Repository: nakosung/VQ-VAE
Branch: master
Commit: e374bf9ddc8c
Files: 7
Total size: 17.0 KB
Directory structure:
gitextract_4u1gmzza/
├── README.md
├── data_loader.py
├── logger.py
├── main.py
├── model.py
├── setup.py
└── solver.py
================================================
FILE CONTENTS
================================================
================================================
FILE: README.md
================================================
# VQ-VAE(Neural Discrete Representation Learning)
https://arxiv.org/abs/1711.00937
## Requirements
[nsml](https://research.clova.ai/nsml-alpha)
## How to run
`nsml run -d CelebA_128 -v -i`
## Introduction
This is a reproduction of the Vector Quantised VAE from DeepMind. The authors applied VQ-VAE to a variety of tasks; this repo is a slight modification of yunjey's VAE-GAN (CelebA dataset) that replaces the VAE with a VQ-VAE.


================================================
FILE: data_loader.py
================================================
import os
from torch.utils import data
from torchvision import transforms
from PIL import Image
class ImageFolder(data.Dataset):
    """Flat-directory image dataset usable with ``torch.utils.data.DataLoader``.

    Every entry of *root* is assumed to be an image file; an optional
    *transform* callable is applied to each image after loading.
    """

    def __init__(self, root, transform=None):
        """Collect the full paths of all entries under *root*."""
        self.image_paths = [os.path.join(root, name) for name in os.listdir(root)]
        self.transform = transform

    def __getitem__(self, index):
        """Load image number *index*, convert it to RGB, apply the transform."""
        image = Image.open(self.image_paths[index]).convert('RGB')
        return image if self.transform is None else self.transform(image)

    def __len__(self):
        """Number of files found under *root*."""
        return len(self.image_paths)
def get_loader(image_path, image_size, batch_size, num_workers=2):
    """Create and return a shuffling ``DataLoader`` over images in *image_path*.

    Images are resized to *image_size*, converted to tensors, and
    normalized per-channel from [0, 1] to [-1, 1] (matching the
    ``denorm`` used elsewhere in this project).

    Args:
        image_path: Directory containing the image files.
        image_size: Target size passed to ``transforms.Resize``.
        batch_size: Number of images per batch.
        num_workers: Loader worker processes (default 2).

    Returns:
        A ``torch.utils.data.DataLoader`` yielding image tensors.
    """
    transform = transforms.Compose([
        # transforms.Scale was deprecated and later removed from
        # torchvision; Resize is its drop-in replacement.
        transforms.Resize(image_size),
        transforms.ToTensor(),
        # (x - 0.5) / 0.5 maps each channel from [0, 1] to [-1, 1].
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    dataset = ImageFolder(image_path, transform)
    data_loader = data.DataLoader(dataset=dataset,
                                  batch_size=batch_size,
                                  shuffle=True,
                                  num_workers=num_workers)
    return data_loader
================================================
FILE: logger.py
================================================
# Code referenced from https://gist.github.com/gyglim/1f8dfb1b5c82627ae3efcfbbadb9f514
import numpy as np
import scipy.misc
import nsml
import visdom
class Logger(object):
    """Logging facade: batches scalars for nsml and pushes images to Visdom."""

    def __init__(self, log_dir):
        # Pending scalar record for the current step; flushed to nsml as
        # soon as a scalar for a *different* step arrives.
        self.last = None
        self.viz = nsml.Visdom(visdom=visdom)

    def scalar_summary(self, tag, value, step, scope=None):
        """Buffer *value* under *tag*; flush the previous step's record first."""
        pending = self.last
        if pending and pending['step'] != step:
            # A new step began: report everything gathered for the old one.
            nsml.report(**pending, scope=scope)
            self.last = None
            pending = None
        if pending is None:
            pending = {'step': step, 'iter': step, 'epoch': 1}
        pending[tag] = value
        self.last = pending

    def images_summary(self, tag, images, step):
        """Display a batch of images in Visdom, labelled "<tag>/<step>"."""
        label = '%s/%d' % (tag, step)
        self.viz.images(
            images,
            opts=dict(title=label, caption=label),
        )

    def histo_summary(self, tag, values, step, bins=1000):
        """Histogram logging is intentionally a no-op."""
        pass
================================================
FILE: main.py
================================================
import argparse
import base64
from io import BytesIO
import os
import numpy as np
from PIL import Image
import torch
from torch.backends import cudnn
from data_loader import get_loader
from solver import Solver
import nsml
# Mode aliases accepted for nsml inference runs.
INFER_STR = ['inference', 'infer']
# Image encoding used when serializing inference results to data URLs.
OUTPUT_FORMAT = 'png'

def str2bool(v):
    """Parse a command-line flag string into a bool.

    Only the exact word "true" (case-insensitive) maps to True; anything
    else is False.  The previous ``v.lower() in ('true')`` tested
    *substring* membership in the string 'true' — ``('true')`` is not a
    tuple — so inputs like 'rue' or 't' were wrongly truthy.
    """
    return v.lower() == 'true'
def main(config, scope):
    """Entry point: prepare directories and data, bind nsml hooks, then train/sample.

    Args:
        config: Parsed ``argparse.Namespace`` holding all command-line options.
        scope: Caller's namespace (``locals()``), forwarded to ``nsml.paused``.
    """
    # Create directories if not exist.
    if not os.path.exists(config.log_path):
        os.makedirs(config.log_path)
    if not os.path.exists(config.model_save_path):
        os.makedirs(config.model_save_path)
    if not os.path.exists(config.sample_path):
        os.makedirs(config.sample_path)

    # In sample mode the batch size doubles as the number of samples drawn.
    if config.mode == 'sample':
        config.batch_size = config.sample_size

    # Data loader
    data_loader = get_loader(config.image_path, config.image_size,
                             config.batch_size, config.num_workers)

    # Solver
    solver = Solver(data_loader, config)

    # nsml checkpoint hooks: delegate (de)serialization to the solver.
    def load(filename, *args):
        solver.load(filename)

    def save(filename, *args):
        solver.save(filename)

    def infer(input):
        # NOTE(review): *input* appears to be a sample count (an int) — it
        # sizes data_url_list below and is passed to solver.infer, which
        # uses it as a batch size.  TODO confirm against the nsml docs.
        result = solver.infer(input)
        # convert tensor to dataurl
        data_url_list = [''] * input
        for idx, sample in enumerate(result):
            # CHW float tensor in [0, 1] -> HWC uint8 -> PNG -> base64 data URL.
            numpy_array = np.uint8(sample.cpu().numpy()*255)
            image = Image.fromarray(np.transpose(numpy_array, axes=(1, 2, 0)), 'RGB')
            temp_out = BytesIO()
            image.save(temp_out, format=OUTPUT_FORMAT)
            byte_data = temp_out.getvalue()
            data_url_list[idx] = u'data:image/{format};base64,{data}'.\
                format(format=OUTPUT_FORMAT,
                       data=base64.b64encode(byte_data).decode('ascii'))
        return data_url_list

    def evaluate(test_data, output):
        # Evaluation is not implemented for this demo.
        pass

    def decode(input):
        # Inputs are already in the form infer() expects; pass through.
        return input

    nsml.bind(save, load, infer, evaluate, decode)

    # When launched paused (nsml serving mode), wait for commands instead
    # of falling through to training/sampling.
    if config.pause:
        nsml.paused(scope=scope)

    if config.mode == 'train':
        solver.train()
    elif config.mode == 'sample':
        solver.sample()
if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    # Model hyper-parameters
    parser.add_argument('--image_size', type=int, default=64)
    parser.add_argument('--z_dim', type=int, default=256)     # latent / codebook vector size
    parser.add_argument('--k_dim', type=int, default=256)     # number of codebook entries
    parser.add_argument('--code_dim', type=int, default=16)   # spatial size of the latent grid
    parser.add_argument('--g_conv_dim', type=int, default=64)
    parser.add_argument('--d_conv_dim', type=int, default=64)  # NOTE(review): unused by Solver — confirm

    # Training settings
    parser.add_argument('--total_step', type=int, default=200000)
    parser.add_argument('--batch_size', type=int, default=16)
    parser.add_argument('--num_workers', type=int, default=2)
    parser.add_argument('--lr', type=float, default=0.0002)
    parser.add_argument('--beta1', type=float, default=0.5)    # Adam beta1
    parser.add_argument('--beta2', type=float, default=0.999)  # Adam beta2
    parser.add_argument('--trained_model', type=int, default=None)  # checkpoint step to resume from
    parser.add_argument('--vq_beta', type=float, default=0.25)      # weight of the second VQ loss term

    # Test settings
    parser.add_argument('--step_for_sampling', type=int, default=200000)
    parser.add_argument('--sample_size', type=int, default=32)

    # Misc
    parser.add_argument('--mode', type=str, default='train', choices=['train', 'sample', *INFER_STR])
    parser.add_argument("--iteration", type=int, default=0)
    parser.add_argument('--use_tensorboard', type=str2bool, default=True)
    parser.add_argument('--log_path', type=str, default='./celebA/logs')
    parser.add_argument('--model_save_path', type=str, default='./celebA/models')
    parser.add_argument('--sample_path', type=str, default='./celebA/samples')
    parser.add_argument('--image_path', type=str, default=nsml.DATASET_PATH)
    parser.add_argument('--log_step', type=int, default=10)
    parser.add_argument('--sample_step', type=int, default=200)
    parser.add_argument('--model_save_step', type=int, default=1000)

    # nsml setting: nonzero means start paused and wait for nsml commands.
    parser.add_argument('--pause', type=int, default=0)

    config = parser.parse_args()
    print(config)
    main(config,scope=locals())
================================================
FILE: model.py
================================================
import torch
import torch.nn as nn
import numpy as np
import math
from torch.autograd import Variable
class Generator(nn.Module):
    """Generator. Vector Quantised Variational Auto-Encoder.

    The encoder maps a 3-channel image down to a (code_dim x code_dim) grid
    of z_dim-dimensional vectors; each vector is snapped to its nearest
    entry in a learned codebook (``self.dict``), and the decoder
    reconstructs the image from the quantised grid.  Because quantisation
    is non-differentiable, decoder gradients are routed to the encoder via
    a straight-through trick implemented with a backward hook (see
    ``forward`` and ``bwd``).
    """

    def __init__(self, image_size=64, z_dim=256, conv_dim=64, code_dim=16, k_dim=256):
        super(Generator, self).__init__()
        self.k_dim = k_dim        # number of codebook entries
        self.z_dim = z_dim        # dimension of each latent / codebook vector
        self.code_dim = code_dim  # spatial size of the latent grid
        # The codebook: k_dim embeddings of dimension z_dim.
        self.dict = nn.Embedding(k_dim, z_dim)

        # Encoder (increasing #filter linearly)
        layers = []
        layers.append(nn.Conv2d(3, conv_dim, kernel_size=3, padding=1))
        layers.append(nn.BatchNorm2d(conv_dim))
        layers.append(nn.ReLU())

        # Each stride-2 conv halves the spatial size; repeat until the
        # feature map is code_dim x code_dim.
        repeat_num = int(math.log2(image_size / code_dim))
        curr_dim = conv_dim
        for i in range(repeat_num):
            layers.append(nn.Conv2d(curr_dim, conv_dim * (i+2), kernel_size=4, stride=2, padding=1))
            layers.append(nn.BatchNorm2d(conv_dim * (i+2)))
            layers.append(nn.ReLU())
            curr_dim = conv_dim * (i+2)

        # Now we have (code_dim,code_dim,curr_dim)
        layers.append(nn.Conv2d(curr_dim, z_dim, kernel_size=1))
        # (code_dim,code_dim,z_dim)
        self.encoder = nn.Sequential(*layers)

        # Decoder (320 - 256 - 192 - 128 - 64): mirror of the encoder.
        layers = []
        layers.append(nn.ConvTranspose2d(z_dim, curr_dim, kernel_size=1))
        layers.append(nn.BatchNorm2d(curr_dim))
        layers.append(nn.ReLU())
        for i in reversed(range(repeat_num)):
            layers.append(nn.ConvTranspose2d(curr_dim , conv_dim * (i+1), kernel_size=4, stride=2, padding=1))
            layers.append(nn.BatchNorm2d(conv_dim * (i+1)))
            layers.append(nn.ReLU())
            curr_dim = conv_dim * (i+1)
        layers.append(nn.Conv2d(curr_dim, 3, kernel_size=3, padding=1))
        self.decoder = nn.Sequential(*layers)

        self.init_weights()

    def init_weights(self):
        """Initialise codebook entries uniformly in [-1/k_dim, 1/k_dim]."""
        initrange = 1.0 / self.k_dim
        self.dict.weight.data.uniform_(-initrange, initrange)

    def forward(self, x):
        """Encode, vector-quantise, and decode *x*.

        Returns a 3-tuple:
          * reconstruction of ``x``,
          * ||Z - sg(W_j)||^2 — gradient flows into the encoder output
            (codebook detached),
          * ||sg(Z) - W_j||^2 — gradient flows into the codebook entries
            (encoder output detached).
        """
        h = self.encoder(x)  # (B, z_dim, code_dim, code_dim)
        sz = h.size()

        # BCWH -> BWHC, so each spatial position becomes one z_dim row.
        org_h = h  # kept so bwd() can route decoder gradients to the encoder
        h = h.permute(0,2,3,1)
        h = h.contiguous()
        Z = h.view(-1,self.z_dim)
        W = self.dict.weight

        def L2_dist(a,b):
            # Element-wise squared difference; callers sum over a dim.
            return ((a - b) ** 2)

        # Sample nearest embedding:
        # pairwise squared distances (N, k_dim); min(1)[1] -> nearest index.
        j = L2_dist(Z[:,None],W[None,:]).sum(2).min(1)[1]
        W_j = W[j]

        # Stop gradients
        Z_sg = Z.detach()
        W_j_sg = W_j.detach()

        # BWHC -> BCWH (sz is BCHW, so sz[2], sz[3] are the spatial dims).
        h = W_j.view(sz[0],sz[2],sz[3],sz[1])
        h = h.permute(0,3,1,2)

        def hook(grad):
            # Straight-through estimator: stash the gradient arriving at the
            # quantised code so bwd() can replay it onto the (pre-quantisation)
            # encoder output, bypassing the non-differentiable nearest-
            # neighbour lookup.
            nonlocal org_h
            self.saved_grad = grad
            self.saved_h = org_h
            return grad
        h.register_hook(hook)

        # losses
        return self.decoder(h), L2_dist(Z,W_j_sg).sum(1).mean(), L2_dist(Z_sg,W_j).sum(1).mean()

    # back propagation for encoder
    def bwd(self):
        """Replay the saved decoder-input gradient through the encoder output."""
        self.saved_h.backward(self.saved_grad)

    def decode(self, z):
        """Decode a batch of latent vectors z of shape (B, z_dim) into images."""
        z = z.view(z.size(0), z.size(1), 1, 1)
        return self.decoder(z)
================================================
FILE: setup.py
================================================
#nsml: pytorch/pytorch
from distutils.core import setup

# Package metadata for the nsml example; visdom and pillow are the only
# runtime requirements declared (torch comes from the nsml docker image,
# see the "#nsml: pytorch/pytorch" directive above).
setup(
    name='nsml example 07 VAE GAN',
    version='1.0',
    description='ns-ml',
    install_requires =[
        'visdom',
        'pillow'
    ]
)
================================================
FILE: solver.py
================================================
import torch
import torch.nn as nn
import os
from torch.autograd import Variable
from torchvision.utils import save_image
from model import Generator
import nsml
class Solver(object):
    """Training / inference driver for the VQ-VAE ``Generator``."""

    def __init__(self, data_loader, config):
        """Copy configuration, build the model, and optionally resume.

        Args:
            data_loader: ``DataLoader`` yielding normalized image batches.
            config: Parsed command-line ``Namespace`` (see main.py).
        """
        # Data loader
        self.data_loader = data_loader

        # Model hyper-parameters
        self.image_size = config.image_size
        self.z_dim = config.z_dim
        self.k_dim = config.k_dim
        self.g_conv_dim = config.g_conv_dim
        self.code_dim = config.code_dim

        # Training settings
        self.total_step = config.total_step
        self.lr = config.lr
        self.beta1 = config.beta1
        self.beta2 = config.beta2
        self.trained_model = config.trained_model
        self.use_tensorboard = config.use_tensorboard
        self.vq_beta = config.vq_beta  # weight of the second VQ loss term

        # Path and step size
        self.log_path = config.log_path
        self.sample_path = config.sample_path
        self.model_save_path = config.model_save_path
        self.log_step = config.log_step
        self.sample_step = config.sample_step
        self.model_save_step = config.model_save_step

        # Test setting
        self.step_for_sampling = config.step_for_sampling
        self.sample_size = config.sample_size

        self.build_model()
        if self.use_tensorboard:
            self.build_tensorboard()

        # Start with trained model
        if self.trained_model:
            self.load_trained_model()

    def build_model(self):
        """Instantiate the generator and its Adam optimizer (GPU if available)."""
        # model and optimizer
        self.G = Generator(self.image_size, self.z_dim, self.g_conv_dim, k_dim=self.k_dim, code_dim=self.code_dim)
        self.g_optimizer = torch.optim.Adam(self.G.parameters(), self.lr, [self.beta1, self.beta2])
        if torch.cuda.is_available():
            self.G.cuda()

    def load_trained_model(self):
        """Load the checkpoint saved at step ``trained_model``."""
        self.load(os.path.join(
            self.model_save_path, '{}_G.pth'.format(self.trained_model)))
        print('loaded trained models (step: {})..!'.format(self.trained_model))

    def load(self,filename):
        """Restore generator weights from *filename*."""
        S = torch.load(filename)
        self.G.load_state_dict(S['G'])

    def build_tensorboard(self):
        """Create the nsml/visdom-backed logger (import deferred to here)."""
        from logger import Logger
        self.logger = Logger(self.log_path)

    def update_lr(self, lr):
        """Set the learning rate of every optimizer parameter group."""
        for param_group in self.g_optimizer.param_groups:
            param_group['lr'] = lr

    def reset_grad(self):
        """Zero all accumulated gradients."""
        self.g_optimizer.zero_grad()

    def to_var(self, x):
        """Wrap a tensor in a Variable, moving it to the GPU when available."""
        if torch.cuda.is_available():
            x = x.cuda()
        return Variable(x)

    def to_np(self, x):
        """Tensor/Variable -> numpy array on the CPU."""
        return x.data.cpu().numpy()

    def denorm(self, x):
        """Map images from [-1, 1] back to [0, 1], clamping in place."""
        out = (x + 1) / 2
        return out.clamp_(0, 1)

    def detach(self, x):
        """Return a Variable sharing x's data but cut from the autograd graph."""
        return Variable(x.data)

    def train(self):
        """Run the VQ-VAE training loop for ``total_step`` iterations."""
        # Reconst loss
        reconst_loss = nn.L1Loss()

        # Data iter
        data_iter = iter(self.data_loader)
        iter_per_epoch = len(self.data_loader)

        # Fixed inputs for sampling
        fixed_x = next(data_iter)
        save_image(self.denorm(fixed_x), os.path.join(self.sample_path, 'fixed_x.png'))
        fixed_x = self.to_var(fixed_x)

        # Start with trained model
        if self.trained_model:
            start = self.trained_model + 1
        else:
            start = 0

        for step in range(start, self.total_step):
            # Schedule learning rate: ** binds tighter than /, so this is
            # lr / 100**alpha — an exponential decay from lr to lr/100.
            alpha = (step - start) / (self.total_step - start)
            lr = self.lr * (1/100 ** alpha)
            self.update_lr(lr)

            # Reset data_iter for each epoch
            if (step+1) % iter_per_epoch == 0:
                data_iter = iter(self.data_loader)
            x = self.to_var(next(data_iter))

            # ================== Train G ================== #

            # Train with real images (VQ-VAE)
            out, loss_e1, loss_e2 = self.G(x)
            loss_rec = reconst_loss(out, x)
            # NOTE(review): vq_beta here weights loss_e2, the term whose
            # gradient updates the codebook; in the VQ-VAE paper beta
            # weights the commitment term instead — confirm intended.
            loss = loss_rec + loss_e1 + self.vq_beta * loss_e2

            self.reset_grad()
            # For decoder (retain_graph so the straight-through replay
            # below can reuse the same graph)
            loss.backward(retain_graph=True)
            # For encoder: replay the decoder's input gradient through the
            # quantisation bottleneck (see Generator.bwd()).
            self.G.bwd()
            self.g_optimizer.step()

            # Print out log info
            if (step+1) % self.log_step == 0:
                # .data[0] is the legacy (pre-0.4 PyTorch) scalar accessor.
                print("[{}/{}] loss: {:.4f}, loss_e1: {:.4f}". \
                    format(step+1, self.total_step, loss_rec.data[0], loss_e1.data[0]))

                if self.use_tensorboard:
                    info = {
                        'loss/loss_rec': loss_rec.data[0],
                        'loss/loss_ee': loss_e1.data[0],
                        'misc/lr': lr
                    }
                    for tag, value in info.items():
                        self.logger.scalar_summary(tag, value, step+1, scope=locals())

            # Sample images
            if (step+1) % self.sample_step == 0:
                reconst, _, _ = self.G(fixed_x)
                def np(tensor):
                    # Local alias: tensor -> CPU numpy array (shadows any
                    # module named np inside this loop body only).
                    return tensor.cpu().numpy()
                self.logger.images_summary('recons',np(self.denorm(reconst.data)),step+1)

            # Save check points
            if (step+1) % self.model_save_step == 0:
                nsml.save(step)

    def save(self, filename):
        """Serialize generator weights to *filename* (nsml save hook)."""
        G = self.G.state_dict()
        torch.save({'G':G}, filename)

    def infer(self, sample_size):
        """Decode ``sample_size`` random latent vectors; return images in [0, 1]."""
        # Data iter
        data_iter = iter(self.data_loader)  # NOTE(review): unused here

        # Inputs for sampling
        z = self.to_var(torch.randn(sample_size, self.z_dim))

        # Load trained params
        self.G.eval()

        # Sampling
        fake = self.G.decode(z)
        return self.denorm(fake.data)

    def sample(self):
        """Reconstruct one real batch and decode random latents, saving PNGs."""
        # Data iter
        data_iter = iter(self.data_loader)

        # Inputs for sampling
        x = next(data_iter)
        z = self.to_var(torch.randn(self.sample_size, self.z_dim))
        save_image(self.denorm(x), 'real.png')
        x = self.to_var(x)

        # Load trained params
        self.G.eval()
        S = torch.load(os.path.join(
            self.model_save_path, '{}_G.pth'.format(self.step_for_sampling)))
        self.G.load_state_dict(S['G'])

        # Sampling
        reconst, _, _ = self.G(x)
        fake = self.G.decode(z)  # NOTE(review): computed but never saved
        save_image(self.denorm(reconst.data), 'reconst.png')
gitextract_4u1gmzza/ ├── README.md ├── data_loader.py ├── logger.py ├── main.py ├── model.py ├── setup.py └── solver.py
SYMBOL INDEX (34 symbols across 5 files)
FILE: data_loader.py
class ImageFolder (line 6) | class ImageFolder(data.Dataset):
method __init__ (line 8) | def __init__(self, root, transform=None):
method __getitem__ (line 13) | def __getitem__(self, index):
method __len__ (line 21) | def __len__(self):
function get_loader (line 26) | def get_loader(image_path, image_size, batch_size, num_workers=2):
FILE: logger.py
class Logger (line 7) | class Logger(object):
method __init__ (line 8) | def __init__(self, log_dir):
method scalar_summary (line 12) | def scalar_summary(self, tag, value, step, scope=None):
method images_summary (line 20) | def images_summary(self, tag, images, step):
method histo_summary (line 27) | def histo_summary(self, tag, values, step, bins=1000):
FILE: main.py
function str2bool (line 17) | def str2bool(v):
function main (line 21) | def main(config, scope):
FILE: model.py
class Generator (line 7) | class Generator(nn.Module):
method __init__ (line 9) | def __init__(self, image_size=64, z_dim=256, conv_dim=64, code_dim=16,...
method init_weights (line 55) | def init_weights(self):
method forward (line 59) | def forward(self, x):
method bwd (line 99) | def bwd(self):
method decode (line 102) | def decode(self, z):
FILE: solver.py
class Solver (line 9) | class Solver(object):
method __init__ (line 11) | def __init__(self, data_loader, config):
method build_model (line 52) | def build_model(self):
method load_trained_model (line 60) | def load_trained_model(self):
method load (line 65) | def load(self,filename):
method build_tensorboard (line 69) | def build_tensorboard(self):
method update_lr (line 73) | def update_lr(self, lr):
method reset_grad (line 77) | def reset_grad(self):
method to_var (line 80) | def to_var(self, x):
method to_np (line 85) | def to_np(self, x):
method denorm (line 88) | def denorm(self, x):
method detach (line 92) | def detach(self, x):
method train (line 95) | def train(self):
method save (line 171) | def save(self, filename):
method infer (line 175) | def infer(self, sample_size):
method sample (line 190) | def sample(self):
Condensed preview — 7 files, each showing path, character count, and a content snippet. Download the .json file or copy for the full structured content (18K chars).
[
{
"path": "README.md",
"chars": 639,
"preview": "# VQ-VAE(Neural Discrete Representation Learning)\n\nhttps://arxiv.org/abs/1711.00937\n\n## Requirements\n[nsml](https://rese"
},
{
"path": "data_loader.py",
"chars": 1472,
"preview": "import os\nfrom torch.utils import data\nfrom torchvision import transforms\nfrom PIL import Image\n\nclass ImageFolder(data."
},
{
"path": "logger.py",
"chars": 934,
"preview": "# Code referenced from https://gist.github.com/gyglim/1f8dfb1b5c82627ae3efcfbbadb9f514\nimport numpy as np\nimport scipy.m"
},
{
"path": "main.py",
"chars": 4160,
"preview": "import argparse\nimport base64\nfrom io import BytesIO\nimport os\nimport numpy as np\nfrom PIL import Image\nimport torch\nfro"
},
{
"path": "model.py",
"chars": 3312,
"preview": "import torch \nimport torch.nn as nn\nimport numpy as np\nimport math\nfrom torch.autograd import Variable\n\nclass Generator("
},
{
"path": "setup.py",
"chars": 209,
"preview": "#nsml: pytorch/pytorch\nfrom distutils.core import setup\nsetup(\n name='nsml example 07 VAE GAN',\n version='1.0',\n "
},
{
"path": "solver.py",
"chars": 6648,
"preview": "import torch\nimport torch.nn as nn\nimport os\nfrom torch.autograd import Variable\nfrom torchvision.utils import save_imag"
}
]
About this extraction
This page contains the full source code of the nakosung/VQ-VAE GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 7 files (17.0 KB), approximately 4.3k tokens, and a symbol index with 34 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.
Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.