[
  {
    "path": "LICENSE",
    "content": "MIT License\n\nCopyright (c) 2019 BayesWatch\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n"
  },
  {
    "path": "README.md",
    "content": "# A Closer Look at Structured Pruning for Neural Network Compression\n\nCode used to reproduce the experiments in https://arxiv.org/abs/1810.04622.\n\nTo prune, we fill our networks with custom `MaskBlock` modules, which are manipulated using the `Pruner` class in `funcs.py`. There is certainly a better way to do this, but we leave that as an exercise for someone who can code much better than we can.\n\n## Setup\nThis is best done in a clean conda environment:\n\n```\nconda create -n prunes python=3.6\nconda activate prunes\nconda install pytorch torchvision -c pytorch\n```\n\n## Repository layout\n- `train.py`: contains all of the code for training large models from scratch and for training pruned models from scratch\n- `prune.py`: contains the code for pruning trained models\n- `funcs.py`: contains the pruning functions and other commonly used helpers\n\n## CIFAR Experiments\nFirst, you will need some initial models.\n\nTo train a WRN-40-2:\n```\npython train.py --net='res' --depth=40 --width=2.0 --data_loc=<path-to-data> --save_file='res'\n```\n\nThe default arguments of `train.py` are suitable for training WRNs. 
The following trains a DenseNet-BC-100 (k=12) with its default hyperparameters:\n\n```\npython train.py --net='dense' --depth=100 --data_loc=<path-to-data> --save_file='dense' --no_epochs 300 -b 64 --epoch_step '[150,225]' --weight_decay 0.0001 --lr_decay_ratio 0.1\n```\n\nBoth commands automatically save checkpoints to the `checkpoints` folder.\n\n### Pruning\nOnce training is finished, we can prune our networks using `prune.py`. The defaults are set for WRN pruning, so DenseNets need the extra arguments shown below:\n```\npython prune.py --net='res'   --data_loc=<path-to-data> --base_model='res' --save_file='res_fisher'\npython prune.py --net='res'   --data_loc=<path-to-data> --l1_prune=True --base_model='res' --save_file='res_l1'\n\npython prune.py --net='dense' --depth 100 --data_loc=<path-to-data> --base_model='dense' --save_file='dense_fisher' --learning_rate 1e-3 --weight_decay 1e-4 --batch_size 64 --no_epochs 2600\npython prune.py --net='dense' --depth 100 --data_loc=<path-to-data> --l1_prune=True --base_model='dense' --save_file='dense_l1'  --learning_rate 1e-3 --weight_decay 1e-4 --batch_size 64  --no_epochs 2600\n\n```\nNote that Fisher pruning is the default, so you don't need to pass a flag to use it; pass `--l1_prune=True` to use L1-norm pruning instead.\n\nOnce finished, we can train the pruned models from scratch, e.g.:\n```\npython train.py --data_loc=<path-to-data> --net='res' --base_file='res_fisher_<N>_prunes' --deploy --mask=1 --save_file='res_fisher_<N>_prunes_scratch'\n```\n\nEach model can then be evaluated using:\n```\npython train.py --deploy --eval --data_loc=<path-to-data> --net='res' --mask=1 --base_file='res_fisher_<N>_prunes'\n```\n\n### Training Reduced models\n\nThis can be done by varying the input arguments to `train.py`. 
To reduce the depth or width of a WRN, change the corresponding options:\n```\npython train.py --net='res' --depth=<REDUCED DEPTH> --width=<REDUCED WIDTH> --data_loc=<path-to-data> --save_file='res_reduced'\n```\n\nTo add bottlenecks, use the following:\n\n```\npython train.py --net='res' --depth=40 --width=2.0 --data_loc=<path-to-data> --save_file='res_bottle' --bottle --bottle_mult <Z>\n```\n\nWith DenseNets, you can modify the `depth` or `growth`, or use `--bottle --bottle_mult <Z>` as above.\n\n### Acknowledgements\n\n[Jack Turner][jack] wrote the L1 pruning code, and some other stuff for that matter.\n\nCode has been liberally borrowed from many a repo, including, but not limited to:\n\n```\nhttps://github.com/xternalz/WideResNet-pytorch\nhttps://github.com/bamos/densenet.pytorch\nhttps://github.com/kuangliu/pytorch-cifar\nhttps://github.com/ShichenLiu/CondenseNet\n```\n\n### Citing this work\n\nIf you would like to cite this work, please use the following BibTeX entry:\n\n```\n@article{crowley2018pruning,\n  title={A Closer Look at Structured Pruning for Neural Network Compression},\n  author={Crowley, Elliot J and Turner, Jack and Storkey, Amos and O'Boyle, Michael},\n  journal={arXiv preprint arXiv:1810.04622},\n  year={2018},\n}\n```\n\n[jack]: https://github.com/jack-willturner\n"
  },
  {
    "path": "funcs.py",
    "content": "import random\nimport time\nimport operator\nfrom functools import reduce\n\nimport numpy as np\nimport torch\nimport torchvision\nimport torchvision.transforms as transforms\nfrom torch.autograd import Variable\n\nfrom models import *\n\n\nclass Pruner:\n    def __init__(self, module_name='MaskBlock'):\n        # Flattened masks over all prunable blocks, and the order in which channels were pruned\n        self.module_name = module_name\n        self.masks = []\n        self.prune_history = []\n\n    def fisher_prune(self, model, prune_every):\n        self._get_fisher(model)\n        # Average the Fisher accumulated since the last prune; channels that are already off get a\n        # large dummy value so they can never be picked again\n        tot_loss = self.fisher.div(prune_every) + 1e6 * (1 - self.masks)\n        _, argmin = torch.min(tot_loss, 0)\n        self.prune(model, argmin.item())\n        self.prune_history.append(argmin.item())\n\n    def fixed_prune(self, model, ID):\n        self.prune(model, ID)\n        self.prune_history.append(ID)\n\n    def random_prune(self, model):\n        # Called to refresh the block costs; this also resets the accumulated Fisher\n        self._get_fisher(model)\n        masks = []\n        for m in model.modules():\n            if m._get_name() == self.module_name:\n                masks.append(m.mask.detach())\n\n        masks = self.concat(masks)\n        masks_on = [i for i, v in enumerate(masks) if v == 1]\n        random_pick = random.choice(masks_on)\n        self.prune(model, random_pick)\n        self.prune_history.append(random_pick)\n\n    def l1_prune(self, model, prune_every):\n        masks = []\n        l1_norms = []\n\n        for m in model.modules():\n            if m._get_name() == self.module_name:\n                # L1 norm of each conv1 filter, i.e. one value per prunable channel\n                l1_norm = torch.sum(m.conv1.weight.abs(), (1, 2, 3)).detach().cpu().numpy()\n                masks.append(m.mask.detach())\n                l1_norms.append(l1_norm)\n\n        masks = self.concat(masks)\n        self.masks = masks\n        l1_norms = np.concatenate(l1_norms)\n\n        l1_norms_on = []\n        for m, l in zip(masks, l1_norms):\n            if m == 1:\n                l1_norms_on.append(l)\n            else:\n                l1_norms_on.append(np.inf)  # channels that are already off can never be picked\n\n        # Global index of the active channel with the smallest L1 norm\n        pick = int(np.argmin(l1_norms_on))\n\n        self.prune(model, pick)\n        self.prune_history.append(pick)\n\n    def prune(self, model, feat_index):\n        \"\"\"feat_index refers to the index of a feature map. This function modifies the mask to turn it off.\"\"\"\n        print('Pruned %d out of %d channels so far' % (len(self.prune_history), len(self.masks)))\n        if len(self.prune_history) > len(self.masks):\n            raise Exception('Time to stop')\n        safe = 0\n        running_index = 0\n        for m in model.modules():\n            if m._get_name() == self.module_name:\n                mask_indices = range(running_index, running_index + len(m.mask))\n\n                if feat_index in mask_indices:\n                    print('Pruning channel %d' % feat_index)\n                    local_index = mask_indices.index(feat_index)\n                    m.mask[local_index] = 0\n                    safe = 1\n                    break\n                else:\n                    running_index += len(m.mask)\n        if not safe:\n            raise Exception('The provided index does not correspond to any feature maps. 
This is bad.')\n\n    def compress(self, model):\n        for m in model.modules():\n            if m._get_name() == 'MaskBlock':\n                m.compress_weights()\n\n    def _get_fisher(self, model):\n        masks = []\n        fisher = []\n\n        self._update_cost(model)\n\n        for m in model.modules():\n            if m._get_name() == self.module_name:\n                masks.append(m.mask.detach())\n                fisher.append(m.running_fisher.detach())\n\n                # Now clear the fisher cache\n                m.reset_fisher()\n\n        self.masks = self.concat(masks)\n        self.fisher = self.concat(fisher)\n\n    def _get_masks(self, model):\n        masks = []\n\n        for m in model.modules():\n            if m._get_name() == self.module_name:\n                masks.append(m.mask.detach())\n\n        self.masks = self.concat(masks)\n\n    def _update_cost(self, model):\n        for m in model.modules():\n            if m._get_name() == self.module_name:\n                m.cost()\n\n    def get_cost(self, model):\n        params = 0\n        for m in model.modules():\n            if m._get_name() == self.module_name:\n                m.cost()\n                params += m.params\n        return params\n\n    @staticmethod\n    def concat(input):\n        return torch.cat([item for item in input])\n\n\ndef find(input):\n    # Find as in MATLAB to find indices in a binary vector\n    return [i for i, j in enumerate(input) if j]\n\n\ndef concat(input):\n    return torch.cat([item for item in input])\n\n\ndef save_checkpoint(state, filename='checkpoint.pth.tar'):\n    torch.save(state, filename)\n\n\ndef get_error(output, target, topk=(1,)):\n    \"\"\"Computes the error@k for the specified values of k\"\"\"\n    maxk = max(topk)\n    batch_size = target.size(0)\n\n    _, pred = output.topk(maxk, 1, True, True)\n    pred = pred.t()\n    correct = pred.eq(target.view(1, -1).expand_as(pred))\n\n    res = []\n    for k in topk:\n        
correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)  # reshape, since correct[:k] is non-contiguous after pred.t()\n        res.append(100.0 - correct_k.mul_(100.0 / batch_size))\n    return res\n\n\nclass AverageMeter(object):\n    \"\"\"Computes and stores the average and current value\"\"\"\n\n    def __init__(self):\n        self.reset()\n\n    def reset(self):\n        self.val = 0\n        self.avg = 0\n        self.sum = 0\n        self.count = 0\n\n    def update(self, val, n=1):\n        self.val = val\n        self.sum += val * n\n        self.count += n\n        self.avg = self.sum / self.count\n\n\ndef get_inf_params(net, verbose=True, sd=False):\n    if sd:\n        params = net\n    else:\n        params = net.state_dict()\n    tot = 0\n    conv_tot = 0\n    for p in params:\n        no = params[p].numel()\n\n        if ('num_batches_tracked' not in p) and ('running' not in p) and ('mask' not in p):\n            tot += no\n\n            if verbose:\n                print('%s has %d params' % (p, no))\n        if 'conv' in p:\n            conv_tot += no\n\n    if verbose:\n        print('Net has %d conv params' % conv_tot)\n        print('Net has %d params in total' % tot)\n\n    return tot\n\n\ncount_ops = 0\ncount_params = 0\n\n\ndef get_num_gen(gen):\n    return sum(1 for x in gen)\n\n\ndef is_pruned(layer):\n    try:\n        layer.mask\n        return True\n    except AttributeError:\n        return False\n\n\ndef is_leaf(model):\n    return get_num_gen(model.children()) == 0\n\n\ndef get_layer_info(layer):\n    layer_str = str(layer)\n    type_name = layer_str[:layer_str.find('(')].strip()\n    return type_name\n\n\ndef get_layer_param(model):\n    return sum([reduce(operator.mul, i.size(), 1) for i in model.parameters()])\n\n\n### The input batch size should be 1 to call this function\ndef measure_layer(layer, x):\n    global count_ops, count_params\n    delta_ops = 0\n    delta_params = 0\n    multi_add = 1\n    type_name = get_layer_info(layer)\n\n    ### ops_conv\n    if type_name in 
['Conv2d']:\n        out_h = int((x.size()[2] + 2 * layer.padding[0] - layer.kernel_size[0]) /\n                    layer.stride[0] + 1)\n        out_w = int((x.size()[3] + 2 * layer.padding[1] - layer.kernel_size[1]) /\n                    layer.stride[1] + 1)\n        delta_ops = layer.in_channels * layer.out_channels * layer.kernel_size[0] * \\\n                    layer.kernel_size[1] * out_h * out_w / layer.groups * multi_add\n        delta_params = get_layer_param(layer)\n\n    ### ops_learned_conv\n    elif type_name in ['LearnedGroupConv']:\n        measure_layer(layer.relu, x)\n        measure_layer(layer.norm, x)\n        conv = layer.conv\n        out_h = int((x.size()[2] + 2 * conv.padding[0] - conv.kernel_size[0]) /\n                    conv.stride[0] + 1)\n        out_w = int((x.size()[3] + 2 * conv.padding[1] - conv.kernel_size[1]) /\n                    conv.stride[1] + 1)\n        delta_ops = conv.in_channels * conv.out_channels * conv.kernel_size[0] * \\\n                    conv.kernel_size[1] * out_h * out_w / layer.condense_factor * multi_add\n        delta_params = get_layer_param(conv) / layer.condense_factor\n\n    ### ops_nonlinearity\n    elif type_name in ['ReLU']:\n        delta_ops = x.numel()\n        delta_params = get_layer_param(layer)\n\n    ### ops_pooling\n    elif type_name in ['AvgPool2d']:\n        # Assumes an integer kernel_size, stride and padding; height and width are handled separately\n        in_h = x.size()[2]\n        in_w = x.size()[3]\n        kernel_ops = layer.kernel_size * layer.kernel_size\n        out_h = int((in_h + 2 * layer.padding - layer.kernel_size) / layer.stride + 1)\n        out_w = int((in_w + 2 * layer.padding - layer.kernel_size) / layer.stride + 1)\n        delta_ops = x.size()[0] * x.size()[1] * out_w * out_h * kernel_ops\n        delta_params = get_layer_param(layer)\n\n    elif type_name in ['AdaptiveAvgPool2d']:\n        delta_ops = x.size()[0] * x.size()[1] * x.size()[2] * x.size()[3]\n        delta_params = get_layer_param(layer)\n\n    ### ops_linear\n    elif type_name in ['Linear']:\n        weight_ops = layer.weight.numel() * multi_add\n        bias_ops = layer.bias.numel()\n        delta_ops = x.size()[0] * (weight_ops + bias_ops)\n        delta_params = get_layer_param(layer)\n\n    ### ops_nothing\n    elif type_name in ['BatchNorm2d', 'Dropout2d', 'DropChannel', 'Dropout']:\n        delta_params = get_layer_param(layer)\n\n    ### unknown layer type: counted as zero ops and zero params\n    else:\n        pass\n        # raise TypeError('unknown layer type: %s' % type_name)\n\n    count_ops += delta_ops\n    count_params += delta_params\n    return\n\n\ndef measure_model(model, H, W):\n    global count_ops, count_params\n    count_ops = 0\n    count_params = 0\n    data = Variable(torch.zeros(1, 3, H, W))\n\n    def should_measure(x):\n        return is_leaf(x) or is_pruned(x)\n\n    def modify_forward(model):\n        for child in model.children():\n            if should_measure(child):\n                def new_forward(m):\n                    def lambda_forward(x):\n                        measure_layer(m, x)\n                        return m.old_forward(x)\n\n                    return lambda_forward\n\n                child.old_forward = child.forward\n                child.forward = new_forward(child)\n            else:\n                modify_forward(child)\n\n    def restore_forward(model):\n        for child in model.children():\n            # Restore every module patched by modify_forward (leaves and MaskBlocks alike)\n            if getattr(child, 'old_forward', None) is not None:\n                child.forward = child.old_forward\n                child.old_forward = None\n            else:\n                restore_forward(child)\n\n    modify_forward(model)\n    model.forward(data)\n    restore_forward(model)\n\n    return count_ops, count_params\n"
  },
  {
    "path": "models/__init__.py",
    "content": "from .wideresnet import *\nfrom .densenet import *\n\n"
  },
  {
    "path": "models/densenet.py",
    "content": "import torch\n\nimport torch.nn as nn\nimport torch.optim as optim\n\nimport torch.nn.functional as F\nfrom torch.autograd import Variable\n\nimport torchvision.datasets as dset\nimport torchvision.transforms as transforms\nfrom torch.utils.data import DataLoader\n\nimport torchvision.models as models\n\nimport sys\nimport math\n\n\nclass Identity(nn.Module):\n    def __init__(self):\n        super(Identity, self).__init__()\n\n    def forward(self, x):\n        return x\n\n\nclass Zero(nn.Module):\n    def __init__(self):\n        super(Zero, self).__init__()\n\n    def forward(self, x):\n        return x * 0\n\n\nclass ZeroMake(nn.Module):\n    def __init__(self, channels, spatial):\n        super(ZeroMake, self).__init__()\n        self.spatial = spatial\n        self.channels = channels\n\n    def forward(self, x):\n        return torch.zeros([x.size()[0], self.channels, x.size()[2] // self.spatial, x.size()[3] // self.spatial],\n                           dtype=x.dtype, layout=x.layout, device=x.device)\n\n\nclass MaskBlock(nn.Module):\n    def __init__(self, nChannels, growthRate):\n        super(MaskBlock, self).__init__()\n        interChannels = 4 * growthRate\n        self.bn1 = nn.BatchNorm2d(nChannels)\n        self.conv1 = nn.Conv2d(nChannels, interChannels, kernel_size=1,\n                               bias=False)\n        self.bn2 = nn.BatchNorm2d(interChannels)\n        self.conv2 = nn.Conv2d(interChannels, growthRate, kernel_size=3,\n                               padding=1, bias=False)\n\n        self.activation = Identity()\n        self.activation.register_backward_hook(self._fisher)\n        self.register_buffer('mask', None)\n\n        self.input_shape = None\n        self.output_shape = None\n        self.flops = None\n        self.params = None\n        self.in_channels = nChannels\n        self.out_channels = growthRate\n        self.stride = 1\n\n        # Fisher method is called on backward passes\n        
self.running_fisher = 0\n\n    def forward(self, x):\n        out = self.conv1(F.relu(self.bn1(x)))\n        out = F.relu(self.bn2(out))\n        if self.mask is not None:\n            out = out * self.mask[None, :, None, None]\n        else:\n            self._create_mask(x, out)\n        out = self.activation(out)\n        self.act = out\n\n        out = self.conv2(out)\n        out = torch.cat([x, out], 1)\n        return out\n\n    def _create_mask(self, x, out):\n        \"\"\"This takes an activation to generate the exact mask required. It also records input and output shapes\n        for posterity.\"\"\"\n        self.mask = x.new_ones(out.shape[1])\n        self.input_shape = x.size()\n        self.output_shape = out.size()\n\n    def _fisher(self, _, __, grad_output):\n        # Per-channel Fisher estimate: sum act * grad over the spatial dims, then take half the batch mean of its square\n        act = self.act.detach()\n        grad = grad_output[0].detach()\n\n        g_nk = (act * grad).sum(-1).sum(-1)\n        del_k = g_nk.pow(2).mean(0).mul(0.5)\n        self.running_fisher += del_k\n\n    def reset_fisher(self):\n        self.running_fisher = 0 * self.running_fisher\n\n    def update(self, previous_mask):\n        # This is only required for non-modular nets.\n        return None\n\n    def cost(self):\n\n        in_channels = self.in_channels\n        out_channels = self.out_channels\n        middle_channels = int(self.mask.sum().item())\n\n        conv1_size = self.conv1.weight.size()\n        conv2_size = self.conv2.weight.size()\n\n        # convs\n        self.params = in_channels * middle_channels * conv1_size[2] * conv1_size[3] + middle_channels * out_channels * \\\n                      conv2_size[2] * conv2_size[3]\n\n        # batchnorms, assuming running stats are absorbed\n        self.params += 2 * in_channels + 2 * middle_channels\n\n    def compress_weights(self):\n        middle_dim = int(self.mask.sum().item())\n\n        if middle_dim != 0:\n            # conv1 is a 1x1 convolution, so the compressed layer must be 1x1 as well\n            conv1 = nn.Conv2d(self.in_channels, middle_dim, kernel_size=1, stride=1, bias=False)\n            conv1.weight = nn.Parameter(self.conv1.weight[self.mask == 1, :, :, :])\n\n            # Batch norm 2 changes\n            bn2 = nn.BatchNorm2d(middle_dim)\n            bn2.weight = nn.Parameter(self.bn2.weight[self.mask == 1])\n            bn2.bias = nn.Parameter(self.bn2.bias[self.mask == 1])\n            bn2.running_mean = self.bn2.running_mean[self.mask == 1]\n            bn2.running_var = self.bn2.running_var[self.mask == 1]\n\n            conv2 = nn.Conv2d(middle_dim, self.out_channels, kernel_size=3, stride=1, padding=1, bias=False)\n            conv2.weight = nn.Parameter(self.conv2.weight[:, self.mask == 1, :, :])\n        else:\n            conv1 = Zero()\n            bn2 = Zero()\n            conv2 = ZeroMake(channels=self.out_channels, spatial=self.stride)\n\n        self.conv1 = conv1\n        self.conv2 = conv2\n        self.bn2 = bn2\n\n        if middle_dim != 0:\n            self.mask = torch.ones(middle_dim)\n        else:\n            self.mask = torch.ones(1)\n\n\nclass Bottleneck(nn.Module):\n    def __init__(self, nChannels, growthRate, width=1):\n        super(Bottleneck, self).__init__()\n        interChannels = int(4 * growthRate * width)\n        self.bn1 = nn.BatchNorm2d(nChannels)\n        self.conv1 = nn.Conv2d(nChannels, interChannels, kernel_size=1,\n                               bias=False)\n        self.bn2 = nn.BatchNorm2d(interChannels)\n        self.conv2 = nn.Conv2d(interChannels, growthRate, kernel_size=3,\n                               padding=1, bias=False)\n\n    def forward(self, x):\n        out = self.conv1(F.relu(self.bn1(x)))\n        out = self.conv2(F.relu(self.bn2(out)))\n        out = torch.cat((x, out), 1)\n        return out\n\n\nclass SingleLayer(nn.Module):\n    def __init__(self, nChannels, growthRate):\n        super(SingleLayer, self).__init__()\n        self.bn1 = nn.BatchNorm2d(nChannels)\n        self.conv1 = nn.Conv2d(nChannels, growthRate, kernel_size=3,\n                               padding=1, bias=False)\n\n    def forward(self, x):\n        out 
= self.conv1(F.relu(self.bn1(x)))\n        out = torch.cat((x, out), 1)\n        return out\n\n\nclass Transition(nn.Module):\n    def __init__(self, nChannels, nOutChannels):\n        super(Transition, self).__init__()\n        self.bn1 = nn.BatchNorm2d(nChannels)\n        self.conv1 = nn.Conv2d(nChannels, nOutChannels, kernel_size=1,\n                               bias=False)\n\n    def forward(self, x):\n        out = self.conv1(F.relu(self.bn1(x)))\n        out = F.avg_pool2d(out, 2)\n        return out\n\n\nclass DenseNet(nn.Module):\n    def __init__(self, growthRate, depth, reduction, nClasses, bottleneck, mask=False, width=1.):\n        super(DenseNet, self).__init__()\n\n        nDenseBlocks = (depth - 4) // 3\n        if bottleneck:\n            nDenseBlocks //= 2\n\n        nChannels = 2 * growthRate\n        self.conv1 = nn.Conv2d(3, nChannels, kernel_size=3, padding=1,\n                               bias=False)\n\n        self.dense1 = self._make_dense(nChannels, growthRate, nDenseBlocks, bottleneck, mask, width)\n        nChannels += nDenseBlocks * growthRate\n        nOutChannels = int(math.floor(nChannels * reduction))\n        self.trans1 = Transition(nChannels, nOutChannels)\n\n        nChannels = nOutChannels\n        self.dense2 = self._make_dense(nChannels, growthRate, nDenseBlocks, bottleneck, mask, width)\n        nChannels += nDenseBlocks * growthRate\n        nOutChannels = int(math.floor(nChannels * reduction))\n        self.trans2 = Transition(nChannels, nOutChannels)\n\n        nChannels = nOutChannels\n        self.dense3 = self._make_dense(nChannels, growthRate, nDenseBlocks, bottleneck, mask, width)\n        nChannels += nDenseBlocks * growthRate\n\n        self.bn1 = nn.BatchNorm2d(nChannels)\n        self.fc = nn.Linear(nChannels, nClasses)\n\n        # Count params that don't exist in blocks (conv1, bn1, fc, trans1, trans2)\n        self.fixed_params = len(self.conv1.weight.view(-1)) + len(self.bn1.weight) + 
len(self.bn1.bias) + \\\n                            len(self.fc.weight.view(-1)) + len(self.fc.bias)\n        self.fixed_params += len(self.trans1.conv1.weight.view(-1)) + 2 * len(self.trans1.bn1.weight)\n        self.fixed_params += len(self.trans2.conv1.weight.view(-1)) + 2 * len(self.trans2.bn1.weight)\n\n        # He (fan-out) initialisation for convs, unit-scale batchnorms and zero biases\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels\n                m.weight.data.normal_(0, math.sqrt(2. / n))\n            elif isinstance(m, nn.BatchNorm2d):\n                m.weight.data.fill_(1)\n                m.bias.data.zero_()\n            elif isinstance(m, nn.Linear):\n                m.bias.data.zero_()\n\n    def _make_dense(self, nChannels, growthRate, nDenseBlocks, bottleneck, mask=False, width=1):\n        layers = []\n        for i in range(int(nDenseBlocks)):\n            if bottleneck and mask:\n                layers.append(MaskBlock(nChannels, growthRate))\n            elif bottleneck:\n                layers.append(Bottleneck(nChannels, growthRate, width))\n            else:\n                layers.append(SingleLayer(nChannels, growthRate))\n            nChannels += growthRate\n        return nn.Sequential(*layers)\n\n    def forward(self, x):\n        out = self.conv1(x)\n        out = self.trans1(self.dense1(out))\n        out = self.trans2(self.dense2(out))\n        out = self.dense3(out)\n        out = F.avg_pool2d(F.relu(self.bn1(out)), 8)\n        # Flatten to (batch, channels); unlike torch.squeeze, this keeps the batch dim when the batch size is 1\n        out = out.view(out.size(0), -1)\n        out = self.fc(out)\n        return out\n"
  },
  {
    "path": "models/wideresnet.py",
    "content": "import math\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\nclass Identity(nn.Module):\n    def __init__(self):\n        super(Identity, self).__init__()\n\n    def forward(self, x):\n        return x\n\n\nclass Zero(nn.Module):\n    def __init__(self):\n        super(Zero, self).__init__()\n\n    def forward(self, x):\n        return x * 0\n\n\nclass ZeroMake(nn.Module):\n    def __init__(self, channels, spatial):\n        super(ZeroMake, self).__init__()\n        self.spatial = spatial\n        self.channels = channels\n\n    def forward(self, x):\n        return torch.zeros([x.size()[0], self.channels, x.size()[2] // self.spatial, x.size()[3] // self.spatial],\n                           dtype=x.dtype, layout=x.layout, device=x.device)\n\n\nclass BasicBlock(nn.Module):\n    def __init__(self, in_channels, out_channels, stride, dropRate=0.0):\n        super(BasicBlock, self).__init__()\n        self.bn1 = nn.BatchNorm2d(in_channels)\n        self.relu1 = nn.ReLU(inplace=True)\n        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride,\n                               padding=1, bias=False)\n        self.bn2 = nn.BatchNorm2d(out_channels)\n        self.relu2 = nn.ReLU(inplace=True)\n        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1,\n                               padding=1, bias=False)\n        self.droprate = dropRate\n        self.equalInOut = (in_channels == out_channels)\n        self.convShortcut = (not self.equalInOut) and nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride,\n                                                                padding=0, bias=False) or None\n\n    def forward(self, x):\n        if not self.equalInOut:\n            x = self.relu1(self.bn1(x))\n        else:\n            out = self.relu1(self.bn1(x))\n        out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x)))\n        if self.droprate > 0:\n      
      out = F.dropout(out, p=self.droprate, training=self.training)\n        out = self.conv2(out)\n\n        return torch.add(x if self.equalInOut else self.convShortcut(x), out)\n\n\nclass BottleBlock(nn.Module):\n    def __init__(self, in_channels, out_channels, mid_channels, stride, dropRate=0.0):\n        super(BottleBlock, self).__init__()\n        self.bn1 = nn.BatchNorm2d(in_channels)\n        self.relu1 = nn.ReLU(inplace=True)\n        self.conv1 = nn.Conv2d(in_channels, mid_channels, kernel_size=3, stride=stride,\n                               padding=1, bias=False)\n        self.bn2 = nn.BatchNorm2d(mid_channels)\n        self.relu2 = nn.ReLU(inplace=True)\n        self.conv2 = nn.Conv2d(mid_channels, out_channels, kernel_size=3, stride=1,\n                               padding=1, bias=False)\n        self.droprate = dropRate\n        self.equalInOut = (in_channels == out_channels)\n        self.convShortcut = (not self.equalInOut) and nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride,\n                                                                padding=0, bias=False) or None\n\n    def forward(self, x):\n        if not self.equalInOut:\n            x = self.relu1(self.bn1(x))\n        else:\n            out = self.relu1(self.bn1(x))\n        out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x)))\n        if self.droprate > 0:\n            out = F.dropout(out, p=self.droprate, training=self.training)\n        out = self.conv2(out)\n\n        return torch.add(x if self.equalInOut else self.convShortcut(x), out)\n\n\nclass MaskBlock(nn.Module):\n    expansion = 1\n\n    def __init__(self, in_channels, out_channels, stride=1, dropRate=0.0):\n\n        super(MaskBlock, self).__init__()\n        self.bn1 = nn.BatchNorm2d(in_channels)\n        self.relu1 = nn.ReLU(inplace=True)\n        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)\n        self.bn2 = 
