master c44d3a02996a cached
10 files
47.0 KB
12.0k tokens
44 symbols
1 requests
Download .txt
Repository: voletiv/self-attention-GAN-pytorch
Branch: master
Commit: c44d3a02996a
Files: 10
Total size: 47.0 KB

Directory structure:
gitextract_cs9fl47f/

├── .gitignore
├── LICENSE
├── README.md
├── parameters.py
├── requirements.txt
├── sagan_models.py
├── test.py
├── train.py
├── trainer.py
└── utils.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/


================================================
FILE: LICENSE
================================================
MIT License

Copyright (c) 2019 Vikram Voleti

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.


================================================
FILE: README.md
================================================
# self-attention-GAN-pytorch

This is an almost exact replica in PyTorch of the Tensorflow version of [SAGAN](https://arxiv.org/abs/1805.08318) released by Google Brain [[repo](https://github.com/brain-research/self-attention-gan)] in August 2018.

Code structure is inspired from [this repo](https://github.com/heykeetae/Self-Attention-GAN), but follows the details of [Google Brain's repo](https://github.com/brain-research/self-attention-gan).

## Prerequisites

Check `requirements.txt`.

* [Python 3.5+](https://www.continuum.io/downloads)
* [PyTorch 0.4.1](http://pytorch.org/)

## Training

#### 1. Check `parameters.py` for all arguments and their default values

#### 2. Train on custom images in folder a/b/c:
```bash
$ python train.py --data_path 'a/b/c' --save_path 'o/p/q' --batch_size 64 --name sagan
```

(Warning: Works only on *128x128* images, input images are resized to that. Tweak the Generator & Discriminator first if you would like to use some other image size. And then use the `imsize` option:
```bash
$ python train.py --data_path 'a/b/c' --save_path 'o/p/q' --batch_size 64 --imsize 32 --name sagan
```
)

Model training will be recorded in a new folder inside `--save_path` with the name `<timestamp>_<name>_<basename of data_path>`.

By default, model weights are saved in a subfolder called `weights`, and train & validation samples during training in a subfolder called `samples` (can be changed in `parameters.py`).

## Testing/Evaluating

Check `test.py`.

## Self-Attention GAN
**[Han Zhang, Ian Goodfellow, Dimitris Metaxas and Augustus Odena, "Self-Attention Generative Adversarial Networks." arXiv preprint arXiv:1805.08318 (2018)](https://arxiv.org/abs/1805.08318).**

```
@article{Zhang2018SelfAttentionGA,
    title={Self-Attention Generative Adversarial Networks},
    author={Han Zhang and Ian J. Goodfellow and Dimitris N. Metaxas and Augustus Odena},
    journal={CoRR},
    year={2018},
    volume={abs/1805.08318}
}
```


================================================
FILE: parameters.py
================================================
import argparse
import datetime
import os


def get_parameters():

    parser = argparse.ArgumentParser()

    # Images data path & Output path
    parser.add_argument('--dataset', type=str, default='folder', choices=["cifar10", "fake", "folder", "hdf5", "imagenet", "lfw", "lsun"],
                        help="cifar10 | fake | folder | hdf5 | imagenet | lfw | lsun")
    parser.add_argument('--data_path', type=str, default='', help='Path to root of image data (saved in dirs of classes)')
    parser.add_argument('--save_path', type=str, default='./sagan_models')

    # Training settings
    parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument('--batch_size_in_gpu', type=int, default=0,
                        help='0 => same as batch_size, else: if using multiple gpu iterations to make an effective batch, e.g. batch_size=32, batch_size_in_gpu=16 => optimizer.step() is run 2 iterations after running loss.backward()')
    parser.add_argument('--total_step', type=int, default=200000, help='how many iterations')
    parser.add_argument('--d_steps_per_iter', type=int, default=1, help='how many D updates per iteration')
    parser.add_argument('--g_steps_per_iter', type=int, default=1, help='how many G updates per iteration')
    parser.add_argument('--d_lr', type=float, default=0.0004)
    parser.add_argument('--g_lr', type=float, default=0.0001)
    parser.add_argument('--beta1', type=float, default=0.0)
    parser.add_argument('--beta2', type=float, default=0.999)

    # Model hyper-parameters
    parser.add_argument('--adv_loss', type=str, default='hinge', choices=['hinge', 'dcgan', 'wgan_gp', 'gan'])
    parser.add_argument('--z_dim', type=int, default=128)
    parser.add_argument('--g_conv_dim', type=int, default=64)
    parser.add_argument('--d_conv_dim', type=int, default=64)
    parser.add_argument('--lambda_gp', type=float, default=10)

    # Instance noise
    # https://github.com/soumith/ganhacks/issues/14#issuecomment-312509518
    # https://www.inference.vc/instance-noise-a-trick-for-stabilising-gan-training/
    parser.add_argument('--inst_noise_sigma', type=float, default=0.0)
    parser.add_argument('--inst_noise_sigma_iters', type=int, default=2000)

    # Image transforms
    parser.add_argument('--dont_shuffle', action='store_true')
    parser.add_argument('--dont_drop_last', action='store_true', help="Whether not to drop the last batch in dataset if its size < batch_size")
    parser.add_argument('--dont_resize', action='store_true', help="Whether not to resize images")
    parser.add_argument('--imsize', type=int, default=128)
    parser.add_argument('--centercrop', action='store_true', help="Whether to center crop images")
    parser.add_argument('--centercrop_size', type=int, default=128)
    parser.add_argument('--dont_normalize', action='store_true', help="Whether to normalize image values")

    # Step sizes
    parser.add_argument('--log_step', type=int, default=10)
    parser.add_argument('--sample_step', type=int, default=10)
    parser.add_argument('--model_save_step', type=float, default=50)
    parser.add_argument('--save_n_images', type=int, default=0,
                        help='0 => same as batch_size_in_gpu')
    parser.add_argument('--nrow', type=int, default=10)
    parser.add_argument('--max_frames_per_gif', type=int, default=100)

    # Pretrained model
    parser.add_argument('--pretrained_model', type=str, default='')
    parser.add_argument('--state_dict_or_model', type=str, default='', help="Specify whether .pth pretrained_model is a 'state_dict' or a complete 'model'")

    # Misc
    parser.add_argument('--manual_seed', type=int, default=29)
    parser.add_argument('--disable_cuda', action='store_true', help='Disable CUDA')
    parser.add_argument('--parallel', action='store_true', help="Run on multiple GPUs")
    parser.add_argument('--num_workers', type=int, default=4)
    # parser.add_argument('--use_tensorboard', action='store_true')

    # Output paths
    parser.add_argument('--model_weights_dir', type=str, default='weights')
    parser.add_argument('--sample_images_dir', type=str, default='samples')

    # Model name
    parser.add_argument('--name', type=str, default='sagan')

    args = parser.parse_args()

    if args.batch_size_in_gpu == 0:
        args.batch_size_in_gpu = args.batch_size

    assert args.batch_size_in_gpu <= args.batch_size, "ERROR: please make sure batch_size >= batch_size_in_gpu!! Given batch_size: " + str(args.batch_size) + " ; batch_size_in_gpu: " + str(args.batch_size_in_gpu)
    assert args.batch_size % args.batch_size_in_gpu == 0, "ERROR: please make sure batch_size_in_gpu divides batch_size!! Given batch_size: " + str(args.batch_size) + " ; batch_size_in_gpu: " + str(args.batch_size_in_gpu)

    args.batch_size_effective = args.batch_size_in_gpu*(args.batch_size//args.batch_size_in_gpu)

    print("Effective BATCH SIZE:", args.batch_size_effective)

    if args.save_n_images == 0:
        args.save_n_images = args.batch_size_in_gpu

    assert args.save_n_images <= args.batch_size_in_gpu, "ERROR: please make save_n_images <= batch_size_in_gpu!! Given save_n_images: " + str(args.save_n_images) + " ; batch_size_in_gpu: " + str(args.batch_size_in_gpu)

    # Corrections
    args.shuffle = not args.dont_shuffle
    args.drop_last = not args.dont_drop_last
    args.resize = not args.dont_resize
    args.normalize = not args.dont_normalize

    args.dataloader_args = {'num_workers':args.num_workers}

    args.name = '{0:%Y%m%d_%H%M%S}_{1}_{2}'.format(datetime.datetime.now(), args.name, os.path.basename(args.data_path))

    args.save_path = os.path.join(args.save_path, args.name)
    args.model_weights_path = os.path.join(args.save_path, args.model_weights_dir)
    args.sample_images_path = os.path.join(args.save_path, args.sample_images_dir)

    return args


================================================
FILE: requirements.txt
================================================
matplotlib==3.0.0
torchvision==0.2.1
torch==2.2.0
opencv_python==4.2.0.32
imageio==2.4.1
numpy==1.22.0


================================================
FILE: sagan_models.py
================================================
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.nn.utils import spectral_norm
from torch.nn.init import xavier_uniform_


def init_weights(m):
    if type(m) == nn.Linear or type(m) == nn.Conv2d:
        xavier_uniform_(m.weight)
        m.bias.data.fill_(0.)


def snconv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
    return spectral_norm(nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
                                   stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias))


def snlinear(in_features, out_features):
    return spectral_norm(nn.Linear(in_features=in_features, out_features=out_features))


def sn_embedding(num_embeddings, embedding_dim):
    return spectral_norm(nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim))


class Self_Attn(nn.Module):
    """ Self attention Layer"""

    def __init__(self, in_channels):
        super(Self_Attn, self).__init__()
        self.in_channels = in_channels
        self.snconv1x1_theta = snconv2d(in_channels=in_channels, out_channels=in_channels//8, kernel_size=1, stride=1, padding=0)
        self.snconv1x1_phi = snconv2d(in_channels=in_channels, out_channels=in_channels//8, kernel_size=1, stride=1, padding=0)
        self.snconv1x1_g = snconv2d(in_channels=in_channels, out_channels=in_channels//2, kernel_size=1, stride=1, padding=0)
        self.snconv1x1_attn = snconv2d(in_channels=in_channels//2, out_channels=in_channels, kernel_size=1, stride=1, padding=0)
        self.maxpool = nn.MaxPool2d(2, stride=2, padding=0)
        self.softmax  = nn.Softmax(dim=-1)
        self.sigma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        """
            inputs :
                x : input feature maps(B X C X W X H)
            returns :
                out : self attention value + input feature 
                attention: B X N X N (N is Width*Height)
        """
        _, ch, h, w = x.size()
        # Theta path
        theta = self.snconv1x1_theta(x)
        theta = theta.view(-1, ch//8, h*w)
        # Phi path
        phi = self.snconv1x1_phi(x)
        phi = self.maxpool(phi)
        phi = phi.view(-1, ch//8, h*w//4)
        # Attn map
        attn = torch.bmm(theta.permute(0, 2, 1), phi)
        attn = self.softmax(attn)
        # g path
        g = self.snconv1x1_g(x)
        g = self.maxpool(g)
        g = g.view(-1, ch//2, h*w//4)
        # Attn_g
        attn_g = torch.bmm(g, attn.permute(0, 2, 1))
        attn_g = attn_g.view(-1, ch//2, h, w)
        attn_g = self.snconv1x1_attn(attn_g)
        # Out
        out = x + self.sigma*attn_g
        return out


class ConditionalBatchNorm2d(nn.Module):
    # https://github.com/pytorch/pytorch/issues/8985#issuecomment-405080775
    def __init__(self, num_features, num_classes):
        super().__init__()
        self.num_features = num_features
        self.bn = nn.BatchNorm2d(num_features, momentum=0.001, affine=False)
        self.embed = nn.Embedding(num_classes, num_features * 2)
        # self.embed.weight.data[:, :num_features].normal_(1, 0.02)  # Initialise scale at N(1, 0.02)
        self.embed.weight.data[:, :num_features].fill_(1.)  # Initialize scale to 1
        self.embed.weight.data[:, num_features:].zero_()  # Initialize bias at 0

    def forward(self, x, y):
        out = self.bn(x)
        gamma, beta = self.embed(y).chunk(2, 1)
        out = gamma.view(-1, self.num_features, 1, 1) * out + beta.view(-1, self.num_features, 1, 1)
        return out


class GenBlock(nn.Module):
    def __init__(self, in_channels, out_channels, num_classes):
        super(GenBlock, self).__init__()
        self.cond_bn1 = ConditionalBatchNorm2d(in_channels, num_classes)
        self.relu = nn.ReLU(inplace=True)
        self.snconv2d1 = snconv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)
        self.cond_bn2 = ConditionalBatchNorm2d(out_channels, num_classes)
        self.snconv2d2 = snconv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)
        self.snconv2d0 = snconv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x, labels):
        x0 = x

        x = self.cond_bn1(x, labels)
        x = self.relu(x)
        x = F.interpolate(x, scale_factor=2, mode='nearest') # upsample
        x = self.snconv2d1(x)
        x = self.cond_bn2(x, labels)
        x = self.relu(x)
        x = self.snconv2d2(x)

        x0 = F.interpolate(x0, scale_factor=2, mode='nearest') # upsample
        x0 = self.snconv2d0(x0)

        out = x + x0
        return out


class Generator(nn.Module):
    """Generator."""

    def __init__(self, z_dim, g_conv_dim, num_classes):
        super(Generator, self).__init__()

        self.z_dim = z_dim
        self.g_conv_dim = g_conv_dim
        self.snlinear0 = snlinear(in_features=z_dim, out_features=g_conv_dim*16*4*4)
        self.block1 = GenBlock(g_conv_dim*16, g_conv_dim*16, num_classes)
        self.block2 = GenBlock(g_conv_dim*16, g_conv_dim*8, num_classes)
        self.block3 = GenBlock(g_conv_dim*8, g_conv_dim*4, num_classes)
        self.self_attn = Self_Attn(g_conv_dim*4)
        self.block4 = GenBlock(g_conv_dim*4, g_conv_dim*2, num_classes)
        self.block5 = GenBlock(g_conv_dim*2, g_conv_dim, num_classes)
        self.bn = nn.BatchNorm2d(g_conv_dim, eps=1e-5, momentum=0.0001, affine=True)
        self.relu = nn.ReLU(inplace=True)
        self.snconv2d1 = snconv2d(in_channels=g_conv_dim, out_channels=3, kernel_size=3, stride=1, padding=1)
        self.tanh = nn.Tanh()

        # Weight init
        self.apply(init_weights)

    def forward(self, z, labels):
        # n x z_dim
        act0 = self.snlinear0(z)            # n x g_conv_dim*16*4*4
        act0 = act0.view(-1, self.g_conv_dim*16, 4, 4) # n x g_conv_dim*16 x 4 x 4
        act1 = self.block1(act0, labels)    # n x g_conv_dim*16 x 8 x 8
        act2 = self.block2(act1, labels)    # n x g_conv_dim*8 x 16 x 16
        act3 = self.block3(act2, labels)    # n x g_conv_dim*4 x 32 x 32
        act3 = self.self_attn(act3)         # n x g_conv_dim*4 x 32 x 32
        act4 = self.block4(act3, labels)    # n x g_conv_dim*2 x 64 x 64
        act5 = self.block5(act4, labels)    # n x g_conv_dim  x 128 x 128
        act5 = self.bn(act5)                # n x g_conv_dim  x 128 x 128
        act5 = self.relu(act5)              # n x g_conv_dim  x 128 x 128
        act6 = self.snconv2d1(act5)         # n x 3 x 128 x 128
        act6 = self.tanh(act6)              # n x 3 x 128 x 128
        return act6


class DiscOptBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(DiscOptBlock, self).__init__()
        self.snconv2d1 = snconv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU(inplace=True)
        self.snconv2d2 = snconv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)
        self.downsample = nn.AvgPool2d(2)
        self.snconv2d0 = snconv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        x0 = x

        x = self.snconv2d1(x)
        x = self.relu(x)
        x = self.snconv2d2(x)
        x = self.downsample(x)

        x0 = self.downsample(x0)
        x0 = self.snconv2d0(x0)

        out = x + x0
        return out


class DiscBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(DiscBlock, self).__init__()
        self.relu = nn.ReLU(inplace=True)
        self.snconv2d1 = snconv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)
        self.snconv2d2 = snconv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)
        self.downsample = nn.AvgPool2d(2)
        self.ch_mismatch = False
        if in_channels != out_channels:
            self.ch_mismatch = True
        self.snconv2d0 = snconv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x, downsample=True):
        x0 = x

        x = self.relu(x)
        x = self.snconv2d1(x)
        x = self.relu(x)
        x = self.snconv2d2(x)
        if downsample:
            x = self.downsample(x)

        if downsample or self.ch_mismatch:
            x0 = self.snconv2d0(x0)
            if downsample:
                x0 = self.downsample(x0)

        out = x + x0
        return out


class Discriminator(nn.Module):
    """Discriminator."""

    def __init__(self, d_conv_dim, num_classes):
        super(Discriminator, self).__init__()
        self.d_conv_dim = d_conv_dim
        self.opt_block1 = DiscOptBlock(3, d_conv_dim)
        self.block1 = DiscBlock(d_conv_dim, d_conv_dim*2)
        self.self_attn = Self_Attn(d_conv_dim*2)
        self.block2 = DiscBlock(d_conv_dim*2, d_conv_dim*4)
        self.block3 = DiscBlock(d_conv_dim*4, d_conv_dim*8)
        self.block4 = DiscBlock(d_conv_dim*8, d_conv_dim*16)
        self.block5 = DiscBlock(d_conv_dim*16, d_conv_dim*16)
        self.relu = nn.ReLU(inplace=True)
        self.snlinear1 = snlinear(in_features=d_conv_dim*16, out_features=1)
        self.sn_embedding1 = sn_embedding(num_classes, d_conv_dim*16)

        # Weight init
        self.apply(init_weights)
        xavier_uniform_(self.sn_embedding1.weight)

    def forward(self, x, labels):
        # n x 3 x 128 x 128
        h0 = self.opt_block1(x) # n x d_conv_dim   x 64 x 64
        h1 = self.block1(h0)    # n x d_conv_dim*2 x 32 x 32
        h1 = self.self_attn(h1) # n x d_conv_dim*2 x 32 x 32
        h2 = self.block2(h1)    # n x d_conv_dim*4 x 16 x 16
        h3 = self.block3(h2)    # n x d_conv_dim*8 x  8 x  8
        h4 = self.block4(h3)    # n x d_conv_dim*16 x 4 x  4
        h5 = self.block5(h4, downsample=False)  # n x d_conv_dim*16 x 4 x 4
        h5 = self.relu(h5)              # n x d_conv_dim*16 x 4 x 4
        h6 = torch.sum(h5, dim=[2,3])   # n x d_conv_dim*16
        output1 = torch.squeeze(self.snlinear1(h6)) # n
        # Projection
        h_labels = self.sn_embedding1(labels)   # n x d_conv_dim*16
        proj = torch.mul(h6, h_labels)          # n x d_conv_dim*16
        output2 = torch.sum(proj, dim=[1])      # n
        # Out
        output = output1 + output2              # n
        return output


================================================
FILE: test.py
================================================
import sys

import utils

from parameters import *
from sagan_models import Generator, Discriminator


if __name__ == '__main__':
    config = get_parameters()
    config.command = 'python ' + ' '.join(sys.argv)
    print(config)
    utils.check_for_CUDA(config)

    # Load pretrained model (if provided)
    if config.pretrained_model != '':
        utils.load_pretrained_model(config)
    else:
        assert config.num_of_classes, "Please provide number of classes! Eg. python3 test.py --num_of_classes 10"
        config.G = Generator(config.z_dim, config.g_conv_dim, config.num_of_classes).to(config.device)
        config.D = Discriminator(config.d_conv_dim, config.num_of_classes).to(config.device)

    config.G.eval()
    config.D.eval()
    print(config.G, config.D)


================================================
FILE: train.py
================================================
import sys

import utils

from parameters import *
from trainer import Trainer


if __name__ == '__main__':
    config = get_parameters()
    config.command = 'python ' + ' '.join(sys.argv)
    print(config)
    trainer = Trainer(config)
    trainer.train()
    utils.save_ckpt(trainer, final=True)


================================================
FILE: trainer.py
================================================
import datetime
import numpy as np
import os
import random
import sys
import time
import torch
import torch.nn as nn
import torchvision.utils as vutils

from torch.backends import cudnn

import utils
from sagan_models import Generator, Discriminator


class Trainer(object):

    def __init__(self, config):

        # Config
        self.config = config

        self.start = 0 # Unless using pre-trained model

        # Create directories if not exist
        utils.make_folder(self.config.save_path)
        utils.make_folder(self.config.model_weights_path)
        utils.make_folder(self.config.sample_images_path)

        # Copy files
        utils.write_config_to_file(self.config, self.config.save_path)
        utils.copy_scripts(self.config.save_path)

        # Check for CUDA
        utils.check_for_CUDA(self)

        # Make dataloader
        self.dataloader, self.num_of_classes = utils.make_dataloader(batch_size=self.config.batch_size_in_gpu,
                                                                     dataset_type=self.config.dataset,
                                                                     data_path=self.config.data_path,
                                                                     shuffle=self.config.shuffle,
                                                                     drop_last=self.config.drop_last,
                                                                     dataloader_args=self.config.dataloader_args,
                                                                     resize=self.config.resize,
                                                                     imsize=self.config.imsize,
                                                                     centercrop=self.config.centercrop,
                                                                     centercrop_size=self.config.centercrop_size,
                                                                     normalize=self.config.normalize,
                                                                     )

        # Data iterator
        self.data_iter = iter(self.dataloader)

        # Build G and D
        self.build_models()

        if self.config.adv_loss == 'dcgan':
            self.criterion = nn.BCELoss()

    def train(self):

        # Seed
        np.random.seed(self.config.manual_seed)
        random.seed(self.config.manual_seed)
        torch.manual_seed(self.config.manual_seed)

        # For fast training
        cudnn.benchmark = True

        # For BatchNorm
        self.G.train()
        self.D.train()

        # Fixed noise for sampling from G
        fixed_noise = torch.randn(self.config.batch_size_in_gpu, self.config.z_dim, device=self.device)
        if self.num_of_classes < self.config.batch_size_in_gpu:
            fixed_labels = torch.from_numpy(np.tile(np.arange(self.num_of_classes), self.config.batch_size_in_gpu//self.num_of_classes + 1)[:self.config.batch_size_in_gpu]).to(self.device)
        else:
            fixed_labels = torch.from_numpy(np.arange(self.config.batch_size_in_gpu)).to(self.device)

        # For gan loss
        label = torch.full((self.config.batch_size_in_gpu,), 1, device=self.device)
        ones = torch.full((self.config.batch_size_in_gpu,), 1, device=self.device)

        # Losses file
        log_file_name = os.path.join(self.config.save_path, 'log.txt')
        log_file = open(log_file_name, "wt")

        # Init
        start_time = time.time()
        G_losses = []
        D_losses_real = []
        D_losses_fake = []
        D_losses = []
        D_xs = []
        D_Gz_trainDs = []
        D_Gz_trainGs = []

        # Instance noise - make random noise mean (0) and std for injecting
        inst_noise_mean = torch.full((self.config.batch_size_in_gpu, 3, self.config.imsize, self.config.imsize), 0, device=self.device)
        inst_noise_std = torch.full((self.config.batch_size_in_gpu, 3, self.config.imsize, self.config.imsize), self.config.inst_noise_sigma, device=self.device)

        self.gpu_batches = self.config.batch_size//self.config.batch_size_in_gpu

        # Start training
        for self.step in range(self.start, self.config.total_step):

            # Instance noise std is linearly annealed from self.inst_noise_sigma to 0 thru self.inst_noise_sigma_iters
            inst_noise_sigma_curr = 0 if self.step > self.config.inst_noise_sigma_iters else (1 - self.step/self.config.inst_noise_sigma_iters)*self.config.inst_noise_sigma
            inst_noise_std.fill_(inst_noise_sigma_curr)

            # ================== TRAIN D ================== #

            for _ in range(self.config.d_steps_per_iter):

                # Zero grad
                self.reset_grad()

                # Accumulate losses for full batch_size
                # while running GPU computations on only batch_size_in_gpu
                for gpu_batch in range(self.gpu_batches):

                    # TRAIN with REAL

                    # Get real images & real labels
                    real_images, real_labels = self.get_real_samples()

                    # Get D output for real images & real labels
                    inst_noise = torch.normal(mean=inst_noise_mean, std=inst_noise_std).to(self.device)
                    d_out_real = self.D(real_images + inst_noise, real_labels)

                    # Compute D loss with real images & real labels
                    if self.config.adv_loss == 'hinge':
                        d_loss_real = torch.nn.ReLU()(ones - d_out_real).mean()
                    elif self.config.adv_loss == 'wgan_gp':
                        d_loss_real = -d_out_real.mean()
                    else:
                        label.fill_(1)
                        d_loss_real = self.criterion(d_out_real, label)

                    # Backward
                    d_loss_real /= self.gpu_batches
                    d_loss_real.backward()

                    # Delete loss, output
                    if self.step % self.config.log_step != 0 or gpu_batch < self.gpu_batches - 1:
                        del d_out_real, d_loss_real

                    # TRAIN with FAKE

                    # Create random noise
                    z = torch.randn(self.config.batch_size_in_gpu, self.config.z_dim, device=self.device)

                    # Generate fake images for same real labels
                    fake_images = self.G(z, real_labels)

                    # Get D output for fake images & same real labels
                    inst_noise = torch.normal(mean=inst_noise_mean, std=inst_noise_std).to(self.device)
                    d_out_fake = self.D(fake_images.detach() + inst_noise, real_labels)

                    # Compute D loss with fake images & real labels
                    if self.config.adv_loss == 'hinge':
                        d_loss_fake = torch.nn.ReLU()(ones + d_out_fake).mean()
                    elif self.config.adv_loss == 'dcgan':
                        label.fill_(0)
                        d_loss_fake = self.criterion(d_out_fake, label)
                    else:
                        d_loss_fake = d_out_fake.mean()

                    # If WGAN_GP, compute GP and add to D loss
                    if self.config.adv_loss == 'wgan_gp':
                        d_loss_gp = self.config.lambda_gp * self.compute_gradient_penalty(real_images, real_labels, fake_images.detach())
                        d_loss_fake += d_loss_gp

                    # Backward
                    d_loss_fake /= self.gpu_batches
                    d_loss_fake.backward()

                    # Delete loss, output
                    del fake_images
                    if self.step % self.config.log_step != 0 or gpu_batch < self.gpu_batches - 1:
                        del d_out_fake, d_loss_fake

                # Optimize
                self.D_optimizer.step()

            # ================== TRAIN G ================== #

            for _ in range(self.config.g_steps_per_iter):

                # Zero grad
                self.reset_grad()

                # Accumulate losses for full batch_size
                # while running GPU computations on only batch_size_in_gpu
                for gpu_batch in range(self.gpu_batches):

                    # Get real images & real labels (only need real labels)
                    real_images, real_labels = self.get_real_samples()

                    # Create random noise
                    z = torch.randn(self.config.batch_size_in_gpu, self.config.z_dim).to(self.device)

                    # Generate fake images for same real labels
                    fake_images = self.G(z, real_labels)

                    # Get D output for fake images & same real labels
                    inst_noise = torch.normal(mean=inst_noise_mean, std=inst_noise_std).to(self.device)
                    g_out_fake = self.D(fake_images + inst_noise, real_labels)

                    # Compute G loss with fake images & real labels
                    if self.config.adv_loss == 'dcgan':
                        label.fill_(1)
                        g_loss = self.criterion(g_out_fake, label)
                    else:
                        g_loss = -g_out_fake.mean()

                    # Backward
                    g_loss /= self.gpu_batches
                    g_loss.backward()

                    # Delete loss, output
                    del fake_images
                    if self.step % self.config.log_step != 0 or gpu_batch < self.gpu_batches - 1:
                        del g_out_fake, g_loss

                # Optimize
                self.G_optimizer.step()

            # Print out log info
            if self.step % self.config.log_step == 0:
                G_losses.append(g_loss.mean().item())
                D_losses_real.append(d_loss_real.mean().item())
                D_losses_fake.append(d_loss_fake.mean().item())
                D_loss = D_losses_real[-1] + D_losses_fake[-1]
                if self.config.adv_loss == 'wgan_gp':
                    D_loss += d_loss_gp.mean().item()
                D_losses.append(D_loss)
                D_xs.append(d_out_real.mean().item())
                D_Gz_trainDs.append(d_out_fake.mean().item())
                D_Gz_trainGs.append(g_out_fake.mean().item())
                curr_time = time.time()
                curr_time_str = datetime.datetime.fromtimestamp(curr_time).strftime('%Y-%m-%d %H:%M:%S')
                elapsed = str(datetime.timedelta(seconds=(curr_time - start_time)))
                log = ("[{}] : Elapsed [{}], Iter [{} / {}], G_loss: {:.4f}, D_loss: {:.4f}, D_loss_real: {:.4f}, D_loss_fake: {:.4f}, D(x): {:.4f}, D(G(z))_trainD: {:.4f}, D(G(z))_trainG: {:.4f}\n".
                       format(curr_time_str, elapsed, self.step, self.config.total_step,
                              G_losses[-1], D_losses[-1], D_losses_real[-1], D_losses_fake[-1],
                              D_xs[-1], D_Gz_trainDs[-1], D_Gz_trainGs[-1]))
                print('\n' + log)
                log_file.write(log)
                log_file.flush()
                utils.make_plots(G_losses, D_losses, D_losses_real, D_losses_fake, D_xs, D_Gz_trainDs, D_Gz_trainGs,
                                 self.config.log_step, self.config.save_path)

                # Delete loss, output
                del d_out_real, d_loss_real, d_out_fake, d_loss_fake, g_out_fake, g_loss

            # Sample images
            if self.step % self.config.sample_step == 0:
                print("Saving image samples..")
                self.G.eval()
                fake_images = self.G(fixed_noise, fixed_labels)
                self.G.train()
                sample_images = utils.denorm(fake_images.detach()[:self.config.save_n_images])
                # Save batch images
                vutils.save_image(sample_images, os.path.join(self.config.sample_images_path, 'fake_{:05d}.png'.format(self.step)), nrow=self.config.nrow)
                # Save gif
                utils.make_gif(sample_images[0].cpu().numpy().transpose(1, 2, 0)*255, self.step,
                               self.config.sample_images_path, self.config.name, max_frames_per_gif=self.config.max_frames_per_gif)
                # Delete output
                del fake_images

            # Save model
            if self.step % self.config.model_save_step == 0:
                utils.save_ckpt(self)

    def build_models(self):
        self.G = Generator(self.config.z_dim, self.config.g_conv_dim, self.num_of_classes).to(self.device)
        self.D = Discriminator(self.config.d_conv_dim, self.num_of_classes).to(self.device)

        # Loss and optimizer
        # self.G_optimizer = torch.optim.Adam(self.G.parameters(), self.g_lr, [self.beta1, self.beta2])
        self.G_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, self.G.parameters()), self.config.g_lr, [self.config.beta1, self.config.beta2])
        self.D_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, self.D.parameters()), self.config.d_lr, [self.config.beta1, self.config.beta2])

        # Start with pretrained model (if it exists)
        if self.config.pretrained_model != '':
            utils.load_pretrained_model(self)

        if 'cuda' in self.device.type and self.config.parallel and torch.cuda.device_count() > 1:
            self.G = nn.DataParallel(self.G)
            self.D = nn.DataParallel(self.D)

        # print networks
        print(self.G)
        print(self.D)

    def reset_grad(self):
        self.G_optimizer.zero_grad()
        self.D_optimizer.zero_grad()

    def get_real_samples(self):
        try:
            real_images, real_labels = next(self.data_iter)
        except:
            self.data_iter = iter(self.dataloader)
            real_images, real_labels = next(self.data_iter)

        real_images, real_labels = real_images.to(self.device), real_labels.to(self.device)
        return real_images, real_labels

    def compute_gradient_penalty(self, real_images, real_labels, fake_images):
        # Compute gradient penalty
        alpha = torch.rand(real_images.size(0), 1, 1, 1).expand_as(real_images).to(device)
        interpolated = torch.tensor(alpha * real_images + (1 - alpha) * fake_images, requires_grad=True)
        out = self.D(interpolated, real_labels)
        exp_grad = torch.ones(out.size()).to(device)
        grad = torch.autograd.grad(outputs=out,
                                   inputs=interpolated,
                                   grad_outputs=exp_grad,
                                   retain_graph=True,
                                   create_graph=True,
                                   only_inputs=True)[0]
        grad = grad.view(grad.size(0), -1)
        grad_l2norm = torch.sqrt(torch.sum(grad ** 2, dim=1))
        d_loss_gp = torch.mean((grad_l2norm - 1) ** 2)
        return d_loss_gp


================================================
FILE: utils.py
================================================
import cv2
import glob
import imageio
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import numpy as np
import os
import shutil
import torch
import torchvision.datasets as dset

from torchvision import transforms


def make_folder(path):
    if not os.path.exists(path):
        os.makedirs(path)


def denorm(x):
    out = (x + 1) / 2
    return out.clamp_(0, 1)


def write_config_to_file(config, save_path):
    with open(os.path.join(save_path, 'config.txt'), 'w') as file:
        for arg in vars(config):
            file.write(str(arg) + ': ' + str(getattr(config, arg)) + '\n')


def copy_scripts(dst):
    for file in glob.glob('*.py'):
        shutil.copy(file, dst)

    for d in glob.glob('*/'):
        if '__' not in d and d[0] != '.':
            shutil.copytree(d, os.path.join(dst, d))


def make_transform(resize=True, imsize=128, centercrop=False, centercrop_size=128,
                   totensor=True, normalize=True, norm_mean=(0.5, 0.5, 0.5), norm_std=(0.5, 0.5, 0.5)):
        options = []
        if resize:
            options.append(transforms.Resize((imsize)))
        if centercrop:
            options.append(transforms.CenterCrop(centercrop_size))
        if totensor:
            options.append(transforms.ToTensor())
        if normalize:
            options.append(transforms.Normalize(norm_mean, norm_std))
        transform = transforms.Compose(options)
        return transform


def make_dataloader(batch_size, dataset_type, data_path, shuffle=True, drop_last=True, dataloader_args={},
                    resize=True, imsize=128, centercrop=False, centercrop_size=128, totensor=True,
                    normalize=True, norm_mean=(0.5, 0.5, 0.5), norm_std=(0.5, 0.5, 0.5)):
    # Make transform
    transform = make_transform(resize=resize, imsize=imsize,
                               centercrop=centercrop, centercrop_size=centercrop_size,
                               totensor=totensor, normalize=normalize, norm_mean=norm_mean, norm_std=norm_std)
    # Make dataset
    if dataset_type in ['folder', 'imagenet', 'lfw']:
        # folder dataset
        assert os.path.exists(data_path), "data_path does not exist! Given: " + data_path
        dataset = dset.ImageFolder(root=data_path, transform=transform)
    elif dataset_type == 'lsun':
        assert os.path.exists(data_path), "data_path does not exist! Given: " + data_path
        dataset = dset.LSUN(root=data_path, classes=['bedroom_train'], transform=transform)
    elif dataset_type == 'cifar10':
        if not os.path.exists(data_path):
            print("data_path does not exist! Given: {}\nDownloading CIFAR10 dataset...".format(data_path))
        dataset = dset.CIFAR10(root=data_path, download=True, transform=transform)
    elif dataset_type == 'fake':
        dataset = dset.FakeData(image_size=(3, centercrop_size, centercrop_size), transform=transforms.ToTensor())
    assert dataset
    num_of_classes = len(dataset.classes)
    print("Data found!  # of images =", len(dataset), ", # of classes =", num_of_classes, ", classes:", dataset.classes)
    # Make dataloader from dataset
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, **dataloader_args)
    return dataloader, num_of_classes


def make_gif(image, iteration_number, save_path, model_name, max_frames_per_gif=100):

    # Make gif
    gif_frames = []

    # Read old gif frames
    try:
        gif_frames_reader = imageio.get_reader(os.path.join(save_path, model_name + ".gif"))
        for frame in gif_frames_reader:
            gif_frames.append(frame[:, :, :3])
    except:
        pass

    # Append new frame
    im = cv2.putText(np.concatenate((np.zeros((32, image.shape[1], image.shape[2])), image), axis=0),
                     'iter %s' % str(iteration_number), (10, 20), cv2.FONT_HERSHEY_SIMPLEX, .5, (255, 255, 255), 1, cv2.LINE_AA).astype('uint8')
    gif_frames.append(im)

    # If frames exceeds, save as different file
    if len(gif_frames) > max_frames_per_gif:
        print("Splitting the GIF...")
        gif_frames_00 = gif_frames[:max_frames_per_gif]
        num_of_gifs_already_saved = len(glob.glob(os.path.join(save_path, model_name + "_*.gif")))
        print("Saving", os.path.join(save_path, model_name + "_%05d.gif" % (num_of_gifs_already_saved)))
        imageio.mimsave(os.path.join(save_path, model_name + "_%05d.gif" % (num_of_gifs_already_saved)), gif_frames_00)
        gif_frames = gif_frames[max_frames_per_gif:]

    # Save gif
    # print("Saving", os.path.join(save_path, model_name + ".gif"))
    imageio.mimsave(os.path.join(save_path, model_name + ".gif"), gif_frames)


def make_plots(G_losses, D_losses, D_losses_real, D_losses_fake, D_xs, D_Gz_trainDs, D_Gz_trainGs, log_step, save_path, init_epoch=0):
    iters = np.arange(len(D_losses))*log_step + init_epoch
    fig = plt.figure(figsize=(20, 20))
    plt.subplot(311)
    plt.plot(iters, np.zeros(iters.shape), 'k--', alpha=0.5)
    plt.plot(iters, G_losses, color='C0', label='G')
    plt.legend()
    plt.title("Generator loss")
    plt.xlabel("Iterations")
    plt.subplot(312)
    plt.plot(iters, np.zeros(iters.shape), 'k--', alpha=0.5)
    plt.plot(iters, D_losses_real, color='C1', alpha=0.7, label='D_real')
    plt.plot(iters, D_losses_fake, color='C2', alpha=0.7, label='D_fake')
    plt.plot(iters, D_losses, color='C0', alpha=0.7, label='D')
    plt.legend()
    plt.title("Discriminator loss")
    plt.xlabel("Iterations")
    plt.subplot(313)
    plt.plot(iters, np.zeros(iters.shape), 'k--', alpha=0.5)
    plt.plot(iters, np.ones(iters.shape), 'k--', alpha=0.5)
    plt.plot(iters, D_xs, alpha=0.7, label='D(x)')
    plt.plot(iters, D_Gz_trainDs, alpha=0.7, label='D(G(z))_trainD')
    plt.plot(iters, D_Gz_trainGs, alpha=0.7, label='D(G(z))_trainG')
    plt.legend()
    plt.title("D(x), D(G(z))")
    plt.xlabel("Iterations")
    plt.savefig(os.path.join(save_path, "plots.png"))
    plt.clf()
    plt.close()


def save_ckpt(sagan_obj, model=False, final=False):
    print("Saving ckpt...")

    if final:
        # Save final - both model and state_dict
        torch.save({
                    'step': sagan_obj.step,
                    'G_state_dict': sagan_obj.G.module.state_dict() if hasattr(sagan_obj.G, "module") else sagan_obj.G.state_dict(),    # "module" in case DataParallel is used
                    'G_optimizer_state_dict': sagan_obj.G_optimizer.state_dict(),
                    'D_state_dict': sagan_obj.D.module.state_dict() if hasattr(sagan_obj.D, "module") else sagan_obj.D.state_dict(),    # "module" in case DataParallel is used,
                    'D_optimizer_state_dict': sagan_obj.D_optimizer.state_dict(),
                    }, os.path.join(sagan_obj.config.model_weights_path, '{}_final_state_dict_ckpt_{:07d}.pth'.format(sagan_obj.config.name, sagan_obj.step)))
        torch.save({
                    'step': sagan_obj.step,
                    'G': sagan_obj.G.module if hasattr(sagan_obj.G, "module") else sagan_obj.G,
                    'G_optimizer': sagan_obj.G_optimizer,
                    'D': sagan_obj.D.module if hasattr(sagan_obj.D, "module") else sagan_obj.D,
                    'D_optimizer': sagan_obj.D_optimizer,
                    }, os.path.join(sagan_obj.config.model_weights_path, '{}_final_model_ckpt_{:07d}.pth'.format(sagan_obj.config.name, sagan_obj.step)))

    elif model:
        # Save full model (not state_dict)
        torch.save({
                    'step': sagan_obj.step,
                    'G': sagan_obj.G.module if hasattr(sagan_obj.G, "module") else sagan_obj.G,     # "module" in case DataParallel is used
                    'G_optimizer': sagan_obj.G_optimizer,
                    'D': sagan_obj.D.module if hasattr(sagan_obj.D, "module") else sagan_obj.D,     # "module" in case DataParallel is used
                    'D_optimizer': sagan_obj.D_optimizer,
                    }, os.path.join(sagan_obj.config.model_weights_path, '{}_model_ckpt_{:07d}.pth'.format(sagan_obj.config.name, sagan_obj.step)))

    else:
        # Save state_dict
        torch.save({
                    'step': sagan_obj.step,
                    'G_state_dict': sagan_obj.G.module.state_dict() if hasattr(sagan_obj.G, "module") else sagan_obj.G.state_dict(),
                    'G_optimizer_state_dict': sagan_obj.G_optimizer.state_dict(),
                    'D_state_dict': sagan_obj.D.module.state_dict() if hasattr(sagan_obj.D, "module") else sagan_obj.D.state_dict(),
                    'D_optimizer_state_dict': sagan_obj.D_optimizer.state_dict(),
                    }, os.path.join(sagan_obj.config.model_weights_path, 'ckpt_{:07d}.pth'.format(sagan_obj.step)))


def load_pretrained_model(sagan_obj):
    print("Loading pretrained_model", sagan_obj.config.pretrained_model, "...")
    # Check for path
    assert os.path.exists(sagan_obj.config.pretrained_model), "Path of .pth pretrained_model doesn't exist! Given: " + sagan_obj.config.pretrained_model
    checkpoint = torch.load(sagan_obj.config.pretrained_model)
    # If we know it is a state_dict (instead of complete model)
    if sagan_obj.config.state_dict_or_model == 'state_dict':
        sagan_obj.start = checkpoint['step'] + 1
        sagan_obj.G.load_state_dict(checkpoint['G_state_dict'])
        sagan_obj.G_optimizer.load_state_dict(checkpoint['G_optimizer_state_dict'])
        sagan_obj.D.load_state_dict(checkpoint['D_state_dict'])
        sagan_obj.D_optimizer.load_state_dict(checkpoint['D_optimizer_state_dict'])
    # Else, if we know it is a complete model (and not just state_dict)
    elif sagan_obj.config.state_dict_or_model == 'model':
        sagan_obj.start = checkpoint['step'] + 1
        sagan_obj.G = torch.load(checkpoint['G']).to(sagan_obj.device)
        sagan_obj.G_optimizer = torch.load(checkpoint['G_optimizer'])
        sagan_obj.D = torch.load(checkpoint['D']).to(sagan_obj.device)
        sagan_obj.D_optimizer = torch.load(checkpoint['D_optimizer'])
    # Else try for complete model, then try for state_dict
    else:
        try:
            sagan_obj.start = checkpoint['step'] + 1
            sagan_obj.G.load_state_dict(checkpoint['G_state_dict'])
            sagan_obj.G_optimizer.load_state_dict(checkpoint['G_optimizer_state_dict'])
            sagan_obj.D.load_state_dict(checkpoint['D_state_dict'])
            sagan_obj.D_optimizer.load_state_dict(checkpoint['D_optimizer_state_dict'])
        except:
            sagan_obj.start = checkpoint['step'] + 1
            sagan_obj.G = torch.load(checkpoint['G']).to(sagan_obj.device)
            sagan_obj.G_optimizer = torch.load(checkpoint['G_optimizer'])
            sagan_obj.D = torch.load(checkpoint['D']).to(sagan_obj.device)
            sagan_obj.D_optimizer = torch.load(checkpoint['D_optimizer'])


def check_for_CUDA(sagan_obj):
    if not sagan_obj.config.disable_cuda and torch.cuda.is_available():
        print("CUDA is available!")
        sagan_obj.device = torch.device('cuda')
        sagan_obj.config.dataloader_args['pin_memory'] = True
    else:
        print("Cuda is NOT available, running on CPU.")
        sagan_obj.device = torch.device('cpu')

    if torch.cuda.is_available() and sagan_obj.config.disable_cuda:
        print("WARNING: You have a CUDA device, so you should probably run without --disable_cuda")
Download .txt
gitextract_cs9fl47f/

├── .gitignore
├── LICENSE
├── README.md
├── parameters.py
├── requirements.txt
├── sagan_models.py
├── test.py
├── train.py
├── trainer.py
└── utils.py
Download .txt
SYMBOL INDEX (44 symbols across 4 files)

FILE: parameters.py
  function get_parameters (line 6) | def get_parameters():

FILE: sagan_models.py
  function init_weights (line 10) | def init_weights(m):
  function snconv2d (line 16) | def snconv2d(in_channels, out_channels, kernel_size, stride=1, padding=0...
  function snlinear (line 21) | def snlinear(in_features, out_features):
  function sn_embedding (line 25) | def sn_embedding(num_embeddings, embedding_dim):
  class Self_Attn (line 29) | class Self_Attn(nn.Module):
    method __init__ (line 32) | def __init__(self, in_channels):
    method forward (line 43) | def forward(self, x):
  class ConditionalBatchNorm2d (line 75) | class ConditionalBatchNorm2d(nn.Module):
    method __init__ (line 77) | def __init__(self, num_features, num_classes):
    method forward (line 86) | def forward(self, x, y):
  class GenBlock (line 93) | class GenBlock(nn.Module):
    method __init__ (line 94) | def __init__(self, in_channels, out_channels, num_classes):
    method forward (line 103) | def forward(self, x, labels):
  class Generator (line 121) | class Generator(nn.Module):
    method __init__ (line 124) | def __init__(self, z_dim, g_conv_dim, num_classes):
    method forward (line 144) | def forward(self, z, labels):
  class DiscOptBlock (line 161) | class DiscOptBlock(nn.Module):
    method __init__ (line 162) | def __init__(self, in_channels, out_channels):
    method forward (line 170) | def forward(self, x):
  class DiscBlock (line 185) | class DiscBlock(nn.Module):
    method __init__ (line 186) | def __init__(self, in_channels, out_channels):
    method forward (line 197) | def forward(self, x, downsample=True):
  class Discriminator (line 216) | class Discriminator(nn.Module):
    method __init__ (line 219) | def __init__(self, d_conv_dim, num_classes):
    method forward (line 237) | def forward(self, x, labels):

FILE: trainer.py
  class Trainer (line 17) | class Trainer(object):
    method __init__ (line 19) | def __init__(self, config):
    method train (line 61) | def train(self):
    method build_models (line 278) | def build_models(self):
    method reset_grad (line 299) | def reset_grad(self):
    method get_real_samples (line 303) | def get_real_samples(self):
    method compute_gradient_penalty (line 313) | def compute_gradient_penalty(self, real_images, real_labels, fake_imag...

FILE: utils.py
  function make_folder (line 16) | def make_folder(path):
  function denorm (line 21) | def denorm(x):
  function write_config_to_file (line 26) | def write_config_to_file(config, save_path):
  function copy_scripts (line 32) | def copy_scripts(dst):
  function make_transform (line 41) | def make_transform(resize=True, imsize=128, centercrop=False, centercrop...
  function make_dataloader (line 56) | def make_dataloader(batch_size, dataset_type, data_path, shuffle=True, d...
  function make_gif (line 85) | def make_gif(image, iteration_number, save_path, model_name, max_frames_...
  function make_plots (line 117) | def make_plots(G_losses, D_losses, D_losses_real, D_losses_fake, D_xs, D...
  function save_ckpt (line 148) | def save_ckpt(sagan_obj, model=False, final=False):
  function load_pretrained_model (line 189) | def load_pretrained_model(sagan_obj):
  function check_for_CUDA (line 224) | def check_for_CUDA(sagan_obj):
Condensed preview — 10 files, each showing path, character count, and a content snippet. Download the .json file or copy for the full structured content (50K chars).
[
  {
    "path": ".gitignore",
    "chars": 1203,
    "preview": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packagi"
  },
  {
    "path": "LICENSE",
    "chars": 1070,
    "preview": "MIT License\n\nCopyright (c) 2019 Vikram Voleti\n\nPermission is hereby granted, free of charge, to any person obtaining a c"
  },
  {
    "path": "README.md",
    "chars": 1967,
    "preview": "# self-attention-GAN-pytorch\n\nThis is an almost exact replica in PyTorch of the Tensorflow version of [SAGAN](https://ar"
  },
  {
    "path": "parameters.py",
    "chars": 5870,
    "preview": "import argparse\nimport datetime\nimport os\n\n\ndef get_parameters():\n\n    parser = argparse.ArgumentParser()\n\n    # Images "
  },
  {
    "path": "requirements.txt",
    "chars": 103,
    "preview": "matplotlib==3.0.0\ntorchvision==0.2.1\ntorch==2.2.0\nopencv_python==4.2.0.32\nimageio==2.4.1\nnumpy==1.22.0\n"
  },
  {
    "path": "sagan_models.py",
    "chars": 10566,
    "preview": "import numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom torch.nn.utils import spectr"
  },
  {
    "path": "test.py",
    "chars": 779,
    "preview": "import sys\n\nimport utils\n\nfrom parameters import *\nfrom sagan_models import Generator, Discriminator\n\n\nif __name__ == '_"
  },
  {
    "path": "train.py",
    "chars": 299,
    "preview": "import sys\n\nimport utils\n\nfrom parameters import *\nfrom trainer import Trainer\n\n\nif __name__ == '__main__':\n    config ="
  },
  {
    "path": "trainer.py",
    "chars": 14882,
    "preview": "import datetime\nimport numpy as np\nimport os\nimport random\nimport sys\nimport time\nimport torch\nimport torch.nn as nn\nimp"
  },
  {
    "path": "utils.py",
    "chars": 11368,
    "preview": "import cv2\nimport glob\nimport imageio\nimport matplotlib\nmatplotlib.use('Agg')\nimport matplotlib.pyplot as plt\nimport num"
  }
]

About this extraction

This page contains the full source code of the voletiv/self-attention-GAN-pytorch GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 10 files (47.0 KB), approximately 12.0k tokens, and a symbol index with 44 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.

Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.

Copied to clipboard!