[
  {
    "path": "LICENSE",
    "content": "MIT License\n\nCopyright (c) 2018 jindongwang\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n"
  },
  {
    "path": "README.md",
    "content": "# Pytorch-CapsuleNet\n\nA flexible and easy-to-follow PyTorch implementation of Hinton's Capsule Network.\n\nThere are already many repos containing code for CapsNet. However, most of them are too rigid to customize. And as we all know, the experiments in Hinton's original paper focus mainly on the *MNIST* dataset. We clearly want to do more.\n\nThis repo is designed to support other datasets and configurations. Most importantly, we want the code to be **flexible**, so that we can *tailor* the network to our needs.\n\nCurrently, the code supports both the **MNIST and CIFAR-10** datasets.\n\n## Requirements\n\n- Python 3.x\n- PyTorch 0.4.0 or above (the training code uses `Tensor.item()`, which was introduced in 0.4.0)\n- torchvision (used to download and load the datasets)\n- NumPy\n- tqdm (for nicer progress display; you can replace it with `print` if you prefer)\n\n## Run\n\nJust run `python test_capsnet.py` in your terminal. That's all. To switch between MNIST and CIFAR-10, set the `dataset` variable in the `__main__` block of `test_capsnet.py` to `'mnist'` or `'cifar10'` (it defaults to `'cifar10'`). The datasets are downloaded automatically to `/data/mnist` and `/data/cifar` (see `data_loader.py`), so make sure those directories are writable.\n\nIt is better to run the code on a server with GPUs, since capsule networks are computationally demanding. For instance, on my device (Nvidia K80), one epoch on the MNIST dataset takes about 5 minutes (batch size = 100).\n\n## More details\n\nThere are 3 `.py` files:\n- `capsnet.py`: the main class for the capsule network\n- `data_loader.py`: the class that loads the supported datasets\n- `test_capsnet.py`: the training and testing code\n\nThe results on your device may look like the following picture:\n\n![](https://raw.githubusercontent.com/jindongwang/Pytorch-CapsuleNet/master/result.jpg)\n\n## Acknowledgements\n\n- [Capsule-Network-Tutorial](https://github.com/higgsfield/Capsule-Network-Tutorial)\n- The original Capsule Network paper by Sabour, Frosst, and Hinton: [Dynamic routing between capsules](http://papers.nips.cc/paper/6975-dynamic-routing-between-capsules)\n"
  },
  {
    "path": "capsnet.py",
    "content": "import torch\r\nimport torch.nn as nn\r\nimport torch.nn.functional as F\r\nfrom torch.autograd import Variable\r\n\r\nUSE_CUDA = True if torch.cuda.is_available() else False\r\n\r\n\r\nclass ConvLayer(nn.Module):\r\n    def __init__(self, in_channels=1, out_channels=256, kernel_size=9):\r\n        super(ConvLayer, self).__init__()\r\n\r\n        self.conv = nn.Conv2d(in_channels=in_channels,\r\n                              out_channels=out_channels,\r\n                              kernel_size=kernel_size,\r\n                              stride=1\r\n                              )\r\n\r\n    def forward(self, x):\r\n        return F.relu(self.conv(x))\r\n\r\n\r\nclass PrimaryCaps(nn.Module):\r\n    def __init__(self, num_capsules=8, in_channels=256, out_channels=32, kernel_size=9, num_routes=32 * 6 * 6):\r\n        super(PrimaryCaps, self).__init__()\r\n        self.num_routes = num_routes\r\n        self.capsules = nn.ModuleList([\r\n            nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=2, padding=0)\r\n            for _ in range(num_capsules)])\r\n\r\n    def forward(self, x):\r\n        u = [capsule(x) for capsule in self.capsules]\r\n        u = torch.stack(u, dim=1)\r\n        u = u.view(x.size(0), self.num_routes, -1)\r\n        return self.squash(u)\r\n\r\n    def squash(self, input_tensor):\r\n        squared_norm = (input_tensor ** 2).sum(-1, keepdim=True)\r\n        output_tensor = squared_norm * input_tensor / ((1. 
+ squared_norm) * torch.sqrt(squared_norm))\r\n        return output_tensor\r\n\r\n\r\nclass DigitCaps(nn.Module):\r\n    def __init__(self, num_capsules=10, num_routes=32 * 6 * 6, in_channels=8, out_channels=16):\r\n        super(DigitCaps, self).__init__()\r\n\r\n        self.in_channels = in_channels\r\n        self.num_routes = num_routes\r\n        self.num_capsules = num_capsules\r\n\r\n        self.W = nn.Parameter(torch.randn(1, num_routes, num_capsules, out_channels, in_channels))\r\n\r\n    def forward(self, x):\r\n        batch_size = x.size(0)\r\n        x = torch.stack([x] * self.num_capsules, dim=2).unsqueeze(4)\r\n\r\n        W = torch.cat([self.W] * batch_size, dim=0)\r\n        u_hat = torch.matmul(W, x)\r\n\r\n        b_ij = Variable(torch.zeros(1, self.num_routes, self.num_capsules, 1))\r\n        if USE_CUDA:\r\n            b_ij = b_ij.cuda()\r\n\r\n        num_iterations = 3\r\n        for iteration in range(num_iterations):\r\n            c_ij = F.softmax(b_ij, dim=1)\r\n            c_ij = torch.cat([c_ij] * batch_size, dim=0).unsqueeze(4)\r\n\r\n            s_j = (c_ij * u_hat).sum(dim=1, keepdim=True)\r\n            v_j = self.squash(s_j)\r\n\r\n            if iteration < num_iterations - 1:\r\n                a_ij = torch.matmul(u_hat.transpose(3, 4), torch.cat([v_j] * self.num_routes, dim=1))\r\n                b_ij = b_ij + a_ij.squeeze(4).mean(dim=0, keepdim=True)\r\n\r\n        return v_j.squeeze(1)\r\n\r\n    def squash(self, input_tensor):\r\n        squared_norm = (input_tensor ** 2).sum(-1, keepdim=True)\r\n        output_tensor = squared_norm * input_tensor / ((1. 
+ squared_norm) * torch.sqrt(squared_norm))\r\n        return output_tensor\r\n\r\n\r\nclass Decoder(nn.Module):\r\n    def __init__(self, input_width=28, input_height=28, input_channel=1):\r\n        super(Decoder, self).__init__()\r\n        self.input_width = input_width\r\n        self.input_height = input_height\r\n        self.input_channel = input_channel\r\n        self.reconstraction_layers = nn.Sequential(\r\n            nn.Linear(16 * 10, 512),\r\n            nn.ReLU(inplace=True),\r\n            nn.Linear(512, 1024),\r\n            nn.ReLU(inplace=True),\r\n            nn.Linear(1024, self.input_height * self.input_width * self.input_channel),\r\n            nn.Sigmoid()\r\n        )\r\n\r\n    def forward(self, x, data):\r\n        classes = torch.sqrt((x ** 2).sum(2))\r\n        classes = F.softmax(classes, dim=0)\r\n\r\n        _, max_length_indices = classes.max(dim=1)\r\n        masked = Variable(torch.sparse.torch.eye(10))\r\n        if USE_CUDA:\r\n            masked = masked.cuda()\r\n        masked = masked.index_select(dim=0, index=Variable(max_length_indices.squeeze(1).data))\r\n        t = (x * masked[:, :, None, None]).view(x.size(0), -1)\r\n        reconstructions = self.reconstraction_layers(t)\r\n        reconstructions = reconstructions.view(-1, self.input_channel, self.input_width, self.input_height)\r\n        return reconstructions, masked\r\n\r\n\r\nclass CapsNet(nn.Module):\r\n    def __init__(self, config=None):\r\n        super(CapsNet, self).__init__()\r\n        if config:\r\n            self.conv_layer = ConvLayer(config.cnn_in_channels, config.cnn_out_channels, config.cnn_kernel_size)\r\n            self.primary_capsules = PrimaryCaps(config.pc_num_capsules, config.pc_in_channels, config.pc_out_channels,\r\n                                                config.pc_kernel_size, config.pc_num_routes)\r\n            self.digit_capsules = DigitCaps(config.dc_num_capsules, config.dc_num_routes, config.dc_in_channels,\r\n           
                                 config.dc_out_channels)\r\n            self.decoder = Decoder(config.input_width, config.input_height, config.cnn_in_channels)\r\n        else:\r\n            self.conv_layer = ConvLayer()\r\n            self.primary_capsules = PrimaryCaps()\r\n            self.digit_capsules = DigitCaps()\r\n            self.decoder = Decoder()\r\n\r\n        self.mse_loss = nn.MSELoss()\r\n\r\n    def forward(self, data):\r\n        output = self.digit_capsules(self.primary_capsules(self.conv_layer(data)))\r\n        reconstructions, masked = self.decoder(output, data)\r\n        return output, reconstructions, masked\r\n\r\n    def loss(self, data, x, target, reconstructions):\r\n        return self.margin_loss(x, target) + self.reconstruction_loss(data, reconstructions)\r\n\r\n    def margin_loss(self, x, labels, size_average=True):\r\n        batch_size = x.size(0)\r\n\r\n        v_c = torch.sqrt((x ** 2).sum(dim=2, keepdim=True))\r\n\r\n        left = F.relu(0.9 - v_c).view(batch_size, -1)\r\n        right = F.relu(v_c - 0.1).view(batch_size, -1)\r\n\r\n        loss = labels * left + 0.5 * (1.0 - labels) * right\r\n        loss = loss.sum(dim=1).mean()\r\n\r\n        return loss\r\n\r\n    def reconstruction_loss(self, data, reconstructions):\r\n        loss = self.mse_loss(reconstructions.view(reconstructions.size(0), -1), data.view(reconstructions.size(0), -1))\r\n        return loss * 0.0005\r\n"
  },
  {
    "path": "data_loader.py",
    "content": "import torch\r\nfrom torchvision import datasets, transforms\r\n\r\n\r\nclass Dataset:\r\n    def __init__(self, dataset, _batch_size):\r\n        super(Dataset, self).__init__()\r\n        if dataset == 'mnist':\r\n            dataset_transform = transforms.Compose([\r\n                transforms.ToTensor(),\r\n                transforms.Normalize((0.1307,), (0.3081,))\r\n            ])\r\n\r\n            train_dataset = datasets.MNIST('/data/mnist', train=True, download=True,\r\n                                           transform=dataset_transform)\r\n            test_dataset = datasets.MNIST('/data/mnist', train=False, download=True,\r\n                                          transform=dataset_transform)\r\n\r\n            self.train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=_batch_size, shuffle=True)\r\n            self.test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=_batch_size, shuffle=False)\r\n\r\n        elif dataset == 'cifar10':\r\n            data_transform = transforms.Compose([\r\n                transforms.ToTensor(),\r\n                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))\r\n            ])\r\n            train_dataset = datasets.CIFAR10(\r\n                '/data/cifar', train=True, download=True, transform=data_transform)\r\n            test_dataset = datasets.CIFAR10(\r\n                '/data/cifar', train=False, download=True, transform=data_transform)\r\n\r\n            self.train_loader = torch.utils.data.DataLoader(\r\n                train_dataset, batch_size=_batch_size, shuffle=True)\r\n\r\n            self.test_loader = torch.utils.data.DataLoader(\r\n                test_dataset, batch_size=_batch_size, shuffle=False)\r\n        elif dataset == 'office-caltech':\r\n            pass\r\n        elif dataset == 'office31':\r\n            pass\r\n"
  },
  {
    "path": "test_capsnet.py",
    "content": "import numpy as np\r\nimport torch\r\nimport torch.nn as nn\r\nimport torch.nn.functional as F\r\nfrom torch.autograd import Variable\r\nfrom torchvision import datasets, transforms\r\nfrom capsnet import CapsNet\r\nfrom data_loader import Dataset\r\nfrom tqdm import tqdm\r\n\r\nUSE_CUDA = True if torch.cuda.is_available() else False\r\nBATCH_SIZE = 100\r\nN_EPOCHS = 30\r\nLEARNING_RATE = 0.01\r\nMOMENTUM = 0.9\r\n\r\n'''\r\nConfig class to determine the parameters for capsule net\r\n'''\r\n\r\n\r\nclass Config:\r\n    def __init__(self, dataset='mnist'):\r\n        if dataset == 'mnist':\r\n            # CNN (cnn)\r\n            self.cnn_in_channels = 1\r\n            self.cnn_out_channels = 256\r\n            self.cnn_kernel_size = 9\r\n\r\n            # Primary Capsule (pc)\r\n            self.pc_num_capsules = 8\r\n            self.pc_in_channels = 256\r\n            self.pc_out_channels = 32\r\n            self.pc_kernel_size = 9\r\n            self.pc_num_routes = 32 * 6 * 6\r\n\r\n            # Digit Capsule (dc)\r\n            self.dc_num_capsules = 10\r\n            self.dc_num_routes = 32 * 6 * 6\r\n            self.dc_in_channels = 8\r\n            self.dc_out_channels = 16\r\n\r\n            # Decoder\r\n            self.input_width = 28\r\n            self.input_height = 28\r\n\r\n        elif dataset == 'cifar10':\r\n            # CNN (cnn)\r\n            self.cnn_in_channels = 3\r\n            self.cnn_out_channels = 256\r\n            self.cnn_kernel_size = 9\r\n\r\n            # Primary Capsule (pc)\r\n            self.pc_num_capsules = 8\r\n            self.pc_in_channels = 256\r\n            self.pc_out_channels = 32\r\n            self.pc_kernel_size = 9\r\n            self.pc_num_routes = 32 * 8 * 8\r\n\r\n            # Digit Capsule (dc)\r\n            self.dc_num_capsules = 10\r\n            self.dc_num_routes = 32 * 8 * 8\r\n            self.dc_in_channels = 8\r\n            self.dc_out_channels = 16\r\n\r\n            # 
Decoder\r\n            self.input_width = 32\r\n            self.input_height = 32\r\n\r\n        elif dataset == 'your own dataset':\r\n            pass\r\n\r\n\r\ndef train(model, optimizer, train_loader, epoch):\r\n    capsule_net = model\r\n    capsule_net.train()\r\n    n_batch = len(list(enumerate(train_loader)))\r\n    total_loss = 0\r\n    for batch_id, (data, target) in enumerate(tqdm(train_loader)):\r\n\r\n        target = torch.sparse.torch.eye(10).index_select(dim=0, index=target)\r\n        data, target = Variable(data), Variable(target)\r\n\r\n        if USE_CUDA:\r\n            data, target = data.cuda(), target.cuda()\r\n\r\n        optimizer.zero_grad()\r\n        output, reconstructions, masked = capsule_net(data)\r\n        loss = capsule_net.loss(data, output, target, reconstructions)\r\n        loss.backward()\r\n        optimizer.step()\r\n        correct = sum(np.argmax(masked.data.cpu().numpy(), 1) == np.argmax(target.data.cpu().numpy(), 1))\r\n        train_loss = loss.item()\r\n        total_loss += train_loss\r\n        if batch_id % 100 == 0:\r\n            tqdm.write(\"Epoch: [{}/{}], Batch: [{}/{}], train accuracy: {:.6f}, loss: {:.6f}\".format(\r\n                epoch,\r\n                N_EPOCHS,\r\n                batch_id + 1,\r\n                n_batch,\r\n                correct / float(BATCH_SIZE),\r\n                train_loss / float(BATCH_SIZE)\r\n                ))\r\n    tqdm.write('Epoch: [{}/{}], train loss: {:.6f}'.format(epoch,N_EPOCHS,total_loss / len(train_loader.dataset)))\r\n\r\n\r\ndef test(capsule_net, test_loader, epoch):\r\n    capsule_net.eval()\r\n    test_loss = 0\r\n    correct = 0\r\n    for batch_id, (data, target) in enumerate(test_loader):\r\n\r\n        target = torch.sparse.torch.eye(10).index_select(dim=0, index=target)\r\n        data, target = Variable(data), Variable(target)\r\n\r\n        if USE_CUDA:\r\n            data, target = data.cuda(), target.cuda()\r\n\r\n        output, reconstructions, 
masked = capsule_net(data)\r\n        loss = capsule_net.loss(data, output, target, reconstructions)\r\n\r\n        test_loss += loss.item()\r\n        correct += sum(np.argmax(masked.data.cpu().numpy(), 1) ==\r\n                       np.argmax(target.data.cpu().numpy(), 1))\r\n\r\n    tqdm.write(\r\n        \"Epoch: [{}/{}], test accuracy: {:.6f}, loss: {:.6f}\".format(epoch, N_EPOCHS, correct / len(test_loader.dataset),\r\n                                                                  test_loss / len(test_loader)))\r\n\r\n\r\nif __name__ == '__main__':\r\n    torch.manual_seed(1)\r\n    dataset = 'cifar10'\r\n    # dataset = 'mnist'\r\n    config = Config(dataset)\r\n    mnist = Dataset(dataset, BATCH_SIZE)\r\n\r\n    capsule_net = CapsNet(config)\r\n    capsule_net = torch.nn.DataParallel(capsule_net)\r\n    if USE_CUDA:\r\n        capsule_net = capsule_net.cuda()\r\n    capsule_net = capsule_net.module\r\n\r\n    optimizer = torch.optim.Adam(capsule_net.parameters())\r\n\r\n    for e in range(1, N_EPOCHS + 1):\r\n        train(capsule_net, optimizer, mnist.train_loader, e)\r\n        test(capsule_net, mnist.test_loader, e)\r\n"
  }
]