Repository: voletiv/self-attention-GAN-pytorch Branch: master Commit: c44d3a02996a Files: 10 Total size: 47.0 KB Directory structure: gitextract_cs9fl47f/ ├── .gitignore ├── LICENSE ├── README.md ├── parameters.py ├── requirements.txt ├── sagan_models.py ├── test.py ├── train.py ├── trainer.py └── utils.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .gitignore ================================================ # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] *$py.class # C extensions *.so # Distribution / packaging .Python build/ develop-eggs/ dist/ downloads/ eggs/ .eggs/ lib/ lib64/ parts/ sdist/ var/ wheels/ *.egg-info/ .installed.cfg *.egg MANIFEST # PyInstaller # Usually these files are written by a python script from a template # before PyInstaller builds the exe, so as to inject date/other infos into it. *.manifest *.spec # Installer logs pip-log.txt pip-delete-this-directory.txt # Unit test / coverage reports htmlcov/ .tox/ .coverage .coverage.* .cache nosetests.xml coverage.xml *.cover .hypothesis/ .pytest_cache/ # Translations *.mo *.pot # Django stuff: *.log local_settings.py db.sqlite3 # Flask stuff: instance/ .webassets-cache # Scrapy stuff: .scrapy # Sphinx documentation docs/_build/ # PyBuilder target/ # Jupyter Notebook .ipynb_checkpoints # pyenv .python-version # celery beat schedule file celerybeat-schedule # SageMath parsed files *.sage.py # Environments .env .venv env/ venv/ ENV/ env.bak/ venv.bak/ # Spyder project settings .spyderproject .spyproject # Rope project settings .ropeproject # mkdocs documentation /site # mypy .mypy_cache/ ================================================ FILE: LICENSE ================================================ MIT License Copyright (c) 2019 Vikram Voleti Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ================================================ FILE: README.md ================================================ # self-attention-GAN-pytorch This is an almost exact replica in PyTorch of the Tensorflow version of [SAGAN](https://arxiv.org/abs/1805.08318) released by Google Brain [[repo](https://github.com/brain-research/self-attention-gan)] in August 2018. Code structure is inspired from [this repo](https://github.com/heykeetae/Self-Attention-GAN), but follows the details of [Google Brain's repo](https://github.com/brain-research/self-attention-gan). ## Prerequisites Check `requirements.txt`. * [Python 3.5+](https://www.continuum.io/downloads) * [PyTorch 0.4.1](http://pytorch.org/) ## Training #### 1. Check `parameters.py` for all arguments and their default values #### 2. Train on custom images in folder a/b/c: ```bash $ python train.py --data_path 'a/b/c' --save_path 'o/p/q' --batch_size 64 --name sagan ``` (Warning: Works only on *128x128* images, input images are resized to that. Tweak the Generator & Discriminator first if you would like to use some other image size. And then use the `imsize` option: ```bash $ python train.py --data_path 'a/b/c' --save_path 'o/p/q' --batch_size 64 --imsize 32 --name sagan ``` ) Model training will be recorded in a new folder inside `--save_path` with the name `__`. By default, model weights are saved in a subfolder called `weights`, and train & validation samples during training in a subfolder called `samples` (can be changed in `parameters.py`). ## Testing/Evaluating Check `test.py`. ## Self-Attention GAN **[Han Zhang, Ian Goodfellow, Dimitris Metaxas and Augustus Odena, "Self-Attention Generative Adversarial Networks." arXiv preprint arXiv:1805.08318 (2018)](https://arxiv.org/abs/1805.08318).** ``` @article{Zhang2018SelfAttentionGA, title={Self-Attention Generative Adversarial Networks}, author={Han Zhang and Ian J. Goodfellow and Dimitris N. Metaxas and Augustus Odena}, journal={CoRR}, year={2018}, volume={abs/1805.08318} } ``` ================================================ FILE: parameters.py ================================================ import argparse import datetime import os def get_parameters(): parser = argparse.ArgumentParser() # Images data path & Output path parser.add_argument('--dataset', type=str, default='folder', choices=["cifar10", "fake", "folder", "hdf5", "imagenet", "lfw", "lsun"], help="cifar10 | fake | folder | hdf5 | imagenet | lfw | lsun") parser.add_argument('--data_path', type=str, default='', help='Path to root of image data (saved in dirs of classes)') parser.add_argument('--save_path', type=str, default='./sagan_models') # Training settings parser.add_argument('--batch_size', type=int, default=64) parser.add_argument('--batch_size_in_gpu', type=int, default=0, help='0 => same as batch_size, else: if using multiple gpu iterations to make an effective batch, e.g. batch_size=32, batch_size_in_gpu=16 => optimizer.step() is run 2 iterations after running loss.backward()') parser.add_argument('--total_step', type=int, default=200000, help='how many iterations') parser.add_argument('--d_steps_per_iter', type=int, default=1, help='how many D updates per iteration') parser.add_argument('--g_steps_per_iter', type=int, default=1, help='how many G updates per iteration') parser.add_argument('--d_lr', type=float, default=0.0004) parser.add_argument('--g_lr', type=float, default=0.0001) parser.add_argument('--beta1', type=float, default=0.0) parser.add_argument('--beta2', type=float, default=0.999) # Model hyper-parameters parser.add_argument('--adv_loss', type=str, default='hinge', choices=['hinge', 'dcgan', 'wgan_gp', 'gan']) parser.add_argument('--z_dim', type=int, default=128) parser.add_argument('--g_conv_dim', type=int, default=64) parser.add_argument('--d_conv_dim', type=int, default=64) parser.add_argument('--lambda_gp', type=float, default=10) # Instance noise # https://github.com/soumith/ganhacks/issues/14#issuecomment-312509518 # https://www.inference.vc/instance-noise-a-trick-for-stabilising-gan-training/ parser.add_argument('--inst_noise_sigma', type=float, default=0.0) parser.add_argument('--inst_noise_sigma_iters', type=int, default=2000) # Image transforms parser.add_argument('--dont_shuffle', action='store_true') parser.add_argument('--dont_drop_last', action='store_true', help="Whether not to drop the last batch in dataset if its size < batch_size") parser.add_argument('--dont_resize', action='store_true', help="Whether not to resize images") parser.add_argument('--imsize', type=int, default=128) parser.add_argument('--centercrop', action='store_true', help="Whether to center crop images") parser.add_argument('--centercrop_size', type=int, default=128) parser.add_argument('--dont_normalize', action='store_true', help="Whether to normalize image values") # Step sizes parser.add_argument('--log_step', type=int, default=10) parser.add_argument('--sample_step', type=int, default=10) parser.add_argument('--model_save_step', type=float, default=50) parser.add_argument('--save_n_images', type=int, default=0, help='0 => same as batch_size_in_gpu') parser.add_argument('--nrow', type=int, default=10) parser.add_argument('--max_frames_per_gif', type=int, default=100) # Pretrained model parser.add_argument('--pretrained_model', type=str, default='') parser.add_argument('--state_dict_or_model', type=str, default='', help="Specify whether .pth pretrained_model is a 'state_dict' or a complete 'model'") # Misc parser.add_argument('--manual_seed', type=int, default=29) parser.add_argument('--disable_cuda', action='store_true', help='Disable CUDA') parser.add_argument('--parallel', action='store_true', help="Run on multiple GPUs") parser.add_argument('--num_workers', type=int, default=4) # parser.add_argument('--use_tensorboard', action='store_true') # Output paths parser.add_argument('--model_weights_dir', type=str, default='weights') parser.add_argument('--sample_images_dir', type=str, default='samples') # Model name parser.add_argument('--name', type=str, default='sagan') args = parser.parse_args() if args.batch_size_in_gpu == 0: args.batch_size_in_gpu = args.batch_size assert args.batch_size_in_gpu <= args.batch_size, "ERROR: please make sure batch_size >= batch_size_in_gpu!! Given batch_size: " + str(args.batch_size) + " ; batch_size_in_gpu: " + str(args.batch_size_in_gpu) assert args.batch_size % args.batch_size_in_gpu == 0, "ERROR: please make sure batch_size_in_gpu divides batch_size!! Given batch_size: " + str(args.batch_size) + " ; batch_size_in_gpu: " + str(args.batch_size_in_gpu) args.batch_size_effective = args.batch_size_in_gpu*(args.batch_size//args.batch_size_in_gpu) print("Effective BATCH SIZE:", args.batch_size_effective) if args.save_n_images == 0: args.save_n_images = args.batch_size_in_gpu assert args.save_n_images <= args.batch_size_in_gpu, "ERROR: please make save_n_images <= batch_size_in_gpu!! Given save_n_images: " + str(args.save_n_images) + " ; batch_size_in_gpu: " + str(args.batch_size_in_gpu) # Corrections args.shuffle = not args.dont_shuffle args.drop_last = not args.dont_drop_last args.resize = not args.dont_resize args.normalize = not args.dont_normalize args.dataloader_args = {'num_workers':args.num_workers} args.name = '{0:%Y%m%d_%H%M%S}_{1}_{2}'.format(datetime.datetime.now(), args.name, os.path.basename(args.data_path)) args.save_path = os.path.join(args.save_path, args.name) args.model_weights_path = os.path.join(args.save_path, args.model_weights_dir) args.sample_images_path = os.path.join(args.save_path, args.sample_images_dir) return args ================================================ FILE: requirements.txt ================================================ matplotlib==3.0.0 torchvision==0.2.1 torch==2.2.0 opencv_python==4.2.0.32 imageio==2.4.1 numpy==1.22.0 ================================================ FILE: sagan_models.py ================================================ import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from torch.nn.utils import spectral_norm from torch.nn.init import xavier_uniform_ def init_weights(m): if type(m) == nn.Linear or type(m) == nn.Conv2d: xavier_uniform_(m.weight) m.bias.data.fill_(0.) def snconv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True): return spectral_norm(nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)) def snlinear(in_features, out_features): return spectral_norm(nn.Linear(in_features=in_features, out_features=out_features)) def sn_embedding(num_embeddings, embedding_dim): return spectral_norm(nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim)) class Self_Attn(nn.Module): """ Self attention Layer""" def __init__(self, in_channels): super(Self_Attn, self).__init__() self.in_channels = in_channels self.snconv1x1_theta = snconv2d(in_channels=in_channels, out_channels=in_channels//8, kernel_size=1, stride=1, padding=0) self.snconv1x1_phi = snconv2d(in_channels=in_channels, out_channels=in_channels//8, kernel_size=1, stride=1, padding=0) self.snconv1x1_g = snconv2d(in_channels=in_channels, out_channels=in_channels//2, kernel_size=1, stride=1, padding=0) self.snconv1x1_attn = snconv2d(in_channels=in_channels//2, out_channels=in_channels, kernel_size=1, stride=1, padding=0) self.maxpool = nn.MaxPool2d(2, stride=2, padding=0) self.softmax = nn.Softmax(dim=-1) self.sigma = nn.Parameter(torch.zeros(1)) def forward(self, x): """ inputs : x : input feature maps(B X C X W X H) returns : out : self attention value + input feature attention: B X N X N (N is Width*Height) """ _, ch, h, w = x.size() # Theta path theta = self.snconv1x1_theta(x) theta = theta.view(-1, ch//8, h*w) # Phi path phi = self.snconv1x1_phi(x) phi = self.maxpool(phi) phi = phi.view(-1, ch//8, h*w//4) # Attn map attn = torch.bmm(theta.permute(0, 2, 1), phi) attn = self.softmax(attn) # g path g = self.snconv1x1_g(x) g = self.maxpool(g) g = g.view(-1, ch//2, h*w//4) # Attn_g attn_g = torch.bmm(g, attn.permute(0, 2, 1)) attn_g = attn_g.view(-1, ch//2, h, w) attn_g = self.snconv1x1_attn(attn_g) # Out out = x + self.sigma*attn_g return out class ConditionalBatchNorm2d(nn.Module): # https://github.com/pytorch/pytorch/issues/8985#issuecomment-405080775 def __init__(self, num_features, num_classes): super().__init__() self.num_features = num_features self.bn = nn.BatchNorm2d(num_features, momentum=0.001, affine=False) self.embed = nn.Embedding(num_classes, num_features * 2) # self.embed.weight.data[:, :num_features].normal_(1, 0.02) # Initialise scale at N(1, 0.02) self.embed.weight.data[:, :num_features].fill_(1.) # Initialize scale to 1 self.embed.weight.data[:, num_features:].zero_() # Initialize bias at 0 def forward(self, x, y): out = self.bn(x) gamma, beta = self.embed(y).chunk(2, 1) out = gamma.view(-1, self.num_features, 1, 1) * out + beta.view(-1, self.num_features, 1, 1) return out class GenBlock(nn.Module): def __init__(self, in_channels, out_channels, num_classes): super(GenBlock, self).__init__() self.cond_bn1 = ConditionalBatchNorm2d(in_channels, num_classes) self.relu = nn.ReLU(inplace=True) self.snconv2d1 = snconv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1) self.cond_bn2 = ConditionalBatchNorm2d(out_channels, num_classes) self.snconv2d2 = snconv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1) self.snconv2d0 = snconv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0) def forward(self, x, labels): x0 = x x = self.cond_bn1(x, labels) x = self.relu(x) x = F.interpolate(x, scale_factor=2, mode='nearest') # upsample x = self.snconv2d1(x) x = self.cond_bn2(x, labels) x = self.relu(x) x = self.snconv2d2(x) x0 = F.interpolate(x0, scale_factor=2, mode='nearest') # upsample x0 = self.snconv2d0(x0) out = x + x0 return out class Generator(nn.Module): """Generator.""" def __init__(self, z_dim, g_conv_dim, num_classes): super(Generator, self).__init__() self.z_dim = z_dim self.g_conv_dim = g_conv_dim self.snlinear0 = snlinear(in_features=z_dim, out_features=g_conv_dim*16*4*4) self.block1 = GenBlock(g_conv_dim*16, g_conv_dim*16, num_classes) self.block2 = GenBlock(g_conv_dim*16, g_conv_dim*8, num_classes) self.block3 = GenBlock(g_conv_dim*8, g_conv_dim*4, num_classes) self.self_attn = Self_Attn(g_conv_dim*4) self.block4 = GenBlock(g_conv_dim*4, g_conv_dim*2, num_classes) self.block5 = GenBlock(g_conv_dim*2, g_conv_dim, num_classes) self.bn = nn.BatchNorm2d(g_conv_dim, eps=1e-5, momentum=0.0001, affine=True) self.relu = nn.ReLU(inplace=True) self.snconv2d1 = snconv2d(in_channels=g_conv_dim, out_channels=3, kernel_size=3, stride=1, padding=1) self.tanh = nn.Tanh() # Weight init self.apply(init_weights) def forward(self, z, labels): # n x z_dim act0 = self.snlinear0(z) # n x g_conv_dim*16*4*4 act0 = act0.view(-1, self.g_conv_dim*16, 4, 4) # n x g_conv_dim*16 x 4 x 4 act1 = self.block1(act0, labels) # n x g_conv_dim*16 x 8 x 8 act2 = self.block2(act1, labels) # n x g_conv_dim*8 x 16 x 16 act3 = self.block3(act2, labels) # n x g_conv_dim*4 x 32 x 32 act3 = self.self_attn(act3) # n x g_conv_dim*4 x 32 x 32 act4 = self.block4(act3, labels) # n x g_conv_dim*2 x 64 x 64 act5 = self.block5(act4, labels) # n x g_conv_dim x 128 x 128 act5 = self.bn(act5) # n x g_conv_dim x 128 x 128 act5 = self.relu(act5) # n x g_conv_dim x 128 x 128 act6 = self.snconv2d1(act5) # n x 3 x 128 x 128 act6 = self.tanh(act6) # n x 3 x 128 x 128 return act6 class DiscOptBlock(nn.Module): def __init__(self, in_channels, out_channels): super(DiscOptBlock, self).__init__() self.snconv2d1 = snconv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1) self.relu = nn.ReLU(inplace=True) self.snconv2d2 = snconv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1) self.downsample = nn.AvgPool2d(2) self.snconv2d0 = snconv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0) def forward(self, x): x0 = x x = self.snconv2d1(x) x = self.relu(x) x = self.snconv2d2(x) x = self.downsample(x) x0 = self.downsample(x0) x0 = self.snconv2d0(x0) out = x + x0 return out class DiscBlock(nn.Module): def __init__(self, in_channels, out_channels): super(DiscBlock, self).__init__() self.relu = nn.ReLU(inplace=True) self.snconv2d1 = snconv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1) self.snconv2d2 = snconv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1) self.downsample = nn.AvgPool2d(2) self.ch_mismatch = False if in_channels != out_channels: self.ch_mismatch = True self.snconv2d0 = snconv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0) def forward(self, x, downsample=True): x0 = x x = self.relu(x) x = self.snconv2d1(x) x = self.relu(x) x = self.snconv2d2(x) if downsample: x = self.downsample(x) if downsample or self.ch_mismatch: x0 = self.snconv2d0(x0) if downsample: x0 = self.downsample(x0) out = x + x0 return out class Discriminator(nn.Module): """Discriminator.""" def __init__(self, d_conv_dim, num_classes): super(Discriminator, self).__init__() self.d_conv_dim = d_conv_dim self.opt_block1 = DiscOptBlock(3, d_conv_dim) self.block1 = DiscBlock(d_conv_dim, d_conv_dim*2) self.self_attn = Self_Attn(d_conv_dim*2) self.block2 = DiscBlock(d_conv_dim*2, d_conv_dim*4) self.block3 = DiscBlock(d_conv_dim*4, d_conv_dim*8) self.block4 = DiscBlock(d_conv_dim*8, d_conv_dim*16) self.block5 = DiscBlock(d_conv_dim*16, d_conv_dim*16) self.relu = nn.ReLU(inplace=True) self.snlinear1 = snlinear(in_features=d_conv_dim*16, out_features=1) self.sn_embedding1 = sn_embedding(num_classes, d_conv_dim*16) # Weight init self.apply(init_weights) xavier_uniform_(self.sn_embedding1.weight) def forward(self, x, labels): # n x 3 x 128 x 128 h0 = self.opt_block1(x) # n x d_conv_dim x 64 x 64 h1 = self.block1(h0) # n x d_conv_dim*2 x 32 x 32 h1 = self.self_attn(h1) # n x d_conv_dim*2 x 32 x 32 h2 = self.block2(h1) # n x d_conv_dim*4 x 16 x 16 h3 = self.block3(h2) # n x d_conv_dim*8 x 8 x 8 h4 = self.block4(h3) # n x d_conv_dim*16 x 4 x 4 h5 = self.block5(h4, downsample=False) # n x d_conv_dim*16 x 4 x 4 h5 = self.relu(h5) # n x d_conv_dim*16 x 4 x 4 h6 = torch.sum(h5, dim=[2,3]) # n x d_conv_dim*16 output1 = torch.squeeze(self.snlinear1(h6)) # n # Projection h_labels = self.sn_embedding1(labels) # n x d_conv_dim*16 proj = torch.mul(h6, h_labels) # n x d_conv_dim*16 output2 = torch.sum(proj, dim=[1]) # n # Out output = output1 + output2 # n return output ================================================ FILE: test.py ================================================ import sys import utils from parameters import * from sagan_models import Generator, Discriminator if __name__ == '__main__': config = get_parameters() config.command = 'python ' + ' '.join(sys.argv) print(config) utils.check_for_CUDA(config) # Load pretrained model (if provided) if config.pretrained_model != '': utils.load_pretrained_model(config) else: assert config.num_of_classes, "Please provide number of classes! Eg. python3 test.py --num_of_classes 10" config.G = Generator(config.z_dim, config.g_conv_dim, config.num_of_classes).to(config.device) config.D = Discriminator(config.d_conv_dim, config.num_of_classes).to(config.device) config.G.eval() config.D.eval() print(config.G, config.D) ================================================ FILE: train.py ================================================ import sys import utils from parameters import * from trainer import Trainer if __name__ == '__main__': config = get_parameters() config.command = 'python ' + ' '.join(sys.argv) print(config) trainer = Trainer(config) trainer.train() utils.save_ckpt(trainer, final=True) ================================================ FILE: trainer.py ================================================ import datetime import numpy as np import os import random import sys import time import torch import torch.nn as nn import torchvision.utils as vutils from torch.backends import cudnn import utils from sagan_models import Generator, Discriminator class Trainer(object): def __init__(self, config): # Config self.config = config self.start = 0 # Unless using pre-trained model # Create directories if not exist utils.make_folder(self.config.save_path) utils.make_folder(self.config.model_weights_path) utils.make_folder(self.config.sample_images_path) # Copy files utils.write_config_to_file(self.config, self.config.save_path) utils.copy_scripts(self.config.save_path) # Check for CUDA utils.check_for_CUDA(self) # Make dataloader self.dataloader, self.num_of_classes = utils.make_dataloader(batch_size=self.config.batch_size_in_gpu, dataset_type=self.config.dataset, data_path=self.config.data_path, shuffle=self.config.shuffle, drop_last=self.config.drop_last, dataloader_args=self.config.dataloader_args, resize=self.config.resize, imsize=self.config.imsize, centercrop=self.config.centercrop, centercrop_size=self.config.centercrop_size, normalize=self.config.normalize, ) # Data iterator self.data_iter = iter(self.dataloader) # Build G and D self.build_models() if self.config.adv_loss == 'dcgan': self.criterion = nn.BCELoss() def train(self): # Seed np.random.seed(self.config.manual_seed) random.seed(self.config.manual_seed) torch.manual_seed(self.config.manual_seed) # For fast training cudnn.benchmark = True # For BatchNorm self.G.train() self.D.train() # Fixed noise for sampling from G fixed_noise = torch.randn(self.config.batch_size_in_gpu, self.config.z_dim, device=self.device) if self.num_of_classes < self.config.batch_size_in_gpu: fixed_labels = torch.from_numpy(np.tile(np.arange(self.num_of_classes), self.config.batch_size_in_gpu//self.num_of_classes + 1)[:self.config.batch_size_in_gpu]).to(self.device) else: fixed_labels = torch.from_numpy(np.arange(self.config.batch_size_in_gpu)).to(self.device) # For gan loss label = torch.full((self.config.batch_size_in_gpu,), 1, device=self.device) ones = torch.full((self.config.batch_size_in_gpu,), 1, device=self.device) # Losses file log_file_name = os.path.join(self.config.save_path, 'log.txt') log_file = open(log_file_name, "wt") # Init start_time = time.time() G_losses = [] D_losses_real = [] D_losses_fake = [] D_losses = [] D_xs = [] D_Gz_trainDs = [] D_Gz_trainGs = [] # Instance noise - make random noise mean (0) and std for injecting inst_noise_mean = torch.full((self.config.batch_size_in_gpu, 3, self.config.imsize, self.config.imsize), 0, device=self.device) inst_noise_std = torch.full((self.config.batch_size_in_gpu, 3, self.config.imsize, self.config.imsize), self.config.inst_noise_sigma, device=self.device) self.gpu_batches = self.config.batch_size//self.config.batch_size_in_gpu # Start training for self.step in range(self.start, self.config.total_step): # Instance noise std is linearly annealed from self.inst_noise_sigma to 0 thru self.inst_noise_sigma_iters inst_noise_sigma_curr = 0 if self.step > self.config.inst_noise_sigma_iters else (1 - self.step/self.config.inst_noise_sigma_iters)*self.config.inst_noise_sigma inst_noise_std.fill_(inst_noise_sigma_curr) # ================== TRAIN D ================== # for _ in range(self.config.d_steps_per_iter): # Zero grad self.reset_grad() # Accumulate losses for full batch_size # while running GPU computations on only batch_size_in_gpu for gpu_batch in range(self.gpu_batches): # TRAIN with REAL # Get real images & real labels real_images, real_labels = self.get_real_samples() # Get D output for real images & real labels inst_noise = torch.normal(mean=inst_noise_mean, std=inst_noise_std).to(self.device) d_out_real = self.D(real_images + inst_noise, real_labels) # Compute D loss with real images & real labels if self.config.adv_loss == 'hinge': d_loss_real = torch.nn.ReLU()(ones - d_out_real).mean() elif self.config.adv_loss == 'wgan_gp': d_loss_real = -d_out_real.mean() else: label.fill_(1) d_loss_real = self.criterion(d_out_real, label) # Backward d_loss_real /= self.gpu_batches d_loss_real.backward() # Delete loss, output if self.step % self.config.log_step != 0 or gpu_batch < self.gpu_batches - 1: del d_out_real, d_loss_real # TRAIN with FAKE # Create random noise z = torch.randn(self.config.batch_size_in_gpu, self.config.z_dim, device=self.device) # Generate fake images for same real labels fake_images = self.G(z, real_labels) # Get D output for fake images & same real labels inst_noise = torch.normal(mean=inst_noise_mean, std=inst_noise_std).to(self.device) d_out_fake = self.D(fake_images.detach() + inst_noise, real_labels) # Compute D loss with fake images & real labels if self.config.adv_loss == 'hinge': d_loss_fake = torch.nn.ReLU()(ones + d_out_fake).mean() elif self.config.adv_loss == 'dcgan': label.fill_(0) d_loss_fake = self.criterion(d_out_fake, label) else: d_loss_fake = d_out_fake.mean() # If WGAN_GP, compute GP and add to D loss if self.config.adv_loss == 'wgan_gp': d_loss_gp = self.config.lambda_gp * self.compute_gradient_penalty(real_images, real_labels, fake_images.detach()) d_loss_fake += d_loss_gp # Backward d_loss_fake /= self.gpu_batches d_loss_fake.backward() # Delete loss, output del fake_images if self.step % self.config.log_step != 0 or gpu_batch < self.gpu_batches - 1: del d_out_fake, d_loss_fake # Optimize self.D_optimizer.step() # ================== TRAIN G ================== # for _ in range(self.config.g_steps_per_iter): # Zero grad self.reset_grad() # Accumulate losses for full batch_size # while running GPU computations on only batch_size_in_gpu for gpu_batch in range(self.gpu_batches): # Get real images & real labels (only need real labels) real_images, real_labels = self.get_real_samples() # Create random noise z = torch.randn(self.config.batch_size_in_gpu, self.config.z_dim).to(self.device) # Generate fake images for same real labels fake_images = self.G(z, real_labels) # Get D output for fake images & same real labels inst_noise = torch.normal(mean=inst_noise_mean, std=inst_noise_std).to(self.device) g_out_fake = self.D(fake_images + inst_noise, real_labels) # Compute G loss with fake images & real labels if self.config.adv_loss == 'dcgan': label.fill_(1) g_loss = self.criterion(g_out_fake, label) else: g_loss = -g_out_fake.mean() # Backward g_loss /= self.gpu_batches g_loss.backward() # Delete loss, output del fake_images if self.step % self.config.log_step != 0 or gpu_batch < self.gpu_batches - 1: del g_out_fake, g_loss # Optimize self.G_optimizer.step() # Print out log info if self.step % self.config.log_step == 0: G_losses.append(g_loss.mean().item()) D_losses_real.append(d_loss_real.mean().item()) D_losses_fake.append(d_loss_fake.mean().item()) D_loss = D_losses_real[-1] + D_losses_fake[-1] if self.config.adv_loss == 'wgan_gp': D_loss += d_loss_gp.mean().item() D_losses.append(D_loss) D_xs.append(d_out_real.mean().item()) D_Gz_trainDs.append(d_out_fake.mean().item()) D_Gz_trainGs.append(g_out_fake.mean().item()) curr_time = time.time() curr_time_str = datetime.datetime.fromtimestamp(curr_time).strftime('%Y-%m-%d %H:%M:%S') elapsed = str(datetime.timedelta(seconds=(curr_time - start_time))) log = ("[{}] : Elapsed [{}], Iter [{} / {}], G_loss: {:.4f}, D_loss: {:.4f}, D_loss_real: {:.4f}, D_loss_fake: {:.4f}, D(x): {:.4f}, D(G(z))_trainD: {:.4f}, D(G(z))_trainG: {:.4f}\n". format(curr_time_str, elapsed, self.step, self.config.total_step, G_losses[-1], D_losses[-1], D_losses_real[-1], D_losses_fake[-1], D_xs[-1], D_Gz_trainDs[-1], D_Gz_trainGs[-1])) print('\n' + log) log_file.write(log) log_file.flush() utils.make_plots(G_losses, D_losses, D_losses_real, D_losses_fake, D_xs, D_Gz_trainDs, D_Gz_trainGs, self.config.log_step, self.config.save_path) # Delete loss, output del d_out_real, d_loss_real, d_out_fake, d_loss_fake, g_out_fake, g_loss # Sample images if self.step % self.config.sample_step == 0: print("Saving image samples..") self.G.eval() fake_images = self.G(fixed_noise, fixed_labels) self.G.train() sample_images = utils.denorm(fake_images.detach()[:self.config.save_n_images]) # Save batch images vutils.save_image(sample_images, os.path.join(self.config.sample_images_path, 'fake_{:05d}.png'.format(self.step)), nrow=self.config.nrow) # Save gif utils.make_gif(sample_images[0].cpu().numpy().transpose(1, 2, 0)*255, self.step, self.config.sample_images_path, self.config.name, max_frames_per_gif=self.config.max_frames_per_gif) # Delete output del fake_images # Save model if self.step % self.config.model_save_step == 0: utils.save_ckpt(self) def build_models(self): self.G = Generator(self.config.z_dim, self.config.g_conv_dim, self.num_of_classes).to(self.device) self.D = Discriminator(self.config.d_conv_dim, self.num_of_classes).to(self.device) # Loss and optimizer # self.G_optimizer = torch.optim.Adam(self.G.parameters(), self.g_lr, [self.beta1, self.beta2]) self.G_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, self.G.parameters()), self.config.g_lr, [self.config.beta1, self.config.beta2]) self.D_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, self.D.parameters()), self.config.d_lr, [self.config.beta1, self.config.beta2]) # Start with pretrained model (if it exists) if self.config.pretrained_model != '': utils.load_pretrained_model(self) if 'cuda' in self.device.type and self.config.parallel and torch.cuda.device_count() > 1: self.G = nn.DataParallel(self.G) self.D = nn.DataParallel(self.D) # print networks print(self.G) print(self.D) def reset_grad(self): self.G_optimizer.zero_grad() self.D_optimizer.zero_grad() def get_real_samples(self): try: real_images, real_labels = next(self.data_iter) except: self.data_iter = iter(self.dataloader) real_images, real_labels = next(self.data_iter) real_images, real_labels = real_images.to(self.device), real_labels.to(self.device) return real_images, real_labels def compute_gradient_penalty(self, real_images, real_labels, fake_images): # Compute gradient penalty alpha = torch.rand(real_images.size(0), 1, 1, 1).expand_as(real_images).to(device) interpolated = torch.tensor(alpha * real_images + (1 - alpha) * fake_images, requires_grad=True) out = self.D(interpolated, real_labels) exp_grad = torch.ones(out.size()).to(device) grad = torch.autograd.grad(outputs=out, inputs=interpolated, grad_outputs=exp_grad, retain_graph=True, create_graph=True, only_inputs=True)[0] grad = grad.view(grad.size(0), -1) grad_l2norm = torch.sqrt(torch.sum(grad ** 2, dim=1)) d_loss_gp = torch.mean((grad_l2norm - 1) ** 2) return d_loss_gp ================================================ FILE: utils.py ================================================ import cv2 import glob import imageio import matplotlib matplotlib.use('Agg') import matplotlib.pyplot as plt import numpy as np import os import shutil import torch import torchvision.datasets as dset from torchvision import transforms def make_folder(path): if not os.path.exists(path): os.makedirs(path) def denorm(x): out = (x + 1) / 2 return out.clamp_(0, 1) def write_config_to_file(config, save_path): with open(os.path.join(save_path, 'config.txt'), 'w') as file: for arg in vars(config): file.write(str(arg) + ': ' + str(getattr(config, arg)) + '\n') def copy_scripts(dst): for file in glob.glob('*.py'): shutil.copy(file, dst) for d in glob.glob('*/'): if '__' not in d and d[0] != '.': shutil.copytree(d, os.path.join(dst, d)) def make_transform(resize=True, imsize=128, centercrop=False, centercrop_size=128, totensor=True, normalize=True, norm_mean=(0.5, 0.5, 0.5), norm_std=(0.5, 0.5, 0.5)): options = [] if resize: options.append(transforms.Resize((imsize))) if centercrop: options.append(transforms.CenterCrop(centercrop_size)) if totensor: options.append(transforms.ToTensor()) if normalize: options.append(transforms.Normalize(norm_mean, norm_std)) transform = transforms.Compose(options) return transform def make_dataloader(batch_size, dataset_type, data_path, shuffle=True, drop_last=True, dataloader_args={}, resize=True, imsize=128, centercrop=False, centercrop_size=128, totensor=True, normalize=True, norm_mean=(0.5, 0.5, 0.5), norm_std=(0.5, 0.5, 0.5)): # Make transform transform = make_transform(resize=resize, imsize=imsize, centercrop=centercrop, centercrop_size=centercrop_size, totensor=totensor, normalize=normalize, norm_mean=norm_mean, norm_std=norm_std) # Make dataset if dataset_type in ['folder', 'imagenet', 'lfw']: # folder dataset assert os.path.exists(data_path), "data_path does not exist! Given: " + data_path dataset = dset.ImageFolder(root=data_path, transform=transform) elif dataset_type == 'lsun': assert os.path.exists(data_path), "data_path does not exist! Given: " + data_path dataset = dset.LSUN(root=data_path, classes=['bedroom_train'], transform=transform) elif dataset_type == 'cifar10': if not os.path.exists(data_path): print("data_path does not exist! Given: {}\nDownloading CIFAR10 dataset...".format(data_path)) dataset = dset.CIFAR10(root=data_path, download=True, transform=transform) elif dataset_type == 'fake': dataset = dset.FakeData(image_size=(3, centercrop_size, centercrop_size), transform=transforms.ToTensor()) assert dataset num_of_classes = len(dataset.classes) print("Data found! # of images =", len(dataset), ", # of classes =", num_of_classes, ", classes:", dataset.classes) # Make dataloader from dataset dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, **dataloader_args) return dataloader, num_of_classes def make_gif(image, iteration_number, save_path, model_name, max_frames_per_gif=100): # Make gif gif_frames = [] # Read old gif frames try: gif_frames_reader = imageio.get_reader(os.path.join(save_path, model_name + ".gif")) for frame in gif_frames_reader: gif_frames.append(frame[:, :, :3]) except: pass # Append new frame im = cv2.putText(np.concatenate((np.zeros((32, image.shape[1], image.shape[2])), image), axis=0), 'iter %s' % str(iteration_number), (10, 20), cv2.FONT_HERSHEY_SIMPLEX, .5, (255, 255, 255), 1, cv2.LINE_AA).astype('uint8') gif_frames.append(im) # If frames exceeds, save as different file if len(gif_frames) > max_frames_per_gif: print("Splitting the GIF...") gif_frames_00 = gif_frames[:max_frames_per_gif] num_of_gifs_already_saved = len(glob.glob(os.path.join(save_path, model_name + "_*.gif"))) print("Saving", os.path.join(save_path, model_name + "_%05d.gif" % (num_of_gifs_already_saved))) imageio.mimsave(os.path.join(save_path, model_name + "_%05d.gif" % (num_of_gifs_already_saved)), gif_frames_00) gif_frames = gif_frames[max_frames_per_gif:] # Save gif # print("Saving", os.path.join(save_path, model_name + ".gif")) imageio.mimsave(os.path.join(save_path, model_name + ".gif"), gif_frames) def make_plots(G_losses, D_losses, D_losses_real, D_losses_fake, D_xs, D_Gz_trainDs, D_Gz_trainGs, log_step, save_path, init_epoch=0): iters = np.arange(len(D_losses))*log_step + init_epoch fig = plt.figure(figsize=(20, 20)) plt.subplot(311) plt.plot(iters, np.zeros(iters.shape), 'k--', alpha=0.5) plt.plot(iters, G_losses, color='C0', label='G') plt.legend() plt.title("Generator loss") plt.xlabel("Iterations") plt.subplot(312) plt.plot(iters, np.zeros(iters.shape), 'k--', alpha=0.5) plt.plot(iters, D_losses_real, color='C1', alpha=0.7, label='D_real') plt.plot(iters, D_losses_fake, color='C2', alpha=0.7, label='D_fake') plt.plot(iters, D_losses, color='C0', alpha=0.7, label='D') plt.legend() plt.title("Discriminator loss") plt.xlabel("Iterations") plt.subplot(313) plt.plot(iters, np.zeros(iters.shape), 'k--', alpha=0.5) plt.plot(iters, np.ones(iters.shape), 'k--', alpha=0.5) plt.plot(iters, D_xs, alpha=0.7, label='D(x)') plt.plot(iters, D_Gz_trainDs, alpha=0.7, label='D(G(z))_trainD') plt.plot(iters, D_Gz_trainGs, alpha=0.7, label='D(G(z))_trainG') plt.legend() plt.title("D(x), D(G(z))") plt.xlabel("Iterations") plt.savefig(os.path.join(save_path, "plots.png")) plt.clf() plt.close() def save_ckpt(sagan_obj, model=False, final=False): print("Saving ckpt...") if final: # Save final - both model and state_dict torch.save({ 'step': sagan_obj.step, 'G_state_dict': sagan_obj.G.module.state_dict() if hasattr(sagan_obj.G, "module") else sagan_obj.G.state_dict(), # "module" in case DataParallel is used 'G_optimizer_state_dict': sagan_obj.G_optimizer.state_dict(), 'D_state_dict': sagan_obj.D.module.state_dict() if hasattr(sagan_obj.D, "module") else sagan_obj.D.state_dict(), # "module" in case DataParallel is used, 'D_optimizer_state_dict': sagan_obj.D_optimizer.state_dict(), }, os.path.join(sagan_obj.config.model_weights_path, '{}_final_state_dict_ckpt_{:07d}.pth'.format(sagan_obj.config.name, sagan_obj.step))) torch.save({ 'step': sagan_obj.step, 'G': sagan_obj.G.module if hasattr(sagan_obj.G, "module") else sagan_obj.G, 'G_optimizer': sagan_obj.G_optimizer, 'D': sagan_obj.D.module if hasattr(sagan_obj.D, "module") else sagan_obj.D, 'D_optimizer': sagan_obj.D_optimizer, }, os.path.join(sagan_obj.config.model_weights_path, '{}_final_model_ckpt_{:07d}.pth'.format(sagan_obj.config.name, sagan_obj.step))) elif model: # Save full model (not state_dict) torch.save({ 'step': sagan_obj.step, 'G': sagan_obj.G.module if hasattr(sagan_obj.G, "module") else sagan_obj.G, # "module" in case DataParallel is used 'G_optimizer': sagan_obj.G_optimizer, 'D': sagan_obj.D.module if hasattr(sagan_obj.D, "module") else sagan_obj.D, # "module" in case DataParallel is used 'D_optimizer': sagan_obj.D_optimizer, }, os.path.join(sagan_obj.config.model_weights_path, '{}_model_ckpt_{:07d}.pth'.format(sagan_obj.config.name, sagan_obj.step))) else: # Save state_dict torch.save({ 'step': sagan_obj.step, 'G_state_dict': sagan_obj.G.module.state_dict() if hasattr(sagan_obj.G, "module") else sagan_obj.G.state_dict(), 'G_optimizer_state_dict': sagan_obj.G_optimizer.state_dict(), 'D_state_dict': sagan_obj.D.module.state_dict() if hasattr(sagan_obj.D, "module") else sagan_obj.D.state_dict(), 'D_optimizer_state_dict': sagan_obj.D_optimizer.state_dict(), }, os.path.join(sagan_obj.config.model_weights_path, 'ckpt_{:07d}.pth'.format(sagan_obj.step))) def load_pretrained_model(sagan_obj): print("Loading pretrained_model", sagan_obj.config.pretrained_model, "...") # Check for path assert os.path.exists(sagan_obj.config.pretrained_model), "Path of .pth pretrained_model doesn't exist! Given: " + sagan_obj.config.pretrained_model checkpoint = torch.load(sagan_obj.config.pretrained_model) # If we know it is a state_dict (instead of complete model) if sagan_obj.config.state_dict_or_model == 'state_dict': sagan_obj.start = checkpoint['step'] + 1 sagan_obj.G.load_state_dict(checkpoint['G_state_dict']) sagan_obj.G_optimizer.load_state_dict(checkpoint['G_optimizer_state_dict']) sagan_obj.D.load_state_dict(checkpoint['D_state_dict']) sagan_obj.D_optimizer.load_state_dict(checkpoint['D_optimizer_state_dict']) # Else, if we know it is a complete model (and not just state_dict) elif sagan_obj.config.state_dict_or_model == 'model': sagan_obj.start = checkpoint['step'] + 1 sagan_obj.G = torch.load(checkpoint['G']).to(sagan_obj.device) sagan_obj.G_optimizer = torch.load(checkpoint['G_optimizer']) sagan_obj.D = torch.load(checkpoint['D']).to(sagan_obj.device) sagan_obj.D_optimizer = torch.load(checkpoint['D_optimizer']) # Else try for complete model, then try for state_dict else: try: sagan_obj.start = checkpoint['step'] + 1 sagan_obj.G.load_state_dict(checkpoint['G_state_dict']) sagan_obj.G_optimizer.load_state_dict(checkpoint['G_optimizer_state_dict']) sagan_obj.D.load_state_dict(checkpoint['D_state_dict']) sagan_obj.D_optimizer.load_state_dict(checkpoint['D_optimizer_state_dict']) except: sagan_obj.start = checkpoint['step'] + 1 sagan_obj.G = torch.load(checkpoint['G']).to(sagan_obj.device) sagan_obj.G_optimizer = torch.load(checkpoint['G_optimizer']) sagan_obj.D = torch.load(checkpoint['D']).to(sagan_obj.device) sagan_obj.D_optimizer = torch.load(checkpoint['D_optimizer']) def check_for_CUDA(sagan_obj): if not sagan_obj.config.disable_cuda and torch.cuda.is_available(): print("CUDA is available!") sagan_obj.device = torch.device('cuda') sagan_obj.config.dataloader_args['pin_memory'] = True else: print("Cuda is NOT available, running on CPU.") sagan_obj.device = torch.device('cpu') if torch.cuda.is_available() and sagan_obj.config.disable_cuda: print("WARNING: You have a CUDA device, so you should probably run without --disable_cuda")