Repository: Tushar-N/blockdrop
Branch: master
Commit: ec52b36d38dc
Files: 10
Total size: 46.6 KB
Directory structure:
gitextract_y_xp5prk/
├── .gitignore
├── Readme.md
├── cl_training.py
├── finetune.py
├── models/
│ ├── __init__.py
│ ├── base.py
│ └── resnet.py
├── requirements.txt
├── test.py
└── utils.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
*.*~
*.pyc
*.pkl
*.h5
*.t7
*.t7*
*.log
*.png
data/
cv/
================================================
FILE: Readme.md
================================================
# BlockDrop: Dynamic Inference Paths in Residual Networks

This code implements a policy network that learns to dynamically choose which blocks of a ResNet to execute during inference so as to best reduce total computation without degrading prediction accuracy. Built upon a ResNet-101 model, our method achieves a speedup of 20% on average, going as high as 36% for some images, while maintaining the same 76.4% top-1 accuracy on ImageNet.
This is the code accompanying the work:
Zuxuan Wu*, Tushar Nagarajan*, Abhishek Kumar, Steven Rennie, Larry S. Davis, Kristen Grauman, and Rogerio Feris. BlockDrop: Dynamic Inference Paths in Residual Networks [[arxiv]](https://arxiv.org/pdf/1711.08393.pdf)
(* authors contributed equally)
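The block-dropping idea can be illustrated with a minimal, framework-free sketch. It follows the structure of `forward_single` in `models/resnet.py`, but the scalar "blocks" below are hypothetical stand-ins for convolutional modules, and the downsampling shortcuts are left out for brevity.

```python
def relu(v):
    return max(v, 0.0)

def forward_single(x, blocks, policy):
    """Run only the residual blocks whose policy entry is 1."""
    for block, keep in zip(blocks, policy):
        if keep == 1:
            x = relu(x + block(x))  # executed block: shortcut + residual branch
        # keep == 0: the block is skipped and x passes through unchanged
    return x

# Hypothetical scalar "blocks"; real blocks are convolutional modules.
blocks = [lambda v: 1.0, lambda v: 2.0 * v, lambda v: -10.0]

print(forward_single(1.0, blocks, [1, 1, 0]))  # 6.0
print(forward_single(1.0, blocks, [0, 0, 0]))  # 1.0
```

A skipped block costs no computation here, which is why the per-image policy translates into fewer FLOPs.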
## Prerequisites
The code is written and tested using Python (2.7) and PyTorch (v0.3.0).
**Packages**: Install using `pip install -r requirements.txt`
**Pretrained models**: Our models require standard pretrained ResNets on CIFAR and ImageNet as starting points. These can be trained using [this repository](https://github.com/felixgwu/img_classification_pk_pytorch), or downloaded directly from us:
```bash
wget -O blockdrop-checkpoints.tar.gz https://utexas.box.com/shared/static/ok98i51v14c0q9lvs1z5g71m6b3zm8sj.gz
tar -zxvf blockdrop-checkpoints.tar.gz
```
The downloaded checkpoints will be unpacked to `./cv/` for further use. The folder also contains various checkpoints from each stage of training.
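Checkpoint file names encode the metrics at save time, following the format string `ckpt_E_%d_A_%.3f_R_%.2E_S_%.2f_#_%d.t7` used in `cl_training.py` and `finetune.py`. The helper below is a hypothetical convenience (not part of the repository) that sketches how those fields could be recovered. The `S` field appears to match the "Block Usage" value printed by `test.py`; that reading is an assumption.

```python
import re

# Mirrors the torch.save format string: ckpt_E_%d_A_%.3f_R_%.2E_S_%.2f_#_%d.t7
CKPT_RE = re.compile(
    r"ckpt_E_(?P<epoch>\d+)_A_(?P<acc>[\d.]+)_R_(?P<reward>[-+\d.E]+)"
    r"_S_(?P<sparsity>[\d.]+)_#_(?P<policies>\d+)\.t7$"
)

def parse_ckpt_name(path):
    """Return the metrics encoded in a checkpoint file name (hypothetical helper)."""
    m = CKPT_RE.search(path)
    if m is None:
        raise ValueError("not a BlockDrop-style checkpoint name: %r" % path)
    return {
        "epoch": int(m.group("epoch")),
        "accuracy": float(m.group("acc")),
        "reward": float(m.group("reward")),
        "sparsity": float(m.group("sparsity")),  # assumed: average blocks used
        "unique_policies": int(m.group("policies")),
    }

info = parse_ckpt_name(
    "cv/finetuned/R110_C10_gamma_10/ckpt_E_2000_A_0.936_R_1.95E-01_S_16.93_#_469.t7"
)
print(info["epoch"], info["accuracy"], info["unique_policies"])
```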
**Datasets**: PyTorch's *torchvision* package automatically downloads CIFAR10 and CIFAR100 during training. ImageNet must be downloaded and organized following [these steps](https://github.com/soumith/imagenet-multiGPU.torch#data-processing).
## Training a model
Training occurs in two steps: (1) Curriculum Learning and (2) Joint Finetuning.
Models operating on ResNets of different depths can be trained on different datasets using the same script. Examples of how to train these models are given below. Checkpoints and tensorboard log files will be saved to the folder specified by `--cv_dir`.
#### Curriculum Learning
The policy network can be trained using a CL schedule as follows.
```bash
# Train a model on CIFAR 10 built upon a ResNet-110
python cl_training.py --model R110_C10 --cv_dir cv/R110_C10_cl/ --lr 1e-3 --batch_size 2048 --max_epochs 5000
# Train a model on ImageNet built upon a ResNet-101
python cl_training.py --model R101_ImgNet --cv_dir cv/R101_ImgNet_cl/ --lr 1e-3 --batch_size 2048 --max_epochs 45 --data_dir data/imagenet/
```
Model checkpoints after the curriculum learning step can be found in the downloaded folder. For example: `./cv/cl_learning/R110_C10/ckpt_E_5300_A_0.754_R_2.22E-01_S_20.10_#_7787.t7`
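The curriculum in `cl_training.py` lets the policy decide only the last `cl_step` blocks and forces all earlier blocks to stay on, growing `cl_step` by one per epoch until it covers every block. A minimal plain-Python sketch of that schedule follows; the function names and the 5-block network are illustrative assumptions, and training is assumed to start at epoch 0 with the default `--cl_step 1`.

```python
def curriculum_step(epoch, num_blocks):
    """Number of trailing blocks the policy controls at a given epoch."""
    return min(1 + epoch, num_blocks)

def apply_curriculum(policy, cl_step):
    """Force all but the last `cl_step` decisions to 1 (block is executed)."""
    if cl_step >= len(policy):
        return list(policy)  # curriculum finished: the full policy is used
    keep = len(policy) - cl_step
    return [1] * keep + list(policy[keep:])

num_blocks = 5  # hypothetical small network
sampled = [0, 0, 0, 0, 0]  # a sampled policy that would drop every block
for epoch in range(6):
    step = curriculum_step(epoch, num_blocks)
    print(epoch, step, apply_curriculum(sampled, step))
```

In the training script, the same mask is also applied to the loss so that only the unlocked decisions receive gradient.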
#### Joint Finetuning
Checkpoints trained during the curriculum learning phase can be used to further jointly finetune the base ResNet to achieve the results reported in the paper. Different values for the penalty parameter control the trade-off between accuracy and speed.
```bash
# Finetune a ResNet-110 on CIFAR 10 using the checkpoint from cl_training
python finetune.py --model R110_C10 --lr 1e-4 --penalty -10 --pretrained cv/cl_training/R110_C10/ckpt_E_5300_A_0.754_R_2.22E-01_S_20.10_#_7787.t7 --batch_size 256 --max_epochs 2000 --cv_dir cv/R110_C10_ft_-10/
# Finetune a ResNet-101 on ImageNet using the checkpoint from cl_training
python finetune.py --model R101_ImgNet --lr 1e-4 --penalty -5 --pretrained cv/cl_training/R101_ImgNet/ckpt_E_4_A_0.746_R_-3.70E-01_S_29.79_#_484.t7 --data_dir data/imagenet/ --batch_size 320 --max_epochs 10 --cv_dir cv/R101_ImgNet_ft_-5/
```
Model checkpoints after the joint finetuning step can be found in the downloaded folder. For example: `./cv/finetuned/R101_ImgNet_gamma_5/ckpt_E_10_A_0.764_R_-8.46E-01_S_24.77_#_10.t7`
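To make the role of `--penalty` concrete, here is a per-sample, plain-Python sketch of the reward computed by `get_reward` in the training scripts (which operate on batched tensors instead). The `advantage` helper mirrors the `reward_sample - reward_map` baseline used in `train()`; the function signatures are simplified for illustration.

```python
def get_reward(policy, correct, penalty):
    """Reward for one sample: 1 - (fraction of blocks used)^2, or the penalty."""
    block_use = sum(policy) / float(len(policy))
    if not correct:
        return penalty  # incorrect predictions get the fixed (negative) penalty
    return 1.0 - block_use ** 2

def advantage(sampled, sampled_ok, greedy, greedy_ok, penalty):
    """Reward of the sampled policy minus the reward of the thresholded policy."""
    return get_reward(sampled, sampled_ok, penalty) - get_reward(greedy, greedy_ok, penalty)

print(get_reward([1, 0, 1, 0], True, -10))   # 0.75
print(get_reward([1, 0, 1, 0], False, -10))  # -10
print(advantage([1, 0, 0, 0], True, [1, 1, 0, 0], True, -10))  # 0.1875
```

A more negative penalty makes a wrong prediction costlier relative to the savings from dropping blocks, which pushes the policy toward accuracy over speed.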
## Testing and Profiling
Once jointly finetuned, models can be profiled for accuracy and FLOP counts.
```bash
python test.py --model R110_C10 --load cv/finetuned/R110_C10_gamma_10/ckpt_E_2000_A_0.936_R_1.95E-01_S_16.93_#_469.t7
```
The model should produce an accuracy of 93.6% and use 1.81E+08 FLOPs on average. The output should look like this:
```
Accuracy: 0.936
Block Usage: 16.933 ± 3.717
FLOPs/img: 1.81E+08 ± 3.43E+07
Unique Policies: 469
```
The ImageNet model can be evaluated in a similar manner, and will generate a corresponding output.
```bash
python test.py --model R101_ImgNet --load cv/finetuned/R101_ImgNet_gamma_5/ckpt_E_10_A_0.764_R_-8.46E-01_S_24.77_#_10.t7
```
```
Accuracy: 0.764
Block Usage: 24.770 ± 0.980
FLOPs/img: 1.25E+10 ± 4.28E+08
Unique Policies: 10
```
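The FLOPs figures come from counting layers in `test.py`, which charge each convolution `2 * in_channels * out_channels * filter_area * output_area` operations and each linear layer `2 * in_features * out_features`. The standalone functions below restate those two formulas as a sketch; like the original counters, they ignore bias terms and grouped convolutions.

```python
def conv2d_flops(in_channels, out_channels, kernel_size, output_hw):
    """Operations for one conv layer, following FConv2d in test.py."""
    kh, kw = kernel_size
    oh, ow = output_hw
    return 2 * in_channels * out_channels * kh * kw * oh * ow

def linear_flops(in_features, out_features):
    """Operations for one fully connected layer, following FLinear in test.py."""
    return 2 * in_features * out_features

# First 3x3 conv of the CIFAR ResNet (3 -> 16 channels, 32x32 output map)
print(conv2d_flops(3, 16, (3, 3), (32, 32)))  # 884736
# Final classifier of the CIFAR-10 model (64 features -> 10 classes)
print(linear_flops(64, 10))  # 1280
```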
## Visualization
Learned policies over ResNet blocks show a clear separation between easy and hard images in terms of the number of blocks they require. In addition, distinct policies over the blocks correspond to distinct image styles.

For more qualitative results, see Sec. 4.3 and Figures 4 and 5 in the paper.
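The "Block Usage" and "Unique Policies" numbers reported during testing can be reproduced from a list of binary policies. The sketch below is an assumption about how such statistics could be computed, not a copy of `utils.performance_stats`; in particular, the population standard deviation is used here for illustration.

```python
import math

def policy_stats(policies):
    """Mean and std of blocks used per image, plus the number of distinct policies."""
    usage = [sum(p) for p in policies]
    mean = sum(usage) / float(len(usage))
    var = sum((u - mean) ** 2 for u in usage) / float(len(usage))
    unique = set(tuple(p) for p in policies)
    return mean, math.sqrt(var), len(unique)

# Hypothetical policies for four images over a 3-block network
policies = [[1, 1, 0], [1, 1, 0], [1, 0, 0], [1, 1, 1]]
mean, std, n_unique = policy_stats(policies)
print("Block Usage: %.3f +/- %.3f" % (mean, std))
print("Unique Policies: %d" % n_unique)
```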
## Cite
If you find this repository useful in your own research, please consider citing:
```
@inproceedings{blockdrop,
title={BlockDrop: Dynamic Inference Paths in Residual Networks},
author={Wu, Zuxuan and Nagarajan, Tushar and Kumar, Abhishek and Rennie, Steven and Davis, Larry S and Grauman, Kristen and Feris, Rogerio},
booktitle={CVPR},
year={2018}
}
```
================================================
FILE: cl_training.py
================================================
import os
from tensorboard_logger import configure, log_value
import torch
import torch.autograd as autograd
from torch.autograd import Variable
import torch.utils.data as torchdata
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import tqdm
import utils
import torch.optim as optim
from torch.distributions import Bernoulli
import torch.backends.cudnn as cudnn
cudnn.benchmark = True
import argparse
parser = argparse.ArgumentParser(description='BlockDrop Training')
parser.add_argument('--lr', type=float, default=1e-3, help='learning rate')
parser.add_argument('--beta', type=float, default=1e-1, help='entropy multiplier')
parser.add_argument('--wd', type=float, default=0.0, help='weight decay')
parser.add_argument('--model', default='R110_C10', help='R<depth>_<dataset> see utils.py for a list of configurations')
parser.add_argument('--data_dir', default='data/', help='data directory')
parser.add_argument('--load', default=None, help='checkpoint to load agent from')
parser.add_argument('--cv_dir', default='cv/tmp/', help='checkpoint directory (models and logs are saved here)')
parser.add_argument('--batch_size', type=int, default=256, help='batch size')
parser.add_argument('--epoch_step', type=int, default=10000, help='epochs after which lr is decayed')
parser.add_argument('--max_epochs', type=int, default=10000, help='total epochs to run')
parser.add_argument('--lr_decay_ratio', type=float, default=0.1, help='lr *= lr_decay_ratio after epoch_steps')
parser.add_argument('--parallel', action ='store_true', default=False, help='use multiple GPUs for training')
parser.add_argument('--cl_step', type=int, default=1, help='steps for curriculum training')
# parser.add_argument('--joint', action ='store_true', default=True, help='train both the policy network and the resnet')
parser.add_argument('--penalty', type=float, default=-1, help='gamma: reward for incorrect predictions')
parser.add_argument('--alpha', type=float, default=0.8, help='probability bounding factor')
args = parser.parse_args()
if not os.path.exists(args.cv_dir):
    os.makedirs(args.cv_dir)  # also creates missing parent directories
utils.save_args(__file__, args)
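def get_reward(preds, targets, policy):
    # fraction of residual blocks that each sample's policy keeps
    block_use = policy.sum(1).float()/policy.size(1)
    # reward for correct predictions: higher when fewer blocks are used
    sparse_reward = 1.0-block_use**2
    _, pred_idx = preds.max(1)
    match = (pred_idx==targets).data
    reward = sparse_reward
    # incorrect predictions receive the fixed penalty (gamma) instead
    reward[1-match] = args.penalty
    reward = reward.unsqueeze(1)
    return reward, match.float()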
def train(epoch):
agent.train()
matches, rewards, policies = [], [], []
for batch_idx, (inputs, targets) in tqdm.tqdm(enumerate(trainloader), total=len(trainloader)):
inputs, targets = Variable(inputs), Variable(targets).cuda(async=True)
if not args.parallel:
inputs = inputs.cuda()
probs, value = agent(inputs)
#---------------------------------------------------------------------#
policy_map = probs.data.clone()
policy_map[policy_map<0.5] = 0.0
policy_map[policy_map>=0.5] = 1.0
policy_map = Variable(policy_map)
probs = probs*args.alpha + (1-probs)*(1-args.alpha)
distr = Bernoulli(probs)
policy = distr.sample()
if args.cl_step < num_blocks:
policy[:, :-args.cl_step] = 1
policy_map[:, :-args.cl_step] = 1
policy_mask = Variable(torch.ones(inputs.size(0), policy.size(1))).cuda()
policy_mask[:, :-args.cl_step] = 0
else:
policy_mask = None
v_inputs = Variable(inputs.data, volatile=True)
preds_map = rnet.forward(v_inputs, policy_map)
preds_sample = rnet.forward(v_inputs, policy)
reward_map, _ = get_reward(preds_map, targets, policy_map.data)
reward_sample, match = get_reward(preds_sample, targets, policy.data)
advantage = reward_sample - reward_map
loss = -distr.log_prob(policy)
loss = loss * Variable(advantage).expand_as(policy)
if policy_mask is not None:
loss = policy_mask * loss # mask for curriculum learning
loss = loss.sum()
probs = probs.clamp(1e-15, 1-1e-15)
entropy_loss = -probs*torch.log(probs)
entropy_loss = args.beta*entropy_loss.sum()
loss = (loss - entropy_loss)/inputs.size(0)
#---------------------------------------------------------------------#
optimizer.zero_grad()
loss.backward()
optimizer.step()
matches.append(match.cpu())
rewards.append(reward_sample.cpu())
policies.append(policy.data.cpu())
accuracy, reward, sparsity, variance, policy_set = utils.performance_stats(policies, rewards, matches)
log_str = 'E: %d | A: %.3f | R: %.2E | S: %.3f | V: %.3f | #: %d'%(epoch, accuracy, reward, sparsity, variance, len(policy_set))
print log_str
log_value('train_accuracy', accuracy, epoch)
log_value('train_reward', reward, epoch)
log_value('train_sparsity', sparsity, epoch)
log_value('train_variance', variance, epoch)
log_value('train_unique_policies', len(policy_set), epoch)
def test(epoch):
agent.eval()
matches, rewards, policies = [], [], []
for batch_idx, (inputs, targets) in tqdm.tqdm(enumerate(testloader), total=len(testloader)):
inputs, targets = Variable(inputs, volatile=True), Variable(targets).cuda(async=True)
if not args.parallel:
inputs = inputs.cuda()
probs, _ = agent(inputs)
policy = probs.data.clone()
policy[policy<0.5] = 0.0
policy[policy>=0.5] = 1.0
policy = Variable(policy)
if args.cl_step < num_blocks:
policy[:, :-args.cl_step] = 1
preds = rnet.forward(inputs, policy)
reward, match = get_reward(preds, targets, policy.data)
matches.append(match)
rewards.append(reward)
policies.append(policy.data)
accuracy, reward, sparsity, variance, policy_set = utils.performance_stats(policies, rewards, matches)
log_str = 'TS - A: %.3f | R: %.2E | S: %.3f | V: %.3f | #: %d'%(accuracy, reward, sparsity, variance, len(policy_set))
print log_str
log_value('test_accuracy', accuracy, epoch)
log_value('test_reward', reward, epoch)
log_value('test_sparsity', sparsity, epoch)
log_value('test_variance', variance, epoch)
log_value('test_unique_policies', len(policy_set), epoch)
# save the model
agent_state_dict = agent.module.state_dict() if args.parallel else agent.state_dict()
state = {
'agent': agent_state_dict,
'epoch': epoch,
'reward': reward,
'acc': accuracy
}
torch.save(state, args.cv_dir+'/ckpt_E_%d_A_%.3f_R_%.2E_S_%.2f_#_%d.t7'%(epoch, accuracy, reward, sparsity, len(policy_set)))
#--------------------------------------------------------------------------------------------------------#
trainset, testset = utils.get_dataset(args.model, args.data_dir)
trainloader = torchdata.DataLoader(trainset, batch_size=args.batch_size, shuffle=True, num_workers=4)
testloader = torchdata.DataLoader(testset, batch_size=args.batch_size, shuffle=False, num_workers=4)
rnet, agent = utils.get_model(args.model)
num_blocks = sum(rnet.layer_config)
start_epoch = 0
if args.load is not None:
checkpoint = torch.load(args.load)
agent.load_state_dict(checkpoint['agent'])
start_epoch = checkpoint['epoch'] + 1
print 'loaded agent from', args.load
if args.parallel:
agent = nn.DataParallel(agent)
rnet = nn.DataParallel(rnet)
rnet.eval().cuda()
agent.cuda()
optimizer = optim.Adam(agent.parameters(), lr=args.lr, weight_decay=args.wd)
configure(args.cv_dir+'/log', flush_secs=5)
lr_scheduler = utils.LrScheduler(optimizer, args.lr, args.lr_decay_ratio, args.epoch_step)
for epoch in range(start_epoch, start_epoch+args.max_epochs+1):
lr_scheduler.adjust_learning_rate(epoch)
if args.cl_step < num_blocks:
args.cl_step = 1 + 1 * (epoch // 1)
else:
args.cl_step = num_blocks
print 'training the last %d blocks ...' % args.cl_step
train(epoch)
if epoch % 10 == 0:
test(epoch)
================================================
FILE: finetune.py
================================================
import os
from tensorboard_logger import configure, log_value
import torch
import torch.autograd as autograd
from torch.autograd import Variable
import torch.utils.data as torchdata
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import tqdm
import utils
import torch.optim as optim
from torch.distributions import Bernoulli
import torch.backends.cudnn as cudnn
cudnn.benchmark = True
import argparse
parser = argparse.ArgumentParser(description='BlockDrop Training')
parser.add_argument('--lr', type=float, default=1e-4, help='learning rate')
parser.add_argument('--wd', type=float, default=0.0, help='weight decay')
parser.add_argument('--model', default='R110_C10', help='R<depth>_<dataset> see utils.py for a list of configurations')
parser.add_argument('--data_dir', default='data/', help='data directory')
parser.add_argument('--load', default=None, help='checkpoint to load rnet+agent from')
parser.add_argument('--pretrained', default=None, help='pretrained policy model checkpoint (from curriculum training)')
parser.add_argument('--cv_dir', default='cv/tmp/', help='checkpoint directory (models and logs are saved here)')
parser.add_argument('--batch_size', type=int, default=256, help='batch size')
parser.add_argument('--epoch_step', type=int, default=1600, help='epochs after which lr is decayed')
parser.add_argument('--max_epochs', type=int, default=2000, help='total epochs to run')
parser.add_argument('--lr_decay_ratio', type=float, default=0.1, help='lr *= lr_decay_ratio after epoch_steps')
parser.add_argument('--parallel', action ='store_true', default=False, help='use multiple GPUs for training')
# parser.add_argument('--joint', action ='store_true', default=True, help='train both the policy network and the resnet')
parser.add_argument('--penalty', type=float, default=-5, help='gamma: reward for incorrect predictions')
parser.add_argument('--alpha', type=float, default=0.8, help='probability bounding factor')
args = parser.parse_args()
if not os.path.exists(args.cv_dir):
    os.makedirs(args.cv_dir)  # also creates missing parent directories
utils.save_args(__file__, args)
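def get_reward(preds, targets, policy):
    # fraction of residual blocks that each sample's policy keeps
    block_use = policy.sum(1).float()/policy.size(1)
    # reward for correct predictions: higher when fewer blocks are used
    sparse_reward = 1.0-block_use**2
    _, pred_idx = preds.max(1)
    match = (pred_idx==targets).data
    reward = sparse_reward
    # incorrect predictions receive the fixed penalty (gamma) instead
    reward[1-match] = args.penalty
    reward = reward.unsqueeze(1)
    return reward, match.float()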
def train(epoch):
agent.train()
rnet.train()
matches, rewards, policies = [], [], []
for batch_idx, (inputs, targets) in tqdm.tqdm(enumerate(trainloader), total=len(trainloader)):
inputs, targets = Variable(inputs), Variable(targets).cuda(async=True)
if not args.parallel:
inputs = inputs.cuda()
probs, value = agent(inputs)
#---------------------------------------------------------------------#
policy_map = probs.data.clone()
policy_map[policy_map<0.5] = 0.0
policy_map[policy_map>=0.5] = 1.0
policy_map = Variable(policy_map)
probs = probs*args.alpha + (1-probs)*(1-args.alpha)
distr = Bernoulli(probs)
policy = distr.sample()
v_inputs = Variable(inputs.data, volatile=True)
preds_map = rnet.forward(v_inputs, policy_map)
preds_sample = rnet.forward(inputs, policy)
reward_map, _ = get_reward(preds_map, targets, policy_map.data)
reward_sample, match = get_reward(preds_sample, targets, policy.data)
advantage = reward_sample - reward_map
# advantage = advantage.expand_as(policy)
loss = -distr.log_prob(policy).sum(1, keepdim=True) * Variable(advantage)
loss = loss.sum()
#---------------------------------------------------------------------#
loss += F.cross_entropy(preds_sample, targets)
optimizer.zero_grad()
loss.backward()
optimizer.step()
matches.append(match.cpu())
rewards.append(reward_sample.cpu())
policies.append(policy.data.cpu())
accuracy, reward, sparsity, variance, policy_set = utils.performance_stats(policies, rewards, matches)
log_str = 'E: %d | A: %.3f | R: %.2E | S: %.3f | V: %.3f | #: %d'%(epoch, accuracy, reward, sparsity, variance, len(policy_set))
print log_str
log_value('train_accuracy', accuracy, epoch)
log_value('train_reward', reward, epoch)
log_value('train_sparsity', sparsity, epoch)
log_value('train_variance', variance, epoch)
log_value('train_unique_policies', len(policy_set), epoch)
def test(epoch):
agent.eval()
rnet.eval()
matches, rewards, policies = [], [], []
for batch_idx, (inputs, targets) in tqdm.tqdm(enumerate(testloader), total=len(testloader)):
inputs, targets = Variable(inputs, volatile=True), Variable(targets).cuda(async=True)
if not args.parallel:
inputs = inputs.cuda()
probs, _ = agent(inputs)
policy = probs.data.clone()
policy[policy<0.5] = 0.0
policy[policy>=0.5] = 1.0
policy = Variable(policy)
preds = rnet.forward(inputs, policy)
reward, match = get_reward(preds, targets, policy.data)
matches.append(match)
rewards.append(reward)
policies.append(policy.data)
accuracy, reward, sparsity, variance, policy_set = utils.performance_stats(policies, rewards, matches)
log_str = 'TS - A: %.3f | R: %.2E | S: %.3f | V: %.3f | #: %d'%(accuracy, reward, sparsity, variance, len(policy_set))
print log_str
log_value('test_accuracy', accuracy, epoch)
log_value('test_reward', reward, epoch)
log_value('test_sparsity', sparsity, epoch)
log_value('test_variance', variance, epoch)
log_value('test_unique_policies', len(policy_set), epoch)
# save the model
agent_state_dict = agent.module.state_dict() if args.parallel else agent.state_dict()
rnet_state_dict = rnet.module.state_dict() if args.parallel else rnet.state_dict()
state = {
'agent': agent_state_dict,
'resnet': rnet_state_dict,
'epoch': epoch,
'reward': reward,
'acc': accuracy
}
torch.save(state, args.cv_dir+'/ckpt_E_%d_A_%.3f_R_%.2E_S_%.2f_#_%d.t7'%(epoch, accuracy, reward, sparsity, len(policy_set)))
#--------------------------------------------------------------------------------------------------------#
trainset, testset = utils.get_dataset(args.model, args.data_dir)
trainloader = torchdata.DataLoader(trainset, batch_size=args.batch_size, shuffle=True, num_workers=4)
testloader = torchdata.DataLoader(testset, batch_size=args.batch_size, shuffle=False, num_workers=4)
rnet, agent = utils.get_model(args.model)
if args.pretrained is not None:
checkpoint = torch.load(args.pretrained)
key = 'net' if 'net' in checkpoint else 'agent'
agent.load_state_dict(checkpoint[key])
print 'loaded pretrained model from', args.pretrained
start_epoch = 0
if args.load is not None:
checkpoint = torch.load(args.load)
rnet.load_state_dict(checkpoint['resnet'])
agent.load_state_dict(checkpoint['agent'])
start_epoch = checkpoint['epoch'] + 1
print 'loaded agent from', args.load
if args.parallel:
agent = nn.DataParallel(agent)
rnet = nn.DataParallel(rnet)
rnet.cuda()
agent.cuda()
optimizer = optim.Adam(list(agent.parameters())+list(rnet.parameters()), lr=args.lr, weight_decay=args.wd)
configure(args.cv_dir+'/log', flush_secs=5)
lr_scheduler = utils.LrScheduler(optimizer, args.lr, args.lr_decay_ratio, args.epoch_step)
for epoch in range(start_epoch, start_epoch+args.max_epochs+1):
lr_scheduler.adjust_learning_rate(epoch)
train(epoch)
if epoch%10==0:
test(epoch)
================================================
FILE: models/__init__.py
================================================
================================================
FILE: models/base.py
================================================
import torch.nn as nn
import math
import torch
import torchvision.models as torchmodels
import re
from torch.autograd import Variable
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import random
import torch.nn.utils as torchutils
from torch.nn import init, Parameter
class Identity(nn.Module):
def __init__(self):
super(Identity, self).__init__()
def forward(self, x):
return x
class Flatten(nn.Module):
def __init__(self):
super(Flatten, self).__init__()
def forward(self, x):
return x.view(x.size(0), -1)
def conv3x3(in_planes, out_planes, stride=1):
"3x3 convolution with padding"
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
class BasicBlock(nn.Module):
expansion = 1
def __init__(self, inplanes, planes, stride=1):
super(BasicBlock, self).__init__()
self.conv1 = conv3x3(inplanes, planes, stride)
self.bn1 = nn.BatchNorm2d(planes)
self.conv2 = conv3x3(planes, planes)
self.bn2 = nn.BatchNorm2d(planes)
def forward(self, x):
out = self.conv1(x)
out = self.bn1(out)
out = F.relu(out)
out = self.conv2(out)
out = self.bn2(out)
return out
class Bottleneck(nn.Module):
expansion = 4
def __init__(self, inplanes, planes, stride=1):
super(Bottleneck, self).__init__()
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
self.bn1 = nn.BatchNorm2d(planes)
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(planes)
self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
self.bn3 = nn.BatchNorm2d(planes * 4)
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
return out
class DownsampleB(nn.Module):
def __init__(self, nIn, nOut, stride):
super(DownsampleB, self).__init__()
self.avg = nn.AvgPool2d(stride)
self.expand_ratio = nOut // nIn
def forward(self, x):
x = self.avg(x)
return torch.cat([x] + [x.mul(0)] * (self.expand_ratio - 1), 1)
================================================
FILE: models/resnet.py
================================================
import torch.nn as nn
import math
import torch
import torchvision.models as torchmodels
import re
from torch.autograd import Variable
import torch.nn.functional as F
import numpy as np
import random
import torch.nn.init as torchinit
import math
from torch.nn import init, Parameter
import copy
from models import base
import utils
#--------------------------------------------------------------------------------------------------#
class FlatResNet(nn.Module):

    def seed(self, x):
        # x = self.relu(self.bn1(self.conv1(x))) -- CIFAR
        # x = self.maxpool(self.relu(self.bn1(self.conv1(x)))) -- ImageNet
        raise NotImplementedError

    # run a variable policy batch through the resnet implemented as a full mask over the residual
    # fast to train, non-indicative of time saving (use forward_single instead)
    def forward(self, x, policy):
        x = self.seed(x)
        t = 0
        for segment, num_blocks in enumerate(self.layer_config):
            for b in range(num_blocks):
                action = policy[:,t].contiguous()
                residual = self.ds[segment](x) if b==0 else x
                # early termination if all actions in the batch are zero
                if action.data.sum() == 0:
                    x = residual
                    t += 1
                    continue
                action_mask = action.float().view(-1,1,1,1)
                fx = F.relu(residual + self.blocks[segment][b](x))
                x = fx*action_mask + residual*(1-action_mask)
                t += 1
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

    # run a single, fixed policy for all items in the batch
    # policy is a 1-D vector with one entry per residual block (length sum(self.layer_config))
    # use with batch_size=1 for profiling
    def forward_single(self, x, policy):
        x = self.seed(x)
        t = 0
        for segment, num_blocks in enumerate(self.layer_config):
            for b in range(num_blocks):
                residual = self.ds[segment](x) if b==0 else x
                if policy[t]==1:
                    x = residual + self.blocks[segment][b](x)
                    x = F.relu(x)
                else:
                    x = residual
                t += 1
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

    # run every block (no policy); used by the policy networks as a feature extractor
    def forward_full(self, x):
        x = self.seed(x)
        for segment, num_blocks in enumerate(self.layer_config):
            for b in range(num_blocks):
                residual = self.ds[segment](x) if b==0 else x
                x = F.relu(residual + self.blocks[segment][b](x))
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x
# Smaller Flattened Resnet, tailored for CIFAR
class FlatResNet32(FlatResNet):
def __init__(self, block, layers, num_classes=10):
super(FlatResNet32, self).__init__()
self.inplanes = 16
self.conv1 = base.conv3x3(3, 16)
self.bn1 = nn.BatchNorm2d(16)
self.relu = nn.ReLU(inplace=True)
self.avgpool = nn.AvgPool2d(8)
strides = [1, 2, 2]
filt_sizes = [16, 32, 64]
self.blocks, self.ds = [], []
for idx, (filt_size, num_blocks, stride) in enumerate(zip(filt_sizes, layers, strides)):
blocks, ds = self._make_layer(block, filt_size, num_blocks, stride=stride)
self.blocks.append(nn.ModuleList(blocks))
self.ds.append(ds)
self.blocks = nn.ModuleList(self.blocks)
self.ds = nn.ModuleList(self.ds)
self.fc = nn.Linear(64 * block.expansion, num_classes)
self.fc_dim = 64 * block.expansion
self.layer_config = layers
for m in self.modules():
if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2. / n))
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
def seed(self, x):
x = self.relu(self.bn1(self.conv1(x)))
return x
def _make_layer(self, block, planes, blocks, stride=1):
downsample = nn.Sequential()
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = base.DownsampleB(self.inplanes, planes * block.expansion, stride)
layers = [block(self.inplanes, planes, stride)]
self.inplanes = planes * block.expansion
for i in range(1, blocks):
layers.append(block(self.inplanes, planes, 1))
return layers, downsample
# Regular Flattened Resnet, tailored for Imagenet etc.
class FlatResNet224(FlatResNet):
def __init__(self, block, layers, num_classes=1000):
self.inplanes = 64
super(FlatResNet224, self).__init__()
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
self.bn1 = nn.BatchNorm2d(64)
self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
strides = [1, 2, 2, 2]
filt_sizes = [64, 128, 256, 512]
self.blocks, self.ds = [], []
for idx, (filt_size, num_blocks, stride) in enumerate(zip(filt_sizes, layers, strides)):
blocks, ds = self._make_layer(block, filt_size, num_blocks, stride=stride)
self.blocks.append(nn.ModuleList(blocks))
self.ds.append(ds)
self.blocks = nn.ModuleList(self.blocks)
self.ds = nn.ModuleList(self.ds)
self.avgpool = nn.AvgPool2d(7)
self.fc = nn.Linear(512 * block.expansion, num_classes)
for m in self.modules():
if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2. / n))
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
self.layer_config = layers
def seed(self, x):
x = self.maxpool(self.relu(self.bn1(self.conv1(x))))
return x
def _make_layer(self, block, planes, blocks, stride=1):
downsample = nn.Sequential()
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
nn.Conv2d(self.inplanes, planes * block.expansion,
kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(planes * block.expansion),
)
layers = [block(self.inplanes, planes, stride)]
self.inplanes = planes * block.expansion
for i in range(1, blocks):
layers.append(block(self.inplanes, planes))
return layers, downsample
#---------------------------------------------------------------------------------------------------------#
# Class to generate resnetNB or any other config (default is 3B)
# removed the fc layer so it serves as a feature extractor
class Policy32(nn.Module):
def __init__(self, layer_config=[1,1,1], num_blocks=15):
super(Policy32, self).__init__()
self.features = FlatResNet32(base.BasicBlock, layer_config, num_classes=10)
self.feat_dim = self.features.fc.weight.data.shape[1]
self.features.fc = nn.Sequential()
self.logit = nn.Linear(self.feat_dim, num_blocks)
self.vnet = nn.Linear(self.feat_dim, 1)
def load_state_dict(self, state_dict):
# support legacy models
state_dict = {k:v for k,v in state_dict.items() if not k.startswith('features.fc')}
return super(Policy32, self).load_state_dict(state_dict)
def forward(self, x):
x = self.features.forward_full(x)
value = self.vnet(x)
probs = F.sigmoid(self.logit(x))
return probs, value
class Policy224(nn.Module):
def __init__(self, layer_config=[1,1,1,1], num_blocks=16):
super(Policy224, self).__init__()
self.features = FlatResNet224(base.BasicBlock, layer_config, num_classes=1000)
resnet18 = torchmodels.resnet18(pretrained=True)
utils.load_weights_to_flatresnet(resnet18, self.features)
self.features.avgpool = nn.AvgPool2d(4)
self.feat_dim = self.features.fc.weight.data.shape[1]
self.features.fc = nn.Sequential()
self.logit = nn.Linear(self.feat_dim, num_blocks)
self.vnet = nn.Linear(self.feat_dim, 1)
def load_state_dict(self, state_dict):
# support legacy models
state_dict = {k:v for k,v in state_dict.items() if not k.startswith('features.fc')}
return super(Policy224, self).load_state_dict(state_dict)
def forward(self, x):
x = F.avg_pool2d(x, 2)
x = self.features.forward_full(x)
value = self.vnet(x)
probs = F.sigmoid(self.logit(x))
return probs, value
#--------------------------------------------------------------------------------------------------------#
class StepResnet32(FlatResNet32):
def __init__(self, block, layers, num_classes, joint=False):
super(StepResnet, self).__init__(block, layers, num_classes)
self.eval() # default to eval mode
self.joint = joint
self.state_ptr = {}
t = 0
for segment, num_blocks in enumerate(self.layer_config):
for b in range(num_blocks):
self.state_ptr[t] = (segment, b)
t += 1
def seed(self, x):
self.state = self.relu(self.bn1(self.conv1(x)))
self.t = 0
if self.joint:
return self.state
return Variable(self.state.data)
def step(self, action):
segment, b = self.state_ptr[self.t]
residual = self.ds[segment](self.state) if b==0 else self.state
action_mask = action.float().view(-1,1,1,1)
fx = F.relu(residual + self.blocks[segment][b](self.state))
self.state = fx*action_mask + residual*(1-action_mask)
self.t += 1
if self.joint:
return self.state
return Variable(self.state.data)
def step_single(self, action):
segment, b = self.state_ptr[self.t]
residual = self.ds[segment](self.state) if b==0 else self.state
if action.data[0,0]==1:
self.state = F.relu(residual + self.blocks[segment][b](self.state))
else:
self.state = residual
self.t += 1
if self.joint:
return self.state
return Variable(self.state.data)
def predict(self):
x = self.avgpool(self.state)
x = x.view(x.size(0), -1)
x = self.fc(x)
return x
class StepPolicy32(nn.Module):

    def __init__(self, layer_config):
        super(StepPolicy32, self).__init__()
        # channel width of the state seen at each timestep: the seed state,
        # then the output of every block except the last
        in_dim = [16] + [16]*layer_config[0] + [32]*layer_config[1] + [64]*(layer_config[2]-1)
        self.pnet = nn.ModuleList([nn.Linear(dim, 2) for dim in in_dim])
        self.vnet = nn.ModuleList([nn.Linear(dim, 1) for dim in in_dim])

    def forward(self, state):
        x, t = state
        x = F.avg_pool2d(x, x.size(2)).view(x.size(0), -1) # pool + flatten --> (B, 16/32/64)
        probs = F.softmax(self.pnet[t](x), dim=1) # action probabilities over {drop, keep}
        value = self.vnet[t](x)
        return probs, value
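
The batched `step` in `StepResnet32` blends each sample's block output with its shortcut through a 0/1 action mask, while `step_single` branches on the action instead. The following is a minimal pure-Python sketch of why the two agree, using scalars as stand-ins for feature maps. The names `masked_step`, `branch_step`, and the toy `block` function are illustrative assumptions and do not exist in this repository.

```python
def relu(v):
    return max(0.0, v)

def block(v):
    # toy stand-in for the residual function of one block (assumption)
    return 2.0 * v - 1.0

def masked_step(states, actions):
    # batched form: always evaluate the block, then select with the mask
    out = []
    for s, a in zip(states, actions):
        fx = relu(s + block(s))
        out.append(fx * a + s * (1 - a))
    return out

def branch_step(states, actions):
    # per-sample form: evaluate the block only when the action is 1
    return [relu(s + block(s)) if a == 1 else s for s, a in zip(states, actions)]

states = [0.5, 1.0, 2.0]
actions = [1, 0, 1]
assert masked_step(states, actions) == branch_step(states, actions)
```

The masked form always pays for the block, so it saves no computation; it only keeps batched training simple. The branching form is the one that actually skips work at inference time.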
================================================
FILE: requirements.txt
================================================
tqdm
numpy
tensorboard_logger
argparse
================================================
FILE: test.py
================================================
import torch
from torch.autograd import Variable
import torch.utils.data as torchdata
import torch.nn as nn
import numpy as np
import tqdm
import utils
import torch.backends.cudnn as cudnn
cudnn.benchmark = True
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('--model', default='R110_C10')
parser.add_argument('--data_dir', default='data/')
parser.add_argument('--load', default=None)
args = parser.parse_args()
#---------------------------------------------------------------------------------------------#
class FConv2d(nn.Conv2d):

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True):
        super(FConv2d, self).__init__(in_channels, out_channels, kernel_size, stride,
                                      padding, dilation, groups, bias)
        self.num_ops = 0

    def forward(self, x):
        output = super(FConv2d, self).forward(x)
        # one multiply and one add per weight per output location
        # (counted per image; groups and bias are not accounted for)
        output_area = output.size(-1)*output.size(-2)
        filter_area = np.prod(self.kernel_size)
        self.num_ops += 2*self.in_channels*self.out_channels*filter_area*output_area
        return output

class FLinear(nn.Linear):

    def __init__(self, in_features, out_features, bias=True):
        super(FLinear, self).__init__(in_features, out_features, bias)
        self.num_ops = 0

    def forward(self, x):
        output = super(FLinear, self).forward(x)
        self.num_ops += 2*self.in_features*self.out_features
        return output

def count_flops(model, reset=True):
    op_count = 0
    for m in model.modules():
        if hasattr(m, 'num_ops'):
            op_count += m.num_ops
            if reset: # count and reset to 0
                m.num_ops = 0
    return op_count

# replace all nn.Conv and nn.Linear layers with layers that count flops
# (must run before the models are constructed)
nn.Conv2d = FConv2d
nn.Linear = FLinear
#--------------------------------------------------------------------------------------------#
def test():

    total_ops = []
    matches, policies = [], []
    for batch_idx, (inputs, targets) in tqdm.tqdm(enumerate(testloader), total=len(testloader)):

        inputs, targets = Variable(inputs, volatile=True).cuda(), Variable(targets).cuda()
        probs, _ = agent(inputs)

        # binarize the policy: keep a block iff its probability is >= 0.5
        policy = probs.clone()
        policy[policy<0.5] = 0.0
        policy[policy>=0.5] = 1.0

        preds = rnet.forward_single(inputs, policy.data.squeeze(0))
        _ , pred_idx = preds.max(1)
        match = (pred_idx==targets).data.float()

        matches.append(match)
        policies.append(policy.data)

        # FLOPs for this image: policy network + the executed ResNet blocks
        ops = count_flops(agent) + count_flops(rnet)
        total_ops.append(ops)

    accuracy, _, sparsity, variance, policy_set = utils.performance_stats(policies, matches, matches)
    ops_mean, ops_std = np.mean(total_ops), np.std(total_ops)

    log_str = u'''
    Accuracy: %.3f
    Block Usage: %.3f \u00B1 %.3f
    FLOPs/img: %.2E \u00B1 %.2E
    Unique Policies: %d
    '''%(accuracy, sparsity, variance, ops_mean, ops_std, len(policy_set))

    print log_str
#--------------------------------------------------------------------------------------------------------#
trainset, testset = utils.get_dataset(args.model, args.data_dir)
testloader = torchdata.DataLoader(testset, batch_size=1, shuffle=False, num_workers=4)
rnet, agent = utils.get_model(args.model)
# if no model is loaded, use all blocks
agent.logit.weight.data.fill_(0)
agent.logit.bias.data.fill_(10)
print "loading checkpoints"
if args.load is not None:
    utils.load_checkpoint(rnet, agent, args.load)
rnet.eval().cuda()
agent.eval().cuda()
test()
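
The counters in `FConv2d` and `FLinear` charge two operations (one multiply, one add) per weight per output element. The following is a small standalone sketch of that arithmetic, with no PyTorch dependency. The helper names `conv2d_flops` and `linear_flops` are hypothetical and are not defined anywhere in this repository.

```python
def conv2d_flops(in_channels, out_channels, kernel_size, out_h, out_w):
    # same expression as FConv2d.forward: 2 * C_in * C_out * filter_area * output_area
    kh, kw = kernel_size
    return 2 * in_channels * out_channels * kh * kw * out_h * out_w

def linear_flops(in_features, out_features):
    # same expression as FLinear.forward: 2 * in_features * out_features
    return 2 * in_features * out_features

# a 3x3 convolution, 16 -> 16 channels, producing a 32x32 output map
print(conv2d_flops(16, 16, (3, 3), 32, 32))  # 4718592
# a 64 -> 10 fully connected classifier
print(linear_flops(64, 10))                  # 1280
```

Because the test loader uses `batch_size=1`, the accumulated count per iteration corresponds to a single image, which is what the `FLOPs/img` line in the log reports.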
================================================
FILE: utils.py
================================================
import os
import re
import torch
import torchvision.transforms as transforms
import torchvision.datasets as torchdata
import numpy as np
# Save the training script and all the arguments to a file so that you
# don't feel like an idiot later when you can't replicate results
import shutil
def save_args(__file__, args):
    shutil.copy(os.path.basename(__file__), args.cv_dir)
    with open(args.cv_dir+'/args.txt','w') as f:
        f.write(str(args))
def performance_stats(policies, rewards, matches):

    policies = torch.cat(policies, 0)
    rewards = torch.cat(rewards, 0)
    accuracy = torch.cat(matches, 0).mean()

    reward = rewards.mean()
    sparsity = policies.sum(1).mean() # mean number of blocks used per image
    variance = policies.sum(1).std() # std of the number of blocks used (despite the name)

    # encode each policy as a string of 0/1 digits to count unique policies
    policy_set = [p.cpu().numpy().astype(int).astype(str) for p in policies]
    policy_set = set([''.join(p) for p in policy_set])

    return accuracy, reward, sparsity, variance, policy_set
class LrScheduler:

    def __init__(self, optimizer, base_lr, lr_decay_ratio, epoch_step):
        self.base_lr = base_lr
        self.lr_decay_ratio = lr_decay_ratio
        self.epoch_step = epoch_step
        self.optimizer = optimizer

    def adjust_learning_rate(self, epoch):
        """Sets the learning rate to base_lr scaled by lr_decay_ratio once every epoch_step epochs"""
        lr = self.base_lr * (self.lr_decay_ratio ** (epoch // self.epoch_step))
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr
            if epoch%self.epoch_step==0:
                print '# setting learning_rate to %.2E'%lr
# load model weights trained using scripts from https://github.com/felixgwu/img_classification_pk_pytorch OR
# from torchvision models into our flattened resnets
def load_weights_to_flatresnet(source_model, target_model):

    # compatibility for nn.Modules + checkpoints
    if hasattr(source_model, 'state_dict'):
        source_model = {'state_dict': source_model.state_dict()}
    source_state = source_model['state_dict']
    target_state = target_model.state_dict()

    # remove the module. prefix if it exists (thanks nn.DataParallel)
    if list(source_state.keys())[0].startswith('module.'):
        source_state = {k[7:]:v for k,v in source_state.items()}

    common = set(['conv1.weight', 'bn1.weight', 'bn1.bias', 'bn1.running_mean', 'bn1.running_var','fc.weight', 'fc.bias'])
    for key in source_state.keys():

        if key in common:
            target_state[key] = source_state[key]
            continue

        # layerN.* keys map to blocks.(N-1).* and downsample keys map to ds.(N-1).*
        if 'downsample' in key:
            layer, num, item = re.match(r'layer(\d+).*\.(\d+)\.(.*)', key).groups()
            translated = 'ds.%s.%s.%s'%(int(layer)-1, num, item)
        else:
            layer, item = re.match(r'layer(\d+)\.(.*)', key).groups()
            translated = 'blocks.%s.%s'%(int(layer)-1, item)

        if translated in target_state.keys():
            target_state[translated] = source_state[key]
        else:
            print translated, 'block missing'

    target_model.load_state_dict(target_state)
    return target_model
def load_checkpoint(rnet, agent, load):
    if load=='nil':
        return None

    checkpoint = torch.load(load)
    if 'resnet' in checkpoint:
        rnet.load_state_dict(checkpoint['resnet'])
        print 'loaded resnet from', os.path.basename(load)
    if 'agent' in checkpoint:
        agent.load_state_dict(checkpoint['agent'])
        print 'loaded agent from', os.path.basename(load)
    # backward compatibility (some old checkpoints)
    if 'net' in checkpoint:
        checkpoint['net'] = {k:v for k,v in checkpoint['net'].items() if 'features.fc' not in k}
        agent.load_state_dict(checkpoint['net'])
        print 'loaded agent from', os.path.basename(load)
def get_transforms(rnet, dset):

    # Only the R32 pretrained model subtracts the mean, sorry :(
    if dset=='C10' and rnet=='R32':
        mean = [x/255.0 for x in [125.3, 123.0, 113.9]]
        std = [x/255.0 for x in [63.0, 62.1, 66.7]]
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ])
        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ])

    # all other CIFAR models use unnormalized inputs
    elif dset=='C100' or (dset=='C10' and rnet!='R32'):
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
        ])
        transform_test = transforms.Compose([
            transforms.ToTensor(),
        ])

    elif dset=='ImgNet':
        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]
        transform_train = transforms.Compose([
            transforms.Scale(256),
            transforms.RandomCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ])
        transform_test = transforms.Compose([
            transforms.Scale(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ])

    return transform_train, transform_test
# Pick from the datasets available and the hundreds of models we have lying around depending on the requirements.
def get_dataset(model, root='data/'):

    # model names look like '<resnet>_<dataset>', e.g. 'R110_C10'
    rnet, dset = model.split('_')
    transform_train, transform_test = get_transforms(rnet, dset)

    if dset=='C10':
        trainset = torchdata.CIFAR10(root=root, train=True, download=True, transform=transform_train)
        testset = torchdata.CIFAR10(root=root, train=False, download=True, transform=transform_test)
    elif dset=='C100':
        trainset = torchdata.CIFAR100(root=root, train=True, download=True, transform=transform_train)
        testset = torchdata.CIFAR100(root=root, train=False, download=True, transform=transform_test)
    elif dset=='ImgNet':
        trainset = torchdata.ImageFolder(root+'/train/', transform_train)
        testset = torchdata.ImageFolder(root+'/val/', transform_test)

    return trainset, testset
# Make a new if statement for every new model variety you want to index
def get_model(model):

    from models import resnet, base

    if model=='R32_C10':
        rnet_checkpoint = 'cv/pretrained/R32_C10/pk_E_164_A_0.923.t7'
        layer_config = [5, 5, 5]
        rnet = resnet.FlatResNet32(base.BasicBlock, layer_config, num_classes=10)
        agent = resnet.Policy32([1,1,1], num_blocks=15)

    elif model=='R110_C10':
        rnet_checkpoint = 'cv/pretrained/R110_C10/pk_E_130_A_0.932.t7'
        layer_config = [18, 18, 18]
        rnet = resnet.FlatResNet32(base.BasicBlock, layer_config, num_classes=10)
        agent = resnet.Policy32([1,1,1], num_blocks=54)

    elif model=='R32_C100':
        rnet_checkpoint = 'cv/pretrained/R32_C100/pk_E_164_A_0.693.t7'
        layer_config = [5, 5, 5]
        rnet = resnet.FlatResNet32(base.BasicBlock, layer_config, num_classes=100)
        agent = resnet.Policy32([1,1,1], num_blocks=15)

    elif model=='R110_C100':
        rnet_checkpoint = 'cv/pretrained/R110_C100/pk_E_160_A_0.723.t7'
        layer_config = [18, 18, 18]
        rnet = resnet.FlatResNet32(base.BasicBlock, layer_config, num_classes=100)
        agent = resnet.Policy32([1,1,1], num_blocks=54)

    elif model=='R101_ImgNet':
        rnet_checkpoint = 'cv/pretrained/R101_ImgNet/ImageNet_R101_224_76.464'
        layer_config = [3,4,23,3]
        rnet = resnet.FlatResNet224(base.Bottleneck, layer_config, num_classes=1000)
        agent = resnet.Policy224([1,1,1,1], num_blocks=33)

    # load pretrained weights into flat ResNet
    rnet_checkpoint = torch.load(rnet_checkpoint)
    load_weights_to_flatresnet(rnet_checkpoint, rnet)

    return rnet, agent
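
`load_weights_to_flatresnet` renames parameters from a nested `layerN.*` layout into the flat `blocks.*` / `ds.*` layout. The sketch below isolates that renaming so it can be tried without PyTorch. The helper name `translate_key` is hypothetical, and the two example keys assume a torchvision-style ResNet state dict; they are not taken from the checkpoints shipped with this repository.

```python
import re

def translate_key(key):
    # same regular expressions as load_weights_to_flatresnet
    if 'downsample' in key:
        layer, num, item = re.match(r'layer(\d+).*\.(\d+)\.(.*)', key).groups()
        return 'ds.%s.%s.%s' % (int(layer) - 1, num, item)
    layer, item = re.match(r'layer(\d+)\.(.*)', key).groups()
    return 'blocks.%s.%s' % (int(layer) - 1, item)

print(translate_key('layer1.0.conv1.weight'))         # blocks.0.0.conv1.weight
print(translate_key('layer2.0.downsample.0.weight'))  # ds.1.0.weight
```

In the downsample case the greedy `.*` makes the captured number the index inside the downsample module rather than the block index, which is why only one shortcut per segment is addressed.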
SYMBOL INDEX (70 symbols across 6 files)
FILE: cl_training.py
function get_reward (line 43) | def get_reward(preds, targets, policy):
function train (line 58) | def train(epoch):
function test (line 136) | def test(epoch):
FILE: finetune.py
function get_reward (line 42) | def get_reward(preds, targets, policy):
function train (line 57) | def train(epoch):
function test (line 117) | def test(epoch):
FILE: models/base.py
class Identity (line 15) | class Identity(nn.Module):
method __init__ (line 16) | def __init__(self):
method forward (line 18) | def forward(self, x):
class Flatten (line 21) | class Flatten(nn.Module):
method __init__ (line 22) | def __init__(self):
method forward (line 24) | def forward(self, x):
function conv3x3 (line 27) | def conv3x3(in_planes, out_planes, stride=1):
class BasicBlock (line 31) | class BasicBlock(nn.Module):
method __init__ (line 34) | def __init__(self, inplanes, planes, stride=1):
method forward (line 41) | def forward(self, x):
class Bottleneck (line 52) | class Bottleneck(nn.Module):
method __init__ (line 55) | def __init__(self, inplanes, planes, stride=1):
method forward (line 65) | def forward(self, x):
class DownsampleB (line 81) | class DownsampleB(nn.Module):
method __init__ (line 83) | def __init__(self, nIn, nOut, stride):
method forward (line 88) | def forward(self, x):
FILE: models/resnet.py
class FlatResNet (line 19) | class FlatResNet(nn.Module):
method seed (line 21) | def seed(self, x):
method forward (line 28) | def forward(self, x, policy):
method forward_single (line 56) | def forward_single(self, x, policy):
method forward_full (line 76) | def forward_full(self, x):
class FlatResNet32 (line 92) | class FlatResNet32(FlatResNet):
method __init__ (line 94) | def __init__(self, block, layers, num_classes=10):
method seed (line 126) | def seed(self, x):
method _make_layer (line 130) | def _make_layer(self, block, planes, blocks, stride=1):
class FlatResNet224 (line 145) | class FlatResNet224(FlatResNet):
method __init__ (line 147) | def __init__(self, block, layers, num_classes=1000):
method seed (line 179) | def seed(self, x):
method _make_layer (line 183) | def _make_layer(self, block, planes, blocks, stride=1):
class Policy32 (line 205) | class Policy32(nn.Module):
method __init__ (line 207) | def __init__(self, layer_config=[1,1,1], num_blocks=15):
method load_state_dict (line 216) | def load_state_dict(self, state_dict):
method forward (line 222) | def forward(self, x):
class Policy224 (line 229) | class Policy224(nn.Module):
method __init__ (line 231) | def __init__(self, layer_config=[1,1,1,1], num_blocks=16):
method load_state_dict (line 247) | def load_state_dict(self, state_dict):
method forward (line 252) | def forward(self, x):
class StepResnet32 (line 261) | class StepResnet32(FlatResNet32):
method __init__ (line 263) | def __init__(self, block, layers, num_classes, joint=False):
method seed (line 276) | def seed(self, x):
method step (line 284) | def step(self, action):
method step_single (line 298) | def step_single(self, action):
method predict (line 313) | def predict(self):
class StepPolicy32 (line 320) | class StepPolicy32(nn.Module):
method __init__ (line 322) | def __init__(self, layer_config):
method forward (line 328) | def forward(self, state):
FILE: test.py
class FConv2d (line 20) | class FConv2d(nn.Conv2d):
method __init__ (line 22) | def __init__(self, in_channels, out_channels, kernel_size, stride=1,
method forward (line 28) | def forward(self, x):
class FLinear (line 35) | class FLinear(nn.Linear):
method __init__ (line 36) | def __init__(self, in_features, out_features, bias=True):
method forward (line 40) | def forward(self, x):
function count_flops (line 45) | def count_flops(model, reset=True):
function test (line 61) | def test():
FILE: utils.py
function save_args (line 11) | def save_args(__file__, args):
function performance_stats (line 16) | def performance_stats(policies, rewards, matches):
class LrScheduler (line 31) | class LrScheduler:
method __init__ (line 32) | def __init__(self, optimizer, base_lr, lr_decay_ratio, epoch_step):
method adjust_learning_rate (line 38) | def adjust_learning_rate(self, epoch):
function load_weights_to_flatresnet (line 49) | def load_weights_to_flatresnet(source_model, target_model):
function load_checkpoint (line 85) | def load_checkpoint(rnet, agent, load):
function get_transforms (line 103) | def get_transforms(rnet, dset):
function get_dataset (line 154) | def get_dataset(model, root='data/'):
function get_model (line 172) | def get_model(model):