Repository: nakosung/VQ-VAE
Branch: master
Commit: e374bf9ddc8c
Files: 7
Total size: 17.0 KB

Directory structure:
gitextract_4u1gmzza/
├── README.md
├── data_loader.py
├── logger.py
├── main.py
├── model.py
├── setup.py
└── solver.py

================================================
FILE CONTENTS
================================================

================================================
FILE: README.md
================================================
# VQ-VAE (Neural Discrete Representation Learning)

https://arxiv.org/abs/1711.00937

## Requirements

[nsml](https://research.clova.ai/nsml-alpha)

## How to run

`nsml run -d CelebA_128 -v -i`

## Introduction

This is a reproduction of the Vector Quantised VAE (VQ-VAE) from DeepMind. The authors applied VQ-VAE to a variety of tasks; this repo is a slight modification of yunjey's VAE-GAN (CelebA dataset) that replaces the VAE with a VQ-VAE.

![image](https://user-images.githubusercontent.com/2463571/32690439-0ec64640-c73a-11e7-9e06-a0a6ea23248d.png)
![image](https://user-images.githubusercontent.com/2463571/32690440-1738bff6-c73a-11e7-92a4-57dd8449e5a5.png)
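
## VQ bottleneck at a glance

For reference, a minimal sketch of the quantisation step in the straight-through style most modern implementations use (the function name, shapes, and the `codebook` argument are illustrative assumptions, not this repo's API). `model.py` instead keeps the gradient path explicit with a backward hook and a separate `bwd()` call.

```python
import torch
import torch.nn.functional as F

def vector_quantize(z_e, codebook, beta=0.25):
    """z_e: encoder output (B, C, H, W); codebook: nn.Embedding(K, C)."""
    B, C, H, W = z_e.shape
    flat = z_e.permute(0, 2, 3, 1).reshape(-1, C)   # (B*H*W, C)
    dist = torch.cdist(flat, codebook.weight)       # (B*H*W, K) pairwise L2 distances
    idx = dist.argmin(dim=1)                        # nearest code per spatial position
    z_q = codebook(idx).view(B, H, W, C).permute(0, 3, 1, 2)

    codebook_loss = F.mse_loss(z_q, z_e.detach())   # ||sg[z_e] - e||^2: moves embeddings
    commit_loss = F.mse_loss(z_e, z_q.detach())     # ||z_e - sg[e]||^2: commits the encoder

    z_q = z_e + (z_q - z_e).detach()                # straight-through: copy decoder grads to encoder
    return z_q, codebook_loss + beta * commit_loss, idx
```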

================================================
FILE: data_loader.py
================================================
import os
from torch.utils import data
from torchvision import transforms
from PIL import Image


class ImageFolder(data.Dataset):
    """Custom Dataset compatible with prebuilt DataLoader."""

    def __init__(self, root, transform=None):
        """Initialize image paths and preprocessing module."""
        self.image_paths = list(map(lambda x: os.path.join(root, x), os.listdir(root)))
        self.transform = transform

    def __getitem__(self, index):
        """Read an image from a file, preprocess it, and return it."""
        image_path = self.image_paths[index]
        image = Image.open(image_path).convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        return image

    def __len__(self):
        """Return the total number of image files."""
        return len(self.image_paths)


def get_loader(image_path, image_size, batch_size, num_workers=2):
    """Create and return a DataLoader."""
    transform = transforms.Compose([
        transforms.Scale(image_size),  # renamed transforms.Resize in newer torchvision
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    dataset = ImageFolder(image_path, transform)
    data_loader = data.DataLoader(dataset=dataset,
                                  batch_size=batch_size,
                                  shuffle=True,
                                  num_workers=num_workers)
    return data_loader

================================================
FILE: logger.py
================================================
# Code referenced from https://gist.github.com/gyglim/1f8dfb1b5c82627ae3efcfbbadb9f514
import numpy as np
import scipy.misc
import nsml
import visdom


class Logger(object):

    def __init__(self, log_dir):
        self.last = None
        self.viz = nsml.Visdom(visdom=visdom)

    def scalar_summary(self, tag, value, step, scope=None):
        # Flush the buffered values once the step changes.
        if self.last and self.last['step'] != step:
            nsml.report(**self.last, scope=scope)
            self.last = None
        if self.last is None:
            self.last = {'step': step, 'iter': step, 'epoch': 1}
        self.last[tag] = value

    def images_summary(self, tag, images, step):
        """Log a list of images."""
        self.viz.images(
            images,
            opts=dict(title='%s/%d' % (tag, step), caption='%s/%d' % (tag, step)),
        )

    def histo_summary(self, tag, values, step, bins=1000):
        pass

================================================
FILE: main.py
================================================
import argparse
import base64
from io import BytesIO
import os

import numpy as np
from PIL import Image
import torch
from torch.backends import cudnn

from data_loader import get_loader
from solver import Solver
import nsml

INFER_STR = ['inference', 'infer']
OUTPUT_FORMAT = 'png'


def str2bool(v):
    return v.lower() in ('true',)


def main(config, scope):
    # Create directories if they do not exist.
    if not os.path.exists(config.log_path):
        os.makedirs(config.log_path)
    if not os.path.exists(config.model_save_path):
        os.makedirs(config.model_save_path)
    if not os.path.exists(config.sample_path):
        os.makedirs(config.sample_path)

    if config.mode == 'sample':
        config.batch_size = config.sample_size

    # Data loader
    data_loader = get_loader(config.image_path, config.image_size,
                             config.batch_size, config.num_workers)

    # Solver
    solver = Solver(data_loader, config)

    def load(filename, *args):
        solver.load(filename)

    def save(filename, *args):
        solver.save(filename)

    def infer(input):
        result = solver.infer(input)

        # Convert each sampled tensor to a data URL.
        data_url_list = [''] * input
        for idx, sample in enumerate(result):
            numpy_array = np.uint8(sample.cpu().numpy() * 255)
            image = Image.fromarray(np.transpose(numpy_array, axes=(1, 2, 0)), 'RGB')
            temp_out = BytesIO()
            image.save(temp_out, format=OUTPUT_FORMAT)
            byte_data = temp_out.getvalue()
            data_url_list[idx] = u'data:image/{format};base64,{data}'.\
                format(format=OUTPUT_FORMAT,
                       data=base64.b64encode(byte_data).decode('ascii'))
        return data_url_list

    def evaluate(test_data, output):
        pass

    def decode(input):
        return input

    nsml.bind(save, load, infer, evaluate, decode)

    if config.pause:
        nsml.paused(scope=scope)

    if config.mode == 'train':
        solver.train()
    elif config.mode == 'sample':
        solver.sample()


if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    # Model hyper-parameters
    parser.add_argument('--image_size', type=int, default=64)
    parser.add_argument('--z_dim', type=int, default=256)
    parser.add_argument('--k_dim', type=int, default=256)
    parser.add_argument('--code_dim', type=int, default=16)
    parser.add_argument('--g_conv_dim', type=int, default=64)
    parser.add_argument('--d_conv_dim', type=int, default=64)

    # Training settings
    parser.add_argument('--total_step', type=int, default=200000)
    parser.add_argument('--batch_size', type=int, default=16)
    parser.add_argument('--num_workers', type=int, default=2)
    parser.add_argument('--lr', type=float, default=0.0002)
    parser.add_argument('--beta1', type=float, default=0.5)
    parser.add_argument('--beta2', type=float, default=0.999)
    parser.add_argument('--trained_model', type=int, default=None)
    parser.add_argument('--vq_beta', type=float, default=0.25)

    # Test settings
    parser.add_argument('--step_for_sampling', type=int, default=200000)
    parser.add_argument('--sample_size', type=int, default=32)

    # Misc
    parser.add_argument('--mode', type=str, default='train',
                        choices=['train', 'sample', *INFER_STR])
    parser.add_argument('--iteration', type=int, default=0)
    parser.add_argument('--use_tensorboard', type=str2bool, default=True)
    parser.add_argument('--log_path', type=str, default='./celebA/logs')
    parser.add_argument('--model_save_path', type=str, default='./celebA/models')
    parser.add_argument('--sample_path', type=str, default='./celebA/samples')
    parser.add_argument('--image_path', type=str, default=nsml.DATASET_PATH)
    parser.add_argument('--log_step', type=int, default=10)
    parser.add_argument('--sample_step', type=int, default=200)
    parser.add_argument('--model_save_step', type=int, default=1000)

    # nsml setting
    parser.add_argument('--pause', type=int, default=0)

    config = parser.parse_args()
    print(config)
    main(config, scope=locals())

================================================
FILE: model.py
================================================
import torch
import torch.nn as nn
import numpy as np
import math
from torch.autograd import Variable


class Generator(nn.Module):
    """Generator. Vector Quantised Variational Auto-Encoder."""

    def __init__(self, image_size=64, z_dim=256, conv_dim=64, code_dim=16, k_dim=256):
        super(Generator, self).__init__()
        self.k_dim = k_dim
        self.z_dim = z_dim
        self.code_dim = code_dim
        self.dict = nn.Embedding(k_dim, z_dim)

        # Encoder (number of filters grows linearly with depth)
        layers = []
        layers.append(nn.Conv2d(3, conv_dim, kernel_size=3, padding=1))
        layers.append(nn.BatchNorm2d(conv_dim))
        layers.append(nn.ReLU())

        repeat_num = int(math.log2(image_size / code_dim))
        curr_dim = conv_dim
        for i in range(repeat_num):
            layers.append(nn.Conv2d(curr_dim, conv_dim * (i + 2), kernel_size=4, stride=2, padding=1))
            layers.append(nn.BatchNorm2d(conv_dim * (i + 2)))
            layers.append(nn.ReLU())
            curr_dim = conv_dim * (i + 2)

        # Now we have a (code_dim, code_dim, curr_dim) feature map.
        layers.append(nn.Conv2d(curr_dim, z_dim, kernel_size=1))  # (code_dim, code_dim, z_dim)
        self.encoder = nn.Sequential(*layers)

        # Decoder (320 - 256 - 192 - 128 - 64)
        layers = []
        layers.append(nn.ConvTranspose2d(z_dim, curr_dim, kernel_size=1))
        layers.append(nn.BatchNorm2d(curr_dim))
        layers.append(nn.ReLU())
        for i in reversed(range(repeat_num)):
            layers.append(nn.ConvTranspose2d(curr_dim, conv_dim * (i + 1), kernel_size=4, stride=2, padding=1))
            layers.append(nn.BatchNorm2d(conv_dim * (i + 1)))
            layers.append(nn.ReLU())
            curr_dim = conv_dim * (i + 1)

        layers.append(nn.Conv2d(curr_dim, 3, kernel_size=3, padding=1))
        self.decoder = nn.Sequential(*layers)

        self.init_weights()

    def init_weights(self):
        initrange = 1.0 / self.k_dim
        self.dict.weight.data.uniform_(-initrange, initrange)

    def forward(self, x):
        h = self.encoder(x)  # (B, z_dim, code_dim, code_dim)
        sz = h.size()

        # BCHW -> BHWC
        org_h = h
        h = h.permute(0, 2, 3, 1)
        h = h.contiguous()

        Z = h.view(-1, self.z_dim)
        W = self.dict.weight

        def L2_dist(a, b):
            return ((a - b) ** 2)

        # Sample nearest embedding
        j = L2_dist(Z[:, None], W[None, :]).sum(2).min(1)[1]
        W_j = W[j]

        # Stop gradients
        Z_sg = Z.detach()
        W_j_sg = W_j.detach()

        # BHWC -> BCHW
        h = W_j.view(sz[0], sz[2], sz[3], sz[1])
        h = h.permute(0, 3, 1, 2)

        def hook(grad):
            nonlocal org_h
            self.saved_grad = grad
            self.saved_h = org_h
            return grad

        h.register_hook(hook)

        # losses
        return self.decoder(h), L2_dist(Z, W_j_sg).sum(1).mean(), L2_dist(Z_sg, W_j).sum(1).mean()

    def bwd(self):
        # Back-propagation for the encoder: replay the gradient captured at the
        # quantised code onto the encoder output, since the reconstruction
        # gradient cannot cross the nearest-neighbour lookup.
        self.saved_h.backward(self.saved_grad)

    def decode(self, z):
        z = z.view(z.size(0), z.size(1), 1, 1)
        return self.decoder(z)

================================================
FILE: setup.py
================================================
#nsml: pytorch/pytorch
from distutils.core import setup

setup(
    name='nsml example 07 VAE GAN',
    version='1.0',
    description='ns-ml',
    install_requires=[
        'visdom',
        'pillow'
    ]
)

================================================
FILE: solver.py
================================================
import torch
import torch.nn as nn
import os
from torch.autograd import Variable
from torchvision.utils import save_image
from model import Generator
import nsml


class Solver(object):
    def __init__(self, data_loader, config):
        # Data loader
        self.data_loader = data_loader

        # Model hyper-parameters
        self.image_size = config.image_size
        self.z_dim = config.z_dim
        self.k_dim = config.k_dim
        self.g_conv_dim = config.g_conv_dim
        self.code_dim = config.code_dim

        # Training settings
        self.total_step = config.total_step
        self.lr = config.lr
        self.beta1 = config.beta1
        self.beta2 = config.beta2
        self.trained_model = config.trained_model
        self.use_tensorboard = config.use_tensorboard
        self.vq_beta = config.vq_beta

        # Path and step size
        self.log_path = config.log_path
        self.sample_path = config.sample_path
        self.model_save_path = config.model_save_path
        self.log_step = config.log_step
        self.sample_step = config.sample_step
        self.model_save_step = config.model_save_step

        # Test setting
        self.step_for_sampling = config.step_for_sampling
        self.sample_size = config.sample_size

        self.build_model()

        if self.use_tensorboard:
            self.build_tensorboard()

        # Start with trained model
        if self.trained_model:
            self.load_trained_model()

    def build_model(self):
        # Model and optimizer
        self.G = Generator(self.image_size, self.z_dim, self.g_conv_dim,
                           k_dim=self.k_dim, code_dim=self.code_dim)
        self.g_optimizer = torch.optim.Adam(self.G.parameters(), self.lr,
                                            [self.beta1, self.beta2])

        if torch.cuda.is_available():
            self.G.cuda()

    def load_trained_model(self):
        self.load(os.path.join(
            self.model_save_path, '{}_G.pth'.format(self.trained_model)))
        print('loaded trained models (step: {})..!'.format(self.trained_model))

    def load(self, filename):
        S = torch.load(filename)
        self.G.load_state_dict(S['G'])

    def build_tensorboard(self):
        from logger import Logger
        self.logger = Logger(self.log_path)

    def update_lr(self, lr):
        for param_group in self.g_optimizer.param_groups:
            param_group['lr'] = lr

    def reset_grad(self):
        self.g_optimizer.zero_grad()

    def to_var(self, x):
        if torch.cuda.is_available():
            x = x.cuda()
        return Variable(x)

    def to_np(self, x):
        return x.data.cpu().numpy()

    def denorm(self, x):
        out = (x + 1) / 2
        return out.clamp_(0, 1)

    def detach(self, x):
        return Variable(x.data)

    def train(self):
        # Reconstruction loss
        reconst_loss = nn.L1Loss()

        # Data iter
        data_iter = iter(self.data_loader)
        iter_per_epoch = len(self.data_loader)

        # Fixed inputs for sampling
        fixed_x = next(data_iter)
        save_image(self.denorm(fixed_x), os.path.join(self.sample_path, 'fixed_x.png'))
        fixed_x = self.to_var(fixed_x)

        # Start with trained model
        if self.trained_model:
            start = self.trained_model + 1
        else:
            start = 0

        for step in range(start, self.total_step):

            # Schedule the learning rate (exponential decay from lr to lr / 100)
            alpha = (step - start) / (self.total_step - start)
            lr = self.lr * (1 / (100 ** alpha))
            self.update_lr(lr)

            # Reset data_iter for each epoch
            if (step + 1) % iter_per_epoch == 0:
                data_iter = iter(self.data_loader)

            x = self.to_var(next(data_iter))

            # ================== Train G ================== #

            # Train with real images (VQ-VAE)
            out, loss_e1, loss_e2 = self.G(x)
            loss_rec = reconst_loss(out, x)

            loss = loss_rec + loss_e1 + self.vq_beta * loss_e2
            self.reset_grad()

            # For the decoder and codebook (retain the graph so the encoder step below can reuse it)
            loss.backward(retain_graph=True)
            # For the encoder
            self.G.bwd()

            self.g_optimizer.step()

            # Print out log info
            if (step + 1) % self.log_step == 0:
                print("[{}/{}] loss: {:.4f}, loss_e1: {:.4f}".
                      format(step + 1, self.total_step,
                             loss_rec.data[0], loss_e1.data[0]))
                if self.use_tensorboard:
                    info = {
                        'loss/loss_rec': loss_rec.data[0],
                        'loss/loss_ee': loss_e1.data[0],
                        'misc/lr': lr
                    }

                    for tag, value in info.items():
                        self.logger.scalar_summary(tag, value, step + 1, scope=locals())

            # Sample images
            if (step + 1) % self.sample_step == 0:
                reconst, _, _ = self.G(fixed_x)

                def np(tensor):
                    return tensor.cpu().numpy()

                self.logger.images_summary('recons', np(self.denorm(reconst.data)), step + 1)

            # Save checkpoints
            if (step + 1) % self.model_save_step == 0:
                nsml.save(step)

    def save(self, filename):
        G = self.G.state_dict()
        torch.save({'G': G}, filename)

    def infer(self, sample_size):
        # Data iter
        data_iter = iter(self.data_loader)

        # Inputs for sampling
        z = self.to_var(torch.randn(sample_size, self.z_dim))

        # Load trained params
        self.G.eval()

        # Sampling
        fake = self.G.decode(z)
        return self.denorm(fake.data)

    def sample(self):
        # Data iter
        data_iter = iter(self.data_loader)

        # Inputs for sampling
        x = next(data_iter)
        z = self.to_var(torch.randn(self.sample_size, self.z_dim))
        save_image(self.denorm(x), 'real.png')
        x = self.to_var(x)

        # Load trained params
        self.G.eval()
        S = torch.load(os.path.join(
            self.model_save_path, '{}_G.pth'.format(self.step_for_sampling)))
        self.G.load_state_dict(S['G'])

        # Sampling
        reconst, _, _ = self.G(x)
        fake = self.G.decode(z)
        save_image(self.denorm(reconst.data), 'reconst.png')
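
A note on the two-stage backward pass above: in `Solver.train`, `loss.backward(retain_graph=True)` updates the decoder, the codebook, and the auxiliary loss terms, but the reconstruction gradient cannot cross the nearest-neighbour lookup into the encoder; `Generator.bwd()` then replays the gradient captured by the hook onto the encoder output. A toy sketch of that gradient-copy pattern (illustrative only, not a repo file; `encoder_out` and `quantized` are hypothetical stand-ins for `org_h` and the codebook lookup):

```python
import torch

encoder_out = torch.randn(4, 8, requires_grad=True)   # stands in for org_h
quantized = encoder_out.detach().round()               # stands in for the nearest-code lookup
quantized.requires_grad_(True)

saved = {}

def capture(grad):
    saved['grad'] = grad                               # what the hook in Generator.forward stores
    return grad

quantized.register_hook(capture)

loss = (quantized ** 2).sum()                          # stands in for decoder + reconstruction loss
loss.backward()                                        # fills saved['grad']; encoder_out.grad is still None

encoder_out.backward(saved['grad'])                    # what Generator.bwd() does
print(encoder_out.grad is not None)                    # True
```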