Repository: markdtw/meta-learning-lstm-pytorch Branch: master Commit: fcc68baad71a Files: 8 Total size: 29.6 KB Directory structure: gitextract_ouj1iose/ ├── README.md ├── dataloader.py ├── learner.py ├── main.py ├── metalearner.py ├── scripts/ │ ├── eval_5s_5c.sh │ └── train_5s_5c.sh └── utils.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: README.md ================================================ # Optimization as a Model for Few-shot Learning Pytorch implementation of [Optimization as a Model for Few-shot Learning](https://openreview.net/forum?id=rJY0-Kcll) in ICLR 2017 (Oral) ![Model Architecture](https://i.imgur.com/lydKeUc.png) ## Prerequisites - python 3+ - pytorch 0.4+ (developed on 1.0.1 with cuda 9.0) - [pillow](https://pillow.readthedocs.io/en/stable/installation.html) - [tqdm](https://tqdm.github.io/) (a nice progress bar) ## Data - Mini-Imagenet as described [here](https://github.com/twitter/meta-learning-lstm/tree/master/data/miniImagenet) - You can download it from [here](https://drive.google.com/file/d/1rV3aj_hgfNTfCakffpPm7Vhpr1in87CR/view?usp=sharing) (~2.7GB, google drive link) ## Preparation - Make sure Mini-Imagenet is split properly. For example: ``` - data/ - miniImagenet/ - train/ - n01532829/ - n0153282900000005.jpg - ... - n01558993/ - ... - val/ - n01855672/ - ... - test/ - ... - main.py - ... ``` - It'd be set if you download and extract Mini-Imagenet from the link above - Check out `scripts/train_5s_5c.sh`, make sure `--data-root` is properly set ## Run For 5-shot, 5-class training, run ```bash bash scripts/train_5s_5c.sh ``` Hyper-parameters are referred to the [author's repo](https://github.com/twitter/meta-learning-lstm). 
For 5-shot, 5-class evaluation, run *(remember to change `--resume` and `--seed` arguments)* ```bash bash scripts/eval_5s_5c.sh ``` ## Notes - Results (This repo is developed following the [pytorch reproducibility guideline](https://pytorch.org/docs/stable/notes/randomness.html)): |seed|train episodes|val episodes|val acc mean|val acc std|test episodes|test acc mean|test acc std| |-|-|-|-|-|-|-|-| |719|41000|100|59.08|9.9|100|56.59|8.4| | -| -| -| -| -|250|57.85|8.6| | -| -| -| -| -|600|57.76|8.6| | 53|44000|100|58.04|9.1|100|57.85|7.7| | -| -| -| -| -|250|57.83|8.3| | -| -| -| -| -|600|58.14|8.5| - The results I get from directly running the author's repo can be found [here](https://i.imgur.com/rtagm2c.png), I have slightly better performance (~5%) but neither results match the number in the paper (60%) *(Discussion and help are welcome!)*. - Training with the default settings takes ~2.5 hours on a single Titan Xp while occupying ~2GB GPU memory. - The implementation replicates two learners similar to the author's repo: - `learner_w_grad` functions as a regular model, get gradients and loss as inputs to meta learner. - `learner_wo_grad` constructs the graph for meta learner: - All the parameters in `learner_wo_grad` are replaced by `cI` output by meta learner. - `nn.Parameters` in this model are casted to `torch.Tensor` to connect the graph to meta learner. - Several ways to **copy** a parameters from meta learner to learner depends on the scenario: - `copy_flat_params`: we only need the parameter values and keep the original `grad_fn`. - `transfer_params`: we want the values as well as the `grad_fn` (from `cI` to `learner_wo_grad`). - `.data.copy_` v.s. `clone()` -> the latter retains all the properties of a tensor including `grad_fn`. - To maintain the batch statistics, `load_state_dict` is used (from `learner_w_grad` to `learner_wo_grad`). 
## References - [CloserLookFewShot](https://github.com/wyharveychen/CloserLookFewShot) (Data loader) - [pytorch-meta-optimizer](https://github.com/ikostrikov/pytorch-meta-optimizer) (Casting `nn.Parameters` to `torch.Tensor` inspired from here) - [meta-learning-lstm](https://github.com/twitter/meta-learning-lstm) (Author's repo in Lua Torch) ================================================ FILE: dataloader.py ================================================ from __future__ import division, print_function, absolute_import import os import re import pdb import glob import pickle import torch import torch.utils.data as data import torchvision.datasets as datasets import torchvision.transforms as transforms import PIL.Image as PILI import numpy as np from tqdm import tqdm class EpisodeDataset(data.Dataset): def __init__(self, root, phase='train', n_shot=5, n_eval=15, transform=None): """Args: root (str): path to data phase (str): train, val or test n_shot (int): how many examples per class for training (k/n_support) n_eval (int): how many examples per class for evaluation - n_shot + n_eval = batch_size for data.DataLoader of ClassDataset transform (torchvision.transforms): data augmentation """ root = os.path.join(root, phase) self.labels = sorted(os.listdir(root)) images = [glob.glob(os.path.join(root, label, '*')) for label in self.labels] self.episode_loader = [data.DataLoader( ClassDataset(images=images[idx], label=idx, transform=transform), batch_size=n_shot+n_eval, shuffle=True, num_workers=0) for idx, _ in enumerate(self.labels)] def __getitem__(self, idx): return next(iter(self.episode_loader[idx])) def __len__(self): return len(self.labels) class ClassDataset(data.Dataset): def __init__(self, images, label, transform=None): """Args: images (list of str): each item is a path to an image of the same label label (int): the label of all the images """ self.images = images self.label = label self.transform = transform def __getitem__(self, idx): image = 
class ClassDataset(data.Dataset):
    def __init__(self, images, label, transform=None):
        """Dataset over all images sharing one class label.

        Args:
            images (list of str): paths to images, all of the same class
            label (int): integer label shared by every image in `images`
            transform (callable, optional): applied to each loaded image
        """
        self.images = images
        self.label = label
        self.transform = transform

    def __getitem__(self, idx):
        # Decode to RGB so grayscale files also come out with 3 channels.
        img = PILI.open(self.images[idx]).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        return img, self.label

    def __len__(self):
        return len(self.images)


