Repository: voletiv/self-attention-GAN-pytorch
Branch: master
Commit: c44d3a02996a
Files: 10
Total size: 47.0 KB
Directory structure:
gitextract_cs9fl47f/
├── .gitignore
├── LICENSE
├── README.md
├── parameters.py
├── requirements.txt
├── sagan_models.py
├── test.py
├── train.py
├── trainer.py
└── utils.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# pyenv
.python-version
# celery beat schedule file
celerybeat-schedule
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
================================================
FILE: LICENSE
================================================
MIT License
Copyright (c) 2019 Vikram Voleti
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
================================================
FILE: README.md
================================================
# self-attention-GAN-pytorch
This is a near-exact PyTorch replica of the TensorFlow version of [SAGAN](https://arxiv.org/abs/1805.08318) released by Google Brain [[repo](https://github.com/brain-research/self-attention-gan)] in August 2018.
The code structure is inspired by [this repo](https://github.com/heykeetae/Self-Attention-GAN), but follows the details of [Google Brain's repo](https://github.com/brain-research/self-attention-gan).
## Prerequisites
Check `requirements.txt`.
* [Python 3.5+](https://www.continuum.io/downloads)
* [PyTorch 0.4.1](http://pytorch.org/)
## Training
#### 1. Check `parameters.py` for all arguments and their default values
#### 2. Train on custom images in folder a/b/c:
```bash
$ python train.py --data_path 'a/b/c' --save_path 'o/p/q' --batch_size 64 --name sagan
```
(Warning: the models work only on *128x128* images; input images are resized to that size. To use a different image size, first tweak the Generator & Discriminator accordingly, and then pass the `imsize` option:
```bash
$ python train.py --data_path 'a/b/c' --save_path 'o/p/q' --batch_size 64 --imsize 32 --name sagan
```
)
Model training will be recorded in a new folder inside `--save_path` with the name `<timestamp>_<name>_<basename of data_path>`.
By default, model weights are saved in a subfolder called `weights`, and image samples generated during training in a subfolder called `samples` (both can be changed in `parameters.py`).
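To resume training from a saved checkpoint, the `--pretrained_model` and `--state_dict_or_model` options from `parameters.py` can be combined as in the sketch below (the checkpoint path is hypothetical, following the `ckpt_<step>.pth` naming used by `utils.save_ckpt`):
```bash
$ python train.py --data_path 'a/b/c' --save_path 'o/p/q' --batch_size 64 --name sagan --pretrained_model 'o/p/q/<run>/weights/ckpt_0010000.pth' --state_dict_or_model 'state_dict'
```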
## Testing/Evaluating
Check `test.py`.
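As a minimal sketch of sampling from a trained generator, assuming a final 'model' checkpoint saved by `utils.save_ckpt` (the checkpoint path is hypothetical; `z_dim=128` is the default from `parameters.py`, and 10 classes are assumed for illustration):
```python
import torch
import torchvision.utils as vutils

# A final 'model' checkpoint stores the full Generator object under the key 'G'
ckpt = torch.load('o/p/q/<run>/weights/<run>_final_model_ckpt_0200000.pth', map_location='cpu')
G = ckpt['G'].eval()

z = torch.randn(10, 128)                     # default --z_dim is 128
labels = torch.arange(10, dtype=torch.long)  # one sample per class 0..9
with torch.no_grad():
    fake = G(z, labels)                      # 10 x 3 x 128 x 128, values in [-1, 1] (tanh)
vutils.save_image((fake + 1) / 2, 'samples.png', nrow=10)  # denormalize to [0, 1]
```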
## Self-Attention GAN
**[Han Zhang, Ian Goodfellow, Dimitris Metaxas and Augustus Odena, "Self-Attention Generative Adversarial Networks." arXiv preprint arXiv:1805.08318 (2018)](https://arxiv.org/abs/1805.08318).**
```
@article{Zhang2018SelfAttentionGA,
title={Self-Attention Generative Adversarial Networks},
author={Han Zhang and Ian J. Goodfellow and Dimitris N. Metaxas and Augustus Odena},
journal={CoRR},
year={2018},
volume={abs/1805.08318}
}
```
================================================
FILE: parameters.py
================================================
import argparse
import datetime
import os
def get_parameters():
parser = argparse.ArgumentParser()
# Images data path & Output path
parser.add_argument('--dataset', type=str, default='folder', choices=["cifar10", "fake", "folder", "hdf5", "imagenet", "lfw", "lsun"],
help="cifar10 | fake | folder | hdf5 | imagenet | lfw | lsun")
parser.add_argument('--data_path', type=str, default='', help='Path to root of image data (saved in dirs of classes)')
parser.add_argument('--save_path', type=str, default='./sagan_models')
# Training settings
parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument('--batch_size_in_gpu', type=int, default=0,
                        help='0 => same as batch_size; else gradients are accumulated over multiple GPU iterations to make an effective batch, e.g. batch_size=32, batch_size_in_gpu=16 => optimizer.step() is run after 2 loss.backward() calls')
parser.add_argument('--total_step', type=int, default=200000, help='how many iterations')
parser.add_argument('--d_steps_per_iter', type=int, default=1, help='how many D updates per iteration')
parser.add_argument('--g_steps_per_iter', type=int, default=1, help='how many G updates per iteration')
parser.add_argument('--d_lr', type=float, default=0.0004)
parser.add_argument('--g_lr', type=float, default=0.0001)
parser.add_argument('--beta1', type=float, default=0.0)
parser.add_argument('--beta2', type=float, default=0.999)
# Model hyper-parameters
parser.add_argument('--adv_loss', type=str, default='hinge', choices=['hinge', 'dcgan', 'wgan_gp', 'gan'])
    parser.add_argument('--z_dim', type=int, default=128)
    parser.add_argument('--num_of_classes', type=int, default=0, help='Number of classes; required by test.py when building models without a pretrained checkpoint (train.py infers it from the dataset)')
parser.add_argument('--g_conv_dim', type=int, default=64)
parser.add_argument('--d_conv_dim', type=int, default=64)
parser.add_argument('--lambda_gp', type=float, default=10)
# Instance noise
# https://github.com/soumith/ganhacks/issues/14#issuecomment-312509518
# https://www.inference.vc/instance-noise-a-trick-for-stabilising-gan-training/
parser.add_argument('--inst_noise_sigma', type=float, default=0.0)
parser.add_argument('--inst_noise_sigma_iters', type=int, default=2000)
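    # (the trainer anneals the noise std linearly from inst_noise_sigma to 0 over inst_noise_sigma_iters iterations)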
# Image transforms
parser.add_argument('--dont_shuffle', action='store_true')
parser.add_argument('--dont_drop_last', action='store_true', help="Whether not to drop the last batch in dataset if its size < batch_size")
parser.add_argument('--dont_resize', action='store_true', help="Whether not to resize images")
parser.add_argument('--imsize', type=int, default=128)
parser.add_argument('--centercrop', action='store_true', help="Whether to center crop images")
parser.add_argument('--centercrop_size', type=int, default=128)
    parser.add_argument('--dont_normalize', action='store_true', help="Whether not to normalize image values")
# Step sizes
parser.add_argument('--log_step', type=int, default=10)
parser.add_argument('--sample_step', type=int, default=10)
    parser.add_argument('--model_save_step', type=int, default=50)
parser.add_argument('--save_n_images', type=int, default=0,
help='0 => same as batch_size_in_gpu')
parser.add_argument('--nrow', type=int, default=10)
parser.add_argument('--max_frames_per_gif', type=int, default=100)
# Pretrained model
parser.add_argument('--pretrained_model', type=str, default='')
parser.add_argument('--state_dict_or_model', type=str, default='', help="Specify whether .pth pretrained_model is a 'state_dict' or a complete 'model'")
# Misc
parser.add_argument('--manual_seed', type=int, default=29)
parser.add_argument('--disable_cuda', action='store_true', help='Disable CUDA')
parser.add_argument('--parallel', action='store_true', help="Run on multiple GPUs")
parser.add_argument('--num_workers', type=int, default=4)
# parser.add_argument('--use_tensorboard', action='store_true')
# Output paths
parser.add_argument('--model_weights_dir', type=str, default='weights')
parser.add_argument('--sample_images_dir', type=str, default='samples')
# Model name
parser.add_argument('--name', type=str, default='sagan')
args = parser.parse_args()
if args.batch_size_in_gpu == 0:
args.batch_size_in_gpu = args.batch_size
assert args.batch_size_in_gpu <= args.batch_size, "ERROR: please make sure batch_size >= batch_size_in_gpu!! Given batch_size: " + str(args.batch_size) + " ; batch_size_in_gpu: " + str(args.batch_size_in_gpu)
assert args.batch_size % args.batch_size_in_gpu == 0, "ERROR: please make sure batch_size_in_gpu divides batch_size!! Given batch_size: " + str(args.batch_size) + " ; batch_size_in_gpu: " + str(args.batch_size_in_gpu)
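    # e.g. batch_size=32, batch_size_in_gpu=16 => gradients accumulate over 2 sub-batches per optimizer step, so batch_size_effective=32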
args.batch_size_effective = args.batch_size_in_gpu*(args.batch_size//args.batch_size_in_gpu)
print("Effective BATCH SIZE:", args.batch_size_effective)
if args.save_n_images == 0:
args.save_n_images = args.batch_size_in_gpu
    assert args.save_n_images <= args.batch_size_in_gpu, "ERROR: please make sure save_n_images <= batch_size_in_gpu!! Given save_n_images: " + str(args.save_n_images) + " ; batch_size_in_gpu: " + str(args.batch_size_in_gpu)
# Corrections
args.shuffle = not args.dont_shuffle
args.drop_last = not args.dont_drop_last
args.resize = not args.dont_resize
args.normalize = not args.dont_normalize
args.dataloader_args = {'num_workers':args.num_workers}
args.name = '{0:%Y%m%d_%H%M%S}_{1}_{2}'.format(datetime.datetime.now(), args.name, os.path.basename(args.data_path))
args.save_path = os.path.join(args.save_path, args.name)
args.model_weights_path = os.path.join(args.save_path, args.model_weights_dir)
args.sample_images_path = os.path.join(args.save_path, args.sample_images_dir)
return args
================================================
FILE: requirements.txt
================================================
matplotlib==3.0.0
torchvision==0.2.1
torch==0.4.1
opencv_python==4.2.0.32
imageio==2.4.1
numpy==1.22.0
================================================
FILE: sagan_models.py
================================================
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import spectral_norm
from torch.nn.init import xavier_uniform_
def init_weights(m):
if type(m) == nn.Linear or type(m) == nn.Conv2d:
xavier_uniform_(m.weight)
m.bias.data.fill_(0.)
def snconv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
return spectral_norm(nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias))
def snlinear(in_features, out_features):
return spectral_norm(nn.Linear(in_features=in_features, out_features=out_features))
def sn_embedding(num_embeddings, embedding_dim):
return spectral_norm(nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim))
class Self_Attn(nn.Module):
""" Self attention Layer"""
def __init__(self, in_channels):
super(Self_Attn, self).__init__()
self.in_channels = in_channels
self.snconv1x1_theta = snconv2d(in_channels=in_channels, out_channels=in_channels//8, kernel_size=1, stride=1, padding=0)
self.snconv1x1_phi = snconv2d(in_channels=in_channels, out_channels=in_channels//8, kernel_size=1, stride=1, padding=0)
self.snconv1x1_g = snconv2d(in_channels=in_channels, out_channels=in_channels//2, kernel_size=1, stride=1, padding=0)
self.snconv1x1_attn = snconv2d(in_channels=in_channels//2, out_channels=in_channels, kernel_size=1, stride=1, padding=0)
self.maxpool = nn.MaxPool2d(2, stride=2, padding=0)
self.softmax = nn.Softmax(dim=-1)
self.sigma = nn.Parameter(torch.zeros(1))
def forward(self, x):
"""
inputs :
x : input feature maps(B X C X W X H)
returns :
out : self attention value + input feature
attention: B X N X N (N is Width*Height)
"""
        _, ch, h, w = x.size()
        # Theta path
        theta = self.snconv1x1_theta(x)
        theta = theta.view(-1, ch//8, h*w)              # B x C/8 x N, where N = H*W
        # Phi path (maxpooled, so it has N/4 key locations)
        phi = self.snconv1x1_phi(x)
        phi = self.maxpool(phi)
        phi = phi.view(-1, ch//8, h*w//4)               # B x C/8 x N/4
        # Attn map
        attn = torch.bmm(theta.permute(0, 2, 1), phi)   # B x N x N/4
        attn = self.softmax(attn)                       # softmax over the N/4 key locations
        # g path (maxpooled like phi)
        g = self.snconv1x1_g(x)
        g = self.maxpool(g)
        g = g.view(-1, ch//2, h*w//4)                   # B x C/2 x N/4
        # Attn_g
        attn_g = torch.bmm(g, attn.permute(0, 2, 1))    # B x C/2 x N
        attn_g = attn_g.view(-1, ch//2, h, w)           # B x C/2 x H x W
        attn_g = self.snconv1x1_attn(attn_g)            # B x C x H x W
        # Out: residual connection scaled by the learnable sigma (initialized to 0)
        out = x + self.sigma*attn_g
return out
class ConditionalBatchNorm2d(nn.Module):
# https://github.com/pytorch/pytorch/issues/8985#issuecomment-405080775
def __init__(self, num_features, num_classes):
super().__init__()
self.num_features = num_features
self.bn = nn.BatchNorm2d(num_features, momentum=0.001, affine=False)
self.embed = nn.Embedding(num_classes, num_features * 2)
# self.embed.weight.data[:, :num_features].normal_(1, 0.02) # Initialise scale at N(1, 0.02)
self.embed.weight.data[:, :num_features].fill_(1.) # Initialize scale to 1
self.embed.weight.data[:, num_features:].zero_() # Initialize bias at 0
def forward(self, x, y):
out = self.bn(x)
gamma, beta = self.embed(y).chunk(2, 1)
out = gamma.view(-1, self.num_features, 1, 1) * out + beta.view(-1, self.num_features, 1, 1)
return out
class GenBlock(nn.Module):
def __init__(self, in_channels, out_channels, num_classes):
super(GenBlock, self).__init__()
self.cond_bn1 = ConditionalBatchNorm2d(in_channels, num_classes)
self.relu = nn.ReLU(inplace=True)
self.snconv2d1 = snconv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)
self.cond_bn2 = ConditionalBatchNorm2d(out_channels, num_classes)
self.snconv2d2 = snconv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)
self.snconv2d0 = snconv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0)
def forward(self, x, labels):
x0 = x
x = self.cond_bn1(x, labels)
x = self.relu(x)
x = F.interpolate(x, scale_factor=2, mode='nearest') # upsample
x = self.snconv2d1(x)
x = self.cond_bn2(x, labels)
x = self.relu(x)
x = self.snconv2d2(x)
x0 = F.interpolate(x0, scale_factor=2, mode='nearest') # upsample
x0 = self.snconv2d0(x0)
out = x + x0
return out
class Generator(nn.Module):
"""Generator."""
def __init__(self, z_dim, g_conv_dim, num_classes):
super(Generator, self).__init__()
self.z_dim = z_dim
self.g_conv_dim = g_conv_dim
self.snlinear0 = snlinear(in_features=z_dim, out_features=g_conv_dim*16*4*4)
self.block1 = GenBlock(g_conv_dim*16, g_conv_dim*16, num_classes)
self.block2 = GenBlock(g_conv_dim*16, g_conv_dim*8, num_classes)
self.block3 = GenBlock(g_conv_dim*8, g_conv_dim*4, num_classes)
self.self_attn = Self_Attn(g_conv_dim*4)
self.block4 = GenBlock(g_conv_dim*4, g_conv_dim*2, num_classes)
self.block5 = GenBlock(g_conv_dim*2, g_conv_dim, num_classes)
self.bn = nn.BatchNorm2d(g_conv_dim, eps=1e-5, momentum=0.0001, affine=True)
self.relu = nn.ReLU(inplace=True)
self.snconv2d1 = snconv2d(in_channels=g_conv_dim, out_channels=3, kernel_size=3, stride=1, padding=1)
self.tanh = nn.Tanh()
# Weight init
self.apply(init_weights)
def forward(self, z, labels):
# n x z_dim
act0 = self.snlinear0(z) # n x g_conv_dim*16*4*4
act0 = act0.view(-1, self.g_conv_dim*16, 4, 4) # n x g_conv_dim*16 x 4 x 4
act1 = self.block1(act0, labels) # n x g_conv_dim*16 x 8 x 8
act2 = self.block2(act1, labels) # n x g_conv_dim*8 x 16 x 16
act3 = self.block3(act2, labels) # n x g_conv_dim*4 x 32 x 32
act3 = self.self_attn(act3) # n x g_conv_dim*4 x 32 x 32
act4 = self.block4(act3, labels) # n x g_conv_dim*2 x 64 x 64
act5 = self.block5(act4, labels) # n x g_conv_dim x 128 x 128
act5 = self.bn(act5) # n x g_conv_dim x 128 x 128
act5 = self.relu(act5) # n x g_conv_dim x 128 x 128
act6 = self.snconv2d1(act5) # n x 3 x 128 x 128
act6 = self.tanh(act6) # n x 3 x 128 x 128
return act6
class DiscOptBlock(nn.Module):
def __init__(self, in_channels, out_channels):
super(DiscOptBlock, self).__init__()
self.snconv2d1 = snconv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)
self.relu = nn.ReLU(inplace=True)
self.snconv2d2 = snconv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)
self.downsample = nn.AvgPool2d(2)
self.snconv2d0 = snconv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0)
def forward(self, x):
x0 = x
x = self.snconv2d1(x)
x = self.relu(x)
x = self.snconv2d2(x)
x = self.downsample(x)
x0 = self.downsample(x0)
x0 = self.snconv2d0(x0)
out = x + x0
return out
class DiscBlock(nn.Module):
def __init__(self, in_channels, out_channels):
super(DiscBlock, self).__init__()
self.relu = nn.ReLU(inplace=True)
self.snconv2d1 = snconv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)
self.snconv2d2 = snconv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)
self.downsample = nn.AvgPool2d(2)
self.ch_mismatch = False
if in_channels != out_channels:
self.ch_mismatch = True
self.snconv2d0 = snconv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0)
def forward(self, x, downsample=True):
x0 = x
x = self.relu(x)
x = self.snconv2d1(x)
x = self.relu(x)
x = self.snconv2d2(x)
if downsample:
x = self.downsample(x)
if downsample or self.ch_mismatch:
x0 = self.snconv2d0(x0)
if downsample:
x0 = self.downsample(x0)
out = x + x0
return out
class Discriminator(nn.Module):
"""Discriminator."""
def __init__(self, d_conv_dim, num_classes):
super(Discriminator, self).__init__()
self.d_conv_dim = d_conv_dim
self.opt_block1 = DiscOptBlock(3, d_conv_dim)
self.block1 = DiscBlock(d_conv_dim, d_conv_dim*2)
self.self_attn = Self_Attn(d_conv_dim*2)
self.block2 = DiscBlock(d_conv_dim*2, d_conv_dim*4)
self.block3 = DiscBlock(d_conv_dim*4, d_conv_dim*8)
self.block4 = DiscBlock(d_conv_dim*8, d_conv_dim*16)
self.block5 = DiscBlock(d_conv_dim*16, d_conv_dim*16)
self.relu = nn.ReLU(inplace=True)
self.snlinear1 = snlinear(in_features=d_conv_dim*16, out_features=1)
self.sn_embedding1 = sn_embedding(num_classes, d_conv_dim*16)
# Weight init
self.apply(init_weights)
xavier_uniform_(self.sn_embedding1.weight)
def forward(self, x, labels):
# n x 3 x 128 x 128
h0 = self.opt_block1(x) # n x d_conv_dim x 64 x 64
h1 = self.block1(h0) # n x d_conv_dim*2 x 32 x 32
h1 = self.self_attn(h1) # n x d_conv_dim*2 x 32 x 32
h2 = self.block2(h1) # n x d_conv_dim*4 x 16 x 16
h3 = self.block3(h2) # n x d_conv_dim*8 x 8 x 8
h4 = self.block4(h3) # n x d_conv_dim*16 x 4 x 4
h5 = self.block5(h4, downsample=False) # n x d_conv_dim*16 x 4 x 4
h5 = self.relu(h5) # n x d_conv_dim*16 x 4 x 4
h6 = torch.sum(h5, dim=[2,3]) # n x d_conv_dim*16
output1 = torch.squeeze(self.snlinear1(h6)) # n
# Projection
h_labels = self.sn_embedding1(labels) # n x d_conv_dim*16
proj = torch.mul(h6, h_labels) # n x d_conv_dim*16
output2 = torch.sum(proj, dim=[1]) # n
# Out
output = output1 + output2 # n
return output
================================================
FILE: test.py
================================================
import sys
import utils
from parameters import *
from sagan_models import Generator, Discriminator
if __name__ == '__main__':
config = get_parameters()
config.command = 'python ' + ' '.join(sys.argv)
print(config)
utils.check_for_CUDA(config)
    # Build G & D first if no pretrained model is given, or if only a state_dict will be loaded
    if config.pretrained_model == '' or config.state_dict_or_model == 'state_dict':
        assert config.num_of_classes, "Please provide number of classes! Eg. python3 test.py --num_of_classes 10"
        config.G = Generator(config.z_dim, config.g_conv_dim, config.num_of_classes).to(config.device)
        config.D = Discriminator(config.d_conv_dim, config.num_of_classes).to(config.device)
    # Load pretrained model (if provided)
    if config.pretrained_model != '':
        utils.load_pretrained_model(config)
    config.G.eval()
    config.D.eval()
print(config.G, config.D)
================================================
FILE: train.py
================================================
import sys
import utils
from parameters import *
from trainer import Trainer
if __name__ == '__main__':
config = get_parameters()
config.command = 'python ' + ' '.join(sys.argv)
print(config)
trainer = Trainer(config)
trainer.train()
utils.save_ckpt(trainer, final=True)
================================================
FILE: trainer.py
================================================
import datetime
import numpy as np
import os
import random
import sys
import time
import torch
import torch.nn as nn
import torchvision.utils as vutils
from torch.backends import cudnn
import utils
from sagan_models import Generator, Discriminator
class Trainer(object):
def __init__(self, config):
# Config
self.config = config
self.start = 0 # Unless using pre-trained model
# Create directories if not exist
utils.make_folder(self.config.save_path)
utils.make_folder(self.config.model_weights_path)
utils.make_folder(self.config.sample_images_path)
# Copy files
utils.write_config_to_file(self.config, self.config.save_path)
utils.copy_scripts(self.config.save_path)
# Check for CUDA
utils.check_for_CUDA(self)
# Make dataloader
self.dataloader, self.num_of_classes = utils.make_dataloader(batch_size=self.config.batch_size_in_gpu,
dataset_type=self.config.dataset,
data_path=self.config.data_path,
shuffle=self.config.shuffle,
drop_last=self.config.drop_last,
dataloader_args=self.config.dataloader_args,
resize=self.config.resize,
imsize=self.config.imsize,
centercrop=self.config.centercrop,
centercrop_size=self.config.centercrop_size,
normalize=self.config.normalize,
)
# Data iterator
self.data_iter = iter(self.dataloader)
# Build G and D
self.build_models()
        if self.config.adv_loss == 'dcgan':
            # D outputs raw logits (no sigmoid), so use the numerically stable logits version of BCE
            self.criterion = nn.BCEWithLogitsLoss()
def train(self):
# Seed
np.random.seed(self.config.manual_seed)
random.seed(self.config.manual_seed)
torch.manual_seed(self.config.manual_seed)
# For fast training
cudnn.benchmark = True
# For BatchNorm
self.G.train()
self.D.train()
# Fixed noise for sampling from G
fixed_noise = torch.randn(self.config.batch_size_in_gpu, self.config.z_dim, device=self.device)
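        # Fixed labels: tile 0..num_of_classes-1 across the batch so the saved samples cover every class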
if self.num_of_classes < self.config.batch_size_in_gpu:
fixed_labels = torch.from_numpy(np.tile(np.arange(self.num_of_classes), self.config.batch_size_in_gpu//self.num_of_classes + 1)[:self.config.batch_size_in_gpu]).to(self.device)
else:
fixed_labels = torch.from_numpy(np.arange(self.config.batch_size_in_gpu)).to(self.device)
# For gan loss
label = torch.full((self.config.batch_size_in_gpu,), 1, device=self.device)
ones = torch.full((self.config.batch_size_in_gpu,), 1, device=self.device)
# Losses file
log_file_name = os.path.join(self.config.save_path, 'log.txt')
log_file = open(log_file_name, "wt")
# Init
start_time = time.time()
G_losses = []
D_losses_real = []
D_losses_fake = []
D_losses = []
D_xs = []
D_Gz_trainDs = []
D_Gz_trainGs = []
# Instance noise - make random noise mean (0) and std for injecting
inst_noise_mean = torch.full((self.config.batch_size_in_gpu, 3, self.config.imsize, self.config.imsize), 0, device=self.device)
inst_noise_std = torch.full((self.config.batch_size_in_gpu, 3, self.config.imsize, self.config.imsize), self.config.inst_noise_sigma, device=self.device)
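        # Number of GPU-sized sub-batches accumulated into one effective batch per optimizer step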
self.gpu_batches = self.config.batch_size//self.config.batch_size_in_gpu
# Start training
for self.step in range(self.start, self.config.total_step):
            # Instance noise std is linearly annealed from config.inst_noise_sigma to 0 over config.inst_noise_sigma_iters iterations
            inst_noise_sigma_curr = 0 if self.step > self.config.inst_noise_sigma_iters else (1 - self.step/self.config.inst_noise_sigma_iters)*self.config.inst_noise_sigma
inst_noise_std.fill_(inst_noise_sigma_curr)
# ================== TRAIN D ================== #
for _ in range(self.config.d_steps_per_iter):
# Zero grad
self.reset_grad()
# Accumulate losses for full batch_size
# while running GPU computations on only batch_size_in_gpu
for gpu_batch in range(self.gpu_batches):
# TRAIN with REAL
# Get real images & real labels
real_images, real_labels = self.get_real_samples()
# Get D output for real images & real labels
inst_noise = torch.normal(mean=inst_noise_mean, std=inst_noise_std).to(self.device)
d_out_real = self.D(real_images + inst_noise, real_labels)
# Compute D loss with real images & real labels
if self.config.adv_loss == 'hinge':
d_loss_real = torch.nn.ReLU()(ones - d_out_real).mean()
elif self.config.adv_loss == 'wgan_gp':
d_loss_real = -d_out_real.mean()
else:
label.fill_(1)
d_loss_real = self.criterion(d_out_real, label)
# Backward
d_loss_real /= self.gpu_batches
d_loss_real.backward()
# Delete loss, output
if self.step % self.config.log_step != 0 or gpu_batch < self.gpu_batches - 1:
del d_out_real, d_loss_real
# TRAIN with FAKE
# Create random noise
z = torch.randn(self.config.batch_size_in_gpu, self.config.z_dim, device=self.device)
# Generate fake images for same real labels
fake_images = self.G(z, real_labels)
# Get D output for fake images & same real labels
inst_noise = torch.normal(mean=inst_noise_mean, std=inst_noise_std).to(self.device)
d_out_fake = self.D(fake_images.detach() + inst_noise, real_labels)
# Compute D loss with fake images & real labels
if self.config.adv_loss == 'hinge':
d_loss_fake = torch.nn.ReLU()(ones + d_out_fake).mean()
elif self.config.adv_loss == 'dcgan':
label.fill_(0)
d_loss_fake = self.criterion(d_out_fake, label)
else:
d_loss_fake = d_out_fake.mean()
# If WGAN_GP, compute GP and add to D loss
if self.config.adv_loss == 'wgan_gp':
d_loss_gp = self.config.lambda_gp * self.compute_gradient_penalty(real_images, real_labels, fake_images.detach())
d_loss_fake += d_loss_gp
# Backward
d_loss_fake /= self.gpu_batches
d_loss_fake.backward()
# Delete loss, output
del fake_images
if self.step % self.config.log_step != 0 or gpu_batch < self.gpu_batches - 1:
del d_out_fake, d_loss_fake
# Optimize
self.D_optimizer.step()
# ================== TRAIN G ================== #
for _ in range(self.config.g_steps_per_iter):
# Zero grad
self.reset_grad()
# Accumulate losses for full batch_size
# while running GPU computations on only batch_size_in_gpu
for gpu_batch in range(self.gpu_batches):
# Get real images & real labels (only need real labels)
real_images, real_labels = self.get_real_samples()
# Create random noise
z = torch.randn(self.config.batch_size_in_gpu, self.config.z_dim).to(self.device)
# Generate fake images for same real labels
fake_images = self.G(z, real_labels)
# Get D output for fake images & same real labels
inst_noise = torch.normal(mean=inst_noise_mean, std=inst_noise_std).to(self.device)
g_out_fake = self.D(fake_images + inst_noise, real_labels)
# Compute G loss with fake images & real labels
if self.config.adv_loss == 'dcgan':
label.fill_(1)
g_loss = self.criterion(g_out_fake, label)
else:
g_loss = -g_out_fake.mean()
# Backward
g_loss /= self.gpu_batches
g_loss.backward()
# Delete loss, output
del fake_images
if self.step % self.config.log_step != 0 or gpu_batch < self.gpu_batches - 1:
del g_out_fake, g_loss
# Optimize
self.G_optimizer.step()
# Print out log info
if self.step % self.config.log_step == 0:
G_losses.append(g_loss.mean().item())
D_losses_real.append(d_loss_real.mean().item())
D_losses_fake.append(d_loss_fake.mean().item())
D_loss = D_losses_real[-1] + D_losses_fake[-1]
if self.config.adv_loss == 'wgan_gp':
D_loss += d_loss_gp.mean().item()
D_losses.append(D_loss)
D_xs.append(d_out_real.mean().item())
D_Gz_trainDs.append(d_out_fake.mean().item())
D_Gz_trainGs.append(g_out_fake.mean().item())
curr_time = time.time()
curr_time_str = datetime.datetime.fromtimestamp(curr_time).strftime('%Y-%m-%d %H:%M:%S')
elapsed = str(datetime.timedelta(seconds=(curr_time - start_time)))
log = ("[{}] : Elapsed [{}], Iter [{} / {}], G_loss: {:.4f}, D_loss: {:.4f}, D_loss_real: {:.4f}, D_loss_fake: {:.4f}, D(x): {:.4f}, D(G(z))_trainD: {:.4f}, D(G(z))_trainG: {:.4f}\n".
format(curr_time_str, elapsed, self.step, self.config.total_step,
G_losses[-1], D_losses[-1], D_losses_real[-1], D_losses_fake[-1],
D_xs[-1], D_Gz_trainDs[-1], D_Gz_trainGs[-1]))
print('\n' + log)
log_file.write(log)
log_file.flush()
utils.make_plots(G_losses, D_losses, D_losses_real, D_losses_fake, D_xs, D_Gz_trainDs, D_Gz_trainGs,
self.config.log_step, self.config.save_path)
# Delete loss, output
del d_out_real, d_loss_real, d_out_fake, d_loss_fake, g_out_fake, g_loss
# Sample images
if self.step % self.config.sample_step == 0:
print("Saving image samples..")
                self.G.eval()
                with torch.no_grad():
                    fake_images = self.G(fixed_noise, fixed_labels)
                self.G.train()
sample_images = utils.denorm(fake_images.detach()[:self.config.save_n_images])
# Save batch images
vutils.save_image(sample_images, os.path.join(self.config.sample_images_path, 'fake_{:05d}.png'.format(self.step)), nrow=self.config.nrow)
# Save gif
utils.make_gif(sample_images[0].cpu().numpy().transpose(1, 2, 0)*255, self.step,
self.config.sample_images_path, self.config.name, max_frames_per_gif=self.config.max_frames_per_gif)
# Delete output
del fake_images
# Save model
if self.step % self.config.model_save_step == 0:
utils.save_ckpt(self)
def build_models(self):
self.G = Generator(self.config.z_dim, self.config.g_conv_dim, self.num_of_classes).to(self.device)
self.D = Discriminator(self.config.d_conv_dim, self.num_of_classes).to(self.device)
# Loss and optimizer
# self.G_optimizer = torch.optim.Adam(self.G.parameters(), self.g_lr, [self.beta1, self.beta2])
self.G_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, self.G.parameters()), self.config.g_lr, [self.config.beta1, self.config.beta2])
self.D_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, self.D.parameters()), self.config.d_lr, [self.config.beta1, self.config.beta2])
# Start with pretrained model (if it exists)
if self.config.pretrained_model != '':
utils.load_pretrained_model(self)
if 'cuda' in self.device.type and self.config.parallel and torch.cuda.device_count() > 1:
self.G = nn.DataParallel(self.G)
self.D = nn.DataParallel(self.D)
# print networks
print(self.G)
print(self.D)
def reset_grad(self):
self.G_optimizer.zero_grad()
self.D_optimizer.zero_grad()
def get_real_samples(self):
        try:
            real_images, real_labels = next(self.data_iter)
        except StopIteration:
            # Epoch ended; restart the data iterator
            self.data_iter = iter(self.dataloader)
            real_images, real_labels = next(self.data_iter)
real_images, real_labels = real_images.to(self.device), real_labels.to(self.device)
return real_images, real_labels
def compute_gradient_penalty(self, real_images, real_labels, fake_images):
# Compute gradient penalty
        alpha = torch.rand(real_images.size(0), 1, 1, 1).expand_as(real_images).to(self.device)
        interpolated = (alpha * real_images + (1 - alpha) * fake_images).requires_grad_(True)
        out = self.D(interpolated, real_labels)
        exp_grad = torch.ones(out.size()).to(self.device)
grad = torch.autograd.grad(outputs=out,
inputs=interpolated,
grad_outputs=exp_grad,
retain_graph=True,
create_graph=True,
only_inputs=True)[0]
grad = grad.view(grad.size(0), -1)
grad_l2norm = torch.sqrt(torch.sum(grad ** 2, dim=1))
d_loss_gp = torch.mean((grad_l2norm - 1) ** 2)
return d_loss_gp
================================================
FILE: utils.py
================================================
import cv2
import glob
import imageio
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import numpy as np
import os
import shutil
import torch
import torchvision.datasets as dset
from torchvision import transforms
def make_folder(path):
if not os.path.exists(path):
os.makedirs(path)
def denorm(x):
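    # Map images from [-1, 1] (generator tanh / Normalize(0.5, 0.5)) back to [0, 1]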
out = (x + 1) / 2
return out.clamp_(0, 1)
def write_config_to_file(config, save_path):
with open(os.path.join(save_path, 'config.txt'), 'w') as file:
for arg in vars(config):
file.write(str(arg) + ': ' + str(getattr(config, arg)) + '\n')
def copy_scripts(dst):
for file in glob.glob('*.py'):
shutil.copy(file, dst)
for d in glob.glob('*/'):
if '__' not in d and d[0] != '.':
shutil.copytree(d, os.path.join(dst, d))
def make_transform(resize=True, imsize=128, centercrop=False, centercrop_size=128,
totensor=True, normalize=True, norm_mean=(0.5, 0.5, 0.5), norm_std=(0.5, 0.5, 0.5)):
options = []
    if resize:
        options.append(transforms.Resize((imsize, imsize)))  # resize to a square imsize x imsize
if centercrop:
options.append(transforms.CenterCrop(centercrop_size))
if totensor:
options.append(transforms.ToTensor())
if normalize:
options.append(transforms.Normalize(norm_mean, norm_std))
transform = transforms.Compose(options)
return transform
def make_dataloader(batch_size, dataset_type, data_path, shuffle=True, drop_last=True, dataloader_args={},
resize=True, imsize=128, centercrop=False, centercrop_size=128, totensor=True,
normalize=True, norm_mean=(0.5, 0.5, 0.5), norm_std=(0.5, 0.5, 0.5)):
# Make transform
transform = make_transform(resize=resize, imsize=imsize,
centercrop=centercrop, centercrop_size=centercrop_size,
totensor=totensor, normalize=normalize, norm_mean=norm_mean, norm_std=norm_std)
# Make dataset
if dataset_type in ['folder', 'imagenet', 'lfw']:
# folder dataset
assert os.path.exists(data_path), "data_path does not exist! Given: " + data_path
dataset = dset.ImageFolder(root=data_path, transform=transform)
elif dataset_type == 'lsun':
assert os.path.exists(data_path), "data_path does not exist! Given: " + data_path
dataset = dset.LSUN(root=data_path, classes=['bedroom_train'], transform=transform)
elif dataset_type == 'cifar10':
if not os.path.exists(data_path):
print("data_path does not exist! Given: {}\nDownloading CIFAR10 dataset...".format(data_path))
dataset = dset.CIFAR10(root=data_path, download=True, transform=transform)
elif dataset_type == 'fake':
dataset = dset.FakeData(image_size=(3, centercrop_size, centercrop_size), transform=transforms.ToTensor())
assert dataset
num_of_classes = len(dataset.classes)
print("Data found! # of images =", len(dataset), ", # of classes =", num_of_classes, ", classes:", dataset.classes)
# Make dataloader from dataset
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, **dataloader_args)
return dataloader, num_of_classes
def make_gif(image, iteration_number, save_path, model_name, max_frames_per_gif=100):
# Make gif
gif_frames = []
# Read old gif frames
try:
gif_frames_reader = imageio.get_reader(os.path.join(save_path, model_name + ".gif"))
for frame in gif_frames_reader:
gif_frames.append(frame[:, :, :3])
    except (OSError, ValueError):
        # No existing gif to append to yet (first call), or unreadable file
        pass
# Append new frame
im = cv2.putText(np.concatenate((np.zeros((32, image.shape[1], image.shape[2])), image), axis=0),
'iter %s' % str(iteration_number), (10, 20), cv2.FONT_HERSHEY_SIMPLEX, .5, (255, 255, 255), 1, cv2.LINE_AA).astype('uint8')
gif_frames.append(im)
# If frames exceeds, save as different file
if len(gif_frames) > max_frames_per_gif:
print("Splitting the GIF...")
gif_frames_00 = gif_frames[:max_frames_per_gif]
num_of_gifs_already_saved = len(glob.glob(os.path.join(save_path, model_name + "_*.gif")))
print("Saving", os.path.join(save_path, model_name + "_%05d.gif" % (num_of_gifs_already_saved)))
imageio.mimsave(os.path.join(save_path, model_name + "_%05d.gif" % (num_of_gifs_already_saved)), gif_frames_00)
gif_frames = gif_frames[max_frames_per_gif:]
# Save gif
# print("Saving", os.path.join(save_path, model_name + ".gif"))
imageio.mimsave(os.path.join(save_path, model_name + ".gif"), gif_frames)
def make_plots(G_losses, D_losses, D_losses_real, D_losses_fake, D_xs, D_Gz_trainDs, D_Gz_trainGs, log_step, save_path, init_epoch=0):
iters = np.arange(len(D_losses))*log_step + init_epoch
fig = plt.figure(figsize=(20, 20))
plt.subplot(311)
plt.plot(iters, np.zeros(iters.shape), 'k--', alpha=0.5)
plt.plot(iters, G_losses, color='C0', label='G')
plt.legend()
plt.title("Generator loss")
plt.xlabel("Iterations")
plt.subplot(312)
plt.plot(iters, np.zeros(iters.shape), 'k--', alpha=0.5)
plt.plot(iters, D_losses_real, color='C1', alpha=0.7, label='D_real')
plt.plot(iters, D_losses_fake, color='C2', alpha=0.7, label='D_fake')
plt.plot(iters, D_losses, color='C0', alpha=0.7, label='D')
plt.legend()
plt.title("Discriminator loss")
plt.xlabel("Iterations")
plt.subplot(313)
plt.plot(iters, np.zeros(iters.shape), 'k--', alpha=0.5)
plt.plot(iters, np.ones(iters.shape), 'k--', alpha=0.5)
plt.plot(iters, D_xs, alpha=0.7, label='D(x)')
plt.plot(iters, D_Gz_trainDs, alpha=0.7, label='D(G(z))_trainD')
plt.plot(iters, D_Gz_trainGs, alpha=0.7, label='D(G(z))_trainG')
plt.legend()
plt.title("D(x), D(G(z))")
plt.xlabel("Iterations")
plt.savefig(os.path.join(save_path, "plots.png"))
plt.clf()
plt.close()
def save_ckpt(sagan_obj, model=False, final=False):
print("Saving ckpt...")
if final:
# Save final - both model and state_dict
torch.save({
'step': sagan_obj.step,
'G_state_dict': sagan_obj.G.module.state_dict() if hasattr(sagan_obj.G, "module") else sagan_obj.G.state_dict(), # "module" in case DataParallel is used
'G_optimizer_state_dict': sagan_obj.G_optimizer.state_dict(),
            'D_state_dict': sagan_obj.D.module.state_dict() if hasattr(sagan_obj.D, "module") else sagan_obj.D.state_dict(), # "module" in case DataParallel is used
'D_optimizer_state_dict': sagan_obj.D_optimizer.state_dict(),
}, os.path.join(sagan_obj.config.model_weights_path, '{}_final_state_dict_ckpt_{:07d}.pth'.format(sagan_obj.config.name, sagan_obj.step)))
torch.save({
'step': sagan_obj.step,
'G': sagan_obj.G.module if hasattr(sagan_obj.G, "module") else sagan_obj.G,
'G_optimizer': sagan_obj.G_optimizer,
'D': sagan_obj.D.module if hasattr(sagan_obj.D, "module") else sagan_obj.D,
'D_optimizer': sagan_obj.D_optimizer,
}, os.path.join(sagan_obj.config.model_weights_path, '{}_final_model_ckpt_{:07d}.pth'.format(sagan_obj.config.name, sagan_obj.step)))
elif model:
# Save full model (not state_dict)
torch.save({
'step': sagan_obj.step,
'G': sagan_obj.G.module if hasattr(sagan_obj.G, "module") else sagan_obj.G, # "module" in case DataParallel is used
'G_optimizer': sagan_obj.G_optimizer,
'D': sagan_obj.D.module if hasattr(sagan_obj.D, "module") else sagan_obj.D, # "module" in case DataParallel is used
'D_optimizer': sagan_obj.D_optimizer,
}, os.path.join(sagan_obj.config.model_weights_path, '{}_model_ckpt_{:07d}.pth'.format(sagan_obj.config.name, sagan_obj.step)))
else:
# Save state_dict
torch.save({
'step': sagan_obj.step,
'G_state_dict': sagan_obj.G.module.state_dict() if hasattr(sagan_obj.G, "module") else sagan_obj.G.state_dict(),
'G_optimizer_state_dict': sagan_obj.G_optimizer.state_dict(),
'D_state_dict': sagan_obj.D.module.state_dict() if hasattr(sagan_obj.D, "module") else sagan_obj.D.state_dict(),
'D_optimizer_state_dict': sagan_obj.D_optimizer.state_dict(),
}, os.path.join(sagan_obj.config.model_weights_path, 'ckpt_{:07d}.pth'.format(sagan_obj.step)))
def load_pretrained_model(sagan_obj):
    # Works with either a Trainer-like object (with a .config attribute) or a config Namespace directly (as in test.py)
    config = getattr(sagan_obj, 'config', sagan_obj)
    print("Loading pretrained_model", config.pretrained_model, "...")
    # Check for path
    assert os.path.exists(config.pretrained_model), "Path of .pth pretrained_model doesn't exist! Given: " + config.pretrained_model
    checkpoint = torch.load(config.pretrained_model, map_location=sagan_obj.device)
    def load_state_dicts():
        sagan_obj.start = checkpoint['step'] + 1
        sagan_obj.G.load_state_dict(checkpoint['G_state_dict'])
        sagan_obj.D.load_state_dict(checkpoint['D_state_dict'])
        # Optimizers exist while training, but not when testing
        if hasattr(sagan_obj, 'G_optimizer'):
            sagan_obj.G_optimizer.load_state_dict(checkpoint['G_optimizer_state_dict'])
        if hasattr(sagan_obj, 'D_optimizer'):
            sagan_obj.D_optimizer.load_state_dict(checkpoint['D_optimizer_state_dict'])
    def load_models():
        sagan_obj.start = checkpoint['step'] + 1
        # 'model' checkpoints store the full (DataParallel-unwrapped) model & optimizer objects, not paths
        sagan_obj.G = checkpoint['G'].to(sagan_obj.device)
        sagan_obj.G_optimizer = checkpoint['G_optimizer']
        sagan_obj.D = checkpoint['D'].to(sagan_obj.device)
        sagan_obj.D_optimizer = checkpoint['D_optimizer']
    # If we know it is a state_dict (instead of a complete model)
    if config.state_dict_or_model == 'state_dict':
        load_state_dicts()
    # Else, if we know it is a complete model (and not just a state_dict)
    elif config.state_dict_or_model == 'model':
        load_models()
    # Else try for state_dict first, then fall back to complete model
    else:
        try:
            load_state_dicts()
        except (KeyError, AttributeError):
            load_models()
def check_for_CUDA(sagan_obj):
    # Works with either a Trainer-like object (with a .config attribute) or a config Namespace directly (as in test.py)
    config = getattr(sagan_obj, 'config', sagan_obj)
    if not config.disable_cuda and torch.cuda.is_available():
        print("CUDA is available!")
        sagan_obj.device = torch.device('cuda')
        config.dataloader_args['pin_memory'] = True
    else:
        print("CUDA is NOT available, running on CPU.")
        sagan_obj.device = torch.device('cpu')
    if torch.cuda.is_available() and config.disable_cuda:
        print("WARNING: You have a CUDA device, so you should probably run without --disable_cuda")
================================================
SYMBOL INDEX (44 symbols across 4 files)
================================================
FILE: parameters.py
function get_parameters (line 6) | def get_parameters():
FILE: sagan_models.py
function init_weights (line 10) | def init_weights(m):
function snconv2d (line 16) | def snconv2d(in_channels, out_channels, kernel_size, stride=1, padding=0...
function snlinear (line 21) | def snlinear(in_features, out_features):
function sn_embedding (line 25) | def sn_embedding(num_embeddings, embedding_dim):
class Self_Attn (line 29) | class Self_Attn(nn.Module):
method __init__ (line 32) | def __init__(self, in_channels):
method forward (line 43) | def forward(self, x):
class ConditionalBatchNorm2d (line 75) | class ConditionalBatchNorm2d(nn.Module):
method __init__ (line 77) | def __init__(self, num_features, num_classes):
method forward (line 86) | def forward(self, x, y):
class GenBlock (line 93) | class GenBlock(nn.Module):
method __init__ (line 94) | def __init__(self, in_channels, out_channels, num_classes):
method forward (line 103) | def forward(self, x, labels):
class Generator (line 121) | class Generator(nn.Module):
method __init__ (line 124) | def __init__(self, z_dim, g_conv_dim, num_classes):
method forward (line 144) | def forward(self, z, labels):
class DiscOptBlock (line 161) | class DiscOptBlock(nn.Module):
method __init__ (line 162) | def __init__(self, in_channels, out_channels):
method forward (line 170) | def forward(self, x):
class DiscBlock (line 185) | class DiscBlock(nn.Module):
method __init__ (line 186) | def __init__(self, in_channels, out_channels):
method forward (line 197) | def forward(self, x, downsample=True):
class Discriminator (line 216) | class Discriminator(nn.Module):
method __init__ (line 219) | def __init__(self, d_conv_dim, num_classes):
method forward (line 237) | def forward(self, x, labels):
FILE: trainer.py
class Trainer (line 17) | class Trainer(object):
method __init__ (line 19) | def __init__(self, config):
method train (line 61) | def train(self):
method build_models (line 278) | def build_models(self):
method reset_grad (line 299) | def reset_grad(self):
method get_real_samples (line 303) | def get_real_samples(self):
method compute_gradient_penalty (line 313) | def compute_gradient_penalty(self, real_images, real_labels, fake_imag...
FILE: utils.py
function make_folder (line 16) | def make_folder(path):
function denorm (line 21) | def denorm(x):
function write_config_to_file (line 26) | def write_config_to_file(config, save_path):
function copy_scripts (line 32) | def copy_scripts(dst):
function make_transform (line 41) | def make_transform(resize=True, imsize=128, centercrop=False, centercrop...
function make_dataloader (line 56) | def make_dataloader(batch_size, dataset_type, data_path, shuffle=True, d...
function make_gif (line 85) | def make_gif(image, iteration_number, save_path, model_name, max_frames_...
function make_plots (line 117) | def make_plots(G_losses, D_losses, D_losses_real, D_losses_fake, D_xs, D...
function save_ckpt (line 148) | def save_ckpt(sagan_obj, model=False, final=False):
function load_pretrained_model (line 189) | def load_pretrained_model(sagan_obj):
function check_for_CUDA (line 224) | def check_for_CUDA(sagan_obj):