Repository: Christian-lyc/NAM
Branch: main
Commit: 6e8d3e321c60
Files: 6
Total size: 28.9 KB
Directory structure:
gitextract_tsj638te/
├── MODELS/
│ ├── attention.py
│ ├── bam.py
│ ├── cbam.py
│ └── model_resnet.py
├── README.md
└── train_cifar100.py
================================================
FILE CONTENTS
================================================
================================================
FILE: MODELS/attention.py
================================================
import torch.nn as nn
import torch
from torch.nn import functional as F
class Channel_Att(nn.Module):
    """NAM channel attention.

    The input is batch-normalised, each channel of the normalised map is
    rescaled by that channel's share of the total absolute BN scale
    (``|gamma_c| / sum(|gamma|)``), and the sigmoid of the result gates the
    original input.

    Args:
        channels: number of input channels the BatchNorm2d is built for.
        t: unused; kept only for signature compatibility.
    """

    def __init__(self, channels, t=16):
        super(Channel_Att, self).__init__()
        self.channels = channels
        self.bn2 = nn.BatchNorm2d(self.channels, affine=True)

    def forward(self, x):
        """Return ``sigmoid(bn2(x) * w) * x``, where ``w`` are the normalised |BN scales|."""
        normalised = self.bn2(x)
        # ``.data`` detaches the importance weights from autograd, as in the
        # original NAM formulation.
        gamma_abs = self.bn2.weight.data.abs()
        channel_weights = gamma_abs / torch.sum(gamma_abs)
        # Broadcast the per-channel weight vector along dim 1 of the NCHW map.
        scaled = normalised * channel_weights.view(1, -1, 1, 1)
        return torch.sigmoid(scaled) * x
class Att(nn.Module):
    """NAM attention wrapper; currently applies only the channel sub-module.

    ``shape``, ``out_channels`` and ``no_spatial`` are accepted so the ResNet
    blocks can pass them, but they are not used here.
    """

    def __init__(self, channels, shape, out_channels=None, no_spatial=True):
        super(Att, self).__init__()
        self.Channel_Att = Channel_Att(channels)

    def forward(self, x):
        """Return ``x`` gated by the channel-attention module."""
        return self.Channel_Att(x)
================================================
FILE: MODELS/bam.py
================================================
import torch
import math
import torch.nn as nn
import torch.nn.functional as F
class Flatten(nn.Module):
    """Collapse every dimension after the batch dimension into one (via ``view``)."""

    def forward(self, x):
        batch = x.shape[0]
        return x.view(batch, -1)
