Repository: BayesWatch/pytorch-prunes
Branch: master
Commit: bc85a5c52865
Files: 8
Total size: 56.7 KB
Directory structure:
gitextract_3q8rylo5/
├── LICENSE
├── README.md
├── funcs.py
├── models/
│ ├── __init__.py
│ ├── densenet.py
│ └── wideresnet.py
├── prune.py
└── train.py
================================================
FILE CONTENTS
================================================
================================================
FILE: LICENSE
================================================
MIT License
Copyright (c) 2019 BayesWatch
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
================================================
FILE: README.md
================================================
# A Closer Look at Structured Pruning for Neural Network Compression
Code used to reproduce experiments in https://arxiv.org/abs/1810.04622.
To prune, we fill our networks with custom `MaskBlocks`, which are manipulated using `Pruner` in funcs.py. There will certainly be a better way to do this, but we leave this as an exercise to someone who can code much better than we can.
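As a rough illustration of the masking idea, here is a minimal sketch in plain Python rather than the repository's actual `MaskBlock`/`Pruner` code. The helper names `apply_mask` and `prune_channel` are hypothetical. Each block keeps a binary mask with one entry per intermediate channel, multiplies its activations by that mask, and pruning a channel amounts to setting its entry to zero:

```python
def apply_mask(channel_activations, mask):
    """Zero every channel whose mask entry is 0 (one mask entry per channel)."""
    return [
        [value * keep for value in channel]
        for channel, keep in zip(channel_activations, mask)
    ]


def prune_channel(mask, index):
    """Return a copy of the mask with one more channel switched off."""
    if mask[index] == 0:
        raise ValueError("channel %d is already pruned" % index)
    new_mask = list(mask)
    new_mask[index] = 0
    return new_mask


# Three channels with two spatial positions each (toy values).
activations = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
mask = [1, 1, 1]
mask = prune_channel(mask, 1)
masked = apply_mask(activations, mask)
print(masked)  # [[1.0, 2.0], [0.0, 0.0], [5.0, 6.0]]
```

In the repository the mask is a tensor buffer on each `MaskBlock`, and `Pruner.prune` maps a global channel index onto the block that owns it.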
## Setup
This is best done in a clean conda environment:
```
conda create -n prunes python=3.6
conda activate prunes
conda install pytorch torchvision -c pytorch
```
## Repository layout
- `train.py`: contains all of the code for training large models from scratch and for training pruned models from scratch
- `prune.py`: contains the code for pruning trained models
- `funcs.py`: contains useful pruning functions and any functions we used commonly
## CIFAR Experiments
First, you will need some initial models.
To train a WRN-40-2:
```
python train.py --net='res' --depth=40 --width=2.0 --data_loc=<path-to-data> --save_file='res'
```
The default arguments of train.py are suitable for training WRNs. The following trains a DenseNet-BC-100 (k=12) with its default hyperparameters:
```
python train.py --net='dense' --depth=100 --data_loc=<path-to-data> --save_file='dense' --no_epochs 300 -b 64 --epoch_step '[150,225]' --weight_decay 0.0001 --lr_decay_ratio 0.1
```
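To see what `--depth=100` with k=12 means in terms of layers and channels, the following sketch mirrors the arithmetic in `DenseNet.__init__` (models/densenet.py). The function name `densenet_channels` is hypothetical, and the compression factor `reduction=0.5` is an assumption based on the usual DenseNet-BC setting, since the value passed by train.py is not shown here:

```python
import math


def densenet_channels(depth, growth_rate, reduction=0.5, bottleneck=True):
    """Trace the channel count through the three dense stages."""
    n_layers = (depth - 4) // 3
    if bottleneck:
        n_layers //= 2  # each bottleneck layer holds two convolutions
    channels = 2 * growth_rate  # width of the first convolution
    trace = [("stem", channels)]
    for stage in (1, 2, 3):
        channels += n_layers * growth_rate  # every layer adds growth_rate maps
        trace.append(("dense%d" % stage, channels))
        if stage < 3:
            channels = int(math.floor(channels * reduction))
            trace.append(("trans%d" % stage, channels))
    return n_layers, trace


n_layers, trace = densenet_channels(depth=100, growth_rate=12)
print(n_layers)  # 16
print(trace)
```

Under these assumptions each dense stage holds 16 bottleneck layers and the classifier sees 342 channels.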
These will automatically save checkpoints to the `checkpoints` folder.
### Pruning
Once training is finished, we can prune our networks using prune.py (the defaults are set for WRN pruning, so DenseNets need extra arguments):
```
python prune.py --net='res' --data_loc=<path-to-data> --base_model='res' --save_file='res_fisher'
python prune.py --net='res' --data_loc=<path-to-data> --l1_prune=True --base_model='res' --save_file='res_l1'
python prune.py --net='dense' --depth 100 --data_loc=<path-to-data> --base_model='dense' --save_file='dense_fisher' --learning_rate 1e-3 --weight_decay 1e-4 --batch_size 64 --no_epochs 2600
python prune.py --net='dense' --depth 100 --data_loc=<path-to-data> --l1_prune=True --base_model='dense' --save_file='dense_l1' --learning_rate 1e-3 --weight_decay 1e-4 --batch_size 64 --no_epochs 2600
```
Note that the default is to perform Fisher pruning, so you don't need to pass a flag to use it.
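For intuition, the per-channel Fisher saliency accumulated by `MaskBlock._fisher` can be written out in plain Python. This is a simplified sketch, not the repository's tensor code: the names `fisher_saliency` and `pick_channel` are hypothetical, and inputs are nested lists indexed `[example][channel][position]`. For each channel, the activation-gradient products are summed over spatial positions, squared, averaged over examples, and halved. The active channel with the smallest value is the next one to prune:

```python
def fisher_saliency(activations, gradients):
    """Return 0.5 * mean_n (sum_positions a * g)^2 for every channel."""
    n_examples = len(activations)
    n_channels = len(activations[0])
    saliency = []
    for k in range(n_channels):
        total = 0.0
        for n in range(n_examples):
            g_nk = sum(a * g for a, g in zip(activations[n][k], gradients[n][k]))
            total += g_nk ** 2
        saliency.append(0.5 * total / n_examples)
    return saliency


def pick_channel(saliency, mask):
    """Index of the still-active channel with the smallest saliency."""
    active = [k for k, keep in enumerate(mask) if keep == 1]
    return min(active, key=lambda k: saliency[k])


# Two examples, two channels, two spatial positions (toy values).
acts = [[[1.0, 2.0], [0.5, 0.5]], [[2.0, 0.0], [1.0, 1.0]]]
grads = [[[1.0, 1.0], [2.0, 2.0]], [[1.0, 1.0], [0.0, 0.0]]]
saliency = fisher_saliency(acts, grads)
print(saliency)                        # [3.25, 1.0]
print(pick_channel(saliency, [1, 1]))  # 1
```

`Pruner.fisher_prune` additionally divides the accumulated values by `prune_every` and adds a large constant to channels that are already masked off so they are never picked again.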
Once finished, we can train the pruned models from scratch, e.g.:
```
python train.py --data_loc=<path-to-data> --net='res' --base_file='res_fisher_<N>_prunes' --deploy --mask=1 --save_file='res_fisher_<N>_prunes_scratch'
```
Each model can then be evaluated using:
```
python train.py --deploy --eval --data_loc=<path-to-data> --net='res' --mask=1 --base_file='res_fisher_<N>_prunes'
```
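The evaluation reports a top-k error in percent, which `get_error` in funcs.py computes on tensors. A minimal plain-Python sketch of the same quantity, with the hypothetical name `top_k_error` and made-up scores, looks like this:

```python
def top_k_error(scores, targets, k=1):
    """Percentage of examples whose target is not among the k highest scores."""
    wrong = 0
    for row, target in zip(scores, targets):
        ranked = sorted(range(len(row)), key=lambda c: row[c], reverse=True)
        if target not in ranked[:k]:
            wrong += 1
    return 100.0 * wrong / len(targets)


# Four examples, three classes (toy values).
scores = [
    [0.1, 0.7, 0.2],
    [0.5, 0.3, 0.2],
    [0.2, 0.3, 0.5],
    [0.6, 0.3, 0.1],
]
targets = [1, 1, 2, 0]
print(top_k_error(scores, targets, k=1))  # 25.0
print(top_k_error(scores, targets, k=2))  # 0.0
```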
### Training Reduced models
This can be done by varying the input arguments to train.py. To reduce depth or width of a WRN, change the corresponding option:
```
python train.py --net='res' --depth=<REDUCED DEPTH> --width=<REDUCED WIDTH> --data_loc=<path-to-data> --save_file='res_reduced'
```
To add bottlenecks, use the following:
```
python train.py --net='res' --depth=40 --width=2.0 --data_loc=<path-to-data> --save_file='res_bottle' --bottle --bottle_mult <Z>
```
With DenseNets you can modify the `depth` or `growth`, or use `--bottle --bottle_mult <Z>` as above.
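When choosing reduced settings for a WRN, it can help to see how depth, width, and the bottleneck multiplier translate into block counts and channel widths. The sketch below mirrors the arithmetic in `WideResNet` and `WideResNetBottle` (models/wideresnet.py). The function name `wrn_config` is hypothetical:

```python
def wrn_config(depth, width, bottle_mult=None):
    """Blocks per group and channel widths for a WRN-depth-width network."""
    if (depth - 4) % 6 != 0:
        raise ValueError("depth must be of the form 6n + 4")
    blocks_per_group = (depth - 4) // 6
    channels = [16, int(16 * width), int(32 * width), int(64 * width)]
    mid_channels = None
    if bottle_mult is not None:
        # Bottleneck blocks shrink the intermediate width of each group.
        mid_channels = [int(c * bottle_mult) for c in channels[1:]]
    return blocks_per_group, channels, mid_channels


print(wrn_config(40, 2.0))       # (6, [16, 32, 64, 128], None)
print(wrn_config(16, 2.0))       # (2, [16, 32, 64, 128], None)
print(wrn_config(40, 2.0, 0.5))  # (6, [16, 32, 64, 128], [16, 32, 64])
```

Depths that are not of the form 6n + 4 are rejected, matching the assertion in the model constructors.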
### Acknowledgements
[Jack Turner][jack] wrote the L1 stuff, and some other stuff for that matter.
Code has been liberally borrowed from many a repo, including, but not limited to:
```
https://github.com/xternalz/WideResNet-pytorch
https://github.com/bamos/densenet.pytorch
https://github.com/kuangliu/pytorch-cifar
https://github.com/ShichenLiu/CondenseNet
```
### Citing this work
If you would like to cite this work, please use the following bibtex entry:
```
@article{crowley2018pruning,
title={A Closer Look at Structured Pruning for Neural Network Compression},
author={Crowley, Elliot J and Turner, Jack and Storkey, Amos and O'Boyle, Michael},
journal={arXiv preprint arXiv:1810.04622},
year={2018},
}
```
[jack]: https://github.com/jack-willturner
================================================
FILE: funcs.py
================================================
import operator
import random
import time
from functools import reduce
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
from models import *
class Pruner:
def __init__(self, module_name='MaskBlock'):
# First get vector of masks
self.module_name = module_name
self.masks = []
self.prune_history = []
def fisher_prune(self, model, prune_every):
self._get_fisher(model)
tot_loss = self.fisher.div(prune_every) + 1e6 * (1 - self.masks) # dummy value for off masks
print(len(tot_loss))
_, argmin = torch.min(tot_loss, 0)  # do not shadow the built-in min
self.prune(model, argmin.item())
self.prune_history.append(argmin.item())
def fixed_prune(self, model, ID):
self.prune(model, ID)
self.prune_history.append(ID)
def random_prune(self, model):
self._get_fisher(model)
# Do this to update costs.
masks = []
for m in model.modules():
if m._get_name() == self.module_name:
masks.append(m.mask.detach())
masks = self.concat(masks)
masks_on = [i for i, v in enumerate(masks) if v == 1]
random_pick = random.choice(masks_on)
self.prune(model, random_pick)
self.prune_history.append(random_pick)
def l1_prune(self, model, prune_every):
masks = []
l1_norms = []
for m in model.modules():
if m._get_name() == self.module_name:
# L1 norm of each output filter: sum of absolute weights
l1_norm = torch.sum(torch.abs(m.conv1.weight), (1, 2, 3)).detach().cpu().numpy()
masks.append(m.mask.detach())
l1_norms.append(l1_norm)
masks = self.concat(masks)
self.masks = masks
l1_norms = np.concatenate(l1_norms)
l1_norms_on = []
for m, l in zip(masks, l1_norms):
if m == 1:
l1_norms_on.append(l)
else:
l1_norms_on.append(9999.) # dummy value
smallest_norm = min(l1_norms_on)
pick = np.where(l1_norms == smallest_norm)[0][0]
self.prune(model, pick)
self.prune_history.append(pick)
def prune(self, model, feat_index):
"""feat_index refers to the index of a feature map. This function modifies the mask to turn it off."""
print('Pruned %d out of %d channels so far' % (len(self.prune_history), len(self.masks)))
if len(self.prune_history) > len(self.masks):
raise Exception('Time to stop')
safe = 0
running_index = 0
for m in model.modules():
if m._get_name() == self.module_name:
mask_indices = range(running_index, running_index + len(m.mask))
if feat_index in mask_indices:
print('Pruning channel %d' % feat_index)
local_index = mask_indices.index(feat_index)
m.mask[local_index] = 0
safe = 1
break
else:
running_index += len(m.mask)
# print(running_index)
if not safe:
raise Exception("The provided index doesn't correspond to any feature maps. This is bad.")
def compress(self, model):
for m in model.modules():
if m._get_name() == self.module_name:
m.compress_weights()
def _get_fisher(self, model):
masks = []
fisher = []
self._update_cost(model)
for m in model.modules():
if m._get_name() == self.module_name:
masks.append(m.mask.detach())
fisher.append(m.running_fisher.detach())
# Now clear the fisher cache
m.reset_fisher()
self.masks = self.concat(masks)
self.fisher = self.concat(fisher)
def _get_masks(self, model):
masks = []
for m in model.modules():
if m._get_name() == self.module_name:
masks.append(m.mask.detach())
self.masks = self.concat(masks)
def _update_cost(self, model):
for m in model.modules():
if m._get_name() == self.module_name:
m.cost()
def get_cost(self, model):
params = 0
for m in model.modules():
if m._get_name() == self.module_name:
m.cost()
params += m.params
return params
@staticmethod
def concat(input):
return torch.cat([item for item in input])
def find(input):
# Find as in MATLAB to find indices in a binary vector
return [i for i, j in enumerate(input) if j]
def concat(input):
return torch.cat([item for item in input])
def save_checkpoint(state, filename='checkpoint.pth.tar'):
torch.save(state, filename)
def get_error(output, target, topk=(1,)):
"""Computes the error@k for the specified values of k"""
maxk = max(topk)
batch_size = target.size(0)
_, pred = output.topk(maxk, 1, True, True)
pred = pred.t()
correct = pred.eq(target.view(1, -1).expand_as(pred))
res = []
for k in topk:
correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
res.append(100.0 - correct_k.mul_(100.0 / batch_size))
return res
class AverageMeter(object):
"""Computes and stores the average and current value"""
def __init__(self):
self.reset()
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
def get_inf_params(net, verbose=True, sd=False):
if sd:
params = net
else:
params = net.state_dict()
tot = 0
conv_tot = 0
for p in params:
no = params[p].view(-1).__len__()
if ('num_batches_tracked' not in p) and ('running' not in p) and ('mask' not in p):
tot += no
if verbose:
print('%s has %d params' % (p, no))
if 'conv' in p:
conv_tot += no
if verbose:
print('Net has %d conv params' % conv_tot)
print('Net has %d params in total' % tot)
return tot
count_ops = 0
count_params = 0
def get_num_gen(gen):
return sum(1 for x in gen)
def is_pruned(layer):
try:
layer.mask
return True
except AttributeError:
return False
def is_leaf(model):
return get_num_gen(model.children()) == 0
def get_layer_info(layer):
layer_str = str(layer)
type_name = layer_str[:layer_str.find('(')].strip()
return type_name
def get_layer_param(model):
return sum([reduce(operator.mul, i.size(), 1) for i in model.parameters()])
### The input batch size should be 1 to call this function
def measure_layer(layer, x):
global count_ops, count_params
delta_ops = 0
delta_params = 0
multi_add = 1
type_name = get_layer_info(layer)
### ops_conv
if type_name in ['Conv2d']:
out_h = int((x.size()[2] + 2 * layer.padding[0] - layer.kernel_size[0]) /
layer.stride[0] + 1)
out_w = int((x.size()[3] + 2 * layer.padding[1] - layer.kernel_size[1]) /
layer.stride[1] + 1)
delta_ops = layer.in_channels * layer.out_channels * layer.kernel_size[0] * \
layer.kernel_size[1] * out_h * out_w / layer.groups * multi_add
delta_params = get_layer_param(layer)
### ops_learned_conv
elif type_name in ['LearnedGroupConv']:
measure_layer(layer.relu, x)
measure_layer(layer.norm, x)
conv = layer.conv
out_h = int((x.size()[2] + 2 * conv.padding[0] - conv.kernel_size[0]) /
conv.stride[0] + 1)
out_w = int((x.size()[3] + 2 * conv.padding[1] - conv.kernel_size[1]) /
conv.stride[1] + 1)
delta_ops = conv.in_channels * conv.out_channels * conv.kernel_size[0] * \
conv.kernel_size[1] * out_h * out_w / layer.condense_factor * multi_add
delta_params = get_layer_param(conv) / layer.condense_factor
### ops_nonlinearity
elif type_name in ['ReLU']:
delta_ops = x.numel()
delta_params = get_layer_param(layer)
### ops_pooling
elif type_name in ['AvgPool2d']:
in_w = x.size()[2]
kernel_ops = layer.kernel_size * layer.kernel_size
out_w = int((in_w + 2 * layer.padding - layer.kernel_size) / layer.stride + 1)
out_h = int((in_w + 2 * layer.padding - layer.kernel_size) / layer.stride + 1)
delta_ops = x.size()[0] * x.size()[1] * out_w * out_h * kernel_ops
print(delta_ops)
delta_params = get_layer_param(layer)
elif type_name in ['AdaptiveAvgPool2d']:
delta_ops = x.size()[0] * x.size()[1] * x.size()[2] * x.size()[3]
delta_params = get_layer_param(layer)
### ops_linear
elif type_name in ['Linear']:
weight_ops = layer.weight.numel() * multi_add
bias_ops = layer.bias.numel()
delta_ops = x.size()[0] * (weight_ops + bias_ops)
delta_params = get_layer_param(layer)
### ops_nothing
elif type_name in ['BatchNorm2d', 'Dropout2d', 'DropChannel', 'Dropout']:
delta_params = get_layer_param(layer)
### unknown layer type
else:
pass
# raise TypeError('unknown layer type: %s' % type_name)
count_ops += delta_ops
count_params += delta_params
return
def measure_model(model, H, W):
global count_ops, count_params
count_ops = 0
count_params = 0
data = torch.zeros(1, 3, H, W)  # Variable wrapper is unnecessary since PyTorch 0.4
def should_measure(x):
return is_leaf(x) or is_pruned(x)
def modify_forward(model):
for child in model.children():
if should_measure(child):
def new_forward(m):
def lambda_forward(x):
measure_layer(m, x)
return m.old_forward(x)
return lambda_forward
child.old_forward = child.forward
child.forward = new_forward(child)
else:
modify_forward(child)
def restore_forward(model):
for child in model.children():
# restore every module patched by modify_forward (leaves and masked blocks)
if should_measure(child) and hasattr(child, 'old_forward'):
child.forward = child.old_forward
child.old_forward = None
else:
restore_forward(child)
modify_forward(model)
model.forward(data)
restore_forward(model)
return count_ops, count_params
================================================
FILE: models/__init__.py
================================================
from .wideresnet import *
from .densenet import *
================================================
FILE: models/densenet.py
================================================
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable
import torchvision.datasets as dset
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torchvision.models as models
import sys
import math
class Identity(nn.Module):
def __init__(self):
super(Identity, self).__init__()
def forward(self, x):
return x
class Zero(nn.Module):
def __init__(self):
super(Zero, self).__init__()
def forward(self, x):
return x * 0
class ZeroMake(nn.Module):
def __init__(self, channels, spatial):
super(ZeroMake, self).__init__()
self.spatial = spatial
self.channels = channels
def forward(self, x):
return torch.zeros([x.size()[0], self.channels, x.size()[2] // self.spatial, x.size()[3] // self.spatial],
dtype=x.dtype, layout=x.layout, device=x.device)
class MaskBlock(nn.Module):
def __init__(self, nChannels, growthRate):
super(MaskBlock, self).__init__()
interChannels = 4 * growthRate
self.bn1 = nn.BatchNorm2d(nChannels)
self.conv1 = nn.Conv2d(nChannels, interChannels, kernel_size=1,
bias=False)
self.bn2 = nn.BatchNorm2d(interChannels)
self.conv2 = nn.Conv2d(interChannels, growthRate, kernel_size=3,
padding=1, bias=False)
self.activation = Identity()
self.activation.register_backward_hook(self._fisher)
self.register_buffer('mask', None)
self.input_shape = None
self.output_shape = None
self.flops = None
self.params = None
self.in_channels = nChannels
self.out_channels = growthRate
self.stride = 1
# Fisher method is called on backward passes
self.running_fisher = 0
def forward(self, x):
out = self.conv1(F.relu(self.bn1(x)))
out = F.relu(self.bn2(out))
if self.mask is not None:
out = out * self.mask[None, :, None, None]
else:
self._create_mask(x, out)
out = self.activation(out)
self.act = out
out = self.conv2(out)
out = torch.cat([x, out], 1)
return out
def _create_mask(self, x, out):
"""This takes an activation to generate the exact mask required. It also records input and output shapes
for posterity."""
self.mask = x.new_ones(out.shape[1])
self.input_shape = x.size()
self.output_shape = out.size()
def _fisher(self, _, __, grad_output):
act = self.act.detach()
grad = grad_output[0].detach()
g_nk = (act * grad).sum(-1).sum(-1)
del_k = g_nk.pow(2).mean(0).mul(0.5)
self.running_fisher += del_k
def reset_fisher(self):
self.running_fisher = 0 * self.running_fisher
def update(self, previous_mask):
# This is only required for non-modular nets.
return None
def cost(self):
in_channels = self.in_channels
out_channels = self.out_channels
middle_channels = int(self.mask.sum().item())
conv1_size = self.conv1.weight.size()
conv2_size = self.conv2.weight.size()
self.params = in_channels * middle_channels * conv1_size[2] * conv1_size[3] + middle_channels * out_channels * \
conv2_size[2] * conv2_size[3]
self.params += 2 * in_channels + 2 * middle_channels
def compress_weights(self):
middle_dim = int(self.mask.sum().item())
if middle_dim != 0:
# conv1 is a 1x1 convolution in the DenseNet MaskBlock, so keep kernel_size=1
conv1 = nn.Conv2d(self.in_channels, middle_dim, kernel_size=1, stride=1, bias=False)
conv1.weight = nn.Parameter(self.conv1.weight[self.mask == 1, :, :, :])
# Batch norm 2 changes
bn2 = nn.BatchNorm2d(middle_dim)
bn2.weight = nn.Parameter(self.bn2.weight[self.mask == 1])
bn2.bias = nn.Parameter(self.bn2.bias[self.mask == 1])
bn2.running_mean = self.bn2.running_mean[self.mask == 1]
bn2.running_var = self.bn2.running_var[self.mask == 1]
conv2 = nn.Conv2d(middle_dim, self.out_channels, kernel_size=3, stride=1, padding=1, bias=False)
conv2.weight = nn.Parameter(self.conv2.weight[:, self.mask == 1, :, :])
if middle_dim == 0:
conv1 = Zero()
bn2 = Zero()
conv2 = ZeroMake(channels=self.out_channels, spatial=self.stride)
self.conv1 = conv1
self.conv2 = conv2
self.bn2 = bn2
if middle_dim != 0:
self.mask = torch.ones(middle_dim)
else:
self.mask = torch.ones(1)
class Bottleneck(nn.Module):
def __init__(self, nChannels, growthRate, width=1):
super(Bottleneck, self).__init__()
interChannels = int(4 * growthRate * width)
self.bn1 = nn.BatchNorm2d(nChannels)
self.conv1 = nn.Conv2d(nChannels, interChannels, kernel_size=1,
bias=False)
self.bn2 = nn.BatchNorm2d(interChannels)
self.conv2 = nn.Conv2d(interChannels, growthRate, kernel_size=3,
padding=1, bias=False)
def forward(self, x):
out = self.conv1(F.relu(self.bn1(x)))
out = self.conv2(F.relu(self.bn2(out)))
out = torch.cat((x, out), 1)
return out
class SingleLayer(nn.Module):
def __init__(self, nChannels, growthRate):
super(SingleLayer, self).__init__()
self.bn1 = nn.BatchNorm2d(nChannels)
self.conv1 = nn.Conv2d(nChannels, growthRate, kernel_size=3,
padding=1, bias=False)
def forward(self, x):
out = self.conv1(F.relu(self.bn1(x)))
out = torch.cat((x, out), 1)
return out
class Transition(nn.Module):
def __init__(self, nChannels, nOutChannels):
super(Transition, self).__init__()
self.bn1 = nn.BatchNorm2d(nChannels)
self.conv1 = nn.Conv2d(nChannels, nOutChannels, kernel_size=1,
bias=False)
def forward(self, x):
out = self.conv1(F.relu(self.bn1(x)))
out = F.avg_pool2d(out, 2)
return out
class DenseNet(nn.Module):
def __init__(self, growthRate, depth, reduction, nClasses, bottleneck, mask=False, width=1.):
super(DenseNet, self).__init__()
nDenseBlocks = (depth - 4) // 3
if bottleneck:
nDenseBlocks //= 2
nChannels = 2 * growthRate
self.conv1 = nn.Conv2d(3, nChannels, kernel_size=3, padding=1,
bias=False)
self.dense1 = self._make_dense(nChannels, growthRate, nDenseBlocks, bottleneck, mask, width)
nChannels += nDenseBlocks * growthRate
nOutChannels = int(math.floor(nChannels * reduction))
self.trans1 = Transition(nChannels, nOutChannels)
nChannels = nOutChannels
self.dense2 = self._make_dense(nChannels, growthRate, nDenseBlocks, bottleneck, mask, width)
nChannels += nDenseBlocks * growthRate
nOutChannels = int(math.floor(nChannels * reduction))
self.trans2 = Transition(nChannels, nOutChannels)
nChannels = nOutChannels
self.dense3 = self._make_dense(nChannels, growthRate, nDenseBlocks, bottleneck, mask, width)
nChannels += nDenseBlocks * growthRate
self.bn1 = nn.BatchNorm2d(nChannels)
self.fc = nn.Linear(nChannels, nClasses)
# Count params that don't exist in blocks (conv1, bn1, fc, trans1, trans2)
self.fixed_params = len(self.conv1.weight.view(-1)) + len(self.bn1.weight) + len(self.bn1.bias) + \
len(self.fc.weight.view(-1)) + len(self.fc.bias)
self.fixed_params += len(self.trans1.conv1.weight.view(-1)) + 2 * len(self.trans1.bn1.weight)
self.fixed_params += len(self.trans2.conv1.weight.view(-1)) + 2 * len(self.trans2.bn1.weight)
for m in self.modules():
if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2. / n))
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
elif isinstance(m, nn.Linear):
m.bias.data.zero_()
def _make_dense(self, nChannels, growthRate, nDenseBlocks, bottleneck, mask=False, width=1):
layers = []
for i in range(int(nDenseBlocks)):
if bottleneck and mask:
layers.append(MaskBlock(nChannels, growthRate))
elif bottleneck:
layers.append(Bottleneck(nChannels, growthRate, width))
else:
layers.append(SingleLayer(nChannels, growthRate))
nChannels += growthRate
return nn.Sequential(*layers)
def forward(self, x):
out = self.conv1(x)
out = self.trans1(self.dense1(out))
out = self.trans2(self.dense2(out))
out = self.dense3(out)
out = torch.squeeze(F.avg_pool2d(F.relu(self.bn1(out)), 8))
out = self.fc(out)
return out
================================================
FILE: models/wideresnet.py
================================================
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
class Identity(nn.Module):
def __init__(self):
super(Identity, self).__init__()
def forward(self, x):
return x
class Zero(nn.Module):
def __init__(self):
super(Zero, self).__init__()
def forward(self, x):
return x * 0
class ZeroMake(nn.Module):
def __init__(self, channels, spatial):
super(ZeroMake, self).__init__()
self.spatial = spatial
self.channels = channels
def forward(self, x):
return torch.zeros([x.size()[0], self.channels, x.size()[2] // self.spatial, x.size()[3] // self.spatial],
dtype=x.dtype, layout=x.layout, device=x.device)
class BasicBlock(nn.Module):
def __init__(self, in_channels, out_channels, stride, dropRate=0.0):
super(BasicBlock, self).__init__()
self.bn1 = nn.BatchNorm2d(in_channels)
self.relu1 = nn.ReLU(inplace=True)
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride,
padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(out_channels)
self.relu2 = nn.ReLU(inplace=True)
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1,
padding=1, bias=False)
self.droprate = dropRate
self.equalInOut = (in_channels == out_channels)
self.convShortcut = (not self.equalInOut) and nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride,
padding=0, bias=False) or None
def forward(self, x):
if not self.equalInOut:
x = self.relu1(self.bn1(x))
else:
out = self.relu1(self.bn1(x))
out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x)))
if self.droprate > 0:
out = F.dropout(out, p=self.droprate, training=self.training)
out = self.conv2(out)
return torch.add(x if self.equalInOut else self.convShortcut(x), out)
class BottleBlock(nn.Module):
def __init__(self, in_channels, out_channels, mid_channels, stride, dropRate=0.0):
super(BottleBlock, self).__init__()
self.bn1 = nn.BatchNorm2d(in_channels)
self.relu1 = nn.ReLU(inplace=True)
self.conv1 = nn.Conv2d(in_channels, mid_channels, kernel_size=3, stride=stride,
padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(mid_channels)
self.relu2 = nn.ReLU(inplace=True)
self.conv2 = nn.Conv2d(mid_channels, out_channels, kernel_size=3, stride=1,
padding=1, bias=False)
self.droprate = dropRate
self.equalInOut = (in_channels == out_channels)
self.convShortcut = (not self.equalInOut) and nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride,
padding=0, bias=False) or None
def forward(self, x):
if not self.equalInOut:
x = self.relu1(self.bn1(x))
else:
out = self.relu1(self.bn1(x))
out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x)))
if self.droprate > 0:
out = F.dropout(out, p=self.droprate, training=self.training)
out = self.conv2(out)
return torch.add(x if self.equalInOut else self.convShortcut(x), out)
class MaskBlock(nn.Module):
expansion = 1
def __init__(self, in_channels, out_channels, stride=1, dropRate=0.0):
super(MaskBlock, self).__init__()
self.bn1 = nn.BatchNorm2d(in_channels)
self.relu1 = nn.ReLU(inplace=True)
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(out_channels)
self.relu2 = nn.ReLU(inplace=True)
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
self.droprate = dropRate
self.equalInOut = (in_channels == out_channels)
self.convShortcut = (not self.equalInOut) and nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride,
padding=0, bias=False) or None
self.activation = Identity()
self.activation.register_backward_hook(self._fisher)
self.register_buffer('mask', None)
self.input_shape = None
self.output_shape = None
self.flops = None
self.params = None
self.in_channels = in_channels
self.out_channels = out_channels
self.stride = stride
self.got_shapes = False
# Fisher method is called on backward passes
self.running_fisher = 0
def forward(self, x):
if not self.equalInOut:
x = self.relu1(self.bn1(x))
else:
out = self.relu1(self.bn1(x))
out = self.conv1(out if self.equalInOut else x)
out = self.relu2(self.bn2(out))
if self.mask is not None:
out = out * self.mask[None, :, None, None]
else:
self._create_mask(x, out)
out = self.activation(out)
self.act = out
if self.droprate > 0:
out = F.dropout(out, p=self.droprate, training=self.training)
out = self.conv2(out)
return torch.add(x if self.equalInOut else self.convShortcut(x), out)
def _create_mask(self, x, out):
self.mask = x.new_ones(out.shape[1])
self.input_shape = x.size()
self.output_shape = out.size()
def _fisher(self, notused1, notused2, grad_output):
act = self.act.detach()
grad = grad_output[0].detach()
g_nk = (act * grad).sum(-1).sum(-1)
del_k = g_nk.pow(2).mean(0).mul(0.5)
self.running_fisher += del_k
def reset_fisher(self):
self.running_fisher = 0 * self.running_fisher
def cost(self):
in_channels = self.in_channels
out_channels = self.out_channels
middle_channels = int(self.mask.sum().item())
conv1_size = self.conv1.weight.size()
conv2_size = self.conv2.weight.size()
# convs
self.params = in_channels * middle_channels * conv1_size[2] * conv1_size[3] + middle_channels * out_channels * \
conv2_size[2] * conv2_size[3]
# batchnorms, assuming running stats are absorbed
self.params += 2 * in_channels + 2 * middle_channels
# skip
if not self.equalInOut:
self.params += in_channels * out_channels
else:
self.params += 0
def compress_weights(self):
middle_dim = int(self.mask.sum().item())
print(middle_dim)
if middle_dim != 0:
conv1 = nn.Conv2d(self.in_channels, middle_dim, kernel_size=3, stride=self.stride, padding=1, bias=False)
conv1.weight = nn.Parameter(self.conv1.weight[self.mask == 1, :, :, :])
# Batch norm 2 changes
bn2 = nn.BatchNorm2d(middle_dim)
bn2.weight = nn.Parameter(self.bn2.weight[self.mask == 1])
bn2.bias = nn.Parameter(self.bn2.bias[self.mask == 1])
bn2.running_mean = self.bn2.running_mean[self.mask == 1]
bn2.running_var = self.bn2.running_var[self.mask == 1]
conv2 = nn.Conv2d(middle_dim, self.out_channels, kernel_size=3, stride=1, padding=1, bias=False)
conv2.weight = nn.Parameter(self.conv2.weight[:, self.mask == 1, :, :])
if middle_dim == 0:
conv1 = Zero()
bn2 = Zero()
conv2 = ZeroMake(channels=self.out_channels, spatial=self.stride)
self.conv1 = conv1
self.conv2 = conv2
self.bn2 = bn2
if middle_dim != 0:
self.mask = torch.ones(middle_dim)
else:
self.mask = torch.ones(1)
class NetworkBlock(nn.Module):
def __init__(self, nb_layers, in_channels, out_channels, block, stride, dropRate=0.0):
super(NetworkBlock, self).__init__()
self.layer = self._make_layer(block, in_channels, out_channels, nb_layers, stride, dropRate)
def _make_layer(self, block, in_channels, out_channels, nb_layers, stride, dropRate):
layers = []
for i in range(int(nb_layers)):
layers.append(block(i == 0 and in_channels or out_channels, out_channels, i == 0 and stride or 1, dropRate))
return nn.Sequential(*layers)
def forward(self, x):
return self.layer(x)
class NetworkBlockBottle(nn.Module):
def __init__(self, nb_layers, in_channels, out_channels, mid_channels, block, stride, dropRate=0.0):
super(NetworkBlockBottle, self).__init__()
self.layer = self._make_layer(block, in_channels, out_channels, mid_channels, nb_layers, stride, dropRate)
def _make_layer(self, block, in_channels, out_channels, mid_channels, nb_layers, stride, dropRate):
layers = []
for i in range(int(nb_layers)):
layers.append(
block(i == 0 and in_channels or out_channels, out_channels, mid_channels, i == 0 and stride or 1,
dropRate))
return nn.Sequential(*layers)
def forward(self, x):
return self.layer(x)
class WideResNet(nn.Module):
def __init__(self, depth, widen_factor, num_classes=10, dropRate=0.0, mask=False):
super(WideResNet, self).__init__()
nChannels = [16, int(16 * widen_factor), int(32 * widen_factor), int(64 * widen_factor)]
assert ((depth - 4) % 6 == 0)
n = (depth - 4) // 6  # integer number of blocks per group
if mask == 1:
block = MaskBlock
else:
block = BasicBlock
# 1st conv before any network block
self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1,
padding=1, bias=False)
# 1st block
self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate)
# 2nd block
self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropRate)
# 3rd block
self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropRate)
# global average pooling and classifier
self.bn1 = nn.BatchNorm2d(nChannels[3])
self.relu = nn.ReLU(inplace=True)
self.fc = nn.Linear(nChannels[3], num_classes)
self.nChannels = nChannels[3]
# Count params that don't exist in blocks (conv1, bn1, fc)
self.fixed_params = len(self.conv1.weight.view(-1)) + len(self.bn1.weight) + len(self.bn1.bias) + \
len(self.fc.weight.view(-1)) + len(self.fc.bias)
for m in self.modules():
if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2. / n))
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
elif isinstance(m, nn.Linear):
m.bias.data.zero_()
def forward(self, x):
out = self.conv1(x)
out = self.block1(out)
out = self.block2(out)
out = self.block3(out)
out = self.relu(self.bn1(out))
out = F.avg_pool2d(out, 8)
out = out.view(-1, self.nChannels)
return self.fc(out)
class WideResNetBottle(nn.Module):
def __init__(self, depth, widen_factor, num_classes=10, dropRate=0.0, bottle_mult=0.5):
super(WideResNetBottle, self).__init__()
nChannels = [16, int(16 * widen_factor), int(32 * widen_factor), int(64 * widen_factor)]
assert ((depth - 4) % 6 == 0)
n = (depth - 4) // 6  # integer number of blocks per group
block = BottleBlock
# 1st conv before any network block
self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1,
padding=1, bias=False)
# 1st block
self.block1 = NetworkBlockBottle(n, nChannels[0], nChannels[1], int(nChannels[1] * bottle_mult), block, 1,
dropRate)
# 2nd block
self.block2 = NetworkBlockBottle(n, nChannels[1], nChannels[2], int(nChannels[2] * bottle_mult), block, 2,
dropRate)
# 3rd block
self.block3 = NetworkBlockBottle(n, nChannels[2], nChannels[3], int(nChannels[3] * bottle_mult), block, 2,
dropRate)
# global average pooling and classifier
self.bn1 = nn.BatchNorm2d(nChannels[3])
self.relu = nn.ReLU(inplace=True)
self.fc = nn.Linear(nChannels[3], num_classes)
self.nChannels = nChannels[3]
for m in self.modules():
if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2. / n))
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
elif isinstance(m, nn.Linear):
m.bias.data.zero_()
def forward(self, x):
out = self.conv1(x)
out = self.block1(out)
out = self.block2(out)
out = self.block3(out)
out = self.relu(self.bn1(out))
out = F.avg_pool2d(out, 8)
out = out.view(-1, self.nChannels)
return self.fc(out)
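The depth and width arithmetic in the two WideResNet constructors can be checked without building a model. The block below is a minimal sketch and not part of the repository: `wrn_layout` is a hypothetical helper that reproduces the `nChannels`, blocks-per-stage and bottleneck-width calculations in plain Python, assuming the same `(depth - 4) % 6 == 0` constraint as the constructors.

```python
def wrn_layout(depth, widen_factor, bottle_mult=0.5):
    """Return blocks per stage and channel widths for a WideResNet-style net.

    Hypothetical helper that mirrors the constructor arithmetic; it builds no layers.
    """
    if (depth - 4) % 6 != 0:
        raise ValueError('depth must satisfy (depth - 4) % 6 == 0')
    # Integer division keeps the block count usable with range().
    blocks_per_stage = (depth - 4) // 6
    channels = [16, int(16 * widen_factor), int(32 * widen_factor), int(64 * widen_factor)]
    # Bottleneck (mid) widths, one per stage, as computed for WideResNetBottle.
    mid_channels = [int(c * bottle_mult) for c in channels[1:]]
    return {'blocks_per_stage': blocks_per_stage,
            'channels': channels,
            'mid_channels': mid_channels}


if __name__ == '__main__':
    print(wrn_layout(40, 2.0))
```

For a depth of 40 and a width of 2.0 this gives six blocks per stage and channel widths 16, 32, 64 and 128.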
================================================
FILE: prune.py
================================================
"""Pruning script"""
import argparse
import os
import time
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from funcs import *
from models import *
def str2bool(value):
    """Parse 'true'/'false' strings; argparse's type=bool treats every non-empty string as True."""
    return str(value).lower() in ('true', '1', 'yes', 'y')
parser = argparse.ArgumentParser(description='Pruning')
parser.add_argument('-j', '--workers', default=0, type=int, metavar='N', help='number of data loading workers')
parser.add_argument('--GPU', default='0', type=str, help='GPU to use')
parser.add_argument('--save_file', default='wrn16_2_p', type=str, help='save file for checkpoints')
parser.add_argument('--print_freq', '-p', default=10, type=int, metavar='N', help='print frequency (default: 10)')
parser.add_argument('--resume', '-r', action='store_true', help='resume from checkpoint')
parser.add_argument('--resume_ckpt', default='checkpoint', type=str,
                    help='save file for resumed checkpoint')
parser.add_argument('--data_loc', default='/disk/scratch/datasets/cifar', type=str, help='where is the dataset')
# Learning specific arguments
parser.add_argument('--optimizer', choices=['sgd', 'adam'], default='sgd', type=str, help='optimizer')
parser.add_argument('-b', '--batch_size', default=128, type=int, metavar='N', help='mini-batch size (default: 128)')
parser.add_argument('-lr', '--learning_rate', default=8e-4, type=float, metavar='LR', help='initial learning rate')
parser.add_argument('-epochs', '--no_epochs', default=1300, type=int, metavar='epochs', help='no. epochs')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum')
parser.add_argument('--weight_decay', '--wd', default=0.0005, type=float, metavar='W', help='weight decay')
parser.add_argument('--prune_every', default=100, type=int, help='prune every X steps')
parser.add_argument('--save_every', default=100, type=int, help='save model every X EPOCHS')
parser.add_argument('--random', default=False, type=str2bool, help='Prune at random')
parser.add_argument('--base_model', default='base_model', type=str, help='basemodel')
parser.add_argument('--val_every', default=1, type=int, help='val model every X EPOCHS')
parser.add_argument('--mask', default=1, type=int, help='Mask type')
parser.add_argument('--l1_prune', default=False, type=str2bool, help='Prune via l1 norm')
parser.add_argument('--net', choices=['dense', 'res'], default='dense', type=str, help='dense, res')
parser.add_argument('--width', default=2.0, type=float, metavar='W', help='width of wideresnet')
parser.add_argument('--depth', default=40, type=int, metavar='D', help='depth of wideresnet/densenet')
parser.add_argument('--growth', default=12, type=int, help='growth rate of densenet')
parser.add_argument('--transition_rate', default=0.5, type=float, help='transition rate of densenet')
args = parser.parse_args()
print(args)
os.environ["CUDA_VISIBLE_DEVICES"] = args.GPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
if args.net == 'res':
    model = WideResNet(args.depth, args.width, mask=args.mask)
elif args.net == 'dense':
    model = DenseNet(args.growth, args.depth, args.transition_rate, 10, True, mask=args.mask)
else:
    raise ValueError('pick a valid net')
model.load_state_dict(torch.load('checkpoints/%s.t7' % args.base_model, map_location='cpu')['state_dict'], strict=True)
if args.resume:
    state = torch.load('checkpoints/%s.t7' % args.resume_ckpt, map_location='cpu')
    # NOTE: resume_from does not appear in this repository's symbol index,
    # so --resume will raise a NameError until that helper is provided.
    model = resume_from(state, model_type=args.net)
    error_history = state['error_history']
    prune_history = state['prune_history']
    # 'flop_history' is not written by the checkpoints saved below; default to an empty list
    flop_history = state.get('flop_history', [])
    param_history = state['param_history']
    start_epoch = state['epoch']
else:
    error_history = []
    prune_history = []
    flop_history = []
    param_history = []
    start_epoch = 0
model.to(device)
normMean = [0.49139968, 0.48215827, 0.44653124]
normStd = [0.24703233, 0.24348505, 0.26158768]
normTransform = transforms.Normalize(normMean, normStd)
print('==> Preparing data..')
num_classes = 10
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normTransform
])
transform_val = transforms.Compose([
    transforms.ToTensor(),
    normTransform
])
trainset = torchvision.datasets.CIFAR10(root=args.data_loc,
                                        train=True, download=True, transform=transform_train)
valset = torchvision.datasets.CIFAR10(root=args.data_loc,
                                      train=False, download=True, transform=transform_val)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size, shuffle=True,
                                          num_workers=args.workers,
                                          pin_memory=False)
valloader = torch.utils.data.DataLoader(valset, batch_size=50, shuffle=False,
                                        num_workers=args.workers,
                                        pin_memory=False)
prune_count = 0
pruner = Pruner()
pruner.prune_history = prune_history
NO_STEPS = args.prune_every
def finetune():
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    # switch to train mode
    model.train()
    end = time.time()
    dataiter = iter(trainloader)
    for i in range(0, NO_STEPS):
        # Use the built-in next(); iterator objects have no .next() method in Python 3
        try:
            input, target = next(dataiter)
        except StopIteration:
            # Loader exhausted before NO_STEPS batches were drawn: start a new pass
            dataiter = iter(trainloader)
            input, target = next(dataiter)
        # measure data loading time
        data_time.update(time.time() - end)
        input, target = input.to(device), target.to(device)
        # compute output
        output = model(input)
        loss = criterion(output, target)
        # measure accuracy and record loss
        err1, err5 = get_error(output.detach(), target, topk=(1, 5))
        losses.update(loss.item(), input.size(0))
        top1.update(err1.item(), input.size(0))
        top5.update(err5.item(), input.size(0))
        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        if i % args.print_freq == 0:
            print('Prunepoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Error@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Error@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                epoch, i, NO_STEPS, batch_time=batch_time,
                data_time=data_time, loss=losses, top1=top1, top5=top5))
def prune():
    print('Pruning')
    if args.random is False:
        if args.l1_prune is False:
            print('fisher pruning')
            pruner.fisher_prune(model, prune_every=args.prune_every)
        else:
            print('l1 pruning')
            pruner.l1_prune(model, prune_every=args.prune_every)
    else:
        print('random pruning')
        pruner.random_prune(model)
def validate():
    global error_history
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    # switch to evaluate mode
    model.eval()
    end = time.time()
    for i, (input, target) in enumerate(valloader):
        # measure data loading time
        data_time.update(time.time() - end)
        input, target = input.to(device), target.to(device)
        # compute output
        output = model(input)
        loss = criterion(output, target)
        # measure accuracy and record loss
        err1, err5 = get_error(output.detach(), target, topk=(1, 5))
        losses.update(loss.item(), input.size(0))
        top1.update(err1.item(), input.size(0))
        top5.update(err5.item(), input.size(0))
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        if i % args.print_freq == 0:
            print('Test: [{0}/{1}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Error@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Error@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                i, len(valloader), batch_time=batch_time, loss=losses,
                top1=top1, top5=top5))
    print(' * Error@1 {top1.avg:.3f} Error@5 {top5.avg:.3f}'
          .format(top1=top1, top5=top5))
    # Record Top 1 for CIFAR
    error_history.append(top1.avg)
if __name__ == '__main__':
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD([v for v in model.parameters() if v.requires_grad],
                                lr=args.learning_rate, momentum=args.momentum, weight_decay=args.weight_decay)
    for epoch in range(start_epoch, args.no_epochs):
        print('Epoch %d:' % epoch)
        print('Learning rate is %s' % [v['lr'] for v in optimizer.param_groups][0])
        # finetune for args.prune_every steps (one "prunepoch")
        finetune()
        # evaluate on validation set; always validate at the last epoch
        if epoch != 0 and ((epoch % args.val_every == 0) or (epoch + 1 == args.no_epochs)):
            validate()
            # Error history is recorded in validate(). Record params here
            no_params = pruner.get_cost(model) + model.fixed_params
            param_history.append(no_params)
        # Save before pruning; always save at the last epoch
        if epoch != 0 and ((epoch % args.save_every == 0) or (epoch + 1 == args.no_epochs)):
            filename = 'checkpoints/%s_%d_prunes.t7' % (args.save_file, epoch)
            save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'error_history': error_history,
                'param_history': param_history,
                'prune_history': pruner.prune_history,
            }, filename=filename)
        # Prune
        prune()
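The fine-tuning loop in `finetune()` draws a fixed number of batches per pruning step and restarts the loader if it runs out. The block below is a minimal sketch of that pattern using plain Python iterables instead of a `DataLoader`. `cycle_steps` and `make_iter` are hypothetical names used only for illustration, and the sketch assumes the source yields at least one item.

```python
def cycle_steps(make_iter, n_steps):
    """Yield exactly n_steps items, starting a fresh iterator whenever one is exhausted.

    make_iter is a zero-argument callable returning a new iterator
    (for a DataLoader this would be `lambda: iter(loader)`).
    Assumes each fresh iterator yields at least one item.
    """
    it = make_iter()
    for _ in range(n_steps):
        try:
            item = next(it)
        except StopIteration:
            # Source exhausted: begin a new pass and take its first item.
            it = make_iter()
            item = next(it)
        yield item


if __name__ == '__main__':
    batches = ['b0', 'b1', 'b2']
    # Seven steps over a three-batch source wraps around twice.
    print(list(cycle_steps(lambda: iter(batches), 7)))
```

Because `finetune()` creates a new iterator on every call, the restart branch only matters when `--prune_every` exceeds the number of batches in one pass over the training set.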
================================================
FILE: train.py
================================================
"""This script just trains models from scratch, to later be pruned"""
import argparse
import json
import os
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim.lr_scheduler as lr_scheduler
import torchvision
import torchvision.transforms as transforms
from models import *
from funcs import *
parser = argparse.ArgumentParser(description='Training')
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', help='number of data loading workers')
parser.add_argument('--GPU', default='0', type=str, help='GPU to use')
parser.add_argument('--save_file', default='saveto', type=str, help='save file for checkpoints')
parser.add_argument('--base_file', default='bbb', type=str, help='base file for checkpoints')
parser.add_argument('--print_freq', '-p', default=10, type=int, metavar='N', help='print frequency (default: 10)')
parser.add_argument('--data_loc', default='/disk/scratch/datasets/cifar', type=str, help='where is the dataset')
# Learning specific arguments
parser.add_argument('-b', '--batch_size', default=128, type=int, metavar='N', help='mini-batch size (default: 128)')
parser.add_argument('-lr', '--learning_rate', default=.1, type=float, metavar='LR', help='initial learning rate')
parser.add_argument('-epochs', '--no_epochs', default=200, type=int, metavar='epochs', help='no. epochs')
parser.add_argument('--epoch_step', default='[60,120,160]', type=str, help='json list with epochs to drop lr on')
parser.add_argument('--lr_decay_ratio', default=0.2, type=float, help='learning rate decay factor')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum')
parser.add_argument('--weight_decay', '--wd', default=0.0005, type=float, metavar='W', help='weight decay')
parser.add_argument('--eval', '-e', action='store_true', help='evaluate a saved checkpoint instead of training')
parser.add_argument('--mask', '-m', type=int, help='mask mode', default=0)
parser.add_argument('--deploy', '-de', action='store_true', help='prune and deploy model')
parser.add_argument('--params_left', '-pl', default=0, type=int, help='prune til...')
parser.add_argument('--net', choices=['res', 'dense'], default='res')
# Net specific
parser.add_argument('--depth', '-d', default=40, type=int, metavar='D', help='depth of wideresnet/densenet')
parser.add_argument('--width', '-w', default=2.0, type=float, metavar='W', help='width of wideresnet')
parser.add_argument('--growth', default=12, type=int, help='growth rate of densenet')
parser.add_argument('--transition_rate', default=0.5, type=float, help='transition rate of densenet')
# Uniform bottlenecks
parser.add_argument('--bottle', action='store_true', help='Linearly scale bottlenecks')
parser.add_argument('--bottle_mult', default=0.5, type=float, help='bottleneck multiplier')
if not os.path.exists('checkpoints/'):
    os.makedirs('checkpoints/')
args = parser.parse_args()
print(args)
os.environ["CUDA_VISIBLE_DEVICES"] = args.GPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
if args.net == 'res':
    if not args.bottle:
        model = WideResNet(args.depth, args.width, mask=args.mask)
    else:
        model = WideResNetBottle(args.depth, args.width, bottle_mult=args.bottle_mult)
elif args.net == 'dense':
    if not args.bottle:
        model = DenseNet(args.growth, args.depth, args.transition_rate, 10, True, mask=args.mask)
    else:
        model = DenseNet(args.growth, args.depth, args.transition_rate, 10, True, width=args.bottle_mult)
else:
    raise ValueError('pick a valid net')
pruner = Pruner()
if args.deploy:
    # Feed example to activate masks
    model(torch.rand(1, 3, 32, 32))
    # Load onto the CPU first so checkpoints saved on a GPU also open on CPU-only machines
    SD = torch.load('checkpoints/%s.t7' % args.base_file, map_location='cpu')
    if not args.eval:
        # Replay the recorded pruning decisions on the freshly built model
        pruner = Pruner()
        pruner._get_masks(model)
        for ii in SD['prune_history']:
            pruner.fixed_prune(model, ii)
    else:
        model.load_state_dict(SD['state_dict'])
    pruner.compress(model)
get_inf_params(model)
time.sleep(1)
model.to(device)
normalize = transforms.Normalize(mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
                                 std=[x / 255.0 for x in [63.0, 62.1, 66.7]])
print('==> Preparing data..')
num_classes = 10
transform_train = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambda x: F.pad(x.unsqueeze(0),
                                      (4, 4, 4, 4), mode='reflect').squeeze()),
    transforms.ToPILImage(),
    transforms.RandomCrop(32),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normalize,
])
transform_val = transforms.Compose([
    transforms.ToTensor(),
    normalize,
])
trainset = torchvision.datasets.CIFAR10(root=args.data_loc,
                                        train=True, download=True, transform=transform_train)
valset = torchvision.datasets.CIFAR10(root=args.data_loc,
                                      train=False, download=True, transform=transform_val)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size, shuffle=True,
                                          num_workers=args.workers,
                                          pin_memory=False)
valloader = torch.utils.data.DataLoader(valset, batch_size=50, shuffle=False,
                                        num_workers=args.workers,
                                        pin_memory=False)
error_history = []
epoch_step = json.loads(args.epoch_step)
def train():
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    # switch to train mode
    model.train()
    end = time.time()
    for i, (input, target) in enumerate(trainloader):
        # measure data loading time
        data_time.update(time.time() - end)
        input, target = input.to(device), target.to(device)
        # compute output
        output = model(input)
        loss = criterion(output, target)
        # measure accuracy and record loss
        err1, err5 = get_error(output.detach(), target, topk=(1, 5))
        losses.update(loss.item(), input.size(0))
        top1.update(err1.item(), input.size(0))
        top5.update(err5.item(), input.size(0))
        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        if i % args.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Error@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Error@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                epoch, i, len(trainloader), batch_time=batch_time,
                data_time=data_time, loss=losses, top1=top1, top5=top5))
def validate():
    global error_history
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    # switch to evaluate mode
    model.eval()
    end = time.time()
    for i, (input, target) in enumerate(valloader):
        # measure data loading time
        data_time.update(time.time() - end)
        input, target = input.to(device), target.to(device)
        # compute output
        output = model(input)
        loss = criterion(output, target)
        # measure accuracy and record loss
        err1, err5 = get_error(output.detach(), target, topk=(1, 5))
        losses.update(loss.item(), input.size(0))
        top1.update(err1.item(), input.size(0))
        top5.update(err5.item(), input.size(0))
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        if i % args.print_freq == 0:
            print('Test: [{0}/{1}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Error@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Error@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                i, len(valloader), batch_time=batch_time, loss=losses,
                top1=top1, top5=top5))
    print(' * Error@1 {top1.avg:.3f} Error@5 {top5.avg:.3f}'
          .format(top1=top1, top5=top5))
    # Record Top 1 for CIFAR
    error_history.append(top1.avg)
if __name__ == '__main__':
    filename = 'checkpoints/%s.t7' % args.save_file
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD([v for v in model.parameters() if v.requires_grad],
                                lr=args.learning_rate, momentum=args.momentum, weight_decay=args.weight_decay)
    scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=epoch_step, gamma=args.lr_decay_ratio)
    if not args.eval:
        for epoch in range(args.no_epochs):
            print('Epoch %d:' % epoch)
            print('Learning rate is %s' % [v['lr'] for v in optimizer.param_groups][0])
            # train for one epoch
            train()
            scheduler.step()
            # evaluate on validation set
            validate()
            save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'error_history': error_history,
            }, filename=filename)
    else:
        if not args.deploy:
            # Load onto the CPU first; load_state_dict copies into the model's own device
            model.load_state_dict(torch.load(filename, map_location='cpu')['state_dict'])
        epoch = 0
        validate()
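The training script drops the learning rate at the epochs listed in `--epoch_step`, multiplying it by `--lr_decay_ratio` each time. The block below is a minimal sketch of that schedule in plain Python. It assumes multi-step decay semantics (one multiplication by `gamma` per milestone reached, with the scheduler stepped once per epoch after training). `lr_at_epoch` is a hypothetical helper, not a function from the repository or from PyTorch.

```python
import json
from bisect import bisect_right


def lr_at_epoch(epoch, base_lr=0.1, milestones=(60, 120, 160), gamma=0.2):
    """Learning rate in effect during `epoch` (0-indexed) under multi-step decay."""
    # Count how many milestones have been reached by this epoch.
    drops = bisect_right(sorted(milestones), epoch)
    return base_lr * gamma ** drops


if __name__ == '__main__':
    # Same JSON list format as the --epoch_step argument.
    milestones = json.loads('[60,120,160]')
    for e in (0, 59, 60, 120, 160):
        print(e, lr_at_epoch(e, milestones=milestones))
```

With the script defaults, the rate stays at 0.1 through epoch 59 and is then scaled by 0.2 at epochs 60, 120 and 160.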
SYMBOL INDEX (102 symbols across 5 files)
FILE: funcs.py
class Pruner (line 17) | class Pruner:
method __init__ (line 18) | def __init__(self, module_name='MaskBlock'):
method fisher_prune (line 24) | def fisher_prune(self, model, prune_every):
method fixed_prune (line 33) | def fixed_prune(self, model, ID):
method random_prune (line 37) | def random_prune(self, model):
method l1_prune (line 52) | def l1_prune(self, model, prune_every):
method prune (line 79) | def prune(self, model, feat_index):
method compress (line 102) | def compress(self, model):
method _get_fisher (line 107) | def _get_fisher(self, model):
method _get_masks (line 124) | def _get_masks(self, model):
method _update_cost (line 133) | def _update_cost(self, model):
method get_cost (line 138) | def get_cost(self, model):
method concat (line 147) | def concat(input):
function find (line 151) | def find(input):
function concat (line 156) | def concat(input):
function save_checkpoint (line 160) | def save_checkpoint(state, filename='checkpoint.pth.tar'):
function get_error (line 164) | def get_error(output, target, topk=(1,)):
class AverageMeter (line 180) | class AverageMeter(object):
method __init__ (line 183) | def __init__(self):
method reset (line 186) | def reset(self):
method update (line 192) | def update(self, val, n=1):
function get_inf_params (line 199) | def get_inf_params(net, verbose=True, sd=False):
function get_num_gen (line 228) | def get_num_gen(gen):
function is_pruned (line 232) | def is_pruned(layer):
function is_leaf (line 240) | def is_leaf(model):
function get_layer_info (line 244) | def get_layer_info(layer):
function get_layer_param (line 250) | def get_layer_param(model):
function measure_layer (line 255) | def measure_layer(layer, x):
function measure_model (line 325) | def measure_model(model, H, W):
FILE: models/densenet.py
class Identity (line 19) | class Identity(nn.Module):
method __init__ (line 20) | def __init__(self):
method forward (line 23) | def forward(self, x):
class Zero (line 27) | class Zero(nn.Module):
method __init__ (line 28) | def __init__(self):
method forward (line 31) | def forward(self, x):
class ZeroMake (line 35) | class ZeroMake(nn.Module):
method __init__ (line 36) | def __init__(self, channels, spatial):
method forward (line 41) | def forward(self, x):
class MaskBlock (line 46) | class MaskBlock(nn.Module):
method __init__ (line 47) | def __init__(self, nChannels, growthRate):
method forward (line 72) | def forward(self, x):
method _create_mask (line 86) | def _create_mask(self, x, out):
method _fisher (line 93) | def _fisher(self, _, __, grad_output):
method reset_fisher (line 101) | def reset_fisher(self):
method update (line 104) | def update(self, previous_mask):
method cost (line 108) | def cost(self):
method compress_weights (line 123) | def compress_weights(self):
class Bottleneck (line 155) | class Bottleneck(nn.Module):
method __init__ (line 156) | def __init__(self, nChannels, growthRate, width=1):
method forward (line 166) | def forward(self, x):
class SingleLayer (line 173) | class SingleLayer(nn.Module):
method __init__ (line 174) | def __init__(self, nChannels, growthRate):
method forward (line 180) | def forward(self, x):
class Transition (line 186) | class Transition(nn.Module):
method __init__ (line 187) | def __init__(self, nChannels, nOutChannels):
method forward (line 193) | def forward(self, x):
class DenseNet (line 199) | class DenseNet(nn.Module):
method __init__ (line 200) | def __init__(self, growthRate, depth, reduction, nClasses, bottleneck,...
method _make_dense (line 245) | def _make_dense(self, nChannels, growthRate, nDenseBlocks, bottleneck,...
method forward (line 257) | def forward(self, x):
FILE: models/wideresnet.py
class Identity (line 7) | class Identity(nn.Module):
method __init__ (line 8) | def __init__(self):
method forward (line 11) | def forward(self, x):
class Zero (line 15) | class Zero(nn.Module):
method __init__ (line 16) | def __init__(self):
method forward (line 19) | def forward(self, x):
class ZeroMake (line 23) | class ZeroMake(nn.Module):
method __init__ (line 24) | def __init__(self, channels, spatial):
method forward (line 29) | def forward(self, x):
class BasicBlock (line 34) | class BasicBlock(nn.Module):
method __init__ (line 35) | def __init__(self, in_channels, out_channels, stride, dropRate=0.0):
method forward (line 50) | def forward(self, x):
class BottleBlock (line 63) | class BottleBlock(nn.Module):
method __init__ (line 64) | def __init__(self, in_channels, out_channels, mid_channels, stride, dr...
method forward (line 79) | def forward(self, x):
class MaskBlock (line 92) | class MaskBlock(nn.Module):
method __init__ (line 95) | def __init__(self, in_channels, out_channels, stride=1, dropRate=0.0):
method forward (line 126) | def forward(self, x):
method _create_mask (line 153) | def _create_mask(self, x, out):
method _fisher (line 159) | def _fisher(self, notused1, notused2, grad_output):
method reset_fisher (line 167) | def reset_fisher(self):
method cost (line 170) | def cost(self):
method compress_weights (line 192) | def compress_weights(self):
class NetworkBlock (line 226) | class NetworkBlock(nn.Module):
method __init__ (line 227) | def __init__(self, nb_layers, in_channels, out_channels, block, stride...
method _make_layer (line 231) | def _make_layer(self, block, in_channels, out_channels, nb_layers, str...
method forward (line 237) | def forward(self, x):
class NetworkBlockBottle (line 241) | class NetworkBlockBottle(nn.Module):
method __init__ (line 242) | def __init__(self, nb_layers, in_channels, out_channels, mid_channels,...
method _make_layer (line 246) | def _make_layer(self, block, in_channels, out_channels, mid_channels, ...
method forward (line 254) | def forward(self, x):
class WideResNet (line 258) | class WideResNet(nn.Module):
method __init__ (line 259) | def __init__(self, depth, widen_factor, num_classes=10, dropRate=0.0, ...
method forward (line 302) | def forward(self, x):
class WideResNetBottle (line 313) | class WideResNetBottle(nn.Module):
method __init__ (line 314) | def __init__(self, depth, widen_factor, num_classes=10, dropRate=0.0, ...
method forward (line 352) | def forward(self, x):
FILE: prune.py
function finetune (line 114) | def finetune():
function prune (line 175) | def prune():
function validate (line 189) | def validate():
FILE: train.py
function train (line 134) | def train():
function validate (line 187) | def validate():