Repository: AAnoosheh/ComboGAN Branch: master Commit: 27643b6fb26b Files: 28 Total size: 68.3 KB Directory structure: gitextract_lw_z0n5w/ ├── .gitignore ├── LICENSE ├── README.md ├── data/ │ ├── __init__.py │ ├── base_dataset.py │ ├── data_loader.py │ ├── image_folder.py │ └── unaligned_dataset.py ├── models/ │ ├── __init__.py │ ├── base_model.py │ ├── combogan_model.py │ └── networks.py ├── options/ │ ├── __init__.py │ ├── base_options.py │ ├── test_options.py │ └── train_options.py ├── scripts/ │ ├── continue_combogan.sh │ ├── test_combogan.sh │ └── train_combogan.sh ├── test.py ├── train.py └── util/ ├── __init__.py ├── get_data.py ├── html.py ├── image_pool.py ├── png.py ├── util.py └── visualizer.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .gitignore ================================================ datasets/ checkpoints/ results/ *.png */**/__pycache__ */*.pyc */**/*.pyc */**/**/*.pyc */**/**/**/*.pyc */**/**/**/**/*.pyc */*.so* */**/*.so* */**/*.dylib* *~ ================================================ FILE: LICENSE ================================================ Copyright (c) 2017, Asha Anoosheh All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. ================================================ FILE: README.md ================================================ # ComboGAN This is our ongoing PyTorch implementation for ComboGAN. Code was written by [Asha Anoosheh](https://github.com/aanoosheh) (built upon [CycleGAN](https://github.com/junyanz/CycleGAN)) #### [[ComboGAN Paper]](https://arxiv.org/pdf/1712.06909.pdf) If you use this code for your research, please cite: ComboGAN: Unrestrained Scalability for Image Domain Translation [Asha Anoosheh](http://ashaanoosheh.com), [Eirikur Augustsson](https://relational.github.io/), [Radu Timofte](http://www.vision.ee.ethz.ch/~timofter/), [Luc van Gool](https://www.vision.ee.ethz.ch/en/members/get_member.cgi?id=1) In Arxiv, 2017.



## Prerequisites - Linux or macOS - Python 3 - CPU or NVIDIA GPU + CUDA CuDNN ## Getting Started ### Installation - Install PyTorch and dependencies from http://pytorch.org - Install Torch vision from the source. ```bash git clone https://github.com/pytorch/vision cd vision python setup.py install ``` - Install python libraries [visdom](https://github.com/facebookresearch/visdom) and [dominate](https://github.com/Knio/dominate). ```bash pip install visdom pip install dominate ``` - Clone this repo: ```bash git clone https://github.com/AAnoosheh/ComboGAN.git cd ComboGAN ``` ### ComboGAN training Our ready datasets can be downloaded using `./datasets/download_dataset.sh `. A pretrained model for the 14-painters dataset can be found [HERE](https://www.dropbox.com/s/t8s6x0bu52d73s0/paint14_pretrained.zip?dl=0). Place under `./checkpoints/` and test using the instructions below, with args `--name paint14_pretrained --dataroot ./datasets/painters_14 --n_domains 14 --which_epoch 1150`. Example running scripts can be found in the `scripts` directory. - Train a model: ``` python train.py --name --dataroot ./datasets/ --n_domains --niter --niter_decay ``` Checkpoints will be saved by default to `./checkpoints//` - Fine-tuning/Resume training: ``` python train.py --continue_train --which_epoch --name --dataroot ./datasets/ --n_domains --niter --niter_decay ``` - Test the model: ``` python test.py --phase test --name --dataroot ./datasets/ --n_domains --which_epoch --serial_test ``` The test results will be saved to a html file here: `./results///index.html`. ## Training/Testing Details - Flags: see `options/train_options.py` for training-specific flags; see `options/test_options.py` for test-specific flags; and see `options/base_options.py` for all common flags. - Dataset format: The desired data directory (provided by `--dataroot`) should contain subfolders of the form `train*/` and `test*/`, and they are loaded in alphabetical order. 
(Note that a folder named train10 would be loaded before train2, and thus all checkpoints and results would be ordered accordingly.)
- CPU/GPU (default `--gpu_ids 0`): set `--gpu_ids -1` to use CPU mode; set `--gpu_ids 0,1,2` for multi-GPU mode. You need a large batch size (e.g. `--batchSize 32`) to benefit from multiple GPUs.
- Visualization: during training, the current results and loss plots can be viewed in two ways. First, if you set `--display_id` > 0, the results and loss plot will appear on a local graphics web server launched by [visdom](https://github.com/facebookresearch/visdom). To do this, you need `visdom` installed and a server started with the command `python -m visdom.server`. The default server URL is `http://localhost:8097`. `display_id` corresponds to the window ID that is displayed on the `visdom` server. The `visdom` display functionality is turned on by default. To avoid the extra overhead of communicating with `visdom`, set `--display_id 0`. Second, the intermediate results are also saved to `./checkpoints//web/index.html`. To avoid this, set the `--no_html` flag.
- Preprocessing: images can be resized and cropped in different ways using the `--resize_or_crop` option. The default option `'resize_and_crop'` resizes the image to size `(opt.loadSize, opt.loadSize)` and takes a random crop of size `(opt.fineSize, opt.fineSize)`. `'crop'` skips the resizing step and only performs random cropping. `'scale_width'` resizes the image to have width `opt.fineSize` while keeping the aspect ratio. `'scale_width_and_crop'` first resizes the image to have width `opt.loadSize` and then takes a random crop of size `(opt.fineSize, opt.fineSize)`.

NOTE: one should **not** expect ComboGAN to work on just any combination of input and output datasets (e.g. `dogs<->houses`). We find it works better if two datasets share similar visual content.
For example, `landscape painting<->landscape photographs` works much better than `portrait painting <-> landscape photographs`. ================================================ FILE: data/__init__.py ================================================ ================================================ FILE: data/base_dataset.py ================================================ import torch.utils.data as data from PIL import Image import torchvision.transforms as transforms class BaseDataset(data.Dataset): def __init__(self): super(BaseDataset, self).__init__() def name(self): return 'BaseDataset' def initialize(self, opt): pass def get_transform(opt): transform_list = [] if 'resize' in opt.resize_or_crop: transform_list.append(transforms.Resize(opt.loadSize, Image.BICUBIC)) if opt.isTrain: if 'crop' in opt.resize_or_crop: transform_list.append(transforms.RandomCrop(opt.fineSize)) if not opt.no_flip: transform_list.append(transforms.RandomHorizontalFlip()) transform_list += [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))] return transforms.Compose(transform_list) ================================================ FILE: data/data_loader.py ================================================ import torch.utils.data from data.unaligned_dataset import UnalignedDataset class DataLoader(): def name(self): return 'DataLoader' def __init__(self, opt): self.opt = opt self.dataset = UnalignedDataset(opt) self.dataloader = torch.utils.data.DataLoader( self.dataset, batch_size=opt.batchSize, num_workers=int(opt.nThreads)) def __len__(self): return min(len(self.dataset), self.opt.max_dataset_size) def __iter__(self): for i, data in enumerate(self.dataloader): if i >= self.opt.max_dataset_size: break yield data ================================================ FILE: data/image_folder.py ================================================ ############################################################################### # Code from # 
https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py # Modified the original code so that it also loads images from the current # directory as well as the subdirectories ############################################################################### import torch.utils.data as data from PIL import Image import os import os.path IMG_EXTENSIONS = [ '.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', ] def is_image_file(filename): return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) def make_dataset(dir): images = [] assert os.path.isdir(dir), '%s is not a valid directory' % dir for root, _, fnames in sorted(os.walk(dir)): for fname in fnames: if is_image_file(fname): path = os.path.join(root, fname) images.append(path) return images def default_loader(path): return Image.open(path).convert('RGB') class ImageFolder(data.Dataset): def __init__(self, root, transform=None, return_paths=False, loader=default_loader): imgs = make_dataset(root) if len(imgs) == 0: raise(RuntimeError("Found 0 images in: " + root + "\n" "Supported image extensions are: " + ",".join(IMG_EXTENSIONS))) self.root = root self.imgs = imgs self.transform = transform self.return_paths = return_paths self.loader = loader def __getitem__(self, index): path = self.imgs[index] img = self.loader(path) if self.transform is not None: img = self.transform(img) if self.return_paths: return img, path else: return img def __len__(self): return len(self.imgs) ================================================ FILE: data/unaligned_dataset.py ================================================ import os.path, glob import torchvision.transforms as transforms from data.base_dataset import BaseDataset, get_transform from data.image_folder import make_dataset from PIL import Image import random class UnalignedDataset(BaseDataset): def __init__(self, opt): super(UnalignedDataset, self).__init__() self.opt = opt self.transform = get_transform(opt) 
datapath = os.path.join(opt.dataroot, opt.phase + '*') self.dirs = sorted(glob.glob(datapath)) self.paths = [sorted(make_dataset(d)) for d in self.dirs] self.sizes = [len(p) for p in self.paths] def load_image(self, dom, idx): path = self.paths[dom][idx] img = Image.open(path).convert('RGB') img = self.transform(img) return img, path def __getitem__(self, index): if not self.opt.isTrain: if self.opt.serial_test: for d,s in enumerate(self.sizes): if index < s: DA = d; break index -= s index_A = index else: DA = index % len(self.dirs) index_A = random.randint(0, self.sizes[DA] - 1) else: # Choose two of our domains to perform a pass on DA, DB = random.sample(range(len(self.dirs)), 2) index_A = random.randint(0, self.sizes[DA] - 1) A_img, A_path = self.load_image(DA, index_A) bundle = {'A': A_img, 'DA': DA, 'path': A_path} if self.opt.isTrain: index_B = random.randint(0, self.sizes[DB] - 1) B_img, _ = self.load_image(DB, index_B) bundle.update( {'B': B_img, 'DB': DB} ) return bundle def __len__(self): if self.opt.isTrain: return max(self.sizes) return sum(self.sizes) def name(self): return 'UnalignedDataset' ================================================ FILE: models/__init__.py ================================================ ================================================ FILE: models/base_model.py ================================================ import os import torch class BaseModel(): def name(self): return 'BaseModel' def __init__(self, opt): self.opt = opt self.gpu_ids = opt.gpu_ids self.isTrain = opt.isTrain self.Tensor = torch.cuda.FloatTensor if self.gpu_ids else torch.Tensor self.save_dir = os.path.join(opt.checkpoints_dir, opt.name) def set_input(self, input): self.input = input def forward(self): pass # used in test time, no backprop def test(self): pass def get_image_paths(self): pass def optimize_parameters(self): pass def get_current_visuals(self): return self.input def get_current_errors(self): return {} def save(self, label): pass # helper saving 
function that can be used by subclasses def save_network(self, network, network_label, epoch, gpu_ids): save_filename = '%d_net_%s' % (epoch, network_label) save_path = os.path.join(self.save_dir, save_filename) network.save(save_path) if gpu_ids and torch.cuda.is_available(): network.cuda(gpu_ids[0]) # helper loading function that can be used by subclasses def load_network(self, network, network_label, epoch): save_filename = '%d_net_%s' % (epoch, network_label) save_path = os.path.join(self.save_dir, save_filename) network.load(save_path) def update_learning_rate(self): pass ================================================ FILE: models/combogan_model.py ================================================ import numpy as np import torch from collections import OrderedDict import util.util as util from util.image_pool import ImagePool from .base_model import BaseModel from . import networks class ComboGANModel(BaseModel): def name(self): return 'ComboGANModel' def __init__(self, opt): super(ComboGANModel, self).__init__(opt) self.n_domains = opt.n_domains self.DA, self.DB = None, None self.real_A = self.Tensor(opt.batchSize, opt.input_nc, opt.fineSize, opt.fineSize) self.real_B = self.Tensor(opt.batchSize, opt.output_nc, opt.fineSize, opt.fineSize) # load/define networks self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG_n_blocks, opt.netG_n_shared, self.n_domains, opt.norm, opt.use_dropout, self.gpu_ids) if self.isTrain: blur_fn = lambda x : torch.nn.functional.conv2d(x, self.Tensor(util.gkern_2d()), groups=3, padding=2) self.netD = networks.define_D(opt.output_nc, opt.ndf, opt.netD_n_layers, self.n_domains, blur_fn, opt.norm, self.gpu_ids) if not self.isTrain or opt.continue_train: which_epoch = opt.which_epoch self.load_network(self.netG, 'G', which_epoch) if self.isTrain: self.load_network(self.netD, 'D', which_epoch) if self.isTrain: self.fake_pools = [ImagePool(opt.pool_size) for _ in range(self.n_domains)] # define loss functions self.L1 =
torch.nn.SmoothL1Loss() self.downsample = torch.nn.AvgPool2d(3, stride=2) self.criterionCycle = self.L1 self.criterionIdt = lambda y,t : self.L1(self.downsample(y), self.downsample(t)) self.criterionLatent = lambda y,t : self.L1(y, t.detach()) self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor) # initialize optimizers self.netG.init_optimizers(torch.optim.Adam, opt.lr, (opt.beta1, 0.999)) self.netD.init_optimizers(torch.optim.Adam, opt.lr, (opt.beta1, 0.999)) # initialize loss storage self.loss_D, self.loss_G = [0]*self.n_domains, [0]*self.n_domains self.loss_cycle = [0]*self.n_domains # initialize loss multipliers self.lambda_cyc, self.lambda_enc = opt.lambda_cycle, (0 * opt.lambda_latent) self.lambda_idt, self.lambda_fwd = opt.lambda_identity, opt.lambda_forward print('---------- Networks initialized -------------') print(self.netG) if self.isTrain: print(self.netD) print('-----------------------------------------------') def set_input(self, input): input_A = input['A'] self.real_A.resize_(input_A.size()).copy_(input_A) self.DA = input['DA'][0] if self.isTrain: input_B = input['B'] self.real_B.resize_(input_B.size()).copy_(input_B) self.DB = input['DB'][0] self.image_paths = input['path'] def test(self): with torch.no_grad(): self.visuals = [self.real_A] self.labels = ['real_%d' % self.DA] # cache encoding to not repeat it everytime encoded = self.netG.encode(self.real_A, self.DA) for d in range(self.n_domains): if d == self.DA and not self.opt.autoencode: continue fake = self.netG.decode(encoded, d) self.visuals.append( fake ) self.labels.append( 'fake_%d' % d ) if self.opt.reconstruct: rec = self.netG.forward(fake, d, self.DA) self.visuals.append( rec ) self.labels.append( 'rec_%d' % d ) def get_image_paths(self): return self.image_paths def backward_D_basic(self, real, fake, domain): # Real pred_real = self.netD.forward(real, domain) loss_D_real = self.criterionGAN(pred_real, True) # Fake pred_fake = 
self.netD.forward(fake.detach(), domain) loss_D_fake = self.criterionGAN(pred_fake, False) # Combined loss loss_D = (loss_D_real + loss_D_fake) * 0.5 # backward loss_D.backward() return loss_D def backward_D(self): #D_A fake_B = self.fake_pools[self.DB].query(self.fake_B) self.loss_D[self.DA] = self.backward_D_basic(self.real_B, fake_B, self.DB) #D_B fake_A = self.fake_pools[self.DA].query(self.fake_A) self.loss_D[self.DB] = self.backward_D_basic(self.real_A, fake_A, self.DA) def backward_G(self): encoded_A = self.netG.encode(self.real_A, self.DA) encoded_B = self.netG.encode(self.real_B, self.DB) # Optional identity "autoencode" loss if self.lambda_idt > 0: # Same encoder and decoder should recreate image idt_A = self.netG.decode(encoded_A, self.DA) loss_idt_A = self.criterionIdt(idt_A, self.real_A) idt_B = self.netG.decode(encoded_B, self.DB) loss_idt_B = self.criterionIdt(idt_B, self.real_B) else: loss_idt_A, loss_idt_B = 0, 0 # GAN loss # D_A(G_A(A)) self.fake_B = self.netG.decode(encoded_A, self.DB) pred_fake = self.netD.forward(self.fake_B, self.DB) self.loss_G[self.DA] = self.criterionGAN(pred_fake, True) # D_B(G_B(B)) self.fake_A = self.netG.decode(encoded_B, self.DA) pred_fake = self.netD.forward(self.fake_A, self.DA) self.loss_G[self.DB] = self.criterionGAN(pred_fake, True) # Forward cycle loss rec_encoded_A = self.netG.encode(self.fake_B, self.DB) self.rec_A = self.netG.decode(rec_encoded_A, self.DA) self.loss_cycle[self.DA] = self.criterionCycle(self.rec_A, self.real_A) # Backward cycle loss rec_encoded_B = self.netG.encode(self.fake_A, self.DA) self.rec_B = self.netG.decode(rec_encoded_B, self.DB) self.loss_cycle[self.DB] = self.criterionCycle(self.rec_B, self.real_B) # Optional cycle loss on encoding space if self.lambda_enc > 0: loss_enc_A = self.criterionLatent(rec_encoded_A, encoded_A) loss_enc_B = self.criterionLatent(rec_encoded_B, encoded_B) else: loss_enc_A, loss_enc_B = 0, 0 # Optional loss on downsampled image before and after if 
self.lambda_fwd > 0: loss_fwd_A = self.criterionIdt(self.fake_B, self.real_A) loss_fwd_B = self.criterionIdt(self.fake_A, self.real_B) else: loss_fwd_A, loss_fwd_B = 0, 0 # combined loss loss_G = self.loss_G[self.DA] + self.loss_G[self.DB] + \ (self.loss_cycle[self.DA] + self.loss_cycle[self.DB]) * self.lambda_cyc + \ (loss_idt_A + loss_idt_B) * self.lambda_idt + \ (loss_enc_A + loss_enc_B) * self.lambda_enc + \ (loss_fwd_A + loss_fwd_B) * self.lambda_fwd loss_G.backward() def optimize_parameters(self): # G_A and G_B self.netG.zero_grads(self.DA, self.DB) self.backward_G() self.netG.step_grads(self.DA, self.DB) # D_A and D_B self.netD.zero_grads(self.DA, self.DB) self.backward_D() self.netD.step_grads(self.DA, self.DB) def get_current_errors(self): extract = lambda l: [(i if type(i) is int or type(i) is float else i.item()) for i in l] D_losses, G_losses, cyc_losses = extract(self.loss_D), extract(self.loss_G), extract(self.loss_cycle) return OrderedDict([('D', D_losses), ('G', G_losses), ('Cyc', cyc_losses)]) def get_current_visuals(self, testing=False): if not testing: self.visuals = [self.real_A, self.fake_B, self.rec_A, self.real_B, self.fake_A, self.rec_B] self.labels = ['real_A', 'fake_B', 'rec_A', 'real_B', 'fake_A', 'rec_B'] images = [util.tensor2im(v.data) for v in self.visuals] return OrderedDict(zip(self.labels, images)) def save(self, label): self.save_network(self.netG, 'G', label, self.gpu_ids) self.save_network(self.netD, 'D', label, self.gpu_ids) def update_hyperparams(self, curr_iter): if curr_iter > self.opt.niter: decay_frac = (curr_iter - self.opt.niter) / self.opt.niter_decay new_lr = self.opt.lr * (1 - decay_frac) self.netG.update_lr(new_lr) self.netD.update_lr(new_lr) print('updated learning rate: %f' % new_lr) if self.opt.lambda_latent > 0: decay_frac = curr_iter / (self.opt.niter + self.opt.niter_decay) self.lambda_enc = self.opt.lambda_latent * decay_frac ================================================ FILE: models/networks.py 
================================================ import torch import torch.nn as nn from torch.nn import init import functools, itertools import numpy as np def weights_init(m): classname = m.__class__.__name__ if classname.find('Conv') != -1: m.weight.data.normal_(0.0, 0.02) if hasattr(m.bias, 'data'): m.bias.data.fill_(0) elif classname.find('BatchNorm2d') != -1: m.weight.data.normal_(1.0, 0.02) m.bias.data.fill_(0) def get_norm_layer(norm_type='instance'): if norm_type == 'batch': return functools.partial(nn.BatchNorm2d, affine=True) elif norm_type == 'instance': return functools.partial(nn.InstanceNorm2d, affine=False) else: raise NotImplementedError('normalization layer [%s] is not found' % norm_type) def define_G(input_nc, output_nc, ngf, n_blocks, n_blocks_shared, n_domains, norm='batch', use_dropout=False, gpu_ids=[]): norm_layer = get_norm_layer(norm_type=norm) if type(norm_layer) == functools.partial: use_bias = norm_layer.func == nn.InstanceNorm2d else: use_bias = norm_layer == nn.InstanceNorm2d n_blocks -= n_blocks_shared n_blocks_enc = n_blocks // 2 n_blocks_dec = n_blocks - n_blocks_enc dup_args = (ngf, norm_layer, use_dropout, gpu_ids, use_bias) enc_args = (input_nc, n_blocks_enc) + dup_args dec_args = (output_nc, n_blocks_dec) + dup_args if n_blocks_shared > 0: n_blocks_shdec = n_blocks_shared // 2 n_blocks_shenc = n_blocks_shared - n_blocks_shdec shenc_args = (n_domains, n_blocks_shenc) + dup_args shdec_args = (n_domains, n_blocks_shdec) + dup_args plex_netG = G_Plexer(n_domains, ResnetGenEncoder, enc_args, ResnetGenDecoder, dec_args, ResnetGenShared, shenc_args, shdec_args) else: plex_netG = G_Plexer(n_domains, ResnetGenEncoder, enc_args, ResnetGenDecoder, dec_args) if len(gpu_ids) > 0: assert(torch.cuda.is_available()) plex_netG.cuda(gpu_ids[0]) plex_netG.apply(weights_init) return plex_netG def define_D(input_nc, ndf, netD_n_layers, n_domains, blur_fn, norm='batch', gpu_ids=[]): norm_layer = get_norm_layer(norm_type=norm) model_args = (input_nc, 
ndf, netD_n_layers, blur_fn, norm_layer, gpu_ids) plex_netD = D_Plexer(n_domains, NLayerDiscriminator, model_args) if len(gpu_ids) > 0: assert(torch.cuda.is_available()) plex_netD.cuda(gpu_ids[0]) plex_netD.apply(weights_init) return plex_netD ############################################################################## # Classes ############################################################################## # Defines the GAN loss which uses either LSGAN or the regular GAN. # When LSGAN is used, it is basically same as MSELoss, # but it abstracts away the need to create the target label tensor # that has the same size as the input class GANLoss(nn.Module): def __init__(self, use_lsgan=True, tensor=torch.FloatTensor): super(GANLoss, self).__init__() self.Tensor = tensor self.labels_real, self.labels_fake = None, None self.preloss = nn.Sigmoid() if not use_lsgan else None self.loss = nn.MSELoss() if use_lsgan else nn.BCELoss() def get_target_tensor(self, inputs, is_real): if self.labels_real is None or self.labels_real[0].numel() != inputs[0].numel(): self.labels_real = [ self.Tensor(input.size()).fill_(1.0) for input in inputs ] self.labels_fake = [ self.Tensor(input.size()).fill_(0.0) for input in inputs ] if is_real: return self.labels_real return self.labels_fake def __call__(self, inputs, is_real): labels = self.get_target_tensor(inputs, is_real) if self.preloss is not None: inputs = [self.preloss(input) for input in inputs] losses = [self.loss(input, label) for input, label in zip(inputs, labels)] multipliers = list(range(1, len(inputs)+1)); multipliers[-1] += 1 losses = [m*l for m,l in zip(multipliers, losses)] return sum(losses) / (sum(multipliers) * len(losses)) # Defines the generator that consists of Resnet blocks between a few # downsampling/upsampling operations. # Code and idea originally from Justin Johnson's architecture. 
# https://github.com/jcjohnson/fast-neural-style/ class ResnetGenEncoder(nn.Module): def __init__(self, input_nc, n_blocks=4, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, gpu_ids=[], use_bias=False, padding_type='reflect'): assert(n_blocks >= 0) super(ResnetGenEncoder, self).__init__() self.gpu_ids = gpu_ids model = [nn.ReflectionPad2d(3), nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias), norm_layer(ngf), nn.PReLU()] n_downsampling = 2 for i in range(n_downsampling): mult = 2**i model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias), norm_layer(ngf * mult * 2), nn.PReLU()] mult = 2**n_downsampling for _ in range(n_blocks): model += [ResnetBlock(ngf * mult, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias, padding_type=padding_type)] self.model = nn.Sequential(*model) def forward(self, input): if self.gpu_ids and isinstance(input.data, torch.cuda.FloatTensor): return nn.parallel.data_parallel(self.model, input, self.gpu_ids) return self.model(input) class ResnetGenShared(nn.Module): def __init__(self, n_domains, n_blocks=2, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, gpu_ids=[], use_bias=False, padding_type='reflect'): assert(n_blocks >= 0) super(ResnetGenShared, self).__init__() self.gpu_ids = gpu_ids model = [] n_downsampling = 2 mult = 2**n_downsampling for _ in range(n_blocks): model += [ResnetBlock(ngf * mult, norm_layer=norm_layer, n_domains=n_domains, use_dropout=use_dropout, use_bias=use_bias, padding_type=padding_type)] self.model = SequentialContext(n_domains, *model) def forward(self, input, domain): if self.gpu_ids and isinstance(input.data, torch.cuda.FloatTensor): return nn.parallel.data_parallel(self.model, (input, domain), self.gpu_ids) return self.model(input, domain) class ResnetGenDecoder(nn.Module): def __init__(self, output_nc, n_blocks=5, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, gpu_ids=[], use_bias=False, padding_type='reflect'): 
assert(n_blocks >= 0) super(ResnetGenDecoder, self).__init__() self.gpu_ids = gpu_ids model = [] n_downsampling = 2 mult = 2**n_downsampling for _ in range(n_blocks): model += [ResnetBlock(ngf * mult, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias, padding_type=padding_type)] for i in range(n_downsampling): mult = 2**(n_downsampling - i) model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), kernel_size=4, stride=2, padding=1, output_padding=0, bias=use_bias), norm_layer(int(ngf * mult / 2)), nn.PReLU()] model += [nn.ReflectionPad2d(3), nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0), nn.Tanh()] self.model = nn.Sequential(*model) def forward(self, input): if self.gpu_ids and isinstance(input.data, torch.cuda.FloatTensor): return nn.parallel.data_parallel(self.model, input, self.gpu_ids) return self.model(input) # Define a resnet block class ResnetBlock(nn.Module): def __init__(self, dim, norm_layer, use_dropout, use_bias, padding_type='reflect', n_domains=0): super(ResnetBlock, self).__init__() conv_block = [] p = 0 if padding_type == 'reflect': conv_block += [nn.ReflectionPad2d(1)] elif padding_type == 'replicate': conv_block += [nn.ReplicationPad2d(1)] elif padding_type == 'zero': p = 1 else: raise NotImplementedError('padding [%s] is not implemented' % padding_type) conv_block += [nn.Conv2d(dim + n_domains, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim), nn.PReLU()] if use_dropout: conv_block += [nn.Dropout(0.5)] p = 0 if padding_type == 'reflect': conv_block += [nn.ReflectionPad2d(1)] elif padding_type == 'replicate': conv_block += [nn.ReplicationPad2d(1)] elif padding_type == 'zero': p = 1 else: raise NotImplementedError('padding [%s] is not implemented' % padding_type) conv_block += [nn.Conv2d(dim + n_domains, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim)] self.conv_block = SequentialContext(n_domains, *conv_block) def forward(self, input): if isinstance(input, tuple): return input[0] + 
self.conv_block(*input) return input + self.conv_block(input) # Defines the PatchGAN discriminator with the specified arguments. class NLayerDiscriminator(nn.Module): def __init__(self, input_nc, ndf=64, n_layers=3, blur_fn=None, norm_layer=nn.BatchNorm2d, gpu_ids=[]): super(NLayerDiscriminator, self).__init__() self.gpu_ids = gpu_ids self.blur_fn = blur_fn self.gray_fn = lambda x: (.299*x[:,0,:,:] + .587*x[:,1,:,:] + .114*x[:,2,:,:]).unsqueeze_(1) self.model_gray = self.model(1, ndf, n_layers, norm_layer) self.model_rgb = self.model(input_nc, ndf, n_layers, norm_layer) def model(self, input_nc, ndf, n_layers, norm_layer): if type(norm_layer) == functools.partial: use_bias = norm_layer.func == nn.InstanceNorm2d else: use_bias = norm_layer == nn.InstanceNorm2d kw = 4 padw = int(np.ceil((kw-1)/2)) sequences = [[ nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.PReLU() ]] nf_mult = 1 nf_mult_prev = 1 for n in range(1, n_layers): nf_mult_prev = nf_mult nf_mult = min(2**n, 8) sequences += [[ nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult + 1, kernel_size=kw, stride=2, padding=padw, bias=use_bias), norm_layer(ndf * nf_mult + 1), nn.PReLU() ]] nf_mult_prev = nf_mult nf_mult = min(2**n_layers, 8) sequences += [[ nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias), norm_layer(ndf * nf_mult), nn.PReLU(), \ nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw) ]] return SequentialOutput(*sequences) def forward(self, input): luminance, blurred_rgb = self.gray_fn(input), self.blur_fn(input) if len(self.gpu_ids) and isinstance(input.data, torch.cuda.FloatTensor): outs1 = nn.parallel.data_parallel(self.model_gray, luminance, self.gpu_ids) outs2 = nn.parallel.data_parallel(self.model_rgb, blurred_rgb, self.gpu_ids) else: outs1 = self.model_gray(luminance) outs2 = self.model_rgb(blurred_rgb) return [torch.cat([o1,o2], 1) for o1,o2 in zip(outs1, outs2)] class Plexer(nn.Module): def __init__(self): 
super(Plexer, self).__init__() def apply(self, func): for net in self.networks: net.apply(func) def cuda(self, device_id): for net in self.networks: net.cuda(device_id) def init_optimizers(self, opt, lr, betas): self.optimizers = [opt(net.parameters(), lr=lr, betas=betas) \ for net in self.networks] def zero_grads(self, dom_a, dom_b): self.optimizers[dom_a].zero_grad() self.optimizers[dom_b].zero_grad() def step_grads(self, dom_a, dom_b): self.optimizers[dom_a].step() self.optimizers[dom_b].step() def update_lr(self, new_lr): for opt in self.optimizers: for param_group in opt.param_groups: param_group['lr'] = new_lr def save(self, save_path): for i, net in enumerate(self.networks): filename = save_path + ('%d.pth' % i) torch.save(net.cpu().state_dict(), filename) def load(self, save_path): for i, net in enumerate(self.networks): filename = save_path + ('%d.pth' % i) net.load_state_dict(torch.load(filename)) class G_Plexer(Plexer): def __init__(self, n_domains, encoder, enc_args, decoder, dec_args, block=None, shenc_args=None, shdec_args=None): super(G_Plexer, self).__init__() self.encoders = [encoder(*enc_args) for _ in range(n_domains)] self.decoders = [decoder(*dec_args) for _ in range(n_domains)] self.sharing = block is not None if self.sharing: self.shared_encoder = block(*shenc_args) self.shared_decoder = block(*shdec_args) self.encoders.append( self.shared_encoder ) self.decoders.append( self.shared_decoder ) self.networks = self.encoders + self.decoders def init_optimizers(self, opt, lr, betas): self.optimizers = [] for enc, dec in zip(self.encoders, self.decoders): params = itertools.chain(enc.parameters(), dec.parameters()) self.optimizers.append( opt(params, lr=lr, betas=betas) ) def forward(self, input, in_domain, out_domain): encoded = self.encode(input, in_domain) return self.decode(encoded, out_domain) def encode(self, input, domain): output = self.encoders[domain].forward(input) if self.sharing: return self.shared_encoder.forward(output, domain) 
return output def decode(self, input, domain): if self.sharing: input = self.shared_decoder.forward(input, domain) return self.decoders[domain].forward(input) def zero_grads(self, dom_a, dom_b): self.optimizers[dom_a].zero_grad() if self.sharing: self.optimizers[-1].zero_grad() self.optimizers[dom_b].zero_grad() def step_grads(self, dom_a, dom_b): self.optimizers[dom_a].step() if self.sharing: self.optimizers[-1].step() self.optimizers[dom_b].step() def __repr__(self): e, d = self.encoders[0], self.decoders[0] e_params = sum([p.numel() for p in e.parameters()]) d_params = sum([p.numel() for p in d.parameters()]) return repr(e) +'\n'+ repr(d) +'\n'+ \ 'Created %d Encoder-Decoder pairs' % len(self.encoders) +'\n'+ \ 'Number of parameters per Encoder: %d' % e_params +'\n'+ \ 'Number of parameters per Decoder: %d' % d_params class D_Plexer(Plexer): def __init__(self, n_domains, model, model_args): super(D_Plexer, self).__init__() self.networks = [model(*model_args) for _ in range(n_domains)] def forward(self, input, domain): discriminator = self.networks[domain] return discriminator.forward(input) def __repr__(self): t = self.networks[0] t_params = sum([p.numel() for p in t.parameters()]) return repr(t) +'\n'+ \ 'Created %d Discriminators' % len(self.networks) +'\n'+ \ 'Number of parameters per Discriminator: %d' % t_params class SequentialContext(nn.Sequential): def __init__(self, n_classes, *args): super(SequentialContext, self).__init__(*args) self.n_classes = n_classes self.context_var = None def prepare_context(self, input, domain): if self.context_var is None or self.context_var.size()[-2:] != input.size()[-2:]: tensor = torch.cuda.FloatTensor if isinstance(input.data, torch.cuda.FloatTensor) \ else torch.FloatTensor self.context_var = tensor(*((1, self.n_classes) + input.size()[-2:])) self.context_var.data.fill_(-1.0) self.context_var.data[:,domain,:,:] = 1.0 return self.context_var def forward(self, *input): if self.n_classes < 2 or len(input) < 2: return 
super(SequentialContext, self).forward(input[0]) x, domain = input for module in self._modules.values(): if 'Conv' in module.__class__.__name__: context_var = self.prepare_context(x, domain) x = torch.cat([x, context_var], dim=1) elif 'Block' in module.__class__.__name__: x = (x,) + input[1:] x = module(x) return x class SequentialOutput(nn.Sequential): def __init__(self, *args): args = [nn.Sequential(*arg) for arg in args] super(SequentialOutput, self).__init__(*args) def forward(self, input): predictions = [] layers = self._modules.values() for i, module in enumerate(layers): output = module(input) if i == 0: input = output; continue predictions.append( output[:,-1,:,:] ) if i != len(layers) - 1: input = output[:,:-1,:,:] return predictions ================================================ FILE: options/__init__.py ================================================ ================================================ FILE: options/base_options.py ================================================ import argparse import os from util import util import torch class BaseOptions(): def __init__(self): self.parser = argparse.ArgumentParser() self.initialized = False def initialize(self): self.parser.add_argument('--name', required=True, type=str, help='name of the experiment. It decides where to store samples and models') self.parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here') self.parser.add_argument('--dataroot', required=True, type=str, help='path to images (should have subfolders trainA, trainB, valA, valB, etc)') self.parser.add_argument('--n_domains', required=True, type=int, help='Number of domains to transfer among') self.parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. 
If the dataset directory contains more than max_dataset_size, only a subset is loaded.') self.parser.add_argument('--resize_or_crop', type=str, default='resize_and_crop', help='scaling and cropping of images at load time [resize|resize_and_crop|crop]') self.parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation') self.parser.add_argument('--loadSize', type=int, default=286, help='scale images to this size') self.parser.add_argument('--fineSize', type=int, default=256, help='then crop to this size') self.parser.add_argument('--batchSize', type=int, default=1, help='input batch size') self.parser.add_argument('--input_nc', type=int, default=3, help='# of input image channels') self.parser.add_argument('--output_nc', type=int, default=3, help='# of output image channels') self.parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in first conv layer') self.parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in first conv layer') self.parser.add_argument('--netG_n_blocks', type=int, default=9, help='number of residual blocks to use for netG') self.parser.add_argument('--netG_n_shared', type=int, default=0, help='number of blocks to use for netG shared center module') self.parser.add_argument('--netD_n_layers', type=int, default=4, help='number of layers to use for netD') self.parser.add_argument('--norm', type=str, default='instance', help='instance normalization or batch normalization') self.parser.add_argument('--use_dropout', action='store_true', help='insert dropout for the generator') self.parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. 
use -1 for CPU') self.parser.add_argument('--nThreads', default=2, type=int, help='# threads for loading data') self.parser.add_argument('--display_id', type=int, default=0, help='window id of the web display (set > 0 to use visdom)') self.parser.add_argument('--display_port', type=int, default=8097, help='visdom port of the web display') self.parser.add_argument('--display_winsize', type=int, default=256, help='display window size') self.parser.add_argument('--display_single_pane_ncols', type=int, default=0, help='if positive, display all images in a single visdom web panel with certain number of images per row.') self.initialized = True def parse(self): if not self.initialized: self.initialize() self.opt = self.parser.parse_args() self.opt.isTrain = self.isTrain # train or test str_ids = self.opt.gpu_ids.split(',') self.opt.gpu_ids = [] for str_id in str_ids: id = int(str_id) if id >= 0: self.opt.gpu_ids.append(id) # set gpu ids if len(self.opt.gpu_ids) > 0: torch.cuda.set_device(self.opt.gpu_ids[0]) args = vars(self.opt) print('------------ Options -------------') for k, v in sorted(args.items()): print('%s: %s' % (str(k), str(v))) print('-------------- End ----------------') # save to the disk expr_dir = os.path.join(self.opt.checkpoints_dir, self.opt.name) util.mkdirs(expr_dir) file_name = os.path.join(expr_dir, 'opt.txt') with open(file_name, 'wt') as opt_file: opt_file.write('------------ Options -------------\n') for k, v in sorted(args.items()): opt_file.write('%s: %s\n' % (str(k), str(v))) opt_file.write('-------------- End ----------------\n') return self.opt ================================================ FILE: options/test_options.py ================================================ from .base_options import BaseOptions class TestOptions(BaseOptions): def initialize(self): BaseOptions.initialize(self) self.isTrain = False self.parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.') 
self.parser.add_argument('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images') self.parser.add_argument('--which_epoch', required=True, type=int, help='which epoch to load for inference?') self.parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc (determines name of folder to load from)') self.parser.add_argument('--how_many', type=int, default=50, help='how many test images to run (if serial_test not enabled)') self.parser.add_argument('--serial_test', action='store_true', help='read each image once from folders in sequential order') self.parser.add_argument('--autoencode', action='store_true', help='translate images back into its own domain') self.parser.add_argument('--reconstruct', action='store_true', help='do reconstructions of images during testing') self.parser.add_argument('--show_matrix', action='store_true', help='visualize images in a matrix format as well') ================================================ FILE: options/train_options.py ================================================ from .base_options import BaseOptions class TrainOptions(BaseOptions): def initialize(self): BaseOptions.initialize(self) self.isTrain = True self.parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model') self.parser.add_argument('--which_epoch', type=int, default=0, help='which epoch to load if continuing training') self.parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc (determines name of folder to load from)') self.parser.add_argument('--niter', required=True, type=int, help='# of epochs at starting learning rate (try 50*n_domains)') self.parser.add_argument('--niter_decay', required=True, type=int, help='# of epochs to linearly decay learning rate to zero (try 50*n_domains)') self.parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for ADAM') self.parser.add_argument('--beta1', type=float, 
default=0.5, help='momentum term of ADAM') self.parser.add_argument('--lambda_cycle', type=float, default=10.0, help='weight for cycle loss (A -> B -> A)') self.parser.add_argument('--lambda_identity', type=float, default=0.0, help='weight for identity "autoencode" mapping (A -> A)') self.parser.add_argument('--lambda_latent', type=float, default=0.0, help='weight for latent-space loss (A -> z -> B -> z)') self.parser.add_argument('--lambda_forward', type=float, default=0.0, help='weight for forward loss (A -> B; try 0.2)') self.parser.add_argument('--save_epoch_freq', type=int, default=5, help='frequency of saving checkpoints at the end of epochs') self.parser.add_argument('--display_freq', type=int, default=100, help='frequency of showing training results on screen') self.parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console') self.parser.add_argument('--no_lsgan', action='store_true', help='use vanilla discriminator in place of least-squares one') self.parser.add_argument('--pool_size', type=int, default=50, help='the size of image buffer that stores previously generated images') self.parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/') ================================================ FILE: scripts/continue_combogan.sh ================================================ python train.py \ --dataroot ./datasets/alps \ --name alps_combogan \ --continue_train \ --which_epoch 117 \ --n_domains 4 \ --niter 200 \ --niter_decay 200 \ --lambda_identity 0.0 \ --lambda_forward 0.0 ================================================ FILE: scripts/test_combogan.sh ================================================ python test.py \ --phase test \ --dataroot ./datasets/alps \ --name alps_combogan \ --n_domains 4 \ --which_epoch 400 \ --show_matrix ================================================ FILE: scripts/train_combogan.sh 
================================================ python train.py \ --dataroot ./datasets/alps \ --name alps_combogan \ --n_domains 4 \ --niter 200 \ --niter_decay 200 \ --lambda_identity 0.0 \ --lambda_forward 0.0 ================================================ FILE: test.py ================================================ import time import os from options.test_options import TestOptions from data.data_loader import DataLoader from models.combogan_model import ComboGANModel from util.visualizer import Visualizer from util import html opt = TestOptions().parse() opt.nThreads = 1 # test code only supports nThreads = 1 opt.batchSize = 1 # test code only supports batchSize = 1 dataset = DataLoader(opt) model = ComboGANModel(opt) visualizer = Visualizer(opt) # create website web_dir = os.path.join(opt.results_dir, opt.name, '%s_%d' % (opt.phase, opt.which_epoch)) webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %d' % (opt.name, opt.phase, opt.which_epoch)) # store images for matrix visualization vis_buffer = [] # test for i, data in enumerate(dataset): if not opt.serial_test and i >= opt.how_many: break model.set_input(data) model.test() visuals = model.get_current_visuals(testing=True) img_path = model.get_image_paths() print('process image... 
%s' % img_path) visualizer.save_images(webpage, visuals, img_path) if opt.show_matrix: vis_buffer.append(visuals) if (i+1) % opt.n_domains == 0: save_path = os.path.join(web_dir, 'mat_%d.png' % (i//opt.n_domains)) visualizer.save_image_matrix(vis_buffer, save_path) vis_buffer.clear() webpage.save() ================================================ FILE: train.py ================================================ import time from options.train_options import TrainOptions from data.data_loader import DataLoader from models.combogan_model import ComboGANModel from util.visualizer import Visualizer opt = TrainOptions().parse() dataset = DataLoader(opt) dataset_size = len(dataset) print('# training images = %d' % dataset_size) model = ComboGANModel(opt) visualizer = Visualizer(opt) total_steps = 0 # Update initially if continuing if opt.which_epoch > 0: model.update_hyperparams(opt.which_epoch) for epoch in range(opt.which_epoch + 1, opt.niter + opt.niter_decay + 1): epoch_start_time = time.time() epoch_iter = 0 for i, data in enumerate(dataset): iter_start_time = time.time() total_steps += opt.batchSize epoch_iter += opt.batchSize model.set_input(data) model.optimize_parameters() if total_steps % opt.display_freq == 0: visualizer.display_current_results(model.get_current_visuals(), epoch) if total_steps % opt.print_freq == 0: errors = model.get_current_errors() t = (time.time() - iter_start_time) / opt.batchSize visualizer.print_current_errors(epoch, epoch_iter, errors, t) if opt.display_id > 0: visualizer.plot_current_errors(epoch, float(epoch_iter)/dataset_size, opt, errors) if epoch % opt.save_epoch_freq == 0: print('saving the model at the end of epoch %d, iters %d' % (epoch, total_steps)) model.save(epoch) print('End of epoch %d / %d \t Time Taken: %d sec' % (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time)) model.update_hyperparams(epoch) ================================================ FILE: util/__init__.py 
================================================ FILE: util/get_data.py ================================================ from __future__ import print_function import os import tarfile import requests from warnings import warn from zipfile import ZipFile from bs4 import BeautifulSoup from os.path import abspath, isdir, join, basename class GetData(object): """ Download CycleGAN or Pix2Pix Data. Args: technique : str One of: 'cyclegan' or 'pix2pix'. verbose : bool If True, print additional information. Examples: >>> from util.get_data import GetData >>> gd = GetData(technique='cyclegan') >>> new_data_path = gd.get(save_path='./datasets') # options will be displayed. """ def __init__(self, technique='cyclegan', verbose=True): url_dict = { 'pix2pix': 'https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets', 'cyclegan': 'https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets' } self.url = url_dict.get(technique.lower()) self._verbose = verbose def _print(self, text): if self._verbose: print(text) @staticmethod def _get_options(r): soup = BeautifulSoup(r.text, 'lxml') options = [h.text for h in soup.find_all('a', href=True) if h.text.endswith(('.zip', 'tar.gz'))] return options def _present_options(self): r = requests.get(self.url) options = self._get_options(r) print('Options:\n') for i, o in enumerate(options): print("{0}: {1}".format(i, o)) choice = input("\nPlease enter the number of the " "dataset above you wish to download:") return options[int(choice)] def _download_data(self, dataset_url, save_path): if not isdir(save_path): os.makedirs(save_path) base = basename(dataset_url) temp_save_path = join(save_path, base) with open(temp_save_path, "wb") as f: r = requests.get(dataset_url) f.write(r.content) if base.endswith('.tar.gz'): obj = tarfile.open(temp_save_path) elif base.endswith('.zip'): obj = ZipFile(temp_save_path, 'r') else: raise ValueError("Unknown File Type: {0}.".format(base)) self._print("Unpacking Data...") 
obj.extractall(save_path) obj.close() os.remove(temp_save_path) def get(self, save_path, dataset=None): """ Download a dataset. Args: save_path : str A directory to save the data to. dataset : str, optional A specific dataset to download. Note: this must include the file extension. If None, options will be presented for you to choose from. Returns: save_path_full : str The absolute path to the downloaded data. """ if dataset is None: selected_dataset = self._present_options() else: selected_dataset = dataset save_path_full = join(save_path, selected_dataset.split('.')[0]) if isdir(save_path_full): warn("\n'{0}' already exists. Voiding Download.".format( save_path_full)) else: self._print('Downloading Data...') url = "{0}/{1}".format(self.url, selected_dataset) self._download_data(url, save_path=save_path) return abspath(save_path_full) ================================================ FILE: util/html.py ================================================ import dominate from dominate.tags import * import os class HTML: def __init__(self, web_dir, title, reflesh=0): self.title = title self.web_dir = web_dir self.img_dir = os.path.join(self.web_dir, 'images') if not os.path.exists(self.web_dir): os.makedirs(self.web_dir) if not os.path.exists(self.img_dir): os.makedirs(self.img_dir) # print(self.img_dir) self.doc = dominate.document(title=title) if reflesh > 0: with self.doc.head: meta(http_equiv="refresh", content=str(reflesh)) def get_image_dir(self): return self.img_dir def add_header(self, str): with self.doc: h3(str) def add_table(self, border=1): self.t = table(border=border, style="table-layout: fixed;") self.doc.add(self.t) def add_images(self, ims, txts, links, width=400): self.add_table() with self.t: with tr(): for im, txt, link in zip(ims, txts, links): with td(style="word-wrap: break-word;", halign="center", valign="top"): with p(): with a(href=os.path.join('images', link)): img(style="width:%dpx" % width, src=os.path.join('images', im)) br() p(txt) def 
save(self): html_file = '%s/index.html' % self.web_dir f = open(html_file, 'wt') f.write(self.doc.render()) f.close() if __name__ == '__main__': html = HTML('web/', 'test_html') html.add_header('hello world') ims = [] txts = [] links = [] for n in range(4): ims.append('image_%d.png' % n) txts.append('text_%d' % n) links.append('image_%d.png' % n) html.add_images(ims, txts, links) html.save() ================================================ FILE: util/image_pool.py ================================================ import random import numpy as np import torch from torch.autograd import Variable class ImagePool(): def __init__(self, pool_size): self.pool_size = pool_size if self.pool_size > 0: self.num_imgs = 0 self.images = [] def query(self, images): if self.pool_size == 0: return images return_images = [] for image in images.data: image = torch.unsqueeze(image, 0) if self.num_imgs < self.pool_size: self.num_imgs = self.num_imgs + 1 self.images.append(image) return_images.append(image) else: p = random.uniform(0, 1) if p > 0.5: random_id = random.randint(0, self.pool_size-1) tmp = self.images[random_id].clone() self.images[random_id] = image return_images.append(tmp) else: return_images.append(image) return_images = Variable(torch.cat(return_images, 0)) return return_images ================================================ FILE: util/png.py ================================================ import struct import zlib def encode(buf, width, height): """ buf: must be bytes or a bytearray in py3, a regular string in py2. formatted RGBRGB... 
""" assert (width * height * 3 == len(buf)) bpp = 3 def raw_data(): # reverse the vertical line order and add null bytes at the start row_bytes = width * bpp for row_start in range((height - 1) * width * bpp, -1, -row_bytes): yield b'\x00' yield buf[row_start:row_start + row_bytes] def chunk(tag, data): return [ struct.pack("!I", len(data)), tag, data, struct.pack("!I", 0xFFFFFFFF & zlib.crc32(data, zlib.crc32(tag))) ] SIGNATURE = b'\x89PNG\r\n\x1a\n' COLOR_TYPE_RGB = 2 COLOR_TYPE_RGBA = 6 bit_depth = 8 return b''.join( [ SIGNATURE ] + chunk(b'IHDR', struct.pack("!2I5B", width, height, bit_depth, COLOR_TYPE_RGB, 0, 0, 0)) + chunk(b'IDAT', zlib.compress(b''.join(raw_data()), 9)) + chunk(b'IEND', b'') ) ================================================ FILE: util/util.py ================================================ from __future__ import print_function import torch import numpy as np from scipy.ndimage.filters import gaussian_filter from PIL import Image import inspect, re import os import collections # Converts a Tensor into a Numpy array # |imtype|: the desired type of the converted numpy array def tensor2im(image_tensor, imtype=np.uint8): image_numpy = image_tensor[0].cpu().float().numpy() image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0 return image_numpy.astype(imtype) def gkern_2d(size=5, sigma=3): # Create 2D gaussian kernel dirac = np.zeros((size, size)) dirac[size//2, size//2] = 1 mask = gaussian_filter(dirac, sigma) # Adjust dimensions for torch conv2d return np.stack([np.expand_dims(mask, axis=0)] * 3) def diagnose_network(net, name='network'): mean = 0.0 count = 0 for param in net.parameters(): if param.grad is not None: mean += torch.mean(torch.abs(param.grad.data)) count += 1 if count > 0: mean = mean / count print(name) print(mean) def save_image(image_numpy, image_path): image_pil = Image.fromarray(image_numpy) image_pil.save(image_path) def info(object, spacing=10, collapse=1): """Print methods and doc strings. 
Takes module, class, list, dictionary, or string.""" methodList = [e for e in dir(object) if callable(getattr(object, e))] processFunc = collapse and (lambda s: " ".join(s.split())) or (lambda s: s) print( "\n".join(["%s %s" % (method.ljust(spacing), processFunc(str(getattr(object, method).__doc__))) for method in methodList]) ) def varname(p): for line in inspect.getframeinfo(inspect.currentframe().f_back)[3]: m = re.search(r'\bvarname\s*\(\s*([A-Za-z_][A-Za-z0-9_]*)\s*\)', line) if m: return m.group(1) def print_numpy(x, val=True, shp=False): x = x.astype(np.float64) if shp: print('shape,', x.shape) if val: x = x.flatten() print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % ( np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x))) def mkdirs(paths): if isinstance(paths, list) and not isinstance(paths, str): for path in paths: mkdir(path) else: mkdir(paths) def mkdir(path): if not os.path.exists(path): os.makedirs(path) ================================================ FILE: util/visualizer.py ================================================ import numpy as np import os import ntpath import time from . import util from . import html class Visualizer(): def __init__(self, opt): # self.opt = opt self.display_id = opt.display_id self.use_html = opt.isTrain and not opt.no_html self.win_size = opt.display_winsize self.name = opt.name if self.display_id > 0: import visdom self.vis = visdom.Visdom(port = opt.display_port) self.display_single_pane_ncols = opt.display_single_pane_ncols if self.use_html: self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web') self.img_dir = os.path.join(self.web_dir, 'images') print('create web directory %s...' 
% self.web_dir) util.mkdirs([self.web_dir, self.img_dir]) self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt') with open(self.log_name, "a") as log_file: now = time.strftime("%c") log_file.write('================ Training Loss (%s) ================\n' % now) # |visuals|: dictionary of images to display or save def display_current_results(self, visuals, epoch): if self.display_id > 0: # show images in the browser if self.display_single_pane_ncols > 0: h, w = next(iter(visuals.values())).shape[:2] table_css = """<style> table {border-collapse: separate; border-spacing:4px; white-space:nowrap; text-align:center} table td {width: %dpx; height: %dpx; padding: 4px; outline: 4px solid black} </style>""" % (w, h) ncols = self.display_single_pane_ncols title = self.name label_html = '' label_html_row = '' nrows = int(np.ceil(len(visuals.items()) / ncols)) images = [] idx = 0 for label, image_numpy in visuals.items(): label_html_row += '<td>%s</td>' % label images.append(image_numpy.transpose([2, 0, 1])) idx += 1 if idx % ncols == 0: label_html += '<tr>%s</tr>' % label_html_row label_html_row = '' white_image = np.ones_like(image_numpy.transpose([2, 0, 1]))*255 while idx % ncols != 0: images.append(white_image) label_html_row += '<td></td>' idx += 1 if label_html_row != '': label_html += '<tr>%s</tr>' % label_html_row # pane col = image row self.vis.images(images, nrow=ncols, win=self.display_id + 1, padding=2, opts=dict(title=title + ' images')) label_html = '<table>%s</table>' % label_html
self.vis.text(table_css + label_html, win = self.display_id + 2, opts=dict(title=title + ' labels')) else: idx = 1 for label, image_numpy in visuals.items(): #image_numpy = np.flipud(image_numpy) self.vis.image(image_numpy.transpose([2,0,1]), opts=dict(title=label), win=self.display_id + idx) idx += 1 if self.use_html: # save images to a html file for label, image_numpy in visuals.items(): img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.png' % (epoch, label)) util.save_image(image_numpy, img_path) # update website webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, reflesh=1) for n in range(epoch, 0, -1): webpage.add_header('epoch [%d]' % n) ims = [] txts = [] links = [] for label, image_numpy in visuals.items(): img_path = 'epoch%.3d_%s.png' % (n, label) ims.append(img_path) txts.append(label) links.append(img_path) webpage.add_images(ims, txts, links, width=self.win_size) webpage.save() # errors: dictionary of error labels and values def plot_current_errors(self, epoch, counter_ratio, opt, errors): if not hasattr(self, 'plot_data'): self.plot_data = {'X':[],'Y':[], 'legend':list(errors.keys())} self.plot_data['X'].append(epoch + counter_ratio) self.plot_data['Y'].append([errors[k] for k in self.plot_data['legend']]) self.vis.line( X=np.stack([np.array(self.plot_data['X'])]*len(self.plot_data['legend']),1), Y=np.array(self.plot_data['Y']), opts={ 'title': self.name + ' loss over time', 'legend': self.plot_data['legend'], 'xlabel': 'epoch', 'ylabel': 'loss'}, win=self.display_id) # errors: same format as |errors| of plotCurrentErrors def print_current_errors(self, epoch, i, errors, t): message = '(epoch: %d, iters: %d, time: %.3f) ' % (epoch, i, t) for k, v in errors.items(): v = ['%.3f' % iv for iv in v] message += k + ': ' + ', '.join(v) + ' | ' print(message) with open(self.log_name, "a") as log_file: log_file.write('%s\n' % message) # save image to the disk def save_images(self, webpage, visuals, image_path): image_dir = 
webpage.get_image_dir() short_path = ntpath.basename(image_path[0]) name = os.path.splitext(short_path)[0] webpage.add_header(name) ims = [] txts = [] links = [] for label, image_numpy in visuals.items(): image_name = '%s_%s.png' % (name, label) save_path = os.path.join(image_dir, image_name) util.save_image(image_numpy, save_path) ims.append(image_name) txts.append(label) links.append(image_name) webpage.add_images(ims, txts, links, width=self.win_size) def save_image_matrix(self, visuals_list, save_path): images_list = [] get_domain = lambda x: x.split('_')[-1] for visuals in visuals_list: pairs = list(visuals.items()) real_label, real_img = pairs[0] real_dom = get_domain(real_label) for label, img in pairs: if 'fake' not in label: continue if get_domain(label) == real_dom: images_list.append(real_img) else: images_list.append(img) immat = self.stack_images(images_list) util.save_image(immat, save_path) # reshape a list of images into a square matrix of them def stack_images(self, list_np_images): n = int(np.ceil(np.sqrt(len(list_np_images)))) # add padding between images for i, im in enumerate(list_np_images): val = 255 if i%n == i//n else 0 r_pad = np.pad(im[:,:,0], (3,3), mode='constant', constant_values=0) g_pad = np.pad(im[:,:,1], (3,3), mode='constant', constant_values=val) b_pad = np.pad(im[:,:,2], (3,3), mode='constant', constant_values=0) list_np_images[i] = np.stack([r_pad,g_pad,b_pad], axis=2) data = np.array(list_np_images) data = data.reshape((n, n) + data.shape[1:]).transpose((0, 2, 1, 3) + tuple(range(4, data.ndim + 1))) data = data.reshape((n * data.shape[1], n * data.shape[3]) + data.shape[4:]) return data
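
The reshape/transpose step in `stack_images` above can be hard to follow when read inline, so here is a minimal standalone sketch of the same tiling idea. It is not part of the repository: `tile_square` and its `fill` parameter are hypothetical names, and unlike `stack_images` it pads the list with constant images when the count is not a perfect square and adds no coloured borders.

```python
import numpy as np


def tile_square(images, fill=0):
    """Tile equally sized HxWxC images into an n x n grid, row-major.

    Hypothetical helper for illustration only; pads with `fill`-valued
    images when len(images) is not a perfect square.
    """
    images = [np.asarray(im) for im in images]
    n = int(np.ceil(np.sqrt(len(images))))

    # Pad the list so that exactly n * n tiles are available.
    blank = np.full_like(images[0], fill)
    images = images + [blank] * (n * n - len(images))

    data = np.stack(images)                       # (n*n, H, W, C)
    h, w = data.shape[1:3]
    grid = data.reshape((n, n, h, w) + data.shape[3:])  # (row, col, H, W, C)
    grid = grid.swapaxes(1, 2)                    # (row, H, col, W, C)
    # Merging (row, H) and (col, W) yields one large image.
    return grid.reshape((n * h, n * w) + data.shape[3:])
```

Swapping the tile-column axis with the pixel-row axis before the final reshape is what places tile `i` at grid row `i // n` and grid column `i % n`; reshaping without the swap would interleave rows from different tiles.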