Repository: tahmid0007/VisualTransformers
Branch: main
Commit: d9fd5834e9a0
Files: 2
Total size: 12.9 KB

Directory structure:
gitextract_ps0_r7th/
├── README.md
└── ResViT.py

================================================
FILE CONTENTS
================================================

================================================
FILE: README.md
================================================
# Visual Transformers: Token-based Image Representation and Processing for Computer Vision (PyTorch)

A PyTorch implementation of the paper ["Visual Transformers: Token-based Image Representation and Processing for Computer Vision"](https://arxiv.org/abs/2006.03677).

- This implementation is based on [this repo](https://github.com/tahmid0007/VisionTransformer).
- The default dataset is CIFAR10, which can easily be swapped for ImageNet or any other dataset.
- You may need to install `einops` (`pip install einops`).
- A minimal usage sketch is shown below.
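A minimal sketch of the model interface (hypothetical, assuming the classes from `ResViT.py` are in scope; note that the script also runs the full training loop at module level, so in practice you would copy the model code into your own module):

```python
import torch

# token_wA / token_wV are sized by this module-level constant,
# so the input batch size must match it
BATCH_SIZE_TRAIN = 100

model = ViTResNet(BasicBlock, [3, 3, 3], num_classes=10)
img = torch.randn(BATCH_SIZE_TRAIN, 3, 32, 32)  # one CIFAR10-sized batch
logits = model(img)                             # -> (100, 10)
```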

================================================
FILE: ResViT.py
================================================
# -*- coding: utf-8 -*-
"""
Created on Fri Oct 16 11:37:52 2020

@author: mthossain
"""
import PIL
import time
import torch
import torchvision
import torch.nn.functional as F
from einops import rearrange
from torch import nn
import torch.nn.init as init


def _weights_init(m):
    classname = m.__class__.__name__
    #print(classname)
    if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d):
        init.kaiming_normal_(m.weight)


class LambdaLayer(nn.Module):
    def __init__(self, lambd):
        super(LambdaLayer, self).__init__()
        self.lambd = lambd

    def forward(self, x):
        return self.lambd(x)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1, option='A'):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != planes:
            if option == 'A':
                """
                For CIFAR10 the ResNet paper uses option A: subsample spatially
                and zero-pad the channel dimension (adds no parameters).
                """
                self.shortcut = LambdaLayer(lambda x:
                                            F.pad(x[:, :, ::2, ::2], (0, 0, 0, 0, planes//4, planes//4), "constant", 0))
            elif option == 'B':
                self.shortcut = nn.Sequential(
                    nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
                    nn.BatchNorm2d(self.expansion * planes)
                )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        #print(out.size())
        return out


class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(x, **kwargs) + x


class LayerNormalize(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)


class MLP_Block(nn.Module):
    def __init__(self, dim, hidden_dim, dropout=0.1):
        super().__init__()
        self.nn1 = nn.Linear(dim, hidden_dim)
        torch.nn.init.xavier_uniform_(self.nn1.weight)
        torch.nn.init.normal_(self.nn1.bias, std=1e-6)
        self.af1 = nn.GELU()
        self.do1 = nn.Dropout(dropout)
        self.nn2 = nn.Linear(hidden_dim, dim)
        torch.nn.init.xavier_uniform_(self.nn2.weight)
        torch.nn.init.normal_(self.nn2.bias, std=1e-6)
        self.do2 = nn.Dropout(dropout)

    def forward(self, x):
        x = self.nn1(x)
        x = self.af1(x)
        x = self.do1(x)
        x = self.nn2(x)
        x = self.do2(x)
        return x
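# =============================================================================
# Shape sketch for the Attention block below (illustrative, using the defaults
# dim=128, heads=8, so each head has d = dim // heads = 16):
#
#   x     : (b, n, 128)    n tokens per image
#   qkv   : (b, n, 384)    one Linear produces Q, K and V together
#   q,k,v : (b, 8, n, 16)  after rearrange 'b n (qkv h d) -> qkv b h n d'
#   dots  : (b, 8, n, n)   scaled q @ k^T per head
#   out   : (b, n, 128)    heads concatenated back via 'b h n d -> b n (h d)'
#
# For example (valid if uncommented):
#     q = k = torch.randn(2, 8, 5, 16)
#     dots = torch.einsum('bhid,bhjd->bhij', q, k)   # -> (2, 8, 5, 5)
#
# Note that this implementation scales by dim ** -0.5 rather than the per-head
# d ** -0.5 used in most Transformer papers.
# =============================================================================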
class Attention(nn.Module):
    def __init__(self, dim, heads=8, dropout=0.1):
        super().__init__()
        self.heads = heads
        self.scale = dim ** -0.5  # 1/sqrt(dim)

        self.to_qkv = nn.Linear(dim, dim * 3, bias=True)  # Wq, Wk, Wv for each vector, hence the *3
        torch.nn.init.xavier_uniform_(self.to_qkv.weight)
        torch.nn.init.zeros_(self.to_qkv.bias)

        self.nn1 = nn.Linear(dim, dim)
        torch.nn.init.xavier_uniform_(self.nn1.weight)
        torch.nn.init.zeros_(self.nn1.bias)
        self.do1 = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        b, n, _, h = *x.shape, self.heads
        qkv = self.to_qkv(x)  # computes q = Wq x, k = Wk x, v = Wv x in one projection
        q, k, v = rearrange(qkv, 'b n (qkv h d) -> qkv b h n d', qkv=3, h=h)  # split into multi-head attention

        dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale

        if mask is not None:
            mask = F.pad(mask.flatten(1), (1, 0), value=True)
            assert mask.shape[-1] == dots.shape[-1], 'mask has incorrect dimensions'
            mask = mask[:, None, :] * mask[:, :, None]
            dots.masked_fill_(~mask, float('-inf'))
            del mask

        attn = dots.softmax(dim=-1)  # the softmax(QK^T / sqrt(d)) V equation from the paper

        out = torch.einsum('bhij,bhjd->bhid', attn, v)  # weight the values by the attention scores
        out = rearrange(out, 'b h n d -> b n (h d)')  # concatenate heads into one matrix, ready for the next encoder block
        out = self.nn1(out)
        out = self.do1(out)
        return out


class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, mlp_dim, dropout):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Residual(LayerNormalize(dim, Attention(dim, heads=heads, dropout=dropout))),
                Residual(LayerNormalize(dim, MLP_Block(dim, mlp_dim, dropout=dropout)))
            ]))

    def forward(self, x, mask=None):
        for attention, mlp in self.layers:
            x = attention(x, mask=mask)  # attention sub-layer
            x = mlp(x)  # MLP_Block sub-layer
        return x


class ViTResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=10, dim=128, num_tokens=8, mlp_dim=256,
                 heads=8, depth=6, emb_dropout=0.1, dropout=0.1):
        super(ViTResNet, self).__init__()
        self.in_planes = 16
        self.L = num_tokens
        self.cT = dim
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(16)
        self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2)  # 8x8 feature maps (64 in total)
        self.apply(_weights_init)

        # Tokenization parameters (note: sized by the global BATCH_SIZE_TRAIN)
        self.token_wA = nn.Parameter(torch.empty(BATCH_SIZE_TRAIN, self.L, 64), requires_grad=True)
        torch.nn.init.xavier_uniform_(self.token_wA)
        self.token_wV = nn.Parameter(torch.empty(BATCH_SIZE_TRAIN, 64, self.cT), requires_grad=True)
        torch.nn.init.xavier_uniform_(self.token_wV)

        self.pos_embedding = nn.Parameter(torch.empty(1, (num_tokens + 1), dim))
        torch.nn.init.normal_(self.pos_embedding, std=.02)  # initialized based on the paper

        #self.patch_conv = nn.Conv2d(64, dim, self.patch_size, stride=self.patch_size)

        self.cls_token = nn.Parameter(torch.zeros(1, 1, dim))  # initialized based on the paper
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, mlp_dim, dropout)

        self.to_cls_token = nn.Identity()

        self.nn1 = nn.Linear(dim, num_classes)  # when fine-tuning, just a linear layer without further hidden layers (paper)
        torch.nn.init.xavier_uniform_(self.nn1.weight)
        torch.nn.init.normal_(self.nn1.bias, std=1e-6)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, img, mask=None):
        x = F.relu(self.bn1(self.conv1(img)))
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = rearrange(x, 'b c h w -> b (h w) c')  # 64 vectors, each with 64 features: the sequence of word vectors as in NLP

        # Tokenization
        wa = rearrange(self.token_wA, 'b h w -> b w h')  # transpose
        A = torch.einsum('bij,bjk->bik', x, wa)
        A = rearrange(A, 'b h w -> b w h')  # transpose
        A = A.softmax(dim=-1)
        VV = torch.einsum('bij,bjk->bik', x, self.token_wV)
        T = torch.einsum('bij,bjk->bik', A, VV)
        #print(T.size())

        cls_tokens = self.cls_token.expand(img.shape[0], -1, -1)
        x = torch.cat((cls_tokens, T), dim=1)
        x += self.pos_embedding
        x = self.dropout(x)
        x = self.transformer(x, mask)  # main game
        x = self.to_cls_token(x[:, 0])
        x = self.nn1(x)
        return x
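# =============================================================================
# Tokenization shape sketch (illustrative, with the defaults L=8, dim=128):
#
#   x  : (b, 64, 64)    flattened 8x8 feature map = 64 feature vectors
#   wa : (b, 64, L)     token_wA transposed
#   A  : (b, L, 64)     softmax over the 64 spatial positions per token
#   VV : (b, 64, 128)   features projected by token_wV
#   T  : (b, L, 128)    L visual tokens, each a weighted average of positions
#
# Caveat: token_wA and token_wV carry BATCH_SIZE_TRAIN as their first
# dimension, so every slot in the batch learns its own tokenizer weights and
# the batch size is frozen at construction time. One common alternative (not
# the author's code) is a single shared matrix with a broadcastable leading
# dimension, e.g.:
#
#     self.token_wA = nn.Parameter(torch.empty(1, self.L, 64))
#     self.token_wV = nn.Parameter(torch.empty(1, 64, self.cT))
#
# with the einsums replaced by torch.matmul, whose batch dimension broadcasts.
# =============================================================================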
BATCH_SIZE_TRAIN = 100
BATCH_SIZE_TEST = 100

DL_PATH = r"C:\Pytorch\Spyder\CIFAR10_data"  # Use your own path (raw string avoids backslash escapes)

# CIFAR10: 60000 32x32 colour images in 10 classes, 6000 images per class
transform = torchvision.transforms.Compose(
    [torchvision.transforms.RandomHorizontalFlip(),
     torchvision.transforms.RandomRotation(10, resample=PIL.Image.BILINEAR),  # newer torchvision uses interpolation= instead of resample=
     torchvision.transforms.RandomAffine(8, translate=(.15, .15)),
     torchvision.transforms.ToTensor(),
     torchvision.transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])  # ImageNet statistics

train_dataset = torchvision.datasets.CIFAR10(DL_PATH, train=True, download=True, transform=transform)
test_dataset = torchvision.datasets.CIFAR10(DL_PATH, train=False, download=True, transform=transform)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE_TRAIN, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH_SIZE_TEST, shuffle=False)


def train(model, optimizer, data_loader, loss_history):
    total_samples = len(data_loader.dataset)
    model.train()

    for i, (data, target) in enumerate(data_loader):
        optimizer.zero_grad()
        output = F.log_softmax(model(data), dim=1)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()

        if i % 100 == 0:
            print('[' + '{:5}'.format(i * len(data)) + '/' + '{:5}'.format(total_samples) +
                  ' (' + '{:3.0f}'.format(100 * i / len(data_loader)) + '%)] Loss: ' +
                  '{:6.4f}'.format(loss.item()))
            loss_history.append(loss.item())


def evaluate(model, data_loader, loss_history):
    model.eval()

    total_samples = len(data_loader.dataset)
    correct_samples = 0
    total_loss = 0

    with torch.no_grad():
        for data, target in data_loader:
            output = F.log_softmax(model(data), dim=1)
            loss = F.nll_loss(output, target, reduction='sum')
            _, pred = torch.max(output, dim=1)

            total_loss += loss.item()
            correct_samples += pred.eq(target).sum().item()

    avg_loss = total_loss / total_samples
    loss_history.append(avg_loss)
    print('\nAverage test loss: ' + '{:.4f}'.format(avg_loss) +
          ' Accuracy:' + '{:5}'.format(correct_samples) + '/' +
          '{:5}'.format(total_samples) + ' (' +
          '{:4.2f}'.format(100.0 * correct_samples / total_samples) + '%)\n')


N_EPOCHS = 150

model = ViTResNet(BasicBlock, [3, 3, 3])
optimizer = torch.optim.Adam(model.parameters(), lr=0.003)
#optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=.9, weight_decay=1e-4)
#lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizer, milestones=[35, 48], gamma=0.1)

train_loss_history, test_loss_history = [], []
for epoch in range(1, N_EPOCHS + 1):
    print('Epoch:', epoch)
    start_time = time.time()
    train(model, optimizer, train_loader, train_loss_history)
    print('Execution time:', '{:5.2f}'.format(time.time() - start_time), 'seconds')
    evaluate(model, test_loader, test_loss_history)

PATH = r".\ViTRes.pt"  # Use your own path
torch.save(model.state_dict(), PATH)
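# =============================================================================
# The script above runs on CPU. A minimal change for GPU training (a sketch,
# not part of the original code) is to place the model and each batch on the
# device:
#
#     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#     model = ViTResNet(BasicBlock, [3, 3, 3]).to(device)
#     # and inside train()/evaluate():
#     data, target = data.to(device), target.to(device)
# =============================================================================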
# =============================================================================
# model = ViTResNet(BasicBlock, [3, 3, 3])
# model.load_state_dict(torch.load(PATH))
# model.eval()
# =============================================================================
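# =============================================================================
# Inference sketch (illustrative, not part of the original script): classify
# one batch of test images with the restored model. Because token_wA and
# token_wV are sized by BATCH_SIZE_TRAIN, the input must be a full batch;
# here we take one from the test loader.
#
#     data, target = next(iter(test_loader))
#     with torch.no_grad():
#         pred = model(data).argmax(dim=1)
#     print('predicted:', pred[0].item(), 'actual:', target[0].item())
# =============================================================================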