class EpisodicSampler(data.Sampler):
    def __init__(self, total_classes, n_class, n_episode):
        """Yield `n_episode` batches, each a random `n_class`-sized subset of class indices."""
        self.total_classes = total_classes
        self.n_class = n_class
        self.n_episode = n_episode

    def __iter__(self):
        # One fresh random subset of all classes per episode.
        return (torch.randperm(self.total_classes)[:self.n_class]
                for _ in range(self.n_episode))

    def __len__(self):
        return self.n_episode


def prepare_data(args):
    """Build episodic train/val/test DataLoaders for Mini-Imagenet.

    Returns:
        (train_loader, val_loader, test_loader)
    """
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    # Augmented pipeline for meta-training episodes.
    train_tf = transforms.Compose([
        transforms.RandomResizedCrop(args.image_size),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(
            brightness=0.4,
            contrast=0.4,
            saturation=0.4,
            hue=0.2),
        transforms.ToTensor(),
        normalize])

    # Deterministic resize + center crop for the evaluation phases.
    eval_tf = transforms.Compose([
        transforms.Resize(args.image_size * 8 // 7),
        transforms.CenterCrop(args.image_size),
        transforms.ToTensor(),
        normalize])

    train_set = EpisodeDataset(args.data_root, 'train', args.n_shot, args.n_eval,
                               transform=train_tf)
    val_set = EpisodeDataset(args.data_root, 'val', args.n_shot, args.n_eval,
                             transform=eval_tf)
    test_set = EpisodeDataset(args.data_root, 'test', args.n_shot, args.n_eval,
                              transform=eval_tf)

    train_loader = data.DataLoader(
        train_set, num_workers=args.n_workers, pin_memory=args.pin_mem,
        batch_sampler=EpisodicSampler(len(train_set), args.n_class, args.episode))
    val_loader = data.DataLoader(
        val_set, num_workers=2, pin_memory=False,
        batch_sampler=EpisodicSampler(len(val_set), args.n_class, args.episode_val))
    test_loader = data.DataLoader(
        test_set, num_workers=2, pin_memory=False,
        batch_sampler=EpisodicSampler(len(test_set), args.n_class, args.episode_val))

    return train_loader, val_loader, test_loader
from __future__ import division, print_function, absolute_import import pdb import copy from collections import OrderedDict import torch import torch.nn as nn import numpy as np class Learner(nn.Module): def __init__(self, image_size, bn_eps, bn_momentum, n_classes): super(Learner, self).__init__() self.model = nn.ModuleDict({'features': nn.Sequential(OrderedDict([ ('conv1', nn.Conv2d(3, 32, 3, padding=1)), ('norm1', nn.BatchNorm2d(32, bn_eps, bn_momentum)), ('relu1', nn.ReLU(inplace=False)), ('pool1', nn.MaxPool2d(2)), ('conv2', nn.Conv2d(32, 32, 3, padding=1)), ('norm2', nn.BatchNorm2d(32, bn_eps, bn_momentum)), ('relu2', nn.ReLU(inplace=False)), ('pool2', nn.MaxPool2d(2)), ('conv3', nn.Conv2d(32, 32, 3, padding=1)), ('norm3', nn.BatchNorm2d(32, bn_eps, bn_momentum)), ('relu3', nn.ReLU(inplace=False)), ('pool3', nn.MaxPool2d(2)), ('conv4', nn.Conv2d(32, 32, 3, padding=1)), ('norm4', nn.BatchNorm2d(32, bn_eps, bn_momentum)), ('relu4', nn.ReLU(inplace=False)), ('pool4', nn.MaxPool2d(2))])) }) clr_in = image_size // 2**4 self.model.update({'cls': nn.Linear(32 * clr_in * clr_in, n_classes)}) self.criterion = nn.CrossEntropyLoss() def forward(self, x): x = self.model.features(x) x = torch.reshape(x, [x.size(0), -1]) outputs = self.model.cls(x) return outputs def get_flat_params(self): return torch.cat([p.view(-1) for p in self.model.parameters()], 0) def copy_flat_params(self, cI): idx = 0 for p in self.model.parameters(): plen = p.view(-1).size(0) p.data.copy_(cI[idx: idx+plen].view_as(p)) idx += plen def transfer_params(self, learner_w_grad, cI): # Use load_state_dict only to copy the running mean/var in batchnorm, the values of the parameters # are going to be replaced by cI self.load_state_dict(learner_w_grad.state_dict()) # replace nn.Parameters with tensors from cI (NOT nn.Parameters anymore). 
idx = 0 for m in self.model.modules(): if isinstance(m, nn.Conv2d) or isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.Linear): wlen = m._parameters['weight'].view(-1).size(0) m._parameters['weight'] = cI[idx: idx+wlen].view_as(m._parameters['weight']).clone() idx += wlen if m._parameters['bias'] is not None: blen = m._parameters['bias'].view(-1).size(0) m._parameters['bias'] = cI[idx: idx+blen].view_as(m._parameters['bias']).clone() idx += blen def reset_batch_stats(self): for m in self.modules(): if isinstance(m, nn.BatchNorm2d): m.reset_running_stats() ================================================ FILE: main.py ================================================ from __future__ import division, print_function, absolute_import import os import pdb import copy import random import argparse import torch import torch.nn as nn import numpy as np from tqdm import tqdm from learner import Learner from metalearner import MetaLearner from dataloader import prepare_data from utils import * FLAGS = argparse.ArgumentParser() FLAGS.add_argument('--mode', choices=['train', 'test']) # Hyper-parameters FLAGS.add_argument('--n-shot', type=int, help="How many examples per class for training (k, n_support)") FLAGS.add_argument('--n-eval', type=int, help="How many examples per class for evaluation (n_query)") FLAGS.add_argument('--n-class', type=int, help="How many classes (N, n_way)") FLAGS.add_argument('--input-size', type=int, help="Input size for the first LSTM") FLAGS.add_argument('--hidden-size', type=int, help="Hidden size for the first LSTM") FLAGS.add_argument('--lr', type=float, help="Learning rate") FLAGS.add_argument('--episode', type=int, help="Episodes to train") FLAGS.add_argument('--episode-val', type=int, help="Episodes to eval") FLAGS.add_argument('--epoch', type=int, help="Epoch to train for an episode") FLAGS.add_argument('--batch-size', type=int, help="Batch size when training an episode") FLAGS.add_argument('--image-size', type=int, help="Resize image to this 
FLAGS.add_argument('--image-size', type=int, help="Resize image to this size")
FLAGS.add_argument('--grad-clip', type=float, help="Clip gradients larger than this number")
FLAGS.add_argument('--bn-momentum', type=float, help="Momentum parameter in BatchNorm2d")
FLAGS.add_argument('--bn-eps', type=float, help="Eps parameter in BatchNorm2d")
# Paths
FLAGS.add_argument('--data', choices=['miniimagenet'], help="Name of dataset")
FLAGS.add_argument('--data-root', type=str, help="Location of data")
FLAGS.add_argument('--resume', type=str, help="Location to pth.tar")
FLAGS.add_argument('--save', type=str, default='logs', help="Location to logs and ckpts")
# Others
FLAGS.add_argument('--cpu', action='store_true', help="Set this to use CPU, default use CUDA")
FLAGS.add_argument('--n-workers', type=int, default=4, help="How many processes for preprocessing")
# BUGFIX: argparse `type=bool` turns ANY non-empty string (including "False")
# into True; parse the string explicitly so `--pin-mem False` actually works.
FLAGS.add_argument('--pin-mem', type=lambda s: s.lower() in ('1', 'true', 'yes'),
                   default=False, help="DataLoader pin_memory")
