Repository: Tushar-N/blockdrop
Branch: master
Commit: ec52b36d38dc
Files: 10
Total size: 46.6 KB

Directory structure:
gitextract_y_xp5prk/
├── .gitignore
├── Readme.md
├── cl_training.py
├── finetune.py
├── models/
│   ├── __init__.py
│   ├── base.py
│   └── resnet.py
├── requirements.txt
├── test.py
└── utils.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
*.*~
*.pyc
*.pkl
*.h5
*.t7
*.t7*
*.log
*.png
data/
cv/

================================================
FILE: Readme.md
================================================
# BlockDrop: Dynamic Inference Paths in Residual Networks

![BlockDrop Model](https://user-images.githubusercontent.com/4995097/35775877-3cc64f86-0957-11e8-85c4-9bd16cda22a0.png)

This code implements a policy network that learns to dynamically choose which blocks of a ResNet to execute during inference so as to best reduce total computation without degrading prediction accuracy. Built upon a ResNet-101 model, our method achieves a speedup of 20% on average, going as high as 36% for some images, while maintaining the same 76.4% top-1 accuracy on ImageNet.

This is the code accompanying the work:
Zuxuan Wu*, Tushar Nagarajan*, Abhishek Kumar, Steven Rennie, Larry S. Davis, Kristen Grauman, and Rogerio Feris. BlockDrop: Dynamic Inference Paths in Residual Networks [[arxiv]](https://arxiv.org/pdf/1711.08393.pdf) (* authors contributed equally)

## Prerequisites
The code is written and tested using Python (2.7) and PyTorch (v0.3.0).

**Packages**: Install using `pip install -r requirements.txt`

**Pretrained models**: Our models require standard pretrained ResNets on CIFAR and ImageNet as starting points.
These can be trained using [this](https://github.com/felixgwu/img_classification_pk_pytorch) repository, or can be obtained directly from us:

```bash
wget -O blockdrop-checkpoints.tar.gz https://utexas.box.com/shared/static/ok98i51v14c0q9lvs1z5g71m6b3zm8sj.gz
tar -zxvf blockdrop-checkpoints.tar.gz
```
The downloaded checkpoints will be unpacked to `./cv/` for further use. The folder also contains various checkpoints from each stage of training.

**Datasets**: PyTorch's *torchvision* package automatically downloads CIFAR10 and CIFAR100 during training. ImageNet must be downloaded and organized following [these steps](https://github.com/soumith/imagenet-multiGPU.torch#data-processing).

## Training a model
Training occurs in two steps: (1) Curriculum Learning and (2) Joint Finetuning.
Models operating on ResNets of different depths can be trained on different datasets using the same script. Examples of how to train these models are given below. Checkpoints and tensorboard log files will be saved to the folder specified by `--cv_dir`.

#### Curriculum Learning
The policy network can be trained using a CL schedule as follows.

```bash
# Train a model on CIFAR 10 built upon a ResNet-110
python cl_training.py --model R110_C10 --cv_dir cv/R110_C10_cl/ --lr 1e-3 --batch_size 2048 --max_epochs 5000

# Train a model on ImageNet built upon a ResNet-101
python cl_training.py --model R101_ImgNet --cv_dir cv/R101_ImgNet_cl/ --lr 1e-3 --batch_size 2048 --max_epochs 45 --data_dir data/imagenet/
```

Model checkpoints after the curriculum learning step can be found in the downloaded folder. For example: `./cv/cl_learning/R110_C10/ckpt_E_5300_A_0.754_R_2.22E-01_S_20.10_#_7787.t7`

#### Joint Finetuning
Checkpoints trained during the curriculum learning phase can be used to further jointly finetune the base ResNet to achieve the results reported in the paper. Different values for the penalty parameter control the trade-off between accuracy and speed.
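
The trade-off can be sketched in plain Python. This is a minimal per-image illustration of the reward computed by `get_reward` in `finetune.py`, which works on batched tensors; the name `blockdrop_reward` and its list-based signature are illustrative assumptions and are not part of this repository.

```python
def blockdrop_reward(policy, correct, penalty):
    """Per-image reward, mirroring get_reward in finetune.py (hypothetical helper).

    policy  -- list of 0/1 flags, one per residual block (1 = block executed)
    correct -- whether the prediction under this policy matched the label
    penalty -- value of --penalty (gamma), the reward for a wrong prediction
    """
    if not correct:
        return float(penalty)
    block_use = sum(policy) / float(len(policy))  # fraction of blocks executed
    return 1.0 - block_use ** 2                   # fewer blocks -> higher reward


# A correct prediction using half of the blocks earns 1 - 0.5**2 = 0.75 ...
print(blockdrop_reward([1, 0, 1, 0], True, -10))   # 0.75
# ... while a wrong prediction receives the penalty regardless of block usage.
print(blockdrop_reward([1, 0, 0, 0], False, -10))  # -10.0
```

Under this reward, a more negative penalty makes a wrong prediction costlier relative to the gain from dropping blocks, which should push the policy toward keeping more blocks. The penalty is set through the `--penalty` flag of `finetune.py`.
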
```bash
# Finetune a ResNet-110 on CIFAR 10 using the checkpoint from cl_training
python finetune.py --model R110_C10 --lr 1e-4 --penalty -10 --pretrained cv/cl_training/R110_C10/ckpt_E_5300_A_0.754_R_2.22E-01_S_20.10_#_7787.t7 --batch_size 256 --max_epochs 2000 --cv_dir cv/R110_C10_ft_-10/

# Finetune a ResNet-101 on ImageNet using the checkpoint from cl_training
python finetune.py --model R101_ImgNet --lr 1e-4 --penalty -5 --pretrained cv/cl_training/R101_ImgNet/ckpt_E_4_A_0.746_R_-3.70E-01_S_29.79_#_484.t7 --data_dir data/imagenet/ --batch_size 320 --max_epochs 10 --cv_dir cv/R101_ImgNet_ft_-5/
```

Model checkpoints after the joint finetuning step can be found in the downloaded folder. For example: `./cv/finetuned/R101_ImgNet_gamma_5/ckpt_E_10_A_0.764_R_-8.46E-01_S_24.77_#_10.t7`

## Testing and Profiling
Once jointly finetuned, models can be profiled for accuracy and FLOPs counts.

```bash
python test.py --model R110_C10 --load cv/finetuned/R110_C10_gamma_10/ckpt_E_2000_A_0.936_R_1.95E-01_S_16.93_#_469.t7
```

The model should produce an accuracy of 93.6% and use 1.81E+08 FLOPs on average. The output should look like this:
```
Accuracy: 0.936
Block Usage: 16.933 ± 3.717
FLOPs/img: 1.81E+08 ± 3.43E+07
Unique Policies: 469
```

The ImageNet model can be evaluated in a similar manner, and will generate a corresponding output.
```
python test.py --model R101_ImgNet --load cv/finetuned/R101_ImgNet_gamma_5/ckpt_E_10_A_0.764_R_-8.46E-01_S_24.77_#_10.t7
```
```
Accuracy: 0.764
Block Usage: 24.770 ± 0.980
FLOPs/img: 1.25E+10 ± 4.28E+08
Unique Policies: 10
```

## Visualization

Learned policies over ResNet blocks show that there is a clear separation between easy/hard images in terms of the number of blocks they require. In addition, unique policies over the blocks admit distinct image styles.

![Policy visualization](https://user-images.githubusercontent.com/4995097/35775878-3e5ee4e8-0957-11e8-832d-b9dc2ea8fecc.png)

For more qualitative results, see Sec. 4.3 and Figures 4
and 5 in the paper.

## Cite

If you find this repository useful in your own research, please consider citing:
```
@inproceedings{blockdrop,
  title={BlockDrop: Dynamic Inference Paths in Residual Networks},
  author={Wu, Zuxuan and Nagarajan, Tushar and Kumar, Abhishek and Rennie, Steven and Davis, Larry S and Grauman, Kristen and Feris, Rogerio},
  booktitle={CVPR},
  year={2018}
}
```

================================================
FILE: cl_training.py
================================================
import os
from tensorboard_logger import configure, log_value
import torch
import torch.autograd as autograd
from torch.autograd import Variable
import torch.utils.data as torchdata
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import tqdm
import utils
import torch.optim as optim
from torch.distributions import Bernoulli

import torch.backends.cudnn as cudnn
cudnn.benchmark = True

import argparse
parser = argparse.ArgumentParser(description='BlockDrop Training')
parser.add_argument('--lr', type=float, default=1e-3, help='learning rate')
parser.add_argument('--beta', type=float, default=1e-1, help='entropy multiplier')
parser.add_argument('--wd', type=float, default=0.0, help='weight decay')
parser.add_argument('--model', default='R110_C10', help='R_ see utils.py for a list of configurations')
parser.add_argument('--data_dir', default='data/', help='data directory')
parser.add_argument('--load', default=None, help='checkpoint to load agent from')
parser.add_argument('--cv_dir', default='cv/tmp/', help='checkpoint directory (models and logs are saved here)')
parser.add_argument('--batch_size', type=int, default=256, help='batch size')
parser.add_argument('--epoch_step', type=int, default=10000, help='epochs after which lr is decayed')
parser.add_argument('--max_epochs', type=int, default=10000, help='total epochs to run')
parser.add_argument('--lr_decay_ratio', type=float, default=0.1, help='lr *= lr_decay_ratio after epoch_steps')
parser.add_argument('--parallel', action ='store_true', default=False, help='use multiple GPUs for training')
parser.add_argument('--cl_step', type=int, default=1, help='steps for curriculum training')
# parser.add_argument('--joint', action ='store_true', default=True, help='train both the policy network and the resnet')
parser.add_argument('--penalty', type=float, default=-1, help='gamma: reward for incorrect predictions')
parser.add_argument('--alpha', type=float, default=0.8, help='probability bounding factor')
args = parser.parse_args()

if not os.path.exists(args.cv_dir):
    os.system('mkdir ' + args.cv_dir)
utils.save_args(__file__, args)

def get_reward(preds, targets, policy):

    block_use = policy.sum(1).float()/policy.size(1)
    sparse_reward = 1.0-block_use**2

    _, pred_idx = preds.max(1)
    match = (pred_idx==targets).data

    reward = sparse_reward
    reward[1-match] = args.penalty
    reward = reward.unsqueeze(1)

    return reward, match.float()

def train(epoch):

    agent.train()

    matches, rewards, policies = [], [], []
    for batch_idx, (inputs, targets) in tqdm.tqdm(enumerate(trainloader), total=len(trainloader)):

        inputs, targets = Variable(inputs), Variable(targets).cuda(async=True)
        if not args.parallel:
            inputs = inputs.cuda()

        probs, value = agent(inputs)

        #---------------------------------------------------------------------#

        policy_map = probs.data.clone()
        policy_map[policy_map<0.5] = 0.0
        policy_map[policy_map>=0.5] = 1.0
        policy_map = Variable(policy_map)

        probs = probs*args.alpha + (1-probs)*(1-args.alpha)
        distr = Bernoulli(probs)
        policy = distr.sample()

        if args.cl_step < num_blocks:
            policy[:, :-args.cl_step] = 1
            policy_map[:, :-args.cl_step] = 1

            policy_mask = Variable(torch.ones(inputs.size(0), policy.size(1))).cuda()
            policy_mask[:, :-args.cl_step] = 0
        else:
            policy_mask = None

        v_inputs = Variable(inputs.data, volatile=True)
        preds_map = rnet.forward(v_inputs, policy_map)
        preds_sample = rnet.forward(v_inputs, policy)

        reward_map, _ = get_reward(preds_map, targets, policy_map.data)
        reward_sample, match = get_reward(preds_sample, targets, policy.data)

        advantage = reward_sample - reward_map

        loss = -distr.log_prob(policy)
        loss = loss * Variable(advantage).expand_as(policy)

        if policy_mask is not None:
            loss = policy_mask * loss # mask for curriculum learning

        loss = loss.sum()

        probs = probs.clamp(1e-15, 1-1e-15)
        entropy_loss = -probs*torch.log(probs)
        entropy_loss = args.beta*entropy_loss.sum()

        loss = (loss - entropy_loss)/inputs.size(0)

        #---------------------------------------------------------------------#

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        matches.append(match.cpu())
        rewards.append(reward_sample.cpu())
        policies.append(policy.data.cpu())

    accuracy, reward, sparsity, variance, policy_set = utils.performance_stats(policies, rewards, matches)

    log_str = 'E: %d | A: %.3f | R: %.2E | S: %.3f | V: %.3f | #: %d'%(epoch, accuracy, reward, sparsity, variance, len(policy_set))
    print log_str

    log_value('train_accuracy', accuracy, epoch)
    log_value('train_reward', reward, epoch)
    log_value('train_sparsity', sparsity, epoch)
    log_value('train_variance', variance, epoch)
    log_value('train_unique_policies', len(policy_set), epoch)

def test(epoch):

    agent.eval()

    matches, rewards, policies = [], [], []
    for batch_idx, (inputs, targets) in tqdm.tqdm(enumerate(testloader), total=len(testloader)):

        inputs, targets = Variable(inputs, volatile=True), Variable(targets).cuda(async=True)
        if not args.parallel:
            inputs = inputs.cuda()

        probs, _ = agent(inputs)

        policy = probs.data.clone()
        policy[policy<0.5] = 0.0
        policy[policy>=0.5] = 1.0
        policy = Variable(policy)

        if args.cl_step < num_blocks:
            policy[:, :-args.cl_step] = 1

        preds = rnet.forward(inputs, policy)
        reward, match = get_reward(preds, targets, policy.data)

        matches.append(match)
        rewards.append(reward)
        policies.append(policy.data)

    accuracy, reward, sparsity, variance, policy_set = utils.performance_stats(policies, rewards, matches)

    log_str = 'TS - A: %.3f | R: %.2E | S: %.3f | V: %.3f | #: %d'%(accuracy,
        reward, sparsity, variance, len(policy_set))
    print log_str

    log_value('test_accuracy', accuracy, epoch)
    log_value('test_reward', reward, epoch)
    log_value('test_sparsity', sparsity, epoch)
    log_value('test_variance', variance, epoch)
    log_value('test_unique_policies', len(policy_set), epoch)

    # save the model
    agent_state_dict = agent.module.state_dict() if args.parallel else agent.state_dict()

    state = {
      'agent': agent_state_dict,
      'epoch': epoch,
      'reward': reward,
      'acc': accuracy
    }
    torch.save(state, args.cv_dir+'/ckpt_E_%d_A_%.3f_R_%.2E_S_%.2f_#_%d.t7'%(epoch, accuracy, reward, sparsity, len(policy_set)))

#--------------------------------------------------------------------------------------------------------#
trainset, testset = utils.get_dataset(args.model, args.data_dir)
trainloader = torchdata.DataLoader(trainset, batch_size=args.batch_size, shuffle=True, num_workers=4)
testloader = torchdata.DataLoader(testset, batch_size=args.batch_size, shuffle=False, num_workers=4)
rnet, agent = utils.get_model(args.model)
num_blocks = sum(rnet.layer_config)

start_epoch = 0
if args.load is not None:
    checkpoint = torch.load(args.load)
    agent.load_state_dict(checkpoint['agent'])
    start_epoch = checkpoint['epoch'] + 1
    print 'loaded agent from', args.load

if args.parallel:
    agent = nn.DataParallel(agent)
    rnet = nn.DataParallel(rnet)

rnet.eval().cuda()
agent.cuda()

optimizer = optim.Adam(agent.parameters(), lr=args.lr, weight_decay=args.wd)

configure(args.cv_dir+'/log', flush_secs=5)
lr_scheduler = utils.LrScheduler(optimizer, args.lr, args.lr_decay_ratio, args.epoch_step)
for epoch in range(start_epoch, start_epoch+args.max_epochs+1):
    lr_scheduler.adjust_learning_rate(epoch)

    if args.cl_step < num_blocks:
        args.cl_step = 1 + 1 * (epoch // 1)
    else:
        args.cl_step = num_blocks

    print 'training the last %d blocks ...' \
        % args.cl_step
    train(epoch)
    if epoch % 10 == 0:
        test(epoch)

================================================
FILE: finetune.py
================================================
import os
from tensorboard_logger import configure, log_value
import torch
import torch.autograd as autograd
from torch.autograd import Variable
import torch.utils.data as torchdata
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import tqdm
import utils
import torch.optim as optim
from torch.distributions import Bernoulli

import torch.backends.cudnn as cudnn
cudnn.benchmark = True

import argparse
parser = argparse.ArgumentParser(description='BlockDrop Training')
parser.add_argument('--lr', type=float, default=1e-4, help='learning rate')
parser.add_argument('--wd', type=float, default=0.0, help='weight decay')
parser.add_argument('--model', default='R110_C10', help='R_ see utils.py for a list of configurations')
parser.add_argument('--data_dir', default='data/', help='data directory')
parser.add_argument('--load', default=None, help='checkpoint to load rnet+agent from')
parser.add_argument('--pretrained', default=None, help='pretrained policy model checkpoint (from curriculum training)')
parser.add_argument('--cv_dir', default='cv/tmp/', help='checkpoint directory (models and logs are saved here)')
parser.add_argument('--batch_size', type=int, default=256, help='batch size')
parser.add_argument('--epoch_step', type=int, default=1600, help='epochs after which lr is decayed')
parser.add_argument('--max_epochs', type=int, default=2000, help='total epochs to run')
parser.add_argument('--lr_decay_ratio', type=float, default=0.1, help='lr *= lr_decay_ratio after epoch_steps')
parser.add_argument('--parallel', action ='store_true', default=False, help='use multiple GPUs for training')
# parser.add_argument('--joint', action ='store_true', default=True, help='train both the policy network and the resnet')
parser.add_argument('--penalty', type=float, default=-5, help='gamma: reward for incorrect predictions')
parser.add_argument('--alpha', type=float, default=0.8, help='probability bounding factor')
args = parser.parse_args()

if not os.path.exists(args.cv_dir):
    os.system('mkdir ' + args.cv_dir)
utils.save_args(__file__, args)

def get_reward(preds, targets, policy):

    block_use = policy.sum(1).float()/policy.size(1)
    sparse_reward = 1.0-block_use**2

    _, pred_idx = preds.max(1)
    match = (pred_idx==targets).data

    reward = sparse_reward
    reward[1-match] = args.penalty
    reward = reward.unsqueeze(1)

    return reward, match.float()

def train(epoch):

    agent.train()
    rnet.train()

    matches, rewards, policies = [], [], []
    for batch_idx, (inputs, targets) in tqdm.tqdm(enumerate(trainloader), total=len(trainloader)):

        inputs, targets = Variable(inputs), Variable(targets).cuda(async=True)
        if not args.parallel:
            inputs = inputs.cuda()

        probs, value = agent(inputs)

        #---------------------------------------------------------------------#

        policy_map = probs.data.clone()
        policy_map[policy_map<0.5] = 0.0
        policy_map[policy_map>=0.5] = 1.0
        policy_map = Variable(policy_map)

        probs = probs*args.alpha + (1-probs)*(1-args.alpha)
        distr = Bernoulli(probs)
        policy = distr.sample()

        v_inputs = Variable(inputs.data, volatile=True)
        preds_map = rnet.forward(v_inputs, policy_map)
        preds_sample = rnet.forward(inputs, policy)

        reward_map, _ = get_reward(preds_map, targets, policy_map.data)
        reward_sample, match = get_reward(preds_sample, targets, policy.data)

        advantage = reward_sample - reward_map
        # advantage = advantage.expand_as(policy)
        loss = -distr.log_prob(policy).sum(1, keepdim=True) * Variable(advantage)
        loss = loss.sum()

        #---------------------------------------------------------------------#

        loss += F.cross_entropy(preds_sample, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        matches.append(match.cpu())
        rewards.append(reward_sample.cpu())
        policies.append(policy.data.cpu())

    accuracy, reward, sparsity, variance, policy_set = utils.performance_stats(policies, rewards,
        matches)

    log_str = 'E: %d | A: %.3f | R: %.2E | S: %.3f | V: %.3f | #: %d'%(epoch, accuracy, reward, sparsity, variance, len(policy_set))
    print log_str

    log_value('train_accuracy', accuracy, epoch)
    log_value('train_reward', reward, epoch)
    log_value('train_sparsity', sparsity, epoch)
    log_value('train_variance', variance, epoch)
    log_value('train_unique_policies', len(policy_set), epoch)

def test(epoch):

    agent.eval()
    rnet.eval()

    matches, rewards, policies = [], [], []
    for batch_idx, (inputs, targets) in tqdm.tqdm(enumerate(testloader), total=len(testloader)):

        inputs, targets = Variable(inputs, volatile=True), Variable(targets).cuda(async=True)
        if not args.parallel:
            inputs = inputs.cuda()

        probs, _ = agent(inputs)

        policy = probs.data.clone()
        policy[policy<0.5] = 0.0
        policy[policy>=0.5] = 1.0
        policy = Variable(policy)

        preds = rnet.forward(inputs, policy)
        reward, match = get_reward(preds, targets, policy.data)

        matches.append(match)
        rewards.append(reward)
        policies.append(policy.data)

    accuracy, reward, sparsity, variance, policy_set = utils.performance_stats(policies, rewards, matches)

    log_str = 'TS - A: %.3f | R: %.2E | S: %.3f | V: %.3f | #: %d'%(accuracy, reward, sparsity, variance, len(policy_set))
    print log_str

    log_value('test_accuracy', accuracy, epoch)
    log_value('test_reward', reward, epoch)
    log_value('test_sparsity', sparsity, epoch)
    log_value('test_variance', variance, epoch)
    log_value('test_unique_policies', len(policy_set), epoch)

    # save the model
    agent_state_dict = agent.module.state_dict() if args.parallel else agent.state_dict()
    rnet_state_dict = rnet.module.state_dict() if args.parallel else rnet.state_dict()

    state = {
      'agent': agent_state_dict,
      'resnet': rnet_state_dict,
      'epoch': epoch,
      'reward': reward,
      'acc': accuracy
    }
    torch.save(state, args.cv_dir+'/ckpt_E_%d_A_%.3f_R_%.2E_S_%.2f_#_%d.t7'%(epoch, accuracy, reward, sparsity, len(policy_set)))

#--------------------------------------------------------------------------------------------------------#
trainset, testset = utils.get_dataset(args.model, args.data_dir)
trainloader = torchdata.DataLoader(trainset, batch_size=args.batch_size, shuffle=True, num_workers=4)
testloader = torchdata.DataLoader(testset, batch_size=args.batch_size, shuffle=False, num_workers=4)
rnet, agent = utils.get_model(args.model)

if args.pretrained is not None:
    checkpoint = torch.load(args.pretrained)
    key = 'net' if 'net' in checkpoint else 'agent'
    agent.load_state_dict(checkpoint[key])
    print 'loaded pretrained model from', args.pretrained

start_epoch = 0
if args.load is not None:
    checkpoint = torch.load(args.load)
    rnet.load_state_dict(checkpoint['resnet'])
    agent.load_state_dict(checkpoint['agent'])
    start_epoch = checkpoint['epoch'] + 1
    print 'loaded agent from', args.load

if args.parallel:
    agent = nn.DataParallel(agent)
    rnet = nn.DataParallel(rnet)

rnet.cuda()
agent.cuda()

optimizer = optim.Adam(list(agent.parameters())+list(rnet.parameters()), lr=args.lr, weight_decay=args.wd)

configure(args.cv_dir+'/log', flush_secs=5)
lr_scheduler = utils.LrScheduler(optimizer, args.lr, args.lr_decay_ratio, args.epoch_step)
for epoch in range(start_epoch, start_epoch+args.max_epochs+1):
    lr_scheduler.adjust_learning_rate(epoch)
    train(epoch)
    if epoch%10==0:
        test(epoch)

================================================
FILE: models/__init__.py
================================================

================================================
FILE: models/base.py
================================================
import torch.nn as nn
import math
import torch
import torchvision.models as torchmodels
import re
from torch.autograd import Variable
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import random
import torch.nn.utils as torchutils
from torch.nn import init, Parameter

class Identity(nn.Module):
    def __init__(self):
        super(Identity, self).__init__()
    def forward(self, x):
        return x

class Flatten(nn.Module):
    def __init__(self):
        super(Flatten, self).__init__()
    def forward(self, x):
        return x.view(x.size(0), -1)

def conv3x3(in_planes, out_planes, stride=1):
    "3x3 convolution with padding"
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)

class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)

    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = F.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)

        return out

class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        return out

class DownsampleB(nn.Module):

    def __init__(self, nIn, nOut, stride):
        super(DownsampleB, self).__init__()
        self.avg = nn.AvgPool2d(stride)
        self.expand_ratio = nOut // nIn

    def forward(self, x):
        x = self.avg(x)
        return torch.cat([x] + [x.mul(0)] * (self.expand_ratio - 1), 1)

================================================
FILE: models/resnet.py
================================================
import torch.nn as nn
import math
import torch
import torchvision.models as torchmodels
import re
from torch.autograd import Variable
import torch.nn.functional as F
import numpy as np
import random
import torch.nn.init as torchinit
import math
from torch.nn import init, Parameter
import copy

from models import base
import utils

#--------------------------------------------------------------------------------------------------#
class FlatResNet(nn.Module):

    def seed(self, x):
        # x = self.relu(self.bn1(self.conv1(x))) -- CIFAR
        # x = self.maxpool(self.relu(self.bn1(self.conv1(x)))) -- ImageNet
        raise NotImplementedError

    # run a variable policy batch through the resnet implemented as a full mask over the residual
    # fast to train, non-indicative of time saving (use forward_single instead)
    def forward(self, x, policy):

        x = self.seed(x)

        t = 0
        for segment, num_blocks in enumerate(self.layer_config):
            for b in range(num_blocks):
                action = policy[:,t].contiguous()
                residual = self.ds[segment](x) if b==0 else x

                # early termination if all actions in the batch are zero
                if action.data.sum() == 0:
                    x = residual
                    t += 1
                    continue

                action_mask = action.float().view(-1,1,1,1)
                fx = F.relu(residual + self.blocks[segment][b](x))
                x = fx*action_mask + residual*(1-action_mask)
                t += 1

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

    # run a single, fixed policy for all items in the batch
    # policy is a (15,) vector.
    # Use with batch_size=1 for profiling
    def forward_single(self, x, policy):
        x = self.seed(x)

        t = 0
        for segment, num_blocks in enumerate(self.layer_config):
            for b in range(num_blocks):
                residual = self.ds[segment](x) if b==0 else x
                if policy[t]==1:
                    x = residual + self.blocks[segment][b](x)
                    x = F.relu(x)
                else:
                    x = residual
                t += 1

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

    def forward_full(self, x):
        x = self.seed(x)

        for segment, num_blocks in enumerate(self.layer_config):
            for b in range(num_blocks):
                residual = self.ds[segment](x) if b==0 else x
                x = F.relu(residual + self.blocks[segment][b](x))

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

# Smaller Flattened Resnet, tailored for CIFAR
class FlatResNet32(FlatResNet):

    def __init__(self, block, layers, num_classes=10):
        super(FlatResNet32, self).__init__()

        self.inplanes = 16
        self.conv1 = base.conv3x3(3, 16)
        self.bn1 = nn.BatchNorm2d(16)
        self.relu = nn.ReLU(inplace=True)
        self.avgpool = nn.AvgPool2d(8)

        strides = [1, 2, 2]
        filt_sizes = [16, 32, 64]
        self.blocks, self.ds = [], []
        for idx, (filt_size, num_blocks, stride) in enumerate(zip(filt_sizes, layers, strides)):
            blocks, ds = self._make_layer(block, filt_size, num_blocks, stride=stride)
            self.blocks.append(nn.ModuleList(blocks))
            self.ds.append(ds)

        self.blocks = nn.ModuleList(self.blocks)
        self.ds = nn.ModuleList(self.ds)

        self.fc = nn.Linear(64 * block.expansion, num_classes)
        self.fc_dim = 64 * block.expansion

        self.layer_config = layers

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2.
                    / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def seed(self, x):
        x = self.relu(self.bn1(self.conv1(x)))
        return x

    def _make_layer(self, block, planes, blocks, stride=1):

        downsample = nn.Sequential()
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = base.DownsampleB(self.inplanes, planes * block.expansion, stride)

        layers = [block(self.inplanes, planes, stride)]
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes, 1))

        return layers, downsample

# Regular Flattened Resnet, tailored for Imagenet etc.
class FlatResNet224(FlatResNet):

    def __init__(self, block, layers, num_classes=1000):
        self.inplanes = 64
        super(FlatResNet224, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        strides = [1, 2, 2, 2]
        filt_sizes = [64, 128, 256, 512]
        self.blocks, self.ds = [], []
        for idx, (filt_size, num_blocks, stride) in enumerate(zip(filt_sizes, layers, strides)):
            blocks, ds = self._make_layer(block, filt_size, num_blocks, stride=stride)
            self.blocks.append(nn.ModuleList(blocks))
            self.ds.append(ds)

        self.blocks = nn.ModuleList(self.blocks)
        self.ds = nn.ModuleList(self.ds)

        self.avgpool = nn.AvgPool2d(7)
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2.
                    / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

        self.layer_config = layers

    def seed(self, x):
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))
        return x

    def _make_layer(self, block, planes, blocks, stride=1):

        downsample = nn.Sequential()
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = [block(self.inplanes, planes, stride)]
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return layers, downsample

#---------------------------------------------------------------------------------------------------------#

# Class to generate resnetNB or any other config (default is 3B)
# removed the fc layer so it serves as a feature extractor
class Policy32(nn.Module):

    def __init__(self, layer_config=[1,1,1], num_blocks=15):
        super(Policy32, self).__init__()
        self.features = FlatResNet32(base.BasicBlock, layer_config, num_classes=10)
        self.feat_dim = self.features.fc.weight.data.shape[1]
        self.features.fc = nn.Sequential()

        self.logit = nn.Linear(self.feat_dim, num_blocks)
        self.vnet = nn.Linear(self.feat_dim, 1)

    def load_state_dict(self, state_dict):
        # support legacy models
        state_dict = {k:v for k,v in state_dict.items() if not k.startswith('features.fc')}
        return super(Policy32, self).load_state_dict(state_dict)

    def forward(self, x):
        x = self.features.forward_full(x)
        value = self.vnet(x)
        probs = F.sigmoid(self.logit(x))
        return probs, value

class Policy224(nn.Module):

    def __init__(self, layer_config=[1,1,1,1], num_blocks=16):
        super(Policy224, self).__init__()
        self.features = FlatResNet224(base.BasicBlock, layer_config, num_classes=1000)

        resnet18 = torchmodels.resnet18(pretrained=True)
        utils.load_weights_to_flatresnet(resnet18, self.features)

        self.features.avgpool = nn.AvgPool2d(4)
        self.feat_dim = self.features.fc.weight.data.shape[1]
        self.features.fc = nn.Sequential()

        self.logit = nn.Linear(self.feat_dim, num_blocks)
        self.vnet = nn.Linear(self.feat_dim, 1)

    def load_state_dict(self, state_dict):
        # support legacy models
        state_dict = {k:v for k,v in state_dict.items() if not k.startswith('features.fc')}
        return super(Policy224, self).load_state_dict(state_dict)

    def forward(self, x):
        x = F.avg_pool2d(x, 2)
        x = self.features.forward_full(x)
        value = self.vnet(x)
        probs = F.sigmoid(self.logit(x))
        return probs, value

#--------------------------------------------------------------------------------------------------------#

class StepResnet32(FlatResNet32):

    def __init__(self, block, layers, num_classes, joint=False):
        # super() must name this class; the original referenced an undefined `StepResnet`
        super(StepResnet32, self).__init__(block, layers, num_classes)
        self.eval() # default to eval mode

        self.joint = joint

        self.state_ptr = {}
        t = 0
        for segment, num_blocks in enumerate(self.layer_config):
            for b in range(num_blocks):
                self.state_ptr[t] = (segment, b)
                t += 1

    def seed(self, x):
        self.state = self.relu(self.bn1(self.conv1(x)))
        self.t = 0

        if self.joint:
            return self.state
        return Variable(self.state.data)

    def step(self, action):
        segment, b = self.state_ptr[self.t]
        residual = self.ds[segment](self.state) if b==0 else self.state
        action_mask = action.float().view(-1,1,1,1)

        fx = F.relu(residual + self.blocks[segment][b](self.state))
        self.state = fx*action_mask + residual*(1-action_mask)
        self.t += 1

        if self.joint:
            return self.state
        return Variable(self.state.data)

    def step_single(self, action):
        segment, b = self.state_ptr[self.t]
        residual = self.ds[segment](self.state) if b==0 else self.state

        if action.data[0,0]==1:
            self.state = F.relu(residual + self.blocks[segment][b](self.state))
        else:
            self.state = residual

        self.t += 1

        if self.joint:
            return self.state
        return Variable(self.state.data)

    def predict(self):
        x = self.avgpool(self.state)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

class StepPolicy32(nn.Module):

    def __init__(self, layer_config):
        # super() must name this class; the original referenced an undefined `StepPolicy`
        super(StepPolicy32, self).__init__()
        in_dim = [16] + [16]*layer_config[0] + [32]*layer_config[1] + [64]*(layer_config[2]-1)
        self.pnet = nn.ModuleList([nn.Linear(dim, 2) for dim in in_dim])
        self.vnet = nn.ModuleList([nn.Linear(dim, 1) for dim in in_dim])

    def forward(self, state):
        x, t = state
        x = F.avg_pool2d(x, x.size(2)).view(x.size(0), -1) # pool + flatten --> (B, 16/32/64)
        logit = F.softmax(self.pnet[t](x))
        value = self.vnet[t](x)
        return logit, value

================================================
FILE: requirements.txt
================================================
tqdm
numpy
tensorboard_logger
argparse

================================================
FILE: test.py
================================================
import torch
from torch.autograd import Variable
import torch.utils.data as torchdata
import torch.nn as nn
import numpy as np
import tqdm
import utils

import torch.backends.cudnn as cudnn
cudnn.benchmark = True

import argparse
parser = argparse.ArgumentParser()
parser.add_argument('--model', default='R110_C10')
parser.add_argument('--data_dir', default='data/')
parser.add_argument('--load', default=None)
args = parser.parse_args()

#---------------------------------------------------------------------------------------------#
class FConv2d(nn.Conv2d):

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
        super(FConv2d, self).__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias)
        self.num_ops = 0

    def forward(self, x):
        output = super(FConv2d, self).forward(x)
        output_area = output.size(-1)*output.size(-2)
        filter_area = np.prod(self.kernel_size)
        self.num_ops += 2*self.in_channels*self.out_channels*filter_area*output_area
        return output

class FLinear(nn.Linear):
    def __init__(self, in_features, out_features, bias=True):
        super(FLinear, self).__init__(in_features, out_features, bias)
        self.num_ops = 0

    def forward(self, x):
        output = super(FLinear, self).forward(x)
        self.num_ops += 2*self.in_features*self.out_features
        return output

def count_flops(model, reset=True):
    op_count = 0
    for m in model.modules():
        if hasattr(m, 'num_ops'):
            op_count += m.num_ops
            if reset: # count and reset to 0
                m.num_ops = 0

    return op_count

# replace all nn.Conv and nn.Linear layers with layers that count flops
nn.Conv2d = FConv2d
nn.Linear = FLinear

#--------------------------------------------------------------------------------------------#

def test():

    total_ops = []
    matches, policies = [], []
    for batch_idx, (inputs, targets) in tqdm.tqdm(enumerate(testloader), total=len(testloader)):

        inputs, targets = Variable(inputs, volatile=True).cuda(), Variable(targets).cuda()
        probs, _ = agent(inputs)

        # threshold the block probabilities into a binary keep/drop policy
        policy = probs.clone()
        policy[policy<0.5] = 0.0
        policy[policy>=0.5] = 1.0

        preds = rnet.forward_single(inputs, policy.data.squeeze(0))
        _ , pred_idx = preds.max(1)
        match = (pred_idx==targets).data.float()

        matches.append(match)
        policies.append(policy.data)

        # per-image op count (batch_size is 1); counters are reset after each read
        ops = count_flops(agent) + count_flops(rnet)
        total_ops.append(ops)

    accuracy, _, sparsity, variance, policy_set = utils.performance_stats(policies, matches, matches)
    ops_mean, ops_std = np.mean(total_ops), np.std(total_ops)

    log_str = u'''
    Accuracy: %.3f
    Block Usage: %.3f \u00B1 %.3f
    FLOPs/img: %.2E \u00B1 %.2E
    Unique Policies: %d
    '''%(accuracy, sparsity, variance, ops_mean, ops_std, len(policy_set))

    print log_str

#--------------------------------------------------------------------------------------------------------#

trainset, testset = utils.get_dataset(args.model, args.data_dir)
testloader = torchdata.DataLoader(testset, batch_size=1, shuffle=False, num_workers=4)
rnet, agent = utils.get_model(args.model)

# if no model is loaded, use all blocks
agent.logit.weight.data.fill_(0)
agent.logit.bias.data.fill_(10)

print "loading checkpoints"

if args.load is not None:
    utils.load_checkpoint(rnet, agent, args.load)

rnet.eval().cuda()
agent.eval().cuda()

test()

================================================
FILE: utils.py
================================================
import os
import re
import torch
import torchvision.transforms as transforms
import torchvision.datasets as torchdata
import numpy as np

# Save the training script and all the arguments to a file so that you
# don't feel like an idiot later when you can't replicate results
import shutil
def save_args(__file__, args):
    shutil.copy(os.path.basename(__file__), args.cv_dir)
    with open(args.cv_dir+'/args.txt','w') as f:
        f.write(str(args))

def performance_stats(policies, rewards, matches):

    policies = torch.cat(policies, 0)
    rewards = torch.cat(rewards, 0)
    accuracy = torch.cat(matches, 0).mean()

    reward = rewards.mean()
    sparsity = policies.sum(1).mean()
    variance = policies.sum(1).std()

    # one string of 0/1 characters per image, deduplicated to count unique policies
    policy_set = [p.cpu().numpy().astype(np.int).astype(np.str) for p in policies]
    policy_set = set([''.join(p) for p in policy_set])

    return accuracy, reward, sparsity, variance, policy_set

class LrScheduler:
    def __init__(self, optimizer, base_lr, lr_decay_ratio, epoch_step):
        self.base_lr = base_lr
        self.lr_decay_ratio = lr_decay_ratio
        self.epoch_step = epoch_step
        self.optimizer = optimizer

    def adjust_learning_rate(self, epoch):
        """Sets the learning rate to base_lr multiplied by lr_decay_ratio once every epoch_step epochs"""
        lr = self.base_lr * (self.lr_decay_ratio ** (epoch // self.epoch_step))
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr
            if epoch%self.epoch_step==0:
                print '# setting learning_rate to %.2E'%lr

# load model weights trained using scripts from https://github.com/felixgwu/img_classification_pk_pytorch OR
# from torchvision models into our flattened resnets
def load_weights_to_flatresnet(source_model, target_model):

    # compatibility for nn.Modules + checkpoints
    if hasattr(source_model, 'state_dict'):
        source_model = {'state_dict': source_model.state_dict()}
    source_state = source_model['state_dict']
    target_state = target_model.state_dict()

    # remove the module. prefix if it exists (thanks nn.DataParallel)
    if source_state.keys()[0].startswith('module.'):
        source_state = {k[7:]:v for k,v in source_state.items()}

    common = set(['conv1.weight', 'bn1.weight', 'bn1.bias', 'bn1.running_mean', 'bn1.running_var','fc.weight', 'fc.bias'])
    for key in source_state.keys():

        if key in common:
            target_state[key] = source_state[key]
            continue

        if 'downsample' in key:
            layer, num, item = re.match(r'layer(\d+).*\.(\d+)\.(.*)', key).groups()
            translated = 'ds.%s.%s.%s'%(int(layer)-1, num, item)
        else:
            layer, item = re.match(r'layer(\d+)\.(.*)', key).groups()
            translated = 'blocks.%s.%s'%(int(layer)-1, item)

        if translated in target_state.keys():
            target_state[translated] = source_state[key]
        else:
            print translated, 'block missing'

    target_model.load_state_dict(target_state)
    return target_model

def load_checkpoint(rnet, agent, load):
    if load=='nil':
        return None

    checkpoint = torch.load(load)
    if 'resnet' in checkpoint:
        rnet.load_state_dict(checkpoint['resnet'])
        print 'loaded resnet from', os.path.basename(load)
    if 'agent' in checkpoint:
        agent.load_state_dict(checkpoint['agent'])
        print 'loaded agent from', os.path.basename(load)
    # backward compatibility (some old checkpoints)
    if 'net' in checkpoint:
        checkpoint['net'] = {k:v for k,v in checkpoint['net'].items() if 'features.fc' not in k}
        agent.load_state_dict(checkpoint['net'])
        print 'loaded agent from', os.path.basename(load)

def get_transforms(rnet, dset):

    # Only the R32 pretrained model subtracts the mean, sorry :(
    if dset=='C10' and rnet=='R32':
        mean = [x/255.0 for x in [125.3, 123.0, 113.9]]
        std = [x/255.0 for x in [63.0, 62.1, 66.7]]
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
            ])
        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
            ])

    elif dset=='C100' or (dset=='C10' and rnet!='R32'):
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            ])
        transform_test = transforms.Compose([
            transforms.ToTensor(),
            ])

    elif dset=='ImgNet':
        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]
        transform_train = transforms.Compose([
            transforms.Scale(256),
            transforms.RandomCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
            ])
        transform_test = transforms.Compose([
            transforms.Scale(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
            ])

    return transform_train, transform_test

# Pick from the datasets available and the hundreds of models we have lying around depending on the requirements.
def get_dataset(model, root='data/'):

    rnet, dset = model.split('_')
    transform_train, transform_test = get_transforms(rnet, dset)

    if dset=='C10':
        trainset = torchdata.CIFAR10(root=root, train=True, download=True, transform=transform_train)
        testset = torchdata.CIFAR10(root=root, train=False, download=True, transform=transform_test)
    elif dset=='C100':
        trainset = torchdata.CIFAR100(root=root, train=True, download=True, transform=transform_train)
        testset = torchdata.CIFAR100(root=root, train=False, download=True, transform=transform_test)
    elif dset=='ImgNet':
        trainset = torchdata.ImageFolder(root+'/train/', transform_train)
        testset = torchdata.ImageFolder(root+'/val/', transform_test)

    return trainset, testset

# Make a new if statement for every new model variety you want to index
def get_model(model):

    from models import resnet, base

    if model=='R32_C10':
        rnet_checkpoint = 'cv/pretrained/R32_C10/pk_E_164_A_0.923.t7'
        layer_config = [5, 5, 5]
        rnet = resnet.FlatResNet32(base.BasicBlock, layer_config, num_classes=10)
        agent = resnet.Policy32([1,1,1], num_blocks=15)

    elif model=='R110_C10':
        rnet_checkpoint = 'cv/pretrained/R110_C10/pk_E_130_A_0.932.t7'
        layer_config = [18, 18, 18]
        rnet = resnet.FlatResNet32(base.BasicBlock, layer_config, num_classes=10)
        agent = resnet.Policy32([1,1,1], num_blocks=54)

    elif model=='R32_C100':
        rnet_checkpoint = 'cv/pretrained/R32_C100/pk_E_164_A_0.693.t7'
        layer_config = [5, 5, 5]
        rnet = resnet.FlatResNet32(base.BasicBlock, layer_config, num_classes=100)
        agent = resnet.Policy32([1,1,1], num_blocks=15)

    elif model=='R110_C100':
        rnet_checkpoint = 'cv/pretrained/R110_C100/pk_E_160_A_0.723.t7'
        layer_config = [18, 18, 18]
        rnet = resnet.FlatResNet32(base.BasicBlock, layer_config, num_classes=100)
        agent = resnet.Policy32([1,1,1], num_blocks=54)

    elif model=='R101_ImgNet':
        rnet_checkpoint = 'cv/pretrained/R101_ImgNet/ImageNet_R101_224_76.464'
        layer_config = [3,4,23,3]
        rnet = resnet.FlatResNet224(base.Bottleneck, layer_config, num_classes=1000)
        agent = resnet.Policy224([1,1,1,1], num_blocks=33)

    # load pretrained weights into flat ResNet
    rnet_checkpoint = torch.load(rnet_checkpoint)
    load_weights_to_flatresnet(rnet_checkpoint, rnet)

    return rnet, agent
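
The key renaming done inside `load_weights_to_flatresnet` can be tried in isolation. The following is a minimal, torch-free sketch of that renaming only, not the loader itself: `translate_key` and `COMMON_KEYS` are hypothetical names introduced for illustration, and the sample keys assume torchvision-style parameter names (`layerN.<block>...`). It should run unchanged under Python 2.7 or Python 3.

```python
import re

# Stem and classifier parameters that the loader copies without renaming.
COMMON_KEYS = set(['conv1.weight', 'bn1.weight', 'bn1.bias',
                   'bn1.running_mean', 'bn1.running_var',
                   'fc.weight', 'fc.bias'])

def translate_key(key):
    """Map a source ResNet parameter name to the flat 'blocks'/'ds' layout.

    Hypothetical helper: it mirrors the two regular expressions used in
    load_weights_to_flatresnet but does not touch any tensors.
    """
    if key in COMMON_KEYS:
        return key
    if 'downsample' in key:
        # layer index, index inside the downsample module, parameter name
        layer, num, item = re.match(r'layer(\d+).*\.(\d+)\.(.*)', key).groups()
        return 'ds.%s.%s.%s' % (int(layer) - 1, num, item)
    # layer index, then "<block index>.<rest of the parameter name>"
    layer, item = re.match(r'layer(\d+)\.(.*)', key).groups()
    return 'blocks.%s.%s' % (int(layer) - 1, item)

if __name__ == '__main__':
    samples = ['conv1.weight',
               'layer1.0.conv1.weight',
               'layer3.0.downsample.1.running_var']
    for k in samples:
        print('%s -> %s' % (k, translate_key(k)))
```

Layer indices shift from 1-based to 0-based, and downsample parameters move to a separate `ds` list. This matches how `StepResnet32` indexes `self.ds[segment]` and `self.blocks[segment][b]`.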