Repository: markdtw/meta-learning-lstm-pytorch
Branch: master
Commit: fcc68baad71a
Files: 8
Total size: 29.6 KB
Directory structure:
gitextract_ouj1iose/
├── README.md
├── dataloader.py
├── learner.py
├── main.py
├── metalearner.py
├── scripts/
│ ├── eval_5s_5c.sh
│ └── train_5s_5c.sh
└── utils.py
================================================
FILE CONTENTS
================================================
================================================
FILE: README.md
================================================
# Optimization as a Model for Few-shot Learning
Pytorch implementation of [Optimization as a Model for Few-shot Learning](https://openreview.net/forum?id=rJY0-Kcll) in ICLR 2017 (Oral)

## Prerequisites
- python 3+
- pytorch 0.4+ (developed on 1.0.1 with cuda 9.0)
- [pillow](https://pillow.readthedocs.io/en/stable/installation.html)
- [tqdm](https://tqdm.github.io/) (a nice progress bar)
## Data
- Mini-Imagenet as described [here](https://github.com/twitter/meta-learning-lstm/tree/master/data/miniImagenet)
- You can download it from [here](https://drive.google.com/file/d/1rV3aj_hgfNTfCakffpPm7Vhpr1in87CR/view?usp=sharing) (~2.7GB, google drive link)
## Preparation
- Make sure Mini-Imagenet is split properly. For example:
```
- data/
- miniImagenet/
- train/
- n01532829/
- n0153282900000005.jpg
- ...
- n01558993/
- ...
- val/
- n01855672/
- ...
- test/
- ...
- main.py
- ...
```
- The directory will already be laid out this way if you download and extract Mini-Imagenet from the link above
- Check `scripts/train_5s_5c.sh` and make sure `--data-root` points to your data
## Run
For 5-shot, 5-class training, run
```bash
bash scripts/train_5s_5c.sh
```
Hyper-parameters follow the [author's repo](https://github.com/twitter/meta-learning-lstm).
For 5-shot, 5-class evaluation, run *(remember to change `--resume` and `--seed` arguments)*
```bash
bash scripts/eval_5s_5c.sh
```
## Notes
- Results (This repo is developed following the [pytorch reproducibility guideline](https://pytorch.org/docs/stable/notes/randomness.html)):

|seed|train episodes|val episodes|val acc mean|val acc std|test episodes|test acc mean|test acc std|
|-|-|-|-|-|-|-|-|
|719|41000|100|59.08|9.9|100|56.59|8.4|
| -| -| -| -| -|250|57.85|8.6|
| -| -| -| -| -|600|57.76|8.6|
| 53|44000|100|58.04|9.1|100|57.85|7.7|
| -| -| -| -| -|250|57.83|8.3|
| -| -| -| -| -|600|58.14|8.5|
- The results I get from directly running the author's repo can be found [here](https://i.imgur.com/rtagm2c.png). This repo performs slightly better (~5%), but neither result matches the number reported in the paper (60%) *(discussion and help are welcome!)*.
- Training with the default settings takes ~2.5 hours on a single Titan Xp while occupying ~2GB GPU memory.
- The implementation maintains two learners, similar to the author's repo:
    - `learner_w_grad` functions as a regular model; its gradients and loss are the inputs to the meta learner.
    - `learner_wo_grad` constructs the graph for the meta learner:
        - All the parameters in `learner_wo_grad` are replaced by `cI`, the output of the meta learner.
        - `nn.Parameter`s in this model are cast to plain `torch.Tensor`s to connect the graph to the meta learner.
- There are several ways to **copy** parameters from the meta learner to a learner, depending on the scenario:
    - `copy_flat_params`: we only need the parameter values and keep the original `grad_fn`.
    - `transfer_params`: we want the values as well as the `grad_fn` (from `cI` to `learner_wo_grad`).
        - `.data.copy_` vs. `clone()` -> the latter retains all the properties of a tensor, including `grad_fn`.
        - To maintain the batch statistics, `load_state_dict` is used (from `learner_w_grad` to `learner_wo_grad`).
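
The `.data.copy_` vs. `clone()` distinction above can be verified in a few lines (a minimal sketch, not code from this repo):

```python
import torch

c = torch.randn(3, requires_grad=True)
theta = c * 2.0                 # a tensor connected to c's autograd graph

a = torch.zeros(3)
a.data.copy_(theta)             # copies values only: `a` stays detached
b = theta.clone()               # copies values AND keeps the graph

assert a.grad_fn is None        # no path back to c
assert b.grad_fn is not None    # gradients still flow from b back to c
```

This is why `transfer_params` uses `clone()`: the validation loss computed through `learner_wo_grad` must backpropagate into the meta learner's `cI`.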
## References
- [CloserLookFewShot](https://github.com/wyharveychen/CloserLookFewShot) (Data loader)
- [pytorch-meta-optimizer](https://github.com/ikostrikov/pytorch-meta-optimizer) (Casting `nn.Parameters` to `torch.Tensor` inspired from here)
- [meta-learning-lstm](https://github.com/twitter/meta-learning-lstm) (Author's repo in Lua Torch)
================================================
FILE: dataloader.py
================================================
from __future__ import division, print_function, absolute_import
import os
import re
import pdb
import glob
import pickle
import torch
import torch.utils.data as data
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import PIL.Image as PILI
import numpy as np
from tqdm import tqdm
class EpisodeDataset(data.Dataset):
def __init__(self, root, phase='train', n_shot=5, n_eval=15, transform=None):
"""Args:
root (str): path to data
phase (str): train, val or test
n_shot (int): how many examples per class for training (k/n_support)
n_eval (int): how many examples per class for evaluation
- n_shot + n_eval = batch_size for data.DataLoader of ClassDataset
transform (torchvision.transforms): data augmentation
"""
root = os.path.join(root, phase)
self.labels = sorted(os.listdir(root))
images = [glob.glob(os.path.join(root, label, '*')) for label in self.labels]
self.episode_loader = [data.DataLoader(
ClassDataset(images=images[idx], label=idx, transform=transform),
batch_size=n_shot+n_eval, shuffle=True, num_workers=0) for idx, _ in enumerate(self.labels)]
def __getitem__(self, idx):
return next(iter(self.episode_loader[idx]))
def __len__(self):
return len(self.labels)
class ClassDataset(data.Dataset):
def __init__(self, images, label, transform=None):
"""Args:
images (list of str): each item is a path to an image of the same label
label (int): the label of all the images
"""
self.images = images
self.label = label
self.transform = transform
def __getitem__(self, idx):
image = PILI.open(self.images[idx]).convert('RGB')
if self.transform is not None:
image = self.transform(image)
return image, self.label
def __len__(self):
return len(self.images)
class EpisodicSampler(data.Sampler):
def __init__(self, total_classes, n_class, n_episode):
self.total_classes = total_classes
self.n_class = n_class
self.n_episode = n_episode
def __iter__(self):
for i in range(self.n_episode):
yield torch.randperm(self.total_classes)[:self.n_class]
def __len__(self):
return self.n_episode
def prepare_data(args):
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
train_set = EpisodeDataset(args.data_root, 'train', args.n_shot, args.n_eval,
transform=transforms.Compose([
transforms.RandomResizedCrop(args.image_size),
transforms.RandomHorizontalFlip(),
transforms.ColorJitter(
brightness=0.4,
contrast=0.4,
saturation=0.4,
hue=0.2),
transforms.ToTensor(),
normalize]))
val_set = EpisodeDataset(args.data_root, 'val', args.n_shot, args.n_eval,
transform=transforms.Compose([
transforms.Resize(args.image_size * 8 // 7),
transforms.CenterCrop(args.image_size),
transforms.ToTensor(),
normalize]))
test_set = EpisodeDataset(args.data_root, 'test', args.n_shot, args.n_eval,
transform=transforms.Compose([
transforms.Resize(args.image_size * 8 // 7),
transforms.CenterCrop(args.image_size),
transforms.ToTensor(),
normalize]))
train_loader = data.DataLoader(train_set, num_workers=args.n_workers, pin_memory=args.pin_mem,
batch_sampler=EpisodicSampler(len(train_set), args.n_class, args.episode))
val_loader = data.DataLoader(val_set, num_workers=2, pin_memory=False,
batch_sampler=EpisodicSampler(len(val_set), args.n_class, args.episode_val))
test_loader = data.DataLoader(test_set, num_workers=2, pin_memory=False,
batch_sampler=EpisodicSampler(len(test_set), args.n_class, args.episode_val))
return train_loader, val_loader, test_loader
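
For reference, `EpisodicSampler` above simply draws `n_class` distinct class indices per episode. A plain-Python sketch of the same behavior (`episodic_sampler` is a hypothetical helper for illustration, not part of the repo):

```python
import random

def episodic_sampler(total_classes, n_class, n_episode, seed=0):
    """Yield n_episode lists of n_class distinct class indices,
    mirroring EpisodicSampler's torch.randperm-based sampling."""
    rng = random.Random(seed)
    for _ in range(n_episode):
        yield rng.sample(range(total_classes), n_class)

episodes = list(episodic_sampler(total_classes=64, n_class=5, n_episode=3))
# each episode indexes 5 of the 64 training classes, without replacement
```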
================================================
FILE: learner.py
================================================
from __future__ import division, print_function, absolute_import
import pdb
import copy
from collections import OrderedDict
import torch
import torch.nn as nn
import numpy as np
class Learner(nn.Module):
def __init__(self, image_size, bn_eps, bn_momentum, n_classes):
super(Learner, self).__init__()
self.model = nn.ModuleDict({'features': nn.Sequential(OrderedDict([
('conv1', nn.Conv2d(3, 32, 3, padding=1)),
('norm1', nn.BatchNorm2d(32, bn_eps, bn_momentum)),
('relu1', nn.ReLU(inplace=False)),
('pool1', nn.MaxPool2d(2)),
('conv2', nn.Conv2d(32, 32, 3, padding=1)),
('norm2', nn.BatchNorm2d(32, bn_eps, bn_momentum)),
('relu2', nn.ReLU(inplace=False)),
('pool2', nn.MaxPool2d(2)),
('conv3', nn.Conv2d(32, 32, 3, padding=1)),
('norm3', nn.BatchNorm2d(32, bn_eps, bn_momentum)),
('relu3', nn.ReLU(inplace=False)),
('pool3', nn.MaxPool2d(2)),
('conv4', nn.Conv2d(32, 32, 3, padding=1)),
('norm4', nn.BatchNorm2d(32, bn_eps, bn_momentum)),
('relu4', nn.ReLU(inplace=False)),
('pool4', nn.MaxPool2d(2))]))
})
clr_in = image_size // 2**4
self.model.update({'cls': nn.Linear(32 * clr_in * clr_in, n_classes)})
self.criterion = nn.CrossEntropyLoss()
def forward(self, x):
x = self.model.features(x)
x = torch.reshape(x, [x.size(0), -1])
outputs = self.model.cls(x)
return outputs
def get_flat_params(self):
return torch.cat([p.view(-1) for p in self.model.parameters()], 0)
def copy_flat_params(self, cI):
idx = 0
for p in self.model.parameters():
plen = p.view(-1).size(0)
p.data.copy_(cI[idx: idx+plen].view_as(p))
idx += plen
def transfer_params(self, learner_w_grad, cI):
# Use load_state_dict only to copy the running mean/var in batchnorm, the values of the parameters
# are going to be replaced by cI
self.load_state_dict(learner_w_grad.state_dict())
# replace nn.Parameters with tensors from cI (NOT nn.Parameters anymore).
idx = 0
for m in self.model.modules():
if isinstance(m, nn.Conv2d) or isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.Linear):
wlen = m._parameters['weight'].view(-1).size(0)
m._parameters['weight'] = cI[idx: idx+wlen].view_as(m._parameters['weight']).clone()
idx += wlen
if m._parameters['bias'] is not None:
blen = m._parameters['bias'].view(-1).size(0)
m._parameters['bias'] = cI[idx: idx+blen].view_as(m._parameters['bias']).clone()
idx += blen
def reset_batch_stats(self):
for m in self.modules():
if isinstance(m, nn.BatchNorm2d):
m.reset_running_stats()
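
The flat-parameter round trip implemented by `get_flat_params`/`copy_flat_params` can be sanity-checked on any small module (a sketch assuming PyTorch is installed; `nn.Linear` stands in for the full `Learner`):

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 2)
flat = torch.cat([p.view(-1) for p in model.parameters()])   # get_flat_params

# write the flat vector back, slice by slice (copy_flat_params)
idx = 0
for p in model.parameters():
    plen = p.view(-1).size(0)
    p.data.copy_(flat[idx: idx + plen].view_as(p))
    idx += plen

assert idx == flat.numel()   # every value consumed exactly once
```

Because iteration order over `model.parameters()` is deterministic, the same index arithmetic recovers each parameter's slice on the way back.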
================================================
FILE: main.py
================================================
from __future__ import division, print_function, absolute_import
import os
import pdb
import copy
import random
import argparse
import torch
import torch.nn as nn
import numpy as np
from tqdm import tqdm
from learner import Learner
from metalearner import MetaLearner
from dataloader import prepare_data
from utils import *
FLAGS = argparse.ArgumentParser()
FLAGS.add_argument('--mode', choices=['train', 'test'])
# Hyper-parameters
FLAGS.add_argument('--n-shot', type=int,
help="How many examples per class for training (k, n_support)")
FLAGS.add_argument('--n-eval', type=int,
help="How many examples per class for evaluation (n_query)")
FLAGS.add_argument('--n-class', type=int,
help="How many classes (N, n_way)")
FLAGS.add_argument('--input-size', type=int,
help="Input size for the first LSTM")
FLAGS.add_argument('--hidden-size', type=int,
help="Hidden size for the first LSTM")
FLAGS.add_argument('--lr', type=float,
help="Learning rate")
FLAGS.add_argument('--episode', type=int,
help="Episodes to train")
FLAGS.add_argument('--episode-val', type=int,
help="Episodes to eval")
FLAGS.add_argument('--epoch', type=int,
help="Epoch to train for an episode")
FLAGS.add_argument('--batch-size', type=int,
help="Batch size when training an episode")
FLAGS.add_argument('--image-size', type=int,
help="Resize image to this size")
FLAGS.add_argument('--grad-clip', type=float,
help="Clip gradients larger than this number")
FLAGS.add_argument('--bn-momentum', type=float,
help="Momentum parameter in BatchNorm2d")
FLAGS.add_argument('--bn-eps', type=float,
help="Eps parameter in BatchNorm2d")
# Paths
FLAGS.add_argument('--data', choices=['miniimagenet'],
help="Name of dataset")
FLAGS.add_argument('--data-root', type=str,
help="Location of data")
FLAGS.add_argument('--resume', type=str,
help="Location to pth.tar")
FLAGS.add_argument('--save', type=str, default='logs',
help="Location to logs and ckpts")
# Others
FLAGS.add_argument('--cpu', action='store_true',
help="Set this to use CPU, default use CUDA")
FLAGS.add_argument('--n-workers', type=int, default=4,
help="How many processes for preprocessing")
FLAGS.add_argument('--pin-mem', type=lambda x: str(x).lower() == 'true', default=False,
                   help="DataLoader pin_memory (note: argparse's type=bool would treat any non-empty string, including 'False', as True)")
FLAGS.add_argument('--log-freq', type=int, default=100,
help="Logging frequency")
FLAGS.add_argument('--val-freq', type=int, default=1000,
help="Validation frequency")
FLAGS.add_argument('--seed', type=int,
help="Random seed")
def meta_test(eps, eval_loader, learner_w_grad, learner_wo_grad, metalearner, args, logger):
for subeps, (episode_x, episode_y) in enumerate(tqdm(eval_loader, ascii=True)):
train_input = episode_x[:, :args.n_shot].reshape(-1, *episode_x.shape[-3:]).to(args.dev) # [n_class * n_shot, :]
train_target = torch.LongTensor(np.repeat(range(args.n_class), args.n_shot)).to(args.dev) # [n_class * n_shot]
test_input = episode_x[:, args.n_shot:].reshape(-1, *episode_x.shape[-3:]).to(args.dev) # [n_class * n_eval, :]
test_target = torch.LongTensor(np.repeat(range(args.n_class), args.n_eval)).to(args.dev) # [n_class * n_eval]
# Train learner with metalearner
learner_w_grad.reset_batch_stats()
learner_wo_grad.reset_batch_stats()
learner_w_grad.train()
learner_wo_grad.eval()
cI = train_learner(learner_w_grad, metalearner, train_input, train_target, args)
learner_wo_grad.transfer_params(learner_w_grad, cI)
output = learner_wo_grad(test_input)
loss = learner_wo_grad.criterion(output, test_target)
acc = accuracy(output, test_target)
logger.batch_info(loss=loss.item(), acc=acc, phase='eval')
return logger.batch_info(eps=eps, totaleps=args.episode_val, phase='evaldone')
def train_learner(learner_w_grad, metalearner, train_input, train_target, args):
cI = metalearner.metalstm.cI.data
hs = [None]
for _ in range(args.epoch):
for i in range(0, len(train_input), args.batch_size):
x = train_input[i:i+args.batch_size]
y = train_target[i:i+args.batch_size]
# get the loss/grad
learner_w_grad.copy_flat_params(cI)
output = learner_w_grad(x)
loss = learner_w_grad.criterion(output, y)
acc = accuracy(output, y)
learner_w_grad.zero_grad()
loss.backward()
grad = torch.cat([p.grad.data.view(-1) / args.batch_size for p in learner_w_grad.parameters()], 0)
# preprocess grad & loss and metalearner forward
grad_prep = preprocess_grad_loss(grad) # [n_learner_params, 2]
loss_prep = preprocess_grad_loss(loss.data.unsqueeze(0)) # [1, 2]
metalearner_input = [loss_prep, grad_prep, grad.unsqueeze(1)]
cI, h = metalearner(metalearner_input, hs[-1])
hs.append(h)
#print("training loss: {:8.6f} acc: {:6.3f}, mean grad: {:8.6f}".format(loss, acc, torch.mean(grad)))
return cI
def main():
args, unparsed = FLAGS.parse_known_args()
if len(unparsed) != 0:
raise NameError("Argument {} not recognized".format(unparsed))
    if args.seed is None:
        args.seed = random.randint(0, 1000)
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
if args.cpu:
args.dev = torch.device('cpu')
else:
if not torch.cuda.is_available():
raise RuntimeError("GPU unavailable.")
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
args.dev = torch.device('cuda')
logger = GOATLogger(args)
# Get data
train_loader, val_loader, test_loader = prepare_data(args)
# Set up learner, meta-learner
learner_w_grad = Learner(args.image_size, args.bn_eps, args.bn_momentum, args.n_class).to(args.dev)
learner_wo_grad = copy.deepcopy(learner_w_grad)
metalearner = MetaLearner(args.input_size, args.hidden_size, learner_w_grad.get_flat_params().size(0)).to(args.dev)
metalearner.metalstm.init_cI(learner_w_grad.get_flat_params())
# Set up loss, optimizer, learning rate scheduler
optim = torch.optim.Adam(metalearner.parameters(), args.lr)
if args.resume:
logger.loginfo("Initialized from: {}".format(args.resume))
last_eps, metalearner, optim = resume_ckpt(metalearner, optim, args.resume, args.dev)
if args.mode == 'test':
_ = meta_test(last_eps, test_loader, learner_w_grad, learner_wo_grad, metalearner, args, logger)
return
best_acc = 0.0
logger.loginfo("Start training")
# Meta-training
for eps, (episode_x, episode_y) in enumerate(train_loader):
# episode_x.shape = [n_class, n_shot + n_eval, c, h, w]
# episode_y.shape = [n_class, n_shot + n_eval] --> NEVER USED
train_input = episode_x[:, :args.n_shot].reshape(-1, *episode_x.shape[-3:]).to(args.dev) # [n_class * n_shot, :]
train_target = torch.LongTensor(np.repeat(range(args.n_class), args.n_shot)).to(args.dev) # [n_class * n_shot]
test_input = episode_x[:, args.n_shot:].reshape(-1, *episode_x.shape[-3:]).to(args.dev) # [n_class * n_eval, :]
test_target = torch.LongTensor(np.repeat(range(args.n_class), args.n_eval)).to(args.dev) # [n_class * n_eval]
# Train learner with metalearner
learner_w_grad.reset_batch_stats()
learner_wo_grad.reset_batch_stats()
learner_w_grad.train()
learner_wo_grad.train()
cI = train_learner(learner_w_grad, metalearner, train_input, train_target, args)
# Train meta-learner with validation loss
learner_wo_grad.transfer_params(learner_w_grad, cI)
output = learner_wo_grad(test_input)
loss = learner_wo_grad.criterion(output, test_target)
acc = accuracy(output, test_target)
optim.zero_grad()
loss.backward()
nn.utils.clip_grad_norm_(metalearner.parameters(), args.grad_clip)
optim.step()
logger.batch_info(eps=eps, totaleps=args.episode, loss=loss.item(), acc=acc, phase='train')
# Meta-validation
if eps % args.val_freq == 0 and eps != 0:
save_ckpt(eps, metalearner, optim, args.save)
acc = meta_test(eps, val_loader, learner_w_grad, learner_wo_grad, metalearner, args, logger)
if acc > best_acc:
best_acc = acc
logger.loginfo("* Best accuracy so far *\n")
logger.loginfo("Done")
if __name__ == '__main__':
main()
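
The episode targets built in `meta_test` and `main` via `np.repeat(range(n_class), n_shot)` rely on the loader stacking classes in order; a quick check of the resulting layout:

```python
import numpy as np

n_class, n_shot = 5, 5
train_target = np.repeat(range(n_class), n_shot)
# class i's examples occupy positions [i * n_shot, (i + 1) * n_shot)
print(train_target[:10])  # → [0 0 0 0 0 1 1 1 1 1]
```

Since `episode_x` is reshaped class-major (`[n_class, n_shot, ...] -> [n_class * n_shot, ...]`), these labels line up with the stacked inputs without any shuffling.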
================================================
FILE: metalearner.py
================================================
from __future__ import division, print_function, absolute_import
import pdb
import math
import torch
import torch.nn as nn
class MetaLSTMCell(nn.Module):
"""C_t = f_t * C_{t-1} + i_t * \tilde{C_t}"""
def __init__(self, input_size, hidden_size, n_learner_params):
super(MetaLSTMCell, self).__init__()
"""Args:
input_size (int): cell input size, default = 20
hidden_size (int): should be 1
n_learner_params (int): number of learner's parameters
"""
self.input_size = input_size
self.hidden_size = hidden_size
self.n_learner_params = n_learner_params
self.WF = nn.Parameter(torch.Tensor(input_size + 2, hidden_size))
self.WI = nn.Parameter(torch.Tensor(input_size + 2, hidden_size))
self.cI = nn.Parameter(torch.Tensor(n_learner_params, 1))
self.bI = nn.Parameter(torch.Tensor(1, hidden_size))
self.bF = nn.Parameter(torch.Tensor(1, hidden_size))
self.reset_parameters()
def reset_parameters(self):
for weight in self.parameters():
nn.init.uniform_(weight, -0.01, 0.01)
# want initial forget value to be high and input value to be low so that
# model starts with gradient descent
nn.init.uniform_(self.bF, 4, 6)
nn.init.uniform_(self.bI, -5, -4)
def init_cI(self, flat_params):
self.cI.data.copy_(flat_params.unsqueeze(1))
def forward(self, inputs, hx=None):
"""Args:
inputs = [x_all, grad]:
x_all (torch.Tensor of size [n_learner_params, input_size]): outputs from previous LSTM
grad (torch.Tensor of size [n_learner_params]): gradients from learner
hx = [f_prev, i_prev, c_prev]:
f (torch.Tensor of size [n_learner_params, 1]): forget gate
i (torch.Tensor of size [n_learner_params, 1]): input gate
c (torch.Tensor of size [n_learner_params, 1]): flattened learner parameters
"""
x_all, grad = inputs
batch, _ = x_all.size()
if hx is None:
f_prev = torch.zeros((batch, self.hidden_size)).to(self.WF.device)
i_prev = torch.zeros((batch, self.hidden_size)).to(self.WI.device)
c_prev = self.cI
hx = [f_prev, i_prev, c_prev]
f_prev, i_prev, c_prev = hx
        # f_t = sigmoid(W_f * [grad_t, loss_t, theta_{t-1}, f_{t-1}] + b_f)
        # (computed as a pre-activation here; the sigmoid is applied in the cell update below)
        f_next = torch.mm(torch.cat((x_all, c_prev, f_prev), 1), self.WF) + self.bF.expand_as(f_prev)
        # i_t = sigmoid(W_i * [grad_t, loss_t, theta_{t-1}, i_{t-1}] + b_i)
        i_next = torch.mm(torch.cat((x_all, c_prev, i_prev), 1), self.WI) + self.bI.expand_as(i_prev)
# next cell/params
c_next = torch.sigmoid(f_next).mul(c_prev) - torch.sigmoid(i_next).mul(grad)
return c_next, [f_next, i_next, c_next]
def extra_repr(self):
s = '{input_size}, {hidden_size}, {n_learner_params}'
return s.format(**self.__dict__)
class MetaLearner(nn.Module):
def __init__(self, input_size, hidden_size, n_learner_params):
super(MetaLearner, self).__init__()
"""Args:
input_size (int): for the first LSTM layer, default = 4
hidden_size (int): for the first LSTM layer, default = 20
n_learner_params (int): number of learner's parameters
"""
self.lstm = nn.LSTMCell(input_size=input_size, hidden_size=hidden_size)
self.metalstm = MetaLSTMCell(input_size=hidden_size, hidden_size=1, n_learner_params=n_learner_params)
def forward(self, inputs, hs=None):
"""Args:
inputs = [loss, grad_prep, grad]
loss (torch.Tensor of size [1, 2])
grad_prep (torch.Tensor of size [n_learner_params, 2])
grad (torch.Tensor of size [n_learner_params])
hs = [(lstm_hn, lstm_cn), [metalstm_fn, metalstm_in, metalstm_cn]]
"""
loss, grad_prep, grad = inputs
loss = loss.expand_as(grad_prep)
inputs = torch.cat((loss, grad_prep), 1) # [n_learner_params, 4]
if hs is None:
hs = [None, None]
lstmhx, lstmcx = self.lstm(inputs, hs[0])
flat_learner_unsqzd, metalstm_hs = self.metalstm([lstmhx, grad], hs[1])
return flat_learner_unsqzd.squeeze(), [(lstmhx, lstmcx), metalstm_hs]
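
With `bF` initialized in [4, 6] and `bI` in [-5, -4] (see `reset_parameters` above), the cell update `c_t = sigmoid(f_t) * c_{t-1} - sigmoid(i_t) * grad_t` starts out close to plain gradient descent with a small learning rate. A scalar sketch in plain Python (the gate pre-activations are illustrative values, not learned ones):

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def cell_update(theta_prev, grad, f_pre=5.0, i_pre=-4.5):
    # c_t = sigmoid(f_t) * c_{t-1} - sigmoid(i_t) * grad_t
    return sigmoid(f_pre) * theta_prev - sigmoid(i_pre) * grad

# forget gate ~0.993 (keep almost all of theta),
# input gate ~0.011 (acts like a small learning rate on the gradient)
theta_next = cell_update(theta_prev=1.0, grad=0.5)
```

As training proceeds, the meta LSTM learns to modulate both gates per parameter and per step, which is what distinguishes it from a fixed SGD rule.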
================================================
FILE: scripts/eval_5s_5c.sh
================================================
#!/bin/bash
#
# For 5-shot, 5-class evaluation, hyper-parameters follow github.com/twitter/meta-learning-lstm
python main.py --mode test \
--resume logs-719/ckpts/meta-learner-42000.pth.tar \
--n-shot 5 \
--n-eval 15 \
--n-class 5 \
--input-size 4 \
--hidden-size 20 \
--lr 1e-3 \
--episode 50000 \
--episode-val 100 \
--epoch 8 \
--batch-size 25 \
--image-size 84 \
--grad-clip 0.25 \
--bn-momentum 0.95 \
--bn-eps 1e-3 \
--data miniimagenet \
       --data-root data/miniImagenet/ \
       --seed 719 \
--pin-mem True \
--log-freq 100
================================================
FILE: scripts/train_5s_5c.sh
================================================
#!/bin/bash
#
# For 5-shot, 5-class training
# Hyper-parameters follow https://github.com/twitter/meta-learning-lstm
python main.py --mode train \
--n-shot 5 \
--n-eval 15 \
--n-class 5 \
--input-size 4 \
--hidden-size 20 \
--lr 1e-3 \
--episode 50000 \
--episode-val 100 \
--epoch 8 \
--batch-size 25 \
--image-size 84 \
--grad-clip 0.25 \
--bn-momentum 0.95 \
--bn-eps 1e-3 \
--data miniimagenet \
--data-root data/miniImagenet/ \
--pin-mem True \
--log-freq 50 \
--val-freq 1000
================================================
FILE: utils.py
================================================
from __future__ import division, print_function, absolute_import
import os
import pdb
import logging
import torch
import numpy as np
class GOATLogger:
def __init__(self, args):
args.save = args.save + '-{}'.format(args.seed)
self.mode = args.mode
self.save_root = args.save
self.log_freq = args.log_freq
if self.mode == 'train':
if not os.path.exists(self.save_root):
os.mkdir(self.save_root)
filename = os.path.join(self.save_root, 'console.log')
logging.basicConfig(level=logging.DEBUG,
format='%(asctime)s.%(msecs)03d - %(message)s',
datefmt='%b-%d %H:%M:%S',
filename=filename,
filemode='w')
console = logging.StreamHandler()
console.setLevel(logging.INFO)
console.setFormatter(logging.Formatter('%(message)s'))
logging.getLogger('').addHandler(console)
logging.info("Logger created at {}".format(filename))
else:
logging.basicConfig(level=logging.INFO,
format='%(asctime)s.%(msecs)03d - %(message)s',
datefmt='%b-%d %H:%M:%S')
logging.info("Random Seed: {}".format(args.seed))
self.reset_stats()
def reset_stats(self):
if self.mode == 'train':
self.stats = {'train': {'loss': [], 'acc': []},
'eval': {'loss': [], 'acc': []}}
else:
self.stats = {'eval': {'loss': [], 'acc': []}}
def batch_info(self, **kwargs):
if kwargs['phase'] == 'train':
self.stats['train']['loss'].append(kwargs['loss'])
self.stats['train']['acc'].append(kwargs['acc'])
if kwargs['eps'] % self.log_freq == 0 and kwargs['eps'] != 0:
loss_mean = np.mean(self.stats['train']['loss'])
acc_mean = np.mean(self.stats['train']['acc'])
#self.draw_stats()
self.loginfo("[{:5d}/{:5d}] loss: {:6.4f} ({:6.4f}), acc: {:6.3f}% ({:6.3f}%)".format(\
kwargs['eps'], kwargs['totaleps'], kwargs['loss'], loss_mean, kwargs['acc'], acc_mean))
elif kwargs['phase'] == 'eval':
self.stats['eval']['loss'].append(kwargs['loss'])
self.stats['eval']['acc'].append(kwargs['acc'])
elif kwargs['phase'] == 'evaldone':
loss_mean = np.mean(self.stats['eval']['loss'])
loss_std = np.std(self.stats['eval']['loss'])
acc_mean = np.mean(self.stats['eval']['acc'])
acc_std = np.std(self.stats['eval']['acc'])
self.loginfo("[{:5d}] Eval ({:3d} episode) - loss: {:6.4f} +- {:6.4f}, acc: {:6.3f} +- {:5.3f}%".format(\
kwargs['eps'], kwargs['totaleps'], loss_mean, loss_std, acc_mean, acc_std))
self.reset_stats()
return acc_mean
else:
raise ValueError("phase {} not supported".format(kwargs['phase']))
def logdebug(self, strout):
logging.debug(strout)
def loginfo(self, strout):
logging.info(strout)
def accuracy(output, target, topk=(1,)):
with torch.no_grad():
maxk = max(topk)
batch_size = target.size(0)
_, pred = output.topk(maxk, 1, True, True)
pred = pred.t()
correct = pred.eq(target.view(1, -1).expand_as(pred))
res = []
for k in topk:
correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
res.append(correct_k.mul_(100.0 / batch_size))
return res[0].item() if len(res) == 1 else [r.item() for r in res]
def save_ckpt(episode, metalearner, optim, save):
if not os.path.exists(os.path.join(save, 'ckpts')):
os.mkdir(os.path.join(save, 'ckpts'))
torch.save({
'episode': episode,
'metalearner': metalearner.state_dict(),
'optim': optim.state_dict()
}, os.path.join(save, 'ckpts', 'meta-learner-{}.pth.tar'.format(episode)))
def resume_ckpt(metalearner, optim, resume, device):
ckpt = torch.load(resume, map_location=device)
last_episode = ckpt['episode']
metalearner.load_state_dict(ckpt['metalearner'])
optim.load_state_dict(ckpt['optim'])
return last_episode, metalearner, optim
def preprocess_grad_loss(x):
p = 10
indicator = (x.abs() >= np.exp(-p)).to(torch.float32)
# preproc1
x_proc1 = indicator * torch.log(x.abs() + 1e-8) / p + (1 - indicator) * -1
# preproc2
x_proc2 = indicator * torch.sign(x) + (1 - indicator) * np.exp(p) * x
return torch.stack((x_proc1, x_proc2), 1)
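
`preprocess_grad_loss` above implements the two-channel gradient/loss preprocessing from the paper (with p = 10): large-magnitude inputs map to `(log|x| / p, sign(x))`, tiny ones to `(-1, e^p * x)`. A scalar version in plain Python (`preprocess` is a hypothetical helper for illustration):

```python
import math

def preprocess(x, p=10):
    # mirrors preprocess_grad_loss for a single scalar input
    if abs(x) >= math.exp(-p):
        return (math.log(abs(x) + 1e-8) / p, math.copysign(1.0, x))
    return (-1.0, math.exp(p) * x)

print(preprocess(0.0))   # → (-1.0, 0.0)
```

The split keeps the magnitude channel on a manageable log scale while the second channel preserves sign (or a rescaled raw value near zero), so the meta LSTM sees well-conditioned inputs regardless of gradient scale.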