FLAGS.add_argument('--log-freq', type=int, default=100, help="Logging frequency")
FLAGS.add_argument('--val-freq', type=int, default=1000, help="Validation frequency")
FLAGS.add_argument('--seed', type=int, help="Random seed")


def meta_test(eps, eval_loader, learner_w_grad, learner_wo_grad, metalearner, args, logger):
    """Evaluate the meta learner: train the learner on each episode's support
    set via the meta learner, then score the query set. Returns mean accuracy."""
    for subeps, (episode_x, episode_y) in enumerate(tqdm(eval_loader, ascii=True)):
        train_input = episode_x[:, :args.n_shot].reshape(-1, *episode_x.shape[-3:]).to(args.dev)   # [n_class * n_shot, :]
        train_target = torch.LongTensor(np.repeat(range(args.n_class), args.n_shot)).to(args.dev)  # [n_class * n_shot]
        test_input = episode_x[:, args.n_shot:].reshape(-1, *episode_x.shape[-3:]).to(args.dev)    # [n_class * n_eval, :]
        test_target = torch.LongTensor(np.repeat(range(args.n_class), args.n_eval)).to(args.dev)   # [n_class * n_eval]

        # Train learner with metalearner on the support set.
        learner_w_grad.reset_batch_stats()
        learner_wo_grad.reset_batch_stats()
        learner_w_grad.train()
        learner_wo_grad.eval()
        cI = train_learner(learner_w_grad, metalearner, train_input, train_target, args)

        learner_wo_grad.transfer_params(learner_w_grad, cI)
        output = learner_wo_grad(test_input)
        loss = learner_wo_grad.criterion(output, test_target)
        acc = accuracy(output, test_target)

        logger.batch_info(loss=loss.item(), acc=acc, phase='eval')

    return logger.batch_info(eps=eps, totaleps=args.episode_val, phase='evaldone')