class ChannelGate(nn.Module):
    """BAM channel-attention branch.

    Each channel is global-average-pooled and the pooled vector is passed
    through an MLP. The MLP goes ``gate_channel -> gate_channel // reduction_ratio``
    (``num_layers`` hidden Linear+BN+ReLU layers) ``-> gate_channel``. The
    resulting per-channel logits are broadcast back to the input's shape.

    Args:
        gate_channel: number of input channels C.
        reduction_ratio: bottleneck reduction factor for the hidden layers.
        num_layers: number of hidden Linear+BatchNorm1d+ReLU layers.
    """

    def __init__(self, gate_channel, reduction_ratio=16, num_layers=1):
        super(ChannelGate, self).__init__()
        self.gate_c = nn.Sequential()
        self.gate_c.add_module('flatten', Flatten())
        gate_channels = [gate_channel]
        gate_channels += [gate_channel // reduction_ratio] * num_layers
        gate_channels += [gate_channel]
        # Module names are part of the state_dict keys; keep them stable.
        for i in range(len(gate_channels) - 2):
            self.gate_c.add_module('gate_c_fc_%d' % i, nn.Linear(gate_channels[i], gate_channels[i + 1]))
            self.gate_c.add_module('gate_c_bn_%d' % (i + 1), nn.BatchNorm1d(gate_channels[i + 1]))
            self.gate_c.add_module('gate_c_relu_%d' % (i + 1), nn.ReLU())
        self.gate_c.add_module('gate_c_fc_final', nn.Linear(gate_channels[-2], gate_channels[-1]))

    def forward(self, in_tensor):
        """Return channel-attention logits expanded to ``in_tensor``'s shape."""
        # Pool over the full (H, W) extent so non-square maps collapse to 1x1.
        # A square (H, H) kernel leaves W // H columns for wide inputs (breaking
        # the Flatten -> Linear input size) and fails outright for tall ones.
        height, width = in_tensor.size(2), in_tensor.size(3)
        avg_pool = F.avg_pool2d(in_tensor, (height, width), stride=(height, width))
        return self.gate_c(avg_pool).unsqueeze(2).unsqueeze(3).expand_as(in_tensor)
class SpatialGate(nn.Module):
    """BAM spatial-attention branch.

    The branch works in four steps:
    1. Reduce channels with a 1x1 conv.
    2. Apply ``dilation_conv_num`` dilated 3x3 convs (each followed by BN+ReLU).
    3. Project to a single channel.
    4. Broadcast that one-channel map to the input's shape.
    """

    def __init__(self, gate_channel, reduction_ratio=16, dilation_conv_num=2, dilation_val=4):
        super(SpatialGate, self).__init__()
        reduced = gate_channel // reduction_ratio
        self.gate_s = nn.Sequential()
        # (name, module) pairs in registration order; names are state_dict keys.
        named_layers = [
            ('gate_s_conv_reduce0', nn.Conv2d(gate_channel, reduced, kernel_size=1)),
            ('gate_s_bn_reduce0', nn.BatchNorm2d(reduced)),
            ('gate_s_relu_reduce0', nn.ReLU()),
        ]
        for idx in range(dilation_conv_num):
            named_layers.append(('gate_s_conv_di_%d' % idx,
                                 nn.Conv2d(reduced, reduced, kernel_size=3,
                                           padding=dilation_val, dilation=dilation_val)))
            named_layers.append(('gate_s_bn_di_%d' % idx, nn.BatchNorm2d(reduced)))
            named_layers.append(('gate_s_relu_di_%d' % idx, nn.ReLU()))
        named_layers.append(('gate_s_conv_final', nn.Conv2d(reduced, 1, kernel_size=1)))
        for layer_name, layer in named_layers:
            self.gate_s.add_module(layer_name, layer)

    def forward(self, in_tensor):
        """Return the one-channel spatial logits expanded to ``in_tensor``'s shape."""
        logits = self.gate_s(in_tensor)
        return logits.expand_as(in_tensor)
class BAM(nn.Module):
    """Bottleneck Attention Module.

    The channel and spatial gates are combined into an attention map
    ``1 + sigmoid(channel * spatial)``, which scales the input element-wise.

    Args:
        gate_channel: number of input channels.
    """

    def __init__(self, gate_channel):
        super(BAM, self).__init__()
        self.channel_att = ChannelGate(gate_channel)
        self.spatial_att = SpatialGate(gate_channel)

    def forward(self, in_tensor):
        """Return ``in_tensor`` scaled by the BAM attention map."""
        # torch.sigmoid replaces the deprecated F.sigmoid (same values).
        att = 1 + torch.sigmoid(self.channel_att(in_tensor) * self.spatial_att(in_tensor))
        return att * in_tensor
================================================
FILE: MODELS/cbam.py
================================================
import torch
import math
import torch.nn as nn
import torch.nn.functional as F
class BasicConv(nn.Module):
    """Conv2d, optionally followed by BatchNorm2d (eps=1e-5, momentum=0.01) and ReLU.

    ``self.bn`` / ``self.relu`` are ``None`` when disabled via ``bn`` / ``relu``.
    """

    def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=True, bn=True, bias=False):
        super(BasicConv, self).__init__()
        self.out_channels = out_planes
        self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
                              padding=padding, dilation=dilation, groups=groups, bias=bias)
        if bn:
            self.bn = nn.BatchNorm2d(out_planes, eps=1e-5, momentum=0.01, affine=True)
        else:
            self.bn = None
        if relu:
            self.relu = nn.ReLU()
        else:
            self.relu = None

    def forward(self, x):
        """Apply conv, then whichever of BN / ReLU are enabled, in that order."""
        out = self.conv(x)
        for stage in (self.bn, self.relu):
            if stage is not None:
                out = stage(out)
        return out
class Flatten(nn.Module):
    """Collapse every dimension after the batch dimension into one (via ``view``)."""

    def forward(self, x):
        batch = x.shape[0]
        return x.view(batch, -1)
