Repository: dongzelian/SSF Branch: main Commit: e94e0e704a4e Files: 75 Total size: 315.5 KB Directory structure: gitextract_ti2b9cvr/ ├── LICENSE ├── README.md ├── data/ │ ├── __init__.py │ ├── cub2011.py │ ├── dataset_factory.py │ ├── loader.py │ ├── nabirds.py │ ├── stanford_dogs.py │ ├── transforms_factory.py │ └── vtab.py ├── log/ │ ├── README.md │ └── cifar100.csv ├── models/ │ ├── as_mlp.py │ ├── convnext.py │ ├── swin_transformer.py │ └── vision_transformer.py ├── optim_factory.py ├── requirements.txt ├── train.py ├── train_scripts/ │ ├── asmlp/ │ │ └── cifar_100/ │ │ ├── train_full.sh │ │ ├── train_linear_probe.sh │ │ └── train_ssf.sh │ ├── convnext/ │ │ ├── cifar_100/ │ │ │ ├── train_full.sh │ │ │ ├── train_linear_probe.sh │ │ │ └── train_ssf.sh │ │ └── imagenet_1k/ │ │ ├── train_full.sh │ │ ├── train_linear_probe.sh │ │ └── train_ssf.sh │ ├── swin/ │ │ ├── cifar_100/ │ │ │ ├── train_full.sh │ │ │ ├── train_linear_probe.sh │ │ │ └── train_ssf.sh │ │ └── imagenet_1k/ │ │ ├── train_full.sh │ │ ├── train_linear_probe.sh │ │ └── train_ssf.sh │ └── vit/ │ ├── cifar_100/ │ │ ├── eval_ssf.sh │ │ ├── train_full.sh │ │ ├── train_linear_probe.sh │ │ └── train_ssf.sh │ ├── fgvc/ │ │ ├── cub_2011/ │ │ │ └── train_ssf.sh │ │ ├── nabirds/ │ │ │ └── train_ssf.sh │ │ ├── oxford_flowers/ │ │ │ └── train_ssf.sh │ │ ├── stanford_cars/ │ │ │ └── train_ssf.sh │ │ └── stanford_dogs/ │ │ └── train_ssf.sh │ ├── imagenet_1k/ │ │ ├── train_full.sh │ │ ├── train_linear_probe.sh │ │ └── train_ssf.sh │ ├── imagenet_a/ │ │ └── eval_ssf.sh │ ├── imagenet_c/ │ │ └── eval_ssf.sh │ ├── imagenet_r/ │ │ └── eval_ssf.sh │ └── vtab/ │ ├── caltech101/ │ │ └── train_ssf.sh │ ├── cifar_100/ │ │ └── train_ssf.sh │ ├── clevr_count/ │ │ └── train_ssf.sh │ ├── clevr_dist/ │ │ └── train_ssf.sh │ ├── diabetic_retinopathy/ │ │ └── train_ssf.sh │ ├── dmlab/ │ │ └── train_ssf.sh │ ├── dsprites_loc/ │ │ └── train_ssf.sh │ ├── dsprites_ori/ │ │ └── train_ssf.sh │ ├── dtd/ │ │ └── train_ssf.sh │ ├── eurosat/ │ │ └── train_ssf.sh │ ├── flowers102/ │ │ └── train_ssf.sh │ ├── kitti/ │ │ └── train_ssf.sh │ ├── patch_camelyon/ │ │ └── train_ssf.sh │ ├── pets/ │ │ └── train_ssf.sh │ ├── resisc45/ │ │ └── train_ssf.sh │ ├── smallnorb_azi/ │ │ └── train_ssf.sh │ ├── smallnorb_ele/ │ │ └── train_ssf.sh │ ├── sun397/ │ │ └── train_ssf.sh │ └── svhn/ │ └── train_ssf.sh ├── utils/ │ ├── __init__.py │ ├── imagenet_a.py │ ├── imagenet_r.py │ ├── mce_utils.py │ ├── scaler.py │ └── utils.py └── validate_ood.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: LICENSE ================================================ MIT License Copyright (c) 2022 dongzelian Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

================================================
FILE: README.md
================================================
# SSF for Efficient Model Tuning

This repo is the official implementation of our NeurIPS 2022 paper "Scaling & Shifting Your Features: A New Baseline for Efficient Model Tuning" ([arXiv](https://arxiv.org/abs/2210.08823)).

## Usage

### Install

- Clone this repo:

```bash
git clone https://github.com/dongzelian/SSF.git
cd SSF
```

- Create a conda virtual environment and activate it:

```bash
conda create -n ssf python=3.7 -y
conda activate ssf
```

- Install `CUDA==10.1` with `cudnn7` following the [official installation instructions](https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html)

- Install `PyTorch==1.7.1` and `torchvision==0.8.2` with `CUDA==10.1`:

```bash
conda install pytorch==1.7.1 torchvision==0.8.2 cudatoolkit=10.1 -c pytorch
```

- Install `timm==0.6.5`:

```bash
pip install timm==0.6.5
```

- Install other requirements:

```bash
pip install -r requirements.txt
```

### Data preparation

- FGVC & vtab-1k

  You can follow [VPT](https://github.com/KMnP/vpt) to download them. Since the original [vtab dataset](https://github.com/google-research/task_adaptation/tree/master/task_adaptation/data) is processed with TensorFlow scripts and the processing of some datasets is tricky, we also upload the extracted vtab-1k dataset to [onedrive](https://shanghaitecheducn-my.sharepoint.com/:f:/g/personal/liandz_shanghaitech_edu_cn/EnV6eYPVCPZKhbqi-WSJIO8BOcyQwDwRk6dAThqonQ1Ycw?e=J884Fp) for your convenience. You can download it from there and then use it directly with our [vtab.py](https://github.com/dongzelian/SSF/blob/main/data/vtab.py). (Note that the license is in the [vtab dataset](https://github.com/google-research/task_adaptation/tree/master/task_adaptation/data).)

- CIFAR-100

```bash
wget https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz
```

- For ImageNet-1K, download it from http://image-net.org/, and move validation images to labeled sub-folders. The file structure should look like:

```bash
$ tree data
imagenet
├── train
│   ├── class1
│   │   ├── img1.jpeg
│   │   ├── img2.jpeg
│   │   └── ...
│   ├── class2
│   │   ├── img3.jpeg
│   │   └── ...
│   └── ...
└── val
    ├── class1
    │   ├── img4.jpeg
    │   ├── img5.jpeg
    │   └── ...
    ├── class2
    │   ├── img6.jpeg
    │   └── ...
    └── ...
```

- Robustness & OOD datasets

  Prepare [ImageNet-A](https://github.com/hendrycks/natural-adv-examples), [ImageNet-R](https://github.com/hendrycks/imagenet-r) and [ImageNet-C](https://zenodo.org/record/2235448#.Y04cBOxByFw) for evaluation.

### Pre-trained model preparation

- For the pre-trained ViT-B/16, Swin-B, and ConvNext-B models on ImageNet-21K, the model weights will be automatically downloaded when you fine-tune a pre-trained model via `SSF`. You can also manually download them from [ViT](https://github.com/google-research/vision_transformer), [Swin Transformer](https://github.com/microsoft/Swin-Transformer), and [ConvNext](https://github.com/facebookresearch/ConvNeXt).

- For the pre-trained AS-MLP-B model on ImageNet-1K, you can manually download it from [AS-MLP](https://github.com/svip-lab/AS-MLP).
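If you want to pre-fetch or sanity-check a pre-trained backbone before launching a training script, the short sketch below does so through `timm`'s model zoo. It assumes `timm==0.6.5` and timm's stock ImageNet-21K model names (e.g. `vit_base_patch16_224_in21k`); the SSF training scripts trigger the same download automatically.

```python
# Minimal sketch (assumes timm==0.6.5 and its stock ImageNet-21K model names).
# Downloads the ViT-B/16 ImageNet-21K checkpoint into the local timm cache,
# which is the same backbone the SSF training scripts start from.
import timm

model = timm.create_model('vit_base_patch16_224_in21k', pretrained=True)
print(f'{sum(p.numel() for p in model.parameters()) / 1e6:.1f}M parameters')
```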
### Fine-tuning a pre-trained model via SSF

To fine-tune a pre-trained ViT model via `SSF` on CIFAR-100 or ImageNet-1K, run:

```bash
bash train_scripts/vit/cifar_100/train_ssf.sh
```

or

```bash
bash train_scripts/vit/imagenet_1k/train_ssf.sh
```

You can also find similar scripts for the Swin, ConvNext, and AS-MLP models, which you can use to easily reproduce our results. Enjoy!

### Robustness & OOD

To evaluate the performance of a model fine-tuned via SSF on the robustness & OOD benchmarks, run:

```bash
bash train_scripts/vit/imagenet_a(r, c)/eval_ssf.sh
```

### Citation

If this project is helpful for you, you can cite our paper:

```
@InProceedings{Lian_2022_SSF,
  title={Scaling \& Shifting Your Features: A New Baseline for Efficient Model Tuning},
  author={Lian, Dongze and Zhou, Daquan and Feng, Jiashi and Wang, Xinchao},
  booktitle={Advances in Neural Information Processing Systems (NeurIPS)},
  year={2022}
}
```

### Acknowledgement

The code is built upon [timm](https://github.com/rwightman/pytorch-image-models). The processing of the vtab-1k dataset refers to [vpt](https://github.com/KMnP/vpt), the [vtab github repo](https://github.com/google-research/task_adaptation/tree/master/task_adaptation/data), and [NOAH](https://github.com/ZhangYuanhan-AI/NOAH).

================================================
FILE: data/__init__.py
================================================
from .loader import create_loader
from .dataset_factory import create_dataset

================================================
FILE: data/cub2011.py
================================================
import os
import pandas as pd
from torchvision.datasets import VisionDataset
from torchvision.datasets.folder import default_loader
from torchvision.datasets.utils import download_file_from_google_drive


class Cub2011(VisionDataset):
    """`CUB-200-2011 `_ Dataset.

    Args:
        root (string): Root directory of the dataset.
        train (bool, optional): If True, creates dataset from training set, otherwise
            creates from test set.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g., ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in the root directory. If the dataset is already downloaded, it is not
            downloaded again.
    """
    base_folder = 'CUB_200_2011/images'
    # url = 'http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz'
    file_id = '1hbzc_P1FuxMkcabkgn9ZKinBwW683j45'
    filename = 'CUB_200_2011.tgz'
    tgz_md5 = '97eceeb196236b17998738112f37df78'

    def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
        super(Cub2011, self).__init__(root, transform=transform, target_transform=target_transform)

        self.loader = default_loader
        self.train = train
        if download:
            self._download()

        if not self._check_integrity():
            raise RuntimeError('Dataset not found or corrupted.
You can use download=True to download it') def _load_metadata(self): images = pd.read_csv(os.path.join(self.root, 'CUB_200_2011', 'images.txt'), sep=' ', names=['img_id', 'filepath']) image_class_labels = pd.read_csv(os.path.join(self.root, 'CUB_200_2011', 'image_class_labels.txt'), sep=' ', names=['img_id', 'target']) train_test_split = pd.read_csv(os.path.join(self.root, 'CUB_200_2011', 'train_test_split.txt'), sep=' ', names=['img_id', 'is_training_img']) data = images.merge(image_class_labels, on='img_id') self.data = data.merge(train_test_split, on='img_id') class_names = pd.read_csv(os.path.join(self.root, 'CUB_200_2011', 'classes.txt'), sep=' ', names=['class_name'], usecols=[1]) self.class_names = class_names['class_name'].to_list() if self.train: self.data = self.data[self.data.is_training_img == 1] else: self.data = self.data[self.data.is_training_img == 0] def _check_integrity(self): try: self._load_metadata() except Exception: return False for index, row in self.data.iterrows(): filepath = os.path.join(self.root, self.base_folder, row.filepath) if not os.path.isfile(filepath): print(filepath) return False return True def _download(self): import tarfile if self._check_integrity(): print('Files already downloaded and verified') return download_file_from_google_drive(self.file_id, self.root, self.filename, self.tgz_md5) with tarfile.open(os.path.join(self.root, self.filename), "r:gz") as tar: tar.extractall(path=self.root) def __len__(self): return len(self.data) def __getitem__(self, idx): sample = self.data.iloc[idx] path = os.path.join(self.root, self.base_folder, sample.filepath) target = sample.target - 1 # Targets start at 1 by default, so shift to 0 img = self.loader(path) if self.transform is not None: img = self.transform(img) if self.target_transform is not None: target = self.target_transform(target) return img, target if __name__ == '__main__': train_dataset = Cub2011('./cub2011', train=True, download=False) test_dataset = Cub2011('./cub2011', train=False, download=False) ================================================ FILE: data/dataset_factory.py ================================================ """ Dataset Factory Hacked together by / Copyright 2021, Ross Wightman """ import os #import hub from torchvision.datasets import CIFAR100, CIFAR10, MNIST, QMNIST, KMNIST, FashionMNIST, ImageNet, ImageFolder try: from torchvision.datasets import Places365 has_places365 = True except ImportError: has_places365 = False try: from torchvision.datasets import INaturalist has_inaturalist = True except ImportError: has_inaturalist = False from timm.data.dataset import IterableImageDataset, ImageDataset # my datasets from .stanford_dogs import dogs from .nabirds import NABirds from .cub2011 import Cub2011 from .vtab import VTAB _TORCH_BASIC_DS = dict( cifar10=CIFAR10, cifar100=CIFAR100, mnist=MNIST, qmist=QMNIST, kmnist=KMNIST, fashion_mnist=FashionMNIST, ) _TRAIN_SYNONYM = {'train', 'training'} _EVAL_SYNONYM = {'val', 'valid', 'validation', 'eval', 'evaluation'} _VTAB_DATASET = ['caltech101', 'clevr_count', 'dmlab', 'dsprites_ori', 'eurosat', 'flowers102', 'patch_camelyon', 'smallnorb_azi', 'svhn', 'cifar100', 'clevr_dist', 'dsprites_loc', 'dtd', 'kitti', 'pets', 'resisc45', 'smallnorb_ele', 'sun397', 'diabetic_retinopathy'] def _search_split(root, split): # look for sub-folder with name of split in root and use that if it exists split_name = split.split('[')[0] try_root = os.path.join(root, split_name) if os.path.exists(try_root): return try_root def _try(syn): for s in syn: 
try_root = os.path.join(root, s) if os.path.exists(try_root): return try_root return root if split_name in _TRAIN_SYNONYM: root = _try(_TRAIN_SYNONYM) elif split_name in _EVAL_SYNONYM: root = _try(_EVAL_SYNONYM) return root def create_dataset( name, root, split='validation', search_split=True, class_map=None, load_bytes=False, is_training=False, download=False, batch_size=None, repeats=0, **kwargs ): """ Dataset factory method In parenthesis after each arg are the type of dataset supported for each arg, one of: * folder - default, timm folder (or tar) based ImageDataset * torch - torchvision based datasets * TFDS - Tensorflow-datasets wrapper in IterabeDataset interface via IterableImageDataset * all - any of the above Args: name: dataset name, empty is okay for folder based datasets root: root folder of dataset (all) split: dataset split (all) search_split: search for split specific child fold from root so one can specify `imagenet/` instead of `/imagenet/val`, etc on cmd line / config. (folder, torch/folder) class_map: specify class -> index mapping via text file or dict (folder) load_bytes: load data, return images as undecoded bytes (folder) download: download dataset if not present and supported (TFDS, torch) is_training: create dataset in train mode, this is different from the split. For Iterable / TDFS it enables shuffle, ignored for other datasets. (TFDS) batch_size: batch size hint for (TFDS) repeats: dataset repeats per iteration i.e. epoch (TFDS) **kwargs: other args to pass to dataset Returns: Dataset object """ name = name.lower() if name.startswith('torch/'): name = name.split('/', 2)[-1] torch_kwargs = dict(root=root, download=download, **kwargs) if name in _TORCH_BASIC_DS: ds_class = _TORCH_BASIC_DS[name] use_train = split in _TRAIN_SYNONYM ds = ds_class(train=use_train, **torch_kwargs) elif name == 'inaturalist' or name == 'inat': assert has_inaturalist, 'Please update to PyTorch 1.10, torchvision 0.11+ for Inaturalist' target_type = 'full' split_split = split.split('/') if len(split_split) > 1: target_type = split_split[0].split('_') if len(target_type) == 1: target_type = target_type[0] split = split_split[-1] if split in _TRAIN_SYNONYM: split = '2021_train' elif split in _EVAL_SYNONYM: split = '2021_valid' ds = INaturalist(version=split, target_type=target_type, **torch_kwargs) elif name == 'places365': assert has_places365, 'Please update to a newer PyTorch and torchvision for Places365 dataset.' 
if split in _TRAIN_SYNONYM: split = 'train-standard' elif split in _EVAL_SYNONYM: split = 'val' ds = Places365(split=split, **torch_kwargs) elif name == 'imagenet': if split in _EVAL_SYNONYM: split = 'val' ds = ImageNet(split=split, **torch_kwargs) elif name == 'image_folder' or name == 'folder': # in case torchvision ImageFolder is preferred over timm ImageDataset for some reason if search_split and os.path.isdir(root): # look for split specific sub-folder in root root = _search_split(root, split) ds = ImageFolder(root, **kwargs) else: assert False, f"Unknown torchvision dataset {name}" elif name.startswith('tfds/'): ds = IterableImageDataset( root, parser=name, split=split, is_training=is_training, download=download, batch_size=batch_size, repeats=repeats, **kwargs) else: # FIXME support more advance split cfg for ImageFolder/Tar datasets in the future # define my datasets if name == 'stanford_dogs': ds = dogs(root=root, train=is_training, **kwargs) elif name == 'nabirds': ds = NABirds(root=root, train=is_training, **kwargs) elif name == 'cub2011': ds = Cub2011(root=root, train=is_training, **kwargs) elif name in _VTAB_DATASET: ds = VTAB(root=root, train=is_training, **kwargs) else: if os.path.isdir(os.path.join(root, split)): root = os.path.join(root, split) else: if search_split and os.path.isdir(root): root = _search_split(root, split) ds = ImageDataset(root, parser=name, class_map=class_map, load_bytes=load_bytes, **kwargs) return ds ================================================ FILE: data/loader.py ================================================ """ Loader Factory, Fast Collate, CUDA Prefetcher Prefetcher and Fast Collate inspired by NVIDIA APEX example at https://github.com/NVIDIA/apex/commit/d5e2bb4bdeedd27b1dfaf5bb2b24d6c000dee9be#diff-cf86c282ff7fba81fad27a559379d5bf Hacked together by / Copyright 2019, Ross Wightman """ import random from functools import partial from itertools import repeat from typing import Callable import torch.utils.data import numpy as np from .transforms_factory import create_transform from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.data.distributed_sampler import OrderedDistributedSampler, RepeatAugSampler from timm.data.random_erasing import RandomErasing from timm.data.mixup import FastCollateMixup def fast_collate(batch): """ A fast collation function optimized for uint8 images (np array or torch) and int64 targets (labels)""" assert isinstance(batch[0], tuple) batch_size = len(batch) if isinstance(batch[0][0], tuple): # This branch 'deinterleaves' and flattens tuples of input tensors into one tensor ordered by position # such that all tuple of position n will end up in a torch.split(tensor, batch_size) in nth position inner_tuple_size = len(batch[0][0]) flattened_batch_size = batch_size * inner_tuple_size targets = torch.zeros(flattened_batch_size, dtype=torch.int64) tensor = torch.zeros((flattened_batch_size, *batch[0][0][0].shape), dtype=torch.uint8) for i in range(batch_size): assert len(batch[i][0]) == inner_tuple_size # all input tensor tuples must be same length for j in range(inner_tuple_size): targets[i + j * batch_size] = batch[i][1] tensor[i + j * batch_size] += torch.from_numpy(batch[i][0][j]) return tensor, targets elif isinstance(batch[0][0], np.ndarray): targets = torch.tensor([b[1] for b in batch], dtype=torch.int64) assert len(targets) == batch_size tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8) for i in range(batch_size): tensor[i] += torch.from_numpy(batch[i][0]) 
return tensor, targets elif isinstance(batch[0][0], torch.Tensor): targets = torch.tensor([b[1] for b in batch], dtype=torch.int64) assert len(targets) == batch_size tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8) for i in range(batch_size): tensor[i].copy_(batch[i][0]) return tensor, targets else: assert False def expand_to_chs(x, n): if not isinstance(x, (tuple, list)): x = tuple(repeat(x, n)) elif len(x) == 1: x = x * n else: assert len(x) == n, 'normalization stats must match image channels' return x class PrefetchLoader: def __init__( self, loader, mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, channels=3, fp16=False, re_prob=0., re_mode='const', re_count=1, re_num_splits=0): mean = expand_to_chs(mean, channels) std = expand_to_chs(std, channels) normalization_shape = (1, channels, 1, 1) self.loader = loader self.mean = torch.tensor([x * 255 for x in mean]).cuda().view(normalization_shape) self.std = torch.tensor([x * 255 for x in std]).cuda().view(normalization_shape) self.fp16 = fp16 if fp16: self.mean = self.mean.half() self.std = self.std.half() if re_prob > 0.: self.random_erasing = RandomErasing( probability=re_prob, mode=re_mode, max_count=re_count, num_splits=re_num_splits) else: self.random_erasing = None def __iter__(self): stream = torch.cuda.Stream() first = True for next_input, next_target in self.loader: with torch.cuda.stream(stream): next_input = next_input.cuda(non_blocking=True) next_target = next_target.cuda(non_blocking=True) if self.fp16: next_input = next_input.half().sub_(self.mean).div_(self.std) else: next_input = next_input.float().sub_(self.mean).div_(self.std) if self.random_erasing is not None: next_input = self.random_erasing(next_input) if not first: yield input, target else: first = False torch.cuda.current_stream().wait_stream(stream) input = next_input target = next_target yield input, target def __len__(self): return len(self.loader) @property def sampler(self): return self.loader.sampler @property def dataset(self): return self.loader.dataset @property def mixup_enabled(self): if isinstance(self.loader.collate_fn, FastCollateMixup): return self.loader.collate_fn.mixup_enabled else: return False @mixup_enabled.setter def mixup_enabled(self, x): if isinstance(self.loader.collate_fn, FastCollateMixup): self.loader.collate_fn.mixup_enabled = x def _worker_init(worker_id, worker_seeding='all'): worker_info = torch.utils.data.get_worker_info() assert worker_info.id == worker_id if isinstance(worker_seeding, Callable): seed = worker_seeding(worker_info) random.seed(seed) torch.manual_seed(seed) np.random.seed(seed % (2 ** 32 - 1)) else: assert worker_seeding in ('all', 'part') # random / torch seed already called in dataloader iter class w/ worker_info.seed # to reproduce some old results (same seed + hparam combo), partial seeding is required (skip numpy re-seed) if worker_seeding == 'all': np.random.seed(worker_info.seed % (2 ** 32 - 1)) def create_loader( dataset, input_size, batch_size, is_training=False, use_prefetcher=True, no_aug=False, simple_aug=False, direct_resize=False, re_prob=0., re_mode='const', re_count=1, re_split=False, scale=None, ratio=None, hflip=0.5, vflip=0., color_jitter=0.4, auto_augment=None, num_aug_repeats=0, num_aug_splits=0, interpolation='bilinear', mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_workers=1, distributed=False, crop_pct=None, collate_fn=None, pin_memory=False, fp16=False, tf_preprocessing=False, use_multi_epochs_loader=False, persistent_workers=True, 
worker_seeding='all', ): re_num_splits = 0 if re_split: # apply RE to second half of batch if no aug split otherwise line up with aug split re_num_splits = num_aug_splits or 2 dataset.transform = create_transform( input_size, is_training=is_training, use_prefetcher=use_prefetcher, no_aug=no_aug, simple_aug=simple_aug, direct_resize=direct_resize, scale=scale, ratio=ratio, hflip=hflip, vflip=vflip, color_jitter=color_jitter, auto_augment=auto_augment, interpolation=interpolation, mean=mean, std=std, crop_pct=crop_pct, tf_preprocessing=tf_preprocessing, re_prob=re_prob, re_mode=re_mode, re_count=re_count, re_num_splits=re_num_splits, separate=num_aug_splits > 0, ) sampler = None if distributed and not isinstance(dataset, torch.utils.data.IterableDataset): if is_training: if num_aug_repeats: sampler = RepeatAugSampler(dataset, num_repeats=num_aug_repeats) else: sampler = torch.utils.data.distributed.DistributedSampler(dataset) else: # This will add extra duplicate entries to result in equal num # of samples per-process, will slightly alter validation results sampler = OrderedDistributedSampler(dataset) else: assert num_aug_repeats == 0, "RepeatAugment not currently supported in non-distributed or IterableDataset use" if collate_fn is None: collate_fn = fast_collate if use_prefetcher else torch.utils.data.dataloader.default_collate loader_class = torch.utils.data.DataLoader if use_multi_epochs_loader: loader_class = MultiEpochsDataLoader loader_args = dict( batch_size=batch_size, shuffle=not isinstance(dataset, torch.utils.data.IterableDataset) and sampler is None and is_training, num_workers=num_workers, sampler=sampler, collate_fn=collate_fn, pin_memory=pin_memory, drop_last=is_training, worker_init_fn=partial(_worker_init, worker_seeding=worker_seeding), persistent_workers=persistent_workers ) try: loader = loader_class(dataset, **loader_args) except TypeError as e: loader_args.pop('persistent_workers') # only in Pytorch 1.7+ loader = loader_class(dataset, **loader_args) if use_prefetcher: prefetch_re_prob = re_prob if is_training and not no_aug else 0. loader = PrefetchLoader( loader, mean=mean, std=std, channels=input_size[0], fp16=fp16, re_prob=prefetch_re_prob, re_mode=re_mode, re_count=re_count, re_num_splits=re_num_splits ) return loader class MultiEpochsDataLoader(torch.utils.data.DataLoader): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._DataLoader__initialized = False self.batch_sampler = _RepeatSampler(self.batch_sampler) self._DataLoader__initialized = True self.iterator = super().__iter__() def __len__(self): return len(self.batch_sampler.sampler) def __iter__(self): for i in range(len(self)): yield next(self.iterator) class _RepeatSampler(object): """ Sampler that repeats forever. Args: sampler (Sampler) """ def __init__(self, sampler): self.sampler = sampler def __iter__(self): while True: yield from iter(self.sampler) ================================================ FILE: data/nabirds.py ================================================ import os import pandas as pd import warnings import numpy as np import torch from PIL import Image from torchvision.datasets import VisionDataset from torchvision.datasets.folder import default_loader from torchvision.datasets.utils import check_integrity, extract_archive from torch.utils.data import DataLoader, Dataset class NABirds(Dataset): """`NABirds `_ Dataset. Args: root (string): Root directory of the dataset. 
train (bool, optional): If True, creates dataset from training set, otherwise creates from test set. transform (callable, optional): A function/transform that takes in an PIL image and returns a transformed version. E.g, ``transforms.RandomCrop`` target_transform (callable, optional): A function/transform that takes in the target and transforms it. download (bool, optional): If true, downloads the dataset from the internet and puts it in root directory. If dataset is already downloaded, it is not downloaded again. """ base_folder = 'nabirds/images' def __init__(self, root, train=True, transform=None): dataset_path = os.path.join(root, 'nabirds') self.root = root self.loader = default_loader self.train = train self.transform = transform image_paths = pd.read_csv(os.path.join(dataset_path, 'images.txt'), sep=' ', names=['img_id', 'filepath']) image_class_labels = pd.read_csv(os.path.join(dataset_path, 'image_class_labels.txt'), sep=' ', names=['img_id', 'target']) # Since the raw labels are non-continuous, map them to new ones self.label_map = get_continuous_class_map(image_class_labels['target']) train_test_split = pd.read_csv(os.path.join(dataset_path, 'train_test_split.txt'), sep=' ', names=['img_id', 'is_training_img']) data = image_paths.merge(image_class_labels, on='img_id') self.data = data.merge(train_test_split, on='img_id') # Load in the train / test split if self.train: self.data = self.data[self.data.is_training_img == 1] else: self.data = self.data[self.data.is_training_img == 0] # Load in the class data self.class_names = load_class_names(dataset_path) self.class_hierarchy = load_hierarchy(dataset_path) def __len__(self): return len(self.data) def __getitem__(self, idx): sample = self.data.iloc[idx] path = os.path.join(self.root, self.base_folder, sample.filepath) target = self.label_map[sample.target] img = self.loader(path) if self.transform is not None: img = self.transform(img) return img, target def get_continuous_class_map(class_labels): label_set = set(class_labels) return {k: i for i, k in enumerate(label_set)} def load_class_names(dataset_path=''): names = {} with open(os.path.join(dataset_path, 'classes.txt')) as f: for line in f: pieces = line.strip().split() class_id = pieces[0] names[class_id] = ' '.join(pieces[1:]) return names def load_hierarchy(dataset_path=''): parents = {} with open(os.path.join(dataset_path, 'hierarchy.txt')) as f: for line in f: pieces = line.strip().split() child_id, parent_id = pieces parents[child_id] = parent_id return parents ================================================ FILE: data/stanford_dogs.py ================================================ from __future__ import print_function from PIL import Image from os.path import join import os import scipy.io import torch.utils.data as data from torchvision.datasets.utils import download_url, list_dir, list_files class dogs(data.Dataset): """`Stanford Dogs `_ Dataset. Args: root (string): Root directory of dataset where directory ``omniglot-py`` exists. cropped (bool, optional): If true, the images will be cropped into the bounding box specified in the annotations transform (callable, optional): A function/transform that takes in an PIL image and returns a transformed version. E.g, ``transforms.RandomCrop`` target_transform (callable, optional): A function/transform that takes in the target and transforms it. download (bool, optional): If true, downloads the dataset tar files from the internet and puts it in root directory. 
If the tar files are already downloaded, they are not downloaded again. """ #folder = 'StanfordDogs' folder = '' download_url_prefix = 'http://vision.stanford.edu/aditya86/ImageNetDogs' def __init__(self, root, train=True, cropped=False, transform=None, target_transform=None, download=False): self.root = join(os.path.expanduser(root), self.folder) self.train = train self.cropped = cropped self.transform = transform self.target_transform = target_transform if download: self.download() split = self.load_split() self.images_folder = join(self.root, 'Images') self.annotations_folder = join(self.root, 'Annotation') self._breeds = list_dir(self.images_folder) if self.cropped: self._breed_annotations = [[(annotation, box, idx) for box in self.get_boxes(join(self.annotations_folder, annotation))] for annotation, idx in split] self._flat_breed_annotations = sum(self._breed_annotations, []) self._flat_breed_images = [(annotation+'.jpg', idx) for annotation, box, idx in self._flat_breed_annotations] else: self._breed_images = [(annotation+'.jpg', idx) for annotation, idx in split] self._flat_breed_images = self._breed_images self.classes = ["Chihuaha", "Japanese Spaniel", "Maltese Dog", "Pekinese", "Shih-Tzu", "Blenheim Spaniel", "Papillon", "Toy Terrier", "Rhodesian Ridgeback", "Afghan Hound", "Basset Hound", "Beagle", "Bloodhound", "Bluetick", "Black-and-tan Coonhound", "Walker Hound", "English Foxhound", "Redbone", "Borzoi", "Irish Wolfhound", "Italian Greyhound", "Whippet", "Ibizian Hound", "Norwegian Elkhound", "Otterhound", "Saluki", "Scottish Deerhound", "Weimaraner", "Staffordshire Bullterrier", "American Staffordshire Terrier", "Bedlington Terrier", "Border Terrier", "Kerry Blue Terrier", "Irish Terrier", "Norfolk Terrier", "Norwich Terrier", "Yorkshire Terrier", "Wirehaired Fox Terrier", "Lakeland Terrier", "Sealyham Terrier", "Airedale", "Cairn", "Australian Terrier", "Dandi Dinmont", "Boston Bull", "Miniature Schnauzer", "Giant Schnauzer", "Standard Schnauzer", "Scotch Terrier", "Tibetan Terrier", "Silky Terrier", "Soft-coated Wheaten Terrier", "West Highland White Terrier", "Lhasa", "Flat-coated Retriever", "Curly-coater Retriever", "Golden Retriever", "Labrador Retriever", "Chesapeake Bay Retriever", "German Short-haired Pointer", "Vizsla", "English Setter", "Irish Setter", "Gordon Setter", "Brittany", "Clumber", "English Springer Spaniel", "Welsh Springer Spaniel", "Cocker Spaniel", "Sussex Spaniel", "Irish Water Spaniel", "Kuvasz", "Schipperke", "Groenendael", "Malinois", "Briard", "Kelpie", "Komondor", "Old English Sheepdog", "Shetland Sheepdog", "Collie", "Border Collie", "Bouvier des Flandres", "Rottweiler", "German Shepard", "Doberman", "Miniature Pinscher", "Greater Swiss Mountain Dog", "Bernese Mountain Dog", "Appenzeller", "EntleBucher", "Boxer", "Bull Mastiff", "Tibetan Mastiff", "French Bulldog", "Great Dane", "Saint Bernard", "Eskimo Dog", "Malamute", "Siberian Husky", "Affenpinscher", "Basenji", "Pug", "Leonberg", "Newfoundland", "Great Pyrenees", "Samoyed", "Pomeranian", "Chow", "Keeshond", "Brabancon Griffon", "Pembroke", "Cardigan", "Toy Poodle", "Miniature Poodle", "Standard Poodle", "Mexican Hairless", "Dingo", "Dhole", "African Hunting Dog"] def __len__(self): return len(self._flat_breed_images) def __getitem__(self, index): """ Args: index (int): Index Returns: tuple: (image, target) where target is index of the target character class. 
""" image_name, target_class = self._flat_breed_images[index] image_path = join(self.images_folder, image_name) image = Image.open(image_path).convert('RGB') if self.cropped: image = image.crop(self._flat_breed_annotations[index][1]) if self.transform: image = self.transform(image) if self.target_transform: target_class = self.target_transform(target_class) return image, target_class def download(self): import tarfile if os.path.exists(join(self.root, 'Images')) and os.path.exists(join(self.root, 'Annotation')): if len(os.listdir(join(self.root, 'Images'))) == len(os.listdir(join(self.root, 'Annotation'))) == 120: print('Files already downloaded and verified') return for filename in ['images', 'annotation', 'lists']: tar_filename = filename + '.tar' url = self.download_url_prefix + '/' + tar_filename download_url(url, self.root, tar_filename, None) print('Extracting downloaded file: ' + join(self.root, tar_filename)) with tarfile.open(join(self.root, tar_filename), 'r') as tar_file: tar_file.extractall(self.root) os.remove(join(self.root, tar_filename)) @staticmethod def get_boxes(path): import xml.etree.ElementTree e = xml.etree.ElementTree.parse(path).getroot() boxes = [] for objs in e.iter('object'): boxes.append([int(objs.find('bndbox').find('xmin').text), int(objs.find('bndbox').find('ymin').text), int(objs.find('bndbox').find('xmax').text), int(objs.find('bndbox').find('ymax').text)]) return boxes def load_split(self): if self.train: split = scipy.io.loadmat(join(self.root, 'train_list.mat'))['annotation_list'] labels = scipy.io.loadmat(join(self.root, 'train_list.mat'))['labels'] else: split = scipy.io.loadmat(join(self.root, 'test_list.mat'))['annotation_list'] labels = scipy.io.loadmat(join(self.root, 'test_list.mat'))['labels'] split = [item[0][0] for item in split] labels = [item[0]-1 for item in labels] return list(zip(split, labels)) def stats(self): counts = {} for index in range(len(self._flat_breed_images)): image_name, target_class = self._flat_breed_images[index] if target_class not in counts.keys(): counts[target_class] = 1 else: counts[target_class] += 1 print("%d samples spanning %d classes (avg %f per class)"%(len(self._flat_breed_images), len(counts.keys()), float(len(self._flat_breed_images))/float(len(counts.keys())))) return counts ================================================ FILE: data/transforms_factory.py ================================================ """ Transforms Factory Factory methods for building image transforms for use with TIMM (PyTorch Image Models) Hacked together by / Copyright 2019, Ross Wightman """ import math import torch from torchvision import transforms from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, DEFAULT_CROP_PCT from timm.data.auto_augment import rand_augment_transform, augment_and_mix_transform, auto_augment_transform from timm.data.transforms import str_to_interp_mode, str_to_pil_interp, RandomResizedCropAndInterpolation, ToNumpy from timm.data.random_erasing import RandomErasing def transforms_direct_resize( img_size=224, interpolation='bilinear', use_prefetcher=False, mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, ): if interpolation == 'random': # random interpolation not supported with no-aug interpolation = 'bilinear' tfl = [ transforms.Resize(img_size, interpolation=str_to_interp_mode(interpolation)), transforms.CenterCrop(img_size) ] if use_prefetcher: # prefetcher and collate will handle tensor conversion and norm tfl += [ToNumpy()] else: tfl += [ transforms.ToTensor(), 
transforms.Normalize( mean=torch.tensor(mean), std=torch.tensor(std)) ] return transforms.Compose(tfl) def transforms_simpleaug_train( img_size=224, scale=None, ratio=None, hflip=0.5, interpolation='bilinear', use_prefetcher=False, mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, ): tfl = [ RandomResizedCropAndInterpolation(img_size, scale=scale, ratio=ratio, interpolation=interpolation), transforms.RandomHorizontalFlip(p=hflip) ] if use_prefetcher: # prefetcher and collate will handle tensor conversion and norm tfl += [ToNumpy()] else: tfl += [ transforms.ToTensor(), transforms.Normalize( mean=torch.tensor(mean), std=torch.tensor(std)) ] return transforms.Compose(tfl) def transforms_imagenet_train( img_size=224, scale=None, ratio=None, hflip=0.5, vflip=0., color_jitter=0.4, auto_augment=None, interpolation='random', use_prefetcher=False, mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, re_prob=0., re_mode='const', re_count=1, re_num_splits=0, separate=False, ): """ If separate==True, the transforms are returned as a tuple of 3 separate transforms for use in a mixing dataset that passes * all data through the first (primary) transform, called the 'clean' data * a portion of the data through the secondary transform * normalizes and converts the branches above with the third, final transform """ scale = tuple(scale or (0.08, 1.0)) # default imagenet scale range ratio = tuple(ratio or (3./4., 4./3.)) # default imagenet ratio range primary_tfl = [ RandomResizedCropAndInterpolation(img_size, scale=scale, ratio=ratio, interpolation=interpolation)] if hflip > 0.: primary_tfl += [transforms.RandomHorizontalFlip(p=hflip)] if vflip > 0.: primary_tfl += [transforms.RandomVerticalFlip(p=vflip)] secondary_tfl = [] if auto_augment: assert isinstance(auto_augment, str) if isinstance(img_size, (tuple, list)): img_size_min = min(img_size) else: img_size_min = img_size aa_params = dict( translate_const=int(img_size_min * 0.45), img_mean=tuple([min(255, round(255 * x)) for x in mean]), ) if interpolation and interpolation != 'random': aa_params['interpolation'] = str_to_pil_interp(interpolation) if auto_augment.startswith('rand'): secondary_tfl += [rand_augment_transform(auto_augment, aa_params)] elif auto_augment.startswith('augmix'): aa_params['translate_pct'] = 0.3 secondary_tfl += [augment_and_mix_transform(auto_augment, aa_params)] else: secondary_tfl += [auto_augment_transform(auto_augment, aa_params)] elif color_jitter is not None: # color jitter is enabled when not using AA if isinstance(color_jitter, (list, tuple)): # color jitter should be a 3-tuple/list if spec brightness/contrast/saturation # or 4 if also augmenting hue assert len(color_jitter) in (3, 4) else: # if it's a scalar, duplicate for brightness, contrast, and saturation, no hue color_jitter = (float(color_jitter),) * 3 secondary_tfl += [transforms.ColorJitter(*color_jitter)] final_tfl = [] if use_prefetcher: # prefetcher and collate will handle tensor conversion and norm final_tfl += [ToNumpy()] else: final_tfl += [ transforms.ToTensor(), transforms.Normalize( mean=torch.tensor(mean), std=torch.tensor(std)) ] if re_prob > 0.: final_tfl.append( RandomErasing(re_prob, mode=re_mode, max_count=re_count, num_splits=re_num_splits, device='cpu')) if separate: return transforms.Compose(primary_tfl), transforms.Compose(secondary_tfl), transforms.Compose(final_tfl) else: return transforms.Compose(primary_tfl + secondary_tfl + final_tfl) def transforms_imagenet_eval( img_size=224, crop_pct=None, interpolation='bilinear', 
use_prefetcher=False, mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD): crop_pct = crop_pct or DEFAULT_CROP_PCT if isinstance(img_size, (tuple, list)): assert len(img_size) == 2 if img_size[-1] == img_size[-2]: # fall-back to older behaviour so Resize scales to shortest edge if target is square scale_size = int(math.floor(img_size[0] / crop_pct)) else: scale_size = tuple([int(x / crop_pct) for x in img_size]) else: scale_size = int(math.floor(img_size / crop_pct)) tfl = [ transforms.Resize(scale_size, interpolation=str_to_interp_mode(interpolation)), transforms.CenterCrop(img_size), ] if use_prefetcher: # prefetcher and collate will handle tensor conversion and norm tfl += [ToNumpy()] else: tfl += [ transforms.ToTensor(), transforms.Normalize( mean=torch.tensor(mean), std=torch.tensor(std)) ] return transforms.Compose(tfl) def create_transform( input_size, is_training=False, use_prefetcher=False, no_aug=False, simple_aug=False, direct_resize=False, scale=None, ratio=None, hflip=0.5, vflip=0., color_jitter=0.4, auto_augment=None, interpolation='bilinear', mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, re_prob=0., re_mode='const', re_count=1, re_num_splits=0, crop_pct=None, tf_preprocessing=False, separate=False): if isinstance(input_size, (tuple, list)): img_size = input_size[-2:] else: img_size = input_size if tf_preprocessing and use_prefetcher: assert not separate, "Separate transforms not supported for TF preprocessing" from timm.data.tf_preprocessing import TfPreprocessTransform transform = TfPreprocessTransform( is_training=is_training, size=img_size, interpolation=interpolation) else: if is_training: if no_aug: assert not separate, "Cannot perform split augmentation with no_aug" transform = transforms_direct_resize( img_size, interpolation=interpolation, use_prefetcher=use_prefetcher, mean=mean, std=std) elif simple_aug: transform = transforms_simpleaug_train( img_size, scale=scale, ratio=ratio, hflip=hflip, interpolation=interpolation, use_prefetcher=use_prefetcher, mean=mean, std=std) else: transform = transforms_imagenet_train( img_size, scale=scale, ratio=ratio, hflip=hflip, vflip=vflip, color_jitter=color_jitter, auto_augment=auto_augment, interpolation=interpolation, use_prefetcher=use_prefetcher, mean=mean, std=std, re_prob=re_prob, re_mode=re_mode, re_count=re_count, re_num_splits=re_num_splits, separate=separate) else: if direct_resize: #print('direct_resize') transform = transforms_direct_resize( img_size, interpolation=interpolation, use_prefetcher=use_prefetcher, mean=mean, std=std) else: assert not separate, "Separate transforms not supported for validation preprocessing" transform = transforms_imagenet_eval( img_size, interpolation=interpolation, use_prefetcher=use_prefetcher, mean=mean, std=std, crop_pct=crop_pct) return transform ================================================ FILE: data/vtab.py ================================================ import os from torchvision.datasets.folder import ImageFolder, default_loader class VTAB(ImageFolder): def __init__(self, root, train=True, transform=None, target_transform=None, mode=None,is_individual_prompt=False,**kwargs): self.dataset_root = root self.loader = default_loader self.target_transform = None self.transform = transform train_list_path = os.path.join(self.dataset_root, 'train800val200.txt') test_list_path = os.path.join(self.dataset_root, 'test.txt') # train_list_path = os.path.join(self.dataset_root, 'train800.txt') # test_list_path = os.path.join(self.dataset_root, 'val200.txt') self.samples = [] if 
train: with open(train_list_path, 'r') as f: for line in f: img_name = line.split(' ')[0] label = int(line.split(' ')[1]) self.samples.append((os.path.join(root,img_name), label)) else: with open(test_list_path, 'r') as f: for line in f: img_name = line.split(' ')[0] label = int(line.split(' ')[1]) self.samples.append((os.path.join(root,img_name), label)) ================================================ FILE: log/README.md ================================================ ================================================ FILE: log/cifar100.csv ================================================ epoch,train_loss,eval_loss,eval_top1,eval_top5 0,5.603768242730035,6.0231625,0.96,4.89 1,4.440675179163615,5.8859375,1.02,5.18 2,3.186616155836317,5.63614375,1.43,6.43 3,2.827679475148519,5.345575,1.92,8.46 4,2.7690945731268988,5.04094375,2.7,11.41 5,2.6746084690093994,4.7265375,4.07,15.29 6,2.5769188139173718,4.404325,6.44,20.77 7,2.658121665318807,4.07631875,10.12,28.31 8,2.5106263955434165,3.74056875,15.35,37.6 9,2.638317664464315,3.398590625,23.04,49.38 10,2.4921720822652182,3.05421875,32.87,61.79 11,2.5904562208387585,2.714753125,44.93,72.77 12,2.521847221586439,2.38726875,57.29,81.89 13,2.555907726287842,2.0806125,67.52,88.52 14,2.678542561001248,1.8002078125,75.47,92.47 15,2.5564871629079184,1.549046875,81.04,94.85 16,2.5498899353875055,1.329515625,84.56,96.42 17,2.555638392766317,1.1425953125,86.4,97.29 18,2.5207514233059354,0.9856625,87.97,97.84 19,2.5135546260409884,0.8570640625,88.97,98.17 20,2.417558749516805,0.75346328125,89.95,98.5 21,2.4377992947896323,0.669475,90.59,98.76 22,2.5416789849599204,0.6010578125,91.05,98.88 23,2.4467749065823026,0.5459734375,91.42,98.99 24,2.3756895065307617,0.5007609375,91.64,99.1 25,2.4439392619662814,0.46489609375,91.92,99.19 26,2.4532913896772595,0.43501953125,92.09,99.26 27,2.508685827255249,0.410413671875,92.19,99.3 28,2.489635467529297,0.390225,92.39,99.31 29,2.4972835646735296,0.37323125,92.44,99.36 30,2.4116059409247503,0.358880078125,92.54,99.42 31,2.4685837162865534,0.346959765625,92.7,99.46 32,2.3165384001202054,0.336875,92.71,99.47 33,2.4255740377638073,0.32844765625,92.71,99.46 34,2.4002869658999972,0.321278125,92.76,99.45 35,2.4275462097591824,0.314995703125,92.8,99.47 36,2.407806317011515,0.309953125,92.91,99.49 37,2.402532418568929,0.30562421875,92.98,99.51 38,2.4551497830284967,0.30197578125,93.06,99.51 39,2.4293878608279758,0.2986171875,93.17,99.51 40,2.4136725531684027,0.296045703125,93.24,99.51 41,2.4232251379224987,0.29378125,93.32,99.52 42,2.2722683482699924,0.2918,93.31,99.52 43,2.397341330846151,0.290190625,93.35,99.52 44,2.406024138132731,0.289073828125,93.39,99.53 45,2.3778079880608454,0.288259765625,93.42,99.53 46,2.4535727500915527,0.28774140625,93.47,99.53 47,2.4934064282311335,0.287232421875,93.56,99.53 48,2.325168079800076,0.287040234375,93.59,99.52 49,2.396822929382324,0.286894140625,93.58,99.55 50,2.3157892756991916,0.287205859375,93.61,99.55 51,2.4792808956570096,0.287432421875,93.63,99.56 52,2.366891860961914,0.287920703125,93.65,99.56 53,2.345345550113254,0.2883359375,93.67,99.56 54,2.303606006834242,0.288706640625,93.66,99.56 55,2.398844109641181,0.28921875,93.68,99.57 56,2.3735866281721325,0.2897921875,93.68,99.57 57,2.4807073805067272,0.2904640625,93.69,99.58 58,2.349070734447903,0.291125,93.76,99.58 59,2.375280910068088,0.292005859375,93.8,99.58 60,2.33055845896403,0.293062109375,93.82,99.57 61,2.4362279574076333,0.29396484375,93.79,99.57 62,2.317348506715563,0.29523125,93.8,99.57 
63,2.4144566191567316,0.29618203125,93.8,99.57 64,2.3383621904585095,0.297141796875,93.79,99.56 65,2.4115795029534235,0.298030078125,93.79,99.56 66,2.3216844929589167,0.299159375,93.82,99.57 67,2.318764951494005,0.30013984375,93.77,99.56 68,2.269577423731486,0.30101484375,93.75,99.56 69,2.369996494717068,0.30192890625,93.79,99.55 70,2.3662983311547174,0.303022265625,93.81,99.56 71,2.303777880138821,0.303908984375,93.85,99.56 72,2.255165616671244,0.30502734375,93.82,99.56 73,2.2920398712158203,0.305877734375,93.82,99.56 74,2.310059520933363,0.30675234375,93.87,99.56 75,2.325947019788954,0.307698046875,93.82,99.57 76,2.2926743825276694,0.308682421875,93.82,99.55 77,2.317892154057821,0.3095203125,93.86,99.55 78,2.3864409658643932,0.310484375,93.86,99.55 79,2.2767338487837048,0.311702734375,93.88,99.55 80,2.459476047092014,0.31251640625,93.89,99.55 81,2.3905074066585965,0.313465234375,93.87,99.55 82,2.380774630440606,0.314390625,93.88,99.54 83,2.191994031270345,0.31534296875,93.91,99.54 84,2.316111962000529,0.31621796875,93.9,99.54 85,2.307388676537408,0.31720625,93.89,99.54 86,2.27423980500963,0.318103125,93.9,99.54 87,2.3265264564090304,0.31895,93.91,99.54 88,2.265656683180067,0.31984140625,93.94,99.54 89,2.3482420444488525,0.320740625,93.97,99.54 90,2.3093814849853516,0.321472265625,93.97,99.54 91,2.3022206094529896,0.322198046875,93.97,99.54 92,2.3547442489200168,0.3228609375,93.99,99.54 93,2.246518611907959,0.3234953125,93.97,99.54 94,2.3851645787556968,0.324027734375,93.98,99.56 95,2.3422129816479154,0.324636328125,93.93,99.56 96,2.2714282936520047,0.325282421875,93.93,99.56 97,2.374925719367133,0.32581875,93.93,99.55 98,2.3505734867519803,0.326328125,93.94,99.54 99,2.418484025531345,0.326844140625,93.95,99.55 100,2.293912728627523,0.32743125,93.94,99.55 101,2.446575509177314,0.327946875,93.95,99.55 102,2.309349775314331,0.32843125,93.96,99.55 103,2.316555658976237,0.3289484375,93.95,99.55 104,2.3285930156707764,0.329358203125,93.94,99.56 105,2.2936308648851185,0.32980234375,93.93,99.56 106,2.3283233642578125,0.330234765625,93.95,99.57 107,2.3618712955050998,0.33079453125,93.95,99.57 108,2.3460081683264837,0.3311765625,93.95,99.57 109,2.352536598841349,0.33160625,93.96,99.57 ================================================ FILE: models/as_mlp.py ================================================ # -------------------------------------------------------- # AS-MLP # Licensed under The MIT License [see LICENSE for details] # Written by Zehao Yu and Dongze Lian (AS-MLP) # -------------------------------------------------------- import logging import math from copy import deepcopy from typing import Optional import torch import torch.nn as nn import torch.nn.functional as F import torch.utils.checkpoint as checkpoint from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.models.fx_features import register_notrace_function from timm.models.helpers import build_model_with_cfg from timm.models.layers import DropPath, to_2tuple, trunc_normal_ from timm.models.layers import _assert from timm.models.registry import register_model from timm.models.vision_transformer import checkpoint_filter_fn _logger = logging.getLogger(__name__) def _cfg(url='', file='', **kwargs): return { 'url': url, 'file': file, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'first_conv': 'patch_embed.proj', 'classifier': 'head', **kwargs } default_cfgs = { 
'as_base_patch4_window7_224': _cfg( file='/path/to/asmlp_base_patch4_shift5_224.pth' ), } class Mlp(nn.Module): def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0., tuning_mode=None): super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features self.fc1 = nn.Conv2d(in_features, hidden_features, 1, 1) self.act = act_layer() self.fc2 = nn.Conv2d(hidden_features, out_features, 1, 1) self.drop = nn.Dropout(drop) self.tuning_mode = tuning_mode if tuning_mode == 'ssf': self.ssf_scale_1, self.ssf_shift_1 = init_ssf_scale_shift(hidden_features) self.ssf_scale_2, self.ssf_shift_2 = init_ssf_scale_shift(out_features) def forward(self, x): x = self.fc1(x) if self.tuning_mode == 'ssf': x = ssf_ada(x, self.ssf_scale_1, self.ssf_shift_1) x = self.act(x) x = self.drop(x) x = self.fc2(x) if self.tuning_mode == 'ssf': x = ssf_ada(x, self.ssf_scale_2, self.ssf_shift_2) x = self.drop(x) return x class AxialShift(nn.Module): r""" Axial shift Args: dim (int): Number of input channels. shift_size (int): shift size . as_bias (bool, optional): If True, add a learnable bias to as mlp. Default: True proj_drop (float, optional): Dropout ratio of output. Default: 0.0 """ def __init__(self, dim, shift_size, as_bias=True, proj_drop=0., tuning_mode=None): super().__init__() self.dim = dim self.shift_size = shift_size self.pad = shift_size // 2 self.conv1 = nn.Conv2d(dim, dim, 1, 1, 0, groups=1, bias=as_bias) self.conv2_1 = nn.Conv2d(dim, dim, 1, 1, 0, groups=1, bias=as_bias) self.conv2_2 = nn.Conv2d(dim, dim, 1, 1, 0, groups=1, bias=as_bias) self.conv3 = nn.Conv2d(dim, dim, 1, 1, 0, groups=1, bias=as_bias) self.actn = nn.GELU() self.norm1 = MyNorm(dim) self.norm2 = MyNorm(dim) self.tuning_mode = tuning_mode if tuning_mode == 'ssf': self.ssf_scale_1, self.ssf_shift_1 = init_ssf_scale_shift(dim) self.ssf_scale_2, self.ssf_shift_2 = init_ssf_scale_shift(dim) self.ssf_scale_3, self.ssf_shift_3 = init_ssf_scale_shift(dim) self.ssf_scale_4, self.ssf_shift_4 = init_ssf_scale_shift(dim) def forward(self, x): """ Args: x: input features with shape of (num_windows*B, N, C) mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None """ B_, C, H, W = x.shape x = self.conv1(x) if self.tuning_mode == 'ssf': x = ssf_ada(x, self.ssf_scale_1, self.ssf_shift_1) x = self.norm1(x) x = self.actn(x) x = F.pad(x, (self.pad, self.pad, self.pad, self.pad) , "constant", 0) xs = torch.chunk(x, self.shift_size, 1) def shift(dim): x_shift = [ torch.roll(x_c, shift, dim) for x_c, shift in zip(xs, range(-self.pad, self.pad+1))] x_cat = torch.cat(x_shift, 1) x_cat = torch.narrow(x_cat, 2, self.pad, H) x_cat = torch.narrow(x_cat, 3, self.pad, W) return x_cat x_shift_lr = shift(3) x_shift_td = shift(2) if self.tuning_mode == 'ssf': x_lr = ssf_ada(self.conv2_1(x_shift_lr), self.ssf_scale_2, self.ssf_shift_2) x_td = ssf_ada(self.conv2_2(x_shift_td), self.ssf_scale_3, self.ssf_shift_3) else: x_lr = self.conv2_1(x_shift_lr) x_td = self.conv2_2(x_shift_td) x_lr = self.actn(x_lr) x_td = self.actn(x_td) x = x_lr + x_td x = self.norm2(x) x = self.conv3(x) if self.tuning_mode == 'ssf': x = ssf_ada(x, self.ssf_scale_4, self.ssf_shift_4) return x def extra_repr(self) -> str: return f'dim={self.dim}, shift_size={self.shift_size}' def flops(self, N): # calculate flops for 1 window with token length of N flops = 0 # conv1 flops += N * self.dim * self.dim # norm 1 flops += N * self.dim # conv2_1 conv2_2 flops += N * self.dim * self.dim * 2 # x_lr + x_td flops 
+= N * self.dim # norm2 flops += N * self.dim # norm3 flops += N * self.dim * self.dim return flops class AxialShiftedBlock(nn.Module): r""" Swin Transformer Block. Args: dim (int): Number of input channels. input_resolution (tuple[int]): Input resulotion. shift_size (int): Shift size. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. as_bias (bool, optional): If True, add a learnable bias to Axial Mlp. Default: True drop (float, optional): Dropout rate. Default: 0.0 drop_path (float, optional): Stochastic depth rate. Default: 0.0 act_layer (nn.Module, optional): Activation layer. Default: nn.GELU norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm """ def __init__(self, dim, input_resolution, shift_size=7, mlp_ratio=4., as_bias=True, drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, tuning_mode=None): super().__init__() self.dim = dim self.input_resolution = input_resolution self.shift_size = shift_size self.mlp_ratio = mlp_ratio self.norm1 = norm_layer(dim) self.axial_shift = AxialShift(dim, shift_size=shift_size, as_bias=as_bias, proj_drop=drop, tuning_mode=tuning_mode) self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() self.norm2 = norm_layer(dim) mlp_hidden_dim = int(dim * mlp_ratio) self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop, tuning_mode=tuning_mode) self.tuning_mode = tuning_mode if tuning_mode == 'ssf': self.ssf_scale_1, self.ssf_shift_1 = init_ssf_scale_shift(dim) self.ssf_scale_2, self.ssf_shift_2 = init_ssf_scale_shift(dim) def forward(self, x): B, C, H, W = x.shape shortcut = x x = self.norm1(x) if self.tuning_mode == 'ssf': x = ssf_ada(x, self.ssf_scale_1, self.ssf_shift_1) # axial shift block x = self.axial_shift(x) # B, C, H, W # FFN x = shortcut + self.drop_path(x) if self.tuning_mode == 'ssf': x = x + self.drop_path(self.mlp(ssf_ada(self.norm2(x), self.ssf_scale_2, self.ssf_shift_2))) else: x = x + self.drop_path(self.mlp(self.norm2(x))) return x def extra_repr(self) -> str: return f"dim={self.dim}, input_resolution={self.input_resolution}, " \ f"shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" def flops(self): flops = 0 H, W = self.input_resolution # norm1 flops += self.dim * H * W # shift mlp flops += self.axial_shift.flops(H * W) # mlp flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio # norm2 flops += self.dim * H * W return flops class PatchMerging(nn.Module): r""" Patch Merging Layer. Args: input_resolution (tuple[int]): Resolution of input feature. dim (int): Number of input channels. norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm """ def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm, tuning_mode=None): super().__init__() self.input_resolution = input_resolution self.dim = dim self.reduction = nn.Conv2d(4 * dim, 2 * dim, 1, 1, bias=False) self.norm = norm_layer(4 * dim) self.tuning_mode = tuning_mode if tuning_mode == 'ssf': self.ssf_scale_1, self.ssf_shift_1 = init_ssf_scale_shift(4 * dim) def forward(self, x): """ x: B, H*W, C """ B, C, H, W = x.shape #assert L == H * W, "input feature has wrong size" assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." 
x = x.view(B, C, H, W) x0 = x[:, :, 0::2, 0::2] # B C H/2 W/2 x1 = x[:, :, 1::2, 0::2] # B C H/2 W/2 x2 = x[:, :, 0::2, 1::2] # B C H/2 W/2 x3 = x[:, :, 1::2, 1::2] # B C H/2 W/2 x = torch.cat([x0, x1, x2, x3], 1) # B 4*C H/2 W/2 x = self.norm(x) if self.tuning_mode == 'ssf': x = ssf_ada(x, self.ssf_scale_1, self.ssf_shift_1) x = self.reduction(x) return x def extra_repr(self) -> str: return f"input_resolution={self.input_resolution}, dim={self.dim}" def flops(self): H, W = self.input_resolution flops = H * W * self.dim flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim return flops class BasicLayer(nn.Module): """ A basic Swin Transformer layer for one stage. Args: dim (int): Number of input channels. input_resolution (tuple[int]): Input resolution. depth (int): Number of blocks. num_heads (int): Number of attention heads. window_size (int): Local window size. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. drop (float, optional): Dropout rate. Default: 0.0 attn_drop (float, optional): Attention dropout rate. Default: 0.0 drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. """ def __init__(self, dim, input_resolution, depth, shift_size, mlp_ratio=4., as_bias=True, drop=0., drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, tuning_mode=None): super().__init__() self.dim = dim self.input_resolution = input_resolution self.depth = depth self.use_checkpoint = use_checkpoint # build blocks self.blocks = nn.ModuleList([ AxialShiftedBlock(dim=dim, input_resolution=input_resolution, shift_size=shift_size, mlp_ratio=mlp_ratio, as_bias=as_bias, drop=drop, drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, norm_layer=norm_layer, tuning_mode=tuning_mode[i]) for i in range(depth)]) # patch merging layer if downsample is not None: self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer, tuning_mode=tuning_mode) else: self.downsample = None def forward(self, x): for blk in self.blocks: if self.use_checkpoint: x = checkpoint.checkpoint(blk, x) else: x = blk(x) if self.downsample is not None: x = self.downsample(x) return x def extra_repr(self) -> str: return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" def flops(self): flops = 0 for blk in self.blocks: flops += blk.flops() if self.downsample is not None: flops += self.downsample.flops() return flops class PatchEmbed(nn.Module): r""" Image to Patch Embedding Args: img_size (int): Image size. Default: 224. patch_size (int): Patch token size. Default: 4. in_chans (int): Number of input image channels. Default: 3. embed_dim (int): Number of linear projection output channels. Default: 96. norm_layer (nn.Module, optional): Normalization layer. 
Default: None """ def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None, tuning_mode=None): super().__init__() img_size = to_2tuple(img_size) patch_size = to_2tuple(patch_size) patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] self.img_size = img_size self.patch_size = patch_size self.patches_resolution = patches_resolution self.num_patches = patches_resolution[0] * patches_resolution[1] self.in_chans = in_chans self.embed_dim = embed_dim self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) if norm_layer is not None: self.norm = norm_layer(embed_dim) else: self.norm = None self.tuning_mode = tuning_mode if tuning_mode == 'ssf': self.ssf_scale_1, self.ssf_shift_1 = init_ssf_scale_shift(embed_dim) if norm_layer: self.ssf_scale_2, self.ssf_shift_2 = init_ssf_scale_shift(embed_dim) def forward(self, x): B, C, H, W = x.shape # FIXME look at relaxing size constraints assert H == self.img_size[0] and W == self.img_size[1], \ f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." x = self.proj(x)#.flatten(2).transpose(1, 2) # B Ph*Pw C if self.tuning_mode == 'ssf': x = ssf_ada(x, self.ssf_scale_1, self.ssf_shift_1) if self.norm is not None: x = self.norm(x) x = ssf_ada(x, self.ssf_scale_2, self.ssf_shift_2) else: if self.norm is not None: x = self.norm(x) return x def flops(self): Ho, Wo = self.patches_resolution flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) if self.norm is not None: flops += Ho * Wo * self.embed_dim return flops def MyNorm(dim): return nn.GroupNorm(1, dim) def init_ssf_scale_shift(dim): scale = nn.Parameter(torch.ones(dim)) shift = nn.Parameter(torch.zeros(dim)) nn.init.normal_(scale, mean=1, std=.02) nn.init.normal_(shift, std=.02) return scale, shift def ssf_ada(x, scale, shift): assert scale.shape == shift.shape if x.shape[-1] == scale.shape[0]: return x * scale + shift elif x.shape[1] == scale.shape[0]: return x * scale.view(1, -1, 1, 1) + shift.view(1, -1, 1, 1) else: raise ValueError('the input tensor shape does not match the shape of the scale factor.') class AS_MLP(nn.Module): r""" AS-MLP A PyTorch impl of : `AS-MLP: An Axial Shifted MLP Architecture for Vision` - https://arxiv.org/pdf/xxx.xxx Args: img_size (int | tuple(int)): Input image size. Default 224 patch_size (int | tuple(int)): Patch size. Default: 4 in_chans (int): Number of input image channels. Default: 3 num_classes (int): Number of classes for classification head. Default: 1000 embed_dim (int): Patch embedding dimension. Default: 96 depths (tuple(int)): Depth of each AS-MLP layer. window_size (int): shift size. Default: 7 mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 as_bias (bool): If True, add a learnable bias to as-mlp block. Default: True drop_rate (float): Dropout rate. Default: 0 drop_path_rate (float): Stochastic depth rate. Default: 0.1 norm_layer (nn.Module): Normalization layer. Default: nn.GroupNorm with group=1. patch_norm (bool): If True, add normalization after patch embedding. Default: True use_checkpoint (bool): Whether to use checkpointing to save memory. 
Default: False """ def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, embed_dim=96, depths=[2, 2, 6, 2], shift_size=5, mlp_ratio=4., as_bias=True, drop_rate=0., drop_path_rate=0.1, norm_layer=MyNorm, patch_norm=True, use_checkpoint=False, tuning_mode=None, **kwargs): super().__init__() self.num_classes = num_classes self.num_layers = len(depths) self.embed_dim = embed_dim self.patch_norm = patch_norm self.num_features = int(embed_dim * 2 ** (self.num_layers - 1)) self.mlp_ratio = mlp_ratio # split image into non-overlapping patches self.patch_embed = PatchEmbed( img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, norm_layer=norm_layer if self.patch_norm else None) num_patches = self.patch_embed.num_patches patches_resolution = self.patch_embed.patches_resolution self.patches_resolution = patches_resolution self.pos_drop = nn.Dropout(p=drop_rate) # stochastic depth dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule self.tuning_mode = tuning_mode tuning_mode_list = [[tuning_mode] * depths[i_layer] for i_layer in range(self.num_layers)] if tuning_mode == 'ssf': self.ssf_scale_1, self.ssf_shift_1 = init_ssf_scale_shift(self.num_features) # build layers self.layers = nn.ModuleList() for i_layer in range(self.num_layers): layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer), input_resolution=(patches_resolution[0] // (2 ** i_layer), patches_resolution[1] // (2 ** i_layer)), depth=depths[i_layer], shift_size=shift_size, mlp_ratio=self.mlp_ratio, as_bias=as_bias, drop=drop_rate, drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], norm_layer=norm_layer, downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, use_checkpoint=use_checkpoint, tuning_mode=tuning_mode_list[i_layer]) self.layers.append(layer) self.norm = norm_layer(self.num_features) self.avgpool = nn.AdaptiveAvgPool2d(1) self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() #self.apply(self._init_weights) @torch.jit.ignore def no_weight_decay(self): return {'absolute_pos_embed'} @torch.jit.ignore def no_weight_decay_keywords(self): return {'relative_position_bias_table'} def get_classifier(self): return self.head def reset_classifier(self, num_classes, global_pool=''): self.num_classes = num_classes self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() def forward_features(self, x): x = self.patch_embed(x) x = self.pos_drop(x) for layer in self.layers: x = layer(x) x = self.norm(x) # B L C if self.tuning_mode == 'ssf': x = ssf_ada(x, self.ssf_scale_1, self.ssf_shift_1) x = self.avgpool(x) # B C 1 1 x = torch.flatten(x, 1) return x def forward(self, x): x = self.forward_features(x) x = self.head(x) return x def flops(self): flops = 0 flops += self.patch_embed.flops() for i, layer in enumerate(self.layers): flops += layer.flops() flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers) flops += self.num_features * self.num_classes return flops def _create_as_mlp(variant, pretrained=False, **kwargs): model = build_model_with_cfg( AS_MLP, variant, pretrained, pretrained_filter_fn=checkpoint_filter_fn, **kwargs) return model @register_model def as_base_patch4_window7_224(pretrained=False, **kwargs): """ AS-MLP-B @ 224x224, pretrained ImageNet-1k """ model_kwargs = dict( patch_size=4, shift_size=5, embed_dim=128, depths=(2, 2, 18, 2), **kwargs) return 
_create_as_mlp('as_base_patch4_window7_224', pretrained=pretrained, **model_kwargs) ================================================ FILE: models/convnext.py ================================================ """ ConvNeXt Paper: `A ConvNet for the 2020s` - https://arxiv.org/pdf/2201.03545.pdf Original code and weights from https://github.com/facebookresearch/ConvNeXt, original copyright below Modifications and additions for timm hacked together by / Copyright 2022, Ross Wightman """ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # This source code is licensed under the MIT license import math from collections import OrderedDict from functools import partial import torch import torch.nn as nn import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.models.fx_features import register_notrace_module from timm.models.helpers import named_apply, build_model_with_cfg, checkpoint_seq from timm.models.layers import trunc_normal_, ClassifierHead, SelectAdaptivePool2d, DropPath, ConvMlp, to_2tuple from timm.models.registry import register_model __all__ = ['ConvNeXt'] # model_registry will add each entrypoint fn to this def _cfg(url='', **kwargs): return { 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), 'crop_pct': 0.875, 'interpolation': 'bicubic', 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'first_conv': 'stem.0', 'classifier': 'head.fc', **kwargs } default_cfgs = dict( convnext_tiny=_cfg(url="https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth"), convnext_small=_cfg(url="https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth"), convnext_base=_cfg(url="https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth"), convnext_large=_cfg(url="https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth"), convnext_tiny_hnf=_cfg(url=''), convnext_base_in22ft1k=_cfg( url='https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_1k_224.pth'), convnext_large_in22ft1k=_cfg( url='https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_1k_224.pth'), convnext_xlarge_in22ft1k=_cfg( url='https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_1k_224_ema.pth'), convnext_base_384_in22ft1k=_cfg( url='https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_1k_384.pth', input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0), convnext_large_384_in22ft1k=_cfg( url='https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_1k_384.pth', input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0), convnext_xlarge_384_in22ft1k=_cfg( url='https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_1k_384_ema.pth', input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0), convnext_base_in22k=_cfg( url="https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth", num_classes=21841), convnext_large_in22k=_cfg( url="https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_224.pth", num_classes=21841), convnext_xlarge_in22k=_cfg( url="https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_224.pth", num_classes=21841), ) def _is_contiguous(tensor: torch.Tensor) -> bool: # jit is oh so lovely :/ # if torch.jit.is_tracing(): # return True if torch.jit.is_scripting(): return tensor.is_contiguous() else: return tensor.is_contiguous(memory_format=torch.contiguous_format) @register_notrace_module class LayerNorm2d(nn.LayerNorm): r""" LayerNorm for channels_first tensors with 2d spatial dimensions (ie N, C, H, W). 
""" def __init__(self, normalized_shape, eps=1e-6): super().__init__(normalized_shape, eps=eps) def forward(self, x) -> torch.Tensor: if _is_contiguous(x): return F.layer_norm( x.permute(0, 2, 3, 1), self.normalized_shape, self.weight, self.bias, self.eps).permute(0, 3, 1, 2) else: s, u = torch.var_mean(x, dim=1, keepdim=True) x = (x - u) * torch.rsqrt(s + self.eps) x = x * self.weight[:, None, None] + self.bias[:, None, None] return x class Mlp(nn.Module): """ MLP as used in Vision Transformer, MLP-Mixer and related networks """ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, bias=True, drop=0., tuning_mode=None): super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features bias = to_2tuple(bias) drop_probs = to_2tuple(drop) self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0]) self.act = act_layer() self.drop1 = nn.Dropout(drop_probs[0]) self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1]) self.drop2 = nn.Dropout(drop_probs[1]) self.tuning_mode = tuning_mode if tuning_mode == 'ssf': self.ssf_scale_1, self.ssf_shift_1 = init_ssf_scale_shift(hidden_features) self.ssf_scale_2, self.ssf_shift_2 = init_ssf_scale_shift(out_features) def forward(self, x): x = self.fc1(x) if self.tuning_mode == 'ssf': x = ssf_ada(x, self.ssf_scale_1, self.ssf_shift_1) x = self.act(x) x = self.drop1(x) x = self.fc2(x) if self.tuning_mode == 'ssf': x = ssf_ada(x, self.ssf_scale_2, self.ssf_shift_2) x = self.drop2(x) return x class ConvNeXtBlock(nn.Module): """ ConvNeXt Block There are two equivalent implementations: (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W) (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back Unlike the official impl, this one allows choice of 1 or 2, 1x1 conv can be faster with appropriate choice of LayerNorm impl, however as model size increases the tradeoffs appear to change and nn.Linear is a better choice. This was observed with PyTorch 1.10 on 3090 GPU, it could change over time & w/ different HW. Args: dim (int): Number of input channels. drop_path (float): Stochastic depth rate. Default: 0.0 ls_init_value (float): Init value for Layer Scale. Default: 1e-6. """ def __init__(self, dim, drop_path=0., ls_init_value=1e-6, conv_mlp=False, mlp_ratio=4, norm_layer=None, tuning_mode=None): super().__init__() if not norm_layer: norm_layer = partial(LayerNorm2d, eps=1e-6) if conv_mlp else partial(nn.LayerNorm, eps=1e-6) mlp_layer = ConvMlp if conv_mlp else Mlp self.use_conv_mlp = conv_mlp self.conv_dw = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv self.norm = norm_layer(dim) mlp_hidden_dim = int(dim * mlp_ratio) self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, tuning_mode=tuning_mode) self.gamma = nn.Parameter(ls_init_value * torch.ones(dim)) if ls_init_value > 0 else None self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() self.tuning_mode = tuning_mode if tuning_mode == 'ssf': self.ssf_scale_1, self.ssf_shift_1 = init_ssf_scale_shift(dim) self.ssf_scale_2, self.ssf_shift_2 = init_ssf_scale_shift(dim) def forward(self, x): shortcut = x x = self.conv_dw(x) if self.tuning_mode == 'ssf': x = ssf_ada(x, self.ssf_scale_1, self.ssf_shift_1) if self.use_conv_mlp: x = self.norm(x) x = self.mlp(x) else: x = x.permute(0, 2, 3, 1) x = self.norm(x) if self.tuning_mode == 'ssf': x = ssf_ada(x, self.ssf_scale_2, self.ssf_shift_2) x = self.mlp(x) x = x.permute(0, 3, 1, 2) if self.gamma is not None: x = x.mul(self.gamma.reshape(1, -1, 1, 1)) x = self.drop_path(x) + shortcut return x class Downsample(nn.Module): """ 2D Image to Downsample """ def __init__(self, dim, out_dim, kernel_size, stride, norm_layer=None, tuning_mode=None): super().__init__() self.norm = norm_layer(dim) self.proj = nn.Conv2d(dim, out_dim, kernel_size=stride, stride=stride) self.tuning_mode = tuning_mode if tuning_mode == 'ssf': self.ssf_scale_1, self.ssf_shift_1 = init_ssf_scale_shift(dim) self.ssf_scale_2, self.ssf_shift_2 = init_ssf_scale_shift(out_dim) def forward(self, x): if self.tuning_mode == 'ssf': x = ssf_ada(self.norm(x), self.ssf_scale_1, self.ssf_shift_1) x = ssf_ada(self.proj(x), self.ssf_scale_2, self.ssf_shift_2) else: x = self.norm(x) x = self.proj(x) return x class ConvNeXtStage(nn.Module): def __init__( self, in_chs, out_chs, stride=2, depth=2, dp_rates=None, ls_init_value=1.0, conv_mlp=False, norm_layer=None, cl_norm_layer=None, cross_stage=False, tuning_mode=None): super().__init__() self.grad_checkpointing = False if in_chs != out_chs or stride > 1: self.downsample = Downsample(dim=in_chs, out_dim=out_chs, kernel_size=stride, stride=stride, norm_layer=norm_layer, tuning_mode=tuning_mode) else: self.downsample = nn.Identity() dp_rates = dp_rates or [0.] 
* depth self.blocks = nn.Sequential(*[ConvNeXtBlock( dim=out_chs, drop_path=dp_rates[j], ls_init_value=ls_init_value, conv_mlp=conv_mlp, norm_layer=norm_layer if conv_mlp else cl_norm_layer, tuning_mode=tuning_mode[j]) for j in range(depth)] ) def forward(self, x): x = self.downsample(x) if self.grad_checkpointing and not torch.jit.is_scripting(): x = checkpoint_seq(self.blocks, x) else: x = self.blocks(x) return x class PatchEmbed(nn.Module): """ 2D Image to Patch Embedding """ def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None, tuning_mode=None): super().__init__() self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() self.tuning_mode = tuning_mode if tuning_mode == 'ssf': self.ssf_scale_1, self.ssf_shift_1 = init_ssf_scale_shift(embed_dim) self.ssf_scale_2, self.ssf_shift_2 = init_ssf_scale_shift(embed_dim) def forward(self, x): if self.tuning_mode == 'ssf': x = ssf_ada(self.proj(x), self.ssf_scale_1, self.ssf_shift_1) x = ssf_ada(self.norm(x), self.ssf_scale_2, self.ssf_shift_2) else: x = self.proj(x) x = self.norm(x) return x def init_ssf_scale_shift(dim): scale = nn.Parameter(torch.ones(dim)) shift = nn.Parameter(torch.zeros(dim)) nn.init.normal_(scale, mean=1, std=.02) nn.init.normal_(shift, std=.02) return scale, shift def ssf_ada(x, scale, shift): assert scale.shape == shift.shape if x.shape[-1] == scale.shape[0]: return x * scale + shift elif x.shape[1] == scale.shape[0]: return x * scale.view(1, -1, 1, 1) + shift.view(1, -1, 1, 1) else: raise ValueError('the input tensor shape does not match the shape of the scale factor.') class ConvNeXt(nn.Module): r""" ConvNeXt A PyTorch impl of : `A ConvNet for the 2020s` - https://arxiv.org/pdf/2201.03545.pdf Args: in_chans (int): Number of input image channels. Default: 3 num_classes (int): Number of classes for classification head. Default: 1000 depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3] dims (tuple(int)): Feature dimension at each stage. Default: [96, 192, 384, 768] drop_rate (float): Head dropout rate drop_path_rate (float): Stochastic depth rate. Default: 0. ls_init_value (float): Init value for Layer Scale. Default: 1e-6. head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1. 
""" def __init__( self, in_chans=3, num_classes=1000, global_pool='avg', output_stride=32, patch_size=4, depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), ls_init_value=1e-6, conv_mlp=False, head_init_scale=1., head_norm_first=False, norm_layer=None, drop_rate=0., drop_path_rate=0., tuning_mode=None ): super().__init__() assert output_stride == 32 if norm_layer is None: norm_layer = partial(LayerNorm2d, eps=1e-6) cl_norm_layer = norm_layer if conv_mlp else partial(nn.LayerNorm, eps=1e-6) else: assert conv_mlp,\ 'If a norm_layer is specified, conv MLP must be used so all norm expect rank-4, channels-first input' cl_norm_layer = norm_layer self.num_classes = num_classes self.drop_rate = drop_rate self.feature_info = [] # NOTE: this stem is a minimal form of ViT PatchEmbed, as used in SwinTransformer w/ patch_size = 4 self.stem = PatchEmbed(patch_size=4, in_chans=3, embed_dim=dims[0], norm_layer=norm_layer, tuning_mode=tuning_mode) self.tuning_mode = tuning_mode tuning_mode_list = [[tuning_mode] * depths[i_layer] for i_layer in range(len(depths))] if tuning_mode == 'ssf': self.ssf_scale_1, self.ssf_shift_1 = init_ssf_scale_shift(dims[3]) self.stages = nn.Sequential() dp_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)] curr_stride = patch_size prev_chs = dims[0] stages = [] # 4 feature resolution stages, each consisting of multiple residual blocks for i in range(4): stride = 2 if i > 0 else 1 # FIXME support dilation / output_stride curr_stride *= stride out_chs = dims[i] stages.append(ConvNeXtStage( prev_chs, out_chs, stride=stride, depth=depths[i], dp_rates=dp_rates[i], ls_init_value=ls_init_value, conv_mlp=conv_mlp, norm_layer=norm_layer, cl_norm_layer=cl_norm_layer, tuning_mode=tuning_mode_list[i]) ) prev_chs = out_chs # NOTE feature_info use currently assumes stage 0 == stride 1, rest are stride 2 self.feature_info += [dict(num_chs=prev_chs, reduction=curr_stride, module=f'stages.{i}')] self.stages = nn.Sequential(*stages) self.num_features = prev_chs if head_norm_first: # norm -> global pool -> fc ordering, like most other nets (not compat with FB weights) self.norm_pre = norm_layer(self.num_features) # final norm layer, before pooling self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=drop_rate) else: # pool -> norm -> fc, the default ConvNeXt ordering (pretrained FB weights) self.norm_pre = nn.Identity() self.head = nn.Sequential(OrderedDict([ ('global_pool', SelectAdaptivePool2d(pool_type=global_pool)), ('norm', norm_layer(self.num_features)), ('flatten', nn.Flatten(1) if global_pool else nn.Identity()), ('drop', nn.Dropout(self.drop_rate)), ('fc', nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()) ])) named_apply(partial(_init_weights, head_init_scale=head_init_scale), self) @torch.jit.ignore def set_grad_checkpointing(self, enable=True): for s in self.stages: s.grad_checkpointing = enable def get_classifier(self): return self.head.fc def reset_classifier(self, num_classes=0, global_pool='avg'): if isinstance(self.head, ClassifierHead): # norm -> global pool -> fc self.head = ClassifierHead( self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate) else: # pool -> norm -> fc self.head = nn.Sequential(OrderedDict([ ('global_pool', SelectAdaptivePool2d(pool_type=global_pool)), ('norm', self.head.norm), ('flatten', nn.Flatten(1) if global_pool else nn.Identity()), ('drop', nn.Dropout(self.drop_rate)), ('fc', nn.Linear(self.num_features, num_classes) if 
num_classes > 0 else nn.Identity()) ])) def forward_features(self, x): x = self.stem(x) x = self.stages(x) x = self.norm_pre(x) if self.tuning_mode == 'ssf': x = ssf_ada(self.norm_pre(x), self.ssf_scale_1, self.ssf_shift_1) return x def forward(self, x): x = self.forward_features(x) x = self.head(x) return x def _init_weights(module, name=None, head_init_scale=1.0): if isinstance(module, nn.Conv2d): trunc_normal_(module.weight, std=.02) nn.init.constant_(module.bias, 0) elif isinstance(module, nn.Linear): trunc_normal_(module.weight, std=.02) nn.init.constant_(module.bias, 0) if name and 'head.' in name: module.weight.data.mul_(head_init_scale) module.bias.data.mul_(head_init_scale) def checkpoint_filter_fn(state_dict, model): """ Remap FB checkpoints -> timm """ #ipdb.set_trace() if 'model' in state_dict: state_dict = state_dict['model'] out_dict = {} import re for k, v in state_dict.items(): k = k.replace('downsample_layers.0.0.', 'stem.proj.') k = k.replace('downsample_layers.0.1.', 'stem.norm.') k = re.sub(r'stages.([0-9]+).([0-9]+)', r'stages.\1.blocks.\2', k) k = re.sub(r'downsample_layers.([0-9]+).([0]+)', r'stages.\1.downsample.norm', k) k = re.sub(r'downsample_layers.([0-9]+).([1]+)', r'stages.\1.downsample.proj', k) k = k.replace('dwconv', 'conv_dw') k = k.replace('pwconv', 'mlp.fc') k = k.replace('head.', 'head.fc.') if k.startswith('norm.'): k = k.replace('norm', 'head.norm') if v.ndim == 2 and 'head' not in k: model_shape = model.state_dict()[k].shape v = v.reshape(model_shape) out_dict[k] = v return out_dict def _create_convnext(variant, pretrained=False, **kwargs): model = build_model_with_cfg( ConvNeXt, variant, pretrained, pretrained_filter_fn=checkpoint_filter_fn, feature_cfg=dict(out_indices=(0, 1, 2, 3), flatten_sequential=True), **kwargs) return model @register_model def convnext_tiny(pretrained=False, **kwargs): model_args = dict(depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), **kwargs) model = _create_convnext('convnext_tiny', pretrained=pretrained, **model_args) return model @register_model def convnext_tiny_hnf(pretrained=False, **kwargs): model_args = dict(depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), head_norm_first=True, **kwargs) model = _create_convnext('convnext_tiny_hnf', pretrained=pretrained, **model_args) return model @register_model def convnext_small(pretrained=False, **kwargs): model_args = dict(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768], **kwargs) model = _create_convnext('convnext_small', pretrained=pretrained, **model_args) return model @register_model def convnext_base(pretrained=False, **kwargs): model_args = dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs) model = _create_convnext('convnext_base', pretrained=pretrained, **model_args) return model @register_model def convnext_large(pretrained=False, **kwargs): model_args = dict(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs) model = _create_convnext('convnext_large', pretrained=pretrained, **model_args) return model @register_model def convnext_base_in22ft1k(pretrained=False, **kwargs): model_args = dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs) model = _create_convnext('convnext_base_in22ft1k', pretrained=pretrained, **model_args) return model @register_model def convnext_large_in22ft1k(pretrained=False, **kwargs): model_args = dict(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs) model = _create_convnext('convnext_large_in22ft1k', pretrained=pretrained, **model_args) return model @register_model def 
convnext_xlarge_in22ft1k(pretrained=False, **kwargs): model_args = dict(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048], **kwargs) model = _create_convnext('convnext_xlarge_in22ft1k', pretrained=pretrained, **model_args) return model @register_model def convnext_base_384_in22ft1k(pretrained=False, **kwargs): model_args = dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs) model = _create_convnext('convnext_base_384_in22ft1k', pretrained=pretrained, **model_args) return model @register_model def convnext_large_384_in22ft1k(pretrained=False, **kwargs): model_args = dict(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs) model = _create_convnext('convnext_large_384_in22ft1k', pretrained=pretrained, **model_args) return model @register_model def convnext_xlarge_384_in22ft1k(pretrained=False, **kwargs): model_args = dict(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048], **kwargs) model = _create_convnext('convnext_xlarge_384_in22ft1k', pretrained=pretrained, **model_args) return model @register_model def convnext_base_in22k(pretrained=False, **kwargs): model_args = dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs) model = _create_convnext('convnext_base_in22k', pretrained=pretrained, **model_args) return model @register_model def convnext_large_in22k(pretrained=False, **kwargs): model_args = dict(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs) model = _create_convnext('convnext_large_in22k', pretrained=pretrained, **model_args) return model @register_model def convnext_xlarge_in22k(pretrained=False, **kwargs): model_args = dict(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048], **kwargs) model = _create_convnext('convnext_xlarge_in22k', pretrained=pretrained, **model_args) return model ================================================ FILE: models/swin_transformer.py ================================================ """ Swin Transformer A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - https://arxiv.org/pdf/2103.14030 Code/weights from https://github.com/microsoft/Swin-Transformer, original copyright/license info below Modifications and additions for timm hacked together by / Copyright 2021, Ross Wightman """ # -------------------------------------------------------- # Swin Transformer # Copyright (c) 2021 Microsoft # Licensed under The MIT License [see LICENSE for details] # Written by Ze Liu # -------------------------------------------------------- import logging import math from copy import deepcopy from typing import Optional import torch import torch.nn as nn import torch.nn.functional as F import torch.utils.checkpoint as checkpoint from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.models.fx_features import register_notrace_function from timm.models.helpers import build_model_with_cfg, named_apply from timm.models.layers import DropPath, to_2tuple, trunc_normal_ from timm.models.layers import _assert from timm.models.registry import register_model from timm.models.vision_transformer import checkpoint_filter_fn, get_init_weights_vit import ipdb _logger = logging.getLogger(__name__) def _cfg(url='', **kwargs): return { 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'first_conv': 'patch_embed.proj', 'classifier': 'head', **kwargs } default_cfgs = { # patch models (my experiments) 
'swin_base_patch4_window12_384': _cfg( url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22kto1k.pth', input_size=(3, 384, 384), crop_pct=1.0), 'swin_base_patch4_window7_224': _cfg( url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22kto1k.pth', ), 'swin_large_patch4_window12_384': _cfg( url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22kto1k.pth', input_size=(3, 384, 384), crop_pct=1.0), 'swin_large_patch4_window7_224': _cfg( url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22kto1k.pth', ), 'swin_small_patch4_window7_224': _cfg( url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_small_patch4_window7_224.pth', ), 'swin_tiny_patch4_window7_224': _cfg( url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth', ), 'swin_base_patch4_window12_384_in22k': _cfg( url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22k.pth', input_size=(3, 384, 384), crop_pct=1.0, num_classes=21841), 'swin_base_patch4_window7_224_in22k': _cfg( url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22k.pth', num_classes=21841), 'swin_large_patch4_window12_384_in22k': _cfg( url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22k.pth', input_size=(3, 384, 384), crop_pct=1.0, num_classes=21841), 'swin_large_patch4_window7_224_in22k': _cfg( url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22k.pth', num_classes=21841), } class Mlp(nn.Module): """ MLP as used in Vision Transformer, MLP-Mixer and related networks """ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, bias=True, drop=0., tuning_mode=None): super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features bias = to_2tuple(bias) drop_probs = to_2tuple(drop) self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0]) self.act = act_layer() self.drop1 = nn.Dropout(drop_probs[0]) self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1]) self.drop2 = nn.Dropout(drop_probs[1]) self.tuning_mode = tuning_mode if tuning_mode == 'ssf': self.ssf_scale_1, self.ssf_shift_1 = init_ssf_scale_shift(hidden_features) self.ssf_scale_2, self.ssf_shift_2 = init_ssf_scale_shift(out_features) def forward(self, x): x = self.fc1(x) if self.tuning_mode == 'ssf': x = ssf_ada(x, self.ssf_scale_1, self.ssf_shift_1) x = self.act(x) x = self.drop1(x) x = self.fc2(x) if self.tuning_mode == 'ssf': x = ssf_ada(x, self.ssf_scale_2, self.ssf_shift_2) x = self.drop2(x) return x def window_partition(x, window_size: int): """ Args: x: (B, H, W, C) window_size (int): window size Returns: windows: (num_windows*B, window_size, window_size, C) """ B, H, W, C = x.shape x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) return windows @register_notrace_function # reason: int argument is a Proxy def window_reverse(windows, window_size: int, H: int, W: int): """ Args: windows: (num_windows*B, window_size, window_size, C) window_size (int): Window size H (int): Height of image W (int): Width of image 
Returns: x: (B, H, W, C) """ B = int(windows.shape[0] / (H * W / window_size / window_size)) x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) return x class WindowAttention(nn.Module): r""" Window based multi-head self attention (W-MSA) module with relative position bias. It supports both of shifted and non-shifted window. Args: dim (int): Number of input channels. window_size (tuple[int]): The height and width of the window. num_heads (int): Number of attention heads. qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 proj_drop (float, optional): Dropout ratio of output. Default: 0.0 """ def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0., tuning_mode=None): super().__init__() self.dim = dim self.window_size = window_size # Wh, Ww self.num_heads = num_heads head_dim = dim // num_heads self.scale = head_dim ** -0.5 # define a parameter table of relative position bias self.relative_position_bias_table = nn.Parameter( torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH # get pair-wise relative position index for each token inside the window coords_h = torch.arange(self.window_size[0]) coords_w = torch.arange(self.window_size[1]) coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 relative_coords[:, :, 1] += self.window_size[1] - 1 relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww self.register_buffer("relative_position_index", relative_position_index) self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) self.attn_drop = nn.Dropout(attn_drop) self.proj = nn.Linear(dim, dim) self.proj_drop = nn.Dropout(proj_drop) trunc_normal_(self.relative_position_bias_table, std=.02) self.softmax = nn.Softmax(dim=-1) self.tuning_mode = tuning_mode if tuning_mode == 'ssf': self.ssf_scale_1, self.ssf_shift_1 = init_ssf_scale_shift(dim * 3) self.ssf_scale_2, self.ssf_shift_2 = init_ssf_scale_shift(dim) def forward(self, x, mask: Optional[torch.Tensor] = None): """ Args: x: input features with shape of (num_windows*B, N, C) mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None """ B_, N, C = x.shape if self.tuning_mode == 'ssf': #qkv = (self.qkv(x) * self.ssf_scale_1 + self.ssf_shift_1).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) qkv = (ssf_ada(self.qkv(x), self.ssf_scale_1, self.ssf_shift_1)).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) else: qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) q = q * self.scale attn = (q @ k.transpose(-2, -1)) relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() 
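Because the SSF pair attached to the fused qkv projection has length `dim * 3`, every channel of q, k and v receives its own scale and shift before the heads are split out. A standalone sketch of the reshape/permute used above (sizes are illustrative, not taken from a specific configuration):

```python
import torch
import torch.nn as nn

B_, N, C, num_heads = 4, 49, 128, 4          # e.g. a 7x7 window of tokens
qkv_proj = nn.Linear(C, C * 3)
ssf_scale = torch.ones(C * 3)                # first C entries modulate q, next C modulate k, last C modulate v
ssf_shift = torch.zeros(C * 3)

x = torch.randn(B_, N, C)
qkv = qkv_proj(x) * ssf_scale + ssf_shift    # channel-last branch of ssf_ada
qkv = qkv.reshape(B_, N, 3, num_heads, C // num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(0)                      # each is (B_, num_heads, N, head_dim)
print(q.shape, k.shape, v.shape)
```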
# nH, Wh*Ww, Wh*Ww attn = attn + relative_position_bias.unsqueeze(0) if mask is not None: nW = mask.shape[0] attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) attn = attn.view(-1, self.num_heads, N, N) attn = self.softmax(attn) else: attn = self.softmax(attn) attn = self.attn_drop(attn) x = (attn @ v).transpose(1, 2).reshape(B_, N, C) x = self.proj(x) if self.tuning_mode == 'ssf': x = ssf_ada(x, self.ssf_scale_2, self.ssf_shift_2) x = self.proj_drop(x) return x class SwinTransformerBlock(nn.Module): r""" Swin Transformer Block. Args: dim (int): Number of input channels. input_resolution (tuple[int]): Input resulotion. num_heads (int): Number of attention heads. window_size (int): Window size. shift_size (int): Shift size for SW-MSA. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True drop (float, optional): Dropout rate. Default: 0.0 attn_drop (float, optional): Attention dropout rate. Default: 0.0 drop_path (float, optional): Stochastic depth rate. Default: 0.0 act_layer (nn.Module, optional): Activation layer. Default: nn.GELU norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm """ def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, tuning_mode=None): super().__init__() self.dim = dim self.input_resolution = input_resolution self.num_heads = num_heads self.window_size = window_size self.shift_size = shift_size self.mlp_ratio = mlp_ratio if min(self.input_resolution) <= self.window_size: # if window size is larger than input resolution, we don't partition windows self.shift_size = 0 self.window_size = min(self.input_resolution) assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" self.norm1 = norm_layer(dim) self.attn = WindowAttention( dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop, tuning_mode=tuning_mode) self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() self.norm2 = norm_layer(dim) mlp_hidden_dim = int(dim * mlp_ratio) self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop, tuning_mode=tuning_mode) if self.shift_size > 0: # calculate attention mask for SW-MSA H, W = self.input_resolution img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 h_slices = (slice(0, -self.window_size), slice(-self.window_size, -self.shift_size), slice(-self.shift_size, None)) w_slices = (slice(0, -self.window_size), slice(-self.window_size, -self.shift_size), slice(-self.shift_size, None)) cnt = 0 for h in h_slices: for w in w_slices: img_mask[:, h, w, :] = cnt cnt += 1 mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 mask_windows = mask_windows.view(-1, self.window_size * self.window_size) attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) else: attn_mask = None self.register_buffer("attn_mask", attn_mask) self.tuning_mode = tuning_mode if tuning_mode == 'ssf': self.ssf_scale_1, self.ssf_shift_1 = init_ssf_scale_shift(dim) self.ssf_scale_2, self.ssf_shift_2 = init_ssf_scale_shift(dim) def forward(self, x): H, W = self.input_resolution B, L, C = x.shape shortcut = x x = self.norm1(x) if self.tuning_mode == 'ssf': x = ssf_ada(x, self.ssf_scale_1, self.ssf_shift_1) x = x.view(B, H, W, C) # cyclic shift if self.shift_size > 0: shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) else: shifted_x = x # partition windows x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C # W-MSA/SW-MSA attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C # merge windows attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C # reverse cyclic shift if self.shift_size > 0: x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) else: x = shifted_x x = x.view(B, H * W, C) # FFN x = shortcut + self.drop_path(x) if self.tuning_mode == 'ssf': x = x + self.drop_path(self.mlp(ssf_ada(self.norm2(x), self.ssf_scale_2, self.ssf_shift_2))) else: x = x + self.drop_path(self.mlp(self.norm2(x))) return x class PatchMerging(nn.Module): r""" Patch Merging Layer. Args: input_resolution (tuple[int]): Resolution of input feature. dim (int): Number of input channels. norm_layer (nn.Module, optional): Normalization layer. 
Default: nn.LayerNorm """ def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm, tuning_mode=None): super().__init__() self.input_resolution = input_resolution self.dim = dim self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) self.norm = norm_layer(4 * dim) self.tuning_mode = tuning_mode if tuning_mode == 'ssf': self.ssf_scale_1, self.ssf_shift_1 = init_ssf_scale_shift(4 * dim) def forward(self, x): """ x: B, H*W, C """ H, W = self.input_resolution B, L, C = x.shape _assert(L == H * W, "input feature has wrong size") _assert(H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even.") x = x.view(B, H, W, C) x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C x = self.norm(x) if self.tuning_mode == 'ssf': x = ssf_ada(x, self.ssf_scale_1, self.ssf_shift_1) x = self.reduction(x) return x def extra_repr(self) -> str: return f"input_resolution={self.input_resolution}, dim={self.dim}" def flops(self): H, W = self.input_resolution flops = H * W * self.dim flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim return flops class BasicLayer(nn.Module): """ A basic Swin Transformer layer for one stage. Args: dim (int): Number of input channels. input_resolution (tuple[int]): Input resolution. depth (int): Number of blocks. num_heads (int): Number of attention heads. window_size (int): Local window size. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True drop (float, optional): Dropout rate. Default: 0.0 attn_drop (float, optional): Attention dropout rate. Default: 0.0 drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
""" def __init__(self, dim, input_resolution, depth, num_heads, window_size, mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, tuning_mode=None): super().__init__() self.dim = dim self.input_resolution = input_resolution self.depth = depth self.use_checkpoint = use_checkpoint # build blocks self.blocks = nn.ModuleList([ SwinTransformerBlock( dim=dim, input_resolution=input_resolution, num_heads=num_heads, window_size=window_size, shift_size=0 if (i % 2 == 0) else window_size // 2, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop, attn_drop=attn_drop, drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, norm_layer=norm_layer, tuning_mode=tuning_mode[i]) for i in range(depth)]) # patch merging layer if downsample is not None: self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer, tuning_mode=tuning_mode) else: self.downsample = None def forward(self, x): for blk in self.blocks: if not torch.jit.is_scripting() and self.use_checkpoint: x = checkpoint.checkpoint(blk, x) else: x = blk(x) if self.downsample is not None: x = self.downsample(x) return x def extra_repr(self) -> str: return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" class PatchEmbed(nn.Module): """ 2D Image to Patch Embedding """ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True, tuning_mode=None): super().__init__() img_size = to_2tuple(img_size) patch_size = to_2tuple(patch_size) self.img_size = img_size self.patch_size = patch_size self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) self.num_patches = self.grid_size[0] * self.grid_size[1] self.flatten = flatten self.norm_layer = norm_layer self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() self.tuning_mode = tuning_mode if tuning_mode == 'ssf': self.ssf_scale_1, self.ssf_shift_1 = init_ssf_scale_shift(embed_dim) if norm_layer: self.ssf_scale_2, self.ssf_shift_2 = init_ssf_scale_shift(embed_dim) def forward(self, x): B, C, H, W = x.shape _assert(H == self.img_size[0], f"Input image height ({H}) doesn't match model ({self.img_size[0]}).") _assert(W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]}).") x = self.proj(x) if self.flatten: x = x.flatten(2).transpose(1, 2) # BCHW -> BNC if self.tuning_mode == 'ssf': x = ssf_ada(x, self.ssf_scale_1, self.ssf_shift_1) if self.norm_layer: x = ssf_ada(self.norm(x), self.ssf_scale_2, self.ssf_shift_2) else: x = self.norm(x) else: x = self.norm(x) return x def init_ssf_scale_shift(dim): scale = nn.Parameter(torch.ones(dim)) shift = nn.Parameter(torch.zeros(dim)) nn.init.normal_(scale, mean=1, std=.02) nn.init.normal_(shift, std=.02) return scale, shift def ssf_ada(x, scale, shift): assert scale.shape == shift.shape if x.shape[-1] == scale.shape[0]: return x * scale + shift elif x.shape[1] == scale.shape[0]: return x * scale.view(1, -1, 1, 1) + shift.view(1, -1, 1, 1) else: raise ValueError('the input tensor shape does not match the shape of the scale factor.') class SwinTransformer(nn.Module): r""" Swin Transformer A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - https://arxiv.org/pdf/2103.14030 Args: img_size (int | tuple(int)): Input image size. Default 224 patch_size (int | tuple(int)): Patch size. 
Default: 4 in_chans (int): Number of input image channels. Default: 3 num_classes (int): Number of classes for classification head. Default: 1000 embed_dim (int): Patch embedding dimension. Default: 96 depths (tuple(int)): Depth of each Swin Transformer layer. num_heads (tuple(int)): Number of attention heads in different layers. window_size (int): Window size. Default: 7 mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True drop_rate (float): Dropout rate. Default: 0 attn_drop_rate (float): Attention dropout rate. Default: 0 drop_path_rate (float): Stochastic depth rate. Default: 0.1 norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. ape (bool): If True, add absolute position embedding to the patch embedding. Default: False patch_norm (bool): If True, add normalization after patch embedding. Default: True use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False """ def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, embed_dim=96, depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24), window_size=7, mlp_ratio=4., qkv_bias=True, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, norm_layer=nn.LayerNorm, ape=False, patch_norm=True, use_checkpoint=False, weight_init='', tuning_mode=None, **kwargs): super().__init__() self.num_classes = num_classes self.num_layers = len(depths) self.embed_dim = embed_dim self.ape = ape self.patch_norm = patch_norm self.num_features = int(embed_dim * 2 ** (self.num_layers - 1)) self.mlp_ratio = mlp_ratio # split image into non-overlapping patches self.patch_embed = PatchEmbed( img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, norm_layer=norm_layer if self.patch_norm else None, tuning_mode=tuning_mode) num_patches = self.patch_embed.num_patches self.patch_grid = self.patch_embed.grid_size # absolute position embedding if self.ape: self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) trunc_normal_(self.absolute_pos_embed, std=.02) else: self.absolute_pos_embed = None self.pos_drop = nn.Dropout(p=drop_rate) # stochastic depth dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule self.tuning_mode = tuning_mode tuning_mode_list = [[tuning_mode] * depths[i_layer] for i_layer in range(self.num_layers)] if tuning_mode == 'ssf': self.ssf_scale_1, self.ssf_shift_1 = init_ssf_scale_shift(self.num_features) # build layers layers = [] for i_layer in range(self.num_layers): layers += [BasicLayer( dim=int(embed_dim * 2 ** i_layer), input_resolution=(self.patch_grid[0] // (2 ** i_layer), self.patch_grid[1] // (2 ** i_layer)), depth=depths[i_layer], num_heads=num_heads[i_layer], window_size=window_size, mlp_ratio=self.mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], norm_layer=norm_layer, downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, use_checkpoint=use_checkpoint, tuning_mode=tuning_mode_list[i_layer]) ] self.layers = nn.Sequential(*layers) self.norm = norm_layer(self.num_features) self.avgpool = nn.AdaptiveAvgPool1d(1) self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() if weight_init != 'skip': self.init_weights(weight_init) @torch.jit.ignore def init_weights(self, mode=''): assert mode in ('jax', 'jax_nlhb', 'moco', '') if self.absolute_pos_embed is not None: 
trunc_normal_(self.absolute_pos_embed, std=.02) head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0. named_apply(get_init_weights_vit(mode, head_bias=head_bias), self) @torch.jit.ignore def no_weight_decay(self): return {'absolute_pos_embed'} @torch.jit.ignore def no_weight_decay_keywords(self): return {'relative_position_bias_table'} def get_classifier(self): return self.head def reset_classifier(self, num_classes, global_pool=''): self.num_classes = num_classes self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() def forward_features(self, x): x = self.patch_embed(x) if self.absolute_pos_embed is not None: x = x + self.absolute_pos_embed x = self.pos_drop(x) x = self.layers(x) x = self.norm(x) if self.tuning_mode == 'ssf': x = ssf_ada(x, self.ssf_scale_1, self.ssf_shift_1) x = self.avgpool(x.transpose(1, 2)) # B C 1 x = torch.flatten(x, 1) return x def forward(self, x): x = self.forward_features(x) x = self.head(x) return x def _create_swin_transformer(variant, pretrained=False, **kwargs): model = build_model_with_cfg( SwinTransformer, variant, pretrained, pretrained_filter_fn=checkpoint_filter_fn, **kwargs) return model @register_model def swin_base_patch4_window12_384(pretrained=False, **kwargs): """ Swin-B @ 384x384, pretrained ImageNet-22k, fine tune 1k """ model_kwargs = dict( patch_size=4, window_size=12, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32), **kwargs) return _create_swin_transformer('swin_base_patch4_window12_384', pretrained=pretrained, **model_kwargs) @register_model def swin_base_patch4_window7_224(pretrained=False, **kwargs): """ Swin-B @ 224x224, pretrained ImageNet-22k, fine tune 1k """ model_kwargs = dict( patch_size=4, window_size=7, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32), **kwargs) return _create_swin_transformer('swin_base_patch4_window7_224', pretrained=pretrained, **model_kwargs) @register_model def swin_large_patch4_window12_384(pretrained=False, **kwargs): """ Swin-L @ 384x384, pretrained ImageNet-22k, fine tune 1k """ model_kwargs = dict( patch_size=4, window_size=12, embed_dim=192, depths=(2, 2, 18, 2), num_heads=(6, 12, 24, 48), **kwargs) return _create_swin_transformer('swin_large_patch4_window12_384', pretrained=pretrained, **model_kwargs) @register_model def swin_large_patch4_window7_224(pretrained=False, **kwargs): """ Swin-L @ 224x224, pretrained ImageNet-22k, fine tune 1k """ model_kwargs = dict( patch_size=4, window_size=7, embed_dim=192, depths=(2, 2, 18, 2), num_heads=(6, 12, 24, 48), **kwargs) return _create_swin_transformer('swin_large_patch4_window7_224', pretrained=pretrained, **model_kwargs) @register_model def swin_small_patch4_window7_224(pretrained=False, **kwargs): """ Swin-S @ 224x224, trained ImageNet-1k """ model_kwargs = dict( patch_size=4, window_size=7, embed_dim=96, depths=(2, 2, 18, 2), num_heads=(3, 6, 12, 24), **kwargs) return _create_swin_transformer('swin_small_patch4_window7_224', pretrained=pretrained, **model_kwargs) @register_model def swin_tiny_patch4_window7_224(pretrained=False, **kwargs): """ Swin-T @ 224x224, trained ImageNet-1k """ model_kwargs = dict( patch_size=4, window_size=7, embed_dim=96, depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24), **kwargs) return _create_swin_transformer('swin_tiny_patch4_window7_224', pretrained=pretrained, **model_kwargs) @register_model def swin_base_patch4_window12_384_in22k(pretrained=False, **kwargs): """ Swin-B @ 384x384, trained ImageNet-22k """ model_kwargs = dict( patch_size=4, 
window_size=12, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32), **kwargs) return _create_swin_transformer('swin_base_patch4_window12_384_in22k', pretrained=pretrained, **model_kwargs) @register_model def swin_base_patch4_window7_224_in22k(pretrained=False, **kwargs): """ Swin-B @ 224x224, trained ImageNet-22k """ model_kwargs = dict( patch_size=4, window_size=7, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32), **kwargs) return _create_swin_transformer('swin_base_patch4_window7_224_in22k', pretrained=pretrained, **model_kwargs) @register_model def swin_large_patch4_window12_384_in22k(pretrained=False, **kwargs): """ Swin-L @ 384x384, trained ImageNet-22k """ model_kwargs = dict( patch_size=4, window_size=12, embed_dim=192, depths=(2, 2, 18, 2), num_heads=(6, 12, 24, 48), **kwargs) return _create_swin_transformer('swin_large_patch4_window12_384_in22k', pretrained=pretrained, **model_kwargs) @register_model def swin_large_patch4_window7_224_in22k(pretrained=False, **kwargs): """ Swin-L @ 224x224, trained ImageNet-22k """ model_kwargs = dict( patch_size=4, window_size=7, embed_dim=192, depths=(2, 2, 18, 2), num_heads=(6, 12, 24, 48), **kwargs) return _create_swin_transformer('swin_large_patch4_window7_224_in22k', pretrained=pretrained, **model_kwargs) ================================================ FILE: models/vision_transformer.py ================================================ """ Vision Transformer (ViT) in PyTorch A PyTorch implement of Vision Transformers as described in: 'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale' - https://arxiv.org/abs/2010.11929 `How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers` - https://arxiv.org/abs/2106.10270 The official jax code is released and available at https://github.com/google-research/vision_transformer Acknowledgments: * The paper authors for releasing code and weights, thanks! * I fixed my class token impl based on Phil Wang's https://github.com/lucidrains/vit-pytorch ... 
check it out for some einops/einsum fun * Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT * Bert reference code checks against Huggingface Transformers and Tensorflow Bert Hacked together by / Copyright 2020, Ross Wightman """ import math import logging from functools import partial from collections import OrderedDict from typing import Optional import torch import torch.nn as nn import torch.nn.functional as F import torch.utils.checkpoint from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD from timm.models.helpers import build_model_with_cfg, named_apply, adapt_input_conv, resolve_pretrained_cfg, checkpoint_seq from timm.models.layers import DropPath, trunc_normal_, lecun_normal_, _assert from timm.models.layers.helpers import to_2tuple from timm.models.registry import register_model import ipdb _logger = logging.getLogger(__name__) def _cfg(url='', **kwargs): return { 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, 'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD, 'first_conv': 'patch_embed.proj', 'classifier': 'head', **kwargs } default_cfgs = { # patch models (weights from official Google JAX impl) 'vit_tiny_patch16_224': _cfg( url='https://storage.googleapis.com/vit_models/augreg/' 'Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'), 'vit_tiny_patch16_384': _cfg( url='https://storage.googleapis.com/vit_models/augreg/' 'Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', input_size=(3, 384, 384), crop_pct=1.0), 'vit_small_patch32_224': _cfg( url='https://storage.googleapis.com/vit_models/augreg/' 'S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'), 'vit_small_patch32_384': _cfg( url='https://storage.googleapis.com/vit_models/augreg/' 'S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', input_size=(3, 384, 384), crop_pct=1.0), 'vit_small_patch16_224': _cfg( url='https://storage.googleapis.com/vit_models/augreg/' 'S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'), 'vit_small_patch16_384': _cfg( url='https://storage.googleapis.com/vit_models/augreg/' 'S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', input_size=(3, 384, 384), crop_pct=1.0), 'vit_base_patch32_224': _cfg( url='https://storage.googleapis.com/vit_models/augreg/' 'B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'), 'vit_base_patch32_384': _cfg( url='https://storage.googleapis.com/vit_models/augreg/' 'B_32-i21k-300ep-lr_0.001-aug_light1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', input_size=(3, 384, 384), crop_pct=1.0), 'vit_base_patch16_224': _cfg( url='https://storage.googleapis.com/vit_models/augreg/' 'B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz'), 'vit_base_patch16_384': _cfg( url='https://storage.googleapis.com/vit_models/augreg/' 'B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_384.npz', input_size=(3, 384, 384), crop_pct=1.0), 'vit_base_patch8_224': _cfg( url='https://storage.googleapis.com/vit_models/augreg/' 
'B_8-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz'), 'vit_large_patch32_224': _cfg( url='', # no official model weights for this combo, only for in21k ), 'vit_large_patch32_384': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p32_384-9b920ba8.pth', input_size=(3, 384, 384), crop_pct=1.0), 'vit_large_patch16_224': _cfg( url='https://storage.googleapis.com/vit_models/augreg/' 'L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz'), 'vit_large_patch16_384': _cfg( url='https://storage.googleapis.com/vit_models/augreg/' 'L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz', input_size=(3, 384, 384), crop_pct=1.0), # patch models, imagenet21k (weights from official Google JAX impl) 'vit_tiny_patch16_224_in21k': _cfg( url='https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz', num_classes=21843), 'vit_small_patch16_224_in21k': _cfg( url='https://storage.googleapis.com/vit_models/augreg/S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz', num_classes=21843), 'vit_base_patch16_224_in21k': _cfg( url='https://storage.googleapis.com/vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz', num_classes=21843), 'vit_large_patch16_224_in21k': _cfg( url='https://storage.googleapis.com/vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1.npz', num_classes=21843), } class Mlp(nn.Module): """ MLP as used in Vision Transformer, MLP-Mixer and related networks """ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, bias=True, drop=0., tuning_mode=None): super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features bias = to_2tuple(bias) drop_probs = to_2tuple(drop) self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0]) self.act = act_layer() self.drop1 = nn.Dropout(drop_probs[0]) self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1]) self.drop2 = nn.Dropout(drop_probs[1]) self.tuning_mode = tuning_mode if tuning_mode == 'ssf': self.ssf_scale_1, self.ssf_shift_1 = init_ssf_scale_shift(hidden_features) self.ssf_scale_2, self.ssf_shift_2 = init_ssf_scale_shift(out_features) def forward(self, x): x = self.fc1(x) if self.tuning_mode == 'ssf': x = ssf_ada(x, self.ssf_scale_1, self.ssf_shift_1) x = self.act(x) x = self.drop1(x) x = self.fc2(x) if self.tuning_mode == 'ssf': x = ssf_ada(x, self.ssf_scale_2, self.ssf_shift_2) x = self.drop2(x) return x class Attention(nn.Module): def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., tuning_mode=None): super().__init__() assert dim % num_heads == 0, 'dim should be divisible by num_heads' self.num_heads = num_heads head_dim = dim // num_heads self.scale = head_dim ** -0.5 self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) self.attn_drop = nn.Dropout(attn_drop) self.proj = nn.Linear(dim, dim) self.proj_drop = nn.Dropout(proj_drop) self.tuning_mode = tuning_mode if tuning_mode == 'ssf': self.ssf_scale_1, self.ssf_shift_1 = init_ssf_scale_shift(dim * 3) self.ssf_scale_2, self.ssf_shift_2 = init_ssf_scale_shift(dim) def forward(self, x): B, N, C = x.shape if self.tuning_mode == 'ssf': qkv = (ssf_ada(self.qkv(x), self.ssf_scale_1, self.ssf_shift_1)).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 
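# NOTE (explanatory comment, not part of the original code): `ssf_ada` here applies a
# per-channel affine transform right after the frozen qkv projection. Because both operations
# are linear, the learned scale/shift can be folded into the frozen layer once tuning is done,
# e.g. merged_weight = self.ssf_scale_1.unsqueeze(-1) * self.qkv.weight and
# merged_bias = self.ssf_scale_1 * self.qkv.bias + self.ssf_shift_1,
# so SSF introduces no extra parameters or FLOPs at inference time.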
else: qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) attn = (q @ k.transpose(-2, -1)) * self.scale attn = attn.softmax(dim=-1) attn = self.attn_drop(attn) x = (attn @ v).transpose(1, 2).reshape(B, N, C) x = self.proj(x) if self.tuning_mode == 'ssf': x = ssf_ada(x, self.ssf_scale_2, self.ssf_shift_2) x = self.proj_drop(x) return x class LayerScale(nn.Module): def __init__(self, dim, init_values=1e-5, inplace=False): super().__init__() self.inplace = inplace self.gamma = nn.Parameter(init_values * torch.ones(dim)) def forward(self, x): return x.mul_(self.gamma) if self.inplace else x * self.gamma class Block(nn.Module): def __init__( self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., init_values=None, drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, tuning_mode=None): super().__init__() self.dim = dim self.norm1 = norm_layer(dim) self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop, tuning_mode=tuning_mode) self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() self.norm2 = norm_layer(dim) self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop, tuning_mode=tuning_mode) self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity() self.tuning_mode = tuning_mode if tuning_mode == 'ssf': self.ssf_scale_1, self.ssf_shift_1 = init_ssf_scale_shift(dim) self.ssf_scale_2, self.ssf_shift_2 = init_ssf_scale_shift(dim) def forward(self, x): if self.tuning_mode == 'ssf': x = x + self.drop_path1(self.ls1(self.attn(ssf_ada(self.norm1(x), self.ssf_scale_1, self.ssf_shift_1)))) x = x + self.drop_path2(self.ls2(self.mlp(ssf_ada(self.norm2(x), self.ssf_scale_2, self.ssf_shift_2)))) else: x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x)))) x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x)))) return x class ResPostBlock(nn.Module): def __init__( self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., init_values=None, drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): super().__init__() self.init_values = init_values self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) self.norm1 = norm_layer(dim) self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop) self.norm2 = norm_layer(dim) self.drop_path2 = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() self.init_weights() def init_weights(self): # NOTE this init overrides that base model init with specific changes for the block type if self.init_values is not None: nn.init.constant_(self.norm1.weight, self.init_values) nn.init.constant_(self.norm2.weight, self.init_values) def forward(self, x): x = x + self.drop_path1(self.norm1(self.attn(x))) x = x + self.drop_path2(self.norm2(self.mlp(x))) return x class ParallelBlock(nn.Module): def __init__( self, dim, num_heads, num_parallel=2, mlp_ratio=4., qkv_bias=False, init_values=None, drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): super().__init__() self.num_parallel = num_parallel self.attns = nn.ModuleList() self.ffns = nn.ModuleList() for _ in range(num_parallel): self.attns.append(nn.Sequential(OrderedDict([ ('norm', norm_layer(dim)), ('attn', Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)), ('ls', LayerScale(dim, init_values=init_values) if init_values else nn.Identity()), ('drop_path', DropPath(drop_path) if drop_path > 0. else nn.Identity()) ]))) self.ffns.append(nn.Sequential(OrderedDict([ ('norm', norm_layer(dim)), ('mlp', Mlp(dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)), ('ls', LayerScale(dim, init_values=init_values) if init_values else nn.Identity()), ('drop_path', DropPath(drop_path) if drop_path > 0. else nn.Identity()) ]))) def _forward_jit(self, x): x = x + torch.stack([attn(x) for attn in self.attns]).sum(dim=0) x = x + torch.stack([ffn(x) for ffn in self.ffns]).sum(dim=0) return x @torch.jit.ignore def _forward(self, x): x = x + sum(attn(x) for attn in self.attns) x = x + sum(ffn(x) for ffn in self.ffns) return x def forward(self, x): if torch.jit.is_scripting() or torch.jit.is_tracing(): return self._forward_jit(x) else: return self._forward(x) class PatchEmbed(nn.Module): """ 2D Image to Patch Embedding """ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True, tuning_mode=None): super().__init__() img_size = to_2tuple(img_size) patch_size = to_2tuple(patch_size) self.img_size = img_size self.patch_size = patch_size self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) self.num_patches = self.grid_size[0] * self.grid_size[1] self.flatten = flatten self.norm_layer = norm_layer self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() self.tuning_mode = tuning_mode if tuning_mode == 'ssf': self.ssf_scale_1, self.ssf_shift_1 = init_ssf_scale_shift(embed_dim) if norm_layer: self.ssf_scale_2, self.ssf_shift_2 = init_ssf_scale_shift(embed_dim) def forward(self, x): B, C, H, W = x.shape _assert(H == self.img_size[0], f"Input image height ({H}) doesn't match model ({self.img_size[0]}).") _assert(W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]}).") x = self.proj(x) if self.flatten: x = x.flatten(2).transpose(1, 2) # BCHW -> BNC if self.tuning_mode == 'ssf': x = ssf_ada(x, self.ssf_scale_1, self.ssf_shift_1) if self.norm_layer: x = ssf_ada(self.norm(x), self.ssf_scale_2, self.ssf_shift_2) else: x = self.norm(x) else: x = self.norm(x) return x def init_ssf_scale_shift(dim): scale = nn.Parameter(torch.ones(dim)) shift = nn.Parameter(torch.zeros(dim)) nn.init.normal_(scale, mean=1, std=.02) nn.init.normal_(shift, std=.02) return scale, shift def ssf_ada(x, scale, shift): assert scale.shape == 
shift.shape if x.shape[-1] == scale.shape[0]: return x * scale + shift elif x.shape[1] == scale.shape[0]: return x * scale.view(1, -1, 1, 1) + shift.view(1, -1, 1, 1) else: raise ValueError('the input tensor shape does not match the shape of the scale factor.') class VisionTransformer(nn.Module): """ Vision Transformer A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` - https://arxiv.org/abs/2010.11929 """ def __init__( self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, global_pool='token', embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, init_values=None, class_token=True, fc_norm=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., weight_init='', embed_layer=PatchEmbed, norm_layer=None, act_layer=None, block_fn=Block, tuning_mode=None): """ Args: img_size (int, tuple): input image size patch_size (int, tuple): patch size in_chans (int): number of input channels num_classes (int): number of classes for classification head global_pool (str): type of global pooling for final sequence (default: 'token') embed_dim (int): embedding dimension depth (int): depth of transformer num_heads (int): number of attention heads mlp_ratio (int): ratio of mlp hidden dim to embedding dim qkv_bias (bool): enable bias for qkv if True init_values: (float): layer-scale init values class_token (bool): use class token fc_norm (Optional[bool]): pre-fc norm after pool, set if global_pool == 'avg' if None (default: None) drop_rate (float): dropout rate attn_drop_rate (float): attention dropout rate drop_path_rate (float): stochastic depth rate weight_init (str): weight init scheme embed_layer (nn.Module): patch embedding layer norm_layer: (nn.Module): normalization layer act_layer: (nn.Module): MLP activation layer """ super().__init__() assert global_pool in ('', 'avg', 'token') assert class_token or global_pool != 'token' use_fc_norm = global_pool == 'avg' if fc_norm is None else fc_norm norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) act_layer = act_layer or nn.GELU self.num_classes = num_classes self.global_pool = global_pool self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models self.num_tokens = 1 if class_token else 0 self.grad_checkpointing = False self.patch_embed = embed_layer( img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, tuning_mode=tuning_mode) num_patches = self.patch_embed.num_patches self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if self.num_tokens > 0 else None self.pos_embed = nn.Parameter(torch.randn(1, num_patches + self.num_tokens, embed_dim) * .02) self.pos_drop = nn.Dropout(p=drop_rate) dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule self.tuning_mode = tuning_mode tuning_mode_list = [tuning_mode] * depth if tuning_mode == 'ssf': self.ssf_scale_1, self.ssf_shift_1 = init_ssf_scale_shift(self.num_features) self.blocks = nn.Sequential(*[ block_fn( dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, init_values=init_values, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer, tuning_mode=tuning_mode_list[i]) for i in range(depth)]) self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity() # Classifier Head self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity() self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() if weight_init 
!= 'skip': self.init_weights(weight_init) def init_weights(self, mode=''): assert mode in ('jax', 'jax_nlhb', 'moco', '') head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0. trunc_normal_(self.pos_embed, std=.02) if self.cls_token is not None: nn.init.normal_(self.cls_token, std=1e-6) named_apply(get_init_weights_vit(mode, head_bias), self) def _init_weights(self, m): # this fn left here for compat with downstream users init_weights_vit_timm(m) @torch.jit.ignore() def load_pretrained(self, checkpoint_path, prefix=''): _load_weights(self, checkpoint_path, prefix) @torch.jit.ignore def no_weight_decay(self): return {'pos_embed', 'cls_token', 'dist_token'} @torch.jit.ignore def group_matcher(self, coarse=False): return dict( stem=r'^cls_token|pos_embed|patch_embed', # stem and embed blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))] ) @torch.jit.ignore def set_grad_checkpointing(self, enable=True): self.grad_checkpointing = enable @torch.jit.ignore def get_classifier(self): return self.head def reset_classifier(self, num_classes: int, global_pool=None): self.num_classes = num_classes if global_pool is not None: assert global_pool in ('', 'avg', 'token') self.global_pool = global_pool self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() def forward_features(self, x): x = self.patch_embed(x) if self.cls_token is not None: x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) x = self.pos_drop(x + self.pos_embed) if self.grad_checkpointing and not torch.jit.is_scripting(): x = checkpoint_seq(self.blocks, x) else: x = self.blocks(x) x = self.norm(x) if self.tuning_mode == 'ssf': x = ssf_ada(x, self.ssf_scale_1, self.ssf_shift_1) return x def forward_head(self, x, pre_logits: bool = False): if self.global_pool: x = x[:, self.num_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0] x = self.fc_norm(x) return x if pre_logits else self.head(x) def forward(self, x): x = self.forward_features(x) x = self.forward_head(x) return x def init_weights_vit_timm(module: nn.Module, name: str = ''): """ ViT weight initialization, original timm impl (for reproducibility) """ if isinstance(module, nn.Linear): trunc_normal_(module.weight, std=.02) if module.bias is not None: nn.init.zeros_(module.bias) elif hasattr(module, 'init_weights'): module.init_weights() def init_weights_vit_jax(module: nn.Module, name: str = '', head_bias: float = 0.): """ ViT weight initialization, matching JAX (Flax) impl """ if isinstance(module, nn.Linear): if name.startswith('head'): nn.init.zeros_(module.weight) nn.init.constant_(module.bias, head_bias) else: nn.init.xavier_uniform_(module.weight) if module.bias is not None: nn.init.normal_(module.bias, std=1e-6) if 'mlp' in name else nn.init.zeros_(module.bias) elif isinstance(module, nn.Conv2d): lecun_normal_(module.weight) if module.bias is not None: nn.init.zeros_(module.bias) elif hasattr(module, 'init_weights'): module.init_weights() def init_weights_vit_moco(module: nn.Module, name: str = ''): """ ViT weight initialization, matching moco-v3 impl minus fixed PatchEmbed """ if isinstance(module, nn.Linear): if 'qkv' in name: # treat the weights of Q, K, V separately val = math.sqrt(6. 
/ float(module.weight.shape[0] // 3 + module.weight.shape[1])) nn.init.uniform_(module.weight, -val, val) else: nn.init.xavier_uniform_(module.weight) if module.bias is not None: nn.init.zeros_(module.bias) elif hasattr(module, 'init_weights'): module.init_weights() def get_init_weights_vit(mode='jax', head_bias: float = 0.): if 'jax' in mode: return partial(init_weights_vit_jax, head_bias=head_bias) elif 'moco' in mode: return init_weights_vit_moco else: return init_weights_vit_timm @torch.no_grad() def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ''): """ Load weights from .npz checkpoints for official Google Brain Flax implementation """ import numpy as np def _n2p(w, t=True): if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1: w = w.flatten() if t: if w.ndim == 4: w = w.transpose([3, 2, 0, 1]) elif w.ndim == 3: w = w.transpose([2, 0, 1]) elif w.ndim == 2: w = w.transpose([1, 0]) return torch.from_numpy(w) w = np.load(checkpoint_path) if not prefix and 'opt/target/embedding/kernel' in w: prefix = 'opt/target/' if hasattr(model.patch_embed, 'backbone'): # hybrid backbone = model.patch_embed.backbone stem_only = not hasattr(backbone, 'stem') stem = backbone if stem_only else backbone.stem stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f'{prefix}conv_root/kernel']))) stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale'])) stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias'])) if not stem_only: for i, stage in enumerate(backbone.stages): for j, block in enumerate(stage.blocks): bp = f'{prefix}block{i + 1}/unit{j + 1}/' for r in range(3): getattr(block, f'conv{r + 1}').weight.copy_(_n2p(w[f'{bp}conv{r + 1}/kernel'])) getattr(block, f'norm{r + 1}').weight.copy_(_n2p(w[f'{bp}gn{r + 1}/scale'])) getattr(block, f'norm{r + 1}').bias.copy_(_n2p(w[f'{bp}gn{r + 1}/bias'])) if block.downsample is not None: block.downsample.conv.weight.copy_(_n2p(w[f'{bp}conv_proj/kernel'])) block.downsample.norm.weight.copy_(_n2p(w[f'{bp}gn_proj/scale'])) block.downsample.norm.bias.copy_(_n2p(w[f'{bp}gn_proj/bias'])) embed_conv_w = _n2p(w[f'{prefix}embedding/kernel']) else: embed_conv_w = adapt_input_conv( model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel'])) model.patch_embed.proj.weight.copy_(embed_conv_w) model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias'])) model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False)) pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False) if pos_embed_w.shape != model.pos_embed.shape: pos_embed_w = resize_pos_embed( # resize pos embedding when different size from pretrained weights pos_embed_w, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size) model.pos_embed.copy_(pos_embed_w) model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale'])) model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias'])) if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]: model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel'])) model.head.bias.copy_(_n2p(w[f'{prefix}head/bias'])) # NOTE representation layer has been removed, not used in latest 21k/1k pretrained weights # if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w: # model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel'])) # model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias'])) for i, block in enumerate(model.blocks.children()): 
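# NOTE (explanatory comment, not part of the original code): the loop below maps each JAX/Flax
# encoder block onto the fused PyTorch layout: LayerNorm_0/LayerNorm_2 become norm1/norm2, the
# separate query/key/value kernels are flattened and concatenated into the single qkv
# projection, the attention output kernel becomes attn.proj, and MlpBlock Dense_0/Dense_1
# become mlp.fc1/mlp.fc2.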
block_prefix = f'{prefix}Transformer/encoderblock_{i}/' mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/' block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'])) block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'])) block.attn.qkv.weight.copy_(torch.cat([ _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')])) block.attn.qkv.bias.copy_(torch.cat([ _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')])) block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1)) block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'])) for r in range(2): getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/kernel'])) getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/bias'])) block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale'])) block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias'])) def resize_pos_embed(posemb, posemb_new, num_tokens=1, gs_new=()): # Rescale the grid of position embeddings when loading from state_dict. Adapted from # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224 _logger.info('Resized position embedding: %s to %s', posemb.shape, posemb_new.shape) ntok_new = posemb_new.shape[1] if num_tokens: posemb_tok, posemb_grid = posemb[:, :num_tokens], posemb[0, num_tokens:] ntok_new -= num_tokens else: posemb_tok, posemb_grid = posemb[:, :0], posemb[0] gs_old = int(math.sqrt(len(posemb_grid))) if not len(gs_new): # backwards compatibility gs_new = [int(math.sqrt(ntok_new))] * 2 assert len(gs_new) >= 2 _logger.info('Position embedding grid-size from %s to %s', [gs_old, gs_old], gs_new) posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2) posemb_grid = F.interpolate(posemb_grid, size=gs_new, mode='bicubic', align_corners=False) posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new[0] * gs_new[1], -1) posemb = torch.cat([posemb_tok, posemb_grid], dim=1) return posemb def checkpoint_filter_fn(state_dict, model): """ convert patch embedding weight from manual patchify + linear proj to conv""" out_dict = {} if 'model' in state_dict: # For deit models state_dict = state_dict['model'] for k, v in state_dict.items(): if 'patch_embed.proj.weight' in k and len(v.shape) < 4: # For old models that I trained prior to conv based patchification O, I, H, W = model.patch_embed.proj.weight.shape v = v.reshape(O, -1, H, W) elif k == 'pos_embed' and v.shape != model.pos_embed.shape: # To resize pos embedding when using model at different size from pretrained weights v = resize_pos_embed( v, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size) elif 'pre_logits' in k: # NOTE representation layer removed as not used in latest 21k/1k pretrained weights continue out_dict[k] = v return out_dict def _create_vision_transformer(variant, pretrained=False, **kwargs): if kwargs.get('features_only', None): raise RuntimeError('features_only not implemented for Vision Transformer models.') pretrained_cfg = resolve_pretrained_cfg(variant, pretrained_cfg=kwargs.pop('pretrained_cfg', None)) model = build_model_with_cfg( VisionTransformer, variant, pretrained, pretrained_cfg=pretrained_cfg, pretrained_filter_fn=checkpoint_filter_fn, pretrained_custom_load='npz' in pretrained_cfg['url'], **kwargs) return model @register_model def vit_tiny_patch16_224(pretrained=False, 
**kwargs): """ ViT-Tiny (Vit-Ti/16) """ model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs) model = _create_vision_transformer('vit_tiny_patch16_224', pretrained=pretrained, **model_kwargs) return model @register_model def vit_tiny_patch16_384(pretrained=False, **kwargs): """ ViT-Tiny (Vit-Ti/16) @ 384x384. """ model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs) model = _create_vision_transformer('vit_tiny_patch16_384', pretrained=pretrained, **model_kwargs) return model @register_model def vit_small_patch16_224(pretrained=False, **kwargs): """ ViT-Small (ViT-S/16) NOTE I've replaced my previous 'small' model definition and weights with the small variant from the DeiT paper """ model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs) model = _create_vision_transformer('vit_small_patch16_224', pretrained=pretrained, **model_kwargs) return model @register_model def vit_small_patch16_384(pretrained=False, **kwargs): """ ViT-Small (ViT-S/16) NOTE I've replaced my previous 'small' model definition and weights with the small variant from the DeiT paper """ model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs) model = _create_vision_transformer('vit_small_patch16_384', pretrained=pretrained, **model_kwargs) return model @register_model def vit_base_patch16_224(pretrained=False, **kwargs): """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer. """ model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) model = _create_vision_transformer('vit_base_patch16_224', pretrained=pretrained, **model_kwargs) return model @register_model def vit_base_patch16_384(pretrained=False, **kwargs): """ ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer. """ model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) model = _create_vision_transformer('vit_base_patch16_384', pretrained=pretrained, **model_kwargs) return model @register_model def vit_large_patch16_224(pretrained=False, **kwargs): """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929). ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer. """ model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs) model = _create_vision_transformer('vit_large_patch16_224', pretrained=pretrained, **model_kwargs) return model @register_model def vit_large_patch16_384(pretrained=False, **kwargs): """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929). ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer. """ model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs) model = _create_vision_transformer('vit_large_patch16_384', pretrained=pretrained, **model_kwargs) return model @register_model def vit_tiny_patch16_224_in21k(pretrained=False, **kwargs): """ ViT-Tiny (Vit-Ti/16). ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. 
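# Illustrative sketch (not part of the original file): instantiating one of the variants
# registered here with SSF enabled and inspecting the extra parameters it introduces. Assumes
# this module has been imported so that its @register_model entries override the stock timm
# definitions (as train.py in this repo does).
import torch
from timm.models import create_model

vit = create_model('vit_tiny_patch16_224_in21k', pretrained=False,
                   num_classes=100, tuning_mode='ssf')
ssf_params = [n for n, _ in vit.named_parameters()
              if 'ssf_scale' in n or 'ssf_shift' in n]
print(len(ssf_params))                          # scale/shift vectors from every block plus patch embed and final norm
print(vit(torch.randn(1, 3, 224, 224)).shape)   # torch.Size([1, 100])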
NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer """ model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs) model = _create_vision_transformer('vit_tiny_patch16_224_in21k', pretrained=pretrained, **model_kwargs) return model @register_model def vit_small_patch16_224_in21k(pretrained=False, **kwargs): """ ViT-Small (ViT-S/16) ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer """ model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs) model = _create_vision_transformer('vit_small_patch16_224_in21k', pretrained=pretrained, **model_kwargs) return model @register_model def vit_base_patch16_224_in21k(pretrained=False, **kwargs): """ ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer """ model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) model = _create_vision_transformer('vit_base_patch16_224_in21k', pretrained=pretrained, **model_kwargs) return model @register_model def vit_large_patch16_224_in21k(pretrained=False, **kwargs): """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929). ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer """ model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs) model = _create_vision_transformer('vit_large_patch16_224_in21k', pretrained=pretrained, **model_kwargs) return model ================================================ FILE: optim_factory.py ================================================ """ Optimizer Factory w/ Custom Weight Decay Hacked together by / Copyright 2021 Ross Wightman """ import json from itertools import islice from typing import Optional, Callable, Tuple, Dict, Union import torch import torch.nn as nn import torch.optim as optim #from timm.models.helpers import group_parameters from timm.optim.adabelief import AdaBelief from timm.optim.adafactor import Adafactor from timm.optim.adahessian import Adahessian from timm.optim.adamp import AdamP from timm.optim.lamb import Lamb from timm.optim.lars import Lars from timm.optim.lookahead import Lookahead from timm.optim.madgrad import MADGRAD from timm.optim.nadam import Nadam from timm.optim.nvnovograd import NvNovoGrad from timm.optim.radam import RAdam from timm.optim.rmsprop_tf import RMSpropTF from timm.optim.sgdp import SGDP try: from apex.optimizers import FusedNovoGrad, FusedAdam, FusedLAMB, FusedSGD has_apex = True except ImportError: has_apex = False def param_groups_weight_decay( model: nn.Module, weight_decay=1e-5, no_weight_decay_list=() ): no_weight_decay_list = set(no_weight_decay_list) decay = [] no_decay = [] for name, param in model.named_parameters(): if not param.requires_grad: continue if param.ndim <= 1 or name.endswith(".bias") or name in no_weight_decay_list: no_decay.append(param) else: decay.append(param) return [ {'params': no_decay, 'weight_decay': 0.}, {'params': decay, 'weight_decay': weight_decay}] def _group(it, size): it = iter(it) return iter(lambda: tuple(islice(it, size)), ()) def _layer_map(model, 
layers_per_group=12, num_groups=None): def _in_head(n, hp): if not hp: return True elif isinstance(hp, (tuple, list)): return any([n.startswith(hpi) for hpi in hp]) else: return n.startswith(hp) head_prefix = getattr(model, 'pretrained_cfg', {}).get('classifier', None) names_trunk = [] names_head = [] for n, _ in model.named_parameters(): names_head.append(n) if _in_head(n, head_prefix) else names_trunk.append(n) # group non-head layers num_trunk_layers = len(names_trunk) if num_groups is not None: layers_per_group = -(num_trunk_layers // -num_groups) names_trunk = list(_group(names_trunk, layers_per_group)) num_trunk_groups = len(names_trunk) layer_map = {n: i for i, l in enumerate(names_trunk) for n in l} layer_map.update({n: num_trunk_groups for n in names_head}) return layer_map def group_with_matcher( named_objects, group_matcher: Union[Dict, Callable], output_values: bool = False, reverse: bool = False ): if isinstance(group_matcher, dict): # dictionary matcher contains a dict of raw-string regex expr that must be compiled compiled = [] for group_ordinal, (group_name, mspec) in enumerate(group_matcher.items()): if mspec is None: continue # map all matching specifications into 3-tuple (compiled re, prefix, suffix) if isinstance(mspec, (tuple, list)): # multi-entry match specifications require each sub-spec to be a 2-tuple (re, suffix) for sspec in mspec: compiled += [(re.compile(sspec[0]), (group_ordinal,), sspec[1])] else: compiled += [(re.compile(mspec), (group_ordinal,), None)] group_matcher = compiled def _get_grouping(name): if isinstance(group_matcher, (list, tuple)): for match_fn, prefix, suffix in group_matcher: r = match_fn.match(name) if r: parts = (prefix, r.groups(), suffix) # map all tuple elem to int for numeric sort, filter out None entries return tuple(map(float, chain.from_iterable(filter(None, parts)))) return float('inf'), # un-matched layers (neck, head) mapped to largest ordinal else: ord = group_matcher(name) if not isinstance(ord, collections.abc.Iterable): return ord, return tuple(ord) # map layers into groups via ordinals (ints or tuples of ints) from matcher grouping = defaultdict(list) for k, v in named_objects: grouping[_get_grouping(k)].append(v if output_values else k) # remap to integers layer_id_to_param = defaultdict(list) lid = -1 for k in sorted(filter(lambda x: x is not None, grouping.keys())): if lid < 0 or k[-1] != MATCH_PREV_GROUP[0]: lid += 1 layer_id_to_param[lid].extend(grouping[k]) if reverse: assert not output_values, "reverse mapping only sensible for name output" # output reverse mapping param_to_layer_id = {} for lid, lm in layer_id_to_param.items(): for n in lm: param_to_layer_id[n] = lid return param_to_layer_id return layer_id_to_param def group_parameters( module: nn.Module, group_matcher, output_values=False, reverse=False, ): return group_with_matcher( module.named_parameters(), group_matcher, output_values=output_values, reverse=reverse) def group_modules( module: nn.Module, group_matcher, output_values=False, reverse=False, ): return group_with_matcher( named_modules_with_params(module), group_matcher, output_values=output_values, reverse=reverse) def param_groups_layer_decay( model: nn.Module, weight_decay: float = 0.05, no_weight_decay_list: Tuple[str] = (), layer_decay: float = .75, end_layer_decay: Optional[float] = None, ): """ Parameter groups for layer-wise lr decay & weight decay Based on BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L58 """ no_weight_decay_list = set(no_weight_decay_list) 
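# Illustrative sketch (not part of the original file): the simpler grouping helper
# param_groups_weight_decay defined above, applied to a toy model. Biases and 1-D parameters
# (e.g. LayerNorm weights) land in the zero-weight-decay group; weight matrices keep the decay.
import torch.nn as nn
from optim_factory import param_groups_weight_decay  # when run outside this module

toy = nn.Sequential(nn.Linear(8, 8), nn.LayerNorm(8))
groups = param_groups_weight_decay(toy, weight_decay=0.05)
print([len(g['params']) for g in groups])    # [3, 1]
print([g['weight_decay'] for g in groups])   # [0.0, 0.05]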
param_group_names = {} # NOTE for debugging param_groups = {} if hasattr(model, 'group_matcher'): # FIXME interface needs more work layer_map = group_parameters(model, model.group_matcher(coarse=False), reverse=True) else: # fallback layer_map = _layer_map(model) num_layers = max(layer_map.values()) + 1 layer_max = num_layers - 1 layer_scales = list(layer_decay ** (layer_max - i) for i in range(num_layers)) for name, param in model.named_parameters(): if not param.requires_grad: continue # no decay: all 1D parameters and model specific ones if param.ndim == 1 or name in no_weight_decay_list: g_decay = "no_decay" this_decay = 0. else: g_decay = "decay" this_decay = weight_decay layer_id = layer_map.get(name, layer_max) group_name = "layer_%d_%s" % (layer_id, g_decay) if group_name not in param_groups: this_scale = layer_scales[layer_id] param_group_names[group_name] = { "lr_scale": this_scale, "weight_decay": this_decay, "param_names": [], } param_groups[group_name] = { "lr_scale": this_scale, "weight_decay": this_decay, "params": [], } param_group_names[group_name]["param_names"].append(name) param_groups[group_name]["params"].append(param) # FIXME temporary output to debug new feature print("parameter groups: \n%s" % json.dumps(param_group_names, indent=2)) return list(param_groups.values()) def optimizer_kwargs(cfg): """ cfg/argparse to kwargs helper Convert optimizer args in argparse args or cfg like object to keyword args for updated create fn. """ kwargs = dict( opt=cfg.opt, lr=cfg.lr, weight_decay=cfg.weight_decay, momentum=cfg.momentum, tuning_mode=cfg.tuning_mode) if getattr(cfg, 'opt_eps', None) is not None: kwargs['eps'] = cfg.opt_eps if getattr(cfg, 'opt_betas', None) is not None: kwargs['betas'] = cfg.opt_betas if getattr(cfg, 'layer_decay', None) is not None: kwargs['layer_decay'] = cfg.layer_decay if getattr(cfg, 'opt_args', None) is not None: kwargs.update(cfg.opt_args) return kwargs def create_optimizer(args, model, filter_bias_and_bn=True): """ Legacy optimizer factory for backwards compatibility. NOTE: Use create_optimizer_v2 for new code. """ return create_optimizer_v2( model, **optimizer_kwargs(cfg=args), filter_bias_and_bn=filter_bias_and_bn, ) def create_optimizer_v2( model_or_params, opt: str = 'sgd', lr: Optional[float] = None, weight_decay: float = 0., momentum: float = 0.9, tuning_mode: str = None, filter_bias_and_bn: bool = True, layer_decay: Optional[float] = None, param_group_fn: Optional[Callable] = None, **kwargs): """ Create an optimizer. TODO currently the model is passed in and all parameters are selected for optimization. For more general use an interface that allows selection of parameters to optimize and lr groups, one of: * a filter fn interface that further breaks params into groups in a weight_decay compatible fashion * expose the parameters interface and leave it up to caller Args: model_or_params (nn.Module): model containing parameters to optimize opt: name of optimizer to create lr: initial learning rate weight_decay: weight decay to apply in optimizer momentum: momentum for momentum based optimizers (others may use betas via kwargs) filter_bias_and_bn: filter out bias, bn and other 1d params from weight decay **kwargs: extra optimizer specific kwargs to pass through Returns: Optimizer """ if isinstance(model_or_params, nn.Module): # TODO: for fine-tuning if tuning_mode: for name, param in model_or_params.named_parameters(): if tuning_mode == 'linear_probe': if "head." not in name: param.requires_grad = False elif tuning_mode == 'ssf': if "head." 
not in name and "ssf_scale" not in name and "ssf_shift_" not in name: param.requires_grad = False if param.requires_grad == True: print(name) print('freezing parameters finished!') # a model was passed in, extract parameters and add weight decays to appropriate layers no_weight_decay = {} if hasattr(model_or_params, 'no_weight_decay'): no_weight_decay = model_or_params.no_weight_decay() if param_group_fn: parameters = param_group_fn(model_or_params) elif layer_decay is not None: parameters = param_groups_layer_decay( model_or_params, weight_decay=weight_decay, layer_decay=layer_decay, no_weight_decay_list=no_weight_decay) weight_decay = 0. elif weight_decay and filter_bias_and_bn: parameters = param_groups_weight_decay(model_or_params, weight_decay, no_weight_decay) weight_decay = 0. else: parameters = model_or_params.parameters() else: # iterable of parameters or param groups passed in parameters = model_or_params opt_lower = opt.lower() opt_split = opt_lower.split('_') opt_lower = opt_split[-1] if 'fused' in opt_lower: assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers' opt_args = dict(weight_decay=weight_decay, **kwargs) if lr is not None: opt_args.setdefault('lr', lr) # basic SGD & related if opt_lower == 'sgd' or opt_lower == 'nesterov': # NOTE 'sgd' refers to SGD + nesterov momentum for legacy / backwards compat reasons opt_args.pop('eps', None) optimizer = optim.SGD(parameters, momentum=momentum, nesterov=True, **opt_args) elif opt_lower == 'momentum': opt_args.pop('eps', None) optimizer = optim.SGD(parameters, momentum=momentum, nesterov=False, **opt_args) elif opt_lower == 'sgdp': optimizer = SGDP(parameters, momentum=momentum, nesterov=True, **opt_args) # adaptive elif opt_lower == 'adam': optimizer = optim.Adam(parameters, **opt_args) elif opt_lower == 'adamw': optimizer = optim.AdamW(parameters, **opt_args) elif opt_lower == 'adamp': optimizer = AdamP(parameters, wd_ratio=0.01, nesterov=True, **opt_args) elif opt_lower == 'nadam': try: # NOTE PyTorch >= 1.10 should have native NAdam optimizer = optim.Nadam(parameters, **opt_args) except AttributeError: optimizer = Nadam(parameters, **opt_args) elif opt_lower == 'radam': optimizer = RAdam(parameters, **opt_args) elif opt_lower == 'adamax': optimizer = optim.Adamax(parameters, **opt_args) elif opt_lower == 'adabelief': optimizer = AdaBelief(parameters, rectify=False, **opt_args) elif opt_lower == 'radabelief': optimizer = AdaBelief(parameters, rectify=True, **opt_args) elif opt_lower == 'adadelta': optimizer = optim.Adadelta(parameters, **opt_args) elif opt_lower == 'adagrad': opt_args.setdefault('eps', 1e-8) optimizer = optim.Adagrad(parameters, **opt_args) elif opt_lower == 'adafactor': optimizer = Adafactor(parameters, **opt_args) elif opt_lower == 'lamb': optimizer = Lamb(parameters, **opt_args) elif opt_lower == 'lambc': optimizer = Lamb(parameters, trust_clip=True, **opt_args) elif opt_lower == 'larc': optimizer = Lars(parameters, momentum=momentum, trust_clip=True, **opt_args) elif opt_lower == 'lars': optimizer = Lars(parameters, momentum=momentum, **opt_args) elif opt_lower == 'nlarc': optimizer = Lars(parameters, momentum=momentum, trust_clip=True, nesterov=True, **opt_args) elif opt_lower == 'nlars': optimizer = Lars(parameters, momentum=momentum, nesterov=True, **opt_args) elif opt_lower == 'madgrad': optimizer = MADGRAD(parameters, momentum=momentum, **opt_args) elif opt_lower == 'madgradw': optimizer = MADGRAD(parameters, momentum=momentum, decoupled_decay=True, **opt_args) 
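# NOTE (explanatory comment, not part of the original code): with tuning_mode='ssf' the
# freezing loop near the top of this function leaves only parameters whose names contain
# 'head.', 'ssf_scale' or 'ssf_shift_' with requires_grad=True, so whichever optimizer is
# selected in this chain only updates the classifier head and the SSF scale/shift vectors.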
elif opt_lower == 'novograd' or opt_lower == 'nvnovograd': optimizer = NvNovoGrad(parameters, **opt_args) elif opt_lower == 'rmsprop': optimizer = optim.RMSprop(parameters, alpha=0.9, momentum=momentum, **opt_args) elif opt_lower == 'rmsproptf': optimizer = RMSpropTF(parameters, alpha=0.9, momentum=momentum, **opt_args) # second order elif opt_lower == 'adahessian': optimizer = Adahessian(parameters, **opt_args) # NVIDIA fused optimizers, require APEX to be installed elif opt_lower == 'fusedsgd': opt_args.pop('eps', None) optimizer = FusedSGD(parameters, momentum=momentum, nesterov=True, **opt_args) elif opt_lower == 'fusedmomentum': opt_args.pop('eps', None) optimizer = FusedSGD(parameters, momentum=momentum, nesterov=False, **opt_args) elif opt_lower == 'fusedadam': optimizer = FusedAdam(parameters, adam_w_mode=False, **opt_args) elif opt_lower == 'fusedadamw': optimizer = FusedAdam(parameters, adam_w_mode=True, **opt_args) elif opt_lower == 'fusedlamb': optimizer = FusedLAMB(parameters, **opt_args) elif opt_lower == 'fusednovograd': opt_args.setdefault('betas', (0.95, 0.98)) optimizer = FusedNovoGrad(parameters, **opt_args) else: raise ValueError(f'Invalid optimizer: {opt_lower}') if len(opt_split) > 1: if opt_split[0] == 'lookahead': optimizer = Lookahead(optimizer) return optimizer ================================================ FILE: requirements.txt ================================================ pyyaml scipy pandas ipdb ================================================ FILE: train.py ================================================ #!/usr/bin/env python3 """ ImageNet Training Script This is intended to be a lean and easily modifiable ImageNet training script that reproduces ImageNet training results with some of the latest networks and training techniques. It favours canonical PyTorch and standard Python style over trying to be able to 'do it all.' That said, it offers quite a few speed and training result improvements over the usual PyTorch example scripts. Repurpose as you see fit.
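# Illustrative sketch (not part of the original script): the minimal SSF fine-tuning flow that
# the arguments below configure (--model, --tuning-mode ssf, optimizer settings). A dummy batch
# stands in for the real data pipeline (data.create_loader); set pretrained=True and add a
# scheduler and real loader for actual fine-tuning.
import torch
from timm.models import create_model
from models import vision_transformer  # noqa: F401 -- registers the SSF-aware ViT variants
from optim_factory import create_optimizer_v2

model = create_model('vit_base_patch16_224_in21k', pretrained=False,
                     num_classes=100, tuning_mode='ssf')
optimizer = create_optimizer_v2(model, opt='adamw', lr=1e-3,
                                weight_decay=5e-2, tuning_mode='ssf')
criterion = torch.nn.CrossEntropyLoss()
images, targets = torch.randn(2, 3, 224, 224), torch.randint(0, 100, (2,))
model.train()
loss = criterion(model(images), targets)
optimizer.zero_grad()
loss.backward()
optimizer.step()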
This script was started from an early version of the PyTorch ImageNet example (https://github.com/pytorch/examples/tree/master/imagenet) NVIDIA CUDA specific speedups adopted from NVIDIA Apex examples (https://github.com/NVIDIA/apex/tree/master/examples/imagenet) Hacked together by / Copyright 2020 Ross Wightman (https://github.com/rwightman) """ import argparse import time import yaml import os import logging import numpy as np from collections import OrderedDict from contextlib import suppress from datetime import datetime import torch import torch.nn as nn import torchvision.utils from torch.nn.parallel import DistributedDataParallel as NativeDDP from timm.data import resolve_data_config, Mixup, FastCollateMixup, AugMixDataset from timm.models import create_model, safe_model_name, resume_checkpoint, load_checkpoint,\ convert_splitbn_model, model_parameters from timm.utils import * from timm.loss import * from timm.scheduler import create_scheduler from timm.utils import ApexScaler, NativeScaler from data import create_loader, create_dataset from optim_factory import create_optimizer_v2, optimizer_kwargs from models import vision_transformer, swin_transformer, convnext, as_mlp import ipdb try: from apex import amp from apex.parallel import DistributedDataParallel as ApexDDP from apex.parallel import convert_syncbn_model has_apex = True except ImportError: has_apex = False has_native_amp = False try: if getattr(torch.cuda.amp, 'autocast') is not None: has_native_amp = True except AttributeError: pass try: import wandb has_wandb = True except ImportError: has_wandb = False torch.backends.cudnn.benchmark = True _logger = logging.getLogger('train') # The first arg parser parses out only the --config argument, this argument is used to # load a yaml file containing key-values that override the defaults for the main parser below config_parser = parser = argparse.ArgumentParser(description='Training Config', add_help=False) parser.add_argument('-c', '--config', default='', type=str, metavar='FILE', help='YAML config file specifying default arguments') parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') # Dataset parameters parser.add_argument('data_dir', metavar='DIR', help='path to dataset') parser.add_argument('--dataset', '-d', metavar='NAME', default='', help='dataset type (default: ImageFolder/ImageTar if empty)') parser.add_argument('--train-split', metavar='NAME', default='train', help='dataset train split (default: train)') parser.add_argument('--val-split', metavar='NAME', default='validation', help='dataset validation split (default: validation)') parser.add_argument('--dataset-download', action='store_true', default=False, help='Allow download of dataset for torch/ and tfds/ datasets that support it.') parser.add_argument('--class-map', default='', type=str, metavar='FILENAME', help='path to class to idx mapping file (default: "")') # Model parameters parser.add_argument('--model', default='resnet50', type=str, metavar='MODEL', help='Name of model to train (default: "resnet50"') parser.add_argument('--pretrained', action='store_true', default=False, help='Start with pretrained version of specified network (if avail)') parser.add_argument('--initial-checkpoint', default='', type=str, metavar='PATH', help='Initialize model from this checkpoint (default: none)') parser.add_argument('--resume', default='', type=str, metavar='PATH', help='Resume full model and optimizer state from checkpoint (default: none)') parser.add_argument('--no-resume-opt', action='store_true', 
default=False, help='prevent resume of optimizer state when resuming model') parser.add_argument('--num-classes', type=int, default=None, metavar='N', help='number of label classes (Model default if None)') parser.add_argument('--gp', default=None, type=str, metavar='POOL', help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.') parser.add_argument('--img-size', type=int, default=None, metavar='N', help='Image patch size (default: None => model default)') parser.add_argument('--input-size', default=None, nargs=3, type=int, metavar='N N N', help='Input all image dimensions (d h w, e.g. --input-size 3 224 224), uses model default if empty') parser.add_argument('--crop-pct', default=None, type=float, metavar='N', help='Input image center crop percent (for validation only)') parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN', help='Override mean pixel value of dataset') parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD', help='Override std deviation of dataset') parser.add_argument('--interpolation', default='', type=str, metavar='NAME', help='Image resize interpolation type (overrides model)') parser.add_argument('-b', '--batch-size', type=int, default=128, metavar='N', help='Input batch size for training (default: 128)') parser.add_argument('-vb', '--validation-batch-size', type=int, default=None, metavar='N', help='Validation batch size override (default: None)') parser.add_argument('--channels-last', action='store_true', default=False, help='Use channels_last memory layout') parser.add_argument('--torchscript', dest='torchscript', action='store_true', help='torch.jit.script the full model') parser.add_argument('--fuser', default='', type=str, help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')") parser.add_argument('--grad-checkpointing', action='store_true', default=False, help='Enable gradient checkpointing through model blocks/stages') # finetuning parser.add_argument('--tuning-mode', default=None, type=str, help='Method of fine-tuning (default: None') parser.add_argument('--evaluate', action='store_true', default=False, help='evaluate') # Optimizer parameters parser.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER', help='Optimizer (default: "sgd"') parser.add_argument('--opt-eps', default=None, type=float, metavar='EPSILON', help='Optimizer Epsilon (default: None, use opt default)') parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA', help='Optimizer Betas (default: None, use opt default)') parser.add_argument('--momentum', type=float, default=0.9, metavar='M', help='Optimizer momentum (default: 0.9)') parser.add_argument('--weight-decay', type=float, default=2e-5, help='weight decay (default: 2e-5)') parser.add_argument('--clip-grad', type=float, default=None, metavar='NORM', help='Clip gradient norm (default: None, no clipping)') parser.add_argument('--clip-mode', type=str, default='norm', help='Gradient clipping mode. 
One of ("norm", "value", "agc")') parser.add_argument('--layer-decay', type=float, default=None, help='layer-wise learning rate decay (default: None)') # Learning rate schedule parameters parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER', help='LR scheduler (default: "step"') parser.add_argument('--lr', type=float, default=0.05, metavar='LR', help='learning rate (default: 0.05)') parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct', help='learning rate noise on/off epoch percentages') parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT', help='learning rate noise limit percent (default: 0.67)') parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV', help='learning rate noise std-dev (default: 1.0)') parser.add_argument('--lr-cycle-mul', type=float, default=1.0, metavar='MULT', help='learning rate cycle len multiplier (default: 1.0)') parser.add_argument('--lr-cycle-decay', type=float, default=0.5, metavar='MULT', help='amount to decay each learning rate cycle (default: 0.5)') parser.add_argument('--lr-cycle-limit', type=int, default=1, metavar='N', help='learning rate cycle limit, cycles enabled if > 1') parser.add_argument('--lr-k-decay', type=float, default=1.0, help='learning rate k-decay for cosine/poly (default: 1.0)') parser.add_argument('--warmup-lr', type=float, default=0.0001, metavar='LR', help='warmup learning rate (default: 0.0001)') parser.add_argument('--min-lr', type=float, default=1e-6, metavar='LR', help='lower lr bound for cyclic schedulers that hit 0 (1e-5)') parser.add_argument('--epochs', type=int, default=300, metavar='N', help='number of epochs to train (default: 300)') parser.add_argument('--epoch-repeats', type=float, default=0., metavar='N', help='epoch repeat multiplier (number of times to repeat dataset epoch per train epoch).') parser.add_argument('--start-epoch', default=None, type=int, metavar='N', help='manual epoch number (useful on restarts)') parser.add_argument('--decay-milestones', default=[30, 60], type=int, nargs='+', metavar="MILESTONES", help='list of decay epoch indices for multistep lr. 
must be increasing') parser.add_argument('--decay-epochs', type=float, default=100, metavar='N', help='epoch interval to decay LR') parser.add_argument('--warmup-epochs', type=int, default=3, metavar='N', help='epochs to warmup LR, if scheduler supports') parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N', help='epochs to cooldown LR at min_lr, after cyclic schedule ends') parser.add_argument('--patience-epochs', type=int, default=10, metavar='N', help='patience epochs for Plateau LR scheduler (default: 10') parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE', help='LR decay rate (default: 0.1)') # Augmentation & regularization parameters parser.add_argument('--no-aug', action='store_true', default=False, help='Disable all training augmentation, override other train aug args') parser.add_argument('--simple-aug', action='store_true', default=False, help='Only randomresize and flip training augmentation, override other train aug args') parser.add_argument('--direct-resize', action='store_true', default=False, help='Direct resize image in validation') parser.add_argument('--scale', type=float, nargs='+', default=[0.08, 1.0], metavar='PCT', help='Random resize scale (default: 0.08 1.0)') parser.add_argument('--ratio', type=float, nargs='+', default=[3./4., 4./3.], metavar='RATIO', help='Random resize aspect ratio (default: 0.75 1.33)') parser.add_argument('--hflip', type=float, default=0.5, help='Horizontal flip training aug probability') parser.add_argument('--vflip', type=float, default=0., help='Vertical flip training aug probability') parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT', help='Color jitter factor (default: 0.4)') parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME', help='Use AutoAugment policy. "v0" or "original". " + "(default: rand-m9-mstd0.5-inc1)') parser.add_argument('--aug-repeats', type=float, default=0, help='Number of augmentation repetitions (distributed training only) (default: 0)') parser.add_argument('--aug-splits', type=int, default=0, help='Number of augmentation splits (default: 0, valid: 0 or >=2)') parser.add_argument('--jsd-loss', action='store_true', default=False, help='Enable Jensen-Shannon Divergence + CE loss. Use with `--aug-splits`.') parser.add_argument('--bce-loss', action='store_true', default=False, help='Enable BCE loss w/ Mixup/CutMix use.') parser.add_argument('--bce-target-thresh', type=float, default=None, help='Threshold for binarizing softened BCE targets (default: None, disabled)') parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT', help='Random erase prob (default: 0.)') parser.add_argument('--remode', type=str, default='pixel', help='Random erase mode (default: "pixel")') parser.add_argument('--recount', type=int, default=1, help='Random erase count (default: 1)') parser.add_argument('--resplit', action='store_true', default=False, help='Do not random erase first (clean) augmentation split') parser.add_argument('--mixup', type=float, default=0.8, help='mixup alpha, mixup enabled if > 0. (default: 0.)') parser.add_argument('--cutmix', type=float, default=1.0, help='cutmix alpha, cutmix enabled if > 0. 
(default: 0.)') parser.add_argument('--cutmix-minmax', type=float, nargs='+', default=None, help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)') parser.add_argument('--mixup-prob', type=float, default=1.0, help='Probability of performing mixup or cutmix when either/both is enabled') parser.add_argument('--mixup-switch-prob', type=float, default=0.5, help='Probability of switching to cutmix when both mixup and cutmix enabled') parser.add_argument('--mixup-mode', type=str, default='batch', help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"') parser.add_argument('--mixup-off-epoch', default=0, type=int, metavar='N', help='Turn off mixup after this epoch, disabled if 0 (default: 0)') parser.add_argument('--smoothing', type=float, default=0.1, help='Label smoothing (default: 0.1)') parser.add_argument('--train-interpolation', type=str, default='random', help='Training interpolation (random, bilinear, bicubic default: "random")') parser.add_argument('--drop', type=float, default=0.0, metavar='PCT', help='Dropout rate (default: 0.)') parser.add_argument('--drop-connect', type=float, default=None, metavar='PCT', help='Drop connect rate, DEPRECATED, use drop-path (default: None)') parser.add_argument('--drop-path', type=float, default=None, metavar='PCT', help='Drop path rate (default: None)') parser.add_argument('--drop-block', type=float, default=None, metavar='PCT', help='Drop block rate (default: None)') # Batch norm parameters (only works with gen_efficientnet based models currently) parser.add_argument('--bn-momentum', type=float, default=None, help='BatchNorm momentum override (if not None)') parser.add_argument('--bn-eps', type=float, default=None, help='BatchNorm epsilon override (if not None)') parser.add_argument('--sync-bn', action='store_true', help='Enable NVIDIA Apex or Torch synchronized BatchNorm.') parser.add_argument('--dist-bn', type=str, default='reduce', help='Distribute BatchNorm stats between nodes after each epoch ("broadcast", "reduce", or "")') parser.add_argument('--split-bn', action='store_true', help='Enable separate BN layers per augmentation split.') # Model Exponential Moving Average parser.add_argument('--model-ema', action='store_true', default=False, help='Enable tracking moving average of model weights') parser.add_argument('--model-ema-force-cpu', action='store_true', default=False, help='Force ema to be tracked on CPU, rank=0 node only. 
Disables EMA validation.') parser.add_argument('--model-ema-decay', type=float, default=0.9998, help='decay factor for model weights moving average (default: 0.9998)') # Misc parser.add_argument('--seed', type=int, default=42, metavar='S', help='random seed (default: 42)') parser.add_argument('--worker-seeding', type=str, default='all', help='worker seed mode (default: all)') parser.add_argument('--log-interval', type=int, default=50, metavar='N', help='how many batches to wait before logging training status') parser.add_argument('--recovery-interval', type=int, default=0, metavar='N', help='how many batches to wait before writing recovery checkpoint') parser.add_argument('--checkpoint-hist', type=int, default=10, metavar='N', help='number of checkpoints to keep (default: 10)') parser.add_argument('-j', '--workers', type=int, default=8, metavar='N', help='how many training processes to use (default: 8)') parser.add_argument('--save-images', action='store_true', default=False, help='save images of input batches every log interval for debugging') parser.add_argument('--amp', action='store_true', default=False, help='use NVIDIA Apex AMP or Native AMP for mixed precision training') parser.add_argument('--apex-amp', action='store_true', default=False, help='Use NVIDIA Apex AMP mixed precision') parser.add_argument('--native-amp', action='store_true', default=False, help='Use Native Torch AMP mixed precision') parser.add_argument('--no-ddp-bb', action='store_true', default=False, help='Force broadcast buffers for native DDP to off.') parser.add_argument('--pin-mem', action='store_true', default=False, help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') parser.add_argument('--no-prefetcher', action='store_true', default=False, help='disable fast prefetcher') parser.add_argument('--output', default='', type=str, metavar='PATH', help='path to output folder (default: none, current dir)') parser.add_argument('--experiment', default='', type=str, metavar='NAME', help='name of train experiment, name of sub-folder for output') parser.add_argument('--eval-metric', default='top1', type=str, metavar='EVAL_METRIC', help='Best metric (default: "top1")') parser.add_argument('--tta', type=int, default=0, metavar='N', help='Test/inference time augmentation (oversampling) factor. 0=None (default: 0)') parser.add_argument("--local_rank", default=0, type=int) parser.add_argument('--use-multi-epochs-loader', action='store_true', default=False, help='use the multi-epochs-loader to save time at the beginning of every epoch') parser.add_argument('--log-wandb', action='store_true', default=False, help='log training and validation metrics to wandb') def _parse_args(): # Do we have a config file to parse? args_config, remaining = config_parser.parse_known_args() if args_config.config: with open(args_config.config, 'r') as f: cfg = yaml.safe_load(f) parser.set_defaults(**cfg) # The main arg parser parses the rest of the args, the usual # defaults will have been overridden if config file specified. args = parser.parse_args(remaining) # Cache the args as a text string to save them in the output dir later args_text = yaml.safe_dump(args.__dict__, default_flow_style=False) return args, args_text def main(): setup_default_logging() args, args_text = _parse_args() if args.log_wandb: if has_wandb: wandb.init(project=args.experiment, config=args) else: _logger.warning("You've requested to log metrics to wandb but package not found.
" "Metrics not being logged to wandb, try `pip install wandb`") args.prefetcher = not args.no_prefetcher args.distributed = False if 'WORLD_SIZE' in os.environ: args.distributed = int(os.environ['WORLD_SIZE']) > 1 args.device = 'cuda:0' args.world_size = 1 args.rank = 0 # global rank if args.distributed: args.device = 'cuda:%d' % args.local_rank torch.cuda.set_device(args.local_rank) torch.distributed.init_process_group(backend='nccl', init_method='env://') args.world_size = torch.distributed.get_world_size() args.rank = torch.distributed.get_rank() _logger.info('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.' % (args.rank, args.world_size)) else: _logger.info('Training with a single process on 1 GPUs.') assert args.rank >= 0 # resolve AMP arguments based on PyTorch / Apex availability use_amp = None if args.amp: # `--amp` chooses native amp before apex (APEX ver not actively maintained) if has_native_amp: args.native_amp = True elif has_apex: args.apex_amp = True if args.apex_amp and has_apex: use_amp = 'apex' elif args.native_amp and has_native_amp: use_amp = 'native' elif args.apex_amp or args.native_amp: _logger.warning("Neither APEX or native Torch AMP is available, using float32. " "Install NVIDA apex or upgrade to PyTorch 1.6") random_seed(args.seed, args.rank) if args.fuser: set_jit_fuser(args.fuser) model = create_model( args.model, pretrained=args.pretrained, num_classes=args.num_classes, drop_rate=args.drop, drop_connect_rate=args.drop_connect, # DEPRECATED, use drop_path drop_path_rate=args.drop_path, drop_block_rate=args.drop_block, global_pool=args.gp, bn_momentum=args.bn_momentum, bn_eps=args.bn_eps, scriptable=args.torchscript, checkpoint_path=args.initial_checkpoint, tuning_mode=args.tuning_mode) if args.num_classes is None: assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.' args.num_classes = model.num_classes # FIXME handle model default vs config num_classes more elegantly if args.grad_checkpointing: model.set_grad_checkpointing(enable=True) data_config = resolve_data_config(vars(args), model=model, verbose=args.local_rank == 0) # setup augmentation batch splits for contrastive loss or split bn num_aug_splits = 0 if args.aug_splits > 0: assert args.aug_splits > 1, 'A split of 1 makes no sense' num_aug_splits = args.aug_splits # enable split bn (separate bn stats per batch-portion) if args.split_bn: assert num_aug_splits > 1 or args.resplit model = convert_splitbn_model(model, max(num_aug_splits, 2)) # move model to GPU, enable channels last layout if set model.cuda() if args.channels_last: model = model.to(memory_format=torch.channels_last) # setup synchronized BatchNorm for distributed training if args.distributed and args.sync_bn: assert not args.split_bn if has_apex and use_amp == 'apex': # Apex SyncBN preferred unless native amp is activated model = convert_syncbn_model(model) else: model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) if args.local_rank == 0: _logger.info( 'Converted model to use Synchronized BatchNorm. 
WARNING: You may have issues if using ' 'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.') if args.torchscript: assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model' assert not args.sync_bn, 'Cannot use SyncBatchNorm with torchscripted model' model = torch.jit.script(model) optimizer = create_optimizer_v2(model, **optimizer_kwargs(cfg=args)) if args.local_rank == 0: _logger.info( f'Model {safe_model_name(args.model)} created, param count:{sum([m.numel() for m in model.parameters()])}') _logger.info(f"number of params for requires grad: {sum(p.numel() for p in model.parameters() if p.requires_grad)}") # setup automatic mixed-precision (AMP) loss scaling and op casting amp_autocast = suppress # do nothing loss_scaler = None if use_amp == 'apex': model, optimizer = amp.initialize(model, optimizer, opt_level='O1') loss_scaler = ApexScaler() if args.local_rank == 0: _logger.info('Using NVIDIA APEX AMP. Training in mixed precision.') elif use_amp == 'native': amp_autocast = torch.cuda.amp.autocast loss_scaler = NativeScaler() if args.local_rank == 0: _logger.info('Using native Torch AMP. Training in mixed precision.') else: if args.local_rank == 0: _logger.info('AMP not enabled. Training in float32.') # optionally resume from a checkpoint resume_epoch = None if args.resume: resume_epoch = resume_checkpoint( model, args.resume, optimizer=None if args.no_resume_opt else optimizer, loss_scaler=None if args.no_resume_opt else loss_scaler, log_info=args.local_rank == 0) # setup exponential moving average of model weights, SWA could be used here too model_ema = None if args.model_ema: # Important to create EMA model after cuda(), DP wrapper, and AMP but before DDP wrapper model_ema = ModelEmaV2( model, decay=args.model_ema_decay, device='cpu' if args.model_ema_force_cpu else None) if args.resume: load_checkpoint(model_ema.module, args.resume, use_ema=True) # setup distributed training if args.distributed: if has_apex and use_amp == 'apex': # Apex DDP preferred unless native amp is activated if args.local_rank == 0: _logger.info("Using NVIDIA APEX DistributedDataParallel.") model = ApexDDP(model, delay_allreduce=True) else: if args.local_rank == 0: _logger.info("Using native Torch DistributedDataParallel.") model = NativeDDP(model, device_ids=[args.local_rank], broadcast_buffers=not args.no_ddp_bb) # NOTE: EMA model does not need to be wrapped by DDP # setup learning rate schedule and starting epoch lr_scheduler, num_epochs = create_scheduler(args, optimizer) start_epoch = 0 if args.start_epoch is not None: # a specified start_epoch will always override the resume epoch start_epoch = args.start_epoch elif resume_epoch is not None: start_epoch = resume_epoch if lr_scheduler is not None and start_epoch > 0: lr_scheduler.step(start_epoch) if args.local_rank == 0: _logger.info('Scheduled epochs: {}'.format(num_epochs)) # create the train and eval datasets dataset_train = create_dataset( args.dataset, root=args.data_dir, split=args.train_split, is_training=True, class_map=args.class_map, download=args.dataset_download, batch_size=args.batch_size, repeats=args.epoch_repeats) dataset_eval = create_dataset( args.dataset, root=args.data_dir, split=args.val_split, is_training=False, class_map=args.class_map, download=args.dataset_download, batch_size=args.batch_size) # setup mixup / cutmix collate_fn = None mixup_fn = None mixup_active = args.mixup > 0 or args.cutmix > 0. 
or args.cutmix_minmax is not None if mixup_active: mixup_args = dict( mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax, prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode, label_smoothing=args.smoothing, num_classes=args.num_classes) if args.prefetcher: assert not num_aug_splits # collate conflict (need to support deinterleaving in collate mixup) collate_fn = FastCollateMixup(**mixup_args) else: mixup_fn = Mixup(**mixup_args) # wrap dataset in AugMix helper if num_aug_splits > 1: dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits) # create data loaders w/ augmentation pipeiine train_interpolation = args.train_interpolation if args.no_aug or not train_interpolation: train_interpolation = data_config['interpolation'] loader_train = create_loader( dataset_train, input_size=data_config['input_size'], batch_size=args.batch_size, is_training=True, use_prefetcher=args.prefetcher, no_aug=args.no_aug, simple_aug=args.simple_aug, re_prob=args.reprob, re_mode=args.remode, re_count=args.recount, re_split=args.resplit, scale=args.scale, ratio=args.ratio, hflip=args.hflip, vflip=args.vflip, color_jitter=args.color_jitter, auto_augment=args.aa, num_aug_repeats=args.aug_repeats, num_aug_splits=num_aug_splits, interpolation=train_interpolation, mean=data_config['mean'], std=data_config['std'], num_workers=args.workers, distributed=args.distributed, collate_fn=collate_fn, pin_memory=args.pin_mem, use_multi_epochs_loader=args.use_multi_epochs_loader, worker_seeding=args.worker_seeding, ) loader_eval = create_loader( dataset_eval, input_size=data_config['input_size'], batch_size=args.validation_batch_size or args.batch_size, is_training=False, use_prefetcher=args.prefetcher, direct_resize=args.direct_resize, interpolation=data_config['interpolation'], mean=data_config['mean'], std=data_config['std'], num_workers=args.workers, distributed=args.distributed, crop_pct=data_config['crop_pct'], pin_memory=args.pin_mem, ) # setup loss function if args.jsd_loss: assert num_aug_splits > 1 # JSD only valid with aug splits set train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits, smoothing=args.smoothing) elif mixup_active: # smoothing is handled with mixup target transform which outputs sparse, soft targets if args.bce_loss: train_loss_fn = BinaryCrossEntropy(target_threshold=args.bce_target_thresh) else: train_loss_fn = SoftTargetCrossEntropy() elif args.smoothing: if args.bce_loss: train_loss_fn = BinaryCrossEntropy(smoothing=args.smoothing, target_threshold=args.bce_target_thresh) else: train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing) else: train_loss_fn = nn.CrossEntropyLoss() train_loss_fn = train_loss_fn.cuda() validate_loss_fn = nn.CrossEntropyLoss().cuda() # setup checkpoint saver and eval metric tracking eval_metric = args.eval_metric best_metric = None best_epoch = None saver = None output_dir = None if args.rank == 0: if args.experiment: exp_name = args.experiment else: exp_name = '-'.join([ datetime.now().strftime("%Y%m%d-%H%M%S"), safe_model_name(args.model), str(data_config['input_size'][-1]) ]) output_dir = get_outdir(args.output if args.output else './output/train', exp_name) decreasing = True if eval_metric == 'loss' else False saver = CheckpointSaver( model=model, optimizer=optimizer, args=args, model_ema=model_ema, amp_scaler=loss_scaler, checkpoint_dir=output_dir, recovery_dir=output_dir, decreasing=decreasing, max_history=args.checkpoint_hist) with open(os.path.join(output_dir, 'args.yaml'), 
'w') as f: f.write(args_text) if args.evaluate: if args.distributed and args.dist_bn in ('broadcast', 'reduce'): if args.local_rank == 0: _logger.info("Distributing BatchNorm running means and vars") distribute_bn(model, args.world_size, args.dist_bn == 'reduce') eval_metrics = validate(model, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast) if model_ema is not None and not args.model_ema_force_cpu: if args.distributed and args.dist_bn in ('broadcast', 'reduce'): distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce') ema_eval_metrics = validate( model_ema.module, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast, log_suffix=' (EMA)') eval_metrics = ema_eval_metrics if saver is not None: # save proper checkpoint with eval metric save_metric = eval_metrics[eval_metric] best_metric, best_epoch = saver.save_checkpoint(start_epoch, metric=save_metric) return try: for epoch in range(start_epoch, num_epochs): if args.distributed and hasattr(loader_train.sampler, 'set_epoch'): loader_train.sampler.set_epoch(epoch) train_metrics = train_one_epoch( epoch, model, loader_train, optimizer, train_loss_fn, args, lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir, amp_autocast=amp_autocast, loss_scaler=loss_scaler, model_ema=model_ema, mixup_fn=mixup_fn) if args.distributed and args.dist_bn in ('broadcast', 'reduce'): if args.local_rank == 0: _logger.info("Distributing BatchNorm running means and vars") distribute_bn(model, args.world_size, args.dist_bn == 'reduce') eval_metrics = validate(model, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast) if model_ema is not None and not args.model_ema_force_cpu: if args.distributed and args.dist_bn in ('broadcast', 'reduce'): distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce') ema_eval_metrics = validate( model_ema.module, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast, log_suffix=' (EMA)') eval_metrics = ema_eval_metrics if lr_scheduler is not None: # step LR for next epoch lr_scheduler.step(epoch + 1, eval_metrics[eval_metric]) if output_dir is not None: update_summary( epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'), write_header=best_metric is None, log_wandb=args.log_wandb and has_wandb) if saver is not None: # save proper checkpoint with eval metric save_metric = eval_metrics[eval_metric] best_metric, best_epoch = saver.save_checkpoint(epoch, metric=save_metric) except KeyboardInterrupt: pass if best_metric is not None: _logger.info('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch)) def train_one_epoch( epoch, model, loader, optimizer, loss_fn, args, lr_scheduler=None, saver=None, output_dir=None, amp_autocast=suppress, loss_scaler=None, model_ema=None, mixup_fn=None): if args.mixup_off_epoch and epoch >= args.mixup_off_epoch: if args.prefetcher and loader.mixup_enabled: loader.mixup_enabled = False elif mixup_fn is not None: mixup_fn.mixup_enabled = False second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order batch_time_m = AverageMeter() data_time_m = AverageMeter() losses_m = AverageMeter() model.train() end = time.time() last_idx = len(loader) - 1 num_updates = epoch * len(loader) for batch_idx, (input, target) in enumerate(loader): last_batch = batch_idx == last_idx data_time_m.update(time.time() - end) if not args.prefetcher: input, target = input.cuda(), target.cuda() if mixup_fn is not None: input, target = mixup_fn(input, target) if args.channels_last: input = 
input.contiguous(memory_format=torch.channels_last) with amp_autocast(): output = model(input) loss = loss_fn(output, target) if not args.distributed: losses_m.update(loss.item(), input.size(0)) optimizer.zero_grad() if loss_scaler is not None: loss_scaler( loss, optimizer, clip_grad=args.clip_grad, clip_mode=args.clip_mode, parameters=model_parameters(model, exclude_head='agc' in args.clip_mode), create_graph=second_order) else: loss.backward(create_graph=second_order) if args.clip_grad is not None: dispatch_clip_grad( model_parameters(model, exclude_head='agc' in args.clip_mode), value=args.clip_grad, mode=args.clip_mode) optimizer.step() if model_ema is not None: model_ema.update(model) torch.cuda.synchronize() num_updates += 1 batch_time_m.update(time.time() - end) if last_batch or batch_idx % args.log_interval == 0: lrl = [param_group['lr'] for param_group in optimizer.param_groups] lr = sum(lrl) / len(lrl) if args.distributed: reduced_loss = reduce_tensor(loss.data, args.world_size) losses_m.update(reduced_loss.item(), input.size(0)) if args.local_rank == 0: _logger.info( 'Train: {} [{:>4d}/{} ({:>3.0f}%)] ' 'Loss: {loss.val:#.4g} ({loss.avg:#.3g}) ' 'Time: {batch_time.val:.3f}s, {rate:>7.2f}/s ' '({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) ' 'LR: {lr:.3e} ' 'Data: {data_time.val:.3f} ({data_time.avg:.3f})'.format( epoch, batch_idx, len(loader), 100. * batch_idx / last_idx, loss=losses_m, batch_time=batch_time_m, rate=input.size(0) * args.world_size / batch_time_m.val, rate_avg=input.size(0) * args.world_size / batch_time_m.avg, lr=lr, data_time=data_time_m)) if args.save_images and output_dir: torchvision.utils.save_image( input, os.path.join(output_dir, 'train-batch-%d.jpg' % batch_idx), padding=0, normalize=True) if saver is not None and args.recovery_interval and ( last_batch or (batch_idx + 1) % args.recovery_interval == 0): saver.save_recovery(epoch, batch_idx=batch_idx) if lr_scheduler is not None: lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg) end = time.time() # end for if hasattr(optimizer, 'sync_lookahead'): optimizer.sync_lookahead() return OrderedDict([('loss', losses_m.avg)]) def validate(model, loader, loss_fn, args, amp_autocast=suppress, log_suffix=''): batch_time_m = AverageMeter() losses_m = AverageMeter() top1_m = AverageMeter() top5_m = AverageMeter() model.eval() end = time.time() last_idx = len(loader) - 1 with torch.no_grad(): for batch_idx, (input, target) in enumerate(loader): last_batch = batch_idx == last_idx if not args.prefetcher: input = input.cuda() target = target.cuda() if args.channels_last: input = input.contiguous(memory_format=torch.channels_last) with amp_autocast(): output = model(input) if isinstance(output, (tuple, list)): output = output[0] # augmentation reduction reduce_factor = args.tta if reduce_factor > 1: output = output.unfold(0, reduce_factor, reduce_factor).mean(dim=2) target = target[0:target.size(0):reduce_factor] loss = loss_fn(output, target) acc1, acc5 = accuracy(output, target, topk=(1, 5)) if args.distributed: reduced_loss = reduce_tensor(loss.data, args.world_size) acc1 = reduce_tensor(acc1, args.world_size) acc5 = reduce_tensor(acc5, args.world_size) else: reduced_loss = loss.data torch.cuda.synchronize() losses_m.update(reduced_loss.item(), input.size(0)) top1_m.update(acc1.item(), output.size(0)) top5_m.update(acc5.item(), output.size(0)) batch_time_m.update(time.time() - end) end = time.time() if args.local_rank == 0 and (last_batch or batch_idx % args.log_interval == 0): log_name = 'Test' + 
log_suffix _logger.info( '{0}: [{1:>4d}/{2}] ' 'Time: {batch_time.val:.3f} ({batch_time.avg:.3f}) ' 'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) ' 'Acc@1: {top1.val:>7.4f} ({top1.avg:>7.4f}) ' 'Acc@5: {top5.val:>7.4f} ({top5.avg:>7.4f})'.format( log_name, batch_idx, last_idx, batch_time=batch_time_m, loss=losses_m, top1=top1_m, top5=top5_m)) metrics = OrderedDict([('loss', losses_m.avg), ('top1', top1_m.avg), ('top5', top5_m.avg)]) return metrics if __name__ == '__main__': main() ================================================ FILE: train_scripts/asmlp/cifar_100/train_full.sh ================================================ CUDA_VISIBLE_DEVICES=0,1,2,3, python -m torch.distributed.launch --nproc_per_node=4 --master_port=33518 \ train.py /path/to/cifar100 --dataset torch/cifar100 --num-classes 100 --model as_base_patch4_window7_224 \ --batch-size 32 --epochs 100 \ --opt adamw --weight-decay 0.05 \ --warmup-lr 5e-7 --warmup-epochs 10 \ --lr 5e-5 --min-lr 5e-8 \ --drop-path 0 --img-size 224 \ --output output/as_base_patch4_window7_224/cifar_100/full \ --amp --pretrained \ ================================================ FILE: train_scripts/asmlp/cifar_100/train_linear_probe.sh ================================================ CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --master_port=33518 \ train.py /path/to/cifar100 --dataset torch/cifar100 --num-classes 100 --model as_base_patch4_window7_224 \ --batch-size 32 --epochs 100 \ --opt adamw --weight-decay 0.05 \ --warmup-lr 5e-7 --warmup-epochs 10 \ --lr 1e-3 --min-lr 5e-8 \ --drop-path 0 --img-size 224 \ --output output/as_base_patch4_window7_224/cifar_100/linear_probe \ --amp --tuning-mode linear_probe --pretrained \ ================================================ FILE: train_scripts/asmlp/cifar_100/train_ssf.sh ================================================ CUDA_VISIBLE_DEVICES=0,1,2,3, python -m torch.distributed.launch --nproc_per_node=4 --master_port=33518 \ train.py /path/to/cifar100 --dataset torch/cifar100 --num-classes 100 --model as_base_patch4_window7_224 \ --batch-size 32 --epochs 100 \ --opt adamw --weight-decay 0.05 \ --warmup-lr 1e-7 --warmup-epochs 10 \ --lr 1e-3 --min-lr 1e-8 \ --drop-path 0 --img-size 224 \ --output output/as_base_patch4_window7_224/cifar_100/ssf \ --amp --tuning-mode ssf --pretrained \ ================================================ FILE: train_scripts/convnext/cifar_100/train_full.sh ================================================ CUDA_VISIBLE_DEVICES=0,1,2,3, python -m torch.distributed.launch --nproc_per_node=4 --master_port=33518 \ train.py /path/to/cifar100 --dataset torch/cifar100 --num-classes 100 --model convnext_base_in22k \ --batch-size 32 --epochs 100 \ --opt adamw --weight-decay 0.05 \ --warmup-lr 5e-7 --warmup-epochs 10 \ --lr 5e-5 --min-lr 5e-8 \ --drop-path 0.2 --img-size 224 \ --model-ema --model-ema-decay 0.99992 \ --output output/convnext_base_in22k/cifar_100/full \ --amp --pretrained \ ================================================ FILE: train_scripts/convnext/cifar_100/train_linear_probe.sh ================================================ CUDA_VISIBLE_DEVICES=0,1,2,3, python -m torch.distributed.launch --nproc_per_node=4 --master_port=33518 \ train.py /path/to/cifar100 --dataset torch/cifar100 --num-classes 100 --model convnext_base_in22k \ --batch-size 32 --epochs 100 \ --opt adamw --weight-decay 0.05 \ --warmup-lr 5e-7 --warmup-epochs 10 \ --lr 1e-3 --min-lr 5e-8 \ --drop-path 0 --img-size 224 \ --model-ema --model-ema-decay 0.99992 \ --output 
output/convnext_base_in22k/cifar_100/linear_probe \ --amp --tuning-mode linear_probe --pretrained \ ================================================ FILE: train_scripts/convnext/cifar_100/train_ssf.sh ================================================ CUDA_VISIBLE_DEVICES=0,1,2,3, python -m torch.distributed.launch --nproc_per_node=4 --master_port=27524 \ train.py /path/to/cifar100 --dataset torch/cifar100 --num-classes 100 --model convnext_base_in22k \ --batch-size 32 --epochs 100 \ --opt adamw --weight-decay 0.05 \ --warmup-lr 1e-7 --warmup-epochs 10 \ --lr 1e-3 --min-lr 1e-8 \ --drop-path 0 --img-size 224 \ --model-ema --model-ema-decay 0.99992 \ --output output/convnext_base_in22k/cifar_100/ssf \ --amp --tuning-mode ssf --pretrained \ ================================================ FILE: train_scripts/convnext/imagenet_1k/train_full.sh ================================================ CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7, python -m torch.distributed.launch --nproc_per_node=8 --master_port=33518 \ train.py /path/to/imagenet_1k --dataset imagenet --num-classes 1000 --model convnext_base_in22k \ --batch-size 32 --epochs 30 \ --opt adamw --weight-decay 0.05 \ --warmup-lr 5e-7 --warmup-epochs 5 \ --lr 5e-5 --min-lr 5e-8 \ --drop-path 0.2 --img-size 224 \ --model-ema --model-ema-decay 0.99992 \ --output output/convnext_base_in22k/imagenet_1k/full \ --amp --pretrained \ ================================================ FILE: train_scripts/convnext/imagenet_1k/train_linear_probe.sh ================================================ CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7, python -m torch.distributed.launch --nproc_per_node=8 --master_port=33518 \ train.py /path/to/imagenet_1k --dataset imagenet --num-classes 1000 --model convnext_base_in22k \ --batch-size 32 --epochs 30 \ --opt adamw --weight-decay 0.05 \ --warmup-lr 5e-7 --warmup-epochs 5 \ --lr 1e-3 --min-lr 5e-8 \ --drop-path 0.1 --img-size 224 \ --model-ema --model-ema-decay 0.99992 \ --output output/convnext_base_in22k/imagenet_1k/linear_probe \ --amp --tuning-mode linear_probe --pretrained ================================================ FILE: train_scripts/convnext/imagenet_1k/train_ssf.sh ================================================ CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7, python -m torch.distributed.launch --nproc_per_node=8 --master_port=33518 \ train.py /path/to/imagenet_1k --dataset imagenet --num-classes 1000 --model convnext_base_in22k \ --batch-size 32 --epochs 30 \ --opt adamw --weight-decay 0.05 \ --warmup-lr 1e-7 --warmup-epochs 5 \ --lr 1e-3 --min-lr 1e-8 \ --drop-path 0.1 --img-size 224 \ --model-ema --model-ema-decay 0.99992 \ --output output/convnext_base_in22k/imagenet_1k/ssf \ --amp --tuning-mode ssf --pretrained ================================================ FILE: train_scripts/swin/cifar_100/train_full.sh ================================================ CUDA_VISIBLE_DEVICES=0,1,2,3, python -m torch.distributed.launch --nproc_per_node=4 --master_port=33518 \ train.py /path/to/cifar100 --dataset torch/cifar100 --num-classes 100 --model swin_base_patch4_window7_224_in22k \ --batch-size 32 --epochs 100 \ --opt adamw --weight-decay 0.05 \ --warmup-lr 5e-7 --warmup-epochs 10 \ --lr 5e-5 --min-lr 5e-8 \ --drop-path 0.1 --img-size 224 \ --output output/swin_base_patch4_window7_224_in22k/cifar_100/full \ --amp --pretrained \ ================================================ FILE: train_scripts/swin/cifar_100/train_linear_probe.sh ================================================ CUDA_VISIBLE_DEVICES=0,1,2,3, python -m 
torch.distributed.launch --nproc_per_node=4 --master_port=33518 \ train.py /path/to/cifar100 --dataset torch/cifar100 --num-classes 100 --model swin_base_patch4_window7_224_in22k \ --batch-size 32 --epochs 100 \ --opt adamw --weight-decay 0.05 \ --warmup-lr 5e-7 --warmup-epochs 10 \ --lr 1e-3 --min-lr 5e-8 \ --drop-path 0 --img-size 224 \ --output output/swin_base_patch4_window7_224_in22k/cifar_100/linear_probe \ --amp --tuning-mode linear_probe --pretrained \ ================================================ FILE: train_scripts/swin/cifar_100/train_ssf.sh ================================================ CUDA_VISIBLE_DEVICES=0,1,2,3, python -m torch.distributed.launch --nproc_per_node=4 --master_port=33518 \ train.py /path/to/cifar100 --dataset torch/cifar100 --num-classes 100 --model swin_base_patch4_window7_224_in22k \ --batch-size 32 --epochs 100 \ --opt adamw --weight-decay 0.05 \ --warmup-lr 5e-7 --warmup-epochs 10 \ --lr 5e-3 --min-lr 5e-8 \ --drop-path 0 --img-size 224 \ --output output/swin_base_patch4_window7_224_in22k/cifar_100/ssf \ --amp --tuning-mode ssf --pretrained \ ================================================ FILE: train_scripts/swin/imagenet_1k/train_full.sh ================================================ CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7, python -m torch.distributed.launch --nproc_per_node=8 --master_port=33518 \ train.py /path/to/imagenet_1k --dataset imagenet --num-classes 1000 --model swin_base_patch4_window7_224_in22k \ --batch-size 32 --epochs 30 \ --opt adamw --weight-decay 0.05 \ --warmup-lr 5e-7 --warmup-epochs 5 \ --lr 5e-5 --min-lr 5e-8 \ --drop-path 0.2 --img-size 224 \ --output output/swin_base_patch4_window7_224_in22k/imagenet_1k/full \ --amp --pretrained \ ================================================ FILE: train_scripts/swin/imagenet_1k/train_linear_probe.sh ================================================ CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7, python -m torch.distributed.launch --nproc_per_node=8 --master_port=33518 \ train.py /path/to/imagenet_1k --dataset imagenet --num-classes 1000 --model swin_base_patch4_window7_224_in22k \ --batch-size 32 --epochs 30 \ --opt adamw --weight-decay 0.05 \ --warmup-lr 5e-7 --warmup-epochs 5 \ --lr 5e-3 --min-lr 5e-8 \ --drop-path 0.1 --img-size 224 \ --output output/swin_base_patch4_window7_224_in22k/imagenet_1k/linear_probe \ --amp --tuning-mode linear_probe --pretrained ================================================ FILE: train_scripts/swin/imagenet_1k/train_ssf.sh ================================================ CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7, python -m torch.distributed.launch --nproc_per_node=8 --master_port=33518 \ train.py /path/to/imagenet_1k --dataset imagenet --num-classes 1000 --model swin_base_patch4_window7_224_in22k \ --batch-size 32 --epochs 30 \ --opt adamw --weight-decay 0.05 \ --warmup-lr 5e-7 --warmup-epochs 5 \ --lr 5e-3 --min-lr 5e-8 \ --drop-path 0.1 --img-size 224 \ --output output/swin_base_patch4_window7_224_in22k/imagenet_1k/ssf \ --amp --tuning-mode ssf --pretrained ================================================ FILE: train_scripts/vit/cifar_100/eval_ssf.sh ================================================ CUDA_VISIBLE_DEVICES=0, python -m torch.distributed.launch --nproc_per_node=1 --master_port=17346 \ train.py /path/to/cifar100 --dataset torch/cifar100 --num-classes 100 --model vit_base_patch16_224_in21k \ --batch-size 32 --epochs 100 \ --opt adamw --weight-decay 0.05 \ --warmup-lr 1e-7 --warmup-epochs 10 \ --lr 1e-3 --min-lr 1e-8 \ --drop-path 0 --img-size 224 \ 
--model-ema --model-ema-decay 0.99992 \ --output output/vit_base_patch16_224_in21k/cifar_100/ssf/eval \ --amp --tuning-mode ssf --pretrained \ --evaluate \ --checkpoint /path/to/vit_base_patch16_224_in21k/cifar_100/ssf/model_best.pth.tar \ ================================================ FILE: train_scripts/vit/cifar_100/train_full.sh ================================================ CUDA_VISIBLE_DEVICES=0,1,2,3, python -m torch.distributed.launch --nproc_per_node=4 --master_port=12346 \ train.py /path/to/cifar100 --dataset torch/cifar100 --num-classes 100 --model vit_base_patch16_224_in21k \ --batch-size 32 --epochs 100 \ --opt adamw --weight-decay 0.05 \ --warmup-lr 1e-7 --warmup-epochs 10 \ --lr 5e-5 --min-lr 1e-8 \ --drop-path 0 --img-size 224 \ --model-ema --model-ema-decay 0.99992 \ --output output/vit_base_patch16_224_in21k/cifar_100/full \ --amp --pretrained \ ================================================ FILE: train_scripts/vit/cifar_100/train_linear_probe.sh ================================================ CUDA_VISIBLE_DEVICES=0,1,2,3, python -m torch.distributed.launch --nproc_per_node=4 --master_port=12346 \ train.py /path/to/cifar100 --dataset torch/cifar100 --num-classes 100 --model vit_base_patch16_224_in21k \ --batch-size 32 --epochs 100 \ --opt adamw --weight-decay 0.05 \ --warmup-lr 1e-7 --warmup-epochs 10 \ --lr 1e-3 --min-lr 1e-8 \ --drop-path 0 --img-size 224 \ --model-ema --model-ema-decay 0.99992 \ --output output/vit_base_patch16_224_in21k/cifar_100/linear_probe \ --amp --tuning-mode linear_probe --pretrained \ ================================================ FILE: train_scripts/vit/cifar_100/train_ssf.sh ================================================ CUDA_VISIBLE_DEVICES=0,1,2,3, python -m torch.distributed.launch --nproc_per_node=4 --master_port=12346 \ train.py /path/to/cifar100 --dataset torch/cifar100 --num-classes 100 --model vit_base_patch16_224_in21k \ --batch-size 32 --epochs 100 \ --opt adamw --weight-decay 0.05 \ --warmup-lr 1e-7 --warmup-epochs 10 \ --lr 1e-3 --min-lr 1e-8 \ --drop-path 0 --img-size 224 \ --model-ema --model-ema-decay 0.99992 \ --output output/vit_base_patch16_224_in21k/cifar_100/ssf \ --amp --tuning-mode ssf --pretrained \ ================================================ FILE: train_scripts/vit/fgvc/cub_2011/train_ssf.sh ================================================ CUDA_VISIBLE_DEVICES=0,1, python -m torch.distributed.launch --nproc_per_node=2 --master_port=14655 \ train.py /path/to/CUB_200_2011 --dataset cub2011 --num-classes 200 --simple-aug --model vit_base_patch16_224_in21k \ --batch-size 32 --epochs 100 \ --opt adamw --weight-decay 5e-2 \ --warmup-lr 1e-7 --warmup-epochs 10 \ --lr 1e-2 --min-lr 1e-8 \ --drop-path 0 --img-size 224 \ --model-ema --model-ema-decay 0.9998 \ --output output/vit_base_patch16_224_in21k/fgvc/cub2011/ssf \ --amp --tuning-mode ssf --pretrained ================================================ FILE: train_scripts/vit/fgvc/nabirds/train_ssf.sh ================================================ CUDA_VISIBLE_DEVICES=0,1, python -m torch.distributed.launch --nproc_per_node=2 --master_port=14222 \ train.py /path/to/nabirds --dataset nabirds --num-classes 555 --simple-aug --model vit_base_patch16_224_in21k \ --batch-size 32 --epochs 100 \ --opt adamw --weight-decay 0.05 \ --warmup-lr 1e-7 --warmup-epochs 10 \ --lr 2e-4 --min-lr 1e-8 \ --drop-path 0.1 --img-size 224 \ --model-ema --model-ema-decay 0.9998 \ --output output/vit_base_patch16_224_in21k/nabirds/ssf \ --amp --tuning-mode ssf --pretrained 
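The `--tuning-mode ssf` flag used throughout these scripts switches the model to scale-and-shift tuning: the pre-trained backbone weights stay frozen and only lightweight per-channel scale and shift parameters (plus the classification head) are trained. The repo's actual implementation lives in the model files (e.g. `models/vision_transformer.py`); the snippet below is only a minimal, illustrative sketch of that per-channel modulation, with hypothetical names rather than the repository's API:

```python
import torch
import torch.nn as nn

class ScaleShift(nn.Module):
    """Illustrative per-channel scale-and-shift (x * scale + shift) over the last dim."""
    def __init__(self, dim):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(dim))   # starts as an identity scaling
        self.shift = nn.Parameter(torch.zeros(dim))  # starts as a zero shift

    def forward(self, x):
        # x: (..., dim); broadcasting applies the same per-channel affine to every token
        return x * self.scale + self.shift

# e.g. modulating frozen ViT-B/16 features (hidden size 768, 196 patch tokens + CLS)
features = torch.randn(4, 197, 768)
print(ScaleShift(768)(features).shape)  # torch.Size([4, 197, 768])
```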
================================================ FILE: train_scripts/vit/fgvc/oxford_flowers/train_ssf.sh ================================================ CUDA_VISIBLE_DEVICES=0,1, python -m torch.distributed.launch --nproc_per_node=2 --master_port=12341 \ train.py /path/to/oxford_flowers --dataset oxford_flowers --num-classes 102 --val-split val --simple-aug --model vit_base_patch16_224_in21k \ --batch-size 32 --epochs 100 \ --opt adamw --weight-decay 0.05 \ --warmup-lr 1e-7 --warmup-epochs 10 \ --lr 1e-2 --min-lr 1e-8 \ --drop-path 0.1 --img-size 224 \ --model-ema --model-ema-decay 0.999 \ --output output/vit_base_patch16_224_in21k/oxford_flowers/ssf \ --amp --tuning-mode ssf --pretrained ================================================ FILE: train_scripts/vit/fgvc/stanford_cars/train_ssf.sh ================================================ CUDA_VISIBLE_DEVICES=0,1, python -m torch.distributed.launch --nproc_per_node=2 --master_port=12349 \ train.py /path/to/stanford_cars --dataset stanford_cars --num-classes 196 --val-split val --simple-aug --model vit_base_patch16_224_in21k \ --batch-size 32 --epochs 100 \ --opt adamw --weight-decay 0.05 \ --warmup-lr 1e-7 --warmup-epochs 10 \ --lr 2e-2 --min-lr 1e-8 \ --drop-path 0 --img-size 224 \ --model-ema --model-ema-decay 0.9998 \ --output output/vit_base_patch16_224_in21k/stanford_cars/ssf \ --amp --tuning-mode ssf --pretrained ================================================ FILE: train_scripts/vit/fgvc/stanford_dogs/train_ssf.sh ================================================ CUDA_VISIBLE_DEVICES=0,1, python -m torch.distributed.launch --nproc_per_node=2 --master_port=12319 \ train.py /path/to/stanford_dogs --dataset stanford_dogs --num-classes 120 --simple-aug --model vit_base_patch16_224_in21k \ --batch-size 32 --epochs 100 \ --opt adamw --weight-decay 0.05 \ --warmup-lr 1e-7 --warmup-epochs 10 \ --lr 2.5e-4 --min-lr 1e-8 \ --drop-path 0 --img-size 224 \ --model-ema --model-ema-decay 0.9998 \ --output output/vit_base_patch16_224_in21k/stanford_dogs/ssf \ --amp --tuning-mode ssf --pretrained \ ================================================ FILE: train_scripts/vit/imagenet_1k/train_full.sh ================================================ CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7, python -m torch.distributed.launch --nproc_per_node=8 --master_port=33518 \ train.py /path/to/imagenet_1k --dataset imagenet --num-classes 1000 --model vit_base_patch16_224_in21k \ --batch-size 32 --epochs 30 \ --opt adamw --weight-decay 0.05 \ --warmup-lr 1e-7 --warmup-epochs 5 \ --lr 1e-4 --min-lr 1e-8 \ --drop-path 0.2 --img-size 224 \ --model-ema --model-ema-decay 0.99992 \ --output output/vit_base_patch16_224_in21k/imagenet_1k/full \ --amp --pretrained \ ================================================ FILE: train_scripts/vit/imagenet_1k/train_linear_probe.sh ================================================ CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7, python -m torch.distributed.launch --nproc_per_node=8 --master_port=33518 \ train.py /path/to/imagenet_1k --dataset imagenet --num-classes 1000 --model vit_base_patch16_224_in21k \ --batch-size 32 --epochs 30 \ --opt adamw --weight-decay 0.05 \ --warmup-lr 1e-7 --warmup-epochs 5 \ --lr 1e-4 --min-lr 1e-8 \ --drop-path 0.1 --img-size 224 \ --model-ema --model-ema-decay 0.99992 \ --output output/vit_base_patch16_224_in21k/imagenet_1k/linear_probe \ --amp --tuning-mode linear_probe --pretrained ================================================ FILE: train_scripts/vit/imagenet_1k/train_ssf.sh 
================================================ CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7, python -m torch.distributed.launch --nproc_per_node=8 --master_port=33518 \ train.py /path/to/imagenet_1k --dataset imagenet --num-classes 1000 --model vit_base_patch16_224_in21k \ --batch-size 32 --epochs 30 \ --opt adamw --weight-decay 0.05 \ --warmup-lr 1e-7 --warmup-epochs 5 \ --lr 1e-3 --min-lr 1e-8 \ --drop-path 0.1 --img-size 224 \ --model-ema --model-ema-decay 0.99992 \ --output output/vit_base_patch16_224_in21k/imagenet_1k/ssf \ --amp --tuning-mode ssf --pretrained ================================================ FILE: train_scripts/vit/imagenet_a/eval_ssf.sh ================================================ CUDA_VISIBLE_DEVICES=0, python validate_ood.py \ /path/to/imagenet-a \ --num-classes 1000 \ --model vit_base_patch16_224_in21k \ --batch-size 64 \ --no-test-pool --imagenet_a \ --results-file output/vit_base_patch16_224_in21k/imagenet_a/ssf \ --tuning-mode ssf \ --checkpoint /path/to/vit_base_patch16_224_in21k/imagenet_1k/ssf/model_best.pth.tar ================================================ FILE: train_scripts/vit/imagenet_c/eval_ssf.sh ================================================ CUDA_VISIBLE_DEVICES=0, python validate_ood.py \ /path/to/imagenet-c/ \ --num-classes 1000 \ --model vit_base_patch16_224_in21k \ --batch-size 64 \ --imagenet_c \ --results-file output/vit_base_patch16_224_in21k/imagenet_c/ssf \ --tuning-mode ssf \ --checkpoint /path/to/vit_base_patch16_224_in21k/imagenet_1k/ssf/model_best.pth.tar ================================================ FILE: train_scripts/vit/imagenet_r/eval_ssf.sh ================================================ CUDA_VISIBLE_DEVICES=0, python validate_ood.py \ /path/to/imagenet-r \ --num-classes 1000 \ --model vit_base_patch16_224_in21k \ --batch-size 64 \ --imagenet_r \ --results-file output/vit_base_patch16_224_in21k/imagenet_r/ssf \ --tuning-mode ssf \ --checkpoint /path/to/vit_base_patch16_224_in21k/imagenet_1k/ssf/model_best.pth.tar ================================================ FILE: train_scripts/vit/vtab/caltech101/train_ssf.sh ================================================ CUDA_VISIBLE_DEVICES=0,1, python -m torch.distributed.launch --nproc_per_node=2 --master_port=14337 \ train.py /path/to/vtab-1k/caltech101 --dataset caltech101 --num-classes 102 --no-aug --direct-resize --model vit_base_patch16_224_in21k \ --batch-size 32 --epochs 100 \ --opt adamw --weight-decay 5e-2 \ --warmup-lr 1e-7 --warmup-epochs 10 \ --lr 1e-3 --min-lr 1e-8 \ --drop-path 0.1 --img-size 224 \ --mixup 0 --cutmix 0 --smoothing 0 \ --output output/vit_base_patch16_224_in21k/vtab/caltech101/ssf \ --amp --tuning-mode ssf --pretrained \ ================================================ FILE: train_scripts/vit/vtab/cifar_100/train_ssf.sh ================================================ CUDA_VISIBLE_DEVICES=0,1, python -m torch.distributed.launch --nproc_per_node=2 --master_port=19547 \ train.py /path/to/vtab-1k/cifar --dataset cifar100 --num-classes 100 --no-aug --direct-resize --model vit_base_patch16_224_in21k \ --batch-size 32 --epochs 100 \ --opt adamw --weight-decay 5e-5 \ --warmup-lr 1e-7 --warmup-epochs 10 \ --lr 5e-3 --min-lr 1e-8 \ --drop-path 0 --img-size 224 \ --mixup 0 --cutmix 0 --smoothing 0 \ --output output/vit_base_patch16_224_in21k/vtab/cifar_100/ssf \ --amp --tuning-mode ssf --pretrained \ ================================================ FILE: train_scripts/vit/vtab/clevr_count/train_ssf.sh ================================================ 
CUDA_VISIBLE_DEVICES=0,1, python -m torch.distributed.launch --nproc_per_node=2 --master_port=14332 \ train.py /path/to/vtab-1k/clevr_count --dataset clevr_count --num-classes 8 --no-aug --direct-resize --model vit_base_patch16_224_in21k \ --batch-size 32 --epochs 100 \ --opt adamw --weight-decay 5e-2 \ --warmup-lr 1e-7 --warmup-epochs 10 \ --lr 2e-3 --min-lr 1e-8 \ --drop-path 0.1 --img-size 224 \ --mixup 0 --cutmix 0 --smoothing 0 \ --output output/vit_base_patch16_224_in21k/vtab/clevr_count/ssf \ --amp --tuning-mode ssf --pretrained \ ================================================ FILE: train_scripts/vit/vtab/clevr_dist/train_ssf.sh ================================================ CUDA_VISIBLE_DEVICES=0,1, python -m torch.distributed.launch --nproc_per_node=2 --master_port=10032 \ train.py /path/to/vtab-1k/clevr_dist --dataset clevr_dist --num-classes 6 --no-aug --direct-resize --model vit_base_patch16_224_in21k \ --batch-size 32 --epochs 100 \ --opt adamw --weight-decay 5e-2 \ --warmup-lr 1e-7 --warmup-epochs 10 \ --lr 5e-2 --min-lr 1e-8 \ --drop-path 0.1 --img-size 224 \ --mixup 0 --cutmix 0 --smoothing 0 \ --output output/vit_base_patch16_224_in21k/vtab/clevr_dist/ssf \ --amp --tuning-mode ssf --pretrained \ ================================================ FILE: train_scripts/vit/vtab/diabetic_retinopathy/train_ssf.sh ================================================ CUDA_VISIBLE_DEVICES=0,1, python -m torch.distributed.launch --nproc_per_node=2 --master_port=26662 \ train.py /path/to/vtab-1k/diabetic_retinopathy --dataset diabetic_retinopathy --num-classes 5 --no-aug --direct-resize --model vit_base_patch16_224_in21k \ --batch-size 32 --epochs 100 \ --opt adamw --weight-decay 5e-5 \ --warmup-lr 1e-7 --warmup-epochs 10 \ --lr 5e-3 --min-lr 1e-8 \ --drop-path 0.2 --img-size 224 \ --mixup 0 --cutmix 0 --smoothing 0 \ --output output/vit_base_patch16_224_in21k/vtab/diabetic_retinopathy/ssf \ --amp --tuning-mode ssf --pretrained ================================================ FILE: train_scripts/vit/vtab/dmlab/train_ssf.sh ================================================ CUDA_VISIBLE_DEVICES=0,1, python -m torch.distributed.launch --nproc_per_node=2 --master_port=13002 \ train.py /path/to/vtab-1k/dmlab --dataset dmlab --num-classes 6 --no-aug --direct-resize --model vit_base_patch16_224_in21k \ --batch-size 32 --epochs 100 \ --opt adamw --weight-decay 5e-5 \ --warmup-lr 1e-7 --warmup-epochs 10 \ --lr 5e-3 --min-lr 1e-8 \ --drop-path 0.1 --img-size 224 \ --mixup 0 --cutmix 0 --smoothing 0 \ --output output/vit_base_patch16_224_in21k/vtab/dmlab/ssf \ --amp --tuning-mode ssf --pretrained ================================================ FILE: train_scripts/vit/vtab/dsprites_loc/train_ssf.sh ================================================ CUDA_VISIBLE_DEVICES=0,1, python -m torch.distributed.launch --nproc_per_node=2 --master_port=12102 \ train.py /path/to/vtab-1k/dsprites_loc --dataset dsprites_loc --num-classes 16 --no-aug --direct-resize --model vit_base_patch16_224_in21k \ --batch-size 32 --epochs 100 \ --opt adamw --weight-decay 5e-5 \ --warmup-lr 1e-7 --warmup-epochs 10 \ --lr 1e-2 --min-lr 1e-8 \ --drop-path 0 --img-size 224 \ --mixup 0 --cutmix 0 --smoothing 0 \ --output output/vit_base_patch16_224_in21k/vtab/dsprites_loc/ssf \ --amp --tuning-mode ssf --pretrained \ ================================================ FILE: train_scripts/vit/vtab/dsprites_ori/train_ssf.sh ================================================ CUDA_VISIBLE_DEVICES=0,1, python -m torch.distributed.launch 
--nproc_per_node=2 --master_port=12002 \ train.py /path/to/vtab-1k/dsprites_ori --dataset dsprites_ori --num-classes 16 --no-aug --direct-resize --model vit_base_patch16_224_in21k \ --batch-size 32 --epochs 100 \ --opt adamw --weight-decay 5e-5 \ --warmup-lr 1e-7 --warmup-epochs 10 \ --lr 5e-3 --min-lr 1e-8 \ --drop-path 0.2 --img-size 224 \ --mixup 0 --cutmix 0 --smoothing 0 \ --output output/vit_base_patch16_224_in21k/vtab/dsprites_ori/ssf \ --amp --tuning-mode ssf --pretrained \ ================================================ FILE: train_scripts/vit/vtab/dtd/train_ssf.sh ================================================ CUDA_VISIBLE_DEVICES=0,1, python -m torch.distributed.launch --nproc_per_node=2 --master_port=14312 \ train.py /path/to/vtab-1k/dtd --dataset dtd --num-classes 47 --no-aug --direct-resize --model vit_base_patch16_224_in21k \ --batch-size 32 --epochs 100 \ --opt adamw --weight-decay 5e-5 \ --warmup-lr 1e-7 --warmup-epochs 10 \ --lr 5e-3 --min-lr 1e-8 \ --drop-path 0 --img-size 224 \ --output output/vit_base_patch16_224_in21k/vtab/dtd/ssf \ --amp --tuning-mode ssf --pretrained \ --mixup 0 --cutmix 0 --smoothing 0 ================================================ FILE: train_scripts/vit/vtab/eurosat/train_ssf.sh ================================================ CUDA_VISIBLE_DEVICES=0,1, python -m torch.distributed.launch --nproc_per_node=2 --master_port=14112 \ train.py /path/to/vtab-1k/eurosat --dataset eurosat --num-classes 10 --no-aug --direct-resize --model vit_base_patch16_224_in21k \ --batch-size 32 --epochs 100 \ --opt adamw --weight-decay 5e-2 \ --warmup-lr 1e-7 --warmup-epochs 10 \ --lr 3e-3 --min-lr 1e-8 \ --drop-path 0.2 --img-size 224 \ --mixup 0 --cutmix 0 --smoothing 0 \ --output output/vit_base_patch16_224_in21k/vtab/eurosat/ssf \ --amp --tuning-mode ssf --pretrained ================================================ FILE: train_scripts/vit/vtab/flowers102/train_ssf.sh ================================================ CUDA_VISIBLE_DEVICES=0,1, python -m torch.distributed.launch --nproc_per_node=2 --master_port=14222 \ train.py /path/to/vtab-1k/oxford_flowers102 --dataset flowers102 --num-classes 102 --no-aug --direct-resize --model vit_base_patch16_224_in21k \ --batch-size 32 --epochs 100 \ --opt adamw --weight-decay 5e-5 \ --warmup-lr 1e-7 --warmup-epochs 10 \ --lr 5e-3 --min-lr 1e-8 \ --drop-path 0 --img-size 224 \ --mixup 0 --cutmix 0 --smoothing 0 \ --output output/vit_base_patch16_224_in21k/vtab/flowers102/ssf \ --amp --tuning-mode ssf --pretrained ================================================ FILE: train_scripts/vit/vtab/kitti/train_ssf.sh ================================================ CUDA_VISIBLE_DEVICES=0,1, python -m torch.distributed.launch --nproc_per_node=2 --master_port=14332 \ train.py /path/to/vtab-1k/kitti --dataset kitti --num-classes 4 --no-aug --direct-resize --model vit_base_patch16_224_in21k \ --batch-size 32 --epochs 100 \ --opt adamw --weight-decay 5e-5 \ --warmup-lr 1e-7 --warmup-epochs 10 \ --lr 1e-2 --min-lr 1e-8 \ --drop-path 0 --img-size 224 \ --mixup 0 --cutmix 0 --smoothing 0 \ --output output/vit_base_patch16_224_in21k/vtab/kitti/ssf \ --amp --tuning-mode ssf --pretrained \ ================================================ FILE: train_scripts/vit/vtab/patch_camelyon/train_ssf.sh ================================================ CUDA_VISIBLE_DEVICES=0,1, python -m torch.distributed.launch --nproc_per_node=2 --master_port=14332 \ train.py /path/to/vtab-1k/patch_camelyon --dataset patch_camelyon --num-classes 2 --no-aug 
--direct-resize --model vit_base_patch16_224_in21k \ --batch-size 32 --epochs 100 \ --opt adamw --weight-decay 5e-5 \ --warmup-lr 1e-7 --warmup-epochs 10 \ --lr 5e-3 --min-lr 1e-8 \ --drop-path 0 --img-size 224 \ --mixup 0 --cutmix 0 --smoothing 0 \ --output output/vit_base_patch16_224_in21k/vtab/patch_camelyon/ssf \ --amp --tuning-mode ssf --pretrained ================================================ FILE: train_scripts/vit/vtab/pets/train_ssf.sh ================================================ CUDA_VISIBLE_DEVICES=0,1, python -m torch.distributed.launch --nproc_per_node=2 --master_port=14332 \ train.py /path/to/vtab-1k/oxford_iiit_pet --dataset pets --num-classes 37 --no-aug --direct-resize --model vit_base_patch16_224_in21k \ --batch-size 32 --epochs 100 \ --opt adamw --weight-decay 5e-5 \ --warmup-lr 1e-7 --warmup-epochs 10 \ --lr 5e-3 --min-lr 1e-8 \ --drop-path 0 --img-size 224 \ --mixup 0 --cutmix 0 --smoothing 0 \ --output output/vit_base_patch16_224_in21k/vtab/pets/ssf \ --amp --tuning-mode ssf --pretrained ================================================ FILE: train_scripts/vit/vtab/resisc45/train_ssf.sh ================================================ CUDA_VISIBLE_DEVICES=0,1, python -m torch.distributed.launch --nproc_per_node=2 --master_port=11222 \ train.py /path/to/vtab-1k/resisc45 --dataset resisc45 --num-classes 45 --no-aug --direct-resize --model vit_base_patch16_224_in21k \ --batch-size 32 --epochs 100 \ --opt adamw --weight-decay 5e-5 \ --warmup-lr 1e-7 --warmup-epochs 10 \ --lr 2e-3 --min-lr 1e-8 \ --drop-path 0.1 --img-size 224 \ --mixup 0 --cutmix 0 --smoothing 0 \ --output output/vit_base_patch16_224_in21k/vtab/resisc45/ssf \ --amp --tuning-mode ssf --pretrained ================================================ FILE: train_scripts/vit/vtab/smallnorb_azi/train_ssf.sh ================================================ CUDA_VISIBLE_DEVICES=0,1, python -m torch.distributed.launch --nproc_per_node=2 --master_port=14882 \ train.py /path/to/vtab-1k/smallnorb_azi --dataset smallnorb_azi --num-classes 18 --no-aug --direct-resize --model vit_base_patch16_224_in21k \ --batch-size 32 --epochs 100 \ --opt adamw --weight-decay 5e-5 \ --warmup-lr 1e-7 --warmup-epochs 10 \ --lr 2e-2 --min-lr 1e-8 \ --drop-path 0.1 --img-size 224 \ --mixup 0 --cutmix 0 --smoothing 0 \ --output output/vit_base_patch16_224_in21k/vtab/smallnorb_azi/ssf \ --amp --tuning-mode ssf --pretrained ================================================ FILE: train_scripts/vit/vtab/smallnorb_ele/train_ssf.sh ================================================ CUDA_VISIBLE_DEVICES=0,1, python -m torch.distributed.launch --nproc_per_node=2 --master_port=24332 \ train.py /path/to/vtab-1k/smallnorb_ele --dataset smallnorb_ele --num-classes 9 --no-aug --direct-resize --model vit_base_patch16_224_in21k \ --batch-size 32 --epochs 100 \ --opt adamw --weight-decay 5e-2 \ --warmup-lr 1e-7 --warmup-epochs 10 \ --lr 5e-3 --min-lr 1e-8 \ --drop-path 0.2 --img-size 224 \ --mixup 0 --cutmix 0 --smoothing 0 \ --output output/vit_base_patch16_224_in21k/vtab/smallnorb_ele/ssf \ --amp --tuning-mode ssf --pretrained ================================================ FILE: train_scripts/vit/vtab/sun397/train_ssf.sh ================================================ CUDA_VISIBLE_DEVICES=0,1, python -m torch.distributed.launch --nproc_per_node=2 --master_port=14192 \ train.py /path/to/vtab-1k/sun397 --dataset sun397 --num-classes 397 --no-aug --direct-resize --model vit_base_patch16_224_in21k \ --batch-size 32 --epochs 100 \ --opt adamw 
--weight-decay 5e-5 \ --warmup-lr 1e-7 --warmup-epochs 10 \ --lr 5e-3 --min-lr 1e-8 \ --drop-path 0 --img-size 224 \ --mixup 0 --cutmix 0 --smoothing 0 \ --output output/vit_base_patch16_224_in21k/vtab/sun397/ssf \ --amp --tuning-mode ssf --pretrained ================================================ FILE: train_scripts/vit/vtab/svhn/train_ssf.sh ================================================ CUDA_VISIBLE_DEVICES=0,1, python -m torch.distributed.launch --nproc_per_node=2 --master_port=14332 \ train.py /path/to/vtab-1k/svhn --dataset svhn --num-classes 10 --no-aug --direct-resize --model vit_base_patch16_224_in21k \ --batch-size 32 --epochs 100 \ --opt adamw --weight-decay 5e-5 \ --warmup-lr 1e-7 --warmup-epochs 10 \ --lr 1e-2 --min-lr 1e-8 \ --drop-path 0 --img-size 224 \ --mixup 0 --cutmix 0 --smoothing 0 \ --output output/vit_base_patch16_224_in21k/vtab/svhn/ssf \ --amp --tuning-mode ssf --pretrained ================================================ FILE: utils/__init__.py ================================================ from .utils import load_for_transfer_learning, load_for_probing from .scaler import ApexScaler_SAM from .mce_utils import * ================================================ FILE: utils/imagenet_a.py ================================================ thousand_k_to_200 = {0: -1, 1: -1, 2: -1, 3: -1, 4: -1, 5: -1, 6: 0, 7: -1, 8: -1, 9: -1, 10: -1, 11: 1, 12: -1, 13: 2, 14: -1, 15: 3, 16: -1, 17: 4, 18: -1, 19: -1, 20: -1, 21: -1, 22: 5, 23: 6, 24: -1, 25: -1, 26: -1, 27: 7, 28: -1, 29: -1, 30: 8, 31: -1, 32: -1, 33: -1, 34: -1, 35: -1, 36: -1, 37: 9, 38: -1, 39: 10, 40: -1, 41: -1, 42: 11, 43: -1, 44: -1, 45: -1, 46: -1, 47: 12, 48: -1, 49: -1, 50: 13, 51: -1, 52: -1, 53: -1, 54: -1, 55: -1, 56: -1, 57: 14, 58: -1, 59: -1, 60: -1, 61: -1, 62: -1, 63: -1, 64: -1, 65: -1, 66: -1, 67: -1, 68: -1, 69: -1, 70: 15, 71: 16, 72: -1, 73: -1, 74: -1, 75: -1, 76: 17, 77: -1, 78: -1, 79: 18, 80: -1, 81: -1, 82: -1, 83: -1, 84: -1, 85: -1, 86: -1, 87: -1, 88: -1, 89: 19, 90: 20, 91: -1, 92: -1, 93: -1, 94: 21, 95: -1, 96: 22, 97: 23, 98: -1, 99: 24, 100: -1, 101: -1, 102: -1, 103: -1, 104: -1, 105: 25, 106: -1, 107: 26, 108: 27, 109: -1, 110: 28, 111: -1, 112: -1, 113: 29, 114: -1, 115: -1, 116: -1, 117: -1, 118: -1, 119: -1, 120: -1, 121: -1, 122: -1, 123: -1, 124: 30, 125: 31, 126: -1, 127: -1, 128: -1, 129: -1, 130: 32, 131: -1, 132: 33, 133: -1, 134: -1, 135: -1, 136: -1, 137: -1, 138: -1, 139: -1, 140: -1, 141: -1, 142: -1, 143: 34, 144: 35, 145: -1, 146: -1, 147: -1, 148: -1, 149: -1, 150: 36, 151: 37, 152: -1, 153: -1, 154: -1, 155: -1, 156: -1, 157: -1, 158: -1, 159: -1, 160: -1, 161: -1, 162: -1, 163: -1, 164: -1, 165: -1, 166: -1, 167: -1, 168: -1, 169: -1, 170: -1, 171: -1, 172: -1, 173: -1, 174: -1, 175: -1, 176: -1, 177: -1, 178: -1, 179: -1, 180: -1, 181: -1, 182: -1, 183: -1, 184: -1, 185: -1, 186: -1, 187: -1, 188: -1, 189: -1, 190: -1, 191: -1, 192: -1, 193: -1, 194: -1, 195: -1, 196: -1, 197: -1, 198: -1, 199: -1, 200: -1, 201: -1, 202: -1, 203: -1, 204: -1, 205: -1, 206: -1, 207: 38, 208: -1, 209: -1, 210: -1, 211: -1, 212: -1, 213: -1, 214: -1, 215: -1, 216: -1, 217: -1, 218: -1, 219: -1, 220: -1, 221: -1, 222: -1, 223: -1, 224: -1, 225: -1, 226: -1, 227: -1, 228: -1, 229: -1, 230: -1, 231: -1, 232: -1, 233: -1, 234: 39, 235: 40, 236: -1, 237: -1, 238: -1, 239: -1, 240: -1, 241: -1, 242: -1, 243: -1, 244: -1, 245: -1, 246: -1, 247: -1, 248: -1, 249: -1, 250: -1, 251: -1, 252: -1, 253: -1, 254: 41, 255: -1, 256: -1, 257: -1, 258: -1, 259: -1, 260: -1, 261: -1, 262: -1, 
263: -1, 264: -1, 265: -1, 266: -1, 267: -1, 268: -1, 269: -1, 270: -1, 271: -1, 272: -1, 273: -1, 274: -1, 275: -1, 276: -1, 277: 42, 278: -1, 279: -1, 280: -1, 281: -1, 282: -1, 283: 43, 284: -1, 285: -1, 286: -1, 287: 44, 288: -1, 289: -1, 290: -1, 291: 45, 292: -1, 293: -1, 294: -1, 295: 46, 296: -1, 297: -1, 298: 47, 299: -1, 300: -1, 301: 48, 302: -1, 303: -1, 304: -1, 305: -1, 306: 49, 307: 50, 308: 51, 309: 52, 310: 53, 311: 54, 312: -1, 313: 55, 314: 56, 315: 57, 316: -1, 317: 58, 318: -1, 319: 59, 320: -1, 321: -1, 322: -1, 323: 60, 324: 61, 325: -1, 326: 62, 327: 63, 328: -1, 329: -1, 330: 64, 331: -1, 332: -1, 333: -1, 334: 65, 335: 66, 336: 67, 337: -1, 338: -1, 339: -1, 340: -1, 341: -1, 342: -1, 343: -1, 344: -1, 345: -1, 346: -1, 347: 68, 348: -1, 349: -1, 350: -1, 351: -1, 352: -1, 353: -1, 354: -1, 355: -1, 356: -1, 357: -1, 358: -1, 359: -1, 360: -1, 361: 69, 362: -1, 363: 70, 364: -1, 365: -1, 366: -1, 367: -1, 368: -1, 369: -1, 370: -1, 371: -1, 372: 71, 373: -1, 374: -1, 375: -1, 376: -1, 377: -1, 378: 72, 379: -1, 380: -1, 381: -1, 382: -1, 383: -1, 384: -1, 385: -1, 386: 73, 387: -1, 388: -1, 389: -1, 390: -1, 391: -1, 392: -1, 393: -1, 394: -1, 395: -1, 396: -1, 397: 74, 398: -1, 399: -1, 400: 75, 401: 76, 402: 77, 403: -1, 404: 78, 405: -1, 406: -1, 407: 79, 408: -1, 409: -1, 410: -1, 411: 80, 412: -1, 413: -1, 414: -1, 415: -1, 416: 81, 417: 82, 418: -1, 419: -1, 420: 83, 421: -1, 422: -1, 423: -1, 424: -1, 425: 84, 426: -1, 427: -1, 428: 85, 429: -1, 430: 86, 431: -1, 432: -1, 433: -1, 434: -1, 435: -1, 436: -1, 437: 87, 438: 88, 439: -1, 440: -1, 441: -1, 442: -1, 443: -1, 444: -1, 445: 89, 446: -1, 447: -1, 448: -1, 449: -1, 450: -1, 451: -1, 452: -1, 453: -1, 454: -1, 455: -1, 456: 90, 457: 91, 458: -1, 459: -1, 460: -1, 461: 92, 462: 93, 463: -1, 464: -1, 465: -1, 466: -1, 467: -1, 468: -1, 469: -1, 470: 94, 471: -1, 472: 95, 473: -1, 474: -1, 475: -1, 476: -1, 477: -1, 478: -1, 479: -1, 480: -1, 481: -1, 482: -1, 483: 96, 484: -1, 485: -1, 486: 97, 487: -1, 488: 98, 489: -1, 490: -1, 491: -1, 492: 99, 493: -1, 494: -1, 495: -1, 496: 100, 497: -1, 498: -1, 499: -1, 500: -1, 501: -1, 502: -1, 503: -1, 504: -1, 505: -1, 506: -1, 507: -1, 508: -1, 509: -1, 510: -1, 511: -1, 512: -1, 513: -1, 514: 101, 515: -1, 516: 102, 517: -1, 518: -1, 519: -1, 520: -1, 521: -1, 522: -1, 523: -1, 524: -1, 525: -1, 526: -1, 527: -1, 528: 103, 529: -1, 530: 104, 531: -1, 532: -1, 533: -1, 534: -1, 535: -1, 536: -1, 537: -1, 538: -1, 539: 105, 540: -1, 541: -1, 542: 106, 543: 107, 544: -1, 545: -1, 546: -1, 547: -1, 548: -1, 549: 108, 550: -1, 551: -1, 552: 109, 553: -1, 554: -1, 555: -1, 556: -1, 557: 110, 558: -1, 559: -1, 560: -1, 561: 111, 562: 112, 563: -1, 564: -1, 565: -1, 566: -1, 567: -1, 568: -1, 569: 113, 570: -1, 571: -1, 572: 114, 573: 115, 574: -1, 575: 116, 576: -1, 577: -1, 578: -1, 579: 117, 580: -1, 581: -1, 582: -1, 583: -1, 584: -1, 585: -1, 586: -1, 587: -1, 588: -1, 589: 118, 590: -1, 591: -1, 592: -1, 593: -1, 594: -1, 595: -1, 596: -1, 597: -1, 598: -1, 599: -1, 600: -1, 601: -1, 602: -1, 603: -1, 604: -1, 605: -1, 606: 119, 607: 120, 608: -1, 609: 121, 610: -1, 611: -1, 612: -1, 613: -1, 614: 122, 615: -1, 616: -1, 617: -1, 618: -1, 619: -1, 620: -1, 621: -1, 622: -1, 623: -1, 624: -1, 625: -1, 626: 123, 627: 124, 628: -1, 629: -1, 630: -1, 631: -1, 632: -1, 633: -1, 634: -1, 635: -1, 636: -1, 637: -1, 638: -1, 639: -1, 640: 125, 641: 126, 642: 127, 643: 128, 644: -1, 645: -1, 646: -1, 647: -1, 648: -1, 649: -1, 650: -1, 651: -1, 652: -1, 653: -1, 654: 
-1, 655: -1, 656: -1, 657: -1, 658: 129, 659: -1, 660: -1, 661: -1, 662: -1, 663: -1, 664: -1, 665: -1, 666: -1, 667: -1, 668: 130, 669: -1, 670: -1, 671: -1, 672: -1, 673: -1, 674: -1, 675: -1, 676: -1, 677: 131, 678: -1, 679: -1, 680: -1, 681: -1, 682: 132, 683: -1, 684: 133, 685: -1, 686: -1, 687: 134, 688: -1, 689: -1, 690: -1, 691: -1, 692: -1, 693: -1, 694: -1, 695: -1, 696: -1, 697: -1, 698: -1, 699: -1, 700: -1, 701: 135, 702: -1, 703: -1, 704: 136, 705: -1, 706: -1, 707: -1, 708: -1, 709: -1, 710: -1, 711: -1, 712: -1, 713: -1, 714: -1, 715: -1, 716: -1, 717: -1, 718: -1, 719: 137, 720: -1, 721: -1, 722: -1, 723: -1, 724: -1, 725: -1, 726: -1, 727: -1, 728: -1, 729: -1, 730: -1, 731: -1, 732: -1, 733: -1, 734: -1, 735: -1, 736: 138, 737: -1, 738: -1, 739: -1, 740: -1, 741: -1, 742: -1, 743: -1, 744: -1, 745: -1, 746: 139, 747: -1, 748: -1, 749: 140, 750: -1, 751: -1, 752: 141, 753: -1, 754: -1, 755: -1, 756: -1, 757: -1, 758: 142, 759: -1, 760: -1, 761: -1, 762: -1, 763: 143, 764: -1, 765: 144, 766: -1, 767: -1, 768: 145, 769: -1, 770: -1, 771: -1, 772: -1, 773: 146, 774: 147, 775: -1, 776: 148, 777: -1, 778: -1, 779: 149, 780: 150, 781: -1, 782: -1, 783: -1, 784: -1, 785: -1, 786: 151, 787: -1, 788: -1, 789: -1, 790: -1, 791: -1, 792: 152, 793: -1, 794: -1, 795: -1, 796: -1, 797: 153, 798: -1, 799: -1, 800: -1, 801: -1, 802: 154, 803: 155, 804: 156, 805: -1, 806: -1, 807: -1, 808: -1, 809: -1, 810: -1, 811: -1, 812: -1, 813: 157, 814: -1, 815: 158, 816: -1, 817: -1, 818: -1, 819: -1, 820: 159, 821: -1, 822: -1, 823: 160, 824: -1, 825: -1, 826: -1, 827: -1, 828: -1, 829: -1, 830: -1, 831: 161, 832: -1, 833: 162, 834: -1, 835: 163, 836: -1, 837: -1, 838: -1, 839: 164, 840: -1, 841: -1, 842: -1, 843: -1, 844: -1, 845: 165, 846: -1, 847: 166, 848: -1, 849: -1, 850: 167, 851: -1, 852: -1, 853: -1, 854: -1, 855: -1, 856: -1, 857: -1, 858: -1, 859: 168, 860: -1, 861: -1, 862: 169, 863: -1, 864: -1, 865: -1, 866: -1, 867: -1, 868: -1, 869: -1, 870: 170, 871: -1, 872: -1, 873: -1, 874: -1, 875: -1, 876: -1, 877: -1, 878: -1, 879: 171, 880: 172, 881: -1, 882: -1, 883: -1, 884: -1, 885: -1, 886: -1, 887: -1, 888: 173, 889: -1, 890: 174, 891: -1, 892: -1, 893: -1, 894: -1, 895: -1, 896: -1, 897: 175, 898: -1, 899: -1, 900: 176, 901: -1, 902: -1, 903: -1, 904: -1, 905: -1, 906: -1, 907: 177, 908: -1, 909: -1, 910: -1, 911: -1, 912: -1, 913: 178, 914: -1, 915: -1, 916: -1, 917: -1, 918: -1, 919: -1, 920: -1, 921: -1, 922: -1, 923: -1, 924: 179, 925: -1, 926: -1, 927: -1, 928: -1, 929: -1, 930: -1, 931: -1, 932: 180, 933: 181, 934: 182, 935: -1, 936: -1, 937: 183, 938: -1, 939: -1, 940: -1, 941: -1, 942: -1, 943: 184, 944: -1, 945: 185, 946: -1, 947: 186, 948: -1, 949: -1, 950: -1, 951: 187, 952: -1, 953: -1, 954: 188, 955: -1, 956: 189, 957: 190, 958: -1, 959: 191, 960: -1, 961: -1, 962: -1, 963: -1, 964: -1, 965: -1, 966: -1, 967: -1, 968: -1, 969: -1, 970: -1, 971: 192, 972: 193, 973: -1, 974: -1, 975: -1, 976: -1, 977: -1, 978: -1, 979: -1, 980: 194, 981: 195, 982: -1, 983: -1, 984: 196, 985: -1, 986: 197, 987: 198, 988: 199, 989: -1, 990: -1, 991: -1, 992: -1, 993: -1, 994: -1, 995: -1, 996: -1, 997: -1, 998: -1, 999: -1} indices_in_1k = [k for k in thousand_k_to_200 if thousand_k_to_200[k] != -1] ================================================ FILE: utils/imagenet_r.py ================================================ all_wnids = ['n01440764', 'n01443537', 'n01484850', 'n01491361', 'n01494475', 'n01496331', 'n01498041', 'n01514668', 'n01514859', 'n01518878', 'n01530575', 'n01531178', 
'n01532829', 'n01534433', 'n01537544', 'n01558993', 'n01560419', 'n01580077', 'n01582220', 'n01592084', 'n01601694', 'n01608432', 'n01614925', 'n01616318', 'n01622779', 'n01629819', 'n01630670', 'n01631663', 'n01632458', 'n01632777', 'n01641577', 'n01644373', 'n01644900', 'n01664065', 'n01665541', 'n01667114', 'n01667778', 'n01669191', 'n01675722', 'n01677366', 'n01682714', 'n01685808', 'n01687978', 'n01688243', 'n01689811', 'n01692333', 'n01693334', 'n01694178', 'n01695060', 'n01697457', 'n01698640', 'n01704323', 'n01728572', 'n01728920', 'n01729322', 'n01729977', 'n01734418', 'n01735189', 'n01737021', 'n01739381', 'n01740131', 'n01742172', 'n01744401', 'n01748264', 'n01749939', 'n01751748', 'n01753488', 'n01755581', 'n01756291', 'n01768244', 'n01770081', 'n01770393', 'n01773157', 'n01773549', 'n01773797', 'n01774384', 'n01774750', 'n01775062', 'n01776313', 'n01784675', 'n01795545', 'n01796340', 'n01797886', 'n01798484', 'n01806143', 'n01806567', 'n01807496', 'n01817953', 'n01818515', 'n01819313', 'n01820546', 'n01824575', 'n01828970', 'n01829413', 'n01833805', 'n01843065', 'n01843383', 'n01847000', 'n01855032', 'n01855672', 'n01860187', 'n01871265', 'n01872401', 'n01873310', 'n01877812', 'n01882714', 'n01883070', 'n01910747', 'n01914609', 'n01917289', 'n01924916', 'n01930112', 'n01943899', 'n01944390', 'n01945685', 'n01950731', 'n01955084', 'n01968897', 'n01978287', 'n01978455', 'n01980166', 'n01981276', 'n01983481', 'n01984695', 'n01985128', 'n01986214', 'n01990800', 'n02002556', 'n02002724', 'n02006656', 'n02007558', 'n02009229', 'n02009912', 'n02011460', 'n02012849', 'n02013706', 'n02017213', 'n02018207', 'n02018795', 'n02025239', 'n02027492', 'n02028035', 'n02033041', 'n02037110', 'n02051845', 'n02056570', 'n02058221', 'n02066245', 'n02071294', 'n02074367', 'n02077923', 'n02085620', 'n02085782', 'n02085936', 'n02086079', 'n02086240', 'n02086646', 'n02086910', 'n02087046', 'n02087394', 'n02088094', 'n02088238', 'n02088364', 'n02088466', 'n02088632', 'n02089078', 'n02089867', 'n02089973', 'n02090379', 'n02090622', 'n02090721', 'n02091032', 'n02091134', 'n02091244', 'n02091467', 'n02091635', 'n02091831', 'n02092002', 'n02092339', 'n02093256', 'n02093428', 'n02093647', 'n02093754', 'n02093859', 'n02093991', 'n02094114', 'n02094258', 'n02094433', 'n02095314', 'n02095570', 'n02095889', 'n02096051', 'n02096177', 'n02096294', 'n02096437', 'n02096585', 'n02097047', 'n02097130', 'n02097209', 'n02097298', 'n02097474', 'n02097658', 'n02098105', 'n02098286', 'n02098413', 'n02099267', 'n02099429', 'n02099601', 'n02099712', 'n02099849', 'n02100236', 'n02100583', 'n02100735', 'n02100877', 'n02101006', 'n02101388', 'n02101556', 'n02102040', 'n02102177', 'n02102318', 'n02102480', 'n02102973', 'n02104029', 'n02104365', 'n02105056', 'n02105162', 'n02105251', 'n02105412', 'n02105505', 'n02105641', 'n02105855', 'n02106030', 'n02106166', 'n02106382', 'n02106550', 'n02106662', 'n02107142', 'n02107312', 'n02107574', 'n02107683', 'n02107908', 'n02108000', 'n02108089', 'n02108422', 'n02108551', 'n02108915', 'n02109047', 'n02109525', 'n02109961', 'n02110063', 'n02110185', 'n02110341', 'n02110627', 'n02110806', 'n02110958', 'n02111129', 'n02111277', 'n02111500', 'n02111889', 'n02112018', 'n02112137', 'n02112350', 'n02112706', 'n02113023', 'n02113186', 'n02113624', 'n02113712', 'n02113799', 'n02113978', 'n02114367', 'n02114548', 'n02114712', 'n02114855', 'n02115641', 'n02115913', 'n02116738', 'n02117135', 'n02119022', 'n02119789', 'n02120079', 'n02120505', 'n02123045', 'n02123159', 'n02123394', 'n02123597', 
'n02124075', 'n02125311', 'n02127052', 'n02128385', 'n02128757', 'n02128925', 'n02129165', 'n02129604', 'n02130308', 'n02132136', 'n02133161', 'n02134084', 'n02134418', 'n02137549', 'n02138441', 'n02165105', 'n02165456', 'n02167151', 'n02168699', 'n02169497', 'n02172182', 'n02174001', 'n02177972', 'n02190166', 'n02206856', 'n02219486', 'n02226429', 'n02229544', 'n02231487', 'n02233338', 'n02236044', 'n02256656', 'n02259212', 'n02264363', 'n02268443', 'n02268853', 'n02276258', 'n02277742', 'n02279972', 'n02280649', 'n02281406', 'n02281787', 'n02317335', 'n02319095', 'n02321529', 'n02325366', 'n02326432', 'n02328150', 'n02342885', 'n02346627', 'n02356798', 'n02361337', 'n02363005', 'n02364673', 'n02389026', 'n02391049', 'n02395406', 'n02396427', 'n02397096', 'n02398521', 'n02403003', 'n02408429', 'n02410509', 'n02412080', 'n02415577', 'n02417914', 'n02422106', 'n02422699', 'n02423022', 'n02437312', 'n02437616', 'n02441942', 'n02442845', 'n02443114', 'n02443484', 'n02444819', 'n02445715', 'n02447366', 'n02454379', 'n02457408', 'n02480495', 'n02480855', 'n02481823', 'n02483362', 'n02483708', 'n02484975', 'n02486261', 'n02486410', 'n02487347', 'n02488291', 'n02488702', 'n02489166', 'n02490219', 'n02492035', 'n02492660', 'n02493509', 'n02493793', 'n02494079', 'n02497673', 'n02500267', 'n02504013', 'n02504458', 'n02509815', 'n02510455', 'n02514041', 'n02526121', 'n02536864', 'n02606052', 'n02607072', 'n02640242', 'n02641379', 'n02643566', 'n02655020', 'n02666196', 'n02667093', 'n02669723', 'n02672831', 'n02676566', 'n02687172', 'n02690373', 'n02692877', 'n02699494', 'n02701002', 'n02704792', 'n02708093', 'n02727426', 'n02730930', 'n02747177', 'n02749479', 'n02769748', 'n02776631', 'n02777292', 'n02782093', 'n02783161', 'n02786058', 'n02787622', 'n02788148', 'n02790996', 'n02791124', 'n02791270', 'n02793495', 'n02794156', 'n02795169', 'n02797295', 'n02799071', 'n02802426', 'n02804414', 'n02804610', 'n02807133', 'n02808304', 'n02808440', 'n02814533', 'n02814860', 'n02815834', 'n02817516', 'n02823428', 'n02823750', 'n02825657', 'n02834397', 'n02835271', 'n02837789', 'n02840245', 'n02841315', 'n02843684', 'n02859443', 'n02860847', 'n02865351', 'n02869837', 'n02870880', 'n02871525', 'n02877765', 'n02879718', 'n02883205', 'n02892201', 'n02892767', 'n02894605', 'n02895154', 'n02906734', 'n02909870', 'n02910353', 'n02916936', 'n02917067', 'n02927161', 'n02930766', 'n02939185', 'n02948072', 'n02950826', 'n02951358', 'n02951585', 'n02963159', 'n02965783', 'n02966193', 'n02966687', 'n02971356', 'n02974003', 'n02977058', 'n02978881', 'n02979186', 'n02980441', 'n02981792', 'n02988304', 'n02992211', 'n02992529', 'n02999410', 'n03000134', 'n03000247', 'n03000684', 'n03014705', 'n03016953', 'n03017168', 'n03018349', 'n03026506', 'n03028079', 'n03032252', 'n03041632', 'n03042490', 'n03045698', 'n03047690', 'n03062245', 'n03063599', 'n03063689', 'n03065424', 'n03075370', 'n03085013', 'n03089624', 'n03095699', 'n03100240', 'n03109150', 'n03110669', 'n03124043', 'n03124170', 'n03125729', 'n03126707', 'n03127747', 'n03127925', 'n03131574', 'n03133878', 'n03134739', 'n03141823', 'n03146219', 'n03160309', 'n03179701', 'n03180011', 'n03187595', 'n03188531', 'n03196217', 'n03197337', 'n03201208', 'n03207743', 'n03207941', 'n03208938', 'n03216828', 'n03218198', 'n03220513', 'n03223299', 'n03240683', 'n03249569', 'n03250847', 'n03255030', 'n03259280', 'n03271574', 'n03272010', 'n03272562', 'n03290653', 'n03291819', 'n03297495', 'n03314780', 'n03325584', 'n03337140', 'n03344393', 'n03345487', 'n03347037', 'n03355925', 
'n03372029', 'n03376595', 'n03379051', 'n03384352', 'n03388043', 'n03388183', 'n03388549', 'n03393912', 'n03394916', 'n03400231', 'n03404251', 'n03417042', 'n03424325', 'n03425413', 'n03443371', 'n03444034', 'n03445777', 'n03445924', 'n03447447', 'n03447721', 'n03450230', 'n03452741', 'n03457902', 'n03459775', 'n03461385', 'n03467068', 'n03476684', 'n03476991', 'n03478589', 'n03481172', 'n03482405', 'n03483316', 'n03485407', 'n03485794', 'n03492542', 'n03494278', 'n03495258', 'n03496892', 'n03498962', 'n03527444', 'n03529860', 'n03530642', 'n03532672', 'n03534580', 'n03535780', 'n03538406', 'n03544143', 'n03584254', 'n03584829', 'n03590841', 'n03594734', 'n03594945', 'n03595614', 'n03598930', 'n03599486', 'n03602883', 'n03617480', 'n03623198', 'n03627232', 'n03630383', 'n03633091', 'n03637318', 'n03642806', 'n03649909', 'n03657121', 'n03658185', 'n03661043', 'n03662601', 'n03666591', 'n03670208', 'n03673027', 'n03676483', 'n03680355', 'n03690938', 'n03691459', 'n03692522', 'n03697007', 'n03706229', 'n03709823', 'n03710193', 'n03710637', 'n03710721', 'n03717622', 'n03720891', 'n03721384', 'n03724870', 'n03729826', 'n03733131', 'n03733281', 'n03733805', 'n03742115', 'n03743016', 'n03759954', 'n03761084', 'n03763968', 'n03764736', 'n03769881', 'n03770439', 'n03770679', 'n03773504', 'n03775071', 'n03775546', 'n03776460', 'n03777568', 'n03777754', 'n03781244', 'n03782006', 'n03785016', 'n03786901', 'n03787032', 'n03788195', 'n03788365', 'n03791053', 'n03792782', 'n03792972', 'n03793489', 'n03794056', 'n03796401', 'n03803284', 'n03804744', 'n03814639', 'n03814906', 'n03825788', 'n03832673', 'n03837869', 'n03838899', 'n03840681', 'n03841143', 'n03843555', 'n03854065', 'n03857828', 'n03866082', 'n03868242', 'n03868863', 'n03871628', 'n03873416', 'n03874293', 'n03874599', 'n03876231', 'n03877472', 'n03877845', 'n03884397', 'n03887697', 'n03888257', 'n03888605', 'n03891251', 'n03891332', 'n03895866', 'n03899768', 'n03902125', 'n03903868', 'n03908618', 'n03908714', 'n03916031', 'n03920288', 'n03924679', 'n03929660', 'n03929855', 'n03930313', 'n03930630', 'n03933933', 'n03935335', 'n03937543', 'n03938244', 'n03942813', 'n03944341', 'n03947888', 'n03950228', 'n03954731', 'n03956157', 'n03958227', 'n03961711', 'n03967562', 'n03970156', 'n03976467', 'n03976657', 'n03977966', 'n03980874', 'n03982430', 'n03983396', 'n03991062', 'n03992509', 'n03995372', 'n03998194', 'n04004767', 'n04005630', 'n04008634', 'n04009552', 'n04019541', 'n04023962', 'n04026417', 'n04033901', 'n04033995', 'n04037443', 'n04039381', 'n04040759', 'n04041544', 'n04044716', 'n04049303', 'n04065272', 'n04067472', 'n04069434', 'n04070727', 'n04074963', 'n04081281', 'n04086273', 'n04090263', 'n04099969', 'n04111531', 'n04116512', 'n04118538', 'n04118776', 'n04120489', 'n04125021', 'n04127249', 'n04131690', 'n04133789', 'n04136333', 'n04141076', 'n04141327', 'n04141975', 'n04146614', 'n04147183', 'n04149813', 'n04152593', 'n04153751', 'n04154565', 'n04162706', 'n04179913', 'n04192698', 'n04200800', 'n04201297', 'n04204238', 'n04204347', 'n04208210', 'n04209133', 'n04209239', 'n04228054', 'n04229816', 'n04235860', 'n04238763', 'n04239074', 'n04243546', 'n04251144', 'n04252077', 'n04252225', 'n04254120', 'n04254680', 'n04254777', 'n04258138', 'n04259630', 'n04263257', 'n04264628', 'n04265275', 'n04266014', 'n04270147', 'n04273569', 'n04275548', 'n04277352', 'n04285008', 'n04286575', 'n04296562', 'n04310018', 'n04311004', 'n04311174', 'n04317175', 'n04325704', 'n04326547', 'n04328186', 'n04330267', 'n04332243', 'n04335435', 'n04336792', 
'n04344873', 'n04346328', 'n04347754', 'n04350905', 'n04355338', 'n04355933', 'n04356056', 'n04357314', 'n04366367', 'n04367480', 'n04370456', 'n04371430', 'n04371774', 'n04372370', 'n04376876', 'n04380533', 'n04389033', 'n04392985', 'n04398044', 'n04399382', 'n04404412', 'n04409515', 'n04417672', 'n04418357', 'n04423845', 'n04428191', 'n04429376', 'n04435653', 'n04442312', 'n04443257', 'n04447861', 'n04456115', 'n04458633', 'n04461696', 'n04462240', 'n04465501', 'n04467665', 'n04476259', 'n04479046', 'n04482393', 'n04483307', 'n04485082', 'n04486054', 'n04487081', 'n04487394', 'n04493381', 'n04501370', 'n04505470', 'n04507155', 'n04509417', 'n04515003', 'n04517823', 'n04522168', 'n04523525', 'n04525038', 'n04525305', 'n04532106', 'n04532670', 'n04536866', 'n04540053', 'n04542943', 'n04548280', 'n04548362', 'n04550184', 'n04552348', 'n04553703', 'n04554684', 'n04557648', 'n04560804', 'n04562935', 'n04579145', 'n04579432', 'n04584207', 'n04589890', 'n04590129', 'n04591157', 'n04591713', 'n04592741', 'n04596742', 'n04597913', 'n04599235', 'n04604644', 'n04606251', 'n04612504', 'n04613696', 'n06359193', 'n06596364', 'n06785654', 'n06794110', 'n06874185', 'n07248320', 'n07565083', 'n07579787', 'n07583066', 'n07584110', 'n07590611', 'n07613480', 'n07614500', 'n07615774', 'n07684084', 'n07693725', 'n07695742', 'n07697313', 'n07697537', 'n07711569', 'n07714571', 'n07714990', 'n07715103', 'n07716358', 'n07716906', 'n07717410', 'n07717556', 'n07718472', 'n07718747', 'n07720875', 'n07730033', 'n07734744', 'n07742313', 'n07745940', 'n07747607', 'n07749582', 'n07753113', 'n07753275', 'n07753592', 'n07754684', 'n07760859', 'n07768694', 'n07802026', 'n07831146', 'n07836838', 'n07860988', 'n07871810', 'n07873807', 'n07875152', 'n07880968', 'n07892512', 'n07920052', 'n07930864', 'n07932039', 'n09193705', 'n09229709', 'n09246464', 'n09256479', 'n09288635', 'n09332890', 'n09399592', 'n09421951', 'n09428293', 'n09468604', 'n09472597', 'n09835506', 'n10148035', 'n10565667', 'n11879895', 'n11939491', 'n12057211', 'n12144580', 'n12267677', 'n12620546', 'n12768682', 'n12985857', 'n12998815', 'n13037406', 'n13040303', 'n13044778', 'n13052670', 'n13054560', 'n13133613', 'n15075141'] imagenet_r_wnids = {'n01443537', 'n01484850', 'n01494475', 'n01498041', 'n01514859', 'n01518878', 'n01531178', 'n01534433', 'n01614925', 'n01616318', 'n01630670', 'n01632777', 'n01644373', 'n01677366', 'n01694178', 'n01748264', 'n01770393', 'n01774750', 'n01784675', 'n01806143', 'n01820546', 'n01833805', 'n01843383', 'n01847000', 'n01855672', 'n01860187', 'n01882714', 'n01910747', 'n01944390', 'n01983481', 'n01986214', 'n02007558', 'n02009912', 'n02051845', 'n02056570', 'n02066245', 'n02071294', 'n02077923', 'n02085620', 'n02086240', 'n02088094', 'n02088238', 'n02088364', 'n02088466', 'n02091032', 'n02091134', 'n02092339', 'n02094433', 'n02096585', 'n02097298', 'n02098286', 'n02099601', 'n02099712', 'n02102318', 'n02106030', 'n02106166', 'n02106550', 'n02106662', 'n02108089', 'n02108915', 'n02109525', 'n02110185', 'n02110341', 'n02110958', 'n02112018', 'n02112137', 'n02113023', 'n02113624', 'n02113799', 'n02114367', 'n02117135', 'n02119022', 'n02123045', 'n02128385', 'n02128757', 'n02129165', 'n02129604', 'n02130308', 'n02134084', 'n02138441', 'n02165456', 'n02190166', 'n02206856', 'n02219486', 'n02226429', 'n02233338', 'n02236044', 'n02268443', 'n02279972', 'n02317335', 'n02325366', 'n02346627', 'n02356798', 'n02363005', 'n02364673', 'n02391049', 'n02395406', 'n02398521', 'n02410509', 'n02423022', 'n02437616', 'n02445715', 
'n02447366', 'n02480495', 'n02480855', 'n02481823', 'n02483362', 'n02486410', 'n02510455', 'n02526121', 'n02607072', 'n02655020', 'n02672831', 'n02701002', 'n02749479', 'n02769748', 'n02793495', 'n02797295', 'n02802426', 'n02808440', 'n02814860', 'n02823750', 'n02841315', 'n02843684', 'n02883205', 'n02906734', 'n02909870', 'n02939185', 'n02948072', 'n02950826', 'n02951358', 'n02966193', 'n02980441', 'n02992529', 'n03124170', 'n03272010', 'n03345487', 'n03372029', 'n03424325', 'n03452741', 'n03467068', 'n03481172', 'n03494278', 'n03495258', 'n03498962', 'n03594945', 'n03602883', 'n03630383', 'n03649909', 'n03676483', 'n03710193', 'n03773504', 'n03775071', 'n03888257', 'n03930630', 'n03947888', 'n04086273', 'n04118538', 'n04133789', 'n04141076', 'n04146614', 'n04147183', 'n04192698', 'n04254680', 'n04266014', 'n04275548', 'n04310018', 'n04325704', 'n04347754', 'n04389033', 'n04409515', 'n04465501', 'n04487394', 'n04522168', 'n04536866', 'n04552348', 'n04591713', 'n07614500', 'n07693725', 'n07695742', 'n07697313', 'n07697537', 'n07714571', 'n07714990', 'n07718472', 'n07720875', 'n07734744', 'n07742313', 'n07745940', 'n07749582', 'n07753275', 'n07753592', 'n07768694', 'n07873807', 'n07880968', 'n07920052', 'n09472597', 'n09835506', 'n10565667', 'n12267677'} imagenet_r_mask = [wnid in imagenet_r_wnids for wnid in all_wnids] # imagenet_r_indices = [i for i in range(1000) if imagenet_r_mask[i] is True] # [1, 2, 4, 6, 8, 9, 11, 13, 22, 23, 26, 29, 31, 39, 47, 63, 71, 76, 79, 84, 90, 94, 96, 97, 99, 100, 105, 107, 113, 122, 125, 130, 132, 144, 145, 147, 148, 150, 151, 155, 160, 161, 162, 163, 171, 172, 178, 187, 195, 199, 203, 207, 208, 219, 231, 232, 234, 235, 242, 245, 247, 250, 251, 254, 259, 260, 263, 265, 267, 269, 276, 277, 281, 288, 289, 291, 292, 293, 296, 299, 301, 308, 309, 310, 311, 314, 315, 319, 323, 327, 330, 334, 335, 337, 338, 340, 341, 344, 347, 353, 355, 361, 362, 365, 366, 367, 368, 372, 388, 390, 393, 397, 401, 407, 413, 414, 425, 428, 430, 435, 437, 441, 447, 448, 457, 462, 463, 469, 470, 471, 472, 476, 483, 487, 515, 546, 555, 558, 570, 579, 583, 587, 593, 594, 596, 609, 613, 617, 621, 629, 637, 657, 658, 701, 717, 724, 763, 768, 774, 776, 779, 780, 787, 805, 812, 815, 820, 824, 833, 847, 852, 866, 875, 883, 889, 895, 907, 928, 931, 932, 933, 934, 936, 937, 943, 945, 947, 948, 949, 951, 953, 954, 957, 963, 965, 967, 980, 981, 983, 988] imagenet_o_wnids = ['n01443537', 'n01704323', 'n01770081', 'n01784675', 'n01819313', 'n01820546', 'n01910747', 'n01917289', 'n01968897', 'n02074367', 'n02317335', 'n02319095', 'n02395406', 'n02454379', 'n02606052', 'n02655020', 'n02666196', 'n02672831', 'n02730930', 'n02777292', 'n02783161', 'n02786058', 'n02787622', 'n02791270', 'n02808304', 'n02817516', 'n02841315', 'n02865351', 'n02877765', 'n02892767', 'n02906734', 'n02910353', 'n02916936', 'n02948072', 'n02965783', 'n03000134', 'n03000684', 'n03017168', 'n03026506', 'n03032252', 'n03075370', 'n03109150', 'n03126707', 'n03134739', 'n03160309', 'n03196217', 'n03207743', 'n03218198', 'n03223299', 'n03240683', 'n03271574', 'n03291819', 'n03297495', 'n03314780', 'n03325584', 'n03344393', 'n03347037', 'n03372029', 'n03376595', 'n03388043', 'n03388183', 'n03400231', 'n03445777', 'n03457902', 'n03467068', 'n03482405', 'n03483316', 'n03494278', 'n03530642', 'n03544143', 'n03584829', 'n03590841', 'n03598930', 'n03602883', 'n03649909', 'n03661043', 'n03666591', 'n03676483', 'n03692522', 'n03706229', 'n03717622', 'n03720891', 'n03721384', 'n03724870', 'n03729826', 'n03733131', 'n03733281', 'n03742115', 
'n03786901', 'n03788365', 'n03794056', 'n03804744', 'n03814639', 'n03814906', 'n03825788', 'n03840681', 'n03843555', 'n03854065', 'n03857828', 'n03868863', 'n03874293', 'n03884397', 'n03891251', 'n03908714', 'n03920288', 'n03929660', 'n03930313', 'n03937543', 'n03942813', 'n03944341', 'n03961711', 'n03970156', 'n03982430', 'n03991062', 'n03995372', 'n03998194', 'n04005630', 'n04023962', 'n04033901', 'n04040759', 'n04067472', 'n04074963', 'n04116512', 'n04118776', 'n04125021', 'n04127249', 'n04131690', 'n04141975', 'n04153751', 'n04154565', 'n04201297', 'n04204347', 'n04209133', 'n04209239', 'n04228054', 'n04235860', 'n04243546', 'n04252077', 'n04254120', 'n04258138', 'n04265275', 'n04270147', 'n04275548', 'n04330267', 'n04332243', 'n04336792', 'n04347754', 'n04371430', 'n04371774', 'n04372370', 'n04376876', 'n04409515', 'n04417672', 'n04418357', 'n04423845', 'n04429376', 'n04435653', 'n04442312', 'n04482393', 'n04501370', 'n04507155', 'n04525305', 'n04542943', 'n04554684', 'n04557648', 'n04562935', 'n04579432', 'n04591157', 'n04597913', 'n04599235', 'n06785654', 'n06874185', 'n07615774', 'n07693725', 'n07695742', 'n07697537', 'n07711569', 'n07714990', 'n07715103', 'n07716358', 'n07717410', 'n07718472', 'n07720875', 'n07742313', 'n07745940', 'n07747607', 'n07749582', 'n07753275', 'n07753592', 'n07754684', 'n07768694', 'n07836838', 'n07871810', 'n07873807', 'n07880968', 'n09229709', 'n09472597', 'n12144580', 'n12267677', 'n13052670'] imagenet_o_mask = [wnid in set(imagenet_o_wnids) for wnid in all_wnids] ================================================ FILE: utils/mce_utils.py ================================================ # Copyright (c) 2021-2022, NVIDIA Corporation & Affiliates. All rights reserved. # # This work is made available under the Nvidia Source Code License-NC. # To view a copy of this license, visit # https://github.com/NVlabs/FAN/blob/main/LICENSE # Copyright (c) 2015-present, Facebook, Inc. # All rights reserved. """ Misc functions, including distributed helpers. Mostly copy-paste from torchvision references. """ # Modified by: Daquan import io import os import time from collections import defaultdict, deque import datetime import torch import torch.distributed as dist data_loaders_names = { 'Brightness': 'brightness', 'Contrast': 'contrast', 'Defocus Blur': 'defocus_blur', 'Elastic Transform': 'elastic_transform', 'Fog': 'fog', 'Frost': 'frost', 'Gaussian Noise': 'gaussian_noise', 'Glass Blur': 'glass_blur', 'Impulse Noise': 'impulse_noise', 'JPEG Compression': 'jpeg_compression', 'Motion Blur': 'motion_blur', 'Pixelate': 'pixelate', 'Shot Noise': 'shot_noise', 'Snow': 'snow', 'Zoom Blur': 'zoom_blur' } def get_ce_alexnet(): """Returns Corruption Error values for AlexNet""" ce_alexnet = dict() ce_alexnet['gaussian_noise'] = 0.886428 ce_alexnet['shot_noise'] = 0.894468 ce_alexnet['impulse_noise'] = 0.922640 ce_alexnet['defocus_blur'] = 0.819880 ce_alexnet['glass_blur'] = 0.826268 ce_alexnet['motion_blur'] = 0.785948 ce_alexnet['zoom_blur'] = 0.798360 ce_alexnet['snow'] = 0.866816 ce_alexnet['frost'] = 0.826572 ce_alexnet['fog'] = 0.819324 ce_alexnet['brightness'] = 0.564592 ce_alexnet['contrast'] = 0.853204 ce_alexnet['elastic_transform'] = 0.646056 ce_alexnet['pixelate'] = 0.717840 ce_alexnet['jpeg_compression'] = 0.606500 return ce_alexnet def get_mce_from_accuracy(accuracy, error_alexnet): """Computes mean Corruption Error from accuracy""" error = 100. - accuracy ce = error / (error_alexnet * 100.) 
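# Worked example with illustrative numbers: a top-1 accuracy of 60.0 on a corruption
# whose AlexNet error is 0.886428 (gaussian_noise) gives error = 40.0 and
# ce = 40.0 / (0.886428 * 100.) ≈ 0.451.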
return ce class SmoothedValue(object): """Track a series of values and provide access to smoothed values over a window or the global series average. """ def __init__(self, window_size=20, fmt=None): if fmt is None: fmt = "{median:.4f} ({global_avg:.4f})" self.deque = deque(maxlen=window_size) self.total = 0.0 self.count = 0 self.fmt = fmt def update(self, value, n=1): self.deque.append(value) self.count += n self.total += value * n def synchronize_between_processes(self): """ Warning: does not synchronize the deque! """ if not is_dist_avail_and_initialized(): return t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') dist.barrier() dist.all_reduce(t) t = t.tolist() self.count = int(t[0]) self.total = t[1] @property def median(self): d = torch.tensor(list(self.deque)) return d.median().item() @property def avg(self): d = torch.tensor(list(self.deque), dtype=torch.float32) return d.mean().item() @property def global_avg(self): return self.total / self.count @property def max(self): return max(self.deque) @property def value(self): return self.deque[-1] def __str__(self): return self.fmt.format( median=self.median, avg=self.avg, global_avg=self.global_avg, max=self.max, value=self.value) class MetricLogger(object): def __init__(self, delimiter="\t"): self.meters = defaultdict(SmoothedValue) self.delimiter = delimiter def update(self, **kwargs): for k, v in kwargs.items(): if isinstance(v, torch.Tensor): v = v.item() assert isinstance(v, (float, int)) self.meters[k].update(v) def __getattr__(self, attr): if attr in self.meters: return self.meters[attr] if attr in self.__dict__: return self.__dict__[attr] raise AttributeError("'{}' object has no attribute '{}'".format( type(self).__name__, attr)) def __str__(self): loss_str = [] for name, meter in self.meters.items(): loss_str.append( "{}: {}".format(name, str(meter)) ) return self.delimiter.join(loss_str) def synchronize_between_processes(self): for meter in self.meters.values(): meter.synchronize_between_processes() def add_meter(self, name, meter): self.meters[name] = meter def log_every(self, iterable, print_freq, header=None): i = 0 if not header: header = '' start_time = time.time() end = time.time() iter_time = SmoothedValue(fmt='{avg:.4f}') data_time = SmoothedValue(fmt='{avg:.4f}') space_fmt = ':' + str(len(str(len(iterable)))) + 'd' log_msg = [ header, '[{0' + space_fmt + '}/{1}]', 'eta: {eta}', '{meters}', 'time: {time}', 'data: {data}' ] if torch.cuda.is_available(): log_msg.append('max mem: {memory:.0f}') log_msg = self.delimiter.join(log_msg) MB = 1024.0 * 1024.0 for obj in iterable: data_time.update(time.time() - end) yield obj iter_time.update(time.time() - end) if i % print_freq == 0 or i == len(iterable) - 1: eta_seconds = iter_time.global_avg * (len(iterable) - i) eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) if torch.cuda.is_available(): print(log_msg.format( i, len(iterable), eta=eta_string, meters=str(self), time=str(iter_time), data=str(data_time), memory=torch.cuda.max_memory_allocated() / MB)) else: print(log_msg.format( i, len(iterable), eta=eta_string, meters=str(self), time=str(iter_time), data=str(data_time))) i += 1 end = time.time() total_time = time.time() - start_time total_time_str = str(datetime.timedelta(seconds=int(total_time))) print('{} Total time: {} ({:.4f} s / it)'.format( header, total_time_str, total_time / len(iterable))) def _load_checkpoint_for_ema(model_ema, checkpoint): """ Workaround for ModelEma._load_checkpoint to accept an already-loaded object """ 
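# Serialize the already-loaded checkpoint object into an in-memory buffer so that
# ModelEma._load_checkpoint, which passes its argument to torch.load, can read it
# without needing a file on disk.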
mem_file = io.BytesIO() torch.save(checkpoint, mem_file) mem_file.seek(0) model_ema._load_checkpoint(mem_file) def setup_for_distributed(is_master): """ This function disables printing when not in master process """ import builtins as __builtin__ builtin_print = __builtin__.print def print(*args, **kwargs): force = kwargs.pop('force', False) if is_master or force: builtin_print(*args, **kwargs) __builtin__.print = print def is_dist_avail_and_initialized(): if not dist.is_available(): return False if not dist.is_initialized(): return False return True def get_world_size(): if not is_dist_avail_and_initialized(): return 1 return dist.get_world_size() def get_rank(): if not is_dist_avail_and_initialized(): return 0 return dist.get_rank() def is_main_process(): return get_rank() == 0 def save_on_master(*args, **kwargs): if is_main_process(): torch.save(*args, **kwargs) ================================================ FILE: utils/scaler.py ================================================ import torch from timm.utils import ApexScaler, NativeScaler, dispatch_clip_grad try: from apex import amp has_apex = True except ImportError: amp = None has_apex = False class ApexScaler_SAM(ApexScaler): def __call__(self, loss, optimizer, clip_grad=None, clip_mode='norm', parameters=None, create_graph=False, step=0, rho=0.05): with amp.scale_loss(loss, optimizer) as scaled_loss: scaled_loss.backward(create_graph=create_graph) if step==0 or step==2: if clip_grad is not None: dispatch_clip_grad(amp.master_params(optimizer), clip_grad, mode=clip_mode) optimizer.step() elif step==1: torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), rho, norm_type=2.0) optimizer.step() ================================================ FILE: utils/utils.py ================================================ # Copyright (c) 2021-2022, NVIDIA Corporation & Affiliates. All rights reserved. # # This work is made available under the Nvidia Source Code License-NC. # To view a copy of this license, visit # https://github.com/NVlabs/FAN/blob/main/LICENSE # Copyright (c) [2012]-[2021] Shanghai Yitu Technology Co., Ltd. # # This source code is licensed under the Clear BSD License # LICENSE file in the root directory of this file # All rights reserved. # Modified by: Daquan Zhou ''' - resize_pos_embed: resize position embedding - load_for_transfer_learning: load pretrained parameters to model in transfer learning - get_mean_and_std: calculate the mean and std value of dataset. - msr_init: net parameter initialization. - progress_bar: progress bar mimic xlua.progress. ''' import os import sys import time import torch import math import torch.nn as nn import torch.nn.init as init import logging import os from collections import OrderedDict import torch.nn.functional as F _logger = logging.getLogger(__name__) def resize_pos_embed(posemb, posemb_new): # example: 224:(14x14+1)-> 384: (24x24+1) # Rescale the grid of position embeddings when loading from state_dict.
Adapted from # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224 ntok_new = posemb_new.shape[1] if True: posemb_tok, posemb_grid = posemb[:, :1], posemb[0, 1:] # posemb_tok is for cls token, posemb_grid for the following tokens ntok_new -= 1 else: posemb_tok, posemb_grid = posemb[:, :0], posemb[0] gs_old = int(math.sqrt(len(posemb_grid))) # 14 gs_new = int(math.sqrt(ntok_new)) # 24 _logger.info('Position embedding grid-size from %s to %s', gs_old, gs_new) posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2) # [1, 196, dim]->[1, 14, 14, dim]->[1, dim, 14, 14] posemb_grid = F.interpolate(posemb_grid, size=(gs_new, gs_new), mode='bicubic') # [1, dim, 14, 14] -> [1, dim, 24, 24] posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new * gs_new, -1) # [1, dim, 24, 24] -> [1, 24*24, dim] posemb = torch.cat([posemb_tok, posemb_grid], dim=1) # [1, 24*24+1, dim] return posemb def resize_pos_embed_cait(posemb, posemb_new): # example: 224:(14x14+1)-> 384: (24x24+1) # Rescale the grid of position embeddings when loading from state_dict. Adapted from # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224 ntok_new = posemb_new.shape[1] posemb_grid = posemb[0] gs_old = int(math.sqrt(len(posemb_grid))) # 14 gs_new = int(math.sqrt(ntok_new)) # 24 _logger.info('Position embedding grid-size from %s to %s', gs_old, gs_new) posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2) # [1, 196, dim]->[1, 14, 14, dim]->[1, dim, 14, 14] posemb_grid = F.interpolate(posemb_grid, size=(gs_new, gs_new), mode='bicubic') # [1, dim, 14, 14] -> [1, dim, 24, 24] posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new * gs_new, -1) # [1, dim, 24, 24] -> [1, 24*24, dim] return posemb_grid def resize_pos_embed_nocls(posemb, posemb_new): # example: 224:(14x14+1)-> 384: (24x24+1) # Rescale the grid of position embeddings when loading from state_dict. 
Adapted from # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224 gs_old = posemb.shape[1] # 14 gs_new = posemb_new.shape[1] # 24 _logger.info('Position embedding grid-size from %s to %s', gs_old, gs_new) posemb_grid = posemb posemb_grid = posemb_grid.permute(0, 3, 1, 2) # [1, 14, 14, dim]->[1, dim, 14, 14] posemb_grid = F.interpolate(posemb_grid, size=(gs_new, gs_new), mode='bicubic') # [1, dim, 14, 14] -> [1, dim, 24, 24] posemb_grid = posemb_grid.permute(0, 2, 3, 1) # [1, dim, 24, 24]->[1, 24, 24, dim] return posemb_grid def load_state_dict(checkpoint_path,model, use_ema=False, num_classes=1000, no_pos_embed=False): if checkpoint_path and os.path.isfile(checkpoint_path): checkpoint = torch.load(checkpoint_path, map_location='cpu') state_dict_key = 'state_dict' if isinstance(checkpoint, dict): if use_ema and 'state_dict_ema' in checkpoint: state_dict_key = 'state_dict_ema' if state_dict_key and state_dict_key in checkpoint: new_state_dict = OrderedDict() for k, v in checkpoint[state_dict_key].items(): # strip `module.` prefix name = k[7:] if k.startswith('module') else k new_state_dict[name] = v state_dict = new_state_dict else: state_dict = checkpoint _logger.info("Loaded {} from checkpoint '{}'".format(state_dict_key, checkpoint_path)) if num_classes != 1000: # completely discard fully connected for all other differences between pretrained and created model del state_dict['head' + '.weight'] del state_dict['head' + '.bias'] old_aux_head_weight = state_dict.pop('aux_head.weight', None) old_aux_head_bias = state_dict.pop('aux_head.bias', None) if not no_pos_embed: old_posemb = state_dict['pos_embed'] if model.pos_embed.shape != old_posemb.shape: # need resize the position embedding by interpolate if len(old_posemb.shape)==3: if int(math.sqrt(old_posemb.shape[1]))**2==old_posemb.shape[1]: new_posemb = resize_pos_embed_cait(old_posemb, model.pos_embed) else: new_posemb = resize_pos_embed(old_posemb, model.pos_embed) elif len(old_posemb.shape)==4: new_posemb = resize_pos_embed_nocls(old_posemb, model.pos_embed) state_dict['pos_embed'] = new_posemb return state_dict else: _logger.error("No checkpoint found at '{}'".format(checkpoint_path)) raise FileNotFoundError() def load_for_transfer_learning(model, checkpoint_path, use_ema=False, strict=True, num_classes=1000): state_dict = load_state_dict(checkpoint_path, model, use_ema, num_classes) model.load_state_dict(state_dict, strict=strict) def load_for_probing(model, checkpoint_path, use_ema=False, strict=False, num_classes=19167): state_dict = load_state_dict(checkpoint_path, model, use_ema, num_classes=19167, no_pos_embed=True) info=model.load_state_dict(state_dict, strict=strict) print(info) def get_mean_and_std(dataset): '''Compute the mean and std value of dataset.''' dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2) mean = torch.zeros(3) std = torch.zeros(3) print('==> Computing mean and std..') for inputs, targets in dataloader: for i in range(3): mean[i] += inputs[:,i,:,:].mean() std[i] += inputs[:,i,:,:].std() mean.div_(len(dataset)) std.div_(len(dataset)) return mean, std def init_params(net): '''Init layer parameters.''' for m in net.modules(): if isinstance(m, nn.Conv2d): init.kaiming_normal(m.weight, mode='fan_out') if m.bias: init.constant(m.bias, 0) elif isinstance(m, nn.BatchNorm2d): init.constant(m.weight, 1) init.constant(m.bias, 0) elif isinstance(m, nn.Linear): init.normal(m.weight, std=1e-3) if m.bias: 
init.constant(m.bias, 0) ================================================ FILE: validate_ood.py ================================================ #!/usr/bin/env python3 # Copyright (c) 2021-2022, NVIDIA Corporation & Affiliates. All rights reserved. # # This work is made available under the Nvidia Source Code License-NC. # To view a copy of this license, visit # https://github.com/NVlabs/FAN/blob/main/LICENSE """ ImageNet Validation Script This is intended to be a lean and easily modifiable ImageNet validation script for evaluating pretrained models or training checkpoints against ImageNet or similarly organized image datasets. It prioritizes canonical PyTorch, standard Python style, and good performance. Repurpose as you see fit. Hacked together by Ross Wightman (https://github.com/rwightman) """ import argparse import errno import os import csv import glob import time import logging import torch import torch.nn as nn import torch.nn.parallel from collections import OrderedDict from contextlib import suppress from timm.models import create_model, apply_test_time_pool, resume_checkpoint, load_checkpoint, is_model, list_models from timm.data import resolve_data_config, RealLabelsImagenet from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging, set_jit_legacy import numpy as np from utils.imagenet_a import indices_in_1k from utils.imagenet_r import imagenet_r_mask from utils.mce_utils import get_ce_alexnet, get_mce_from_accuracy from data import create_loader, create_dataset from optim_factory import create_optimizer_v2, optimizer_kwargs from models import vision_transformer, swin_transformer, convnext has_apex = False try: from apex import amp has_apex = True except ImportError: pass has_native_amp = False try: if getattr(torch.cuda.amp, 'autocast') is not None: has_native_amp = True except AttributeError: pass torch.backends.cudnn.benchmark = True _logger = logging.getLogger('validate') parser = argparse.ArgumentParser(description='PyTorch ImageNet Validation') parser.add_argument('data', metavar='DIR', help='path to dataset') parser.add_argument('--dataset', '-d', metavar='NAME', default='', help='dataset type (default: ImageFolder/ImageTar if empty)') parser.add_argument('--split', metavar='NAME', default='validation', help='dataset split (default: validation)') parser.add_argument('--model', '-m', metavar='NAME', default='dpn92', help='model architecture (default: dpn92)') parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', help='number of data loading workers (default: 2)') parser.add_argument('-b', '--batch-size', default=256, type=int, metavar='N', help='mini-batch size (default: 256)') parser.add_argument('--img-size', default=None, type=int, metavar='N', help='Input image dimension, uses model default if empty') parser.add_argument('--input-size', default=None, nargs=3, type=int, metavar='N N N', help='Input all image dimensions (d h w, e.g. 
--input-size 3 224 224), uses model default if empty') parser.add_argument('--crop-pct', default=None, type=float, metavar='N', help='Input image center crop pct') parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN', help='Override mean pixel value of dataset') parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD', help='Override std deviation of of dataset') parser.add_argument('--interpolation', default='', type=str, metavar='NAME', help='Image resize interpolation type (overrides model)') parser.add_argument('--num-classes', type=int, default=None, help='Number classes in dataset') parser.add_argument('--class-map', default='', type=str, metavar='FILENAME', help='path to class to idx mapping file (default: "")') parser.add_argument('--gp', default=None, type=str, metavar='POOL', help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.') parser.add_argument('--log-freq', default=50, type=int, metavar='N', help='batch logging frequency (default: 10)') parser.add_argument('--checkpoint', default='', type=str, metavar='PATH', help='path to latest checkpoint (default: none)') parser.add_argument('--pretrained', dest='pretrained', action='store_true', help='use pre-trained model') parser.add_argument('--num-gpu', type=int, default=1, help='Number of GPUS to use') parser.add_argument('--no-test-pool', dest='no_test_pool', action='store_true', help='disable test time pool') parser.add_argument('--no-prefetcher', action='store_true', default=False, help='disable fast prefetcher') parser.add_argument('--pin-mem', action='store_true', default=False, help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') parser.add_argument('--channels-last', action='store_true', default=False, help='Use channels_last memory layout') parser.add_argument('--amp', action='store_true', default=False, help='Use AMP mixed precision. 
Defaults to Apex, fallback to native Torch AMP.') parser.add_argument('--apex-amp', action='store_true', default=False, help='Use NVIDIA Apex AMP mixed precision') parser.add_argument('--native-amp', action='store_true', default=False, help='Use Native Torch AMP mixed precision') parser.add_argument('--tf-preprocessing', action='store_true', default=False, help='Use Tensorflow preprocessing pipeline (require CPU TF installed') parser.add_argument('--use-ema', dest='use_ema', action='store_true', help='use ema version of weights if present') parser.add_argument('--torchscript', dest='torchscript', action='store_true', help='convert model torchscript for inference') parser.add_argument('--legacy-jit', dest='legacy_jit', action='store_true', help='use legacy jit mode for pytorch 1.5/1.5.1/1.6 to get back fusion performance') parser.add_argument('--results-file', default='', type=str, metavar='FILENAME', help='Output csv file for validation results (summary)') parser.add_argument('--real-labels', default='', type=str, metavar='FILENAME', help='Real labels JSON file for imagenet evaluation') parser.add_argument('--valid-labels', default='', type=str, metavar='FILENAME', help='Valid label indices txt file for validation of partial label space') # finetuning parser.add_argument('--tuning-mode', default=None, type=str, help='Method of fine-tuning (default: None') parser.add_argument('--num-vpt', default=None, type=int, help='The number of prompts in VPT') parser.add_argument('--imagenet_a', action='store_true', default=False, help='replace labels from 1k to 200') parser.add_argument('--imagenet_r', action='store_true', default=False, help='replace labels from 1k to imagenet-r indices') parser.add_argument('--imagenet_c', action='store_true', default=False, help='use corrupted dataset for evaluation') def validate(args): args.pretrained = args.pretrained or not args.checkpoint args.prefetcher = not args.no_prefetcher amp_autocast = suppress # do nothing if args.amp: if has_native_amp: args.native_amp = True elif has_apex: args.apex_amp = True else: _logger.warning("Neither APEX or Native Torch AMP is available.") assert not args.apex_amp or not args.native_amp, "Only one AMP mode should be set." if args.native_amp: amp_autocast = torch.cuda.amp.autocast _logger.info('Validating in mixed precision with native PyTorch AMP.') elif args.apex_amp: _logger.info('Validating in mixed precision with NVIDIA APEX AMP.') else: _logger.info('Validating in float32. AMP not enabled.') if args.legacy_jit: set_jit_legacy() model = create_model( args.model, pretrained=args.pretrained, num_classes=args.num_classes, global_pool=args.gp, scriptable=args.torchscript, tuning_mode=args.tuning_mode, num_vpt=args.num_vpt ) if args.num_classes is None: assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.' 
args.num_classes = model.num_classes if args.checkpoint: resume_epoch = resume_checkpoint( model, args.checkpoint ) param_count = sum([m.numel() for m in model.parameters()]) _logger.info('Model %s created, param count: %d' % (args.model, param_count)) data_config = resolve_data_config(vars(args), model=model, use_test_size=True) test_time_pool = False if not args.no_test_pool: model, test_time_pool = apply_test_time_pool(model, data_config, use_test_size=True) if args.torchscript: torch.jit.optimized_execution(True) model = torch.jit.script(model) model = model.cuda() if args.apex_amp: model = amp.initialize(model, opt_level='O1') if args.channels_last: model = model.to(memory_format=torch.channels_last) if args.num_gpu > 1: model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu))) criterion = nn.CrossEntropyLoss().cuda() print(args.data) dataset = create_dataset( root=args.data, name=args.dataset, split=args.split, load_bytes=args.tf_preprocessing, class_map=args.class_map) if args.valid_labels: with open(args.valid_labels, 'r') as f: valid_labels = {int(line.rstrip()) for line in f} valid_labels = [i in valid_labels for i in range(args.num_classes)] else: valid_labels = None if args.real_labels: real_labels = RealLabelsImagenet(dataset.filenames(basename=True), real_json=args.real_labels) else: real_labels = None crop_pct = 1.0 if test_time_pool else data_config['crop_pct'] loader = create_loader( dataset, input_size=data_config['input_size'], batch_size=args.batch_size, use_prefetcher=args.prefetcher, interpolation=data_config['interpolation'], mean=data_config['mean'], std=data_config['std'], num_workers=args.workers, crop_pct=crop_pct, pin_memory=args.pin_mem, tf_preprocessing=args.tf_preprocessing) batch_time = AverageMeter() losses = AverageMeter() top1 = AverageMeter() top5 = AverageMeter() model.eval() with torch.no_grad(): input = torch.randn((args.batch_size,) + data_config['input_size']).cuda() if args.channels_last: input = input.contiguous(memory_format=torch.channels_last) end = time.time() for batch_idx, (input, target) in enumerate(loader): if args.no_prefetcher: target = target.cuda() input = input.cuda() if args.channels_last: input = input.contiguous(memory_format=torch.channels_last) # compute output with amp_autocast(): output = model(input) if args.imagenet_a: output = output[:, indices_in_1k] if args.imagenet_r: output = output[:, imagenet_r_mask] if isinstance(output, (tuple, list)): output = output[0] if valid_labels is not None: output = output[:, valid_labels] loss = criterion(output, target) if real_labels is not None: real_labels.add_result(output) # measure accuracy and record loss acc1, acc5 = accuracy(output.detach(), target, topk=(1, 5)) losses.update(loss.item(), input.size(0)) top1.update(acc1.item(), input.size(0)) top5.update(acc5.item(), input.size(0)) # measure elapsed time batch_time.update(time.time() - end) end = time.time() if batch_idx % args.log_freq == 0: _logger.info( 'Test: [{0:>4d}/{1}] ' 'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) ' 'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) ' 'Acc@1: {top1.val:>7.3f} ({top1.avg:>7.3f}) ' 'Acc@5: {top5.val:>7.3f} ({top5.avg:>7.3f})'.format( batch_idx, len(loader), batch_time=batch_time, rate_avg=input.size(0) / batch_time.avg, loss=losses, top1=top1, top5=top5)) if real_labels is not None: # real labels mode replaces topk values at the end top1a, top5a = real_labels.get_accuracy(k=1), real_labels.get_accuracy(k=5) else: top1a, top5a = top1.avg, top5.avg 
results = OrderedDict( top1=round(top1a, 4), top1_err=round(100 - top1a, 4), top5=round(top5a, 4), top5_err=round(100 - top5a, 4), param_count=round(param_count / 1e6, 2), img_size=data_config['input_size'][-1], crop_pct=crop_pct, interpolation=data_config['interpolation']) _logger.info(' * Acc@1 {:.3f} ({:.3f}) Acc@5 {:.3f} ({:.3f})'.format( results['top1'], results['top1_err'], results['top5'], results['top5_err'])) return results def main(): setup_default_logging() args = parser.parse_args() if not args.imagenet_c: if args.imagenet_a or args.imagenet_r: validate(args) else: print('Please specify an OOD dataset.') return else: results_file = args.results_file or './results-all.csv' os.makedirs(os.path.dirname(results_file) or '.', exist_ok=True) blur_list = ['gaussian_blur', 'motion_blur', 'glass_blur', 'defocus_blur'] noise_list = ['gaussian_noise', 'shot_noise', 'speckle_noise', 'impulse_noise'] digital_list = ['contrast', 'jpeg_compression', 'saturate', 'pixelate'] weather_list = ['snow', 'fog', 'frost', 'spatter', 'brightness'] extra = ['zoom_blur', 'elastic_transform'] name_list = noise_list + extra + blur_list + digital_list + weather_list ce_alexnet = get_ce_alexnet() mCE = 0 counter = 0 average_acc = {} base_dir = args.data for noise_name in name_list: res_sum = 0 root = base_dir + noise_name + '/' results = [] for i in range(0, 5): args.data = root + str(i+1) print('validating dir:', args.data) res = validate(args) results.append(res['top1']) res_sum += res['top1'] if noise_name in ce_alexnet.keys(): CE = get_mce_from_accuracy(res['top1'], ce_alexnet[noise_name]) mCE += CE counter += 1 results.append(res_sum/(i+1)) average_acc[noise_name] = res_sum/(i+1) np.savetxt(results_file + noise_name + '_' + '%.2f' % (res_sum/(i+1)) + '.csv', results) print('average score is:', res_sum / (i+1)) print('current mCE is: ', mCE/counter) np.savetxt(results_file + 'mCE' + '_' + '%.2f' % (mCE/counter) + '.csv', results) print('all average score is:', average_acc) print('mCE is: ', mCE/counter) def write_results(results_file, results): with open(results_file, mode='w') as cf: dw = csv.DictWriter(cf, fieldnames=results[0].keys()) dw.writeheader() for r in results: dw.writerow(r) cf.flush() if __name__ == '__main__': main()
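Note that main() above accumulates one Corruption Error per severity level; because the AlexNet reference error is a constant for each corruption, this is equivalent to computing the CE from the severity-averaged top-1 accuracy for every corruption that has an AlexNet reference. A minimal standalone sketch of that aggregation follows; the accuracy values are placeholders for illustration, not results produced by this repo.

```python
# Minimal sketch of the mCE aggregation performed in validate_ood.main().
# The top-1 accuracies below are placeholder numbers, not measured results.
from utils.mce_utils import get_ce_alexnet, get_mce_from_accuracy

# hypothetical severity-averaged top-1 accuracy per corruption
mean_top1 = {'gaussian_noise': 55.0, 'motion_blur': 60.0, 'fog': 62.5}

ce_alexnet = get_ce_alexnet()
ces = [get_mce_from_accuracy(acc, ce_alexnet[name]) for name, acc in mean_top1.items()]
print('mCE: %.4f' % (sum(ces) / len(ces)))
```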