def train_learner(learner_w_grad, metalearner, train_input, train_target, args):
    """Run `args.epoch` passes of meta-learner-driven updates on the support set.

    Returns:
        cI (torch.Tensor): flat learner parameters produced by the meta learner.
    """
    cI = metalearner.metalstm.cI.data
    hs = [None]
    for _ in range(args.epoch):
        for i in range(0, len(train_input), args.batch_size):
            x = train_input[i:i + args.batch_size]
            y = train_target[i:i + args.batch_size]

            # Get the loss/grad of the learner at the current cI.
            learner_w_grad.copy_flat_params(cI)
            output = learner_w_grad(x)
            loss = learner_w_grad.criterion(output, y)
            acc = accuracy(output, y)
            learner_w_grad.zero_grad()
            loss.backward()
            grad = torch.cat([p.grad.data.view(-1) / args.batch_size
                              for p in learner_w_grad.parameters()], 0)

            # Preprocess grad & loss, then let the meta learner propose new cI.
            grad_prep = preprocess_grad_loss(grad)                  # [n_learner_params, 2]
            loss_prep = preprocess_grad_loss(loss.data.unsqueeze(0))  # [1, 2]
            metalearner_input = [loss_prep, grad_prep, grad.unsqueeze(1)]
            cI, h = metalearner(metalearner_input, hs[-1])
            hs.append(h)

    return cI


