Repository: markdtw/meta-learning-lstm-pytorch
Branch: master
Commit: fcc68baad71a
Files: 8
Total size: 29.6 KB

Directory structure:
gitextract_ouj1iose/

├── README.md
├── dataloader.py
├── learner.py
├── main.py
├── metalearner.py
├── scripts/
│   ├── eval_5s_5c.sh
│   └── train_5s_5c.sh
└── utils.py

================================================
FILE CONTENTS
================================================

================================================
FILE: README.md
================================================
# Optimization as a Model for Few-shot Learning
PyTorch implementation of [Optimization as a Model for Few-shot Learning](https://openreview.net/forum?id=rJY0-Kcll) (ICLR 2017, Oral).

![Model Architecture](https://i.imgur.com/lydKeUc.png)

## Prerequisites
- Python 3+
- PyTorch 0.4+ (developed on 1.0.1 with CUDA 9.0)
- torchvision and numpy (both imported by `dataloader.py`)
- [pillow](https://pillow.readthedocs.io/en/stable/installation.html)
- [tqdm](https://tqdm.github.io/) (a nice progress bar)

## Data
- Mini-Imagenet as described [here](https://github.com/twitter/meta-learning-lstm/tree/master/data/miniImagenet)
  - You can download it from [here](https://drive.google.com/file/d/1rV3aj_hgfNTfCakffpPm7Vhpr1in87CR/view?usp=sharing) (~2.7GB, google drive link)

## Preparation
- Make sure Mini-Imagenet is split properly. For example:
  ```
  - data/
    - miniImagenet/
      - train/
        - n01532829/
          - n0153282900000005.jpg
          - ...
        - n01558993/
        - ...
      - val/
        - n01855672/
        - ...
      - test/
        - ...
  - main.py
  - ...
  ```
  - The layout is already correct if you download and extract Mini-Imagenet from the link above.
- Check out `scripts/train_5s_5c.sh`, make sure `--data-root` is properly set
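
A quick way to confirm the layout before training is to count the class folders in each split. The sketch below uses only the standard library. `check_split_layout` is a hypothetical helper, not part of this repo, and it assumes the `train`/`val`/`test` folder names shown above.

```python
import os


def check_split_layout(data_root, phases=("train", "val", "test")):
    """Return {phase: number of class folders}; raise if a split is missing or empty.

    Hypothetical helper (not part of this repo). It mirrors how EpisodeDataset
    expects one sub-folder per class under <data_root>/<phase>/.
    """
    counts = {}
    for phase in phases:
        phase_dir = os.path.join(data_root, phase)
        if not os.path.isdir(phase_dir):
            raise FileNotFoundError("missing split folder: {}".format(phase_dir))
        # Keep directories only, so stray files do not count as classes.
        classes = [d for d in sorted(os.listdir(phase_dir))
                   if os.path.isdir(os.path.join(phase_dir, d))]
        if not classes:
            raise ValueError("no class folders in {}".format(phase_dir))
        counts[phase] = len(classes)
    return counts
```

Calling `check_split_layout('data/miniImagenet')` from the repo root should return one count per split; a missing or empty split raises before any training time is spent.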

## Run
For 5-shot, 5-class training, run
```bash
bash scripts/train_5s_5c.sh
```
Hyper-parameters follow the [author's repo](https://github.com/twitter/meta-learning-lstm).

For 5-shot, 5-class evaluation, run *(remember to change `--resume` and `--seed` arguments)*
```bash
bash scripts/eval_5s_5c.sh
```
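
The `--resume` path depends on the seed used for training: the logger appends the seed to `--save` (default `logs`), and checkpoints are written to a `ckpts/` sub-folder named after the episode. The sketch below rebuilds that path with the standard library. `ckpt_path` is a hypothetical helper for illustration only; the format strings are copied from `utils.py`.

```python
import os


def ckpt_path(seed, episode, save="logs"):
    """Rebuild the checkpoint path written during training.

    Hypothetical helper (not part of this repo): GOATLogger renames the save
    folder to '<save>-<seed>', and save_ckpt writes
    'ckpts/meta-learner-<episode>.pth.tar' inside it.
    """
    save_root = "{}-{}".format(save, seed)
    return os.path.join(save_root, "ckpts", "meta-learner-{}.pth.tar".format(episode))


# On Linux/macOS this prints: logs-719/ckpts/meta-learner-42000.pth.tar
print(ckpt_path(seed=719, episode=42000))
```

That result is the value `scripts/eval_5s_5c.sh` passes to `--resume` by default, so a run trained with a different seed needs the path updated accordingly.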

## Notes
- Results (This repo is developed following the [pytorch reproducibility guideline](https://pytorch.org/docs/stable/notes/randomness.html)):

|seed|train episodes|val episodes|val acc mean|val acc std|test episodes|test acc mean|test acc std|
|-|-|-|-|-|-|-|-|
|719|41000|100|59.08|9.9|100|56.59|8.4|
|  -|    -|  -|    -|  -|250|57.85|8.6|
|  -|    -|  -|    -|  -|600|57.76|8.6|
| 53|44000|100|58.04|9.1|100|57.85|7.7|
|  -|    -|  -|    -|  -|250|57.83|8.3|
|  -|    -|  -|    -|  -|600|58.14|8.5|

- The results I get from directly running the author's repo can be found [here](https://i.imgur.com/rtagm2c.png). This implementation performs slightly better (~5%), but neither result matches the number reported in the paper (60%) *(Discussion and help are welcome!)*.
- Training with the default settings takes ~2.5 hours on a single Titan Xp while occupying ~2GB GPU memory.
- The implementation keeps two copies of the learner, similar to the author's repo:
  - `learner_w_grad` acts as a regular model; its gradients and loss are the inputs to the meta-learner.
  - `learner_wo_grad` constructs the graph for the meta-learner:
    - All the parameters in `learner_wo_grad` are replaced by `cI`, the output of the meta-learner.
    - The `nn.Parameter`s in this model are cast to `torch.Tensor` to connect the graph to the meta-learner.
- There are several ways to **copy** parameters from the meta-learner to the learner, depending on the scenario:
  - `copy_flat_params`: we only need the parameter values and keep the original `grad_fn`.
  - `transfer_params`: we want the values as well as the `grad_fn` (from `cI` to `learner_wo_grad`).
    - `.data.copy_` vs. `clone()`: the latter retains all the properties of a tensor, including `grad_fn`.
    - To maintain the batch statistics, `load_state_dict` is used (from `learner_w_grad` to `learner_wo_grad`).
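
The difference between the two copy styles can be seen on toy tensors. This is a minimal sketch, not the repo's code: `w`, `cI`, `p` and `q` are stand-ins chosen for illustration, with `cI` playing the role of the meta-learner output.

```python
import torch
import torch.nn as nn

# Stand-in for the meta-learner output: a non-leaf tensor with a grad_fn.
# (w and cI are toy tensors, not the repo's real meta-learner state.)
w = torch.ones(3, requires_grad=True)
cI = w * 2.0

# (a) .data.copy_: copies the values only; p stays an ordinary leaf nn.Parameter.
p = nn.Parameter(torch.zeros(3))
p.data.copy_(cI)

# (b) clone(): copies the values and keeps the link back to w.
q = cI.clone()

# Backpropagating through q reaches w; nothing flows back through p.
q.sum().backward()

print(p.grad_fn is None)      # True
print(q.grad_fn is not None)  # True
print(w.grad)                 # tensor([2., 2., 2.])
```

This is why `transfer_params` uses `clone()`: the loss computed by `learner_wo_grad` has to backpropagate into the meta-learner, whereas `copy_flat_params` only needs the current values.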

## References
- [CloserLookFewShot](https://github.com/wyharveychen/CloserLookFewShot) (Data loader)
- [pytorch-meta-optimizer](https://github.com/ikostrikov/pytorch-meta-optimizer) (Casting `nn.Parameters` to `torch.Tensor` was inspired by this repo)
- [meta-learning-lstm](https://github.com/twitter/meta-learning-lstm) (Author's repo in Lua Torch)



================================================
FILE: dataloader.py
================================================
from __future__ import division, print_function, absolute_import

import os
import re
import pdb
import glob
import pickle

import torch
import torch.utils.data as data
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import PIL.Image as PILI
import numpy as np

from tqdm import tqdm


class EpisodeDataset(data.Dataset):

    def __init__(self, root, phase='train', n_shot=5, n_eval=15, transform=None):
        """Args:
            root (str): path to data
            phase (str): train, val or test
            n_shot (int): how many examples per class for training (k/n_support)
            n_eval (int): how many examples per class for evaluation
                - n_shot + n_eval = batch_size for data.DataLoader of ClassDataset
            transform (torchvision.transforms): data augmentation
        """
        root = os.path.join(root, phase)
        self.labels = sorted(os.listdir(root))
        images = [glob.glob(os.path.join(root, label, '*')) for label in self.labels]

        self.episode_loader = [data.DataLoader(
            ClassDataset(images=images[idx], label=idx, transform=transform),
            batch_size=n_shot+n_eval, shuffle=True, num_workers=0) for idx, _ in enumerate(self.labels)]

    def __getitem__(self, idx):
        return next(iter(self.episode_loader[idx]))

    def __len__(self):
        return len(self.labels)


class ClassDataset(data.Dataset):

    def __init__(self, images, label, transform=None):
        """Args:
            images (list of str): each item is a path to an image of the same label
            label (int): the label of all the images
        """
        self.images = images
        self.label = label
        self.transform = transform

    def __getitem__(self, idx):
        image = PILI.open(self.images[idx]).convert('RGB')
        if self.transform is not None:
            image = self.transform(image)

        return image, self.label

    def __len__(self):
        return len(self.images)


class EpisodicSampler(data.Sampler):

    def __init__(self, total_classes, n_class, n_episode):
        self.total_classes = total_classes
        self.n_class = n_class
        self.n_episode = n_episode

    def __iter__(self):
        for i in range(self.n_episode):
            yield torch.randperm(self.total_classes)[:self.n_class]

    def __len__(self):
        return self.n_episode


def prepare_data(args):

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    
    train_set = EpisodeDataset(args.data_root, 'train', args.n_shot, args.n_eval,
        transform=transforms.Compose([
            transforms.RandomResizedCrop(args.image_size),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(
                brightness=0.4,
                contrast=0.4,
                saturation=0.4,
                hue=0.2),
            transforms.ToTensor(),
            normalize]))

    val_set = EpisodeDataset(args.data_root, 'val', args.n_shot, args.n_eval,
        transform=transforms.Compose([
            transforms.Resize(args.image_size * 8 // 7),
            transforms.CenterCrop(args.image_size),
            transforms.ToTensor(),
            normalize]))

    test_set = EpisodeDataset(args.data_root, 'test', args.n_shot, args.n_eval,
        transform=transforms.Compose([
            transforms.Resize(args.image_size * 8 // 7),
            transforms.CenterCrop(args.image_size),
            transforms.ToTensor(),
            normalize]))

    train_loader = data.DataLoader(train_set, num_workers=args.n_workers, pin_memory=args.pin_mem,
        batch_sampler=EpisodicSampler(len(train_set), args.n_class, args.episode))

    val_loader = data.DataLoader(val_set, num_workers=2, pin_memory=False,
        batch_sampler=EpisodicSampler(len(val_set), args.n_class, args.episode_val))

    test_loader = data.DataLoader(test_set, num_workers=2, pin_memory=False,
        batch_sampler=EpisodicSampler(len(test_set), args.n_class, args.episode_val))

    return train_loader, val_loader, test_loader


================================================
FILE: learner.py
================================================
from __future__ import division, print_function, absolute_import

import pdb
import copy
from collections import OrderedDict

import torch
import torch.nn as nn
import numpy as np

class Learner(nn.Module):

    def __init__(self, image_size, bn_eps, bn_momentum, n_classes):
        super(Learner, self).__init__()
        self.model = nn.ModuleDict({'features': nn.Sequential(OrderedDict([
            ('conv1', nn.Conv2d(3, 32, 3, padding=1)),
            ('norm1', nn.BatchNorm2d(32, bn_eps, bn_momentum)),
            ('relu1', nn.ReLU(inplace=False)),
            ('pool1', nn.MaxPool2d(2)),

            ('conv2', nn.Conv2d(32, 32, 3, padding=1)),
            ('norm2', nn.BatchNorm2d(32, bn_eps, bn_momentum)),
            ('relu2', nn.ReLU(inplace=False)),
            ('pool2', nn.MaxPool2d(2)),

            ('conv3', nn.Conv2d(32, 32, 3, padding=1)),
            ('norm3', nn.BatchNorm2d(32, bn_eps, bn_momentum)),
            ('relu3', nn.ReLU(inplace=False)),
            ('pool3', nn.MaxPool2d(2)),

            ('conv4', nn.Conv2d(32, 32, 3, padding=1)),
            ('norm4', nn.BatchNorm2d(32, bn_eps, bn_momentum)),
            ('relu4', nn.ReLU(inplace=False)),
            ('pool4', nn.MaxPool2d(2))]))
        })

        clr_in = image_size // 2**4
        self.model.update({'cls': nn.Linear(32 * clr_in * clr_in, n_classes)})
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, x):
        x = self.model.features(x)
        x = torch.reshape(x, [x.size(0), -1])
        outputs = self.model.cls(x)
        return outputs

    def get_flat_params(self):
        return torch.cat([p.view(-1) for p in self.model.parameters()], 0)

    def copy_flat_params(self, cI):
        idx = 0
        for p in self.model.parameters():
            plen = p.view(-1).size(0)
            p.data.copy_(cI[idx: idx+plen].view_as(p))
            idx += plen

    def transfer_params(self, learner_w_grad, cI):
        # Use load_state_dict only to copy the running mean/var in batchnorm, the values of the parameters
        #  are going to be replaced by cI
        self.load_state_dict(learner_w_grad.state_dict())
        #  replace nn.Parameters with tensors from cI (NOT nn.Parameters anymore).
        idx = 0
        for m in self.model.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.Linear):
                wlen = m._parameters['weight'].view(-1).size(0)
                m._parameters['weight'] = cI[idx: idx+wlen].view_as(m._parameters['weight']).clone()
                idx += wlen
                if m._parameters['bias'] is not None:
                    blen = m._parameters['bias'].view(-1).size(0)
                    m._parameters['bias'] = cI[idx: idx+blen].view_as(m._parameters['bias']).clone()
                    idx += blen

    def reset_batch_stats(self):
        for m in self.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.reset_running_stats()



================================================
FILE: main.py
================================================
from __future__ import division, print_function, absolute_import

import os
import pdb
import copy
import random
import argparse

import torch
import torch.nn as nn
import numpy as np
from tqdm import tqdm

from learner import Learner
from metalearner import MetaLearner
from dataloader import prepare_data
from utils import *


FLAGS = argparse.ArgumentParser()
FLAGS.add_argument('--mode', choices=['train', 'test'])
# Hyper-parameters
FLAGS.add_argument('--n-shot', type=int,
                   help="How many examples per class for training (k, n_support)")
FLAGS.add_argument('--n-eval', type=int,
                   help="How many examples per class for evaluation (n_query)")
FLAGS.add_argument('--n-class', type=int,
                   help="How many classes (N, n_way)")
FLAGS.add_argument('--input-size', type=int,
                   help="Input size for the first LSTM")
FLAGS.add_argument('--hidden-size', type=int,
                   help="Hidden size for the first LSTM")
FLAGS.add_argument('--lr', type=float,
                   help="Learning rate")
FLAGS.add_argument('--episode', type=int,
                   help="Episodes to train")
FLAGS.add_argument('--episode-val', type=int,
                   help="Episodes to eval")
FLAGS.add_argument('--epoch', type=int,
                   help="Epoch to train for an episode")
FLAGS.add_argument('--batch-size', type=int,
                   help="Batch size when training an episode")
FLAGS.add_argument('--image-size', type=int,
                   help="Resize image to this size")
FLAGS.add_argument('--grad-clip', type=float,
                   help="Clip gradients larger than this number")
FLAGS.add_argument('--bn-momentum', type=float,
                   help="Momentum parameter in BatchNorm2d")
FLAGS.add_argument('--bn-eps', type=float,
                   help="Eps parameter in BatchNorm2d")

# Paths
FLAGS.add_argument('--data', choices=['miniimagenet'],
                   help="Name of dataset")
FLAGS.add_argument('--data-root', type=str,
                   help="Location of data")
FLAGS.add_argument('--resume', type=str,
                   help="Location to pth.tar")
FLAGS.add_argument('--save', type=str, default='logs',
                   help="Location to logs and ckpts")
# Others
FLAGS.add_argument('--cpu', action='store_true',
                   help="Set this to use CPU, default use CUDA")
FLAGS.add_argument('--n-workers', type=int, default=4,
                   help="How many processes for preprocessing")
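# NOTE: type=bool would turn any non-empty string (including "False") into True,
#  so the flag value is parsed explicitly.
FLAGS.add_argument('--pin-mem', type=lambda s: str(s).lower() in ('true', '1', 'yes'), default=False,
                   help="DataLoader pin_memory (True/False)")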
FLAGS.add_argument('--log-freq', type=int, default=100,
                   help="Logging frequency")
FLAGS.add_argument('--val-freq', type=int, default=1000,
                   help="Validation frequency")
FLAGS.add_argument('--seed', type=int,
                   help="Random seed")


def meta_test(eps, eval_loader, learner_w_grad, learner_wo_grad, metalearner, args, logger):
    for subeps, (episode_x, episode_y) in enumerate(tqdm(eval_loader, ascii=True)):
        train_input = episode_x[:, :args.n_shot].reshape(-1, *episode_x.shape[-3:]).to(args.dev) # [n_class * n_shot, :]
        train_target = torch.LongTensor(np.repeat(range(args.n_class), args.n_shot)).to(args.dev) # [n_class * n_shot]
        test_input = episode_x[:, args.n_shot:].reshape(-1, *episode_x.shape[-3:]).to(args.dev) # [n_class * n_eval, :]
        test_target = torch.LongTensor(np.repeat(range(args.n_class), args.n_eval)).to(args.dev) # [n_class * n_eval]

        # Train learner with metalearner
        learner_w_grad.reset_batch_stats()
        learner_wo_grad.reset_batch_stats()
        learner_w_grad.train()
        learner_wo_grad.eval()
        cI = train_learner(learner_w_grad, metalearner, train_input, train_target, args)

        learner_wo_grad.transfer_params(learner_w_grad, cI)
        output = learner_wo_grad(test_input)
        loss = learner_wo_grad.criterion(output, test_target)
        acc = accuracy(output, test_target)
 
        logger.batch_info(loss=loss.item(), acc=acc, phase='eval')

    return logger.batch_info(eps=eps, totaleps=args.episode_val, phase='evaldone')


def train_learner(learner_w_grad, metalearner, train_input, train_target, args):
    cI = metalearner.metalstm.cI.data
    hs = [None]
    for _ in range(args.epoch):
        for i in range(0, len(train_input), args.batch_size):
            x = train_input[i:i+args.batch_size]
            y = train_target[i:i+args.batch_size]

            # get the loss/grad
            learner_w_grad.copy_flat_params(cI)
            output = learner_w_grad(x)
            loss = learner_w_grad.criterion(output, y)
            acc = accuracy(output, y)
            learner_w_grad.zero_grad()
            loss.backward()
            grad = torch.cat([p.grad.data.view(-1) / args.batch_size for p in learner_w_grad.parameters()], 0)

            # preprocess grad & loss and metalearner forward
            grad_prep = preprocess_grad_loss(grad)  # [n_learner_params, 2]
            loss_prep = preprocess_grad_loss(loss.data.unsqueeze(0)) # [1, 2]
            metalearner_input = [loss_prep, grad_prep, grad.unsqueeze(1)]
            cI, h = metalearner(metalearner_input, hs[-1])
            hs.append(h)

            #print("training loss: {:8.6f} acc: {:6.3f}, mean grad: {:8.6f}".format(loss, acc, torch.mean(grad)))

    return cI


def main():

    args, unparsed = FLAGS.parse_known_args()
    if len(unparsed) != 0:
        raise NameError("Argument {} not recognized".format(unparsed))

    if args.seed is None:
        args.seed = random.randint(0, 1000)  # randint needs integer bounds; 1e3 is a float
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    if args.cpu:
        args.dev = torch.device('cpu')
    else:
        if not torch.cuda.is_available():
            raise RuntimeError("GPU unavailable.")

        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        args.dev = torch.device('cuda')

    logger = GOATLogger(args)

    # Get data
    train_loader, val_loader, test_loader = prepare_data(args)
    
    # Set up learner, meta-learner
    learner_w_grad = Learner(args.image_size, args.bn_eps, args.bn_momentum, args.n_class).to(args.dev)
    learner_wo_grad = copy.deepcopy(learner_w_grad)
    metalearner = MetaLearner(args.input_size, args.hidden_size, learner_w_grad.get_flat_params().size(0)).to(args.dev)
    metalearner.metalstm.init_cI(learner_w_grad.get_flat_params())

    # Set up loss, optimizer, learning rate scheduler
    optim = torch.optim.Adam(metalearner.parameters(), args.lr)

    last_eps = 0  # used by test mode; stays 0 when no checkpoint is resumed
    if args.resume:
        logger.loginfo("Initialized from: {}".format(args.resume))
        last_eps, metalearner, optim = resume_ckpt(metalearner, optim, args.resume, args.dev)

    if args.mode == 'test':
        _ = meta_test(last_eps, test_loader, learner_w_grad, learner_wo_grad, metalearner, args, logger)
        return

    best_acc = 0.0
    logger.loginfo("Start training")
    # Meta-training
    for eps, (episode_x, episode_y) in enumerate(train_loader):
        # episode_x.shape = [n_class, n_shot + n_eval, c, h, w]
        # episode_y.shape = [n_class, n_shot + n_eval] --> NEVER USED
        train_input = episode_x[:, :args.n_shot].reshape(-1, *episode_x.shape[-3:]).to(args.dev) # [n_class * n_shot, :]
        train_target = torch.LongTensor(np.repeat(range(args.n_class), args.n_shot)).to(args.dev) # [n_class * n_shot]
        test_input = episode_x[:, args.n_shot:].reshape(-1, *episode_x.shape[-3:]).to(args.dev) # [n_class * n_eval, :]
        test_target = torch.LongTensor(np.repeat(range(args.n_class), args.n_eval)).to(args.dev) # [n_class * n_eval]

        # Train learner with metalearner
        learner_w_grad.reset_batch_stats()
        learner_wo_grad.reset_batch_stats()
        learner_w_grad.train()
        learner_wo_grad.train()
        cI = train_learner(learner_w_grad, metalearner, train_input, train_target, args)

        # Train meta-learner with validation loss
        learner_wo_grad.transfer_params(learner_w_grad, cI)
        output = learner_wo_grad(test_input)
        loss = learner_wo_grad.criterion(output, test_target)
        acc = accuracy(output, test_target)
        
        optim.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(metalearner.parameters(), args.grad_clip)
        optim.step()

        logger.batch_info(eps=eps, totaleps=args.episode, loss=loss.item(), acc=acc, phase='train')

        # Meta-validation
        if eps % args.val_freq == 0 and eps != 0:
            save_ckpt(eps, metalearner, optim, args.save)
            acc = meta_test(eps, val_loader, learner_w_grad, learner_wo_grad, metalearner, args, logger)
            if acc > best_acc:
                best_acc = acc
                logger.loginfo("* Best accuracy so far *\n")

    logger.loginfo("Done")


if __name__ == '__main__':
    main()


================================================
FILE: metalearner.py
================================================
from __future__ import division, print_function, absolute_import

import pdb
import math
import torch
import torch.nn as nn


class MetaLSTMCell(nn.Module):
    r"""C_t = f_t * C_{t-1} + i_t * \tilde{C_t}"""
    def __init__(self, input_size, hidden_size, n_learner_params):
        """Args:
            input_size (int): cell input size, default = 20
            hidden_size (int): should be 1
            n_learner_params (int): number of learner's parameters
        """
        super(MetaLSTMCell, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.n_learner_params = n_learner_params
        self.WF = nn.Parameter(torch.Tensor(input_size + 2, hidden_size))
        self.WI = nn.Parameter(torch.Tensor(input_size + 2, hidden_size))
        self.cI = nn.Parameter(torch.Tensor(n_learner_params, 1))
        self.bI = nn.Parameter(torch.Tensor(1, hidden_size))
        self.bF = nn.Parameter(torch.Tensor(1, hidden_size))

        self.reset_parameters()

    def reset_parameters(self):
        for weight in self.parameters():
            nn.init.uniform_(weight, -0.01, 0.01)

        # want initial forget value to be high and input value to be low so that 
        #  model starts with gradient descent
        nn.init.uniform_(self.bF, 4, 6)
        nn.init.uniform_(self.bI, -5, -4)

    def init_cI(self, flat_params):
        self.cI.data.copy_(flat_params.unsqueeze(1))

    def forward(self, inputs, hx=None):
        """Args:
            inputs = [x_all, grad]:
                x_all (torch.Tensor of size [n_learner_params, input_size]): outputs from previous LSTM
                grad (torch.Tensor of size [n_learner_params, 1]): gradients from learner
            hx = [f_prev, i_prev, c_prev]:
                f (torch.Tensor of size [n_learner_params, 1]): forget gate
                i (torch.Tensor of size [n_learner_params, 1]): input gate
                c (torch.Tensor of size [n_learner_params, 1]): flattened learner parameters
        """
        x_all, grad = inputs
        batch, _ = x_all.size()

        if hx is None:
            f_prev = torch.zeros((batch, self.hidden_size)).to(self.WF.device)
            i_prev = torch.zeros((batch, self.hidden_size)).to(self.WI.device)
            c_prev = self.cI
            hx = [f_prev, i_prev, c_prev]

        f_prev, i_prev, c_prev = hx
        
        # pre-activation of f_t = sigmoid(W_f * [grad_t, loss_t, theta_{t-1}, f_{t-1}] + b_f)
        f_next = torch.mm(torch.cat((x_all, c_prev, f_prev), 1), self.WF) + self.bF.expand_as(f_prev)
        # pre-activation of i_t = sigmoid(W_i * [grad_t, loss_t, theta_{t-1}, i_{t-1}] + b_i)
        i_next = torch.mm(torch.cat((x_all, c_prev, i_prev), 1), self.WI) + self.bI.expand_as(i_prev)
        # next cell/params: theta_t = f_t * theta_{t-1} - i_t * grad_t (the sigmoids are applied here)
        c_next = torch.sigmoid(f_next).mul(c_prev) - torch.sigmoid(i_next).mul(grad)

        return c_next, [f_next, i_next, c_next]

    def extra_repr(self):
        s = '{input_size}, {hidden_size}, {n_learner_params}'
        return s.format(**self.__dict__)


class MetaLearner(nn.Module):

    def __init__(self, input_size, hidden_size, n_learner_params):
        """Args:
            input_size (int): for the first LSTM layer, default = 4
            hidden_size (int): for the first LSTM layer, default = 20
            n_learner_params (int): number of learner's parameters
        """
        super(MetaLearner, self).__init__()
        self.lstm = nn.LSTMCell(input_size=input_size, hidden_size=hidden_size)
        self.metalstm = MetaLSTMCell(input_size=hidden_size, hidden_size=1, n_learner_params=n_learner_params)

    def forward(self, inputs, hs=None):
        """Args:
            inputs = [loss, grad_prep, grad]
                loss (torch.Tensor of size [1, 2])
                grad_prep (torch.Tensor of size [n_learner_params, 2])
                grad (torch.Tensor of size [n_learner_params, 1])

            hs = [(lstm_hn, lstm_cn), [metalstm_fn, metalstm_in, metalstm_cn]]
        """
        loss, grad_prep, grad = inputs
        loss = loss.expand_as(grad_prep)
        inputs = torch.cat((loss, grad_prep), 1)   # [n_learner_params, 4]

        if hs is None:
            hs = [None, None]

        lstmhx, lstmcx = self.lstm(inputs, hs[0])
        flat_learner_unsqzd, metalstm_hs = self.metalstm([lstmhx, grad], hs[1])

        return flat_learner_unsqzd.squeeze(), [(lstmhx, lstmcx), metalstm_hs]



================================================
FILE: scripts/eval_5s_5c.sh
================================================
#!/bin/bash
#
# For 5-shot, 5-class evaluation, hyper-parameters follow github.com/twitter/meta-learning-lstm

python main.py --mode test \
               --resume logs-719/ckpts/meta-learner-42000.pth.tar \
               --n-shot 5 \
               --n-eval 15 \
               --n-class 5 \
               --input-size 4 \
               --hidden-size 20 \
               --lr 1e-3 \
               --episode 50000 \
               --episode-val 100 \
               --epoch 8 \
               --batch-size 25 \
               --image-size 84 \
               --grad-clip 0.25 \
               --bn-momentum 0.95 \
               --bn-eps 1e-3 \
               --data miniimagenet \
               --data-root data/miniImagenet/ \
               --pin-mem True \
               --log-freq 100


================================================
FILE: scripts/train_5s_5c.sh
================================================
#!/bin/bash
#
# For 5-shot, 5-class training
# Hyper-parameters follow https://github.com/twitter/meta-learning-lstm

python main.py --mode train \
               --n-shot 5 \
               --n-eval 15 \
               --n-class 5 \
               --input-size 4 \
               --hidden-size 20 \
               --lr 1e-3 \
               --episode 50000 \
               --episode-val 100 \
               --epoch 8 \
               --batch-size 25 \
               --image-size 84 \
               --grad-clip 0.25 \
               --bn-momentum 0.95 \
               --bn-eps 1e-3 \
               --data miniimagenet \
               --data-root data/miniImagenet/ \
               --pin-mem True \
               --log-freq 50 \
               --val-freq 1000


================================================
FILE: utils.py
================================================
from __future__ import division, print_function, absolute_import

import os
import pdb
import logging

import torch
import numpy as np


class GOATLogger:

    def __init__(self, args):
        args.save = args.save + '-{}'.format(args.seed)

        self.mode = args.mode
        self.save_root = args.save
        self.log_freq = args.log_freq

        if self.mode == 'train':
            if not os.path.exists(self.save_root):
                os.mkdir(self.save_root)
            filename = os.path.join(self.save_root, 'console.log')
            logging.basicConfig(level=logging.DEBUG,
                format='%(asctime)s.%(msecs)03d - %(message)s',
                datefmt='%b-%d %H:%M:%S',
                filename=filename,
                filemode='w')
            console = logging.StreamHandler()
            console.setLevel(logging.INFO)
            console.setFormatter(logging.Formatter('%(message)s'))
            logging.getLogger('').addHandler(console)

            logging.info("Logger created at {}".format(filename))
        else:
            logging.basicConfig(level=logging.INFO,
                format='%(asctime)s.%(msecs)03d - %(message)s',
                datefmt='%b-%d %H:%M:%S')

        logging.info("Random Seed: {}".format(args.seed))
        self.reset_stats()

    def reset_stats(self):
        if self.mode == 'train':
            self.stats = {'train': {'loss': [], 'acc': []},
                          'eval': {'loss': [], 'acc': []}}
        else:
            self.stats = {'eval': {'loss': [], 'acc': []}}

    def batch_info(self, **kwargs):
        if kwargs['phase'] == 'train':
            self.stats['train']['loss'].append(kwargs['loss'])
            self.stats['train']['acc'].append(kwargs['acc'])

            if kwargs['eps'] % self.log_freq == 0 and kwargs['eps'] != 0:
                loss_mean = np.mean(self.stats['train']['loss'])
                acc_mean = np.mean(self.stats['train']['acc'])
                #self.draw_stats()
                self.loginfo("[{:5d}/{:5d}] loss: {:6.4f} ({:6.4f}), acc: {:6.3f}% ({:6.3f}%)".format(\
                    kwargs['eps'], kwargs['totaleps'], kwargs['loss'], loss_mean, kwargs['acc'], acc_mean))

        elif kwargs['phase'] == 'eval':
            self.stats['eval']['loss'].append(kwargs['loss'])
            self.stats['eval']['acc'].append(kwargs['acc'])

        elif kwargs['phase'] == 'evaldone':
            loss_mean = np.mean(self.stats['eval']['loss'])
            loss_std = np.std(self.stats['eval']['loss'])
            acc_mean = np.mean(self.stats['eval']['acc'])
            acc_std = np.std(self.stats['eval']['acc'])
            self.loginfo("[{:5d}] Eval ({:3d} episode) - loss: {:6.4f} +- {:6.4f}, acc: {:6.3f} +- {:5.3f}%".format(\
                kwargs['eps'], kwargs['totaleps'], loss_mean, loss_std, acc_mean, acc_std))

            self.reset_stats()
            return acc_mean

        else:
            raise ValueError("phase {} not supported".format(kwargs['phase']))

    def logdebug(self, strout):
        logging.debug(strout)
    def loginfo(self, strout):
        logging.info(strout)


def accuracy(output, target, topk=(1,)):
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
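            # reshape instead of view: `correct` comes from a transposed tensor and may be non-contiguous for k > 1
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)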
            res.append(correct_k.mul_(100.0 / batch_size))
        return res[0].item() if len(res) == 1 else [r.item() for r in res]


def save_ckpt(episode, metalearner, optim, save):
    if not os.path.exists(os.path.join(save, 'ckpts')):
        os.mkdir(os.path.join(save, 'ckpts'))

    torch.save({
        'episode': episode,
        'metalearner': metalearner.state_dict(),
        'optim': optim.state_dict()
    }, os.path.join(save, 'ckpts', 'meta-learner-{}.pth.tar'.format(episode)))


def resume_ckpt(metalearner, optim, resume, device):
    ckpt = torch.load(resume, map_location=device)
    last_episode = ckpt['episode']
    metalearner.load_state_dict(ckpt['metalearner'])
    optim.load_state_dict(ckpt['optim'])
    return last_episode, metalearner, optim


def preprocess_grad_loss(x):
    p = 10
    indicator = (x.abs() >= np.exp(-p)).to(torch.float32)

    # preproc1
    x_proc1 = indicator * torch.log(x.abs() + 1e-8) / p + (1 - indicator) * -1
    # preproc2
    x_proc2 = indicator * torch.sign(x) + (1 - indicator) * np.exp(p) * x
    return torch.stack((x_proc1, x_proc2), 1)

    method forward (line 41) | def forward(self, inputs, hx=None):
    method extra_repr (line 71) | def extra_repr(self):
  class MetaLearner (line 76) | class MetaLearner(nn.Module):
    method __init__ (line 78) | def __init__(self, input_size, hidden_size, n_learner_params):
    method forward (line 88) | def forward(self, inputs, hs=None):

FILE: utils.py
  class GOATLogger (line 11) | class GOATLogger:
    method __init__ (line 13) | def __init__(self, args):
    method reset_stats (line 43) | def reset_stats(self):
    method batch_info (line 50) | def batch_info(self, **kwargs):
    method logdebug (line 80) | def logdebug(self, strout):
    method loginfo (line 82) | def loginfo(self, strout):
  function accuracy (line 86) | def accuracy(output, target, topk=(1,)):
  function save_ckpt (line 102) | def save_ckpt(episode, metalearner, optim, save):
  function resume_ckpt (line 113) | def resume_ckpt(metalearner, optim, resume, device):
  function preprocess_grad_loss (line 121) | def preprocess_grad_loss(x):