class ChannelGate(nn.Module):
    """CBAM channel-attention gate.

    Each entry of ``pool_types`` pools the input over its full spatial extent.
    Each pooled vector goes through a shared MLP bottleneck, the MLP outputs
    are summed, and their sigmoid rescales the input per channel.

    Args:
        gate_channels: number of input channels C.
        reduction_ratio: MLP hidden size is ``gate_channels // reduction_ratio``.
        pool_types: non-empty sequence drawn from ``'avg'``, ``'max'``, ``'lp'``,
            ``'lse'``. ``None`` (the default) means ``['avg', 'max']``; a fresh
            list is created per instance rather than sharing a mutable default.

    Raises:
        ValueError: if ``pool_types`` is empty or contains an unsupported name.
    """

    _SUPPORTED_POOLS = ('avg', 'max', 'lp', 'lse')

    def __init__(self, gate_channels, reduction_ratio=16, pool_types=None):
        super(ChannelGate, self).__init__()
        if pool_types is None:
            pool_types = ['avg', 'max']
        if len(pool_types) == 0:
            # With nothing to sum, forward() would call sigmoid(None).
            raise ValueError("pool_types must contain at least one pooling type")
        unknown = [p for p in pool_types if p not in self._SUPPORTED_POOLS]
        if unknown:
            # Previously an unknown name silently reused the previous branch's
            # MLP output (double counting) or raised UnboundLocalError.
            raise ValueError("unsupported pool_types %r; expected any of %r" % (unknown, self._SUPPORTED_POOLS))
        self.gate_channels = gate_channels
        self.mlp = nn.Sequential(
            Flatten(),
            nn.Linear(gate_channels, gate_channels // reduction_ratio),
            nn.ReLU(),
            nn.Linear(gate_channels // reduction_ratio, gate_channels)
        )
        self.pool_types = pool_types

    def forward(self, x):
        """Return ``x`` scaled by the per-channel sigmoid attention."""
        spatial = (x.size(2), x.size(3))
        channel_att_sum = None
        for pool_type in self.pool_types:
            if pool_type == 'avg':
                pooled = F.avg_pool2d(x, spatial, stride=spatial)
            elif pool_type == 'max':
                pooled = F.max_pool2d(x, spatial, stride=spatial)
            elif pool_type == 'lp':
                pooled = F.lp_pool2d(x, 2, spatial, stride=spatial)
            elif pool_type == 'lse':
                # LSE pool only
                pooled = logsumexp_2d(x)
            else:
                # Guards against pool_types being mutated after construction.
                raise ValueError("unsupported pool type %r" % (pool_type,))
            channel_att_raw = self.mlp(pooled)
            if channel_att_sum is None:
                channel_att_sum = channel_att_raw
            else:
                channel_att_sum = channel_att_sum + channel_att_raw
        scale = torch.sigmoid(channel_att_sum).unsqueeze(2).unsqueeze(3).expand_as(x)
        return x * scale
def logsumexp_2d(tensor):
    """Log-sum-exp over all dimensions after the first two.

    For the ``(N, C, H, W)`` feature maps CBAM passes in, this reduces over H
    and W and returns shape ``(N, C, 1)``.

    ``torch.logsumexp`` is numerically stable. Unlike the manual max-shift
    form it returns ``-inf`` rather than NaN for a channel that is entirely
    ``-inf``. ``flatten`` also accepts non-contiguous inputs, where ``view``
    raises.
    """
    return torch.logsumexp(tensor.flatten(2), dim=2, keepdim=True)
class ChannelPool(nn.Module):
    """Stack the channel-wise max and the channel-wise mean of ``x`` as two channels."""

    def forward(self, x):
        channel_max, _ = torch.max(x, dim=1, keepdim=True)
        channel_mean = torch.mean(x, dim=1, keepdim=True)
        return torch.cat([channel_max, channel_mean], dim=1)
class SpatialGate(nn.Module):
    """CBAM spatial-attention gate.

    The input is compressed to [max, mean] over channels, a 7x7 conv+BN
    (no ReLU) produces one-channel logits, and their sigmoid scales the
    input (broadcast over channels).
    """

    def __init__(self):
        super(SpatialGate, self).__init__()
        ksize = 7
        self.compress = ChannelPool()
        # (ksize - 1) // 2 padding keeps the spatial size unchanged at stride 1.
        self.spatial = BasicConv(2, 1, ksize, stride=1, padding=(ksize - 1) // 2, relu=False)

    def forward(self, x):
        attention_logits = self.spatial(self.compress(x))
        return x * torch.sigmoid(attention_logits)  # broadcasting
class CBAM(nn.Module):
    """Convolutional Block Attention Module: a channel gate, then an optional spatial gate.

    Args:
        gate_channels: number of input channels.
        reduction_ratio: MLP bottleneck ratio for the channel gate.
        pool_types: pooling operators for the channel gate. ``None`` (the
            default) means ``['avg', 'max']``; a fresh list is created per
            instance instead of sharing one mutable default list.
        no_spatial: if True, the spatial gate is not built or applied.
    """

    def __init__(self, gate_channels, reduction_ratio=16, pool_types=None, no_spatial=False):
        super(CBAM, self).__init__()
        if pool_types is None:
            pool_types = ['avg', 'max']
        self.ChannelGate = ChannelGate(gate_channels, reduction_ratio, pool_types)
        self.no_spatial = no_spatial
        if not no_spatial:
            self.SpatialGate = SpatialGate()

    def forward(self, x):
        """Apply the channel gate and, unless disabled, the spatial gate."""
        x_out = self.ChannelGate(x)
        if not self.no_spatial:
            x_out = self.SpatialGate(x_out)
        return x_out
================================================
FILE: MODELS/model_resnet.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from torch.nn import init
from .cbam import *
from .bam import *
from .attention import *
def conv3x3(in_planes, out_planes, stride=1):
    """Return a bias-free 3x3 ``nn.Conv2d`` with padding 1 (spatial size kept at stride 1)."""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )
class BasicBlock(nn.Module):
    """Two 3x3 convs with a residual shortcut (ResNet-18/34), optionally with CBAM or NAM.

    Args:
        inplanes: input channels.
        planes: output channels (``expansion`` is 1).
        shape: forwarded to ``Att`` when ``use_nam`` is set.
        stride: stride of the first conv.
        downsample: optional module applied to the shortcut.
        use_cbam / use_nam: enable the corresponding attention on the residual branch.
        no_spatial: stored and forwarded to ``Att``.
    """
    expansion = 1

    def __init__(self, inplanes, planes, shape, stride=1, downsample=None, use_cbam=False, use_nam=False, no_spatial=True):
        super(BasicBlock, self).__init__()
        # Registration order is kept as-is: it fixes ``self.modules()`` order,
        # which matters for reproducible (seeded) weight initialisation.
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride
        self.no_spatial = no_spatial
        self.cbam = CBAM(planes, 16) if use_cbam else None
        self.nam = Att(planes, no_spatial=self.no_spatial, shape=shape) if use_nam else None

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        shortcut = x if self.downsample is None else self.downsample(x)
        # Attention (CBAM first, then NAM) is applied before the residual add.
        for attention in (self.cbam, self.nam):
            if attention is not None:
                out = attention(out)
        out += shortcut
        return self.relu(out)
class Bottleneck(nn.Module):
    """1x1 -> 3x3 -> 1x1 residual bottleneck (ResNet-50/101), optionally with CBAM or NAM.

    The block outputs ``planes * 4`` channels. ``shape`` and ``no_spatial`` are
    forwarded to ``Att`` when ``use_nam`` is set.
    """
    expansion = 4

    def __init__(self, inplanes, planes, shape, stride=1, downsample=None, use_cbam=False, use_nam=False, no_spatial=False):
        super(Bottleneck, self).__init__()
        widened = planes * 4
        # Registration order is kept as-is: it fixes ``self.modules()`` order,
        # which matters for reproducible (seeded) weight initialisation.
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, widened, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(widened)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride
        self.no_spatial = no_spatial
        self.cbam = CBAM(widened, 16) if use_cbam else None
        self.nam = Att(widened, no_spatial=self.no_spatial, shape=shape) if use_nam else None

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        shortcut = x if self.downsample is None else self.downsample(x)
        # Attention (CBAM first, then NAM) is applied before the residual add.
        for attention in (self.cbam, self.nam):
            if attention is not None:
                out = attention(out)
        out += shortcut
        return self.relu(out)
class ResNet(nn.Module):
    """ResNet backbone with optional BAM / CBAM / NAM attention.

    Args:
        block: ``BasicBlock`` or ``Bottleneck``.
        layers: number of blocks in each of the four stages.
        network_type: ``"ImageNet"`` selects the 7x7/stride-2 stem plus max-pool;
            any other value selects the CIFAR 3x3/stride-1 stem.
        num_classes: output size of the final linear layer.
        att_type: ``'BAM'`` inserts BAM after stages 1-3; ``'CBAM'`` / ``'NAM'``
            enable per-block attention; any other value disables attention.
    """

    def __init__(self, block, layers, network_type, num_classes, att_type=None):
        super(ResNet, self).__init__()
        self.inplanes = 64
        self.network_type = network_type
        # different model config between ImageNet and CIFAR
        if network_type == "ImageNet":
            self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
            self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
            shape = 56
        else:
            self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
            shape = 32
        # Global average pool over whatever spatial size stage 4 produces.
        # The previous fixed AvgPool2d(7) / avg_pool2d(x, 4) only matched
        # 224x224 / 32x32 inputs. Other sizes either crashed in ``fc`` or
        # silently ignored border pixels. The pool is parameter-free, so the
        # state_dict keys are unchanged.
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        if att_type == 'BAM':
            self.bam1 = BAM(64 * block.expansion)
            self.bam2 = BAM(128 * block.expansion)
            self.bam3 = BAM(256 * block.expansion)
        else:
            self.bam1, self.bam2, self.bam3 = None, None, None
        self.layer1 = self._make_layer(block, 64, shape, layers[0], att_type=att_type, no_spatial=False)
        self.layer2 = self._make_layer(block, 128, shape // 2, layers[1], stride=2, att_type=att_type, no_spatial=False)
        self.layer3 = self._make_layer(block, 256, shape // 4, layers[2], stride=2, att_type=att_type, no_spatial=False)
        self.layer4 = self._make_layer(block, 512, shape // 8, layers[3], stride=2, att_type=att_type, no_spatial=False)
        self.fc = nn.Linear(512 * block.expansion, num_classes)
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_uniform_(m.weight.data)
            elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, shape, blocks, stride=1, att_type=None, no_spatial=True):
        """Build one stage of ``blocks`` residual blocks; only the first may stride/downsample.

        ``self.inplanes`` is updated to the stage's output width as a side effect.
        """
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            # 1x1 projection so the shortcut matches the block output shape.
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )
        layers = []
        layers.append(
            block(self.inplanes, planes, shape, stride, downsample, use_cbam=att_type == 'CBAM', use_nam=att_type == 'NAM',
                  no_spatial=no_spatial))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, shape, use_cbam=att_type == 'CBAM', use_nam=att_type == 'NAM',
                                no_spatial=no_spatial))
        return nn.Sequential(*layers)

    def forward(self, x, label=None):
        """Return class logits for image batch ``x``; ``label`` is unused."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        if self.network_type == "ImageNet":
            x = self.maxpool(x)
        x = self.layer1(x)
        if self.bam1 is not None:
            x = self.bam1(x)
        x = self.layer2(x)
        if self.bam2 is not None:
            x = self.bam2(x)
        x = self.layer3(x)
        if self.bam3 is not None:
            x = self.bam3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x
def ResidualNet(network_type, depth, num_classes, att_type):
    """Build an attention ResNet of the requested depth.

    Args:
        network_type: "ImageNet", "CIFAR10" or "CIFAR100" (selects the stem).
        depth: 18 / 34 (``BasicBlock``) or 50 / 101 (``Bottleneck``).
        num_classes: size of the final classifier.
        att_type: forwarded to ``ResNet`` ('BAM', 'CBAM', 'NAM' or None).

    Raises:
        AssertionError: on an unsupported ``network_type`` or ``depth``.
    """
    assert network_type in ["ImageNet", "CIFAR10", "CIFAR100"], "network type should be ImageNet or CIFAR10 / CIFAR100"
    assert depth in [18, 34, 50, 101], 'network depth should be 18, 34, 50 or 101'
    block_and_stages = {
        18: (BasicBlock, [2, 2, 2, 2]),
        34: (BasicBlock, [3, 4, 6, 3]),
        50: (Bottleneck, [3, 4, 6, 3]),
        101: (Bottleneck, [3, 4, 23, 3]),
    }
    block, stages = block_and_stages[depth]
    return ResNet(block, stages, network_type, num_classes, att_type)
================================================
FILE: README.md
================================================
# NAM
================================================
FILE: train_cifar100.py
================================================
import argparse
import os
import shutil
import time
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models
from MODELS.model_resnet import *
from PIL import ImageFile
from thop import profile
ImageFile.LOAD_TRUNCATED_IMAGES = True

model_names = sorted(name for name in models.__dict__
                     if name.islower() and not name.startswith("__")
                     and callable(models.__dict__[name]))

parser = argparse.ArgumentParser(description='PyTorch CIFAR Training')
parser.add_argument('data', metavar='DIR', help='path to dataset')
parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet',
                    help='model architecture: ' + ' | '.join(model_names) + ' (default: resnet)')
parser.add_argument('--depth', default=50, type=int, metavar='D', help='model depth')
parser.add_argument('--ngpu', default=4, type=int, metavar='G', help='number of gpus to use')
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', help='number of data loading workers (default: 4)')
parser.add_argument('--epochs', default=200, type=int, metavar='N', help='number of total epochs to run')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N', help='manual epoch number (useful on restarts)')
parser.add_argument('-b', '--batch-size', default=128, type=int, metavar='N', help='mini-batch size (default: 128)')
parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, metavar='LR', help='initial learning rate')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum')
parser.add_argument('--weight-decay', '--wd', default=5e-4, type=float, metavar='W', help='weight decay (default: 5e-4)')
parser.add_argument('--print-freq', '-p', default=100, type=int, metavar='N', help='print frequency (default: 100)')
parser.add_argument('--resume', default='', type=str, metavar='PATH', help='path to latest checkpoint (default: none)')
parser.add_argument("--seed", type=int, default=1234, metavar='S', help='random seed (default: 1234)')
parser.add_argument("--prefix", type=str, required=True, metavar='PFX', help='prefix for logging & checkpoint saving')
parser.add_argument('--evaluate', dest='evaluate', action='store_true', help='evaluation only')
parser.add_argument('--att-type', type=str, choices=['BAM', 'CBAM', 'NAM'], default=None)
# ``type=list`` split a CLI value such as "60" into characters ('6', '0'),
# so the schedule never decayed. Parse one integer per token instead:
#   --milestones 60 120 160
parser.add_argument('--milestones', type=int, nargs='+', default=[60, 120, 160],
                    help='epochs at which the learning rate is multiplied by --gamma (default: 60 120 160)')
parser.add_argument('--set', type=str, default='cifar100',
                    help="dataset: 'cifar100' (default); any other value selects CIFAR-10")
parser.add_argument('--gamma', type=float, default=0.2, help='learning-rate decay factor applied at each milestone (default: 0.2)')
parser.add_argument('--sparsity-regularization', '-sr', dest='sr', action='store_true', help='train with channel sparsity regularization')
parser.add_argument('--s', type=float, default=0.0001, help='scale sparse rate (default: 0.0001)')

best_prec1 = 0

# ``exist_ok`` avoids the check-then-create race of ``exists`` + ``mkdir``.
os.makedirs('./checkpoints', exist_ok=True)
def updateBN(model, s=None):
    """Add an L1 sub-gradient ``s * sign(gamma)`` to every NAM BatchNorm scale.

    This is meant to run between ``loss.backward()`` and ``optimizer.step()``.
    It visits the blocks of ``layer1``..``layer4`` and updates
    ``block.nam.Channel_Att.bn2.weight.grad`` in place.

    Args:
        model: the ResNet, or a wrapper exposing it as ``.module`` (e.g.
            ``nn.DataParallel``). Previously a wrapped model was not unwrapped,
            so no layer names matched and the penalty was silently never applied.
        s: penalty strength. Defaults to the module-global ``args.s``.
    """
    scale = args.s if s is None else s
    net = model.module if hasattr(model, 'module') else model
    for stage_name in ('layer1', 'layer2', 'layer3', 'layer4'):
        stage = getattr(net, stage_name, None)
        if stage is None:
            continue
        for block in stage:
            nam = getattr(block, 'nam', None)
            if nam is None:
                # Block built without NAM attention: nothing to regularise.
                continue
            weight = nam.Channel_Att.bn2.weight
            if weight.grad is None:
                continue
            weight.grad.data.add_(scale * torch.sign(weight.data))
def main():
    """Parse CLI arguments, build the model and data loaders, then train or evaluate.

    Sets the module globals ``args`` (read by ``train``, ``validate`` and
    ``updateBN``) and ``best_prec1``. A checkpoint is written after every
    epoch via ``save_checkpoint``.

    Raises:
        ValueError: if ``--arch`` is anything other than ``'resnet'``.
    """
    global args, best_prec1
    args = parser.parse_args()
    print("args", args)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    random.seed(args.seed)

    # create model
    if args.arch != "resnet":
        # Previously any other value left ``model`` unbound and crashed later.
        raise ValueError("unsupported --arch '{}': only 'resnet' is implemented".format(args.arch))
    model = ResidualNet('CIFAR100', args.depth, 100, args.att_type)
    inputs = torch.randn(1, 3, 32, 32)
    total_ops, total_params = profile(model, (inputs,), verbose=False)
    print(" %.2f | %.2f" % (total_params / (1000 ** 2), total_ops / (1000 ** 3)))

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    model = torch.nn.DataParallel(model, device_ids=list(range(args.ngpu)))
    model = model.cuda()
    train_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=args.milestones, gamma=args.gamma)

    # get the number of model parameters
    print('Number of model parameters: {}'.format(
        sum([p.data.nelement() for p in model.parameters()])))

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            if 'optimizer' in checkpoint:
                optimizer.load_state_dict(checkpoint['optimizer'])
            if 'scheduler' in checkpoint:
                train_scheduler.load_state_dict(checkpoint['scheduler'])
            else:
                # Older checkpoints lack scheduler state. The restored optimizer
                # already carries the decayed LR (it was saved after
                # scheduler.step()), so only the epoch counter needs advancing.
                # Otherwise the milestones would be counted again from epoch 0.
                train_scheduler.last_epoch = args.start_epoch
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    # NOTE(review): CIFAR-10 normalisation statistics are used for both datasets
    # even though ``data_transforms_cifar100`` exists. They are kept unchanged
    # so results stay comparable with earlier runs; confirm which is intended.
    train_transform, valid_transform = data_transforms_cifar10(args)
    dataset_cls = datasets.CIFAR100 if args.set == 'cifar100' else datasets.CIFAR10
    train_data = dataset_cls(root=args.data, train=True, download=True, transform=train_transform)
    valid_data = dataset_cls(root=args.data, train=False, download=True, transform=valid_transform)
    train_loader = torch.utils.data.DataLoader(
        train_data, batch_size=args.batch_size, shuffle=True, pin_memory=True, num_workers=args.workers)
    val_loader = torch.utils.data.DataLoader(
        valid_data, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=args.workers)

    if args.evaluate:
        # ``--evaluate`` used to be parsed but ignored, silently starting training.
        validate(val_loader, model, criterion, args.start_epoch)
        return

    for epoch in range(args.start_epoch, args.epochs):
        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch)
        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion, epoch)
        train_scheduler.step()
        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        save_checkpoint({
            'epoch': epoch + 1,
            'arch': args.arch,
            'state_dict': model.state_dict(),
            'best_prec1': best_prec1,
            'optimizer': optimizer.state_dict(),
            'scheduler': train_scheduler.state_dict(),
        }, is_best, args.prefix)
def train(train_loader, model, criterion, optimizer, epoch):
    """Run one training epoch over ``train_loader`` and print progress.

    The function reads the module-global ``args``: ``print_freq``, ``sr`` and
    ``att_type``. When ``--sparsity-regularization`` is set with NAM attention,
    ``updateBN`` adds its L1 term to the NAM BN-scale gradients between
    ``backward()`` and ``step()``.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    # switch to train mode
    model.train()
    end = time.time()
    for i, (images, target) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)
        # ``torch.autograd.Variable`` has been a no-op wrapper since PyTorch 0.4;
        # tensors are moved to the GPU and used directly.
        target = target.cuda()
        images = images.cuda()
        # compute output
        output = model(images)
        loss = criterion(output, target)
        # measure accuracy and record loss
        prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
        losses.update(loss.item(), images.size(0))
        top1.update(prec1.item(), images.size(0))
        top5.update(prec5.item(), images.size(0))
        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        if args.sr and args.att_type == 'NAM':
            updateBN(model)
        optimizer.step()
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        if i % args.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                      epoch, i, len(train_loader), batch_time=batch_time,
                      data_time=data_time, loss=losses, top1=top1, top5=top5))