def main():
    args, unparsed = FLAGS.parse_known_args()
    if len(unparsed) != 0:
        raise NameError("Argument {} not recognized".format(unparsed))

    if args.seed is None:
        # BUGFIX: random.randint requires int endpoints; 1e3 (a float)
        # raises TypeError on Python >= 3.10.
        args.seed = random.randint(0, 1000)
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    if args.cpu:
        args.dev = torch.device('cpu')
    else:
        if not torch.cuda.is_available():
            raise RuntimeError("GPU unavailable.")
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        args.dev = torch.device('cuda')

    logger = GOATLogger(args)

    # Get data
    train_loader, val_loader, test_loader = prepare_data(args)

    # Set up learner, meta-learner
    learner_w_grad = Learner(args.image_size, args.bn_eps, args.bn_momentum,
                             args.n_class).to(args.dev)
    learner_wo_grad = copy.deepcopy(learner_w_grad)
    metalearner = MetaLearner(args.input_size, args.hidden_size,
                              learner_w_grad.get_flat_params().size(0)).to(args.dev)
    metalearner.metalstm.init_cI(learner_w_grad.get_flat_params())

    # Set up loss, optimizer, learning rate scheduler
    optim = torch.optim.Adam(metalearner.parameters(), args.lr)

    # BUGFIX: `last_eps` was unbound when --mode test ran without --resume.
    last_eps = 0
    if args.resume:
        logger.loginfo("Initialized from: {}".format(args.resume))
        last_eps, metalearner, optim = resume_ckpt(metalearner, optim, args.resume, args.dev)

    if args.mode == 'test':
        _ = meta_test(last_eps, test_loader, learner_w_grad, learner_wo_grad,
                      metalearner, args, logger)
        return

    best_acc = 0.0
    logger.loginfo("Start training")
    # Meta-training
    for eps, (episode_x, episode_y) in enumerate(train_loader):
        # episode_x.shape = [n_class, n_shot + n_eval, c, h, w]
        # episode_y.shape = [n_class, n_shot + n_eval] --> NEVER USED
        train_input = episode_x[:, :args.n_shot].reshape(-1, *episode_x.shape[-3:]).to(args.dev)   # [n_class * n_shot, :]
        train_target = torch.LongTensor(np.repeat(range(args.n_class), args.n_shot)).to(args.dev)  # [n_class * n_shot]
        test_input = episode_x[:, args.n_shot:].reshape(-1, *episode_x.shape[-3:]).to(args.dev)    # [n_class * n_eval, :]
        test_target = torch.LongTensor(np.repeat(range(args.n_class), args.n_eval)).to(args.dev)   # [n_class * n_eval]

        # Train learner with metalearner
        learner_w_grad.reset_batch_stats()
        learner_wo_grad.reset_batch_stats()
        learner_w_grad.train()
        learner_wo_grad.train()
        cI = train_learner(learner_w_grad, metalearner, train_input, train_target, args)

        # Train meta-learner with validation loss
        learner_wo_grad.transfer_params(learner_w_grad, cI)
        output = learner_wo_grad(test_input)
        loss = learner_wo_grad.criterion(output, test_target)
        acc = accuracy(output, test_target)

        optim.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(metalearner.parameters(), args.grad_clip)
        optim.step()

        logger.batch_info(eps=eps, totaleps=args.episode, loss=loss.item(),
                          acc=acc, phase='train')
    # NOTE: the remainder of main() (periodic meta-validation / checkpointing
    # and the final log) continues further down in this file.
