Repository: tahmid0007/VisualTransformers
Branch: main
Commit: d9fd5834e9a0
Files: 2
Total size: 12.9 KB
Directory structure:
gitextract_ps0_r7th/
├── README.md
└── ResViT.py
================================================
FILE CONTENTS
================================================
================================================
FILE: README.md
================================================
# Visual Transformers: Token-based Image Representation and Processing for Computer Vision (PyTorch)
A PyTorch implementation of the paper "Visual Transformers: Token-based Image Representation and Processing for Computer Vision".
**Visual Transformers**
Find the original paper [here](https://arxiv.org/abs/2006.03677).
- This PyTorch implementation is based on [this repo](https://github.com/tahmid0007/VisionTransformer). The default dataset is CIFAR10, which can easily be swapped for ImageNet or any other dataset; see the sketch below.
- You will need to install `einops` (`pip install einops`).
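For example, a minimal sketch of switching the data pipeline to CIFAR100 (an illustration, not part of the original script; any `torchvision.datasets` class with the same interface works, and `num_classes` must match the new dataset):

```python
import torchvision

# Swap CIFAR10 for CIFAR100 in the data pipeline; the rest of the training
# script in ResViT.py stays the same except for the classifier head size.
train_dataset = torchvision.datasets.CIFAR100(
    "./CIFAR100_data", train=True, download=True,
    transform=torchvision.transforms.ToTensor())

# 100 classes instead of 10.
model = ViTResNet(BasicBlock, [3, 3, 3], num_classes=100)
```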
================================================
FILE: ResViT.py
================================================
# -*- coding: utf-8 -*-
"""
Created on Fri Oct 16 11:37:52 2020
@author: mthossain
"""
import PIL
import time
import torch
import torchvision
import torch.nn.functional as F
from einops import rearrange
from torch import nn
import torch.nn.init as init
def _weights_init(m):
classname = m.__class__.__name__
#print(classname)
if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d):
init.kaiming_normal_(m.weight)
class LambdaLayer(nn.Module):
def __init__(self, lambd):
super(LambdaLayer, self).__init__()
self.lambd = lambd
def forward(self, x):
return self.lambd(x)
class BasicBlock(nn.Module):
expansion = 1
def __init__(self, in_planes, planes, stride=1, option='A'):
super(BasicBlock, self).__init__()
self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(planes)
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(planes)
self.shortcut = nn.Sequential()
if stride != 1 or in_planes != planes:
if option == 'A':
"""
For CIFAR10 ResNet paper uses option A.
"""
self.shortcut = LambdaLayer(lambda x:
F.pad(x[:, :, ::2, ::2], (0, 0, 0, 0, planes//4, planes//4), "constant", 0))
elif option == 'B':
self.shortcut = nn.Sequential(
nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(self.expansion * planes)
)
def forward(self, x):
out = F.relu(self.bn1(self.conv1(x)))
out = self.bn2(self.conv2(out))
out += self.shortcut(x)
out = F.relu(out)
#print(out.size())
return out
class Residual(nn.Module):
def __init__(self, fn):
super().__init__()
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(x, **kwargs) + x
class LayerNormalize(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs)
class MLP_Block(nn.Module):
def __init__(self, dim, hidden_dim, dropout = 0.1):
super().__init__()
self.nn1 = nn.Linear(dim, hidden_dim)
torch.nn.init.xavier_uniform_(self.nn1.weight)
torch.nn.init.normal_(self.nn1.bias, std = 1e-6)
self.af1 = nn.GELU()
self.do1 = nn.Dropout(dropout)
self.nn2 = nn.Linear(hidden_dim, dim)
torch.nn.init.xavier_uniform_(self.nn2.weight)
torch.nn.init.normal_(self.nn2.bias, std = 1e-6)
self.do2 = nn.Dropout(dropout)
def forward(self, x):
x = self.nn1(x)
x = self.af1(x)
x = self.do1(x)
x = self.nn2(x)
x = self.do2(x)
return x
class Attention(nn.Module):
def __init__(self, dim, heads = 8, dropout = 0.1):
super().__init__()
self.heads = heads
        self.scale = dim ** -0.5  # 1/sqrt(dim); note that standard multi-head attention scales by the per-head dim, 1/sqrt(dim/heads)
        self.to_qkv = nn.Linear(dim, dim * 3, bias=True)  # one projection yields Wq, Wk and Wv for each vector, hence dim * 3
torch.nn.init.xavier_uniform_(self.to_qkv.weight)
torch.nn.init.zeros_(self.to_qkv.bias)
self.nn1 = nn.Linear(dim, dim)
torch.nn.init.xavier_uniform_(self.nn1.weight)
torch.nn.init.zeros_(self.nn1.bias)
self.do1 = nn.Dropout(dropout)
def forward(self, x, mask = None):
b, n, _, h = *x.shape, self.heads
        qkv = self.to_qkv(x)  # (b, n, 3*dim): Q = Wq x, K = Wk x, V = Wv x, stacked along the last dim
        q, k, v = rearrange(qkv, 'b n (qkv h d) -> qkv b h n d', qkv=3, h=h)  # split into multiple heads: each is (b, h, n, d)
dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale
if mask is not None:
mask = F.pad(mask.flatten(1), (1, 0), value = True)
assert mask.shape[-1] == dots.shape[-1], 'mask has incorrect dimensions'
mask = mask[:, None, :] * mask[:, :, None]
dots.masked_fill_(~mask, float('-inf'))
del mask
        attn = dots.softmax(dim=-1)  # softmax(QK^T * scale), following the attention equation in the paper
        out = torch.einsum('bhij,bhjd->bhid', attn, v)  # attention-weighted sum of the values
        out = rearrange(out, 'b h n d -> b n (h d)')  # concat heads into one matrix, ready for the next encoder block
out = self.nn1(out)
out = self.do1(out)
return out
class Transformer(nn.Module):
def __init__(self, dim, depth, heads, mlp_dim, dropout):
super().__init__()
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
Residual(LayerNormalize(dim, Attention(dim, heads = heads, dropout = dropout))),
Residual(LayerNormalize(dim, MLP_Block(dim, mlp_dim, dropout = dropout)))
]))
def forward(self, x, mask = None):
for attention, mlp in self.layers:
            x = attention(x, mask=mask)  # pre-norm self-attention with residual connection
            x = mlp(x)  # pre-norm MLP block with residual connection
return x
class ViTResNet(nn.Module):
def __init__(self, block, num_blocks, num_classes=10, dim = 128, num_tokens = 8, mlp_dim = 256, heads = 8, depth = 6, emb_dropout = 0.1, dropout= 0.1):
super(ViTResNet, self).__init__()
self.in_planes = 16
self.L = num_tokens
self.cT = dim
self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(16)
self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1)
self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2)
self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2) #8x8 feature maps (64 in total)
self.apply(_weights_init)
        # Tokenization parameters, shared across the batch (shapes (1, L, 64) and
        # (1, 64, cT)) and expanded to the batch size in forward(); tying them to a
        # fixed training batch size would break inference with other batch sizes.
        self.token_wA = nn.Parameter(torch.empty(1, self.L, 64), requires_grad=True)
        torch.nn.init.xavier_uniform_(self.token_wA)
        self.token_wV = nn.Parameter(torch.empty(1, 64, self.cT), requires_grad=True)
        torch.nn.init.xavier_uniform_(self.token_wV)
self.pos_embedding = nn.Parameter(torch.empty(1, (num_tokens + 1), dim))
torch.nn.init.normal_(self.pos_embedding, std = .02) # initialized based on the paper
#self.patch_conv= nn.Conv2d(64,dim, self.patch_size, stride = self.patch_size)
self.cls_token = nn.Parameter(torch.zeros(1, 1, dim)) #initialized based on the paper
self.dropout = nn.Dropout(emb_dropout)
self.transformer = Transformer(dim, depth, heads, mlp_dim, dropout)
self.to_cls_token = nn.Identity()
self.nn1 = nn.Linear(dim, num_classes) # if finetuning, just use a linear layer without further hidden layers (paper)
torch.nn.init.xavier_uniform_(self.nn1.weight)
torch.nn.init.normal_(self.nn1.bias, std = 1e-6)
def _make_layer(self, block, planes, num_blocks, stride):
strides = [stride] + [1]*(num_blocks-1)
layers = []
for stride in strides:
layers.append(block(self.in_planes, planes, stride))
self.in_planes = planes * block.expansion
return nn.Sequential(*layers)
def forward(self, img, mask = None):
x = F.relu(self.bn1(self.conv1(img)))
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
        x = rearrange(x, 'b c h w -> b (h w) c')  # 64 vectors, each with 64 features: the sequence of "word vectors", as in NLP
        # Tokenization: T = softmax((x wA)^T) (x wV), producing L visual tokens
        wa = rearrange(self.token_wA, 'b h w -> b w h')  # transpose to (1, 64, L)
        A = torch.einsum('bij,bjk->bik', x, wa.expand(x.size(0), -1, -1))  # (b, hw, L)
        A = rearrange(A, 'b h w -> b w h')  # transpose to (b, L, hw)
        A = A.softmax(dim=-1)  # spatial attention over the feature-map positions
        VV = torch.einsum('bij,bjk->bik', x, self.token_wV.expand(x.size(0), -1, -1))  # (b, hw, cT)
        T = torch.einsum('bij,bjk->bik', A, VV)  # (b, L, cT)
#print(T.size())
cls_tokens = self.cls_token.expand(img.shape[0], -1, -1)
x = torch.cat((cls_tokens, T), dim=1)
x += self.pos_embedding
x = self.dropout(x)
        x = self.transformer(x, mask)  # transformer encoder over the token sequence
x = self.to_cls_token(x[:, 0])
x = self.nn1(x)
return x
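# Shape walk-through for a CIFAR10 batch (a sketch, assuming the default
# hyper-parameters above): img (B, 3, 32, 32) -> layer3 output (B, 64, 8, 8)
# -> flattened to (B, 64, 64) -> L = 8 visual tokens (B, 8, 128) -> with the
# cls token prepended (B, 9, 128) -> transformer -> logits (B, 10).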
BATCH_SIZE_TRAIN = 100
BATCH_SIZE_TEST = 100
DL_PATH = "C:\Pytorch\Spyder\CIFAR10_data" # Use your own path
# CIFAR10: 60000 32x32 color images in 10 classes, with 6000 images per class
transform = torchvision.transforms.Compose(
    [torchvision.transforms.RandomHorizontalFlip(),
     torchvision.transforms.RandomRotation(10, interpolation=torchvision.transforms.InterpolationMode.BILINEAR),
     torchvision.transforms.RandomAffine(8, translate=(.15, .15)),
     torchvision.transforms.ToTensor(),
     # ImageNet statistics; CIFAR10's own mean/std would also be reasonable
     torchvision.transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])
train_dataset = torchvision.datasets.CIFAR10(DL_PATH, train=True,
download=True, transform=transform)
test_dataset = torchvision.datasets.CIFAR10(DL_PATH, train=False,
download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE_TRAIN,
shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH_SIZE_TEST,
shuffle=False)
def train(model, optimizer, data_loader, loss_history):
total_samples = len(data_loader.dataset)
model.train()
for i, (data, target) in enumerate(data_loader):
optimizer.zero_grad()
output = F.log_softmax(model(data), dim=1)
loss = F.nll_loss(output, target)
loss.backward()
optimizer.step()
if i % 100 == 0:
print('[' + '{:5}'.format(i * len(data)) + '/' + '{:5}'.format(total_samples) +
' (' + '{:3.0f}'.format(100 * i / len(data_loader)) + '%)] Loss: ' +
'{:6.4f}'.format(loss.item()))
loss_history.append(loss.item())
def evaluate(model, data_loader, loss_history):
model.eval()
total_samples = len(data_loader.dataset)
correct_samples = 0
total_loss = 0
with torch.no_grad():
for data, target in data_loader:
output = F.log_softmax(model(data), dim=1)
loss = F.nll_loss(output, target, reduction='sum')
_, pred = torch.max(output, dim=1)
total_loss += loss.item()
            correct_samples += pred.eq(target).sum().item()  # .item() keeps the counter a plain Python int
avg_loss = total_loss / total_samples
loss_history.append(avg_loss)
print('\nAverage test loss: ' + '{:.4f}'.format(avg_loss) +
' Accuracy:' + '{:5}'.format(correct_samples) + '/' +
'{:5}'.format(total_samples) + ' (' +
'{:4.2f}'.format(100.0 * correct_samples / total_samples) + '%)\n')
N_EPOCHS = 150
model = ViTResNet(BasicBlock, [3, 3, 3])
optimizer = torch.optim.Adam(model.parameters(), lr=0.003)
#optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate,momentum=.9,weight_decay=1e-4)
#lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizer, milestones=[35,48],gamma = 0.1)
train_loss_history, test_loss_history = [], []
for epoch in range(1, N_EPOCHS + 1):
print('Epoch:', epoch)
start_time = time.time()
train(model, optimizer, train_loader, train_loss_history)
print('Execution time:', '{:5.2f}'.format(time.time() - start_time), 'seconds')
evaluate(model, test_loader, test_loss_history)
PATH = ".\ViTRes.pt" # Use your own path
torch.save(model.state_dict(), PATH)
# =============================================================================
# model = ViTResNet(BasicBlock, [3, 3, 3])
# model.load_state_dict(torch.load(PATH))
# model.eval()
# =============================================================================
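# A minimal inference sketch (an illustration, not part of the original script):
# classify one batch from the test loader with the reloaded model.
# with torch.no_grad():
#     images, labels = next(iter(test_loader))
#     preds = model(images).argmax(dim=1)
#     print('Predicted:', preds[:10].tolist())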