[
  {
    "path": "README.md",
    "content": "# Optimization as a Model for Few-shot Learning\nPytorch implementation of [Optimization as a Model for Few-shot Learning](https://openreview.net/forum?id=rJY0-Kcll) in ICLR 2017 (Oral)\n\n![Model Architecture](https://i.imgur.com/lydKeUc.png)\n\n## Prerequisites\n- python 3+\n- pytorch 0.4+ (developed on 1.0.1 with cuda 9.0)\n- [pillow](https://pillow.readthedocs.io/en/stable/installation.html)\n- [tqdm](https://tqdm.github.io/) (a nice progress bar)\n\n## Data\n- Mini-Imagenet as described [here](https://github.com/twitter/meta-learning-lstm/tree/master/data/miniImagenet)\n  - You can download it from [here](https://drive.google.com/file/d/1rV3aj_hgfNTfCakffpPm7Vhpr1in87CR/view?usp=sharing) (~2.7GB, google drive link)\n\n## Preparation\n- Make sure Mini-Imagenet is split properly. For example:\n  ```\n  - data/\n    - miniImagenet/\n      - train/\n        - n01532829/\n          - n0153282900000005.jpg\n          - ...\n        - n01558993/\n        - ...\n      - val/\n        - n01855672/\n        - ...\n      - test/\n        - ...\n  - main.py\n  - ...\n  ```\n  - It'd be set if you download and extract Mini-Imagenet from the link above\n- Check out `scripts/train_5s_5c.sh`, make sure `--data-root` is properly set\n\n## Run\nFor 5-shot, 5-class training, run\n```bash\nbash scripts/train_5s_5c.sh\n```\nHyper-parameters are referred to the [author's repo](https://github.com/twitter/meta-learning-lstm).\n\nFor 5-shot, 5-class evaluation, run *(remember to change `--resume` and `--seed` arguments)*\n```bash\nbash scripts/eval_5s_5c.sh\n```\n\n## Notes\n- Results (This repo is developed following the [pytorch reproducibility guideline](https://pytorch.org/docs/stable/notes/randomness.html)):\n\n|seed|train episodes|val episodes|val acc mean|val acc std|test episodes|test acc mean|test acc std|\n|-|-|-|-|-|-|-|-|\n|719|41000|100|59.08|9.9|100|56.59|8.4|\n|  -|    -|  -|    -|  -|250|57.85|8.6|\n|  -|    -|  -|    -|  -|600|57.76|8.6|\n| 
53|44000|100|58.04|9.1|100|57.85|7.7|\n|  -|    -|  -|    -|  -|250|57.83|8.3|\n|  -|    -|  -|    -|  -|600|58.14|8.5|\n\n- The results I get from directly running the author's repo can be found [here](https://i.imgur.com/rtagm2c.png), I have slightly better performance (~5%) but neither results match the number in the paper (60%) *(Discussion and help are welcome!)*.\n- Training with the default settings takes ~2.5 hours on a single Titan Xp while occupying ~2GB GPU memory.\n- The implementation replicates two learners similar to the author's repo:\n  - `learner_w_grad` functions as a regular model, get gradients and loss as inputs to meta learner.\n  - `learner_wo_grad` constructs the graph for meta learner:\n    - All the parameters in `learner_wo_grad` are replaced by `cI` output by meta learner.\n    - `nn.Parameters` in this model are casted to `torch.Tensor` to connect the graph to meta learner.\n- Several ways to **copy** a parameters from meta learner to learner depends on the scenario:\n  - `copy_flat_params`: we only need the parameter values and keep the original `grad_fn`.\n  - `transfer_params`: we want the values as well as the `grad_fn` (from `cI` to `learner_wo_grad`).\n    - `.data.copy_` v.s. `clone()` -> the latter retains all the properties of a tensor including `grad_fn`.\n    - To maintain the batch statistics, `load_state_dict` is used (from `learner_w_grad` to `learner_wo_grad`).\n\n## References\n- [CloserLookFewShot](https://github.com/wyharveychen/CloserLookFewShot) (Data loader)\n- [pytorch-meta-optimizer](https://github.com/ikostrikov/pytorch-meta-optimizer) (Casting `nn.Parameters` to `torch.Tensor` inspired from here)\n- [meta-learning-lstm](https://github.com/twitter/meta-learning-lstm) (Author's repo in Lua Torch)\n\n"
  },
  {
    "path": "dataloader.py",
    "content": "from __future__ import division, print_function, absolute_import\n\nimport os\nimport re\nimport pdb\nimport glob\nimport pickle\n\nimport torch\nimport torch.utils.data as data\nimport torchvision.datasets as datasets\nimport torchvision.transforms as transforms\nimport PIL.Image as PILI\nimport numpy as np\n\nfrom tqdm import tqdm\n\n\nclass EpisodeDataset(data.Dataset):\n\n    def __init__(self, root, phase='train', n_shot=5, n_eval=15, transform=None):\n        \"\"\"Args:\n            root (str): path to data\n            phase (str): train, val or test\n            n_shot (int): how many examples per class for training (k/n_support)\n            n_eval (int): how many examples per class for evaluation\n                - n_shot + n_eval = batch_size for data.DataLoader of ClassDataset\n            transform (torchvision.transforms): data augmentation\n        \"\"\"\n        root = os.path.join(root, phase)\n        self.labels = sorted(os.listdir(root))\n        images = [glob.glob(os.path.join(root, label, '*')) for label in self.labels]\n\n        self.episode_loader = [data.DataLoader(\n            ClassDataset(images=images[idx], label=idx, transform=transform),\n            batch_size=n_shot+n_eval, shuffle=True, num_workers=0) for idx, _ in enumerate(self.labels)]\n\n    def __getitem__(self, idx):\n        return next(iter(self.episode_loader[idx]))\n\n    def __len__(self):\n        return len(self.labels)\n\n\nclass ClassDataset(data.Dataset):\n\n    def __init__(self, images, label, transform=None):\n        \"\"\"Args:\n            images (list of str): each item is a path to an image of the same label\n            label (int): the label of all the images\n        \"\"\"\n        self.images = images\n        self.label = label\n        self.transform = transform\n\n    def __getitem__(self, idx):\n        image = PILI.open(self.images[idx]).convert('RGB')\n        if self.transform is not None:\n            image = 
self.transform(image)\n\n        return image, self.label\n\n    def __len__(self):\n        return len(self.images)\n\n\nclass EpisodicSampler(data.Sampler):\n\n    def __init__(self, total_classes, n_class, n_episode):\n        self.total_classes = total_classes\n        self.n_class = n_class\n        self.n_episode = n_episode\n\n    def __iter__(self):\n        for i in range(self.n_episode):\n            yield torch.randperm(self.total_classes)[:self.n_class]\n\n    def __len__(self):\n        return self.n_episode\n\n\ndef prepare_data(args):\n\n    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])\n    \n    train_set = EpisodeDataset(args.data_root, 'train', args.n_shot, args.n_eval,\n        transform=transforms.Compose([\n            transforms.RandomResizedCrop(args.image_size),\n            transforms.RandomHorizontalFlip(),\n            transforms.ColorJitter(\n                brightness=0.4,\n                contrast=0.4,\n                saturation=0.4,\n                hue=0.2),\n            transforms.ToTensor(),\n            normalize]))\n\n    val_set = EpisodeDataset(args.data_root, 'val', args.n_shot, args.n_eval,\n        transform=transforms.Compose([\n            transforms.Resize(args.image_size * 8 // 7),\n            transforms.CenterCrop(args.image_size),\n            transforms.ToTensor(),\n            normalize]))\n\n    test_set = EpisodeDataset(args.data_root, 'test', args.n_shot, args.n_eval,\n        transform=transforms.Compose([\n            transforms.Resize(args.image_size * 8 // 7),\n            transforms.CenterCrop(args.image_size),\n            transforms.ToTensor(),\n            normalize]))\n\n    train_loader = data.DataLoader(train_set, num_workers=args.n_workers, pin_memory=args.pin_mem,\n        batch_sampler=EpisodicSampler(len(train_set), args.n_class, args.episode))\n\n    val_loader = data.DataLoader(val_set, num_workers=2, pin_memory=False,\n        
batch_sampler=EpisodicSampler(len(val_set), args.n_class, args.episode_val))\n\n    test_loader = data.DataLoader(test_set, num_workers=2, pin_memory=False,\n        batch_sampler=EpisodicSampler(len(test_set), args.n_class, args.episode_val))\n\n    return train_loader, val_loader, test_loader\n"
  },
  {
    "path": "learner.py",
    "content": "from __future__ import division, print_function, absolute_import\n\nimport pdb\nimport copy\nfrom collections import OrderedDict\n\nimport torch\nimport torch.nn as nn\nimport numpy as np\n\nclass Learner(nn.Module):\n\n    def __init__(self, image_size, bn_eps, bn_momentum, n_classes):\n        super(Learner, self).__init__()\n        self.model = nn.ModuleDict({'features': nn.Sequential(OrderedDict([\n            ('conv1', nn.Conv2d(3, 32, 3, padding=1)),\n            ('norm1', nn.BatchNorm2d(32, bn_eps, bn_momentum)),\n            ('relu1', nn.ReLU(inplace=False)),\n            ('pool1', nn.MaxPool2d(2)),\n\n            ('conv2', nn.Conv2d(32, 32, 3, padding=1)),\n            ('norm2', nn.BatchNorm2d(32, bn_eps, bn_momentum)),\n            ('relu2', nn.ReLU(inplace=False)),\n            ('pool2', nn.MaxPool2d(2)),\n\n            ('conv3', nn.Conv2d(32, 32, 3, padding=1)),\n            ('norm3', nn.BatchNorm2d(32, bn_eps, bn_momentum)),\n            ('relu3', nn.ReLU(inplace=False)),\n            ('pool3', nn.MaxPool2d(2)),\n\n            ('conv4', nn.Conv2d(32, 32, 3, padding=1)),\n            ('norm4', nn.BatchNorm2d(32, bn_eps, bn_momentum)),\n            ('relu4', nn.ReLU(inplace=False)),\n            ('pool4', nn.MaxPool2d(2))]))\n        })\n\n        clr_in = image_size // 2**4\n        self.model.update({'cls': nn.Linear(32 * clr_in * clr_in, n_classes)})\n        self.criterion = nn.CrossEntropyLoss()\n\n    def forward(self, x):\n        x = self.model.features(x)\n        x = torch.reshape(x, [x.size(0), -1])\n        outputs = self.model.cls(x)\n        return outputs\n\n    def get_flat_params(self):\n        return torch.cat([p.view(-1) for p in self.model.parameters()], 0)\n\n    def copy_flat_params(self, cI):\n        idx = 0\n        for p in self.model.parameters():\n            plen = p.view(-1).size(0)\n            p.data.copy_(cI[idx: idx+plen].view_as(p))\n            idx += plen\n\n    def transfer_params(self, 
learner_w_grad, cI):\n        # Use load_state_dict only to copy the running mean/var in batchnorm, the values of the parameters\n        #  are going to be replaced by cI\n        self.load_state_dict(learner_w_grad.state_dict())\n        #  replace nn.Parameters with tensors from cI (NOT nn.Parameters anymore).\n        idx = 0\n        for m in self.model.modules():\n            if isinstance(m, nn.Conv2d) or isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.Linear):\n                wlen = m._parameters['weight'].view(-1).size(0)\n                m._parameters['weight'] = cI[idx: idx+wlen].view_as(m._parameters['weight']).clone()\n                idx += wlen\n                if m._parameters['bias'] is not None:\n                    blen = m._parameters['bias'].view(-1).size(0)\n                    m._parameters['bias'] = cI[idx: idx+blen].view_as(m._parameters['bias']).clone()\n                    idx += blen\n\n    def reset_batch_stats(self):\n        for m in self.modules():\n            if isinstance(m, nn.BatchNorm2d):\n                m.reset_running_stats()\n\n"
  },
  {
    "path": "main.py",
    "content": "from __future__ import division, print_function, absolute_import\n\nimport os\nimport pdb\nimport copy\nimport random\nimport argparse\n\nimport torch\nimport torch.nn as nn\nimport numpy as np\nfrom tqdm import tqdm\n\nfrom learner import Learner\nfrom metalearner import MetaLearner\nfrom dataloader import prepare_data\nfrom utils import *\n\n\nFLAGS = argparse.ArgumentParser()\nFLAGS.add_argument('--mode', choices=['train', 'test'])\n# Hyper-parameters\nFLAGS.add_argument('--n-shot', type=int,\n                   help=\"How many examples per class for training (k, n_support)\")\nFLAGS.add_argument('--n-eval', type=int,\n                   help=\"How many examples per class for evaluation (n_query)\")\nFLAGS.add_argument('--n-class', type=int,\n                   help=\"How many classes (N, n_way)\")\nFLAGS.add_argument('--input-size', type=int,\n                   help=\"Input size for the first LSTM\")\nFLAGS.add_argument('--hidden-size', type=int,\n                   help=\"Hidden size for the first LSTM\")\nFLAGS.add_argument('--lr', type=float,\n                   help=\"Learning rate\")\nFLAGS.add_argument('--episode', type=int,\n                   help=\"Episodes to train\")\nFLAGS.add_argument('--episode-val', type=int,\n                   help=\"Episodes to eval\")\nFLAGS.add_argument('--epoch', type=int,\n                   help=\"Epoch to train for an episode\")\nFLAGS.add_argument('--batch-size', type=int,\n                   help=\"Batch size when training an episode\")\nFLAGS.add_argument('--image-size', type=int,\n                   help=\"Resize image to this size\")\nFLAGS.add_argument('--grad-clip', type=float,\n                   help=\"Clip gradients larger than this number\")\nFLAGS.add_argument('--bn-momentum', type=float,\n                   help=\"Momentum parameter in BatchNorm2d\")\nFLAGS.add_argument('--bn-eps', type=float,\n                   help=\"Eps parameter in BatchNorm2d\")\n\n# Paths\nFLAGS.add_argument('--data', 
choices=['miniimagenet'],\n                   help=\"Name of dataset\")\nFLAGS.add_argument('--data-root', type=str,\n                   help=\"Location of data\")\nFLAGS.add_argument('--resume', type=str,\n                   help=\"Location to pth.tar\")\nFLAGS.add_argument('--save', type=str, default='logs',\n                   help=\"Location to logs and ckpts\")\n# Others\nFLAGS.add_argument('--cpu', action='store_true',\n                   help=\"Set this to use CPU, default use CUDA\")\nFLAGS.add_argument('--n-workers', type=int, default=4,\n                   help=\"How many processes for preprocessing\")\nFLAGS.add_argument('--pin-mem', type=bool, default=False,\n                   help=\"DataLoader pin_memory\")\nFLAGS.add_argument('--log-freq', type=int, default=100,\n                   help=\"Logging frequency\")\nFLAGS.add_argument('--val-freq', type=int, default=1000,\n                   help=\"Validation frequency\")\nFLAGS.add_argument('--seed', type=int,\n                   help=\"Random seed\")\n\n\ndef meta_test(eps, eval_loader, learner_w_grad, learner_wo_grad, metalearner, args, logger):\n    for subeps, (episode_x, episode_y) in enumerate(tqdm(eval_loader, ascii=True)):\n        train_input = episode_x[:, :args.n_shot].reshape(-1, *episode_x.shape[-3:]).to(args.dev) # [n_class * n_shot, :]\n        train_target = torch.LongTensor(np.repeat(range(args.n_class), args.n_shot)).to(args.dev) # [n_class * n_shot]\n        test_input = episode_x[:, args.n_shot:].reshape(-1, *episode_x.shape[-3:]).to(args.dev) # [n_class * n_eval, :]\n        test_target = torch.LongTensor(np.repeat(range(args.n_class), args.n_eval)).to(args.dev) # [n_class * n_eval]\n\n        # Train learner with metalearner\n        learner_w_grad.reset_batch_stats()\n        learner_wo_grad.reset_batch_stats()\n        learner_w_grad.train()\n        learner_wo_grad.eval()\n        cI = train_learner(learner_w_grad, metalearner, train_input, train_target, args)\n\n        
learner_wo_grad.transfer_params(learner_w_grad, cI)\n        output = learner_wo_grad(test_input)\n        loss = learner_wo_grad.criterion(output, test_target)\n        acc = accuracy(output, test_target)\n\n        logger.batch_info(loss=loss.item(), acc=acc, phase='eval')\n\n    return logger.batch_info(eps=eps, totaleps=args.episode_val, phase='evaldone')\n\n\ndef train_learner(learner_w_grad, metalearner, train_input, train_target, args):\n    cI = metalearner.metalstm.cI.data\n    hs = [None]\n    for _ in range(args.epoch):\n        for i in range(0, len(train_input), args.batch_size):\n            x = train_input[i:i+args.batch_size]\n            y = train_target[i:i+args.batch_size]\n\n            # get the loss/grad\n            learner_w_grad.copy_flat_params(cI)\n            output = learner_w_grad(x)\n            loss = learner_w_grad.criterion(output, y)\n            acc = accuracy(output, y)\n            learner_w_grad.zero_grad()\n            loss.backward()\n            grad = torch.cat([p.grad.data.view(-1) / args.batch_size for p in learner_w_grad.parameters()], 0)\n\n            # preprocess grad & loss and metalearner forward\n            grad_prep = preprocess_grad_loss(grad)  # [n_learner_params, 2]\n            loss_prep = preprocess_grad_loss(loss.data.unsqueeze(0)) # [1, 2]\n            metalearner_input = [loss_prep, grad_prep, grad.unsqueeze(1)]\n            cI, h = metalearner(metalearner_input, hs[-1])\n            hs.append(h)\n\n            #print(\"training loss: {:8.6f} acc: {:6.3f}, mean grad: {:8.6f}\".format(loss, acc, torch.mean(grad)))\n\n    return cI\n\n\ndef main():\n\n    args, unparsed = FLAGS.parse_known_args()\n    if len(unparsed) != 0:\n        raise ValueError(\"Argument {} not recognized\".format(unparsed))\n\n    if args.seed is None:\n        args.seed = random.randint(0, 1000)\n    random.seed(args.seed)\n    np.random.seed(args.seed)\n    torch.manual_seed(args.seed)\n\n    if args.cpu:\n        args.dev = 
torch.device('cpu')\n    else:\n        if not torch.cuda.is_available():\n            raise RuntimeError(\"GPU unavailable.\")\n\n        torch.backends.cudnn.deterministic = True\n        torch.backends.cudnn.benchmark = False\n        args.dev = torch.device('cuda')\n\n    logger = GOATLogger(args)\n\n    # Get data\n    train_loader, val_loader, test_loader = prepare_data(args)\n    \n    # Set up learner, meta-learner\n    learner_w_grad = Learner(args.image_size, args.bn_eps, args.bn_momentum, args.n_class).to(args.dev)\n    learner_wo_grad = copy.deepcopy(learner_w_grad)\n    metalearner = MetaLearner(args.input_size, args.hidden_size, learner_w_grad.get_flat_params().size(0)).to(args.dev)\n    metalearner.metalstm.init_cI(learner_w_grad.get_flat_params())\n\n    # Set up loss, optimizer, learning rate scheduler\n    optim = torch.optim.Adam(metalearner.parameters(), args.lr)\n\n    if args.resume:\n        logger.loginfo(\"Initialized from: {}\".format(args.resume))\n        last_eps, metalearner, optim = resume_ckpt(metalearner, optim, args.resume, args.dev)\n\n    if args.mode == 'test':\n        _ = meta_test(last_eps, test_loader, learner_w_grad, learner_wo_grad, metalearner, args, logger)\n        return\n\n    best_acc = 0.0\n    logger.loginfo(\"Start training\")\n    # Meta-training\n    for eps, (episode_x, episode_y) in enumerate(train_loader):\n        # episode_x.shape = [n_class, n_shot + n_eval, c, h, w]\n        # episode_y.shape = [n_class, n_shot + n_eval] --> NEVER USED\n        train_input = episode_x[:, :args.n_shot].reshape(-1, *episode_x.shape[-3:]).to(args.dev) # [n_class * n_shot, :]\n        train_target = torch.LongTensor(np.repeat(range(args.n_class), args.n_shot)).to(args.dev) # [n_class * n_shot]\n        test_input = episode_x[:, args.n_shot:].reshape(-1, *episode_x.shape[-3:]).to(args.dev) # [n_class * n_eval, :]\n        test_target = torch.LongTensor(np.repeat(range(args.n_class), args.n_eval)).to(args.dev) # [n_class * 
n_eval]\n\n        # Train learner with metalearner\n        learner_w_grad.reset_batch_stats()\n        learner_wo_grad.reset_batch_stats()\n        learner_w_grad.train()\n        learner_wo_grad.train()\n        cI = train_learner(learner_w_grad, metalearner, train_input, train_target, args)\n\n        # Train meta-learner with validation loss\n        learner_wo_grad.transfer_params(learner_w_grad, cI)\n        output = learner_wo_grad(test_input)\n        loss = learner_wo_grad.criterion(output, test_target)\n        acc = accuracy(output, test_target)\n        \n        optim.zero_grad()\n        loss.backward()\n        nn.utils.clip_grad_norm_(metalearner.parameters(), args.grad_clip)\n        optim.step()\n\n        logger.batch_info(eps=eps, totaleps=args.episode, loss=loss.item(), acc=acc, phase='train')\n\n        # Meta-validation\n        if eps % args.val_freq == 0 and eps != 0:\n            save_ckpt(eps, metalearner, optim, args.save)\n            acc = meta_test(eps, val_loader, learner_w_grad, learner_wo_grad, metalearner, args, logger)\n            if acc > best_acc:\n                best_acc = acc\n                logger.loginfo(\"* Best accuracy so far *\\n\")\n\n    logger.loginfo(\"Done\")\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "metalearner.py",
    "content": "from __future__ import division, print_function, absolute_import\n\nimport pdb\nimport math\nimport torch\nimport torch.nn as nn\n\n\nclass MetaLSTMCell(nn.Module):\n    \"\"\"C_t = f_t * C_{t-1} + i_t * \\tilde{C_t}\"\"\"\n    def __init__(self, input_size, hidden_size, n_learner_params):\n        super(MetaLSTMCell, self).__init__()\n        \"\"\"Args:\n            input_size (int): cell input size, default = 20\n            hidden_size (int): should be 1\n            n_learner_params (int): number of learner's parameters\n        \"\"\"\n        self.input_size = input_size\n        self.hidden_size = hidden_size\n        self.n_learner_params = n_learner_params\n        self.WF = nn.Parameter(torch.Tensor(input_size + 2, hidden_size))\n        self.WI = nn.Parameter(torch.Tensor(input_size + 2, hidden_size))\n        self.cI = nn.Parameter(torch.Tensor(n_learner_params, 1))\n        self.bI = nn.Parameter(torch.Tensor(1, hidden_size))\n        self.bF = nn.Parameter(torch.Tensor(1, hidden_size))\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        for weight in self.parameters():\n            nn.init.uniform_(weight, -0.01, 0.01)\n\n        # want initial forget value to be high and input value to be low so that \n        #  model starts with gradient descent\n        nn.init.uniform_(self.bF, 4, 6)\n        nn.init.uniform_(self.bI, -5, -4)\n\n    def init_cI(self, flat_params):\n        self.cI.data.copy_(flat_params.unsqueeze(1))\n\n    def forward(self, inputs, hx=None):\n        \"\"\"Args:\n            inputs = [x_all, grad]:\n                x_all (torch.Tensor of size [n_learner_params, input_size]): outputs from previous LSTM\n                grad (torch.Tensor of size [n_learner_params]): gradients from learner\n            hx = [f_prev, i_prev, c_prev]:\n                f (torch.Tensor of size [n_learner_params, 1]): forget gate\n                i (torch.Tensor of size [n_learner_params, 1]): input gate\n    
            c (torch.Tensor of size [n_learner_params, 1]): flattened learner parameters\n        \"\"\"\n        x_all, grad = inputs\n        batch, _ = x_all.size()\n\n        if hx is None:\n            f_prev = torch.zeros((batch, self.hidden_size)).to(self.WF.device)\n            i_prev = torch.zeros((batch, self.hidden_size)).to(self.WI.device)\n            c_prev = self.cI\n            hx = [f_prev, i_prev, c_prev]\n\n        f_prev, i_prev, c_prev = hx\n        \n        # f_t = sigmoid(W_f * [grad_t, loss_t, theta_{t-1}, f_{t-1}] + b_f)\n        f_next = torch.mm(torch.cat((x_all, c_prev, f_prev), 1), self.WF) + self.bF.expand_as(f_prev)\n        # i_t = sigmoid(W_i * [grad_t, loss_t, theta_{t-1}, i_{t-1}] + b_i)\n        i_next = torch.mm(torch.cat((x_all, c_prev, i_prev), 1), self.WI) + self.bI.expand_as(i_prev)\n        # next cell/params\n        c_next = torch.sigmoid(f_next).mul(c_prev) - torch.sigmoid(i_next).mul(grad)\n\n        return c_next, [f_next, i_next, c_next]\n\n    def extra_repr(self):\n        s = '{input_size}, {hidden_size}, {n_learner_params}'\n        return s.format(**self.__dict__)\n\n\nclass MetaLearner(nn.Module):\n\n    def __init__(self, input_size, hidden_size, n_learner_params):\n        super(MetaLearner, self).__init__()\n        \"\"\"Args:\n            input_size (int): for the first LSTM layer, default = 4\n            hidden_size (int): for the first LSTM layer, default = 20\n            n_learner_params (int): number of learner's parameters\n        \"\"\"\n        self.lstm = nn.LSTMCell(input_size=input_size, hidden_size=hidden_size)\n        self.metalstm = MetaLSTMCell(input_size=hidden_size, hidden_size=1, n_learner_params=n_learner_params)\n\n    def forward(self, inputs, hs=None):\n        \"\"\"Args:\n            inputs = [loss, grad_prep, grad]\n                loss (torch.Tensor of size [1, 2])\n                grad_prep (torch.Tensor of size [n_learner_params, 2])\n                grad (torch.Tensor of 
size [n_learner_params])\n\n            hs = [(lstm_hn, lstm_cn), [metalstm_fn, metalstm_in, metalstm_cn]]\n        \"\"\"\n        loss, grad_prep, grad = inputs\n        loss = loss.expand_as(grad_prep)\n        inputs = torch.cat((loss, grad_prep), 1)   # [n_learner_params, 4]\n\n        if hs is None:\n            hs = [None, None]\n\n        lstmhx, lstmcx = self.lstm(inputs, hs[0])\n        flat_learner_unsqzd, metalstm_hs = self.metalstm([lstmhx, grad], hs[1])\n\n        return flat_learner_unsqzd.squeeze(), [(lstmhx, lstmcx), metalstm_hs]\n\n"
  },
  {
    "path": "scripts/eval_5s_5c.sh",
    "content": "#!/bin/bash\n#\n# For 5-shot, 5-class evaluation, hyper-parameters follow github.com/twitter/meta-learning-lstm\n\npython main.py --mode test \\\n               --resume logs-719/ckpts/meta-learner-42000.pth.tar \\\n               --n-shot 5 \\\n               --n-eval 15 \\\n               --n-class 5 \\\n               --input-size 4 \\\n               --hidden-size 20 \\\n               --lr 1e-3 \\\n               --episode 50000 \\\n               --episode-val 100 \\\n               --epoch 8 \\\n               --batch-size 25 \\\n               --image-size 84 \\\n               --grad-clip 0.25 \\\n               --bn-momentum 0.95 \\\n               --bn-eps 1e-3 \\\n               --data miniimagenet \\\n               --data-root data/miniImagenet/ \\\n               --pin-mem True \\\n               --log-freq 100\n"
  },
  {
    "path": "scripts/train_5s_5c.sh",
    "content": "#!/bin/bash\n#\n# For 5-shot, 5-class training\n# Hyper-parameters follow https://github.com/twitter/meta-learning-lstm\n\npython main.py --mode train \\\n               --n-shot 5 \\\n               --n-eval 15 \\\n               --n-class 5 \\\n               --input-size 4 \\\n               --hidden-size 20 \\\n               --lr 1e-3 \\\n               --episode 50000 \\\n               --episode-val 100 \\\n               --epoch 8 \\\n               --batch-size 25 \\\n               --image-size 84 \\\n               --grad-clip 0.25 \\\n               --bn-momentum 0.95 \\\n               --bn-eps 1e-3 \\\n               --data miniimagenet \\\n               --data-root data/miniImagenet/ \\\n               --pin-mem True \\\n               --log-freq 50 \\\n               --val-freq 1000\n"
  },
  {
    "path": "utils.py",
    "content": "from __future__ import division, print_function, absolute_import\n\nimport os\nimport pdb\nimport logging\n\nimport torch\nimport numpy as np\n\n\nclass GOATLogger:\n\n    def __init__(self, args):\n        args.save = args.save + '-{}'.format(args.seed)\n\n        self.mode = args.mode\n        self.save_root = args.save\n        self.log_freq = args.log_freq\n\n        if self.mode == 'train':\n            if not os.path.exists(self.save_root):\n                os.mkdir(self.save_root)\n            filename = os.path.join(self.save_root, 'console.log')\n            logging.basicConfig(level=logging.DEBUG,\n                format='%(asctime)s.%(msecs)03d - %(message)s',\n                datefmt='%b-%d %H:%M:%S',\n                filename=filename,\n                filemode='w')\n            console = logging.StreamHandler()\n            console.setLevel(logging.INFO)\n            console.setFormatter(logging.Formatter('%(message)s'))\n            logging.getLogger('').addHandler(console)\n\n            logging.info(\"Logger created at {}\".format(filename))\n        else:\n            logging.basicConfig(level=logging.INFO,\n                format='%(asctime)s.%(msecs)03d - %(message)s',\n                datefmt='%b-%d %H:%M:%S')\n\n        logging.info(\"Random Seed: {}\".format(args.seed))\n        self.reset_stats()\n\n    def reset_stats(self):\n        if self.mode == 'train':\n           self.stats = {'train': {'loss': [], 'acc': []},\n                          'eval': {'loss': [], 'acc': []}}\n        else:\n            self.stats = {'eval': {'loss': [], 'acc': []}}\n\n    def batch_info(self, **kwargs):\n        if kwargs['phase'] == 'train':\n            self.stats['train']['loss'].append(kwargs['loss'])\n            self.stats['train']['acc'].append(kwargs['acc'])\n\n            if kwargs['eps'] % self.log_freq == 0 and kwargs['eps'] != 0:\n                loss_mean = np.mean(self.stats['train']['loss'])\n                acc_mean = 
np.mean(self.stats['train']['acc'])\n                #self.draw_stats()\n                self.loginfo(\"[{:5d}/{:5d}] loss: {:6.4f} ({:6.4f}), acc: {:6.3f}% ({:6.3f}%)\".format(\\\n                    kwargs['eps'], kwargs['totaleps'], kwargs['loss'], loss_mean, kwargs['acc'], acc_mean))\n\n        elif kwargs['phase'] == 'eval':\n            self.stats['eval']['loss'].append(kwargs['loss'])\n            self.stats['eval']['acc'].append(kwargs['acc'])\n\n        elif kwargs['phase'] == 'evaldone':\n            loss_mean = np.mean(self.stats['eval']['loss'])\n            loss_std = np.std(self.stats['eval']['loss'])\n            acc_mean = np.mean(self.stats['eval']['acc'])\n            acc_std = np.std(self.stats['eval']['acc'])\n            self.loginfo(\"[{:5d}] Eval ({:3d} episode) - loss: {:6.4f} +- {:6.4f}, acc: {:6.3f} +- {:5.3f}%\".format(\\\n                kwargs['eps'], kwargs['totaleps'], loss_mean, loss_std, acc_mean, acc_std))\n\n            self.reset_stats()\n            return acc_mean\n\n        else:\n            raise ValueError(\"phase {} not supported\".format(kwargs['phase']))\n\n    def logdebug(self, strout):\n        logging.debug(strout)\n    def loginfo(self, strout):\n        logging.info(strout)\n\n\ndef accuracy(output, target, topk=(1,)):\n    with torch.no_grad():\n        maxk = max(topk)\n        batch_size = target.size(0)\n\n        _, pred = output.topk(maxk, 1, True, True)\n        pred = pred.t()\n        correct = pred.eq(target.view(1, -1).expand_as(pred))\n\n        res = []\n        for k in topk:\n            correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)\n            res.append(correct_k.mul_(100.0 / batch_size))\n        return res[0].item() if len(res) == 1 else [r.item() for r in res]\n\n\ndef save_ckpt(episode, metalearner, optim, save):\n    if not os.path.exists(os.path.join(save, 'ckpts')):\n        os.mkdir(os.path.join(save, 'ckpts'))\n\n    torch.save({\n        'episode': episode,\n        
'metalearner': metalearner.state_dict(),\n        'optim': optim.state_dict()\n    }, os.path.join(save, 'ckpts', 'meta-learner-{}.pth.tar'.format(episode)))\n\n\ndef resume_ckpt(metalearner, optim, resume, device):\n    ckpt = torch.load(resume, map_location=device)\n    last_episode = ckpt['episode']\n    metalearner.load_state_dict(ckpt['metalearner'])\n    optim.load_state_dict(ckpt['optim'])\n    return last_episode, metalearner, optim\n\n\ndef preprocess_grad_loss(x):\n    p = 10\n    indicator = (x.abs() >= np.exp(-p)).to(torch.float32)\n\n    # preproc1\n    x_proc1 = indicator * torch.log(x.abs() + 1e-8) / p + (1 - indicator) * -1\n    # preproc2\n    x_proc2 = indicator * torch.sign(x) + (1 - indicator) * np.exp(p) * x\n    return torch.stack((x_proc1, x_proc2), 1)\n\n"
  }
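  ,
  {
    "path": "examples/metalstm_step_sketch.py",
    "content": "# Illustrative sketch, NOT part of the original implementation: a single\n# meta-LSTM cell update as used in metalearner.py,\n#   c_t = sigmoid(f_t) * c_{t-1} - sigmoid(i_t) * grad_t,\n# applied to a toy 3-parameter learner. All numbers below are made up for\n# demonstration; shapes follow MetaLSTMCell ([n_learner_params, 1]).\nimport torch\n\nn_params = 3\nc_prev = torch.zeros(n_params, 1)        # flattened learner parameters (cI)\ngrad = torch.full((n_params, 1), 0.5)    # toy gradient\nf_pre = torch.full((n_params, 1), 5.0)   # bF is initialised in [4, 6]  -> forget gate close to 1\ni_pre = torch.full((n_params, 1), -4.5)  # bI is initialised in [-5, -4] -> input gate close to 0\n\nc_next = torch.sigmoid(f_pre) * c_prev - torch.sigmoid(i_pre) * grad\n# With these initial biases the update behaves like plain gradient descent\n# with a small learning rate: c_next is approximately c_prev - 0.011 * grad.\nprint(c_next.squeeze())\n"
  }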
]