[
  {
    "path": "README.md",
    "content": "# VQ-VAE(Neural Discrete Representation Learning)\n\nhttps://arxiv.org/abs/1711.00937\n\n## Requirements\n[nsml](https://research.clova.ai/nsml-alpha)\n\n## How to run\n`nsml run -d CelebA_128 -v -i`\n\n## Introduction\n\nThis is a repro of Vector Quantisation VAE from Deepmind. Authors had applied VQ-VAE for various tasks, but this repo is a slight modification of yunjey's VAE-GAN(CelebA dataset) to replace VAE with VQ-VAE.\n\n![image](https://user-images.githubusercontent.com/2463571/32690439-0ec64640-c73a-11e7-9e06-a0a6ea23248d.png)\n\n![image](https://user-images.githubusercontent.com/2463571/32690440-1738bff6-c73a-11e7-92a4-57dd8449e5a5.png)\n"
  },
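  {
    "path": "vq_sketch.py",
    "content": "\"\"\"Illustrative sketch (not part of the original training code): a minimal\nvector-quantisation bottleneck in the spirit of model.py, assuming a recent PyTorch.\nIt shows the three ingredients of the VQ-VAE objective: nearest-codebook lookup,\na straight-through gradient copy from decoder input to encoder output, and the\ncodebook / commitment loss terms. model.py realises the gradient copy differently,\nvia a backward hook and a separate bwd() call.\"\"\"\nimport torch\nimport torch.nn as nn\n\n\nclass VectorQuantizer(nn.Module):\n    def __init__(self, k_dim=256, z_dim=256, beta=0.25):\n        super(VectorQuantizer, self).__init__()\n        # Codebook of k_dim embeddings of size z_dim (same init range as model.py).\n        self.codebook = nn.Embedding(k_dim, z_dim)\n        self.codebook.weight.data.uniform_(-1.0 / k_dim, 1.0 / k_dim)\n        self.beta = beta\n\n    def forward(self, h):\n        # h: encoder output of shape (B, z_dim, H, W) -> flatten to (B*H*W, z_dim).\n        b, c, height, width = h.size()\n        z = h.permute(0, 2, 3, 1).contiguous().view(-1, c)\n        # Squared L2 distance from every latent vector to every codebook entry.\n        dist = (z.unsqueeze(1) - self.codebook.weight.unsqueeze(0)).pow(2).sum(2)\n        z_q = self.codebook.weight[dist.argmin(1)]\n        # Codebook term moves embeddings toward the (frozen) encoder outputs;\n        # commitment term keeps encoder outputs close to their chosen codes.\n        codebook_loss = (z.detach() - z_q).pow(2).sum(1).mean()\n        commit_loss = (z - z_q.detach()).pow(2).sum(1).mean()\n        # Straight-through estimator: the forward pass uses z_q, the backward pass\n        # copies the decoder's gradient directly onto z.\n        z_q = z + (z_q - z).detach()\n        z_q = z_q.view(b, height, width, c).permute(0, 3, 1, 2)\n        return z_q, codebook_loss + self.beta * commit_loss\n\n\nif __name__ == '__main__':\n    # Tiny smoke test with a small spatial size.\n    vq = VectorQuantizer(k_dim=256, z_dim=256, beta=0.25)\n    h = torch.randn(2, 256, 8, 8, requires_grad=True)\n    z_q, aux_loss = vq(h)\n    print(z_q.size(), aux_loss.item())\n"
  },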
  {
    "path": "data_loader.py",
    "content": "import os\nfrom torch.utils import data\nfrom torchvision import transforms\nfrom PIL import Image\n\nclass ImageFolder(data.Dataset):\n    \"\"\"Custom Dataset compatible with prebuilt DataLoader.\"\"\"\n    def __init__(self, root, transform=None):\n        \"\"\"Initialize image paths and preprocessing module.\"\"\"\n        self.image_paths = list(map(lambda x: os.path.join(root, x), os.listdir(root)))\n        self.transform = transform\n        \n    def __getitem__(self, index):\n        \"\"\"Read an image from a file and preprocesses it and returns.\"\"\"\n        image_path = self.image_paths[index]\n        image = Image.open(image_path).convert('RGB')\n        if self.transform is not None:\n            image = self.transform(image)\n        return image\n    \n    def __len__(self):\n        \"\"\"Return the total number of image files.\"\"\"\n        return len(self.image_paths)\n\n\ndef get_loader(image_path, image_size, batch_size, num_workers=2):\n    \"\"\"Create and return Dataloader.\"\"\"\n    \n    transform = transforms.Compose([\n                    transforms.Scale(image_size),\n                    transforms.ToTensor(),\n                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])\n    \n    dataset = ImageFolder(image_path, transform)\n    data_loader = data.DataLoader(dataset=dataset,\n                                  batch_size=batch_size,\n                                  shuffle=True,\n                                  num_workers=num_workers)\n    return data_loader"
  },
  {
    "path": "logger.py",
    "content": "# Code referenced from https://gist.github.com/gyglim/1f8dfb1b5c82627ae3efcfbbadb9f514\nimport numpy as np\nimport scipy.misc \nimport nsml\nimport visdom\n\nclass Logger(object):    \n    def __init__(self, log_dir):\n        self.last = None\n        self.viz = nsml.Visdom(visdom=visdom)\n\n    def scalar_summary(self, tag, value, step, scope=None):\n        if self.last and self.last['step'] != step:\n            nsml.report(**self.last,scope=scope)\n            self.last = None\n        if self.last is None:\n            self.last = {'step':step,'iter':step,'epoch':1}\n        self.last[tag] = value            \n\n    def images_summary(self, tag, images, step):\n        \"\"\"Log a list of images.\"\"\"\n        self.viz.images(\n            images,            \n            opts=dict(title='%s/%d' % (tag, step), caption='%s/%d' % (tag, step)),\n        )            \n        \n    def histo_summary(self, tag, values, step, bins=1000):\n        pass"
  },
  {
    "path": "main.py",
    "content": "import argparse\nimport base64\nfrom io import BytesIO\nimport os\nimport numpy as np\nfrom PIL import Image\nimport torch\nfrom torch.backends import cudnn\nfrom data_loader import get_loader\nfrom solver import Solver\nimport nsml\n\nINFER_STR = ['inference', 'infer']\nOUTPUT_FORMAT = 'png'\n\n\ndef str2bool(v):\n    return v.lower() in ('true')\n\n\ndef main(config, scope):\n\n    # Create directories if not exist.\n    if not os.path.exists(config.log_path):\n        os.makedirs(config.log_path)\n    if not os.path.exists(config.model_save_path):\n        os.makedirs(config.model_save_path)\n    if not os.path.exists(config.sample_path):\n        os.makedirs(config.sample_path)\n\n    if config.mode == 'sample':\n        config.batch_size = config.sample_size\n\n    # Data loader\n    data_loader = get_loader(config.image_path, config.image_size,\n                              config.batch_size, config.num_workers)\n    # Solver\n    solver = Solver(data_loader, config)\n\n    def load(filename, *args):\n        solver.load(filename)\n\n    def save(filename, *args):\n        solver.save(filename)\n\n    def infer(input):\n        result = solver.infer(input)\n        # convert tensor to dataurl\n        data_url_list = [''] * input\n        for idx, sample in enumerate(result):\n            numpy_array = np.uint8(sample.cpu().numpy()*255)\n            image = Image.fromarray(np.transpose(numpy_array, axes=(1, 2, 0)), 'RGB')\n            temp_out = BytesIO()\n            image.save(temp_out, format=OUTPUT_FORMAT)\n            byte_data = temp_out.getvalue()\n            data_url_list[idx] = u'data:image/{format};base64,{data}'.\\\n                format(format=OUTPUT_FORMAT,\n                       data=base64.b64encode(byte_data).decode('ascii'))\n        return data_url_list\n\n    def evaluate(test_data, output):\n        pass\n\n    def decode(input):\n        return input\n\n    nsml.bind(save, load, infer, evaluate, decode)\n\n    if config.pause:\n        nsml.paused(scope=scope)\n    \n    if config.mode == 'train':\n        solver.train()\n    elif config.mode == 'sample':\n        solver.sample()\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser()\n    \n    # Model hyper-parameters\n    parser.add_argument('--image_size', type=int, default=64)\n    parser.add_argument('--z_dim', type=int, default=256)\n    parser.add_argument('--k_dim', type=int, default=256)\n    parser.add_argument('--code_dim', type=int, default=16)\n    parser.add_argument('--g_conv_dim', type=int, default=64)\n    parser.add_argument('--d_conv_dim', type=int, default=64)\n    \n    # Training settings\n    parser.add_argument('--total_step', type=int, default=200000)\n    parser.add_argument('--batch_size', type=int, default=16)\n    parser.add_argument('--num_workers', type=int, default=2)\n    parser.add_argument('--lr', type=float, default=0.0002)\n    parser.add_argument('--beta1', type=float, default=0.5)\n    parser.add_argument('--beta2', type=float, default=0.999)\n    parser.add_argument('--trained_model', type=int, default=None)\n\n    parser.add_argument('--vq_beta', type=float, default=0.25)\n    \n    # Test settings\n    parser.add_argument('--step_for_sampling', type=int, default=200000)\n    parser.add_argument('--sample_size', type=int, default=32)\n    \n    # Misc\n    parser.add_argument('--mode', type=str, default='train', choices=['train', 'sample', *INFER_STR])\n    parser.add_argument(\"--iteration\", type=int, default=0)\n    
parser.add_argument('--use_tensorboard', type=str2bool, default=True)\n    \n    parser.add_argument('--log_path', type=str, default='./celebA/logs')\n    parser.add_argument('--model_save_path', type=str, default='./celebA/models')\n    parser.add_argument('--sample_path', type=str, default='./celebA/samples')\n    parser.add_argument('--image_path', type=str, default=nsml.DATASET_PATH)\n    \n    parser.add_argument('--log_step', type=int, default=10)\n    parser.add_argument('--sample_step', type=int, default=200)\n    parser.add_argument('--model_save_step', type=int, default=1000)\n\n    # nsml setting\n    parser.add_argument('--pause', type=int, default=0)\n    \n    config = parser.parse_args()\n    print(config)\n    main(config,scope=locals())\n"
  },
  {
    "path": "model.py",
    "content": "import torch \nimport torch.nn as nn\nimport numpy as np\nimport math\nfrom torch.autograd import Variable\n\nclass Generator(nn.Module):\n    \"\"\"Generator. Vector Quantised Variational Auto-Encoder.\"\"\"\n    def __init__(self, image_size=64, z_dim=256, conv_dim=64, code_dim=16, k_dim=256):\n        super(Generator, self).__init__()\n\n        self.k_dim = k_dim\n        self.z_dim = z_dim\n        self.code_dim = code_dim\n        \n        self.dict = nn.Embedding(k_dim, z_dim)\n        \n        # Encoder (increasing #filter linearly)\n        layers = []\n        layers.append(nn.Conv2d(3, conv_dim, kernel_size=3, padding=1))\n        layers.append(nn.BatchNorm2d(conv_dim))\n        layers.append(nn.ReLU())\n        \n        repeat_num = int(math.log2(image_size / code_dim))\n        curr_dim = conv_dim\n        for i in range(repeat_num):\n            layers.append(nn.Conv2d(curr_dim, conv_dim * (i+2), kernel_size=4, stride=2, padding=1))\n            layers.append(nn.BatchNorm2d(conv_dim * (i+2)))\n            layers.append(nn.ReLU())\n            curr_dim = conv_dim * (i+2)\n\n        # Now we have (code_dim,code_dim,curr_dim)\n        layers.append(nn.Conv2d(curr_dim, z_dim, kernel_size=1))\n\n        # (code_dim,code_dim,z_dim)\n        self.encoder = nn.Sequential(*layers)\n        \n        # Decoder (320 - 256 - 192 - 128 - 64)\n        layers = []\n        layers.append(nn.ConvTranspose2d(z_dim, curr_dim, kernel_size=1))\n        layers.append(nn.BatchNorm2d(curr_dim))\n        layers.append(nn.ReLU())\n                \n        for i in reversed(range(repeat_num)):\n            layers.append(nn.ConvTranspose2d(curr_dim , conv_dim * (i+1), kernel_size=4, stride=2, padding=1))\n            layers.append(nn.BatchNorm2d(conv_dim * (i+1)))\n            layers.append(nn.ReLU())\n            curr_dim = conv_dim * (i+1)\n        \n        layers.append(nn.Conv2d(curr_dim, 3, kernel_size=3, padding=1))\n        self.decoder = nn.Sequential(*layers)\n\n        self.init_weights()        \n        \n    def init_weights(self):\n        initrange = 1.0 / self.k_dim\n        self.dict.weight.data.uniform_(-initrange, initrange)        \n            \n    def forward(self, x):\n        h = self.encoder(x)                             # (?, z_dim*2, 1, 1)\n\n        sz = h.size()\n        \n        # BCWH -> BWHC\n        org_h = h\n        h = h.permute(0,2,3,1)\n        h = h.contiguous()\n        Z = h.view(-1,self.z_dim)\n\n        W = self.dict.weight\n\n        def L2_dist(a,b):\n            return ((a - b) ** 2)\n        \n        # Sample nearest embedding\n        j = L2_dist(Z[:,None],W[None,:]).sum(2).min(1)[1]\n        W_j = W[j]\n\n        # Stop gradients\n        Z_sg = Z.detach()\n        W_j_sg = W_j.detach()\n\n        # BWHC -> BCWH\n        h = W_j.view(sz[0],sz[2],sz[3],sz[1])\n        h = h.permute(0,3,1,2)\n\n        def hook(grad):\n            nonlocal org_h\n            self.saved_grad = grad\n            self.saved_h = org_h\n            return grad\n\n        h.register_hook(hook)\n        \n        # losses\n        return self.decoder(h), L2_dist(Z,W_j_sg).sum(1).mean(), L2_dist(Z_sg,W_j).sum(1).mean()\n\n    # back propagation for encoder\n    def bwd(self):\n        self.saved_h.backward(self.saved_grad)\n        \n    def decode(self, z):\n        z = z.view(z.size(0), z.size(1), 1, 1)\n        return self.decoder(z)\n    \n"
  },
  {
    "path": "setup.py",
    "content": "#nsml: pytorch/pytorch\nfrom distutils.core import setup\nsetup(\n    name='nsml example 07 VAE GAN',\n    version='1.0',\n    description='ns-ml',\n    install_requires =[\n        'visdom',\n        'pillow'\n    ]\n)"
  },
  {
    "path": "solver.py",
    "content": "import torch\nimport torch.nn as nn\nimport os\nfrom torch.autograd import Variable\nfrom torchvision.utils import save_image\nfrom model import Generator\nimport nsml\n\nclass Solver(object):\n    \n    def __init__(self, data_loader, config):\n        # Data loader새\n        self.data_loader = data_loader\n        \n        # Model hyper-parameters\n        self.image_size = config.image_size\n        self.z_dim = config.z_dim\n        self.k_dim = config.k_dim\n        self.g_conv_dim = config.g_conv_dim\n        self.code_dim = config.code_dim\n        \n        # Training settings\n        self.total_step = config.total_step\n        self.lr = config.lr\n        self.beta1 = config.beta1\n        self.beta2 = config.beta2\n        self.trained_model = config.trained_model\n        self.use_tensorboard = config.use_tensorboard\n\n        self.vq_beta = config.vq_beta\n        \n        # Path and step size\n        self.log_path = config.log_path\n        self.sample_path = config.sample_path\n        self.model_save_path = config.model_save_path\n        self.log_step = config.log_step\n        self.sample_step = config.sample_step\n        self.model_save_step = config.model_save_step\n        \n        # Test setting\n        self.step_for_sampling = config.step_for_sampling\n        self.sample_size = config.sample_size\n        \n        self.build_model()\n        if self.use_tensorboard:\n            self.build_tensorboard()\n        \n        # Start with trained model\n        if self.trained_model:\n            self.load_trained_model()\n        \n    def build_model(self):\n        # model and optimizer\n        self.G = Generator(self.image_size, self.z_dim, self.g_conv_dim, k_dim=self.k_dim, code_dim=self.code_dim)\n        self.g_optimizer = torch.optim.Adam(self.G.parameters(), self.lr, [self.beta1, self.beta2])\n        \n        if torch.cuda.is_available():\n            self.G.cuda()\n    \n    def load_trained_model(self):\n        self.load(os.path.join(\n            self.model_save_path, '{}_G.pth'.format(self.trained_model)))\n        print('loaded trained models (step: {})..!'.format(self.trained_model))\n\n    def load(self,filename):\n        S = torch.load(filename)        \n        self.G.load_state_dict(S['G'])\n         \n    def build_tensorboard(self):\n        from logger import Logger\n        self.logger = Logger(self.log_path)\n        \n    def update_lr(self, lr):\n        for param_group in self.g_optimizer.param_groups:\n            param_group['lr'] = lr\n    \n    def reset_grad(self):\n        self.g_optimizer.zero_grad()\n        \n    def to_var(self, x):\n        if torch.cuda.is_available():\n            x = x.cuda()\n        return Variable(x)\n    \n    def to_np(self, x):\n        return x.data.cpu().numpy()\n    \n    def denorm(self, x):\n        out = (x + 1) / 2\n        return out.clamp_(0, 1)\n        \n    def detach(self, x):\n        return Variable(x.data)\n    \n    def train(self):\n        # Reconst loss\n        reconst_loss = nn.L1Loss()\n        \n        # Data iter\n        data_iter = iter(self.data_loader)\n        iter_per_epoch = len(self.data_loader)\n        \n        # Fixed inputs for sampling\n        fixed_x = next(data_iter)\n        save_image(self.denorm(fixed_x), os.path.join(self.sample_path, 'fixed_x.png'))\n        fixed_x = self.to_var(fixed_x)\n        \n        # Start with trained model\n        if self.trained_model:\n            start = self.trained_model + 1\n        else:\n       
     start = 0\n            \n        for step in range(start, self.total_step):\n            # Schedule learning rate\n            alpha = (step - start) / (self.total_step - start)\n            lr = self.lr * (1/100 ** alpha)\n\n            self.update_lr(lr)\n            \n            # Reset data_iter for each epoch\n            if (step+1) % iter_per_epoch == 0:\n                data_iter = iter(self.data_loader)\n               \n            x = self.to_var(next(data_iter))\n                        \n            # ================== Train G ================== #\n            # Train with real images (VQ-VAE)\n            out, loss_e1, loss_e2 = self.G(x)\n            loss_rec = reconst_loss(out, x)\n            \n            loss = loss_rec + loss_e1 + self.vq_beta * loss_e2\n            self.reset_grad()\n\n            # For decoder\n            loss.backward(retain_graph=True)\n\n            # For encoder\n            self.G.bwd()\n\n            self.g_optimizer.step()\n            \n            # Print out log info\n            if (step+1) % self.log_step == 0:\n                print(\"[{}/{}] loss: {:.4f}, loss_e1: {:.4f}\". \\\n                      format(step+1, self.total_step, loss_rec.data[0], loss_e1.data[0]))\n                      \n                if self.use_tensorboard:\n                    info = {\n                        'loss/loss_rec': loss_rec.data[0],\n                        'loss/loss_ee': loss_e1.data[0],\n                        'misc/lr': lr\n                    }\n                    \n                    for tag, value in info.items():\n                        self.logger.scalar_summary(tag, value, step+1, scope=locals())\n                \n            # Sample images\n            if (step+1) % self.sample_step == 0:\n                reconst, _, _ = self.G(fixed_x)\n\n                def np(tensor):\n                    return tensor.cpu().numpy()\n\n                self.logger.images_summary('recons',np(self.denorm(reconst.data)),step+1)\n\n            # Save check points\n            if (step+1) % self.model_save_step == 0:\n                nsml.save(step)\n    \n    def save(self, filename):\n        G = self.G.state_dict()\n        torch.save({'G':G}, filename)\n\n    def infer(self, sample_size):\n        # Data iter\n        data_iter = iter(self.data_loader)\n        \n        # Inputs for sampling\n        z = self.to_var(torch.randn(sample_size, self.z_dim)) \n        \n        # Load trained params\n        self.G.eval()  \n        \n        # Sampling\n        fake = self.G.decode(z)\n\n        return self.denorm(fake.data)\n        \n    def sample(self):\n        # Data iter\n        data_iter = iter(self.data_loader)\n        \n        # Inputs for sampling\n        x = next(data_iter)\n        z = self.to_var(torch.randn(self.sample_size, self.z_dim)) \n        \n        save_image(self.denorm(x), 'real.png')\n        x = self.to_var(x)\n        \n        # Load trained params\n        self.G.eval()  \n        S = torch.load(os.path.join(\n            self.model_save_path, '{}_G.pth'.format(self.step_for_sampling)))\n        self.G.load_state_dict(S['G'])\n        \n        # Sampling\n        reconst, _, _ = self.G(x)\n        fake = self.G.decode(z)\n        save_image(self.denorm(reconst.data), 'reconst.png')"
  }
]