acc=acc, phase='train') # Meta-validation if eps % args.val_freq == 0 and eps != 0: save_ckpt(eps, metalearner, optim, args.save) acc = meta_test(eps, val_loader, learner_w_grad, learner_wo_grad, metalearner, args, logger) if acc > best_acc: best_acc = acc logger.loginfo("* Best accuracy so far *\n") logger.loginfo("Done") if __name__ == '__main__': main() ================================================ FILE: metalearner.py ================================================ from __future__ import division, print_function, absolute_import import pdb import math import torch import torch.nn as nn class MetaLSTMCell(nn.Module): """C_t = f_t * C_{t-1} + i_t * \tilde{C_t}""" def __init__(self, input_size, hidden_size, n_learner_params): super(MetaLSTMCell, self).__init__() """Args: input_size (int): cell input size, default = 20 hidden_size (int): should be 1 n_learner_params (int): number of learner's parameters """ self.input_size = input_size self.hidden_size = hidden_size self.n_learner_params = n_learner_params self.WF = nn.Parameter(torch.Tensor(input_size + 2, hidden_size)) self.WI = nn.Parameter(torch.Tensor(input_size + 2, hidden_size)) self.cI = nn.Parameter(torch.Tensor(n_learner_params, 1)) self.bI = nn.Parameter(torch.Tensor(1, hidden_size)) self.bF = nn.Parameter(torch.Tensor(1, hidden_size)) self.reset_parameters() def reset_parameters(self): for weight in self.parameters(): nn.init.uniform_(weight, -0.01, 0.01) # want initial forget value to be high and input value to be low so that # model starts with gradient descent nn.init.uniform_(self.bF, 4, 6) nn.init.uniform_(self.bI, -5, -4) def init_cI(self, flat_params): self.cI.data.copy_(flat_params.unsqueeze(1)) def forward(self, inputs, hx=None): """Args: inputs = [x_all, grad]: x_all (torch.Tensor of size [n_learner_params, input_size]): outputs from previous LSTM grad (torch.Tensor of size [n_learner_params]): gradients from learner hx = [f_prev, i_prev, c_prev]: f (torch.Tensor of size [n_learner_params, 
1]): forget gate i (torch.Tensor of size [n_learner_params, 1]): input gate c (torch.Tensor of size [n_learner_params, 1]): flattened learner parameters """ x_all, grad = inputs batch, _ = x_all.size() if hx is None: f_prev = torch.zeros((batch, self.hidden_size)).to(self.WF.device) i_prev = torch.zeros((batch, self.hidden_size)).to(self.WI.device) c_prev = self.cI hx = [f_prev, i_prev, c_prev] f_prev, i_prev, c_prev = hx # f_t = sigmoid(W_f * [grad_t, loss_t, theta_{t-1}, f_{t-1}] + b_f) f_next = torch.mm(torch.cat((x_all, c_prev, f_prev), 1), self.WF) + self.bF.expand_as(f_prev) # i_t = sigmoid(W_i * [grad_t, loss_t, theta_{t-1}, i_{t-1}] + b_i) i_next = torch.mm(torch.cat((x_all, c_prev, i_prev), 1), self.WI) + self.bI.expand_as(i_prev) # next cell/params c_next = torch.sigmoid(f_next).mul(c_prev) - torch.sigmoid(i_next).mul(grad) return c_next, [f_next, i_next, c_next] def extra_repr(self): s = '{input_size}, {hidden_size}, {n_learner_params}' return s.format(**self.__dict__) class MetaLearner(nn.Module): def __init__(self, input_size, hidden_size, n_learner_params): super(MetaLearner, self).__init__() """Args: input_size (int): for the first LSTM layer, default = 4 hidden_size (int): for the first LSTM layer, default = 20 n_learner_params (int): number of learner's parameters """ self.lstm = nn.LSTMCell(input_size=input_size, hidden_size=hidden_size) self.metalstm = MetaLSTMCell(input_size=hidden_size, hidden_size=1, n_learner_params=n_learner_params) def forward(self, inputs, hs=None): """Args: inputs = [loss, grad_prep, grad] loss (torch.Tensor of size [1, 2]) grad_prep (torch.Tensor of size [n_learner_params, 2]) grad (torch.Tensor of size [n_learner_params]) hs = [(lstm_hn, lstm_cn), [metalstm_fn, metalstm_in, metalstm_cn]] """ loss, grad_prep, grad = inputs loss = loss.expand_as(grad_prep) inputs = torch.cat((loss, grad_prep), 1) # [n_learner_params, 4] if hs is None: hs = [None, None] lstmhx, lstmcx = self.lstm(inputs, hs[0]) flat_learner_unsqzd, 
metalstm_hs = self.metalstm([lstmhx, grad], hs[1]) return flat_learner_unsqzd.squeeze(), [(lstmhx, lstmcx), metalstm_hs] ================================================ FILE: scripts/eval_5s_5c.sh ================================================ #!/bin/bash # # For 5-shot, 5-class evaluation, hyper-parameters follow github.com/twitter/meta-learning-lstm python main.py --mode test \ --resume logs-719/ckpts/meta-learner-42000.pth.tar \ --n-shot 5 \ --n-eval 15 \ --n-class 5 \ --input-size 4 \ --hidden-size 20 \ --lr 1e-3 \ --episode 50000 \ --episode-val 100 \ --epoch 8 \ --batch-size 25 \ --image-size 84 \ --grad-clip 0.25 \ --bn-momentum 0.95 \ --bn-eps 1e-3 \ --data miniimagenet \ --data-root data/miniImagenet/ \ --pin-mem True \ --log-freq 100 ================================================ FILE: scripts/train_5s_5c.sh ================================================ #!/bin/bash # # For 5-shot, 5-class training # Hyper-parameters follow https://github.com/twitter/meta-learning-lstm python main.py --mode train \ --n-shot 5 \ --n-eval 15 \ --n-class 5 \ --input-size 4 \ --hidden-size 20 \ --lr 1e-3 \ --episode 50000 \ --episode-val 100 \ --epoch 8 \ --batch-size 25 \ --image-size 84 \ --grad-clip 0.25 \ --bn-momentum 0.95 \ --bn-eps 1e-3 \ --data miniimagenet \ --data-root data/miniImagenet/ \ --pin-mem True \ --log-freq 50 \ --val-freq 1000 ================================================ FILE: utils.py ================================================ from __future__ import division, print_function, absolute_import import os import pdb import logging import torch import numpy as np class GOATLogger: def __init__(self, args): args.save = args.save + '-{}'.format(args.seed) self.mode = args.mode self.save_root = args.save self.log_freq = args.log_freq if self.mode == 'train': if not os.path.exists(self.save_root): os.mkdir(self.save_root) filename = os.path.join(self.save_root, 'console.log') logging.basicConfig(level=logging.DEBUG, format='%(asctime)s.%(msecs)03d 
- %(message)s', datefmt='%b-%d %H:%M:%S', filename=filename, filemode='w') console = logging.StreamHandler() console.setLevel(logging.INFO) console.setFormatter(logging.Formatter('%(message)s')) logging.getLogger('').addHandler(console) logging.info("Logger created at {}".format(filename)) else: logging.basicConfig(level=logging.INFO, format='%(asctime)s.%(msecs)03d - %(message)s', datefmt='%b-%d %H:%M:%S') logging.info("Random Seed: {}".format(args.seed)) self.reset_stats() def reset_stats(self): if self.mode == 'train': self.stats = {'train': {'loss': [], 'acc': []}, 'eval': {'loss': [], 'acc': []}} else: self.stats = {'eval': {'loss': [], 'acc': []}} def batch_info(self, **kwargs): if kwargs['phase'] == 'train': self.stats['train']['loss'].append(kwargs['loss']) self.stats['train']['acc'].append(kwargs['acc']) if kwargs['eps'] % self.log_freq == 0 and kwargs['eps'] != 0: loss_mean = np.mean(self.stats['train']['loss']) acc_mean = np.mean(self.stats['train']['acc']) #self.draw_stats() self.loginfo("[{:5d}/{:5d}] loss: {:6.4f} ({:6.4f}), acc: {:6.3f}% ({:6.3f}%)".format(\ kwargs['eps'], kwargs['totaleps'], kwargs['loss'], loss_mean, kwargs['acc'], acc_mean)) elif kwargs['phase'] == 'eval': self.stats['eval']['loss'].append(kwargs['loss']) self.stats['eval']['acc'].append(kwargs['acc']) elif kwargs['phase'] == 'evaldone': loss_mean = np.mean(self.stats['eval']['loss']) loss_std = np.std(self.stats['eval']['loss']) acc_mean = np.mean(self.stats['eval']['acc']) acc_std = np.std(self.stats['eval']['acc']) self.loginfo("[{:5d}] Eval ({:3d} episode) - loss: {:6.4f} +- {:6.4f}, acc: {:6.3f} +- {:5.3f}%".format(\ kwargs['eps'], kwargs['totaleps'], loss_mean, loss_std, acc_mean, acc_std)) self.reset_stats() return acc_mean else: raise ValueError("phase {} not supported".format(kwargs['phase'])) def logdebug(self, strout): logging.debug(strout) def loginfo(self, strout): logging.info(strout) def accuracy(output, target, topk=(1,)): with torch.no_grad(): maxk = max(topk) 
def accuracy(output, target, topk=(1,)):
    """Compute top-k accuracy (in percent) of `output` scores w.r.t. `target`.

    Args:
        output (torch.Tensor): [batch, n_classes] class scores
        target (torch.Tensor): [batch] integer class labels
        topk (tuple of int): which top-k accuracies to report

    Returns:
        float when a single k is requested, otherwise a list of floats.
    """
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            # BUGFIX: reshape(-1) instead of view(-1) — `correct` inherits the
            # transposed (non-contiguous) memory layout of `pred.t()`, and
            # view() raises on non-contiguous tensors in recent PyTorch.
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))

        return res[0].item() if len(res) == 1 else [r.item() for r in res]


