Full Code of nakosung/VQ-VAE for AI

master e374bf9ddc8c cached
7 files
17.0 KB
4.3k tokens
34 symbols
1 requests
Download .txt
Repository: nakosung/VQ-VAE
Branch: master
Commit: e374bf9ddc8c
Files: 7
Total size: 17.0 KB

Directory structure:
gitextract_4u1gmzza/

├── README.md
├── data_loader.py
├── logger.py
├── main.py
├── model.py
├── setup.py
└── solver.py

================================================
FILE CONTENTS
================================================

================================================
FILE: README.md
================================================
# VQ-VAE(Neural Discrete Representation Learning)

https://arxiv.org/abs/1711.00937

## Requirements
[nsml](https://research.clova.ai/nsml-alpha)

## How to run
`nsml run -d CelebA_128 -v -i`

## Introduction

This is a reproduction of the Vector Quantised VAE (VQ-VAE) from DeepMind. The authors applied VQ-VAE to various tasks; this repo is a slight modification of yunjey's VAE-GAN (CelebA dataset) that replaces the VAE with a VQ-VAE.

![image](https://user-images.githubusercontent.com/2463571/32690439-0ec64640-c73a-11e7-9e06-a0a6ea23248d.png)

![image](https://user-images.githubusercontent.com/2463571/32690440-1738bff6-c73a-11e7-92a4-57dd8449e5a5.png)


================================================
FILE: data_loader.py
================================================
import os
from torch.utils import data
from torchvision import transforms
from PIL import Image

class ImageFolder(data.Dataset):
    """Dataset that serves every file found under a directory as an RGB image.

    Compatible with the prebuilt ``torch.utils.data.DataLoader``.
    """

    def __init__(self, root, transform=None):
        """Collect the paths of all entries directly under *root*."""
        self.image_paths = [os.path.join(root, name) for name in os.listdir(root)]
        self.transform = transform

    def __getitem__(self, index):
        """Load the image at *index*, convert to RGB, and apply the transform."""
        img = Image.open(self.image_paths[index]).convert('RGB')
        return img if self.transform is None else self.transform(img)

    def __len__(self):
        """Return how many image files were found under the root directory."""
        return len(self.image_paths)


def get_loader(image_path, image_size, batch_size, num_workers=2):
    """Create and return a DataLoader over the images under *image_path*.

    Images are resized so their shorter side equals *image_size*, converted
    to tensors, and normalised per channel from [0, 1] to [-1, 1].

    Args:
        image_path: directory containing the image files.
        image_size: target size for the shorter image side.
        batch_size: number of images per batch.
        num_workers: DataLoader worker processes.
    """
    transform = transforms.Compose([
                    # transforms.Scale was deprecated and later removed from
                    # torchvision; Resize is the drop-in replacement.
                    transforms.Resize(image_size),
                    transforms.ToTensor(),
                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    dataset = ImageFolder(image_path, transform)
    data_loader = data.DataLoader(dataset=dataset,
                                  batch_size=batch_size,
                                  shuffle=True,
                                  num_workers=num_workers)
    return data_loader

================================================
FILE: logger.py
================================================
# Code referenced from https://gist.github.com/gyglim/1f8dfb1b5c82627ae3efcfbbadb9f514
import numpy as np
import scipy.misc 
import nsml
import visdom

class Logger(object):
    """Thin logging facade over nsml reporting and a Visdom connection."""

    def __init__(self, log_dir):
        # log_dir is unused; kept for interface compatibility with callers.
        self.last = None
        self.viz = nsml.Visdom(visdom=visdom)

    def scalar_summary(self, tag, value, step, scope=None):
        """Buffer scalar values per step; flush to nsml when the step advances."""
        pending = self.last
        if pending and pending['step'] != step:
            # A new step began: report everything gathered for the old one.
            nsml.report(**pending, scope=scope)
            pending = None
        if pending is None:
            pending = {'step': step, 'iter': step, 'epoch': 1}
        pending[tag] = value
        self.last = pending

    def images_summary(self, tag, images, step):
        """Log a list of images."""
        caption = '%s/%d' % (tag, step)
        self.viz.images(
            images,
            opts=dict(title=caption, caption=caption),
        )

    def histo_summary(self, tag, values, step, bins=1000):
        """Histogram logging is intentionally a no-op in this setup."""
        pass

================================================
FILE: main.py
================================================
import argparse
import base64
from io import BytesIO
import os
import numpy as np
from PIL import Image
import torch
from torch.backends import cudnn
from data_loader import get_loader
from solver import Solver
import nsml

INFER_STR = ['inference', 'infer']
OUTPUT_FORMAT = 'png'


def str2bool(v):
    """Parse a command-line flag string into a bool (case-insensitive 'true').

    The original used ``v.lower() in ('true')``, which is a substring test
    against the string 'true' — so '' and 'rue' evaluated True. Compare for
    equality instead.
    """
    return v.lower() == 'true'


def main(config, scope):
    """Build the data pipeline and solver, bind nsml hooks, then train/sample.

    Args:
        config: parsed argparse namespace (see the __main__ block).
        scope: the caller's namespace, forwarded to nsml.paused so a paused
            session can be resumed with the same bindings.
    """
    # Create output directories; exist_ok avoids the check-then-create race
    # of the previous os.path.exists/os.makedirs pairs.
    for path in (config.log_path, config.model_save_path, config.sample_path):
        os.makedirs(path, exist_ok=True)

    # When sampling, a single batch must hold all requested samples.
    if config.mode == 'sample':
        config.batch_size = config.sample_size

    # Data loader
    data_loader = get_loader(config.image_path, config.image_size,
                              config.batch_size, config.num_workers)
    # Solver
    solver = Solver(data_loader, config)

    def load(filename, *args):
        """nsml hook: restore model weights from *filename*."""
        solver.load(filename)

    def save(filename, *args):
        """nsml hook: persist model weights to *filename*."""
        solver.save(filename)

    def infer(input):
        """nsml hook: generate `input` samples, returned as PNG data URLs."""
        # `input` is the number of samples (an int); the name shadows the
        # builtin but is kept in case nsml invokes this hook by keyword.
        result = solver.infer(input)
        # Convert each sample tensor to a base64-encoded data URL.
        data_url_list = [''] * input
        for idx, sample in enumerate(result):
            numpy_array = np.uint8(sample.cpu().numpy()*255)
            # CHW -> HWC layout for PIL.
            image = Image.fromarray(np.transpose(numpy_array, axes=(1, 2, 0)), 'RGB')
            temp_out = BytesIO()
            image.save(temp_out, format=OUTPUT_FORMAT)
            byte_data = temp_out.getvalue()
            data_url_list[idx] = u'data:image/{format};base64,{data}'.\
                format(format=OUTPUT_FORMAT,
                       data=base64.b64encode(byte_data).decode('ascii'))
        return data_url_list

    def evaluate(test_data, output):
        """nsml hook: evaluation is not implemented for this task."""
        pass

    def decode(input):
        """nsml hook: inputs are already in the expected form; pass through."""
        return input

    nsml.bind(save, load, infer, evaluate, decode)

    if config.pause:
        nsml.paused(scope=scope)

    if config.mode == 'train':
        solver.train()
    elif config.mode == 'sample':
        solver.sample()


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    
    # Model hyper-parameters
    parser.add_argument('--image_size', type=int, default=64)
    parser.add_argument('--z_dim', type=int, default=256)    # size of each latent / codebook vector
    parser.add_argument('--k_dim', type=int, default=256)    # number of codebook entries
    parser.add_argument('--code_dim', type=int, default=16)  # spatial size of the latent grid
    parser.add_argument('--g_conv_dim', type=int, default=64)
    parser.add_argument('--d_conv_dim', type=int, default=64)
    
    # Training settings
    parser.add_argument('--total_step', type=int, default=200000)
    parser.add_argument('--batch_size', type=int, default=16)
    parser.add_argument('--num_workers', type=int, default=2)
    parser.add_argument('--lr', type=float, default=0.0002)
    parser.add_argument('--beta1', type=float, default=0.5)   # Adam beta1
    parser.add_argument('--beta2', type=float, default=0.999) # Adam beta2
    # Step number of a checkpoint to resume from (None = train from scratch).
    parser.add_argument('--trained_model', type=int, default=None)

    # Weight of the commitment loss (beta in the VQ-VAE paper).
    parser.add_argument('--vq_beta', type=float, default=0.25)
    
    # Test settings
    parser.add_argument('--step_for_sampling', type=int, default=200000)
    parser.add_argument('--sample_size', type=int, default=32)
    
    # Misc
    parser.add_argument('--mode', type=str, default='train', choices=['train', 'sample', *INFER_STR])
    parser.add_argument("--iteration", type=int, default=0)
    parser.add_argument('--use_tensorboard', type=str2bool, default=True)
    
    parser.add_argument('--log_path', type=str, default='./celebA/logs')
    parser.add_argument('--model_save_path', type=str, default='./celebA/models')
    parser.add_argument('--sample_path', type=str, default='./celebA/samples')
    parser.add_argument('--image_path', type=str, default=nsml.DATASET_PATH)
    
    parser.add_argument('--log_step', type=int, default=10)
    parser.add_argument('--sample_step', type=int, default=200)
    parser.add_argument('--model_save_step', type=int, default=1000)

    # nsml setting
    # Nonzero when nsml resumes a paused session (see nsml.paused in main()).
    parser.add_argument('--pause', type=int, default=0)
    
    config = parser.parse_args()
    print(config)
    # locals() here is the __main__ module namespace; nsml uses it to rebind
    # state when resuming, so this block must stay at module level.
    main(config,scope=locals())


================================================
FILE: model.py
================================================
import torch 
import torch.nn as nn
import numpy as np
import math
from torch.autograd import Variable

class Generator(nn.Module):
    """Vector Quantised VAE (van den Oord et al., 2017, arXiv:1711.00937).

    The encoder maps an image to a (code_dim x code_dim) grid of z_dim
    vectors; each vector is snapped to its nearest entry of a learned
    k_dim-entry codebook before being decoded back to an image.
    """
    def __init__(self, image_size=64, z_dim=256, conv_dim=64, code_dim=16, k_dim=256):
        super(Generator, self).__init__()

        self.k_dim = k_dim        # number of codebook entries
        self.z_dim = z_dim        # dimensionality of each latent vector
        self.code_dim = code_dim  # spatial size of the latent grid
        
        # Codebook: k_dim embedding vectors of size z_dim.
        self.dict = nn.Embedding(k_dim, z_dim)
        
        # Encoder (increasing #filter linearly)
        layers = []
        layers.append(nn.Conv2d(3, conv_dim, kernel_size=3, padding=1))
        layers.append(nn.BatchNorm2d(conv_dim))
        layers.append(nn.ReLU())
        
        # Halve the spatial resolution until image_size shrinks to code_dim.
        repeat_num = int(math.log2(image_size / code_dim))
        curr_dim = conv_dim
        for i in range(repeat_num):
            layers.append(nn.Conv2d(curr_dim, conv_dim * (i+2), kernel_size=4, stride=2, padding=1))
            layers.append(nn.BatchNorm2d(conv_dim * (i+2)))
            layers.append(nn.ReLU())
            curr_dim = conv_dim * (i+2)

        # Now we have (code_dim,code_dim,curr_dim)
        layers.append(nn.Conv2d(curr_dim, z_dim, kernel_size=1))

        # (code_dim,code_dim,z_dim)
        self.encoder = nn.Sequential(*layers)
        
        # Decoder: mirror of the encoder (channel counts e.g. 320-256-192-128-64).
        layers = []
        layers.append(nn.ConvTranspose2d(z_dim, curr_dim, kernel_size=1))
        layers.append(nn.BatchNorm2d(curr_dim))
        layers.append(nn.ReLU())
                
        for i in reversed(range(repeat_num)):
            layers.append(nn.ConvTranspose2d(curr_dim , conv_dim * (i+1), kernel_size=4, stride=2, padding=1))
            layers.append(nn.BatchNorm2d(conv_dim * (i+1)))
            layers.append(nn.ReLU())
            curr_dim = conv_dim * (i+1)
        
        layers.append(nn.Conv2d(curr_dim, 3, kernel_size=3, padding=1))
        self.decoder = nn.Sequential(*layers)

        self.init_weights()        
        
    def init_weights(self):
        """Initialise the codebook uniformly in [-1/k_dim, 1/k_dim]."""
        initrange = 1.0 / self.k_dim
        self.dict.weight.data.uniform_(-initrange, initrange)        
            
    def forward(self, x):
        """Encode, quantise to the nearest codebook entries, decode.

        Returns a 3-tuple:
          * the reconstructed image,
          * commitment loss ||Z - sg(W_j)||^2 (pulls encoder toward codes),
          * codebook loss   ||sg(Z) - W_j||^2 (pulls codes toward encoder).
        """
        h = self.encoder(x)                             # (B, z_dim, code_dim, code_dim)

        sz = h.size()
        
        # NCHW -> NHWC so each spatial position becomes one z_dim row vector.
        org_h = h
        h = h.permute(0,2,3,1)
        h = h.contiguous()
        Z = h.view(-1,self.z_dim)

        W = self.dict.weight

        def L2_dist(a,b):
            # Elementwise squared difference; callers sum over the feature axis.
            return ((a - b) ** 2)
        
        # Nearest-neighbour lookup: index of the closest codebook entry per vector.
        j = L2_dist(Z[:,None],W[None,:]).sum(2).min(1)[1]
        W_j = W[j]

        # Stop-gradient copies used to split the two VQ loss terms below.
        Z_sg = Z.detach()
        W_j_sg = W_j.detach()

        # NHWC -> NCHW: reassemble the quantised vectors into a feature map.
        h = W_j.view(sz[0],sz[2],sz[3],sz[1])
        h = h.permute(0,3,1,2)

        def hook(grad):
            # Straight-through estimator: stash the decoder's gradient w.r.t.
            # the quantised code alongside the pre-quantisation encoder output;
            # bwd() later replays this gradient through the encoder.
            nonlocal org_h
            self.saved_grad = grad
            self.saved_h = org_h
            return grad

        h.register_hook(hook)
        
        # Reconstruction plus the two VQ losses (commitment, codebook).
        return self.decoder(h), L2_dist(Z,W_j_sg).sum(1).mean(), L2_dist(Z_sg,W_j).sum(1).mean()

    # Back-propagation for the encoder: replay the gradient captured by the
    # forward hook through the (unquantised) encoder output. Must be called
    # after loss.backward(retain_graph=True) — see Solver.train().
    def bwd(self):
        self.saved_h.backward(self.saved_grad)
        
    def decode(self, z):
        """Decode latent vectors z of shape (B, z_dim) into images.

        NOTE(review): this feeds a 1x1 latent grid, while training decodes a
        code_dim x code_dim grid — the output resolution therefore differs
        from training reconstructions; confirm this is intended.
        """
        z = z.view(z.size(0), z.size(1), 1, 1)
        return self.decoder(z)
    
    


================================================
FILE: setup.py
================================================
#nsml: pytorch/pytorch
# distutils.core.setup silently ignores install_requires; setuptools is
# required for the dependency declarations below to take effect.
from setuptools import setup

setup(
    name='nsml example 07 VAE GAN',
    version='1.0',
    description='ns-ml',
    install_requires=[
        'visdom',
        'pillow'
    ]
)

================================================
FILE: solver.py
================================================
import torch
import torch.nn as nn
import os
from torch.autograd import Variable
from torchvision.utils import save_image
from model import Generator
import nsml

class Solver(object):
    """Owns the VQ-VAE Generator, its optimizer, and the train/sample/infer loops."""
    
    def __init__(self, data_loader, config):
        # Data loader
        self.data_loader = data_loader
        
        # Model hyper-parameters
        self.image_size = config.image_size
        self.z_dim = config.z_dim
        self.k_dim = config.k_dim
        self.g_conv_dim = config.g_conv_dim
        self.code_dim = config.code_dim
        
        # Training settings
        self.total_step = config.total_step
        self.lr = config.lr
        self.beta1 = config.beta1
        self.beta2 = config.beta2
        self.trained_model = config.trained_model
        self.use_tensorboard = config.use_tensorboard

        # Weight of the commitment loss (beta in the VQ-VAE paper).
        self.vq_beta = config.vq_beta
        
        # Path and step size
        self.log_path = config.log_path
        self.sample_path = config.sample_path
        self.model_save_path = config.model_save_path
        self.log_step = config.log_step
        self.sample_step = config.sample_step
        self.model_save_step = config.model_save_step
        
        # Test setting
        self.step_for_sampling = config.step_for_sampling
        self.sample_size = config.sample_size
        
        self.build_model()
        if self.use_tensorboard:
            self.build_tensorboard()
        
        # Start with trained model
        if self.trained_model:
            self.load_trained_model()
        
    def build_model(self):
        """Instantiate the Generator and its Adam optimizer (moved to GPU if available)."""
        self.G = Generator(self.image_size, self.z_dim, self.g_conv_dim, k_dim=self.k_dim, code_dim=self.code_dim)
        self.g_optimizer = torch.optim.Adam(self.G.parameters(), self.lr, [self.beta1, self.beta2])
        
        if torch.cuda.is_available():
            self.G.cuda()
    
    def load_trained_model(self):
        """Restore the checkpoint saved at step `self.trained_model`."""
        self.load(os.path.join(
            self.model_save_path, '{}_G.pth'.format(self.trained_model)))
        print('loaded trained models (step: {})..!'.format(self.trained_model))

    def load(self,filename):
        """Load a checkpoint dict (key 'G') from *filename* into the model."""
        S = torch.load(filename)
        self.G.load_state_dict(S['G'])
         
    def build_tensorboard(self):
        """Create the nsml/visdom-backed Logger (import deferred: logger needs nsml)."""
        from logger import Logger
        self.logger = Logger(self.log_path)
        
    def update_lr(self, lr):
        """Set the learning rate of every optimizer parameter group."""
        for param_group in self.g_optimizer.param_groups:
            param_group['lr'] = lr
    
    def reset_grad(self):
        """Zero accumulated gradients before a new backward pass."""
        self.g_optimizer.zero_grad()
        
    def to_var(self, x):
        """Wrap a tensor in a Variable, moving it to GPU when available."""
        if torch.cuda.is_available():
            x = x.cuda()
        return Variable(x)
    
    def to_np(self, x):
        """Return the tensor's data as a numpy array on the CPU."""
        return x.data.cpu().numpy()
    
    def denorm(self, x):
        """Map images from [-1, 1] back to [0, 1], clamping in place."""
        out = (x + 1) / 2
        return out.clamp_(0, 1)
        
    def detach(self, x):
        """Return a Variable sharing x's data but cut off from its graph."""
        return Variable(x.data)
    
    def train(self):
        """Run the training loop: L1 reconstruction + the two VQ losses."""
        # Reconstruction loss
        reconst_loss = nn.L1Loss()
        
        # Data iter
        data_iter = iter(self.data_loader)
        iter_per_epoch = len(self.data_loader)
        
        # Fixed inputs used for periodic reconstruction snapshots.
        fixed_x = next(data_iter)
        save_image(self.denorm(fixed_x), os.path.join(self.sample_path, 'fixed_x.png'))
        fixed_x = self.to_var(fixed_x)
        
        # Start with trained model
        if self.trained_model:
            start = self.trained_model + 1
        else:
            start = 0
            
        for step in range(start, self.total_step):
            # Learning-rate schedule: exponential decay lr * (1/100)**alpha,
            # reaching lr/100 at the final step.
            alpha = (step - start) / (self.total_step - start)
            lr = self.lr * (1/100 ** alpha)

            self.update_lr(lr)
            
            # Reset data_iter for each epoch.
            # NOTE(review): this reset point assumes start == 0 (one batch was
            # already consumed for fixed_x); on resume with a nonzero start it
            # can misalign and let next() raise StopIteration — confirm the
            # resume path.
            if (step+1) % iter_per_epoch == 0:
                data_iter = iter(self.data_loader)
               
            x = self.to_var(next(data_iter))
                        
            # ================== Train G ================== #
            # Train with real images (VQ-VAE).
            # loss_e1: commitment loss; loss_e2: codebook loss (see model.py).
            out, loss_e1, loss_e2 = self.G(x)
            loss_rec = reconst_loss(out, x)
            
            loss = loss_rec + loss_e1 + self.vq_beta * loss_e2
            self.reset_grad()

            # For decoder: retain the graph so G.bwd() below can replay the
            # gradient captured by the forward hook.
            loss.backward(retain_graph=True)

            # For encoder: straight-through step — push the decoder's gradient
            # w.r.t. the quantised code back through the encoder output.
            self.G.bwd()

            self.g_optimizer.step()
            
            # Print out log info
            if (step+1) % self.log_step == 0:
                # .data[0] is the pre-0.4 PyTorch way to read a scalar loss.
                print("[{}/{}] loss: {:.4f}, loss_e1: {:.4f}". \
                      format(step+1, self.total_step, loss_rec.data[0], loss_e1.data[0]))
                      
                if self.use_tensorboard:
                    info = {
                        'loss/loss_rec': loss_rec.data[0],
                        'loss/loss_ee': loss_e1.data[0],
                        'misc/lr': lr
                    }
                    
                    for tag, value in info.items():
                        self.logger.scalar_summary(tag, value, step+1, scope=locals())
                
            # Sample images
            if (step+1) % self.sample_step == 0:
                reconst, _, _ = self.G(fixed_x)

                def np(tensor):
                    # Local helper (shadows any module named `np` in this scope).
                    return tensor.cpu().numpy()

                self.logger.images_summary('recons',np(self.denorm(reconst.data)),step+1)

            # Save check points via nsml (routed to Solver.save below).
            if (step+1) % self.model_save_step == 0:
                nsml.save(step)
    
    def save(self, filename):
        """Write the model's state dict to *filename* under key 'G'."""
        G = self.G.state_dict()
        torch.save({'G':G}, filename)

    def infer(self, sample_size):
        """Decode `sample_size` random latents and return denormalised images."""
        # NOTE(review): data_iter is created but never used in this method.
        data_iter = iter(self.data_loader)
        
        # Inputs for sampling.
        # NOTE(review): z ~ N(0, 1), but training-time latents are codebook
        # entries — confirm that Gaussian sampling is the intended prior.
        z = self.to_var(torch.randn(sample_size, self.z_dim)) 
        
        # Switch to evaluation mode (affects BatchNorm).
        self.G.eval()  
        
        # Sampling
        fake = self.G.decode(z)

        return self.denorm(fake.data)
        
    def sample(self):
        """Save real and reconstructed images; also decode random latents."""
        # Data iter
        data_iter = iter(self.data_loader)
        
        # Inputs for sampling
        x = next(data_iter)
        z = self.to_var(torch.randn(self.sample_size, self.z_dim)) 
        
        save_image(self.denorm(x), 'real.png')
        x = self.to_var(x)
        
        # Load trained params from the checkpoint at step_for_sampling.
        self.G.eval()  
        S = torch.load(os.path.join(
            self.model_save_path, '{}_G.pth'.format(self.step_for_sampling)))
        self.G.load_state_dict(S['G'])
        
        # Sampling
        # NOTE(review): `fake` is computed but never written to a file.
        reconst, _, _ = self.G(x)
        fake = self.G.decode(z)
        save_image(self.denorm(reconst.data), 'reconst.png')
Download .txt
gitextract_4u1gmzza/

├── README.md
├── data_loader.py
├── logger.py
├── main.py
├── model.py
├── setup.py
└── solver.py
Download .txt
SYMBOL INDEX (34 symbols across 5 files)

FILE: data_loader.py
  class ImageFolder (line 6) | class ImageFolder(data.Dataset):
    method __init__ (line 8) | def __init__(self, root, transform=None):
    method __getitem__ (line 13) | def __getitem__(self, index):
    method __len__ (line 21) | def __len__(self):
  function get_loader (line 26) | def get_loader(image_path, image_size, batch_size, num_workers=2):

FILE: logger.py
  class Logger (line 7) | class Logger(object):
    method __init__ (line 8) | def __init__(self, log_dir):
    method scalar_summary (line 12) | def scalar_summary(self, tag, value, step, scope=None):
    method images_summary (line 20) | def images_summary(self, tag, images, step):
    method histo_summary (line 27) | def histo_summary(self, tag, values, step, bins=1000):

FILE: main.py
  function str2bool (line 17) | def str2bool(v):
  function main (line 21) | def main(config, scope):

FILE: model.py
  class Generator (line 7) | class Generator(nn.Module):
    method __init__ (line 9) | def __init__(self, image_size=64, z_dim=256, conv_dim=64, code_dim=16,...
    method init_weights (line 55) | def init_weights(self):
    method forward (line 59) | def forward(self, x):
    method bwd (line 99) | def bwd(self):
    method decode (line 102) | def decode(self, z):

FILE: solver.py
  class Solver (line 9) | class Solver(object):
    method __init__ (line 11) | def __init__(self, data_loader, config):
    method build_model (line 52) | def build_model(self):
    method load_trained_model (line 60) | def load_trained_model(self):
    method load (line 65) | def load(self,filename):
    method build_tensorboard (line 69) | def build_tensorboard(self):
    method update_lr (line 73) | def update_lr(self, lr):
    method reset_grad (line 77) | def reset_grad(self):
    method to_var (line 80) | def to_var(self, x):
    method to_np (line 85) | def to_np(self, x):
    method denorm (line 88) | def denorm(self, x):
    method detach (line 92) | def detach(self, x):
    method train (line 95) | def train(self):
    method save (line 171) | def save(self, filename):
    method infer (line 175) | def infer(self, sample_size):
    method sample (line 190) | def sample(self):
Condensed preview — 7 files, each showing path, character count, and a content snippet. Download the .json file or copy for the full structured content (18K chars).
[
  {
    "path": "README.md",
    "chars": 639,
    "preview": "# VQ-VAE(Neural Discrete Representation Learning)\n\nhttps://arxiv.org/abs/1711.00937\n\n## Requirements\n[nsml](https://rese"
  },
  {
    "path": "data_loader.py",
    "chars": 1472,
    "preview": "import os\nfrom torch.utils import data\nfrom torchvision import transforms\nfrom PIL import Image\n\nclass ImageFolder(data."
  },
  {
    "path": "logger.py",
    "chars": 934,
    "preview": "# Code referenced from https://gist.github.com/gyglim/1f8dfb1b5c82627ae3efcfbbadb9f514\nimport numpy as np\nimport scipy.m"
  },
  {
    "path": "main.py",
    "chars": 4160,
    "preview": "import argparse\nimport base64\nfrom io import BytesIO\nimport os\nimport numpy as np\nfrom PIL import Image\nimport torch\nfro"
  },
  {
    "path": "model.py",
    "chars": 3312,
    "preview": "import torch \nimport torch.nn as nn\nimport numpy as np\nimport math\nfrom torch.autograd import Variable\n\nclass Generator("
  },
  {
    "path": "setup.py",
    "chars": 209,
    "preview": "#nsml: pytorch/pytorch\nfrom distutils.core import setup\nsetup(\n    name='nsml example 07 VAE GAN',\n    version='1.0',\n  "
  },
  {
    "path": "solver.py",
    "chars": 6648,
    "preview": "import torch\nimport torch.nn as nn\nimport os\nfrom torch.autograd import Variable\nfrom torchvision.utils import save_imag"
  }
]

About this extraction

This page contains the full source code of the nakosung/VQ-VAE GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 7 files (17.0 KB), approximately 4.3k tokens, and a symbol index with 34 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.

Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.

Copied to clipboard!