def validate(val_loader, model, criterion, epoch):
    """Evaluate ``model`` on ``val_loader`` and return the average top-1 precision (percent).

    Prints progress every ``args.print_freq`` batches (module-global ``args``)
    and a final summary line. ``epoch`` is accepted but not used.
    """
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    # switch to evaluate mode
    model.eval()
    end = time.time()
    # The whole loop runs without autograd. The ``Variable`` wrappers (no-ops
    # since PyTorch 0.4) have been dropped.
    with torch.no_grad():
        for i, (images, target) in enumerate(val_loader):
            target = target.cuda()
            images = images.cuda()
            # compute output
            output = model(images)
            loss = criterion(output, target)
            # measure accuracy and record loss
            prec1, prec5 = accuracy(output, target, topk=(1, 5))
            losses.update(loss.item(), images.size(0))
            top1.update(prec1.item(), images.size(0))
            top5.update(prec5.item(), images.size(0))
            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()
            if i % args.print_freq == 0:
                print('Test: [{0}/{1}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                      'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                          i, len(val_loader), batch_time=batch_time, loss=losses,
                          top1=top1, top5=top5))
    print(' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}'
          .format(top1=top1, top5=top5))
    return top1.avg
def save_checkpoint(state, is_best, prefix):
    """Save ``state`` to ``./checkpoints/<prefix>_checkpoint.pth.tar``.

    When ``is_best`` is true, the file is also copied to
    ``./checkpoints/<prefix>_model_best.pth.tar``.
    """
    latest_path = './checkpoints/%s_checkpoint.pth.tar' % prefix
    torch.save(state, latest_path)
    if not is_best:
        return
    best_path = './checkpoints/%s_model_best.pth.tar' % prefix
    shutil.copyfile(latest_path, best_path)
