Repository: alvinwan/shiftresnet-cifar
Branch: master
Commit: 289a1a162526
Files: 13
Total size: 49.5 KB

Directory structure:
gitextract_279b2ese/

├── .gitignore
├── .gitmodules
├── LICENSE
├── README.md
├── count.py
├── eval.py
├── main.py
├── models/
│   ├── __init__.py
│   ├── depthwiseresnet.py
│   ├── resnet.py
│   └── shiftresnet.py
├── requirements.txt
└── utils.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
.idea
test.py
data
checkpoint

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# dotenv
.env

# virtualenv
.venv
venv/
ENV/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/


================================================
FILE: .gitmodules
================================================
[submodule "models/shiftnet_cuda_v2"]
	path = models/shiftnet_cuda_v2
	url = git@github.com:peterhj/shiftnet_cuda_v2.git
	branch = master


================================================
FILE: LICENSE
================================================
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!)  The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.


================================================
FILE: README.md
================================================
# ShiftResNet

Train ResNet with shift operations on CIFAR10 and CIFAR100 using PyTorch. This builds on the [original ResNet CIFAR10 codebase](https://github.com/kuangliu/pytorch-cifar.git) written by Kuang Liu. In this codebase, we replace 3x3 convolutional layers with a conv-shift-conv module: a 1x1 convolutional layer, a set of shift operations, and a second 1x1 convolutional layer. The repository includes the following:

- training utility to reproduce results
- efficient implementation of the shift layer from [Peter Jin](https://people.eecs.berkeley.edu/~phj/)
- ResNet and ShiftResNet derivatives on CIFAR10/CIFAR100
- count utility for parameters and FLOPs
- evaluation script for offline evaluation
- links to 60+ pretrained models: [#12](https://github.com/alvinwan/shiftresnet-cifar/issues/12) for CIFAR-10 and CIFAR-100
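
The shift step in a conv-shift-conv module moves each group of channels by a fixed spatial offset, so it has no weights and performs no multiplications. The pure-Python sketch below illustrates the idea on a `[C][H][W]` nested list. The function name `shift`, the contiguous assignment of channels to offsets, and the zero padding at the borders are assumptions for illustration; this is not a description of the CUDA kernel in `shiftnet_cuda_v2`.

```python
def shift(feature_map, kernel_size=3):
    """Grouped spatial shift of a [C][H][W] feature map given as nested lists.

    Channels are split into kernel_size**2 contiguous groups, and group g is
    moved by the g-th (dy, dx) offset of a kernel_size x kernel_size grid:
    output[c][y][x] = input[c][y - dy][x - dx], or 0 outside the map.
    No weights are learned and no multiplications are performed.
    """
    radius = kernel_size // 2
    offsets = [(dy, dx)
               for dy in range(-radius, radius + 1)
               for dx in range(-radius, radius + 1)]
    channels = len(feature_map)
    height, width = len(feature_map[0]), len(feature_map[0][0])
    group_size = -(-channels // len(offsets))  # ceiling division
    output = []
    for c in range(channels):
        dy, dx = offsets[c // group_size]
        plane = feature_map[c]
        output.append([[plane[y - dy][x - dx]
                        if 0 <= y - dy < height and 0 <= x - dx < width else 0
                        for x in range(width)]
                       for y in range(height)])
    return output


image = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
shifted = shift([image] * 9)  # 9 channels: one per offset
print(shifted[4])  # offset (0, 0): [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
print(shifted[8])  # offset (1, 1): [[0, 0, 0], [0, 1, 2], [0, 4, 5]]
```

In a full module, a 1x1 convolution would mix channels before this step and another 1x1 convolution after it, so the spatial information gathered by the shifts reaches every output channel.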

Unless otherwise specified, the code was written by and experiments were run by [Alvin Wan](http://alvinwan.com) with help from [Bichen Wu](https://github.com/BichenWuUCB).

## [_Shift:_ A Zero FLOP, Zero Parameter Alternative to Spatial Convolutions](https://arxiv.org/pdf/1711.08141.pdf)
By Bichen Wu, Alvin Wan, Xiangyu Yue, Peter Jin, Sicheng Zhao, Noah Golmant, Amir Gholaminejad, Joseph Gonzalez, Kurt Keutzer

Tradeoffs and further analysis can be found in the paper. If you find this work useful for your research, please consider citing:

    @inproceedings{shift,
        Author = {Bichen Wu and Alvin Wan and Xiangyu Yue and Peter Jin and Sicheng Zhao and Noah Golmant and Amir Gholaminejad and Joseph Gonzalez and Kurt Keutzer},
        Title = {Shift: A Zero FLOP, Zero Parameter Alternative to Spatial Convolutions},
        Journal = {arXiv:1711.08141},
        Year = {2017}
    }
    

## Getting Started

1. If you have not already, set up a virtual environment with Python 2.7 and activate it.

```
virtualenv shift --python=python2.7
source shift/bin/activate
```

Your prompt should now be prefaced with `(shift)`, as in

```
(shift) [user@server:~]$ 
```

2. Install PyTorch and `torchvision`. Go to [pytorch.org](http://pytorch.org), scroll down to the "Getting Started" section, and select the appropriate OS, package manager, Python, and CUDA build. For example, selecting Linux, pip, Python 2.7, and CUDA 8 gave the following at the time of this writing:

```
pip install torch torchvision # upgrade to the latest official stable release (PyTorch 0.4.1 at the time of writing)
```

3. Clone the repository

```
git clone --recursive git@github.com:alvinwan/shiftresnet-cifar.git
```

4. `cd` into the CUDA layer repository.
```
cd shiftresnet-cifar/models/shiftnet_cuda_v2
```

5. Follow the [ShiftNet CUDA layer instructions](https://github.com/peterhj/shiftnet_cuda_v2), steps 5 and 6:

```
pip install -r requirements.txt
make
```

6. In the directory `shiftresnet-cifar/models/shiftnet_cuda_v2`, create an empty `__init__.py` so that Python 2 can import `shiftnet_cuda_v2` as a module.

```
touch __init__.py
```

7. Then, `cd` back into the root of this repository. Create the `checkpoint` directory and download a checkpoint.

```
cd ../..
mkdir checkpoint
```

In the example below, we use the original `ResNet20`, the 3x smaller `ShiftResNet20-3`, and the 3x smaller `ResNet20`. Download them from the list of [all CIFAR-100 models](https://github.com/alvinwan/shiftresnet-cifar/issues/12) and save them in the `checkpoint` directory, so that your file structure resembles the following:

```
shiftresnet-cifar/
   |
   |-- eval.py
   |-- checkpoint/
       |-- resnet20_cifar100.t7
       |-- ...
```

8. Run the following, which downloads the dataset to `./data` if it is not already there. We begin by evaluating the original ResNet model on CIFAR100.

```
python eval.py --model=checkpoint/resnet20_cifar100.t7 --dataset=cifar100
```

This default ResNet model should give 66.25% accuracy on CIFAR-100. By default, the script loads and evaluates on CIFAR10; use the `--dataset` flag, as above, for CIFAR100.
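
`eval.py` reports top-1 accuracy: for each test image it takes the arg-max of the network's outputs (`torch.max(outputs.data, 1)`) and counts how often that class equals the label. The framework-free sketch below mirrors that computation; `top1_accuracy` is a hypothetical helper, not part of this repository.

```python
def top1_accuracy(logits, targets):
    """Percentage of rows whose highest-scoring class equals the target label."""
    correct = 0
    for scores, target in zip(logits, targets):
        # Arg-max over classes; in this sketch, ties go to the lowest index.
        predicted = max(range(len(scores)), key=lambda k: scores[k])
        correct += int(predicted == target)
    return 100.0 * correct / len(targets)


logits = [[0.1, 2.0, 0.3], [1.5, 0.2, 0.1], [0.0, 0.1, 3.0], [2.0, 1.0, 0.5]]
targets = [1, 0, 1, 0]
print('Acc: %.2f%%' % top1_accuracy(logits, targets))  # prints "Acc: 75.00%"
```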

### ShiftNet Expansion

To control the expansion hyperparameter for ShiftNet, pick a ShiftNet architecture and set its expansion. For example, the following uses ResNet20 with Shift modules of expansion `3c`. We start by counting parameters and FLOPs (for CIFAR10/CIFAR100):

```
python count.py --arch=shiftresnet20 --expansion=3
```

This should output the following parameter and FLOP count:

```
Parameters: (new) 95642 (original) 272474 (reduction) 2.85
FLOPs: (new) 16581248 (original) 40960640 (reduction) 2.47
```

We can then evaluate the associated ShiftResNet, which we downloaded in the first part of this README. Note that the arguments to `main.py` and `count.py` are very similar.

```
python eval.py --model=checkpoint/shiftresnet20_3.0_cifar100.t7 --dataset=cifar100
```

The ShiftResNet model above yields 70.77% on CIFAR-100.

### ResNet Reduction

To reduce ResNet by some factor in terms of its parameters, specify a reduction either block-wise or net-wise. The former reduces the internal channel representation of each BasicBlock. The latter reduces the input and output channels for all convolution layers by half. First, we check the reduction in parameter count for the entire network. For example, we specify a block-wise reduction of roughly 3x (`--reduction=2.8`) below:

```
python count.py --arch=resnet20 --reduction=2.8 --reduction-mode=block
```

This should output the following parameter and FLOP count:

```
==> resnet20 with reduction 2.80
Parameters: (new) 96206 (original) 272474 (reduction) 2.83
FLOPs: (new) 14197376 (original) 40960640 (reduction) 2.89
```

We again evaluate the associated neural network, which we downloaded in the first part of this README.

```
python eval.py --model=checkpoint/resnet20_2.8_block_cifar100.t7 --dataset=cifar100
```

This reduced ResNet gives 68.30% accuracy on CIFAR-100, 2.47 percentage points lower than ShiftResNet despite having several hundred more parameters.
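
The checkpoint names used in this README (`resnet20_cifar100.t7`, `shiftresnet20_3.0_cifar100.t7`, `resnet20_2.8_block_cifar100.t7`) follow the suffix logic in `main.py`: shift models append the expansion, and reduced ResNets append the reduction and its mode. The hypothetical helper below maps flags to a filename; its `else` branch and the `<arch><suffix>_<dataset>.t7` pattern are assumptions inferred from the example filenames, not code shown in this repository dump.

```python
def checkpoint_name(arch, dataset, expansion=1.0, reduction=1.0,
                    reduction_mode='net'):
    """Hypothetical reconstruction of the checkpoint paths used in this README."""
    if 'shift' in arch:
        # main.py parses --expansion as a float, so 3 becomes '3.0'.
        suffix = '_%s' % float(expansion)
    elif reduction != 1:
        suffix = '_%s_%s' % (float(reduction), reduction_mode)
    else:
        suffix = ''  # assumption: unreduced ResNets get no suffix
    return 'checkpoint/%s%s_%s.t7' % (arch, suffix, dataset)


print(checkpoint_name('shiftresnet20', 'cifar100', expansion=3))
print(checkpoint_name('resnet20', 'cifar100', reduction=2.8, reduction_mode='block'))
```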

## Experiments

Below, we run experiments on the following:

1. Varying expansion used for all conv-shift-conv layers in the neural network. Here, we replace 3x3 filters.
2. Varying number of output channels for a 3x3 convolution filter, matching the reduction in parameters that shift provides. This is `--reduction-mode=block`, which is *not* the default reduction mode.

`a` is the number of filters in the first set of 1x1 convolutional filters. `c` is the number of channels in our input.

### CIFAR-100 Accuracy

Accuracies below are all top-1. All CIFAR-100 pretrained models can be found [here](https://github.com/alvinwan/shiftresnet-cifar/issues/12). Note that these pretrained models were pickled under Python 2, which may cause problems when they are loaded in Python 3. Below, we compare reductions in parameters for the entire net (`--reduction-mode=net`) and block-wise (`--reduction-mode=block`).

| Model | `e` | SRN Acc* | RN Conv Acc | RN Depth Acc | Params (M) | Reduction (conv) | `r`** | `r`*** |
|-------|-----|----------|-------------|--------------|--------|------------------|-------|--------|
| ResNet20  | 1c | 55.05% | 50.23% | **61.32%** | 0.03 | 7.8 (7.2) | 1.12 | 0.38 |
| ResNet20  | 3c | **65.83%** | 60.72% | 64.51% | 0.10 | 2.9 (2.8) | 0.38 | 0.13 | 
| ResNet20  | 6c | **69.73%** | 65.59% | 65.38% | 0.19 | 1.5 | 0.19 | 0.065 |
| ResNet20  | 9c | **70.77%** | 68.30% | 65.59% | 0.28 | 0.98 | 0.125 | 0.04 |
| ResNet20  | -- | -- | 66.25% | -- | 0.27 | 1.0 | -- | -- |
| ResNet56  | 1c | 63.20% | 58.70% | **65.30%** | 0.10 | 8.4 (7.6) | 1.12 | 0.38 |
| ResNet56  | 3c | **69.77%** | 66.89% | 66.49% | 0.29 | 2.9 | 0.37 | 0.128 |
| ResNet56  | 6c | **72.33%** | 70.49% | 67.46% | 0.58 | 1.5 | 0.19 | 0.065 |
| ResNet56  | 9c | **73.43%** | 71.57% | 67.75% | 0.87 | 0.98 | 0.124 | 0.04 |
| ResNet56  | -- | -- | 69.27% | -- | 0.86 | 1.0 | -- | -- |
| ResNet110 | 1c | **68.01%** | 65.79% | 65.80% | 0.20 | 8.5 (7.8) | 1.1 | 0.37 |
| ResNet110 | 3c | **72.10%** | 70.22% | 67.22% | 0.59 | 2.9 | 0.37 | 0.125 |
| ResNet110 | 6c | **73.17%** | 72.21% | 68.11% | 1.18 | 1.5 | 0.19 | 0.065 |
| ResNet110 | 9c | **73.71%** | 72.67% | 68.39% | 1.76 | 0.98 | 0.123 | 0.04 |
| ResNet110 | -- | -- | 72.11% | -- | 1.73 | 1.0 | -- | -- |

`*` `SRN` denotes ShiftResNet accuracy. `RN Conv` and `RN Depth` denote ResNet accuracy using convolutional layers and depth-wise convolutional layers, respectively; in both cases, the number of channels in the intermediate representation of each ResNet block is reduced.

`**` This parameter `r` is used for the `--reduction` flag when replicating results for depth-wise convolutional blocks and for MobileNet blocks.

`***` This parameter `r` is used for the `--reduction` flag with shuffle blocks.
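
The Python 2 caveat for the pretrained models comes from pickling: Python 2 stores a plain `str` as raw 8-bit bytes, and Python 3's `pickle` decodes such strings as ASCII by default, which fails on non-ASCII bytes. The standard-library sketch below reproduces the problem with a hand-written protocol-2 pickle (the constant and helper names are illustrative) and shows the usual `encoding='latin1'` workaround. Recent versions of `torch.load` document an `encoding` keyword that is forwarded to the unpickler for the same purpose. Because these checkpoints store the whole `net` object, the `models` package must also be importable, and PyTorch releases that default to `weights_only=True` may additionally need `weights_only=False`.

```python
import pickle

# Hand-written protocol-2 pickle of the Python 2 byte string '\xff'. Python 2
# writes a short plain `str` with the SHORT_BINSTRING opcode ('U'); the memo
# opcode it would also emit is omitted here.
PY2_PICKLE = b'\x80\x02U\x01\xff.'


def load_py2_pickle(data, encoding='latin1'):
    """Unpickle Python 2 data in Python 3, decoding 8-bit strings with `encoding`.

    'latin1' maps every byte to one character; 'bytes' keeps them as bytes.
    """
    return pickle.loads(data, encoding=encoding)


try:
    pickle.loads(PY2_PICKLE)  # default encoding='ASCII' cannot decode 0xff
except UnicodeDecodeError as err:
    print('default load failed:', err)

print(repr(load_py2_pickle(PY2_PICKLE)))                    # latin1 text
print(repr(load_py2_pickle(PY2_PICKLE, encoding='bytes')))  # raw bytes
```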

### CIFAR-10 Accuracy

All CIFAR-10 pretrained models can be found [here](https://github.com/alvinwan/shiftresnet-cifar/issues/12). As above, these models were pickled under Python 2, which may cause problems when they are loaded in Python 3.

| Model | `e` | ShiftResNet Acc | ResNet Acc | Params* | Reduction** |
|-------|-----|-----|-----------|---------|-------------|
| ResNet20 | c | 85.78% | 84.77% | 0.03 | 7.8 (7.2) |
| ResNet20 | 3c | 89.56% | 88.81% | 0.10 | 2.9 (2.8) |
| ResNet20 | 6c | 91.07% | 91.30% | 0.19 | 1.5  |
| ResNet20 | 9c | 91.79% | 91.96% | 0.28 | 0.98 |
| ResNet20 | original | - | 91.35% | 0.27 | 1.0 |
| ResNet56 | c | 89.69% | 88.32% | 0.10 | 8.4 (7.6) |
| ResNet56 | 3c | 92.48% | 91.20% | 0.29 | 2.9 |
| ResNet56 | 6c | 93.49% | 93.01% | 0.58 | 1.5 |
| ResNet56 | 9c | 93.17% | 93.74% | 0.87 | 0.98 |
| ResNet56 | original | - | 92.01% | 0.86 | 1.0 |
| ResNet110 | c | 90.67% | 89.79% | 0.20 | 8.5 (7.8) |
| ResNet110 | 3c | 92.42% | 93.18% | 0.59 | 2.9 |
| ResNet110 | 6c | 93.03% | 93.40% | 1.18 | 1.5 |
| ResNet110 | 9c | 93.36% | 94.09% | 1.76 | 0.98 (0.95) |
| ResNet110 | original | - | 92.46% | 1.73 | 1.0 |

`*` parameters are in the millions

`**` The number in parentheses is the reduction in parameters we used for ResNet when we could not obtain the exact reduction in parameters used for shift.

`***` If using `--reduction-mode=block`, pass the `reduction` to `main.py` via the `--reduction` flag to reproduce the provided accuracies. This is the factor by which each ResNet block's number of "internal convolutional channels" is reduced. In contrast, the column to the left of it is the whole network's reduction in parameters.


================================================
FILE: count.py
================================================
from models import ResNet20
from models import ShiftResNet20
from models import ResNet56
from models import ShiftResNet56
from models import ResNet110
from models import ShiftResNet110
import torch
from torch.autograd import Variable
import numpy as np
import argparse

all_models = {
    'resnet20': ResNet20,
    'shiftresnet20': ShiftResNet20,
    'resnet56': ResNet56,
    'shiftresnet56': ShiftResNet56,
    'resnet110': ResNet110,
    'shiftresnet110': ShiftResNet110,
}

parser = argparse.ArgumentParser(description='Count parameters and FLOPs for CIFAR models')
parser.add_argument('--arch', choices=all_models.keys(),
                    help='Architecture to count parameters for', default='shiftresnet110')
parser.add_argument('--expansion', type=int, default=1, help='expansion for shift layers')
parser.add_argument('--reduction', type=float, default=1, help='reduction for resnet')
parser.add_argument('--reduction-mode', choices=('block', 'net', 'depthwise', 'shuffle', 'mobile'), help='"block" reduces inner representation for BasicBlock, "net" reduces for all layers', default='net')
args = parser.parse_args()

def count_params(net):
    return sum([np.prod(param.size()) for name, param in net.named_parameters()])

def count_flops(net):
    """Approximately count number of FLOPs"""
    dummy = Variable(torch.randn(1, 3, 32, 32)).cuda()  # size is specific to cifar10, cifar100!
    net.cuda().forward(dummy)
    return net.flops()

original = all_models[args.arch.replace('shift', '')]()
original_count = count_params(original)
original_flops = count_flops(original)

cls = all_models[args.arch]

assert 'shift' not in args.arch or args.reduction == 1, \
    'Only default resnet supports reductions'
if args.reduction != 1:
    print('==> %s with reduction %.2f' % (args.arch, args.reduction))
    net = cls(reduction=args.reduction, reduction_mode=args.reduction_mode)
else:
    net = cls() if 'shift' not in args.arch else cls(expansion=args.expansion)
new_count = count_params(net)
new_flops = count_flops(net)

print('Parameters: (new) %d (original) %d (reduction) %.2f' % (
      new_count, original_count, float(original_count) / new_count))
print('FLOPs: (new) %d (original) %d (reduction) %.2f' % (
      new_flops, original_flops, float(original_flops) / new_flops))


================================================
FILE: eval.py
================================================
'''Test CIFAR10 with PyTorch.'''
from __future__ import print_function

import glob

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn

import torchvision
import torchvision.transforms as transforms

import os
import argparse

from models import *
from utils import progress_bar
from torch.autograd import Variable


parser = argparse.ArgumentParser(description='PyTorch CIFAR10/CIFAR100 Evaluation')
parser.add_argument('--model', action='append', required=True,
                    help='Checkpoint path or glob pattern to test; may be repeated')
parser.add_argument('--suppress-errors', action='store_true')
parser.add_argument('--dataset', choices=('cifar10', 'cifar100'), help='Dataset to train and validate on.', default='cifar10')
args = parser.parse_args()

use_cuda = torch.cuda.is_available()
best_acc = 0  # best test accuracy
start_epoch = 0  # start from epoch 0 or last checkpoint epoch

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

if args.dataset == 'cifar10':
    testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
elif args.dataset == 'cifar100':
    testset = torchvision.datasets.CIFAR100(root='./data', train=False, download=True, transform=transform_test)

testloader = torch.utils.data.DataLoader(testset, batch_size=512, shuffle=False, num_workers=4)

criterion = nn.CrossEntropyLoss()

# Load checkpoint.
print('==> Resuming from checkpoint..')
assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'

def test(epoch):
    global best_acc
    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    for batch_idx, (inputs, targets) in enumerate(testloader):
        with torch.no_grad():
            if use_cuda:
                inputs, targets = inputs.cuda(), targets.cuda()
            inputs, targets = Variable(inputs), Variable(targets)
            outputs = net(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item() * targets.size(0)
            _, predicted = torch.max(outputs.data, 1)
            total += targets.size(0)
            correct += predicted.eq(targets.data).cpu().sum().item()  # .item() gives a Python int, so 100.*correct/total is float division

        progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
            % (test_loss/total, 100.*correct/total, correct, total))
    return ' '

for pattern in args.model:
    for model in sorted(glob.iglob(pattern), reverse=True):
        print('Reading from model', model)
        checkpoint = torch.load(model)
        net = checkpoint['net']
        best_acc = checkpoint.get('acc', 0)
        start_epoch = checkpoint.get('epoch', 0)
        if use_cuda:
            net.cuda()
            net = torch.nn.DataParallel(net, device_ids=range(torch.cuda.device_count()))
            cudnn.benchmark = True
        if args.suppress_errors:
            try:
                print(test(0))
            except AssertionError as e:
                print('The model may be malformed.')
                print(e)
        else:
            print(test(0))


================================================
FILE: main.py
================================================
'''Train CIFAR10 with PyTorch.'''

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn

import torchvision
import torchvision.transforms as transforms

import os
import argparse

from models import ResNet20
from models import ResNet56
from models import ResNet110
from models import ShiftResNet20
from models import ShiftResNet56
from models import ShiftResNet110
from models import DepthwiseResNet20
from models import DepthwiseResNet56
from models import DepthwiseResNet110
from utils import progress_bar
from torch.autograd import Variable


all_models = {
    'resnet20': ResNet20,
    'shiftresnet20': ShiftResNet20,
    'depthwiseresnet20': DepthwiseResNet20,
    'resnet56': ResNet56,
    'shiftresnet56': ShiftResNet56,
    'depthwiseresnet56': DepthwiseResNet56,
    'resnet110': ResNet110,
    'shiftresnet110': ShiftResNet110,
    'depthwiseresnet110': DepthwiseResNet110
}

parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training')
parser.add_argument('--lr', default=0.1, type=float, help='learning rate')
parser.add_argument('--resume', '-r', action='store_true', help='resume from checkpoint')
parser.add_argument('--batch_size', '-b', default=128, type=int, help='batch size')
parser.add_argument('--arch', '-a', choices=all_models.keys(), default='shiftresnet110', help='neural network architecture')
parser.add_argument('--expansion', '-e', help='Expansion for shift resnet.', default=1, type=float)
parser.add_argument('--reduction', help='Amount to reduce raw resnet model by', default=1.0, type=float)
parser.add_argument('--reduction-mode', choices=('block', 'net', 'depthwise'), help='"block" reduces inner representation for BasicBlock, "net" reduces for all layers', default='net')
parser.add_argument('--dataset', choices=('cifar10', 'cifar100', 'imagenet'), help='Dataset to train and validate on.', default='cifar10')
parser.add_argument('--datadir', help='Folder containing data', default='./data/')
args = parser.parse_args()

use_cuda = torch.cuda.is_available()
best_acc = 0.0  # best test accuracy
start_epoch = 0  # start from epoch 0 or last checkpoint epoch

# Data
print('==> Preparing data..')
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

if args.dataset == 'cifar10':
    trainset = torchvision.datasets.CIFAR10(root=args.datadir, train=True, download=True, transform=transform_train)
    testset = torchvision.datasets.CIFAR10(root=args.datadir, train=False, download=True, transform=transform_test)
    num_classes = 10
elif args.dataset == 'cifar100':
    trainset = torchvision.datasets.CIFAR100(root=args.datadir, train=True, download=True, transform=transform_train)
    testset = torchvision.datasets.CIFAR100(root=args.datadir, train=False, download=True, transform=transform_test)
    num_classes = 100
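elif args.dataset == 'imagenet':
    # ImageNet support is unfinished: everything below the raise is unreachable.
    raise NotImplementedError('ImageNet training is not supported yet')
    transform_train = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    # Resize and crop so validation images share one size and can be batched.
    transform_test = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])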

    traindir = os.path.join(args.datadir, 'train')
    valdir = os.path.join(args.datadir, 'val')
    trainset = torchvision.datasets.ImageFolder(traindir, transform_train)
    testset = torchvision.datasets.ImageFolder(valdir, transform_test)
    num_classes = 1000

trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size, shuffle=True, num_workers=2)
testloader = torch.utils.data.DataLoader(testset, batch_size=1000, shuffle=False, num_workers=2)


if 'shift' in args.arch:
    suffix = '_%s' % args.expansion
elif args.reduction != 1:
    suffix = '_%s_%s' % (args.reduction, args.reduction_mode)
else:
    suffix = ''

if args.dataset == 'cifar100':
    suffix += '_cifar100'

if args.dataset == 'imagenet':
    suffix += '_imagenet'

path = './checkpoint/%s%s.t7' % (args.arch, suffix)
print('Using path: %s' % path)

# Model
if args.resume:
    # Load checkpoint.
    print('==> Resuming from checkpoint.. %s' % path)
    assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'
    checkpoint = torch.load(path)
    net = checkpoint['net']
    best_acc = float(checkpoint['acc'])
    start_epoch = checkpoint['epoch']
else:
    print('==> Building model..')
    cls = all_models[args.arch]
    assert 'shift' not in args.arch or args.reduction == 1, \
        'Only default resnet and depthwise resnet support reductions'
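    if args.reduction != 1:
        print('==> %s with reduction %.2f' % (args.arch, args.reduction))
        if 'depthwise' in args.arch:
            # DepthwiseResNet* factories take no reduction_mode argument.
            net = cls(reduction=args.reduction, num_classes=num_classes)
        else:
            net = cls(reduction=args.reduction, reduction_mode=args.reduction_mode, num_classes=num_classes)
    else:
        net = cls(args.expansion, num_classes=num_classes) if 'shift' in args.arch else cls(num_classes=num_classes)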

if use_cuda:
    net.cuda()
    net = torch.nn.DataParallel(
        net, device_ids=range(torch.cuda.device_count()))
    cudnn.benchmark = True

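criterion = nn.CrossEntropyLoss()
# Build the optimizer once so SGD momentum buffers persist across epochs;
# train() only updates the learning rate of each parameter group.
optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)

def adjust_learning_rate(epoch, lr):
    if epoch <= 81:  # 32k iterations
        return lr
    elif epoch <= 122:  # 48k iterations
        return lr/10
    else:
        return lr/100

# Training
def train(epoch):
    lr = adjust_learning_rate(epoch, args.lr)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr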
    print('\nEpoch: %d' % epoch)
    net.train()
    train_loss = 0
    correct = 0
    total = 0
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        if use_cuda:
            inputs, targets = inputs.cuda(), targets.cuda()
        optimizer.zero_grad()
        inputs, targets = Variable(inputs), Variable(targets)
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        train_loss += loss.item() * targets.size(0)
        _, predicted = torch.max(outputs.data, 1)
        total += targets.size(0)
        correct += predicted.eq(targets.data).cpu().sum().item()

        progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
            % (train_loss/total, 100.*float(correct)/float(total), correct, total))

def test(epoch):
    global best_acc
    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    for batch_idx, (inputs, targets) in enumerate(testloader):
        with torch.no_grad():
            if use_cuda:
                inputs, targets = inputs.cuda(), targets.cuda()
            inputs, targets = Variable(inputs), Variable(targets)
            outputs = net(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item() * targets.size(0)
            _, predicted = torch.max(outputs.data, 1)
            total += targets.size(0)
            correct += predicted.eq(targets.data).cpu().sum().item()

        progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
            % (test_loss/total, 100.*float(correct)/float(total), correct, total))

    # Save checkpoint.
    acc = 100.*float(correct)/float(total)
    if acc > best_acc:
        print('Saving..')
        state = {
            'net': net.module if use_cuda else net,
            'acc': acc,
            'epoch': epoch,
        }
        if not os.path.isdir('checkpoint'):
            os.mkdir('checkpoint')
        torch.save(state, path)
        print('* Saved checkpoint to %s' % path)
        best_acc = acc


for epoch in range(start_epoch, 164):
    train(epoch)
    test(epoch)
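The step schedule in `adjust_learning_rate()` can be checked against the iteration counts in its comments. A minimal sketch, assuming 50,000 CIFAR training images and the default batch size of 128; `step_lr` and `iterations_per_epoch` are hypothetical names, not functions in `main.py`:

```python
import math


def step_lr(epoch, base_lr=0.1, milestones=(81, 122)):
    """Step schedule matching adjust_learning_rate() in main.py."""
    if epoch <= milestones[0]:
        return base_lr
    elif epoch <= milestones[1]:
        return base_lr / 10
    return base_lr / 100


def iterations_per_epoch(num_train=50000, batch_size=128):
    # DataLoader keeps the final partial batch (drop_last=False by default),
    # so the iteration count rounds up.
    return math.ceil(num_train / batch_size)


iters = iterations_per_epoch()   # 391
first_drop = (81 + 1) * iters    # iterations run at base_lr (epochs 0..81)
second_drop = (122 + 1) * iters  # iterations run before lr = base_lr / 100
total = 164 * iters              # range(start_epoch, 164) with start_epoch = 0
print(iters, first_drop, second_drop, total)  # 391 32062 48093 64124
```

Epochs are 0-indexed, so the first drop takes effect at epoch 82. The totals land close to the 32k / 48k / 64k iteration schedule used for CIFAR-10 in the original ResNet paper.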

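`main.py` derives the checkpoint filename from the architecture, the width knob, and the dataset. A hypothetical helper that mirrors that suffix logic (`checkpoint_path` is not a function in the repository) makes the naming concrete:

```python
def checkpoint_path(arch, expansion=1, reduction=1.0, reduction_mode='net',
                    dataset='cifar10'):
    """Hypothetical mirror of the checkpoint suffix logic in main.py."""
    if 'shift' in arch:
        suffix = '_%s' % expansion
    elif reduction != 1:
        suffix = '_%s_%s' % (reduction, reduction_mode)
    else:
        suffix = ''
    if dataset == 'cifar100':
        suffix += '_cifar100'
    if dataset == 'imagenet':
        suffix += '_imagenet'
    return './checkpoint/%s%s.t7' % (arch, suffix)


print(checkpoint_path('shiftresnet110'))                 # ./checkpoint/shiftresnet110_1.t7
print(checkpoint_path('shiftresnet110', expansion=1.0))  # ./checkpoint/shiftresnet110_1.0.t7
print(checkpoint_path('resnet20', reduction=4.0))        # ./checkpoint/resnet20_4.0_net.t7
print(checkpoint_path('resnet110', dataset='cifar100'))  # ./checkpoint/resnet110_cifar100.t7
```

One quirk: argparse applies `type=float` only to defaults given as strings, so a run without `-e` keeps `--expansion` as the integer `1` and writes `shiftresnet110_1.t7`, while passing `-e 1` explicitly yields `shiftresnet110_1.0.t7`. Resuming with `-r` therefore only finds the checkpoint if `-e` is given (or omitted) the same way as in the original run.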

================================================
FILE: models/__init__.py
================================================
from .resnet import *
from .shiftresnet import *
from .depthwiseresnet import *


================================================
FILE: models/depthwiseresnet.py
================================================
"""PyTorch implementation of DepthwiseResNet

ShiftResNet modifications written by Bichen Wu and Alvin Wan.

Reference:
[1] Bichen Wu, Alvin Wan, Xiangyu Yue, Peter Jin, Sicheng Zhao, Noah Golmant,
    Amir Gholaminejad, Joseph Gonzalez, Kurt Keutzer
    Shift: A Zero FLOP, Zero Parameter Alternative to Spatial Convolutions.
    arXiv:1711.08141
"""

import torch.nn as nn
import torch.nn.functional as F

from .resnet import ResNet


class DepthWiseWithSkipBlock(nn.Module):

    def __init__(self, in_planes, out_planes, stride=1, reduction=1):
        super(DepthWiseWithSkipBlock, self).__init__()
        self.expansion = 1 / float(reduction)
        self.in_planes = in_planes
        self.mid_planes = mid_planes = int(self.expansion * out_planes)
        self.out_planes = out_planes

        self.conv1 = nn.Conv2d(
            in_planes, mid_planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(mid_planes)
        self.depth = nn.Conv2d(mid_planes, mid_planes, kernel_size=3, padding=1,
                               stride=1, bias=False, groups=mid_planes)
        self.bn2 = nn.BatchNorm2d(mid_planes)
        self.conv3 = nn.Conv2d(
            mid_planes, out_planes, kernel_size=1, bias=False, stride=stride)
        self.bn3 = nn.BatchNorm2d(out_planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != out_planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride,
                          bias=False),
                nn.BatchNorm2d(out_planes)
            )

    def flops(self):
        if not hasattr(self, 'int_nchw'):
            raise UserWarning('Must run forward at least once')
        (_, _, int_h, int_w), (_, _, out_h, out_w) = self.int_nchw, self.out_nchw
        flops = int_h * int_w * self.mid_planes * self.in_planes + out_h * out_w * self.mid_planes * self.out_planes
        flops += int_h * int_w * self.mid_planes * 9  # depth-wise conv runs at stride 1, on the intermediate map
        if len(self.shortcut) > 0:
            flops += self.in_planes * self.out_planes * out_h * out_w
        return flops

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        self.int_nchw = out.size()
        out = self.bn2(self.depth(out))
        out = self.bn3(self.conv3(out))
        self.out_nchw = out.size()
        out += self.shortcut(x)
        out = F.relu(out)
        return out


def DepthwiseResNet20(reduction=1, num_classes=10):
    block = lambda in_planes, planes, stride: \
        DepthWiseWithSkipBlock(in_planes, planes, stride, reduction=reduction)
    return ResNet(block, [3, 3, 3], num_classes=num_classes)


def DepthwiseResNet56(reduction=1, num_classes=10):
    block = lambda in_planes, planes, stride: \
        DepthWiseWithSkipBlock(in_planes, planes, stride, reduction=reduction)
    return ResNet(block, [9, 9, 9], num_classes=num_classes)


def DepthwiseResNet110(reduction=1, num_classes=10):
    block = lambda in_planes, planes, stride: \
        DepthWiseWithSkipBlock(in_planes, planes, stride, reduction=reduction)
    return ResNet(block, [18, 18, 18], num_classes=num_classes)
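The cost of a `DepthWiseWithSkipBlock` can be compared with a ResNet `BasicBlock` by hand. The sketch below counts one multiply-accumulate per term, the same quantity the `flops()` methods report; the function names are hypothetical, not part of the repository. It evaluates the 3×3 depthwise convolution at the intermediate resolution, since that convolution always runs at stride 1:

```python
def depthwise_block_macs(in_planes, out_planes, h, w, stride=1, reduction=1):
    """Multiply-accumulates for one DepthWiseWithSkipBlock on an h x w input."""
    mid = int((1 / float(reduction)) * out_planes)  # same rounding as the block
    out_h, out_w = h // stride, w // stride          # exact for even h, w
    macs = h * w * in_planes * mid                   # 1x1 conv1, stride 1
    macs += h * w * mid * 9                          # 3x3 depthwise, stride 1
    macs += out_h * out_w * mid * out_planes         # 1x1 conv3 carries the stride
    if stride != 1 or in_planes != out_planes:
        macs += out_h * out_w * in_planes * out_planes  # projection shortcut
    return macs


def basic_block_macs(in_planes, planes, h, w, stride=1, reduction=1):
    """Multiply-accumulates for one BasicBlock, mirroring BasicBlock.flops()."""
    mid = int((1 / float(reduction)) * planes)
    out_h, out_w = h // stride, w // stride
    macs = out_h * out_w * 9 * in_planes * mid       # 3x3 conv1 carries the stride
    macs += out_h * out_w * 9 * mid * planes         # 3x3 conv2
    if stride != 1 or in_planes != planes:
        macs += out_h * out_w * in_planes * planes   # projection shortcut
    return macs


print(depthwise_block_macs(16, 16, 32, 32))  # 671744
print(basic_block_macs(16, 16, 32, 32))      # 4718592
```

On the first CIFAR stage (16 channels at 32×32) the depthwise block needs roughly 7× fewer multiply-accumulates than a BasicBlock of the same width.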


================================================
FILE: models/resnet.py
================================================
"""PyTorch implementation of ResNet

ResNet modifications written by Bichen Wu and Alvin Wan, based
on the ResNet implementation by Kuang Liu.

Reference:
[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
    Deep Residual Learning for Image Recognition. arXiv:1512.03385
"""
import torch.nn as nn
import torch.nn.functional as F


class BasicBlock(nn.Module):

    def __init__(self, in_planes, planes, stride=1, reduction=1):
        super(BasicBlock, self).__init__()
        self.expansion = 1 / float(reduction)
        self.in_planes = in_planes
        self.mid_planes = mid_planes = int(self.expansion * planes)
        self.out_planes = planes

        self.conv1 = nn.Conv2d(in_planes, mid_planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(mid_planes)
        self.conv2 = nn.Conv2d(mid_planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes)
            )

    def flops(self):
        if not hasattr(self, 'int_nchw'):
            raise UserWarning('Must run forward at least once')
        (_, _, int_h, int_w), (_, _, out_h, out_w) = self.int_nchw, self.out_nchw
        flops = int_h*int_w*9*self.mid_planes*self.in_planes + out_h*out_w*9*self.mid_planes*self.out_planes
        if len(self.shortcut) > 0:
            flops += self.in_planes*self.out_planes*out_h*out_w
        return flops

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        self.int_nchw = out.size()
        out = self.bn2(self.conv2(out))
        self.out_nchw = out.size()
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class ResNet(nn.Module):
    def __init__(self, block, num_blocks, reduction=1, num_classes=10):
        super(ResNet, self).__init__()
        self.reduction = float(reduction) ** 0.5
        self.num_classes = num_classes
        self.in_planes = int(16 / self.reduction)

        self.conv1 = nn.Conv2d(3, self.in_planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(self.in_planes)
        self.layer1 = self._make_layer(block, self.in_planes, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, int(32 / self.reduction), num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, int(64 / self.reduction), num_blocks[2], stride=2)
        self.linear = nn.Linear(int(64 / self.reduction), num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        planes = int(planes)
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes
        return nn.Sequential(*layers)

    def flops(self):
        if not hasattr(self, 'int_nchw'):
            raise UserWarning('Must run forward at least once')
        (_, _, int_h, int_w), (_, num_features) = self.int_nchw, self.out_hw
        flops = 0
        for mod in (self.layer1, self.layer2, self.layer3):
            for layer in mod:
                flops += layer.flops()
        # _make_layer overwrites self.in_planes, so read the stem width from conv1.
        stem = int_h*int_w*9*self.conv1.out_channels*3
        return stem + num_features*self.num_classes + flops

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        self.int_nchw = out.size()
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = F.avg_pool2d(out, 8)
        out = out.view(out.size(0), -1)
        self.out_hw = out.size()
        out = self.linear(out)
        return out


def ResNetWrapper(num_blocks, reduction=1, reduction_mode='net', num_classes=10):
    if reduction_mode == 'block':
        block = lambda in_planes, planes, stride: \
            BasicBlock(in_planes, planes, stride, reduction=reduction)
        return ResNet(block, num_blocks, num_classes=num_classes)
    return ResNet(BasicBlock, num_blocks, num_classes=num_classes, reduction=reduction)


def ResNet20(reduction=1, reduction_mode='net', num_classes=10):
    return ResNetWrapper([3, 3, 3], reduction, reduction_mode, num_classes)


def ResNet56(reduction=1, reduction_mode='net', num_classes=10):
    return ResNetWrapper([9, 9, 9], reduction, reduction_mode, num_classes)


def ResNet110(reduction=1, reduction_mode='net', num_classes=10):
    return ResNetWrapper([18, 18, 18], reduction, reduction_mode, num_classes)
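In `'net'` mode, `ResNet` divides every stage width by √reduction, so a 3×3 convolution whose input and output widths both shrink by that factor costs roughly 1/reduction as many multiply-accumulates. The block counts passed by `ResNet20`/`56`/`110` follow the 6n + 2 depth rule from the ResNet paper. A small sketch of both (the helper names are hypothetical):

```python
def stage_widths(reduction=1.0):
    """Stage widths ResNet builds in 'net' mode: base width / sqrt(reduction)."""
    r = float(reduction) ** 0.5
    return int(16 / r), int(32 / r), int(64 / r)


def depth(blocks_per_stage):
    """Weighted layers: 2 convs per BasicBlock x 3 stages + stem conv + linear."""
    return 6 * blocks_per_stage + 2


print(stage_widths(1))                 # (16, 32, 64)
print(stage_widths(4))                 # (8, 16, 32)
print(stage_widths(2))                 # (11, 22, 45)
print([depth(n) for n in (3, 9, 18)])  # [20, 56, 110]
```

Because widths are truncated with `int()`, non-square reductions such as 2 land slightly below the nominal 1/reduction cost.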


================================================
FILE: models/shiftresnet.py
================================================
"""PyTorch implementation of ShiftResNet

ShiftResNet modifications written by Bichen Wu and Alvin Wan. Efficient CUDA
implementation of shift written by Peter Jin.

Reference:
[1] Bichen Wu, Alvin Wan, Xiangyu Yue, Peter Jin, Sicheng Zhao, Noah Golmant,
    Amir Gholaminejad, Joseph Gonzalez, Kurt Keutzer
    Shift: A Zero FLOP, Zero Parameter Alternative to Spatial Convolutions.
    arXiv:1711.08141
"""

import torch.nn as nn
import torch.nn.functional as F

from .resnet import ResNet
from models.shiftnet_cuda_v2.nn import GenericShift_cuda


class ShiftConv(nn.Module):

    def __init__(self, in_planes, out_planes, stride=1, expansion=1):
        super(ShiftConv, self).__init__()
        self.expansion = expansion
        self.in_planes = in_planes
        self.out_planes = out_planes
        self.mid_planes = mid_planes = int(out_planes * self.expansion)

        self.conv1 = nn.Conv2d(
            in_planes, mid_planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(mid_planes)

        self.shift2 = GenericShift_cuda(kernel_size=3, dilate_factor=1)
        self.conv2 = nn.Conv2d(
            mid_planes, out_planes, kernel_size=1, bias=False, stride=stride)
        self.bn2 = nn.BatchNorm2d(out_planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != out_planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(
                  in_planes, out_planes, kernel_size=1, stride=stride,
                  bias=False),
                nn.BatchNorm2d(out_planes)
            )

    def flops(self):
        if not hasattr(self, 'int_nchw'):
            raise UserWarning('Must run forward at least once')
        (_, _, int_h, int_w), (_, _, out_h, out_w) = self.int_nchw, self.out_nchw
        flops = int_h * int_w * self.in_planes * self.mid_planes + \
                out_h * out_w * self.mid_planes * self.out_planes
        if len(self.shortcut) > 0:
            flops += self.in_planes * self.out_planes * out_h * out_w
        return flops

    def forward(self, x):
        shortcut = self.shortcut(x)
        x = F.relu(self.bn1(self.conv1(x)))
        self.int_nchw = x.size()
        x = F.relu(self.bn2(self.conv2(self.shift2(x))))
        self.out_nchw = x.size()
        x += shortcut
        return x


def ShiftResNet20(expansion=1, num_classes=10):
    block = lambda in_planes, out_planes, stride: \
        ShiftConv(in_planes, out_planes, stride, expansion=expansion)
    return ResNet(block, [3, 3, 3], num_classes=num_classes)


def ShiftResNet56(expansion=1, num_classes=10):
    block = lambda in_planes, out_planes, stride: \
        ShiftConv(in_planes, out_planes, stride, expansion=expansion)
    return ResNet(block, [9, 9, 9], num_classes=num_classes)


def ShiftResNet110(expansion=1, num_classes=10):
    block = lambda in_planes, out_planes, stride: \
        ShiftConv(in_planes, out_planes, stride, expansion=expansion)
    return ResNet(block, [18, 18, 18], num_classes=num_classes)
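`ShiftConv` replaces the 3×3 convolutions with two 1×1 convolutions around a shift, which performs no multiplies. A sketch mirroring `ShiftConv.flops()` (`shift_block_macs` is a hypothetical name) shows how `--expansion` trades cost for capacity:

```python
def shift_block_macs(in_planes, out_planes, h, w, stride=1, expansion=1):
    """Multiply-accumulates for one ShiftConv block, mirroring ShiftConv.flops()."""
    mid = int(out_planes * expansion)
    out_h, out_w = h // stride, w // stride
    macs = h * w * in_planes * mid            # 1x1 conv1, stride 1
    # the shift between the two 1x1 convs contributes no multiplies
    macs += out_h * out_w * mid * out_planes  # 1x1 conv2 carries the stride
    if stride != 1 or in_planes != out_planes:
        macs += out_h * out_w * in_planes * out_planes  # projection shortcut
    return macs


for e in (1, 3, 9):
    print(e, shift_block_macs(16, 16, 32, 32, expansion=e))
# 1 524288
# 3 1572864
# 9 4718592
```

At expansion 9, equal to the 3×3 kernel area, the block costs exactly as much as a BasicBlock of the same width. The same holds for weights: with input and output widths both C, ShiftConv holds 2·E·C² convolution weights versus 18·C² for a BasicBlock.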

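`GenericShift_cuda` comes from the `shiftnet_cuda_v2` submodule, whose source is not part of this extraction. The pure-Python sketch below illustrates the shift operation described in the Shift paper (arXiv:1711.08141) under stated assumptions: channels are split into kernel_size² contiguous groups, each group is moved by one offset in the k×k neighborhood, and vacated pixels are zero-filled. The channel-to-offset order and the sign convention are assumptions, not necessarily what the CUDA kernel does:

```python
def shift(x, kernel_size=3):
    """x: list of channels, each an H x W list of lists. Returns a shifted copy.

    Assumption: channel i reads from offset offsets[i * k*k // C], so output
    pixel (r, c) takes the input value at (r + dy, c + dx), or 0 outside.
    """
    half = kernel_size // 2
    offsets = [(dy, dx) for dy in range(-half, half + 1)
               for dx in range(-half, half + 1)]
    channels = len(x)
    out = []
    for i, plane in enumerate(x):
        dy, dx = offsets[i * len(offsets) // channels]
        h, w = len(plane), len(plane[0])
        out.append([[plane[r + dy][c + dx]
                     if 0 <= r + dy < h and 0 <= c + dx < w else 0
                     for c in range(w)]
                    for r in range(h)])
    return out


# Nine 3x3 channels, each with a single 1 in the centre.
x = [[[1 if (r, c) == (1, 1) else 0 for c in range(3)] for r in range(3)]
     for _ in range(9)]
y = shift(x)
print(y[0])  # centre pixel moved to the bottom-right corner
print(y[8])  # centre pixel moved to the top-left corner
```

The operation only moves values, which is why the shift contributes neither FLOPs nor parameters to `ShiftConv`.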

================================================
FILE: requirements.txt
================================================
cffi==1.11.2
numpy==1.13.3


================================================
FILE: utils.py
================================================
'''Some helper functions for PyTorch, including:
    - get_mean_and_std: estimate the per-channel mean and std of a dataset.
    - init_params: net parameter initialization.
    - progress_bar: progress bar mimicking xlua.progress.
'''
import shutil
import sys
import time
import math

import torch
import torch.nn as nn
import torch.nn.init as init
import torch.utils.data


def get_mean_and_std(dataset):
    '''Estimate per-channel mean and std, averaged over per-image statistics.'''
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2)
    mean = torch.zeros(3)
    std = torch.zeros(3)
    print('==> Computing mean and std..')
    for inputs, targets in dataloader:
        for i in range(3):
            mean[i] += inputs[:,i,:,:].mean()
            std[i] += inputs[:,i,:,:].std()
    mean.div_(len(dataset))
    std.div_(len(dataset))
    return mean, std

def init_params(net):
    '''Init layer parameters.'''
    for m in net.modules():
        if isinstance(m, nn.Conv2d):
            init.kaiming_normal_(m.weight, mode='fan_out')
            if m.bias is not None:
                init.constant_(m.bias, 0)
        elif isinstance(m, nn.BatchNorm2d):
            init.constant_(m.weight, 1)
            init.constant_(m.bias, 0)
        elif isinstance(m, nn.Linear):
            init.normal_(m.weight, std=1e-3)
            if m.bias is not None:
                init.constant_(m.bias, 0)


# Fall back to 80 columns when stdout is not a terminal (e.g. output piped to a log).
term_width = shutil.get_terminal_size(fallback=(80, 24)).columns

TOTAL_BAR_LENGTH = 65.
last_time = time.time()
begin_time = last_time
def progress_bar(current, total, msg=None):
    global last_time, begin_time
    if current == 0:
        begin_time = time.time()  # Reset for new bar.

    cur_len = int(TOTAL_BAR_LENGTH*current/total)
    rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1

    sys.stdout.write(' [')
    for i in range(cur_len):
        sys.stdout.write('=')
    sys.stdout.write('>')
    for i in range(rest_len):
        sys.stdout.write('.')
    sys.stdout.write(']')

    cur_time = time.time()
    step_time = cur_time - last_time
    last_time = cur_time
    tot_time = cur_time - begin_time

    L = []
    L.append('  Step: %s' % format_time(step_time))
    L.append(' | Tot: %s' % format_time(tot_time))
    if msg:
        L.append(' | ' + msg)

    msg = ''.join(L)
    sys.stdout.write(msg)
    for i in range(term_width-int(TOTAL_BAR_LENGTH)-len(msg)-3):
        sys.stdout.write(' ')

    # Go back to the center of the bar.
    for i in range(term_width-int(TOTAL_BAR_LENGTH/2)+2):
        sys.stdout.write('\b')
    sys.stdout.write(' %d/%d ' % (current+1, total))

    if current < total-1:
        sys.stdout.write('\r')
    else:
        sys.stdout.write('\n')
    sys.stdout.flush()

def format_time(seconds):
    days = int(seconds / 3600/24)
    seconds = seconds - days*3600*24
    hours = int(seconds / 3600)
    seconds = seconds - hours*3600
    minutes = int(seconds / 60)
    seconds = seconds - minutes*60
    secondsf = int(seconds)
    seconds = seconds - secondsf
    millis = int(seconds*1000)

    f = ''
    i = 1
    if days > 0:
        f += str(days) + 'D'
        i += 1
    if hours > 0 and i <= 2:
        f += str(hours) + 'h'
        i += 1
    if minutes > 0 and i <= 2:
        f += str(minutes) + 'm'
        i += 1
    if secondsf > 0 and i <= 2:
        f += str(secondsf) + 's'
        i += 1
    if millis > 0 and i <= 2:
        f += str(millis) + 'ms'
        i += 1
    if f == '':
        f = '0ms'
    return f
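`get_mean_and_std()` loads one image at a time and averages each image's per-channel mean and std. Averaging the means recovers the dataset mean (every CIFAR image has the same pixel count), but averaging per-image standard deviations ignores variation between images, so it can understate the dataset-wide std. A toy illustration with Python's `statistics` module, treating each "image" as a flat list of one channel's pixel values (the helper names are hypothetical):

```python
import statistics


def averaged_per_image_std(images):
    # What get_mean_and_std() effectively computes for one channel
    # (torch's .std() is the unbiased estimator, like statistics.stdev).
    return sum(statistics.stdev(img) for img in images) / len(images)


def pooled_std(images):
    # Standard deviation over every pixel in the dataset at once.
    return statistics.stdev([p for img in images for p in img])


images = [[0.0, 0.0], [1.0, 1.0]]  # two constant "images"
print(averaged_per_image_std(images))  # 0.0
print(pooled_std(images))              # about 0.577, i.e. sqrt(1/3)
```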