Repository: Christian-lyc/NAM Branch: main Commit: 6e8d3e321c60 Files: 6 Total size: 28.9 KB Directory structure: gitextract_tsj638te/ ├── MODELS/ │ ├── attention.py │ ├── bam.py │ ├── cbam.py │ └── model_resnet.py ├── README.md └── train_cifar100.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: MODELS/attention.py ================================================ import torch.nn as nn import torch from torch.nn import functional as F class Channel_Att(nn.Module): def __init__(self, channels, t=16): super(Channel_Att, self).__init__() self.channels = channels self.bn2 = nn.BatchNorm2d(self.channels, affine=True) def forward(self, x): residual = x x = self.bn2(x) weight_bn = self.bn2.weight.data.abs() / torch.sum(self.bn2.weight.data.abs()) x = x.permute(0, 2, 3, 1).contiguous() x = torch.mul(weight_bn, x) x = x.permute(0, 3, 1, 2).contiguous() x = torch.sigmoid(x) * residual # return x class Att(nn.Module): def __init__(self, channels,shape, out_channels=None, no_spatial=True): super(Att, self).__init__() self.Channel_Att = Channel_Att(channels) def forward(self, x): x_out1=self.Channel_Att(x) return x_out1 ================================================ FILE: MODELS/bam.py ================================================ import torch import math import torch.nn as nn import torch.nn.functional as F class Flatten(nn.Module): def forward(self, x): return x.view(x.size(0), -1) class ChannelGate(nn.Module): def __init__(self, gate_channel, reduction_ratio=16, num_layers=1): super(ChannelGate, self).__init__() #self.gate_activation = gate_activation self.gate_c = nn.Sequential() self.gate_c.add_module( 'flatten', Flatten() ) gate_channels = [gate_channel] gate_channels += [gate_channel // reduction_ratio] * num_layers gate_channels += [gate_channel] for i in range( len(gate_channels) - 2 ): self.gate_c.add_module( 'gate_c_fc_%d'%i, 
nn.Linear(gate_channels[i], gate_channels[i+1]) ) self.gate_c.add_module( 'gate_c_bn_%d'%(i+1), nn.BatchNorm1d(gate_channels[i+1]) ) self.gate_c.add_module( 'gate_c_relu_%d'%(i+1), nn.ReLU() ) self.gate_c.add_module( 'gate_c_fc_final', nn.Linear(gate_channels[-2], gate_channels[-1]) ) def forward(self, in_tensor): avg_pool = F.avg_pool2d( in_tensor, in_tensor.size(2), stride=in_tensor.size(2) ) return self.gate_c( avg_pool ).unsqueeze(2).unsqueeze(3).expand_as(in_tensor) class SpatialGate(nn.Module): def __init__(self, gate_channel, reduction_ratio=16, dilation_conv_num=2, dilation_val=4): super(SpatialGate, self).__init__() self.gate_s = nn.Sequential() self.gate_s.add_module( 'gate_s_conv_reduce0', nn.Conv2d(gate_channel, gate_channel//reduction_ratio, kernel_size=1)) self.gate_s.add_module( 'gate_s_bn_reduce0', nn.BatchNorm2d(gate_channel//reduction_ratio) ) self.gate_s.add_module( 'gate_s_relu_reduce0',nn.ReLU() ) for i in range( dilation_conv_num ): self.gate_s.add_module( 'gate_s_conv_di_%d'%i, nn.Conv2d(gate_channel//reduction_ratio, gate_channel//reduction_ratio, kernel_size=3, \ padding=dilation_val, dilation=dilation_val) ) self.gate_s.add_module( 'gate_s_bn_di_%d'%i, nn.BatchNorm2d(gate_channel//reduction_ratio) ) self.gate_s.add_module( 'gate_s_relu_di_%d'%i, nn.ReLU() ) self.gate_s.add_module( 'gate_s_conv_final', nn.Conv2d(gate_channel//reduction_ratio, 1, kernel_size=1) ) def forward(self, in_tensor): return self.gate_s( in_tensor ).expand_as(in_tensor) class BAM(nn.Module): def __init__(self, gate_channel): super(BAM, self).__init__() self.channel_att = ChannelGate(gate_channel) self.spatial_att = SpatialGate(gate_channel) def forward(self,in_tensor): att = 1 + F.sigmoid( self.channel_att(in_tensor) * self.spatial_att(in_tensor) ) return att * in_tensor ================================================ FILE: MODELS/cbam.py ================================================ import torch import math import torch.nn as nn import torch.nn.functional as F 
class BasicConv(nn.Module): def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=True, bn=True, bias=False): super(BasicConv, self).__init__() self.out_channels = out_planes self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias) self.bn = nn.BatchNorm2d(out_planes,eps=1e-5, momentum=0.01, affine=True) if bn else None self.relu = nn.ReLU() if relu else None def forward(self, x): x = self.conv(x) if self.bn is not None: x = self.bn(x) if self.relu is not None: x = self.relu(x) return x class Flatten(nn.Module): def forward(self, x): return x.view(x.size(0), -1) class ChannelGate(nn.Module): def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max']): super(ChannelGate, self).__init__() self.gate_channels = gate_channels self.mlp = nn.Sequential( Flatten(), nn.Linear(gate_channels, gate_channels // reduction_ratio), nn.ReLU(), nn.Linear(gate_channels // reduction_ratio, gate_channels) ) self.pool_types = pool_types def forward(self, x): channel_att_sum = None for pool_type in self.pool_types: if pool_type=='avg': avg_pool = F.avg_pool2d( x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3))) channel_att_raw = self.mlp( avg_pool ) elif pool_type=='max': max_pool = F.max_pool2d( x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3))) channel_att_raw = self.mlp( max_pool ) elif pool_type=='lp': lp_pool = F.lp_pool2d( x, 2, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3))) channel_att_raw = self.mlp( lp_pool ) elif pool_type=='lse': # LSE pool only lse_pool = logsumexp_2d(x) channel_att_raw = self.mlp( lse_pool ) if channel_att_sum is None: channel_att_sum = channel_att_raw else: channel_att_sum = channel_att_sum + channel_att_raw scale = torch.sigmoid( channel_att_sum ).unsqueeze(2).unsqueeze(3).expand_as(x) return x * scale def logsumexp_2d(tensor): tensor_flatten = tensor.view(tensor.size(0), 
tensor.size(1), -1) s, _ = torch.max(tensor_flatten, dim=2, keepdim=True) outputs = s + (tensor_flatten - s).exp().sum(dim=2, keepdim=True).log() return outputs class ChannelPool(nn.Module): def forward(self, x): return torch.cat( (torch.max(x,1)[0].unsqueeze(1), torch.mean(x,1).unsqueeze(1)), dim=1 ) class SpatialGate(nn.Module): def __init__(self): super(SpatialGate, self).__init__() kernel_size = 7 self.compress = ChannelPool() self.spatial = BasicConv(2, 1, kernel_size, stride=1, padding=(kernel_size-1) // 2, relu=False) def forward(self, x): x_compress = self.compress(x) x_out = self.spatial(x_compress) scale = torch.sigmoid(x_out) # broadcasting return x * scale class CBAM(nn.Module): def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max'], no_spatial=False): super(CBAM, self).__init__() self.ChannelGate = ChannelGate(gate_channels, reduction_ratio, pool_types) self.no_spatial=no_spatial if not no_spatial: self.SpatialGate = SpatialGate() def forward(self, x): x_out = self.ChannelGate(x) if not self.no_spatial: x_out = self.SpatialGate(x_out) return x_out ================================================ FILE: MODELS/model_resnet.py ================================================ import torch import torch.nn as nn import torch.nn.functional as F import math from torch.nn import init from .cbam import * from .bam import * from .attention import * def conv3x3(in_planes, out_planes, stride=1): "3x3 convolution with padding" return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) class BasicBlock(nn.Module): expansion = 1 def __init__(self, inplanes, planes, shape,stride=1, downsample=None, use_cbam=False, use_nam=False,no_spatial=True): super(BasicBlock, self).__init__() self.conv1 = conv3x3(inplanes, planes, stride) self.bn1 = nn.BatchNorm2d(planes) self.relu = nn.ReLU(inplace=True) self.conv2 = conv3x3(planes, planes) self.bn2 = nn.BatchNorm2d(planes) self.downsample = downsample self.stride = stride 
self.no_spatial = no_spatial if use_cbam: self.cbam = CBAM(planes, 16) else: self.cbam = None if use_nam: self.nam = Att(planes,no_spatial=self.no_spatial,shape=shape) else: self.nam = None def forward(self, x): residual = x out = self.conv1(x) out = self.bn1(out) out = self.relu(out) out = self.conv2(out) out = self.bn2(out) if self.downsample is not None: residual = self.downsample(x) if not self.cbam is None: out = self.cbam(out) if not self.nam is None: out = self.nam(out) out += residual out = self.relu(out) return out class Bottleneck(nn.Module): expansion = 4 def __init__(self, inplanes, planes,shape, stride=1, downsample=None, use_cbam=False, use_nam=False, no_spatial=False): super(Bottleneck, self).__init__() self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) self.bn1 = nn.BatchNorm2d(planes) self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) self.bn2 = nn.BatchNorm2d(planes) self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) self.bn3 = nn.BatchNorm2d(planes * 4) self.relu = nn.ReLU(inplace=True) self.downsample = downsample self.stride = stride self.no_spatial = no_spatial if use_cbam: self.cbam = CBAM(planes * 4, 16) else: self.cbam = None if use_nam: self.nam = Att(planes * 4, no_spatial=self.no_spatial,shape=shape) else: self.nam = None def forward(self, x): residual = x out = self.conv1(x) out = self.bn1(out) out = self.relu(out) out = self.conv2(out) out = self.bn2(out) out = self.relu(out) out = self.conv3(out) out = self.bn3(out) if self.downsample is not None: residual = self.downsample(x) if not self.cbam is None: out = self.cbam(out) if not self.nam is None: out = self.nam(out) out += residual out = self.relu(out) return out class ResNet(nn.Module): def __init__(self, block, layers, network_type, num_classes, att_type=None): self.inplanes = 64 super(ResNet, self).__init__() self.network_type = network_type # different model config between ImageNet and CIFAR if 
network_type == "ImageNet": self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) self.avgpool = nn.AvgPool2d(7) shape=56 else: self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) shape=32 self.bn1 = nn.BatchNorm2d(64) self.relu = nn.ReLU(inplace=True) if att_type == 'BAM': self.bam1 = BAM(64*block.expansion) self.bam2 = BAM(128*block.expansion) self.bam3 = BAM(256*block.expansion) else: self.bam1, self.bam2, self.bam3 = None, None, None self.layer1 = self._make_layer(block, 64, shape,layers[0], att_type=att_type, no_spatial=False) self.layer2 = self._make_layer(block, 128,shape//2, layers[1], stride=2, att_type=att_type, no_spatial=False) self.layer3 = self._make_layer(block, 256, shape//4,layers[2], stride=2, att_type=att_type, no_spatial=False) self.layer4 = self._make_layer(block, 512, shape//8, layers[3], stride=2, att_type=att_type, no_spatial=False) self.fc = nn.Linear(512 * block.expansion, num_classes) for m in self.modules(): if isinstance(m, nn.Conv2d): nn.init.kaiming_uniform_(m.weight.data) elif isinstance(m, nn.BatchNorm2d): m.weight.data.fill_(1) m.bias.data.zero_() elif isinstance(m, nn.BatchNorm1d): m.weight.data.fill_(1) m.bias.data.zero_() ''' init.kaiming_normal_(self.fc.weight) for key in self.state_dict(): if key.split('.')[-1] == "weight": if "conv" in key: init.kaiming_normal_(self.state_dict()[key], mode='fan_out') if "bn" in key: if "SpatialGate" in key: self.state_dict()[key][...] = 0 else: self.state_dict()[key][...] = 1 elif key.split(".")[-1] == 'bias': self.state_dict()[key][...] 
= 0 ''' def _make_layer(self, block, planes, shape, blocks, stride=1, att_type=None, no_spatial=True): downsample = None if stride != 1 or self.inplanes != planes * block.expansion: downsample = nn.Sequential( nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False), nn.BatchNorm2d(planes * block.expansion), ) layers = [] layers.append( block(self.inplanes, planes, shape,stride, downsample, use_cbam=att_type == 'CBAM', use_nam=att_type == 'NAM', no_spatial=no_spatial)) self.inplanes = planes * block.expansion for i in range(1, blocks): layers.append(block(self.inplanes, planes, shape,use_cbam=att_type == 'CBAM', use_nam=att_type == 'NAM', no_spatial=no_spatial)) return nn.Sequential(*layers) def forward(self, x,label=None): x = self.conv1(x) x = self.bn1(x) x = self.relu(x) if self.network_type == "ImageNet": x = self.maxpool(x) x = self.layer1(x) if not self.bam1 is None: x = self.bam1(x) x = self.layer2(x) if not self.bam2 is None: x = self.bam2(x) x = self.layer3(x) if not self.bam3 is None: x = self.bam3(x) x = self.layer4(x) if self.network_type == "ImageNet": x = self.avgpool(x) else: x = F.avg_pool2d(x, 4) x = x.view(x.size(0), -1) x = self.fc(x) return x def ResidualNet(network_type, depth, num_classes, att_type): assert network_type in ["ImageNet", "CIFAR10", "CIFAR100"], "network type should be ImageNet or CIFAR10 / CIFAR100" assert depth in [18, 34, 50, 101], 'network depth should be 18, 34, 50 or 101' if depth == 18: model = ResNet(BasicBlock, [2, 2, 2, 2], network_type, num_classes, att_type) elif depth == 34: model = ResNet(BasicBlock, [3, 4, 6, 3], network_type, num_classes, att_type) elif depth == 50: model = ResNet(Bottleneck, [3, 4, 6, 3], network_type, num_classes, att_type) elif depth == 101: model = ResNet(Bottleneck, [3, 4, 23, 3], network_type, num_classes, att_type) return model ================================================ FILE: README.md ================================================ # NAM 
================================================ FILE: train_cifar100.py ================================================ import argparse import os import shutil import time import random import torch import torch.nn as nn import torch.nn.parallel import torch.backends.cudnn as cudnn import torch.optim import torch.utils import torch.utils.data import torchvision.transforms as transforms import torchvision.datasets as datasets import torchvision.models as models from MODELS.model_resnet import * from PIL import ImageFile from thop import profile ImageFile.LOAD_TRUNCATED_IMAGES = True model_names = sorted(name for name in models.__dict__ if name.islower() and not name.startswith("__") and callable(models.__dict__[name])) parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') parser.add_argument('data', metavar='DIR', help='path to dataset') parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet',help='model architecture: ' +' | '.join(model_names) + ' (default: resnet18)') parser.add_argument('--depth', default=50, type=int, metavar='D', help='model depth') parser.add_argument('--ngpu', default=4, type=int, metavar='G', help='number of gpus to use') parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',help='number of data loading workers (default: 4)') parser.add_argument('--epochs', default=200, type=int, metavar='N', help='number of total epochs to run') parser.add_argument('--start-epoch', default=0, type=int, metavar='N', help='manual epoch number (useful on restarts)') parser.add_argument('-b', '--batch-size', default=128, type=int, metavar='N', help='mini-batch size (default: 256)') parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, metavar='LR', help='initial learning rate') parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum') parser.add_argument('--weight-decay', '--wd', default=5e-4, type=float, metavar='W', help='weight decay (default: 1e-4)') 
parser.add_argument('--print-freq', '-p', default=100, type=int,
                    metavar='N', help='print frequency (default: 100)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')
parser.add_argument("--seed", type=int, default=1234, metavar='S',
                    help='random seed (default: 1234)')
parser.add_argument("--prefix", type=str, required=True, metavar='PFX',
                    help='prefix for logging & checkpoint saving')
parser.add_argument('--evaluate', dest='evaluate', action='store_true',
                    help='evaluation only')
parser.add_argument('--att-type', type=str, choices=['BAM', 'CBAM', 'NAM'], default=None)
# `type=list` would split a command-line string into single characters
# ("60" -> ['6', '0']); parse a space-separated list of integers instead.
parser.add_argument('--milestones', type=int, nargs='+', default=[60, 120, 160],
                    help='optimizer milestones')
parser.add_argument('--set', type=str, default='cifar100',
                    help="dataset to train on: 'cifar100' (default); any other value selects CIFAR-10")
parser.add_argument('--gamma', type=float, default=0.2, help='gamma')
parser.add_argument('--sparsity-regularization', '-sr', dest='sr', action='store_true',
                    help='train with channel sparsity regularization')
parser.add_argument('--s', type=float, default=0.0001,
                    help='scale sparse rate (default: 0.0001)')

best_prec1 = 0

# exist_ok avoids the check-then-create race of os.path.exists + os.mkdir.
os.makedirs('./checkpoints', exist_ok=True)


def updateBN(model):
    """Add the L1 sub-gradient ``args.s * sign(w)`` to every NAM BatchNorm scale.

    Must be called after ``loss.backward()`` and before ``optimizer.step()``.

    Args:
        model: The ResNet, optionally wrapped in (Distributed)DataParallel.
    """
    # The training loop passes the DataParallel wrapper, whose only child is
    # named 'module'. The original looked for 'layer1'..'layer4' directly on
    # the wrapper, so the penalty was never applied; unwrap first.
    if isinstance(model, (torch.nn.DataParallel,
                          torch.nn.parallel.DistributedDataParallel)):
        model = model.module

    for layer_name in ('layer1', 'layer2', 'layer3', 'layer4'):
        layer = getattr(model, layer_name, None)
        if layer is None:
            continue
        for block in layer:
            nam = getattr(block, 'nam', None)
            if nam is None:
                continue
            bn_weight = nam.Channel_Att.bn2.weight
            if bn_weight.grad is None:
                continue
            bn_weight.grad.data.add_(args.s * torch.sign(bn_weight.data))


def main():
    """Parse arguments, build the model and data loaders, then train/evaluate."""
    global args, best_prec1
    args = parser.parse_args()
    print ("args", args)

    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    random.seed(args.seed)

    # create model
    # NOTE(review): the network is always built as a 100-class 'CIFAR100'
    # model, even when --set selects CIFAR-10. Left unchanged so existing
    # checkpoints still load.
    if args.arch == "resnet":
        model = ResidualNet('CIFAR100', args.depth, 100, args.att_type)
    else:
        # Previously fell through and failed later with an undefined `model`.
        raise ValueError("unsupported architecture: %r (only 'resnet')" % (args.arch,))

    # Report parameter count (millions) and operation count (billions) for a
    # single 32x32 input, as measured by thop.
    inputs = torch.randn(1, 3, 32, 32)
    total_ops, total_params = profile(model, (inputs,), verbose=False)
    print(" %.2f | %.2f" % (total_params / (1000 ** 2), total_ops / (1000 ** 3)))

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()

    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    model = torch.nn.DataParallel(model, device_ids=list(range(args.ngpu)))
    model = model.cuda()
    train_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=args.milestones, gamma=args.gamma)

    # get the number of model parameters
    print('Number of model parameters: {}'.format(
        sum([p.data.nelement() for p in model.parameters()])))

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            # NOTE(security): torch.load unpickles the file; only resume from
            # checkpoints you trust.
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            if 'optimizer' in checkpoint:
                optimizer.load_state_dict(checkpoint['optimizer'])
            # Restore the LR schedule position when the checkpoint has it;
            # older checkpoints lack the key and behave as before.
            if 'scheduler' in checkpoint:
                train_scheduler.load_state_dict(checkpoint['scheduler'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    # NOTE(review): both branches use the CIFAR-10 normalisation statistics;
    # data_transforms_cifar100 is defined below but never called. Left
    # unchanged so results stay comparable with earlier runs.
    if args.set == 'cifar100':
        train_transform, valid_transform = data_transforms_cifar10(args)
        train_data = datasets.CIFAR100(root=args.data, train=True, download=True, transform=train_transform)
        valid_data = datasets.CIFAR100(root=args.data, train=False, download=True, transform=valid_transform)
    else:
        train_transform, valid_transform = data_transforms_cifar10(args)
        train_data = datasets.CIFAR10(root=args.data, train=True, download=True, transform=train_transform)
        valid_data = datasets.CIFAR10(root=args.data, train=False, download=True, transform=valid_transform)

    train_loader = torch.utils.data.DataLoader(
        train_data, batch_size=args.batch_size,
        shuffle=True, pin_memory=True, num_workers=args.workers)

    val_loader = torch.utils.data.DataLoader(
        valid_data, batch_size=args.batch_size,
        shuffle=False, pin_memory=True, num_workers=args.workers)

    # --evaluate was parsed but ignored before; honour it.
    if args.evaluate:
        validate(val_loader, model, criterion, args.start_epoch)
        return

    for epoch in range(args.start_epoch, args.epochs):
        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch)

        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion, epoch)
        train_scheduler.step()

        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        save_checkpoint({
            'epoch': epoch + 1,
            'arch': args.arch,
            'state_dict': model.state_dict(),
            'best_prec1': best_prec1,
            'optimizer': optimizer.state_dict(),
            'scheduler': train_scheduler.state_dict(),
        }, is_best, args.prefix)


def train(train_loader, model, criterion, optimizer, epoch):
    """Run one training epoch and print running statistics."""
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to train mode
    model.train()

    end = time.time()
    for i, (images, target) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        # Tensors are used directly; the autograd.Variable wrapper is a
        # deprecated no-op.
        target = target.cuda()
        images = images.cuda()

        # compute output
        output = model(images)
        loss = criterion(output, target)

        # measure accuracy and record loss
        prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
        losses.update(loss.item(), images.size(0))
        top1.update(prec1.item(), images.size(0))
        top5.update(prec5.item(), images.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        # The sparsity penalty is added between backward() and step().
        if args.sr and args.att_type == 'NAM':
            updateBN(model)
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                      epoch, i, len(train_loader), batch_time=batch_time,
                      data_time=data_time, loss=losses, top1=top1, top5=top5))


def validate(val_loader, model, criterion, epoch):
    """Evaluate ``model`` on ``val_loader`` and return the average top-1 precision.

    ``epoch`` is accepted for call-site compatibility and is not used.
    """
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to evaluate mode
    model.eval()

    end = time.time()
    for i, (images, target) in enumerate(val_loader):
        target = target.cuda()
        images = images.cuda()

        # The forward pass runs inside no_grad so no autograd graph is built
        # during evaluation.
        with torch.no_grad():
            output = model(images)
            loss = criterion(output, target)

        # measure accuracy and record loss
        prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
        losses.update(loss.item(), images.size(0))
        top1.update(prec1.item(), images.size(0))
        top5.update(prec5.item(), images.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            print('Test: [{0}/{1}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                      i, len(val_loader), batch_time=batch_time, loss=losses,
                      top1=top1, top5=top5))

    print(' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}'
          .format(top1=top1, top5=top5))

    return top1.avg


def save_checkpoint(state, is_best, prefix):
    """Save ``state`` under ./checkpoints and copy it as the best model if flagged."""
    filename = './checkpoints/%s_checkpoint.pth.tar' % prefix
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, './checkpoints/%s_model_best.pth.tar' % prefix)


class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` observed ``n`` times and refresh the running average."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def adjust_learning_rate(optimizer, epoch):
    """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
    lr = args.lr * (0.1 ** (epoch // 30))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr


def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k"""
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].contiguous().view(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res


def _cifar_transforms(mean, std):
    """Build the (train, valid) transform pair for the given channel statistics."""
    train_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
    valid_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
    return train_transform, valid_transform


def data_transforms_cifar10(args):
    """Return (train, valid) transforms normalised with CIFAR-10 statistics.

    ``args`` is accepted for call-site compatibility and is not used.
    """
    CIFAR_MEAN = [0.49139968, 0.48215827, 0.44653124]
    CIFAR_STD = [0.24703233, 0.24348505, 0.26158768]
    return _cifar_transforms(CIFAR_MEAN, CIFAR_STD)


def data_transforms_cifar100(args):
    """Return (train, valid) transforms normalised with CIFAR-100 statistics.

    ``args`` is accepted for call-site compatibility and is not used.
    """
    CIFAR_MEAN = [125.3 / 255.0, 123.0 / 255.0, 113.9 / 255.0]
    CIFAR_STD = [63.0 / 255.0, 62.1 / 255.0, 66.7 / 255.0]
    return _cifar_transforms(CIFAR_MEAN, CIFAR_STD)


if __name__ == '__main__':
    main()