[
  {
    "path": "MODELS/attention.py",
    "content": "import torch.nn as nn\nimport torch\nfrom torch.nn import functional as F\n\n\nclass Channel_Att(nn.Module):\n    def __init__(self, channels, t=16):\n        super(Channel_Att, self).__init__()\n        self.channels = channels\n      \n        self.bn2 = nn.BatchNorm2d(self.channels, affine=True)\n\n\n    def forward(self, x):\n        residual = x\n\n        x = self.bn2(x)\n        weight_bn = self.bn2.weight.data.abs() / torch.sum(self.bn2.weight.data.abs())\n        x = x.permute(0, 2, 3, 1).contiguous()\n        x = torch.mul(weight_bn, x)\n        x = x.permute(0, 3, 1, 2).contiguous()\n        \n        x = torch.sigmoid(x) * residual #\n        \n        return x\n\n\nclass Att(nn.Module):\n    def __init__(self, channels,shape, out_channels=None, no_spatial=True):\n        super(Att, self).__init__()\n        self.Channel_Att = Channel_Att(channels)\n  \n    def forward(self, x):\n        x_out1=self.Channel_Att(x)\n \n        return x_out1  \n"
  },
  {
    "path": "MODELS/bam.py",
    "content": "import torch\nimport math\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nclass Flatten(nn.Module):\n    def forward(self, x):\n        return x.view(x.size(0), -1)\nclass ChannelGate(nn.Module):\n    def __init__(self, gate_channel, reduction_ratio=16, num_layers=1):\n        super(ChannelGate, self).__init__()\n        #self.gate_activation = gate_activation\n        self.gate_c = nn.Sequential()\n        self.gate_c.add_module( 'flatten', Flatten() )\n        gate_channels = [gate_channel]\n        gate_channels += [gate_channel // reduction_ratio] * num_layers\n        gate_channels += [gate_channel]\n        for i in range( len(gate_channels) - 2 ):\n            self.gate_c.add_module( 'gate_c_fc_%d'%i, nn.Linear(gate_channels[i], gate_channels[i+1]) )\n            self.gate_c.add_module( 'gate_c_bn_%d'%(i+1), nn.BatchNorm1d(gate_channels[i+1]) )\n            self.gate_c.add_module( 'gate_c_relu_%d'%(i+1), nn.ReLU() )\n        self.gate_c.add_module( 'gate_c_fc_final', nn.Linear(gate_channels[-2], gate_channels[-1]) )\n    def forward(self, in_tensor):\n        avg_pool = F.avg_pool2d( in_tensor, in_tensor.size(2), stride=in_tensor.size(2) )\n        return self.gate_c( avg_pool ).unsqueeze(2).unsqueeze(3).expand_as(in_tensor)\n\nclass SpatialGate(nn.Module):\n    def __init__(self, gate_channel, reduction_ratio=16, dilation_conv_num=2, dilation_val=4):\n        super(SpatialGate, self).__init__()\n        self.gate_s = nn.Sequential()\n        self.gate_s.add_module( 'gate_s_conv_reduce0', nn.Conv2d(gate_channel, gate_channel//reduction_ratio, kernel_size=1))\n        self.gate_s.add_module( 'gate_s_bn_reduce0',\tnn.BatchNorm2d(gate_channel//reduction_ratio) )\n        self.gate_s.add_module( 'gate_s_relu_reduce0',nn.ReLU() )\n        for i in range( dilation_conv_num ):\n            self.gate_s.add_module( 'gate_s_conv_di_%d'%i, nn.Conv2d(gate_channel//reduction_ratio, gate_channel//reduction_ratio, kernel_size=3, \\\n\t\t\t\t\t\tpadding=dilation_val, dilation=dilation_val) )\n            self.gate_s.add_module( 'gate_s_bn_di_%d'%i, nn.BatchNorm2d(gate_channel//reduction_ratio) )\n            self.gate_s.add_module( 'gate_s_relu_di_%d'%i, nn.ReLU() )\n        self.gate_s.add_module( 'gate_s_conv_final', nn.Conv2d(gate_channel//reduction_ratio, 1, kernel_size=1) )\n    def forward(self, in_tensor):\n        return self.gate_s( in_tensor ).expand_as(in_tensor)\nclass BAM(nn.Module):\n    def __init__(self, gate_channel):\n        super(BAM, self).__init__()\n        self.channel_att = ChannelGate(gate_channel)\n        self.spatial_att = SpatialGate(gate_channel)\n    def forward(self,in_tensor):\n        att = 1 + F.sigmoid( self.channel_att(in_tensor) * self.spatial_att(in_tensor) )\n        return att * in_tensor\n"
  },
  {
    "path": "MODELS/cbam.py",
    "content": "import torch\nimport math\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nclass BasicConv(nn.Module):\n    def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=True, bn=True, bias=False):\n        super(BasicConv, self).__init__()\n        self.out_channels = out_planes\n        self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)\n        self.bn = nn.BatchNorm2d(out_planes,eps=1e-5, momentum=0.01, affine=True) if bn else None\n        self.relu = nn.ReLU() if relu else None\n\n    def forward(self, x):\n        x = self.conv(x)\n        if self.bn is not None:\n            x = self.bn(x)\n        if self.relu is not None:\n            x = self.relu(x)\n        return x\n\nclass Flatten(nn.Module):\n    def forward(self, x):\n        return x.view(x.size(0), -1)\n\nclass ChannelGate(nn.Module):\n    def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max']):\n        super(ChannelGate, self).__init__()\n        self.gate_channels = gate_channels\n        self.mlp = nn.Sequential(\n            Flatten(),\n            nn.Linear(gate_channels, gate_channels // reduction_ratio),\n            nn.ReLU(),\n            nn.Linear(gate_channels // reduction_ratio, gate_channels)\n            )\n        self.pool_types = pool_types\n    def forward(self, x):\n        channel_att_sum = None\n        for pool_type in self.pool_types:\n            if pool_type=='avg':\n                avg_pool = F.avg_pool2d( x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))\n                channel_att_raw = self.mlp( avg_pool )\n            elif pool_type=='max':\n                max_pool = F.max_pool2d( x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))\n                channel_att_raw = self.mlp( max_pool )\n            elif pool_type=='lp':\n                lp_pool = F.lp_pool2d( x, 2, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))\n                channel_att_raw = self.mlp( lp_pool )\n            elif pool_type=='lse':\n                # LSE pool only\n                lse_pool = logsumexp_2d(x)\n                channel_att_raw = self.mlp( lse_pool )\n\n            if channel_att_sum is None:\n                channel_att_sum = channel_att_raw\n            else:\n                channel_att_sum = channel_att_sum + channel_att_raw\n\n        scale = torch.sigmoid( channel_att_sum ).unsqueeze(2).unsqueeze(3).expand_as(x)\n        return x * scale\n\ndef logsumexp_2d(tensor):\n    tensor_flatten = tensor.view(tensor.size(0), tensor.size(1), -1)\n    s, _ = torch.max(tensor_flatten, dim=2, keepdim=True)\n    outputs = s + (tensor_flatten - s).exp().sum(dim=2, keepdim=True).log()\n    return outputs\n\nclass ChannelPool(nn.Module):\n    def forward(self, x):\n        return torch.cat( (torch.max(x,1)[0].unsqueeze(1), torch.mean(x,1).unsqueeze(1)), dim=1 )\n\nclass SpatialGate(nn.Module):\n    def __init__(self):\n        super(SpatialGate, self).__init__()\n        kernel_size = 7\n        self.compress = ChannelPool()\n        self.spatial = BasicConv(2, 1, kernel_size, stride=1, padding=(kernel_size-1) // 2, relu=False)\n    def forward(self, x):\n        x_compress = self.compress(x)\n        x_out = self.spatial(x_compress)\n        scale = torch.sigmoid(x_out) # broadcasting\n        return x * scale\n\nclass CBAM(nn.Module):\n    def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max'], no_spatial=False):\n        super(CBAM, self).__init__()\n        self.ChannelGate = ChannelGate(gate_channels, reduction_ratio, pool_types)\n        self.no_spatial=no_spatial\n        if not no_spatial:\n            self.SpatialGate = SpatialGate()\n    def forward(self, x):\n        x_out = self.ChannelGate(x)\n        if not self.no_spatial:\n            x_out = self.SpatialGate(x_out)\n        return x_out\n"
  },
  {
    "path": "MODELS/model_resnet.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport math\nfrom torch.nn import init\nfrom .cbam import *\nfrom .bam import *\nfrom .attention import *\n\n\ndef conv3x3(in_planes, out_planes, stride=1):\n    \"3x3 convolution with padding\"\n    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,\n                     padding=1, bias=False)\n\n\nclass BasicBlock(nn.Module):\n    expansion = 1\n\n    def __init__(self, inplanes, planes, shape,stride=1, downsample=None, use_cbam=False, use_nam=False,no_spatial=True):\n        super(BasicBlock, self).__init__()\n        self.conv1 = conv3x3(inplanes, planes, stride)\n        self.bn1 = nn.BatchNorm2d(planes)\n        self.relu = nn.ReLU(inplace=True)\n        self.conv2 = conv3x3(planes, planes)\n        self.bn2 = nn.BatchNorm2d(planes)\n        self.downsample = downsample\n        self.stride = stride\n        self.no_spatial = no_spatial\n\n        if use_cbam:\n            self.cbam = CBAM(planes, 16)\n        else:\n            self.cbam = None\n\n        if use_nam:\n            self.nam = Att(planes,no_spatial=self.no_spatial,shape=shape)\n        else:\n            self.nam = None\n\n    def forward(self, x):\n        residual = x\n\n        out = self.conv1(x)\n        out = self.bn1(out)\n        out = self.relu(out)\n\n        out = self.conv2(out)\n        out = self.bn2(out)\n\n        if self.downsample is not None:\n            residual = self.downsample(x)\n\n        if not self.cbam is None:\n            out = self.cbam(out)\n\n        if not self.nam is None:\n            out = self.nam(out)\n\n        out += residual\n        out = self.relu(out)\n\n        return out\n\n\nclass Bottleneck(nn.Module):\n    expansion = 4\n\n    def __init__(self, inplanes, planes,shape, stride=1, downsample=None, use_cbam=False, use_nam=False, no_spatial=False):\n        super(Bottleneck, self).__init__()\n        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)\n        self.bn1 = nn.BatchNorm2d(planes)\n        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,\n                               padding=1, bias=False)\n        self.bn2 = nn.BatchNorm2d(planes)\n        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)\n        self.bn3 = nn.BatchNorm2d(planes * 4)\n        self.relu = nn.ReLU(inplace=True)\n        self.downsample = downsample\n        self.stride = stride\n        self.no_spatial = no_spatial\n\n        if use_cbam:\n            self.cbam = CBAM(planes * 4, 16)\n        else:\n            self.cbam = None\n        \n        if use_nam:\n            self.nam = Att(planes * 4, no_spatial=self.no_spatial,shape=shape)\n  \n        else:\n            self.nam = None\n        \n    def forward(self, x):\n        \n        \n        residual = x\n        \n        out = self.conv1(x)\n        out = self.bn1(out)\n        out = self.relu(out)\n\n        out = self.conv2(out)\n        out = self.bn2(out)\n        out = self.relu(out)\n\n        out = self.conv3(out)\n        out = self.bn3(out)\n        \n\n        if self.downsample is not None:\n            residual = self.downsample(x)\n\n        if not self.cbam is None:\n            out = self.cbam(out)\n\n        if not self.nam is None:\n            out = self.nam(out)\n\n        out += residual\n\n\n        out = self.relu(out)\n\n        return out\n\n\nclass ResNet(nn.Module):\n    def __init__(self, block, layers, network_type, num_classes, att_type=None):\n        self.inplanes = 64\n        super(ResNet, self).__init__()\n        self.network_type = network_type\n        # different model config between ImageNet and CIFAR\n\n        if network_type == \"ImageNet\":\n            self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)\n            self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)\n            self.avgpool = nn.AvgPool2d(7)\n            shape=56\n        else:\n            self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)\n            shape=32\n\n        self.bn1 = nn.BatchNorm2d(64)\n        self.relu = nn.ReLU(inplace=True)\n\n        if att_type == 'BAM':\n            self.bam1 = BAM(64*block.expansion)\n            self.bam2 = BAM(128*block.expansion)\n            self.bam3 = BAM(256*block.expansion)\n        else:\n            self.bam1, self.bam2, self.bam3 = None, None, None\n\n        self.layer1 = self._make_layer(block, 64, shape,layers[0], att_type=att_type, no_spatial=False)  \n        self.layer2 = self._make_layer(block, 128,shape//2, layers[1], stride=2, att_type=att_type, no_spatial=False)\n        self.layer3 = self._make_layer(block, 256, shape//4,layers[2], stride=2, att_type=att_type, no_spatial=False)\n        self.layer4 = self._make_layer(block, 512, shape//8, layers[3], stride=2, att_type=att_type, no_spatial=False)  \n\n        self.fc = nn.Linear(512 * block.expansion, num_classes)\n        \n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                nn.init.kaiming_uniform_(m.weight.data)\n            elif isinstance(m, nn.BatchNorm2d):\n                m.weight.data.fill_(1)\n                m.bias.data.zero_()\n            elif isinstance(m, nn.BatchNorm1d):\n                m.weight.data.fill_(1)\n                m.bias.data.zero_()\n        '''\n        init.kaiming_normal_(self.fc.weight)\n        for key in self.state_dict():\n            if key.split('.')[-1] == \"weight\":\n                if \"conv\" in key:\n                    init.kaiming_normal_(self.state_dict()[key], mode='fan_out')\n                if \"bn\" in key:\n                    if \"SpatialGate\" in key:\n                        self.state_dict()[key][...] = 0\n                    else:\n                        self.state_dict()[key][...] = 1\n            elif key.split(\".\")[-1] == 'bias':\n                self.state_dict()[key][...] = 0\n        '''\n\n    def _make_layer(self, block, planes, shape, blocks, stride=1, att_type=None, no_spatial=True):\n        downsample = None\n        if stride != 1 or self.inplanes != planes * block.expansion:\n            downsample = nn.Sequential(\n                nn.Conv2d(self.inplanes, planes * block.expansion,\n                          kernel_size=1, stride=stride, bias=False),\n                nn.BatchNorm2d(planes * block.expansion),\n            )\n        layers = []\n        layers.append(\n            block(self.inplanes, planes, shape,stride, downsample, use_cbam=att_type == 'CBAM', use_nam=att_type == 'NAM',\n                  no_spatial=no_spatial))\n        self.inplanes = planes * block.expansion\n        for i in range(1, blocks):\n            layers.append(block(self.inplanes, planes, shape,use_cbam=att_type == 'CBAM', use_nam=att_type == 'NAM',\n                                no_spatial=no_spatial))\n\n        return nn.Sequential(*layers)\n\n    def forward(self, x,label=None):\n        x = self.conv1(x)\n        x = self.bn1(x)\n        x = self.relu(x)\n        if self.network_type == \"ImageNet\":\n            x = self.maxpool(x)\n\n        x = self.layer1(x)\n        if not self.bam1 is None:\n            x = self.bam1(x)\n\n        x = self.layer2(x)\n        if not self.bam2 is None:\n            x = self.bam2(x)\n\n        x = self.layer3(x)\n        if not self.bam3 is None:\n            x = self.bam3(x)\n\n        x = self.layer4(x)\n\n        if self.network_type == \"ImageNet\":\n            x = self.avgpool(x)\n        else:\n            x = F.avg_pool2d(x, 4)\n        x = x.view(x.size(0), -1)\n        x = self.fc(x)\n        return x\n\n\ndef ResidualNet(network_type, depth, num_classes, att_type):\n    assert network_type in [\"ImageNet\", \"CIFAR10\", \"CIFAR100\"], \"network type should be ImageNet or CIFAR10 / CIFAR100\"\n    assert depth in [18, 34, 50, 101], 'network depth should be 18, 34, 50 or 101'\n\n    if depth == 18:\n        model = ResNet(BasicBlock, [2, 2, 2, 2], network_type, num_classes, att_type)\n\n    elif depth == 34:\n        model = ResNet(BasicBlock, [3, 4, 6, 3], network_type, num_classes, att_type)\n\n    elif depth == 50:\n        model = ResNet(Bottleneck, [3, 4, 6, 3], network_type, num_classes, att_type)\n\n    elif depth == 101:\n        model = ResNet(Bottleneck, [3, 4, 23, 3], network_type, num_classes, att_type)\n\n    return model\n"
  },
  {
    "path": "README.md",
    "content": "# NAM"
  },
  {
    "path": "train_cifar100.py",
    "content": "import argparse\nimport os\nimport shutil\nimport time\nimport random\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.parallel\nimport torch.backends.cudnn as cudnn\nimport torch.optim\nimport torch.utils\nimport torch.utils.data\nimport torchvision.transforms as transforms\nimport torchvision.datasets as datasets\nimport torchvision.models as models\nfrom MODELS.model_resnet import *\nfrom PIL import ImageFile\nfrom thop import profile\n\nImageFile.LOAD_TRUNCATED_IMAGES = True\nmodel_names = sorted(name for name in models.__dict__\n    if name.islower() and not name.startswith(\"__\")\n    and callable(models.__dict__[name]))\n\nparser = argparse.ArgumentParser(description='PyTorch ImageNet Training')\nparser.add_argument('data', metavar='DIR', help='path to dataset')\nparser.add_argument('--arch', '-a', metavar='ARCH', default='resnet',help='model architecture: ' +' | '.join(model_names) +\n                        ' (default: resnet18)')\nparser.add_argument('--depth', default=50, type=int, metavar='D', help='model depth')\nparser.add_argument('--ngpu', default=4, type=int, metavar='G', help='number of gpus to use')\nparser.add_argument('-j', '--workers', default=4, type=int, metavar='N',help='number of data loading workers (default: 4)')\nparser.add_argument('--epochs', default=200, type=int, metavar='N', help='number of total epochs to run')\nparser.add_argument('--start-epoch', default=0, type=int, metavar='N', help='manual epoch number (useful on restarts)')\nparser.add_argument('-b', '--batch-size', default=128, type=int, metavar='N', help='mini-batch size (default: 256)')\nparser.add_argument('--lr', '--learning-rate', default=0.1, type=float, metavar='LR', help='initial learning rate')\nparser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum')\nparser.add_argument('--weight-decay', '--wd', default=5e-4, type=float, metavar='W', help='weight decay (default: 1e-4)')\nparser.add_argument('--print-freq', '-p', default=100, type=int,metavar='N', help='print frequency (default: 10)')\nparser.add_argument('--resume', default='', type=str, metavar='PATH', help='path to latest checkpoint (default: none)')\nparser.add_argument(\"--seed\", type=int, default=1234, metavar='BS', help='input batch size for training (default: 64)')\nparser.add_argument(\"--prefix\", type=str, required=True, metavar='PFX', help='prefix for logging & checkpoint saving')\nparser.add_argument('--evaluate', dest='evaluate', action='store_true', help='evaluation only')\nparser.add_argument('--att-type', type=str, choices=['BAM', 'CBAM','NAM'], default=None)\nparser.add_argument('--milestones',type=list,default=[60, 120, 160],help='optimizer milestones')\nparser.add_argument('--set', type=str, default='cifar100', help='location of the data corpus')\nparser.add_argument('--gamma',type=float,default=0.2,help='gamma')##\nparser.add_argument('--sparsity-regularization', '-sr', dest='sr', action='store_true',help='train with channel sparsity regularization')\nparser.add_argument('--s', type=float, default=0.0001,help='scale sparse rate (default: 0.0001)')\nbest_prec1 = 0\n\nif not os.path.exists('./checkpoints'):\n    os.mkdir('./checkpoints')\n    \ndef updateBN(model):\n    Op = model._modules.items()\n    for m in Op:\n        if m[0]=='layer1':\n            for m1 in m[1]:\n                m1.nam.Channel_Att.bn2.weight.grad.data.add_(args.s * torch.sign(m1.nam.Channel_Att.bn2.weight.data))\n        if m[0]=='layer2':\n            for m2 in m[1]:\n                m2.nam.Channel_Att.bn2.weight.grad.data.add_(args.s * torch.sign(m2.nam.Channel_Att.bn2.weight.data))\n        if m[0]=='layer3':\n            for m3 in m[1]:\n                m3.nam.Channel_Att.bn2.weight.grad.data.add_(args.s * torch.sign(m3.nam.Channel_Att.bn2.weight.data))\n        if m[0]=='layer4':\n            for m4 in m[1]:\n                m4.nam.Channel_Att.bn2.weight.grad.data.add_(args.s * torch.sign(m4.nam.Channel_Att.bn2.weight.data))\n\ndef main():\n    global args, best_prec1\n    global viz, train_lot, test_lot\n    args = parser.parse_args()\n    print (\"args\", args)\n\n    torch.manual_seed(args.seed)\n    torch.cuda.manual_seed_all(args.seed)\n    random.seed(args.seed)\n\n    # create model\n    if args.arch == \"resnet\":\n        model = ResidualNet( 'CIFAR100', args.depth, 100, args.att_type )\n    \n    inputs = torch.randn(1, 3, 32, 32)\n    total_ops, total_params = profile(model, (inputs,), verbose=False)\n    print(\" %.2f | %.2f\" % ( total_params / (1000 ** 2), total_ops / (1000 ** 3)))\n\n    # define loss function (criterion) and optimizer\n    criterion = nn.CrossEntropyLoss().cuda()\n\n    optimizer = torch.optim.SGD(model.parameters(), args.lr,\n                            momentum=args.momentum,\n                            weight_decay=args.weight_decay)\n    model = torch.nn.DataParallel(model, device_ids=list(range(args.ngpu)))\n    #model = torch.nn.DataParallel(model).cuda()\n    model = model.cuda()\n    #print (\"model\")\n    #print (model)\n    \n    train_scheduler = torch.optim.lr_scheduler.MultiStepLR(\n        optimizer, milestones=args.milestones, gamma=args.gamma)##\n\n    # get the number of model parameters\n    print('Number of model parameters: {}'.format(\n        sum([p.data.nelement() for p in model.parameters()])))\n\n    # optionally resume from a checkpoint\n    if args.resume:\n        if os.path.isfile(args.resume):\n            print(\"=> loading checkpoint '{}'\".format(args.resume))\n            checkpoint = torch.load(args.resume)\n            args.start_epoch = checkpoint['epoch']\n            best_prec1 = checkpoint['best_prec1']\n            model.load_state_dict(checkpoint['state_dict'])\n            if 'optimizer' in checkpoint:\n                optimizer.load_state_dict(checkpoint['optimizer'])\n            print(\"=> loaded checkpoint '{}' (epoch {})\"\n                  .format(args.resume, checkpoint['epoch']))\n        else:\n            print(\"=> no checkpoint found at '{}'\".format(args.resume))\n\n\n    cudnn.benchmark = True\n    \n    if args.set=='cifar100':\n        train_transform, valid_transform = data_transforms_cifar10(args)\n        train_data = datasets.CIFAR100(root=args.data, train=True, download=True, transform=train_transform)\n        valid_data = datasets.CIFAR100(root=args.data, train=False, download=True, transform=valid_transform)\n    else:\n        train_transform, valid_transform = data_transforms_cifar10(args)\n        train_data = datasets.CIFAR10(root=args.data, train=True, download=True, transform=train_transform)\n        valid_data = datasets.CIFAR10(root=args.data, train=False, download=True, transform=valid_transform)\n        \n    train_loader = torch.utils.data.DataLoader(\n      train_data, batch_size=args.batch_size, shuffle=True, pin_memory=True, num_workers=args.workers)\n\n    val_loader = torch.utils.data.DataLoader(\n      valid_data, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=args.workers)\n    \n    for epoch in range(args.start_epoch, args.epochs):\n        #adjust_learning_rate(optimizer, epoch)\n        \n        # train for one epoch\n        train(train_loader, model, criterion, optimizer, epoch)\n        \n        # evaluate on validation set\n        prec1 = validate(val_loader, model, criterion, epoch)\n        \n        train_scheduler.step()\n        \n        # remember best prec@1 and save checkpoint\n        is_best = prec1 > best_prec1\n        best_prec1 = max(prec1, best_prec1)\n        save_checkpoint({\n            'epoch': epoch + 1,\n            'arch': args.arch,\n            'state_dict': model.state_dict(),\n            'best_prec1': best_prec1,\n            'optimizer' : optimizer.state_dict(),\n        }, is_best, args.prefix)\n\ndef train(train_loader, model, criterion, optimizer, epoch):\n    batch_time = AverageMeter()\n    data_time = AverageMeter()\n    losses = AverageMeter()\n    top1 = AverageMeter()\n    top5 = AverageMeter()\n\n    # switch to train mode\n    model.train()\n\n    end = time.time()\n    for i, (input, target) in enumerate(train_loader):\n        # measure data loading time\n        data_time.update(time.time() - end)\n        target = target.cuda()\n        input_var = torch.autograd.Variable(input).cuda()\n        target_var = torch.autograd.Variable(target)\n        \n        # compute output\n        output = model(input_var)\n        loss = criterion(output, target_var)\n\n        # measure accuracy and record loss\n        prec1, prec5 = accuracy(output.data, target, topk=(1, 5))\n        losses.update(loss.item(), input.size(0))\n        top1.update(prec1.item(), input.size(0))\n        top5.update(prec5.item(), input.size(0))\n        \n        # compute gradient and do SGD step\n        optimizer.zero_grad()\n        loss.backward()\n        if args.sr and args.att_type=='NAM':\n            updateBN(model)\n        optimizer.step()\n        \n        # measure elapsed time\n        batch_time.update(time.time() - end)\n        end = time.time()\n        \n        if i % args.print_freq == 0:\n            print('Epoch: [{0}][{1}/{2}]\\t'\n                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\\t'\n                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\\t'\n                  'Loss {loss.val:.4f} ({loss.avg:.4f})\\t'\n                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\\t'\n                  'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(\n                   epoch, i, len(train_loader), batch_time=batch_time,\n                   data_time=data_time, loss=losses, top1=top1, top5=top5))\n\ndef validate(val_loader, model, criterion, epoch):\n    batch_time = AverageMeter()\n    losses = AverageMeter()\n    top1 = AverageMeter()\n    top5 = AverageMeter()\n\n    # switch to evaluate mode\n    model.eval()\n\n    end = time.time()\n    for i, (input, target) in enumerate(val_loader):\n        target = target.cuda()\n        with torch.no_grad():\n            input_var = torch.autograd.Variable(input).cuda()\n            target_var = torch.autograd.Variable(target)\n        \n        # compute output\n            #output = model(input_var)\n            #loss = criterion(output, target_var)\n            \n            output = model(input_var)#,target_var\n            loss = criterion(output, target_var)\n        \n        # measure accuracy and record loss\n        prec1, prec5 = accuracy(output.data, target, topk=(1, 5))\n        losses.update(loss.item(), input.size(0))\n        top1.update(prec1.item(), input.size(0))\n        top5.update(prec5.item(), input.size(0))\n        \n        # measure elapsed time\n        batch_time.update(time.time() - end)\n        end = time.time()\n        \n        if i % args.print_freq == 0:\n            print('Test: [{0}/{1}]\\t'\n                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\\t'\n                  'Loss {loss.val:.4f} ({loss.avg:.4f})\\t'\n                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\\t'\n                  'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(\n                   i, len(val_loader), batch_time=batch_time, loss=losses,\n                   top1=top1, top5=top5))\n    \n    print(' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}'\n            .format(top1=top1, top5=top5))\n\n    return top1.avg\n\n\ndef save_checkpoint(state, is_best, prefix):\n    filename='./checkpoints/%s_checkpoint.pth.tar'%prefix\n    torch.save(state, filename)\n    if is_best:\n        shutil.copyfile(filename, './checkpoints/%s_model_best.pth.tar'%prefix)\n\n\nclass AverageMeter(object):\n    \"\"\"Computes and stores the average and current value\"\"\"\n    def __init__(self):\n        self.reset()\n\n    def reset(self):\n        self.val = 0\n        self.avg = 0\n        self.sum = 0\n        self.count = 0\n\n    def update(self, val, n=1):\n        self.val = val\n        self.sum += val * n\n        self.count += n\n        self.avg = self.sum / self.count\n\n\ndef adjust_learning_rate(optimizer, epoch):\n    \"\"\"Sets the learning rate to the initial LR decayed by 10 every 30 epochs\"\"\"\n    lr = args.lr * (0.1 ** (epoch // 30))\n    for param_group in optimizer.param_groups:\n        param_group['lr'] = lr\n\n\ndef accuracy(output, target, topk=(1,)):\n    \"\"\"Computes the precision@k for the specified values of k\"\"\"\n    maxk = max(topk)\n    batch_size = target.size(0)\n\n    _, pred = output.topk(maxk, 1, True, True)\n    pred = pred.t()\n    correct = pred.eq(target.view(1, -1).expand_as(pred))\n\n    res = []\n    for k in topk:\n        correct_k = correct[:k].contiguous().view(-1).float().sum(0, keepdim=True)\n        res.append(correct_k.mul_(100.0 / batch_size))\n    return res\n\ndef data_transforms_cifar10(args):\n    CIFAR_MEAN = [0.49139968, 0.48215827, 0.44653124]\n    CIFAR_STD = [0.24703233, 0.24348505, 0.26158768]\n\n    train_transform = transforms.Compose([\n        transforms.RandomCrop(32, padding=4),\n        transforms.RandomHorizontalFlip(),\n        transforms.ToTensor(),\n        transforms.Normalize(CIFAR_MEAN, CIFAR_STD),\n      ])\n    #if args.cutout:\n    #    train_transform.transforms.append(Cutout(args.cutout_length))\n\n    valid_transform = transforms.Compose([\n        transforms.ToTensor(),\n        transforms.Normalize(CIFAR_MEAN, CIFAR_STD),\n        ])\n    return train_transform, valid_transform\n\ndef data_transforms_cifar100(args):\n    #CIFAR_MEAN = [0.5071, 0.4865, 0.4409]\n    #CIFAR_STD = [0.2673, 0.2564, 0.2762]\n    CIFAR_MEAN = [125.3/ 255.0, 123.0/ 255.0, 113.9/ 255.0] \n    CIFAR_STD = [63.0/ 255.0, 62.1/ 255.0, 66.7/ 255.0] \n\n    train_transform = transforms.Compose([\n        transforms.RandomCrop(32, padding=4),\n        transforms.RandomHorizontalFlip(),\n        transforms.ToTensor(),\n        transforms.Normalize(CIFAR_MEAN, CIFAR_STD),\n      ])\n    #if args.cutout:\n    #    train_transform.transforms.append(Cutout(args.cutout_length))\n\n    valid_transform = transforms.Compose([\n        transforms.ToTensor(),\n        transforms.Normalize(CIFAR_MEAN, CIFAR_STD),\n        ])\n    return train_transform, valid_transform\n\n\nif __name__ == '__main__':\n    main()\n"
  }
]