[
  {
    "path": "README.md",
    "content": "# MAML in PyTorch - Re-implementation and Beyond\n\nA PyTorch implementation of [Model-Agnostic Meta-Learning (MAML)](https://arxiv.org/abs/1703.03400). We faithfully reproduce [the official TensorFlow implementation](https://github.com/cbfinn/maml) while incorporating a number of additional features that may ease further study of this high-profile meta-learning framework.\n\n## Overview\n\nThis repository contains code for training and evaluating MAML on the mini-ImageNet and tiered-ImageNet datasets most commonly used for few-shot image classification. To the best of our knowledge, this is the only PyTorch implementation of MAML to date that **fully reproduces the results in the original paper** without applying tricks such as data augmentation, evaluation on multiple crops, or ensembling multiple models. Other existing PyTorch implementations typically see a ~3% accuracy gap on the 5-way-1-shot and 5-way-5-shot classification tasks on mini-ImageNet.\n\nBeyond reproducing the results, our implementation comes with a few extras that we believe can be helpful for further development of the framework. Below we highlight the improvements built into our code and discuss observations that warrant some attention.\n\n## Implementation Highlights\n\n- **Batch normalization with per-episode running statistics.** Our implementation provides the flexibility to track global and/or per-episode running statistics, hence supporting both transductive and inductive inference.\n\n- **Better data pre-processing.** The official implementation does not normalize or augment data. We support data normalization and a variety of data augmentation techniques. 
We also implement data batching and support/query-set splitting more efficiently.\n\n- **More datasets.** We support mini-ImageNet, tiered-ImageNet, and more.\n\n- **More options for outer-loop optimization.** We support multiple optimizers and learning-rate schedulers for outer-loop optimization.\n\n- **More powerful inner-loop optimization.** The official implementation uses vanilla gradient descent in the inner loop. We additionally support momentum and weight decay.\n\n- **More options for the encoder architecture.** We support the standard four-layer ConvNet as well as ResNet-12 and ResNet-18 as the encoder.\n\n- **Easy layer freezing.** We provide an interface for layer-freezing experiments. One may freeze an arbitrary set of layers or blocks during inner-loop adaptation.\n\n- **Meta-learning with a zero-initialized classifier head.** The official implementation learns a meta-initialization for both the encoder and the classifier head. This prevents one from varying the number of categories at training or test time. With our implementation, one may opt to learn a meta-initialization for the encoder only while initializing the classifier head to zero.\n\n- **Distributed training and gradient checkpointing.** MAML is very memory-intensive because it buffers all tensors generated throughout the inner-loop adaptation steps. Gradient checkpointing trades compute for memory, effectively bringing the memory cost from O(N) down to O(1), where N is the number of inner-loop steps. In our experiments, gradient checkpointing saved up to 80% of GPU memory at the cost of running the forward pass more than once (a moderate ~20% increase in running time).\n\n## Transductive or Inductive?\n\nThe official implementation assumes transductive learning. The batch normalization layers do not track running statistics at training time, and they use mini-batch statistics at test time. The implicit assumption here is that test data come in mini-batches and are perhaps balanced across categories. 
This is a very restrictive assumption and does not make MAML directly comparable with the vast majority of meta-learning and few-shot learning methods. Unfortunately, this is not immediately obvious from the paper, and our findings suggest that the performance of MAML is hugely overestimated.\n\n- **Accuracy is very sensitive to the size of the query set in the transductive setting.** For example, the result for 5-way-1-shot classification on mini-ImageNet from the paper (48.70%) was obtained with five queries, one per category. We found that the accuracy dropped by ~1.5% given five queries per category, and by ~2.5% given 15 queries per category.\n\n- The paper reports mean accuracy over 600 independently sampled tasks, or trials. We found that **600 trials, again in the transductive setting, are insufficient for an unbiased estimate of model performance**. The mean accuracy over 6,000 trials is more stable, and is consistently ~2% lower than that over the first 600 trials. We conjecture that the distribution of per-trial accuracy is highly skewed towards the high end.\n\n- We found that **MAML performs considerably worse in the inductive setting**. Given the same model configuration, inductive accuracy is consistently much lower (by ~4%) than the *corrected* transductive accuracy, which is itself already a few percentage points behind the reported number.\n\nAs the discussion above makes evident, one should be extremely cautious when comparing MAML with its competitors.\n\n## FOMAML and layer freezing\n\nUnfortunately, some insights discussed in the original paper and its follow-up works do not appear to hold in the inductive setting.\n\n- FOMAML (i.e., the first-order approximation of MAML) performs as well as MAML in transductive learning, but fails completely in the inductive setting. 
\n\n- Completely freezing the encoder during inner-loop adaptation, as was done in [this work](https://arxiv.org/abs/1909.09157), results in a dramatic decrease in accuracy.\n\n## BatchNorm and TaskNorm\n\n[A recent work](https://arxiv.org/abs/2003.03284) proposes TaskNorm, a test-time enhancement of batch normalization, noting that the small batch sizes during training may leave batch normalization less effective. We did not have much success with this method. We observed marginal improvement most of the time, and found that it hurts performance occasionally. That said, we do believe that batch normalization is hard to deal with in MAML. TaskNorm attempts to attack the problem of small batch sizes, which we conjecture is just one of the three main causes (i.e., extremely scarce training data, extremely small batch sizes, and an extremely small number of inner-loop updates) of the ineffectiveness of batch normalization in MAML.\n\n## Quick Start\n\n### 0. Preliminaries\n\n**Environment**\n\n- Python 3.6.8 (or any Python 3 distribution)\n- PyTorch 1.3.1 (or any PyTorch > 1.0)\n- tensorboardX\n\n**Datasets**\n\nPlease follow the download links [here](https://github.com/cyvius96/few-shot-meta-baseline). Please modify the file names accordingly so that they can be recognized by the data loaders.\n\n**Configurations**\n\nTemplate configuration files, as well as those for reproducing the results in the original paper, can be found in `configs/`. The hyperparameters are self-explanatory.\n\n### 1. Training MAML\nHere is the command for single-GPU training of MAML with a ConvNet4 backbone for 5-way-1-shot classification on mini-ImageNet, reproducing the result in the original paper.\n```\npython train.py --config=configs/convnet4/mini-imagenet/train_reproduce.yaml\n```\n\nUse `--gpu` to specify available GPUs for multi-GPU training. 
For example,\n```\npython train.py --config=configs/convnet4/mini-imagenet/train_reproduce.yaml --gpu=0,1\n```\n\nAdd `--efficient` to enable gradient checkpointing. This aggressively saves GPU memory while slightly increasing running time.\n```\npython train.py --config=configs/convnet4/mini-imagenet/train_reproduce.yaml --efficient\n```\n\nUse `--tag` to customize the name of the directory where the checkpoints and log files are saved.\n\n### 2. Testing MAML\nHere is how one would test MAML for 5-way-1-shot classification on mini-ImageNet to reproduce the result in the original paper. Please confirm the loading path (the `load` entry in the config) first.\n```\npython test.py --config=configs/convnet4/mini-imagenet/test_reproduce.yaml\n```\n\nThe `--gpu` and `--efficient` flags function as they do in training.\n\n## Contact\n[Xinchan Zhu](https://www.linkedin.com/in/xinchan-zhu-66673b106) (zhuxinchan@gmail.com)\n\n## Cite our Repository\n```\n@misc{pytorch_maml,\n  title={maml in pytorch - re-implementation and beyond},\n  author={Zhu, Xinchan},\n  howpublished={\\url{https://github.com/shirleyzhu233/PyTorch-MAML}},\n  year={2020}\n}\n```\n\n## Related Code Repositories\n\nOur implementation is inspired by the following repositories.\n* maml (the official implementation) <https://github.com/cbfinn/maml>\n* MAML-Pytorch <https://github.com/dragen1860/MAML-Pytorch>\n* HowToTrainYourMAMLPytorch <https://github.com/AntreasAntoniou/HowToTrainYourMAMLPytorch>\n* memory-efficient-maml <https://github.com/dbaranchuk/memory-efficient-maml>\n\n## References\n```\n@inproceedings{finn2017model,\n  title={Model-agnostic meta-learning for fast adaptation of deep networks},\n  author={Finn, Chelsea and Abbeel, Pieter and Levine, Sergey},\n  booktitle={International Conference on Machine Learning (ICML)},\n  year={2017}\n}\n\n@inproceedings{raghu2019rapid,\n  title={Rapid learning or feature reuse? 
towards understanding the effectiveness of maml},\n  author={Raghu, Aniruddh and Raghu, Maithra and Bengio, Samy and Vinyals, Oriol},\n  booktitle={International Conference on Learning Representations (ICLR)},\n  year={2019}\n}\n\n@article{Bronskill2020tasknorm,\n  title={Tasknorm: rethinking batch normalization for meta-learning},\n  author={Bronskill, John and Gordon, Jonathan and Requeima, James and Nowozin, Sebastian and Turner, Richard E.},\n  journal={arXiv preprint arXiv:2003.03284},\n  year={2020}\n}\n```\n"
  },
  {
    "path": "configs/convnet4/mini-imagenet/5_way_1_shot/test_reproduce.yaml",
    "content": "dataset: meta-mini-imagenet\ntest:\n  split: meta-test\n  image_size: 84\n  normalization: False\n  transform: null\n  n_batch: 150\n  n_episode: 4\n  n_way: 5\n  n_shot: 1\n  n_query: 1\n\nload: ./save/convnet4_mini-imagenet_5_way_1_shot/max-va.pth\n\ninner_args:\n  n_step: 10\n  encoder_lr: 0.01\n  classifier_lr: 0.01\n  first_order: False  # set to True for FOMAML\n  frozen:\n    - bn\n\nepoch: 1"
  },
  {
    "path": "configs/convnet4/mini-imagenet/5_way_1_shot/test_template.yaml",
    "content": "dataset: meta-mini-imagenet\ntest:\n  split: meta-test\n  image_size: 84\n  normalization: True\n  transform: flip\n  n_batch: 200\n  n_episode: 4\n  n_way: 5\n  n_shot: 1\n  n_query: 15\n\nload: ./save/convnet4_mini-imagenet_5_way_1_shot/max-va.pth\n\ninner_args:\n  reset_classifier: True\n  n_step: 5\n  encoder_lr: 0.01\n  classifier_lr: 0.01\n  momentum: 0.9\n  weight_decay: 5.e-4\n  first_order: False\n\nepoch: 10"
  },
  {
    "path": "configs/convnet4/mini-imagenet/5_way_1_shot/train_reproduce.yaml",
    "content": "dataset: meta-mini-imagenet\ntrain:\n  split: meta-train\n  image_size: 84\n  normalization: False\n  transform: null\n  n_batch: 200\n  n_episode: 4\n  n_way: 5\n  n_shot: 1\n  n_query: 15\nval:\n  split: meta-val\n  image_size: 84\n  normalization: False\n  transform: null\n  n_batch: 200\n  n_episode: 4\n  n_way: 5\n  n_shot: 1\n  n_query: 15\n\nencoder: convnet4\nencoder_args:\n  bn_args:\n    track_running_stats: False\nclassifier: logistic\n\ninner_args:\n  n_step: 5\n  encoder_lr: 0.01\n  classifier_lr: 0.01\n  first_order: False  # set to True for FOMAML\n  frozen:\n    - bn\n\noptimizer: adam\noptimizer_args:\n  lr: 0.001\n\nepoch: 300"
  },
  {
    "path": "configs/convnet4/mini-imagenet/5_way_1_shot/train_template.yaml",
    "content": "dataset: meta-mini-imagenet\ntrain:\n  split: meta-train\n  image_size: 84\n  normalization: True\n  transform: flip\n  n_batch: 200\n  n_episode: 4\n  n_way: 5\n  n_shot: 1\n  n_query: 15\nval:\n  split: meta-val\n  image_size: 84\n  normalization: True\n  transform: flip\n  n_batch: 200\n  n_episode: 4\n  n_way: 5\n  n_shot: 1\n  n_query: 15\n\nencoder: convnet4\nencoder_args:\n  bn_args:\n    track_running_stats: True\n    episodic:\n      - conv1\n      - conv2\n      - conv3\n      - conv4\nclassifier: logistic\n\ninner_args:\n  reset_classifier: True\n  n_step: 5\n  encoder_lr: 0.01\n  classifier_lr: 0.01\n  momentum: 0.9\n  weight_decay: 5.e-4\n  first_order: False\n\noptimizer: sgd\noptimizer_args:\n  lr: 0.01\n  weight_decay: 5.e-4\n  schedule: step\n  milestones:\n    - 120\n    - 140\n\nepoch: 150"
  },
  {
    "path": "configs/convnet4/mini-imagenet/5_way_5_shot/test_reproduce.yaml",
    "content": "dataset: meta-mini-imagenet\ntest:\n  split: meta-test\n  image_size: 84\n  normalization: False\n  transform: null\n  n_batch: 150\n  n_episode: 4\n  n_way: 5\n  n_shot: 5\n  n_query: 5\n\nload: ./save/convnet4_mini-imagenet_5_way_5_shot/max-va.pth\n\ninner_args:\n  n_step: 10\n  encoder_lr: 0.01\n  classifier_lr: 0.01\n  first_order: False\n  frozen:\n    - bn\n\nepoch: 1"
  },
  {
    "path": "configs/convnet4/mini-imagenet/5_way_5_shot/test_template.yaml",
    "content": "dataset: meta-mini-imagenet\ntest:\n  split: meta-test\n  image_size: 84\n  normalization: True\n  transform: flip\n  n_batch: 200\n  n_episode: 4\n  n_way: 5\n  n_shot: 5\n  n_query: 15\n\nload: ./save/convnet4_mini-imagenet_5_way_5_shot/max-va.pth\n\ninner_args:\n  reset_classifier: True\n  n_step: 5\n  encoder_lr: 0.01\n  classifier_lr: 0.01\n  momentum: 0.9\n  weight_decay: 5.e-4\n  first_order: False\n\nepoch: 10"
  },
  {
    "path": "configs/convnet4/mini-imagenet/5_way_5_shot/train_reproduce.yaml",
    "content": "dataset: meta-mini-imagenet\ntrain:\n  split: meta-train\n  image_size: 84\n  normalization: False\n  transform: null\n  n_batch: 200\n  n_episode: 4\n  n_way: 5\n  n_shot: 5\n  n_query: 15\nval:\n  split: meta-val\n  image_size: 84\n  normalization: False\n  transform: null\n  n_batch: 200\n  n_episode: 4\n  n_way: 5\n  n_shot: 5\n  n_query: 15\n\nencoder: convnet4\nencoder_args:\n  bn_args:\n    track_running_stats: False\nclassifier: logistic\n\ninner_args:\n  n_step: 5\n  encoder_lr: 0.01\n  classifier_lr: 0.01\n  first_order: False\n  frozen:\n    - bn\n\noptimizer: adam\noptimizer_args:\n  lr: 0.001\n\nepoch: 300"
  },
  {
    "path": "configs/convnet4/mini-imagenet/5_way_5_shot/train_template.yaml",
    "content": "dataset: meta-mini-imagenet\ntrain:\n  split: meta-train\n  image_size: 84\n  normalization: True\n  transform: flip\n  n_batch: 200\n  n_episode: 4\n  n_way: 5\n  n_shot: 5\n  n_query: 15\nval:\n  split: meta-val\n  image_size: 84\n  normalization: True\n  transform: flip\n  n_batch: 200\n  n_episode: 4\n  n_way: 5\n  n_shot: 5\n  n_query: 15\n\nencoder: convnet4\nencoder_args:\n  bn_args:\n    track_running_stats: True\n    episodic:\n      - conv1\n      - conv2\n      - conv3\n      - conv4\nclassifier: logistic\n\ninner_args:\n  reset_classifier: True\n  n_step: 5\n  encoder_lr: 0.01\n  classifier_lr: 0.01\n  momentum: 0.9\n  weight_decay: 5.e-4\n  first_order: False\n\noptimizer: sgd\noptimizer_args:\n  lr: 0.01\n  weight_decay: 5.e-4\n  schedule: step\n  milestones:\n    - 120\n    - 140\n\nepoch: 150"
  },
  {
    "path": "configs/convnet4/tiered-imagenet/5_way_1_shot/test_reproduce.yaml",
    "content": "dataset: meta-tiered-imagenet\ntest:\n  split: meta-test\n  image_size: 84\n  normalization: False\n  transform: null\n  n_batch: 150\n  n_episode: 4\n  n_way: 5\n  n_shot: 1\n  n_query: 1\n\nload: ./save/wide-convnet4_tiered-imagenet_5_way_1_shot/max-va.pth\n\ninner_args:\n  n_step: 10\n  encoder_lr: 0.01\n  classifier_lr: 0.01\n  first_order: False  # set to True for FOMAML\n  frozen:\n    - bn\n\nepoch: 1"
  },
  {
    "path": "configs/convnet4/tiered-imagenet/5_way_1_shot/test_template.yaml",
    "content": "dataset: meta-tiered-imagenet\ntest:\n  split: meta-test\n  image_size: 84\n  normalization: True\n  transform: flip\n  n_batch: 200\n  n_episode: 4\n  n_way: 5\n  n_shot: 1\n  n_query: 15\n\nload: ./save/wide-convnet4_tiered-imagenet_5_way_1_shot/max-va.pth\n\ninner_args:\n  reset_classifier: True\n  n_step: 5\n  encoder_lr: 0.01\n  classifier_lr: 0.01\n  momentum: 0.9\n  weight_decay: 5.e-4\n  first_order: False\n\nepoch: 10"
  },
  {
    "path": "configs/convnet4/tiered-imagenet/5_way_1_shot/train_reproduce.yaml",
    "content": "dataset: meta-tiered-imagenet\ntrain:\n  split: meta-train\n  image_size: 84\n  normalization: False\n  transform: null\n  n_batch: 200\n  n_episode: 4\n  n_way: 5\n  n_shot: 1\n  n_query: 15\nval:\n  split: meta-val\n  image_size: 84\n  normalization: False\n  transform: null\n  n_batch: 200\n  n_episode: 4\n  n_way: 5\n  n_shot: 1\n  n_query: 15\n\nencoder: wide-convnet4\nencoder_args:\n  bn_args:\n    track_running_stats: False\nclassifier: logistic\n\ninner_args:\n  n_step: 5\n  encoder_lr: 0.01\n  classifier_lr: 0.01\n  first_order: False  # set to True for FOMAML\n  frozen:\n    - bn\n\noptimizer: adam\noptimizer_args:\n  lr: 0.001\n\nepoch: 300"
  },
  {
    "path": "configs/convnet4/tiered-imagenet/5_way_1_shot/train_template.yaml",
    "content": "dataset: meta-tiered-imagenet\ntrain:\n  split: meta-train\n  image_size: 84\n  normalization: True\n  transform: flip\n  n_batch: 200\n  n_episode: 4\n  n_way: 5\n  n_shot: 1\n  n_query: 15\nval:\n  split: meta-val\n  image_size: 84\n  normalization: True\n  transform: flip\n  n_batch: 200\n  n_episode: 4\n  n_way: 5\n  n_shot: 1\n  n_query: 15\n\nencoder: wide-convnet4\nencoder_args:\n  bn_args:\n    track_running_stats: True\n    episodic:\n      - conv1\n      - conv2\n      - conv3\n      - conv4\nclassifier: logistic\n\ninner_args:\n  reset_classifier: True\n  n_step: 5\n  encoder_lr: 0.01\n  classifier_lr: 0.01\n  momentum: 0.9\n  weight_decay: 5.e-4\n  first_order: False\n\noptimizer: sgd\noptimizer_args:\n  lr: 0.01\n  weight_decay: 5.e-4\n  schedule: step\n  milestones:\n    - 120\n    - 140\n\nepoch: 150"
  },
  {
    "path": "configs/convnet4/tiered-imagenet/5_way_5_shot/test_reproduce.yaml",
    "content": "dataset: meta-tiered-imagenet\ntest:\n  split: meta-test\n  image_size: 84\n  normalization: False\n  transform: null\n  n_batch: 150\n  n_episode: 4\n  n_way: 5\n  n_shot: 5\n  n_query: 5\n\nload: ./save/wide-convnet4_tiered-imagenet_5_way_5_shot/max-va.pth\n\ninner_args:\n  n_step: 10\n  encoder_lr: 0.01\n  classifier_lr: 0.01\n  first_order: False  # set to True for FOMAML\n  frozen:\n    - bn\n\nepoch: 1"
  },
  {
    "path": "configs/convnet4/tiered-imagenet/5_way_5_shot/test_template.yaml",
    "content": "dataset: meta-tiered-imagenet\ntest:\n  split: meta-test\n  image_size: 84\n  normalization: True\n  transform: flip\n  n_batch: 200\n  n_episode: 4\n  n_way: 5\n  n_shot: 5\n  n_query: 15\n\nload: ./save/wide-convnet4_tiered-imagenet_5_way_5_shot/max-va.pth\n\ninner_args:\n  reset_classifier: True\n  n_step: 5\n  encoder_lr: 0.01\n  classifier_lr: 0.01\n  momentum: 0.9\n  weight_decay: 5.e-4\n  first_order: False\n\nepoch: 10"
  },
  {
    "path": "configs/convnet4/tiered-imagenet/5_way_5_shot/train_reproduce.yaml",
    "content": "dataset: meta-tiered-imagenet\ntrain:\n  split: meta-train\n  image_size: 84\n  normalization: False\n  transform: null\n  n_batch: 200\n  n_episode: 4\n  n_way: 5\n  n_shot: 5\n  n_query: 15\nval:\n  split: meta-val\n  image_size: 84\n  normalization: False\n  transform: null\n  n_batch: 200\n  n_episode: 4\n  n_way: 5\n  n_shot: 5\n  n_query: 15\n\nencoder: wide-convnet4\nencoder_args:\n  bn_args:\n    track_running_stats: False\nclassifier: logistic\n\ninner_args:\n  n_step: 5\n  encoder_lr: 0.01\n  classifier_lr: 0.01\n  first_order: False  # set to True for FOMAML\n  frozen:\n    - bn\n\noptimizer: adam\noptimizer_args:\n  lr: 0.001\n\nepoch: 300"
  },
  {
    "path": "configs/convnet4/tiered-imagenet/5_way_5_shot/train_template.yaml",
    "content": "dataset: meta-tiered-imagenet\ntrain:\n  split: meta-train\n  image_size: 84\n  normalization: True\n  transform: flip\n  n_batch: 200\n  n_episode: 4\n  n_way: 5\n  n_shot: 5\n  n_query: 15\nval:\n  split: meta-val\n  image_size: 84\n  normalization: True\n  transform: flip\n  n_batch: 200\n  n_episode: 4\n  n_way: 5\n  n_shot: 5\n  n_query: 15\n\nencoder: wide-convnet4\nencoder_args:\n  bn_args:\n    track_running_stats: True\n    episodic:\n      - conv1\n      - conv2\n      - conv3\n      - conv4\nclassifier: logistic\n\ninner_args:\n  reset_classifier: True\n  n_step: 5\n  encoder_lr: 0.01\n  classifier_lr: 0.01\n  momentum: 0.9\n  weight_decay: 5.e-4\n  first_order: False\n\noptimizer: sgd\noptimizer_args:\n  lr: 0.01\n  weight_decay: 5.e-4\n  schedule: step\n  milestones:\n    - 120\n    - 140\n\nepoch: 150"
  },
  {
    "path": "datasets/__init__.py",
    "content": "from .datasets import make, collate_fn\nfrom . import mini_imagenet\nfrom . import tiered_imagenet\nfrom . import cifar100\nfrom . import cub200\nfrom . import inatural\nfrom . import transforms"
  },
  {
    "path": "datasets/cifar100.py",
    "content": "import os\nimport pickle\n\nimport torch\nfrom torch.utils.data import Dataset\nimport numpy as np\nfrom PIL import Image\n\nfrom .datasets import register\nfrom .transforms import get_transform\n\n\nclass Cifar100(Dataset):\n  def __init__(self, root_path, split='train', image_size=32, \n               normalization=True, transform=None):\n    super(Cifar100, self).__init__()\n    split_dict = {'train': 'train',             # standard train\n                  'trainval': 'trainval',       # standard train + val\n                  'meta-train': 'train',        # meta-train\n                  'meta-val': 'val',            # meta-val\n                  'meta-trainval': 'trainval',  # meta-train + meta-val\n                  'meta-test': 'test',          # meta-test\n                 }\n    split_tag = split_dict[split]\n\n    split_file = os.path.join(root_path, split_tag + '.pickle')\n    assert os.path.isfile(split_file)\n    with open(split_file, 'rb') as f:\n      pack = pickle.load(f, encoding='latin1')\n    data, label = pack['data'], pack['labels']\n\n    data = [Image.fromarray(x) for x in data]\n    label = np.array(label)\n    label_key = sorted(np.unique(label))\n    label_map = dict(zip(label_key, range(len(label_key))))\n    new_label = np.array([label_map[x] for x in label])\n    \n    self.root_path = root_path\n    self.split_tag = split_tag\n    self.image_size = image_size\n    \n    self.data = data\n    self.label = new_label\n    self.n_classes = len(label_key)\n\n    if normalization:\n      self.norm_params = {'mean': [0.5071, 0.4867, 0.4408],\n                          'std':  [0.2675, 0.2565, 0.2761]}\n    else:\n      self.norm_params = {'mean': [0., 0., 0.],\n                          'std':  [1., 1., 1.]}\n\n    self.transform = get_transform(transform, image_size, self.norm_params)\n    \n    def convert_raw(x):\n      mean = torch.tensor(self.norm_params['mean']).view(3, 1, 1).type_as(x)\n      std = 
torch.tensor(self.norm_params['std']).view(3, 1, 1).type_as(x)\n      return x * std + mean\n      \n    self.convert_raw = convert_raw\n\n  def __len__(self):\n    return len(self.data)\n\n  def __getitem__(self, index):\n    image = self.transform(self.data[index])\n    label = self.label[index]\n    return image, label\n\n\nclass MetaCifar100(Cifar100):\n  def __init__(self, root_path, split='train', image_size=32, \n               normalization=True, transform=None, val_transform=None,\n               n_batch=200, n_episode=4, n_way=5, n_shot=1, n_query=15):\n    super(MetaCifar100, self).__init__(root_path, split, image_size, \n                                       normalization, transform)\n    self.n_batch = n_batch\n    self.n_episode = n_episode\n    self.n_way = n_way\n    self.n_shot = n_shot\n    self.n_query = n_query\n\n    self.catlocs = tuple()\n    for cat in range(self.n_classes):\n      self.catlocs += (np.argwhere(self.label == cat).reshape(-1),)\n\n    self.val_transform = get_transform(\n      val_transform, image_size, self.norm_params)\n\n  def __len__(self):\n    return self.n_batch * self.n_episode\n\n  def __getitem__(self, index):\n    shot, query = [], []\n    cats = np.random.choice(self.n_classes, self.n_way, replace=False)\n    for c in cats:\n      c_shot, c_query = [], []\n      idx_list = np.random.choice(\n        self.catlocs[c], self.n_shot + self.n_query, replace=False)\n      shot_idx, query_idx = idx_list[:self.n_shot], idx_list[-self.n_query:]\n      for idx in shot_idx:\n        c_shot.append(self.transform(self.data[idx]))\n      for idx in query_idx:\n        c_query.append(self.val_transform(self.data[idx]))\n      shot.append(torch.stack(c_shot))\n      query.append(torch.stack(c_query))\n    \n    shot = torch.cat(shot, dim=0)             # [n_way * n_shot, C, H, W]\n    query = torch.cat(query, dim=0)           # [n_way * n_query, C, H, W]\n    cls = torch.arange(self.n_way)[:, None]\n    shot_labels = cls.repeat(1, 
self.n_shot).flatten()    # [n_way * n_shot]\n    query_labels = cls.repeat(1, self.n_query).flatten()  # [n_way * n_query]\n    \n    return shot, query, shot_labels, query_labels\n\n\n@register('cifar-fs')\nclass CifarFS(Cifar100):\n  # accept keyword arguments, since datasets.make() instantiates with **kwargs\n  def __init__(self, *args, **kwargs):\n    super(CifarFS, self).__init__(*args, **kwargs)\n\n\n@register('meta-cifar-fs')\nclass MetaCifarFS(MetaCifar100):\n  def __init__(self, *args, **kwargs):\n    super(MetaCifarFS, self).__init__(*args, **kwargs)\n\n\n@register('fc100')\nclass FC100(Cifar100):\n  def __init__(self, *args, **kwargs):\n    super(FC100, self).__init__(*args, **kwargs)\n\n\n@register('meta-fc100')\nclass MetaFC100(MetaCifar100):\n  def __init__(self, *args, **kwargs):\n    super(MetaFC100, self).__init__(*args, **kwargs)"
  },
  {
    "path": "datasets/cub200.py",
    "content": "import os\n\nimport torch\nfrom torch.utils.data import Dataset\nimport numpy as np\nfrom PIL import Image\n\nfrom .datasets import register\nfrom .transforms import get_transform\n\n\n@register('cub200')\nclass CUB200(Dataset):\n  def __init__(self, root_path, split='train', image_size=84, \n               normalization=True, transform=None):\n    super(CUB200, self).__init__()\n    split_dict = {'train': 'train',      # standard train\n                  'meta-train': 'train', # meta-train\n                  'meta-val': 'val',     # meta-val\n                  'meta-test': 'test',   # meta-test\n                 }\n    split_tag = split_dict[split]\n\n    split_file = os.path.join(root_path, 'fs-splits', split_tag + '.csv')\n    assert os.path.isfile(split_file)\n    with open(split_file, 'r') as f:\n      pairs = [x.strip().split(',') \n                for x in f.readlines() if x.strip() != '']\n\n    data, label = [x[0] for x in pairs], [int(x[1]) for x in pairs]\n    label = np.array(label)\n    label_key = sorted(np.unique(label))\n    label_map = dict(zip(label_key, range(len(label_key))))\n    new_label = np.array([label_map[x] for x in label])\n\n    self.root_path = root_path\n    self.split_tag = split_tag\n    self.image_size = image_size\n\n    self.data = data\n    self.label = new_label\n    self.n_classes = len(label_key)\n\n    if normalization:\n      self.norm_params = {'mean': [0.485, 0.456, 0.406],\n                          'std':  [0.229, 0.224, 0.225]}   # ImageNet statistics\n    else:\n      self.norm_params = {'mean': [0., 0., 0.],\n                          'std':  [1., 1., 1.]}\n\n    self.transform = get_transform(transform, image_size, self.norm_params)\n\n    def convert_raw(x):\n      mean = torch.tensor(self.norm_params['mean']).view(3, 1, 1).type_as(x)\n      std = torch.tensor(self.norm_params['std']).view(3, 1, 1).type_as(x)\n      return x * std + mean\n    self.convert_raw = convert_raw\n\n  def 
_load_image(self, index):\n    image_path = os.path.join(self.root_path, 'images', self.data[index])\n    assert os.path.isfile(image_path)\n    image = Image.open(image_path).convert('RGB')\n    return image\n\n  def __len__(self):\n    return len(self.label)\n\n  def __getitem__(self, index):\n    image = self.transform(self._load_image(index))\n    label = self.label[index]\n    return image, label\n\n\n@register('meta-cub200')\nclass MetaCUB200(CUB200):\n  def __init__(self, root_path, split='train', image_size=84, \n               normalization=True, transform=None, val_transform=None,\n               n_batch=200, n_episode=4, n_way=5, n_shot=1, n_query=15):\n    super(MetaCUB200, self).__init__(root_path, split, image_size, \n                                     normalization, transform)\n    self.n_batch = n_batch\n    self.n_episode = n_episode\n    self.n_way = n_way\n    self.n_shot = n_shot\n    self.n_query = n_query\n\n    self.catlocs = tuple()\n    for cat in range(self.n_classes):\n      self.catlocs += (np.argwhere(self.label == cat).reshape(-1),)\n\n    self.val_transform = get_transform(\n      val_transform, image_size, self.norm_params)\n\n  def __len__(self):\n    return self.n_batch * self.n_episode\n\n  def __getitem__(self, index):\n    shot, query = [], []\n    cats = np.random.choice(self.n_classes, self.n_way, replace=False)\n    for c in cats:\n      c_shot, c_query = [], []\n      idx_list = np.random.choice(\n        self.catlocs[c], self.n_shot + self.n_query, replace=False)\n      shot_idx, query_idx = idx_list[:self.n_shot], idx_list[-self.n_query:]\n      for idx in shot_idx:\n        c_shot.append(self.transform(self._load_image(idx)))\n      for idx in query_idx:\n        c_query.append(self.val_transform(self._load_image(idx)))\n      shot.append(torch.stack(c_shot))\n      query.append(torch.stack(c_query))\n    \n    shot = torch.cat(shot, dim=0)             # [n_way * n_shot, C, H, W]\n    query = torch.cat(query, dim=0)     
      # [n_way * n_query, C, H, W]\n    cls = torch.arange(self.n_way)[:, None]\n    shot_labels = cls.repeat(1, self.n_shot).flatten()    # [n_way * n_shot]\n    query_labels = cls.repeat(1, self.n_query).flatten()  # [n_way * n_query]\n    \n    return shot, query, shot_labels, query_labels"
  },
  {
    "path": "datasets/datasets.py",
    "content": "import os\n\nimport torch\n\n\nDEFAULT_ROOT = './materials'\ndatasets = {}\n\ndef register(name):\n  def decorator(cls):\n    datasets[name] = cls\n    return cls\n  return decorator\n\n\ndef make(name, **kwargs):\n  if kwargs.get('root_path') is None:\n    kwargs['root_path'] = os.path.join(DEFAULT_ROOT, name.replace('meta-', ''))\n  dataset = datasets[name](**kwargs)\n  return dataset\n\n\ndef collate_fn(batch):\n  shot, query, shot_label, query_label = [], [], [], []\n  for s, q, sl, ql in batch:\n    shot.append(s)\n    query.append(q)\n    shot_label.append(sl)\n    query_label.append(ql)\n  \n  shot = torch.stack(shot)                # [n_ep, n_way * n_shot, C, H, W]\n  query = torch.stack(query)              # [n_ep, n_way * n_query, C, H, W]\n  shot_label = torch.stack(shot_label)    # [n_ep, n_way * n_shot]\n  query_label = torch.stack(query_label)  # [n_ep, n_way * n_query]\n  \n  return shot, query, shot_label, query_label"
  },
  {
    "path": "datasets/inatural.py",
    "content": "import os\n\nimport torch\nfrom torch.utils.data import Dataset\nimport numpy as np\nfrom PIL import Image\n\nfrom .datasets import register\nfrom .transforms import get_transform\n\n\n@register('inatural')\nclass INat2017(Dataset):\n  def __init__(self, root_path, split='train', image_size=84, \n               normalization=True, transform=None):\n    super(INat2017, self).__init__()\n    split_dict = {'train': 'train',      # standard train\n                  'meta-train': 'train', # meta-train\n                  'meta-test': 'test',   # meta-test\n                 }\n    split_tag = split_dict[split]\n\n    split_file = os.path.join(root_path, 'fs-splits', split_tag + '.csv')\n    assert os.path.isfile(split_file)\n    with open(split_file, 'r') as f:\n      pairs = [x.strip().split(',') \n                for x in f.readlines() if x.strip() != '']\n\n    data, label = [x[0] for x in pairs], [int(x[1]) for x in pairs]\n    label = np.array(label)\n    label_key = sorted(np.unique(label))\n    label_map = dict(zip(label_key, range(len(label_key))))\n    new_label = np.array([label_map[x] for x in label])\n\n    self.root_path = root_path\n    self.split_tag = split_tag\n    self.image_size = image_size\n\n    self.data = data\n    self.label = new_label\n    self.n_classes = len(label_key)\n\n    if normalization:\n      self.norm_params = {'mean': [0.4905, 0.4961, 0.4330],\n                          'std':  [0.1737, 0.1713, 0.1779]}\n    else:\n      self.norm_params = {'mean': [0., 0., 0.],\n                          'std':  [1., 1., 1.]}\n                     \n    self.transform = get_transform(transform, image_size, self.norm_params)\n\n    def convert_raw(x):\n      mean = torch.tensor(self.norm_params['mean']).view(3, 1, 1).type_as(x)\n      std = torch.tensor(self.norm_params['std']).view(3, 1, 1).type_as(x)\n      return x * std + mean\n\n    self.convert_raw = convert_raw\n\n  def _load_image(self, index):\n    image_path = 
os.path.join(self.root_path, 'images', self.data[index])\n    assert os.path.isfile(image_path)\n    image = Image.open(image_path).convert('RGB')\n    return image\n\n  def __len__(self):\n    return len(self.label)\n\n  def __getitem__(self, index):\n    image = self.transform(self._load_image(index))\n    label = self.label[index]\n    return image, label\n\n\n@register('meta-inatural')\nclass MetaINat2017(INat2017):\n  def __init__(self, root_path, split='train', image_size=84, \n               normalization=True, transform=None, val_transform=None,\n               n_batch=200, n_episode=4, n_way=5, n_shot=1, n_query=15):\n    super(MetaINat2017, self).__init__(root_path, split, image_size, \n                                       normalization, transform)\n    self.n_batch = n_batch\n    self.n_episode = n_episode\n    self.n_way = n_way\n    self.n_shot = n_shot\n    self.n_query = n_query\n\n    self.catlocs = tuple()\n    for cat in range(self.n_classes):\n      self.catlocs += (np.argwhere(self.label == cat).reshape(-1),)\n\n    self.val_transform = get_transform(\n      val_transform, image_size, self.norm_params)\n\n  def __len__(self):\n    return self.n_batch * self.n_episode\n\n  def __getitem__(self, index):\n    shot, query = [], []\n    cats = np.random.choice(self.n_classes, self.n_way, replace=False)\n    for c in cats:\n      c_shot, c_query = [], []\n      idx_list = np.random.choice(\n        self.catlocs[c], self.n_shot + self.n_query, replace=False)\n      shot_idx, query_idx = idx_list[:self.n_shot], idx_list[-self.n_query:]\n      for idx in shot_idx:\n        c_shot.append(self.transform(self._load_image(idx)))\n      for idx in query_idx:\n        c_query.append(self.val_transform(self._load_image(idx)))\n      shot.append(torch.stack(c_shot))\n      query.append(torch.stack(c_query))\n    \n    shot = torch.cat(shot, dim=0)             # [n_way * n_shot, C, H, W]\n    query = torch.cat(query, dim=0)           # [n_way * n_query, C, H, 
W]\n    cls = torch.arange(self.n_way)[:, None]\n    shot_labels = cls.repeat(1, self.n_shot).flatten()    # [n_way * n_shot]\n    query_labels = cls.repeat(1, self.n_query).flatten()  # [n_way * n_query]\n    \n    return shot, query, shot_labels, query_labels"
  },
  {
    "path": "datasets/mini_imagenet.py",
    "content": "import os\nimport pickle\n\nimport torch\nfrom torch.utils.data import Dataset\nimport numpy as np\nfrom PIL import Image\n\nfrom .datasets import register\nfrom .transforms import get_transform\n\n\n@register('mini-imagenet')\nclass MiniImageNet(Dataset):\n  def __init__(self, root_path, split='train', image_size=84, \n               normalization=True, transform=None):\n    super(MiniImageNet, self).__init__()\n    split_dict = {'train': 'train_phase_train',        # standard train\n                  'val': 'train_phase_val',            # standard val\n                  'trainval': 'train_phase_trainval',  # standard train and val\n                  'test': 'train_phase_test',          # standard test\n                  'meta-train': 'train_phase_train',   # meta-train\n                  'meta-val': 'val',                   # meta-val\n                  'meta-test': 'test',                 # meta-test\n                 }\n    split_tag = split_dict[split]\n\n    split_file = os.path.join(root_path, split_tag + '.pickle')\n    assert os.path.isfile(split_file)\n    with open(split_file, 'rb') as f:\n      pack = pickle.load(f, encoding='latin1')\n    data, label = pack['data'], pack['labels']\n\n    data = [Image.fromarray(x) for x in data]\n    label = np.array(label)\n    label_key = sorted(np.unique(label))\n    label_map = dict(zip(label_key, range(len(label_key))))\n    new_label = np.array([label_map[x] for x in label])\n    \n    self.root_path = root_path\n    self.split_tag = split_tag\n    self.image_size = image_size\n\n    self.data = data\n    self.label = new_label\n    self.n_classes = len(label_key)\n\n    if normalization:\n      self.norm_params = {'mean': [0.471, 0.450, 0.403],\n                          'std':  [0.278, 0.268, 0.284]}\n    else:\n      self.norm_params = {'mean': [0., 0., 0.],\n                          'std':  [1., 1., 1.]}\n\n    self.transform = get_transform(transform, image_size, self.norm_params)\n\n    def 
convert_raw(x):\n      mean = torch.tensor(self.norm_params['mean']).view(3, 1, 1).type_as(x)\n      std = torch.tensor(self.norm_params['std']).view(3, 1, 1).type_as(x)\n      return x * std + mean\n    \n    self.convert_raw = convert_raw\n\n  def __len__(self):\n    return len(self.data)\n\n  def __getitem__(self, index):\n    image = self.transform(self.data[index])\n    label = self.label[index]\n    return image, label\n\n\n@register('meta-mini-imagenet')\nclass MetaMiniImageNet(MiniImageNet):\n  def __init__(self, root_path, split='train', image_size=84, \n               normalization=True, transform=None, val_transform=None,\n               n_batch=200, n_episode=4, n_way=5, n_shot=1, n_query=15):\n    super(MetaMiniImageNet, self).__init__(root_path, split, image_size, \n                                           normalization, transform)\n    self.n_batch = n_batch\n    self.n_episode = n_episode\n    self.n_way = n_way\n    self.n_shot = n_shot\n    self.n_query = n_query\n\n    self.catlocs = tuple()\n    for cat in range(self.n_classes):\n      self.catlocs += (np.argwhere(self.label == cat).reshape(-1),)\n\n    self.val_transform = get_transform(\n      val_transform, image_size, self.norm_params)\n\n  def __len__(self):\n    return self.n_batch * self.n_episode\n\n  def __getitem__(self, index):\n    shot, query = [], []\n    cats = np.random.choice(self.n_classes, self.n_way, replace=False)\n    for c in cats:\n      c_shot, c_query = [], []\n      idx_list = np.random.choice(\n        self.catlocs[c], self.n_shot + self.n_query, replace=False)\n      shot_idx, query_idx = idx_list[:self.n_shot], idx_list[-self.n_query:]\n      for idx in shot_idx:\n        c_shot.append(self.transform(self.data[idx]))\n      for idx in query_idx:\n        c_query.append(self.val_transform(self.data[idx]))\n      shot.append(torch.stack(c_shot))\n      query.append(torch.stack(c_query))\n    \n    shot = torch.cat(shot, dim=0)             # [n_way * n_shot, C, H, 
W]\n    query = torch.cat(query, dim=0)           # [n_way * n_query, C, H, W]\n    cls = torch.arange(self.n_way)[:, None]\n    shot_labels = cls.repeat(1, self.n_shot).flatten()    # [n_way * n_shot]\n    query_labels = cls.repeat(1, self.n_query).flatten()  # [n_way * n_query]\n    \n    return shot, query, shot_labels, query_labels"
  },
  {
    "path": "datasets/tiered_imagenet.py",
    "content": "import os\nimport pickle\n\nimport torch\nfrom torch.utils.data import Dataset\nimport numpy as np\nfrom PIL import Image\n\nfrom .datasets import register\nfrom .transforms import get_transform\n\n\n@register('tiered-imagenet')\nclass TieredImageNet(Dataset):\n  def __init__(self, root_path, split='train', image_size=84, \n               normalization=True, transform=None):\n    super(TieredImageNet, self).__init__()\n    split_dict = {'train': 'train',         # standard train\n                  'val': 'train_phase_val', # standard val\n                  'meta-train': 'train',    # meta-train\n                  'meta-val': 'val',        # meta-val\n                  'meta-test': 'test',      # meta-test\n                 }\n    split_tag = split_dict[split]\n\n    split_file = os.path.join(root_path, split_tag + '_images.npz')\n    label_file = os.path.join(root_path, split_tag + '_labels.pkl')\n    assert os.path.isfile(split_file)\n    assert os.path.isfile(label_file)\n    data = np.load(split_file, allow_pickle=True)['images']\n    data = data[:, :, :, ::-1]\n    with open(label_file, 'rb') as f:\n      label = pickle.load(f)['labels']\n\n    data = [Image.fromarray(x) for x in data]\n    label = np.array(label)\n    label_key = sorted(np.unique(label))\n    label_map = dict(zip(label_key, range(len(label_key))))\n    new_label = np.array([label_map[x] for x in label])\n\n    self.root_path = root_path\n    self.split_tag = split_tag\n    self.image_size = image_size\n    \n    self.data = data\n    self.label = new_label\n    self.n_classes = len(label_key)\n\n    if normalization:\n      self.norm_params = {'mean': [0.478, 0.456, 0.410],\n                          'std':  [0.279, 0.274, 0.286]}\n    else:\n      self.norm_params = {'mean': [0., 0., 0.],\n                          'std':  [1., 1., 1.]}\n\n    self.transform = get_transform(transform, image_size, self.norm_params)\n\n    def convert_raw(x):\n      mean = 
torch.tensor(self.norm_params['mean']).view(3, 1, 1).type_as(x)\n      std = torch.tensor(self.norm_params['std']).view(3, 1, 1).type_as(x)\n      return x * std + mean\n      \n    self.convert_raw = convert_raw\n\n  def __len__(self):\n    return len(self.data)\n\n  def __getitem__(self, index):\n    image = self.transform(self.data[index])\n    label = self.label[index]\n    return image, label\n\n\n@register('meta-tiered-imagenet')\nclass MetaTieredImageNet(TieredImageNet):\n  def __init__(self, root_path, split='train', image_size=84, \n               normalization=True, transform=None, val_transform=None,\n               n_batch=200, n_episode=4, n_way=5, n_shot=1, n_query=15):\n    super(MetaTieredImageNet, self).__init__(root_path, split, image_size, \n                                             normalization, transform)\n    self.n_batch = n_batch\n    self.n_episode = n_episode\n    self.n_way = n_way\n    self.n_shot = n_shot\n    self.n_query = n_query\n\n    self.catlocs = tuple()\n    for cat in range(self.n_classes):\n      self.catlocs += (np.argwhere(self.label == cat).reshape(-1),)\n\n    self.val_transform = get_transform(\n      val_transform, image_size, self.norm_params)\n\n  def __len__(self):\n    return self.n_batch * self.n_episode\n\n  def __getitem__(self, index):\n    shot, query = [], []\n    cats = np.random.choice(self.n_classes, self.n_way, replace=False)\n    for c in cats:\n      c_shot, c_query = [], []\n      idx_list = np.random.choice(\n        self.catlocs[c], self.n_shot + self.n_query, replace=False)\n      shot_idx, query_idx = idx_list[:self.n_shot], idx_list[-self.n_query:]\n      for idx in shot_idx:\n        c_shot.append(self.transform(self.data[idx]))\n      for idx in query_idx:\n        c_query.append(self.val_transform(self.data[idx]))\n      shot.append(torch.stack(c_shot))\n      query.append(torch.stack(c_query))\n    \n    shot = torch.cat(shot, dim=0)             # [n_way * n_shot, C, H, W]\n    query = 
torch.cat(query, dim=0)           # [n_way * n_query, C, H, W]\n    cls = torch.arange(self.n_way)[:, None]\n    shot_labels = cls.repeat(1, self.n_shot).flatten()    # [n_way * n_shot]\n    query_labels = cls.repeat(1, self.n_query).flatten()  # [n_way * n_query]\n    \n    return shot, query, shot_labels, query_labels"
  },
  {
    "path": "datasets/transforms.py",
    "content": "import torchvision.transforms as transforms\n\n\n__all__ = ['get_transform']\n\n\ndef get_transform(name, image_size, norm_params):\n  if name == 'resize':\n    return transforms.Compose([\n      transforms.RandomResizedCrop(image_size),\n      transforms.RandomHorizontalFlip(),\n      transforms.ToTensor(),\n      transforms.Normalize(**norm_params),\n    ])\n  elif name == 'crop':\n    return transforms.Compose([\n      transforms.Resize(image_size),\n      transforms.RandomCrop(image_size, padding=8),\n      transforms.RandomHorizontalFlip(),\n      transforms.ToTensor(),\n      transforms.Normalize(**norm_params),\n    ])\n  elif name == 'color':\n    return transforms.Compose([\n      transforms.Resize(image_size),\n      transforms.RandomCrop(image_size, padding=8),\n      transforms.ColorJitter(\n        brightness=0.4, contrast=0.4, saturation=0.4),\n      transforms.RandomHorizontalFlip(),\n      transforms.ToTensor(),\n      transforms.Normalize(**norm_params),\n    ])\n  elif name == 'flip':\n    return transforms.Compose([\n      transforms.Resize(image_size),\n      transforms.RandomHorizontalFlip(),\n      transforms.ToTensor(),\n      transforms.Normalize(**norm_params),\n    ])\n  elif name == 'enlarge':\n    return transforms.Compose([\n      transforms.Resize(int(image_size * 256 / 224)),\n      transforms.CenterCrop(image_size),\n      transforms.ToTensor(),\n      transforms.Normalize(**norm_params),\n    ])\n  elif name is None:\n    return transforms.Compose([\n      transforms.Resize(image_size),\n      transforms.ToTensor(),\n      transforms.Normalize(**norm_params),\n    ])\n  else:\n    raise ValueError('invalid transformation')"
  },
  {
    "path": "models/__init__.py",
    "content": "from .maml import make\nfrom .maml import load"
  },
  {
    "path": "models/classifiers/__init__.py",
    "content": "from .classifiers import make, load\nfrom . import logistic"
  },
  {
    "path": "models/classifiers/classifiers.py",
    "content": "import torch\n\n\n__all__ = ['make', 'load']\n\n\nmodels = {}\n\ndef register(name):\n  def decorator(cls):\n    models[name] = cls\n    return cls\n  return decorator\n\n\ndef make(name, **kwargs):\n  if name is None:\n    return None\n  model = models[name](**kwargs)\n  if torch.cuda.is_available():\n    model.cuda()\n  return model\n\n\ndef load(ckpt):\n  model = make(ckpt['classifier'], **ckpt['classifier_args'])\n  model.load_state_dict(ckpt['classifier_state_dict'])\n  return model"
  },
  {
    "path": "models/classifiers/logistic.py",
    "content": "import torch\nimport torch.nn as nn\n\nfrom .classifiers import register\nfrom ..modules import *\n\n\n__all__ = ['LogisticClassifier']\n\n\n@register('logistic')\nclass LogisticClassifier(Module):\n  def __init__(self, in_dim, n_way, temp=1., learn_temp=False):\n    super(LogisticClassifier, self).__init__()\n    self.in_dim = in_dim\n    self.n_way = n_way\n    self.temp = temp\n    self.learn_temp = learn_temp\n\n    self.linear = Linear(in_dim, n_way)\n    if self.learn_temp:\n      self.temp = nn.Parameter(torch.tensor(temp))\n\n  def reset_parameters(self):\n    nn.init.zeros_(self.linear.weight)\n    nn.init.zeros_(self.linear.bias)\n\n  def forward(self, x_shot, params=None):\n    assert x_shot.dim() == 2\n    logits = self.linear(x_shot, get_child_dict(params, 'linear'))\n    logits = logits * self.temp\n    return logits"
  },
  {
    "path": "models/encoders/__init__.py",
    "content": "from .encoders import make, load\nfrom . import convnet4\nfrom . import resnet12\nfrom . import resnet18"
  },
  {
    "path": "models/encoders/convnet4.py",
    "content": "from collections import OrderedDict\n\nimport torch.nn as nn\n\nfrom .encoders import register\nfrom ..modules import *\n\n\n__all__ = ['convnet4', 'wide_convnet4']\n\n\nclass ConvBlock(Module):\n  def __init__(self, in_channels, out_channels, bn_args):\n    super(ConvBlock, self).__init__()\n    self.in_channels = in_channels\n    self.out_channels = out_channels\n\n    self.conv = Conv2d(in_channels, out_channels, 3, 1, padding=1)\n    self.bn = BatchNorm2d(out_channels, **bn_args)\n    self.relu = nn.ReLU(inplace=True)\n    self.pool = nn.MaxPool2d(2)\n\n  def forward(self, x, params=None, episode=None):\n    out = self.conv(x, get_child_dict(params, 'conv'))\n    out = self.bn(out, get_child_dict(params, 'bn'), episode)\n    out = self.pool(self.relu(out))\n    return out\n\n\nclass ConvNet4(Module):\n  def __init__(self, hid_dim, out_dim, bn_args):\n    super(ConvNet4, self).__init__()\n    self.hid_dim = hid_dim\n    self.out_dim = out_dim\n\n    episodic = bn_args.get('episodic') or []\n    bn_args_ep, bn_args_no_ep = bn_args.copy(), bn_args.copy()\n    bn_args_ep['episodic'] = True\n    bn_args_no_ep['episodic'] = False\n    bn_args_dict = dict()\n    for i in [1, 2, 3, 4]:\n      if 'conv%d' % i in episodic:\n        bn_args_dict[i] = bn_args_ep\n      else:\n        bn_args_dict[i] = bn_args_no_ep\n\n    self.encoder = Sequential(OrderedDict([\n      ('conv1', ConvBlock(3, hid_dim, bn_args_dict[1])),\n      ('conv2', ConvBlock(hid_dim, hid_dim, bn_args_dict[2])),\n      ('conv3', ConvBlock(hid_dim, hid_dim, bn_args_dict[3])),\n      ('conv4', ConvBlock(hid_dim, out_dim, bn_args_dict[4])),\n    ]))\n\n  def get_out_dim(self, scale=25):\n    return self.out_dim * scale\n\n  def forward(self, x, params=None, episode=None):\n    out = self.encoder(x, get_child_dict(params, 'encoder'), episode)\n    out = out.view(out.shape[0], -1)\n    return out\n\n\n@register('convnet4')\ndef convnet4(bn_args=dict()):\n  return ConvNet4(32, 32, 
bn_args)\n\n\n@register('wide-convnet4')\ndef wide_convnet4(bn_args=dict()):\n  return ConvNet4(64, 64, bn_args)"
  },
  {
    "path": "models/encoders/encoders.py",
    "content": "import torch\n\n\nmodels = {}\n\ndef register(name):\n  def decorator(cls):\n    models[name] = cls\n    return cls\n  return decorator\n\n\ndef make(name, **kwargs):\n  if name is None:\n    return None\n  model = models[name](**kwargs)\n  if torch.cuda.is_available():\n    model.cuda()\n  return model\n\n\ndef load(ckpt):\n  model = make(ckpt['encoder'], **ckpt['encoder_args'])\n  if model is not None:\n    model.load_state_dict(ckpt['encoder_state_dict'])\n  return model"
  },
  {
    "path": "models/encoders/resnet12.py",
    "content": "from collections import OrderedDict\n\nimport torch.nn as nn\n\nfrom .encoders import register\nfrom ..modules import *\n\n\n__all__ = ['resnet12', 'wide_resnet12']\n\n\ndef conv3x3(in_channels, out_channels):\n  return Conv2d(in_channels, out_channels, 3, 1, padding=1, bias=False)\n\n\ndef conv1x1(in_channels, out_channels):\n  return Conv2d(in_channels, out_channels, 1, 1, padding=0, bias=False)\n\n\nclass Block(Module):\n  def __init__(self, in_planes, planes, bn_args):\n    super(Block, self).__init__()\n    self.in_planes = in_planes\n    self.planes = planes\n\n    self.conv1 = conv3x3(in_planes, planes)\n    self.bn1 = BatchNorm2d(planes, **bn_args)\n    self.conv2 = conv3x3(planes, planes)\n    self.bn2 = BatchNorm2d(planes, **bn_args)\n    self.conv3 = conv3x3(planes, planes)\n    self.bn3 = BatchNorm2d(planes, **bn_args)\n\n    self.res_conv = Sequential(OrderedDict([\n      ('conv', conv1x1(in_planes, planes)),\n      ('bn', BatchNorm2d(planes, **bn_args)),\n    ]))\n\n    self.relu = nn.LeakyReLU(0.1, inplace=True)\n    self.pool = nn.MaxPool2d(2)\n\n  def forward(self, x, params=None, episode=None):\n    out = self.conv1(x, get_child_dict(params, 'conv1'))\n    out = self.bn1(out, get_child_dict(params, 'bn1'), episode)\n    out = self.relu(out)\n\n    out = self.conv2(out, get_child_dict(params, 'conv2'))\n    out = self.bn2(out, get_child_dict(params, 'bn2'), episode)\n    out = self.relu(out)\n\n    out = self.conv3(out, get_child_dict(params, 'conv3'))\n    out = self.bn3(out, get_child_dict(params, 'bn3'), episode)\n\n    x = self.res_conv(x, get_child_dict(params, 'res_conv'), episode)\n    out = self.pool(self.relu(out + x))\n    return out\n\n\nclass ResNet12(Module):\n  def __init__(self, channels, bn_args):\n    super(ResNet12, self).__init__()\n    self.channels = channels\n\n    episodic = bn_args.get('episodic') or []\n    bn_args_ep, bn_args_no_ep = bn_args.copy(), bn_args.copy()\n    bn_args_ep['episodic'] = True\n    
bn_args_no_ep['episodic'] = False\n    bn_args_dict = dict()\n    for i in [1, 2, 3, 4]:\n      if 'layer%d' % i in episodic:\n        bn_args_dict[i] = bn_args_ep\n      else:\n        bn_args_dict[i] = bn_args_no_ep\n\n    self.layer1 = Block(3, channels[0], bn_args_dict[1])\n    self.layer2 = Block(channels[0], channels[1], bn_args_dict[2])\n    self.layer3 = Block(channels[1], channels[2], bn_args_dict[3])\n    self.layer4 = Block(channels[2], channels[3], bn_args_dict[4])\n    \n    self.pool = nn.AdaptiveAvgPool2d(1)\n    self.out_dim = channels[3]\n\n    for m in self.modules():\n      if isinstance(m, Conv2d):\n        nn.init.kaiming_normal_(\n          m.weight, mode='fan_out', nonlinearity='leaky_relu')\n      elif isinstance(m, BatchNorm2d):\n        nn.init.constant_(m.weight, 1.)\n        nn.init.constant_(m.bias, 0.)\n\n  def get_out_dim(self):\n    return self.out_dim\n\n  def forward(self, x, params=None, episode=None):\n    out = self.layer1(x, get_child_dict(params, 'layer1'), episode)\n    out = self.layer2(out, get_child_dict(params, 'layer2'), episode)\n    out = self.layer3(out, get_child_dict(params, 'layer3'), episode)\n    out = self.layer4(out, get_child_dict(params, 'layer4'), episode)\n    out = self.pool(out).flatten(1)\n    return out\n\n\n@register('resnet12')\ndef resnet12(bn_args=dict()):\n  return ResNet12([64, 128, 256, 512], bn_args)\n\n\n@register('wide-resnet12')\ndef wide_resnet12(bn_args=dict()):\n  return ResNet12([64, 160, 320, 640], bn_args)"
  },
  {
    "path": "models/encoders/resnet18.py",
    "content": "from collections import OrderedDict\n\nimport torch.nn as nn\n\nfrom .encoders import register\nfrom ..modules import *\n\n\n__all__ = ['resnet18', 'wide_resnet18']\n\n\ndef conv3x3(in_channels, out_channels, stride=1):\n  return Conv2d(in_channels, out_channels, 3, stride, padding=1, bias=False)\n\n\ndef conv1x1(in_channels, out_channels, stride=1):\n  return Conv2d(in_channels, out_channels, 1, stride, padding=0, bias=False)\n\n\nclass Block(Module):\n  def __init__(self, in_planes, planes, stride, bn_args):\n    super(Block, self).__init__()\n    self.in_planes = in_planes\n    self.planes = planes\n    self.stride = stride\n\n    self.conv1 = conv3x3(in_planes, planes, stride)\n    self.bn1 = BatchNorm2d(planes, **bn_args)\n    self.conv2 = conv3x3(planes, planes)\n    self.bn2 = BatchNorm2d(planes, **bn_args)\n\n    if stride > 1:\n      self.res_conv = Sequential(OrderedDict([\n        ('conv', conv1x1(in_planes, planes)),\n        ('bn', BatchNorm2d(planes, **bn_args)),\n      ]))\n\n    self.relu = nn.ReLU(inplace=True)\n\n  def forward(self, x, params=None, episode=None):\n    out = self.conv1(x, get_child_dict(params, 'conv1'))\n    out = self.bn1(out, get_child_dict(params, 'bn1'), episode)\n    out = self.relu(out)\n\n    out = self.conv2(out, get_child_dict(params, 'conv2'))\n    out = self.bn2(out, get_child_dict(params, 'bn2'), episode)\n\n    if self.stride > 1:\n      x = self.res_conv(x, get_child_dict(params, 'res_conv'), episode)\n    out = self.relu(out + x)\n    return out\n\n\nclass ResNet18(Module):\n  def __init__(self, channels, bn_args):\n    super(ResNet18, self).__init__()\n    self.channels = channels\n\n    episodic = bn_args.get('episodic') or []\n    bn_args_ep, bn_args_no_ep = bn_args.copy(), bn_args.copy()\n    bn_args_ep['episodic'] = True\n    bn_args_no_ep['episodic'] = False\n    bn_args_dict = dict()\n    for i in [0, 1, 2, 3, 4]:\n      if 'layer%d' % i in episodic:\n        bn_args_dict[i] = bn_args_ep\n     
 else:\n        bn_args_dict[i] = bn_args_no_ep\n\n    self.layer0 = Sequential(OrderedDict([\n      ('conv', conv3x3(3, 64)),\n      ('bn', BatchNorm2d(64, **bn_args_dict[0])),\n    ]))\n    self.relu = nn.ReLU(inplace=True)\n    self.layer1 = Block(64, channels[0], 1, bn_args_dict[1])\n    self.layer2 = Block(channels[0], channels[1], 2, bn_args_dict[2])\n    self.layer3 = Block(channels[1], channels[2], 2, bn_args_dict[3])\n    self.layer4 = Block(channels[2], channels[3], 2, bn_args_dict[4])\n    \n    self.pool = nn.AdaptiveAvgPool2d(1)\n    self.out_dim = channels[3]\n\n    for m in self.modules():\n      if isinstance(m, Conv2d):\n        nn.init.kaiming_normal_(\n          m.weight, mode='fan_out', nonlinearity='relu')\n      elif isinstance(m, BatchNorm2d):\n        nn.init.constant_(m.weight, 1.)\n        nn.init.constant_(m.bias, 0.)\n\n  def get_out_dim(self, scale=1):\n    return self.out_dim * scale\n\n  def forward(self, x, params=None, episode=None):\n    out = self.layer0(x, get_child_dict(params, 'layer0'), episode)\n    out = self.relu(out)\n    out = self.layer1(out, get_child_dict(params, 'layer1'), episode)\n    out = self.layer2(out, get_child_dict(params, 'layer2'), episode)\n    out = self.layer3(out, get_child_dict(params, 'layer3'), episode)\n    out = self.layer4(out, get_child_dict(params, 'layer4'), episode)\n    out = self.pool(out).flatten(1)\n    return out\n\n\n@register('resnet18')\ndef resnet18(bn_args=dict()):\n  return ResNet18([64, 128, 256, 512], bn_args)\n\n\n@register('wide-resnet18')\ndef wide_resnet18(bn_args=dict()):\n  return ResNet18([64, 160, 320, 640], bn_args)"
  },
  {
    "path": "models/maml.py",
    "content": "from collections import OrderedDict\n\nimport torch\nimport torch.nn.functional as F\nimport torch.autograd as autograd\nimport torch.utils.checkpoint as cp\n\nfrom . import encoders\nfrom . import classifiers\nfrom .modules import get_child_dict, Module, BatchNorm2d\n\n\ndef make(enc_name, enc_args, clf_name, clf_args):\n  \"\"\"\n  Initializes a random meta model.\n\n  Args:\n    enc_name (str): name of the encoder (e.g., 'resnet12').\n    enc_args (dict): arguments for the encoder.\n    clf_name (str): name of the classifier (e.g., 'meta-nn').\n    clf_args (dict): arguments for the classifier.\n\n  Returns:\n    model (MAML): a meta classifier with a random encoder.\n  \"\"\"\n  enc = encoders.make(enc_name, **enc_args)\n  clf_args['in_dim'] = enc.get_out_dim()\n  clf = classifiers.make(clf_name, **clf_args)\n  model = MAML(enc, clf)\n  return model\n\n\ndef load(ckpt, load_clf=False, clf_name=None, clf_args=None):\n  \"\"\"\n  Initializes a meta model with a pre-trained encoder.\n\n  Args:\n    ckpt (dict): a checkpoint from which a pre-trained encoder is restored.\n    load_clf (bool, optional): if True, loads a pre-trained classifier.\n      Default: False (in which case the classifier is randomly initialized)\n    clf_name (str, optional): name of the classifier (e.g., 'meta-nn')\n    clf_args (dict, optional): arguments for the classifier.\n    (The last two arguments are ignored if load_clf=True.)\n\n  Returns:\n    model (MAML): a meta model with a pre-trained encoder.\n  \"\"\"\n  enc = encoders.load(ckpt)\n  if load_clf:\n    clf = classifiers.load(ckpt)\n  else:\n    if clf_name is None and clf_args is None:\n      clf = classifiers.make(ckpt['classifier'], **ckpt['classifier_args'])\n    else:\n      clf_args['in_dim'] = enc.get_out_dim()\n      clf = classifiers.make(clf_name, **clf_args)\n  model = MAML(enc, clf)\n  return model\n\n\nclass MAML(Module):\n  def __init__(self, encoder, classifier):\n    super(MAML, self).__init__()\n  
  self.encoder = encoder\n    self.classifier = classifier\n\n  def reset_classifier(self):\n    self.classifier.reset_parameters()\n\n  def _inner_forward(self, x, params, episode):\n    \"\"\" Forward pass for the inner loop. \"\"\"\n    feat = self.encoder(x, get_child_dict(params, 'encoder'), episode)\n    logits = self.classifier(feat, get_child_dict(params, 'classifier'))\n    return logits\n\n  def _inner_iter(self, x, y, params, mom_buffer, episode, inner_args, detach):\n    \"\"\" \n    Performs one inner-loop iteration of MAML including the forward and \n    backward passes and the parameter update.\n\n    Args:\n      x (float tensor, [n_way * n_shot, C, H, W]): per-episode support set.\n      y (int tensor, [n_way * n_shot]): per-episode support set labels.\n      params (dict): the model parameters BEFORE the update.\n      mom_buffer (dict): the momentum buffer BEFORE the update.\n      episode (int): the current episode index.\n      inner_args (dict): inner-loop optimization hyperparameters.\n      detach (bool): if True, detachs the graph for the current iteration.\n\n    Returns:\n      updated_params (dict): the model parameters AFTER the update.\n      mom_buffer (dict): the momentum buffer AFTER the update.\n    \"\"\"\n    with torch.enable_grad():\n      # forward pass\n      logits = self._inner_forward(x, params, episode)\n      loss = F.cross_entropy(logits, y)\n      # backward pass\n      grads = autograd.grad(loss, params.values(), \n        create_graph=(not detach and not inner_args['first_order']),\n        only_inputs=True, allow_unused=True)\n      # parameter update\n      updated_params = OrderedDict()\n      for (name, param), grad in zip(params.items(), grads):\n        if grad is None:\n          updated_param = param\n        else:\n          if inner_args['weight_decay'] > 0:\n            grad = grad + inner_args['weight_decay'] * param\n          if inner_args['momentum'] > 0:\n            grad = grad + 
inner_args['momentum'] * mom_buffer[name]\n            mom_buffer[name] = grad\n          if 'encoder' in name:\n            lr = inner_args['encoder_lr']\n          elif 'classifier' in name:\n            lr = inner_args['classifier_lr']\n          else:\n            raise ValueError('invalid parameter name')\n          updated_param = param - lr * grad\n        if detach:\n          updated_param = updated_param.detach().requires_grad_(True)\n        updated_params[name] = updated_param\n\n    return updated_params, mom_buffer\n\n  def _adapt(self, x, y, params, episode, inner_args, meta_train):\n    \"\"\"\n    Performs inner-loop adaptation in MAML.\n\n    Args:\n      x (float tensor, [n_way * n_shot, C, H, W]): per-episode support set.\n        (C: channels, H: height, W: width)\n      y (int tensor, [n_way * n_shot]): per-episode support set labels.\n      params (dict): a dictionary of parameters at meta-initialization.\n      episode (int): the current episode index.\n      inner_args (dict): inner-loop optimization hyperparameters.\n      meta_train (bool): if True, the model is in meta-training.\n      \n    Returns:\n      params (dict): model parameters AFTER inner-loop adaptation.\n    \"\"\"\n    assert x.dim() == 4 and y.dim() == 1\n    assert x.size(0) == y.size(0)\n\n    # Initializes a dictionary of momentum buffers for gradient descent in the \n    # inner loop. It has the same set of keys as the parameter dictionary.\n    mom_buffer = OrderedDict()\n    if inner_args['momentum'] > 0:\n      for name, param in params.items():\n        mom_buffer[name] = torch.zeros_like(param)\n    params_keys = tuple(params.keys())\n    mom_buffer_keys = tuple(mom_buffer.keys())\n\n    for m in self.modules():\n      if isinstance(m, BatchNorm2d) and m.is_episodic():\n        m.reset_episodic_running_stats(episode)\n\n    def _inner_iter_cp(episode, *state):\n      \"\"\" \n      Performs one inner-loop iteration when checkpointing is enabled.
\n      The code is executed twice:\n        - 1st time with torch.no_grad() for creating checkpoints.\n        - 2nd time with torch.enable_grad() for computing gradients.\n      \"\"\"\n      params = OrderedDict(zip(params_keys, state[:len(params_keys)]))\n      mom_buffer = OrderedDict(\n        zip(mom_buffer_keys, state[-len(mom_buffer_keys):]))\n\n      detach = not torch.is_grad_enabled()  # detach graph in the first pass\n      self.is_first_pass(detach)\n      params, mom_buffer = self._inner_iter(\n        x, y, params, mom_buffer, int(episode), inner_args, detach)\n      state = tuple(t if t.requires_grad else t.clone().requires_grad_(True)\n        for t in tuple(params.values()) + tuple(mom_buffer.values()))\n      return state\n\n    for step in range(inner_args['n_step']):\n      if self.efficient:  # checkpointing\n        state = tuple(params.values()) + tuple(mom_buffer.values())\n        state = cp.checkpoint(_inner_iter_cp, torch.as_tensor(episode), *state)\n        params = OrderedDict(zip(params_keys, state[:len(params_keys)]))\n        mom_buffer = OrderedDict(\n          zip(mom_buffer_keys, state[-len(mom_buffer_keys):]))\n      else:\n        params, mom_buffer = self._inner_iter(\n          x, y, params, mom_buffer, episode, inner_args, not meta_train)\n        \n    return params\n\n  def forward(self, x_shot, x_query, y_shot, inner_args, meta_train):\n    \"\"\"\n    Args:\n      x_shot (float tensor, [n_episode, n_way * n_shot, C, H, W]): support sets.\n      x_query (float tensor, [n_episode, n_way * n_query, C, H, W]): query sets.\n        (C: channels, H: height, W: width)\n      y_shot (int tensor, [n_episode, n_way * n_shot]): support set labels.\n      inner_args (dict): inner-loop hyperparameters.\n      meta_train (bool): if True, the model is in meta-training.\n      \n    Returns:\n      logits (float tensor, [n_episode, n_way * n_query, n_way]): logits \n        predicted on the query sets.\n    \"\"\"\n    assert self.encoder is not None\n    assert self.classifier is not None\n    assert x_shot.dim() == 5 and x_query.dim() == 5\n    assert x_shot.size(0) == x_query.size(0)\n\n    # a dictionary of parameters that will be updated in the inner loop\n    params = OrderedDict(self.named_parameters())\n    for name in list(params.keys()):\n      if not params[name].requires_grad or \\\n        any(s in name for s in inner_args['frozen'] + ['temp']):\n        params.pop(name)\n\n    logits = []\n    for ep in range(x_shot.size(0)):\n      # inner-loop training\n      self.train()\n      if not meta_train:\n        for m in self.modules():\n          if isinstance(m, BatchNorm2d) and not m.is_episodic():\n            m.eval()\n      updated_params = self._adapt(\n        x_shot[ep], y_shot[ep], params, ep, inner_args, meta_train)\n      # inner-loop validation\n      with torch.set_grad_enabled(meta_train):\n        self.eval()\n        logits_ep = self._inner_forward(x_query[ep], updated_params, ep)\n      logits.append(logits_ep)\n\n    self.train(meta_train)\n    logits = torch.stack(logits)\n    return logits"
  },
  {
    "path": "models/modules.py",
    "content": "import re\nfrom collections import OrderedDict\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\n__all__ = ['Module', 'Conv2d', 'Linear', 'BatchNorm2d', 'Sequential', \n           'get_child_dict']\n\n\ndef get_child_dict(params, key=None):\n  \"\"\"\n  Constructs parameter dictionary for a network module.\n\n  Args:\n    params (dict): a parent dictionary of named parameters.\n    key (str, optional): a key that specifies the root of the child dictionary.\n\n  Returns:\n    child_dict (dict): a child dictionary of model parameters.\n  \"\"\"\n  if params is None:\n    return None\n  if key is None or (isinstance(key, str) and key == ''):\n    return params\n\n  key_re = re.compile(r'^{0}\\.(.+)'.format(re.escape(key)))\n  if not any(filter(key_re.match, params.keys())):  # handles nn.DataParallel\n    key_re = re.compile(r'^module\\.{0}\\.(.+)'.format(re.escape(key)))\n  child_dict = OrderedDict(\n    (key_re.sub(r'\\1', k), value) for (k, value)\n      in params.items() if key_re.match(k) is not None)\n  return child_dict\n\n\nclass Module(nn.Module):\n  def __init__(self):\n    super(Module, self).__init__()\n    self.efficient = False\n    self.first_pass = True\n\n  def go_efficient(self, mode=True):\n    \"\"\" Switches on / off gradient checkpointing. \"\"\"\n    self.efficient = mode\n    for m in self.children():\n      if isinstance(m, Module):\n        m.go_efficient(mode)\n\n  def is_first_pass(self, mode=True):\n    \"\"\" Tracks the progress of forward and backward pass when gradient \n    checkpointing is enabled. 
\"\"\"\n    self.first_pass = mode\n    for m in self.children():\n      if isinstance(m, Module):\n        m.is_first_pass(mode)\n\n\nclass Conv2d(nn.Conv2d, Module):\n  def __init__(self, in_channels, out_channels, kernel_size, \n               stride=1, padding=0, bias=True):\n    super(Conv2d, self).__init__(in_channels, out_channels, kernel_size, \n                                 stride, padding, bias=bias)\n\n  def forward(self, x, params=None, episode=None):\n    if params is None:\n      x = super(Conv2d, self).forward(x)\n    else:\n      weight, bias = params.get('weight'), params.get('bias')\n      if weight is None:\n        weight = self.weight\n      if bias is None:\n        bias = self.bias\n      x = F.conv2d(x, weight, bias, self.stride, self.padding)\n    return x\n\n
\nclass Linear(nn.Linear, Module):\n  def __init__(self, in_features, out_features, bias=True):\n    super(Linear, self).__init__(in_features, out_features, bias=bias)\n\n  def forward(self, x, params=None, episode=None):\n    if params is None:\n      x = super(Linear, self).forward(x)\n    else:\n      weight, bias = params.get('weight'), params.get('bias')\n      if weight is None:\n        weight = self.weight\n      if bias is None:\n        bias = self.bias\n      x = F.linear(x, weight, bias)\n    return x\n\n
\nclass BatchNorm2d(nn.BatchNorm2d, Module):\n  def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, \n               track_running_stats=True, episodic=False, n_episode=4,\n               alpha=False):\n    \"\"\"\n    Args:\n      episodic (bool, optional): if True, maintains running statistics for \n        each episode separately. It is ignored if track_running_stats=False. \n        Default: False\n      n_episode (int, optional): number of episodes per mini-batch. 
It is \n        ignored if episodic=False. Default: 4\n      alpha (bool, optional): if True, learns to interpolate between batch \n        statistics computed over the support set and instance statistics from \n        a query at validation time. Default: False\n        (It is ignored if track_running_stats=False.)\n    \"\"\"\n    super(BatchNorm2d, self).__init__(num_features, eps, momentum, affine, \n                                      track_running_stats)\n    self.episodic = episodic\n    self.n_episode = n_episode\n    self.alpha = alpha\n\n
    if self.track_running_stats:\n      if self.episodic:\n        for ep in range(n_episode):\n          self.register_buffer(\n            'running_mean_%d' % ep, torch.zeros(num_features))\n          self.register_buffer(\n            'running_var_%d' % ep, torch.ones(num_features))\n          self.register_buffer(\n            'num_batches_tracked_%d' % ep, torch.tensor(0, dtype=torch.int))\n      if self.alpha:\n        self.register_buffer('batch_size', torch.tensor(0, dtype=torch.int))\n        self.alpha_scale = nn.Parameter(torch.tensor(0.))\n        self.alpha_offset = nn.Parameter(torch.tensor(0.))\n\n
  def is_episodic(self):\n    return self.episodic\n\n  def _batch_norm(self, x, mean, var, weight=None, bias=None):\n    if self.affine:\n      assert weight is not None and bias is not None\n      weight = weight.view(1, -1, 1, 1)\n      bias = bias.view(1, -1, 1, 1)\n      x = weight * (x - mean) / (var + self.eps) ** .5 + bias\n    else:\n      x = (x - mean) / (var + self.eps) ** .5\n    return x\n\n
  def reset_episodic_running_stats(self, episode):\n    if self.episodic:\n      getattr(self, 'running_mean_%d' % episode).zero_()\n      getattr(self, 'running_var_%d' % episode).fill_(1.)\n      getattr(self, 'num_batches_tracked_%d' % episode).zero_()\n\n  def forward(self, x, params=None, episode=None):\n    self._check_input_dim(x)\n    if params is not None:\n      weight, bias = 
params.get('weight'), params.get('bias')\n      if weight is None:\n        weight = self.weight\n      if bias is None:\n        bias = self.bias\n    else:\n      weight, bias = self.weight, self.bias\n\n
    if self.track_running_stats:\n      if self.episodic:\n        assert episode is not None and episode < self.n_episode\n        running_mean = getattr(self, 'running_mean_%d' % episode)\n        running_var = getattr(self, 'running_var_%d' % episode)\n        num_batches_tracked = getattr(self, 'num_batches_tracked_%d' % episode)\n      else:\n        running_mean, running_var = self.running_mean, self.running_var\n        num_batches_tracked = self.num_batches_tracked\n\n
      if self.training:\n        exp_avg_factor = 0.\n        if self.first_pass:  # only updates statistics in the first pass\n          if self.alpha:\n            # 'batch_size' is a registered buffer; update it in place\n            self.batch_size.fill_(x.size(0))\n          num_batches_tracked += 1\n          if self.momentum is None:\n            exp_avg_factor = 1. / float(num_batches_tracked)\n          else:\n            exp_avg_factor = self.momentum\n        return F.batch_norm(x, running_mean, running_var, weight, bias,\n                            True, exp_avg_factor, self.eps)\n
      else:\n        if self.alpha:\n          assert self.batch_size > 0\n          alpha = torch.sigmoid(\n            self.alpha_scale * self.batch_size + self.alpha_offset)\n          # exponentially moving-averaged training statistics\n          running_mean = running_mean.view(1, -1, 1, 1)\n          running_var = running_var.view(1, -1, 1, 1)\n          # per-sample statistics\n          sample_mean = torch.mean(x, dim=(2, 3), keepdim=True)\n          sample_var = torch.var(x, dim=(2, 3), unbiased=False, keepdim=True)\n          # interpolated statistics\n          mean = alpha * running_mean + (1 - alpha) * sample_mean\n          var = alpha * running_var + (1 - alpha) * sample_var + \\\n                alpha * (1 - alpha) * (sample_mean - running_mean) ** 2\n          return self._batch_norm(x, mean, var, weight, bias)\n        else:\n          return F.batch_norm(x, running_mean, running_var, weight, bias,\n                              False, 0., self.eps)\n    else:\n      return F.batch_norm(x, None, None, weight, bias, True, 0., self.eps)\n\n
\nclass Sequential(nn.Sequential, Module):\n  def __init__(self, *args):\n    super(Sequential, self).__init__(*args)\n\n  def forward(self, x, params=None, episode=None):\n    if params is None:\n      for module in self:\n        x = module(x, None, episode)\n    else:\n      for name, module in self._modules.items():\n        x = module(x, get_child_dict(params, name), episode)\n    return x"
  },
  {
    "path": "test.py",
    "content": "import argparse\nimport random\n\nimport yaml\nimport torch\nimport torch.nn as nn\nimport numpy as np\nfrom tqdm import tqdm\nfrom torch.utils.data import DataLoader\n\nimport datasets\nimport models\nimport utils\n\n\ndef main(config):\n  random.seed(0)\n  np.random.seed(0)\n  torch.manual_seed(0)\n  torch.cuda.manual_seed(0)\n  # torch.backends.cudnn.deterministic = True\n  # torch.backends.cudnn.benchmark = False\n\n  ##### Dataset #####\n\n  dataset = datasets.make(config['dataset'], **config['test'])\n  utils.log('meta-test set: {} (x{}), {}'.format(\n    dataset[0][0].shape, len(dataset), dataset.n_classes))\n  loader = DataLoader(dataset, config['test']['n_episode'],\n    collate_fn=datasets.collate_fn, num_workers=1, pin_memory=True)\n\n  ##### Model #####\n\n  ckpt = torch.load(config['load'])\n  inner_args = utils.config_inner_args(config.get('inner_args'))\n  model = models.load(ckpt, load_clf=(not inner_args['reset_classifier']))\n\n  if args.efficient:\n    model.go_efficient()\n\n  if config.get('_parallel'):\n    model = nn.DataParallel(model)\n\n  utils.log('num params: {}'.format(utils.compute_n_params(model)))\n\n  ##### Evaluation #####\n\n  model.eval()\n  aves_va = utils.AverageMeter()\n  va_lst = []\n\n  for epoch in range(1, config['epoch'] + 1):\n    for data in tqdm(loader, leave=False):\n      x_shot, x_query, y_shot, y_query = data\n      x_shot, y_shot = x_shot.cuda(), y_shot.cuda()\n      x_query, y_query = x_query.cuda(), y_query.cuda()\n\n      if inner_args['reset_classifier']:\n        if config.get('_parallel'):\n          model.module.reset_classifier()\n        else:\n          model.reset_classifier()\n\n      logits = model(x_shot, x_query, y_shot, inner_args, meta_train=False)\n      logits = logits.view(-1, config['test']['n_way'])\n      labels = y_query.view(-1)\n      \n      pred = torch.argmax(logits, dim=1)\n      acc = utils.compute_acc(pred, labels)\n      aves_va.update(acc, 1)\n      
va_lst.append(acc)\n\n    print('test epoch {}: acc={:.2f} +- {:.2f} (%)'.format(\n      epoch, aves_va.item() * 100, \n      utils.mean_confidence_interval(va_lst) * 100))\n\n\nif __name__ == '__main__':\n  parser = argparse.ArgumentParser()\n  parser.add_argument('--config', \n                      help='configuration file')\n  parser.add_argument('--gpu', \n                      help='gpu device number', \n                      type=str, default='0')\n  parser.add_argument('--efficient', \n                      help='if True, enables gradient checkpointing',\n                      action='store_true')\n  args = parser.parse_args()\n  config = yaml.load(open(args.config, 'r'), Loader=yaml.FullLoader)\n  \n  if len(args.gpu.split(',')) > 1:\n    config['_parallel'] = True\n    config['_gpu'] = args.gpu\n\n  utils.set_gpu(args.gpu)\n  main(config)"
  },
  {
    "path": "train.py",
    "content": "import argparse\nimport os\nimport random\nfrom collections import OrderedDict\n\nimport yaml\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport numpy as np\nfrom tqdm import tqdm\nfrom torch.utils.data import DataLoader\nfrom tensorboardX import SummaryWriter\n\nimport datasets\nimport models\nimport utils\nimport utils.optimizers as optimizers\n\n\ndef main(config):\n  random.seed(0)\n  np.random.seed(0)\n  torch.manual_seed(0)\n  torch.cuda.manual_seed(0)\n  # torch.backends.cudnn.deterministic = True\n  # torch.backends.cudnn.benchmark = False\n\n  ckpt_name = args.name\n  if ckpt_name is None:\n    ckpt_name = config['encoder']\n    ckpt_name += '_' + config['dataset'].replace('meta-', '')\n    ckpt_name += '_{}_way_{}_shot'.format(\n      config['train']['n_way'], config['train']['n_shot'])\n  if args.tag is not None:\n    ckpt_name += '_' + args.tag\n\n  ckpt_path = os.path.join('./save', ckpt_name)\n  utils.ensure_path(ckpt_path)\n  utils.set_log_path(ckpt_path)\n  writer = SummaryWriter(os.path.join(ckpt_path, 'tensorboard'))\n  yaml.dump(config, open(os.path.join(ckpt_path, 'config.yaml'), 'w'))\n\n  ##### Dataset #####\n\n  # meta-train\n  train_set = datasets.make(config['dataset'], **config['train'])\n  utils.log('meta-train set: {} (x{}), {}'.format(\n    train_set[0][0].shape, len(train_set), train_set.n_classes))\n  train_loader = DataLoader(\n    train_set, config['train']['n_episode'],\n    collate_fn=datasets.collate_fn, num_workers=1, pin_memory=True)\n\n  # meta-val\n  eval_val = False\n  if config.get('val'):\n    eval_val = True\n    val_set = datasets.make(config['dataset'], **config['val'])\n    utils.log('meta-val set: {} (x{}), {}'.format(\n      val_set[0][0].shape, len(val_set), val_set.n_classes))\n    val_loader = DataLoader(\n      val_set, config['val']['n_episode'],\n      collate_fn=datasets.collate_fn, num_workers=1, pin_memory=True)\n  \n  ##### Model and Optimizer #####\n\n  inner_args 
= utils.config_inner_args(config.get('inner_args'))\n  if config.get('load'):\n    ckpt = torch.load(config['load'])\n    config['encoder'] = ckpt['encoder']\n    config['encoder_args'] = ckpt['encoder_args']\n    config['classifier'] = ckpt['classifier']\n    config['classifier_args'] = ckpt['classifier_args']\n    model = models.load(ckpt, load_clf=(not inner_args['reset_classifier']))\n    optimizer, lr_scheduler = optimizers.load(ckpt, model.parameters())\n    start_epoch = ckpt['training']['epoch'] + 1\n    max_va = ckpt['training']['max_va']\n  else:\n    config['encoder_args'] = config.get('encoder_args') or dict()\n    config['classifier_args'] = config.get('classifier_args') or dict()\n    config['encoder_args']['bn_args']['n_episode'] = config['train']['n_episode']\n    config['classifier_args']['n_way'] = config['train']['n_way']\n    model = models.make(config['encoder'], config['encoder_args'],\n                        config['classifier'], config['classifier_args'])\n    optimizer, lr_scheduler = optimizers.make(\n      config['optimizer'], model.parameters(), **config['optimizer_args'])\n    start_epoch = 1\n    max_va = 0.\n\n  if args.efficient:\n    model.go_efficient()\n\n  if config.get('_parallel'):\n    model = nn.DataParallel(model)\n\n  utils.log('num params: {}'.format(utils.compute_n_params(model)))\n  timer_elapsed, timer_epoch = utils.Timer(), utils.Timer()\n\n  ##### Training and evaluation #####\n    \n  # 'tl': meta-train loss\n  # 'ta': meta-train accuracy\n  # 'vl': meta-val loss\n  # 'va': meta-val accuracy\n  aves_keys = ['tl', 'ta', 'vl', 'va']\n  trlog = dict()\n  for k in aves_keys:\n    trlog[k] = []\n\n  for epoch in range(start_epoch, config['epoch'] + 1):\n    timer_epoch.start()\n    aves = {k: utils.AverageMeter() for k in aves_keys}\n\n    # meta-train\n    model.train()\n    writer.add_scalar('lr', optimizer.param_groups[0]['lr'], epoch)\n    np.random.seed(epoch)\n\n    for data in tqdm(train_loader, desc='meta-train', 
leave=False):\n      x_shot, x_query, y_shot, y_query = data\n      x_shot, y_shot = x_shot.cuda(), y_shot.cuda()\n      x_query, y_query = x_query.cuda(), y_query.cuda()\n\n      if inner_args['reset_classifier']:\n        if config.get('_parallel'):\n          model.module.reset_classifier()\n        else:\n          model.reset_classifier()\n\n      logits = model(x_shot, x_query, y_shot, inner_args, meta_train=True)\n      logits = logits.flatten(0, 1)\n      labels = y_query.flatten()\n      \n      pred = torch.argmax(logits, dim=-1)\n      acc = utils.compute_acc(pred, labels)\n      loss = F.cross_entropy(logits, labels)\n      aves['tl'].update(loss.item(), 1)\n      aves['ta'].update(acc, 1)\n      \n      optimizer.zero_grad()\n      loss.backward()\n      for param in optimizer.param_groups[0]['params']:\n        nn.utils.clip_grad_value_(param, 10)\n      optimizer.step()\n\n    # meta-val\n    if eval_val:\n      model.eval()\n      np.random.seed(0)\n\n      for data in tqdm(val_loader, desc='meta-val', leave=False):\n        x_shot, x_query, y_shot, y_query = data\n        x_shot, y_shot = x_shot.cuda(), y_shot.cuda()\n        x_query, y_query = x_query.cuda(), y_query.cuda()\n\n        if inner_args['reset_classifier']:\n          if config.get('_parallel'):\n            model.module.reset_classifier()\n          else:\n            model.reset_classifier()\n\n        logits = model(x_shot, x_query, y_shot, inner_args, meta_train=False)\n        logits = logits.flatten(0, 1)\n        labels = y_query.flatten()\n        \n        pred = torch.argmax(logits, dim=-1)\n        acc = utils.compute_acc(pred, labels)\n        loss = F.cross_entropy(logits, labels)\n        aves['vl'].update(loss.item(), 1)\n        aves['va'].update(acc, 1)\n\n    if lr_scheduler is not None:\n      lr_scheduler.step()\n\n    for k, avg in aves.items():\n      aves[k] = avg.item()\n      trlog[k].append(aves[k])\n\n    t_epoch = utils.time_str(timer_epoch.end())\n    
t_elapsed = utils.time_str(timer_elapsed.end())\n    t_estimate = utils.time_str(timer_elapsed.end() / \n      (epoch - start_epoch + 1) * (config['epoch'] - start_epoch + 1))\n\n    # formats output\n    log_str = 'epoch {}, meta-train {:.4f}|{:.4f}'.format(\n      str(epoch), aves['tl'], aves['ta'])\n    writer.add_scalars('loss', {'meta-train': aves['tl']}, epoch)\n    writer.add_scalars('acc', {'meta-train': aves['ta']}, epoch)\n\n    if eval_val:\n      log_str += ', meta-val {:.4f}|{:.4f}'.format(aves['vl'], aves['va'])\n      writer.add_scalars('loss', {'meta-val': aves['vl']}, epoch)\n      writer.add_scalars('acc', {'meta-val': aves['va']}, epoch)\n\n    log_str += ', {} {}/{}'.format(t_epoch, t_elapsed, t_estimate)\n    utils.log(log_str)\n\n    # saves model and meta-data\n    if config.get('_parallel'):\n      model_ = model.module\n    else:\n      model_ = model\n\n    training = {\n      'epoch': epoch,\n      'max_va': max(max_va, aves['va']),\n\n      'optimizer': config['optimizer'],\n      'optimizer_args': config['optimizer_args'],\n      'optimizer_state_dict': optimizer.state_dict(),\n      'lr_scheduler_state_dict': lr_scheduler.state_dict() \n        if lr_scheduler is not None else None,\n    }\n    ckpt = {\n      'file': __file__,\n      'config': config,\n\n      'encoder': config['encoder'],\n      'encoder_args': config['encoder_args'],\n      'encoder_state_dict': model_.encoder.state_dict(),\n\n      'classifier': config['classifier'],\n      'classifier_args': config['classifier_args'],\n      'classifier_state_dict': model_.classifier.state_dict(),\n\n      'training': training,\n    }\n\n    # 'epoch-last.pth': saved at the latest epoch\n    # 'max-va.pth': saved when validation accuracy is at its maximum\n    torch.save(ckpt, os.path.join(ckpt_path, 'epoch-last.pth'))\n    torch.save(trlog, os.path.join(ckpt_path, 'trlog.pth'))\n\n    if aves['va'] > max_va:\n      max_va = aves['va']\n      torch.save(ckpt, 
os.path.join(ckpt_path, 'max-va.pth'))\n\n    writer.flush()\n\n\nif __name__ == '__main__':\n  parser = argparse.ArgumentParser()\n  parser.add_argument('--config', \n                      help='configuration file')\n  parser.add_argument('--name', \n                      help='model name', \n                      type=str, default=None)\n  parser.add_argument('--tag', \n                      help='auxiliary information', \n                      type=str, default=None)\n  parser.add_argument('--gpu', \n                      help='gpu device number', \n                      type=str, default='0')\n  parser.add_argument('--efficient', \n                      help='if True, enables gradient checkpointing',\n                      action='store_true')\n  args = parser.parse_args()\n  config = yaml.load(open(args.config, 'r'), Loader=yaml.FullLoader)\n\n  if len(args.gpu.split(',')) > 1:\n    config['_parallel'] = True\n    config['_gpu'] = args.gpu\n\n  utils.set_gpu(args.gpu)\n  main(config)"
  },
  {
    "path": "utils/__init__.py",
    "content": "import os\nimport shutil\nimport time\n\nimport numpy as np\nimport scipy.stats as stats\n\n\n_log_path = None\n\ndef set_log_path(path):\n  global _log_path\n  _log_path = path\n\n\ndef log(obj, filename='log.txt'):\n  print(obj)\n  if _log_path is not None:\n    with open(os.path.join(_log_path, filename), 'a') as f:\n      print(obj, file=f)\n\n\nclass AverageMeter(object):\n  def __init__(self):\n    self.reset()\n\n  def reset(self):\n    self.val = 0.\n    self.avg = 0.\n    self.sum = 0.\n    self.count = 0.\n\n  def update(self, val, n=1):\n    self.val = val\n    self.sum += val * n\n    self.count += n\n    self.avg = self.sum / self.count\n\n  def item(self):\n    return self.avg\n\n\nclass Timer(object):\n  def __init__(self):\n    self.start()\n\n  def start(self):\n    self.v = time.time()\n\n  def end(self):\n    return time.time() - self.v\n\n\ndef set_gpu(gpu):\n  print('set gpu:', gpu)\n  os.environ['CUDA_VISIBLE_DEVICES'] = gpu\n\n\ndef ensure_path(path, remove=True):\n  basename = os.path.basename(path.rstrip('/'))\n  if os.path.exists(path):\n    if remove and (basename.startswith('_')\n      or input('{} exists, remove? 
([y]/n): '.format(path)) != 'n'):\n      shutil.rmtree(path)\n      os.makedirs(path)\n  else:\n    os.makedirs(path)\n\n\ndef time_str(t):\n  if t >= 3600:\n    return '{:.1f}h'.format(t / 3600)\n  if t >= 60:\n    return '{:.1f}m'.format(t / 60)\n  return '{:.1f}s'.format(t)\n\n\ndef compute_acc(pred, label, reduction='mean'):\n  result = (pred == label).float()\n  if reduction == 'none':\n    return result.detach()\n  elif reduction == 'mean':\n    return result.mean().item()\n\n\ndef compute_n_params(model, return_str=True):\n  n_params = 0\n  for p in model.parameters():\n    n_params += p.numel()\n  if return_str:\n    if n_params >= 1e6:\n      return '{:.1f}M'.format(n_params / 1e6)\n    else:\n      return '{:.1f}K'.format(n_params / 1e3)\n  else:\n    return n_params\n\n\ndef mean_confidence_interval(data, confidence=0.95):\n  a = 1.0 * np.array(data)\n  stderr = stats.sem(a)\n  h = stderr * stats.t.ppf((1 + confidence) / 2., len(a) - 1)\n  return h\n\n\ndef config_inner_args(inner_args):\n  if inner_args is None: \n    inner_args = dict()\n\n  inner_args['reset_classifier'] = inner_args.get('reset_classifier') or False\n  inner_args['n_step'] = inner_args.get('n_step') or 5\n  inner_args['encoder_lr'] = inner_args.get('encoder_lr') or 0.01\n  inner_args['classifier_lr'] = inner_args.get('classifier_lr') or 0.01\n  inner_args['momentum'] = inner_args.get('momentum') or 0.\n  inner_args['weight_decay'] = inner_args.get('weight_decay') or 0.\n  inner_args['first_order'] = inner_args.get('first_order') or False\n  inner_args['frozen'] = inner_args.get('frozen') or []\n\n  return inner_args"
  },
  {
    "path": "utils/optimizers.py",
    "content": "from torch.optim import SGD, RMSprop, Adam\nfrom torch.optim.lr_scheduler import MultiStepLR, CosineAnnealingLR\n\n\ndef make(name, params, lr, weight_decay=0., \n         schedule='step', milestones=None, gamma=0.1):\n  \"\"\"\n  Prepares an optimizer and its learning-rate scheduler.\n\n  Args:\n    name (str): name of the optimizer. Options: 'sgd', 'rmsprop', 'adam'\n    params (iterable): parameters to optimize.\n    lr (float): initial learning rate.\n    weight_decay (float, optional): weight decay. Default: 0.\n    schedule (str, optional): type of learning-rate schedule. Default: 'step'\n      Options: 'step', 'cosine'\n      (This argument is ignored if milestones=None.)\n    milestones (int list, optional): a list of epochs at which the learning \n      rate is altered. Default: None\n    gamma (float, optional): multiplicative factor of learning rate decay.\n      Default: 0.1\n  \"\"\"\n
  if name == 'sgd':\n    optimizer = SGD(params, lr, momentum=0.9, weight_decay=weight_decay)\n  elif name == 'rmsprop':\n    optimizer = RMSprop(params, lr, weight_decay=weight_decay)\n  elif name == 'adam':\n    optimizer = Adam(params, lr, weight_decay=weight_decay)\n  else:\n    raise ValueError('invalid optimizer')\n\n  if milestones is not None:\n    if schedule == 'step':\n      lr_scheduler = MultiStepLR(optimizer, milestones, gamma)\n    elif schedule == 'cosine':\n      lr_scheduler = CosineAnnealingLR(optimizer, milestones[-1])\n    else:\n      raise ValueError('invalid schedule')\n  else:\n    lr_scheduler = None\n\n  return optimizer, lr_scheduler\n\n
\ndef load(ckpt, params):\n  train = ckpt['training']\n  optimizer, lr_scheduler = make(\n    train['optimizer'], params, **train['optimizer_args'])\n  optimizer.load_state_dict(train['optimizer_state_dict'])\n\n  if lr_scheduler is not None:\n    lr_scheduler.load_state_dict(train['lr_scheduler_state_dict'])\n\n  return optimizer, lr_scheduler"
  }
]