Repository: Tushar-N/blockdrop
Branch: master
Commit: ec52b36d38dc
Files: 10
Total size: 46.6 KB

Directory structure:
gitextract_y_xp5prk/

├── .gitignore
├── Readme.md
├── cl_training.py
├── finetune.py
├── models/
│   ├── __init__.py
│   ├── base.py
│   └── resnet.py
├── requirements.txt
├── test.py
└── utils.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
*.*~
*.pyc
*.pkl
*.h5
*.t7
*.t7*
*.log
*.png

data/
cv/



================================================
FILE: Readme.md
================================================
# BlockDrop: Dynamic Inference Paths in Residual Networks
![BlockDrop Model](https://user-images.githubusercontent.com/4995097/35775877-3cc64f86-0957-11e8-85c4-9bd16cda22a0.png)

This code implements a policy network that learns to dynamically choose which blocks of a ResNet to execute during inference so as to best reduce total computation without degrading prediction accuracy. Built upon a ResNet-101 model, our method achieves a speedup of 20% on average, going as high as 36% for some images, while maintaining the same 76.4% top-1 accuracy on ImageNet.

This is the code accompanying the work:  
Zuxuan Wu*, Tushar Nagarajan*, Abhishek Kumar, Steven Rennie, Larry S. Davis, Kristen Grauman, and Rogerio Feris. BlockDrop: Dynamic Inference Paths in Residual Networks [[arxiv]](https://arxiv.org/pdf/1711.08393.pdf)  
(* authors contributed equally)
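
As a rough illustration of the block-skipping idea described above, the sketch below applies a binary policy to a chain of residual blocks. It is a toy scalar stand-in and not the repository's model code, although it follows the same skip rule as `FlatResNet.forward_single` in `models/resnet.py`. The name `forward_with_policy` and the lambda "blocks" are hypothetical.

```python
def relu(v):
    # Scalar ReLU, standing in for F.relu on a feature map.
    return max(0.0, v)


def forward_with_policy(x, blocks, policy):
    """Run residual blocks on x, skipping every block whose policy bit is 0.

    blocks: list of callables, each playing the role of one residual branch F(x).
    policy: list of 0/1 decisions, one per block.
    """
    for block, keep in zip(blocks, policy):
        if keep:
            # Executed block: usual residual update followed by ReLU.
            x = relu(x + block(x))
        # Skipped block: x passes through the identity shortcut unchanged.
    return x


# Three toy "blocks" (hypothetical, for illustration only).
toy_blocks = [lambda v: 1.0, lambda v: -10.0, lambda v: 0.5 * v]
full = forward_with_policy(1.0, toy_blocks, [1, 1, 1])
skipped = forward_with_policy(1.0, toy_blocks, [1, 0, 1])
```

Skipping a block changes the output, which is why the policy has to be learned together with a reward that also accounts for prediction correctness.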

## Prerequisites
The code is written and tested using Python (2.7) and PyTorch (v0.3.0).

**Packages**: Install using `pip install -r requirements.txt`

**Pretrained models**: Our models require standard pretrained ResNets on CIFAR and ImageNet as starting points. These can be trained using [this](https://github.com/felixgwu/img_classification_pk_pytorch) repository, or can be obtained directly from us:

```bash
wget -O blockdrop-checkpoints.tar.gz https://utexas.box.com/shared/static/ok98i51v14c0q9lvs1z5g71m6b3zm8sj.gz
tar -zxvf blockdrop-checkpoints.tar.gz
```
The downloaded checkpoints will be unpacked to `./cv/` for further use. The folder also contains various checkpoints from each stage of training.

**Datasets**: PyTorch's *torchvision* package automatically downloads CIFAR10 and CIFAR100 during training. ImageNet must be downloaded and organized following [these steps](https://github.com/soumith/imagenet-multiGPU.torch#data-processing).

## Training a model
Training occurs in two steps: (1) Curriculum Learning and (2) Joint Finetuning.  
Models operating on ResNets of different depths can be trained on different datasets using the same script. Examples of how to train these models are given below. Checkpoints and tensorboard log files will be saved to the folder specified by `--cv_dir`.

#### Curriculum Learning
The policy network can be trained using a CL schedule as follows.

```bash
# Train a model on CIFAR 10 built upon a ResNet-110
python cl_training.py --model R110_C10 --cv_dir cv/R110_C10_cl/ --lr 1e-3 --batch_size 2048 --max_epochs 5000

# Train a model on ImageNet built upon a ResNet-101
python cl_training.py --model R101_ImgNet --cv_dir cv/R101_ImgNet_cl/ --lr 1e-3 --batch_size 2048 --max_epochs 45 --data_dir data/imagenet/
```
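
The curriculum used by `cl_training.py` can be sketched in plain Python. While `cl_step` is smaller than the number of blocks, only the last `cl_step` decisions come from the policy; all earlier blocks are forced on and excluded from the loss. The helper below is a per-image simplification of that logic with a hypothetical name (`curriculum_policy`). The script itself works on batched tensors and uses no mask at all once `cl_step` reaches the number of blocks, which is equivalent to the all-ones mask returned here.

```python
def curriculum_policy(sampled, cl_step):
    """Apply the curriculum constraint to one sampled policy.

    sampled: list of 0/1 block decisions drawn from the policy network.
    cl_step: number of trailing blocks the policy is currently allowed to control.
    Returns (policy, loss_mask), both lists of length len(sampled).
    """
    num_blocks = len(sampled)
    if cl_step >= num_blocks:
        # Curriculum finished: every decision is learned.
        return list(sampled), [1] * num_blocks

    forced = num_blocks - cl_step
    # Early blocks are always executed; only the last cl_step are sampled.
    policy = [1] * forced + list(sampled[forced:])
    # The loss is masked so that forced blocks contribute no gradient.
    loss_mask = [0] * forced + [1] * cl_step
    return policy, loss_mask


policy, loss_mask = curriculum_policy([0, 1, 0, 0, 1], cl_step=2)
```

In the training loop `cl_step` is set to `1 + epoch` until it reaches the total number of blocks, so one more block is opened up to the policy each epoch.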

Model checkpoints after the curriculum learning step can be found in the downloaded folder. For example: `./cv/cl_learning/R110_C10/ckpt_E_5300_A_0.754_R_2.22E-01_S_20.10_#_7787.t7`
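
Checkpoint file names encode the statistics at save time, following the format string `ckpt_E_%d_A_%.3f_R_%.2E_S_%.2f_#_%d.t7` used in `cl_training.py` and `finetune.py`. The small helper below is a sketch for reading those fields back; `parse_ckpt_name` and the dictionary keys are hypothetical and not part of the repository. Reading `S` as the mean number of executed blocks is an assumption, based on its matching the `Block Usage` figure printed by `test.py`.

```python
import re

# Mirrors 'ckpt_E_%d_A_%.3f_R_%.2E_S_%.2f_#_%d.t7' from the training scripts.
_CKPT_RE = re.compile(
    r'ckpt_E_(\d+)_A_([\d.]+)_R_(-?[\d.]+E[+-]\d+)_S_([\d.]+)_#_(\d+)\.t7$')


def parse_ckpt_name(path):
    """Extract epoch, accuracy, reward, block usage and policy count from a path."""
    m = _CKPT_RE.search(path)
    if m is None:
        raise ValueError('not a BlockDrop checkpoint name: %s' % path)
    epoch, acc, reward, sparsity, n_policies = m.groups()
    return {
        'epoch': int(epoch),
        'accuracy': float(acc),
        'reward': float(reward),
        'sparsity': float(sparsity),        # assumed: mean blocks executed
        'unique_policies': int(n_policies),
    }


info = parse_ckpt_name(
    './cv/cl_learning/R110_C10/ckpt_E_5300_A_0.754_R_2.22E-01_S_20.10_#_7787.t7')
```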

#### Joint Finetuning
Checkpoints trained during the curriculum learning phase can be used to further jointly finetune the base ResNet to achieve the results reported in the paper. Different values for the penalty parameter control the trade-off between accuracy and speed.

```bash
# Finetune a ResNet-110 on CIFAR 10 using the checkpoint from cl_training
python finetune.py --model R110_C10 --lr 1e-4 --penalty -10 --pretrained cv/cl_training/R110_C10/ckpt_E_5300_A_0.754_R_2.22E-01_S_20.10_#_7787.t7 --batch_size 256  --max_epochs 2000 --cv_dir cv/R110_C10_ft_-10/

# Finetune a ResNet-101 on ImageNet using the checkpoint from cl_training
python finetune.py --model R101_ImgNet --lr 1e-4  --penalty -5 --pretrained cv/cl_training/R101_ImgNet/ckpt_E_4_A_0.746_R_-3.70E-01_S_29.79_#_484.t7 --data_dir data/imagenet/ --batch_size 320 --max_epochs 10 --cv_dir cv/R101_ImgNet_ft_-5/
```

Model checkpoints after the joint finetuning step can be found in the downloaded folder. For example: `./cv/finetuned/R101_ImgNet_gamma_5/ckpt_E_10_A_0.764_R_-8.46E-01_S_24.77_#_10.t7`
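
To make the role of `--penalty` concrete, the sketch below restates the reward computed by `get_reward` in `cl_training.py` and `finetune.py` for a single image in plain Python. The function name `block_reward` is hypothetical; the scripts compute the same quantity on batched tensors.

```python
def block_reward(policy, correct, penalty):
    """Reward for one image.

    policy:  list of 0/1 block decisions.
    correct: True if the prediction under this policy matches the label.
    penalty: the (negative) reward assigned to a wrong prediction.
    """
    if not correct:
        # Wrong predictions receive the penalty regardless of block usage.
        return float(penalty)
    block_use = sum(policy) / float(len(policy))
    # Correct predictions earn more the fewer blocks they execute.
    return 1.0 - block_use ** 2


half_used = block_reward([1, 0, 1, 0], correct=True, penalty=-5)
all_used = block_reward([1, 1, 1, 1], correct=True, penalty=-5)
wrong = block_reward([0, 0, 0, 0], correct=False, penalty=-10)
```

Because a correct prediction is rewarded with at most 1, a more negative penalty makes mistakes relatively more costly and pushes the policy toward keeping more blocks, while a milder penalty favors speed.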

## Testing and Profiling
Once jointly finetuned, models can be profiled for accuracy and FLOPs counts.
```bash
python test.py --model R110_C10 --load cv/finetuned/R110_C10_gamma_10/ckpt_E_2000_A_0.936_R_1.95E-01_S_16.93_#_469.t7
```
The model should produce an accuracy of 93.6% and use 1.81E+08 FLOPs on average. The output should look like this:
```
    Accuracy: 0.936
    Block Usage: 16.933 ± 3.717
    FLOPs/img: 1.81E+08 ± 3.43E+07
    Unique Policies: 469
```

The ImageNet model can be evaluated in a similar manner, and will generate a corresponding output.
```bash
python test.py --model R101_ImgNet --load cv/finetuned/R101_ImgNet_gamma_5/ckpt_E_10_A_0.764_R_-8.46E-01_S_24.77_#_10.t7
```
```
    Accuracy: 0.764
    Block Usage: 24.770 ± 0.980
    FLOPs/img: 1.25E+10 ± 4.28E+08
    Unique Policies: 10
```
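
To put the `Block Usage` figures in context, they can be divided by the total number of residual blocks in each base network. The totals below are assumptions taken from the standard architectures (3 stages of 18 blocks for a CIFAR ResNet-110, and 3 + 4 + 23 + 3 bottleneck blocks for ResNet-101), not values read from the repository's `utils.py`. The names are hypothetical.

```python
# Assumed block counts of the standard architectures (not read from utils.py).
TOTAL_BLOCKS = {
    'R110_C10': 3 * 18,             # ResNet-110 on CIFAR: 54 residual blocks
    'R101_ImgNet': 3 + 4 + 23 + 3,  # ResNet-101: 33 bottleneck blocks
}


def fraction_of_blocks_used(model, mean_blocks):
    """Share of residual blocks executed on average, as a value in [0, 1]."""
    return mean_blocks / float(TOTAL_BLOCKS[model])


cifar_fraction = fraction_of_blocks_used('R110_C10', 16.933)
imagenet_fraction = fraction_of_blocks_used('R101_ImgNet', 24.770)
```

Under these assumptions the CIFAR-10 model executes a bit under a third of its blocks on average, and the ImageNet model about three quarters.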


## Visualization
Learned policies over ResNet blocks show that there is a clear separation between easy/hard images in terms of the number of blocks they require. In addition, unique policies over the blocks admit distinct image styles.

![Policy visualization](https://user-images.githubusercontent.com/4995097/35775878-3e5ee4e8-0957-11e8-832d-b9dc2ea8fecc.png)

For more qualitative results, see Sec. 4.3 and Figures 4 and 5 in the paper.



## Cite

If you find this repository useful in your own research, please consider citing:
```
@inproceedings{blockdrop,
  title={BlockDrop: Dynamic Inference Paths in Residual Networks},
  author={Wu, Zuxuan and Nagarajan, Tushar and Kumar, Abhishek and Rennie, Steven and Davis, Larry S and Grauman, Kristen and Feris, Rogerio},
  booktitle={CVPR},
  year={2018}
}
```


================================================
FILE: cl_training.py
================================================
import os
from tensorboard_logger import configure, log_value
import torch
import torch.autograd as autograd
from torch.autograd import Variable
import torch.utils.data as torchdata
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import tqdm
import utils
import torch.optim as optim
from torch.distributions import Bernoulli

import torch.backends.cudnn as cudnn
cudnn.benchmark = True


import argparse
parser = argparse.ArgumentParser(description='BlockDrop Training')
parser.add_argument('--lr', type=float, default=1e-3, help='learning rate')
parser.add_argument('--beta', type=float, default=1e-1, help='entropy multiplier')
parser.add_argument('--wd', type=float, default=0.0, help='weight decay')
parser.add_argument('--model', default='R110_C10', help='R<depth>_<dataset> see utils.py for a list of configurations')
parser.add_argument('--data_dir', default='data/', help='data directory')
parser.add_argument('--load', default=None, help='checkpoint to load agent from')
parser.add_argument('--cv_dir', default='cv/tmp/', help='checkpoint directory (models and logs are saved here)')
parser.add_argument('--batch_size', type=int, default=256, help='batch size')
parser.add_argument('--epoch_step', type=int, default=10000, help='epochs after which lr is decayed')
parser.add_argument('--max_epochs', type=int, default=10000, help='total epochs to run')
parser.add_argument('--lr_decay_ratio', type=float, default=0.1, help='lr *= lr_decay_ratio after epoch_steps')
parser.add_argument('--parallel', action ='store_true', default=False, help='use multiple GPUs for training')
parser.add_argument('--cl_step', type=int, default=1, help='steps for curriculum training')
# parser.add_argument('--joint', action ='store_true', default=True, help='train both the policy network and the resnet')
parser.add_argument('--penalty', type=float, default=-1, help='gamma: reward for incorrect predictions')
parser.add_argument('--alpha', type=float, default=0.8, help='probability bounding factor')
args = parser.parse_args()

if not os.path.exists(args.cv_dir):
    os.makedirs(args.cv_dir)  # also creates missing parent directories
utils.save_args(__file__, args)

def get_reward(preds, targets, policy):

    block_use = policy.sum(1).float()/policy.size(1)
    sparse_reward = 1.0-block_use**2

    _, pred_idx = preds.max(1)
    match = (pred_idx==targets).data

    reward = sparse_reward
    reward[1-match] = args.penalty
    reward = reward.unsqueeze(1)

    return reward, match.float()


def train(epoch):

    agent.train()

    matches, rewards, policies = [], [], []
    for batch_idx, (inputs, targets) in tqdm.tqdm(enumerate(trainloader), total=len(trainloader)):

        inputs, targets = Variable(inputs), Variable(targets).cuda(async=True)
        if not args.parallel:
            inputs = inputs.cuda()

        probs, value = agent(inputs)

        #---------------------------------------------------------------------#

        policy_map = probs.data.clone()
        policy_map[policy_map<0.5] = 0.0
        policy_map[policy_map>=0.5] = 1.0
        policy_map = Variable(policy_map)

        probs = probs*args.alpha + (1-probs)*(1-args.alpha)
        distr = Bernoulli(probs)
        policy = distr.sample()

        if args.cl_step < num_blocks:
            policy[:, :-args.cl_step] = 1
            policy_map[:, :-args.cl_step] = 1

            policy_mask = Variable(torch.ones(inputs.size(0), policy.size(1))).cuda()
            policy_mask[:, :-args.cl_step] = 0
        else:
            policy_mask = None

        v_inputs = Variable(inputs.data, volatile=True)
        preds_map = rnet.forward(v_inputs, policy_map)
        preds_sample = rnet.forward(v_inputs, policy)

        reward_map, _ = get_reward(preds_map, targets, policy_map.data)
        reward_sample, match = get_reward(preds_sample, targets, policy.data)

        advantage = reward_sample - reward_map

        loss = -distr.log_prob(policy)
        loss = loss * Variable(advantage).expand_as(policy)

        if policy_mask is not None:
            loss = policy_mask * loss # mask for curriculum learning

        loss = loss.sum()

        probs = probs.clamp(1e-15, 1-1e-15)
        entropy_loss = -probs*torch.log(probs)
        entropy_loss = args.beta*entropy_loss.sum()

        loss = (loss - entropy_loss)/inputs.size(0)

        #---------------------------------------------------------------------#

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        matches.append(match.cpu())
        rewards.append(reward_sample.cpu())
        policies.append(policy.data.cpu())

    accuracy, reward, sparsity, variance, policy_set = utils.performance_stats(policies, rewards, matches)

    log_str = 'E: %d | A: %.3f | R: %.2E | S: %.3f | V: %.3f | #: %d'%(epoch, accuracy, reward, sparsity, variance, len(policy_set))
    print log_str

    log_value('train_accuracy', accuracy, epoch)
    log_value('train_reward', reward, epoch)
    log_value('train_sparsity', sparsity, epoch)
    log_value('train_variance', variance, epoch)
    log_value('train_unique_policies', len(policy_set), epoch)


def test(epoch):

    agent.eval()

    matches, rewards, policies = [], [], []
    for batch_idx, (inputs, targets) in tqdm.tqdm(enumerate(testloader), total=len(testloader)):

        inputs, targets = Variable(inputs, volatile=True), Variable(targets).cuda(async=True)
        if not args.parallel:
            inputs = inputs.cuda()

        probs, _ = agent(inputs)

        policy = probs.data.clone()
        policy[policy<0.5] = 0.0
        policy[policy>=0.5] = 1.0
        policy = Variable(policy)

        if args.cl_step < num_blocks:
            policy[:, :-args.cl_step] = 1

        preds = rnet.forward(inputs, policy)
        reward, match = get_reward(preds, targets, policy.data)

        matches.append(match)
        rewards.append(reward)
        policies.append(policy.data)

    accuracy, reward, sparsity, variance, policy_set = utils.performance_stats(policies, rewards, matches)

    log_str = 'TS - A: %.3f | R: %.2E | S: %.3f | V: %.3f | #: %d'%(accuracy, reward, sparsity, variance, len(policy_set))
    print log_str

    log_value('test_accuracy', accuracy, epoch)
    log_value('test_reward', reward, epoch)
    log_value('test_sparsity', sparsity, epoch)
    log_value('test_variance', variance, epoch)
    log_value('test_unique_policies', len(policy_set), epoch)

    # save the model
    agent_state_dict = agent.module.state_dict() if args.parallel else agent.state_dict()

    state = {
      'agent': agent_state_dict,
      'epoch': epoch,
      'reward': reward,
      'acc': accuracy
    }
    torch.save(state, args.cv_dir+'/ckpt_E_%d_A_%.3f_R_%.2E_S_%.2f_#_%d.t7'%(epoch, accuracy, reward, sparsity, len(policy_set)))


#--------------------------------------------------------------------------------------------------------#
trainset, testset = utils.get_dataset(args.model, args.data_dir)
trainloader = torchdata.DataLoader(trainset, batch_size=args.batch_size, shuffle=True, num_workers=4)
testloader = torchdata.DataLoader(testset, batch_size=args.batch_size, shuffle=False, num_workers=4)
rnet, agent = utils.get_model(args.model)
num_blocks = sum(rnet.layer_config)

start_epoch = 0
if args.load is not None:
    checkpoint = torch.load(args.load)
    agent.load_state_dict(checkpoint['agent'])
    start_epoch = checkpoint['epoch'] + 1
    print 'loaded agent from', args.load

if args.parallel:
    agent = nn.DataParallel(agent)
    rnet = nn.DataParallel(rnet)

rnet.eval().cuda()
agent.cuda()

optimizer = optim.Adam(agent.parameters(), lr=args.lr, weight_decay=args.wd)

configure(args.cv_dir+'/log', flush_secs=5)
lr_scheduler = utils.LrScheduler(optimizer, args.lr, args.lr_decay_ratio, args.epoch_step)
for epoch in range(start_epoch, start_epoch+args.max_epochs+1):
    lr_scheduler.adjust_learning_rate(epoch)

    if args.cl_step < num_blocks:
        args.cl_step = 1 + 1 * (epoch // 1)
    else:
        args.cl_step = num_blocks

    print 'training the last %d blocks ...' % args.cl_step
    train(epoch)

    if epoch % 10 == 0:
        test(epoch)


================================================
FILE: finetune.py
================================================
import os
from tensorboard_logger import configure, log_value
import torch
import torch.autograd as autograd
from torch.autograd import Variable
import torch.utils.data as torchdata
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import tqdm
import utils
import torch.optim as optim
from torch.distributions import Bernoulli

import torch.backends.cudnn as cudnn
cudnn.benchmark = True


import argparse
parser = argparse.ArgumentParser(description='BlockDrop Training')
parser.add_argument('--lr', type=float, default=1e-4, help='learning rate')
parser.add_argument('--wd', type=float, default=0.0, help='weight decay')
parser.add_argument('--model', default='R110_C10', help='R<depth>_<dataset> see utils.py for a list of configurations')
parser.add_argument('--data_dir', default='data/', help='data directory')
parser.add_argument('--load', default=None, help='checkpoint to load rnet+agent from')
parser.add_argument('--pretrained', default=None, help='pretrained policy model checkpoint (from curriculum training)')
parser.add_argument('--cv_dir', default='cv/tmp/', help='checkpoint directory (models and logs are saved here)')
parser.add_argument('--batch_size', type=int, default=256, help='batch size')
parser.add_argument('--epoch_step', type=int, default=1600, help='epochs after which lr is decayed')
parser.add_argument('--max_epochs', type=int, default=2000, help='total epochs to run')
parser.add_argument('--lr_decay_ratio', type=float, default=0.1, help='lr *= lr_decay_ratio after epoch_steps')
parser.add_argument('--parallel', action ='store_true', default=False, help='use multiple GPUs for training')
# parser.add_argument('--joint', action ='store_true', default=True, help='train both the policy network and the resnet')
parser.add_argument('--penalty', type=float, default=-5, help='gamma: reward for incorrect predictions')
parser.add_argument('--alpha', type=float, default=0.8, help='probability bounding factor')
args = parser.parse_args()

if not os.path.exists(args.cv_dir):
    os.makedirs(args.cv_dir)  # also creates missing parent directories
utils.save_args(__file__, args)

def get_reward(preds, targets, policy):

    block_use = policy.sum(1).float()/policy.size(1)
    sparse_reward = 1.0-block_use**2

    _, pred_idx = preds.max(1)
    match = (pred_idx==targets).data

    reward = sparse_reward
    reward[1-match] = args.penalty
    reward = reward.unsqueeze(1)

    return reward, match.float()


def train(epoch):

    agent.train()
    rnet.train()

    matches, rewards, policies = [], [], []
    for batch_idx, (inputs, targets) in tqdm.tqdm(enumerate(trainloader), total=len(trainloader)):

        inputs, targets = Variable(inputs), Variable(targets).cuda(async=True)
        if not args.parallel:
            inputs = inputs.cuda()

        probs, value = agent(inputs)

        #---------------------------------------------------------------------#

        policy_map = probs.data.clone()
        policy_map[policy_map<0.5] = 0.0
        policy_map[policy_map>=0.5] = 1.0
        policy_map = Variable(policy_map)

        probs = probs*args.alpha + (1-probs)*(1-args.alpha)
        distr = Bernoulli(probs)
        policy = distr.sample()

        v_inputs = Variable(inputs.data, volatile=True)
        preds_map = rnet.forward(v_inputs, policy_map)
        preds_sample = rnet.forward(inputs, policy)

        reward_map, _ = get_reward(preds_map, targets, policy_map.data)
        reward_sample, match = get_reward(preds_sample, targets, policy.data)

        advantage = reward_sample - reward_map
        # advantage = advantage.expand_as(policy)
        loss = -distr.log_prob(policy).sum(1, keepdim=True) * Variable(advantage)
        loss = loss.sum()

        #---------------------------------------------------------------------#
        loss += F.cross_entropy(preds_sample, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        matches.append(match.cpu())
        rewards.append(reward_sample.cpu())
        policies.append(policy.data.cpu())

    accuracy, reward, sparsity, variance, policy_set = utils.performance_stats(policies, rewards, matches)

    log_str = 'E: %d | A: %.3f | R: %.2E | S: %.3f | V: %.3f | #: %d'%(epoch, accuracy, reward, sparsity, variance, len(policy_set))
    print log_str

    log_value('train_accuracy', accuracy, epoch)
    log_value('train_reward', reward, epoch)
    log_value('train_sparsity', sparsity, epoch)
    log_value('train_variance', variance, epoch)
    log_value('train_unique_policies', len(policy_set), epoch)


def test(epoch):

    agent.eval()
    rnet.eval()

    matches, rewards, policies = [], [], []
    for batch_idx, (inputs, targets) in tqdm.tqdm(enumerate(testloader), total=len(testloader)):

        inputs, targets = Variable(inputs, volatile=True), Variable(targets).cuda(async=True)
        if not args.parallel:
            inputs = inputs.cuda()

        probs, _ = agent(inputs)

        policy = probs.data.clone()
        policy[policy<0.5] = 0.0
        policy[policy>=0.5] = 1.0
        policy = Variable(policy)

        preds = rnet.forward(inputs, policy)
        reward, match = get_reward(preds, targets, policy.data)

        matches.append(match)
        rewards.append(reward)
        policies.append(policy.data)

    accuracy, reward, sparsity, variance, policy_set = utils.performance_stats(policies, rewards, matches)

    log_str = 'TS - A: %.3f | R: %.2E | S: %.3f | V: %.3f | #: %d'%(accuracy, reward, sparsity, variance, len(policy_set))
    print log_str

    log_value('test_accuracy', accuracy, epoch)
    log_value('test_reward', reward, epoch)
    log_value('test_sparsity', sparsity, epoch)
    log_value('test_variance', variance, epoch)
    log_value('test_unique_policies', len(policy_set), epoch)

    # save the model
    agent_state_dict = agent.module.state_dict() if args.parallel else agent.state_dict()
    rnet_state_dict = rnet.module.state_dict() if args.parallel else rnet.state_dict()

    state = {
      'agent': agent_state_dict,
      'resnet': rnet_state_dict,
      'epoch': epoch,
      'reward': reward,
      'acc': accuracy
    }
    torch.save(state, args.cv_dir+'/ckpt_E_%d_A_%.3f_R_%.2E_S_%.2f_#_%d.t7'%(epoch, accuracy, reward, sparsity, len(policy_set)))


#--------------------------------------------------------------------------------------------------------#
trainset, testset = utils.get_dataset(args.model, args.data_dir)
trainloader = torchdata.DataLoader(trainset, batch_size=args.batch_size, shuffle=True, num_workers=4)
testloader = torchdata.DataLoader(testset, batch_size=args.batch_size, shuffle=False, num_workers=4)
rnet, agent = utils.get_model(args.model)

if args.pretrained is not None:
    checkpoint = torch.load(args.pretrained)
    key = 'net' if 'net' in checkpoint else 'agent'
    agent.load_state_dict(checkpoint[key])
    print 'loaded pretrained model from', args.pretrained

start_epoch = 0
if args.load is not None:
    checkpoint = torch.load(args.load)
    rnet.load_state_dict(checkpoint['resnet'])
    agent.load_state_dict(checkpoint['agent'])
    start_epoch = checkpoint['epoch'] + 1
    print 'loaded rnet and agent from', args.load


if args.parallel:
    agent = nn.DataParallel(agent)
    rnet = nn.DataParallel(rnet)

rnet.cuda()
agent.cuda()

optimizer = optim.Adam(list(agent.parameters())+list(rnet.parameters()), lr=args.lr, weight_decay=args.wd)

configure(args.cv_dir+'/log', flush_secs=5)
lr_scheduler = utils.LrScheduler(optimizer, args.lr, args.lr_decay_ratio, args.epoch_step)
for epoch in range(start_epoch, start_epoch+args.max_epochs+1):
    lr_scheduler.adjust_learning_rate(epoch)

    train(epoch)
    if epoch%10==0:
        test(epoch)


================================================
FILE: models/__init__.py
================================================


================================================
FILE: models/base.py
================================================
import torch.nn as nn
import math
import torch
import torchvision.models as torchmodels
import re
from torch.autograd import Variable
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import random
import torch.nn.utils as torchutils
from torch.nn import init, Parameter


class Identity(nn.Module):
    def __init__(self):
        super(Identity, self).__init__()
    def forward(self, x):
        return x

class Flatten(nn.Module):
    def __init__(self):
        super(Flatten, self).__init__()
    def forward(self, x):
        return x.view(x.size(0), -1)

def conv3x3(in_planes, out_planes, stride=1):
    "3x3 convolution with padding"
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)

class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)

    def forward(self, x):
       
        out = self.conv1(x)
        out = self.bn1(out)
        out = F.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        return out

class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        return out

class DownsampleB(nn.Module):

    def __init__(self, nIn, nOut, stride):
        super(DownsampleB, self).__init__()
        self.avg = nn.AvgPool2d(stride)
        self.expand_ratio = nOut // nIn

    def forward(self, x):
        x = self.avg(x)
        return torch.cat([x] + [x.mul(0)] * (self.expand_ratio - 1), 1)





================================================
FILE: models/resnet.py
================================================
import torch.nn as nn
import math
import torch
import torchvision.models as torchmodels
import re
from torch.autograd import Variable
import torch.nn.functional as F
import numpy as np
import random
import torch.nn.init as torchinit
import math
from torch.nn import init, Parameter
import copy

from models import base
import utils

#--------------------------------------------------------------------------------------------------#
class FlatResNet(nn.Module):

    def seed(self, x):
        # x = self.relu(self.bn1(self.conv1(x))) -- CIFAR
        # x = self.maxpool(self.relu(self.bn1(self.conv1(x)))) -- ImageNet
        raise NotImplementedError

    # run a variable policy batch through the resnet implemented as a full mask over the residual
    # fast to train, non-indicative of time saving (use forward_single instead)
    def forward(self, x, policy):

        x = self.seed(x)

        t = 0
        for segment, num_blocks in enumerate(self.layer_config):
            for b in range(num_blocks):
                action = policy[:,t].contiguous()
                residual = self.ds[segment](x) if b==0 else x

                # skip this block entirely if no sample in the batch executes it
                if action.data.sum() == 0:
                    x = residual
                    t += 1
                    continue

                action_mask = action.float().view(-1,1,1,1)
                fx = F.relu(residual + self.blocks[segment][b](x))
                x = fx*action_mask + residual*(1-action_mask)
                t += 1

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

    # run a single, fixed policy for all items in the batch
    # policy is a (num_blocks,) vector. Use with batch_size=1 for profiling
    def forward_single(self, x, policy):
        x = self.seed(x)

        t = 0
        for segment, num_blocks in enumerate(self.layer_config):
            for b in range(num_blocks):
                residual = self.ds[segment](x) if b==0 else x
                if policy[t]==1:
                    x = residual + self.blocks[segment][b](x)
                    x = F.relu(x)
                else:
                    x = residual
                t += 1

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x


    def forward_full(self, x):
        x = self.seed(x)

        for segment, num_blocks in enumerate(self.layer_config):
            for b in range(num_blocks):
                residual = self.ds[segment](x) if b==0 else x
                x = F.relu(residual + self.blocks[segment][b](x))

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x



# Smaller Flattened Resnet, tailored for CIFAR
class FlatResNet32(FlatResNet):

    def __init__(self, block, layers, num_classes=10):
        super(FlatResNet32, self).__init__()

        self.inplanes = 16
        self.conv1 = base.conv3x3(3, 16)
        self.bn1 = nn.BatchNorm2d(16)
        self.relu = nn.ReLU(inplace=True)
        self.avgpool = nn.AvgPool2d(8)

        strides = [1, 2, 2]
        filt_sizes = [16, 32, 64]
        self.blocks, self.ds = [], []
        for idx, (filt_size, num_blocks, stride) in enumerate(zip(filt_sizes, layers, strides)):
            blocks, ds = self._make_layer(block, filt_size, num_blocks, stride=stride)
            self.blocks.append(nn.ModuleList(blocks))
            self.ds.append(ds)

        self.blocks = nn.ModuleList(self.blocks)
        self.ds = nn.ModuleList(self.ds)
        self.fc = nn.Linear(64 * block.expansion, num_classes)
        self.fc_dim = 64 * block.expansion

        self.layer_config = layers

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def seed(self, x):
        x = self.relu(self.bn1(self.conv1(x)))
        return x

    def _make_layer(self, block, planes, blocks, stride=1):

        downsample = nn.Sequential()
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = base.DownsampleB(self.inplanes, planes * block.expansion, stride)

        layers = [block(self.inplanes, planes, stride)]
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes, 1))

        return layers, downsample


# Regular Flattened Resnet, tailored for Imagenet etc.
class FlatResNet224(FlatResNet):

    def __init__(self, block, layers, num_classes=1000):
        self.inplanes = 64
        super(FlatResNet224, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        strides = [1, 2, 2, 2]
        filt_sizes = [64, 128, 256, 512]
        self.blocks, self.ds = [], []
        for idx, (filt_size, num_blocks, stride) in enumerate(zip(filt_sizes, layers, strides)):
            blocks, ds = self._make_layer(block, filt_size, num_blocks, stride=stride)
            self.blocks.append(nn.ModuleList(blocks))
            self.ds.append(ds)

        self.blocks = nn.ModuleList(self.blocks)
        self.ds = nn.ModuleList(self.ds)

        self.avgpool = nn.AvgPool2d(7)
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

        self.layer_config = layers

    def seed(self, x):
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))
        return x

    def _make_layer(self, block, planes, blocks, stride=1):

        downsample = nn.Sequential()
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = [block(self.inplanes, planes, stride)]
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return layers, downsample


#---------------------------------------------------------------------------------------------------------#

# Class to generate resnetNB or any other config (default is 3B)
# removed the fc layer so it serves as a feature extractor
class Policy32(nn.Module):

    def __init__(self, layer_config=[1,1,1], num_blocks=15):
        super(Policy32, self).__init__()
        self.features = FlatResNet32(base.BasicBlock, layer_config, num_classes=10)
        self.feat_dim = self.features.fc.weight.data.shape[1]
        self.features.fc = nn.Sequential()

        self.logit = nn.Linear(self.feat_dim, num_blocks)
        self.vnet = nn.Linear(self.feat_dim, 1)

    def load_state_dict(self, state_dict):
        # support legacy models
        state_dict = {k:v for k,v in state_dict.items() if not k.startswith('features.fc')}
        return super(Policy32, self).load_state_dict(state_dict)


    def forward(self, x):
        x = self.features.forward_full(x)
        value = self.vnet(x)
        probs = F.sigmoid(self.logit(x))
        return probs, value


class Policy224(nn.Module):

    def __init__(self, layer_config=[1,1,1,1], num_blocks=16):
        super(Policy224, self).__init__()
        self.features = FlatResNet224(base.BasicBlock, layer_config, num_classes=1000)

        resnet18 = torchmodels.resnet18(pretrained=True)
        utils.load_weights_to_flatresnet(resnet18, self.features)

        self.features.avgpool = nn.AvgPool2d(4)
        self.feat_dim = self.features.fc.weight.data.shape[1]
        self.features.fc = nn.Sequential()


        self.logit = nn.Linear(self.feat_dim, num_blocks)
        self.vnet = nn.Linear(self.feat_dim, 1)


    def load_state_dict(self, state_dict):
        # support legacy models
        state_dict = {k:v for k,v in state_dict.items() if not k.startswith('features.fc')}
        return super(Policy224, self).load_state_dict(state_dict)

    def forward(self, x):
        x = F.avg_pool2d(x, 2)
        x = self.features.forward_full(x)
        value = self.vnet(x)
        probs = F.sigmoid(self.logit(x))
        return probs, value

#--------------------------------------------------------------------------------------------------------#

class StepResnet32(FlatResNet32):

    def __init__(self, block, layers, num_classes, joint=False):
        super(StepResnet32, self).__init__(block, layers, num_classes)
        self.eval() # default to eval mode

        self.joint = joint

        self.state_ptr = {}
        t = 0
        for segment, num_blocks in enumerate(self.layer_config):
            for b in range(num_blocks):
                self.state_ptr[t] = (segment, b)
                t += 1

    def seed(self, x):
        self.state = self.relu(self.bn1(self.conv1(x)))
        self.t = 0

        if self.joint:
            return self.state
        return Variable(self.state.data)

    def step(self, action):
        segment, b = self.state_ptr[self.t]
        residual = self.ds[segment](self.state) if b==0 else self.state
        action_mask = action.float().view(-1,1,1,1)

        fx = F.relu(residual + self.blocks[segment][b](self.state))
        self.state = fx*action_mask + residual*(1-action_mask)
        self.t += 1

        if self.joint:
            return self.state
        return Variable(self.state.data)


    def step_single(self, action):
        segment, b = self.state_ptr[self.t]
        residual = self.ds[segment](self.state) if b==0 else self.state

        if action.data[0,0]==1:
            self.state = F.relu(residual + self.blocks[segment][b](self.state))
        else:
            self.state = residual

        self.t += 1

        if self.joint:
            return self.state
        return Variable(self.state.data)

    def predict(self):
        x = self.avgpool(self.state)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x


class StepPolicy32(nn.Module):

    def __init__(self, layer_config):
        super(StepPolicy32, self).__init__()
        in_dim = [16] + [16]*layer_config[0] + [32]*layer_config[1] + [64]*(layer_config[2]-1)
        self.pnet = nn.ModuleList([nn.Linear(dim, 2) for dim in in_dim])
        self.vnet = nn.ModuleList([nn.Linear(dim, 1) for dim in in_dim])

    def forward(self, state):
        x, t = state
        x = F.avg_pool2d(x, x.size(2)).view(x.size(0), -1) # pool + flatten --> (B, 16/32/64)
        probs = F.softmax(self.pnet[t](x), dim=1) # normalize over the 2 actions, not over the batch
        value = self.vnet[t](x)
        return probs, value


================================================
FILE: requirements.txt
================================================
tqdm
numpy
tensorboard_logger
argparse


================================================
FILE: test.py
================================================
import torch
from torch.autograd import Variable
import torch.utils.data as torchdata
import torch.nn as nn
import numpy as np
import tqdm
import utils

import torch.backends.cudnn as cudnn
cudnn.benchmark = True

import argparse
parser = argparse.ArgumentParser()
parser.add_argument('--model', default='R110_C10')
parser.add_argument('--data_dir', default='data/')
parser.add_argument('--load', default=None)
args = parser.parse_args()

#---------------------------------------------------------------------------------------------#
class FConv2d(nn.Conv2d):

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True):
        super(FConv2d, self).__init__(in_channels, out_channels, kernel_size, stride,
                 padding, dilation, groups, bias)
        self.num_ops = 0

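    def forward(self, x):
        output = super(FConv2d, self).forward(x)
        # each output position costs in_channels*out_channels*filter_area multiply-adds,
        # counted as 2 ops each; batch size and groups are ignored (test.py uses batch_size=1)
        output_area = output.size(-1)*output.size(-2)
        filter_area = np.prod(self.kernel_size)
        self.num_ops += 2*self.in_channels*self.out_channels*filter_area*output_area
        return output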

class FLinear(nn.Linear):
    def __init__(self, in_features, out_features, bias=True):
        super(FLinear, self).__init__(in_features, out_features, bias)
        self.num_ops = 0

    def forward(self, x):
        output = super(FLinear, self).forward(x)
        self.num_ops += 2*self.in_features*self.out_features
        return output

def count_flops(model, reset=True):
    op_count = 0
    for m in model.modules():
        if hasattr(m, 'num_ops'):
            op_count += m.num_ops
            if reset: # count and reset to 0
                m.num_ops = 0

    return op_count
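
# Illustrative sketch (not part of the original script): the per-layer op counts
# accumulated by FConv2d and FLinear above, written as standalone functions so the
# formulas can be checked by hand. The helper names conv2d_ops and linear_ops are
# hypothetical; batch size and groups are ignored, as in the layers above.
def conv2d_ops(in_channels, out_channels, kernel_size, out_h, out_w):
    # 2 ops (multiply + add) per weight per output position
    filter_area = kernel_size[0]*kernel_size[1]
    return 2*in_channels*out_channels*filter_area*out_h*out_w

def linear_ops(in_features, out_features):
    # 2 ops (multiply + add) per weight
    return 2*in_features*out_features

# e.g. the 7x7, stride-2 stem conv of FlatResNet224 maps a 224x224 input to a
# 112x112 output, so conv2d_ops(3, 64, (7, 7), 112, 112) == 236027904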

# replace all nn.Conv and nn.Linear layers with layers that count flops
nn.Conv2d = FConv2d
nn.Linear = FLinear

#--------------------------------------------------------------------------------------------#

def test():

    total_ops = []
    matches, policies = [], []
    for batch_idx, (inputs, targets) in tqdm.tqdm(enumerate(testloader), total=len(testloader)):

        inputs, targets = Variable(inputs, volatile=True).cuda(), Variable(targets).cuda()
        probs, _ = agent(inputs)

        policy = probs.clone()
        policy[policy<0.5] = 0.0
        policy[policy>=0.5] = 1.0

        preds = rnet.forward_single(inputs, policy.data.squeeze(0))
        _ , pred_idx = preds.max(1)
        match = (pred_idx==targets).data.float()

        matches.append(match)
        policies.append(policy.data)

        ops = count_flops(agent) + count_flops(rnet)
        total_ops.append(ops)

    accuracy, _, sparsity, variance, policy_set = utils.performance_stats(policies, matches, matches)
    ops_mean, ops_std = np.mean(total_ops), np.std(total_ops)

    log_str = u'''
    Accuracy: %.3f
    Block Usage: %.3f \u00B1 %.3f
    FLOPs/img: %.2E \u00B1 %.2E
    Unique Policies: %d
    '''%(accuracy, sparsity, variance, ops_mean, ops_std, len(policy_set))

    print log_str

#--------------------------------------------------------------------------------------------------------#
trainset, testset = utils.get_dataset(args.model, args.data_dir)
testloader = torchdata.DataLoader(testset, batch_size=1, shuffle=False, num_workers=4)
rnet, agent = utils.get_model(args.model)

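# default policy when no checkpoint is loaded: zero weights and a large positive bias
# give sigmoid(10) ~ 1 for every block, so all blocks are used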
agent.logit.weight.data.fill_(0)
agent.logit.bias.data.fill_(10)

print "loading checkpoints"

if args.load is not None:
    utils.load_checkpoint(rnet, agent, args.load)

rnet.eval().cuda()
agent.eval().cuda()

test()


================================================
FILE: utils.py
================================================
import os
import re
import torch
import torchvision.transforms as transforms
import torchvision.datasets as torchdata
import numpy as np

# Save the training script and all the arguments to a file so that you
# don't feel like an idiot later when you can't replicate results
import shutil
def save_args(__file__, args):
    shutil.copy(os.path.basename(__file__), args.cv_dir)
    with open(args.cv_dir+'/args.txt','w') as f:
        f.write(str(args))

def performance_stats(policies, rewards, matches):

    policies = torch.cat(policies, 0)
    rewards = torch.cat(rewards, 0)
    accuracy = torch.cat(matches, 0).mean()

    reward = rewards.mean()
    sparsity = policies.sum(1).mean()
    variance = policies.sum(1).std()

    policy_set = [p.cpu().numpy().astype(int).astype(str) for p in policies]
    policy_set = set([''.join(p) for p in policy_set])

    return accuracy, reward, sparsity, variance, policy_set

class LrScheduler:
    def __init__(self, optimizer, base_lr, lr_decay_ratio, epoch_step):
        self.base_lr = base_lr
        self.lr_decay_ratio = lr_decay_ratio
        self.epoch_step = epoch_step
        self.optimizer = optimizer

    def adjust_learning_rate(self, epoch):
        """Sets the learning rate to base_lr decayed by lr_decay_ratio every epoch_step epochs"""
        lr = self.base_lr * (self.lr_decay_ratio ** (epoch // self.epoch_step))
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr
            if epoch%self.epoch_step==0:
                print '# setting learning_rate to %.2E'%lr
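
# Illustrative sketch (not part of the original file): the step-decay rule used by
# LrScheduler.adjust_learning_rate, as a pure function that needs no optimizer.
# The name step_decay_lr and the example values below are assumptions chosen for
# illustration only.
def step_decay_lr(base_lr, lr_decay_ratio, epoch_step, epoch):
    # integer division: the rate drops once per completed block of epoch_step epochs
    return base_lr * (lr_decay_ratio ** (epoch // epoch_step))

# with base_lr=0.1, lr_decay_ratio=0.5, epoch_step=30:
# epochs 0-29 use 0.1, epochs 30-59 use 0.05, epochs 60-89 use 0.025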


# load model weights trained using scripts from https://github.com/felixgwu/img_classification_pk_pytorch OR
# from torchvision models into our flattened resnets
def load_weights_to_flatresnet(source_model, target_model):

    # compatibility for nn.Modules + checkpoints
    if hasattr(source_model, 'state_dict'):
        source_model = {'state_dict': source_model.state_dict()}
    source_state = source_model['state_dict']
    target_state = target_model.state_dict()

    # remove the module. prefix if it exists (thanks nn.DataParallel)
    if list(source_state.keys())[0].startswith('module.'):
        source_state = {k[7:]:v for k,v in source_state.items()}


    common = set(['conv1.weight', 'bn1.weight', 'bn1.bias', 'bn1.running_mean', 'bn1.running_var','fc.weight', 'fc.bias'])
    for key in source_state.keys():

        if key in common:
            target_state[key] = source_state[key]
            continue

        if 'downsample' in key:
            layer, num, item = re.match(r'layer(\d+).*\.(\d+)\.(.*)', key).groups()
            translated = 'ds.%s.%s.%s'%(int(layer)-1, num, item)
        else:
            layer, item = re.match(r'layer(\d+)\.(.*)', key).groups()
            translated = 'blocks.%s.%s'%(int(layer)-1, item)


        if translated in target_state.keys():
            target_state[translated] = source_state[key]
        else:
            print translated, 'block missing'

    target_model.load_state_dict(target_state)
    return target_model
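
# Illustrative sketch (not part of the original file): the key translation done
# inside load_weights_to_flatresnet, isolated so it can be tried on plain strings.
# translate_key is a hypothetical helper and assumes torchvision-style parameter
# names ('layerN.<block>...'). Keys in the `common` set (conv1, bn1, fc) are copied
# unchanged by the loader and are not handled here.
import re  # already imported at the top of utils.py; repeated so the sketch stands alone

def translate_key(key):
    if 'downsample' in key:
        # 'layerN.<block>.downsample.<num>.<item>' --> 'ds.<N-1>.<num>.<item>'
        layer, num, item = re.match(r'layer(\d+).*\.(\d+)\.(.*)', key).groups()
        return 'ds.%s.%s.%s'%(int(layer)-1, num, item)
    # 'layerN.<rest>' --> 'blocks.<N-1>.<rest>'
    layer, item = re.match(r'layer(\d+)\.(.*)', key).groups()
    return 'blocks.%s.%s'%(int(layer)-1, item)

# translate_key('layer2.0.downsample.1.weight') gives 'ds.1.1.weight'
# translate_key('layer2.3.conv1.weight') gives 'blocks.1.3.conv1.weight'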

def load_checkpoint(rnet, agent, load):
    if load=='nil':
        return None

    checkpoint = torch.load(load)
    if 'resnet' in checkpoint:
        rnet.load_state_dict(checkpoint['resnet'])
        print 'loaded resnet from', os.path.basename(load)
    if 'agent' in checkpoint:
        agent.load_state_dict(checkpoint['agent'])
        print 'loaded agent from', os.path.basename(load)
    # backward compatibility (some old checkpoints)
    if 'net' in checkpoint:
        checkpoint['net'] = {k:v for k,v in checkpoint['net'].items() if 'features.fc' not in k}
        agent.load_state_dict(checkpoint['net'])
        print 'loaded agent from', os.path.basename(load)


def get_transforms(rnet, dset):

    # Only the R32 pretrained model subtracts the mean, sorry :(
    if dset=='C10' and rnet=='R32':
        mean = [x/255.0 for x in [125.3, 123.0, 113.9]]
        std = [x/255.0 for x in [63.0, 62.1, 66.7]]
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
            ])

        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
            ])

    elif dset=='C100' or (dset=='C10' and rnet!='R32'):
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            ])

        transform_test = transforms.Compose([
            transforms.ToTensor(),
            ])

    elif dset=='ImgNet':
        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]
        transform_train = transforms.Compose([
            transforms.Scale(256),
            transforms.RandomCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
            ])

        transform_test = transforms.Compose([
            transforms.Scale(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
            ])


    return transform_train, transform_test

# Pick from the datasets available and the hundreds of models we have lying around depending on the requirements.
def get_dataset(model, root='data/'):

    rnet, dset = model.split('_')
    transform_train, transform_test = get_transforms(rnet, dset)

    if dset=='C10':
        trainset = torchdata.CIFAR10(root=root, train=True, download=True, transform=transform_train)
        testset = torchdata.CIFAR10(root=root, train=False, download=True, transform=transform_test)
    elif dset=='C100':
        trainset = torchdata.CIFAR100(root=root, train=True, download=True, transform=transform_train)
        testset = torchdata.CIFAR100(root=root, train=False, download=True, transform=transform_test)
    elif dset=='ImgNet':
        trainset = torchdata.ImageFolder(root+'/train/', transform_train)
        testset = torchdata.ImageFolder(root+'/val/', transform_test)

    return trainset, testset

# Make a new if statement for every new model variety you want to index
def get_model(model):

    from models import resnet, base

    if model=='R32_C10':
        rnet_checkpoint = 'cv/pretrained/R32_C10/pk_E_164_A_0.923.t7'
        layer_config = [5, 5, 5]
        rnet = resnet.FlatResNet32(base.BasicBlock, layer_config, num_classes=10)
        agent = resnet.Policy32([1,1,1], num_blocks=15)

    elif model=='R110_C10':
        rnet_checkpoint = 'cv/pretrained/R110_C10/pk_E_130_A_0.932.t7'
        layer_config = [18, 18, 18]
        rnet = resnet.FlatResNet32(base.BasicBlock, layer_config, num_classes=10)
        agent = resnet.Policy32([1,1,1], num_blocks=54)

    elif model=='R32_C100':
        rnet_checkpoint = 'cv/pretrained/R32_C100/pk_E_164_A_0.693.t7'
        layer_config = [5, 5, 5]
        rnet = resnet.FlatResNet32(base.BasicBlock, layer_config, num_classes=100)
        agent = resnet.Policy32([1,1,1], num_blocks=15)

    elif model=='R110_C100':
        rnet_checkpoint = 'cv/pretrained/R110_C100/pk_E_160_A_0.723.t7'
        layer_config = [18, 18, 18]
        rnet = resnet.FlatResNet32(base.BasicBlock, layer_config, num_classes=100)
        agent = resnet.Policy32([1,1,1], num_blocks=54)

    elif model=='R101_ImgNet':
        rnet_checkpoint = 'cv/pretrained/R101_ImgNet/ImageNet_R101_224_76.464'
        layer_config = [3,4,23,3]
        rnet = resnet.FlatResNet224(base.Bottleneck, layer_config, num_classes=1000)
        agent = resnet.Policy224([1,1,1,1], num_blocks=33)

    else:
        raise ValueError('unknown model: %s'%model)

    # load pretrained weights into flat ResNet
    rnet_checkpoint = torch.load(rnet_checkpoint)
    load_weights_to_flatresnet(rnet_checkpoint, rnet)

    return rnet, agent