nn.BatchNorm2d(out_channels)\n        self.relu2 = nn.ReLU(inplace=True)\n        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)\n\n        self.droprate = dropRate\n        self.equalInOut = (in_channels == out_channels)\n        self.convShortcut = (not self.equalInOut) and nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride,\n                                                                padding=0, bias=False) or None\n\n        self.activation = Identity()\n        self.activation.register_backward_hook(self._fisher)\n        self.register_buffer('mask', None)\n\n        self.input_shape = None\n        self.output_shape = None\n        self.flops = None\n        self.params = None\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n        self.stride = stride\n        self.got_shapes = False\n\n        # Accumulated by _fisher, which runs as a backward hook on self.activation\n        self.running_fisher = 0\n\n    def forward(self, x):\n\n        if not self.equalInOut:\n            x = self.relu1(self.bn1(x))\n        else:\n            out = self.relu1(self.bn1(x))\n\n        out = self.conv1(out if self.equalInOut else x)\n\n        out = self.relu2(self.bn2(out))\n\n        if self.mask is not None:\n            out = out * self.mask[None, :, None, None]\n\n        else:\n            self._create_mask(x, out)\n\n        out = self.activation(out)\n        self.act = out\n\n        if self.droprate > 0:\n            out = F.dropout(out, p=self.droprate, training=self.training)\n\n        out = self.conv2(out)\n\n        return torch.add(x if self.equalInOut else self.convShortcut(x), out)\n\n    def _create_mask(self, x, out):\n\n        self.mask = x.new_ones(out.shape[1])\n        self.input_shape = x.size()\n        self.output_shape = out.size()\n\n    def _fisher(self, notused1, notused2, grad_output):\n        act = self.act.detach()\n        grad = grad_output[0].detach()\n\n        # Per-channel Fisher saliency: sum act * grad over H and W, then take 0.5 * the batch mean of its square\n        g_nk = (act * grad).sum(-1).sum(-1)
\n        del_k = g_nk.pow(2).mean(0).mul(0.5)\n        self.running_fisher += del_k\n\n    def reset_fisher(self):\n        self.running_fisher = 0 * self.running_fisher\n\n    def cost(self):\n\n        in_channels = self.in_channels\n        out_channels = self.out_channels\n        middle_channels = int(self.mask.sum().item())\n\n        conv1_size = self.conv1.weight.size()\n        conv2_size = self.conv2.weight.size()\n\n        # conv1 and conv2 weights (only the unpruned middle channels)\n        self.params = in_channels * middle_channels * conv1_size[2] * conv1_size[3] + middle_channels * out_channels * \\\n                      conv2_size[2] * conv2_size[3]\n\n        # batchnorms, assuming running stats are absorbed\n        self.params += 2 * in_channels + 2 * middle_channels\n\n        # 1x1 shortcut conv, present only when in_channels != out_channels\n        if not self.equalInOut:\n            self.params += in_channels * out_channels\n        else:\n            self.params += 0\n\n    def compress_weights(self):\n\n        middle_dim = int(self.mask.sum().item())\n        print(middle_dim)\n\n        if middle_dim != 0:\n            conv1 = nn.Conv2d(self.in_channels, middle_dim, kernel_size=3, stride=self.stride, padding=1, bias=False)\n            conv1.weight = nn.Parameter(self.conv1.weight[self.mask == 1, :, :, :])\n\n            # Keep only the unpruned channels of bn2\n            bn2 = nn.BatchNorm2d(middle_dim)\n            bn2.weight = nn.Parameter(self.bn2.weight[self.mask == 1])\n            bn2.bias = nn.Parameter(self.bn2.bias[self.mask == 1])\n            bn2.running_mean = self.bn2.running_mean[self.mask == 1]\n            bn2.running_var = self.bn2.running_var[self.mask == 1]\n\n            conv2 = nn.Conv2d(middle_dim, self.out_channels, kernel_size=3, stride=1, padding=1, bias=False)\n            conv2.weight = nn.Parameter(self.conv2.weight[:, self.mask == 1, :, :])\n\n        if middle_dim == 0:\n            conv1 = Zero()\n            bn2 = Zero()\n            conv2 = ZeroMake(channels=self.out_channels, 
spatial=self.stride)\n\n        self.conv1 = conv1\n        self.conv2 = conv2\n        self.bn2 = bn2\n\n        if middle_dim != 0:\n            self.mask = torch.ones(middle_dim)\n        else:\n            self.mask = torch.ones(1)\n\n\nclass NetworkBlock(nn.Module):\n    def __init__(self, nb_layers, in_channels, out_channels, block, stride, dropRate=0.0):\n        super(NetworkBlock, self).__init__()\n        self.layer = self._make_layer(block, in_channels, out_channels, nb_layers, stride, dropRate)\n\n    def _make_layer(self, block, in_channels, out_channels, nb_layers, stride, dropRate):\n        layers = []\n        for i in range(int(nb_layers)):\n            layers.append(block(i == 0 and in_channels or out_channels, out_channels, i == 0 and stride or 1, dropRate))\n        return nn.Sequential(*layers)\n\n    def forward(self, x):\n        return self.layer(x)\n\n\nclass NetworkBlockBottle(nn.Module):\n    def __init__(self, nb_layers, in_channels, out_channels, mid_channels, block, stride, dropRate=0.0):\n        super(NetworkBlockBottle, self).__init__()\n        self.layer = self._make_layer(block, in_channels, out_channels, mid_channels, nb_layers, stride, dropRate)\n\n    def _make_layer(self, block, in_channels, out_channels, mid_channels, nb_layers, stride, dropRate):\n        layers = []\n        for i in range(int(nb_layers)):\n            layers.append(\n                block(i == 0 and in_channels or out_channels, out_channels, mid_channels, i == 0 and stride or 1,\n                      dropRate))\n        return nn.Sequential(*layers)\n\n    def forward(self, x):\n        return self.layer(x)\n\n\nclass WideResNet(nn.Module):\n    def __init__(self, depth, widen_factor, num_classes=10, dropRate=0.0, mask=False):\n        super(WideResNet, self).__init__()\n\n        nChannels = [16, int(16 * widen_factor), int(32 * widen_factor), int(64 * widen_factor)]\n\n        assert ((depth - 4) % 6 == 0)\n        n = (depth - 4) / 6\n\n        if mask == 1:
\n            block = MaskBlock\n        else:\n            block = BasicBlock\n\n        # 1st conv before any network block\n        self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1,\n                               padding=1, bias=False)\n        # 1st block\n        self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate)\n        # 2nd block\n        self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropRate)\n        # 3rd block\n        self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropRate)\n        # global average pooling and classifier\n        self.bn1 = nn.BatchNorm2d(nChannels[3])\n        self.relu = nn.ReLU(inplace=True)\n        self.fc = nn.Linear(nChannels[3], num_classes)\n        self.nChannels = nChannels[3]\n\n        # Count params that don't exist in blocks (conv1, bn1, fc)\n        self.fixed_params = len(self.conv1.weight.view(-1)) + len(self.bn1.weight) + len(self.bn1.bias) + \\\n                            len(self.fc.weight.view(-1)) + len(self.fc.bias)\n\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels\n                m.weight.data.normal_(0, math.sqrt(2. 
/ n))\n            elif isinstance(m, nn.BatchNorm2d):\n                m.weight.data.fill_(1)\n                m.bias.data.zero_()\n            elif isinstance(m, nn.Linear):\n                m.bias.data.zero_()\n\n    def forward(self, x):\n        out = self.conv1(x)\n        out = self.block1(out)\n        out = self.block2(out)\n        out = self.block3(out)\n        out = self.relu(self.bn1(out))\n        out = F.avg_pool2d(out, 8)\n        out = out.view(-1, self.nChannels)\n        return self.fc(out)\n\n\nclass WideResNetBottle(nn.Module):\n    def __init__(self, depth, widen_factor, num_classes=10, dropRate=0.0, bottle_mult=0.5):\n        super(WideResNetBottle, self).__init__()\n\n        nChannels = [16, int(16 * widen_factor), int(32 * widen_factor), int(64 * widen_factor)]\n\n        assert ((depth - 4) % 6 == 0)\n        n = (depth - 4) / 6\n\n        block = BottleBlock\n\n        # 1st conv before any network block\n        self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1,\n                               padding=1, bias=False)\n        # 1st block\n        self.block1 = NetworkBlockBottle(n, nChannels[0], nChannels[1], int(nChannels[1] * bottle_mult), block, 1,\n                                         dropRate)\n        # 2nd block\n        self.block2 = NetworkBlockBottle(n, nChannels[1], nChannels[2], int(nChannels[2] * bottle_mult), block, 2,\n                                         dropRate)\n        # 3rd block\n        self.block3 = NetworkBlockBottle(n, nChannels[2], nChannels[3], int(nChannels[3] * bottle_mult), block, 2,\n                                         dropRate)\n        # global average pooling and classifier\n        self.bn1 = nn.BatchNorm2d(nChannels[3])\n        self.relu = nn.ReLU(inplace=True)\n        self.fc = nn.Linear(nChannels[3], num_classes)\n        self.nChannels = nChannels[3]\n\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                n = m.kernel_size[0] * 
m.kernel_size[1] * m.out_channels\n                m.weight.data.normal_(0, math.sqrt(2. / n))\n            elif isinstance(m, nn.BatchNorm2d):\n                m.weight.data.fill_(1)\n                m.bias.data.zero_()\n            elif isinstance(m, nn.Linear):\n                m.bias.data.zero_()\n\n    def forward(self, x):\n        out = self.conv1(x)\n        out = self.block1(out)\n        out = self.block2(out)\n        out = self.block3(out)\n        out = self.relu(self.bn1(out))\n        out = F.avg_pool2d(out, 8)\n        out = out.view(-1, self.nChannels)\n        return self.fc(out)\n"
  },
  {
    "path": "prune.py",
    "content": "\"\"\"Pruning script\"\"\"\n\nimport argparse\nimport os\n\nimport torch.utils.model_zoo as model_zoo\n\nfrom funcs import *\nfrom models import *\n\n\nparser = argparse.ArgumentParser(description='Pruning')\nparser.add_argument('-j', '--workers', default=0, type=int, metavar='N', help='number of data loading workers')\nparser.add_argument('--GPU', default='0', type=str, help='GPU to use')\nparser.add_argument('--save_file', default='wrn16_2_p', type=str, help='save file for checkpoints')\nparser.add_argument('--print_freq', '-p', default=10, type=int, metavar='N', help='print frequency (default: 10)')\nparser.add_argument('--resume', '-r', action='store_true', help='resume from checkpoint')\nparser.add_argument('--resume_ckpt', default='checkpoint', type=str,\n                    help='save file for resumed checkpoint')\nparser.add_argument('--data_loc', default='/disk/scratch/datasets/cifar', type=str, help='location of the CIFAR-10 dataset')\n\n# Learning specific arguments\nparser.add_argument('--optimizer', choices=['sgd', 'adam'], default='sgd', type=str, help='optimizer')\nparser.add_argument('-b', '--batch_size', default=128, type=int, metavar='N', help='mini-batch size (default: 128)')\nparser.add_argument('-lr', '--learning_rate', default=8e-4, type=float, metavar='LR', help='initial learning rate')\nparser.add_argument('-epochs', '--no_epochs', default=1300, type=int, metavar='epochs', help='number of epochs')
\nparser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum')\nparser.add_argument('--weight_decay', '--wd', default=0.0005, type=float, metavar='W', help='weight decay')\nparser.add_argument('--prune_every', default=100, type=int, help='prune every X steps')\nparser.add_argument('--save_every', default=100, type=int, help='save model every X EPOCHS')\nparser.add_argument('--random', default=False, type=bool, help='Prune at random')\nparser.add_argument('--base_model', default='base_model', type=str, help='name of the trained checkpoint in checkpoints/ to prune')\nparser.add_argument('--val_every', default=1, type=int, help='validate model every X EPOCHS')\nparser.add_argument('--mask', default=1, type=int, help='Mask type')\nparser.add_argument('--l1_prune', default=False, type=bool, help='Prune via l1 norm')\nparser.add_argument('--net', default='dense', type=str, help='network type: dense or res')\nparser.add_argument('--width', default=2.0, type=float, metavar='W', help='width of wideresnet')\nparser.add_argument('--depth', default=40, type=int, metavar='D', help='depth of wideresnet/densenet')\nparser.add_argument('--growth', default=12, type=int, help='growth rate of densenet')\nparser.add_argument('--transition_rate', default=0.5, type=float, help='transition rate of densenet')\n\nargs = parser.parse_args()\nprint(args)\nos.environ[\"CUDA_VISIBLE_DEVICES\"] = args.GPU\n\ndevice = torch.device(\"cuda:%s\" % '0' if torch.cuda.is_available() else \"cpu\")\n\n\nif args.net == 'res':\n    model = WideResNet(args.depth, args.width, mask=args.mask)\nelif args.net == 'dense':\n    model = DenseNet(args.growth, args.depth, args.transition_rate, 10, True, mask=args.mask)\n\nmodel.load_state_dict(torch.load('checkpoints/%s.t7' % args.base_model, map_location='cpu')['state_dict'], strict=True)\n\nif args.resume:\n    state = torch.load('checkpoints/%s.t7' % args.resume_ckpt, map_location='cpu')\n\n    model = resume_from(state, model_type=args.net)\n    error_history = state['error_history']\n    prune_history = state['prune_history']
\n    param_history = state['param_history']\n    start_epoch = state['epoch']\n\nelse:\n\n    error_history = []\n    prune_history = []\n    param_history = []\n    start_epoch = 0\n\nmodel.to(device)\n\nnormMean = [0.49139968, 0.48215827, 0.44653124]\nnormStd = [0.24703233, 0.24348505, 0.26158768]\nnormTransform = transforms.Normalize(normMean, normStd)\n\nprint('==> Preparing data..')\nnum_classes = 10\n\ntransform_train = transforms.Compose([\n    transforms.RandomCrop(32, padding=4),\n    transforms.RandomHorizontalFlip(),\n    transforms.ToTensor(),\n    normTransform\n])\n\ntransform_val = transforms.Compose([\n    transforms.ToTensor(),\n    normTransform\n])\n\ntrainset = torchvision.datasets.CIFAR10(root=args.data_loc,\n                                        train=True, download=True, transform=transform_train)\nvalset = torchvision.datasets.CIFAR10(root=args.data_loc,\n                                      train=False, download=True, transform=transform_val)\n\ntrainloader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size, shuffle=True,\n                                          num_workers=args.workers,\n                                          pin_memory=False)\nvalloader = torch.utils.data.DataLoader(valset, batch_size=50, shuffle=False,\n                                        num_workers=args.workers,\n                                        pin_memory=False)\n\nprune_count = 0\npruner = Pruner()\npruner.prune_history = prune_history\n\nNO_STEPS = args.prune_every\n\n\ndef finetune():\n    batch_time = AverageMeter()\n    data_time = AverageMeter()\n    losses = AverageMeter()\n    top1 = AverageMeter()\n    top5 = AverageMeter()\n\n    # switch to train mode\n    model.train()\n\n    end = time.time()\n\n    dataiter = iter(trainloader)\n\n    for i in range(0, NO_STEPS):\n\n        try:\n            input, target = next(dataiter)\n        except StopIteration:\n            # restart the loader once a full pass over the training set is exhausted\n            dataiter = iter(trainloader)\n            input, target = 
next(dataiter)\n\n        # measure data loading time\n        data_time.update(time.time() - end)\n\n        input, target = input.to(device), target.to(device)\n\n        # compute output\n        output = model(input)\n\n        loss = criterion(output, target)\n\n        # measure accuracy and record loss\n        err1, err5 = get_error(output.detach(), target, topk=(1, 5))\n\n        losses.update(loss.item(), input.size(0))\n        top1.update(err1.item(), input.size(0))\n        top5.update(err5.item(), input.size(0))\n\n        # compute gradient and do SGD step\n        optimizer.zero_grad()\n        loss.backward()\n        optimizer.step()\n\n        # measure elapsed time\n        batch_time.update(time.time() - end)\n        end = time.time()\n\n        if i % args.print_freq == 0:\n            print('Prunepoch: [{0}][{1}/{2}]\\t'\n                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\\t'\n                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\\t'\n                  'Loss {loss.val:.4f} ({loss.avg:.4f})\\t'\n                  'Error@1 {top1.val:.3f} ({top1.avg:.3f})\\t'\n                  'Error@5 {top5.val:.3f} ({top5.avg:.3f})'.format(\n                epoch, i, NO_STEPS, batch_time=batch_time,\n                data_time=data_time, loss=losses, top1=top1, top5=top5))\n\n\ndef prune():\n    print('Pruning')\n    if args.random is False:\n        if args.l1_prune is False:\n            print('fisher pruning')\n            pruner.fisher_prune(model, prune_every=args.prune_every)\n        else:\n            print('l1 pruning')\n            pruner.l1_prune(model, prune_every=args.prune_every)\n    else:\n        print('random pruning')\n        pruner.random_prune(model)\n\n\ndef validate():\n    global error_history\n\n    batch_time = AverageMeter()\n    data_time = AverageMeter()\n\n    losses = AverageMeter()\n    top1 = AverageMeter()\n    top5 = AverageMeter()\n\n    # switch to evaluate 
mode\n    model.eval()\n\n    end = time.time()\n\n    for i, (input, target) in enumerate(valloader):\n\n        # measure data loading time\n        data_time.update(time.time() - end)\n\n        input, target = input.to(device), target.to(device)\n\n        # compute output\n        output = model(input)\n\n        loss = criterion(output, target)\n\n        # measure accuracy and record loss\n        err1, err5 = get_error(output.detach(), target, topk=(1, 5))\n\n        losses.update(loss.item(), input.size(0))\n        top1.update(err1.item(), input.size(0))\n        top5.update(err5.item(), input.size(0))\n\n        # measure elapsed time\n        batch_time.update(time.time() - end)\n        end = time.time()\n\n        if i % args.print_freq == 0:\n            print('Test: [{0}/{1}]\\t'\n                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\\t'\n                  'Loss {loss.val:.4f} ({loss.avg:.4f})\\t'\n                  'Error@1 {top1.val:.3f} ({top1.avg:.3f})\\t'\n                  'Error@5 {top5.val:.3f} ({top5.avg:.3f})'.format(\n                i, len(valloader), batch_time=batch_time, loss=losses,\n                top1=top1, top5=top5))\n\n    print(' * Error@1 {top1.avg:.3f} Error@5 {top5.avg:.3f}'\n          .format(top1=top1, top5=top5))\n\n\n\n    # Record Top 1 for CIFAR\n    error_history.append(top1.avg)\n\n\nif __name__ == '__main__':\n\n    criterion = nn.CrossEntropyLoss()\n\n    optimizer = torch.optim.SGD([v for v in model.parameters() if v.requires_grad],\n                                lr=args.learning_rate, momentum=args.momentum, weight_decay=args.weight_decay)\n\n    for epoch in range(start_epoch, args.no_epochs):\n\n        print('Epoch %d:' % epoch)\n        print('Learning rate is %s' % [v['lr'] for v in optimizer.param_groups][0])\n\n        # finetune for one epoch\n        finetune()\n        # # evaluate on validation set\n        if epoch != 0 and ((epoch % args.val_every == 0) or (epoch + 1 == 
args.no_epochs)):  # always validate at the last epoch\n            validate()\n\n            # Error history is recorded in validate(); record the parameter count here\n            no_params = pruner.get_cost(model) + model.fixed_params\n            param_history.append(no_params)\n\n        # Save before pruning (always save at the last epoch)\n        if epoch != 0 and ((epoch % args.save_every == 0) or (epoch + 1 == args.no_epochs)):\n            filename = 'checkpoints/%s_%d_prunes.t7' % (args.save_file, epoch)\n            save_checkpoint({\n                'epoch': epoch + 1,\n                'state_dict': model.state_dict(),\n                'error_history': error_history,\n                'param_history': param_history,\n                'prune_history': pruner.prune_history,\n            }, filename=filename)\n\n        # Prune\n        prune()\n\n"
  },
  {
    "path": "train.py",
    "content": "\"\"\"This script just trains models from scratch, to later be pruned\"\"\"\n\nimport argparse\nimport json\nimport os\nimport time\nimport torch.optim.lr_scheduler as lr_scheduler\nimport torch.utils.model_zoo as model_zoo\n\nfrom models import *\n\nfrom funcs import *\n\n\nparser = argparse.ArgumentParser(description='Training')\nparser.add_argument('-j', '--workers', default=4, type=int, metavar='N', help='number of data loading workers')\nparser.add_argument('--GPU', default='0', type=str, help='GPU to use')\nparser.add_argument('--save_file', default='saveto', type=str, help='save file for checkpoints')\nparser.add_argument('--base_file', default='bbb', type=str, help='checkpoint in checkpoints/ to load when --deploy is set')\nparser.add_argument('--print_freq', '-p', default=10, type=int, metavar='N', help='print frequency (default: 10)')\nparser.add_argument('--data_loc', default='/disk/scratch/datasets/cifar', type=str, help='location of the CIFAR-10 dataset')\n\n# Learning specific arguments\nparser.add_argument('-b', '--batch_size', default=128, type=int, metavar='N', help='mini-batch size (default: 128)')\nparser.add_argument('-lr', '--learning_rate', default=.1, type=float, metavar='LR', help='initial learning rate')\nparser.add_argument('-epochs', '--no_epochs', default=200, type=int, metavar='epochs', help='number of epochs')
\nparser.add_argument('--epoch_step', default='[60,120,160]', type=str, help='json list with epochs to drop lr on')\nparser.add_argument('--lr_decay_ratio', default=0.2, type=float, help='learning rate decay factor')\nparser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum')\nparser.add_argument('--weight_decay', '--wd', default=0.0005, type=float, metavar='W', help='weight decay')\nparser.add_argument('--eval', '-e', action='store_true', help='evaluate a trained checkpoint instead of training')\nparser.add_argument('--mask', '-m', type=int, help='mask mode', default=0)\nparser.add_argument('--deploy', '-de', action='store_true', help='prune and deploy model')\nparser.add_argument('--params_left', '-pl', default=0, type=int, help='prune til...')\nparser.add_argument('--net', choices=['res', 'dense'], default='res', help='network architecture')\n\n# Net specific\nparser.add_argument('--depth', '-d', default=40, type=int, metavar='D', help='depth of wideresnet/densenet')\nparser.add_argument('--width', '-w', default=2.0, type=float, metavar='W', help='width of wideresnet')\nparser.add_argument('--growth', default=12, type=int, help='growth rate of densenet')\nparser.add_argument('--transition_rate', default=0.5, type=float, help='transition rate of densenet')\n\n\n# Uniform bottlenecks\nparser.add_argument('--bottle', action='store_true', help='Linearly scale bottlenecks')\nparser.add_argument('--bottle_mult', default=0.5, type=float, help='bottleneck multiplier')\n\n\nif not os.path.exists('checkpoints/'):\n    os.makedirs('checkpoints/')\n\nargs = parser.parse_args()\nprint(args)\nos.environ[\"CUDA_VISIBLE_DEVICES\"] = args.GPU\ndevice = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n\nif args.net == 'res':\n    if not args.bottle:\n        model = WideResNet(args.depth, args.width, mask=args.mask)\n    else:\n        model = WideResNetBottle(args.depth, args.width, bottle_mult=args.bottle_mult)\nelif args.net == 'dense':\n    if not args.bottle:\n        model 
= DenseNet(args.growth, args.depth, args.transition_rate, 10, True, mask=args.mask)\n    else:\n        model = DenseNet(args.growth, args.depth, args.transition_rate, 10, True, width=args.bottle_mult)\n\nelse:\n    raise ValueError('pick a valid net')\n\npruner = Pruner()\n\nif args.deploy:\n    # Feed example to activate masks\n    model(torch.rand(1, 3, 32, 32))\n    SD = torch.load('checkpoints/%s.t7' % args.base_file)\n\n    if not args.eval:\n\n        pruner = Pruner()\n        pruner._get_masks(model)\n\n        for ii in SD['prune_history']:\n            pruner.fixed_prune(model, ii)\n\n    else:\n        model.load_state_dict(SD['state_dict'])\n\npruner.compress(model)\n\nget_inf_params(model)\ntime.sleep(1)\nmodel.to(device)\n\nnormalize = transforms.Normalize(mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],\n                                 std=[x / 255.0 for x in [63.0, 62.1, 66.7]])\n\nprint('==> Preparing data..')\nnum_classes = 10\n\ntransform_train = transforms.Compose([\n    transforms.ToTensor(),\n    transforms.Lambda(lambda x: F.pad(x.unsqueeze(0),\n                                      (4, 4, 4, 4), mode='reflect').squeeze()),\n    transforms.ToPILImage(),\n    transforms.RandomCrop(32),\n    transforms.RandomHorizontalFlip(),\n    transforms.ToTensor(),\n    normalize,\n])\n\ntransform_val = transforms.Compose([\n    transforms.ToTensor(),\n    normalize,\n\n])\n\ntrainset = torchvision.datasets.CIFAR10(root=args.data_loc,\n                                        train=True, download=True, transform=transform_train)\nvalset = torchvision.datasets.CIFAR10(root=args.data_loc,\n                                      train=False, download=True, transform=transform_val)\n\ntrainloader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size, shuffle=True,\n                                          num_workers=args.workers,\n                                          pin_memory=False)\nvalloader = torch.utils.data.DataLoader(valset, 
batch_size=50, shuffle=False,\n                                        num_workers=args.workers,\n                                        pin_memory=False)\n\nerror_history = []\nepoch_step = json.loads(args.epoch_step)\n\n\ndef train():\n    batch_time = AverageMeter()\n    data_time = AverageMeter()\n    losses = AverageMeter()\n    top1 = AverageMeter()\n    top5 = AverageMeter()\n\n    # switch to train mode\n    model.train()\n\n    end = time.time()\n\n    for i, (input, target) in enumerate(trainloader):\n\n        # measure data loading time\n        data_time.update(time.time() - end)\n\n        input, target = input.to(device), target.to(device)\n\n        # compute output\n        output = model(input)\n\n        loss = criterion(output, target)\n\n        # measure accuracy and record loss\n        err1, err5 = get_error(output.detach(), target, topk=(1, 5))\n\n        losses.update(loss.item(), input.size(0))\n        top1.update(err1.item(), input.size(0))\n        top5.update(err5.item(), input.size(0))\n\n        # compute gradient and do SGD step\n        optimizer.zero_grad()\n        loss.backward()\n        optimizer.step()\n\n        # measure elapsed time\n        batch_time.update(time.time() - end)\n        end = time.time()\n\n        if i % args.print_freq == 0:\n            print('Epoch: [{0}][{1}/{2}]\\t'\n                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\\t'\n                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\\t'\n                  'Loss {loss.val:.4f} ({loss.avg:.4f})\\t'\n                  'Error@1 {top1.val:.3f} ({top1.avg:.3f})\\t'\n                  'Error@5 {top5.val:.3f} ({top5.avg:.3f})'.format(\n                epoch, i, len(trainloader), batch_time=batch_time,\n                data_time=data_time, loss=losses, top1=top1, top5=top5))\n\n\n\n\ndef validate():\n    global error_history\n\n    batch_time = AverageMeter()\n    data_time = AverageMeter()\n\n    losses = AverageMeter()\n    top1 = 
AverageMeter()\n    top5 = AverageMeter()\n\n    # switch to evaluate mode\n    model.eval()\n\n    end = time.time()\n\n    for i, (input, target) in enumerate(valloader):\n\n        # measure data loading time\n        data_time.update(time.time() - end)\n\n        input, target = input.to(device), target.to(device)\n\n        # compute output\n        output = model(input)\n\n        loss = criterion(output, target)\n\n        # measure accuracy and record loss\n        err1, err5 = get_error(output.detach(), target, topk=(1, 5))\n\n        losses.update(loss.item(), input.size(0))\n        top1.update(err1.item(), input.size(0))\n        top5.update(err5.item(), input.size(0))\n\n        # measure elapsed time\n        batch_time.update(time.time() - end)\n        end = time.time()\n\n        if i % args.print_freq == 0:\n            print('Test: [{0}/{1}]\\t'\n                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\\t'\n                  'Loss {loss.val:.4f} ({loss.avg:.4f})\\t'\n                  'Error@1 {top1.val:.3f} ({top1.avg:.3f})\\t'\n                  'Error@5 {top5.val:.3f} ({top5.avg:.3f})'.format(\n                i, len(valloader), batch_time=batch_time, loss=losses,\n                top1=top1, top5=top5))\n\n    print(' * Error@1 {top1.avg:.3f} Error@5 {top5.avg:.3f}'\n          .format(top1=top1, top5=top5))\n\n\n    # Record Top 1 for CIFAR\n    error_history.append(top1.avg)\n\n\nif __name__ == '__main__':\n\n    filename = 'checkpoints/%s.t7' % args.save_file\n    criterion = nn.CrossEntropyLoss()\n    optimizer = torch.optim.SGD([v for v in model.parameters() if v.requires_grad],\n                                lr=args.learning_rate, momentum=args.momentum, weight_decay=args.weight_decay)\n    scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=epoch_step, gamma=args.lr_decay_ratio)\n\n    if not args.eval:\n\n        for epoch in range(args.no_epochs):\n\n            print('Epoch %d:' % epoch)\n            
print('Learning rate is %s' % [v['lr'] for v in optimizer.param_groups][0])\n            # train for one epoch\n            train()\n            scheduler.step()\n            # # evaluate on validation set\n            validate()\n\n            save_checkpoint({\n                'epoch': epoch + 1,\n                'state_dict': model.state_dict(),\n                'error_history': error_history,\n            }, filename=filename)\n\n    else:\n        if not args.deploy:\n            model.load_state_dict(torch.load(filename)['state_dict'])\n        epoch = 0\n        validate()\n"
  }
]