[
  {
    "path": ".gitignore",
    "content": "*.*~\n*.pyc\n*.pkl\n*.h5\n*.t7\n*.t7*\n*.log\n*.png\n\ndata/\ncv/\n\n"
  },
  {
    "path": "Readme.md",
    "content": "# BlockDrop: Dynamic Inference Paths in Residual Networks\n![BlockDrop Model](https://user-images.githubusercontent.com/4995097/35775877-3cc64f86-0957-11e8-85c4-9bd16cda22a0.png)\n\nThis code implements a policy network that learns to dynamically choose which blocks of a ResNet to execute during inference so as to best reduce total computation without degrading prediction accuracy. Built upon a ResNet-101 model, our method achieves a speedup of 20% on average, going as high as 36% for some images, while maintaining the same 76.4% top-1 accuracy on ImageNet.\n\nThis is the code accompanying the work:  \nZuxuan Wu*, Tushar Nagarajan*, Abhishek Kumar, Steven Rennie, Larry S. Davis, Kristen Grauman, and Rogerio Feris. BlockDrop: Dynamic Inference Paths in Residual Networks [[arxiv]](https://arxiv.org/pdf/1711.08393.pdf)  \n(* authors contributed equally)\n\n## Prerequisites\nThe code is written and tested using Python 2.7 and PyTorch v0.3.0.\n\n**Packages**: Install using `pip install -r requirements.txt`.\n\n**Pretrained models**: Our models require standard pretrained ResNets on CIFAR and ImageNet as starting points. These can be trained using [this](https://github.com/felixgwu/img_classification_pk_pytorch) repository or obtained directly from us:\n\n```bash\nwget -O blockdrop-checkpoints.tar.gz https://utexas.box.com/shared/static/ok98i51v14c0q9lvs1z5g71m6b3zm8sj.gz\ntar -zxvf blockdrop-checkpoints.tar.gz\n```\nThe downloaded checkpoints are unpacked to `./cv/` for further use. The folder also contains checkpoints from each stage of training.\n\n**Datasets**: PyTorch's *torchvision* package automatically downloads CIFAR10 and CIFAR100 during training. ImageNet must be downloaded and organized following [these steps](https://github.com/soumith/imagenet-multiGPU.torch#data-processing).\n\n## Training a model\nTraining occurs in two steps: (1) Curriculum Learning and (2) Joint Finetuning.  
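\nBoth training stages optimize the reward computed by `get_reward()` in `cl_training.py` and `finetune.py`. Write $\\mathbf{u} \\in \\{0,1\\}^K$ for a block policy over the $K$ residual blocks of the base ResNet. A correct prediction earns a reward that shrinks quadratically with the fraction of blocks executed. An incorrect prediction receives the penalty $\\gamma$ set by `--penalty`:\n\n$$R(\\mathbf{u}) = \\begin{cases} 1 - \\left(\\frac{\\lVert \\mathbf{u} \\rVert_1}{K}\\right)^2 & \\text{if the prediction is correct,} \\\\ \\gamma & \\text{otherwise.} \\end{cases}$$\n\nThe policy-gradient advantage is $R(\\mathbf{u}_{\\text{sample}}) - R(\\mathbf{u}_{\\text{map}})$. Here $\\mathbf{u}_{\\text{sample}}$ is drawn from the Bernoulli policy, whose probabilities are bounded by `--alpha`, and $\\mathbf{u}_{\\text{map}}$ thresholds the predicted probabilities at 0.5.\n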
\nModels operating on ResNets of different depths can be trained on different datasets using the same script. Examples of how to train these models are given below. Checkpoints and tensorboard log files are saved to the folder specified by `--cv_dir`.\n\n#### Curriculum Learning\nThe policy network can be trained using a curriculum learning (CL) schedule as follows.\n\n```bash\n# Train a model on CIFAR 10 built upon a ResNet-110\npython cl_training.py --model R110_C10 --cv_dir cv/R110_C10_cl/ --lr 1e-3 --batch_size 2048 --max_epochs 5000\n\n# Train a model on ImageNet built upon a ResNet-101\npython cl_training.py --model R101_ImgNet --cv_dir cv/R101_ImgNet_cl/ --lr 1e-3 --batch_size 2048 --max_epochs 45 --data_dir data/imagenet/\n```\n\nModel checkpoints after the curriculum learning step can be found in the downloaded folder. For example: `./cv/cl_training/R110_C10/ckpt_E_5300_A_0.754_R_2.22E-01_S_20.10_#_7787.t7`\n\n#### Joint Finetuning\nCheckpoints from the curriculum learning phase are then used to jointly finetune the policy network and the base ResNet, which achieves the results reported in the paper. The penalty parameter (`--penalty`) controls the trade-off between accuracy and speed.\n\n```bash\n# Finetune a ResNet-110 on CIFAR 10 using the checkpoint from cl_training\npython finetune.py --model R110_C10 --lr 1e-4 --penalty -10 --pretrained cv/cl_training/R110_C10/ckpt_E_5300_A_0.754_R_2.22E-01_S_20.10_#_7787.t7 --batch_size 256 --max_epochs 2000 --cv_dir cv/R110_C10_ft_-10/\n\n# Finetune a ResNet-101 on ImageNet using the checkpoint from cl_training\npython finetune.py --model R101_ImgNet --lr 1e-4 --penalty -5 --pretrained cv/cl_training/R101_ImgNet/ckpt_E_4_A_0.746_R_-3.70E-01_S_29.79_#_484.t7 --data_dir data/imagenet/ --batch_size 320 --max_epochs 10 --cv_dir cv/R101_ImgNet_ft_-5/\n```\n\nModel checkpoints after the joint finetuning step can be found in the downloaded folder. 
For example: `./cv/finetuned/R101_ImgNet_gamma_5/ckpt_E_10_A_0.764_R_-8.46E-01_S_24.77_#_10.t7`\n\n## Testing and Profiling\nOnce jointly finetuned, models can be profiled for accuracy and FLOP counts.\n```bash\npython test.py --model R110_C10 --load cv/finetuned/R110_C10_gamma_10/ckpt_E_2000_A_0.936_R_1.95E-01_S_16.93_#_469.t7\n```\nThe model should produce an accuracy of 93.6% and use 1.81E+08 FLOPs on average. The output should look like this:\n```\n    Accuracy: 0.936\n    Block Usage: 16.933 ± 3.717\n    FLOPs/img: 1.81E+08 ± 3.43E+07\n    Unique Policies: 469\n```\n\nThe ImageNet model can be evaluated in the same way and should produce the corresponding output:\n```bash\npython test.py --model R101_ImgNet --load cv/finetuned/R101_ImgNet_gamma_5/ckpt_E_10_A_0.764_R_-8.46E-01_S_24.77_#_10.t7\n```\n```\n    Accuracy: 0.764\n    Block Usage: 24.770 ± 0.980\n    FLOPs/img: 1.25E+10 ± 4.28E+08\n    Unique Policies: 10\n```\n\n\n## Visualization\nThe learned policies over ResNet blocks show a clear separation between easy and hard images in terms of the number of blocks they require. In addition, distinct policies over the blocks correspond to distinct image styles.\n\n![Policy visualization](https://user-images.githubusercontent.com/4995097/35775878-3e5ee4e8-0957-11e8-832d-b9dc2ea8fecc.png)\n\nFor more qualitative results, see Sec. 4.3 and Figures 4 and 5 in the paper.\n\n\n\n## Cite\n\nIf you find this repository useful in your own research, please consider citing:\n```\n@inproceedings{blockdrop,\n  title={BlockDrop: Dynamic Inference Paths in Residual Networks},\n  author={Wu, Zuxuan and Nagarajan, Tushar and Kumar, Abhishek and Rennie, Steven and Davis, Larry S and Grauman, Kristen and Feris, Rogerio},\n  booktitle={CVPR},\n  year={2018}\n}\n```\n"
  },
  {
    "path": "cl_training.py",
    "content": "import os\r\nfrom tensorboard_logger import configure, log_value\r\nimport torch\r\nimport torch.autograd as autograd\r\nfrom torch.autograd import Variable\r\nimport torch.utils.data as torchdata\r\nimport torch.nn as nn\r\nimport torch.nn.functional as F\r\nimport numpy as np\r\nimport tqdm\r\nimport utils\r\nimport torch.optim as optim\r\nfrom torch.distributions import Bernoulli\r\n\r\nimport torch.backends.cudnn as cudnn\r\ncudnn.benchmark = True\r\n\r\n\r\nimport argparse\r\nparser = argparse.ArgumentParser(description='BlockDrop Training')\r\nparser.add_argument('--lr', type=float, default=1e-3, help='learning rate')\r\nparser.add_argument('--beta', type=float, default=1e-1, help='entropy multiplier')\r\nparser.add_argument('--wd', type=float, default=0.0, help='weight decay')\r\nparser.add_argument('--model', default='R110_C10', help='R<depth>_<dataset> see utils.py for a list of configurations')\r\nparser.add_argument('--data_dir', default='data/', help='data directory')\r\nparser.add_argument('--load', default=None, help='checkpoint to load agent from')\r\nparser.add_argument('--cv_dir', default='cv/tmp/', help='checkpoint directory (models and logs are saved here)')\r\nparser.add_argument('--batch_size', type=int, default=256, help='batch size')\r\nparser.add_argument('--epoch_step', type=int, default=10000, help='epochs after which lr is decayed')\r\nparser.add_argument('--max_epochs', type=int, default=10000, help='total epochs to run')\r\nparser.add_argument('--lr_decay_ratio', type=float, default=0.1, help='lr *= lr_decay_ratio after epoch_steps')\r\nparser.add_argument('--parallel', action ='store_true', default=False, help='use multiple GPUs for training')\r\nparser.add_argument('--cl_step', type=int, default=1, help='steps for curriculum training')\r\n# parser.add_argument('--joint', action ='store_true', default=True, help='train both the policy network and the resnet')\r\nparser.add_argument('--penalty', type=float, default=-1, 
help='gamma: reward for incorrect predictions')\r\nparser.add_argument('--alpha', type=float, default=0.8, help='probability bounding factor')\r\nargs = parser.parse_args()\r\n\r\nif not os.path.exists(args.cv_dir):\r\n    os.makedirs(args.cv_dir) # also creates missing parent dirs (cv/ is gitignored)\r\nutils.save_args(__file__, args)\r\n\r\ndef get_reward(preds, targets, policy):\r\n\r\n    block_use = policy.sum(1).float()/policy.size(1)\r\n    sparse_reward = 1.0-block_use**2\r\n\r\n    _, pred_idx = preds.max(1)\r\n    match = (pred_idx==targets).data\r\n\r\n    reward = sparse_reward\r\n    reward[1-match] = args.penalty\r\n    reward = reward.unsqueeze(1)\r\n\r\n    return reward, match.float()\r\n\r\n\r\ndef train(epoch):\r\n\r\n    agent.train()\r\n\r\n    matches, rewards, policies = [], [], []\r\n    for batch_idx, (inputs, targets) in tqdm.tqdm(enumerate(trainloader), total=len(trainloader)):\r\n\r\n        inputs, targets = Variable(inputs), Variable(targets).cuda(async=True)\r\n        if not args.parallel:\r\n            inputs = inputs.cuda()\r\n\r\n        probs, value = agent(inputs)\r\n\r\n        #---------------------------------------------------------------------#\r\n\r\n        policy_map = probs.data.clone()\r\n        policy_map[policy_map<0.5] = 0.0\r\n        policy_map[policy_map>=0.5] = 1.0\r\n        policy_map = Variable(policy_map)\r\n\r\n        probs = probs*args.alpha + (1-probs)*(1-args.alpha)\r\n        distr = Bernoulli(probs)\r\n        policy = distr.sample()\r\n\r\n        if args.cl_step < num_blocks:\r\n            policy[:, :-args.cl_step] = 1\r\n            policy_map[:, :-args.cl_step] = 1\r\n\r\n            policy_mask = Variable(torch.ones(inputs.size(0), policy.size(1))).cuda()\r\n            policy_mask[:, :-args.cl_step] = 0\r\n        else:\r\n            policy_mask = None\r\n\r\n        v_inputs = Variable(inputs.data, volatile=True)\r\n        preds_map = rnet.forward(v_inputs, policy_map)\r\n        preds_sample = rnet.forward(v_inputs, policy)\r\n\r\n   
     reward_map, _ = get_reward(preds_map, targets, policy_map.data)\r\n        reward_sample, match = get_reward(preds_sample, targets, policy.data)\r\n\r\n        advantage = reward_sample - reward_map\r\n\r\n        loss = -distr.log_prob(policy)\r\n        loss = loss * Variable(advantage).expand_as(policy)\r\n\r\n        if policy_mask is not None:\r\n            loss = policy_mask * loss # mask for curriculum learning\r\n\r\n        loss = loss.sum()\r\n\r\n        probs = probs.clamp(1e-15, 1-1e-15)\r\n        entropy_loss = -probs*torch.log(probs)\r\n        entropy_loss = args.beta*entropy_loss.sum()\r\n\r\n        loss = (loss - entropy_loss)/inputs.size(0)\r\n\r\n        #---------------------------------------------------------------------#\r\n\r\n        optimizer.zero_grad()\r\n        loss.backward()\r\n        optimizer.step()\r\n\r\n        matches.append(match.cpu())\r\n        rewards.append(reward_sample.cpu())\r\n        policies.append(policy.data.cpu())\r\n\r\n    accuracy, reward, sparsity, variance, policy_set = utils.performance_stats(policies, rewards, matches)\r\n\r\n    log_str = 'E: %d | A: %.3f | R: %.2E | S: %.3f | V: %.3f | #: %d'%(epoch, accuracy, reward, sparsity, variance, len(policy_set))\r\n    print log_str\r\n\r\n    log_value('train_accuracy', accuracy, epoch)\r\n    log_value('train_reward', reward, epoch)\r\n    log_value('train_sparsity', sparsity, epoch)\r\n    log_value('train_variance', variance, epoch)\r\n    log_value('train_unique_policies', len(policy_set), epoch)\r\n\r\n\r\ndef test(epoch):\r\n\r\n    agent.eval()\r\n\r\n    matches, rewards, policies = [], [], []\r\n    for batch_idx, (inputs, targets) in tqdm.tqdm(enumerate(testloader), total=len(testloader)):\r\n\r\n        inputs, targets = Variable(inputs, volatile=True), Variable(targets).cuda(async=True)\r\n        if not args.parallel:\r\n            inputs = inputs.cuda()\r\n\r\n        probs, _ = agent(inputs)\r\n\r\n        policy = 
probs.data.clone()\r\n        policy[policy<0.5] = 0.0\r\n        policy[policy>=0.5] = 1.0\r\n        policy = Variable(policy)\r\n\r\n        if args.cl_step < num_blocks:\r\n            policy[:, :-args.cl_step] = 1\r\n\r\n        preds = rnet.forward(inputs, policy)\r\n        reward, match = get_reward(preds, targets, policy.data)\r\n\r\n        matches.append(match)\r\n        rewards.append(reward)\r\n        policies.append(policy.data)\r\n\r\n    accuracy, reward, sparsity, variance, policy_set = utils.performance_stats(policies, rewards, matches)\r\n\r\n    log_str = 'TS - A: %.3f | R: %.2E | S: %.3f | V: %.3f | #: %d'%(accuracy, reward, sparsity, variance, len(policy_set))\r\n    print log_str\r\n\r\n    log_value('test_accuracy', accuracy, epoch)\r\n    log_value('test_reward', reward, epoch)\r\n    log_value('test_sparsity', sparsity, epoch)\r\n    log_value('test_variance', variance, epoch)\r\n    log_value('test_unique_policies', len(policy_set), epoch)\r\n\r\n    # save the model\r\n    agent_state_dict = agent.module.state_dict() if args.parallel else agent.state_dict()\r\n\r\n    state = {\r\n      'agent': agent_state_dict,\r\n      'epoch': epoch,\r\n      'reward': reward,\r\n      'acc': accuracy\r\n    }\r\n    torch.save(state, args.cv_dir+'/ckpt_E_%d_A_%.3f_R_%.2E_S_%.2f_#_%d.t7'%(epoch, accuracy, reward, sparsity, len(policy_set)))\r\n\r\n\r\n#--------------------------------------------------------------------------------------------------------#\r\ntrainset, testset = utils.get_dataset(args.model, args.data_dir)\r\ntrainloader = torchdata.DataLoader(trainset, batch_size=args.batch_size, shuffle=True, num_workers=4)\r\ntestloader = torchdata.DataLoader(testset, batch_size=args.batch_size, shuffle=False, num_workers=4)\r\nrnet, agent = utils.get_model(args.model)\r\nnum_blocks = sum(rnet.layer_config)\r\n\r\nstart_epoch = 0\r\nif args.load is not None:\r\n    checkpoint = torch.load(args.load)\r\n    
agent.load_state_dict(checkpoint['agent'])\r\n    start_epoch = checkpoint['epoch'] + 1\r\n    print 'loaded agent from', args.load\r\n\r\nif args.parallel:\r\n    agent = nn.DataParallel(agent)\r\n    rnet = nn.DataParallel(rnet)\r\n\r\nrnet.eval().cuda()\r\nagent.cuda()\r\n\r\noptimizer = optim.Adam(agent.parameters(), lr=args.lr, weight_decay=args.wd)\r\n\r\nconfigure(args.cv_dir+'/log', flush_secs=5)\r\nlr_scheduler = utils.LrScheduler(optimizer, args.lr, args.lr_decay_ratio, args.epoch_step)\r\nfor epoch in range(start_epoch, start_epoch+args.max_epochs+1):\r\n    lr_scheduler.adjust_learning_rate(epoch)\r\n\r\n    if args.cl_step < num_blocks:\r\n        args.cl_step = 1 + 1 * (epoch // 1)\r\n    else:\r\n        args.cl_step = num_blocks\r\n\r\n    print 'training the last %d blocks ...' % args.cl_step\r\n    train(epoch)\r\n\r\n    if epoch % 10 == 0:\r\n        test(epoch)\r\n"
  },
  {
    "path": "finetune.py",
    "content": "import os\r\nfrom tensorboard_logger import configure, log_value\r\nimport torch\r\nimport torch.autograd as autograd\r\nfrom torch.autograd import Variable\r\nimport torch.utils.data as torchdata\r\nimport torch.nn as nn\r\nimport torch.nn.functional as F\r\nimport numpy as np\r\nimport tqdm\r\nimport utils\r\nimport torch.optim as optim\r\nfrom torch.distributions import Bernoulli\r\n\r\nimport torch.backends.cudnn as cudnn\r\ncudnn.benchmark = True\r\n\r\n\r\nimport argparse\r\nparser = argparse.ArgumentParser(description='BlockDrop Training')\r\nparser.add_argument('--lr', type=float, default=1e-4, help='learning rate')\r\nparser.add_argument('--wd', type=float, default=0.0, help='weight decay')\r\nparser.add_argument('--model', default='R110_C10', help='R<depth>_<dataset> see utils.py for a list of configurations')\r\nparser.add_argument('--data_dir', default='data/', help='data directory')\r\nparser.add_argument('--load', default=None, help='checkpoint to load rnet+agent from')\r\nparser.add_argument('--pretrained', default=None, help='pretrained policy model checkpoint (from curriculum training)')\r\nparser.add_argument('--cv_dir', default='cv/tmp/', help='checkpoint directory (models and logs are saved here)')\r\nparser.add_argument('--batch_size', type=int, default=256, help='batch size')\r\nparser.add_argument('--epoch_step', type=int, default=1600, help='epochs after which lr is decayed')\r\nparser.add_argument('--max_epochs', type=int, default=2000, help='total epochs to run')\r\nparser.add_argument('--lr_decay_ratio', type=float, default=0.1, help='lr *= lr_decay_ratio after epoch_steps')\r\nparser.add_argument('--parallel', action ='store_true', default=False, help='use multiple GPUs for training')\r\n# parser.add_argument('--joint', action ='store_true', default=True, help='train both the policy network and the resnet')\r\nparser.add_argument('--penalty', type=float, default=-5, help='gamma: reward for incorrect 
predictions')\r\nparser.add_argument('--alpha', type=float, default=0.8, help='probability bounding factor')\r\nargs = parser.parse_args()\r\n\r\nif not os.path.exists(args.cv_dir):\r\n    os.makedirs(args.cv_dir) # also creates missing parent dirs (cv/ is gitignored)\r\nutils.save_args(__file__, args)\r\n\r\ndef get_reward(preds, targets, policy):\r\n\r\n    block_use = policy.sum(1).float()/policy.size(1)\r\n    sparse_reward = 1.0-block_use**2\r\n\r\n    _, pred_idx = preds.max(1)\r\n    match = (pred_idx==targets).data\r\n\r\n    reward = sparse_reward\r\n    reward[1-match] = args.penalty\r\n    reward = reward.unsqueeze(1)\r\n\r\n    return reward, match.float()\r\n\r\n\r\ndef train(epoch):\r\n\r\n    agent.train()\r\n    rnet.train()\r\n\r\n    matches, rewards, policies = [], [], []\r\n    for batch_idx, (inputs, targets) in tqdm.tqdm(enumerate(trainloader), total=len(trainloader)):\r\n\r\n        inputs, targets = Variable(inputs), Variable(targets).cuda(async=True)\r\n        if not args.parallel:\r\n            inputs = inputs.cuda()\r\n\r\n        probs, value = agent(inputs)\r\n\r\n        #---------------------------------------------------------------------#\r\n\r\n        policy_map = probs.data.clone()\r\n        policy_map[policy_map<0.5] = 0.0\r\n        policy_map[policy_map>=0.5] = 1.0\r\n        policy_map = Variable(policy_map)\r\n\r\n        probs = probs*args.alpha + (1-probs)*(1-args.alpha)\r\n        distr = Bernoulli(probs)\r\n        policy = distr.sample()\r\n\r\n        v_inputs = Variable(inputs.data, volatile=True)\r\n        preds_map = rnet.forward(v_inputs, policy_map)\r\n        preds_sample = rnet.forward(inputs, policy)\r\n\r\n        reward_map, _ = get_reward(preds_map, targets, policy_map.data)\r\n        reward_sample, match = get_reward(preds_sample, targets, policy.data)\r\n\r\n        advantage = reward_sample - reward_map\r\n        # advantage = advantage.expand_as(policy)\r\n        loss = -distr.log_prob(policy).sum(1, keepdim=True) * Variable(advantage)\r\n    
    loss = loss.sum()\r\n\r\n        #---------------------------------------------------------------------#\r\n        loss += F.cross_entropy(preds_sample, targets)\r\n\r\n        optimizer.zero_grad()\r\n        loss.backward()\r\n        optimizer.step()\r\n\r\n        matches.append(match.cpu())\r\n        rewards.append(reward_sample.cpu())\r\n        policies.append(policy.data.cpu())\r\n\r\n    accuracy, reward, sparsity, variance, policy_set = utils.performance_stats(policies, rewards, matches)\r\n\r\n    log_str = 'E: %d | A: %.3f | R: %.2E | S: %.3f | V: %.3f | #: %d'%(epoch, accuracy, reward, sparsity, variance, len(policy_set))\r\n    print log_str\r\n\r\n    log_value('train_accuracy', accuracy, epoch)\r\n    log_value('train_reward', reward, epoch)\r\n    log_value('train_sparsity', sparsity, epoch)\r\n    log_value('train_variance', variance, epoch)\r\n    log_value('train_unique_policies', len(policy_set), epoch)\r\n\r\n\r\ndef test(epoch):\r\n\r\n    agent.eval()\r\n    rnet.eval()\r\n\r\n    matches, rewards, policies = [], [], []\r\n    for batch_idx, (inputs, targets) in tqdm.tqdm(enumerate(testloader), total=len(testloader)):\r\n\r\n        inputs, targets = Variable(inputs, volatile=True), Variable(targets).cuda(async=True)\r\n        if not args.parallel:\r\n            inputs = inputs.cuda()\r\n\r\n        probs, _ = agent(inputs)\r\n\r\n        policy = probs.data.clone()\r\n        policy[policy<0.5] = 0.0\r\n        policy[policy>=0.5] = 1.0\r\n        policy = Variable(policy)\r\n\r\n        preds = rnet.forward(inputs, policy)\r\n        reward, match = get_reward(preds, targets, policy.data)\r\n\r\n        matches.append(match)\r\n        rewards.append(reward)\r\n        policies.append(policy.data)\r\n\r\n    accuracy, reward, sparsity, variance, policy_set = utils.performance_stats(policies, rewards, matches)\r\n\r\n    log_str = 'TS - A: %.3f | R: %.2E | S: %.3f | V: %.3f | #: %d'%(accuracy, reward, sparsity, variance, 
len(policy_set))\r\n    print log_str\r\n\r\n    log_value('test_accuracy', accuracy, epoch)\r\n    log_value('test_reward', reward, epoch)\r\n    log_value('test_sparsity', sparsity, epoch)\r\n    log_value('test_variance', variance, epoch)\r\n    log_value('test_unique_policies', len(policy_set), epoch)\r\n\r\n    # save the model\r\n    agent_state_dict = agent.module.state_dict() if args.parallel else agent.state_dict()\r\n    rnet_state_dict = rnet.module.state_dict() if args.parallel else rnet.state_dict()\r\n\r\n    state = {\r\n      'agent': agent_state_dict,\r\n      'resnet': rnet_state_dict,\r\n      'epoch': epoch,\r\n      'reward': reward,\r\n      'acc': accuracy\r\n    }\r\n    torch.save(state, args.cv_dir+'/ckpt_E_%d_A_%.3f_R_%.2E_S_%.2f_#_%d.t7'%(epoch, accuracy, reward, sparsity, len(policy_set)))\r\n\r\n\r\n#--------------------------------------------------------------------------------------------------------#\r\ntrainset, testset = utils.get_dataset(args.model, args.data_dir)\r\ntrainloader = torchdata.DataLoader(trainset, batch_size=args.batch_size, shuffle=True, num_workers=4)\r\ntestloader = torchdata.DataLoader(testset, batch_size=args.batch_size, shuffle=False, num_workers=4)\r\nrnet, agent = utils.get_model(args.model)\r\n\r\nif args.pretrained is not None:\r\n    checkpoint = torch.load(args.pretrained)\r\n    key = 'net' if 'net' in checkpoint else 'agent'\r\n    agent.load_state_dict(checkpoint[key])\r\n    print 'loaded pretrained model from', args.pretrained\r\n\r\nstart_epoch = 0\r\nif args.load is not None:\r\n    checkpoint = torch.load(args.load)\r\n    rnet.load_state_dict(checkpoint['resnet'])\r\n    agent.load_state_dict(checkpoint['agent'])\r\n    start_epoch = checkpoint['epoch'] + 1\r\n    print 'loaded agent from', args.load\r\n\r\n\r\nif args.parallel:\r\n    agent = nn.DataParallel(agent)\r\n    rnet = nn.DataParallel(rnet)\r\n\r\nrnet.cuda()\r\nagent.cuda()\r\n\r\noptimizer = 
optim.Adam(list(agent.parameters())+list(rnet.parameters()), lr=args.lr, weight_decay=args.wd)\r\n\r\nconfigure(args.cv_dir+'/log', flush_secs=5)\r\nlr_scheduler = utils.LrScheduler(optimizer, args.lr, args.lr_decay_ratio, args.epoch_step)\r\nfor epoch in range(start_epoch, start_epoch+args.max_epochs+1):\r\n    lr_scheduler.adjust_learning_rate(epoch)\r\n\r\n    train(epoch)\r\n    if epoch%10==0:\r\n        test(epoch)\r\n"
  },
  {
    "path": "models/__init__.py",
    "content": ""
  },
  {
    "path": "models/base.py",
    "content": "import torch.nn as nn\r\nimport math\r\nimport torch\r\nimport torchvision.models as torchmodels\r\nimport re\r\nfrom torch.autograd import Variable\r\nimport torch.optim as optim\r\nimport torch.nn.functional as F\r\nimport numpy as np\r\nimport random\r\nimport torch.nn.utils as torchutils\r\nfrom torch.nn import init, Parameter\r\n\r\n\r\nclass Identity(nn.Module):\r\n    def __init__(self):\r\n        super(Identity, self).__init__()\r\n    def forward(self, x):\r\n        return x\r\n\r\nclass Flatten(nn.Module):\r\n    def __init__(self):\r\n        super(Flatten, self).__init__()\r\n    def forward(self, x):\r\n        return x.view(x.size(0), -1)\r\n\r\ndef conv3x3(in_planes, out_planes, stride=1):\r\n    \"3x3 convolution with padding\"\r\n    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)\r\n\r\nclass BasicBlock(nn.Module):\r\n    expansion = 1\r\n\r\n    def __init__(self, inplanes, planes, stride=1):\r\n        super(BasicBlock, self).__init__()\r\n        self.conv1 = conv3x3(inplanes, planes, stride)\r\n        self.bn1 = nn.BatchNorm2d(planes)\r\n        self.conv2 = conv3x3(planes, planes)\r\n        self.bn2 = nn.BatchNorm2d(planes)\r\n\r\n    def forward(self, x):\r\n       \r\n        out = self.conv1(x)\r\n        out = self.bn1(out)\r\n        out = F.relu(out)\r\n\r\n        out = self.conv2(out)\r\n        out = self.bn2(out)\r\n\r\n        return out\r\n\r\nclass Bottleneck(nn.Module):\r\n    expansion = 4\r\n\r\n    def __init__(self, inplanes, planes, stride=1):\r\n        super(Bottleneck, self).__init__()\r\n        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)\r\n        self.bn1 = nn.BatchNorm2d(planes)\r\n        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)\r\n        self.bn2 = nn.BatchNorm2d(planes)\r\n        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)\r\n        self.bn3 = 
nn.BatchNorm2d(planes * 4)\r\n        self.relu = nn.ReLU(inplace=True)\r\n\r\n    def forward(self, x):\r\n        residual = x\r\n\r\n        out = self.conv1(x)\r\n        out = self.bn1(out)\r\n        out = self.relu(out)\r\n\r\n        out = self.conv2(out)\r\n        out = self.bn2(out)\r\n        out = self.relu(out)\r\n\r\n        out = self.conv3(out)\r\n        out = self.bn3(out)\r\n\r\n        return out\r\n\r\nclass DownsampleB(nn.Module):\r\n\r\n    def __init__(self, nIn, nOut, stride):\r\n        super(DownsampleB, self).__init__()\r\n        self.avg = nn.AvgPool2d(stride)\r\n        self.expand_ratio = nOut // nIn\r\n\r\n    def forward(self, x):\r\n        x = self.avg(x)\r\n        return torch.cat([x] + [x.mul(0)] * (self.expand_ratio - 1), 1)\r\n\r\n\r\n\r\n"
  },
  {
    "path": "models/resnet.py",
    "content": "import torch.nn as nn\r\nimport math\r\nimport torch\r\nimport torchvision.models as torchmodels\r\nimport re\r\nfrom torch.autograd import Variable\r\nimport torch.nn.functional as F\r\nimport numpy as np\r\nimport random\r\nimport torch.nn.init as torchinit\r\nimport math\r\nfrom torch.nn import init, Parameter\r\nimport copy\r\n\r\nfrom models import base\r\nimport utils\r\n\r\n#--------------------------------------------------------------------------------------------------#\r\nclass FlatResNet(nn.Module):\r\n\r\n    def seed(self, x):\r\n        # x = self.relu(self.bn1(self.conv1(x))) -- CIFAR\r\n        # x = self.maxpool(self.relu(self.bn1(self.conv1(x)))) -- ImageNet\r\n        raise NotImplementedError\r\n\r\n    # run a variable policy batch through the resnet implemented as a full mask over the residual\r\n    # fast to train, non-indicative of time saving (use forward_single instead)\r\n    def forward(self, x, policy):\r\n\r\n        x = self.seed(x)\r\n\r\n        t = 0\r\n        for segment, num_blocks in enumerate(self.layer_config):\r\n            for b in range(num_blocks):\r\n                action = policy[:,t].contiguous()\r\n                residual = self.ds[segment](x) if b==0 else x\r\n\r\n                # early termination if all actions in the batch are zero\r\n                if action.data.sum() == 0:\r\n                    x = residual\r\n                    t += 1\r\n                    continue\r\n\r\n                action_mask = action.float().view(-1,1,1,1)\r\n                fx = F.relu(residual + self.blocks[segment][b](x))\r\n                x = fx*action_mask + residual*(1-action_mask)\r\n                t += 1\r\n\r\n        x = self.avgpool(x)\r\n        x = x.view(x.size(0), -1)\r\n        x = self.fc(x)\r\n        return x\r\n\r\n    # run a single, fixed policy for all items in the batch\r\n    # policy is a (15,) vector. 
Use with batch_size=1 for profiling\r\n    def forward_single(self, x, policy):\r\n        x = self.seed(x)\r\n\r\n        t = 0\r\n        for segment, num_blocks in enumerate(self.layer_config):\r\n           for b in range(num_blocks):\r\n                residual = self.ds[segment](x) if b==0 else x\r\n                if policy[t]==1:\r\n                    x = residual + self.blocks[segment][b](x)\r\n                    x = F.relu(x)\r\n                else:\r\n                    x = residual\r\n                t += 1\r\n\r\n        x = self.avgpool(x)\r\n        x = x.view(x.size(0), -1)\r\n        x = self.fc(x)\r\n        return x\r\n\r\n\r\n    def forward_full(self, x):\r\n        x = self.seed(x)\r\n\r\n        for segment, num_blocks in enumerate(self.layer_config):\r\n            for b in range(num_blocks):\r\n                residual = self.ds[segment](x) if b==0 else x\r\n                x = F.relu(residual + self.blocks[segment][b](x))\r\n\r\n        x = self.avgpool(x)\r\n        x = x.view(x.size(0), -1)\r\n        x = self.fc(x)\r\n        return x\r\n\r\n\r\n\r\n# Smaller Flattened Resnet, tailored for CIFAR\r\nclass FlatResNet32(FlatResNet):\r\n\r\n    def __init__(self, block, layers, num_classes=10):\r\n        super(FlatResNet32, self).__init__()\r\n\r\n        self.inplanes = 16\r\n        self.conv1 = base.conv3x3(3, 16)\r\n        self.bn1 = nn.BatchNorm2d(16)\r\n        self.relu = nn.ReLU(inplace=True)\r\n        self.avgpool = nn.AvgPool2d(8)\r\n\r\n        strides = [1, 2, 2]\r\n        filt_sizes = [16, 32, 64]\r\n        self.blocks, self.ds = [], []\r\n        for idx, (filt_size, num_blocks, stride) in enumerate(zip(filt_sizes, layers, strides)):\r\n            blocks, ds = self._make_layer(block, filt_size, num_blocks, stride=stride)\r\n            self.blocks.append(nn.ModuleList(blocks))\r\n            self.ds.append(ds)\r\n\r\n        self.blocks = nn.ModuleList(self.blocks)\r\n        self.ds = nn.ModuleList(self.ds)\r\n      
  self.fc = nn.Linear(64 * block.expansion, num_classes)\r\n        self.fc_dim = 64 * block.expansion\r\n\r\n        self.layer_config = layers\r\n\r\n        for m in self.modules():\r\n            if isinstance(m, nn.Conv2d):\r\n                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels\r\n                m.weight.data.normal_(0, math.sqrt(2. / n))\r\n            elif isinstance(m, nn.BatchNorm2d):\r\n                m.weight.data.fill_(1)\r\n                m.bias.data.zero_()\r\n\r\n    def seed(self, x):\r\n        x = self.relu(self.bn1(self.conv1(x)))\r\n        return x\r\n\r\n    def _make_layer(self, block, planes, blocks, stride=1):\r\n\r\n        downsample = nn.Sequential()\r\n        if stride != 1 or self.inplanes != planes * block.expansion:\r\n            downsample = base.DownsampleB(self.inplanes, planes * block.expansion, stride)\r\n\r\n        layers = [block(self.inplanes, planes, stride)]\r\n        self.inplanes = planes * block.expansion\r\n        for i in range(1, blocks):\r\n            layers.append(block(self.inplanes, planes, 1))\r\n\r\n        return layers, downsample\r\n\r\n\r\n# Regular Flattened Resnet, tailored for Imagenet etc.\r\nclass FlatResNet224(FlatResNet):\r\n\r\n    def __init__(self, block, layers, num_classes=1000):\r\n        self.inplanes = 64\r\n        super(FlatResNet224, self).__init__()\r\n        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)\r\n        self.bn1 = nn.BatchNorm2d(64)\r\n        self.relu = nn.ReLU(inplace=True)\r\n        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)\r\n\r\n        strides = [1, 2, 2, 2]\r\n        filt_sizes = [64, 128, 256, 512]\r\n        self.blocks, self.ds = [], []\r\n        for idx, (filt_size, num_blocks, stride) in enumerate(zip(filt_sizes, layers, strides)):\r\n            blocks, ds = self._make_layer(block, filt_size, num_blocks, stride=stride)\r\n            
self.blocks.append(nn.ModuleList(blocks))\r\n            self.ds.append(ds)\r\n\r\n        self.blocks = nn.ModuleList(self.blocks)\r\n        self.ds = nn.ModuleList(self.ds)\r\n\r\n        self.avgpool = nn.AvgPool2d(7)\r\n        self.fc = nn.Linear(512 * block.expansion, num_classes)\r\n\r\n        for m in self.modules():\r\n            if isinstance(m, nn.Conv2d):\r\n                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels\r\n                m.weight.data.normal_(0, math.sqrt(2. / n))\r\n            elif isinstance(m, nn.BatchNorm2d):\r\n                m.weight.data.fill_(1)\r\n                m.bias.data.zero_()\r\n\r\n        self.layer_config = layers\r\n\r\n    def seed(self, x):\r\n        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))\r\n        return x\r\n\r\n    def _make_layer(self, block, planes, blocks, stride=1):\r\n\r\n        downsample = nn.Sequential()\r\n        if stride != 1 or self.inplanes != planes * block.expansion:\r\n            downsample = nn.Sequential(\r\n                nn.Conv2d(self.inplanes, planes * block.expansion,\r\n                          kernel_size=1, stride=stride, bias=False),\r\n                nn.BatchNorm2d(planes * block.expansion),\r\n            )\r\n\r\n        layers = [block(self.inplanes, planes, stride)]\r\n        self.inplanes = planes * block.expansion\r\n        for i in range(1, blocks):\r\n            layers.append(block(self.inplanes, planes))\r\n\r\n        return layers, downsample\r\n\r\n\r\n#---------------------------------------------------------------------------------------------------------#\r\n\r\n# Class to generate resnetNB or any other config (default is 3B)\r\n# removed the fc layer so it serves as a feature extractor\r\nclass Policy32(nn.Module):\r\n\r\n    def __init__(self, layer_config=[1,1,1], num_blocks=15):\r\n        super(Policy32, self).__init__()\r\n        self.features = FlatResNet32(base.BasicBlock, layer_config, num_classes=10)\r\n        
self.feat_dim = self.features.fc.weight.data.shape[1]\r\n        self.features.fc = nn.Sequential()\r\n\r\n        self.logit = nn.Linear(self.feat_dim, num_blocks)\r\n        self.vnet = nn.Linear(self.feat_dim, 1)\r\n\r\n    def load_state_dict(self, state_dict):\r\n        # support legacy models\r\n        state_dict = {k:v for k,v in state_dict.items() if not k.startswith('features.fc')}\r\n        return super(Policy32, self).load_state_dict(state_dict)\r\n\r\n\r\n    def forward(self, x):\r\n        x = self.features.forward_full(x)\r\n        value = self.vnet(x)\r\n        probs = F.sigmoid(self.logit(x))\r\n        return probs, value\r\n\r\n\r\nclass Policy224(nn.Module):\r\n\r\n    def __init__(self, layer_config=[1,1,1,1], num_blocks=16):\r\n        super(Policy224, self).__init__()\r\n        self.features = FlatResNet224(base.BasicBlock, layer_config, num_classes=1000)\r\n\r\n        resnet18 = torchmodels.resnet18(pretrained=True)\r\n        utils.load_weights_to_flatresnet(resnet18, self.features)\r\n\r\n        self.features.avgpool = nn.AvgPool2d(4)\r\n        self.feat_dim = self.features.fc.weight.data.shape[1]\r\n        self.features.fc = nn.Sequential()\r\n\r\n\r\n        self.logit = nn.Linear(self.feat_dim, num_blocks)\r\n        self.vnet = nn.Linear(self.feat_dim, 1)\r\n\r\n\r\n    def load_state_dict(self, state_dict):\r\n        # support legacy models\r\n        state_dict = {k:v for k,v in state_dict.items() if not k.startswith('features.fc')}\r\n        return super(Policy224, self).load_state_dict(state_dict)\r\n\r\n    def forward(self, x):\r\n        x = F.avg_pool2d(x, 2)\r\n        x = self.features.forward_full(x)\r\n        value = self.vnet(x)\r\n        probs = F.sigmoid(self.logit(x))\r\n        return probs, value\r\n\r\n#--------------------------------------------------------------------------------------------------------#\r\n\r\nclass StepResnet32(FlatResNet32):\r\n\r\n    def __init__(self, block, layers, 
num_classes, joint=False):\r\n        super(StepResnet32, self).__init__(block, layers, num_classes)\r\n        self.eval() # default to eval mode\r\n\r\n        self.joint = joint\r\n\r\n        self.state_ptr = {}\r\n        t = 0\r\n        for segment, num_blocks in enumerate(self.layer_config):\r\n            for b in range(num_blocks):\r\n                self.state_ptr[t] = (segment, b)\r\n                t += 1\r\n\r\n    def seed(self, x):\r\n        self.state = self.relu(self.bn1(self.conv1(x)))\r\n        self.t = 0\r\n\r\n        if self.joint:\r\n            return self.state\r\n        return Variable(self.state.data)\r\n\r\n    def step(self, action):\r\n        segment, b = self.state_ptr[self.t]\r\n        residual = self.ds[segment](self.state) if b==0 else self.state\r\n        action_mask = action.float().view(-1,1,1,1)\r\n\r\n        fx = F.relu(residual + self.blocks[segment][b](self.state))\r\n        self.state = fx*action_mask + residual*(1-action_mask)\r\n        self.t += 1\r\n\r\n        if self.joint:\r\n            return self.state\r\n        return Variable(self.state.data)\r\n\r\n\r\n    def step_single(self, action):\r\n        segment, b = self.state_ptr[self.t]\r\n        residual = self.ds[segment](self.state) if b==0 else self.state\r\n\r\n        if action.data[0,0]==1:\r\n            self.state = F.relu(residual + self.blocks[segment][b](self.state))\r\n        else:\r\n            self.state = residual\r\n\r\n        self.t += 1\r\n\r\n        if self.joint:\r\n            return self.state\r\n        return Variable(self.state.data)\r\n\r\n    def predict(self):\r\n        x = self.avgpool(self.state)\r\n        x = x.view(x.size(0), -1)\r\n        x = self.fc(x)\r\n        return x\r\n\r\n\r\nclass StepPolicy32(nn.Module):\r\n\r\n    def __init__(self, layer_config):\r\n        super(StepPolicy32, self).__init__()\r\n        in_dim = [16] + [16]*layer_config[0] + [32]*layer_config[1] + [64]*(layer_config[2]-1)\r\n        
self.pnet = nn.ModuleList([nn.Linear(dim, 2) for dim in in_dim])\r\n        self.vnet = nn.ModuleList([nn.Linear(dim, 1) for dim in in_dim])\r\n\r\n    def forward(self, state):\r\n        x, t = state\r\n        x = F.avg_pool2d(x, x.size(2)).view(x.size(0), -1) # pool + flatten --> (B, 16/32/64)\r\n        logit = F.softmax(self.pnet[t](x))\r\n        value = self.vnet[t](x)\r\n        return logit, value\r\n"
  },
  {
    "path": "requirements.txt",
    "content": "tqdm\nnumpy\ntensorboard_logger\nargparse\n"
  },
  {
    "path": "test.py",
    "content": "import torch\r\nfrom torch.autograd import Variable\r\nimport torch.utils.data as torchdata\r\nimport torch.nn as nn\r\nimport numpy as np\r\nimport tqdm\r\nimport utils\r\n\r\nimport torch.backends.cudnn as cudnn\r\ncudnn.benchmark = True\r\n\r\nimport argparse\r\nparser = argparse.ArgumentParser()\r\nparser.add_argument('--model', default='R110_C10')\r\nparser.add_argument('--data_dir', default='data/')\r\nparser.add_argument('--load', default=None)\r\nargs = parser.parse_args()\r\n\r\n#---------------------------------------------------------------------------------------------#\r\nclass FConv2d(nn.Conv2d):\r\n\r\n    def __init__(self, in_channels, out_channels, kernel_size, stride=1,\r\n                 padding=0, dilation=1, groups=1, bias=True):\r\n        super(FConv2d, self).__init__(in_channels, out_channels, kernel_size, stride,\r\n                 padding, dilation, groups, bias)\r\n        self.num_ops = 0\r\n\r\n    def forward(self, x):\r\n        output = super(FConv2d, self).forward(x)\r\n        output_area = output.size(-1)*output.size(-2)\r\n        filter_area = np.prod(self.kernel_size)\r\n        self.num_ops += 2*self.in_channels*self.out_channels*filter_area*output_area\r\n        return output\r\n\r\nclass FLinear(nn.Linear):\r\n    def __init__(self, in_features, out_features, bias=True):\r\n        super(FLinear, self).__init__(in_features, out_features, bias)\r\n        self.num_ops = 0\r\n\r\n    def forward(self, x):\r\n        output = super(FLinear, self).forward(x)\r\n        self.num_ops += 2*self.in_features*self.out_features\r\n        return output\r\n\r\ndef count_flops(model, reset=True):\r\n    op_count = 0\r\n    for m in model.modules():\r\n        if hasattr(m, 'num_ops'):\r\n            op_count += m.num_ops\r\n            if reset: # count and reset to 0\r\n                m.num_ops = 0\r\n\r\n    return op_count\r\n\r\n# replace all nn.Conv and nn.Linear layers with layers that count flops\r\nnn.Conv2d 
= FConv2d\r\nnn.Linear = FLinear\r\n\r\n#--------------------------------------------------------------------------------------------#\r\n\r\ndef test():\r\n\r\n    total_ops = []\r\n    matches, policies = [], []\r\n    for batch_idx, (inputs, targets) in tqdm.tqdm(enumerate(testloader), total=len(testloader)):\r\n\r\n        inputs, targets = Variable(inputs, volatile=True).cuda(), Variable(targets).cuda()\r\n        probs, _ = agent(inputs)\r\n\r\n        policy = probs.clone()\r\n        policy[policy<0.5] = 0.0\r\n        policy[policy>=0.5] = 1.0\r\n\r\n        preds = rnet.forward_single(inputs, policy.data.squeeze(0))\r\n        _ , pred_idx = preds.max(1)\r\n        match = (pred_idx==targets).data.float()\r\n\r\n        matches.append(match)\r\n        policies.append(policy.data)\r\n\r\n        ops = count_flops(agent) + count_flops(rnet)\r\n        total_ops.append(ops)\r\n\r\n    accuracy, _, sparsity, variance, policy_set = utils.performance_stats(policies, matches, matches)\r\n    ops_mean, ops_std = np.mean(total_ops), np.std(total_ops)\r\n\r\n    log_str = u'''\r\n    Accuracy: %.3f\r\n    Block Usage: %.3f \\u00B1 %.3f\r\n    FLOPs/img: %.2E \\u00B1 %.2E\r\n    Unique Policies: %d\r\n    '''%(accuracy, sparsity, variance, ops_mean, ops_std, len(policy_set))\r\n\r\n    print log_str\r\n\r\n#--------------------------------------------------------------------------------------------------------#\r\ntrainset, testset = utils.get_dataset(args.model, args.data_dir)\r\ntestloader = torchdata.DataLoader(testset, batch_size=1, shuffle=False, num_workers=4)\r\nrnet, agent = utils.get_model(args.model)\r\n\r\n# default policy: zero weights and a bias of 10 give sigmoid(10) ~ 1, so every block\r\n# is executed unless a checkpoint containing agent weights is loaded below\r\nagent.logit.weight.data.fill_(0)\r\nagent.logit.bias.data.fill_(10)\r\n\r\nif args.load is not None:\r\n    print \"loading checkpoints\"\r\n    utils.load_checkpoint(rnet, agent, args.load)\r\n\r\nrnet.eval().cuda()\r\nagent.eval().cuda()\r\n\r\ntest()\r\n"
  },
  {
    "path": "utils.py",
    "content": "import os\nimport re\nimport torch\nimport torchvision.transforms as transforms\nimport torchvision.datasets as torchdata\nimport numpy as np\n\n# Save the training script and all the arguments to a file so that you\n# don't feel like an idiot later when you can't replicate results\nimport shutil\ndef save_args(__file__, args):\n    shutil.copy(os.path.basename(__file__), args.cv_dir)\n    with open(args.cv_dir+'/args.txt','w') as f:\n        f.write(str(args))\n\ndef performance_stats(policies, rewards, matches):\n\n    policies = torch.cat(policies, 0)\n    rewards = torch.cat(rewards, 0)\n    accuracy = torch.cat(matches, 0).mean()\n\n    reward = rewards.mean()\n    sparsity = policies.sum(1).mean()\n    variance = policies.sum(1).std()\n\n    policy_set = [p.cpu().numpy().astype(np.int).astype(np.str) for p in policies]\n    policy_set = set([''.join(p) for p in policy_set])\n\n    return accuracy, reward, sparsity, variance, policy_set\n\nclass LrScheduler:\n    def __init__(self, optimizer, base_lr, lr_decay_ratio, epoch_step):\n        self.base_lr = base_lr\n        self.lr_decay_ratio = lr_decay_ratio\n        self.epoch_step = epoch_step\n        self.optimizer = optimizer\n\n    def adjust_learning_rate(self, epoch):\n        \"\"\"Sets the learning rate to the initial LR decayed by 10 every 30 epochs\"\"\"\n        lr = self.base_lr * (self.lr_decay_ratio ** (epoch // self.epoch_step))\n        for param_group in self.optimizer.param_groups:\n            param_group['lr'] = lr\n            if epoch%self.epoch_step==0:\n                print '# setting learning_rate to %.2E'%lr\n\n\n# load model weights trained using scripts from https://github.com/felixgwu/img_classification_pk_pytorch OR\n# from torchvision models into our flattened resnets\ndef load_weights_to_flatresnet(source_model, target_model):\n\n    # compatibility for nn.Modules + checkpoints\n    if hasattr(source_model, 'state_dict'):\n        source_model = {'state_dict': 
source_model.state_dict()}\n    source_state = source_model['state_dict']\n    target_state = target_model.state_dict()\n\n    # remove the module. prefix if it exists (thanks nn.DataParallel)\n    if source_state.keys()[0].startswith('module.'):\n        source_state = {k[7:]:v for k,v in source_state.items()}\n\n\n    common = set(['conv1.weight', 'bn1.weight', 'bn1.bias', 'bn1.running_mean', 'bn1.running_var','fc.weight', 'fc.bias'])\n    for key in source_state.keys():\n\n        if key in common:\n            target_state[key] = source_state[key]\n            continue\n\n        if 'downsample' in key:\n            layer, num, item = re.match('layer(\\d+).*\\.(\\d+)\\.(.*)', key).groups()\n            translated = 'ds.%s.%s.%s'%(int(layer)-1, num, item)\n        else:\n            layer, item = re.match('layer(\\d+)\\.(.*)', key).groups()\n            translated = 'blocks.%s.%s'%(int(layer)-1, item)\n\n\n        if translated in target_state.keys():\n            target_state[translated] = source_state[key]\n        else:\n            print translated, 'block missing'\n\n    target_model.load_state_dict(target_state)\n    return target_model\n\ndef load_checkpoint(rnet, agent, load):\n    if load=='nil':\n        return None\n\n    checkpoint = torch.load(load)\n    if 'resnet' in checkpoint:\n        rnet.load_state_dict(checkpoint['resnet'])\n        print 'loaded resnet from', os.path.basename(load)\n    if 'agent' in checkpoint:\n        agent.load_state_dict(checkpoint['agent'])\n        print 'loaded agent from', os.path.basename(load)\n    # backward compatibility (some old checkpoints)\n    if 'net' in checkpoint:\n        checkpoint['net'] = {k:v for k,v in checkpoint['net'].items() if 'features.fc' not in k}\n        agent.load_state_dict(checkpoint['net'])\n        print 'loaded agent from', os.path.basename(load)\n\n\ndef get_transforms(rnet, dset):\n\n    # Only the R32 pretrained model subtracts the mean, sorry :(\n    if dset=='C10' and 
rnet=='R32':\n        mean = [x/255.0 for x in [125.3, 123.0, 113.9]]\n        std = [x/255.0 for x in [63.0, 62.1, 66.7]]\n        transform_train = transforms.Compose([\n            transforms.RandomCrop(32, padding=4),\n            transforms.RandomHorizontalFlip(),\n            transforms.ToTensor(),\n            transforms.Normalize(mean, std)\n            ])\n\n        transform_test = transforms.Compose([\n            transforms.ToTensor(),\n            transforms.Normalize(mean, std)\n            ])\n\n    elif dset=='C100' or (dset=='C10' and rnet!='R32'):\n        transform_train = transforms.Compose([\n            transforms.RandomCrop(32, padding=4),\n            transforms.RandomHorizontalFlip(),\n            transforms.ToTensor(),\n            ])\n\n        transform_test = transforms.Compose([\n            transforms.ToTensor(),\n            ])\n\n    elif dset=='ImgNet':\n        mean = [0.485, 0.456, 0.406]\n        std = [0.229, 0.224, 0.225]\n        transform_train = transforms.Compose([\n            transforms.Scale(256),\n            transforms.RandomCrop(224),\n            transforms.RandomHorizontalFlip(),\n            transforms.ToTensor(),\n            transforms.Normalize(mean, std)\n            ])\n\n        transform_test = transforms.Compose([\n            transforms.Scale(256),\n            transforms.CenterCrop(224),\n            transforms.ToTensor(),\n            transforms.Normalize(mean, std)\n            ])\n\n\n    return transform_train, transform_test\n\n# Pick from the datasets available and the hundreds of models we have lying around depending on the requirements.\ndef get_dataset(model, root='data/'):\n\n    rnet, dset = model.split('_')\n    transform_train, transform_test = get_transforms(rnet, dset)\n\n    if dset=='C10':\n        trainset = torchdata.CIFAR10(root=root, train=True, download=True, transform=transform_train)\n        testset = torchdata.CIFAR10(root=root, train=False, download=True, 
transform=transform_test)\n    elif dset=='C100':\n        trainset = torchdata.CIFAR100(root=root, train=True, download=True, transform=transform_train)\n        testset = torchdata.CIFAR100(root=root, train=False, download=True, transform=transform_test)\n    elif dset=='ImgNet':\n        trainset = torchdata.ImageFolder(root+'/train/', transform_train)\n        testset = torchdata.ImageFolder(root+'/val/', transform_test)\n\n    return trainset, testset\n\n# Make a new if statement for every new model variety you want to index\ndef get_model(model):\n\n    from models import resnet, base\n\n    if model=='R32_C10':\n        rnet_checkpoint = 'cv/pretrained/R32_C10/pk_E_164_A_0.923.t7'\n        layer_config = [5, 5, 5]\n        rnet = resnet.FlatResNet32(base.BasicBlock, layer_config, num_classes=10)\n        agent = resnet.Policy32([1,1,1], num_blocks=15)\n\n    elif model=='R110_C10':\n        rnet_checkpoint = 'cv/pretrained/R110_C10/pk_E_130_A_0.932.t7'\n        layer_config = [18, 18, 18]\n        rnet = resnet.FlatResNet32(base.BasicBlock, layer_config, num_classes=10)\n        agent = resnet.Policy32([1,1,1], num_blocks=54)\n\n    elif model=='R32_C100':\n        rnet_checkpoint = 'cv/pretrained/R32_C100/pk_E_164_A_0.693.t7'\n        layer_config = [5, 5, 5]\n        rnet = resnet.FlatResNet32(base.BasicBlock, layer_config, num_classes=100)\n        agent = resnet.Policy32([1,1,1], num_blocks=15)\n\n    elif model=='R110_C100':\n        rnet_checkpoint = 'cv/pretrained/R110_C100/pk_E_160_A_0.723.t7'\n        layer_config = [18, 18, 18]\n        rnet = resnet.FlatResNet32(base.BasicBlock, layer_config, num_classes=100)\n        agent = resnet.Policy32([1,1,1], num_blocks=54)\n\n    elif model=='R101_ImgNet':\n        rnet_checkpoint = 'cv/pretrained/R101_ImgNet/ImageNet_R101_224_76.464'\n        layer_config = [3,4,23,3]\n        rnet = resnet.FlatResNet224(base.Bottleneck, layer_config, num_classes=1000)\n        agent = resnet.Policy224([1,1,1,1], 
num_blocks=33)\n\n    # load pretrained weights into flat ResNet\n    rnet_checkpoint = torch.load(rnet_checkpoint)\n    load_weights_to_flatresnet(rnet_checkpoint, rnet)\n\n    return rnet, agent\n"
  }
]