Repository: AAnoosheh/ComboGAN
Branch: master
Commit: 27643b6fb26b
Files: 28
Total size: 68.3 KB
Directory structure:
gitextract_lw_z0n5w/
├── .gitignore
├── LICENSE
├── README.md
├── data/
│   ├── __init__.py
│   ├── base_dataset.py
│   ├── data_loader.py
│   ├── image_folder.py
│   └── unaligned_dataset.py
├── models/
│   ├── __init__.py
│   ├── base_model.py
│   ├── combogan_model.py
│   └── networks.py
├── options/
│   ├── __init__.py
│   ├── base_options.py
│   ├── test_options.py
│   └── train_options.py
├── scripts/
│   ├── continue_combogan.sh
│   ├── test_combogan.sh
│   └── train_combogan.sh
├── test.py
├── train.py
└── util/
    ├── __init__.py
    ├── get_data.py
    ├── html.py
    ├── image_pool.py
    ├── png.py
    ├── util.py
    └── visualizer.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
datasets/
checkpoints/
results/
*.png
*/**/__pycache__
*/*.pyc
*/**/*.pyc
*/**/**/*.pyc
*/**/**/**/*.pyc
*/**/**/**/**/*.pyc
*/*.so*
*/**/*.so*
*/**/*.dylib*
*~
================================================
FILE: LICENSE
================================================
Copyright (c) 2017, Asha Anoosheh
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
================================================
FILE: README.md
================================================
# ComboGAN
This is our ongoing PyTorch implementation of ComboGAN.
Code was written by [Asha Anoosheh](https://github.com/aanoosheh) (built upon [CycleGAN](https://github.com/junyanz/CycleGAN)).
#### [[ComboGAN Paper]](https://arxiv.org/pdf/1712.06909.pdf)
If you use this code for your research, please cite:
ComboGAN: Unrestrained Scalability for Image Domain Translation
[Asha Anoosheh](http://ashaanoosheh.com), [Eirikur Augustsson](https://relational.github.io/), [Radu Timofte](http://www.vision.ee.ethz.ch/~timofter/), [Luc van Gool](https://www.vision.ee.ethz.ch/en/members/get_member.cgi?id=1)
In arXiv, 2017.
## Prerequisites
- Linux or macOS
- Python 3
- CPU or NVIDIA GPU + CUDA CuDNN
## Getting Started
### Installation
- Install PyTorch and dependencies from http://pytorch.org
- Install torchvision from source.
```bash
git clone https://github.com/pytorch/vision
cd vision
python setup.py install
```
- Install the Python libraries [visdom](https://github.com/facebookresearch/visdom) and [dominate](https://github.com/Knio/dominate).
```bash
pip install visdom
pip install dominate
```
- Clone this repo:
```bash
git clone https://github.com/AAnoosheh/ComboGAN.git
cd ComboGAN
```
### ComboGAN training
Our ready datasets can be downloaded using `./datasets/download_dataset.sh `.
A pretrained model for the 14-painters dataset can be found [HERE](https://www.dropbox.com/s/t8s6x0bu52d73s0/paint14_pretrained.zip?dl=0). Place under `./checkpoints/` and test using the instructions below, with args `--name paint14_pretrained --dataroot ./datasets/painters_14 --n_domains 14 --which_epoch 1150`.
Example running scripts can be found in the `scripts` directory.
- Train a model:
```
python train.py --name --dataroot ./datasets/ --n_domains --niter --niter_decay
```
Checkpoints will be saved by default to `./checkpoints//`
- Fine-tuning/Resume training:
```
python train.py --continue_train --which_epoch --name --dataroot ./datasets/ --n_domains --niter --niter_decay
```
- Test the model:
```
python test.py --phase test --name --dataroot ./datasets/ --n_domains --which_epoch --serial_test
```
The test results will be saved to an HTML file here: `./results///index.html`.
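The `--niter` and `--niter_decay` flags in the commands above control the learning-rate schedule: a constant rate for the first `niter` steps, then a linear decay to zero over the next `niter_decay` steps. The sketch below mirrors the arithmetic in `update_hyperparams` in `models/combogan_model.py`. The helper name `lr_at_epoch`, the assumption that the counter is an epoch index, and the example values are illustrative and not part of the repository.

```python
def lr_at_epoch(epoch, base_lr, niter, niter_decay):
    """Learning rate in effect after the update at `epoch` (hypothetical helper)."""
    if epoch <= niter:
        # Constant phase: the learning rate is left untouched.
        return base_lr
    # Linear decay phase: fraction of the decay window already used up.
    decay_frac = (epoch - niter) / niter_decay
    return base_lr * (1 - decay_frac)


# Example values only; the actual defaults live in the options files.
for epoch in (1, 100, 150, 200):
    print(epoch, lr_at_epoch(epoch, base_lr=0.0002, niter=100, niter_decay=100))
```

With these example values the rate stays at `base_lr` through epoch 100 and reaches zero at epoch 200.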
## Training/Testing Details
- Flags: see `options/train_options.py` for training-specific flags; see `options/test_options.py` for test-specific flags; and see `options/base_options.py` for all common flags.
- Dataset format: The desired data directory (provided by `--dataroot`) should contain subfolders of the form `train*/` and `test*/`, and they are loaded in alphabetical order. (Note that a folder named train10 would be loaded before train2, and thus all checkpoints and results would be ordered accordingly.)
- CPU/GPU (default `--gpu_ids 0`): set `--gpu_ids -1` to use CPU mode; set `--gpu_ids 0,1,2` for multi-GPU mode. You need a large batch size (e.g. `--batchSize 32`) to benefit from multiple GPUs.
- Visualization: during training, the current results and loss plots can be viewed using two methods. First, if you set `--display_id` > 0, the results and loss plot will appear on a local graphics web server launched by [visdom](https://github.com/facebookresearch/visdom). To do this, you should have `visdom` installed and a server running by the command `python -m visdom.server`. The default server URL is `http://localhost:8097`. `display_id` corresponds to the window ID that is displayed on the `visdom` server. The `visdom` display functionality is turned on by default. To avoid the extra overhead of communicating with `visdom` set `--display_id 0`. Secondly, the intermediate results are also saved to `./checkpoints//web/index.html`. To avoid this, set the `--no_html` flag.
- Preprocessing: images can be resized and cropped in different ways using the `--resize_or_crop` option. The default option `'resize_and_crop'` resizes the image to be of size `(opt.loadSize, opt.loadSize)` and does a random crop of size `(opt.fineSize, opt.fineSize)`. `'crop'` skips the resizing step and only performs random cropping. `'scale_width'` resizes the image to have width `opt.fineSize` while keeping the aspect ratio. `'scale_width_and_crop'` first resizes the image to have width `opt.loadSize` and then does random cropping of size `(opt.fineSize, opt.fineSize)`.
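To make the `'resize_and_crop'` description concrete, here is a minimal, geometry-only sketch of choosing a random crop window inside an already resized square image. It follows the wording of the preprocessing bullet; the actual pipeline is built from torchvision transforms in `get_transform` in `data/base_dataset.py`. The helper `random_crop_box` and the sizes 286 and 256 are assumptions chosen for illustration.

```python
import random


def random_crop_box(load_size, fine_size, rng=random):
    """Return (left, top, right, bottom) of a fine_size x fine_size crop
    inside a load_size x load_size image (hypothetical helper)."""
    if fine_size > load_size:
        raise ValueError("fine_size must not exceed load_size")
    # Any offset from 0 to load_size - fine_size keeps the crop inside the image.
    left = rng.randint(0, load_size - fine_size)
    top = rng.randint(0, load_size - fine_size)
    return left, top, left + fine_size, top + fine_size


# Example sizes only, standing in for opt.loadSize and opt.fineSize.
box = random_crop_box(load_size=286, fine_size=256)
print(box)
```

Because the offsets are drawn independently on each call, the same image can contribute a different crop every time it is sampled.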
NOTE: one should **not** expect ComboGAN to work on just any combination of input and output datasets (e.g. `dogs<->houses`). We find it works better if two datasets share similar visual content. For example, `landscape painting<->landscape photographs` works much better than `portrait painting <-> landscape photographs`.
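The folder-ordering caveat in the dataset-format bullet comes from plain string sorting: `data/unaligned_dataset.py` applies `sorted()` to the globbed folder paths, so domain indices follow lexicographic order rather than numeric order. A minimal sketch with made-up folder names (zero-padding the suffix is a suggested workaround, not something the repository enforces):

```python
# Hypothetical folder names under --dataroot.
dirs = ["train2", "train10", "train1"]

# Strings are compared character by character, so "train10" sorts before "train2".
domain_order = sorted(dirs)
print(domain_order)

# Zero-padding the numeric suffix makes string order match numeric order.
padded = sorted(["train02", "train10", "train01"])
print(padded)
```

The position of each folder in the sorted list is the domain index that appears in checkpoint and result labels.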
================================================
FILE: data/__init__.py
================================================
================================================
FILE: data/base_dataset.py
================================================
import torch.utils.data as data
from PIL import Image
import torchvision.transforms as transforms


class BaseDataset(data.Dataset):
    def __init__(self):
        super(BaseDataset, self).__init__()

    def name(self):
        return 'BaseDataset'

    def initialize(self, opt):
        pass


def get_transform(opt):
    transform_list = []
    if 'resize' in opt.resize_or_crop:
        transform_list.append(transforms.Resize(opt.loadSize, Image.BICUBIC))

    if opt.isTrain:
        if 'crop' in opt.resize_or_crop:
            transform_list.append(transforms.RandomCrop(opt.fineSize))
        if not opt.no_flip:
            transform_list.append(transforms.RandomHorizontalFlip())

    transform_list += [transforms.ToTensor(),
                       transforms.Normalize((0.5, 0.5, 0.5),
                                            (0.5, 0.5, 0.5))]
    return transforms.Compose(transform_list)
================================================
FILE: data/data_loader.py
================================================
import torch.utils.data
from data.unaligned_dataset import UnalignedDataset


class DataLoader():
    def name(self):
        return 'DataLoader'

    def __init__(self, opt):
        self.opt = opt
        self.dataset = UnalignedDataset(opt)
        self.dataloader = torch.utils.data.DataLoader(
            self.dataset,
            batch_size=opt.batchSize,
            num_workers=int(opt.nThreads))

    def __len__(self):
        return min(len(self.dataset), self.opt.max_dataset_size)

    def __iter__(self):
        for i, data in enumerate(self.dataloader):
            # i counts batches, while max_dataset_size counts samples
            if i * self.opt.batchSize >= self.opt.max_dataset_size:
                break
            yield data
================================================
FILE: data/image_folder.py
================================================
###############################################################################
# Code from
# https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py
# Modified the original code so that it also loads images from the current
# directory as well as the subdirectories
###############################################################################
import torch.utils.data as data
from PIL import Image
import os
import os.path

IMG_EXTENSIONS = [
    '.jpg', '.JPG', '.jpeg', '.JPEG',
    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
]


def is_image_file(filename):
    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)


def make_dataset(dir):
    images = []
    assert os.path.isdir(dir), '%s is not a valid directory' % dir

    for root, _, fnames in sorted(os.walk(dir)):
        for fname in fnames:
            if is_image_file(fname):
                path = os.path.join(root, fname)
                images.append(path)

    return images


def default_loader(path):
    return Image.open(path).convert('RGB')


class ImageFolder(data.Dataset):
    def __init__(self, root, transform=None, return_paths=False,
                 loader=default_loader):
        imgs = make_dataset(root)
        if len(imgs) == 0:
            raise(RuntimeError("Found 0 images in: " + root + "\n"
                               "Supported image extensions are: " +
                               ",".join(IMG_EXTENSIONS)))

        self.root = root
        self.imgs = imgs
        self.transform = transform
        self.return_paths = return_paths
        self.loader = loader

    def __getitem__(self, index):
        path = self.imgs[index]
        img = self.loader(path)
        if self.transform is not None:
            img = self.transform(img)
        if self.return_paths:
            return img, path
        else:
            return img

    def __len__(self):
        return len(self.imgs)
================================================
FILE: data/unaligned_dataset.py
================================================
import os.path, glob
import torchvision.transforms as transforms
from data.base_dataset import BaseDataset, get_transform
from data.image_folder import make_dataset
from PIL import Image
import random


class UnalignedDataset(BaseDataset):
    def __init__(self, opt):
        super(UnalignedDataset, self).__init__()
        self.opt = opt
        self.transform = get_transform(opt)

        datapath = os.path.join(opt.dataroot, opt.phase + '*')
        self.dirs = sorted(glob.glob(datapath))

        self.paths = [sorted(make_dataset(d)) for d in self.dirs]
        self.sizes = [len(p) for p in self.paths]

    def load_image(self, dom, idx):
        path = self.paths[dom][idx]
        img = Image.open(path).convert('RGB')
        img = self.transform(img)
        return img, path

    def __getitem__(self, index):
        if not self.opt.isTrain:
            if self.opt.serial_test:
                for d,s in enumerate(self.sizes):
                    if index < s:
                        DA = d; break
                    index -= s
                index_A = index
            else:
                DA = index % len(self.dirs)
                index_A = random.randint(0, self.sizes[DA] - 1)
        else:
            # Choose two of our domains to perform a pass on
            DA, DB = random.sample(range(len(self.dirs)), 2)
            index_A = random.randint(0, self.sizes[DA] - 1)

        A_img, A_path = self.load_image(DA, index_A)
        bundle = {'A': A_img, 'DA': DA, 'path': A_path}

        if self.opt.isTrain:
            index_B = random.randint(0, self.sizes[DB] - 1)
            B_img, _ = self.load_image(DB, index_B)
            bundle.update( {'B': B_img, 'DB': DB} )

        return bundle

    def __len__(self):
        if self.opt.isTrain:
            return max(self.sizes)
        return sum(self.sizes)

    def name(self):
        return 'UnalignedDataset'
================================================
FILE: models/__init__.py
================================================
================================================
FILE: models/base_model.py
================================================
import os
import torch


class BaseModel():
    def name(self):
        return 'BaseModel'

    def __init__(self, opt):
        self.opt = opt
        self.gpu_ids = opt.gpu_ids
        self.isTrain = opt.isTrain
        self.Tensor = torch.cuda.FloatTensor if self.gpu_ids else torch.Tensor
        self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)

    def set_input(self, input):
        self.input = input

    def forward(self):
        pass

    # used in test time, no backprop
    def test(self):
        pass

    def get_image_paths(self):
        pass

    def optimize_parameters(self):
        pass

    def get_current_visuals(self):
        return self.input

    def get_current_errors(self):
        return {}

    def save(self, label):
        pass

    # helper saving function that can be used by subclasses
    def save_network(self, network, network_label, epoch, gpu_ids):
        save_filename = '%d_net_%s' % (epoch, network_label)
        save_path = os.path.join(self.save_dir, save_filename)
        network.save(save_path)
        if gpu_ids and torch.cuda.is_available():
            network.cuda(gpu_ids[0])

    # helper loading function that can be used by subclasses
    def load_network(self, network, network_label, epoch):
        save_filename = '%d_net_%s' % (epoch, network_label)
        save_path = os.path.join(self.save_dir, save_filename)
        network.load(save_path)

    def update_learning_rate(self):
        pass
================================================
FILE: models/combogan_model.py
================================================
import numpy as np
import torch
from collections import OrderedDict
import util.util as util
from util.image_pool import ImagePool
from .base_model import BaseModel
from . import networks
class ComboGANModel(BaseModel):
    def name(self):
        return 'ComboGANModel'

    def __init__(self, opt):
        super(ComboGANModel, self).__init__(opt)

        self.n_domains = opt.n_domains
        self.DA, self.DB = None, None

        self.real_A = self.Tensor(opt.batchSize, opt.input_nc, opt.fineSize, opt.fineSize)
        self.real_B = self.Tensor(opt.batchSize, opt.output_nc, opt.fineSize, opt.fineSize)

        # load/define networks
        self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf,
                                      opt.netG_n_blocks, opt.netG_n_shared,
                                      self.n_domains, opt.norm, opt.use_dropout, self.gpu_ids)
        if self.isTrain:
            blur_fn = lambda x : torch.nn.functional.conv2d(x, self.Tensor(util.gkern_2d()), groups=3, padding=2)
            self.netD = networks.define_D(opt.output_nc, opt.ndf, opt.netD_n_layers,
                                          self.n_domains, blur_fn, opt.norm, self.gpu_ids)

        if not self.isTrain or opt.continue_train:
            which_epoch = opt.which_epoch
            self.load_network(self.netG, 'G', which_epoch)
            if self.isTrain:
                self.load_network(self.netD, 'D', which_epoch)

        if self.isTrain:
            self.fake_pools = [ImagePool(opt.pool_size) for _ in range(self.n_domains)]
            # define loss functions
            self.L1 = torch.nn.SmoothL1Loss()
            self.downsample = torch.nn.AvgPool2d(3, stride=2)
            self.criterionCycle = self.L1
            self.criterionIdt = lambda y,t : self.L1(self.downsample(y), self.downsample(t))
            self.criterionLatent = lambda y,t : self.L1(y, t.detach())
            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor)
            # initialize optimizers
            self.netG.init_optimizers(torch.optim.Adam, opt.lr, (opt.beta1, 0.999))
            self.netD.init_optimizers(torch.optim.Adam, opt.lr, (opt.beta1, 0.999))
            # initialize loss storage
            self.loss_D, self.loss_G = [0]*self.n_domains, [0]*self.n_domains
            self.loss_cycle = [0]*self.n_domains
            # initialize loss multipliers
            self.lambda_cyc, self.lambda_enc = opt.lambda_cycle, (0 * opt.lambda_latent)
            self.lambda_idt, self.lambda_fwd = opt.lambda_identity, opt.lambda_forward

        print('---------- Networks initialized -------------')
        print(self.netG)
        if self.isTrain:
            print(self.netD)
        print('-----------------------------------------------')
    def set_input(self, input):
        input_A = input['A']
        self.real_A.resize_(input_A.size()).copy_(input_A)
        self.DA = input['DA'][0]
        if self.isTrain:
            input_B = input['B']
            self.real_B.resize_(input_B.size()).copy_(input_B)
            self.DB = input['DB'][0]
        self.image_paths = input['path']

    def test(self):
        with torch.no_grad():
            self.visuals = [self.real_A]
            self.labels = ['real_%d' % self.DA]

            # cache the encoding so it is not recomputed for every target domain
            encoded = self.netG.encode(self.real_A, self.DA)
            for d in range(self.n_domains):
                if d == self.DA and not self.opt.autoencode:
                    continue
                fake = self.netG.decode(encoded, d)
                self.visuals.append( fake )
                self.labels.append( 'fake_%d' % d )
                if self.opt.reconstruct:
                    rec = self.netG.forward(fake, d, self.DA)
                    self.visuals.append( rec )
                    self.labels.append( 'rec_%d' % d )

    def get_image_paths(self):
        return self.image_paths
    def backward_D_basic(self, real, fake, domain):
        # Real
        pred_real = self.netD.forward(real, domain)
        loss_D_real = self.criterionGAN(pred_real, True)
        # Fake
        pred_fake = self.netD.forward(fake.detach(), domain)
        loss_D_fake = self.criterionGAN(pred_fake, False)
        # Combined loss
        loss_D = (loss_D_real + loss_D_fake) * 0.5
        # backward
        loss_D.backward()
        return loss_D

    def backward_D(self):
        #D_A
        fake_B = self.fake_pools[self.DB].query(self.fake_B)
        self.loss_D[self.DA] = self.backward_D_basic(self.real_B, fake_B, self.DB)
        #D_B
        fake_A = self.fake_pools[self.DA].query(self.fake_A)
        self.loss_D[self.DB] = self.backward_D_basic(self.real_A, fake_A, self.DA)
    def backward_G(self):
        encoded_A = self.netG.encode(self.real_A, self.DA)
        encoded_B = self.netG.encode(self.real_B, self.DB)

        # Optional identity "autoencode" loss
        if self.lambda_idt > 0:
            # Same encoder and decoder should recreate image
            idt_A = self.netG.decode(encoded_A, self.DA)
            loss_idt_A = self.criterionIdt(idt_A, self.real_A)
            idt_B = self.netG.decode(encoded_B, self.DB)
            loss_idt_B = self.criterionIdt(idt_B, self.real_B)
        else:
            loss_idt_A, loss_idt_B = 0, 0

        # GAN loss
        # D_A(G_A(A))
        self.fake_B = self.netG.decode(encoded_A, self.DB)
        pred_fake = self.netD.forward(self.fake_B, self.DB)
        self.loss_G[self.DA] = self.criterionGAN(pred_fake, True)
        # D_B(G_B(B))
        self.fake_A = self.netG.decode(encoded_B, self.DA)
        pred_fake = self.netD.forward(self.fake_A, self.DA)
        self.loss_G[self.DB] = self.criterionGAN(pred_fake, True)
        # Forward cycle loss
        rec_encoded_A = self.netG.encode(self.fake_B, self.DB)
        self.rec_A = self.netG.decode(rec_encoded_A, self.DA)
        self.loss_cycle[self.DA] = self.criterionCycle(self.rec_A, self.real_A)
        # Backward cycle loss
        rec_encoded_B = self.netG.encode(self.fake_A, self.DA)
        self.rec_B = self.netG.decode(rec_encoded_B, self.DB)
        self.loss_cycle[self.DB] = self.criterionCycle(self.rec_B, self.real_B)

        # Optional cycle loss on encoding space
        if self.lambda_enc > 0:
            loss_enc_A = self.criterionLatent(rec_encoded_A, encoded_A)
            loss_enc_B = self.criterionLatent(rec_encoded_B, encoded_B)
        else:
            loss_enc_A, loss_enc_B = 0, 0

        # Optional loss on downsampled image before and after
        if self.lambda_fwd > 0:
            loss_fwd_A = self.criterionIdt(self.fake_B, self.real_A)
            loss_fwd_B = self.criterionIdt(self.fake_A, self.real_B)
        else:
            loss_fwd_A, loss_fwd_B = 0, 0

        # combined loss
        loss_G = self.loss_G[self.DA] + self.loss_G[self.DB] + \
                 (self.loss_cycle[self.DA] + self.loss_cycle[self.DB]) * self.lambda_cyc + \
                 (loss_idt_A + loss_idt_B) * self.lambda_idt + \
                 (loss_enc_A + loss_enc_B) * self.lambda_enc + \
                 (loss_fwd_A + loss_fwd_B) * self.lambda_fwd
        loss_G.backward()
    def optimize_parameters(self):
        # G_A and G_B
        self.netG.zero_grads(self.DA, self.DB)
        self.backward_G()
        self.netG.step_grads(self.DA, self.DB)
        # D_A and D_B
        self.netD.zero_grads(self.DA, self.DB)
        self.backward_D()
        self.netD.step_grads(self.DA, self.DB)

    def get_current_errors(self):
        extract = lambda l: [(i if type(i) is int or type(i) is float else i.item()) for i in l]
        D_losses, G_losses, cyc_losses = extract(self.loss_D), extract(self.loss_G), extract(self.loss_cycle)
        return OrderedDict([('D', D_losses), ('G', G_losses), ('Cyc', cyc_losses)])

    def get_current_visuals(self, testing=False):
        if not testing:
            self.visuals = [self.real_A, self.fake_B, self.rec_A, self.real_B, self.fake_A, self.rec_B]
            self.labels = ['real_A', 'fake_B', 'rec_A', 'real_B', 'fake_A', 'rec_B']
        images = [util.tensor2im(v.data) for v in self.visuals]
        return OrderedDict(zip(self.labels, images))

    def save(self, label):
        self.save_network(self.netG, 'G', label, self.gpu_ids)
        self.save_network(self.netD, 'D', label, self.gpu_ids)

    def update_hyperparams(self, curr_iter):
        if curr_iter > self.opt.niter:
            decay_frac = (curr_iter - self.opt.niter) / self.opt.niter_decay
            new_lr = self.opt.lr * (1 - decay_frac)
            self.netG.update_lr(new_lr)
            self.netD.update_lr(new_lr)
            print('updated learning rate: %f' % new_lr)

        if self.opt.lambda_latent > 0:
            decay_frac = curr_iter / (self.opt.niter + self.opt.niter_decay)
            self.lambda_enc = self.opt.lambda_latent * decay_frac
================================================
FILE: models/networks.py
================================================
import torch
import torch.nn as nn
from torch.nn import init
import functools, itertools
import numpy as np
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
        if hasattr(m.bias, 'data'):
            m.bias.data.fill_(0)
    elif classname.find('BatchNorm2d') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)


def get_norm_layer(norm_type='instance'):
    if norm_type == 'batch':
        return functools.partial(nn.BatchNorm2d, affine=True)
    elif norm_type == 'instance':
        return functools.partial(nn.InstanceNorm2d, affine=False)
    else:
        raise NotImplementedError('normalization layer [%s] is not found' % norm_type)


def define_G(input_nc, output_nc, ngf, n_blocks, n_blocks_shared, n_domains, norm='batch', use_dropout=False, gpu_ids=[]):
    norm_layer = get_norm_layer(norm_type=norm)
    if type(norm_layer) == functools.partial:
        use_bias = norm_layer.func == nn.InstanceNorm2d
    else:
        use_bias = norm_layer == nn.InstanceNorm2d

    n_blocks -= n_blocks_shared
    n_blocks_enc = n_blocks // 2
    n_blocks_dec = n_blocks - n_blocks_enc

    dup_args = (ngf, norm_layer, use_dropout, gpu_ids, use_bias)
    enc_args = (input_nc, n_blocks_enc) + dup_args
    dec_args = (output_nc, n_blocks_dec) + dup_args

    if n_blocks_shared > 0:
        n_blocks_shdec = n_blocks_shared // 2
        n_blocks_shenc = n_blocks_shared - n_blocks_shdec
        shenc_args = (n_domains, n_blocks_shenc) + dup_args
        shdec_args = (n_domains, n_blocks_shdec) + dup_args
        plex_netG = G_Plexer(n_domains, ResnetGenEncoder, enc_args, ResnetGenDecoder, dec_args, ResnetGenShared, shenc_args, shdec_args)
    else:
        plex_netG = G_Plexer(n_domains, ResnetGenEncoder, enc_args, ResnetGenDecoder, dec_args)

    if len(gpu_ids) > 0:
        assert(torch.cuda.is_available())
        plex_netG.cuda(gpu_ids[0])

    plex_netG.apply(weights_init)
    return plex_netG


def define_D(input_nc, ndf, netD_n_layers, n_domains, blur_fn, norm='batch', gpu_ids=[]):
    norm_layer = get_norm_layer(norm_type=norm)

    model_args = (input_nc, ndf, netD_n_layers, blur_fn, norm_layer, gpu_ids)
    plex_netD = D_Plexer(n_domains, NLayerDiscriminator, model_args)

    if len(gpu_ids) > 0:
        assert(torch.cuda.is_available())
        plex_netD.cuda(gpu_ids[0])

    plex_netD.apply(weights_init)
    return plex_netD
##############################################################################
# Classes
##############################################################################
# Defines the GAN loss which uses either LSGAN or the regular GAN.
# When LSGAN is used, it is basically same as MSELoss,
# but it abstracts away the need to create the target label tensor
# that has the same size as the input
class GANLoss(nn.Module):
    def __init__(self, use_lsgan=True, tensor=torch.FloatTensor):
        super(GANLoss, self).__init__()
        self.Tensor = tensor
        self.labels_real, self.labels_fake = None, None
        self.preloss = nn.Sigmoid() if not use_lsgan else None
        self.loss = nn.MSELoss() if use_lsgan else nn.BCELoss()

    def get_target_tensor(self, inputs, is_real):
        if self.labels_real is None or self.labels_real[0].numel() != inputs[0].numel():
            self.labels_real = [ self.Tensor(input.size()).fill_(1.0) for input in inputs ]
            self.labels_fake = [ self.Tensor(input.size()).fill_(0.0) for input in inputs ]
        if is_real:
            return self.labels_real
        return self.labels_fake

    def __call__(self, inputs, is_real):
        labels = self.get_target_tensor(inputs, is_real)
        if self.preloss is not None:
            inputs = [self.preloss(input) for input in inputs]
        losses = [self.loss(input, label) for input, label in zip(inputs, labels)]
        multipliers = list(range(1, len(inputs)+1)); multipliers[-1] += 1
        losses = [m*l for m,l in zip(multipliers, losses)]
        return sum(losses) / (sum(multipliers) * len(losses))
# Defines the generator that consists of Resnet blocks between a few
# downsampling/upsampling operations.
# Code and idea originally from Justin Johnson's architecture.
# https://github.com/jcjohnson/fast-neural-style/
class ResnetGenEncoder(nn.Module):
    def __init__(self, input_nc, n_blocks=4, ngf=64, norm_layer=nn.BatchNorm2d,
                 use_dropout=False, gpu_ids=[], use_bias=False, padding_type='reflect'):
        assert(n_blocks >= 0)
        super(ResnetGenEncoder, self).__init__()
        self.gpu_ids = gpu_ids

        model = [nn.ReflectionPad2d(3),
                 nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0,
                           bias=use_bias),
                 norm_layer(ngf),
                 nn.PReLU()]

        n_downsampling = 2
        for i in range(n_downsampling):
            mult = 2**i
            model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3,
                                stride=2, padding=1, bias=use_bias),
                      norm_layer(ngf * mult * 2),
                      nn.PReLU()]

        mult = 2**n_downsampling
        for _ in range(n_blocks):
            model += [ResnetBlock(ngf * mult, norm_layer=norm_layer,
                                  use_dropout=use_dropout, use_bias=use_bias, padding_type=padding_type)]

        self.model = nn.Sequential(*model)

    def forward(self, input):
        if self.gpu_ids and isinstance(input.data, torch.cuda.FloatTensor):
            return nn.parallel.data_parallel(self.model, input, self.gpu_ids)
        return self.model(input)
class ResnetGenShared(nn.Module):
    def __init__(self, n_domains, n_blocks=2, ngf=64, norm_layer=nn.BatchNorm2d,
                 use_dropout=False, gpu_ids=[], use_bias=False, padding_type='reflect'):
        assert(n_blocks >= 0)
        super(ResnetGenShared, self).__init__()
        self.gpu_ids = gpu_ids

        model = []
        n_downsampling = 2
        mult = 2**n_downsampling

        for _ in range(n_blocks):
            model += [ResnetBlock(ngf * mult, norm_layer=norm_layer, n_domains=n_domains,
                                  use_dropout=use_dropout, use_bias=use_bias, padding_type=padding_type)]

        self.model = SequentialContext(n_domains, *model)

    def forward(self, input, domain):
        if self.gpu_ids and isinstance(input.data, torch.cuda.FloatTensor):
            return nn.parallel.data_parallel(self.model, (input, domain), self.gpu_ids)
        return self.model(input, domain)
class ResnetGenDecoder(nn.Module):
    def __init__(self, output_nc, n_blocks=5, ngf=64, norm_layer=nn.BatchNorm2d,
                 use_dropout=False, gpu_ids=[], use_bias=False, padding_type='reflect'):
        assert(n_blocks >= 0)
        super(ResnetGenDecoder, self).__init__()
        self.gpu_ids = gpu_ids

        model = []
        n_downsampling = 2
        mult = 2**n_downsampling

        for _ in range(n_blocks):
            model += [ResnetBlock(ngf * mult, norm_layer=norm_layer,
                                  use_dropout=use_dropout, use_bias=use_bias, padding_type=padding_type)]

        for i in range(n_downsampling):
            mult = 2**(n_downsampling - i)
            model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
                                         kernel_size=4, stride=2,
                                         padding=1, output_padding=0,
                                         bias=use_bias),
                      norm_layer(int(ngf * mult / 2)),
                      nn.PReLU()]

        model += [nn.ReflectionPad2d(3),
                  nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0),
                  nn.Tanh()]

        self.model = nn.Sequential(*model)

    def forward(self, input):
        if self.gpu_ids and isinstance(input.data, torch.cuda.FloatTensor):
            return nn.parallel.data_parallel(self.model, input, self.gpu_ids)
        return self.model(input)
# Define a resnet block
class ResnetBlock(nn.Module):
    def __init__(self, dim, norm_layer, use_dropout, use_bias, padding_type='reflect', n_domains=0):
        super(ResnetBlock, self).__init__()

        conv_block = []
        p = 0
        if padding_type == 'reflect':
            conv_block += [nn.ReflectionPad2d(1)]
        elif padding_type == 'replicate':
            conv_block += [nn.ReplicationPad2d(1)]
        elif padding_type == 'zero':
            p = 1
        else:
            raise NotImplementedError('padding [%s] is not implemented' % padding_type)

        conv_block += [nn.Conv2d(dim + n_domains, dim, kernel_size=3, padding=p, bias=use_bias),
                       norm_layer(dim),
                       nn.PReLU()]
        if use_dropout:
            conv_block += [nn.Dropout(0.5)]

        p = 0
        if padding_type == 'reflect':
            conv_block += [nn.ReflectionPad2d(1)]
        elif padding_type == 'replicate':
            conv_block += [nn.ReplicationPad2d(1)]
        elif padding_type == 'zero':
            p = 1
        else:
            raise NotImplementedError('padding [%s] is not implemented' % padding_type)

        conv_block += [nn.Conv2d(dim + n_domains, dim, kernel_size=3, padding=p, bias=use_bias),
                       norm_layer(dim)]

        self.conv_block = SequentialContext(n_domains, *conv_block)

    def forward(self, input):
        if isinstance(input, tuple):
            return input[0] + self.conv_block(*input)
        return input + self.conv_block(input)
# Defines the PatchGAN discriminator with the specified arguments.
class NLayerDiscriminator(nn.Module):
    def __init__(self, input_nc, ndf=64, n_layers=3, blur_fn=None, norm_layer=nn.BatchNorm2d, gpu_ids=[]):
        super(NLayerDiscriminator, self).__init__()
        self.gpu_ids = gpu_ids

        self.blur_fn = blur_fn
        self.gray_fn = lambda x: (.299*x[:,0,:,:] + .587*x[:,1,:,:] + .114*x[:,2,:,:]).unsqueeze_(1)

        self.model_gray = self.model(1, ndf, n_layers, norm_layer)
        self.model_rgb = self.model(input_nc, ndf, n_layers, norm_layer)

    def model(self, input_nc, ndf, n_layers, norm_layer):
        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d

        kw = 4
        padw = int(np.ceil((kw-1)/2))
        sequences = [[
            nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
            nn.PReLU()
        ]]

        nf_mult = 1
        nf_mult_prev = 1
        for n in range(1, n_layers):
            nf_mult_prev = nf_mult
            nf_mult = min(2**n, 8)
            sequences += [[
                nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult + 1,
                          kernel_size=kw, stride=2, padding=padw, bias=use_bias),
                norm_layer(ndf * nf_mult + 1),
                nn.PReLU()
            ]]

        nf_mult_prev = nf_mult
        nf_mult = min(2**n_layers, 8)
        sequences += [[
            nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
                      kernel_size=kw, stride=1, padding=padw, bias=use_bias),
            norm_layer(ndf * nf_mult),
            nn.PReLU(),
            nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)
        ]]

        return SequentialOutput(*sequences)

    def forward(self, input):
        luminance, blurred_rgb = self.gray_fn(input), self.blur_fn(input)
        if len(self.gpu_ids) and isinstance(input.data, torch.cuda.FloatTensor):
            outs1 = nn.parallel.data_parallel(self.model_gray, luminance, self.gpu_ids)
            outs2 = nn.parallel.data_parallel(self.model_rgb, blurred_rgb, self.gpu_ids)
        else:
            outs1 = self.model_gray(luminance)
            outs2 = self.model_rgb(blurred_rgb)
        return [torch.cat([o1,o2], 1) for o1,o2 in zip(outs1, outs2)]
class Plexer(nn.Module):
    def __init__(self):
        super(Plexer, self).__init__()

    def apply(self, func):
        for net in self.networks:
            net.apply(func)

    def cuda(self, device_id):
        for net in self.networks:
            net.cuda(device_id)

    def init_optimizers(self, opt, lr, betas):
        self.optimizers = [opt(net.parameters(), lr=lr, betas=betas) \
                           for net in self.networks]

    def zero_grads(self, dom_a, dom_b):
        self.optimizers[dom_a].zero_grad()
        self.optimizers[dom_b].zero_grad()

    def step_grads(self, dom_a, dom_b):
        self.optimizers[dom_a].step()
        self.optimizers[dom_b].step()

    def update_lr(self, new_lr):
        for opt in self.optimizers:
            for param_group in opt.param_groups:
                param_group['lr'] = new_lr

    def save(self, save_path):
        for i, net in enumerate(self.networks):
            filename = save_path + ('%d.pth' % i)
            torch.save(net.cpu().state_dict(), filename)

    def load(self, save_path):
        for i, net in enumerate(self.networks):
            filename = save_path + ('%d.pth' % i)
            net.load_state_dict(torch.load(filename))
class G_Plexer(Plexer):
def __init__(self, n_domains, encoder, enc_args, decoder, dec_args,
block=None, shenc_args=None, shdec_args=None):
super(G_Plexer, self).__init__()
self.encoders = [encoder(*enc_args) for _ in range(n_domains)]
self.decoders = [decoder(*dec_args) for _ in range(n_domains)]
self.sharing = block is not None
if self.sharing:
self.shared_encoder = block(*shenc_args)
self.shared_decoder = block(*shdec_args)
self.encoders.append( self.shared_encoder )
self.decoders.append( self.shared_decoder )
self.networks = self.encoders + self.decoders
def init_optimizers(self, opt, lr, betas):
self.optimizers = []
for enc, dec in zip(self.encoders, self.decoders):
params = itertools.chain(enc.parameters(), dec.parameters())
self.optimizers.append( opt(params, lr=lr, betas=betas) )
def forward(self, input, in_domain, out_domain):
encoded = self.encode(input, in_domain)
return self.decode(encoded, out_domain)
def encode(self, input, domain):
output = self.encoders[domain].forward(input)
if self.sharing:
return self.shared_encoder.forward(output, domain)
return output
def decode(self, input, domain):
if self.sharing:
input = self.shared_decoder.forward(input, domain)
return self.decoders[domain].forward(input)
def zero_grads(self, dom_a, dom_b):
self.optimizers[dom_a].zero_grad()
if self.sharing:
self.optimizers[-1].zero_grad()
self.optimizers[dom_b].zero_grad()
def step_grads(self, dom_a, dom_b):
self.optimizers[dom_a].step()
if self.sharing:
self.optimizers[-1].step()
self.optimizers[dom_b].step()
def __repr__(self):
e, d = self.encoders[0], self.decoders[0]
e_params = sum([p.numel() for p in e.parameters()])
d_params = sum([p.numel() for p in d.parameters()])
return repr(e) +'\n'+ repr(d) +'\n'+ \
'Created %d Encoder-Decoder pairs' % len(self.encoders) +'\n'+ \
'Number of parameters per Encoder: %d' % e_params +'\n'+ \
'Number of parameters per Decoder: %d' % d_params
class D_Plexer(Plexer):
def __init__(self, n_domains, model, model_args):
super(D_Plexer, self).__init__()
self.networks = [model(*model_args) for _ in range(n_domains)]
def forward(self, input, domain):
discriminator = self.networks[domain]
return discriminator.forward(input)
def __repr__(self):
t = self.networks[0]
t_params = sum([p.numel() for p in t.parameters()])
return repr(t) +'\n'+ \
'Created %d Discriminators' % len(self.networks) +'\n'+ \
'Number of parameters per Discriminator: %d' % t_params
class SequentialContext(nn.Sequential):
def __init__(self, n_classes, *args):
super(SequentialContext, self).__init__(*args)
self.n_classes = n_classes
self.context_var = None
def prepare_context(self, input, domain):
if self.context_var is None or self.context_var.size()[-2:] != input.size()[-2:]:
tensor = torch.cuda.FloatTensor if isinstance(input.data, torch.cuda.FloatTensor) \
else torch.FloatTensor
self.context_var = tensor(*((1, self.n_classes) + input.size()[-2:]))
self.context_var.data.fill_(-1.0)
self.context_var.data[:,domain,:,:] = 1.0
return self.context_var
def forward(self, *input):
if self.n_classes < 2 or len(input) < 2:
return super(SequentialContext, self).forward(input[0])
x, domain = input
for module in self._modules.values():
if 'Conv' in module.__class__.__name__:
context_var = self.prepare_context(x, domain)
x = torch.cat([x, context_var], dim=1)
elif 'Block' in module.__class__.__name__:
x = (x,) + input[1:]
x = module(x)
return x
class SequentialOutput(nn.Sequential):
def __init__(self, *args):
args = [nn.Sequential(*arg) for arg in args]
super(SequentialOutput, self).__init__(*args)
def forward(self, input):
predictions = []
layers = self._modules.values()
for i, module in enumerate(layers):
output = module(input)
if i == 0:
input = output; continue
predictions.append( output[:,-1,:,:] )
if i != len(layers) - 1:
input = output[:,:-1,:,:]
return predictions
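The routing done by `G_Plexer.forward` above (encode with the input domain's encoder, decode with the output domain's decoder) can be illustrated without PyTorch. This is only a toy sketch: `ToyGPlexer` and `_make_stage` are hypothetical names, the "networks" are plain functions that record which stage ran, and the optional shared block is left out.

```python
class ToyGPlexer:
    """Routes an input through encoders[in_domain] and then decoders[out_domain]."""

    def __init__(self, n_domains):
        # Each "network" is a plain function that appends its own tag to a trace list.
        self.encoders = [self._make_stage('E', d) for d in range(n_domains)]
        self.decoders = [self._make_stage('D', d) for d in range(n_domains)]

    @staticmethod
    def _make_stage(kind, domain):
        return lambda trace: trace + ['%s%d' % (kind, domain)]

    def forward(self, trace, in_domain, out_domain):
        encoded = self.encoders[in_domain](trace)
        return self.decoders[out_domain](encoded)


plexer = ToyGPlexer(n_domains=4)
translated = plexer.forward([], 0, 3)
print(translated)                        # ['E0', 'D3']
print(plexer.forward(translated, 3, 0))  # ['E0', 'D3', 'E3', 'D0']
```

The second call mirrors a cycle (domain 0 to 3 and back to 0), which touches only the encoder and decoder of the two domains involved.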
================================================
FILE: options/__init__.py
================================================
================================================
FILE: options/base_options.py
================================================
import argparse
import os
from util import util
import torch
class BaseOptions():
def __init__(self):
self.parser = argparse.ArgumentParser()
self.initialized = False
def initialize(self):
self.parser.add_argument('--name', required=True, type=str, help='name of the experiment. It decides where to store samples and models')
self.parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')
self.parser.add_argument('--dataroot', required=True, type=str, help='path to images (should have subfolders trainA, trainB, valA, valB, etc)')
self.parser.add_argument('--n_domains', required=True, type=int, help='Number of domains to transfer among')
self.parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.')
self.parser.add_argument('--resize_or_crop', type=str, default='resize_and_crop', help='scaling and cropping of images at load time [resize|resize_and_crop|crop]')
self.parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation')
self.parser.add_argument('--loadSize', type=int, default=286, help='scale images to this size')
self.parser.add_argument('--fineSize', type=int, default=256, help='then crop to this size')
self.parser.add_argument('--batchSize', type=int, default=1, help='input batch size')
self.parser.add_argument('--input_nc', type=int, default=3, help='# of input image channels')
self.parser.add_argument('--output_nc', type=int, default=3, help='# of output image channels')
self.parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in first conv layer')
self.parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in first conv layer')
self.parser.add_argument('--netG_n_blocks', type=int, default=9, help='number of residual blocks to use for netG')
self.parser.add_argument('--netG_n_shared', type=int, default=0, help='number of blocks to use for netG shared center module')
self.parser.add_argument('--netD_n_layers', type=int, default=4, help='number of layers to use for netD')
self.parser.add_argument('--norm', type=str, default='instance', help='instance normalization or batch normalization')
self.parser.add_argument('--use_dropout', action='store_true', help='insert dropout for the generator')
self.parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids, e.g. "0", "0,1,2" or "0,2"; use -1 for CPU')
self.parser.add_argument('--nThreads', default=2, type=int, help='# threads for loading data')
self.parser.add_argument('--display_id', type=int, default=0, help='window id of the web display (set > 0 to use visdom)')
self.parser.add_argument('--display_port', type=int, default=8097, help='visdom port of the web display')
self.parser.add_argument('--display_winsize', type=int, default=256, help='display window size')
self.parser.add_argument('--display_single_pane_ncols', type=int, default=0, help='if positive, display all images in a single visdom web panel with certain number of images per row.')
self.initialized = True
def parse(self):
if not self.initialized:
self.initialize()
self.opt = self.parser.parse_args()
self.opt.isTrain = self.isTrain # train or test
str_ids = self.opt.gpu_ids.split(',')
self.opt.gpu_ids = []
for str_id in str_ids:
id = int(str_id)
if id >= 0:
self.opt.gpu_ids.append(id)
# set gpu ids
if len(self.opt.gpu_ids) > 0:
torch.cuda.set_device(self.opt.gpu_ids[0])
args = vars(self.opt)
print('------------ Options -------------')
for k, v in sorted(args.items()):
print('%s: %s' % (str(k), str(v)))
print('-------------- End ----------------')
# save to the disk
expr_dir = os.path.join(self.opt.checkpoints_dir, self.opt.name)
util.mkdirs(expr_dir)
file_name = os.path.join(expr_dir, 'opt.txt')
with open(file_name, 'wt') as opt_file:
opt_file.write('------------ Options -------------\n')
for k, v in sorted(args.items()):
opt_file.write('%s: %s\n' % (str(k), str(v)))
opt_file.write('-------------- End ----------------\n')
return self.opt
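The `--gpu_ids` handling in `parse()` above can be tried in isolation, without argparse or CUDA. A minimal sketch; the helper name `parse_gpu_ids` is hypothetical and does not exist in the repository.

```python
def parse_gpu_ids(gpu_ids):
    """Turn a string such as '0,2' into [0, 2]; negative ids (e.g. '-1') are dropped."""
    ids = []
    for str_id in gpu_ids.split(','):
        gpu_id = int(str_id)
        if gpu_id >= 0:
            ids.append(gpu_id)
    return ids


print(parse_gpu_ids('0,1,2'))  # [0, 1, 2]
print(parse_gpu_ids('-1'))     # [] (empty list, so no CUDA device is selected)
```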
================================================
FILE: options/test_options.py
================================================
from .base_options import BaseOptions
class TestOptions(BaseOptions):
def initialize(self):
BaseOptions.initialize(self)
self.isTrain = False
self.parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.')
self.parser.add_argument('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images')
self.parser.add_argument('--which_epoch', required=True, type=int, help='which epoch to load for inference?')
self.parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc (determines name of folder to load from)')
self.parser.add_argument('--how_many', type=int, default=50, help='how many test images to run (if serial_test not enabled)')
self.parser.add_argument('--serial_test', action='store_true', help='read each image once from folders in sequential order')
self.parser.add_argument('--autoencode', action='store_true', help='translate each image back into its own domain')
self.parser.add_argument('--reconstruct', action='store_true', help='do reconstructions of images during testing')
self.parser.add_argument('--show_matrix', action='store_true', help='visualize images in a matrix format as well')
================================================
FILE: options/train_options.py
================================================
from .base_options import BaseOptions
class TrainOptions(BaseOptions):
def initialize(self):
BaseOptions.initialize(self)
self.isTrain = True
self.parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model')
self.parser.add_argument('--which_epoch', type=int, default=0, help='which epoch to load if continuing training')
self.parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc (determines name of folder to load from)')
self.parser.add_argument('--niter', required=True, type=int, help='# of epochs at starting learning rate (try 50*n_domains)')
self.parser.add_argument('--niter_decay', required=True, type=int, help='# of epochs to linearly decay learning rate to zero (try 50*n_domains)')
self.parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for ADAM')
self.parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of ADAM')
self.parser.add_argument('--lambda_cycle', type=float, default=10.0, help='weight for cycle loss (A -> B -> A)')
self.parser.add_argument('--lambda_identity', type=float, default=0.0, help='weight for identity "autoencode" mapping (A -> A)')
self.parser.add_argument('--lambda_latent', type=float, default=0.0, help='weight for latent-space loss (A -> z -> B -> z)')
self.parser.add_argument('--lambda_forward', type=float, default=0.0, help='weight for forward loss (A -> B; try 0.2)')
self.parser.add_argument('--save_epoch_freq', type=int, default=5, help='frequency of saving checkpoints at the end of epochs')
self.parser.add_argument('--display_freq', type=int, default=100, help='frequency of showing training results on screen')
self.parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console')
self.parser.add_argument('--no_lsgan', action='store_true', help='use vanilla discriminator in place of least-squares one')
self.parser.add_argument('--pool_size', type=int, default=50, help='the size of image buffer that stores previously generated images')
self.parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/')
================================================
FILE: scripts/continue_combogan.sh
================================================
python train.py \
--dataroot ./datasets/alps \
--name alps_combogan \
--continue_train \
--which_epoch 117 \
--n_domains 4 \
--niter 200 \
--niter_decay 200 \
--lambda_identity 0.0 \
--lambda_forward 0.0
================================================
FILE: scripts/test_combogan.sh
================================================
python test.py \
--phase test \
--dataroot ./datasets/alps \
--name alps_combogan \
--n_domains 4 \
--which_epoch 400 \
--show_matrix
================================================
FILE: scripts/train_combogan.sh
================================================
python train.py \
--dataroot ./datasets/alps \
--name alps_combogan \
--n_domains 4 \
--niter 200 \
--niter_decay 200 \
--lambda_identity 0.0 \
--lambda_forward 0.0
================================================
FILE: test.py
================================================
import time
import os
from options.test_options import TestOptions
from data.data_loader import DataLoader
from models.combogan_model import ComboGANModel
from util.visualizer import Visualizer
from util import html
opt = TestOptions().parse()
opt.nThreads = 1 # test code only supports nThreads = 1
opt.batchSize = 1 # test code only supports batchSize = 1
dataset = DataLoader(opt)
model = ComboGANModel(opt)
visualizer = Visualizer(opt)
# create website
web_dir = os.path.join(opt.results_dir, opt.name, '%s_%d' % (opt.phase, opt.which_epoch))
webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %d' % (opt.name, opt.phase, opt.which_epoch))
# store images for matrix visualization
vis_buffer = []
# test
for i, data in enumerate(dataset):
if not opt.serial_test and i >= opt.how_many:
break
model.set_input(data)
model.test()
visuals = model.get_current_visuals(testing=True)
img_path = model.get_image_paths()
print('process image... %s' % img_path)
visualizer.save_images(webpage, visuals, img_path)
if opt.show_matrix:
vis_buffer.append(visuals)
if (i+1) % opt.n_domains == 0:
save_path = os.path.join(web_dir, 'mat_%d.png' % (i//opt.n_domains))
visualizer.save_image_matrix(vis_buffer, save_path)
vis_buffer.clear()
webpage.save()
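The `--show_matrix` branch above buffers results and flushes them every `n_domains` iterations. The sketch below reproduces only that counting logic with plain Python values; `group_for_matrix` is a hypothetical helper, and the items stand in for the `visuals` dictionaries.

```python
def group_for_matrix(items, n_domains):
    """Collect items into groups of n_domains, numbered like the mat_%d.png files."""
    buffer, matrices = [], []
    for i, item in enumerate(items):
        buffer.append(item)
        if (i + 1) % n_domains == 0:
            matrices.append((i // n_domains, list(buffer)))
            buffer.clear()
    # Anything still in `buffer` here never formed a full group and is not emitted.
    return matrices


print(group_for_matrix(list(range(7)), 3))  # [(0, [0, 1, 2]), (1, [3, 4, 5])]
```

As in the loop above, a trailing partial group (item 6 in this example) produces no matrix.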
================================================
FILE: train.py
================================================
import time
from options.train_options import TrainOptions
from data.data_loader import DataLoader
from models.combogan_model import ComboGANModel
from util.visualizer import Visualizer
opt = TrainOptions().parse()
dataset = DataLoader(opt)
print('# training images = %d' % len(dataset))
model = ComboGANModel(opt)
visualizer = Visualizer(opt)
total_steps = 0
# Update initially if continuing
if opt.which_epoch > 0:
model.update_hyperparams(opt.which_epoch)
for epoch in range(opt.which_epoch + 1, opt.niter + opt.niter_decay + 1):
epoch_start_time = time.time()
epoch_iter = 0
for i, data in enumerate(dataset):
iter_start_time = time.time()
total_steps += opt.batchSize
epoch_iter += opt.batchSize
model.set_input(data)
model.optimize_parameters()
if total_steps % opt.display_freq == 0:
visualizer.display_current_results(model.get_current_visuals(), epoch)
if total_steps % opt.print_freq == 0:
errors = model.get_current_errors()
t = (time.time() - iter_start_time) / opt.batchSize
visualizer.print_current_errors(epoch, epoch_iter, errors, t)
if opt.display_id > 0:
visualizer.plot_current_errors(epoch, float(epoch_iter)/len(dataset), opt, errors)
if epoch % opt.save_epoch_freq == 0:
print('saving the model at the end of epoch %d, iters %d' % (epoch, total_steps))
model.save(epoch)
print('End of epoch %d / %d \t Time Taken: %d sec' %
(epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time))
model.update_hyperparams(epoch)
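The help strings for `--niter` and `--niter_decay` describe a schedule that holds the learning rate constant and then decays it linearly to zero. The sketch below is an assumption based on those help strings only: the real logic lives in `model.update_hyperparams` (in models/combogan_model.py, not shown here) and may differ, and `linear_decay_lr` is a hypothetical name.

```python
def linear_decay_lr(epoch, base_lr=0.0002, niter=200, niter_decay=200):
    """Constant base_lr for the first niter epochs, then linear decay to zero."""
    if epoch <= niter:
        return base_lr
    decay_frac = (epoch - niter) / float(niter_decay)
    return base_lr * max(0.0, 1.0 - decay_frac)


for epoch in (1, 200, 300, 400):
    print(epoch, linear_decay_lr(epoch))
```

With the defaults from scripts/train_combogan.sh (`--niter 200 --niter_decay 200`), the rate would be halved at epoch 300 and reach zero at epoch 400 under this assumption.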
================================================
FILE: util/__init__.py
================================================
================================================
FILE: util/get_data.py
================================================
from __future__ import print_function
import os
import tarfile
import requests
from warnings import warn
from zipfile import ZipFile
from bs4 import BeautifulSoup
from os.path import abspath, isdir, join, basename
class GetData(object):
"""
Download CycleGAN or Pix2Pix Data.
Args:
technique : str
One of: 'cyclegan' or 'pix2pix'.
verbose : bool
If True, print additional information.
Examples:
>>> from util.get_data import GetData
>>> gd = GetData(technique='cyclegan')
>>> new_data_path = gd.get(save_path='./datasets') # options will be displayed.
"""
def __init__(self, technique='cyclegan', verbose=True):
url_dict = {
'pix2pix': 'https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets',
'cyclegan': 'https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets'
}
self.url = url_dict.get(technique.lower())
self._verbose = verbose
def _print(self, text):
if self._verbose:
print(text)
@staticmethod
def _get_options(r):
soup = BeautifulSoup(r.text, 'lxml')
options = [h.text for h in soup.find_all('a', href=True)
if h.text.endswith(('.zip', 'tar.gz'))]
return options
def _present_options(self):
r = requests.get(self.url)
options = self._get_options(r)
print('Options:\n')
for i, o in enumerate(options):
print("{0}: {1}".format(i, o))
choice = input("\nPlease enter the number of the "
"dataset above you wish to download:")
return options[int(choice)]
def _download_data(self, dataset_url, save_path):
if not isdir(save_path):
os.makedirs(save_path)
base = basename(dataset_url)
temp_save_path = join(save_path, base)
with open(temp_save_path, "wb") as f:
r = requests.get(dataset_url)
f.write(r.content)
if base.endswith('.tar.gz'):
obj = tarfile.open(temp_save_path)
elif base.endswith('.zip'):
obj = ZipFile(temp_save_path, 'r')
else:
raise ValueError("Unknown File Type: {0}.".format(base))
self._print("Unpacking Data...")
obj.extractall(save_path)
obj.close()
os.remove(temp_save_path)
def get(self, save_path, dataset=None):
"""
Download a dataset.
Args:
save_path : str
A directory to save the data to.
dataset : str, optional
A specific dataset to download.
Note: this must include the file extension.
If None, options will be presented for you
to choose from.
Returns:
save_path_full : str
The absolute path to the downloaded data.
"""
if dataset is None:
selected_dataset = self._present_options()
else:
selected_dataset = dataset
save_path_full = join(save_path, selected_dataset.split('.')[0])
if isdir(save_path_full):
warn("\n'{0}' already exists. Skipping download.".format(
save_path_full))
else:
self._print('Downloading Data...')
url = "{0}/{1}".format(self.url, selected_dataset)
self._download_data(url, save_path=save_path)
return abspath(save_path_full)
================================================
FILE: util/html.py
================================================
import dominate
from dominate.tags import *
import os
class HTML:
def __init__(self, web_dir, title, reflesh=0):
self.title = title
self.web_dir = web_dir
self.img_dir = os.path.join(self.web_dir, 'images')
if not os.path.exists(self.web_dir):
os.makedirs(self.web_dir)
if not os.path.exists(self.img_dir):
os.makedirs(self.img_dir)
# print(self.img_dir)
self.doc = dominate.document(title=title)
if reflesh > 0:
with self.doc.head:
meta(http_equiv="refresh", content=str(reflesh))
def get_image_dir(self):
return self.img_dir
def add_header(self, str):
with self.doc:
h3(str)
def add_table(self, border=1):
self.t = table(border=border, style="table-layout: fixed;")
self.doc.add(self.t)
def add_images(self, ims, txts, links, width=400):
self.add_table()
with self.t:
with tr():
for im, txt, link in zip(ims, txts, links):
with td(style="word-wrap: break-word;", halign="center", valign="top"):
with p():
with a(href=os.path.join('images', link)):
img(style="width:%dpx" % width, src=os.path.join('images', im))
br()
p(txt)
def save(self):
html_file = '%s/index.html' % self.web_dir
with open(html_file, 'wt') as f:
f.write(self.doc.render())
if __name__ == '__main__':
html = HTML('web/', 'test_html')
html.add_header('hello world')
ims = []
txts = []
links = []
for n in range(4):
ims.append('image_%d.png' % n)
txts.append('text_%d' % n)
links.append('image_%d.png' % n)
html.add_images(ims, txts, links)
html.save()
================================================
FILE: util/image_pool.py
================================================
import random
import numpy as np
import torch
from torch.autograd import Variable
class ImagePool():
def __init__(self, pool_size):
self.pool_size = pool_size
if self.pool_size > 0:
self.num_imgs = 0
self.images = []
def query(self, images):
if self.pool_size == 0:
return images
return_images = []
for image in images.data:
image = torch.unsqueeze(image, 0)
if self.num_imgs < self.pool_size:
self.num_imgs = self.num_imgs + 1
self.images.append(image)
return_images.append(image)
else:
p = random.uniform(0, 1)
if p > 0.5:
random_id = random.randint(0, self.pool_size-1)
tmp = self.images[random_id].clone()
self.images[random_id] = image
return_images.append(tmp)
else:
return_images.append(image)
return_images = Variable(torch.cat(return_images, 0))
return return_images
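The buffering policy of `ImagePool.query` above can be sketched without tensors: fill the pool first, then with probability 0.5 hand back an older stored item and store the new one in its place. `HistoryPool` is a hypothetical stand-in that works on any Python objects.

```python
import random


class HistoryPool:
    """Toy version of the image pool that stores arbitrary Python objects."""

    def __init__(self, pool_size):
        self.pool_size = pool_size
        self.items = []

    def query(self, batch):
        if self.pool_size == 0:
            return list(batch)  # pooling disabled: pass everything through
        out = []
        for item in batch:
            if len(self.items) < self.pool_size:
                # Pool not full yet: store the item and return it unchanged.
                self.items.append(item)
                out.append(item)
            elif random.uniform(0, 1) > 0.5:
                # Swap: return a randomly chosen stored item, keep the new one.
                idx = random.randint(0, self.pool_size - 1)
                out.append(self.items[idx])
                self.items[idx] = item
            else:
                out.append(item)
        return out


pool = HistoryPool(pool_size=2)
print(pool.query(['fake_0', 'fake_1']))  # pool still filling: returned unchanged
print(pool.query(['fake_2', 'fake_3']))  # each item may be swapped for an older one
```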
================================================
FILE: util/png.py
================================================
import struct
import zlib
def encode(buf, width, height):
""" buf: must be bytes or a bytearray in py3, a regular string in py2. formatted RGBRGB... """
assert (width * height * 3 == len(buf))
bpp = 3
def raw_data():
# reverse the vertical line order and add null bytes at the start
row_bytes = width * bpp
for row_start in range((height - 1) * width * bpp, -1, -row_bytes):
yield b'\x00'
yield buf[row_start:row_start + row_bytes]
def chunk(tag, data):
return [
struct.pack("!I", len(data)),
tag,
data,
struct.pack("!I", 0xFFFFFFFF & zlib.crc32(data, zlib.crc32(tag)))
]
SIGNATURE = b'\x89PNG\r\n\x1a\n'
COLOR_TYPE_RGB = 2
COLOR_TYPE_RGBA = 6
bit_depth = 8
return b''.join(
[ SIGNATURE ] +
chunk(b'IHDR', struct.pack("!2I5B", width, height, bit_depth, COLOR_TYPE_RGB, 0, 0, 0)) +
chunk(b'IDAT', zlib.compress(b''.join(raw_data()), 9)) +
chunk(b'IEND', b'')
)
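The chunk layout produced by `chunk()` above (4-byte big-endian length, tag, data, CRC32 over tag plus data) can be checked by walking the bytes back. A small sketch with hypothetical helpers `make_chunk` and `read_chunks`; it builds a 1x1 image by hand rather than calling `encode`.

```python
import struct
import zlib


def make_chunk(tag, data):
    """Build one PNG chunk: length, tag, data, CRC32 over tag + data."""
    crc = zlib.crc32(data, zlib.crc32(tag)) & 0xFFFFFFFF
    return struct.pack("!I", len(data)) + tag + data + struct.pack("!I", crc)


def read_chunks(png_bytes):
    """Yield (tag, data) pairs from a PNG byte string, verifying each CRC."""
    pos = 8  # skip the 8-byte signature
    while pos < len(png_bytes):
        (length,) = struct.unpack("!I", png_bytes[pos:pos + 4])
        tag = png_bytes[pos + 4:pos + 8]
        data = png_bytes[pos + 8:pos + 8 + length]
        (crc,) = struct.unpack("!I", png_bytes[pos + 8 + length:pos + 12 + length])
        if crc != zlib.crc32(tag + data) & 0xFFFFFFFF:
            raise ValueError("bad CRC in chunk %r" % tag)
        yield tag, data
        pos += 12 + length


# A 1x1 red RGB image: one filter byte (0) followed by the pixel bytes.
signature = b'\x89PNG\r\n\x1a\n'
ihdr = struct.pack("!2I5B", 1, 1, 8, 2, 0, 0, 0)
png = (signature
       + make_chunk(b'IHDR', ihdr)
       + make_chunk(b'IDAT', zlib.compress(b'\x00\xff\x00\x00', 9))
       + make_chunk(b'IEND', b''))
print([tag for tag, _ in read_chunks(png)])  # [b'IHDR', b'IDAT', b'IEND']
```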
================================================
FILE: util/util.py
================================================
from __future__ import print_function
import torch
import numpy as np
from scipy.ndimage import gaussian_filter
from PIL import Image
import inspect, re
import os
import collections
# Converts a Tensor into a Numpy array
# |imtype|: the desired type of the converted numpy array
def tensor2im(image_tensor, imtype=np.uint8):
image_numpy = image_tensor[0].cpu().float().numpy()
image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0
return image_numpy.astype(imtype)
def gkern_2d(size=5, sigma=3):
# Create 2D gaussian kernel
dirac = np.zeros((size, size))
dirac[size//2, size//2] = 1
mask = gaussian_filter(dirac, sigma)
# Adjust dimensions for torch conv2d
return np.stack([np.expand_dims(mask, axis=0)] * 3)
def diagnose_network(net, name='network'):
mean = 0.0
count = 0
for param in net.parameters():
if param.grad is not None:
mean += torch.mean(torch.abs(param.grad.data))
count += 1
if count > 0:
mean = mean / count
print(name)
print(mean)
def save_image(image_numpy, image_path):
image_pil = Image.fromarray(image_numpy)
image_pil.save(image_path)
def info(object, spacing=10, collapse=1):
"""Print methods and doc strings.
Takes module, class, list, dictionary, or string."""
methodList = [e for e in dir(object) if callable(getattr(object, e))]
processFunc = collapse and (lambda s: " ".join(s.split())) or (lambda s: s)
print( "\n".join(["%s %s" %
(method.ljust(spacing),
processFunc(str(getattr(object, method).__doc__)))
for method in methodList]) )
def varname(p):
for line in inspect.getframeinfo(inspect.currentframe().f_back)[3]:
m = re.search(r'\bvarname\s*\(\s*([A-Za-z_][A-Za-z0-9_]*)\s*\)', line)
if m:
return m.group(1)
def print_numpy(x, val=True, shp=False):
x = x.astype(np.float64)
if shp:
print('shape,', x.shape)
if val:
x = x.flatten()
print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % (
np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x)))
def mkdirs(paths):
if isinstance(paths, list) and not isinstance(paths, str):
for path in paths:
mkdir(path)
else:
mkdir(paths)
def mkdir(path):
if not os.path.exists(path):
os.makedirs(path)
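The value mapping inside `tensor2im` above, `(x + 1) / 2.0 * 255.0` followed by a cast to `uint8`, can be seen on single numbers. A minimal sketch with a hypothetical helper `to_uint8`; it assumes inputs already lie in [-1, 1] and uses `int()` to mimic the truncating cast.

```python
def to_uint8(value):
    """Map a value in [-1, 1] to an integer in [0, 255], truncating the fraction."""
    return int((value + 1) / 2.0 * 255.0)


print([to_uint8(v) for v in (-1.0, 0.0, 1.0)])  # [0, 127, 255]
```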
================================================
FILE: util/visualizer.py
================================================
import numpy as np
import os
import ntpath
import time
from . import util
from . import html
class Visualizer():
def __init__(self, opt):
# self.opt = opt
self.display_id = opt.display_id
self.use_html = opt.isTrain and not opt.no_html
self.win_size = opt.display_winsize
self.name = opt.name
if self.display_id > 0:
import visdom
self.vis = visdom.Visdom(port = opt.display_port)
self.display_single_pane_ncols = opt.display_single_pane_ncols
if self.use_html:
self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web')
self.img_dir = os.path.join(self.web_dir, 'images')
print('create web directory %s...' % self.web_dir)
util.mkdirs([self.web_dir, self.img_dir])
self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt')
with open(self.log_name, "a") as log_file:
now = time.strftime("%c")
log_file.write('================ Training Loss (%s) ================\n' % now)
# |visuals|: dictionary of images to display or save
def display_current_results(self, visuals, epoch):
if self.display_id > 0: # show images in the browser
if self.display_single_pane_ncols > 0:
h, w = next(iter(visuals.values())).shape[:2]
table_css = """<style>
table {border-collapse: separate; border-spacing:4px; white-space:nowrap; text-align:center}
table td {width: %dpx; height: %dpx; padding: 4px; outline: 4px solid black}
</style>""" % (w, h)
ncols = self.display_single_pane_ncols
title = self.name
label_html = ''
label_html_row = ''
nrows = int(np.ceil(len(visuals.items()) / ncols))
images = []
idx = 0
for label, image_numpy in visuals.items():
label_html_row += '<td>%s</td>' % label
images.append(image_numpy.transpose([2, 0, 1]))
idx += 1
if idx % ncols == 0:
label_html += '<tr>%s</tr>' % label_html_row
label_html_row = ''
white_image = np.ones_like(image_numpy.transpose([2, 0, 1]))*255
while idx % ncols != 0:
images.append(white_image)
label_html_row += '<td></td>'
idx += 1
if label_html_row != '':
label_html += '<tr>%s</tr>' % label_html_row
# pane col = image row
self.vis.images(images, nrow=ncols, win=self.display_id + 1,
padding=2, opts=dict(title=title + ' images'))
label_html = '<table>%s</table>' % label_html
self.vis.text(table_css + label_html, win = self.display_id + 2,
opts=dict(title=title + ' labels'))
else:
idx = 1
for label, image_numpy in visuals.items():
#image_numpy = np.flipud(image_numpy)
self.vis.image(image_numpy.transpose([2,0,1]), opts=dict(title=label),
win=self.display_id + idx)
idx += 1
if self.use_html: # save images to a html file
for label, image_numpy in visuals.items():
img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.png' % (epoch, label))
util.save_image(image_numpy, img_path)
# update website
webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, reflesh=1)
for n in range(epoch, 0, -1):
webpage.add_header('epoch [%d]' % n)
ims = []
txts = []
links = []
for label, image_numpy in visuals.items():
img_path = 'epoch%.3d_%s.png' % (n, label)
ims.append(img_path)
txts.append(label)
links.append(img_path)
webpage.add_images(ims, txts, links, width=self.win_size)
webpage.save()
# errors: dictionary of error labels and values
def plot_current_errors(self, epoch, counter_ratio, opt, errors):
if not hasattr(self, 'plot_data'):
self.plot_data = {'X':[],'Y':[], 'legend':list(errors.keys())}
self.plot_data['X'].append(epoch + counter_ratio)
self.plot_data['Y'].append([errors[k] for k in self.plot_data['legend']])
self.vis.line(
X=np.stack([np.array(self.plot_data['X'])]*len(self.plot_data['legend']),1),
Y=np.array(self.plot_data['Y']),
opts={
'title': self.name + ' loss over time',
'legend': self.plot_data['legend'],
'xlabel': 'epoch',
'ylabel': 'loss'},
win=self.display_id)
# errors: same format as |errors| of plot_current_errors
def print_current_errors(self, epoch, i, errors, t):
message = '(epoch: %d, iters: %d, time: %.3f) ' % (epoch, i, t)
for k, v in errors.items():
v = ['%.3f' % iv for iv in v]
message += k + ': ' + ', '.join(v) + ' | '
print(message)
with open(self.log_name, "a") as log_file:
log_file.write('%s\n' % message)
# save image to the disk
def save_images(self, webpage, visuals, image_path):
image_dir = webpage.get_image_dir()
short_path = ntpath.basename(image_path[0])
name = os.path.splitext(short_path)[0]
webpage.add_header(name)
ims = []
txts = []
links = []
for label, image_numpy in visuals.items():
image_name = '%s_%s.png' % (name, label)
save_path = os.path.join(image_dir, image_name)
util.save_image(image_numpy, save_path)
ims.append(image_name)
txts.append(label)
links.append(image_name)
webpage.add_images(ims, txts, links, width=self.win_size)
def save_image_matrix(self, visuals_list, save_path):
images_list = []
get_domain = lambda x: x.split('_')[-1]
for visuals in visuals_list:
pairs = list(visuals.items())
real_label, real_img = pairs[0]
real_dom = get_domain(real_label)
for label, img in pairs:
if 'fake' not in label:
continue
if get_domain(label) == real_dom:
images_list.append(real_img)
else:
images_list.append(img)
immat = self.stack_images(images_list)
util.save_image(immat, save_path)
# reshape a list of images into a square matrix of them
def stack_images(self, list_np_images):
n = int(np.ceil(np.sqrt(len(list_np_images))))
# add padding between images
for i, im in enumerate(list_np_images):
val = 255 if i%n == i//n else 0
r_pad = np.pad(im[:,:,0], (3,3), mode='constant', constant_values=0)
g_pad = np.pad(im[:,:,1], (3,3), mode='constant', constant_values=val)
b_pad = np.pad(im[:,:,2], (3,3), mode='constant', constant_values=0)
list_np_images[i] = np.stack([r_pad,g_pad,b_pad], axis=2)
data = np.array(list_np_images)
data = data.reshape((n, n) + data.shape[1:]).transpose((0, 2, 1, 3) + tuple(range(4, data.ndim + 1)))
data = data.reshape((n * data.shape[1], n * data.shape[3]) + data.shape[4:])
return data
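`stack_images` above lays the tiles out row by row in an n-by-n grid and gives a tile a green border when `i % n == i // n`, that is, when it sits on the diagonal. The sketch below only computes the grid side and those diagonal indices; `diagonal_tiles` is a hypothetical helper, and it assumes the image count is a perfect square, as the reshape in `stack_images` appears to require.

```python
import math


def diagonal_tiles(num_images):
    """Return the grid side n and the indices of tiles on the main diagonal."""
    n = int(math.ceil(math.sqrt(num_images)))
    return n, [i for i in range(num_images) if i % n == i // n]


print(diagonal_tiles(16))  # (4, [0, 5, 10, 15])
```

For a 4-domain run these would be the positions where `save_image_matrix` places the real image instead of a translation.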