[
  {
    "path": "README.md",
    "content": "# (Pytorch) Visual Transformers: Token-based Image Representation and Processing for Computer Vision:\nA Pytorch Implementation of the following paper \"Visual Transformers: Token-based Image Representation and Processing for Computer Vision\"\n\n**Visual Transformers**\nFind the original paper [here](https://arxiv.org/abs/2006.03677).\n<p align=\"center\">\n  <img src=\"./Overview.png\" width=\"600\" title=\"Vision transformer\">\n</p>\n\n- This Pytorch Implementation is based on [This repo](https://github.com/tahmid0007/VisionTransformer). The default dataset used here is CIFAR10 which can be easily changed to ImageNet or anything else.\n- You might need to install einops.\n"
  },
  {
    "path": "ResViT.py",
    "content": "# -*- coding: utf-8 -*-\r\n\"\"\"\r\nCreated on Fri Oct 16 11:37:52 2020\r\n\r\n@author: mthossain\r\n\"\"\"\r\nimport PIL\r\nimport time\r\nimport torch\r\nimport torchvision\r\nimport torch.nn.functional as F\r\nfrom einops import rearrange\r\nfrom torch import nn\r\nimport torch.nn.init as init\r\ndef _weights_init(m):\r\n    classname = m.__class__.__name__\r\n    #print(classname)\r\n    if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d):\r\n        init.kaiming_normal_(m.weight)\r\n\r\nclass LambdaLayer(nn.Module):\r\n    def __init__(self, lambd):\r\n        super(LambdaLayer, self).__init__()\r\n        self.lambd = lambd\r\n\r\n    def forward(self, x):\r\n        return self.lambd(x)\r\n\r\n\r\nclass BasicBlock(nn.Module):\r\n    expansion = 1\r\n\r\n    def __init__(self, in_planes, planes, stride=1, option='A'):\r\n        super(BasicBlock, self).__init__()\r\n        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)\r\n        self.bn1 = nn.BatchNorm2d(planes)\r\n        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)\r\n        self.bn2 = nn.BatchNorm2d(planes)\r\n\r\n        self.shortcut = nn.Sequential()\r\n        if stride != 1 or in_planes != planes:\r\n            if option == 'A':\r\n                \"\"\"\r\n                For CIFAR10 ResNet paper uses option A.\r\n                \"\"\"\r\n                self.shortcut = LambdaLayer(lambda x:\r\n                                            F.pad(x[:, :, ::2, ::2], (0, 0, 0, 0, planes//4, planes//4), \"constant\", 0))\r\n            elif option == 'B':\r\n                self.shortcut = nn.Sequential(\r\n                     nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),\r\n                     nn.BatchNorm2d(self.expansion * planes)\r\n                )\r\n\r\n    def forward(self, x):\r\n        out = F.relu(self.bn1(self.conv1(x)))\r\n        out = self.bn2(self.conv2(out))\r\n        out += self.shortcut(x)\r\n        out = F.relu(out)\r\n        #print(out.size())\r\n        return out\r\n\r\n\r\n\r\nclass Residual(nn.Module):\r\n    def __init__(self, fn):\r\n        super().__init__()\r\n        self.fn = fn\r\n    def forward(self, x, **kwargs):\r\n        return self.fn(x, **kwargs) + x\r\n\r\nclass LayerNormalize(nn.Module):\r\n    def __init__(self, dim, fn):\r\n        super().__init__()\r\n        self.norm = nn.LayerNorm(dim)\r\n        self.fn = fn\r\n    def forward(self, x, **kwargs):\r\n        return self.fn(self.norm(x), **kwargs)\r\n\r\nclass MLP_Block(nn.Module):\r\n    def __init__(self, dim, hidden_dim, dropout = 0.1):\r\n        super().__init__()\r\n        self.nn1 = nn.Linear(dim, hidden_dim)\r\n        torch.nn.init.xavier_uniform_(self.nn1.weight)\r\n        torch.nn.init.normal_(self.nn1.bias, std = 1e-6)\r\n        self.af1 = nn.GELU()\r\n        self.do1 = nn.Dropout(dropout)\r\n        self.nn2 = nn.Linear(hidden_dim, dim)\r\n        torch.nn.init.xavier_uniform_(self.nn2.weight)\r\n        torch.nn.init.normal_(self.nn2.bias, std = 1e-6)\r\n        self.do2 = nn.Dropout(dropout)\r\n        \r\n    def forward(self, x):\r\n        x = self.nn1(x)\r\n        x = self.af1(x)\r\n        x = self.do1(x)\r\n        x = self.nn2(x)\r\n        x = self.do2(x)\r\n        \r\n        return x\r\n\r\nclass Attention(nn.Module):\r\n    def __init__(self, dim, heads = 8, dropout = 0.1):\r\n        super().__init__()\r\n        self.heads = heads\r\n   
class Attention(nn.Module):\r\n    def __init__(self, dim, heads=8, dropout=0.1):\r\n        super().__init__()\r\n        self.heads = heads\r\n        self.scale = (dim // heads) ** -0.5  # 1/sqrt(d_head): scale the per-head dot products, as in standard attention\r\n\r\n        self.to_qkv = nn.Linear(dim, dim * 3, bias=True)  # one projection producing Wq, Wk and Wv together, hence the *3\r\n        torch.nn.init.xavier_uniform_(self.to_qkv.weight)\r\n        torch.nn.init.zeros_(self.to_qkv.bias)\r\n\r\n        self.nn1 = nn.Linear(dim, dim)\r\n        torch.nn.init.xavier_uniform_(self.nn1.weight)\r\n        torch.nn.init.zeros_(self.nn1.bias)\r\n        self.do1 = nn.Dropout(dropout)\r\n\r\n    def forward(self, x, mask=None):\r\n        b, n, _, h = *x.shape, self.heads\r\n        qkv = self.to_qkv(x)  # a single matmul yields q, k and v for every token\r\n        q, k, v = rearrange(qkv, 'b n (qkv h d) -> qkv b h n d', qkv=3, h=h)  # split into the h attention heads\r\n\r\n        dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale  # scaled dot products between queries and keys\r\n\r\n        if mask is not None:\r\n            mask = F.pad(mask.flatten(1), (1, 0), value=True)  # prepend True for the cls token\r\n            assert mask.shape[-1] == dots.shape[-1], 'mask has incorrect dimensions'\r\n            mask = mask[:, None, :] * mask[:, :, None]\r\n            dots.masked_fill_(~mask, float('-inf'))\r\n            del mask\r\n\r\n        attn = dots.softmax(dim=-1)  # attention weights: softmax over the key dimension\r\n\r\n        out = torch.einsum('bhij,bhjd->bhid', attn, v)  # weight the values by the attention scores\r\n        out = rearrange(out, 'b h n d -> b n (h d)')  # concatenate the heads, ready for the next encoder block\r\n        out = self.nn1(out)\r\n        out = self.do1(out)\r\n        return out\r\n\r\n\r\nclass Transformer(nn.Module):\r\n    def __init__(self, dim, depth, heads, mlp_dim, dropout):\r\n        super().__init__()\r\n        self.layers = nn.ModuleList([])\r\n        for _ in range(depth):\r\n            self.layers.append(nn.ModuleList([\r\n                Residual(LayerNormalize(dim, Attention(dim, heads=heads, dropout=dropout))),\r\n                Residual(LayerNormalize(dim, MLP_Block(dim, mlp_dim, dropout=dropout)))\r\n            ]))\r\n\r\n    def forward(self, x, mask=None):\r\n        for attention, mlp in self.layers:\r\n            x = attention(x, mask=mask)  # self-attention sub-layer\r\n            x = mlp(x)  # feed-forward sub-layer\r\n        return x\r\n\r\n\r\nclass ViTResNet(nn.Module):\r\n    def __init__(self, block, num_blocks, num_classes=10, dim=128, num_tokens=8, mlp_dim=256, heads=8, depth=6, emb_dropout=0.1, dropout=0.1):\r\n        super(ViTResNet, self).__init__()\r\n        self.in_planes = 16\r\n        self.L = num_tokens\r\n        self.cT = dim\r\n\r\n        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)\r\n        self.bn1 = nn.BatchNorm2d(16)\r\n        self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1)\r\n        self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2)\r\n        self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2)  # 8x8 feature maps with 64 channels\r\n        self.apply(_weights_init)\r\n\r\n        # Tokenization parameters, shared across the batch: shape (1, ...) here and\r\n        # expanded over the batch in forward(), so any batch size works\r\n        self.token_wA = nn.Parameter(torch.empty(1, self.L, 64), requires_grad=True)\r\n        torch.nn.init.xavier_uniform_(self.token_wA)\r\n        self.token_wV = nn.Parameter(torch.empty(1, 64, self.cT), requires_grad=True)\r\n        torch.nn.init.xavier_uniform_(self.token_wV)\r\n\r\n
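        # Tokenizer shape walkthrough (see forward()): the 8x8x64 feature map is\r\n        # flattened to x of shape (b, 64, 64) (positions x channels); the attention\r\n        # A = softmax(x @ wA, transposed) has shape (b, L, 64), one normalized\r\n        # spatial map per token; T = A @ (x @ wV) yields L tokens of shape (b, L, cT).\r\n\r\n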
        self.pos_embedding = nn.Parameter(torch.empty(1, (num_tokens + 1), dim))\r\n        torch.nn.init.normal_(self.pos_embedding, std=.02)  # initialized based on the paper\r\n\r\n        self.cls_token = nn.Parameter(torch.zeros(1, 1, dim))  # initialized based on the paper\r\n        self.dropout = nn.Dropout(emb_dropout)\r\n\r\n        self.transformer = Transformer(dim, depth, heads, mlp_dim, dropout)\r\n\r\n        self.to_cls_token = nn.Identity()\r\n\r\n        self.nn1 = nn.Linear(dim, num_classes)  # classification head: for fine-tuning, a single linear layer suffices (paper)\r\n        torch.nn.init.xavier_uniform_(self.nn1.weight)\r\n        torch.nn.init.normal_(self.nn1.bias, std=1e-6)\r\n\r\n    def _make_layer(self, block, planes, num_blocks, stride):\r\n        strides = [stride] + [1]*(num_blocks-1)\r\n        layers = []\r\n        for stride in strides:\r\n            layers.append(block(self.in_planes, planes, stride))\r\n            self.in_planes = planes * block.expansion\r\n\r\n        return nn.Sequential(*layers)\r\n\r\n    def forward(self, img, mask=None):\r\n        x = F.relu(self.bn1(self.conv1(img)))\r\n        x = self.layer1(x)\r\n        x = self.layer2(x)\r\n        x = self.layer3(x)\r\n\r\n        x = rearrange(x, 'b c h w -> b (h w) c')  # 64 vectors, each of length 64: the sequence of 'word vectors', as in NLP\r\n\r\n        # Tokenization\r\n        wa = rearrange(self.token_wA, 'b h w -> b w h')  # transpose\r\n        wa = wa.expand(x.size(0), -1, -1)  # broadcast the shared tokenizer weights over the batch\r\n        A = torch.einsum('bij,bjk->bik', x, wa)\r\n        A = rearrange(A, 'b h w -> b w h')  # transpose\r\n        A = A.softmax(dim=-1)\r\n\r\n        VV = torch.einsum('bij,bjk->bik', x, self.token_wV.expand(x.size(0), -1, -1))\r\n        T = torch.einsum('bij,bjk->bik', A, VV)\r\n\r\n        cls_tokens = self.cls_token.expand(img.shape[0], -1, -1)\r\n        x = torch.cat((cls_tokens, T), dim=1)\r\n        x += self.pos_embedding\r\n        x = self.dropout(x)\r\n        x = self.transformer(x, mask)  # transformer encoder over cls token + visual tokens\r\n        x = self.to_cls_token(x[:, 0])  # classify from the cls token only\r\n        x = self.nn1(x)\r\n        return x\r\n\r\n\r\nBATCH_SIZE_TRAIN = 100\r\nBATCH_SIZE_TEST = 100\r\n\r\nDL_PATH = 'C:/Pytorch/Spyder/CIFAR10_data'  # Use your own path\r\n# CIFAR10: 60000 32x32 color images in 10 classes, with 6000 images per class\r\n# Note: the normalization statistics below are the standard ImageNet ones;\r\n# CIFAR10-specific statistics would differ slightly.\r\ntransform = torchvision.transforms.Compose(\r\n    [torchvision.transforms.RandomHorizontalFlip(),\r\n     torchvision.transforms.RandomRotation(10, interpolation=torchvision.transforms.InterpolationMode.BILINEAR),\r\n     torchvision.transforms.RandomAffine(8, translate=(.15, .15)),\r\n     torchvision.transforms.ToTensor(),\r\n     torchvision.transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])\r\n\r\n\r\n
def train(model, optimizer, data_loader, loss_history):\r\n    total_samples = len(data_loader.dataset)\r\n    model.train()\r\n\r\n    for i, (data, target) in enumerate(data_loader):\r\n        optimizer.zero_grad()\r\n        output = F.log_softmax(model(data), dim=1)\r\n        loss = F.nll_loss(output, target)\r\n        loss.backward()\r\n        optimizer.step()\r\n\r\n        if i % 100 == 0:\r\n            print('[' + '{:5}'.format(i * len(data)) + '/' + '{:5}'.format(total_samples) +\r\n                  ' (' + '{:3.0f}'.format(100 * i / len(data_loader)) + '%)]  Loss: ' +\r\n                  '{:6.4f}'.format(loss.item()))\r\n            loss_history.append(loss.item())\r\n\r\n\r\ndef evaluate(model, data_loader, loss_history):\r\n    model.eval()\r\n\r\n    total_samples = len(data_loader.dataset)\r\n    correct_samples = 0\r\n    total_loss = 0\r\n\r\n    with torch.no_grad():\r\n        for data, target in data_loader:\r\n            output = F.log_softmax(model(data), dim=1)\r\n            loss = F.nll_loss(output, target, reduction='sum')\r\n            _, pred = torch.max(output, dim=1)\r\n\r\n            total_loss += loss.item()\r\n            correct_samples += pred.eq(target).sum().item()\r\n\r\n    avg_loss = total_loss / total_samples\r\n    loss_history.append(avg_loss)\r\n    print('\\nAverage test loss: ' + '{:.4f}'.format(avg_loss) +\r\n          '  Accuracy:' + '{:5}'.format(correct_samples) + '/' +\r\n          '{:5}'.format(total_samples) + ' (' +\r\n          '{:4.2f}'.format(100.0 * correct_samples / total_samples) + '%)\\n')\r\n\r\n\r\nif __name__ == '__main__':\r\n    # Guarding the script body keeps `import ResViT` side-effect free\r\n    train_dataset = torchvision.datasets.CIFAR10(DL_PATH, train=True,\r\n                                                 download=True, transform=transform)\r\n    test_dataset = torchvision.datasets.CIFAR10(DL_PATH, train=False,\r\n                                                download=True, transform=transform)\r\n\r\n    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE_TRAIN,\r\n                                               shuffle=True)\r\n    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH_SIZE_TEST,\r\n                                              shuffle=False)\r\n\r\n    N_EPOCHS = 150\r\n\r\n    model = ViTResNet(BasicBlock, [3, 3, 3])\r\n    optimizer = torch.optim.Adam(model.parameters(), lr=0.003)\r\n\r\n    # optimizer = torch.optim.SGD(model.parameters(), lr=0.003, momentum=.9, weight_decay=1e-4)\r\n    # lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizer, milestones=[35, 48], gamma=0.1)\r\n\r\n    train_loss_history, test_loss_history = [], []\r\n    for epoch in range(1, N_EPOCHS + 1):\r\n        print('Epoch:', epoch)\r\n        start_time = time.time()\r\n        train(model, optimizer, train_loader, train_loss_history)\r\n        print('Execution time:', '{:5.2f}'.format(time.time() - start_time), 'seconds')\r\n        evaluate(model, test_loader, test_loss_history)\r\n\r\n    PATH = './ViTRes.pt'  # Use your own path\r\n    torch.save(model.state_dict(), PATH)\r\n\r\n# =============================================================================\r\n# Reload the trained weights later with:\r\n# model = ViTResNet(BasicBlock, [3, 3, 3])\r\n# model.load_state_dict(torch.load(PATH))\r\n# model.eval()\r\n# =============================================================================\r\n
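\r\n# An illustrative sketch of single-image inference after reloading (assumption:\r\n# run inside __main__ above, where test_dataset and model exist):\r\n# =============================================================================\r\n# img, label = test_dataset[0]\r\n# with torch.no_grad():\r\n#     pred = model(img.unsqueeze(0)).argmax(dim=1).item()\r\n# print('predicted class:', pred, ' ground truth:', label)\r\n# =============================================================================\r\n"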
  }
]