Repository: fmu2/PyTorch-MAML
Branch: master
Commit: eee3fe3da538
Files: 40
Total size: 83.4 KB
Directory structure:
gitextract_xgy6p1l2/
├── README.md
├── configs/
│ └── convnet4/
│ ├── mini-imagenet/
│ │ ├── 5_way_1_shot/
│ │ │ ├── test_reproduce.yaml
│ │ │ ├── test_template.yaml
│ │ │ ├── train_reproduce.yaml
│ │ │ └── train_template.yaml
│ │ └── 5_way_5_shot/
│ │ ├── test_reproduce.yaml
│ │ ├── test_template.yaml
│ │ ├── train_reproduce.yaml
│ │ └── train_template.yaml
│ └── tiered-imagenet/
│ ├── 5_way_1_shot/
│ │ ├── test_reproduce.yaml
│ │ ├── test_template.yaml
│ │ ├── train_reproduce.yaml
│ │ └── train_template.yaml
│ └── 5_way_5_shot/
│ ├── test_reproduce.yaml
│ ├── test_template.yaml
│ ├── train_reproduce.yaml
│ └── train_template.yaml
├── datasets/
│ ├── __init__.py
│ ├── cifar100.py
│ ├── cub200.py
│ ├── datasets.py
│ ├── inatural.py
│ ├── mini_imagenet.py
│ ├── tiered_imagenet.py
│ └── transforms.py
├── models/
│ ├── __init__.py
│ ├── classifiers/
│ │ ├── __init__.py
│ │ ├── classifiers.py
│ │ └── logistic.py
│ ├── encoders/
│ │ ├── __init__.py
│ │ ├── convnet4.py
│ │ ├── encoders.py
│ │ ├── resnet12.py
│ │ └── resnet18.py
│ ├── maml.py
│ └── modules.py
├── test.py
├── train.py
└── utils/
├── __init__.py
└── optimizers.py
================================================
FILE CONTENTS
================================================
================================================
FILE: README.md
================================================
# MAML in PyTorch - Re-implementation and Beyond
A PyTorch implementation of [Model Agnostic Meta-Learning (MAML)](https://arxiv.org/abs/1703.03400). We faithfully reproduce [the official Tensorflow implementation](https://github.com/cbfinn/maml) while incorporating a number of additional features that may ease further study of this very high-profile meta-learning framework.
## Overview
This repository contains code for training and evaluating MAML on the mini-ImageNet and tiered-ImageNet datasets most commonly used for few-shot image classification. To the best of our knowledge, this is the only PyTorch implementation of MAML to date that **fully reproduces the results in the original paper** without applying tricks such as data augmentation, evaluation on multiple crops, and ensemble of multiple models. Other existing PyTorch implementations typically see a ~3% gap in accuracy for the 5-way-1-shot and 5-way-5-shot classification tasks on mini-ImageNet.
Beyond reproducing the results, our implementation comes with a few extra bits that we believe can be helpful for further development of the framework. We highlight the improvements we have built into our code, and discuss our observations that warrent some attention.
## Implementation Highlights
- **Batch normalization with per-episode running statistics.** Our implementation provides flexibility of tracking global and/or per-episode running statistics, hence supporting both transductive and inductive inference.
- **Better data pre-processing.** The official implementation does not normalize and augment data. We support data normalization and a variety of data augmentation techniques. We also implement data batching and support/query-set splitting more efficiently.
- **More datasets.** We support mini-ImageNet, tiered-ImageNet and more.
- **More options for outer-loop optimization.** We support mutiple optimizers and learning-rate schedulers for the outer-loop optimization.
- **More powerful inner-loop optimization.** The official implementation uses vanilla gradient descent in the inner loop. We support momentum and weight decay.
- **More options for encoder architecture.** We support the standard four-layer ConvNet as well as ResNet-12 and ResNet-18 as the encoder.
- **Easy layer freezing.** We provide an interface for layer freezing experiments. One may freeze an arbitrary set of layers or blocks during inner-loop adaptation.
- **Meta-learning with zero-initialized classifier head.** The official implementation learns a meta-initialization for both the encoder and the classifier head. This prevents one from varying the number of categories at training or test time. With our implementation, one may opt to learn a meta-initialization for the encoder while initializing the classifier head at zero.
- **Distributed training and gradient checkpointing.** MAML is very memory-intensive because it buffers all tensors generated throughout the inner-loop adaptation steps. Gradient checkpointing trades compute for memory, effectively bringing the memory cost from O(N) down to O(1), where N is the number of inner-loop steps. In our experiments, gradient checkpointing saved up to 80% of GPU memory at the cost of running the forward pass more than once (a moderate 20% increase in running time).
## Transductive or Inductive?
The official implementation assumes transductive learning. The batch normalization layers do not track running statistics at training time, and they use mini-batch statistics at test time. The implicit assumption here is that test data come in mini-batches and are perhaps balanced across categories. This is a very restrictive assumption and does not land MAML directly comparable with the vast majority of meta-learning and few-shot learning methods. Unfortunately, this is not immediately obvious from the paper, and our findings suggest that the performance of MAML is hugely overestimated.
- **Accuracy is very sensitive to the size of query set in the transductive setting.** For example, the result for 5-way-1-shot classification on miniImageNet from the paper (48.70%) was obtained on five queries, one per category. We found that the accuracy dropped by ~1.5% given five queries per category, and by ~2.5% given 15 queries per category.
- The paper reports mean accuracy over 600 independently sampled tasks, or trials. We found that **600 trials, again in the transductive setting, are insufficient for an unbiased estimate of model performance**. The mean accuracy from 6,000 trials is more stable, and is always ~2% lower than that from the first 600 trials. We conjecture that the distribution of per-trial accuracy is highly skewed towards the high end.
- We found that **MAML performs a lot worse in the inductive setting**. Given the same model configuration, inductive accuracy is always much lower (~4%) than the *corrected* transductive accuracy, which is already a few percentage points behind the reported number.
Hence, one should be extremely cautious when comparing MAML with its competitors as is evident from the discussion above.
## FOMAML and layer freezing
Unfortunately, some insights discussed in the original paper and its follow-up works do not appear to hold in the inductive setting.
- FOMAML (i.e. the first-order approximation of MAML) performs as well as MAML in transductive learning, but fails completely in the inductive setting.
- Completely freezing the encoder during inner-loop adaption as was done in [this work](https://arxiv.org/abs/1909.09157) results in dramatic decrease in accuracy.
## BatchNorm and TaskNorm
[A recent work](https://arxiv.org/abs/2003.03284) proposes TaskNorm, a test-time enhancement of batch normalization, noting that the small batch sizes during training may leave batch normalization less effective. We did not have much success with this method. We observed marginal improvement most of the time, and found that it hurts performance occationally. That said, we do believe that batch normalization is hard to deal with in MAML. TaskNorm attempts to attack the problem of small batch sizes, which we conjecture is just one among the three main causes (i.e., extremely scarse training data, extremely small batch sizes, and extremely small number of inner-loop updates) of the ineffectiveness of batch normalization in MAML.
## Quick Start
### 0. Preliminaries
**Environment**
- Python 3.6.8 (or any Python 3 distribution)
- PyTorch 1.3.1 (or any PyTorch > 1.0)
- tensorboardX
**Datasets**
Please follow the download links [here](https://github.com/cyvius96/few-shot-meta-baseline). Please modify the file names accordingly so that they can be recognized by the data loaders.
**Configurations**
Template configuration files as well as those for reproducing the results in the original paper can be found in `configs/`. The hyperparameters are self-explanatory.
### 1. Training MAML
Here is the command for single-GPU training of MAML with ConvNet4 backbone for 5-way-1-shot classification on mini-ImageNet to reproduce the result in the original paper.
```
python train.py --config=configs/convnet4/mini-imagenet/train_reproduce.yaml
```
Use `-gpu` to specify available GPUs for multi-GPU training. For example,
```
python train.py --config=configs/convnet4/mini-imagenet/train_reproduce.yaml --gpu=0,1
```
Add `-efficient` to enable gradient checkpointing. This aggressively saves GPU memory while slightly increases running time.
```
python train.py --config=configs/convnet4/mini-imagenet/train_reproduce.yaml --efficient
```
Use `-tag` to customize the name of the directory where the checkpoints and log files are saved.
### 2. Testing MAML
Here is how one would test MAML for 5-way-1-shot classification on mini-ImageNet to reproduce the result in the original paper. Please confirm the loading path first.
```
python test.py --config=configs/convnet4/mini-imagenet/test_reproduce.yaml
```
The `-gpu` and `-efficient` tags function similarly as in training.
## Contact
[Xinchan Zhu](https://www.linkedin.com/in/xinchan-zhu-66673b106) (zhuxinchan@gmail.com)
## Cite our Repository
```
@misc{pytorch_maml,
title={maml in pytorch - re-implementation and beyond},
author={Zhu, Xinchan},
howpublished={\url{https://github.com/shirleyzhu233/PyTorch-MAML}},
year={2020}
}
```
## Related Code Repositories
Our implementation is inspired by the following repositories.
* maml (the official implementation) <https://github.com/cbfinn/maml>
* MAML-Pytorch <https://github.com/dragen1860/MAML-Pytorch>
* HowToTrainYourMAMLPytorch <https://github.com/AntreasAntoniou/HowToTrainYourMAMLPytorch>
* memory-efficient-maml <https://github.com/dbaranchuk/memory-efficient-maml>
## References
```
@inproceedings{finn2017model,
title={Model-agnostic meta-learning for fast adaptation of deep networks},
author={Finn, Chelsea and Abbeel, Pieter and Levine, Sergey},
booktitle={International Conference on Machine Learning (ICML)},
year={2017}
}
@inproceedings{raghu2019rapid,
title={Rapid learning or feature reuse? towards understanding the effectiveness of maml},
author={Raghu, Aniruddh and Raghu, Maithra and Bengio, Samy and Vinyals, Oriol},
booktitle={International Conference on Learning Representations (ICLR)},
year={2019}
}
@article{Bronskill2020tasknorm,
title={Tasknorm: rethinking batch normalization for meta-learning},
author={Bronskill, John and Gordon, Jonathan and Requeima, James and Nowozin, Sebastian and Turner, Richard E.},
journal={arXiv preprint arXiv:2003.03284},
year={2020}
}
```
================================================
FILE: configs/convnet4/mini-imagenet/5_way_1_shot/test_reproduce.yaml
================================================
dataset: meta-mini-imagenet
test:
split: meta-test
image_size: 84
normalization: False
transform: null
n_batch: 150
n_episode: 4
n_way: 5
n_shot: 1
n_query: 1
load: ./save/convnet4_mini-imagenet_5_way_1_shot/max-va.pth
inner_args:
n_step: 10
encoder_lr: 0.01
classifier_lr: 0.01
first_order: False # set to True for FOMAML
frozen:
- bn
epoch: 1
================================================
FILE: configs/convnet4/mini-imagenet/5_way_1_shot/test_template.yaml
================================================
dataset: meta-mini-imagenet
test:
split: meta-test
image_size: 84
normalization: True
transform: flip
n_batch: 200
n_episode: 4
n_way: 5
n_shot: 1
n_query: 15
load: ./save/convnet4_mini-imagenet_5_way_1_shot/max-va.pth
inner_args:
reset_classifier: True
n_step: 5
encoder_lr: 0.01
classifier_lr: 0.01
momentum: 0.9
weight_decay: 5.e-4
first_order: False
epoch: 10
================================================
FILE: configs/convnet4/mini-imagenet/5_way_1_shot/train_reproduce.yaml
================================================
dataset: meta-mini-imagenet
train:
split: meta-train
image_size: 84
normalization: False
transform: null
n_batch: 200
n_episode: 4
n_way: 5
n_shot: 1
n_query: 15
val:
split: meta-val
image_size: 84
normalization: False
transform: null
n_batch: 200
n_episode: 4
n_way: 5
n_shot: 1
n_query: 15
encoder: convnet4
encoder_args:
bn_args:
track_running_stats: False
classifier: logistic
inner_args:
n_step: 5
encoder_lr: 0.01
classifier_lr: 0.01
first_order: False # set to True for FOMAML
frozen:
- bn
optimizer: adam
optimizer_args:
lr: 0.001
epoch: 300
================================================
FILE: configs/convnet4/mini-imagenet/5_way_1_shot/train_template.yaml
================================================
dataset: meta-mini-imagenet
train:
split: meta-train
image_size: 84
normalization: True
transform: flip
n_batch: 200
n_episode: 4
n_way: 5
n_shot: 1
n_query: 15
val:
split: meta-val
image_size: 84
normalization: True
transform: flip
n_batch: 200
n_episode: 4
n_way: 5
n_shot: 1
n_query: 15
encoder: convnet4
encoder_args:
bn_args:
track_running_stats: True
episodic:
- conv1
- conv2
- conv3
- conv4
classifier: logistic
inner_args:
reset_classifier: True
n_step: 5
encoder_lr: 0.01
classifier_lr: 0.01
momentum: 0.9
weight_decay: 5.e-4
first_order: False
optimizer: sgd
optimizer_args:
lr: 0.01
weight_decay: 5.e-4
schedule: step
milestones:
- 120
- 140
epoch: 150
================================================
FILE: configs/convnet4/mini-imagenet/5_way_5_shot/test_reproduce.yaml
================================================
dataset: meta-mini-imagenet
test:
split: meta-test
image_size: 84
normalization: False
transform: null
n_batch: 150
n_episode: 4
n_way: 5
n_shot: 5
n_query: 5
load: ./save/convnet4_mini-imagenet_5_way_5_shot/max-va.pth
inner_args:
n_step: 10
encoder_lr: 0.01
classifier_lr: 0.01
first_order: False
frozen:
- bn
epoch: 1
================================================
FILE: configs/convnet4/mini-imagenet/5_way_5_shot/test_template.yaml
================================================
dataset: meta-mini-imagenet
test:
split: meta-test
image_size: 84
normalization: True
transform: flip
n_batch: 200
n_episode: 4
n_way: 5
n_shot: 5
n_query: 15
load: ./save/convnet4_mini-imagenet_5_way_5_shot/max-va.pth
inner_args:
reset_classifier: True
n_step: 5
encoder_lr: 0.01
classifier_lr: 0.01
momentum: 0.9
weight_decay: 5.e-4
first_order: False
epoch: 10
================================================
FILE: configs/convnet4/mini-imagenet/5_way_5_shot/train_reproduce.yaml
================================================
dataset: meta-mini-imagenet
train:
split: meta-train
image_size: 84
normalization: False
transform: null
n_batch: 200
n_episode: 4
n_way: 5
n_shot: 5
n_query: 15
val:
split: meta-val
image_size: 84
normalization: False
transform: null
n_batch: 200
n_episode: 4
n_way: 5
n_shot: 5
n_query: 15
encoder: convnet4
encoder_args:
bn_args:
track_running_stats: False
classifier: logistic
inner_args:
n_step: 5
encoder_lr: 0.01
classifier_lr: 0.01
first_order: False
frozen:
- bn
optimizer: adam
optimizer_args:
lr: 0.001
epoch: 300
================================================
FILE: configs/convnet4/mini-imagenet/5_way_5_shot/train_template.yaml
================================================
dataset: meta-mini-imagenet
train:
split: meta-train
image_size: 84
normalization: True
transform: flip
n_batch: 200
n_episode: 4
n_way: 5
n_shot: 5
n_query: 15
val:
split: meta-val
image_size: 84
normalization: True
transform: flip
n_batch: 200
n_episode: 4
n_way: 5
n_shot: 5
n_query: 15
encoder: convnet4
encoder_args:
bn_args:
track_running_stats: True
episodic:
- conv1
- conv2
- conv3
- conv4
classifier: logistic
inner_args:
reset_classifier: True
n_step: 5
encoder_lr: 0.01
classifier_lr: 0.01
momentum: 0.9
weight_decay: 5.e-4
first_order: False
optimizer: sgd
optimizer_args:
lr: 0.01
weight_decay: 5.e-4
schedule: step
milestones:
- 120
- 140
epoch: 150
================================================
FILE: configs/convnet4/tiered-imagenet/5_way_1_shot/test_reproduce.yaml
================================================
dataset: meta-tiered-imagenet
test:
split: meta-test
image_size: 84
normalization: False
transform: null
n_batch: 150
n_episode: 4
n_way: 5
n_shot: 1
n_query: 1
load: ./save/wide-convnet4_tiered-imagenet_5_way_1_shot/max-va.pth
inner_args:
n_step: 10
encoder_lr: 0.01
classifier_lr: 0.01
first_order: False # set to True for FOMAML
frozen:
- bn
epoch: 1
================================================
FILE: configs/convnet4/tiered-imagenet/5_way_1_shot/test_template.yaml
================================================
dataset: meta-tiered-imagenet
test:
split: meta-test
image_size: 84
normalization: True
transform: flip
n_batch: 200
n_episode: 4
n_way: 5
n_shot: 1
n_query: 15
load: ./save/wide-convnet4_tiered-imagenet_5_way_1_shot/max-va.pth
inner_args:
reset_classifier: True
n_step: 5
encoder_lr: 0.01
classifier_lr: 0.01
momentum: 0.9
weight_decay: 5.e-4
first_order: False
epoch: 10
================================================
FILE: configs/convnet4/tiered-imagenet/5_way_1_shot/train_reproduce.yaml
================================================
dataset: meta-tiered-imagenet
train:
split: meta-train
image_size: 84
normalization: False
transform: null
n_batch: 200
n_episode: 4
n_way: 5
n_shot: 1
n_query: 15
val:
split: meta-val
image_size: 84
normalization: False
transform: null
n_batch: 200
n_episode: 4
n_way: 5
n_shot: 1
n_query: 15
encoder: wide-convnet4
encoder_args:
bn_args:
track_running_stats: False
classifier: logistic
inner_args:
n_step: 5
encoder_lr: 0.01
classifier_lr: 0.01
first_order: False # set to True for FOMAML
frozen:
- bn
optimizer: adam
optimizer_args:
lr: 0.001
epoch: 300
================================================
FILE: configs/convnet4/tiered-imagenet/5_way_1_shot/train_template.yaml
================================================
dataset: meta-tiered-imagenet
train:
split: meta-train
image_size: 84
normalization: True
transform: flip
n_batch: 200
n_episode: 4
n_way: 5
n_shot: 1
n_query: 15
val:
split: meta-val
image_size: 84
normalization: True
transform: flip
n_batch: 200
n_episode: 4
n_way: 5
n_shot: 1
n_query: 15
encoder: wide-convnet4
encoder_args:
bn_args:
track_running_stats: True
episodic:
- conv1
- conv2
- conv3
- conv4
classifier: logistic
inner_args:
reset_classifier: True
n_step: 5
encoder_lr: 0.01
classifier_lr: 0.01
momentum: 0.9
weight_decay: 5.e-4
first_order: False
optimizer: sgd
optimizer_args:
lr: 0.01
weight_decay: 5.e-4
schedule: step
milestones:
- 120
- 140
epoch: 150
================================================
FILE: configs/convnet4/tiered-imagenet/5_way_5_shot/test_reproduce.yaml
================================================
dataset: meta-tiered-imagenet
test:
split: meta-test
image_size: 84
normalization: False
transform: null
n_batch: 150
n_episode: 4
n_way: 5
n_shot: 5
n_query: 5
load: ./save/wide-convnet4_tiered-imagenet_5_way_5_shot/max-va.pth
inner_args:
n_step: 10
encoder_lr: 0.01
classifier_lr: 0.01
first_order: False # set to True for FOMAML
frozen:
- bn
epoch: 1
================================================
FILE: configs/convnet4/tiered-imagenet/5_way_5_shot/test_template.yaml
================================================
dataset: meta-tiered-imagenet
test:
split: meta-test
image_size: 84
normalization: True
transform: flip
n_batch: 200
n_episode: 4
n_way: 5
n_shot: 5
n_query: 15
load: ./save/wide-convnet4_tiered-imagenet_5_way_5_shot/max-va.pth
inner_args:
reset_classifier: True
n_step: 5
encoder_lr: 0.01
classifier_lr: 0.01
momentum: 0.9
weight_decay: 5.e-4
first_order: False
epoch: 10
================================================
FILE: configs/convnet4/tiered-imagenet/5_way_5_shot/train_reproduce.yaml
================================================
dataset: meta-tiered-imagenet
train:
split: meta-train
image_size: 84
normalization: False
transform: null
n_batch: 200
n_episode: 4
n_way: 5
n_shot: 5
n_query: 15
val:
split: meta-val
image_size: 84
normalization: False
transform: null
n_batch: 200
n_episode: 4
n_way: 5
n_shot: 5
n_query: 15
encoder: wide-convnet4
encoder_args:
bn_args:
track_running_stats: False
classifier: logistic
inner_args:
n_step: 5
encoder_lr: 0.01
classifier_lr: 0.01
first_order: False # set to True for FOMAML
frozen:
- bn
optimizer: adam
optimizer_args:
lr: 0.001
epoch: 300
================================================
FILE: configs/convnet4/tiered-imagenet/5_way_5_shot/train_template.yaml
================================================
dataset: meta-tiered-imagenet
train:
split: meta-train
image_size: 84
normalization: True
transform: flip
n_batch: 200
n_episode: 4
n_way: 5
n_shot: 5
n_query: 15
val:
split: meta-val
image_size: 84
normalization: True
transform: flip
n_batch: 200
n_episode: 4
n_way: 5
n_shot: 5
n_query: 15
encoder: wide-convnet4
encoder_args:
bn_args:
track_running_stats: True
episodic:
- conv1
- conv2
- conv3
- conv4
classifier: logistic
inner_args:
reset_classifier: True
n_step: 5
encoder_lr: 0.01
classifier_lr: 0.01
momentum: 0.9
weight_decay: 5.e-4
first_order: False
optimizer: sgd
optimizer_args:
lr: 0.01
weight_decay: 5.e-4
schedule: step
milestones:
- 120
- 140
epoch: 150
================================================
FILE: datasets/__init__.py
================================================
from .datasets import make, collate_fn
from . import mini_imagenet
from . import tiered_imagenet
from . import cifar100
from . import cub200
from . import inatural
from . import transforms
================================================
FILE: datasets/cifar100.py
================================================
import os
import pickle
import torch
from torch.utils.data import Dataset
import numpy as np
from PIL import Image
from .datasets import register
from .transforms import get_transform
class Cifar100(Dataset):
def __init__(self, root_path, split='train', image_size=32,
normalization=True, transform=None):
super(Cifar100, self).__init__()
split_dict = {'train': 'train', # standard train
'trainval': 'trainval', # standard train + val
'meta-train': 'train', # meta-train
'meta-val': 'val', # meta-val
'meta-trainval': 'trainval', # meta-train + meta-val
'meta-test': 'test', # meta-test
}
split_tag = split_dict[split]
split_file = os.path.join(root_path, split_tag + '.pickle')
assert os.path.isfile(split_file)
with open(split_file, 'rb') as f:
pack = pickle.load(f, encoding='latin1')
data, label = pack['data'], pack['labels']
data = [Image.fromarray(x) for x in data]
label = np.array(label)
label_key = sorted(np.unique(label))
label_map = dict(zip(label_key, range(len(label_key))))
new_label = np.array([label_map[x] for x in label])
self.root_path = root_path
self.split_tag = split_tag
self.image_size = image_size
self.data = data
self.label = new_label
self.n_classes = len(label_key)
if normalization:
self.norm_params = {'mean': [0.5071, 0.4867, 0.4408],
'std': [0.2675, 0.2565, 0.2761]}
else:
self.norm_params = {'mean': [0., 0., 0.],
'std': [1., 1., 1.]}
self.transform = get_transform(transform, image_size, self.norm_params)
def convert_raw(x):
mean = torch.tensor(self.norm_params['mean']).view(3, 1, 1).type_as(x)
std = torch.tensor(self.norm_params['std']).view(3, 1, 1).type_as(x)
return x * std + mean
self.convert_raw = convert_raw
def __len__(self):
return len(self.data)
def __getitem__(self, index):
image = self.transform(self.data[index])
label = self.label[index]
return image, label
class MetaCifar100(Cifar100):
def __init__(self, root_path, split='train', image_size=32,
normalization=True, transform=None, val_transform=None,
n_batch=200, n_episode=4, n_way=5, n_shot=1, n_query=15):
super(MetaCifar100, self).__init__(root_path, split, image_size,
normalization, transform)
self.n_batch = n_batch
self.n_episode = n_episode
self.n_way = n_way
self.n_shot = n_shot
self.n_query = n_query
self.catlocs = tuple()
for cat in range(self.n_classes):
self.catlocs += (np.argwhere(self.label == cat).reshape(-1),)
self.val_transform = get_transform(
val_transform, image_size, self.norm_params)
def __len__(self):
return self.n_batch * self.n_episode
def __getitem__(self, index):
shot, query = [], []
cats = np.random.choice(self.n_classes, self.n_way, replace=False)
for c in cats:
c_shot, c_query = [], []
idx_list = np.random.choice(
self.catlocs[c], self.n_shot + self.n_query, replace=False)
shot_idx, query_idx = idx_list[:self.n_shot], idx_list[-self.n_query:]
for idx in shot_idx:
c_shot.append(self.transform(self.data[idx]))
for idx in query_idx:
c_query.append(self.val_transform(self.data[idx]))
shot.append(torch.stack(c_shot))
query.append(torch.stack(c_query))
shot = torch.cat(shot, dim=0) # [n_way * n_shot, C, H, W]
query = torch.cat(query, dim=0) # [n_way * n_query, C, H, W]
cls = torch.arange(self.n_way)[:, None]
shot_labels = cls.repeat(1, self.n_shot).flatten() # [n_way * n_shot]
query_labels = cls.repeat(1, self.n_query).flatten() # [n_way * n_query]
return shot, query, shot_labels, query_labels
@register('cifar-fs')
class CifarFS(Cifar100):
def __init__(self, *args):
super(CifarFS, self).__init__(*args)
@register('meta-cifar-fs')
class MetaCifarFS(MetaCifar100):
def __init__(self, *args):
super(MetaCifarFS, self).__init__(*args)
@register('fc100')
class FC100(Cifar100):
def __init__(self, *args):
super(FC100, self).__init__(*args)
@register('meta-fc100')
class MetaFC100(MetaCifar100):
def __init__(self, *args):
super(MetaFC100, self).__init__(*args)
================================================
FILE: datasets/cub200.py
================================================
import os
import torch
from torch.utils.data import Dataset
import numpy as np
from PIL import Image
from .datasets import register
from .transforms import get_transform
@register('cub200')
class CUB200(Dataset):
def __init__(self, root_path, split='train', image_size=84,
normalization=True, transform=None):
super(CUB200, self).__init__()
split_dict = {'train': 'train', # standard train
'meta-train': 'train', # meta-train
'meta-val': 'val', # meta-val
'meta-test': 'test', # meta-test
}
split_tag = split_dict[split]
split_file = os.path.join(root_path, 'fs-splits', split_tag + '.csv')
assert os.path.isfile(split_file)
with open(split_file, 'r') as f:
pairs = [x.strip().split(',')
for x in f.readlines() if x.strip() != '']
data, label = [x[0] for x in pairs], [int(x[1]) for x in pairs]
label = np.array(label)
label_key = sorted(np.unique(label))
label_map = dict(zip(label_key, range(len(label_key))))
new_label = np.array([label_map[x] for x in label])
self.root_path = root_path
self.split_tag = split_tag
self.image_size = image_size
self.data = data
self.label = new_label
self.n_classes = len(label_key)
if normalization:
self.norm_params = {'mean': [0.485, 0.456, 0.406],
'std': [0.229, 0.224, 0.225]} # ImageNet statistics
else:
self.norm_params = {'mean': [0., 0., 0.],
'std': [1., 1., 1.]}
self.transform = get_transform(transform, image_size, self.norm_params)
def convert_raw(x):
mean = torch.tensor(self.norm_params['mean']).view(3, 1, 1).type_as(x)
std = torch.tensor(self.norm_params['std']).view(3, 1, 1).type_as(x)
return x * std + mean
self.convert_raw = convert_raw
def _load_image(self, index):
image_path = os.path.join(self.root_path, 'images', self.data[index])
assert os.path.isfile(image_path)
image = Image.open(image_path).convert('RGB')
return image
def __len__(self):
return len(self.label)
def __getitem__(self, index):
image = self.transform(self._load_image(index))
label = self.label[index]
return image, label
@register('meta-cub200')
class MetaCUB200(CUB200):
def __init__(self, root_path, split='train', image_size=84,
normalization=True, transform=None, val_transform=None,
n_batch=200, n_episode=4, n_way=5, n_shot=1, n_query=15):
super(MetaCUB200, self).__init__(root_path, split, image_size,
normalization, transform)
self.n_batch = n_batch
self.n_episode = n_episode
self.n_way = n_way
self.n_shot = n_shot
self.n_query = n_query
self.catlocs = tuple()
for cat in range(self.n_classes):
self.catlocs += (np.argwhere(self.label == cat).reshape(-1),)
self.val_transform = get_transform(
val_transform, image_size, self.norm_params)
def __len__(self):
return self.n_batch * self.n_episode
def __getitem__(self, index):
shot, query = [], []
cats = np.random.choice(self.n_classes, self.n_way, replace=False)
for c in cats:
c_shot, c_query = [], []
idx_list = np.random.choice(
self.catlocs[c], self.n_shot + self.n_query, replace=False)
shot_idx, query_idx = idx_list[:self.n_shot], idx_list[-self.n_query:]
for idx in shot_idx:
c_shot.append(self.transform(self._load_image(idx)))
for idx in query_idx:
c_query.append(self.val_transform(self._load_image(idx)))
shot.append(torch.stack(c_shot))
query.append(torch.stack(c_query))
shot = torch.cat(shot, dim=0) # [n_way * n_shot, C, H, W]
query = torch.cat(query, dim=0) # [n_way * n_query, C, H, W]
cls = torch.arange(self.n_way)[:, None]
shot_labels = cls.repeat(1, self.n_shot).flatten() # [n_way * n_shot]
query_labels = cls.repeat(1, self.n_query).flatten() # [n_way * n_query]
return shot, query, shot_labels, query_labels
================================================
FILE: datasets/datasets.py
================================================
import os
import torch
DEFAULT_ROOT = './materials'
datasets = {}
def register(name):
def decorator(cls):
datasets[name] = cls
return cls
return decorator
def make(name, **kwargs):
if kwargs.get('root_path') is None:
kwargs['root_path'] = os.path.join(DEFAULT_ROOT, name.replace('meta-', ''))
dataset = datasets[name](**kwargs)
return dataset
def collate_fn(batch):
shot, query, shot_label, query_label = [], [], [], []
for s, q, sl, ql in batch:
shot.append(s)
query.append(q)
shot_label.append(sl)
query_label.append(ql)
shot = torch.stack(shot) # [n_ep, n_way * n_shot, C, H, W]
query = torch.stack(query) # [n_ep, n_way * n_query, C, H, W]
shot_label = torch.stack(shot_label) # [n_ep, n_way * n_shot]
query_label = torch.stack(query_label) # [n_ep, n_way * n_query]
return shot, query, shot_label, query_label
================================================
FILE: datasets/inatural.py
================================================
import os
import torch
from torch.utils.data import Dataset
import numpy as np
from PIL import Image
from .datasets import register
from .transforms import get_transform
@register('inatural')
class INat2017(Dataset):
def __init__(self, root_path, split='train', image_size=84,
normalization=True, transform=None):
super(INat2017, self).__init__()
split_dict = {'train': 'train', # standard train
'meta-train': 'train', # meta-train
'meta-test': 'test', # meta-test
}
split_tag = split_dict[split]
split_file = os.path.join(root_path, 'fs-splits', split_tag + '.csv')
assert os.path.isfile(split_file)
with open(split_file, 'r') as f:
pairs = [x.strip().split(',')
for x in f.readlines() if x.strip() != '']
data, label = [x[0] for x in pairs], [int(x[1]) for x in pairs]
label = np.array(label)
label_key = sorted(np.unique(label))
label_map = dict(zip(label_key, range(len(label_key))))
new_label = np.array([label_map[x] for x in label])
self.root_path = root_path
self.split_tag = split_tag
self.image_size = image_size
self.data = data
self.label = new_label
self.n_classes = len(label_key)
if normalization:
self.norm_params = {'mean': [0.4905, 0.4961, 0.4330],
'std': [0.1737, 0.1713, 0.1779]}
else:
self.norm_params = {'mean': [0., 0., 0.],
'std': [1., 1., 1.]}
self.transform = get_transform(transform, image_size, self.norm_params)
def convert_raw(x):
mean = torch.tensor(self.norm_params['mean']).view(3, 1, 1).type_as(x)
std = torch.tensor(self.norm_params['std']).view(3, 1, 1).type_as(x)
return x * std + mean
self.convert_raw = convert_raw
def _load_image(self, index):
image_path = os.path.join(self.root_path, 'images', self.data[index])
assert os.path.isfile(image_path)
image = Image.open(image_path).convert('RGB')
return image
def __len__(self):
return len(self.label)
def __getitem__(self, index):
image = self.transform(self._load_image(index))
label = self.label[index]
return image, label
@register('meta-inatural')
class MetaINat2017(INat2017):
def __init__(self, root_path, split='train', image_size=84,
normalization=True, transform=None, val_transform=None,
n_batch=200, n_episode=4, n_way=5, n_shot=1, n_query=15):
super(MetaINat2017, self).__init__(root_path, split, image_size,
normalization, transform)
self.n_batch = n_batch
self.n_episode = n_episode
self.n_way = n_way
self.n_shot = n_shot
self.n_query = n_query
self.catlocs = tuple()
for cat in range(self.n_classes):
self.catlocs += (np.argwhere(self.label == cat).reshape(-1),)
self.val_transform = get_transform(
val_transform, image_size, self.norm_params)
def __len__(self):
return self.n_batch * self.n_episode
def __getitem__(self, index):
shot, query = [], []
cats = np.random.choice(self.n_classes, self.n_way, replace=False)
for c in cats:
c_shot, c_query = [], []
idx_list = np.random.choice(
self.catlocs[c], self.n_shot + self.n_query, replace=False)
shot_idx, query_idx = idx_list[:self.n_shot], idx_list[-self.n_query:]
for idx in shot_idx:
c_shot.append(self.transform(self._load_image(idx)))
for idx in query_idx:
c_query.append(self.val_transform(self._load_image(idx)))
shot.append(torch.stack(c_shot))
query.append(torch.stack(c_query))
shot = torch.cat(shot, dim=0) # [n_way * n_shot, C, H, W]
query = torch.cat(query, dim=0) # [n_way * n_query, C, H, W]
cls = torch.arange(self.n_way)[:, None]
shot_labels = cls.repeat(1, self.n_shot).flatten() # [n_way * n_shot]
query_labels = cls.repeat(1, self.n_query).flatten() # [n_way * n_query]
return shot, query, shot_labels, query_labels
================================================
FILE: datasets/mini_imagenet.py
================================================
import os
import pickle
import torch
from torch.utils.data import Dataset
import numpy as np
from PIL import Image
from .datasets import register
from .transforms import get_transform
@register('mini-imagenet')
class MiniImageNet(Dataset):
def __init__(self, root_path, split='train', image_size=84,
normalization=True, transform=None):
super(MiniImageNet, self).__init__()
split_dict = {'train': 'train_phase_train', # standard train
'val': 'train_phase_val', # standard val
'trainval': 'train_phase_trainval', # standard train and val
'test': 'train_phase_test', # standard test
'meta-train': 'train_phase_train', # meta-train
'meta-val': 'val', # meta-val
'meta-test': 'test', # meta-test
}
split_tag = split_dict[split]
split_file = os.path.join(root_path, split_tag + '.pickle')
assert os.path.isfile(split_file)
with open(split_file, 'rb') as f:
pack = pickle.load(f, encoding='latin1')
data, label = pack['data'], pack['labels']
data = [Image.fromarray(x) for x in data]
label = np.array(label)
label_key = sorted(np.unique(label))
label_map = dict(zip(label_key, range(len(label_key))))
new_label = np.array([label_map[x] for x in label])
self.root_path = root_path
self.split_tag = split_tag
self.image_size = image_size
self.data = data
self.label = new_label
self.n_classes = len(label_key)
if normalization:
self.norm_params = {'mean': [0.471, 0.450, 0.403],
'std': [0.278, 0.268, 0.284]}
else:
self.norm_params = {'mean': [0., 0., 0.],
'std': [1., 1., 1.]}
self.transform = get_transform(transform, image_size, self.norm_params)
def convert_raw(x):
mean = torch.tensor(self.norm_params['mean']).view(3, 1, 1).type_as(x)
std = torch.tensor(self.norm_params['std']).view(3, 1, 1).type_as(x)
return x * std + mean
self.convert_raw = convert_raw
def __len__(self):
return len(self.data)
def __getitem__(self, index):
image = self.transform(self.data[index])
label = self.label[index]
return image, label
@register('meta-mini-imagenet')
class MetaMiniImageNet(MiniImageNet):
def __init__(self, root_path, split='train', image_size=84,
normalization=True, transform=None, val_transform=None,
n_batch=200, n_episode=4, n_way=5, n_shot=1, n_query=15):
super(MetaMiniImageNet, self).__init__(root_path, split, image_size,
normalization, transform)
self.n_batch = n_batch
self.n_episode = n_episode
self.n_way = n_way
self.n_shot = n_shot
self.n_query = n_query
self.catlocs = tuple()
for cat in range(self.n_classes):
self.catlocs += (np.argwhere(self.label == cat).reshape(-1),)
self.val_transform = get_transform(
val_transform, image_size, self.norm_params)
def __len__(self):
return self.n_batch * self.n_episode
def __getitem__(self, index):
shot, query = [], []
cats = np.random.choice(self.n_classes, self.n_way, replace=False)
for c in cats:
c_shot, c_query = [], []
idx_list = np.random.choice(
self.catlocs[c], self.n_shot + self.n_query, replace=False)
shot_idx, query_idx = idx_list[:self.n_shot], idx_list[-self.n_query:]
for idx in shot_idx:
c_shot.append(self.transform(self.data[idx]))
for idx in query_idx:
c_query.append(self.val_transform(self.data[idx]))
shot.append(torch.stack(c_shot))
query.append(torch.stack(c_query))
shot = torch.cat(shot, dim=0) # [n_way * n_shot, C, H, W]
query = torch.cat(query, dim=0) # [n_way * n_query, C, H, W]
cls = torch.arange(self.n_way)[:, None]
shot_labels = cls.repeat(1, self.n_shot).flatten() # [n_way * n_shot]
query_labels = cls.repeat(1, self.n_query).flatten() # [n_way * n_query]
return shot, query, shot_labels, query_labels
================================================
FILE: datasets/tiered_imagenet.py
================================================
import os
import pickle
import torch
from torch.utils.data import Dataset
import numpy as np
from PIL import Image
from .datasets import register
from .transforms import get_transform
@register('tiered-imagenet')
class TieredImageNet(Dataset):
def __init__(self, root_path, split='train', image_size=84,
normalization=True, transform=None):
super(TieredImageNet, self).__init__()
split_dict = {'train': 'train', # standard train
'val': 'train_phase_val', # standard val
'meta-train': 'train', # meta-train
'meta-val': 'val', # meta-val
'meta-test': 'test', # meta-test
}
split_tag = split_dict[split]
split_file = os.path.join(root_path, split_tag + '_images.npz')
label_file = os.path.join(root_path, split_tag + '_labels.pkl')
assert os.path.isfile(split_file)
assert os.path.isfile(label_file)
data = np.load(split_file, allow_pickle=True)['images']
data = data[:, :, :, ::-1]
with open(label_file, 'rb') as f:
label = pickle.load(f)['labels']
data = [Image.fromarray(x) for x in data]
label = np.array(label)
label_key = sorted(np.unique(label))
label_map = dict(zip(label_key, range(len(label_key))))
new_label = np.array([label_map[x] for x in label])
self.root_path = root_path
self.split_tag = split_tag
self.image_size = image_size
self.data = data
self.label = new_label
self.n_classes = len(label_key)
if normalization:
self.norm_params = {'mean': [0.478, 0.456, 0.410],
'std': [0.279, 0.274, 0.286]}
else:
self.norm_params = {'mean': [0., 0., 0.],
'std': [1., 1., 1.]}
self.transform = get_transform(transform, image_size, self.norm_params)
def convert_raw(x):
mean = torch.tensor(self.norm_params['mean']).view(3, 1, 1).type_as(x)
std = torch.tensor(self.norm_params['std']).view(3, 1, 1).type_as(x)
return x * std + mean
self.convert_raw = convert_raw
def __len__(self):
return len(self.data)
def __getitem__(self, index):
image = self.transform(self.data[index])
label = self.label[index]
return image, label
@register('meta-tiered-imagenet')
class MetaTieredImageNet(TieredImageNet):
def __init__(self, root_path, split='train', image_size=84,
normalization=True, transform=None, val_transform=None,
n_batch=200, n_episode=4, n_way=5, n_shot=1, n_query=15):
super(MetaTieredImageNet, self).__init__(root_path, split, image_size,
normalization, transform)
self.n_batch = n_batch
self.n_episode = n_episode
self.n_way = n_way
self.n_shot = n_shot
self.n_query = n_query
self.catlocs = tuple()
for cat in range(self.n_classes):
self.catlocs += (np.argwhere(self.label == cat).reshape(-1),)
self.val_transform = get_transform(
val_transform, image_size, self.norm_params)
def __len__(self):
return self.n_batch * self.n_episode
def __getitem__(self, index):
shot, query = [], []
cats = np.random.choice(self.n_classes, self.n_way, replace=False)
for c in cats:
c_shot, c_query = [], []
idx_list = np.random.choice(
self.catlocs[c], self.n_shot + self.n_query, replace=False)
shot_idx, query_idx = idx_list[:self.n_shot], idx_list[-self.n_query:]
for idx in shot_idx:
c_shot.append(self.transform(self.data[idx]))
for idx in query_idx:
c_query.append(self.val_transform(self.data[idx]))
shot.append(torch.stack(c_shot))
query.append(torch.stack(c_query))
shot = torch.cat(shot, dim=0) # [n_way * n_shot, C, H, W]
query = torch.cat(query, dim=0) # [n_way * n_query, C, H, W]
cls = torch.arange(self.n_way)[:, None]
shot_labels = cls.repeat(1, self.n_shot).flatten() # [n_way * n_shot]
query_labels = cls.repeat(1, self.n_query).flatten() # [n_way * n_query]
return shot, query, shot_labels, query_labels
================================================
FILE: datasets/transforms.py
================================================
import torchvision.transforms as transforms
__all__ = ['get_transform']
def get_transform(name, image_size, norm_params):
if name == 'resize':
return transforms.Compose([
transforms.RandomResizedCrop(image_size),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(**norm_params),
])
elif name == 'crop':
return transforms.Compose([
transforms.Resize(image_size),
transforms.RandomCrop(image_size, padding=8),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(**norm_params),
])
elif name == 'color':
return transforms.Compose([
transforms.Resize(image_size),
transforms.RandomCrop(image_size, padding=8),
transforms.ColorJitter(
brightness=0.4, contrast=0.4, saturation=0.4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(**norm_params),
])
elif name == 'flip':
return transforms.Compose([
transforms.Resize(image_size),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(**norm_params),
])
elif name == 'enlarge':
return transforms.Compose([
transforms.Resize(int(image_size * 256 / 224)),
transforms.CenterCrop(image_size),
transforms.ToTensor(),
transforms.Normalize(**norm_params),
])
elif name is None:
return transforms.Compose([
transforms.Resize(image_size),
transforms.ToTensor(),
transforms.Normalize(**norm_params),
])
else:
raise ValueError('invalid transformation')
================================================
FILE: models/__init__.py
================================================
from .maml import make
from .maml import load
================================================
FILE: models/classifiers/__init__.py
================================================
from .classifiers import make, load
from . import logistic
================================================
FILE: models/classifiers/classifiers.py
================================================
import torch
__all__ = ['make', 'load']
models = {}
def register(name):
def decorator(cls):
models[name] = cls
return cls
return decorator
def make(name, **kwargs):
if name is None:
return None
model = models[name](**kwargs)
if torch.cuda.is_available():
model.cuda()
return model
def load(ckpt):
model = make(ckpt['classifier'], **ckpt['classifier_args'])
model.load_state_dict(ckpt['classifier_state_dict'])
return model
================================================
FILE: models/classifiers/logistic.py
================================================
import torch
import torch.nn as nn
from .classifiers import register
from ..modules import *
__all__ = ['LogisticClassifier']
@register('logistic')
class LogisticClassifier(Module):
def __init__(self, in_dim, n_way, temp=1., learn_temp=False):
super(LogisticClassifier, self).__init__()
self.in_dim = in_dim
self.n_way = n_way
self.temp = temp
self.learn_temp = learn_temp
self.linear = Linear(in_dim, n_way)
if self.learn_temp:
self.temp = nn.Parameter(torch.tensor(temp))
def reset_parameters(self):
nn.init.zeros_(self.linear.weight)
nn.init.zeros_(self.linear.bias)
def forward(self, x_shot, params=None):
assert x_shot.dim() == 2
logits = self.linear(x_shot, get_child_dict(params, 'linear'))
logits = logits * self.temp
return logits
================================================
FILE: models/encoders/__init__.py
================================================
from .encoders import make, load
from . import convnet4
from . import resnet12
from . import resnet18
================================================
FILE: models/encoders/convnet4.py
================================================
from collections import OrderedDict
import torch.nn as nn
from .encoders import register
from ..modules import *
__all__ = ['convnet4', 'wide_convnet4']
class ConvBlock(Module):
def __init__(self, in_channels, out_channels, bn_args):
super(ConvBlock, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.conv = Conv2d(in_channels, out_channels, 3, 1, padding=1)
self.bn = BatchNorm2d(out_channels, **bn_args)
self.relu = nn.ReLU(inplace=True)
self.pool = nn.MaxPool2d(2)
def forward(self, x, params=None, episode=None):
out = self.conv(x, get_child_dict(params, 'conv'))
out = self.bn(out, get_child_dict(params, 'bn'), episode)
out = self.pool(self.relu(out))
return out
class ConvNet4(Module):
def __init__(self, hid_dim, out_dim, bn_args):
super(ConvNet4, self).__init__()
self.hid_dim = hid_dim
self.out_dim = out_dim
episodic = bn_args.get('episodic') or []
bn_args_ep, bn_args_no_ep = bn_args.copy(), bn_args.copy()
bn_args_ep['episodic'] = True
bn_args_no_ep['episodic'] = False
bn_args_dict = dict()
for i in [1, 2, 3, 4]:
if 'conv%d' % i in episodic:
bn_args_dict[i] = bn_args_ep
else:
bn_args_dict[i] = bn_args_no_ep
self.encoder = Sequential(OrderedDict([
('conv1', ConvBlock(3, hid_dim, bn_args_dict[1])),
('conv2', ConvBlock(hid_dim, hid_dim, bn_args_dict[2])),
('conv3', ConvBlock(hid_dim, hid_dim, bn_args_dict[3])),
('conv4', ConvBlock(hid_dim, out_dim, bn_args_dict[4])),
]))
def get_out_dim(self, scale=25):
return self.out_dim * scale
def forward(self, x, params=None, episode=None):
out = self.encoder(x, get_child_dict(params, 'encoder'), episode)
out = out.view(out.shape[0], -1)
return out
@register('convnet4')
def convnet4(bn_args=dict()):
return ConvNet4(32, 32, bn_args)
@register('wide-convnet4')
def wide_convnet4(bn_args=dict()):
return ConvNet4(64, 64, bn_args)
================================================
FILE: models/encoders/encoders.py
================================================
import torch
models = {}
def register(name):
def decorator(cls):
models[name] = cls
return cls
return decorator
def make(name, **kwargs):
if name is None:
return None
model = models[name](**kwargs)
if torch.cuda.is_available():
model.cuda()
return model
def load(ckpt):
model = make(ckpt['encoder'], **ckpt['encoder_args'])
if model is not None:
model.load_state_dict(ckpt['encoder_state_dict'])
return model
================================================
FILE: models/encoders/resnet12.py
================================================
from collections import OrderedDict
import torch.nn as nn
from .encoders import register
from ..modules import *
__all__ = ['resnet12', 'wide_resnet12']
def conv3x3(in_channels, out_channels):
return Conv2d(in_channels, out_channels, 3, 1, padding=1, bias=False)
def conv1x1(in_channels, out_channels):
return Conv2d(in_channels, out_channels, 1, 1, padding=0, bias=False)
class Block(Module):
def __init__(self, in_planes, planes, bn_args):
super(Block, self).__init__()
self.in_planes = in_planes
self.planes = planes
self.conv1 = conv3x3(in_planes, planes)
self.bn1 = BatchNorm2d(planes, **bn_args)
self.conv2 = conv3x3(planes, planes)
self.bn2 = BatchNorm2d(planes, **bn_args)
self.conv3 = conv3x3(planes, planes)
self.bn3 = BatchNorm2d(planes, **bn_args)
self.res_conv = Sequential(OrderedDict([
('conv', conv1x1(in_planes, planes)),
('bn', BatchNorm2d(planes, **bn_args)),
]))
self.relu = nn.LeakyReLU(0.1, inplace=True)
self.pool = nn.MaxPool2d(2)
def forward(self, x, params=None, episode=None):
out = self.conv1(x, get_child_dict(params, 'conv1'))
out = self.bn1(out, get_child_dict(params, 'bn1'), episode)
out = self.relu(out)
out = self.conv2(out, get_child_dict(params, 'conv2'))
out = self.bn2(out, get_child_dict(params, 'bn2'), episode)
out = self.relu(out)
out = self.conv3(out, get_child_dict(params, 'conv3'))
out = self.bn3(out, get_child_dict(params, 'bn3'), episode)
x = self.res_conv(x, get_child_dict(params, 'res_conv'), episode)
out = self.pool(self.relu(out + x))
return out
class ResNet12(Module):
def __init__(self, channels, bn_args):
super(ResNet12, self).__init__()
self.channels = channels
episodic = bn_args.get('episodic') or []
bn_args_ep, bn_args_no_ep = bn_args.copy(), bn_args.copy()
bn_args_ep['episodic'] = True
bn_args_no_ep['episodic'] = False
bn_args_dict = dict()
for i in [1, 2, 3, 4]:
if 'layer%d' % i in episodic:
bn_args_dict[i] = bn_args_ep
else:
bn_args_dict[i] = bn_args_no_ep
self.layer1 = Block(3, channels[0], bn_args_dict[1])
self.layer2 = Block(channels[0], channels[1], bn_args_dict[2])
self.layer3 = Block(channels[1], channels[2], bn_args_dict[3])
self.layer4 = Block(channels[2], channels[3], bn_args_dict[4])
self.pool = nn.AdaptiveAvgPool2d(1)
self.out_dim = channels[3]
for m in self.modules():
if isinstance(m, Conv2d):
nn.init.kaiming_normal_(
m.weight, mode='fan_out', nonlinearity='leaky_relu')
elif isinstance(m, BatchNorm2d):
nn.init.constant_(m.weight, 1.)
nn.init.constant_(m.bias, 0.)
def get_out_dim(self):
return self.out_dim
def forward(self, x, params=None, episode=None):
out = self.layer1(x, get_child_dict(params, 'layer1'), episode)
out = self.layer2(out, get_child_dict(params, 'layer2'), episode)
out = self.layer3(out, get_child_dict(params, 'layer3'), episode)
out = self.layer4(out, get_child_dict(params, 'layer4'), episode)
out = self.pool(out).flatten(1)
return out
@register('resnet12')
def resnet12(bn_args=dict()):
return ResNet12([64, 128, 256, 512], bn_args)
@register('wide-resnet12')
def wide_resnet12(bn_args=dict()):
return ResNet12([64, 160, 320, 640], bn_args)
================================================
FILE: models/encoders/resnet18.py
================================================
from collections import OrderedDict
import torch.nn as nn
from .encoders import register
from ..modules import *
__all__ = ['resnet18', 'wide_resnet18']
def conv3x3(in_channels, out_channels, stride=1):
return Conv2d(in_channels, out_channels, 3, stride, padding=1, bias=False)
def conv1x1(in_channels, out_channels, stride=1):
return Conv2d(in_channels, out_channels, 1, stride, padding=0, bias=False)
class Block(Module):
def __init__(self, in_planes, planes, stride, bn_args):
super(Block, self).__init__()
self.in_planes = in_planes
self.planes = planes
self.stride = stride
self.conv1 = conv3x3(in_planes, planes, stride)
self.bn1 = BatchNorm2d(planes, **bn_args)
self.conv2 = conv3x3(planes, planes)
self.bn2 = BatchNorm2d(planes, **bn_args)
if stride > 1:
self.res_conv = Sequential(OrderedDict([
('conv', conv1x1(in_planes, planes)),
('bn', BatchNorm2d(planes, **bn_args)),
]))
self.relu = nn.ReLU(inplace=True)
def forward(self, x, params=None, episode=None):
out = self.conv1(x, get_child_dict(params, 'conv1'))
out = self.bn1(out, get_child_dict(params, 'bn1'), episode)
out = self.relu(out)
out = self.conv2(out, get_child_dict(params, 'conv2'))
out = self.bn2(out, get_child_dict(params, 'bn2'), episode)
if self.stride > 1:
x = self.res_conv(x, get_child_dict(params, 'res_conv'), episode)
out = self.relu(out + x)
return out
class ResNet18(Module):
def __init__(self, channels, bn_args):
super(ResNet18, self).__init__()
self.channels = channels
episodic = bn_args.get('episodic') or []
bn_args_ep, bn_args_no_ep = bn_args.copy(), bn_args.copy()
bn_args_ep['episodic'] = True
bn_args_no_ep['episodic'] = False
bn_args_dict = dict()
for i in [0, 1, 2, 3, 4]:
if 'layer%d' % i in episodic:
bn_args_dict[i] = bn_args_ep
else:
bn_args_dict[i] = bn_args_no_ep
self.layer0 = Sequential(OrderedDict([
('conv', conv3x3(3, 64)),
('bn', BatchNorm2d(64, **bn_args_dict[0])),
]))
self.relu = nn.ReLU(inplace=True)
self.layer1 = Block(64, channels[0], 1, bn_args_dict[1])
self.layer2 = Block(channels[0], channels[1], 2, bn_args_dict[2])
self.layer3 = Block(channels[1], channels[2], 2, bn_args_dict[3])
self.layer4 = Block(channels[2], channels[3], 2, bn_args_dict[4])
self.pool = nn.AdaptiveAvgPool2d(1)
self.out_dim = channels[3]
for m in self.modules():
if isinstance(m, Conv2d):
nn.init.kaiming_normal_(
m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, BatchNorm2d):
nn.init.constant_(m.weight, 1.)
nn.init.constant_(m.bias, 0.)
def get_out_dim(self, scale=1):
return self.out_dim * scale
def forward(self, x, params=None, episode=None):
out = self.layer0(x, get_child_dict(params, 'layer0'), episode)
out = self.relu(out)
out = self.layer1(out, get_child_dict(params, 'layer1'), episode)
out = self.layer2(out, get_child_dict(params, 'layer2'), episode)
out = self.layer3(out, get_child_dict(params, 'layer3'), episode)
out = self.layer4(out, get_child_dict(params, 'layer4'), episode)
out = self.pool(out).flatten(1)
return out
@register('resnet18')
def resnet18(bn_args=dict()):
return ResNet18([64, 128, 256, 512], bn_args)
@register('wide-resnet18')
def wide_resnet18(bn_args=dict()):
return ResNet18([64, 160, 320, 640], bn_args)
================================================
FILE: models/maml.py
================================================
from collections import OrderedDict
import torch
import torch.nn.functional as F
import torch.autograd as autograd
import torch.utils.checkpoint as cp
from . import encoders
from . import classifiers
from .modules import get_child_dict, Module, BatchNorm2d
def make(enc_name, enc_args, clf_name, clf_args):
"""
Initializes a random meta model.
Args:
enc_name (str): name of the encoder (e.g., 'resnet12').
enc_args (dict): arguments for the encoder.
clf_name (str): name of the classifier (e.g., 'meta-nn').
clf_args (dict): arguments for the classifier.
Returns:
model (MAML): a meta classifier with a random encoder.
"""
enc = encoders.make(enc_name, **enc_args)
clf_args['in_dim'] = enc.get_out_dim()
clf = classifiers.make(clf_name, **clf_args)
model = MAML(enc, clf)
return model
def load(ckpt, load_clf=False, clf_name=None, clf_args=None):
"""
Initializes a meta model with a pre-trained encoder.
Args:
ckpt (dict): a checkpoint from which a pre-trained encoder is restored.
load_clf (bool, optional): if True, loads a pre-trained classifier.
Default: False (in which case the classifier is randomly initialized)
clf_name (str, optional): name of the classifier (e.g., 'meta-nn')
clf_args (dict, optional): arguments for the classifier.
(The last two arguments are ignored if load_clf=True.)
Returns:
model (MAML): a meta model with a pre-trained encoder.
"""
enc = encoders.load(ckpt)
if load_clf:
clf = classifiers.load(ckpt)
else:
if clf_name is None and clf_args is None:
clf = classifiers.make(ckpt['classifier'], **ckpt['classifier_args'])
else:
clf_args['in_dim'] = enc.get_out_dim()
clf = classifiers.make(clf_name, **clf_args)
model = MAML(enc, clf)
return model
class MAML(Module):
def __init__(self, encoder, classifier):
super(MAML, self).__init__()
self.encoder = encoder
self.classifier = classifier
def reset_classifier(self):
self.classifier.reset_parameters()
def _inner_forward(self, x, params, episode):
""" Forward pass for the inner loop. """
feat = self.encoder(x, get_child_dict(params, 'encoder'), episode)
logits = self.classifier(feat, get_child_dict(params, 'classifier'))
return logits
def _inner_iter(self, x, y, params, mom_buffer, episode, inner_args, detach):
"""
Performs one inner-loop iteration of MAML including the forward and
backward passes and the parameter update.
Args:
x (float tensor, [n_way * n_shot, C, H, W]): per-episode support set.
y (int tensor, [n_way * n_shot]): per-episode support set labels.
params (dict): the model parameters BEFORE the update.
mom_buffer (dict): the momentum buffer BEFORE the update.
episode (int): the current episode index.
inner_args (dict): inner-loop optimization hyperparameters.
detach (bool): if True, detachs the graph for the current iteration.
Returns:
updated_params (dict): the model parameters AFTER the update.
mom_buffer (dict): the momentum buffer AFTER the update.
"""
with torch.enable_grad():
# forward pass
logits = self._inner_forward(x, params, episode)
loss = F.cross_entropy(logits, y)
# backward pass
grads = autograd.grad(loss, params.values(),
create_graph=(not detach and not inner_args['first_order']),
only_inputs=True, allow_unused=True)
# parameter update
updated_params = OrderedDict()
for (name, param), grad in zip(params.items(), grads):
if grad is None:
updated_param = param
else:
if inner_args['weight_decay'] > 0:
grad = grad + inner_args['weight_decay'] * param
if inner_args['momentum'] > 0:
grad = grad + inner_args['momentum'] * mom_buffer[name]
mom_buffer[name] = grad
if 'encoder' in name:
lr = inner_args['encoder_lr']
elif 'classifier' in name:
lr = inner_args['classifier_lr']
else:
raise ValueError('invalid parameter name')
updated_param = param - lr * grad
if detach:
updated_param = updated_param.detach().requires_grad_(True)
updated_params[name] = updated_param
return updated_params, mom_buffer
def _adapt(self, x, y, params, episode, inner_args, meta_train):
"""
Performs inner-loop adaptation in MAML.
Args:
x (float tensor, [n_way * n_shot, C, H, W]): per-episode support set.
(T: transforms, C: channels, H: height, W: width)
y (int tensor, [n_way * n_shot]): per-episode support set labels.
params (dict): a dictionary of parameters at meta-initialization.
episode (int): the current episode index.
inner_args (dict): inner-loop optimization hyperparameters.
meta_train (bool): if True, the model is in meta-training.
Returns:
params (dict): model paramters AFTER inner-loop adaptation.
"""
assert x.dim() == 4 and y.dim() == 1
assert x.size(0) == y.size(0)
# Initializes a dictionary of momentum buffer for gradient descent in the
# inner loop. It has the same set of keys as the parameter dictionary.
mom_buffer = OrderedDict()
if inner_args['momentum'] > 0:
for name, param in params.items():
mom_buffer[name] = torch.zeros_like(param)
params_keys = tuple(params.keys())
mom_buffer_keys = tuple(mom_buffer.keys())
for m in self.modules():
if isinstance(m, BatchNorm2d) and m.is_episodic():
m.reset_episodic_running_stats(episode)
def _inner_iter_cp(episode, *state):
"""
Performs one inner-loop iteration when checkpointing is enabled.
The code is executed twice:
- 1st time with torch.no_grad() for creating checkpoints.
- 2nd time with torch.enable_grad() for computing gradients.
"""
params = OrderedDict(zip(params_keys, state[:len(params_keys)]))
mom_buffer = OrderedDict(
zip(mom_buffer_keys, state[-len(mom_buffer_keys):]))
detach = not torch.is_grad_enabled() # detach graph in the first pass
self.is_first_pass(detach)
params, mom_buffer = self._inner_iter(
x, y, params, mom_buffer, int(episode), inner_args, detach)
state = tuple(t if t.requires_grad else t.clone().requires_grad_(True)
for t in tuple(params.values()) + tuple(mom_buffer.values()))
return state
for step in range(inner_args['n_step']):
if self.efficient: # checkpointing
state = tuple(params.values()) + tuple(mom_buffer.values())
state = cp.checkpoint(_inner_iter_cp, torch.as_tensor(episode), *state)
params = OrderedDict(zip(params_keys, state[:len(params_keys)]))
mom_buffer = OrderedDict(
zip(mom_buffer_keys, state[-len(mom_buffer_keys):]))
else:
params, mom_buffer = self._inner_iter(
x, y, params, mom_buffer, episode, inner_args, not meta_train)
return params
def forward(self, x_shot, x_query, y_shot, inner_args, meta_train):
"""
Args:
x_shot (float tensor, [n_episode, n_way * n_shot, C, H, W]): support sets.
x_query (float tensor, [n_episode, n_way * n_query, C, H, W]): query sets.
(T: transforms, C: channels, H: height, W: width)
y_shot (int tensor, [n_episode, n_way * n_shot]): support set labels.
inner_args (dict, optional): inner-loop hyperparameters.
meta_train (bool): if True, the model is in meta-training.
Returns:
logits (float tensor, [n_episode, n_way * n_shot, n_way]): predicted logits.
"""
assert self.encoder is not None
assert self.classifier is not None
assert x_shot.dim() == 5 and x_query.dim() == 5
assert x_shot.size(0) == x_query.size(0)
# a dictionary of parameters that will be updated in the inner loop
params = OrderedDict(self.named_parameters())
for name in list(params.keys()):
if not params[name].requires_grad or \
any(s in name for s in inner_args['frozen'] + ['temp']):
params.pop(name)
logits = []
for ep in range(x_shot.size(0)):
# inner-loop training
self.train()
if not meta_train:
for m in self.modules():
if isinstance(m, BatchNorm2d) and not m.is_episodic():
m.eval()
updated_params = self._adapt(
x_shot[ep], y_shot[ep], params, ep, inner_args, meta_train)
# inner-loop validation
with torch.set_grad_enabled(meta_train):
self.eval()
logits_ep = self._inner_forward(x_query[ep], updated_params, ep)
logits.append(logits_ep)
self.train(meta_train)
logits = torch.stack(logits)
return logits
================================================
FILE: models/modules.py
================================================
import re
from collections import OrderedDict
import torch
import torch.nn as nn
import torch.nn.functional as F
__all__ = ['Module', 'Conv2d', 'Linear', 'BatchNorm2d', 'Sequential',
'get_child_dict']
def get_child_dict(params, key=None):
"""
Constructs parameter dictionary for a network module.
Args:
params (dict): a parent dictionary of named parameters.
key (str, optional): a key that specifies the root of the child dictionary.
Returns:
child_dict (dict): a child dictionary of model parameters.
"""
if params is None:
return None
if key is None or (isinstance(key, str) and key == ''):
return params
key_re = re.compile(r'^{0}\.(.+)'.format(re.escape(key)))
if not any(filter(key_re.match, params.keys())): # handles nn.DataParallel
key_re = re.compile(r'^module\.{0}\.(.+)'.format(re.escape(key)))
child_dict = OrderedDict(
(key_re.sub(r'\1', k), value) for (k, value)
in params.items() if key_re.match(k) is not None)
return child_dict
class Module(nn.Module):
def __init__(self):
super(Module, self).__init__()
self.efficient = False
self.first_pass = True
def go_efficient(self, mode=True):
""" Switches on / off gradient checkpointing. """
self.efficient = mode
for m in self.children():
if isinstance(m, Module):
m.go_efficient(mode)
def is_first_pass(self, mode=True):
""" Tracks the progress of forward and backward pass when gradient
checkpointing is enabled. """
self.first_pass = mode
for m in self.children():
if isinstance(m, Module):
m.is_first_pass(mode)
class Conv2d(nn.Conv2d, Module):
def __init__(self, in_channels, out_channels, kernel_size,
stride=1, padding=0, bias=True):
super(Conv2d, self).__init__(in_channels, out_channels, kernel_size,
stride, padding, bias=bias)
def forward(self, x, params=None, episode=None):
if params is None:
x = super(Conv2d, self).forward(x)
else:
weight, bias = params.get('weight'), params.get('bias')
if weight is None:
weight = self.weight
if bias is None:
bias = self.bias
x = F.conv2d(x, weight, bias, self.stride, self.padding)
return x
class Linear(nn.Linear, Module):
def __init__(self, in_features, out_features, bias=True):
super(Linear, self).__init__(in_features, out_features, bias=bias)
def forward(self, x, params=None, episode=None):
if params is None:
x = super(Linear, self).forward(x)
else:
weight, bias = params.get('weight'), params.get('bias')
if weight is None:
weight = self.weight
if bias is None:
bias = self.bias
x = F.linear(x, weight, bias)
return x
class BatchNorm2d(nn.BatchNorm2d, Module):
def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
track_running_stats=True, episodic=False, n_episode=4,
alpha=False):
"""
Args:
episodic (bool, optional): if True, maintains running statistics for
each episode separately. It is ignored if track_running_stats=False.
Default: True
n_episode (int, optional): number of episodes per mini-batch. It is
ignored if episodic=False.
alpha (bool, optional): if True, learns to interpolate between batch
statistics computed over the support set and instance statistics from
a query at validation time. Default: True
(It is ignored if track_running_stats=False or meta_learn=False)
"""
super(BatchNorm2d, self).__init__(num_features, eps, momentum, affine,
track_running_stats)
self.episodic = episodic
self.n_episode = n_episode
self.alpha = alpha
if self.track_running_stats:
if self.episodic:
for ep in range(n_episode):
self.register_buffer(
'running_mean_%d' % ep, torch.zeros(num_features))
self.register_buffer(
'running_var_%d' % ep, torch.ones(num_features))
self.register_buffer(
'num_batches_tracked_%d' % ep, torch.tensor(0, dtype=torch.int))
if self.alpha:
self.register_buffer('batch_size', torch.tensor(0, dtype=torch.int))
self.alpha_scale = nn.Parameter(torch.tensor(0.))
self.alpha_offset = nn.Parameter(torch.tensor(0.))
def is_episodic(self):
return self.episodic
def _batch_norm(self, x, mean, var, weight=None, bias=None):
if self.affine:
assert weight is not None and bias is not None
weight = weight.view(1, -1, 1, 1)
bias = bias.view(1, -1, 1, 1)
x = weight * (x - mean) / (var + self.eps) ** .5 + bias
else:
x = (x - mean) / (var + self.eps) ** .5
return x
def reset_episodic_running_stats(self, episode):
if self.episodic:
getattr(self, 'running_mean_%d' % episode).zero_()
getattr(self, 'running_var_%d' % episode).fill_(1.)
getattr(self, 'num_batches_tracked_%d' % episode).zero_()
def forward(self, x, params=None, episode=None):
self._check_input_dim(x)
if params is not None:
weight, bias = params.get('weight'), params.get('bias')
if weight is None:
weight = self.weight
if bias is None:
bias = self.bias
else:
weight, bias = self.weight, self.bias
if self.track_running_stats:
if self.episodic:
assert episode is not None and episode < self.n_episode
running_mean = getattr(self, 'running_mean_%d' % episode)
running_var = getattr(self, 'running_var_%d' % episode)
num_batches_tracked = getattr(self, 'num_batches_tracked_%d' % episode)
else:
running_mean, running_var = self.running_mean, self.running_var
num_batches_tracked = self.num_batches_tracked
if self.training:
exp_avg_factor = 0.
if self.first_pass: # only updates statistics in the first pass
if self.alpha:
self.batch_size = x.size(0)
num_batches_tracked += 1
if self.momentum is None:
exp_avg_factor = 1. / float(num_batches_tracked)
else:
exp_avg_factor = self.momentum
return F.batch_norm(x, running_mean, running_var, weight, bias,
True, exp_avg_factor, self.eps)
else:
if self.alpha:
assert self.batch_size > 0
alpha = torch.sigmoid(
self.alpha_scale * self.batch_size + self.alpha_offset)
# exponentially moving-averaged training statistics
running_mean = running_mean.view(1, -1, 1, 1)
running_var = running_var.view(1, -1, 1, 1)
# per-sample statistics
sample_mean = torch.mean(x, dim=(2, 3), keepdim=True)
sample_var = torch.var(x, dim=(2, 3), unbiased=False, keepdim=True)
# interpolated statistics
mean = alpha * running_mean + (1 - alpha) * sample_mean
var = alpha * running_var + (1 - alpha) * sample_var + \
alpha * (1 - alpha) * (sample_mean - running_mean) ** 2
return self._batch_norm(x, mean, var, weight, bias)
else:
return F.batch_norm(x, running_mean, running_var, weight, bias,
False, 0., self.eps)
else:
return F.batch_norm(x, None, None, weight, bias, True, 0., self.eps)
class Sequential(nn.Sequential, Module):
def __init__(self, *args):
super(Sequential, self).__init__(*args)
def forward(self, x, params=None, episode=None):
if params is None:
for module in self:
x = module(x, None, episode)
else:
for name, module in self._modules.items():
x = module(x, get_child_dict(params, name), episode)
return x
================================================
FILE: test.py
================================================
import argparse
import random
import yaml
import torch
import torch.nn as nn
import numpy as np
from tqdm import tqdm
from torch.utils.data import DataLoader
import datasets
import models
import utils
def main(config):
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)
torch.cuda.manual_seed(0)
# torch.backends.cudnn.deterministic = True
# torch.backends.cudnn.benchmark = False
##### Dataset #####
dataset = datasets.make(config['dataset'], **config['test'])
utils.log('meta-test set: {} (x{}), {}'.format(
dataset[0][0].shape, len(dataset), dataset.n_classes))
loader = DataLoader(dataset, config['test']['n_episode'],
collate_fn=datasets.collate_fn, num_workers=1, pin_memory=True)
##### Model #####
ckpt = torch.load(config['load'])
inner_args = utils.config_inner_args(config.get('inner_args'))
model = models.load(ckpt, load_clf=(not inner_args['reset_classifier']))
if args.efficient:
model.go_efficient()
if config.get('_parallel'):
model = nn.DataParallel(model)
utils.log('num params: {}'.format(utils.compute_n_params(model)))
##### Evaluation #####
model.eval()
aves_va = utils.AverageMeter()
va_lst = []
for epoch in range(1, config['epoch'] + 1):
for data in tqdm(loader, leave=False):
x_shot, x_query, y_shot, y_query = data
x_shot, y_shot = x_shot.cuda(), y_shot.cuda()
x_query, y_query = x_query.cuda(), y_query.cuda()
if inner_args['reset_classifier']:
if config.get('_parallel'):
model.module.reset_classifier()
else:
model.reset_classifier()
logits = model(x_shot, x_query, y_shot, inner_args, meta_train=False)
logits = logits.view(-1, config['test']['n_way'])
labels = y_query.view(-1)
pred = torch.argmax(logits, dim=1)
acc = utils.compute_acc(pred, labels)
aves_va.update(acc, 1)
va_lst.append(acc)
print('test epoch {}: acc={:.2f} +- {:.2f} (%)'.format(
epoch, aves_va.item() * 100,
utils.mean_confidence_interval(va_lst) * 100))
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--config',
help='configuration file')
parser.add_argument('--gpu',
help='gpu device number',
type=str, default='0')
parser.add_argument('--efficient',
help='if True, enables gradient checkpointing',
action='store_true')
args = parser.parse_args()
config = yaml.load(open(args.config, 'r'), Loader=yaml.FullLoader)
if len(args.gpu.split(',')) > 1:
config['_parallel'] = True
config['_gpu'] = args.gpu
utils.set_gpu(args.gpu)
main(config)
================================================
FILE: train.py
================================================
import argparse
import os
import random
from collections import OrderedDict
import yaml
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm
from torch.utils.data import DataLoader
from tensorboardX import SummaryWriter
import datasets
import models
import utils
import utils.optimizers as optimizers
def main(config):
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)
torch.cuda.manual_seed(0)
# torch.backends.cudnn.deterministic = True
# torch.backends.cudnn.benchmark = False
ckpt_name = args.name
if ckpt_name is None:
ckpt_name = config['encoder']
ckpt_name += '_' + config['dataset'].replace('meta-', '')
ckpt_name += '_{}_way_{}_shot'.format(
config['train']['n_way'], config['train']['n_shot'])
if args.tag is not None:
ckpt_name += '_' + args.tag
ckpt_path = os.path.join('./save', ckpt_name)
utils.ensure_path(ckpt_path)
utils.set_log_path(ckpt_path)
writer = SummaryWriter(os.path.join(ckpt_path, 'tensorboard'))
yaml.dump(config, open(os.path.join(ckpt_path, 'config.yaml'), 'w'))
##### Dataset #####
# meta-train
train_set = datasets.make(config['dataset'], **config['train'])
utils.log('meta-train set: {} (x{}), {}'.format(
train_set[0][0].shape, len(train_set), train_set.n_classes))
train_loader = DataLoader(
train_set, config['train']['n_episode'],
collate_fn=datasets.collate_fn, num_workers=1, pin_memory=True)
# meta-val
eval_val = False
if config.get('val'):
eval_val = True
val_set = datasets.make(config['dataset'], **config['val'])
utils.log('meta-val set: {} (x{}), {}'.format(
val_set[0][0].shape, len(val_set), val_set.n_classes))
val_loader = DataLoader(
val_set, config['val']['n_episode'],
collate_fn=datasets.collate_fn, num_workers=1, pin_memory=True)
##### Model and Optimizer #####
inner_args = utils.config_inner_args(config.get('inner_args'))
if config.get('load'):
ckpt = torch.load(config['load'])
config['encoder'] = ckpt['encoder']
config['encoder_args'] = ckpt['encoder_args']
config['classifier'] = ckpt['classifier']
config['classifier_args'] = ckpt['classifier_args']
model = models.load(ckpt, load_clf=(not inner_args['reset_classifier']))
optimizer, lr_scheduler = optimizers.load(ckpt, model.parameters())
start_epoch = ckpt['training']['epoch'] + 1
max_va = ckpt['training']['max_va']
else:
config['encoder_args'] = config.get('encoder_args') or dict()
config['classifier_args'] = config.get('classifier_args') or dict()
config['encoder_args']['bn_args']['n_episode'] = config['train']['n_episode']
config['classifier_args']['n_way'] = config['train']['n_way']
model = models.make(config['encoder'], config['encoder_args'],
config['classifier'], config['classifier_args'])
optimizer, lr_scheduler = optimizers.make(
config['optimizer'], model.parameters(), **config['optimizer_args'])
start_epoch = 1
max_va = 0.
if args.efficient:
model.go_efficient()
if config.get('_parallel'):
model = nn.DataParallel(model)
utils.log('num params: {}'.format(utils.compute_n_params(model)))
timer_elapsed, timer_epoch = utils.Timer(), utils.Timer()
##### Training and evaluation #####
# 'tl': meta-train loss
# 'ta': meta-train accuracy
# 'vl': meta-val loss
# 'va': meta-val accuracy
aves_keys = ['tl', 'ta', 'vl', 'va']
trlog = dict()
for k in aves_keys:
trlog[k] = []
for epoch in range(start_epoch, config['epoch'] + 1):
timer_epoch.start()
aves = {k: utils.AverageMeter() for k in aves_keys}
# meta-train
model.train()
writer.add_scalar('lr', optimizer.param_groups[0]['lr'], epoch)
np.random.seed(epoch)
for data in tqdm(train_loader, desc='meta-train', leave=False):
x_shot, x_query, y_shot, y_query = data
x_shot, y_shot = x_shot.cuda(), y_shot.cuda()
x_query, y_query = x_query.cuda(), y_query.cuda()
if inner_args['reset_classifier']:
if config.get('_parallel'):
model.module.reset_classifier()
else:
model.reset_classifier()
logits = model(x_shot, x_query, y_shot, inner_args, meta_train=True)
logits = logits.flatten(0, 1)
labels = y_query.flatten()
pred = torch.argmax(logits, dim=-1)
acc = utils.compute_acc(pred, labels)
loss = F.cross_entropy(logits, labels)
aves['tl'].update(loss.item(), 1)
aves['ta'].update(acc, 1)
optimizer.zero_grad()
loss.backward()
for param in optimizer.param_groups[0]['params']:
nn.utils.clip_grad_value_(param, 10)
optimizer.step()
# meta-val
if eval_val:
model.eval()
np.random.seed(0)
for data in tqdm(val_loader, desc='meta-val', leave=False):
x_shot, x_query, y_shot, y_query = data
x_shot, y_shot = x_shot.cuda(), y_shot.cuda()
x_query, y_query = x_query.cuda(), y_query.cuda()
if inner_args['reset_classifier']:
if config.get('_parallel'):
model.module.reset_classifier()
else:
model.reset_classifier()
logits = model(x_shot, x_query, y_shot, inner_args, meta_train=False)
logits = logits.flatten(0, 1)
labels = y_query.flatten()
pred = torch.argmax(logits, dim=-1)
acc = utils.compute_acc(pred, labels)
loss = F.cross_entropy(logits, labels)
aves['vl'].update(loss.item(), 1)
aves['va'].update(acc, 1)
if lr_scheduler is not None:
lr_scheduler.step()
for k, avg in aves.items():
aves[k] = avg.item()
trlog[k].append(aves[k])
t_epoch = utils.time_str(timer_epoch.end())
t_elapsed = utils.time_str(timer_elapsed.end())
t_estimate = utils.time_str(timer_elapsed.end() /
(epoch - start_epoch + 1) * (config['epoch'] - start_epoch + 1))
# formats output
log_str = 'epoch {}, meta-train {:.4f}|{:.4f}'.format(
str(epoch), aves['tl'], aves['ta'])
writer.add_scalars('loss', {'meta-train': aves['tl']}, epoch)
writer.add_scalars('acc', {'meta-train': aves['ta']}, epoch)
if eval_val:
log_str += ', meta-val {:.4f}|{:.4f}'.format(aves['vl'], aves['va'])
writer.add_scalars('loss', {'meta-val': aves['vl']}, epoch)
writer.add_scalars('acc', {'meta-val': aves['va']}, epoch)
log_str += ', {} {}/{}'.format(t_epoch, t_elapsed, t_estimate)
utils.log(log_str)
# saves model and meta-data
if config.get('_parallel'):
model_ = model.module
else:
model_ = model
training = {
'epoch': epoch,
'max_va': max(max_va, aves['va']),
'optimizer': config['optimizer'],
'optimizer_args': config['optimizer_args'],
'optimizer_state_dict': optimizer.state_dict(),
'lr_scheduler_state_dict': lr_scheduler.state_dict()
if lr_scheduler is not None else None,
}
ckpt = {
'file': __file__,
'config': config,
'encoder': config['encoder'],
'encoder_args': config['encoder_args'],
'encoder_state_dict': model_.encoder.state_dict(),
'classifier': config['classifier'],
'classifier_args': config['classifier_args'],
'classifier_state_dict': model_.classifier.state_dict(),
'training': training,
}
# 'epoch-last.pth': saved at the latest epoch
# 'max-va.pth': saved when validation accuracy is at its maximum
torch.save(ckpt, os.path.join(ckpt_path, 'epoch-last.pth'))
torch.save(trlog, os.path.join(ckpt_path, 'trlog.pth'))
if aves['va'] > max_va:
max_va = aves['va']
torch.save(ckpt, os.path.join(ckpt_path, 'max-va.pth'))
writer.flush()
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--config',
help='configuration file')
parser.add_argument('--name',
help='model name',
type=str, default=None)
parser.add_argument('--tag',
help='auxiliary information',
type=str, default=None)
parser.add_argument('--gpu',
help='gpu device number',
type=str, default='0')
parser.add_argument('--efficient',
help='if True, enables gradient checkpointing',
action='store_true')
args = parser.parse_args()
config = yaml.load(open(args.config, 'r'), Loader=yaml.FullLoader)
if len(args.gpu.split(',')) > 1:
config['_parallel'] = True
config['_gpu'] = args.gpu
utils.set_gpu(args.gpu)
main(config)
================================================
FILE: utils/__init__.py
================================================
import os
import shutil
import time
import numpy as np
import scipy.stats as stats
_log_path = None
def set_log_path(path):
global _log_path
_log_path = path
def log(obj, filename='log.txt'):
print(obj)
if _log_path is not None:
with open(os.path.join(_log_path, filename), 'a') as f:
print(obj, file=f)
class AverageMeter(object):
def __init__(self):
self.reset()
def reset(self):
self.val = 0.
self.avg = 0.
self.sum = 0.
self.count = 0.
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
def item(self):
return self.avg
class Timer(object):
def __init__(self):
self.start()
def start(self):
self.v = time.time()
def end(self):
return time.time() - self.v
def set_gpu(gpu):
print('set gpu:', gpu)
os.environ['CUDA_VISIBLE_DEVICES'] = gpu
def ensure_path(path, remove=True):
basename = os.path.basename(path.rstrip('/'))
if os.path.exists(path):
if remove and (basename.startswith('_')
or input('{} exists, remove? ([y]/n): '.format(path)) != 'n'):
shutil.rmtree(path)
os.makedirs(path)
else:
os.makedirs(path)
def time_str(t):
if t >= 3600:
return '{:.1f}h'.format(t / 3600)
if t >= 60:
return '{:.1f}m'.format(t / 60)
return '{:.1f}s'.format(t)
def compute_acc(pred, label, reduction='mean'):
result = (pred == label).float()
if reduction == 'none':
return result.detach()
elif reduction == 'mean':
return result.mean().item()
def compute_n_params(model, return_str=True):
n_params = 0
for p in model.parameters():
n_params += p.numel()
if return_str:
if n_params >= 1e6:
return '{:.1f}M'.format(n_params / 1e6)
else:
return '{:.1f}K'.format(n_params / 1e3)
else:
return n_params
def mean_confidence_interval(data, confidence=0.95):
a = 1.0 * np.array(data)
stderr = stats.sem(a)
h = stderr * stats.t.ppf((1 + confidence) / 2., len(a) - 1)
return h
def config_inner_args(inner_args):
if inner_args is None:
inner_args = dict()
inner_args['reset_classifier'] = inner_args.get('reset_classifier') or False
inner_args['n_step'] = inner_args.get('n_step') or 5
inner_args['encoder_lr'] = inner_args.get('encoder_lr') or 0.01
inner_args['classifier_lr'] = inner_args.get('classifier_lr') or 0.01
inner_args['momentum'] = inner_args.get('momentum') or 0.
inner_args['weight_decay'] = inner_args.get('weight_decay') or 0.
inner_args['first_order'] = inner_args.get('first_order') or False
inner_args['frozen'] = inner_args.get('frozen') or []
return inner_args
================================================
FILE: utils/optimizers.py
================================================
from torch.optim import SGD, RMSprop, Adam
from torch.optim.lr_scheduler import MultiStepLR, CosineAnnealingLR
def make(name, params, lr, weight_decay=0.,
schedule='step', milestones=None, gamma=0.1):
"""
Prepares an optimizer and its learning-rate scheduler.
Args:
name (str): name of the optimizer. Options: 'sgd', 'rmsprop', 'adam'
params (iterable): parameters to optimize.
lr (float): initial learning rate.
weight_decay (float, optional): weight decay. Default: 0.
schedule (str, optional): type of learning-rate schedule. Default: 'step'
Options: 'step', 'cosine'
(This argument is ignored if milestones=None.)
milestones (int list, optional): a list of epoches when learning rate
is altered. Default: None
gamma (float, optional): multiplicative factor of learning rate decay.
Default: 0.1
"""
if name == 'sgd':
optimizer = SGD(params, lr, momentum=0.9, weight_decay=weight_decay)
elif name == 'rmsprop':
optimizer = RMSprop(params, lr, weight_decay=weight_decay)
elif name == 'adam':
optimizer = Adam(params, lr, weight_decay=weight_decay)
else:
raise ValueError('invalid optimizer')
if milestones is not None:
if schedule == 'step':
lr_scheduler = MultiStepLR(optimizer, milestones, gamma)
elif schedule == 'cosine':
lr_scheduler = CosineAnnealingLR(optimizer, milestones[-1])
else:
lr_scheduler = None
return optimizer, lr_scheduler
def load(ckpt, params):
train = ckpt['training']
optimizer, lr_scheduler = make(
train['optimizer'], params, **train['optimizer_args'])
optimizer.load_state_dict(train['optimizer_state_dict'])
if lr_scheduler is not None:
lr_scheduler.load_state_dict(train['lr_scheduler_state_dict'])
return optimizer, lr_scheduler
gitextract_xgy6p1l2/
├── README.md
├── configs/
│ └── convnet4/
│ ├── mini-imagenet/
│ │ ├── 5_way_1_shot/
│ │ │ ├── test_reproduce.yaml
│ │ │ ├── test_template.yaml
│ │ │ ├── train_reproduce.yaml
│ │ │ └── train_template.yaml
│ │ └── 5_way_5_shot/
│ │ ├── test_reproduce.yaml
│ │ ├── test_template.yaml
│ │ ├── train_reproduce.yaml
│ │ └── train_template.yaml
│ └── tiered-imagenet/
│ ├── 5_way_1_shot/
│ │ ├── test_reproduce.yaml
│ │ ├── test_template.yaml
│ │ ├── train_reproduce.yaml
│ │ └── train_template.yaml
│ └── 5_way_5_shot/
│ ├── test_reproduce.yaml
│ ├── test_template.yaml
│ ├── train_reproduce.yaml
│ └── train_template.yaml
├── datasets/
│ ├── __init__.py
│ ├── cifar100.py
│ ├── cub200.py
│ ├── datasets.py
│ ├── inatural.py
│ ├── mini_imagenet.py
│ ├── tiered_imagenet.py
│ └── transforms.py
├── models/
│ ├── __init__.py
│ ├── classifiers/
│ │ ├── __init__.py
│ │ ├── classifiers.py
│ │ └── logistic.py
│ ├── encoders/
│ │ ├── __init__.py
│ │ ├── convnet4.py
│ │ ├── encoders.py
│ │ ├── resnet12.py
│ │ └── resnet18.py
│ ├── maml.py
│ └── modules.py
├── test.py
├── train.py
└── utils/
├── __init__.py
└── optimizers.py
SYMBOL INDEX (146 symbols across 19 files)
FILE: datasets/cifar100.py
class Cifar100 (line 13) | class Cifar100(Dataset):
method __init__ (line 14) | def __init__(self, root_path, split='train', image_size=32,
method __len__ (line 62) | def __len__(self):
method __getitem__ (line 65) | def __getitem__(self, index):
class MetaCifar100 (line 71) | class MetaCifar100(Cifar100):
method __init__ (line 72) | def __init__(self, root_path, split='train', image_size=32,
method __len__ (line 90) | def __len__(self):
method __getitem__ (line 93) | def __getitem__(self, index):
class CifarFS (line 118) | class CifarFS(Cifar100):
method __init__ (line 119) | def __init__(self, *args):
class MetaCifarFS (line 124) | class MetaCifarFS(MetaCifar100):
method __init__ (line 125) | def __init__(self, *args):
class FC100 (line 130) | class FC100(Cifar100):
method __init__ (line 131) | def __init__(self, *args):
class MetaFC100 (line 136) | class MetaFC100(MetaCifar100):
method __init__ (line 137) | def __init__(self, *args):
FILE: datasets/cub200.py
class CUB200 (line 13) | class CUB200(Dataset):
method __init__ (line 14) | def __init__(self, root_path, split='train', image_size=84,
method _load_image (line 59) | def _load_image(self, index):
method __len__ (line 65) | def __len__(self):
method __getitem__ (line 68) | def __getitem__(self, index):
class MetaCUB200 (line 75) | class MetaCUB200(CUB200):
method __init__ (line 76) | def __init__(self, root_path, split='train', image_size=84,
method __len__ (line 94) | def __len__(self):
method __getitem__ (line 97) | def __getitem__(self, index):
FILE: datasets/datasets.py
function register (line 9) | def register(name):
function make (line 16) | def make(name, **kwargs):
function collate_fn (line 23) | def collate_fn(batch):
FILE: datasets/inatural.py
class INat2017 (line 13) | class INat2017(Dataset):
method __init__ (line 14) | def __init__(self, root_path, split='train', image_size=84,
method _load_image (line 59) | def _load_image(self, index):
method __len__ (line 65) | def __len__(self):
method __getitem__ (line 68) | def __getitem__(self, index):
class MetaINat2017 (line 75) | class MetaINat2017(INat2017):
method __init__ (line 76) | def __init__(self, root_path, split='train', image_size=84,
method __len__ (line 94) | def __len__(self):
method __getitem__ (line 97) | def __getitem__(self, index):
FILE: datasets/mini_imagenet.py
class MiniImageNet (line 14) | class MiniImageNet(Dataset):
method __init__ (line 15) | def __init__(self, root_path, split='train', image_size=84,
method __len__ (line 64) | def __len__(self):
method __getitem__ (line 67) | def __getitem__(self, index):
class MetaMiniImageNet (line 74) | class MetaMiniImageNet(MiniImageNet):
method __init__ (line 75) | def __init__(self, root_path, split='train', image_size=84,
method __len__ (line 93) | def __len__(self):
method __getitem__ (line 96) | def __getitem__(self, index):
FILE: datasets/tiered_imagenet.py
class TieredImageNet (line 14) | class TieredImageNet(Dataset):
method __init__ (line 15) | def __init__(self, root_path, split='train', image_size=84,
method __len__ (line 65) | def __len__(self):
method __getitem__ (line 68) | def __getitem__(self, index):
class MetaTieredImageNet (line 75) | class MetaTieredImageNet(TieredImageNet):
method __init__ (line 76) | def __init__(self, root_path, split='train', image_size=84,
method __len__ (line 94) | def __len__(self):
method __getitem__ (line 97) | def __getitem__(self, index):
FILE: datasets/transforms.py
function get_transform (line 7) | def get_transform(name, image_size, norm_params):
FILE: models/classifiers/classifiers.py
function register (line 9) | def register(name):
function make (line 16) | def make(name, **kwargs):
function load (line 25) | def load(ckpt):
FILE: models/classifiers/logistic.py
class LogisticClassifier (line 12) | class LogisticClassifier(Module):
method __init__ (line 13) | def __init__(self, in_dim, n_way, temp=1., learn_temp=False):
method reset_parameters (line 24) | def reset_parameters(self):
method forward (line 28) | def forward(self, x_shot, params=None):
FILE: models/encoders/convnet4.py
class ConvBlock (line 12) | class ConvBlock(Module):
method __init__ (line 13) | def __init__(self, in_channels, out_channels, bn_args):
method forward (line 23) | def forward(self, x, params=None, episode=None):
class ConvNet4 (line 30) | class ConvNet4(Module):
method __init__ (line 31) | def __init__(self, hid_dim, out_dim, bn_args):
method get_out_dim (line 54) | def get_out_dim(self, scale=25):
method forward (line 57) | def forward(self, x, params=None, episode=None):
function convnet4 (line 64) | def convnet4(bn_args=dict()):
function wide_convnet4 (line 69) | def wide_convnet4(bn_args=dict()):
FILE: models/encoders/encoders.py
function register (line 6) | def register(name):
function make (line 13) | def make(name, **kwargs):
function load (line 22) | def load(ckpt):
FILE: models/encoders/resnet12.py
function conv3x3 (line 12) | def conv3x3(in_channels, out_channels):
function conv1x1 (line 16) | def conv1x1(in_channels, out_channels):
class Block (line 20) | class Block(Module):
method __init__ (line 21) | def __init__(self, in_planes, planes, bn_args):
method forward (line 41) | def forward(self, x, params=None, episode=None):
class ResNet12 (line 58) | class ResNet12(Module):
method __init__ (line 59) | def __init__(self, channels, bn_args):
method get_out_dim (line 90) | def get_out_dim(self):
method forward (line 93) | def forward(self, x, params=None, episode=None):
function resnet12 (line 103) | def resnet12(bn_args=dict()):
function wide_resnet12 (line 108) | def wide_resnet12(bn_args=dict()):
FILE: models/encoders/resnet18.py
function conv3x3 (line 12) | def conv3x3(in_channels, out_channels, stride=1):
function conv1x1 (line 16) | def conv1x1(in_channels, out_channels, stride=1):
class Block (line 20) | class Block(Module):
method __init__ (line 21) | def __init__(self, in_planes, planes, stride, bn_args):
method forward (line 40) | def forward(self, x, params=None, episode=None):
class ResNet18 (line 54) | class ResNet18(Module):
method __init__ (line 55) | def __init__(self, channels, bn_args):
method get_out_dim (line 91) | def get_out_dim(self, scale=1):
method forward (line 94) | def forward(self, x, params=None, episode=None):
function resnet18 (line 106) | def resnet18(bn_args=dict()):
function wide_resnet18 (line 111) | def wide_resnet18(bn_args=dict()):
FILE: models/maml.py
function make (line 13) | def make(enc_name, enc_args, clf_name, clf_args):
function load (line 33) | def load(ckpt, load_clf=False, clf_name=None, clf_args=None):
class MAML (line 61) | class MAML(Module):
method __init__ (line 62) | def __init__(self, encoder, classifier):
method reset_classifier (line 67) | def reset_classifier(self):
method _inner_forward (line 70) | def _inner_forward(self, x, params, episode):
method _inner_iter (line 76) | def _inner_iter(self, x, y, params, mom_buffer, episode, inner_args, d...
method _adapt (line 126) | def _adapt(self, x, y, params, episode, inner_args, meta_train):
method forward (line 190) | def forward(self, x_shot, x_query, y_shot, inner_args, meta_train):
FILE: models/modules.py
function get_child_dict (line 13) | def get_child_dict(params, key=None):
class Module (line 38) | class Module(nn.Module):
method __init__ (line 39) | def __init__(self):
method go_efficient (line 44) | def go_efficient(self, mode=True):
method is_first_pass (line 51) | def is_first_pass(self, mode=True):
class Conv2d (line 60) | class Conv2d(nn.Conv2d, Module):
method __init__ (line 61) | def __init__(self, in_channels, out_channels, kernel_size,
method forward (line 66) | def forward(self, x, params=None, episode=None):
class Linear (line 79) | class Linear(nn.Linear, Module):
method __init__ (line 80) | def __init__(self, in_features, out_features, bias=True):
method forward (line 83) | def forward(self, x, params=None, episode=None):
class BatchNorm2d (line 96) | class BatchNorm2d(nn.BatchNorm2d, Module):
method __init__ (line 97) | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
method is_episodic (line 132) | def is_episodic(self):
method _batch_norm (line 135) | def _batch_norm(self, x, mean, var, weight=None, bias=None):
method reset_episodic_running_stats (line 145) | def reset_episodic_running_stats(self, episode):
method forward (line 151) | def forward(self, x, params=None, episode=None):
class Sequential (line 207) | class Sequential(nn.Sequential, Module):
method __init__ (line 208) | def __init__(self, *args):
method forward (line 211) | def forward(self, x, params=None, episode=None):
FILE: test.py
function main (line 16) | def main(config):
FILE: train.py
function main (line 21) | def main(config):
FILE: utils/__init__.py
function set_log_path (line 11) | def set_log_path(path):
function log (line 16) | def log(obj, filename='log.txt'):
class AverageMeter (line 23) | class AverageMeter(object):
method __init__ (line 24) | def __init__(self):
method reset (line 27) | def reset(self):
method update (line 33) | def update(self, val, n=1):
method item (line 39) | def item(self):
class Timer (line 43) | class Timer(object):
method __init__ (line 44) | def __init__(self):
method start (line 47) | def start(self):
method end (line 50) | def end(self):
function set_gpu (line 54) | def set_gpu(gpu):
function ensure_path (line 59) | def ensure_path(path, remove=True):
function time_str (line 70) | def time_str(t):
function compute_acc (line 78) | def compute_acc(pred, label, reduction='mean'):
function compute_n_params (line 86) | def compute_n_params(model, return_str=True):
function mean_confidence_interval (line 99) | def mean_confidence_interval(data, confidence=0.95):
function config_inner_args (line 106) | def config_inner_args(inner_args):
FILE: utils/optimizers.py
function make (line 5) | def make(name, params, lr, weight_decay=0.,
function load (line 43) | def load(ckpt, params):
Condensed preview — 40 files, each showing path, character count, and a content snippet. Download the .json file or copy for the full structured content (91K chars).
[
{
"path": "README.md",
"chars": 9571,
"preview": "# MAML in PyTorch - Re-implementation and Beyond\n\nA PyTorch implementation of [Model Agnostic Meta-Learning (MAML)](http"
},
{
"path": "configs/convnet4/mini-imagenet/5_way_1_shot/test_reproduce.yaml",
"chars": 380,
"preview": "dataset: meta-mini-imagenet\ntest:\n split: meta-test\n image_size: 84\n normalization: False\n transform: null\n n_batch"
},
{
"path": "configs/convnet4/mini-imagenet/5_way_1_shot/test_template.yaml",
"chars": 398,
"preview": "dataset: meta-mini-imagenet\ntest:\n split: meta-test\n image_size: 84\n normalization: True\n transform: flip\n n_batch:"
},
{
"path": "configs/convnet4/mini-imagenet/5_way_1_shot/train_reproduce.yaml",
"chars": 612,
"preview": "dataset: meta-mini-imagenet\ntrain:\n split: meta-train\n image_size: 84\n normalization: False\n transform: null\n n_bat"
},
{
"path": "configs/convnet4/mini-imagenet/5_way_1_shot/train_template.yaml",
"chars": 768,
"preview": "dataset: meta-mini-imagenet\ntrain:\n split: meta-train\n image_size: 84\n normalization: True\n transform: flip\n n_batc"
},
{
"path": "configs/convnet4/mini-imagenet/5_way_5_shot/test_reproduce.yaml",
"chars": 354,
"preview": "dataset: meta-mini-imagenet\ntest:\n split: meta-test\n image_size: 84\n normalization: False\n transform: null\n n_batch"
},
{
"path": "configs/convnet4/mini-imagenet/5_way_5_shot/test_template.yaml",
"chars": 398,
"preview": "dataset: meta-mini-imagenet\ntest:\n split: meta-test\n image_size: 84\n normalization: True\n transform: flip\n n_batch:"
},
{
"path": "configs/convnet4/mini-imagenet/5_way_5_shot/train_reproduce.yaml",
"chars": 586,
"preview": "dataset: meta-mini-imagenet\ntrain:\n split: meta-train\n image_size: 84\n normalization: False\n transform: null\n n_bat"
},
{
"path": "configs/convnet4/mini-imagenet/5_way_5_shot/train_template.yaml",
"chars": 768,
"preview": "dataset: meta-mini-imagenet\ntrain:\n split: meta-train\n image_size: 84\n normalization: True\n transform: flip\n n_batc"
},
{
"path": "configs/convnet4/tiered-imagenet/5_way_1_shot/test_reproduce.yaml",
"chars": 389,
"preview": "dataset: meta-tiered-imagenet\ntest:\n split: meta-test\n image_size: 84\n normalization: False\n transform: null\n n_bat"
},
{
"path": "configs/convnet4/tiered-imagenet/5_way_1_shot/test_template.yaml",
"chars": 407,
"preview": "dataset: meta-tiered-imagenet\ntest:\n split: meta-test\n image_size: 84\n normalization: True\n transform: flip\n n_batc"
},
{
"path": "configs/convnet4/tiered-imagenet/5_way_1_shot/train_reproduce.yaml",
"chars": 619,
"preview": "dataset: meta-tiered-imagenet\ntrain:\n split: meta-train\n image_size: 84\n normalization: False\n transform: null\n n_b"
},
{
"path": "configs/convnet4/tiered-imagenet/5_way_1_shot/train_template.yaml",
"chars": 775,
"preview": "dataset: meta-tiered-imagenet\ntrain:\n split: meta-train\n image_size: 84\n normalization: True\n transform: flip\n n_ba"
},
{
"path": "configs/convnet4/tiered-imagenet/5_way_5_shot/test_reproduce.yaml",
"chars": 389,
"preview": "dataset: meta-tiered-imagenet\ntest:\n split: meta-test\n image_size: 84\n normalization: False\n transform: null\n n_bat"
},
{
"path": "configs/convnet4/tiered-imagenet/5_way_5_shot/test_template.yaml",
"chars": 407,
"preview": "dataset: meta-tiered-imagenet\ntest:\n split: meta-test\n image_size: 84\n normalization: True\n transform: flip\n n_batc"
},
{
"path": "configs/convnet4/tiered-imagenet/5_way_5_shot/train_reproduce.yaml",
"chars": 619,
"preview": "dataset: meta-tiered-imagenet\ntrain:\n split: meta-train\n image_size: 84\n normalization: False\n transform: null\n n_b"
},
{
"path": "configs/convnet4/tiered-imagenet/5_way_5_shot/train_template.yaml",
"chars": 775,
"preview": "dataset: meta-tiered-imagenet\ntrain:\n split: meta-train\n image_size: 84\n normalization: True\n transform: flip\n n_ba"
},
{
"path": "datasets/__init__.py",
"chars": 188,
"preview": "from .datasets import make, collate_fn\nfrom . import mini_imagenet\nfrom . import tiered_imagenet\nfrom . import cifar100\n"
},
{
"path": "datasets/cifar100.py",
"chars": 4525,
"preview": "import os\nimport pickle\n\nimport torch\nfrom torch.utils.data import Dataset\nimport numpy as np\nfrom PIL import Image\n\nfro"
},
{
"path": "datasets/cub200.py",
"chars": 4143,
"preview": "import os\n\nimport torch\nfrom torch.utils.data import Dataset\nimport numpy as np\nfrom PIL import Image\n\nfrom .datasets im"
},
{
"path": "datasets/datasets.py",
"chars": 912,
"preview": "import os\n\nimport torch\n\n\nDEFAULT_ROOT = './materials'\ndatasets = {}\n\ndef register(name):\n def decorator(cls):\n data"
},
{
"path": "datasets/inatural.py",
"chars": 4111,
"preview": "import os\n\nimport torch\nfrom torch.utils.data import Dataset\nimport numpy as np\nfrom PIL import Image\n\nfrom .datasets im"
},
{
"path": "datasets/mini_imagenet.py",
"chars": 4202,
"preview": "import os\nimport pickle\n\nimport torch\nfrom torch.utils.data import Dataset\nimport numpy as np\nfrom PIL import Image\n\nfro"
},
{
"path": "datasets/tiered_imagenet.py",
"chars": 4160,
"preview": "import os\nimport pickle\n\nimport torch\nfrom torch.utils.data import Dataset\nimport numpy as np\nfrom PIL import Image\n\nfro"
},
{
"path": "datasets/transforms.py",
"chars": 1630,
"preview": "import torchvision.transforms as transforms\n\n\n__all__ = ['get_transform']\n\n\ndef get_transform(name, image_size, norm_par"
},
{
"path": "models/__init__.py",
"chars": 45,
"preview": "from .maml import make\nfrom .maml import load"
},
{
"path": "models/classifiers/__init__.py",
"chars": 58,
"preview": "from .classifiers import make, load\nfrom . import logistic"
},
{
"path": "models/classifiers/classifiers.py",
"chars": 465,
"preview": "import torch\n\n\n__all__ = ['make', 'load']\n\n\nmodels = {}\n\ndef register(name):\n def decorator(cls):\n models[name] = cl"
},
{
"path": "models/classifiers/logistic.py",
"chars": 811,
"preview": "import torch\nimport torch.nn as nn\n\nfrom .classifiers import register\nfrom ..modules import *\n\n\n__all__ = ['LogisticClas"
},
{
"path": "models/encoders/__init__.py",
"chars": 101,
"preview": "from .encoders import make, load\nfrom . import convnet4\nfrom . import resnet12\nfrom . import resnet18"
},
{
"path": "models/encoders/convnet4.py",
"chars": 2013,
"preview": "from collections import OrderedDict\n\nimport torch.nn as nn\n\nfrom .encoders import register\nfrom ..modules import *\n\n\n__a"
},
{
"path": "models/encoders/encoders.py",
"chars": 453,
"preview": "import torch\n\n\nmodels = {}\n\ndef register(name):\n def decorator(cls):\n models[name] = cls\n return cls\n return dec"
},
{
"path": "models/encoders/resnet12.py",
"chars": 3381,
"preview": "from collections import OrderedDict\n\nimport torch.nn as nn\n\nfrom .encoders import register\nfrom ..modules import *\n\n\n__a"
},
{
"path": "models/encoders/resnet18.py",
"chars": 3509,
"preview": "from collections import OrderedDict\n\nimport torch.nn as nn\n\nfrom .encoders import register\nfrom ..modules import *\n\n\n__a"
},
{
"path": "models/maml.py",
"chars": 8784,
"preview": "from collections import OrderedDict\n\nimport torch\nimport torch.nn.functional as F\nimport torch.autograd as autograd\nimpo"
},
{
"path": "models/modules.py",
"chars": 7816,
"preview": "import re\nfrom collections import OrderedDict\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\n__al"
},
{
"path": "test.py",
"chars": 2732,
"preview": "import argparse\nimport random\n\nimport yaml\nimport torch\nimport torch.nn as nn\nimport numpy as np\nfrom tqdm import tqdm\nf"
},
{
"path": "train.py",
"chars": 8695,
"preview": "import argparse\nimport os\nimport random\nfrom collections import OrderedDict\n\nimport yaml\nimport torch\nimport torch.nn as"
},
{
"path": "utils/__init__.py",
"chars": 2665,
"preview": "import os\nimport shutil\nimport time\n\nimport numpy as np\nimport scipy.stats as stats\n\n\n_log_path = None\n\ndef set_log_path"
},
{
"path": "utils/optimizers.py",
"chars": 1819,
"preview": "from torch.optim import SGD, RMSprop, Adam\nfrom torch.optim.lr_scheduler import MultiStepLR, CosineAnnealingLR\n\n\ndef mak"
}
]
About this extraction
This page contains the full source code of the fmu2/PyTorch-MAML GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 40 files (83.4 KB), approximately 24.4k tokens, and a symbol index with 146 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.
Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.