def save_ckpt(episode, metalearner, optim, save):
    """Save episode index, meta-learner and optimizer state under `save`/ckpts."""
    ckpt_dir = os.path.join(save, 'ckpts')
    # makedirs(exist_ok=True) also creates missing parents and is race-safe.
    os.makedirs(ckpt_dir, exist_ok=True)

    torch.save({
        'episode': episode,
        'metalearner': metalearner.state_dict(),
        'optim': optim.state_dict()
    }, os.path.join(ckpt_dir, 'meta-learner-{}.pth.tar'.format(episode)))


def resume_ckpt(metalearner, optim, resume, device):
    """Load a checkpoint written by `save_ckpt` onto `device` (in place)."""
    ckpt = torch.load(resume, map_location=device)
    last_episode = ckpt['episode']
    metalearner.load_state_dict(ckpt['metalearner'])
    optim.load_state_dict(ckpt['optim'])
    return last_episode, metalearner, optim


def preprocess_grad_loss(x):
    """Preprocess gradients/losses into (log-magnitude, sign)-like coordinates
    so the meta-learner LSTM sees inputs of comparable scale.

    Values with |x| >= e^-10 map to (log|x|/10, sign(x)); tiny values map to
    (-1, e^10 * x).

    Returns:
        torch.Tensor of shape [x.numel(), 2]
    """
    p = 10
    indicator = (x.abs() >= np.exp(-p)).to(torch.float32)

    # preproc1: scaled log-magnitude for large values, -1 otherwise
    x_proc1 = indicator * torch.log(x.abs() + 1e-8) / p + (1 - indicator) * -1
    # preproc2: sign for large values, scaled raw value otherwise
    x_proc2 = indicator * torch.sign(x) + (1 - indicator) * np.exp(p) * x
    return torch.stack((x_proc1, x_proc2), 1)