Full Code of BayesWatch/pytorch-prunes for AI

Repository: BayesWatch/pytorch-prunes
Branch: master
Commit: bc85a5c52865
Files: 8
Total size: 56.7 KB

Directory structure:
gitextract_3q8rylo5/

├── LICENSE
├── README.md
├── funcs.py
├── models/
│   ├── __init__.py
│   ├── densenet.py
│   └── wideresnet.py
├── prune.py
└── train.py

================================================
FILE CONTENTS
================================================

================================================
FILE: LICENSE
================================================
MIT License

Copyright (c) 2019 BayesWatch

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.


================================================
FILE: README.md
================================================
# A Closer Look at Structured Pruning for Neural Network Compression

Code used to reproduce experiments in https://arxiv.org/abs/1810.04622.

To prune, we fill our networks with custom `MaskBlocks`, which are manipulated using `Pruner` in funcs.py. There will certainly be a better way to do this, but we leave this as an exercise to someone who can code much better than we can.
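
As a rough illustration of the bookkeeping involved: `Pruner.prune` in funcs.py treats the masks of all `MaskBlock`s as one long concatenated vector, so a single integer identifies a channel anywhere in the network. The sketch below mirrors that lookup with plain Python lists standing in for the mask tensors. `locate_channel` and the block sizes are made up for illustration and are not part of this repository.

```python
def locate_channel(mask_lengths, feat_index):
    """Map a global channel index to (block number, index within that block).

    mask_lengths: number of maskable channels in each block, in model order.
    Hypothetical helper; it mirrors the running-index loop in Pruner.prune.
    """
    running_index = 0
    for block_id, length in enumerate(mask_lengths):
        if running_index <= feat_index < running_index + length:
            return block_id, feat_index - running_index
        running_index += length
    raise IndexError('index %d does not correspond to any feature map' % feat_index)


# Three blocks with 32, 32 and 64 maskable channels (illustrative sizes).
mask_lengths = [32, 32, 64]
block_id, local_index = locate_channel(mask_lengths, 40)
print(block_id, local_index)  # 1 8
```

Pruning that channel then amounts to setting entry `local_index` of that block's mask to zero.
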

## Setup
This is best done in a clean conda environment:

```
conda create -n prunes python=3.6
conda activate prunes
conda install pytorch torchvision -c pytorch
```

## Repository layout
- `train.py`: code for training large models from scratch and for training pruned models from scratch
- `prune.py`: code for pruning trained models
- `funcs.py`: pruning utilities and other commonly used helper functions

## CIFAR Experiments
First, you will need some initial models. 

To train a WRN-40-2:
```
python train.py --net='res' --depth=40 --width=2.0 --data_loc=<path-to-data> --save_file='res'
```

The default arguments of train.py are suitable for training WRNs. The following trains a DenseNet-BC-100 (k=12) with its default hyperparameters:

```
python train.py --net='dense' --depth=100 --data_loc=<path-to-data> --save_file='dense' --no_epochs 300 -b 64 --epoch_step '[150,225]' --weight_decay 0.0001 --lr_decay_ratio 0.1
```

These will automatically save checkpoints to the `checkpoints` folder.
 
 
 
### Pruning 
Once training is finished, we can prune our networks using prune.py. The defaults are set for WRN pruning, so DenseNets need extra arguments:
```
python prune.py --net='res'   --data_loc=<path-to-data> --base_model='res' --save_file='res_fisher'
python prune.py --net='res'   --data_loc=<path-to-data> --l1_prune=True --base_model='res' --save_file='res_l1'

python prune.py --net='dense' --depth 100 --data_loc=<path-to-data> --base_model='dense' --save_file='dense_fisher' --learning_rate 1e-3 --weight_decay 1e-4 --batch_size 64 --no_epochs 2600
python prune.py --net='dense' --depth 100 --data_loc=<path-to-data> --l1_prune=True --base_model='dense' --save_file='dense_l1'  --learning_rate 1e-3 --weight_decay 1e-4 --batch_size 64  --no_epochs 2600

```
Note that the default is to perform Fisher pruning, so you don't need to pass a flag to use it.  
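
For intuition, the Fisher signal accumulated by the `_fisher` hooks in the `MaskBlock`s, and the channel choice made in `Pruner.fisher_prune`, can be sketched in plain Python. Nested lists stand in for the activation and gradient tensors. The function names and toy numbers are illustrative assumptions rather than code from this repository.

```python
def fisher_saliency(acts, grads):
    """Per-channel score: 0.5 * mean over the batch of (sum over H, W of act * grad)^2.

    acts, grads: nested lists indexed as [batch][channel][row][col].
    """
    n_batch, n_chan = len(acts), len(acts[0])
    scores = [0.0] * n_chan
    for n in range(n_batch):
        for k in range(n_chan):
            # Activation times gradient, summed over all spatial positions
            g_nk = sum(a * g
                       for row_a, row_g in zip(acts[n][k], grads[n][k])
                       for a, g in zip(row_a, row_g))
            scores[k] += g_nk ** 2
    return [0.5 * s / n_batch for s in scores]


def pick_channel(fisher, masks, prune_every, penalty=1e6):
    """Index of the lowest-scoring channel that is still on (mask == 1)."""
    # The running Fisher sum is averaged over the steps between prunes;
    # channels that are already off get a large dummy cost.
    total = [f / prune_every + penalty * (1 - m) for f, m in zip(fisher, masks)]
    return min(range(len(total)), key=total.__getitem__)


# One sample, two channels, 1x2 feature maps (toy values).
acts = [[[[1.0, 2.0]], [[0.0, 1.0]]]]
grads = [[[[0.5, 0.5]], [[1.0, 2.0]]]]
print(fisher_saliency(acts, grads))  # [1.125, 2.0]

# Channel 1 has the lowest score but is already pruned, so channel 2 is chosen.
print(pick_channel([0.5, 0.1, 0.3], masks=[1, 0, 1], prune_every=100))  # 2
```
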
Once finished, we can train the pruned models from scratch, e.g.:  
```
python train.py --data_loc=<path-to-data> --net='res' --base_file='res_fisher_<N>_prunes' --deploy --mask=1 --save_file='res_fisher_<N>_prunes_scratch'
```

Each model can then be evaluated using:
```
python train.py --deploy --eval --data_loc=<path-to-data> --net='res' --mask=1 --base_file='res_fisher_<N>_prunes'
```


### Training Reduced models

This can be done by varying the input arguments to train.py. To reduce depth or width of a WRN, change the corresponding option:
```
python train.py --net='res' --depth=<REDUCED DEPTH> --width=<REDUCED WIDTH> --data_loc=<path-to-data> --save_file='res_reduced'
```

To add bottlenecks, use the following:

```
python train.py --net='res' --depth=40 --width=2.0 --data_loc=<path-to-data> --save_file='res_bottle' --bottle --bottle_mult <Z>
```
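
To get a feel for how much a narrower middle layer saves, the parameter count that `MaskBlock.cost` in models/wideresnet.py computes for a block can be reproduced by hand. This is a small sketch under assumptions: `block_params` is a hypothetical helper, and both convolutions are taken to be 3x3 as in `BottleBlock`. How `--bottle_mult` translates into the number of middle channels is decided in train.py and is not modelled here.

```python
def block_params(in_channels, mid_channels, out_channels, kernel=3):
    """Parameter count of one two-conv residual block, following MaskBlock.cost."""
    convs = (in_channels * mid_channels + mid_channels * out_channels) * kernel * kernel
    batchnorms = 2 * in_channels + 2 * mid_channels  # scale and shift for bn1 and bn2
    # A 1x1 shortcut convolution is only needed when the channel count changes
    skip = in_channels * out_channels if in_channels != out_channels else 0
    return convs + batchnorms + skip


full = block_params(32, 32, 32)    # no bottleneck
narrow = block_params(32, 16, 32)  # middle layer halved
print(full, narrow)  # 18560 9312
```

Halving the middle layer roughly halves the block, because both convolutions scale linearly with the number of middle channels.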

With DenseNets you can modify the `depth` or `growth`, or use `--bottle --bottle_mult <Z>` as above.
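
The effect of `depth` and `growth` on a DenseNet can be traced with the same arithmetic that `DenseNet.__init__` in models/densenet.py uses. The helper below is a hypothetical sketch, and the compression factor `reduction=0.5` is an assumption (the usual DenseNet-BC setting), not a value read from this repository's defaults.

```python
import math


def densenet_channels(growth_rate, depth, reduction=0.5, bottleneck=True):
    """Layers per dense block and channel counts after each block."""
    n_layers = (depth - 4) // 3
    if bottleneck:
        n_layers //= 2  # each bottleneck layer holds two convolutions

    channels = 2 * growth_rate  # output of the first convolution
    after_blocks = []
    for block in range(3):
        channels += n_layers * growth_rate  # every layer appends growth_rate maps
        after_blocks.append(channels)
        if block < 2:  # a transition follows the first two dense blocks only
            channels = int(math.floor(channels * reduction))
    return n_layers, after_blocks


print(densenet_channels(growth_rate=12, depth=100))  # (16, [216, 300, 342])
```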


### Acknowledgements

[Jack Turner][jack] wrote the L1 stuff, and some other stuff for that matter.

Code has been liberally borrowed from many a repo, including, but not limited to:

```
https://github.com/xternalz/WideResNet-pytorch
https://github.com/bamos/densenet.pytorch
https://github.com/kuangliu/pytorch-cifar
https://github.com/ShichenLiu/CondenseNet
```
### Citing this work

If you would like to cite this work, please use the following bibtex entry:

```
@article{crowley2018pruning,
  title={A Closer Look at Structured Pruning for Neural Network Compression},
  author={Crowley, Elliot J and Turner, Jack and Storkey, Amos and O'Boyle, Michael},
  journal={arXiv preprint arXiv:1810.04622},
  year={2018},
  }
```
[jack]: https://github.com/jack-willturner


================================================
FILE: funcs.py
================================================
import operator
import random
import time
from functools import reduce

import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
from torch.autograd import Variable

from models import *


class Pruner:
    def __init__(self, module_name='MaskBlock'):
        # First get vector of masks
        self.module_name = module_name
        self.masks = []
        self.prune_history = []

    def fisher_prune(self, model, prune_every):

        self._get_fisher(model)
        tot_loss = self.fisher.div(prune_every) + 1e6 * (1 - self.masks)  # dummy value for off masks
        print(len(tot_loss))
        _, argmin = torch.min(tot_loss, 0)  # only the index is needed; avoids shadowing the builtin min
        self.prune(model, argmin.item())
        self.prune_history.append(argmin.item())

    def fixed_prune(self, model, ID):
        self.prune(model, ID)
        self.prune_history.append(ID)

    def random_prune(self, model):

        self._get_fisher(model)
        # Do this to update costs.
        masks = []
        for m in model.modules():
            if m._get_name() == self.module_name:
                masks.append(m.mask.detach())

        masks = self.concat(masks)
        masks_on = [i for i, v in enumerate(masks) if v == 1]
        random_pick = random.choice(masks_on)
        self.prune(model, random_pick)
        self.prune_history.append(random_pick)

    def l1_prune(self, model, prune_every):
        masks = []
        l1_norms = []

        for m in model.modules():
            if m._get_name() == self.module_name:
                # L1 norm of each conv1 filter: sum of absolute weights over (in, kH, kW)
                l1_norm = torch.sum(torch.abs(m.conv1.weight), (1, 2, 3)).detach().cpu().numpy()
                masks.append(m.mask.detach())
                l1_norms.append(l1_norm)

        masks = self.concat(masks)
        self.masks = masks
        l1_norms = np.concatenate(l1_norms)

        l1_norms_on = []
        for m, l in zip(masks, l1_norms):
            if m == 1:
                l1_norms_on.append(l)
            else:
                l1_norms_on.append(9999.)  # dummy value

        # argmin over the masked list, so an already-pruned channel with an equal norm is never picked
        pick = int(np.argmin(l1_norms_on))

        self.prune(model, pick)
        self.prune_history.append(pick)

    def prune(self, model, feat_index):
        """feat_index refers to the index of a feature map. This function modifies the mask to turn it off."""
        print('Pruned %d out of %d channels so far' % (len(self.prune_history), len(self.masks)))
        if len(self.prune_history) > len(self.masks):
            raise Exception('Time to stop')
        safe = 0
        running_index = 0
        for m in model.modules():
            if m._get_name() == self.module_name:
                mask_indices = range(running_index, running_index + len(m.mask))

                if feat_index in mask_indices:
                    print('Pruning channel %d' % feat_index)
                    local_index = mask_indices.index(feat_index)
                    m.mask[local_index] = 0
                    safe = 1
                    break
                else:
                    running_index += len(m.mask)
                    # print(running_index)
        if not safe:
            raise Exception("The provided index doesn't correspond to any feature maps. This is bad.")

    def compress(self, model):
        for m in model.modules():
            if m._get_name() == 'MaskBlock':
                m.compress_weights()

    def _get_fisher(self, model):
        masks = []
        fisher = []

        self._update_cost(model)

        for m in model.modules():
            if m._get_name() == self.module_name:
                masks.append(m.mask.detach())
                fisher.append(m.running_fisher.detach())

                # Now clear the fisher cache
                m.reset_fisher()

        self.masks = self.concat(masks)
        self.fisher = self.concat(fisher)

    def _get_masks(self, model):
        masks = []

        for m in model.modules():
            if m._get_name() == self.module_name:
                masks.append(m.mask.detach())

        self.masks = self.concat(masks)

    def _update_cost(self, model):
        for m in model.modules():
            if m._get_name() == self.module_name:
                m.cost()

    def get_cost(self, model):
        params = 0
        for m in model.modules():
            if m._get_name() == self.module_name:
                m.cost()
                params += m.params
        return params

    @staticmethod
    def concat(input):
        return torch.cat([item for item in input])


def find(input):
    # Find as in MATLAB to find indices in a binary vector
    return [i for i, j in enumerate(input) if j]


def concat(input):
    return torch.cat([item for item in input])


def save_checkpoint(state, filename='checkpoint.pth.tar'):
    torch.save(state, filename)


def get_error(output, target, topk=(1,)):
    """Computes the error@k for the specified values of k"""
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
        res.append(100.0 - correct_k.mul_(100.0 / batch_size))
    return res


class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def get_inf_params(net, verbose=True, sd=False):
    if sd:
        params = net
    else:
        params = net.state_dict()
    tot = 0
    conv_tot = 0
    for p in params:
        no = params[p].numel()

        if ('num_batches_tracked' not in p) and ('running' not in p) and ('mask' not in p):
            tot += no

            if verbose:
                print('%s has %d params' % (p, no))
        if 'conv' in p:
            conv_tot += no

    if verbose:
        print('Net has %d conv params' % conv_tot)
        print('Net has %d params in total' % tot)

    return tot


count_ops = 0
count_params = 0


def get_num_gen(gen):
    return sum(1 for x in gen)


def is_pruned(layer):
    try:
        layer.mask
        return True
    except AttributeError:
        return False


def is_leaf(model):
    return get_num_gen(model.children()) == 0


def get_layer_info(layer):
    layer_str = str(layer)
    type_name = layer_str[:layer_str.find('(')].strip()
    return type_name


def get_layer_param(model):
    return sum([reduce(operator.mul, i.size(), 1) for i in model.parameters()])


### The input batch size should be 1 to call this function
def measure_layer(layer, x):
    global count_ops, count_params
    delta_ops = 0
    delta_params = 0
    multi_add = 1
    type_name = get_layer_info(layer)

    ### ops_conv
    if type_name in ['Conv2d']:
        out_h = int((x.size()[2] + 2 * layer.padding[0] - layer.kernel_size[0]) /
                    layer.stride[0] + 1)
        out_w = int((x.size()[3] + 2 * layer.padding[1] - layer.kernel_size[1]) /
                    layer.stride[1] + 1)
        delta_ops = layer.in_channels * layer.out_channels * layer.kernel_size[0] * \
                    layer.kernel_size[1] * out_h * out_w / layer.groups * multi_add
        delta_params = get_layer_param(layer)

    ### ops_learned_conv
    elif type_name in ['LearnedGroupConv']:
        measure_layer(layer.relu, x)
        measure_layer(layer.norm, x)
        conv = layer.conv
        out_h = int((x.size()[2] + 2 * conv.padding[0] - conv.kernel_size[0]) /
                    conv.stride[0] + 1)
        out_w = int((x.size()[3] + 2 * conv.padding[1] - conv.kernel_size[1]) /
                    conv.stride[1] + 1)
        delta_ops = conv.in_channels * conv.out_channels * conv.kernel_size[0] * \
                    conv.kernel_size[1] * out_h * out_w / layer.condense_factor * multi_add
        delta_params = get_layer_param(conv) / layer.condense_factor

    ### ops_nonlinearity
    elif type_name in ['ReLU']:
        delta_ops = x.numel()
        delta_params = get_layer_param(layer)

    ### ops_pooling
    elif type_name in ['AvgPool2d']:
        in_w = x.size()[2]
        kernel_ops = layer.kernel_size * layer.kernel_size
        out_w = int((in_w + 2 * layer.padding - layer.kernel_size) / layer.stride + 1)
        out_h = int((in_w + 2 * layer.padding - layer.kernel_size) / layer.stride + 1)
        delta_ops = x.size()[0] * x.size()[1] * out_w * out_h * kernel_ops
        print(delta_ops)
        delta_params = get_layer_param(layer)

    elif type_name in ['AdaptiveAvgPool2d']:
        delta_ops = x.size()[0] * x.size()[1] * x.size()[2] * x.size()[3]
        delta_params = get_layer_param(layer)

    ### ops_linear
    elif type_name in ['Linear']:
        weight_ops = layer.weight.numel() * multi_add
        bias_ops = layer.bias.numel()
        delta_ops = x.size()[0] * (weight_ops + bias_ops)
        delta_params = get_layer_param(layer)

    ### ops_nothing
    elif type_name in ['BatchNorm2d', 'Dropout2d', 'DropChannel', 'Dropout']:
        delta_params = get_layer_param(layer)

    ### unknown layer type
    else:
        pass
        # raise TypeError('unknown layer type: %s' % type_name)

    count_ops += delta_ops
    count_params += delta_params
    return


def measure_model(model, H, W):
    global count_ops, count_params
    count_ops = 0
    count_params = 0
    data = Variable(torch.zeros(1, 3, H, W))

    def should_measure(x):
        return is_leaf(x) or is_pruned(x)

    def modify_forward(model):
        for child in model.children():
            if should_measure(child):
                def new_forward(m):
                    def lambda_forward(x):
                        measure_layer(m, x)
                        return m.old_forward(x)

                    return lambda_forward

                child.old_forward = child.forward
                child.forward = new_forward(child)
            else:
                modify_forward(child)

    def restore_forward(model):
        for child in model.children():
            # Restore every wrapped module, including masked blocks that are not leaves
            if getattr(child, 'old_forward', None) is not None:
                child.forward = child.old_forward
                child.old_forward = None
            else:
                restore_forward(child)

    modify_forward(model)
    model.forward(data)
    restore_forward(model)

    return count_ops, count_params


================================================
FILE: models/__init__.py
================================================
from .wideresnet import *
from .densenet import *



================================================
FILE: models/densenet.py
================================================
import torch

import torch.nn as nn
import torch.optim as optim

import torch.nn.functional as F
from torch.autograd import Variable

import torchvision.datasets as dset
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

import torchvision.models as models

import sys
import math


class Identity(nn.Module):
    def __init__(self):
        super(Identity, self).__init__()

    def forward(self, x):
        return x


class Zero(nn.Module):
    def __init__(self):
        super(Zero, self).__init__()

    def forward(self, x):
        return x * 0


class ZeroMake(nn.Module):
    def __init__(self, channels, spatial):
        super(ZeroMake, self).__init__()
        self.spatial = spatial
        self.channels = channels

    def forward(self, x):
        return torch.zeros([x.size()[0], self.channels, x.size()[2] // self.spatial, x.size()[3] // self.spatial],
                           dtype=x.dtype, layout=x.layout, device=x.device)


class MaskBlock(nn.Module):
    def __init__(self, nChannels, growthRate):
        super(MaskBlock, self).__init__()
        interChannels = 4 * growthRate
        self.bn1 = nn.BatchNorm2d(nChannels)
        self.conv1 = nn.Conv2d(nChannels, interChannels, kernel_size=1,
                               bias=False)
        self.bn2 = nn.BatchNorm2d(interChannels)
        self.conv2 = nn.Conv2d(interChannels, growthRate, kernel_size=3,
                               padding=1, bias=False)

        self.activation = Identity()
        self.activation.register_backward_hook(self._fisher)
        self.register_buffer('mask', None)

        self.input_shape = None
        self.output_shape = None
        self.flops = None
        self.params = None
        self.in_channels = nChannels
        self.out_channels = growthRate
        self.stride = 1

        # Fisher method is called on backward passes
        self.running_fisher = 0

    def forward(self, x):
        out = self.conv1(F.relu(self.bn1(x)))
        out = F.relu(self.bn2(out))
        if self.mask is not None:
            out = out * self.mask[None, :, None, None]
        else:
            self._create_mask(x, out)
        out = self.activation(out)
        self.act = out

        out = self.conv2(out)
        out = torch.cat([x, out], 1)
        return out

    def _create_mask(self, x, out):
        """This takes an activation to generate the exact mask required. It also records input and output shapes
        for posterity."""
        self.mask = x.new_ones(out.shape[1])
        self.input_shape = x.size()
        self.output_shape = out.size()

    def _fisher(self, _, __, grad_output):
        act = self.act.detach()
        grad = grad_output[0].detach()

        g_nk = (act * grad).sum(-1).sum(-1)
        del_k = g_nk.pow(2).mean(0).mul(0.5)
        self.running_fisher += del_k

    def reset_fisher(self):
        self.running_fisher = 0 * self.running_fisher

    def update(self, previous_mask):
        # This is only required for non-modular nets.
        return None

    def cost(self):

        in_channels = self.in_channels
        out_channels = self.out_channels
        middle_channels = int(self.mask.sum().item())

        conv1_size = self.conv1.weight.size()
        conv2_size = self.conv2.weight.size()

        self.params = in_channels * middle_channels * conv1_size[2] * conv1_size[3] + middle_channels * out_channels * \
                      conv2_size[2] * conv2_size[3]

        self.params += 2 * in_channels + 2 * middle_channels


    def compress_weights(self):
        middle_dim = int(self.mask.sum().item())

        if middle_dim != 0:
            # conv1 is a 1x1 convolution in this block, so the replacement keeps kernel_size=1
            conv1 = nn.Conv2d(self.in_channels, middle_dim, kernel_size=1, stride=1, bias=False)
            conv1.weight = nn.Parameter(self.conv1.weight[self.mask == 1, :, :, :])

            # Batch norm 2 changes
            bn2 = nn.BatchNorm2d(middle_dim)
            bn2.weight = nn.Parameter(self.bn2.weight[self.mask == 1])
            bn2.bias = nn.Parameter(self.bn2.bias[self.mask == 1])
            bn2.running_mean = self.bn2.running_mean[self.mask == 1]
            bn2.running_var = self.bn2.running_var[self.mask == 1]

            conv2 = nn.Conv2d(middle_dim, self.out_channels, kernel_size=3, stride=1, padding=1, bias=False)
            conv2.weight = nn.Parameter(self.conv2.weight[:, self.mask == 1, :, :])
        else:
            # Every middle channel was pruned, so the block only contributes zeros
            conv1 = Zero()
            bn2 = Zero()
            conv2 = ZeroMake(channels=self.out_channels, spatial=self.stride)

        self.conv1 = conv1
        self.conv2 = conv2
        self.bn2 = bn2

        if middle_dim != 0:
            self.mask = torch.ones(middle_dim)
        else:
            self.mask = torch.ones(1)


class Bottleneck(nn.Module):
    def __init__(self, nChannels, growthRate, width=1):
        super(Bottleneck, self).__init__()
        interChannels = int(4 * growthRate * width)
        self.bn1 = nn.BatchNorm2d(nChannels)
        self.conv1 = nn.Conv2d(nChannels, interChannels, kernel_size=1,
                               bias=False)
        self.bn2 = nn.BatchNorm2d(interChannels)
        self.conv2 = nn.Conv2d(interChannels, growthRate, kernel_size=3,
                               padding=1, bias=False)

    def forward(self, x):
        out = self.conv1(F.relu(self.bn1(x)))
        out = self.conv2(F.relu(self.bn2(out)))
        out = torch.cat((x, out), 1)
        return out


class SingleLayer(nn.Module):
    def __init__(self, nChannels, growthRate):
        super(SingleLayer, self).__init__()
        self.bn1 = nn.BatchNorm2d(nChannels)
        self.conv1 = nn.Conv2d(nChannels, growthRate, kernel_size=3,
                               padding=1, bias=False)

    def forward(self, x):
        out = self.conv1(F.relu(self.bn1(x)))
        out = torch.cat((x, out), 1)
        return out


class Transition(nn.Module):
    def __init__(self, nChannels, nOutChannels):
        super(Transition, self).__init__()
        self.bn1 = nn.BatchNorm2d(nChannels)
        self.conv1 = nn.Conv2d(nChannels, nOutChannels, kernel_size=1,
                               bias=False)

    def forward(self, x):
        out = self.conv1(F.relu(self.bn1(x)))
        out = F.avg_pool2d(out, 2)
        return out


class DenseNet(nn.Module):
    def __init__(self, growthRate, depth, reduction, nClasses, bottleneck, mask=False, width=1.):
        super(DenseNet, self).__init__()

        nDenseBlocks = (depth - 4) // 3
        if bottleneck:
            nDenseBlocks //= 2

        nChannels = 2 * growthRate
        self.conv1 = nn.Conv2d(3, nChannels, kernel_size=3, padding=1,
                               bias=False)

        self.dense1 = self._make_dense(nChannels, growthRate, nDenseBlocks, bottleneck, mask, width)
        nChannels += nDenseBlocks * growthRate
        nOutChannels = int(math.floor(nChannels * reduction))
        self.trans1 = Transition(nChannels, nOutChannels)

        nChannels = nOutChannels
        self.dense2 = self._make_dense(nChannels, growthRate, nDenseBlocks, bottleneck, mask, width)
        nChannels += nDenseBlocks * growthRate
        nOutChannels = int(math.floor(nChannels * reduction))
        self.trans2 = Transition(nChannels, nOutChannels)

        nChannels = nOutChannels
        self.dense3 = self._make_dense(nChannels, growthRate, nDenseBlocks, bottleneck, mask, width)
        nChannels += nDenseBlocks * growthRate

        self.bn1 = nn.BatchNorm2d(nChannels)
        self.fc = nn.Linear(nChannels, nClasses)

        # Count params that don't exist in blocks (conv1, bn1, fc, trans1, trans2)
        self.fixed_params = len(self.conv1.weight.view(-1)) + len(self.bn1.weight) + len(self.bn1.bias) + \
                            len(self.fc.weight.view(-1)) + len(self.fc.bias)
        self.fixed_params += len(self.trans1.conv1.weight.view(-1)) + 2 * len(self.trans1.bn1.weight)
        self.fixed_params += len(self.trans2.conv1.weight.view(-1)) + 2 * len(self.trans2.bn1.weight)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.bias.data.zero_()

    def _make_dense(self, nChannels, growthRate, nDenseBlocks, bottleneck, mask=False, width=1):
        layers = []
        for i in range(int(nDenseBlocks)):
            if bottleneck and mask:
                layers.append(MaskBlock(nChannels, growthRate))
            elif bottleneck:
                layers.append(Bottleneck(nChannels, growthRate, width))
            else:
                layers.append(SingleLayer(nChannels, growthRate))
            nChannels += growthRate
        return nn.Sequential(*layers)

    def forward(self, x):
        out = self.conv1(x)
        out = self.trans1(self.dense1(out))
        out = self.trans2(self.dense2(out))
        out = self.dense3(out)
        out = F.avg_pool2d(F.relu(self.bn1(out)), 8)
        out = out.view(out.size(0), -1)  # flatten but keep the batch dimension, even for a batch of 1
        out = self.fc(out)
        return out


================================================
FILE: models/wideresnet.py
================================================
import math
import torch
import torch.nn as nn
import torch.nn.functional as F


class Identity(nn.Module):
    def __init__(self):
        super(Identity, self).__init__()

    def forward(self, x):
        return x


class Zero(nn.Module):
    def __init__(self):
        super(Zero, self).__init__()

    def forward(self, x):
        return x * 0


class ZeroMake(nn.Module):
    def __init__(self, channels, spatial):
        super(ZeroMake, self).__init__()
        self.spatial = spatial
        self.channels = channels

    def forward(self, x):
        return torch.zeros([x.size()[0], self.channels, x.size()[2] // self.spatial, x.size()[3] // self.spatial],
                           dtype=x.dtype, layout=x.layout, device=x.device)


class BasicBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride, dropRate=0.0):
        super(BasicBlock, self).__init__()
        self.bn1 = nn.BatchNorm2d(in_channels)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1,
                               padding=1, bias=False)
        self.droprate = dropRate
        self.equalInOut = (in_channels == out_channels)
        self.convShortcut = (not self.equalInOut) and nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride,
                                                                padding=0, bias=False) or None

    def forward(self, x):
        if not self.equalInOut:
            x = self.relu1(self.bn1(x))
        else:
            out = self.relu1(self.bn1(x))
        out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x)))
        if self.droprate > 0:
            out = F.dropout(out, p=self.droprate, training=self.training)
        out = self.conv2(out)

        return torch.add(x if self.equalInOut else self.convShortcut(x), out)


class BottleBlock(nn.Module):
    def __init__(self, in_channels, out_channels, mid_channels, stride, dropRate=0.0):
        super(BottleBlock, self).__init__()
        self.bn1 = nn.BatchNorm2d(in_channels)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(in_channels, mid_channels, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(mid_channels)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(mid_channels, out_channels, kernel_size=3, stride=1,
                               padding=1, bias=False)
        self.droprate = dropRate
        self.equalInOut = (in_channels == out_channels)
        self.convShortcut = (not self.equalInOut) and nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride,
                                                                padding=0, bias=False) or None

    def forward(self, x):
        if not self.equalInOut:
            x = self.relu1(self.bn1(x))
        else:
            out = self.relu1(self.bn1(x))
        out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x)))
        if self.droprate > 0:
            out = F.dropout(out, p=self.droprate, training=self.training)
        out = self.conv2(out)

        return torch.add(x if self.equalInOut else self.convShortcut(x), out)


class MaskBlock(nn.Module):
    expansion = 1

    def __init__(self, in_channels, out_channels, stride=1, dropRate=0.0):

        super(MaskBlock, self).__init__()
        self.bn1 = nn.BatchNorm2d(in_channels)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)

        self.droprate = dropRate
        self.equalInOut = (in_channels == out_channels)
        self.convShortcut = (not self.equalInOut) and nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride,
                                                                padding=0, bias=False) or None

        self.activation = Identity()
        self.activation.register_backward_hook(self._fisher)
        self.register_buffer('mask', None)

        self.input_shape = None
        self.output_shape = None
        self.flops = None
        self.params = None
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.stride = stride
        self.got_shapes = False

        # Fisher method is called on backward passes
        self.running_fisher = 0

    def forward(self, x):

        if not self.equalInOut:
            x = self.relu1(self.bn1(x))
        else:
            out = self.relu1(self.bn1(x))

        out = self.conv1(out if self.equalInOut else x)

        out = self.relu2(self.bn2(out))

        if self.mask is not None:
            out = out * self.mask[None, :, None, None]

        else:
            self._create_mask(x, out)

        out = self.activation(out)
        self.act = out

        if self.droprate > 0:
            out = F.dropout(out, p=self.droprate, training=self.training)

        out = self.conv2(out)

        return torch.add(x if self.equalInOut else self.convShortcut(x), out)

    def _create_mask(self, x, out):

        self.mask = x.new_ones(out.shape[1])
        self.input_shape = x.size()
        self.output_shape = out.size()

    def _fisher(self, notused1, notused2, grad_output):
        act = self.act.detach()
        grad = grad_output[0].detach()

        g_nk = (act * grad).sum(-1).sum(-1)
        del_k = g_nk.pow(2).mean(0).mul(0.5)
        self.running_fisher += del_k

    def reset_fisher(self):
        self.running_fisher = 0 * self.running_fisher

    def cost(self):

        in_channels = self.in_channels
        out_channels = self.out_channels
        middle_channels = int(self.mask.sum().item())

        conv1_size = self.conv1.weight.size()
        conv2_size = self.conv2.weight.size()

        # convs
        self.params = in_channels * middle_channels * conv1_size[2] * conv1_size[3] + middle_channels * out_channels * \
                      conv2_size[2] * conv2_size[3]

        # batchnorms, assuming running stats are absorbed
        self.params += 2 * in_channels + 2 * middle_channels

        # skip
        if not self.equalInOut:
            self.params += in_channels * out_channels
        else:
            self.params += 0

    def compress_weights(self):

        middle_dim = int(self.mask.sum().item())
        print(middle_dim)

        if middle_dim != 0:
            conv1 = nn.Conv2d(self.in_channels, middle_dim, kernel_size=3, stride=self.stride, padding=1, bias=False)
            conv1.weight = nn.Parameter(self.conv1.weight[self.mask == 1, :, :, :])

            # Batch norm 2 changes
            bn2 = nn.BatchNorm2d(middle_dim)
            bn2.weight = nn.Parameter(self.bn2.weight[self.mask == 1])
            bn2.bias = nn.Parameter(self.bn2.bias[self.mask == 1])
            bn2.running_mean = self.bn2.running_mean[self.mask == 1]
            bn2.running_var = self.bn2.running_var[self.mask == 1]

            conv2 = nn.Conv2d(middle_dim, self.out_channels, kernel_size=3, stride=1, padding=1, bias=False)
            conv2.weight = nn.Parameter(self.conv2.weight[:, self.mask == 1, :, :])

        if middle_dim == 0:
            conv1 = Zero()
            bn2 = Zero()
            conv2 = ZeroMake(channels=self.out_channels, spatial=self.stride)

        self.conv1 = conv1
        self.conv2 = conv2
        self.bn2 = bn2

        if middle_dim != 0:
            self.mask = torch.ones(middle_dim)
        else:
            self.mask = torch.ones(1)


class NetworkBlock(nn.Module):
    def __init__(self, nb_layers, in_channels, out_channels, block, stride, dropRate=0.0):
        super(NetworkBlock, self).__init__()
        self.layer = self._make_layer(block, in_channels, out_channels, nb_layers, stride, dropRate)

    def _make_layer(self, block, in_channels, out_channels, nb_layers, stride, dropRate):
        layers = []
        for i in range(int(nb_layers)):
            layers.append(block(in_channels if i == 0 else out_channels, out_channels, stride if i == 0 else 1, dropRate))
        return nn.Sequential(*layers)

    def forward(self, x):
        return self.layer(x)


class NetworkBlockBottle(nn.Module):
    def __init__(self, nb_layers, in_channels, out_channels, mid_channels, block, stride, dropRate=0.0):
        super(NetworkBlockBottle, self).__init__()
        self.layer = self._make_layer(block, in_channels, out_channels, mid_channels, nb_layers, stride, dropRate)

    def _make_layer(self, block, in_channels, out_channels, mid_channels, nb_layers, stride, dropRate):
        layers = []
        for i in range(int(nb_layers)):
            layers.append(
                block(in_channels if i == 0 else out_channels, out_channels, mid_channels, stride if i == 0 else 1,
                      dropRate))
        return nn.Sequential(*layers)

    def forward(self, x):
        return self.layer(x)


class WideResNet(nn.Module):
    def __init__(self, depth, widen_factor, num_classes=10, dropRate=0.0, mask=False):
        super(WideResNet, self).__init__()

        nChannels = [16, int(16 * widen_factor), int(32 * widen_factor), int(64 * widen_factor)]

        assert ((depth - 4) % 6 == 0)
        n = (depth - 4) // 6

        if mask == 1:
            block = MaskBlock
        else:
            block = BasicBlock

        # 1st conv before any network block
        self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1,
                               padding=1, bias=False)
        # 1st block
        self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate)
        # 2nd block
        self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropRate)
        # 3rd block
        self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropRate)
        # global average pooling and classifier
        self.bn1 = nn.BatchNorm2d(nChannels[3])
        self.relu = nn.ReLU(inplace=True)
        self.fc = nn.Linear(nChannels[3], num_classes)
        self.nChannels = nChannels[3]

        # Count params that don't exist in blocks (conv1, bn1, fc)
        self.fixed_params = len(self.conv1.weight.view(-1)) + len(self.bn1.weight) + len(self.bn1.bias) + \
                            len(self.fc.weight.view(-1)) + len(self.fc.bias)


        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.bias.data.zero_()

    def forward(self, x):
        out = self.conv1(x)
        out = self.block1(out)
        out = self.block2(out)
        out = self.block3(out)
        out = self.relu(self.bn1(out))
        out = F.avg_pool2d(out, 8)
        out = out.view(-1, self.nChannels)
        return self.fc(out)


class WideResNetBottle(nn.Module):
    def __init__(self, depth, widen_factor, num_classes=10, dropRate=0.0, bottle_mult=0.5):
        super(WideResNetBottle, self).__init__()

        nChannels = [16, int(16 * widen_factor), int(32 * widen_factor), int(64 * widen_factor)]

        assert ((depth - 4) % 6 == 0)
        n = (depth - 4) // 6

        block = BottleBlock

        # 1st conv before any network block
        self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1,
                               padding=1, bias=False)
        # 1st block
        self.block1 = NetworkBlockBottle(n, nChannels[0], nChannels[1], int(nChannels[1] * bottle_mult), block, 1,
                                         dropRate)
        # 2nd block
        self.block2 = NetworkBlockBottle(n, nChannels[1], nChannels[2], int(nChannels[2] * bottle_mult), block, 2,
                                         dropRate)
        # 3rd block
        self.block3 = NetworkBlockBottle(n, nChannels[2], nChannels[3], int(nChannels[3] * bottle_mult), block, 2,
                                         dropRate)
        # global average pooling and classifier
        self.bn1 = nn.BatchNorm2d(nChannels[3])
        self.relu = nn.ReLU(inplace=True)
        self.fc = nn.Linear(nChannels[3], num_classes)
        self.nChannels = nChannels[3]

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.bias.data.zero_()

    def forward(self, x):
        out = self.conv1(x)
        out = self.block1(out)
        out = self.block2(out)
        out = self.block3(out)
        out = self.relu(self.bn1(out))
        out = F.avg_pool2d(out, 8)
        out = out.view(-1, self.nChannels)
        return self.fc(out)


================================================
FILE: prune.py
================================================
"""Pruning script"""

import argparse
import os
import time
import torch.utils.model_zoo as model_zoo

from funcs import *
from models import *


parser = argparse.ArgumentParser(description='Pruning')
parser.add_argument('-j', '--workers', default=0, type=int, metavar='N', help='number of data loading workers')
parser.add_argument('--GPU', default='0', type=str, help='GPU to use')
parser.add_argument('--save_file', default='wrn16_2_p', type=str, help='save file for checkpoints')
parser.add_argument('--print_freq', '-p', default=10, type=int, metavar='N', help='print frequency (default: 10)')
parser.add_argument('--resume', '-r', action='store_true', help='resume from checkpoint')
parser.add_argument('--resume_ckpt', default='checkpoint', type=str,
                    help='save file for resumed checkpoint')
parser.add_argument('--data_loc', default='/disk/scratch/datasets/cifar', type=str, help='where is the dataset')

# Learning specific arguments
parser.add_argument('--optimizer', choices=['sgd', 'adam'], default='sgd', type=str, help='optimizer')
parser.add_argument('-b', '--batch_size', default=128, type=int, metavar='N', help='mini-batch size (default: 128)')
parser.add_argument('-lr', '--learning_rate', default=8e-4, type=float, metavar='LR', help='initial learning rate')
parser.add_argument('-epochs', '--no_epochs', default=1300, type=int, metavar='epochs', help='no. epochs')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum')
parser.add_argument('--weight_decay', '--wd', default=0.0005, type=float, metavar='W', help='weight decay')
parser.add_argument('--prune_every', default=100, type=int, help='prune every X steps')
parser.add_argument('--save_every', default=100, type=int, help='save model every X EPOCHS')
parser.add_argument('--random', default=False, type=lambda s: s.lower() in ('true', '1', 'yes'), help='Prune at random (True/False)')
parser.add_argument('--base_model', default='base_model', type=str, help='basemodel')
parser.add_argument('--val_every', default=1, type=int, help='val model every X EPOCHS')
parser.add_argument('--mask', default=1, type=int, help='Mask type')
parser.add_argument('--l1_prune', default=False, type=lambda s: s.lower() in ('true', '1', 'yes'), help='Prune via l1 norm (True/False)')
parser.add_argument('--net', choices=['dense', 'res'], default='dense', type=str, help='dense, res')
parser.add_argument('--width', default=2.0, type=float, metavar='W', help='width of wideresnet')
parser.add_argument('--depth', default=40, type=int, metavar='D', help='depth of wideresnet/densenet')
parser.add_argument('--growth', default=12, type=int, help='growth rate of densenet')
parser.add_argument('--transition_rate', default=0.5, type=float, help='transition rate of densenet')

args = parser.parse_args()
print(args)
os.environ["CUDA_VISIBLE_DEVICES"] = args.GPU

device = torch.device("cuda:%s" % '0' if torch.cuda.is_available() else "cpu")


if args.net == 'res':
    model = WideResNet(args.depth, args.width, mask=args.mask)
elif args.net =='dense':
    model = DenseNet(args.growth, args.depth, args.transition_rate, 10, True, mask=args.mask)

model.load_state_dict(torch.load('checkpoints/%s.t7' % args.base_model, map_location='cpu')['state_dict'], strict=True)

if args.resume:
    state = torch.load('checkpoints/%s.t7' % args.resume_ckpt, map_location='cpu')

    model = resume_from(state, model_type=args.net)
    error_history = state['error_history']
    prune_history = state['prune_history']
    flop_history = state.get('flop_history', [])  # key is not written by the checkpoint code in this script
    param_history = state['param_history']
    start_epoch = state['epoch']

else:

    error_history = []
    prune_history = []
    param_history = []
    start_epoch = 0

model.to(device)

normMean = [0.49139968, 0.48215827, 0.44653124]
normStd = [0.24703233, 0.24348505, 0.26158768]
normTransform = transforms.Normalize(normMean, normStd)

print('==> Preparing data..')
num_classes = 10

transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normTransform
])

transform_val = transforms.Compose([
    transforms.ToTensor(),
    normTransform

])

trainset = torchvision.datasets.CIFAR10(root=args.data_loc,
                                        train=True, download=True, transform=transform_train)
valset = torchvision.datasets.CIFAR10(root=args.data_loc,
                                      train=False, download=True, transform=transform_val)

trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size, shuffle=True,
                                          num_workers=args.workers,
                                          pin_memory=False)
valloader = torch.utils.data.DataLoader(valset, batch_size=50, shuffle=False,
                                        num_workers=args.workers,
                                        pin_memory=False)

prune_count = 0
pruner = Pruner()
pruner.prune_history = prune_history

NO_STEPS = args.prune_every


def finetune():
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to train mode
    model.train()

    end = time.time()

    dataiter = iter(trainloader)

    for i in range(0, NO_STEPS):

        try:
            input, target = next(dataiter)
        except StopIteration:
            dataiter = iter(trainloader)
            input, target = next(dataiter)

        # measure data loading time
        data_time.update(time.time() - end)

        input, target = input.to(device), target.to(device)

        # compute output
        output = model(input)

        loss = criterion(output, target)

        # measure accuracy and record loss
        err1, err5 = get_error(output.detach(), target, topk=(1, 5))

        losses.update(loss.item(), input.size(0))
        top1.update(err1.item(), input.size(0))
        top5.update(err5.item(), input.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            print('Prunepoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Error@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Error@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                epoch, i, NO_STEPS, batch_time=batch_time,
                data_time=data_time, loss=losses, top1=top1, top5=top5))




def prune():
    print('Pruning')
    if not args.random:
        if not args.l1_prune:
            print('fisher pruning')
            pruner.fisher_prune(model, prune_every=args.prune_every)
        else:
            print('l1 pruning')
            pruner.l1_prune(model, prune_every=args.prune_every)
    else:
        print('random pruning')
        pruner.random_prune(model)


def validate():
    global error_history

    batch_time = AverageMeter()
    data_time = AverageMeter()

    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to evaluate mode
    model.eval()

    end = time.time()

    for i, (input, target) in enumerate(valloader):

        # measure data loading time
        data_time.update(time.time() - end)

        input, target = input.to(device), target.to(device)

        # compute output
        output = model(input)

        loss = criterion(output, target)

        # measure accuracy and record loss
        err1, err5 = get_error(output.detach(), target, topk=(1, 5))

        losses.update(loss.item(), input.size(0))
        top1.update(err1.item(), input.size(0))
        top5.update(err5.item(), input.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            print('Test: [{0}/{1}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Error@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Error@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                i, len(valloader), batch_time=batch_time, loss=losses,
                top1=top1, top5=top5))

    print(' * Error@1 {top1.avg:.3f} Error@5 {top5.avg:.3f}'
          .format(top1=top1, top5=top5))



    # Record Top 1 for CIFAR
    error_history.append(top1.avg)


if __name__ == '__main__':

    criterion = nn.CrossEntropyLoss()

    optimizer = torch.optim.SGD([v for v in model.parameters() if v.requires_grad],
                                lr=args.learning_rate, momentum=args.momentum, weight_decay=args.weight_decay)

    for epoch in range(start_epoch, args.no_epochs):

        print('Epoch %d:' % epoch)
        print('Learning rate is %s' % [v['lr'] for v in optimizer.param_groups][0])

        # finetune for one epoch
        finetune()
        # evaluate on validation set
        if epoch != 0 and ((epoch % args.val_every == 0) or (epoch + 1 == args.no_epochs)):  # always validate at the last epoch
            validate()

            # Error history is recorded in validate(). Record params here
            no_params = pruner.get_cost(model) + model.fixed_params
            param_history.append(no_params)

        # Save before pruning
        if epoch != 0 and ((epoch % args.save_every == 0) or (epoch + 1 == args.no_epochs)):  #
            filename = 'checkpoints/%s_%d_prunes.t7' % (args.save_file, epoch)
            save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'error_history': error_history,
                'param_history': param_history,
                'prune_history': pruner.prune_history,
            }, filename=filename)

        ## Prune
        prune()



================================================
FILE: train.py
================================================
"""This script just trains models from scratch, to later be pruned"""

import argparse
import json
import os
import time
import torch.optim.lr_scheduler as lr_scheduler
import torch.utils.model_zoo as model_zoo

from models import *

from funcs import *


parser = argparse.ArgumentParser(description='Training')
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', help='number of data loading workers')
parser.add_argument('--GPU', default='0', type=str, help='GPU to use')
parser.add_argument('--save_file', default='saveto', type=str, help='save file for checkpoints')
parser.add_argument('--base_file', default='bbb', type=str, help='base file for checkpoints')
parser.add_argument('--print_freq', '-p', default=10, type=int, metavar='N', help='print frequency (default: 10)')
parser.add_argument('--data_loc', default='/disk/scratch/datasets/cifar')

# Learning specific arguments
parser.add_argument('-b', '--batch_size', default=128, type=int, metavar='N', help='mini-batch size (default: 128)')
parser.add_argument('-lr', '--learning_rate', default=.1, type=float, metavar='LR', help='initial learning rate')
parser.add_argument('-epochs', '--no_epochs', default=200, type=int, metavar='epochs', help='no. epochs')
parser.add_argument('--epoch_step', default='[60,120,160]', type=str, help='json list with epochs to drop lr on')
parser.add_argument('--lr_decay_ratio', default=0.2, type=float, help='learning rate decay factor')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum')
parser.add_argument('--weight_decay', '--wd', default=0.0005, type=float, metavar='W', help='weight decay')
parser.add_argument('--eval', '-e', action='store_true', help='evaluate only (skip training)')
parser.add_argument('--mask', '-m', type=int, help='mask mode', default=0)
parser.add_argument('--deploy', '-de', action='store_true', help='prune and deploy model')
parser.add_argument('--params_left', '-pl', default=0, type=int, help='prune til...')
parser.add_argument('--net', choices=['res', 'dense'], default='res')

# Net specific
parser.add_argument('--depth', '-d', default=40, type=int, metavar='D', help='depth of wideresnet/densenet')
parser.add_argument('--width', '-w', default=2.0, type=float, metavar='W', help='width of wideresnet')
parser.add_argument('--growth', default=12, type=int, help='growth rate of densenet')
parser.add_argument('--transition_rate', default=0.5, type=float, help='transition rate of densenet')


# Uniform bottlenecks
parser.add_argument('--bottle', action='store_true', help='Linearly scale bottlenecks')
parser.add_argument('--bottle_mult', default=0.5, type=float, help='bottleneck multiplier')


if not os.path.exists('checkpoints/'):
    os.makedirs('checkpoints/')

args = parser.parse_args()
print(args)
os.environ["CUDA_VISIBLE_DEVICES"] = args.GPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

if args.net == 'res':
    if not args.bottle:
        model = WideResNet(args.depth, args.width, mask=args.mask)
    else:
        model = WideResNetBottle(args.depth, args.width, bottle_mult=args.bottle_mult)
elif args.net == 'dense':
    if not args.bottle:
        model = DenseNet(args.growth, args.depth, args.transition_rate, 10, True, mask=args.mask)
    else:
        model = DenseNet(args.growth, args.depth, args.transition_rate, 10, True, width=args.bottle_mult)

else:
    raise ValueError('pick a valid net')

pruner = Pruner()

if args.deploy:
    # Feed example to activate masks
    model(torch.rand(1, 3, 32, 32))
    SD = torch.load('checkpoints/%s.t7' % args.base_file)

    if not args.eval:

        pruner = Pruner()
        pruner._get_masks(model)

        for ii in SD['prune_history']:
            pruner.fixed_prune(model, ii)

    else:
        model.load_state_dict(SD['state_dict'])

pruner.compress(model)

get_inf_params(model)
time.sleep(1)
model.to(device)

normalize = transforms.Normalize(mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
                                 std=[x / 255.0 for x in [63.0, 62.1, 66.7]])

print('==> Preparing data..')
num_classes = 10

transform_train = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambda x: F.pad(x.unsqueeze(0),
                                      (4, 4, 4, 4), mode='reflect').squeeze()),
    transforms.ToPILImage(),
    transforms.RandomCrop(32),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normalize,
])

transform_val = transforms.Compose([
    transforms.ToTensor(),
    normalize,

])

trainset = torchvision.datasets.CIFAR10(root=args.data_loc,
                                        train=True, download=True, transform=transform_train)
valset = torchvision.datasets.CIFAR10(root=args.data_loc,
                                      train=False, download=True, transform=transform_val)

trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size, shuffle=True,
                                          num_workers=args.workers,
                                          pin_memory=False)
valloader = torch.utils.data.DataLoader(valset, batch_size=50, shuffle=False,
                                        num_workers=args.workers,
                                        pin_memory=False)

error_history = []
epoch_step = json.loads(args.epoch_step)


def train():
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to train mode
    model.train()

    end = time.time()

    for i, (input, target) in enumerate(trainloader):

        # measure data loading time
        data_time.update(time.time() - end)

        input, target = input.to(device), target.to(device)

        # compute output
        output = model(input)

        loss = criterion(output, target)

        # measure accuracy and record loss
        err1, err5 = get_error(output.detach(), target, topk=(1, 5))

        losses.update(loss.item(), input.size(0))
        top1.update(err1.item(), input.size(0))
        top5.update(err5.item(), input.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Error@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Error@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                epoch, i, len(trainloader), batch_time=batch_time,
                data_time=data_time, loss=losses, top1=top1, top5=top5))




def validate():
    global error_history

    batch_time = AverageMeter()
    data_time = AverageMeter()

    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to evaluate mode
    model.eval()

    end = time.time()

    for i, (input, target) in enumerate(valloader):

        # measure data loading time
        data_time.update(time.time() - end)

        input, target = input.to(device), target.to(device)

        # compute output
        output = model(input)

        loss = criterion(output, target)

        # measure accuracy and record loss
        err1, err5 = get_error(output.detach(), target, topk=(1, 5))

        losses.update(loss.item(), input.size(0))
        top1.update(err1.item(), input.size(0))
        top5.update(err5.item(), input.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            print('Test: [{0}/{1}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Error@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Error@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                i, len(valloader), batch_time=batch_time, loss=losses,
                top1=top1, top5=top5))

    print(' * Error@1 {top1.avg:.3f} Error@5 {top5.avg:.3f}'
          .format(top1=top1, top5=top5))


    # Record Top 1 for CIFAR
    error_history.append(top1.avg)


if __name__ == '__main__':

    filename = 'checkpoints/%s.t7' % args.save_file
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD([v for v in model.parameters() if v.requires_grad],
                                lr=args.learning_rate, momentum=args.momentum, weight_decay=args.weight_decay)
    scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=epoch_step, gamma=args.lr_decay_ratio)

    if not args.eval:

        for epoch in range(args.no_epochs):

            print('Epoch %d:' % epoch)
            print('Learning rate is %s' % [v['lr'] for v in optimizer.param_groups][0])
            # train for one epoch
            train()
            scheduler.step()
            # # evaluate on validation set
            validate()

            save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'error_history': error_history,
            }, filename=filename)

    else:
        if not args.deploy:
            model.load_state_dict(torch.load(filename)['state_dict'])
        epoch = 0
        validate()
SYMBOL INDEX (102 symbols across 5 files)

FILE: funcs.py
  class Pruner (line 17) | class Pruner:
    method __init__ (line 18) | def __init__(self, module_name='MaskBlock'):
    method fisher_prune (line 24) | def fisher_prune(self, model, prune_every):
    method fixed_prune (line 33) | def fixed_prune(self, model, ID):
    method random_prune (line 37) | def random_prune(self, model):
    method l1_prune (line 52) | def l1_prune(self, model, prune_every):
    method prune (line 79) | def prune(self, model, feat_index):
    method compress (line 102) | def compress(self, model):
    method _get_fisher (line 107) | def _get_fisher(self, model):
    method _get_masks (line 124) | def _get_masks(self, model):
    method _update_cost (line 133) | def _update_cost(self, model):
    method get_cost (line 138) | def get_cost(self, model):
    method concat (line 147) | def concat(input):
  function find (line 151) | def find(input):
  function concat (line 156) | def concat(input):
  function save_checkpoint (line 160) | def save_checkpoint(state, filename='checkpoint.pth.tar'):
  function get_error (line 164) | def get_error(output, target, topk=(1,)):
  class AverageMeter (line 180) | class AverageMeter(object):
    method __init__ (line 183) | def __init__(self):
    method reset (line 186) | def reset(self):
    method update (line 192) | def update(self, val, n=1):
  function get_inf_params (line 199) | def get_inf_params(net, verbose=True, sd=False):
  function get_num_gen (line 228) | def get_num_gen(gen):
  function is_pruned (line 232) | def is_pruned(layer):
  function is_leaf (line 240) | def is_leaf(model):
  function get_layer_info (line 244) | def get_layer_info(layer):
  function get_layer_param (line 250) | def get_layer_param(model):
  function measure_layer (line 255) | def measure_layer(layer, x):
  function measure_model (line 325) | def measure_model(model, H, W):

FILE: models/densenet.py
  class Identity (line 19) | class Identity(nn.Module):
    method __init__ (line 20) | def __init__(self):
    method forward (line 23) | def forward(self, x):
  class Zero (line 27) | class Zero(nn.Module):
    method __init__ (line 28) | def __init__(self):
    method forward (line 31) | def forward(self, x):
  class ZeroMake (line 35) | class ZeroMake(nn.Module):
    method __init__ (line 36) | def __init__(self, channels, spatial):
    method forward (line 41) | def forward(self, x):
  class MaskBlock (line 46) | class MaskBlock(nn.Module):
    method __init__ (line 47) | def __init__(self, nChannels, growthRate):
    method forward (line 72) | def forward(self, x):
    method _create_mask (line 86) | def _create_mask(self, x, out):
    method _fisher (line 93) | def _fisher(self, _, __, grad_output):
    method reset_fisher (line 101) | def reset_fisher(self):
    method update (line 104) | def update(self, previous_mask):
    method cost (line 108) | def cost(self):
    method compress_weights (line 123) | def compress_weights(self):
  class Bottleneck (line 155) | class Bottleneck(nn.Module):
    method __init__ (line 156) | def __init__(self, nChannels, growthRate, width=1):
    method forward (line 166) | def forward(self, x):
  class SingleLayer (line 173) | class SingleLayer(nn.Module):
    method __init__ (line 174) | def __init__(self, nChannels, growthRate):
    method forward (line 180) | def forward(self, x):
  class Transition (line 186) | class Transition(nn.Module):
    method __init__ (line 187) | def __init__(self, nChannels, nOutChannels):
    method forward (line 193) | def forward(self, x):
  class DenseNet (line 199) | class DenseNet(nn.Module):
    method __init__ (line 200) | def __init__(self, growthRate, depth, reduction, nClasses, bottleneck,...
    method _make_dense (line 245) | def _make_dense(self, nChannels, growthRate, nDenseBlocks, bottleneck,...
    method forward (line 257) | def forward(self, x):

FILE: models/wideresnet.py
  class Identity (line 7) | class Identity(nn.Module):
    method __init__ (line 8) | def __init__(self):
    method forward (line 11) | def forward(self, x):
  class Zero (line 15) | class Zero(nn.Module):
    method __init__ (line 16) | def __init__(self):
    method forward (line 19) | def forward(self, x):
  class ZeroMake (line 23) | class ZeroMake(nn.Module):
    method __init__ (line 24) | def __init__(self, channels, spatial):
    method forward (line 29) | def forward(self, x):
  class BasicBlock (line 34) | class BasicBlock(nn.Module):
    method __init__ (line 35) | def __init__(self, in_channels, out_channels, stride, dropRate=0.0):
    method forward (line 50) | def forward(self, x):
  class BottleBlock (line 63) | class BottleBlock(nn.Module):
    method __init__ (line 64) | def __init__(self, in_channels, out_channels, mid_channels, stride, dr...
    method forward (line 79) | def forward(self, x):
  class MaskBlock (line 92) | class MaskBlock(nn.Module):
    method __init__ (line 95) | def __init__(self, in_channels, out_channels, stride=1, dropRate=0.0):
    method forward (line 126) | def forward(self, x):
    method _create_mask (line 153) | def _create_mask(self, x, out):
    method _fisher (line 159) | def _fisher(self, notused1, notused2, grad_output):
    method reset_fisher (line 167) | def reset_fisher(self):
    method cost (line 170) | def cost(self):
    method compress_weights (line 192) | def compress_weights(self):
  class NetworkBlock (line 226) | class NetworkBlock(nn.Module):
    method __init__ (line 227) | def __init__(self, nb_layers, in_channels, out_channels, block, stride...
    method _make_layer (line 231) | def _make_layer(self, block, in_channels, out_channels, nb_layers, str...
    method forward (line 237) | def forward(self, x):
  class NetworkBlockBottle (line 241) | class NetworkBlockBottle(nn.Module):
    method __init__ (line 242) | def __init__(self, nb_layers, in_channels, out_channels, mid_channels,...
    method _make_layer (line 246) | def _make_layer(self, block, in_channels, out_channels, mid_channels, ...
    method forward (line 254) | def forward(self, x):
  class WideResNet (line 258) | class WideResNet(nn.Module):
    method __init__ (line 259) | def __init__(self, depth, widen_factor, num_classes=10, dropRate=0.0, ...
    method forward (line 302) | def forward(self, x):
  class WideResNetBottle (line 313) | class WideResNetBottle(nn.Module):
    method __init__ (line 314) | def __init__(self, depth, widen_factor, num_classes=10, dropRate=0.0, ...
    method forward (line 352) | def forward(self, x):

FILE: prune.py
  function finetune (line 114) | def finetune():
  function prune (line 175) | def prune():
  function validate (line 189) | def validate():

FILE: train.py
  function train (line 134) | def train():
  function validate (line 187) | def validate():
Condensed preview: 8 files, each showing path, character count, and a content snippet.
[
  {
    "path": "LICENSE",
    "chars": 1067,
    "preview": "MIT License\n\nCopyright (c) 2019 BayesWatch\n\nPermission is hereby granted, free of charge, to any person obtaining a copy"
  },
  {
    "path": "README.md",
    "chars": 4058,
    "preview": "# A Closer Look at Structured Pruning for Neural Network Compression\n\nCode used to reproduce experiments in https://arxi"
  },
  {
    "path": "funcs.py",
    "chars": 10715,
    "preview": "import random\nimport numpy as np\nimport torchvision.transforms as transforms\nimport torchvision\nimport time\nfrom functoo"
  },
  {
    "path": "models/__init__.py",
    "chars": 51,
    "preview": "from .wideresnet import *\nfrom .densenet import *\n\n"
  },
  {
    "path": "models/densenet.py",
    "chars": 9235,
    "preview": "import torch\n\nimport torch.nn as nn\nimport torch.optim as optim\n\nimport torch.nn.functional as F\nfrom torch.autograd imp"
  },
  {
    "path": "models/wideresnet.py",
    "chars": 13490,
    "preview": "import math\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\nclass Identity(nn.Module):\n    def __in"
  },
  {
    "path": "prune.py",
    "chars": 9903,
    "preview": "\"\"\"Pruning script\"\"\"\n\nimport argparse\nimport os\n\nimport torch.utils.model_zoo as model_zoo\n\nfrom funcs import *\nfrom mod"
  },
  {
    "path": "train.py",
    "chars": 9533,
    "preview": "\"\"\"This script just trains models from scratch, to later be pruned\"\"\"\n\nimport argparse\nimport json\nimport os\nimport time"
  }
]
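
As a quick consistency check, the `chars` values in the preview above can be summed and compared with the total repository size reported for this extraction (56.7 KB). The following is a minimal sketch, not part of the repository: the `preview` list simply re-types the counts shown above, and it assumes one byte per character and 1 KB = 1024 bytes.

```python
# Character counts copied from the condensed preview above.
# Assumption: one byte per character and 1 KB = 1024 bytes.
preview = [
    {"path": "LICENSE", "chars": 1067},
    {"path": "README.md", "chars": 4058},
    {"path": "funcs.py", "chars": 10715},
    {"path": "models/__init__.py", "chars": 51},
    {"path": "models/densenet.py", "chars": 9235},
    {"path": "models/wideresnet.py", "chars": 13490},
    {"path": "prune.py", "chars": 9903},
    {"path": "train.py", "chars": 9533},
]

# Sum the per-file character counts and convert to kilobytes.
total_chars = sum(entry["chars"] for entry in preview)
total_kb = total_chars / 1024

print(total_chars)         # 58052
print(round(total_kb, 1))  # 56.7
```

Under those assumptions the per-file counts agree with the reported total, which suggests the preview lists every file in the extraction.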