class AverageMeter(object):
    """Computes and stores the average and current value.

    ``val`` is the most recent value. ``sum`` and ``count`` accumulate
    ``val * n`` and ``n``, and ``avg`` is ``sum / count``.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero every statistic."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the mean of ``n`` samples and refresh the running average."""
        self.val = val
        self.sum = self.sum + val * n
        self.count = self.count + n
        self.avg = self.sum / self.count
def adjust_learning_rate(optimizer, epoch):
    """Sets the learning rate to the initial LR decayed by 10 every 30 epochs.

    The base rate is read from the module-global ``args.lr`` and written to
    every param group. The training loop's call to it is commented out.
    """
    decayed = args.lr * (0.1 ** (epoch // 30))
    for group in optimizer.param_groups:
        group['lr'] = decayed
def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k.

    Args:
        output: 2-D scores, one row per sample (top-k is taken along dim 1).
        target: one integer label per sample.
        topk: iterable of k values.

    Returns:
        A list with one 1-element tensor per k: the percentage of samples whose
        label is among that sample's top-k scores.
    """
    largest_k = max(topk)
    batch_size = target.size(0)
    _, top_indices = output.topk(largest_k, 1, True, True)
    top_indices = top_indices.t()  # rows are ranks, columns are samples
    hits = top_indices.eq(target.view(1, -1).expand_as(top_indices))
    scale = 100.0 / batch_size
    return [hits[:k].reshape(-1).float().sum(0, keepdim=True).mul_(scale) for k in topk]
def data_transforms_cifar10(args):
    """Return ``(train_transform, valid_transform)`` normalised with CIFAR-10 channel statistics.

    Training applies a random 32x32 crop with 4-pixel padding and a random
    horizontal flip before tensor conversion and normalisation. Validation
    only converts and normalises. ``args`` is not used.
    """
    cifar_mean = [0.49139968, 0.48215827, 0.44653124]
    cifar_std = [0.24703233, 0.24348505, 0.26158768]
    train_steps = [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(cifar_mean, cifar_std),
    ]
    valid_steps = [
        transforms.ToTensor(),
        transforms.Normalize(cifar_mean, cifar_std),
    ]
    return transforms.Compose(train_steps), transforms.Compose(valid_steps)
def data_transforms_cifar100(args):
    """Return ``(train_transform, valid_transform)`` normalised with CIFAR-100 channel statistics.

    The statistics are given as 0-255 pixel values divided by 255. Training
    applies a random 32x32 crop with 4-pixel padding and a random horizontal
    flip before tensor conversion and normalisation. Validation only converts
    and normalises. ``args`` is not used.
    """
    cifar_mean = [125.3 / 255.0, 123.0 / 255.0, 113.9 / 255.0]
    cifar_std = [63.0 / 255.0, 62.1 / 255.0, 66.7 / 255.0]
    train_steps = [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(cifar_mean, cifar_std),
    ]
    valid_steps = [
        transforms.ToTensor(),
        transforms.Normalize(cifar_mean, cifar_std),
    ]
    return transforms.Compose(train_steps), transforms.Compose(valid_steps)
if __name__ == "__main__":
    # Run training / evaluation only when executed as a script, not on import.
    main()
gitextract_tsj638te/ ├── MODELS/ │ ├── attention.py │ ├── bam.py │ ├── cbam.py │ └── model_resnet.py ├── README.md └── train_cifar100.py
SYMBOL INDEX (59 symbols across 5 files)
FILE: MODELS/attention.py
class Channel_Att (line 6) | class Channel_Att(nn.Module):
method __init__ (line 7) | def __init__(self, channels, t=16):
method forward (line 14) | def forward(self, x):
class Att (line 28) | class Att(nn.Module):
method __init__ (line 29) | def __init__(self, channels,shape, out_channels=None, no_spatial=True):
method forward (line 33) | def forward(self, x):
FILE: MODELS/bam.py
class Flatten (line 6) | class Flatten(nn.Module):
method forward (line 7) | def forward(self, x):
class ChannelGate (line 9) | class ChannelGate(nn.Module):
method __init__ (line 10) | def __init__(self, gate_channel, reduction_ratio=16, num_layers=1):
method forward (line 23) | def forward(self, in_tensor):
class SpatialGate (line 27) | class SpatialGate(nn.Module):
method __init__ (line 28) | def __init__(self, gate_channel, reduction_ratio=16, dilation_conv_num...
method forward (line 40) | def forward(self, in_tensor):
class BAM (line 42) | class BAM(nn.Module):
method __init__ (line 43) | def __init__(self, gate_channel):
method forward (line 47) | def forward(self,in_tensor):
FILE: MODELS/cbam.py
class BasicConv (line 6) | class BasicConv(nn.Module):
method __init__ (line 7) | def __init__(self, in_planes, out_planes, kernel_size, stride=1, paddi...
method forward (line 14) | def forward(self, x):
class Flatten (line 22) | class Flatten(nn.Module):
method forward (line 23) | def forward(self, x):
class ChannelGate (line 26) | class ChannelGate(nn.Module):
method __init__ (line 27) | def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg...
method forward (line 37) | def forward(self, x):
function logsumexp_2d (line 62) | def logsumexp_2d(tensor):
class ChannelPool (line 68) | class ChannelPool(nn.Module):
method forward (line 69) | def forward(self, x):
class SpatialGate (line 72) | class SpatialGate(nn.Module):
method __init__ (line 73) | def __init__(self):
method forward (line 78) | def forward(self, x):
class CBAM (line 84) | class CBAM(nn.Module):
method __init__ (line 85) | def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg...
method forward (line 91) | def forward(self, x):
FILE: MODELS/model_resnet.py
function conv3x3 (line 11) | def conv3x3(in_planes, out_planes, stride=1):
class BasicBlock (line 17) | class BasicBlock(nn.Module):
method __init__ (line 20) | def __init__(self, inplanes, planes, shape,stride=1, downsample=None, ...
method forward (line 41) | def forward(self, x):
class Bottleneck (line 66) | class Bottleneck(nn.Module):
method __init__ (line 69) | def __init__(self, inplanes, planes,shape, stride=1, downsample=None, ...
method forward (line 94) | def forward(self, x):
class ResNet (line 128) | class ResNet(nn.Module):
method __init__ (line 129) | def __init__(self, block, layers, network_type, num_classes, att_type=...
method _make_layer (line 185) | def _make_layer(self, block, planes, shape, blocks, stride=1, att_type...
method forward (line 204) | def forward(self, x,label=None):
function ResidualNet (line 234) | def ResidualNet(network_type, depth, num_classes, att_type):
FILE: train_cifar100.py
function updateBN (line 55) | def updateBN(model):
function main (line 71) | def main():
function train (line 163) | def train(train_loader, model, criterion, optimizer, epoch):
function validate (line 212) | def validate(val_loader, model, criterion, epoch):
function save_checkpoint (line 260) | def save_checkpoint(state, is_best, prefix):
class AverageMeter (line 267) | class AverageMeter(object):
method __init__ (line 269) | def __init__(self):
method reset (line 272) | def reset(self):
method update (line 278) | def update(self, val, n=1):
function adjust_learning_rate (line 285) | def adjust_learning_rate(optimizer, epoch):
function accuracy (line 292) | def accuracy(output, target, topk=(1,)):
function data_transforms_cifar10 (line 307) | def data_transforms_cifar10(args):
function data_transforms_cifar100 (line 326) | def data_transforms_cifar100(args):
Condensed preview — 6 files, each showing path, character count, and a content snippet. Download the .json file or copy for the full structured content (31K chars).
[
{
"path": "MODELS/attention.py",
"chars": 933,
"preview": "import torch.nn as nn\nimport torch\nfrom torch.nn import functional as F\n\n\nclass Channel_Att(nn.Module):\n def __init__"
},
{
"path": "MODELS/bam.py",
"chars": 2726,
"preview": "import torch\nimport math\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nclass Flatten(nn.Module):\n def forwar"
},
{
"path": "MODELS/cbam.py",
"chars": 3868,
"preview": "import torch\nimport math\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nclass BasicConv(nn.Module):\n def __in"
},
{
"path": "MODELS/model_resnet.py",
"chars": 8153,
"preview": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport math\nfrom torch.nn import init\nfrom .cbam impo"
},
{
"path": "README.md",
"chars": 5,
"preview": "# NAM"
},
{
"path": "train_cifar100.py",
"chars": 13878,
"preview": "import argparse\nimport os\nimport shutil\nimport time\nimport random\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.pa"
}
]
About this extraction
This page contains the full source code of the Christian-lyc/NAM GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 6 files (28.9 KB), approximately 7.8k tokens, and a symbol index with 59 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.
Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.