Repository: BayesWatch/pytorch-prunes Branch: master Commit: bc85a5c52865 Files: 8 Total size: 56.7 KB Directory structure: gitextract_3q8rylo5/ ├── LICENSE ├── README.md ├── funcs.py ├── models/ │ ├── __init__.py │ ├── densenet.py │ └── wideresnet.py ├── prune.py └── train.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: LICENSE ================================================ MIT License Copyright (c) 2019 BayesWatch Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ================================================ FILE: README.md ================================================ # A Closer Look at Structured Pruning for Neural Network Compression Code used to reproduce experiments in https://arxiv.org/abs/1810.04622. To prune, we fill our networks with custom `MaskBlocks`, which are manipulated using `Pruner` in funcs.py. 
There will certainly be a better way to do this, but we leave this as an exercise to someone who can code much better than we can.

## Setup

This is best done in a clean conda environment:

```
conda create -n prunes python=3.6
conda activate prunes
conda install pytorch torchvision -c pytorch
```

## Repository layout

- `train.py`: contains all of the code for training large models from scratch and for training pruned models from scratch
- `prune.py`: contains the code for pruning trained models
- `funcs.py`: contains useful pruning functions and any functions we used commonly

## CIFAR Experiments

First, you will need some initial models.

To train a WRN-40-2:

```
python train.py --net='res' --depth=40 --width=2.0 --data_loc= --save_file='res'
```

The default arguments of train.py are suitable for training WRNs. The following trains a DenseNet-BC-100 (k=12) with its default hyperparameters:

```
python train.py --net='dense' --depth=100 --data_loc= --save_file='dense' --no_epochs 300 -b 64 --epoch_step '[150,225]' --weight_decay 0.0001 --lr_decay_ratio 0.1
```

These will automatically save checkpoints to the `checkpoints` folder.

### Pruning

Once training is finished, we can prune our networks using prune.py (defaults are set to WRN pruning, so extra arguments are needed for DenseNets)

```
python prune.py --net='res' --data_loc= --base_model='res' --save_file='res_fisher'
python prune.py --net='res' --data_loc= --l1_prune=True --base_model='res' --save_file='res_l1'
python prune.py --net='dense' --depth 100 --data_loc= --base_model='dense' --save_file='dense_fisher' --learning_rate 1e-3 --weight_decay 1e-4 --batch_size 64 --no_epochs 2600
python prune.py --net='dense' --depth 100 --data_loc= --l1_prune=True --base_model='dense' --save_file='dense_l1' --learning_rate 1e-3 --weight_decay 1e-4 --batch_size 64 --no_epochs 2600
```

Note that the default is to perform Fisher pruning, so you don't need to pass a flag to use it.
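
As a rough illustration of what the default Fisher criterion does, here is a minimal plain-Python sketch. It loosely mirrors `MaskBlock._fisher` (per-channel saliency) and `Pruner.fisher_prune` (choosing the next channel to switch off) from this repository. The function names `fisher_scores` and `pick_channel` and the toy numbers are assumptions made for illustration only; the real code works on PyTorch tensors and accumulates the scores over many training steps.

```python
# Plain-Python sketch of the Fisher pruning criterion (illustrative names, toy data).


def fisher_scores(acts, grads):
    """Per-channel saliency: 0.5 * mean over examples of (sum over pixels of act * grad)^2.

    acts and grads are nested lists indexed as [example][channel][pixel].
    """
    n_examples = len(acts)
    n_channels = len(acts[0])
    scores = []
    for k in range(n_channels):
        total = 0.0
        for n in range(n_examples):
            # Sum activation * gradient over the spatial positions of channel k
            g_nk = sum(a * g for a, g in zip(acts[n][k], grads[n][k]))
            total += g_nk ** 2
        scores.append(0.5 * total / n_examples)
    return scores


def pick_channel(scores, mask, prune_every=1, penalty=1e6):
    """Index of the lowest-scoring channel whose mask entry is still 1."""
    # Channels that are already off receive a large dummy cost so they are never picked
    costs = [s / prune_every + penalty * (1 - m) for s, m in zip(scores, mask)]
    return min(range(len(costs)), key=costs.__getitem__)


# Two examples, three channels, two pixels per channel (made-up values)
acts = [[[1.0, 2.0], [0.5, 0.5], [0.0, 0.1]],
        [[2.0, 1.0], [0.5, 0.0], [0.1, 0.0]]]
grads = [[[0.1, 0.1], [1.0, -1.0], [0.2, 0.2]],
         [[0.2, 0.0], [0.4, 0.0], [0.0, 0.3]]]
mask = [1, 1, 0]  # channel 2 has already been pruned

scores = fisher_scores(acts, grads)
chosen = pick_channel(scores, mask)
print(scores, chosen)
```

In this toy example channel 2 has the smallest score, but it is already masked off, so channel 1 is selected instead.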
Once finished, we can train the pruned models from scratch, e.g.:

```
python train.py --data_loc= --net='res' --base_file='res_fisher__prunes' --deploy --mask=1 --save_file='res_fisher__prunes_scratch'
```

Each model can then be evaluated using:

```
python train.py --deploy --eval --data_loc= --net='res' --mask=1 --base_file='res_fisher__prunes'
```

### Training Reduced models

This can be done by varying the input arguments to train.py. To reduce depth or width of a WRN, change the corresponding option:

```
python train.py --net='res' --depth= --width= --data_loc= --save_file='res_reduced'
```

To add bottlenecks, use the following:

```
python train.py --net='res' --depth=40 --width=2.0 --data_loc= --save_file='res_bottle' --bottle --bottle_mult
```

With DenseNets you can modify the `depth` or `growth`, or use `--bottle --bottle_mult ` as above.

### Acknowledgements

[Jack Turner][jack] wrote the L1 stuff, and some other stuff for that matter.

Code has been liberally borrowed from many a repo, including, but not limited to:

```
https://github.com/xternalz/WideResNet-pytorch
https://github.com/bamos/densenet.pytorch
https://github.com/kuangliu/pytorch-cifar
https://github.com/ShichenLiu/CondenseNet
```

### Citing this work

If you would like to cite this work, please use the following bibtex entry:

```
@article{crowley2018pruning,
  title={A Closer Look at Structured Pruning for Neural Network Compression},
  author={Crowley, Elliot J and Turner, Jack and Storkey, Amos and O'Boyle, Michael},
  journal={arXiv preprint arXiv:1810.04622},
  year={2018},
}
```

[jack]: https://github.com/jack-willturner

================================================
FILE: funcs.py
================================================
import operator
import random
import time
from functools import reduce

import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
from
models import *


class Pruner:
    def __init__(self, module_name='MaskBlock'):
        # First get vector of masks
        self.module_name = module_name
        self.masks = []
        self.prune_history = []

    def fisher_prune(self, model, prune_every):
        self._get_fisher(model)
        tot_loss = self.fisher.div(prune_every) + 1e6 * (1 - self.masks)  # dummy value for off masks
        print(len(tot_loss))
        _, argmin = torch.min(tot_loss, 0)
        self.prune(model, argmin.item())
        self.prune_history.append(argmin.item())

    def fixed_prune(self, model, ID):
        self.prune(model, ID)
        self.prune_history.append(ID)

    def random_prune(self, model):
        self._get_fisher(model)  # Do this to update costs.
        masks = []
        for m in model.modules():
            if m._get_name() == self.module_name:
                masks.append(m.mask.detach())

        masks = self.concat(masks)
        masks_on = [i for i, v in enumerate(masks) if v == 1]
        random_pick = random.choice(masks_on)
        self.prune(model, random_pick)
        self.prune_history.append(random_pick)

    def l1_prune(self, model, prune_every):
        masks = []
        l1_norms = []

        for m in model.modules():
            if m._get_name() == self.module_name:
                # L1 norm of each output filter: sum of absolute weights
                l1_norm = torch.sum(m.conv1.weight.abs(), (1, 2, 3)).detach().cpu().numpy()
                masks.append(m.mask.detach())
                l1_norms.append(l1_norm)

        masks = self.concat(masks)
        self.masks = masks
        l1_norms = np.concatenate(l1_norms)

        l1_norms_on = []
        for m, l in zip(masks, l1_norms):
            if m == 1:
                l1_norms_on.append(l)
            else:
                l1_norms_on.append(float('inf'))  # dummy value so pruned channels are never picked

        # Index of the smallest norm among channels that are still on
        pick = int(np.argmin(l1_norms_on))

        self.prune(model, pick)
        self.prune_history.append(pick)

    def prune(self, model, feat_index):
        """feat_index refers to the index of a feature map.
        This function modifies the mask to turn it off."""
        print('Pruned %d out of %d channels so far' % (len(self.prune_history), len(self.masks)))
        if len(self.prune_history) > len(self.masks):
            raise Exception('Time to stop')
        safe = 0
        running_index = 0
        for m in model.modules():
            if m._get_name() == self.module_name:
                mask_indices = range(running_index, running_index + len(m.mask))

                if feat_index in mask_indices:
                    print('Pruning channel %d' % feat_index)
                    local_index = mask_indices.index(feat_index)
                    m.mask[local_index] = 0
                    safe = 1
                    break
                else:
                    running_index += len(m.mask)
                    # print(running_index)
        if not safe:
            raise Exception("The provided index doesn't correspond to any feature maps. This is bad.")

    def compress(self, model):
        for m in model.modules():
            if m._get_name() == 'MaskBlock':
                m.compress_weights()

    def _get_fisher(self, model):
        masks = []
        fisher = []

        self._update_cost(model)

        for m in model.modules():
            if m._get_name() == self.module_name:
                masks.append(m.mask.detach())
                fisher.append(m.running_fisher.detach())

                # Now clear the fisher cache
                m.reset_fisher()

        self.masks = self.concat(masks)
        self.fisher = self.concat(fisher)

    def _get_masks(self, model):
        masks = []

        for m in model.modules():
            if m._get_name() == self.module_name:
                masks.append(m.mask.detach())

        self.masks = self.concat(masks)

    def _update_cost(self, model):
        for m in model.modules():
            if m._get_name() == self.module_name:
                m.cost()

    def get_cost(self, model):
        params = 0
        for m in model.modules():
            if m._get_name() == self.module_name:
                m.cost()
                params += m.params
        return params

    @staticmethod
    def concat(input):
        return torch.cat([item for item in input])


def find(input):
    # Find as in MATLAB to find indices in a binary vector
    return [i for i, j in enumerate(input) if j]


def concat(input):
    return torch.cat([item for item in input])


def save_checkpoint(state, filename='checkpoint.pth.tar'):
    torch.save(state, filename)


def get_error(output, target, topk=(1,)):
    """Computes the error@k for the specified values of k"""
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) res.append(100.0 - correct_k.mul_(100.0 / batch_size)) return res class AverageMeter(object): """Computes and stores the average and current value""" def __init__(self): self.reset() def reset(self): self.val = 0 self.avg = 0 self.sum = 0 self.count = 0 def update(self, val, n=1): self.val = val self.sum += val * n self.count += n self.avg = self.sum / self.count def get_inf_params(net, verbose=True, sd=False): if sd: params = net else: params = net.state_dict() tot = 0 conv_tot = 0 for p in params: no = params[p].view(-1).__len__() if ('num_batches_tracked' not in p) and ('running' not in p) and ('mask' not in p): tot += no if verbose: print('%s has %d params' % (p, no)) if 'conv' in p: conv_tot += no if verbose: print('Net has %d conv params' % conv_tot) print('Net has %d params in total' % tot) return tot count_ops = 0 count_params = 0 def get_num_gen(gen): return sum(1 for x in gen) def is_pruned(layer): try: layer.mask return True except AttributeError: return False def is_leaf(model): return get_num_gen(model.children()) == 0 def get_layer_info(layer): layer_str = str(layer) type_name = layer_str[:layer_str.find('(')].strip() return type_name def get_layer_param(model): return sum([reduce(operator.mul, i.size(), 1) for i in model.parameters()]) ### The input batch size should be 1 to call this function def measure_layer(layer, x): global count_ops, count_params delta_ops = 0 delta_params = 0 multi_add = 1 type_name = get_layer_info(layer) ### ops_conv if type_name in ['Conv2d']: out_h = int((x.size()[2] + 2 * layer.padding[0] - layer.kernel_size[0]) / layer.stride[0] + 1) out_w = int((x.size()[3] + 2 * layer.padding[1] - layer.kernel_size[1]) / layer.stride[1] + 1) delta_ops = layer.in_channels * layer.out_channels * layer.kernel_size[0] * \ layer.kernel_size[1] * out_h * out_w / layer.groups * multi_add delta_params = get_layer_param(layer) ### ops_learned_conv elif type_name in 
['LearnedGroupConv']: measure_layer(layer.relu, x) measure_layer(layer.norm, x) conv = layer.conv out_h = int((x.size()[2] + 2 * conv.padding[0] - conv.kernel_size[0]) / conv.stride[0] + 1) out_w = int((x.size()[3] + 2 * conv.padding[1] - conv.kernel_size[1]) / conv.stride[1] + 1) delta_ops = conv.in_channels * conv.out_channels * conv.kernel_size[0] * \ conv.kernel_size[1] * out_h * out_w / layer.condense_factor * multi_add delta_params = get_layer_param(conv) / layer.condense_factor ### ops_nonlinearity elif type_name in ['ReLU']: delta_ops = x.numel() delta_params = get_layer_param(layer) ### ops_pooling elif type_name in ['AvgPool2d']: in_w = x.size()[2] kernel_ops = layer.kernel_size * layer.kernel_size out_w = int((in_w + 2 * layer.padding - layer.kernel_size) / layer.stride + 1) out_h = int((in_w + 2 * layer.padding - layer.kernel_size) / layer.stride + 1) delta_ops = x.size()[0] * x.size()[1] * out_w * out_h * kernel_ops print(delta_ops) delta_params = get_layer_param(layer) elif type_name in ['AdaptiveAvgPool2d']: delta_ops = x.size()[0] * x.size()[1] * x.size()[2] * x.size()[3] delta_params = get_layer_param(layer) ### ops_linear elif type_name in ['Linear']: weight_ops = layer.weight.numel() * multi_add bias_ops = layer.bias.numel() delta_ops = x.size()[0] * (weight_ops + bias_ops) delta_params = get_layer_param(layer) ### ops_nothing elif type_name in ['BatchNorm2d', 'Dropout2d', 'DropChannel', 'Dropout']: delta_params = get_layer_param(layer) ### unknown layer type else: None # raise TypeError('unknown layer type: %s' % type_name) count_ops += delta_ops count_params += delta_params return def measure_model(model, H, W): global count_ops, count_params count_ops = 0 count_params = 0 data = Variable(torch.zeros(1, 3, H, W)) def should_measure(x): return is_leaf(x) or is_pruned(x) def modify_forward(model): for child in model.children(): if should_measure(child): def new_forward(m): def lambda_forward(x): measure_layer(m, x) return m.old_forward(x) return 
lambda_forward child.old_forward = child.forward child.forward = new_forward(child) else: modify_forward(child) def restore_forward(model): for child in model.children(): # leaf node if is_leaf(child) and hasattr(child, 'old_forward'): child.forward = child.old_forward child.old_forward = None else: restore_forward(child) modify_forward(model) model.forward(data) restore_forward(model) return count_ops, count_params ================================================ FILE: models/__init__.py ================================================ from .wideresnet import * from .densenet import * ================================================ FILE: models/densenet.py ================================================ import torch import torch.nn as nn import torch.optim as optim import torch.nn.functional as F from torch.autograd import Variable import torchvision.datasets as dset import torchvision.transforms as transforms from torch.utils.data import DataLoader import torchvision.models as models import sys import math class Identity(nn.Module): def __init__(self): super(Identity, self).__init__() def forward(self, x): return x class Zero(nn.Module): def __init__(self): super(Zero, self).__init__() def forward(self, x): return x * 0 class ZeroMake(nn.Module): def __init__(self, channels, spatial): super(ZeroMake, self).__init__() self.spatial = spatial self.channels = channels def forward(self, x): return torch.zeros([x.size()[0], self.channels, x.size()[2] // self.spatial, x.size()[3] // self.spatial], dtype=x.dtype, layout=x.layout, device=x.device) class MaskBlock(nn.Module): def __init__(self, nChannels, growthRate): super(MaskBlock, self).__init__() interChannels = 4 * growthRate self.bn1 = nn.BatchNorm2d(nChannels) self.conv1 = nn.Conv2d(nChannels, interChannels, kernel_size=1, bias=False) self.bn2 = nn.BatchNorm2d(interChannels) self.conv2 = nn.Conv2d(interChannels, growthRate, kernel_size=3, padding=1, bias=False) self.activation = Identity() 
self.activation.register_backward_hook(self._fisher) self.register_buffer('mask', None) self.input_shape = None self.output_shape = None self.flops = None self.params = None self.in_channels = nChannels self.out_channels = growthRate self.stride = 1 # Fisher method is called on backward passes self.running_fisher = 0 def forward(self, x): out = self.conv1(F.relu(self.bn1(x))) out = F.relu(self.bn2(out)) if self.mask is not None: out = out * self.mask[None, :, None, None] else: self._create_mask(x, out) out = self.activation(out) self.act = out out = self.conv2(out) out = torch.cat([x, out], 1) return out def _create_mask(self, x, out): """This takes an activation to generate the exact mask required. It also records input and output shapes for posterity.""" self.mask = x.new_ones(out.shape[1]) self.input_shape = x.size() self.output_shape = out.size() def _fisher(self, _, __, grad_output): act = self.act.detach() grad = grad_output[0].detach() g_nk = (act * grad).sum(-1).sum(-1) del_k = g_nk.pow(2).mean(0).mul(0.5) self.running_fisher += del_k def reset_fisher(self): self.running_fisher = 0 * self.running_fisher def update(self, previous_mask): # This is only required for non-modular nets. 
        return None

    def cost(self):
        in_channels = self.in_channels
        out_channels = self.out_channels
        middle_channels = int(self.mask.sum().item())

        conv1_size = self.conv1.weight.size()
        conv2_size = self.conv2.weight.size()

        self.params = in_channels * middle_channels * conv1_size[2] * conv1_size[3] + middle_channels * out_channels * \
                      conv2_size[2] * conv2_size[3]

        self.params += 2 * in_channels + 2 * middle_channels

    def compress_weights(self):
        middle_dim = int(self.mask.sum().item())

        if middle_dim != 0:
            # conv1 is a 1x1 convolution in this block, so keep kernel_size consistent with its weights
            conv1 = nn.Conv2d(self.in_channels, middle_dim, kernel_size=1, stride=1, bias=False)
            conv1.weight = nn.Parameter(self.conv1.weight[self.mask == 1, :, :, :])

            # Batch norm 2 changes
            bn2 = nn.BatchNorm2d(middle_dim)
            bn2.weight = nn.Parameter(self.bn2.weight[self.mask == 1])
            bn2.bias = nn.Parameter(self.bn2.bias[self.mask == 1])
            bn2.running_mean = self.bn2.running_mean[self.mask == 1]
            bn2.running_var = self.bn2.running_var[self.mask == 1]

            conv2 = nn.Conv2d(middle_dim, self.out_channels, kernel_size=3, stride=1, padding=1, bias=False)
            conv2.weight = nn.Parameter(self.conv2.weight[:, self.mask == 1, :, :])

        if middle_dim == 0:
            # Every channel has been pruned: replace the layers with zero-producing stand-ins
            conv1 = Zero()
            bn2 = Zero()
            conv2 = ZeroMake(channels=self.out_channels, spatial=self.stride)

        self.conv1 = conv1
        self.conv2 = conv2
        self.bn2 = bn2

        if middle_dim != 0:
            self.mask = torch.ones(middle_dim)
        else:
            self.mask = torch.ones(1)


class Bottleneck(nn.Module):
    def __init__(self, nChannels, growthRate, width=1):
        super(Bottleneck, self).__init__()
        interChannels = int(4 * growthRate * width)
        self.bn1 = nn.BatchNorm2d(nChannels)
        self.conv1 = nn.Conv2d(nChannels, interChannels, kernel_size=1,
                               bias=False)
        self.bn2 = nn.BatchNorm2d(interChannels)
        self.conv2 = nn.Conv2d(interChannels, growthRate, kernel_size=3,
                               padding=1, bias=False)

    def forward(self, x):
        out = self.conv1(F.relu(self.bn1(x)))
        out = self.conv2(F.relu(self.bn2(out)))
        out = torch.cat((x, out), 1)
        return out


class SingleLayer(nn.Module):
    def __init__(self, nChannels, growthRate):
super(SingleLayer, self).__init__() self.bn1 = nn.BatchNorm2d(nChannels) self.conv1 = nn.Conv2d(nChannels, growthRate, kernel_size=3, padding=1, bias=False) def forward(self, x): out = self.conv1(F.relu(self.bn1(x))) out = torch.cat((x, out), 1) return out class Transition(nn.Module): def __init__(self, nChannels, nOutChannels): super(Transition, self).__init__() self.bn1 = nn.BatchNorm2d(nChannels) self.conv1 = nn.Conv2d(nChannels, nOutChannels, kernel_size=1, bias=False) def forward(self, x): out = self.conv1(F.relu(self.bn1(x))) out = F.avg_pool2d(out, 2) return out class DenseNet(nn.Module): def __init__(self, growthRate, depth, reduction, nClasses, bottleneck, mask=False, width=1.): super(DenseNet, self).__init__() nDenseBlocks = (depth - 4) // 3 if bottleneck: nDenseBlocks //= 2 nChannels = 2 * growthRate self.conv1 = nn.Conv2d(3, nChannels, kernel_size=3, padding=1, bias=False) self.dense1 = self._make_dense(nChannels, growthRate, nDenseBlocks, bottleneck, mask, width) nChannels += nDenseBlocks * growthRate nOutChannels = int(math.floor(nChannels * reduction)) self.trans1 = Transition(nChannels, nOutChannels) nChannels = nOutChannels self.dense2 = self._make_dense(nChannels, growthRate, nDenseBlocks, bottleneck, mask, width) nChannels += nDenseBlocks * growthRate nOutChannels = int(math.floor(nChannels * reduction)) self.trans2 = Transition(nChannels, nOutChannels) nChannels = nOutChannels self.dense3 = self._make_dense(nChannels, growthRate, nDenseBlocks, bottleneck, mask, width) nChannels += nDenseBlocks * growthRate self.bn1 = nn.BatchNorm2d(nChannels) self.fc = nn.Linear(nChannels, nClasses) # Count params that don't exist in blocks (conv1, bn1, fc, trans1, trans2, trans3) self.fixed_params = len(self.conv1.weight.view(-1)) + len(self.bn1.weight) + len(self.bn1.bias) + \ len(self.fc.weight.view(-1)) + len(self.fc.bias) self.fixed_params += len(self.trans1.conv1.weight.view(-1)) + 2 * len(self.trans1.bn1.weight) self.fixed_params += 
len(self.trans2.conv1.weight.view(-1)) + 2 * len(self.trans2.bn1.weight) for m in self.modules(): if isinstance(m, nn.Conv2d): n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels m.weight.data.normal_(0, math.sqrt(2. / n)) elif isinstance(m, nn.BatchNorm2d): m.weight.data.fill_(1) m.bias.data.zero_() elif isinstance(m, nn.Linear): m.bias.data.zero_() def _make_dense(self, nChannels, growthRate, nDenseBlocks, bottleneck, mask=False, width=1): layers = [] for i in range(int(nDenseBlocks)): if bottleneck and mask: layers.append(MaskBlock(nChannels, growthRate)) elif bottleneck: layers.append(Bottleneck(nChannels, growthRate, width)) else: layers.append(SingleLayer(nChannels, growthRate)) nChannels += growthRate return nn.Sequential(*layers) def forward(self, x): out = self.conv1(x) out = self.trans1(self.dense1(out)) out = self.trans2(self.dense2(out)) out = self.dense3(out) out = torch.squeeze(F.avg_pool2d(F.relu(self.bn1(out)), 8)) out = self.fc(out) return out ================================================ FILE: models/wideresnet.py ================================================ import math import torch import torch.nn as nn import torch.nn.functional as F class Identity(nn.Module): def __init__(self): super(Identity, self).__init__() def forward(self, x): return x class Zero(nn.Module): def __init__(self): super(Zero, self).__init__() def forward(self, x): return x * 0 class ZeroMake(nn.Module): def __init__(self, channels, spatial): super(ZeroMake, self).__init__() self.spatial = spatial self.channels = channels def forward(self, x): return torch.zeros([x.size()[0], self.channels, x.size()[2] // self.spatial, x.size()[3] // self.spatial], dtype=x.dtype, layout=x.layout, device=x.device) class BasicBlock(nn.Module): def __init__(self, in_channels, out_channels, stride, dropRate=0.0): super(BasicBlock, self).__init__() self.bn1 = nn.BatchNorm2d(in_channels) self.relu1 = nn.ReLU(inplace=True) self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, 
stride=stride, padding=1, bias=False) self.bn2 = nn.BatchNorm2d(out_channels) self.relu2 = nn.ReLU(inplace=True) self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False) self.droprate = dropRate self.equalInOut = (in_channels == out_channels) self.convShortcut = (not self.equalInOut) and nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, padding=0, bias=False) or None def forward(self, x): if not self.equalInOut: x = self.relu1(self.bn1(x)) else: out = self.relu1(self.bn1(x)) out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x))) if self.droprate > 0: out = F.dropout(out, p=self.droprate, training=self.training) out = self.conv2(out) return torch.add(x if self.equalInOut else self.convShortcut(x), out) class BottleBlock(nn.Module): def __init__(self, in_channels, out_channels, mid_channels, stride, dropRate=0.0): super(BottleBlock, self).__init__() self.bn1 = nn.BatchNorm2d(in_channels) self.relu1 = nn.ReLU(inplace=True) self.conv1 = nn.Conv2d(in_channels, mid_channels, kernel_size=3, stride=stride, padding=1, bias=False) self.bn2 = nn.BatchNorm2d(mid_channels) self.relu2 = nn.ReLU(inplace=True) self.conv2 = nn.Conv2d(mid_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False) self.droprate = dropRate self.equalInOut = (in_channels == out_channels) self.convShortcut = (not self.equalInOut) and nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, padding=0, bias=False) or None def forward(self, x): if not self.equalInOut: x = self.relu1(self.bn1(x)) else: out = self.relu1(self.bn1(x)) out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x))) if self.droprate > 0: out = F.dropout(out, p=self.droprate, training=self.training) out = self.conv2(out) return torch.add(x if self.equalInOut else self.convShortcut(x), out) class MaskBlock(nn.Module): expansion = 1 def __init__(self, in_channels, out_channels, stride=1, dropRate=0.0): super(MaskBlock, 
self).__init__() self.bn1 = nn.BatchNorm2d(in_channels) self.relu1 = nn.ReLU(inplace=True) self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False) self.bn2 = nn.BatchNorm2d(out_channels) self.relu2 = nn.ReLU(inplace=True) self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False) self.droprate = dropRate self.equalInOut = (in_channels == out_channels) self.convShortcut = (not self.equalInOut) and nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, padding=0, bias=False) or None self.activation = Identity() self.activation.register_backward_hook(self._fisher) self.register_buffer('mask', None) self.input_shape = None self.output_shape = None self.flops = None self.params = None self.in_channels = in_channels self.out_channels = out_channels self.stride = stride self.got_shapes = False # Fisher method is called on backward passes self.running_fisher = 0 def forward(self, x): if not self.equalInOut: x = self.relu1(self.bn1(x)) else: out = self.relu1(self.bn1(x)) out = self.conv1(out if self.equalInOut else x) out = self.relu2(self.bn2(out)) if self.mask is not None: out = out * self.mask[None, :, None, None] else: self._create_mask(x, out) out = self.activation(out) self.act = out if self.droprate > 0: out = F.dropout(out, p=self.droprate, training=self.training) out = self.conv2(out) return torch.add(x if self.equalInOut else self.convShortcut(x), out) def _create_mask(self, x, out): self.mask = x.new_ones(out.shape[1]) self.input_shape = x.size() self.output_shape = out.size() def _fisher(self, notused1, notused2, grad_output): act = self.act.detach() grad = grad_output[0].detach() g_nk = (act * grad).sum(-1).sum(-1) del_k = g_nk.pow(2).mean(0).mul(0.5) self.running_fisher += del_k def reset_fisher(self): self.running_fisher = 0 * self.running_fisher def cost(self): in_channels = self.in_channels out_channels = self.out_channels middle_channels = 
int(self.mask.sum().item())

        conv1_size = self.conv1.weight.size()
        conv2_size = self.conv2.weight.size()

        # convs
        self.params = in_channels * middle_channels * conv1_size[2] * conv1_size[3] + middle_channels * out_channels * \
                      conv2_size[2] * conv2_size[3]

        # batchnorms, assuming running stats are absorbed
        self.params += 2 * in_channels + 2 * middle_channels

        # skip
        if not self.equalInOut:
            self.params += in_channels * out_channels
        else:
            self.params += 0

    def compress_weights(self):
        middle_dim = int(self.mask.sum().item())
        print(middle_dim)

        if middle_dim != 0:
            conv1 = nn.Conv2d(self.in_channels, middle_dim, kernel_size=3, stride=self.stride, padding=1, bias=False)
            conv1.weight = nn.Parameter(self.conv1.weight[self.mask == 1, :, :, :])

            # Batch norm 2 changes
            bn2 = nn.BatchNorm2d(middle_dim)
            bn2.weight = nn.Parameter(self.bn2.weight[self.mask == 1])
            bn2.bias = nn.Parameter(self.bn2.bias[self.mask == 1])
            bn2.running_mean = self.bn2.running_mean[self.mask == 1]
            bn2.running_var = self.bn2.running_var[self.mask == 1]

            conv2 = nn.Conv2d(middle_dim, self.out_channels, kernel_size=3, stride=1, padding=1, bias=False)
            conv2.weight = nn.Parameter(self.conv2.weight[:, self.mask == 1, :, :])

        if middle_dim == 0:
            # Every channel has been pruned: replace the layers with zero-producing stand-ins
            conv1 = Zero()
            bn2 = Zero()
            conv2 = ZeroMake(channels=self.out_channels, spatial=self.stride)

        self.conv1 = conv1
        self.conv2 = conv2
        self.bn2 = bn2

        if middle_dim != 0:
            self.mask = torch.ones(middle_dim)
        else:
            self.mask = torch.ones(1)


class NetworkBlock(nn.Module):
    def __init__(self, nb_layers, in_channels, out_channels, block, stride, dropRate=0.0):
        super(NetworkBlock, self).__init__()
        self.layer = self._make_layer(block, in_channels, out_channels, nb_layers, stride, dropRate)

    def _make_layer(self, block, in_channels, out_channels, nb_layers, stride, dropRate):
        layers = []
        for i in range(int(nb_layers)):
            layers.append(block(i == 0 and in_channels or out_channels, out_channels, i == 0 and stride or 1, dropRate))
        return nn.Sequential(*layers)

    def forward(self, x):
        return
self.layer(x) class NetworkBlockBottle(nn.Module): def __init__(self, nb_layers, in_channels, out_channels, mid_channels, block, stride, dropRate=0.0): super(NetworkBlockBottle, self).__init__() self.layer = self._make_layer(block, in_channels, out_channels, mid_channels, nb_layers, stride, dropRate) def _make_layer(self, block, in_channels, out_channels, mid_channels, nb_layers, stride, dropRate): layers = [] for i in range(int(nb_layers)): layers.append( block(i == 0 and in_channels or out_channels, out_channels, mid_channels, i == 0 and stride or 1, dropRate)) return nn.Sequential(*layers) def forward(self, x): return self.layer(x) class WideResNet(nn.Module): def __init__(self, depth, widen_factor, num_classes=10, dropRate=0.0, mask=False): super(WideResNet, self).__init__() nChannels = [16, int(16 * widen_factor), int(32 * widen_factor), int(64 * widen_factor)] assert ((depth - 4) % 6 == 0) n = (depth - 4) / 6 if mask == 1: block = MaskBlock else: block = BasicBlock # 1st conv before any network block self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1, padding=1, bias=False) # 1st block self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate) # 2nd block self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropRate) # 3rd block self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropRate) # global average pooling and classifier self.bn1 = nn.BatchNorm2d(nChannels[3]) self.relu = nn.ReLU(inplace=True) self.fc = nn.Linear(nChannels[3], num_classes) self.nChannels = nChannels[3] # Count params that don't exist in blocks (conv1, bn1, fc) self.fixed_params = len(self.conv1.weight.view(-1)) + len(self.bn1.weight) + len(self.bn1.bias) + \ len(self.fc.weight.view(-1)) + len(self.fc.bias) for m in self.modules(): if isinstance(m, nn.Conv2d): n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels m.weight.data.normal_(0, math.sqrt(2. 
/ n)) elif isinstance(m, nn.BatchNorm2d): m.weight.data.fill_(1) m.bias.data.zero_() elif isinstance(m, nn.Linear): m.bias.data.zero_() def forward(self, x): out = self.conv1(x) out = self.block1(out) out = self.block2(out) out = self.block3(out) out = self.relu(self.bn1(out)) out = F.avg_pool2d(out, 8) out = out.view(-1, self.nChannels) return self.fc(out) class WideResNetBottle(nn.Module): def __init__(self, depth, widen_factor, num_classes=10, dropRate=0.0, bottle_mult=0.5): super(WideResNetBottle, self).__init__() nChannels = [16, int(16 * widen_factor), int(32 * widen_factor), int(64 * widen_factor)] assert ((depth - 4) % 6 == 0) n = (depth - 4) / 6 block = BottleBlock # 1st conv before any network block self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1, padding=1, bias=False) # 1st block self.block1 = NetworkBlockBottle(n, nChannels[0], nChannels[1], int(nChannels[1] * bottle_mult), block, 1, dropRate) # 2nd block self.block2 = NetworkBlockBottle(n, nChannels[1], nChannels[2], int(nChannels[2] * bottle_mult), block, 2, dropRate) # 3rd block self.block3 = NetworkBlockBottle(n, nChannels[2], nChannels[3], int(nChannels[3] * bottle_mult), block, 2, dropRate) # global average pooling and classifier self.bn1 = nn.BatchNorm2d(nChannels[3]) self.relu = nn.ReLU(inplace=True) self.fc = nn.Linear(nChannels[3], num_classes) self.nChannels = nChannels[3] for m in self.modules(): if isinstance(m, nn.Conv2d): n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels m.weight.data.normal_(0, math.sqrt(2. 
/ n)) elif isinstance(m, nn.BatchNorm2d): m.weight.data.fill_(1) m.bias.data.zero_() elif isinstance(m, nn.Linear): m.bias.data.zero_() def forward(self, x): out = self.conv1(x) out = self.block1(out) out = self.block2(out) out = self.block3(out) out = self.relu(self.bn1(out)) out = F.avg_pool2d(out, 8) out = out.view(-1, self.nChannels) return self.fc(out) ================================================ FILE: prune.py ================================================ """Pruning script""" import argparse import os import time import torch.utils.model_zoo as model_zoo from funcs import * from models import * parser = argparse.ArgumentParser(description='Pruning') parser.add_argument('-j', '--workers', default=0, type=int, metavar='N', help='number of data loading workers') parser.add_argument('--GPU', default='0', type=str, help='GPU to use') parser.add_argument('--save_file', default='wrn16_2_p', type=str, help='save file for checkpoints') parser.add_argument('--print_freq', '-p', default=10, type=int, metavar='N', help='print frequency (default: 10)') parser.add_argument('--resume', '-r', action='store_true', help='resume from checkpoint') parser.add_argument('--resume_ckpt', default='checkpoint', type=str, help='name of the checkpoint to resume from') parser.add_argument('--data_loc', default='/disk/scratch/datasets/cifar', type=str, help='where is the dataset') # Learning specific arguments parser.add_argument('--optimizer', choices=['sgd', 'adam'], default='sgd', type=str, help='optimizer') parser.add_argument('-b', '--batch_size', default=128, type=int, metavar='N', help='mini-batch size (default: 128)') parser.add_argument('-lr', '--learning_rate', default=8e-4, type=float, metavar='LR', help='initial learning rate') parser.add_argument('-epochs', '--no_epochs', default=1300, type=int, metavar='epochs', help='no.
epochs') parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum') parser.add_argument('--weight_decay', '--wd', default=0.0005, type=float, metavar='W', help='weight decay') parser.add_argument('--prune_every', default=100, type=int, help='prune every X steps') parser.add_argument('--save_every', default=100, type=int, help='save model every X EPOCHS') parser.add_argument('--random', default=False, type=bool, help='Prune at random (type=bool: any non-empty value enables it)') parser.add_argument('--base_model', default='base_model', type=str, help='name of the base model checkpoint') parser.add_argument('--val_every', default=1, type=int, help='val model every X EPOCHS') parser.add_argument('--mask', default=1, type=int, help='Mask type') parser.add_argument('--l1_prune', default=False, type=bool, help='Prune via l1 norm (type=bool: any non-empty value enables it)') parser.add_argument('--net', default='dense', type=str, help='dense, res') parser.add_argument('--width', default=2.0, type=float, metavar='W') parser.add_argument('--depth', default=40, type=int, metavar='D') parser.add_argument('--growth', default=12, type=int, help='growth rate of densenet') parser.add_argument('--transition_rate', default=0.5, type=float, help='transition rate of densenet') args = parser.parse_args() print(args) os.environ["CUDA_VISIBLE_DEVICES"] = args.GPU device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") if args.net == 'res': model = WideResNet(args.depth, args.width, mask=args.mask) elif args.net == 'dense': model = DenseNet(args.growth, args.depth, args.transition_rate, 10, True, mask=args.mask) model.load_state_dict(torch.load('checkpoints/%s.t7' % args.base_model, map_location='cpu')['state_dict'], strict=True) if args.resume: state = torch.load('checkpoints/%s.t7' % args.resume_ckpt, map_location='cpu') model = resume_from(state, model_type=args.net) error_history = state['error_history'] prune_history = state['prune_history'] flop_history = state.get('flop_history', []) param_history = state['param_history'] start_epoch =
state['epoch'] else: error_history = [] prune_history = [] param_history = [] start_epoch = 0 model.to(device) normMean = [0.49139968, 0.48215827, 0.44653124] normStd = [0.24703233, 0.24348505, 0.26158768] normTransform = transforms.Normalize(normMean, normStd) print('==> Preparing data..') num_classes = 10 transform_train = transforms.Compose([ transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(), transforms.ToTensor(), normTransform ]) transform_val = transforms.Compose([ transforms.ToTensor(), normTransform ]) trainset = torchvision.datasets.CIFAR10(root=args.data_loc, train=True, download=True, transform=transform_train) valset = torchvision.datasets.CIFAR10(root=args.data_loc, train=False, download=True, transform=transform_val) trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, pin_memory=False) valloader = torch.utils.data.DataLoader(valset, batch_size=50, shuffle=False, num_workers=args.workers, pin_memory=False) prune_count = 0 pruner = Pruner() pruner.prune_history = prune_history NO_STEPS = args.prune_every def finetune(): batch_time = AverageMeter() data_time = AverageMeter() losses = AverageMeter() top1 = AverageMeter() top5 = AverageMeter() # switch to train mode model.train() end = time.time() dataiter = iter(trainloader) for i in range(0, NO_STEPS): try: input, target = next(dataiter) except StopIteration: dataiter = iter(trainloader) input, target = next(dataiter) # measure data loading time data_time.update(time.time() - end) input, target = input.to(device), target.to(device) # compute output output = model(input) loss = criterion(output, target) # measure accuracy and record loss err1, err5 = get_error(output.detach(), target, topk=(1, 5)) losses.update(loss.item(), input.size(0)) top1.update(err1.item(), input.size(0)) top5.update(err5.item(), input.size(0)) # compute gradient and do SGD step optimizer.zero_grad() loss.backward() optimizer.step() # measure
elapsed time batch_time.update(time.time() - end) end = time.time() if i % args.print_freq == 0: print('Prunepoch: [{0}][{1}/{2}]\t' 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 'Error@1 {top1.val:.3f} ({top1.avg:.3f})\t' 'Error@5 {top5.val:.3f} ({top5.avg:.3f})'.format( epoch, i, NO_STEPS, batch_time=batch_time, data_time=data_time, loss=losses, top1=top1, top5=top5)) def prune(): print('Pruning') if args.random is False: if args.l1_prune is False: print('fisher pruning') pruner.fisher_prune(model, prune_every=args.prune_every) else: print('l1 pruning') pruner.l1_prune(model, prune_every=args.prune_every) else: print('random pruning') pruner.random_prune(model, ) def validate(): global error_history batch_time = AverageMeter() data_time = AverageMeter() losses = AverageMeter() top1 = AverageMeter() top5 = AverageMeter() # switch to evaluate mode model.eval() end = time.time() for i, (input, target) in enumerate(valloader): # measure data loading time data_time.update(time.time() - end) input, target = input.to(device), target.to(device) # compute output output = model(input) loss = criterion(output, target) # measure accuracy and record loss err1, err5 = get_error(output.detach(), target, topk=(1, 5)) losses.update(loss.item(), input.size(0)) top1.update(err1.item(), input.size(0)) top5.update(err5.item(), input.size(0)) # measure elapsed time batch_time.update(time.time() - end) end = time.time() if i % args.print_freq == 0: print('Test: [{0}/{1}]\t' 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 'Error@1 {top1.val:.3f} ({top1.avg:.3f})\t' 'Error@5 {top5.val:.3f} ({top5.avg:.3f})'.format( i, len(valloader), batch_time=batch_time, loss=losses, top1=top1, top5=top5)) print(' * Error@1 {top1.avg:.3f} Error@5 {top5.avg:.3f}' .format(top1=top1, top5=top5)) # Record Top 1 for CIFAR error_history.append(top1.avg) if __name__ 
== '__main__': criterion = nn.CrossEntropyLoss() optimizer = torch.optim.SGD([v for v in model.parameters() if v.requires_grad], lr=args.learning_rate, momentum=args.momentum, weight_decay=args.weight_decay) for epoch in range(start_epoch, args.no_epochs): print('Epoch %d:' % epoch) print('Learning rate is %s' % [v['lr'] for v in optimizer.param_groups][0]) # finetune for one epoch finetune() # # evaluate on validation set if epoch != 0 and ((epoch % args.val_every == 0) or (epoch + 1 == args.no_epochs)): # Save at last epoch! validate() # Error history is recorded in validate(). Record params here no_params = pruner.get_cost(model) + model.fixed_params param_history.append(no_params) # Save before pruning if epoch != 0 and ((epoch % args.save_every == 0) or (epoch + 1 == args.no_epochs)): # filename = 'checkpoints/%s_%d_prunes.t7' % (args.save_file, epoch) save_checkpoint({ 'epoch': epoch + 1, 'state_dict': model.state_dict(), 'error_history': error_history, 'param_history': param_history, 'prune_history': pruner.prune_history, }, filename=filename) ## Prune prune() ================================================ FILE: train.py ================================================ """This script just trains models from scratch, to later be pruned""" import argparse import json import os import time import torch.optim.lr_scheduler as lr_scheduler import torch.utils.model_zoo as model_zoo from models import * from funcs import * parser = argparse.ArgumentParser(description='Pruning') parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', help='number of data loading workers') parser.add_argument('--GPU', default='0', type=str, help='GPU to use') parser.add_argument('--save_file', default='saveto', type=str, help='save file for checkpoints') parser.add_argument('--base_file', default='bbb', type=str, help='base file for checkpoints') parser.add_argument('--print_freq', '-p', default=10, type=int, metavar='N', help='print frequency (default: 10)') 
parser.add_argument('--data_loc', default='/disk/scratch/datasets/cifar') # Learning specific arguments parser.add_argument('-b', '--batch_size', default=128, type=int, metavar='N', help='mini-batch size (default: 128)') parser.add_argument('-lr', '--learning_rate', default=.1, type=float, metavar='LR', help='initial learning rate') parser.add_argument('-epochs', '--no_epochs', default=200, type=int, metavar='epochs', help='no. epochs') parser.add_argument('--epoch_step', default='[60,120,160]', type=str, help='json list with epochs to drop lr on') parser.add_argument('--lr_decay_ratio', default=0.2, type=float, help='learning rate decay factor') parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum') parser.add_argument('--weight_decay', '--wd', default=0.0005, type=float, metavar='W', help='weight decay') parser.add_argument('--eval', '-e', action='store_true', help='evaluate a saved checkpoint instead of training') parser.add_argument('--mask', '-m', type=int, help='mask mode', default=0) parser.add_argument('--deploy', '-de', action='store_true', help='prune and deploy model') parser.add_argument('--params_left', '-pl', default=0, type=int, help='prune til...') parser.add_argument('--net', choices=['res', 'dense'], default='res') # Net specific parser.add_argument('--depth', '-d', default=40, type=int, metavar='D', help='depth of wideresnet/densenet') parser.add_argument('--width', '-w', default=2.0, type=float, metavar='W', help='width of wideresnet') parser.add_argument('--growth', default=12, type=int, help='growth rate of densenet') parser.add_argument('--transition_rate', default=0.5, type=float, help='transition rate of densenet') # Uniform bottlenecks parser.add_argument('--bottle', action='store_true', help='Linearly scale bottlenecks') parser.add_argument('--bottle_mult', default=0.5, type=float, help='bottleneck multiplier') if not os.path.exists('checkpoints/'): os.makedirs('checkpoints/') args = parser.parse_args() print(args)
os.environ["CUDA_VISIBLE_DEVICES"] = args.GPU device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") if args.net == 'res': if not args.bottle: model = WideResNet(args.depth, args.width, mask=args.mask) else: model = WideResNetBottle(args.depth, args.width, bottle_mult=args.bottle_mult) elif args.net == 'dense': if not args.bottle: model = DenseNet(args.growth, args.depth, args.transition_rate, 10, True, mask=args.mask) else: model = DenseNet(args.growth, args.depth, args.transition_rate, 10, True, width=args.bottle_mult) else: raise ValueError('pick a valid net') pruner = Pruner() if args.deploy: # Feed example to activate masks model(torch.rand(1, 3, 32, 32)) SD = torch.load('checkpoints/%s.t7' % args.base_file) if not args.eval: pruner = Pruner() pruner._get_masks(model) for ii in SD['prune_history']: pruner.fixed_prune(model, ii) else: model.load_state_dict(SD['state_dict']) pruner.compress(model) get_inf_params(model) time.sleep(1) model.to(device) normalize = transforms.Normalize(mean=[x / 255.0 for x in [125.3, 123.0, 113.9]], std=[x / 255.0 for x in [63.0, 62.1, 66.7]]) print('==> Preparing data..') num_classes = 10 transform_train = transforms.Compose([ transforms.ToTensor(), transforms.Lambda(lambda x: F.pad(x.unsqueeze(0), (4, 4, 4, 4), mode='reflect').squeeze()), transforms.ToPILImage(), transforms.RandomCrop(32), transforms.RandomHorizontalFlip(), transforms.ToTensor(), normalize, ]) transform_val = transforms.Compose([ transforms.ToTensor(), normalize, ]) trainset = torchvision.datasets.CIFAR10(root=args.data_loc, train=True, download=True, transform=transform_train) valset = torchvision.datasets.CIFAR10(root=args.data_loc, train=False, download=True, transform=transform_val) trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, pin_memory=False) valloader = torch.utils.data.DataLoader(valset, batch_size=50, shuffle=False, num_workers=args.workers, pin_memory=False) 
error_history = [] epoch_step = json.loads(args.epoch_step) def train(): batch_time = AverageMeter() data_time = AverageMeter() losses = AverageMeter() top1 = AverageMeter() top5 = AverageMeter() # switch to train mode model.train() end = time.time() for i, (input, target) in enumerate(trainloader): # measure data loading time data_time.update(time.time() - end) input, target = input.to(device), target.to(device) # compute output output = model(input) loss = criterion(output, target) # measure accuracy and record loss err1, err5 = get_error(output.detach(), target, topk=(1, 5)) losses.update(loss.item(), input.size(0)) top1.update(err1.item(), input.size(0)) top5.update(err5.item(), input.size(0)) # compute gradient and do SGD step optimizer.zero_grad() loss.backward() optimizer.step() # measure elapsed time batch_time.update(time.time() - end) end = time.time() if i % args.print_freq == 0: print('Epoch: [{0}][{1}/{2}]\t' 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 'Error@1 {top1.val:.3f} ({top1.avg:.3f})\t' 'Error@5 {top5.val:.3f} ({top5.avg:.3f})'.format( epoch, i, len(trainloader), batch_time=batch_time, data_time=data_time, loss=losses, top1=top1, top5=top5)) def validate(): global error_history batch_time = AverageMeter() data_time = AverageMeter() losses = AverageMeter() top1 = AverageMeter() top5 = AverageMeter() # switch to evaluate mode model.eval() end = time.time() for i, (input, target) in enumerate(valloader): # measure data loading time data_time.update(time.time() - end) input, target = input.to(device), target.to(device) # compute output output = model(input) loss = criterion(output, target) # measure accuracy and record loss err1, err5 = get_error(output.detach(), target, topk=(1, 5)) losses.update(loss.item(), input.size(0)) top1.update(err1.item(), input.size(0)) top5.update(err5.item(), input.size(0)) # measure elapsed time 
batch_time.update(time.time() - end) end = time.time() if i % args.print_freq == 0: print('Test: [{0}/{1}]\t' 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 'Error@1 {top1.val:.3f} ({top1.avg:.3f})\t' 'Error@5 {top5.val:.3f} ({top5.avg:.3f})'.format( i, len(valloader), batch_time=batch_time, loss=losses, top1=top1, top5=top5)) print(' * Error@1 {top1.avg:.3f} Error@5 {top5.avg:.3f}' .format(top1=top1, top5=top5)) # Record Top 1 for CIFAR error_history.append(top1.avg) if __name__ == '__main__': filename = 'checkpoints/%s.t7' % args.save_file criterion = nn.CrossEntropyLoss() optimizer = torch.optim.SGD([v for v in model.parameters() if v.requires_grad], lr=args.learning_rate, momentum=args.momentum, weight_decay=args.weight_decay) scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=epoch_step, gamma=args.lr_decay_ratio) if not args.eval: for epoch in range(args.no_epochs): print('Epoch %d:' % epoch) print('Learning rate is %s' % [v['lr'] for v in optimizer.param_groups][0]) # train for one epoch train() scheduler.step() # # evaluate on validation set validate() save_checkpoint({ 'epoch': epoch + 1, 'state_dict': model.state_dict(), 'error_history': error_history, }, filename=filename) else: if not args.deploy: model.load_state_dict(torch.load(filename)['state_dict']) epoch = 0 validate()