Repository: fmu2/PyTorch-MAML
Branch: master
Commit: eee3fe3da538
Files: 40
Total size: 83.4 KB

Directory structure:
gitextract_xgy6p1l2/
├── README.md
├── configs/
│   └── convnet4/
│       ├── mini-imagenet/
│       │   ├── 5_way_1_shot/
│       │   │   ├── test_reproduce.yaml
│       │   │   ├── test_template.yaml
│       │   │   ├── train_reproduce.yaml
│       │   │   └── train_template.yaml
│       │   └── 5_way_5_shot/
│       │       ├── test_reproduce.yaml
│       │       ├── test_template.yaml
│       │       ├── train_reproduce.yaml
│       │       └── train_template.yaml
│       └── tiered-imagenet/
│           ├── 5_way_1_shot/
│           │   ├── test_reproduce.yaml
│           │   ├── test_template.yaml
│           │   ├── train_reproduce.yaml
│           │   └── train_template.yaml
│           └── 5_way_5_shot/
│               ├── test_reproduce.yaml
│               ├── test_template.yaml
│               ├── train_reproduce.yaml
│               └── train_template.yaml
├── datasets/
│   ├── __init__.py
│   ├── cifar100.py
│   ├── cub200.py
│   ├── datasets.py
│   ├── inatural.py
│   ├── mini_imagenet.py
│   ├── tiered_imagenet.py
│   └── transforms.py
├── models/
│   ├── __init__.py
│   ├── classifiers/
│   │   ├── __init__.py
│   │   ├── classifiers.py
│   │   └── logistic.py
│   ├── encoders/
│   │   ├── __init__.py
│   │   ├── convnet4.py
│   │   ├── encoders.py
│   │   ├── resnet12.py
│   │   └── resnet18.py
│   ├── maml.py
│   └── modules.py
├── test.py
├── train.py
└── utils/
    ├── __init__.py
    └── optimizers.py

================================================
FILE CONTENTS
================================================

================================================
FILE: README.md
================================================
# MAML in PyTorch - Re-implementation and Beyond

A PyTorch implementation of [Model Agnostic Meta-Learning (MAML)](https://arxiv.org/abs/1703.03400). We faithfully reproduce [the official TensorFlow implementation](https://github.com/cbfinn/maml) while incorporating a number of additional features that may ease further study of this very high-profile meta-learning framework.
## Overview

This repository contains code for training and evaluating MAML on the mini-ImageNet and tiered-ImageNet datasets most commonly used for few-shot image classification. To the best of our knowledge, this is the only PyTorch implementation of MAML to date that **fully reproduces the results in the original paper** without applying tricks such as data augmentation, evaluation on multiple crops, or ensembles of multiple models. Other existing PyTorch implementations typically see a ~3% gap in accuracy for the 5-way-1-shot and 5-way-5-shot classification tasks on mini-ImageNet.

Beyond reproducing the results, our implementation comes with a few extra bits that we believe can be helpful for further development of the framework. We highlight the improvements we have built into our code, and discuss our observations that warrant some attention.

## Implementation Highlights

- **Batch normalization with per-episode running statistics.** Our implementation provides the flexibility of tracking global and/or per-episode running statistics, hence supporting both transductive and inductive inference.
- **Better data pre-processing.** The official implementation does not normalize or augment data. We support data normalization and a variety of data augmentation techniques. We also implement data batching and support/query-set splitting more efficiently.
- **More datasets.** We support mini-ImageNet, tiered-ImageNet and more.
- **More options for outer-loop optimization.** We support multiple optimizers and learning-rate schedulers for the outer-loop optimization.
- **More powerful inner-loop optimization.** The official implementation uses vanilla gradient descent in the inner loop. We additionally support momentum and weight decay.
- **More options for encoder architecture.** We support the standard four-layer ConvNet as well as ResNet-12 and ResNet-18 as the encoder.
- **Easy layer freezing.** We provide an interface for layer freezing experiments.
One may freeze an arbitrary set of layers or blocks during inner-loop adaptation.
- **Meta-learning with a zero-initialized classifier head.** The official implementation learns a meta-initialization for both the encoder and the classifier head. This prevents one from varying the number of categories at training or test time. With our implementation, one may opt to learn a meta-initialization for the encoder while initializing the classifier head at zero.
- **Distributed training and gradient checkpointing.** MAML is very memory-intensive because it buffers all tensors generated throughout the inner-loop adaptation steps. Gradient checkpointing trades compute for memory, effectively bringing the memory cost from O(N) down to O(1), where N is the number of inner-loop steps. In our experiments, gradient checkpointing saved up to 80% of GPU memory at the cost of running the forward pass more than once (a moderate 20% increase in running time).

## Transductive or Inductive?

The official implementation assumes transductive learning. The batch normalization layers do not track running statistics at training time, and they use mini-batch statistics at test time. The implicit assumption here is that test data come in mini-batches and are perhaps balanced across categories. This is a very restrictive assumption and does not make MAML directly comparable with the vast majority of meta-learning and few-shot learning methods. Unfortunately, this is not immediately obvious from the paper, and our findings suggest that the performance of MAML is hugely overestimated.

- **Accuracy is very sensitive to the size of the query set in the transductive setting.** For example, the result for 5-way-1-shot classification on mini-ImageNet from the paper (48.70%) was obtained with five queries, one per category. We found that the accuracy dropped by ~1.5% given five queries per category, and by ~2.5% given 15 queries per category.
- The paper reports mean accuracy over 600 independently sampled tasks, or trials. We found that **600 trials, again in the transductive setting, are insufficient for an unbiased estimate of model performance**. The mean accuracy from 6,000 trials is more stable, and is always ~2% lower than that from the first 600 trials. We conjecture that the distribution of per-trial accuracy is highly skewed towards the high end.
- We found that **MAML performs a lot worse in the inductive setting**. Given the same model configuration, inductive accuracy is always much lower (~4%) than the *corrected* transductive accuracy, which is already a few percentage points behind the reported number.

Hence, one should be extremely cautious when comparing MAML with its competitors, as is evident from the discussion above.

## FOMAML and layer freezing

Unfortunately, some insights discussed in the original paper and its follow-up works do not appear to hold in the inductive setting.

- FOMAML (i.e., the first-order approximation of MAML) performs as well as MAML in transductive learning, but fails completely in the inductive setting.
- Completely freezing the encoder during inner-loop adaptation, as was done in [this work](https://arxiv.org/abs/1909.09157), results in a dramatic decrease in accuracy.

## BatchNorm and TaskNorm

[A recent work](https://arxiv.org/abs/2003.03284) proposes TaskNorm, a test-time enhancement of batch normalization, noting that the small batch sizes during training may leave batch normalization less effective. We did not have much success with this method. We observed marginal improvement most of the time, and found that it hurts performance occasionally. That said, we do believe that batch normalization is hard to deal with in MAML.
TaskNorm attempts to attack the problem of small batch sizes, which we conjecture is just one of the three main causes (i.e., extremely scarce training data, extremely small batch sizes, and an extremely small number of inner-loop updates) of the ineffectiveness of batch normalization in MAML.

## Quick Start

### 0. Preliminaries

**Environment**

- Python 3.6.8 (or any Python 3 distribution)
- PyTorch 1.3.1 (or any PyTorch > 1.0)
- tensorboardX

**Datasets**

Please follow the download links [here](https://github.com/cyvius96/few-shot-meta-baseline). Please modify the file names accordingly so that they can be recognized by the data loaders.

**Configurations**

Template configuration files, as well as those for reproducing the results in the original paper, can be found in `configs/`. The hyperparameters are self-explanatory.

### 1. Training MAML

Here is the command for single-GPU training of MAML with a ConvNet4 backbone for 5-way-1-shot classification on mini-ImageNet, reproducing the result in the original paper.

```
python train.py --config=configs/convnet4/mini-imagenet/5_way_1_shot/train_reproduce.yaml
```

Use `--gpu` to specify available GPUs for multi-GPU training. For example,

```
python train.py --config=configs/convnet4/mini-imagenet/5_way_1_shot/train_reproduce.yaml --gpu=0,1
```

Add `--efficient` to enable gradient checkpointing. This aggressively saves GPU memory while slightly increasing running time.

```
python train.py --config=configs/convnet4/mini-imagenet/5_way_1_shot/train_reproduce.yaml --efficient
```

Use `--tag` to customize the name of the directory where the checkpoints and log files are saved.

### 2. Testing MAML

Here is how one would test MAML for 5-way-1-shot classification on mini-ImageNet to reproduce the result in the original paper. Please confirm the loading path first.

```
python test.py --config=configs/convnet4/mini-imagenet/5_way_1_shot/test_reproduce.yaml
```

The `--gpu` and `--efficient` flags function the same as in training.
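As a concrete illustration of the "more powerful inner-loop optimization" highlight above (the `momentum` and `weight_decay` fields in `inner_args`), here is a minimal sketch of an inner-loop update rule with momentum and weight decay, applied to a toy one-dimensional quadratic. The function name and signature are illustrative assumptions, not this repository's actual interface.

```python
# Illustrative sketch (not this repository's API): the inner-loop update
# rule with optional momentum and weight decay, shown on a toy 1-D
# quadratic loss L(w) = 0.5 * (w - 3)^2, whose gradient is (w - 3).

def inner_adapt(w, grad_fn, n_step=5, lr=0.01, momentum=0.0, weight_decay=0.0):
    """Run n_step inner-loop updates on a scalar parameter w."""
    buf = 0.0  # momentum buffer, as in SGD with heavy-ball momentum
    for _ in range(n_step):
        g = grad_fn(w) + weight_decay * w  # L2 penalty folded into the gradient
        buf = momentum * buf + g
        w = w - lr * buf
    return w

grad = lambda w: w - 3.0  # d/dw of 0.5 * (w - 3)^2

w_plain = inner_adapt(0.0, grad, n_step=100, lr=0.1)                  # vanilla GD
w_mom = inner_adapt(0.0, grad, n_step=100, lr=0.1, momentum=0.9)      # + momentum
w_wd = inner_adapt(0.0, grad, n_step=100, lr=0.1, weight_decay=0.01)  # + weight decay
```

Vanilla gradient descent and the momentum variant both converge to the minimizer w = 3, while weight decay biases the solution slightly toward zero (fixed point 3 / 1.01 ≈ 2.97). In actual MAML, `w` would be the encoder and classifier parameters and `grad_fn` the support-set loss gradient.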
## Contact

[Xinchan Zhu](https://www.linkedin.com/in/xinchan-zhu-66673b106) (zhuxinchan@gmail.com)

## Cite our Repository

```
@misc{pytorch_maml,
  title={maml in pytorch - re-implementation and beyond},
  author={Zhu, Xinchan},
  howpublished={\url{https://github.com/shirleyzhu233/PyTorch-MAML}},
  year={2020}
}
```

## Related Code Repositories

Our implementation is inspired by the following repositories.

* maml (the official implementation)
* MAML-Pytorch
* HowToTrainYourMAMLPytorch
* memory-efficient-maml

## References

```
@inproceedings{finn2017model,
  title={Model-agnostic meta-learning for fast adaptation of deep networks},
  author={Finn, Chelsea and Abbeel, Pieter and Levine, Sergey},
  booktitle={International Conference on Machine Learning (ICML)},
  year={2017}
}

@inproceedings{raghu2019rapid,
  title={Rapid learning or feature reuse? towards understanding the effectiveness of maml},
  author={Raghu, Aniruddh and Raghu, Maithra and Bengio, Samy and Vinyals, Oriol},
  booktitle={International Conference on Learning Representations (ICLR)},
  year={2019}
}

@article{Bronskill2020tasknorm,
  title={Tasknorm: rethinking batch normalization for meta-learning},
  author={Bronskill, John and Gordon, Jonathan and Requeima, James and Nowozin, Sebastian and Turner, Richard E.},
  journal={arXiv preprint arXiv:2003.03284},
  year={2020}
}
```

================================================
FILE: configs/convnet4/mini-imagenet/5_way_1_shot/test_reproduce.yaml
================================================
dataset: meta-mini-imagenet
test:
  split: meta-test
  image_size: 84
  normalization: False
  transform: null
  n_batch: 150
  n_episode: 4
  n_way: 5
  n_shot: 1
  n_query: 1

load: ./save/convnet4_mini-imagenet_5_way_1_shot/max-va.pth

inner_args:
  n_step: 10
  encoder_lr: 0.01
  classifier_lr: 0.01
  first_order: False  # set to True for FOMAML
  frozen:
    - bn

epoch: 1

================================================
FILE: configs/convnet4/mini-imagenet/5_way_1_shot/test_template.yaml
================================================
dataset: meta-mini-imagenet
test:
  split: meta-test
  image_size: 84
  normalization: True
  transform: flip
  n_batch: 200
  n_episode: 4
  n_way: 5
  n_shot: 1
  n_query: 15

load: ./save/convnet4_mini-imagenet_5_way_1_shot/max-va.pth

inner_args:
  reset_classifier: True
  n_step: 5
  encoder_lr: 0.01
  classifier_lr: 0.01
  momentum: 0.9
  weight_decay: 5.e-4
  first_order: False

epoch: 10

================================================
FILE: configs/convnet4/mini-imagenet/5_way_1_shot/train_reproduce.yaml
================================================
dataset: meta-mini-imagenet
train:
  split: meta-train
  image_size: 84
  normalization: False
  transform: null
  n_batch: 200
  n_episode: 4
  n_way: 5
  n_shot: 1
  n_query: 15
val:
  split: meta-val
  image_size: 84
  normalization: False
  transform: null
  n_batch: 200
  n_episode: 4
  n_way: 5
  n_shot: 1
  n_query: 15

encoder: convnet4
encoder_args:
  bn_args:
    track_running_stats: False
classifier: logistic

inner_args:
  n_step: 5
  encoder_lr: 0.01
  classifier_lr: 0.01
  first_order: False  # set to True for FOMAML
  frozen:
    - bn

optimizer: adam
optimizer_args:
  lr: 0.001
epoch: 300

================================================
FILE: configs/convnet4/mini-imagenet/5_way_1_shot/train_template.yaml
================================================
dataset: meta-mini-imagenet
train:
  split: meta-train
  image_size: 84
  normalization: True
  transform: flip
  n_batch: 200
  n_episode: 4
  n_way: 5
  n_shot: 1
  n_query: 15
val:
  split: meta-val
  image_size: 84
  normalization: True
  transform: flip
  n_batch: 200
  n_episode: 4
  n_way: 5
  n_shot: 1
  n_query: 15

encoder: convnet4
encoder_args:
  bn_args:
    track_running_stats: True
    episodic:
      - conv1
      - conv2
      - conv3
      - conv4
classifier: logistic

inner_args:
  reset_classifier: True
  n_step: 5
  encoder_lr: 0.01
  classifier_lr: 0.01
  momentum: 0.9
  weight_decay: 5.e-4
  first_order: False

optimizer: sgd
optimizer_args:
  lr: 0.01
  weight_decay: 5.e-4
schedule: step
milestones:
  - 120
  - 140
epoch: 150
================================================
FILE: configs/convnet4/mini-imagenet/5_way_5_shot/test_reproduce.yaml
================================================
dataset: meta-mini-imagenet
test:
  split: meta-test
  image_size: 84
  normalization: False
  transform: null
  n_batch: 150
  n_episode: 4
  n_way: 5
  n_shot: 5
  n_query: 5

load: ./save/convnet4_mini-imagenet_5_way_5_shot/max-va.pth

inner_args:
  n_step: 10
  encoder_lr: 0.01
  classifier_lr: 0.01
  first_order: False
  frozen:
    - bn

epoch: 1

================================================
FILE: configs/convnet4/mini-imagenet/5_way_5_shot/test_template.yaml
================================================
dataset: meta-mini-imagenet
test:
  split: meta-test
  image_size: 84
  normalization: True
  transform: flip
  n_batch: 200
  n_episode: 4
  n_way: 5
  n_shot: 5
  n_query: 15

load: ./save/convnet4_mini-imagenet_5_way_5_shot/max-va.pth

inner_args:
  reset_classifier: True
  n_step: 5
  encoder_lr: 0.01
  classifier_lr: 0.01
  momentum: 0.9
  weight_decay: 5.e-4
  first_order: False

epoch: 10

================================================
FILE: configs/convnet4/mini-imagenet/5_way_5_shot/train_reproduce.yaml
================================================
dataset: meta-mini-imagenet
train:
  split: meta-train
  image_size: 84
  normalization: False
  transform: null
  n_batch: 200
  n_episode: 4
  n_way: 5
  n_shot: 5
  n_query: 15
val:
  split: meta-val
  image_size: 84
  normalization: False
  transform: null
  n_batch: 200
  n_episode: 4
  n_way: 5
  n_shot: 5
  n_query: 15

encoder: convnet4
encoder_args:
  bn_args:
    track_running_stats: False
classifier: logistic

inner_args:
  n_step: 5
  encoder_lr: 0.01
  classifier_lr: 0.01
  first_order: False
  frozen:
    - bn

optimizer: adam
optimizer_args:
  lr: 0.001
epoch: 300

================================================
FILE: configs/convnet4/mini-imagenet/5_way_5_shot/train_template.yaml
================================================
dataset: meta-mini-imagenet
train:
  split: meta-train
  image_size: 84
  normalization: True
  transform: flip
  n_batch: 200
  n_episode: 4
  n_way: 5
  n_shot: 5
  n_query: 15
val:
  split: meta-val
  image_size: 84
  normalization: True
  transform: flip
  n_batch: 200
  n_episode: 4
  n_way: 5
  n_shot: 5
  n_query: 15

encoder: convnet4
encoder_args:
  bn_args:
    track_running_stats: True
    episodic:
      - conv1
      - conv2
      - conv3
      - conv4
classifier: logistic

inner_args:
  reset_classifier: True
  n_step: 5
  encoder_lr: 0.01
  classifier_lr: 0.01
  momentum: 0.9
  weight_decay: 5.e-4
  first_order: False

optimizer: sgd
optimizer_args:
  lr: 0.01
  weight_decay: 5.e-4
schedule: step
milestones:
  - 120
  - 140
epoch: 150

================================================
FILE: configs/convnet4/tiered-imagenet/5_way_1_shot/test_reproduce.yaml
================================================
dataset: meta-tiered-imagenet
test:
  split: meta-test
  image_size: 84
  normalization: False
  transform: null
  n_batch: 150
  n_episode: 4
  n_way: 5
  n_shot: 1
  n_query: 1

load: ./save/wide-convnet4_tiered-imagenet_5_way_1_shot/max-va.pth

inner_args:
  n_step: 10
  encoder_lr: 0.01
  classifier_lr: 0.01
  first_order: False  # set to True for FOMAML
  frozen:
    - bn

epoch: 1

================================================
FILE: configs/convnet4/tiered-imagenet/5_way_1_shot/test_template.yaml
================================================
dataset: meta-tiered-imagenet
test:
  split: meta-test
  image_size: 84
  normalization: True
  transform: flip
  n_batch: 200
  n_episode: 4
  n_way: 5
  n_shot: 1
  n_query: 15

load: ./save/wide-convnet4_tiered-imagenet_5_way_1_shot/max-va.pth

inner_args:
  reset_classifier: True
  n_step: 5
  encoder_lr: 0.01
  classifier_lr: 0.01
  momentum: 0.9
  weight_decay: 5.e-4
  first_order: False

epoch: 10

================================================
FILE: configs/convnet4/tiered-imagenet/5_way_1_shot/train_reproduce.yaml
================================================
dataset: meta-tiered-imagenet
train:
  split: meta-train
  image_size: 84
  normalization: False
  transform: null
  n_batch: 200
  n_episode: 4
  n_way: 5
  n_shot: 1
  n_query: 15
val:
  split: meta-val
  image_size: 84
  normalization: False
  transform: null
  n_batch: 200
  n_episode: 4
  n_way: 5
  n_shot: 1
  n_query: 15

encoder: wide-convnet4
encoder_args:
  bn_args:
    track_running_stats: False
classifier: logistic

inner_args:
  n_step: 5
  encoder_lr: 0.01
  classifier_lr: 0.01
  first_order: False  # set to True for FOMAML
  frozen:
    - bn

optimizer: adam
optimizer_args:
  lr: 0.001
epoch: 300

================================================
FILE: configs/convnet4/tiered-imagenet/5_way_1_shot/train_template.yaml
================================================
dataset: meta-tiered-imagenet
train:
  split: meta-train
  image_size: 84
  normalization: True
  transform: flip
  n_batch: 200
  n_episode: 4
  n_way: 5
  n_shot: 1
  n_query: 15
val:
  split: meta-val
  image_size: 84
  normalization: True
  transform: flip
  n_batch: 200
  n_episode: 4
  n_way: 5
  n_shot: 1
  n_query: 15

encoder: wide-convnet4
encoder_args:
  bn_args:
    track_running_stats: True
    episodic:
      - conv1
      - conv2
      - conv3
      - conv4
classifier: logistic

inner_args:
  reset_classifier: True
  n_step: 5
  encoder_lr: 0.01
  classifier_lr: 0.01
  momentum: 0.9
  weight_decay: 5.e-4
  first_order: False

optimizer: sgd
optimizer_args:
  lr: 0.01
  weight_decay: 5.e-4
schedule: step
milestones:
  - 120
  - 140
epoch: 150

================================================
FILE: configs/convnet4/tiered-imagenet/5_way_5_shot/test_reproduce.yaml
================================================
dataset: meta-tiered-imagenet
test:
  split: meta-test
  image_size: 84
  normalization: False
  transform: null
  n_batch: 150
  n_episode: 4
  n_way: 5
  n_shot: 5
  n_query: 5

load: ./save/wide-convnet4_tiered-imagenet_5_way_5_shot/max-va.pth

inner_args:
  n_step: 10
  encoder_lr: 0.01
  classifier_lr: 0.01
  first_order: False  # set to True for FOMAML
  frozen:
    - bn

epoch: 1

================================================
FILE: configs/convnet4/tiered-imagenet/5_way_5_shot/test_template.yaml
================================================
dataset: meta-tiered-imagenet
test:
  split: meta-test
  image_size: 84
  normalization: True
  transform: flip
  n_batch: 200
  n_episode: 4
  n_way: 5
  n_shot: 5
  n_query: 15

load: ./save/wide-convnet4_tiered-imagenet_5_way_5_shot/max-va.pth

inner_args:
  reset_classifier: True
  n_step: 5
  encoder_lr: 0.01
  classifier_lr: 0.01
  momentum: 0.9
  weight_decay: 5.e-4
  first_order: False

epoch: 10

================================================
FILE: configs/convnet4/tiered-imagenet/5_way_5_shot/train_reproduce.yaml
================================================
dataset: meta-tiered-imagenet
train:
  split: meta-train
  image_size: 84
  normalization: False
  transform: null
  n_batch: 200
  n_episode: 4
  n_way: 5
  n_shot: 5
  n_query: 15
val:
  split: meta-val
  image_size: 84
  normalization: False
  transform: null
  n_batch: 200
  n_episode: 4
  n_way: 5
  n_shot: 5
  n_query: 15

encoder: wide-convnet4
encoder_args:
  bn_args:
    track_running_stats: False
classifier: logistic

inner_args:
  n_step: 5
  encoder_lr: 0.01
  classifier_lr: 0.01
  first_order: False  # set to True for FOMAML
  frozen:
    - bn

optimizer: adam
optimizer_args:
  lr: 0.001
epoch: 300

================================================
FILE: configs/convnet4/tiered-imagenet/5_way_5_shot/train_template.yaml
================================================
dataset: meta-tiered-imagenet
train:
  split: meta-train
  image_size: 84
  normalization: True
  transform: flip
  n_batch: 200
  n_episode: 4
  n_way: 5
  n_shot: 5
  n_query: 15
val:
  split: meta-val
  image_size: 84
  normalization: True
  transform: flip
  n_batch: 200
  n_episode: 4
  n_way: 5
  n_shot: 5
  n_query: 15

encoder: wide-convnet4
encoder_args:
  bn_args:
    track_running_stats: True
    episodic:
      - conv1
      - conv2
      - conv3
      - conv4
classifier: logistic

inner_args:
  reset_classifier: True
  n_step: 5
  encoder_lr: 0.01
  classifier_lr: 0.01
  momentum: 0.9
  weight_decay: 5.e-4
  first_order: False

optimizer: sgd
optimizer_args:
  lr: 0.01
  weight_decay: 5.e-4
schedule: step
milestones:
  - 120
  - 140
epoch: 150

================================================
FILE: datasets/__init__.py
================================================
from .datasets import make, collate_fn
from . import mini_imagenet
from . import tiered_imagenet
from . import cifar100
from . import cub200
from . import inatural
from . import transforms

================================================
FILE: datasets/cifar100.py
================================================
import os
import pickle

import torch
from torch.utils.data import Dataset
import numpy as np
from PIL import Image

from .datasets import register
from .transforms import get_transform


class Cifar100(Dataset):

  def __init__(self, root_path, split='train', image_size=32,
               normalization=True, transform=None):
    super(Cifar100, self).__init__()

    split_dict = {'train': 'train',             # standard train
                  'trainval': 'trainval',       # standard train + val
                  'meta-train': 'train',        # meta-train
                  'meta-val': 'val',            # meta-val
                  'meta-trainval': 'trainval',  # meta-train + meta-val
                  'meta-test': 'test',          # meta-test
                  }
    split_tag = split_dict[split]

    split_file = os.path.join(root_path, split_tag + '.pickle')
    assert os.path.isfile(split_file)
    with open(split_file, 'rb') as f:
      pack = pickle.load(f, encoding='latin1')
    data, label = pack['data'], pack['labels']

    data = [Image.fromarray(x) for x in data]
    label = np.array(label)
    label_key = sorted(np.unique(label))
    label_map = dict(zip(label_key, range(len(label_key))))
    new_label = np.array([label_map[x] for x in label])

    self.root_path = root_path
    self.split_tag = split_tag
    self.image_size = image_size

    self.data = data
    self.label = new_label
    self.n_classes = len(label_key)

    if normalization:
      self.norm_params = {'mean': [0.5071, 0.4867, 0.4408],
                          'std': [0.2675, 0.2565, 0.2761]}
    else:
      self.norm_params = {'mean': [0., 0., 0.],
                          'std': [1., 1., 1.]}
    self.transform = get_transform(transform, image_size, self.norm_params)

    def convert_raw(x):
      mean = torch.tensor(self.norm_params['mean']).view(3, 1, 1).type_as(x)
      std = torch.tensor(self.norm_params['std']).view(3, 1, 1).type_as(x)
      return x * std + mean
    self.convert_raw = convert_raw

  def __len__(self):
    return len(self.data)

  def __getitem__(self, index):
    image = self.transform(self.data[index])
    label = self.label[index]
    return image, label


class MetaCifar100(Cifar100):

  def __init__(self, root_path, split='train', image_size=32,
               normalization=True, transform=None, val_transform=None,
               n_batch=200, n_episode=4, n_way=5, n_shot=1, n_query=15):
    super(MetaCifar100, self).__init__(root_path, split, image_size,
                                       normalization, transform)

    self.n_batch = n_batch
    self.n_episode = n_episode
    self.n_way = n_way
    self.n_shot = n_shot
    self.n_query = n_query

    self.catlocs = tuple()
    for cat in range(self.n_classes):
      self.catlocs += (np.argwhere(self.label == cat).reshape(-1),)

    self.val_transform = get_transform(
      val_transform, image_size, self.norm_params)

  def __len__(self):
    return self.n_batch * self.n_episode

  def __getitem__(self, index):
    shot, query = [], []
    cats = np.random.choice(self.n_classes, self.n_way, replace=False)
    for c in cats:
      c_shot, c_query = [], []
      idx_list = np.random.choice(
        self.catlocs[c], self.n_shot + self.n_query, replace=False)
      shot_idx, query_idx = idx_list[:self.n_shot], idx_list[-self.n_query:]
      for idx in shot_idx:
        c_shot.append(self.transform(self.data[idx]))
      for idx in query_idx:
        c_query.append(self.val_transform(self.data[idx]))
      shot.append(torch.stack(c_shot))
      query.append(torch.stack(c_query))

    shot = torch.cat(shot, dim=0)    # [n_way * n_shot, C, H, W]
    query = torch.cat(query, dim=0)  # [n_way * n_query, C, H, W]
    cls = torch.arange(self.n_way)[:, None]
    shot_labels = cls.repeat(1, self.n_shot).flatten()    # [n_way * n_shot]
    query_labels = cls.repeat(1, self.n_query).flatten()  # [n_way * n_query]

    return shot, query, shot_labels, query_labels


@register('cifar-fs')
class CifarFS(Cifar100):
  def __init__(self, *args):
    super(CifarFS, self).__init__(*args)


@register('meta-cifar-fs')
class MetaCifarFS(MetaCifar100):
  def __init__(self, *args):
    super(MetaCifarFS, self).__init__(*args)


@register('fc100')
class FC100(Cifar100):
  def __init__(self, *args):
    super(FC100, self).__init__(*args)


@register('meta-fc100')
class MetaFC100(MetaCifar100):
  def __init__(self, *args):
    super(MetaFC100, self).__init__(*args)

================================================
FILE: datasets/cub200.py
================================================
import os

import torch
from torch.utils.data import Dataset
import numpy as np
from PIL import Image

from .datasets import register
from .transforms import get_transform


@register('cub200')
class CUB200(Dataset):

  def __init__(self, root_path, split='train', image_size=84,
               normalization=True, transform=None):
    super(CUB200, self).__init__()

    split_dict = {'train': 'train',       # standard train
                  'meta-train': 'train',  # meta-train
                  'meta-val': 'val',      # meta-val
                  'meta-test': 'test',    # meta-test
                  }
    split_tag = split_dict[split]

    split_file = os.path.join(root_path, 'fs-splits', split_tag + '.csv')
    assert os.path.isfile(split_file)
    with open(split_file, 'r') as f:
      pairs = [x.strip().split(',') for x in f.readlines()
               if x.strip() != '']
    data, label = [x[0] for x in pairs], [int(x[1]) for x in pairs]

    label = np.array(label)
    label_key = sorted(np.unique(label))
    label_map = dict(zip(label_key, range(len(label_key))))
    new_label = np.array([label_map[x] for x in label])

    self.root_path = root_path
    self.split_tag = split_tag
    self.image_size = image_size

    self.data = data
    self.label = new_label
    self.n_classes = len(label_key)

    if normalization:
      self.norm_params = {'mean': [0.485, 0.456, 0.406],
                          'std': [0.229, 0.224, 0.225]}  # ImageNet statistics
    else:
      self.norm_params = {'mean': [0., 0., 0.],
                          'std': [1., 1., 1.]}
    self.transform = get_transform(transform, image_size, self.norm_params)

    def convert_raw(x):
      mean = torch.tensor(self.norm_params['mean']).view(3, 1, 1).type_as(x)
      std = torch.tensor(self.norm_params['std']).view(3, 1, 1).type_as(x)
      return x * std + mean
    self.convert_raw = convert_raw

  def _load_image(self, index):
    image_path = os.path.join(self.root_path, 'images', self.data[index])
    assert os.path.isfile(image_path)
    image = Image.open(image_path).convert('RGB')
    return image

  def __len__(self):
    return len(self.label)

  def __getitem__(self, index):
    image = self.transform(self._load_image(index))
    label = self.label[index]
    return image, label


@register('meta-cub200')
class MetaCUB200(CUB200):

  def __init__(self, root_path, split='train', image_size=84,
               normalization=True, transform=None, val_transform=None,
               n_batch=200, n_episode=4, n_way=5, n_shot=1, n_query=15):
    super(MetaCUB200, self).__init__(root_path, split, image_size,
                                     normalization, transform)

    self.n_batch = n_batch
    self.n_episode = n_episode
    self.n_way = n_way
    self.n_shot = n_shot
    self.n_query = n_query

    self.catlocs = tuple()
    for cat in range(self.n_classes):
      self.catlocs += (np.argwhere(self.label == cat).reshape(-1),)

    self.val_transform = get_transform(
      val_transform, image_size, self.norm_params)

  def __len__(self):
    return self.n_batch * self.n_episode

  def __getitem__(self, index):
    shot, query = [], []
    cats = np.random.choice(self.n_classes, self.n_way, replace=False)
    for c in cats:
      c_shot, c_query = [], []
      idx_list = np.random.choice(
        self.catlocs[c], self.n_shot + self.n_query, replace=False)
      shot_idx, query_idx = idx_list[:self.n_shot], idx_list[-self.n_query:]
      for idx in shot_idx:
        c_shot.append(self.transform(self._load_image(idx)))
      for idx in query_idx:
        c_query.append(self.val_transform(self._load_image(idx)))
      shot.append(torch.stack(c_shot))
      query.append(torch.stack(c_query))

    shot = torch.cat(shot, dim=0)    # [n_way * n_shot, C, H, W]
    query = torch.cat(query, dim=0)  # [n_way * n_query, C, H, W]
    cls = torch.arange(self.n_way)[:, None]
    shot_labels = cls.repeat(1, self.n_shot).flatten()    # [n_way * n_shot]
    query_labels = cls.repeat(1, self.n_query).flatten()  # [n_way * n_query]

    return shot, query, shot_labels, query_labels

================================================
FILE: datasets/datasets.py
================================================
import os

import torch

DEFAULT_ROOT = './materials'

datasets = {}


def register(name):
  def decorator(cls):
    datasets[name] = cls
    return cls
  return decorator


def make(name, **kwargs):
  if kwargs.get('root_path') is None:
    kwargs['root_path'] = os.path.join(DEFAULT_ROOT, name.replace('meta-', ''))
  dataset = datasets[name](**kwargs)
  return dataset


def collate_fn(batch):
  shot, query, shot_label, query_label = [], [], [], []
  for s, q, sl, ql in batch:
    shot.append(s)
    query.append(q)
    shot_label.append(sl)
    query_label.append(ql)
  shot = torch.stack(shot)                # [n_ep, n_way * n_shot, C, H, W]
  query = torch.stack(query)              # [n_ep, n_way * n_query, C, H, W]
  shot_label = torch.stack(shot_label)    # [n_ep, n_way * n_shot]
  query_label = torch.stack(query_label)  # [n_ep, n_way * n_query]
  return shot, query, shot_label, query_label

================================================
FILE: datasets/inatural.py
================================================
import os

import torch
from torch.utils.data import Dataset
import numpy as np
from PIL import Image

from .datasets import register
from .transforms import get_transform


@register('inatural')
class INat2017(Dataset):

  def __init__(self, root_path, split='train', image_size=84,
               normalization=True, transform=None):
    super(INat2017, self).__init__()

    split_dict = {'train': 'train',       # standard train
                  'meta-train': 'train',  # meta-train
                  'meta-test': 'test',    # meta-test
                  }
    split_tag = split_dict[split]

    split_file = os.path.join(root_path, 'fs-splits', split_tag + '.csv')
    assert os.path.isfile(split_file)
    with open(split_file, 'r') as f:
      pairs = [x.strip().split(',') for x in f.readlines()
               if x.strip() != '']
    data, label = [x[0] for x in pairs], [int(x[1]) for x in pairs]

    label = np.array(label)
    label_key = sorted(np.unique(label))
    label_map = dict(zip(label_key, range(len(label_key))))
    new_label = np.array([label_map[x] for x in label])

    self.root_path = root_path
    self.split_tag = split_tag
    self.image_size = image_size

    self.data = data
    self.label = new_label
    self.n_classes = len(label_key)

    if normalization:
      self.norm_params = {'mean': [0.4905, 0.4961, 0.4330],
                          'std': [0.1737, 0.1713, 0.1779]}
    else:
      self.norm_params = {'mean': [0., 0., 0.],
                          'std': [1., 1., 1.]}
    self.transform = get_transform(transform, image_size, self.norm_params)

    def convert_raw(x):
      mean = torch.tensor(self.norm_params['mean']).view(3, 1, 1).type_as(x)
      std = torch.tensor(self.norm_params['std']).view(3, 1, 1).type_as(x)
      return x * std + mean
    self.convert_raw = convert_raw

  def _load_image(self, index):
    image_path = os.path.join(self.root_path, 'images', self.data[index])
    assert os.path.isfile(image_path)
    image = Image.open(image_path).convert('RGB')
    return image

  def __len__(self):
    return len(self.label)

  def __getitem__(self, index):
    image = self.transform(self._load_image(index))
    label = self.label[index]
    return image, label


@register('meta-inatural')
class MetaINat2017(INat2017):

  def __init__(self, root_path, split='train', image_size=84,
               normalization=True, transform=None, val_transform=None,
               n_batch=200, n_episode=4, n_way=5, n_shot=1, n_query=15):
    super(MetaINat2017, self).__init__(root_path, split, image_size,
                                       normalization, transform)

    self.n_batch = n_batch
    self.n_episode = n_episode
    self.n_way = n_way
    self.n_shot = n_shot
    self.n_query = n_query

    self.catlocs = tuple()
    for cat in range(self.n_classes):
      self.catlocs += (np.argwhere(self.label == cat).reshape(-1),)

    self.val_transform = get_transform(
      val_transform, image_size, self.norm_params)

  def __len__(self):
    return self.n_batch * self.n_episode

  def __getitem__(self, index):
    shot, query = [], []
    cats = np.random.choice(self.n_classes, self.n_way, replace=False)
    for c in cats:
      c_shot, c_query = [], []
      idx_list = np.random.choice(
        self.catlocs[c], self.n_shot + self.n_query, replace=False)
      shot_idx, query_idx = idx_list[:self.n_shot], idx_list[-self.n_query:]
      for idx in shot_idx:
        c_shot.append(self.transform(self._load_image(idx)))
      for idx in query_idx:
        c_query.append(self.val_transform(self._load_image(idx)))
      shot.append(torch.stack(c_shot))
      query.append(torch.stack(c_query))

    shot = torch.cat(shot, dim=0)    # [n_way * n_shot, C, H, W]
    query = torch.cat(query, dim=0)  # [n_way * n_query, C, H, W]
    cls = torch.arange(self.n_way)[:, None]
    shot_labels = cls.repeat(1, self.n_shot).flatten()    # [n_way * n_shot]
    query_labels = cls.repeat(1, self.n_query).flatten()  # [n_way * n_query]

    return shot, query, shot_labels, query_labels

================================================
FILE: datasets/mini_imagenet.py
================================================
import os
import pickle

import torch
from torch.utils.data import Dataset
import numpy as np
from PIL import Image

from .datasets import register
from .transforms import get_transform


@register('mini-imagenet')
class MiniImageNet(Dataset):

  def __init__(self, root_path, split='train', image_size=84,
               normalization=True, transform=None):
    super(MiniImageNet, self).__init__()

    split_dict = {'train': 'train_phase_train',        # standard train
                  'val': 'train_phase_val',            # standard val
                  'trainval': 'train_phase_trainval',  # standard train and val
                  'test': 'train_phase_test',          # standard test
                  'meta-train': 'train_phase_train',   # meta-train
                  'meta-val': 'val',                   # meta-val
                  'meta-test': 'test',                 # meta-test
                  }
    split_tag = split_dict[split]

    split_file = os.path.join(root_path, split_tag + '.pickle')
    assert os.path.isfile(split_file)
    with open(split_file, 'rb') as f:
      pack = pickle.load(f, encoding='latin1')
    data, label = pack['data'], pack['labels']

    data = [Image.fromarray(x) for x in data]
    label = np.array(label)
    label_key = sorted(np.unique(label))
    label_map = dict(zip(label_key, range(len(label_key))))
    new_label = np.array([label_map[x] for x in label])

    self.root_path = root_path
    self.split_tag = split_tag
    self.image_size = image_size

    self.data = data
    self.label = new_label
    self.n_classes = len(label_key)

    if normalization:
      self.norm_params = {'mean': [0.471, 0.450, 0.403],
                          'std': [0.278, 0.268, 0.284]}
    else:
      self.norm_params = {'mean': [0., 0., 0.],
                          'std': [1., 1., 1.]}
    self.transform = get_transform(transform, image_size, self.norm_params)

    def convert_raw(x):
      mean = torch.tensor(self.norm_params['mean']).view(3, 1, 1).type_as(x)
      std = torch.tensor(self.norm_params['std']).view(3, 1, 1).type_as(x)
      return x * std + mean
    self.convert_raw = convert_raw

  def __len__(self):
    return len(self.data)

  def __getitem__(self, index):
    image = self.transform(self.data[index])
    label = self.label[index]
    return image, label


@register('meta-mini-imagenet')
class MetaMiniImageNet(MiniImageNet):

  def __init__(self, root_path, split='train', image_size=84,
               normalization=True, transform=None, val_transform=None,
               n_batch=200, n_episode=4, n_way=5, n_shot=1, n_query=15):
    super(MetaMiniImageNet, self).__init__(root_path, split, image_size,
                                           normalization, transform)

    self.n_batch = n_batch
    self.n_episode = n_episode
    self.n_way = n_way
    self.n_shot = n_shot
    self.n_query = n_query

    self.catlocs = tuple()
    for cat in range(self.n_classes):
      self.catlocs += (np.argwhere(self.label == cat).reshape(-1),)

    self.val_transform = get_transform(
      val_transform, image_size, self.norm_params)

  def __len__(self):
    return self.n_batch * self.n_episode

  def __getitem__(self, index):
    shot, query = [], []
    cats = np.random.choice(self.n_classes, self.n_way, replace=False)
    for c in cats:
      c_shot, c_query = [], []
      idx_list = np.random.choice(
        self.catlocs[c], self.n_shot + self.n_query, replace=False)
      shot_idx, query_idx = idx_list[:self.n_shot], idx_list[-self.n_query:]
      for idx in shot_idx:
        c_shot.append(self.transform(self.data[idx]))
      for idx in query_idx:
        c_query.append(self.val_transform(self.data[idx]))
      shot.append(torch.stack(c_shot))
      query.append(torch.stack(c_query))

    shot = torch.cat(shot, dim=0)    # [n_way * n_shot, C, H, W]
    query = torch.cat(query, dim=0)  # [n_way * n_query, C, H, W]
    cls = torch.arange(self.n_way)[:, None]
    shot_labels = cls.repeat(1, self.n_shot).flatten()    # [n_way * n_shot]
    query_labels = cls.repeat(1, self.n_query).flatten()  # [n_way * n_query]

    return shot, query, shot_labels, query_labels

================================================
FILE: datasets/tiered_imagenet.py
================================================
import os
import pickle

import torch
from torch.utils.data import Dataset
import numpy as np
from PIL import Image

from .datasets import register
from .transforms import get_transform


@register('tiered-imagenet')
class TieredImageNet(Dataset):
  def __init__(self, root_path, split='train', image_size=84,
               normalization=True, transform=None):
    super(TieredImageNet, self).__init__()
    split_dict = {'train': 'train',            # standard train
                  'val': 'train_phase_val',    # standard val
                  'meta-train': 'train',       # meta-train
                  'meta-val': 'val',           # meta-val
                  'meta-test': 'test',         # meta-test
                  }
    split_tag = split_dict[split]

    split_file = os.path.join(root_path, split_tag + '_images.npz')
    label_file = os.path.join(root_path, split_tag + '_labels.pkl')
    assert os.path.isfile(split_file)
    assert os.path.isfile(label_file)
    data = np.load(split_file, allow_pickle=True)['images']
    data = data[:, :, :, ::-1]
    with open(label_file, 'rb') as f:
      label = pickle.load(f)['labels']

    data = [Image.fromarray(x) for x in data]
    label = np.array(label)
    label_key = sorted(np.unique(label))
    label_map = dict(zip(label_key, range(len(label_key))))
    new_label = np.array([label_map[x] for x in label])

    self.root_path = root_path
    self.split_tag = split_tag
    self.image_size = image_size

    self.data = data
    self.label = new_label
    self.n_classes = len(label_key)

    if normalization:
      self.norm_params = {'mean': [0.478, 0.456, 0.410],
                          'std': [0.279, 0.274, 0.286]}
    else:
      self.norm_params = {'mean': [0., 0., 0.],
                          'std': [1., 1., 1.]}
    self.transform = get_transform(transform, image_size, self.norm_params)

    def convert_raw(x):
      mean = torch.tensor(self.norm_params['mean']).view(3, 1, 1).type_as(x)
      std = torch.tensor(self.norm_params['std']).view(3, 1, 1).type_as(x)
      return x * std + mean
    self.convert_raw = convert_raw

  def __len__(self):
    return len(self.data)

  def __getitem__(self, index):
    image = self.transform(self.data[index])
    label = self.label[index]
    return image, label


@register('meta-tiered-imagenet')
class MetaTieredImageNet(TieredImageNet):
  def __init__(self, root_path, split='train', image_size=84,
               normalization=True, transform=None, val_transform=None,
               n_batch=200, n_episode=4, n_way=5, n_shot=1, n_query=15):
    super(MetaTieredImageNet, self).__init__(root_path, split, image_size,
                                             normalization, transform)
    self.n_batch = n_batch
    self.n_episode = n_episode
    self.n_way = n_way
    self.n_shot = n_shot
    self.n_query = n_query

    self.catlocs = tuple()
    for cat in range(self.n_classes):
      self.catlocs += (np.argwhere(self.label == cat).reshape(-1),)

    self.val_transform = get_transform(
      val_transform, image_size, self.norm_params)

  def __len__(self):
    return self.n_batch * self.n_episode

  def __getitem__(self, index):
    shot, query = [], []
    cats = np.random.choice(self.n_classes, self.n_way, replace=False)
    for c in cats:
      c_shot, c_query = [], []
      idx_list = np.random.choice(
        self.catlocs[c], self.n_shot + self.n_query, replace=False)
      shot_idx, query_idx = idx_list[:self.n_shot], idx_list[-self.n_query:]
      for idx in shot_idx:
        c_shot.append(self.transform(self.data[idx]))
      for idx in query_idx:
        c_query.append(self.val_transform(self.data[idx]))
      shot.append(torch.stack(c_shot))
      query.append(torch.stack(c_query))
    shot = torch.cat(shot, dim=0)            # [n_way * n_shot, C, H, W]
    query = torch.cat(query, dim=0)          # [n_way * n_query, C, H, W]
    cls = torch.arange(self.n_way)[:, None]
    shot_labels = cls.repeat(1, self.n_shot).flatten()    # [n_way * n_shot]
    query_labels = cls.repeat(1, self.n_query).flatten()  # [n_way * n_query]
    return shot, query, shot_labels, query_labels


================================================
FILE: datasets/transforms.py
================================================
import torchvision.transforms as transforms


__all__ = ['get_transform']


def get_transform(name, image_size, norm_params):
  if name == 'resize':
    return transforms.Compose([
      transforms.RandomResizedCrop(image_size),
      transforms.RandomHorizontalFlip(),
      transforms.ToTensor(),
      transforms.Normalize(**norm_params),
    ])
  elif name == 'crop':
    return transforms.Compose([
      transforms.Resize(image_size),
      transforms.RandomCrop(image_size, padding=8),
      transforms.RandomHorizontalFlip(),
      transforms.ToTensor(),
      transforms.Normalize(**norm_params),
    ])
  elif name == 'color':
    return transforms.Compose([
      transforms.Resize(image_size),
      transforms.RandomCrop(image_size, padding=8),
      transforms.ColorJitter(
        brightness=0.4, contrast=0.4, saturation=0.4),
      transforms.RandomHorizontalFlip(),
      transforms.ToTensor(),
      transforms.Normalize(**norm_params),
    ])
  elif name == 'flip':
    return transforms.Compose([
      transforms.Resize(image_size),
      transforms.RandomHorizontalFlip(),
      transforms.ToTensor(),
      transforms.Normalize(**norm_params),
    ])
  elif name == 'enlarge':
    return transforms.Compose([
      transforms.Resize(int(image_size * 256 / 224)),
      transforms.CenterCrop(image_size),
      transforms.ToTensor(),
      transforms.Normalize(**norm_params),
    ])
  elif name is None:
    return transforms.Compose([
      transforms.Resize(image_size),
      transforms.ToTensor(),
      transforms.Normalize(**norm_params),
    ])
  else:
    raise ValueError('invalid transformation')


================================================
FILE: models/__init__.py
================================================
from .maml import make
from .maml import load


================================================
FILE: models/classifiers/__init__.py
================================================
from .classifiers import make, load
from . import logistic


================================================
FILE: models/classifiers/classifiers.py
================================================
import torch


__all__ = ['make', 'load']


models = {}


def register(name):
  def decorator(cls):
    models[name] = cls
    return cls
  return decorator


def make(name, **kwargs):
  if name is None:
    return None
  model = models[name](**kwargs)
  if torch.cuda.is_available():
    model.cuda()
  return model


def load(ckpt):
  model = make(ckpt['classifier'], **ckpt['classifier_args'])
  model.load_state_dict(ckpt['classifier_state_dict'])
  return model


================================================
FILE: models/classifiers/logistic.py
================================================
import torch
import torch.nn as nn

from .classifiers import register
from ..modules import *


__all__ = ['LogisticClassifier']


@register('logistic')
class LogisticClassifier(Module):
  def __init__(self, in_dim, n_way, temp=1., learn_temp=False):
    super(LogisticClassifier, self).__init__()
    self.in_dim = in_dim
    self.n_way = n_way
    self.temp = temp
    self.learn_temp = learn_temp

    self.linear = Linear(in_dim, n_way)
    if self.learn_temp:
      self.temp = nn.Parameter(torch.tensor(temp))

  def reset_parameters(self):
    nn.init.zeros_(self.linear.weight)
    nn.init.zeros_(self.linear.bias)

  def forward(self, x_shot, params=None):
    assert x_shot.dim() == 2
    logits = self.linear(x_shot, get_child_dict(params, 'linear'))
    logits = logits * self.temp
    return logits


================================================
FILE: models/encoders/__init__.py
================================================
from .encoders import make, load
from . import convnet4
from . import resnet12
from . import resnet18


================================================
FILE: models/encoders/convnet4.py
================================================
from collections import OrderedDict

import torch.nn as nn

from .encoders import register
from ..modules import *


__all__ = ['convnet4', 'wide_convnet4']


class ConvBlock(Module):
  def __init__(self, in_channels, out_channels, bn_args):
    super(ConvBlock, self).__init__()
    self.in_channels = in_channels
    self.out_channels = out_channels

    self.conv = Conv2d(in_channels, out_channels, 3, 1, padding=1)
    self.bn = BatchNorm2d(out_channels, **bn_args)
    self.relu = nn.ReLU(inplace=True)
    self.pool = nn.MaxPool2d(2)

  def forward(self, x, params=None, episode=None):
    out = self.conv(x, get_child_dict(params, 'conv'))
    out = self.bn(out, get_child_dict(params, 'bn'), episode)
    out = self.pool(self.relu(out))
    return out


class ConvNet4(Module):
  def __init__(self, hid_dim, out_dim, bn_args):
    super(ConvNet4, self).__init__()
    self.hid_dim = hid_dim
    self.out_dim = out_dim

    episodic = bn_args.get('episodic') or []
    bn_args_ep, bn_args_no_ep = bn_args.copy(), bn_args.copy()
    bn_args_ep['episodic'] = True
    bn_args_no_ep['episodic'] = False
    bn_args_dict = dict()
    for i in [1, 2, 3, 4]:
      if 'conv%d' % i in episodic:
        bn_args_dict[i] = bn_args_ep
      else:
        bn_args_dict[i] = bn_args_no_ep

    self.encoder = Sequential(OrderedDict([
      ('conv1', ConvBlock(3, hid_dim, bn_args_dict[1])),
      ('conv2', ConvBlock(hid_dim, hid_dim, bn_args_dict[2])),
      ('conv3', ConvBlock(hid_dim, hid_dim, bn_args_dict[3])),
      ('conv4', ConvBlock(hid_dim, out_dim, bn_args_dict[4])),
    ]))

  def get_out_dim(self, scale=25):
    return self.out_dim * scale

  def forward(self, x, params=None, episode=None):
    out = self.encoder(x, get_child_dict(params, 'encoder'), episode)
    out = out.view(out.shape[0], -1)
    return out


@register('convnet4')
def convnet4(bn_args=dict()):
  return ConvNet4(32, 32, bn_args)


@register('wide-convnet4')
def wide_convnet4(bn_args=dict()):
  return ConvNet4(64, 64, bn_args)
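The datasets, encoders, and classifiers above all share the same name-based factory pattern: a module-level `models` dict, a `register` decorator, and a `make` function. A minimal, self-contained sketch of that pattern (the `'toy-encoder'` class and its `out_dim` argument are illustrative, not part of the repository):

```python
# Name-based registry: classes register themselves under a string key,
# and make() instantiates by key with forwarded keyword arguments.
models = {}


def register(name):
  def decorator(cls):
    models[name] = cls
    return cls
  return decorator


def make(name, **kwargs):
  if name is None:
    return None
  return models[name](**kwargs)


@register('toy-encoder')  # hypothetical entry for illustration
class ToyEncoder:
  def __init__(self, out_dim=32):
    self.out_dim = out_dim


enc = make('toy-encoder', out_dim=64)
print(enc.out_dim)  # -> 64
```

This is what lets a YAML config select a model by string (e.g. `encoder: convnet4`) without any conditional construction logic; importing a submodule such as `models.encoders.convnet4` is enough to populate the registry as a side effect.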
================================================
FILE: models/encoders/encoders.py
================================================
import torch


models = {}


def register(name):
  def decorator(cls):
    models[name] = cls
    return cls
  return decorator


def make(name, **kwargs):
  if name is None:
    return None
  model = models[name](**kwargs)
  if torch.cuda.is_available():
    model.cuda()
  return model


def load(ckpt):
  model = make(ckpt['encoder'], **ckpt['encoder_args'])
  if model is not None:
    model.load_state_dict(ckpt['encoder_state_dict'])
  return model


================================================
FILE: models/encoders/resnet12.py
================================================
from collections import OrderedDict

import torch.nn as nn

from .encoders import register
from ..modules import *


__all__ = ['resnet12', 'wide_resnet12']


def conv3x3(in_channels, out_channels):
  return Conv2d(in_channels, out_channels, 3, 1, padding=1, bias=False)


def conv1x1(in_channels, out_channels):
  return Conv2d(in_channels, out_channels, 1, 1, padding=0, bias=False)


class Block(Module):
  def __init__(self, in_planes, planes, bn_args):
    super(Block, self).__init__()
    self.in_planes = in_planes
    self.planes = planes

    self.conv1 = conv3x3(in_planes, planes)
    self.bn1 = BatchNorm2d(planes, **bn_args)
    self.conv2 = conv3x3(planes, planes)
    self.bn2 = BatchNorm2d(planes, **bn_args)
    self.conv3 = conv3x3(planes, planes)
    self.bn3 = BatchNorm2d(planes, **bn_args)

    self.res_conv = Sequential(OrderedDict([
      ('conv', conv1x1(in_planes, planes)),
      ('bn', BatchNorm2d(planes, **bn_args)),
    ]))

    self.relu = nn.LeakyReLU(0.1, inplace=True)
    self.pool = nn.MaxPool2d(2)

  def forward(self, x, params=None, episode=None):
    out = self.conv1(x, get_child_dict(params, 'conv1'))
    out = self.bn1(out, get_child_dict(params, 'bn1'), episode)
    out = self.relu(out)

    out = self.conv2(out, get_child_dict(params, 'conv2'))
    out = self.bn2(out, get_child_dict(params, 'bn2'), episode)
    out = self.relu(out)

    out = self.conv3(out, get_child_dict(params, 'conv3'))
    out = self.bn3(out, get_child_dict(params, 'bn3'), episode)

    x = self.res_conv(x, get_child_dict(params, 'res_conv'), episode)
    out = self.pool(self.relu(out + x))
    return out


class ResNet12(Module):
  def __init__(self, channels, bn_args):
    super(ResNet12, self).__init__()
    self.channels = channels

    episodic = bn_args.get('episodic') or []
    bn_args_ep, bn_args_no_ep = bn_args.copy(), bn_args.copy()
    bn_args_ep['episodic'] = True
    bn_args_no_ep['episodic'] = False
    bn_args_dict = dict()
    for i in [1, 2, 3, 4]:
      if 'layer%d' % i in episodic:
        bn_args_dict[i] = bn_args_ep
      else:
        bn_args_dict[i] = bn_args_no_ep

    self.layer1 = Block(3, channels[0], bn_args_dict[1])
    self.layer2 = Block(channels[0], channels[1], bn_args_dict[2])
    self.layer3 = Block(channels[1], channels[2], bn_args_dict[3])
    self.layer4 = Block(channels[2], channels[3], bn_args_dict[4])
    self.pool = nn.AdaptiveAvgPool2d(1)
    self.out_dim = channels[3]

    for m in self.modules():
      if isinstance(m, Conv2d):
        nn.init.kaiming_normal_(
          m.weight, mode='fan_out', nonlinearity='leaky_relu')
      elif isinstance(m, BatchNorm2d):
        nn.init.constant_(m.weight, 1.)
        nn.init.constant_(m.bias, 0.)

  def get_out_dim(self):
    return self.out_dim

  def forward(self, x, params=None, episode=None):
    out = self.layer1(x, get_child_dict(params, 'layer1'), episode)
    out = self.layer2(out, get_child_dict(params, 'layer2'), episode)
    out = self.layer3(out, get_child_dict(params, 'layer3'), episode)
    out = self.layer4(out, get_child_dict(params, 'layer4'), episode)
    out = self.pool(out).flatten(1)
    return out


@register('resnet12')
def resnet12(bn_args=dict()):
  return ResNet12([64, 128, 256, 512], bn_args)


@register('wide-resnet12')
def wide_resnet12(bn_args=dict()):
  return ResNet12([64, 160, 320, 640], bn_args)


================================================
FILE: models/encoders/resnet18.py
================================================
from collections import OrderedDict

import torch.nn as nn

from .encoders import register
from ..modules import *


__all__ = ['resnet18', 'wide_resnet18']


def conv3x3(in_channels, out_channels, stride=1):
  return Conv2d(in_channels, out_channels, 3, stride, padding=1, bias=False)


def conv1x1(in_channels, out_channels, stride=1):
  return Conv2d(in_channels, out_channels, 1, stride, padding=0, bias=False)


class Block(Module):
  def __init__(self, in_planes, planes, stride, bn_args):
    super(Block, self).__init__()
    self.in_planes = in_planes
    self.planes = planes
    self.stride = stride

    self.conv1 = conv3x3(in_planes, planes, stride)
    self.bn1 = BatchNorm2d(planes, **bn_args)
    self.conv2 = conv3x3(planes, planes)
    self.bn2 = BatchNorm2d(planes, **bn_args)

    if stride > 1:
      self.res_conv = Sequential(OrderedDict([
        ('conv', conv1x1(in_planes, planes)),
        ('bn', BatchNorm2d(planes, **bn_args)),
      ]))

    self.relu = nn.ReLU(inplace=True)

  def forward(self, x, params=None, episode=None):
    out = self.conv1(x, get_child_dict(params, 'conv1'))
    out = self.bn1(out, get_child_dict(params, 'bn1'), episode)
    out = self.relu(out)

    out = self.conv2(out, get_child_dict(params, 'conv2'))
    out = self.bn2(out, get_child_dict(params, 'bn2'), episode)

    if self.stride > 1:
      x = self.res_conv(x, get_child_dict(params, 'res_conv'), episode)
    out = self.relu(out + x)
    return out


class ResNet18(Module):
  def __init__(self, channels, bn_args):
    super(ResNet18, self).__init__()
    self.channels = channels

    episodic = bn_args.get('episodic') or []
    bn_args_ep, bn_args_no_ep = bn_args.copy(), bn_args.copy()
    bn_args_ep['episodic'] = True
    bn_args_no_ep['episodic'] = False
    bn_args_dict = dict()
    for i in [0, 1, 2, 3, 4]:
      if 'layer%d' % i in episodic:
        bn_args_dict[i] = bn_args_ep
      else:
        bn_args_dict[i] = bn_args_no_ep

    self.layer0 = Sequential(OrderedDict([
      ('conv', conv3x3(3, 64)),
      ('bn', BatchNorm2d(64, **bn_args_dict[0])),
    ]))
    self.relu = nn.ReLU(inplace=True)
    self.layer1 = Block(64, channels[0], 1, bn_args_dict[1])
    self.layer2 = Block(channels[0], channels[1], 2, bn_args_dict[2])
    self.layer3 = Block(channels[1], channels[2], 2, bn_args_dict[3])
    self.layer4 = Block(channels[2], channels[3], 2, bn_args_dict[4])
    self.pool = nn.AdaptiveAvgPool2d(1)
    self.out_dim = channels[3]

    for m in self.modules():
      if isinstance(m, Conv2d):
        nn.init.kaiming_normal_(
          m.weight, mode='fan_out', nonlinearity='relu')
      elif isinstance(m, BatchNorm2d):
        nn.init.constant_(m.weight, 1.)
        nn.init.constant_(m.bias, 0.)
  def get_out_dim(self, scale=1):
    return self.out_dim * scale

  def forward(self, x, params=None, episode=None):
    out = self.layer0(x, get_child_dict(params, 'layer0'), episode)
    out = self.relu(out)
    out = self.layer1(out, get_child_dict(params, 'layer1'), episode)
    out = self.layer2(out, get_child_dict(params, 'layer2'), episode)
    out = self.layer3(out, get_child_dict(params, 'layer3'), episode)
    out = self.layer4(out, get_child_dict(params, 'layer4'), episode)
    out = self.pool(out).flatten(1)
    return out


@register('resnet18')
def resnet18(bn_args=dict()):
  return ResNet18([64, 128, 256, 512], bn_args)


@register('wide-resnet18')
def wide_resnet18(bn_args=dict()):
  return ResNet18([64, 160, 320, 640], bn_args)


================================================
FILE: models/maml.py
================================================
from collections import OrderedDict

import torch
import torch.nn.functional as F
import torch.autograd as autograd
import torch.utils.checkpoint as cp

from . import encoders
from . import classifiers
from .modules import get_child_dict, Module, BatchNorm2d


def make(enc_name, enc_args, clf_name, clf_args):
  """
  Initializes a random meta model.

  Args:
    enc_name (str): name of the encoder (e.g., 'resnet12').
    enc_args (dict): arguments for the encoder.
    clf_name (str): name of the classifier (e.g., 'meta-nn').
    clf_args (dict): arguments for the classifier.

  Returns:
    model (MAML): a meta classifier with a random encoder.
  """
  enc = encoders.make(enc_name, **enc_args)
  clf_args['in_dim'] = enc.get_out_dim()
  clf = classifiers.make(clf_name, **clf_args)
  model = MAML(enc, clf)
  return model


def load(ckpt, load_clf=False, clf_name=None, clf_args=None):
  """
  Initializes a meta model with a pre-trained encoder.

  Args:
    ckpt (dict): a checkpoint from which a pre-trained encoder is restored.
    load_clf (bool, optional): if True, loads a pre-trained classifier.
      Default: False (in which case the classifier is randomly initialized)
    clf_name (str, optional): name of the classifier (e.g., 'meta-nn').
    clf_args (dict, optional): arguments for the classifier.
    (The last two arguments are ignored if load_clf=True.)

  Returns:
    model (MAML): a meta model with a pre-trained encoder.
  """
  enc = encoders.load(ckpt)
  if load_clf:
    clf = classifiers.load(ckpt)
  else:
    if clf_name is None and clf_args is None:
      clf = classifiers.make(ckpt['classifier'], **ckpt['classifier_args'])
    else:
      clf_args['in_dim'] = enc.get_out_dim()
      clf = classifiers.make(clf_name, **clf_args)
  model = MAML(enc, clf)
  return model


class MAML(Module):
  def __init__(self, encoder, classifier):
    super(MAML, self).__init__()
    self.encoder = encoder
    self.classifier = classifier

  def reset_classifier(self):
    self.classifier.reset_parameters()

  def _inner_forward(self, x, params, episode):
    """ Forward pass for the inner loop. """
    feat = self.encoder(x, get_child_dict(params, 'encoder'), episode)
    logits = self.classifier(feat, get_child_dict(params, 'classifier'))
    return logits

  def _inner_iter(self, x, y, params, mom_buffer, episode, inner_args, detach):
    """
    Performs one inner-loop iteration of MAML, including the forward and
    backward passes and the parameter update.

    Args:
      x (float tensor, [n_way * n_shot, C, H, W]): per-episode support set.
      y (int tensor, [n_way * n_shot]): per-episode support set labels.
      params (dict): the model parameters BEFORE the update.
      mom_buffer (dict): the momentum buffer BEFORE the update.
      episode (int): the current episode index.
      inner_args (dict): inner-loop optimization hyperparameters.
      detach (bool): if True, detaches the graph for the current iteration.

    Returns:
      updated_params (dict): the model parameters AFTER the update.
      mom_buffer (dict): the momentum buffer AFTER the update.
    """
    with torch.enable_grad():
      # forward pass
      logits = self._inner_forward(x, params, episode)
      loss = F.cross_entropy(logits, y)
      # backward pass
      grads = autograd.grad(
        loss, params.values(),
        create_graph=(not detach and not inner_args['first_order']),
        only_inputs=True, allow_unused=True)
      # parameter update
      updated_params = OrderedDict()
      for (name, param), grad in zip(params.items(), grads):
        if grad is None:
          updated_param = param
        else:
          if inner_args['weight_decay'] > 0:
            grad = grad + inner_args['weight_decay'] * param
          if inner_args['momentum'] > 0:
            grad = grad + inner_args['momentum'] * mom_buffer[name]
            mom_buffer[name] = grad
          if 'encoder' in name:
            lr = inner_args['encoder_lr']
          elif 'classifier' in name:
            lr = inner_args['classifier_lr']
          else:
            raise ValueError('invalid parameter name')
          updated_param = param - lr * grad
        if detach:
          updated_param = updated_param.detach().requires_grad_(True)
        updated_params[name] = updated_param
    return updated_params, mom_buffer

  def _adapt(self, x, y, params, episode, inner_args, meta_train):
    """
    Performs inner-loop adaptation in MAML.

    Args:
      x (float tensor, [n_way * n_shot, C, H, W]): per-episode support set.
        (C: channels, H: height, W: width)
      y (int tensor, [n_way * n_shot]): per-episode support set labels.
      params (dict): a dictionary of parameters at meta-initialization.
      episode (int): the current episode index.
      inner_args (dict): inner-loop optimization hyperparameters.
      meta_train (bool): if True, the model is in meta-training.

    Returns:
      params (dict): model parameters AFTER inner-loop adaptation.
    """
    assert x.dim() == 4 and y.dim() == 1
    assert x.size(0) == y.size(0)

    # Initializes a dictionary of momentum buffers for gradient descent in the
    # inner loop. It has the same set of keys as the parameter dictionary.
    mom_buffer = OrderedDict()
    if inner_args['momentum'] > 0:
      for name, param in params.items():
        mom_buffer[name] = torch.zeros_like(param)
    params_keys = tuple(params.keys())
    mom_buffer_keys = tuple(mom_buffer.keys())

    for m in self.modules():
      if isinstance(m, BatchNorm2d) and m.is_episodic():
        m.reset_episodic_running_stats(episode)

    def _inner_iter_cp(episode, *state):
      """
      Performs one inner-loop iteration when checkpointing is enabled.
      The code is executed twice:
      - 1st time with torch.no_grad() for creating checkpoints.
      - 2nd time with torch.enable_grad() for computing gradients.
      """
      params = OrderedDict(zip(params_keys, state[:len(params_keys)]))
      mom_buffer = OrderedDict(
        zip(mom_buffer_keys, state[-len(mom_buffer_keys):]))

      detach = not torch.is_grad_enabled()  # detach graph in the first pass
      self.is_first_pass(detach)
      params, mom_buffer = self._inner_iter(
        x, y, params, mom_buffer, int(episode), inner_args, detach)
      state = tuple(t if t.requires_grad else t.clone().requires_grad_(True)
                    for t in tuple(params.values()) + tuple(mom_buffer.values()))
      return state

    for step in range(inner_args['n_step']):
      if self.efficient:  # checkpointing
        state = tuple(params.values()) + tuple(mom_buffer.values())
        state = cp.checkpoint(_inner_iter_cp, torch.as_tensor(episode), *state)
        params = OrderedDict(zip(params_keys, state[:len(params_keys)]))
        mom_buffer = OrderedDict(
          zip(mom_buffer_keys, state[-len(mom_buffer_keys):]))
      else:
        params, mom_buffer = self._inner_iter(
          x, y, params, mom_buffer, episode, inner_args, not meta_train)

    return params

  def forward(self, x_shot, x_query, y_shot, inner_args, meta_train):
    """
    Args:
      x_shot (float tensor, [n_episode, n_way * n_shot, C, H, W]): support sets.
      x_query (float tensor, [n_episode, n_way * n_query, C, H, W]): query sets.
        (C: channels, H: height, W: width)
      y_shot (int tensor, [n_episode, n_way * n_shot]): support set labels.
      inner_args (dict): inner-loop hyperparameters.
      meta_train (bool): if True, the model is in meta-training.

    Returns:
      logits (float tensor, [n_episode, n_way * n_query, n_way]): predicted
        logits on the query sets.
    """
    assert self.encoder is not None
    assert self.classifier is not None
    assert x_shot.dim() == 5 and x_query.dim() == 5
    assert x_shot.size(0) == x_query.size(0)

    # a dictionary of parameters that will be updated in the inner loop
    params = OrderedDict(self.named_parameters())
    for name in list(params.keys()):
      if not params[name].requires_grad or \
        any(s in name for s in inner_args['frozen'] + ['temp']):
        params.pop(name)

    logits = []
    for ep in range(x_shot.size(0)):
      # inner-loop training
      self.train()
      if not meta_train:
        for m in self.modules():
          if isinstance(m, BatchNorm2d) and not m.is_episodic():
            m.eval()
      updated_params = self._adapt(
        x_shot[ep], y_shot[ep], params, ep, inner_args, meta_train)
      # inner-loop validation
      with torch.set_grad_enabled(meta_train):
        self.eval()
        logits_ep = self._inner_forward(x_query[ep], updated_params, ep)
      logits.append(logits_ep)
    self.train(meta_train)
    logits = torch.stack(logits)
    return logits


================================================
FILE: models/modules.py
================================================
import re
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F


__all__ = ['Module', 'Conv2d', 'Linear', 'BatchNorm2d', 'Sequential',
           'get_child_dict']


def get_child_dict(params, key=None):
  """
  Constructs parameter dictionary for a network module.

  Args:
    params (dict): a parent dictionary of named parameters.
    key (str, optional): a key that specifies the root of the child dictionary.

  Returns:
    child_dict (dict): a child dictionary of model parameters.
""" if params is None: return None if key is None or (isinstance(key, str) and key == ''): return params key_re = re.compile(r'^{0}\.(.+)'.format(re.escape(key))) if not any(filter(key_re.match, params.keys())): # handles nn.DataParallel key_re = re.compile(r'^module\.{0}\.(.+)'.format(re.escape(key))) child_dict = OrderedDict( (key_re.sub(r'\1', k), value) for (k, value) in params.items() if key_re.match(k) is not None) return child_dict class Module(nn.Module): def __init__(self): super(Module, self).__init__() self.efficient = False self.first_pass = True def go_efficient(self, mode=True): """ Switches on / off gradient checkpointing. """ self.efficient = mode for m in self.children(): if isinstance(m, Module): m.go_efficient(mode) def is_first_pass(self, mode=True): """ Tracks the progress of forward and backward pass when gradient checkpointing is enabled. """ self.first_pass = mode for m in self.children(): if isinstance(m, Module): m.is_first_pass(mode) class Conv2d(nn.Conv2d, Module): def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=True): super(Conv2d, self).__init__(in_channels, out_channels, kernel_size, stride, padding, bias=bias) def forward(self, x, params=None, episode=None): if params is None: x = super(Conv2d, self).forward(x) else: weight, bias = params.get('weight'), params.get('bias') if weight is None: weight = self.weight if bias is None: bias = self.bias x = F.conv2d(x, weight, bias, self.stride, self.padding) return x class Linear(nn.Linear, Module): def __init__(self, in_features, out_features, bias=True): super(Linear, self).__init__(in_features, out_features, bias=bias) def forward(self, x, params=None, episode=None): if params is None: x = super(Linear, self).forward(x) else: weight, bias = params.get('weight'), params.get('bias') if weight is None: weight = self.weight if bias is None: bias = self.bias x = F.linear(x, weight, bias) return x class BatchNorm2d(nn.BatchNorm2d, Module): def __init__(self, 
num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True, episodic=False, n_episode=4, alpha=False): """ Args: episodic (bool, optional): if True, maintains running statistics for each episode separately. It is ignored if track_running_stats=False. Default: True n_episode (int, optional): number of episodes per mini-batch. It is ignored if episodic=False. alpha (bool, optional): if True, learns to interpolate between batch statistics computed over the support set and instance statistics from a query at validation time. Default: True (It is ignored if track_running_stats=False or meta_learn=False) """ super(BatchNorm2d, self).__init__(num_features, eps, momentum, affine, track_running_stats) self.episodic = episodic self.n_episode = n_episode self.alpha = alpha if self.track_running_stats: if self.episodic: for ep in range(n_episode): self.register_buffer( 'running_mean_%d' % ep, torch.zeros(num_features)) self.register_buffer( 'running_var_%d' % ep, torch.ones(num_features)) self.register_buffer( 'num_batches_tracked_%d' % ep, torch.tensor(0, dtype=torch.int)) if self.alpha: self.register_buffer('batch_size', torch.tensor(0, dtype=torch.int)) self.alpha_scale = nn.Parameter(torch.tensor(0.)) self.alpha_offset = nn.Parameter(torch.tensor(0.)) def is_episodic(self): return self.episodic def _batch_norm(self, x, mean, var, weight=None, bias=None): if self.affine: assert weight is not None and bias is not None weight = weight.view(1, -1, 1, 1) bias = bias.view(1, -1, 1, 1) x = weight * (x - mean) / (var + self.eps) ** .5 + bias else: x = (x - mean) / (var + self.eps) ** .5 return x def reset_episodic_running_stats(self, episode): if self.episodic: getattr(self, 'running_mean_%d' % episode).zero_() getattr(self, 'running_var_%d' % episode).fill_(1.) 
            getattr(self, 'num_batches_tracked_%d' % episode).zero_()

    def forward(self, x, params=None, episode=None):
        self._check_input_dim(x)
        if params is not None:
            weight, bias = params.get('weight'), params.get('bias')
            if weight is None:
                weight = self.weight
            if bias is None:
                bias = self.bias
        else:
            weight, bias = self.weight, self.bias

        if self.track_running_stats:
            if self.episodic:
                assert episode is not None and episode < self.n_episode
                running_mean = getattr(self, 'running_mean_%d' % episode)
                running_var = getattr(self, 'running_var_%d' % episode)
                num_batches_tracked = getattr(
                    self, 'num_batches_tracked_%d' % episode)
            else:
                running_mean, running_var = self.running_mean, self.running_var
                num_batches_tracked = self.num_batches_tracked

            if self.training:
                exp_avg_factor = 0.
                if self.first_pass:
                    # only updates statistics in the first pass
                    if self.alpha:
                        # records the support-set size for alpha interpolation
                        # (a registered buffer must be assigned a tensor, not
                        # a plain int, hence the in-place fill_)
                        self.batch_size.fill_(x.size(0))
                    num_batches_tracked += 1
                    if self.momentum is None:
                        exp_avg_factor = 1. / float(num_batches_tracked)
                    else:
                        exp_avg_factor = self.momentum
                return F.batch_norm(x, running_mean, running_var, weight, bias,
                                    True, exp_avg_factor, self.eps)
            else:
                if self.alpha:
                    assert self.batch_size > 0
                    alpha = torch.sigmoid(
                        self.alpha_scale * self.batch_size + self.alpha_offset)
                    # exponentially moving-averaged training statistics
                    running_mean = running_mean.view(1, -1, 1, 1)
                    running_var = running_var.view(1, -1, 1, 1)
                    # per-sample statistics
                    sample_mean = torch.mean(x, dim=(2, 3), keepdim=True)
                    sample_var = torch.var(x, dim=(2, 3), unbiased=False,
                                           keepdim=True)
                    # interpolated statistics
                    mean = alpha * running_mean + (1 - alpha) * sample_mean
                    var = alpha * running_var + (1 - alpha) * sample_var + \
                        alpha * (1 - alpha) * (sample_mean - running_mean) ** 2
                    return self._batch_norm(x, mean, var, weight, bias)
                else:
                    return F.batch_norm(x, running_mean, running_var,
                                        weight, bias, False, 0., self.eps)
        else:
            return F.batch_norm(x, None, None, weight, bias, True, 0.,
                                self.eps)


class Sequential(nn.Sequential, Module):
    def __init__(self, *args):
        super(Sequential, self).__init__(*args)

    def forward(self, x, params=None, episode=None):
        if params is None:
            for module in self:
                x = module(x, None, episode)
        else:
            for name, module in self._modules.items():
                x = module(x, get_child_dict(params, name), episode)
        return x


================================================
FILE: test.py
================================================
import argparse
import random

import yaml
import torch
import torch.nn as nn
import numpy as np
from tqdm import tqdm
from torch.utils.data import DataLoader

import datasets
import models
import utils


def main(config):
    random.seed(0)
    np.random.seed(0)
    torch.manual_seed(0)
    torch.cuda.manual_seed(0)
    # torch.backends.cudnn.deterministic = True
    # torch.backends.cudnn.benchmark = False

    ##### Dataset #####

    dataset = datasets.make(config['dataset'], **config['test'])
    utils.log('meta-test set: {} (x{}), {}'.format(
        dataset[0][0].shape, len(dataset), dataset.n_classes))
    loader = DataLoader(dataset, config['test']['n_episode'],
                        collate_fn=datasets.collate_fn,
                        num_workers=1, pin_memory=True)

    ##### Model #####

    ckpt = torch.load(config['load'])
    inner_args = utils.config_inner_args(config.get('inner_args'))
    model = models.load(ckpt, load_clf=(not inner_args['reset_classifier']))

    if args.efficient:
        model.go_efficient()
    if config.get('_parallel'):
        model = nn.DataParallel(model)

    utils.log('num params: {}'.format(utils.compute_n_params(model)))

    ##### Evaluation #####

    model.eval()
    aves_va = utils.AverageMeter()
    va_lst = []

    for epoch in range(1, config['epoch'] + 1):
        for data in tqdm(loader, leave=False):
            x_shot, x_query, y_shot, y_query = data
            x_shot, y_shot = x_shot.cuda(), y_shot.cuda()
            x_query, y_query = x_query.cuda(), y_query.cuda()

            if inner_args['reset_classifier']:
                if config.get('_parallel'):
                    model.module.reset_classifier()
                else:
                    model.reset_classifier()

            logits = model(x_shot, x_query, y_shot, inner_args,
                           meta_train=False)
            logits = logits.view(-1, config['test']['n_way'])
            labels = y_query.view(-1)

            pred = torch.argmax(logits, dim=1)
            acc = utils.compute_acc(pred, labels)
            aves_va.update(acc, 1)
            va_lst.append(acc)

        print('test epoch {}: acc={:.2f} +- {:.2f} (%)'.format(
            epoch, aves_va.item() * 100,
            utils.mean_confidence_interval(va_lst) * 100))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', help='configuration file')
    parser.add_argument('--gpu', help='gpu device number',
                        type=str, default='0')
    parser.add_argument('--efficient',
                        help='if True, enables gradient checkpointing',
                        action='store_true')
    args = parser.parse_args()
    config = yaml.load(open(args.config, 'r'), Loader=yaml.FullLoader)

    if len(args.gpu.split(',')) > 1:
        config['_parallel'] = True
        config['_gpu'] = args.gpu

    utils.set_gpu(args.gpu)
    main(config)


================================================
FILE: train.py
================================================
import argparse
import os
import random
from collections import OrderedDict

import yaml
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm
from torch.utils.data import DataLoader
from tensorboardX import SummaryWriter

import datasets
import models
import utils
import utils.optimizers as optimizers


def main(config):
    random.seed(0)
    np.random.seed(0)
    torch.manual_seed(0)
    torch.cuda.manual_seed(0)
    # torch.backends.cudnn.deterministic = True
    # torch.backends.cudnn.benchmark = False

    ckpt_name = args.name
    if ckpt_name is None:
        ckpt_name = config['encoder']
        ckpt_name += '_' + config['dataset'].replace('meta-', '')
        ckpt_name += '_{}_way_{}_shot'.format(
            config['train']['n_way'], config['train']['n_shot'])
    if args.tag is not None:
        ckpt_name += '_' + args.tag

    ckpt_path = os.path.join('./save', ckpt_name)
    utils.ensure_path(ckpt_path)
    utils.set_log_path(ckpt_path)
    writer = SummaryWriter(os.path.join(ckpt_path, 'tensorboard'))
    yaml.dump(config, open(os.path.join(ckpt_path, 'config.yaml'), 'w'))

    ##### Dataset #####

    # meta-train
    train_set = datasets.make(config['dataset'], **config['train'])
    utils.log('meta-train set: {} (x{}), {}'.format(
        train_set[0][0].shape, len(train_set), train_set.n_classes))
    train_loader = DataLoader(
        train_set, config['train']['n_episode'],
        collate_fn=datasets.collate_fn, num_workers=1, pin_memory=True)

    # meta-val
    eval_val = False
    if config.get('val'):
        eval_val = True
        val_set = datasets.make(config['dataset'], **config['val'])
        utils.log('meta-val set: {} (x{}), {}'.format(
            val_set[0][0].shape, len(val_set), val_set.n_classes))
        val_loader = DataLoader(
            val_set, config['val']['n_episode'],
            collate_fn=datasets.collate_fn, num_workers=1, pin_memory=True)

    ##### Model and Optimizer #####

    inner_args = utils.config_inner_args(config.get('inner_args'))
    if config.get('load'):
        ckpt = torch.load(config['load'])
        config['encoder'] = ckpt['encoder']
        config['encoder_args'] = ckpt['encoder_args']
        config['classifier'] = ckpt['classifier']
        config['classifier_args'] = ckpt['classifier_args']
        model = models.load(ckpt,
                            load_clf=(not inner_args['reset_classifier']))
        optimizer, lr_scheduler = optimizers.load(ckpt, model.parameters())
        start_epoch = ckpt['training']['epoch'] + 1
        max_va = ckpt['training']['max_va']
    else:
        config['encoder_args'] = config.get('encoder_args') or dict()
        config['classifier_args'] = config.get('classifier_args') or dict()
        config['encoder_args']['bn_args']['n_episode'] = \
            config['train']['n_episode']
        config['classifier_args']['n_way'] = config['train']['n_way']
        model = models.make(config['encoder'], config['encoder_args'],
                            config['classifier'], config['classifier_args'])
        optimizer, lr_scheduler = optimizers.make(
            config['optimizer'], model.parameters(),
            **config['optimizer_args'])
        start_epoch = 1
        max_va = 0.
    if args.efficient:
        model.go_efficient()
    if config.get('_parallel'):
        model = nn.DataParallel(model)

    utils.log('num params: {}'.format(utils.compute_n_params(model)))
    timer_elapsed, timer_epoch = utils.Timer(), utils.Timer()

    ##### Training and evaluation #####

    # 'tl': meta-train loss
    # 'ta': meta-train accuracy
    # 'vl': meta-val loss
    # 'va': meta-val accuracy
    aves_keys = ['tl', 'ta', 'vl', 'va']
    trlog = dict()
    for k in aves_keys:
        trlog[k] = []

    for epoch in range(start_epoch, config['epoch'] + 1):
        timer_epoch.start()
        aves = {k: utils.AverageMeter() for k in aves_keys}

        # meta-train
        model.train()
        writer.add_scalar('lr', optimizer.param_groups[0]['lr'], epoch)
        np.random.seed(epoch)

        for data in tqdm(train_loader, desc='meta-train', leave=False):
            x_shot, x_query, y_shot, y_query = data
            x_shot, y_shot = x_shot.cuda(), y_shot.cuda()
            x_query, y_query = x_query.cuda(), y_query.cuda()

            if inner_args['reset_classifier']:
                if config.get('_parallel'):
                    model.module.reset_classifier()
                else:
                    model.reset_classifier()

            logits = model(x_shot, x_query, y_shot, inner_args,
                           meta_train=True)
            logits = logits.flatten(0, 1)
            labels = y_query.flatten()

            pred = torch.argmax(logits, dim=-1)
            acc = utils.compute_acc(pred, labels)
            loss = F.cross_entropy(logits, labels)
            aves['tl'].update(loss.item(), 1)
            aves['ta'].update(acc, 1)

            optimizer.zero_grad()
            loss.backward()
            for param in optimizer.param_groups[0]['params']:
                nn.utils.clip_grad_value_(param, 10)
            optimizer.step()

        # meta-val
        if eval_val:
            model.eval()
            np.random.seed(0)

            for data in tqdm(val_loader, desc='meta-val', leave=False):
                x_shot, x_query, y_shot, y_query = data
                x_shot, y_shot = x_shot.cuda(), y_shot.cuda()
                x_query, y_query = x_query.cuda(), y_query.cuda()

                if inner_args['reset_classifier']:
                    if config.get('_parallel'):
                        model.module.reset_classifier()
                    else:
                        model.reset_classifier()

                logits = model(x_shot, x_query, y_shot, inner_args,
                               meta_train=False)
                logits = logits.flatten(0, 1)
                labels = y_query.flatten()

                pred = torch.argmax(logits, dim=-1)
                acc = utils.compute_acc(pred, labels)
                loss = F.cross_entropy(logits, labels)
                aves['vl'].update(loss.item(), 1)
                aves['va'].update(acc, 1)

        if lr_scheduler is not None:
            lr_scheduler.step()

        for k, avg in aves.items():
            aves[k] = avg.item()
            trlog[k].append(aves[k])

        t_epoch = utils.time_str(timer_epoch.end())
        t_elapsed = utils.time_str(timer_elapsed.end())
        t_estimate = utils.time_str(
            timer_elapsed.end() / (epoch - start_epoch + 1)
            * (config['epoch'] - start_epoch + 1))

        # formats output
        log_str = 'epoch {}, meta-train {:.4f}|{:.4f}'.format(
            epoch, aves['tl'], aves['ta'])
        writer.add_scalars('loss', {'meta-train': aves['tl']}, epoch)
        writer.add_scalars('acc', {'meta-train': aves['ta']}, epoch)

        if eval_val:
            log_str += ', meta-val {:.4f}|{:.4f}'.format(
                aves['vl'], aves['va'])
            writer.add_scalars('loss', {'meta-val': aves['vl']}, epoch)
            writer.add_scalars('acc', {'meta-val': aves['va']}, epoch)

        log_str += ', {} {}/{}'.format(t_epoch, t_elapsed, t_estimate)
        utils.log(log_str)

        # saves model and meta-data
        if config.get('_parallel'):
            model_ = model.module
        else:
            model_ = model

        training = {
            'epoch': epoch,
            'max_va': max(max_va, aves['va']),
            'optimizer': config['optimizer'],
            'optimizer_args': config['optimizer_args'],
            'optimizer_state_dict': optimizer.state_dict(),
            'lr_scheduler_state_dict': (
                lr_scheduler.state_dict()
                if lr_scheduler is not None else None),
        }
        ckpt = {
            'file': __file__,
            'config': config,
            'encoder': config['encoder'],
            'encoder_args': config['encoder_args'],
            'encoder_state_dict': model_.encoder.state_dict(),
            'classifier': config['classifier'],
            'classifier_args': config['classifier_args'],
            'classifier_state_dict': model_.classifier.state_dict(),
            'training': training,
        }

        # 'epoch-last.pth': saved at the latest epoch
        # 'max-va.pth': saved when validation accuracy is at its maximum
        torch.save(ckpt, os.path.join(ckpt_path, 'epoch-last.pth'))
        torch.save(trlog, os.path.join(ckpt_path, 'trlog.pth'))
        if aves['va'] > max_va:
            max_va = aves['va']
            torch.save(ckpt, os.path.join(ckpt_path, 'max-va.pth'))

        writer.flush()


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', help='configuration file')
    parser.add_argument('--name', help='model name', type=str, default=None)
    parser.add_argument('--tag', help='auxiliary information',
                        type=str, default=None)
    parser.add_argument('--gpu', help='gpu device number',
                        type=str, default='0')
    parser.add_argument('--efficient',
                        help='if True, enables gradient checkpointing',
                        action='store_true')
    args = parser.parse_args()
    config = yaml.load(open(args.config, 'r'), Loader=yaml.FullLoader)

    if len(args.gpu.split(',')) > 1:
        config['_parallel'] = True
        config['_gpu'] = args.gpu

    utils.set_gpu(args.gpu)
    main(config)


================================================
FILE: utils/__init__.py
================================================
import os
import shutil
import time

import numpy as np
import scipy.stats as stats

_log_path = None


def set_log_path(path):
    global _log_path
    _log_path = path


def log(obj, filename='log.txt'):
    print(obj)
    if _log_path is not None:
        with open(os.path.join(_log_path, filename), 'a') as f:
            print(obj, file=f)


class AverageMeter(object):
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0.
        self.avg = 0.
        self.sum = 0.
        self.count = 0.

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def item(self):
        return self.avg


class Timer(object):
    def __init__(self):
        self.start()

    def start(self):
        self.v = time.time()

    def end(self):
        return time.time() - self.v


def set_gpu(gpu):
    print('set gpu:', gpu)
    os.environ['CUDA_VISIBLE_DEVICES'] = gpu


def ensure_path(path, remove=True):
    basename = os.path.basename(path.rstrip('/'))
    if os.path.exists(path):
        if remove and (basename.startswith('_') or
                       input('{} exists, remove? ([y]/n): '.format(path))
                       != 'n'):
            shutil.rmtree(path)
            os.makedirs(path)
    else:
        os.makedirs(path)


def time_str(t):
    if t >= 3600:
        return '{:.1f}h'.format(t / 3600)
    if t >= 60:
        return '{:.1f}m'.format(t / 60)
    return '{:.1f}s'.format(t)


def compute_acc(pred, label, reduction='mean'):
    result = (pred == label).float()
    if reduction == 'none':
        return result.detach()
    elif reduction == 'mean':
        return result.mean().item()


def compute_n_params(model, return_str=True):
    n_params = 0
    for p in model.parameters():
        n_params += p.numel()
    if return_str:
        if n_params >= 1e6:
            return '{:.1f}M'.format(n_params / 1e6)
        else:
            return '{:.1f}K'.format(n_params / 1e3)
    else:
        return n_params


def mean_confidence_interval(data, confidence=0.95):
    a = 1.0 * np.array(data)
    stderr = stats.sem(a)
    h = stderr * stats.t.ppf((1 + confidence) / 2., len(a) - 1)
    return h


def config_inner_args(inner_args):
    if inner_args is None:
        inner_args = dict()
    inner_args['reset_classifier'] = inner_args.get('reset_classifier') or False
    inner_args['n_step'] = inner_args.get('n_step') or 5
    inner_args['encoder_lr'] = inner_args.get('encoder_lr') or 0.01
    inner_args['classifier_lr'] = inner_args.get('classifier_lr') or 0.01
    inner_args['momentum'] = inner_args.get('momentum') or 0.
    inner_args['weight_decay'] = inner_args.get('weight_decay') or 0.
    inner_args['first_order'] = inner_args.get('first_order') or False
    inner_args['frozen'] = inner_args.get('frozen') or []
    return inner_args


================================================
FILE: utils/optimizers.py
================================================
from torch.optim import SGD, RMSprop, Adam
from torch.optim.lr_scheduler import MultiStepLR, CosineAnnealingLR


def make(name, params, lr, weight_decay=0., schedule='step',
         milestones=None, gamma=0.1):
    """ Prepares an optimizer and its learning-rate scheduler.

    Args:
        name (str): name of the optimizer. Options: 'sgd', 'rmsprop', 'adam'
        params (iterable): parameters to optimize.
        lr (float): initial learning rate.
        weight_decay (float, optional): weight decay. Default: 0.
        schedule (str, optional): type of learning-rate schedule.
            Options: 'step', 'cosine'. Default: 'step'
            (This argument is ignored if milestones=None.)
        milestones (int list, optional): a list of epochs at which the
            learning rate is altered. Default: None
        gamma (float, optional): multiplicative factor of learning-rate
            decay. Default: 0.1
    """
    if name == 'sgd':
        optimizer = SGD(params, lr, momentum=0.9, weight_decay=weight_decay)
    elif name == 'rmsprop':
        optimizer = RMSprop(params, lr, weight_decay=weight_decay)
    elif name == 'adam':
        optimizer = Adam(params, lr, weight_decay=weight_decay)
    else:
        raise ValueError('invalid optimizer')

    if milestones is not None:
        if schedule == 'step':
            lr_scheduler = MultiStepLR(optimizer, milestones, gamma)
        elif schedule == 'cosine':
            lr_scheduler = CosineAnnealingLR(optimizer, milestones[-1])
        else:
            # guards against an unbound lr_scheduler at the return below
            raise ValueError('invalid schedule')
    else:
        lr_scheduler = None

    return optimizer, lr_scheduler


def load(ckpt, params):
    train = ckpt['training']
    optimizer, lr_scheduler = make(
        train['optimizer'], params, **train['optimizer_args'])
    optimizer.load_state_dict(train['optimizer_state_dict'])
    if lr_scheduler is not None:
        lr_scheduler.load_state_dict(train['lr_scheduler_state_dict'])
    return optimizer, lr_scheduler
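
================================================
NOTE: parameter routing in models/modules.py
================================================
The per-module forward methods above (Conv2d, Linear, BatchNorm2d, Sequential)
all take an optional flat `params` dictionary, and Sequential routes the
relevant slice to each child via `get_child_dict(params, name)`. The routing is
pure prefix matching on the state-dict naming convention, with a fallback for
the `module.` prefix that nn.DataParallel adds. A minimal stdlib-only sketch of
that logic (the function name `get_child_dict_sketch` and the toy dictionary
are illustrative, not part of the repository):

```python
import re
from collections import OrderedDict


def get_child_dict_sketch(params, key):
    """ Filters a flat state-dict-like mapping down to the entries under
    `key` and strips the prefix, mirroring get_child_dict in
    models/modules.py (including the nn.DataParallel 'module.' fallback). """
    if params is None:
        return None
    if not key:
        return params
    key_re = re.compile(r'^{0}\.(.+)'.format(re.escape(key)))
    if not any(filter(key_re.match, params.keys())):
        # handles nn.DataParallel, which prepends 'module.' to every name
        key_re = re.compile(r'^module\.{0}\.(.+)'.format(re.escape(key)))
    return OrderedDict(
        (key_re.sub(r'\1', k), v) for k, v in params.items()
        if key_re.match(k) is not None)


# toy flat parameter dictionary, shaped like a state_dict()
params = OrderedDict([
    ('encoder.conv1.weight', 'w1'),
    ('encoder.conv1.bias', 'b1'),
    ('classifier.weight', 'w2'),
])
print(dict(get_child_dict_sketch(params, 'encoder')))
# {'conv1.weight': 'w1', 'conv1.bias': 'b1'}
```

This is what lets MAML's inner loop carry fast weights as one flat dictionary
while each submodule still receives only its own `weight`/`bias` entries.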