Repository: AAnoosheh/ComboGAN
Branch: master
Commit: 27643b6fb26b
Files: 28
Total size: 68.3 KB
Directory structure:
gitextract_lw_z0n5w/
├── .gitignore
├── LICENSE
├── README.md
├── data/
│ ├── __init__.py
│ ├── base_dataset.py
│ ├── data_loader.py
│ ├── image_folder.py
│ └── unaligned_dataset.py
├── models/
│ ├── __init__.py
│ ├── base_model.py
│ ├── combogan_model.py
│ └── networks.py
├── options/
│ ├── __init__.py
│ ├── base_options.py
│ ├── test_options.py
│ └── train_options.py
├── scripts/
│ ├── continue_combogan.sh
│ ├── test_combogan.sh
│ └── train_combogan.sh
├── test.py
├── train.py
└── util/
  ├── __init__.py
  ├── get_data.py
  ├── html.py
  ├── image_pool.py
  ├── png.py
  ├── util.py
  └── visualizer.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
datasets/
checkpoints/
results/
*.png
*/**/__pycache__
*/*.pyc
*/**/*.pyc
*/**/**/*.pyc
*/**/**/**/*.pyc
*/**/**/**/**/*.pyc
*/*.so*
*/**/*.so*
*/**/*.dylib*
*~
================================================
FILE: LICENSE
================================================
Copyright (c) 2017, Asha Anoosheh
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
================================================
FILE: README.md
================================================
# ComboGAN
This is our ongoing PyTorch implementation for ComboGAN.
Code was written by [Asha Anoosheh](https://github.com/aanoosheh) (built upon [CycleGAN](https://github.com/junyanz/CycleGAN))
#### [[ComboGAN Paper]](https://arxiv.org/pdf/1712.06909.pdf)
<img src="img/Inference.png" width=420/>
If you use this code for your research, please cite:
ComboGAN: Unrestrained Scalability for Image Domain Translation
[Asha Anoosheh](http://ashaanoosheh.com), [Eirikur Augustsson](https://relational.github.io/), [Radu Timofte](http://www.vision.ee.ethz.ch/~timofter/), [Luc van Gool](https://www.vision.ee.ethz.ch/en/members/get_member.cgi?id=1)
In arXiv, 2017.
<br><br>
<img src='img/Paintings.png' align="center" width=900>
<br><br>
## Prerequisites
- Linux or macOS
- Python 3
- CPU or NVIDIA GPU + CUDA CuDNN
## Getting Started
### Installation
- Install PyTorch and dependencies from http://pytorch.org
- Install torchvision from source:
```bash
git clone https://github.com/pytorch/vision
cd vision
python setup.py install
```
- Install the Python libraries [visdom](https://github.com/facebookresearch/visdom) and [dominate](https://github.com/Knio/dominate).
```bash
pip install visdom
pip install dominate
```
- Clone this repo:
```bash
git clone https://github.com/AAnoosheh/ComboGAN.git
cd ComboGAN
```
### ComboGAN training
Our ready datasets can be downloaded using `./datasets/download_dataset.sh <dataset_name>`.
A pretrained model for the 14-painters dataset can be found [HERE](https://www.dropbox.com/s/t8s6x0bu52d73s0/paint14_pretrained.zip?dl=0). Place under `./checkpoints/` and test using the instructions below, with args `--name paint14_pretrained --dataroot ./datasets/painters_14 --n_domains 14 --which_epoch 1150`.
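Before placing the pretrained model, it helps to know which files `--which_epoch` will look for. The sketch below is an illustration and makes two assumptions. First, it combines `BaseModel.save_network` (`'%d_net_%s' % (epoch, label)`) with `Plexer.save` (`'%d.pth' % i` for each sub-network). Second, it assumes the generator uses shared blocks (`--netG_n_shared` > 0), which is the case where `G_Plexer` appends one extra shared encoder and decoder. The function name `generator_checkpoint_paths` is hypothetical and does not appear in the repo.

```python
import os

def generator_checkpoint_paths(checkpoints_dir, name, epoch, n_domains, shared=True):
    # Hypothetical helper: lists the generator files a ComboGAN checkpoint should contain.
    # save_network builds '<epoch>_net_G' and Plexer.save appends '<i>.pth' per sub-network.
    prefix = os.path.join(checkpoints_dir, name, '%d_net_%s' % (epoch, 'G'))
    # G_Plexer.networks = encoders + decoders; with sharing, each list gets one extra block.
    n_per_side = n_domains + (1 if shared else 0)
    return [prefix + ('%d.pth' % i) for i in range(2 * n_per_side)]

paths = generator_checkpoint_paths('./checkpoints', 'paint14_pretrained', 1150, 14)
print(len(paths), paths[0])
```

In this layout the encoders come first (indices `0` to `n_domains`, with the shared encoder last) and are followed by the decoders. The discriminator files use the same `'%d_net_D'` prefix.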
Example running scripts can be found in the `scripts` directory.
- Train a model:
```
python train.py --name <experiment_name> --dataroot ./datasets/<your_dataset> --n_domains <N> --niter <num_epochs_constant_LR> --niter_decay <num_epochs_decaying_LR>
```
Checkpoints will be saved by default to `./checkpoints/<experiment_name>/`.
- Fine-tuning/Resume training:
```
python train.py --continue_train --which_epoch <checkpoint_number_to_load> --name <experiment_name> --dataroot ./datasets/<your_dataset> --n_domains <N> --niter <num_epochs_constant_LR> --niter_decay <num_epochs_decaying_LR>
```
- Test the model:
```
python test.py --phase test --name <experiment_name> --dataroot ./datasets/<your_dataset> --n_domains <N> --which_epoch <checkpoint_number_to_load> --serial_test
```
The test results will be saved to an HTML file at `./results/<experiment_name>/<epoch_number>/index.html`.
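The `--niter` and `--niter_decay` flags control the schedule in `ComboGANModel.update_hyperparams`. The learning rate stays at `--lr` for the first `niter` epochs and then decays linearly, reaching zero at `niter + niter_decay`. When `--lambda_latent` > 0, the latent-cycle weight `lambda_enc` ramps up from 0 over the whole run. The sketch below reproduces that arithmetic. It assumes that `train.py` (not shown here) calls `update_hyperparams` once per epoch with the epoch number. The helper names are hypothetical.

```python
def learning_rate_at(curr_iter, lr, niter, niter_decay):
    # Constant for niter epochs, then linear decay to 0 at niter + niter_decay.
    if curr_iter > niter:
        decay_frac = (curr_iter - niter) / niter_decay
        return lr * (1 - decay_frac)
    return lr

def latent_weight_at(curr_iter, lambda_latent, niter, niter_decay):
    # lambda_enc starts at 0 and grows linearly toward lambda_latent over training.
    if lambda_latent > 0:
        return lambda_latent * curr_iter / (niter + niter_decay)
    return 0.0

for epoch in (50, 100, 150, 200):
    print(epoch, learning_rate_at(epoch, 2e-4, 100, 100))
```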
## Training/Testing Details
- Flags: see `options/train_options.py` for training-specific flags; see `options/test_options.py` for test-specific flags; and see `options/base_options.py` for all common flags.
- Dataset format: The desired data directory (provided by `--dataroot`) should contain subfolders of the form `train*/` and `test*/`, and they are loaded in alphabetical order. (Note that a folder named train10 would be loaded before train2, and thus all checkpoints and results would be ordered accordingly.)
- CPU/GPU (default `--gpu_ids 0`): set `--gpu_ids -1` to use CPU mode; set `--gpu_ids 0,1,2` for multi-GPU mode. You need a large batch size (e.g. `--batchSize 32`) to benefit from multiple GPUs.
- Visualization: during training, the current results and loss plots can be viewed in two ways. First, if you set `--display_id` > 0, the results and loss plots will appear on a local graphics web server launched by [visdom](https://github.com/facebookresearch/visdom). To do this, you need `visdom` installed and a server started with the command `python -m visdom.server`. The default server URL is `http://localhost:8097`. `display_id` corresponds to the window ID that is displayed on the `visdom` server. The `visdom` display functionality is turned on by default. To avoid the extra overhead of communicating with `visdom`, set `--display_id 0`. Second, the intermediate results are also saved to `./checkpoints/<experiment_name>/web/index.html`. To avoid this, set the `--no_html` flag.
- Preprocessing: images can be resized and cropped in different ways using the `--resize_or_crop` option. The default option `'resize_and_crop'` resizes the image so that its shorter side is `opt.loadSize` (keeping the aspect ratio) and, during training, takes a random crop of size `(opt.fineSize, opt.fineSize)`. `'crop'` skips the resizing step and only performs random cropping (again only during training). `'scale_width'` resizes the image to have width `opt.fineSize` while keeping the aspect ratio. `'scale_width_and_crop'` first resizes the image to have width `opt.loadSize` and then does random cropping of size `(opt.fineSize, opt.fineSize)`.
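To check which image size reaches the network, the sketch below predicts the output of the resize and crop steps in `get_transform` (`data/base_dataset.py`). That function only looks for the substrings `'resize'` and `'crop'` in `--resize_or_crop`. It applies `transforms.Resize(loadSize)` and, during training only, `transforms.RandomCrop(fineSize)`. The sketch assumes torchvision's behaviour for an integer `Resize`: the shorter edge is matched to `loadSize`, and the longer edge is scaled by the same factor and truncated to an int. The helper name and the example sizes (286/256) are illustrative, not repo defaults.

```python
def transformed_size(width, height, mode, load_size, fine_size, is_train):
    # Hypothetical helper: (width, height) after get_transform's resize/crop steps.
    if 'resize' in mode:
        # Resize(int): shorter edge -> load_size, aspect ratio preserved.
        short_side, long_side = min(width, height), max(width, height)
        new_short, new_long = load_size, int(load_size * long_side / short_side)
        if width <= height:
            width, height = new_short, new_long
        else:
            width, height = new_long, new_short
    if is_train and 'crop' in mode:
        # RandomCrop(fine_size) yields a square fine_size x fine_size patch.
        width, height = fine_size, fine_size
    # Horizontal flips, ToTensor and Normalize do not change the spatial size.
    return width, height

print(transformed_size(800, 600, 'resize_and_crop', 286, 256, is_train=False))
```

One consequence is that test-time images are resized but not cropped, so their shape follows the input aspect ratio.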
NOTE: one should **not** expect ComboGAN to work on just any combination of input and output datasets (e.g. `dogs<->houses`). We find it works better if the datasets share similar visual content. For example, `landscape painting<->landscape photographs` works much better than `portrait painting<->landscape photographs`.
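Domain indices come from the dataset-format rule described above. `UnalignedDataset` globs `<dataroot>/<phase>*` and sorts the matches as plain strings, so `train10` gets a lower index than `train2`. With `--serial_test`, a flat test index is then walked through the per-domain image counts. The sketch below mimics both steps on folder names and counts, with no filesystem access. It uses `fnmatch` to imitate the glob pattern. The helper names are hypothetical, and raising `IndexError` for out-of-range indices is this sketch's own choice.

```python
import fnmatch

def domain_order(folder_names, phase='train'):
    # Same ordering as sorted(glob.glob(os.path.join(dataroot, phase + '*'))).
    return sorted(n for n in folder_names if fnmatch.fnmatch(n, phase + '*'))

def serial_test_index(index, sizes):
    # Map a flat index to (domain, index within domain), as with --serial_test.
    flat = index
    for domain, size in enumerate(sizes):
        if index < size:
            return domain, index
        index -= size
    raise IndexError('index %d is past the end of all domains' % flat)

print(domain_order(['train2', 'test0', 'train10', 'train1']))
print(serial_test_index(3, [3, 5]))
```

Zero-padding folder names (e.g. `train02`, `train10`) keeps the string order aligned with the numeric order.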
================================================
FILE: data/__init__.py
================================================
================================================
FILE: data/base_dataset.py
================================================
import torch.utils.data as data
from PIL import Image
import torchvision.transforms as transforms
class BaseDataset(data.Dataset):
    def __init__(self):
        super(BaseDataset, self).__init__()
    def name(self):
        return 'BaseDataset'
    def initialize(self, opt):
        pass
def get_transform(opt):
    transform_list = []
    if 'resize' in opt.resize_or_crop:
        transform_list.append(transforms.Resize(opt.loadSize, Image.BICUBIC))
    if opt.isTrain:
        if 'crop' in opt.resize_or_crop:
            transform_list.append(transforms.RandomCrop(opt.fineSize))
        if not opt.no_flip:
            transform_list.append(transforms.RandomHorizontalFlip())
    transform_list += [transforms.ToTensor(),
                       transforms.Normalize((0.5, 0.5, 0.5),
                                            (0.5, 0.5, 0.5))]
    return transforms.Compose(transform_list)
================================================
FILE: data/data_loader.py
================================================
import torch.utils.data
from data.unaligned_dataset import UnalignedDataset
class DataLoader():
    def name(self):
        return 'DataLoader'
    def __init__(self, opt):
        self.opt = opt
        self.dataset = UnalignedDataset(opt)
        self.dataloader = torch.utils.data.DataLoader(
            self.dataset,
            batch_size=opt.batchSize,
            num_workers=int(opt.nThreads))
    def __len__(self):
        return min(len(self.dataset), self.opt.max_dataset_size)
    def __iter__(self):
        for i, data in enumerate(self.dataloader):
            if i >= self.opt.max_dataset_size:
                break
            yield data
================================================
FILE: data/image_folder.py
================================================
###############################################################################
# Code from
# https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py
# Modified the original code so that it also loads images from the current
# directory as well as the subdirectories
###############################################################################
import torch.utils.data as data
from PIL import Image
import os
import os.path
IMG_EXTENSIONS = [
    '.jpg', '.JPG', '.jpeg', '.JPEG',
    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
]
def is_image_file(filename):
    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
def make_dataset(dir):
    images = []
    assert os.path.isdir(dir), '%s is not a valid directory' % dir
    for root, _, fnames in sorted(os.walk(dir)):
        for fname in fnames:
            if is_image_file(fname):
                path = os.path.join(root, fname)
                images.append(path)
    return images
def default_loader(path):
    return Image.open(path).convert('RGB')
class ImageFolder(data.Dataset):
    def __init__(self, root, transform=None, return_paths=False,
                 loader=default_loader):
        imgs = make_dataset(root)
        if len(imgs) == 0:
            raise(RuntimeError("Found 0 images in: " + root + "\n"
                               "Supported image extensions are: " +
                               ",".join(IMG_EXTENSIONS)))
        self.root = root
        self.imgs = imgs
        self.transform = transform
        self.return_paths = return_paths
        self.loader = loader
    def __getitem__(self, index):
        path = self.imgs[index]
        img = self.loader(path)
        if self.transform is not None:
            img = self.transform(img)
        if self.return_paths:
            return img, path
        else:
            return img
    def __len__(self):
        return len(self.imgs)
================================================
FILE: data/unaligned_dataset.py
================================================
import os.path, glob
import torchvision.transforms as transforms
from data.base_dataset import BaseDataset, get_transform
from data.image_folder import make_dataset
from PIL import Image
import random
class UnalignedDataset(BaseDataset):
    def __init__(self, opt):
        super(UnalignedDataset, self).__init__()
        self.opt = opt
        self.transform = get_transform(opt)
        datapath = os.path.join(opt.dataroot, opt.phase + '*')
        self.dirs = sorted(glob.glob(datapath))
        self.paths = [sorted(make_dataset(d)) for d in self.dirs]
        self.sizes = [len(p) for p in self.paths]
    def load_image(self, dom, idx):
        path = self.paths[dom][idx]
        img = Image.open(path).convert('RGB')
        img = self.transform(img)
        return img, path
    def __getitem__(self, index):
        if not self.opt.isTrain:
            if self.opt.serial_test:
                for d,s in enumerate(self.sizes):
                    if index < s:
                        DA = d
                        break
                    index -= s
                index_A = index
            else:
                DA = index % len(self.dirs)
                index_A = random.randint(0, self.sizes[DA] - 1)
        else:
            # Choose two of our domains to perform a pass on
            DA, DB = random.sample(range(len(self.dirs)), 2)
            index_A = random.randint(0, self.sizes[DA] - 1)
        A_img, A_path = self.load_image(DA, index_A)
        bundle = {'A': A_img, 'DA': DA, 'path': A_path}
        if self.opt.isTrain:
            index_B = random.randint(0, self.sizes[DB] - 1)
            B_img, _ = self.load_image(DB, index_B)
            bundle.update( {'B': B_img, 'DB': DB} )
        return bundle
    def __len__(self):
        if self.opt.isTrain:
            return max(self.sizes)
        return sum(self.sizes)
    def name(self):
        return 'UnalignedDataset'
================================================
FILE: models/__init__.py
================================================
================================================
FILE: models/base_model.py
================================================
import os
import torch
class BaseModel():
    def name(self):
        return 'BaseModel'
    def __init__(self, opt):
        self.opt = opt
        self.gpu_ids = opt.gpu_ids
        self.isTrain = opt.isTrain
        self.Tensor = torch.cuda.FloatTensor if self.gpu_ids else torch.Tensor
        self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)
    def set_input(self, input):
        self.input = input
    def forward(self):
        pass
    # used in test time, no backprop
    def test(self):
        pass
    def get_image_paths(self):
        pass
    def optimize_parameters(self):
        pass
    def get_current_visuals(self):
        return self.input
    def get_current_errors(self):
        return {}
    def save(self, label):
        pass
    # helper saving function that can be used by subclasses
    def save_network(self, network, network_label, epoch, gpu_ids):
        save_filename = '%d_net_%s' % (epoch, network_label)
        save_path = os.path.join(self.save_dir, save_filename)
        network.save(save_path)
        if gpu_ids and torch.cuda.is_available():
            network.cuda(gpu_ids[0])
    # helper loading function that can be used by subclasses
    def load_network(self, network, network_label, epoch):
        save_filename = '%d_net_%s' % (epoch, network_label)
        save_path = os.path.join(self.save_dir, save_filename)
        network.load(save_path)
    # instance method: needs `self` to be callable on a model object
    def update_learning_rate(self):
        pass
================================================
FILE: models/combogan_model.py
================================================
import numpy as np
import torch
from collections import OrderedDict
import util.util as util
from util.image_pool import ImagePool
from .base_model import BaseModel
from . import networks
class ComboGANModel(BaseModel):
    def name(self):
        return 'ComboGANModel'
    def __init__(self, opt):
        super(ComboGANModel, self).__init__(opt)
        self.n_domains = opt.n_domains
        self.DA, self.DB = None, None
        self.real_A = self.Tensor(opt.batchSize, opt.input_nc, opt.fineSize, opt.fineSize)
        self.real_B = self.Tensor(opt.batchSize, opt.output_nc, opt.fineSize, opt.fineSize)
        # load/define networks
        self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf,
                                      opt.netG_n_blocks, opt.netG_n_shared,
                                      self.n_domains, opt.norm, opt.use_dropout, self.gpu_ids)
        if self.isTrain:
            blur_fn = lambda x : torch.nn.functional.conv2d(x, self.Tensor(util.gkern_2d()), groups=3, padding=2)
            self.netD = networks.define_D(opt.output_nc, opt.ndf, opt.netD_n_layers,
                                          self.n_domains, blur_fn, opt.norm, self.gpu_ids)
        if not self.isTrain or opt.continue_train:
            which_epoch = opt.which_epoch
            self.load_network(self.netG, 'G', which_epoch)
            if self.isTrain:
                self.load_network(self.netD, 'D', which_epoch)
        if self.isTrain:
            self.fake_pools = [ImagePool(opt.pool_size) for _ in range(self.n_domains)]
            # define loss functions
            self.L1 = torch.nn.SmoothL1Loss()
            self.downsample = torch.nn.AvgPool2d(3, stride=2)
            self.criterionCycle = self.L1
            self.criterionIdt = lambda y,t : self.L1(self.downsample(y), self.downsample(t))
            self.criterionLatent = lambda y,t : self.L1(y, t.detach())
            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor)
            # initialize optimizers
            self.netG.init_optimizers(torch.optim.Adam, opt.lr, (opt.beta1, 0.999))
            self.netD.init_optimizers(torch.optim.Adam, opt.lr, (opt.beta1, 0.999))
            # initialize loss storage
            self.loss_D, self.loss_G = [0]*self.n_domains, [0]*self.n_domains
            self.loss_cycle = [0]*self.n_domains
            # initialize loss multipliers
            self.lambda_cyc, self.lambda_enc = opt.lambda_cycle, (0 * opt.lambda_latent)
            self.lambda_idt, self.lambda_fwd = opt.lambda_identity, opt.lambda_forward
        print('---------- Networks initialized -------------')
        print(self.netG)
        if self.isTrain:
            print(self.netD)
        print('-----------------------------------------------')
    def set_input(self, input):
        input_A = input['A']
        self.real_A.resize_(input_A.size()).copy_(input_A)
        self.DA = input['DA'][0]
        if self.isTrain:
            input_B = input['B']
            self.real_B.resize_(input_B.size()).copy_(input_B)
            self.DB = input['DB'][0]
        self.image_paths = input['path']
    def test(self):
        with torch.no_grad():
            self.visuals = [self.real_A]
            self.labels = ['real_%d' % self.DA]
            # cache encoding to not repeat it everytime
            encoded = self.netG.encode(self.real_A, self.DA)
            for d in range(self.n_domains):
                if d == self.DA and not self.opt.autoencode:
                    continue
                fake = self.netG.decode(encoded, d)
                self.visuals.append( fake )
                self.labels.append( 'fake_%d' % d )
                if self.opt.reconstruct:
                    rec = self.netG.forward(fake, d, self.DA)
                    self.visuals.append( rec )
                    self.labels.append( 'rec_%d' % d )
    def get_image_paths(self):
        return self.image_paths
    def backward_D_basic(self, real, fake, domain):
        # Real
        pred_real = self.netD.forward(real, domain)
        loss_D_real = self.criterionGAN(pred_real, True)
        # Fake
        pred_fake = self.netD.forward(fake.detach(), domain)
        loss_D_fake = self.criterionGAN(pred_fake, False)
        # Combined loss
        loss_D = (loss_D_real + loss_D_fake) * 0.5
        # backward
        loss_D.backward()
        return loss_D
    def backward_D(self):
        #D_A
        fake_B = self.fake_pools[self.DB].query(self.fake_B)
        self.loss_D[self.DA] = self.backward_D_basic(self.real_B, fake_B, self.DB)
        #D_B
        fake_A = self.fake_pools[self.DA].query(self.fake_A)
        self.loss_D[self.DB] = self.backward_D_basic(self.real_A, fake_A, self.DA)
    def backward_G(self):
        encoded_A = self.netG.encode(self.real_A, self.DA)
        encoded_B = self.netG.encode(self.real_B, self.DB)
        # Optional identity "autoencode" loss
        if self.lambda_idt > 0:
            # Same encoder and decoder should recreate image
            idt_A = self.netG.decode(encoded_A, self.DA)
            loss_idt_A = self.criterionIdt(idt_A, self.real_A)
            idt_B = self.netG.decode(encoded_B, self.DB)
            loss_idt_B = self.criterionIdt(idt_B, self.real_B)
        else:
            loss_idt_A, loss_idt_B = 0, 0
        # GAN loss
        # D_A(G_A(A))
        self.fake_B = self.netG.decode(encoded_A, self.DB)
        pred_fake = self.netD.forward(self.fake_B, self.DB)
        self.loss_G[self.DA] = self.criterionGAN(pred_fake, True)
        # D_B(G_B(B))
        self.fake_A = self.netG.decode(encoded_B, self.DA)
        pred_fake = self.netD.forward(self.fake_A, self.DA)
        self.loss_G[self.DB] = self.criterionGAN(pred_fake, True)
        # Forward cycle loss
        rec_encoded_A = self.netG.encode(self.fake_B, self.DB)
        self.rec_A = self.netG.decode(rec_encoded_A, self.DA)
        self.loss_cycle[self.DA] = self.criterionCycle(self.rec_A, self.real_A)
        # Backward cycle loss
        rec_encoded_B = self.netG.encode(self.fake_A, self.DA)
        self.rec_B = self.netG.decode(rec_encoded_B, self.DB)
        self.loss_cycle[self.DB] = self.criterionCycle(self.rec_B, self.real_B)
        # Optional cycle loss on encoding space
        if self.lambda_enc > 0:
            loss_enc_A = self.criterionLatent(rec_encoded_A, encoded_A)
            loss_enc_B = self.criterionLatent(rec_encoded_B, encoded_B)
        else:
            loss_enc_A, loss_enc_B = 0, 0
        # Optional loss on downsampled image before and after
        if self.lambda_fwd > 0:
            loss_fwd_A = self.criterionIdt(self.fake_B, self.real_A)
            loss_fwd_B = self.criterionIdt(self.fake_A, self.real_B)
        else:
            loss_fwd_A, loss_fwd_B = 0, 0
        # combined loss
        loss_G = self.loss_G[self.DA] + self.loss_G[self.DB] + \
                 (self.loss_cycle[self.DA] + self.loss_cycle[self.DB]) * self.lambda_cyc + \
                 (loss_idt_A + loss_idt_B) * self.lambda_idt + \
                 (loss_enc_A + loss_enc_B) * self.lambda_enc + \
                 (loss_fwd_A + loss_fwd_B) * self.lambda_fwd
        loss_G.backward()
    def optimize_parameters(self):
        # G_A and G_B
        self.netG.zero_grads(self.DA, self.DB)
        self.backward_G()
        self.netG.step_grads(self.DA, self.DB)
        # D_A and D_B
        self.netD.zero_grads(self.DA, self.DB)
        self.backward_D()
        self.netD.step_grads(self.DA, self.DB)
    def get_current_errors(self):
        extract = lambda l: [(i if type(i) is int or type(i) is float else i.item()) for i in l]
        D_losses, G_losses, cyc_losses = extract(self.loss_D), extract(self.loss_G), extract(self.loss_cycle)
        return OrderedDict([('D', D_losses), ('G', G_losses), ('Cyc', cyc_losses)])
    def get_current_visuals(self, testing=False):
        if not testing:
            self.visuals = [self.real_A, self.fake_B, self.rec_A, self.real_B, self.fake_A, self.rec_B]
            self.labels = ['real_A', 'fake_B', 'rec_A', 'real_B', 'fake_A', 'rec_B']
        images = [util.tensor2im(v.data) for v in self.visuals]
        return OrderedDict(zip(self.labels, images))
    def save(self, label):
        self.save_network(self.netG, 'G', label, self.gpu_ids)
        self.save_network(self.netD, 'D', label, self.gpu_ids)
    def update_hyperparams(self, curr_iter):
        if curr_iter > self.opt.niter:
            decay_frac = (curr_iter - self.opt.niter) / self.opt.niter_decay
            new_lr = self.opt.lr * (1 - decay_frac)
            self.netG.update_lr(new_lr)
            self.netD.update_lr(new_lr)
            print('updated learning rate: %f' % new_lr)
        if self.opt.lambda_latent > 0:
            decay_frac = curr_iter / (self.opt.niter + self.opt.niter_decay)
            self.lambda_enc = self.opt.lambda_latent * decay_frac
================================================
FILE: models/networks.py
================================================
import torch
import torch.nn as nn
from torch.nn import init
import functools, itertools
import numpy as np
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
        if hasattr(m.bias, 'data'):
            m.bias.data.fill_(0)
    elif classname.find('BatchNorm2d') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)
def get_norm_layer(norm_type='instance'):
    if norm_type == 'batch':
        return functools.partial(nn.BatchNorm2d, affine=True)
    elif norm_type == 'instance':
        return functools.partial(nn.InstanceNorm2d, affine=False)
    else:
        raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
def define_G(input_nc, output_nc, ngf, n_blocks, n_blocks_shared, n_domains, norm='batch', use_dropout=False, gpu_ids=[]):
    norm_layer = get_norm_layer(norm_type=norm)
    if type(norm_layer) == functools.partial:
        use_bias = norm_layer.func == nn.InstanceNorm2d
    else:
        use_bias = norm_layer == nn.InstanceNorm2d
    n_blocks -= n_blocks_shared
    n_blocks_enc = n_blocks // 2
    n_blocks_dec = n_blocks - n_blocks_enc
    dup_args = (ngf, norm_layer, use_dropout, gpu_ids, use_bias)
    enc_args = (input_nc, n_blocks_enc) + dup_args
    dec_args = (output_nc, n_blocks_dec) + dup_args
    if n_blocks_shared > 0:
        n_blocks_shdec = n_blocks_shared // 2
        n_blocks_shenc = n_blocks_shared - n_blocks_shdec
        shenc_args = (n_domains, n_blocks_shenc) + dup_args
        shdec_args = (n_domains, n_blocks_shdec) + dup_args
        plex_netG = G_Plexer(n_domains, ResnetGenEncoder, enc_args, ResnetGenDecoder, dec_args, ResnetGenShared, shenc_args, shdec_args)
    else:
        plex_netG = G_Plexer(n_domains, ResnetGenEncoder, enc_args, ResnetGenDecoder, dec_args)
    if len(gpu_ids) > 0:
        assert(torch.cuda.is_available())
        plex_netG.cuda(gpu_ids[0])
    plex_netG.apply(weights_init)
    return plex_netG
def define_D(input_nc, ndf, netD_n_layers, n_domains, blur_fn, norm='batch', gpu_ids=[]):
    norm_layer = get_norm_layer(norm_type=norm)
    model_args = (input_nc, ndf, netD_n_layers, blur_fn, norm_layer, gpu_ids)
    plex_netD = D_Plexer(n_domains, NLayerDiscriminator, model_args)
    if len(gpu_ids) > 0:
        assert(torch.cuda.is_available())
        plex_netD.cuda(gpu_ids[0])
    plex_netD.apply(weights_init)
    return plex_netD
##############################################################################
# Classes
##############################################################################
# Defines the GAN loss which uses either LSGAN or the regular GAN.
# When LSGAN is used, it is basically same as MSELoss,
# but it abstracts away the need to create the target label tensor
# that has the same size as the input
class GANLoss(nn.Module):
    def __init__(self, use_lsgan=True, tensor=torch.FloatTensor):
        super(GANLoss, self).__init__()
        self.Tensor = tensor
        self.labels_real, self.labels_fake = None, None
        self.preloss = nn.Sigmoid() if not use_lsgan else None
        self.loss = nn.MSELoss() if use_lsgan else nn.BCELoss()
    def get_target_tensor(self, inputs, is_real):
        if self.labels_real is None or self.labels_real[0].numel() != inputs[0].numel():
            self.labels_real = [ self.Tensor(input.size()).fill_(1.0) for input in inputs ]
            self.labels_fake = [ self.Tensor(input.size()).fill_(0.0) for input in inputs ]
        if is_real:
            return self.labels_real
        return self.labels_fake
    def __call__(self, inputs, is_real):
        labels = self.get_target_tensor(inputs, is_real)
        if self.preloss is not None:
            inputs = [self.preloss(input) for input in inputs]
        losses = [self.loss(input, label) for input, label in zip(inputs, labels)]
        multipliers = list(range(1, len(inputs)+1))
        multipliers[-1] += 1
        losses = [m*l for m,l in zip(multipliers, losses)]
        return sum(losses) / (sum(multipliers) * len(losses))
# Defines the generator that consists of Resnet blocks between a few
# downsampling/upsampling operations.
# Code and idea originally from Justin Johnson's architecture.
# https://github.com/jcjohnson/fast-neural-style/
class ResnetGenEncoder(nn.Module):
    def __init__(self, input_nc, n_blocks=4, ngf=64, norm_layer=nn.BatchNorm2d,
                 use_dropout=False, gpu_ids=[], use_bias=False, padding_type='reflect'):
        assert(n_blocks >= 0)
        super(ResnetGenEncoder, self).__init__()
        self.gpu_ids = gpu_ids
        model = [nn.ReflectionPad2d(3),
                 nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0,
                           bias=use_bias),
                 norm_layer(ngf),
                 nn.PReLU()]
        n_downsampling = 2
        for i in range(n_downsampling):
            mult = 2**i
            model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3,
                                stride=2, padding=1, bias=use_bias),
                      norm_layer(ngf * mult * 2),
                      nn.PReLU()]
        mult = 2**n_downsampling
        for _ in range(n_blocks):
            model += [ResnetBlock(ngf * mult, norm_layer=norm_layer,
                                  use_dropout=use_dropout, use_bias=use_bias, padding_type=padding_type)]
        self.model = nn.Sequential(*model)
    def forward(self, input):
        if self.gpu_ids and isinstance(input.data, torch.cuda.FloatTensor):
            return nn.parallel.data_parallel(self.model, input, self.gpu_ids)
        return self.model(input)
class ResnetGenShared(nn.Module):
    def __init__(self, n_domains, n_blocks=2, ngf=64, norm_layer=nn.BatchNorm2d,
                 use_dropout=False, gpu_ids=[], use_bias=False, padding_type='reflect'):
        assert(n_blocks >= 0)
        super(ResnetGenShared, self).__init__()
        self.gpu_ids = gpu_ids
        model = []
        n_downsampling = 2
        mult = 2**n_downsampling
        for _ in range(n_blocks):
            model += [ResnetBlock(ngf * mult, norm_layer=norm_layer, n_domains=n_domains,
                                  use_dropout=use_dropout, use_bias=use_bias, padding_type=padding_type)]
        self.model = SequentialContext(n_domains, *model)
    def forward(self, input, domain):
        if self.gpu_ids and isinstance(input.data, torch.cuda.FloatTensor):
            return nn.parallel.data_parallel(self.model, (input, domain), self.gpu_ids)
        return self.model(input, domain)
class ResnetGenDecoder(nn.Module):
    def __init__(self, output_nc, n_blocks=5, ngf=64, norm_layer=nn.BatchNorm2d,
                 use_dropout=False, gpu_ids=[], use_bias=False, padding_type='reflect'):
        assert(n_blocks >= 0)
        super(ResnetGenDecoder, self).__init__()
        self.gpu_ids = gpu_ids
        model = []
        n_downsampling = 2
        mult = 2**n_downsampling
        for _ in range(n_blocks):
            model += [ResnetBlock(ngf * mult, norm_layer=norm_layer,
                                  use_dropout=use_dropout, use_bias=use_bias, padding_type=padding_type)]
        for i in range(n_downsampling):
            mult = 2**(n_downsampling - i)
            model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
                                         kernel_size=4, stride=2,
                                         padding=1, output_padding=0,
                                         bias=use_bias),
                      norm_layer(int(ngf * mult / 2)),
                      nn.PReLU()]
        model += [nn.ReflectionPad2d(3),
                  nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0),
                  nn.Tanh()]
        self.model = nn.Sequential(*model)
    def forward(self, input):
        if self.gpu_ids and isinstance(input.data, torch.cuda.FloatTensor):
            return nn.parallel.data_parallel(self.model, input, self.gpu_ids)
        return self.model(input)
# Define a resnet block
class ResnetBlock(nn.Module):
    def __init__(self, dim, norm_layer, use_dropout, use_bias, padding_type='reflect', n_domains=0):
        super(ResnetBlock, self).__init__()
        conv_block = []
        p = 0
        if padding_type == 'reflect':
            conv_block += [nn.ReflectionPad2d(1)]
        elif padding_type == 'replicate':
            conv_block += [nn.ReplicationPad2d(1)]
        elif padding_type == 'zero':
            p = 1
        else:
            raise NotImplementedError('padding [%s] is not implemented' % padding_type)
        conv_block += [nn.Conv2d(dim + n_domains, dim, kernel_size=3, padding=p, bias=use_bias),
                       norm_layer(dim),
                       nn.PReLU()]
        if use_dropout:
            conv_block += [nn.Dropout(0.5)]
        p = 0
        if padding_type == 'reflect':
            conv_block += [nn.ReflectionPad2d(1)]
        elif padding_type == 'replicate':
            conv_block += [nn.ReplicationPad2d(1)]
        elif padding_type == 'zero':
            p = 1
        else:
            raise NotImplementedError('padding [%s] is not implemented' % padding_type)
        conv_block += [nn.Conv2d(dim + n_domains, dim, kernel_size=3, padding=p, bias=use_bias),
                       norm_layer(dim)]
        self.conv_block = SequentialContext(n_domains, *conv_block)
    def forward(self, input):
        if isinstance(input, tuple):
            return input[0] + self.conv_block(*input)
        return input + self.conv_block(input)
# Defines the PatchGAN discriminator with the specified arguments.
class NLayerDiscriminator(nn.Module):
    def __init__(self, input_nc, ndf=64, n_layers=3, blur_fn=None, norm_layer=nn.BatchNorm2d, gpu_ids=[]):
        super(NLayerDiscriminator, self).__init__()
        self.gpu_ids = gpu_ids
        self.blur_fn = blur_fn
        self.gray_fn = lambda x: (.299*x[:,0,:,:] + .587*x[:,1,:,:] + .114*x[:,2,:,:]).unsqueeze_(1)
        self.model_gray = self.model(1, ndf, n_layers, norm_layer)
        self.model_rgb = self.model(input_nc, ndf, n_layers, norm_layer)
    def model(self, input_nc, ndf, n_layers, norm_layer):
        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d
        kw = 4
        padw = int(np.ceil((kw-1)/2))
        sequences = [[
            nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
            nn.PReLU()
        ]]
        nf_mult = 1
        nf_mult_prev = 1
        for n in range(1, n_layers):
            nf_mult_prev = nf_mult
            nf_mult = min(2**n, 8)
            sequences += [[
                nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult + 1,
                          kernel_size=kw, stride=2, padding=padw, bias=use_bias),
                norm_layer(ndf * nf_mult + 1),
                nn.PReLU()
            ]]
        nf_mult_prev = nf_mult
        nf_mult = min(2**n_layers, 8)
        sequences += [[
            nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
                      kernel_size=kw, stride=1, padding=padw, bias=use_bias),
            norm_layer(ndf * nf_mult),
            nn.PReLU(),
            nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)
        ]]
        return SequentialOutput(*sequences)
    def forward(self, input):
        luminance, blurred_rgb = self.gray_fn(input), self.blur_fn(input)
        if len(self.gpu_ids) and isinstance(input.data, torch.cuda.FloatTensor):
            outs1 = nn.parallel.data_parallel(self.model_gray, luminance, self.gpu_ids)
            outs2 = nn.parallel.data_parallel(self.model_rgb, blurred_rgb, self.gpu_ids)
        else:
            outs1 = self.model_gray(luminance)
            outs2 = self.model_rgb(blurred_rgb)
        return [torch.cat([o1,o2], 1) for o1,o2 in zip(outs1, outs2)]
class Plexer(nn.Module):
    def __init__(self):
        super(Plexer, self).__init__()
    def apply(self, func):
        for net in self.networks:
            net.apply(func)
    def cuda(self, device_id):
        for net in self.networks:
            net.cuda(device_id)
    def init_optimizers(self, opt, lr, betas):
        self.optimizers = [opt(net.parameters(), lr=lr, betas=betas)
                           for net in self.networks]
    def zero_grads(self, dom_a, dom_b):
        self.optimizers[dom_a].zero_grad()
        self.optimizers[dom_b].zero_grad()
    def step_grads(self, dom_a, dom_b):
        self.optimizers[dom_a].step()
        self.optimizers[dom_b].step()
    def update_lr(self, new_lr):
        for opt in self.optimizers:
            for param_group in opt.param_groups:
                param_group['lr'] = new_lr
    def save(self, save_path):
        for i, net in enumerate(self.networks):
            filename = save_path + ('%d.pth' % i)
            torch.save(net.cpu().state_dict(), filename)
    def load(self, save_path):
        for i, net in enumerate(self.networks):
            filename = save_path + ('%d.pth' % i)
            net.load_state_dict(torch.load(filename))
class G_Plexer(Plexer):
    def __init__(self, n_domains, encoder, enc_args, decoder, dec_args,
                 block=None, shenc_args=None, shdec_args=None):
        super(G_Plexer, self).__init__()
        self.encoders = [encoder(*enc_args) for _ in range(n_domains)]
        self.decoders = [decoder(*dec_args) for _ in range(n_domains)]
        self.sharing = block is not None
        if self.sharing:
            self.shared_encoder = block(*shenc_args)
            self.shared_decoder = block(*shdec_args)
            self.encoders.append( self.shared_encoder )
            self.decoders.append( self.shared_decoder )
        self.networks = self.encoders + self.decoders
    def init_optimizers(self, opt, lr, betas):
        self.optimizers = []
        for enc, dec in zip(self.encoders, self.decoders):
            params = itertools.chain(enc.parameters(), dec.parameters())
            self.optimizers.append( opt(params, lr=lr, betas=betas) )
    def forward(self, input, in_domain, out_domain):
        encoded = self.encode(input, in_domain)
        return self.decode(encoded, out_domain)
    def encode(self, input, domain):
        output = self.encoders[domain].forward(input)
        if self.sharing:
            return self.shared_encoder.forward(output, domain)
        return output
    def decode(self, input, domain):
        if self.sharing:
            input = self.shared_decoder.forward(input, domain)
        return self.decoders[domain].forward(input)
    def zero_grads(self, dom_a, dom_b):
        self.optimizers[dom_a].zero_grad()
        if self.sharing:
            self.optimizers[-1].zero_grad()
        self.optimizers[dom_b].zero_grad()
    def step_grads(self, dom_a, dom_b):
        self.optimizers[dom_a].step()
        if self.sharing:
            self.optimizers[-1].step()
        self.optimizers[dom_b].step()
    def __repr__(self):
        e, d = self.encoders[0], self.decoders[0]
        e_params = sum([p.numel() for p in e.parameters()])
        d_params = sum([p.numel() for p in d.parameters()])
        return repr(e) +'\n'+ repr(d) +'\n'+ \
            'Created %d Encoder-Decoder pairs' % len(self.encoders) +'\n'+ \
            'Number of parameters per Encoder: %d' % e_params +'\n'+ \
'Number of parameters per Decoder: %d' % d_params
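`G_Plexer.init_optimizers` pairs each encoder with the decoder at the same index. When a shared block is used, the shared encoder/decoder pair is appended last, so it owns the final optimizer. The sketch below shows which optimizers `step_grads` touches for a domain pair. `RecordingOptimizer` is a hypothetical stand-in that only logs calls; it is not a PyTorch class.

```python
class RecordingOptimizer:
    """Hypothetical stand-in for a torch.optim optimizer; records calls."""

    def __init__(self, name, log):
        self.name, self.log = name, log

    def zero_grad(self):
        self.log.append(('zero_grad', self.name))

    def step(self):
        self.log.append(('step', self.name))


def build_optimizers(n_domains, sharing, log):
    # one optimizer per (encoder, decoder) pair, plus one for the shared pair
    names = ['domain%d' % d for d in range(n_domains)]
    if sharing:
        names.append('shared')
    return [RecordingOptimizer(n, log) for n in names]


def step_grads(optimizers, dom_a, dom_b, sharing):
    # same order as G_Plexer.step_grads: dom_a, shared (if any), dom_b
    optimizers[dom_a].step()
    if sharing:
        optimizers[-1].step()
    optimizers[dom_b].step()


log = []
opts = build_optimizers(n_domains=4, sharing=True, log=log)
step_grads(opts, dom_a=0, dom_b=2, sharing=True)
print(log)  # [('step', 'domain0'), ('step', 'shared'), ('step', 'domain2')]
```

Because the shared optimizer is stepped for every pair, the shared block receives an update on every iteration, whichever two domains are sampled.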
class D_Plexer(Plexer):
def __init__(self, n_domains, model, model_args):
super(D_Plexer, self).__init__()
self.networks = [model(*model_args) for _ in range(n_domains)]
def forward(self, input, domain):
discriminator = self.networks[domain]
return discriminator.forward(input)
def __repr__(self):
t = self.networks[0]
t_params = sum([p.numel() for p in t.parameters()])
return repr(t) +'\n'+ \
'Created %d Discriminators' % len(self.networks) +'\n'+ \
'Number of parameters per Discriminator: %d' % t_params
class SequentialContext(nn.Sequential):
def __init__(self, n_classes, *args):
super(SequentialContext, self).__init__(*args)
self.n_classes = n_classes
self.context_var = None
def prepare_context(self, input, domain):
if self.context_var is None or self.context_var.size()[-2:] != input.size()[-2:]:
tensor = torch.cuda.FloatTensor if isinstance(input.data, torch.cuda.FloatTensor) \
else torch.FloatTensor
self.context_var = tensor(*((1, self.n_classes) + input.size()[-2:]))
self.context_var.data.fill_(-1.0)
self.context_var.data[:,domain,:,:] = 1.0
return self.context_var
def forward(self, *input):
if self.n_classes < 2 or len(input) < 2:
return super(SequentialContext, self).forward(input[0])
x, domain = input
for module in self._modules.values():
if 'Conv' in module.__class__.__name__:
context_var = self.prepare_context(x, domain)
x = torch.cat([x, context_var], dim=1)
elif 'Block' in module.__class__.__name__:
x = (x,) + input[1:]
x = module(x)
return x
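Before every convolution, `SequentialContext` concatenates `n_classes` constant planes to the activations. Every plane is -1 except the one for the requested domain, which is +1. The NumPy sketch below builds the same tensor; `make_context` and `append_context` are hypothetical names, since the repository builds a cached torch tensor in `prepare_context` instead.

```python
import numpy as np


def make_context(n_classes, height, width, domain):
    """Context planes as built by prepare_context: -1 everywhere,
    +1 on the channel belonging to `domain`. Shape (1, n_classes, H, W)."""
    context = np.full((1, n_classes, height, width), -1.0, dtype=np.float32)
    context[:, domain, :, :] = 1.0
    return context


def append_context(x, domain, n_classes):
    # mirrors torch.cat([x, context_var], dim=1); the context has batch size 1,
    # so, as in the original, shapes only line up when x also has batch size 1
    context = make_context(n_classes, x.shape[-2], x.shape[-1], domain)
    return np.concatenate([x, context], axis=1)


features = np.zeros((1, 8, 4, 4), dtype=np.float32)
out = append_context(features, domain=2, n_classes=4)
print(out.shape)  # (1, 12, 4, 4)
```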
class SequentialOutput(nn.Sequential):
def __init__(self, *args):
args = [nn.Sequential(*arg) for arg in args]
super(SequentialOutput, self).__init__(*args)
def forward(self, input):
predictions = []
layers = self._modules.values()
for i, module in enumerate(layers):
output = module(input)
if i == 0:
input = output; continue
predictions.append( output[:,-1,:,:] )
if i != len(layers) - 1:
input = output[:,:-1,:,:]
return predictions
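`SequentialOutput.forward` passes the first stage's output through whole. For each later stage, it keeps the last channel as a prediction map and feeds the remaining channels to the next stage. The NumPy sketch below replaces each stage with a hypothetical `fake_stage` function that only produces an array of the right shape, so the bookkeeping is easy to follow.

```python
import numpy as np


def run_sequential_output(stages, x):
    """Mimic SequentialOutput.forward with plain callables.

    Stage 0's output is passed on whole; every later stage contributes its
    last channel as a prediction and (unless it is the final stage) passes
    the other channels to the next stage."""
    predictions = []
    for i, stage in enumerate(stages):
        output = stage(x)
        if i == 0:
            x = output
            continue
        predictions.append(output[:, -1, :, :])  # indexing drops the channel dim
        if i != len(stages) - 1:
            x = output[:, :-1, :, :]
    return predictions


def fake_stage(out_channels, stride):
    # hypothetical stand-in for a conv stage: returns ones with the requested
    # channel count and an idealised (padding-free) downsampled spatial size
    def stage(x):
        n, _, h, w = x.shape
        return np.ones((n, out_channels, h // stride, w // stride))
    return stage


stages = [fake_stage(64, 2), fake_stage(129, 2), fake_stage(257, 2), fake_stage(1, 1)]
preds = run_sequential_output(stages, np.zeros((1, 3, 64, 64)))
print([p.shape for p in preds])  # [(1, 16, 16), (1, 8, 8), (1, 8, 8)]
```

Each prediction is three-dimensional (N, H, W). That is why `NLayerDiscriminator.forward` can concatenate the gray and RGB predictions along dimension 1.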
================================================
FILE: options/__init__.py
================================================
================================================
FILE: options/base_options.py
================================================
import argparse
import os
from util import util
import torch
class BaseOptions():
def __init__(self):
self.parser = argparse.ArgumentParser()
self.initialized = False
def initialize(self):
self.parser.add_argument('--name', required=True, type=str, help='name of the experiment. It decides where to store samples and models')
self.parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')
self.parser.add_argument('--dataroot', required=True, type=str, help='path to images (should have subfolders trainA, trainB, valA, valB, etc)')
self.parser.add_argument('--n_domains', required=True, type=int, help='Number of domains to transfer among')
self.parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.')
self.parser.add_argument('--resize_or_crop', type=str, default='resize_and_crop', help='scaling and cropping of images at load time [resize|resize_and_crop|crop]')
self.parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation')
self.parser.add_argument('--loadSize', type=int, default=286, help='scale images to this size')
self.parser.add_argument('--fineSize', type=int, default=256, help='then crop to this size')
self.parser.add_argument('--batchSize', type=int, default=1, help='input batch size')
self.parser.add_argument('--input_nc', type=int, default=3, help='# of input image channels')
self.parser.add_argument('--output_nc', type=int, default=3, help='# of output image channels')
self.parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in first conv layer')
self.parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in first conv layer')
self.parser.add_argument('--netG_n_blocks', type=int, default=9, help='number of residual blocks to use for netG')
self.parser.add_argument('--netG_n_shared', type=int, default=0, help='number of blocks to use for netG shared center module')
self.parser.add_argument('--netD_n_layers', type=int, default=4, help='number of layers to use for netD')
self.parser.add_argument('--norm', type=str, default='instance', help='instance normalization or batch normalization')
self.parser.add_argument('--use_dropout', action='store_true', help='insert dropout for the generator')
self.parser.add_argument('--gpu_ids', type=str, default='0', help="gpu ids, e.g. '0', '0,1,2' or '0,2'; use -1 for CPU")
self.parser.add_argument('--nThreads', default=2, type=int, help='# threads for loading data')
self.parser.add_argument('--display_id', type=int, default=0, help='window id of the web display (set >0 to use visdom)')
self.parser.add_argument('--display_port', type=int, default=8097, help='visdom port of the web display')
self.parser.add_argument('--display_winsize', type=int, default=256, help='display window size')
self.parser.add_argument('--display_single_pane_ncols', type=int, default=0, help='if positive, display all images in a single visdom web panel with certain number of images per row.')
self.initialized = True
def parse(self):
if not self.initialized:
self.initialize()
self.opt = self.parser.parse_args()
self.opt.isTrain = self.isTrain # train or test
str_ids = self.opt.gpu_ids.split(',')
self.opt.gpu_ids = []
for str_id in str_ids:
id = int(str_id)
if id >= 0:
self.opt.gpu_ids.append(id)
# set gpu ids
if len(self.opt.gpu_ids) > 0:
torch.cuda.set_device(self.opt.gpu_ids[0])
args = vars(self.opt)
print('------------ Options -------------')
for k, v in sorted(args.items()):
print('%s: %s' % (str(k), str(v)))
print('-------------- End ----------------')
# save to the disk
expr_dir = os.path.join(self.opt.checkpoints_dir, self.opt.name)
util.mkdirs(expr_dir)
file_name = os.path.join(expr_dir, 'opt.txt')
with open(file_name, 'wt') as opt_file:
opt_file.write('------------ Options -------------\n')
for k, v in sorted(args.items()):
opt_file.write('%s: %s\n' % (str(k), str(v)))
opt_file.write('-------------- End ----------------\n')
return self.opt
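`BaseOptions.parse` turns the comma-separated `--gpu_ids` string into a list of non-negative integers, so `-1` yields an empty list and selects the CPU. Here is the same logic as a standalone function; the name `parse_gpu_ids` is hypothetical.

```python
def parse_gpu_ids(gpu_ids):
    """Convert a string such as '0,2' into [0, 2]; negative ids are dropped,
    so '-1' gives [] (CPU only), as in BaseOptions.parse."""
    ids = []
    for str_id in gpu_ids.split(','):
        gpu_id = int(str_id)
        if gpu_id >= 0:
            ids.append(gpu_id)
    return ids


print(parse_gpu_ids('0,1,2'))  # [0, 1, 2]
print(parse_gpu_ids('-1'))     # []
```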
================================================
FILE: options/test_options.py
================================================
from .base_options import BaseOptions
class TestOptions(BaseOptions):
def initialize(self):
BaseOptions.initialize(self)
self.isTrain = False
self.parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.')
self.parser.add_argument('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images')
self.parser.add_argument('--which_epoch', required=True, type=int, help='which epoch to load for inference?')
self.parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc (determines name of folder to load from)')
self.parser.add_argument('--how_many', type=int, default=50, help='how many test images to run (if serial_test not enabled)')
self.parser.add_argument('--serial_test', action='store_true', help='read each image once from folders in sequential order')
self.parser.add_argument('--autoencode', action='store_true', help='translate images back into their own domain')
self.parser.add_argument('--reconstruct', action='store_true', help='do reconstructions of images during testing')
self.parser.add_argument('--show_matrix', action='store_true', help='visualize images in a matrix format as well')
================================================
FILE: options/train_options.py
================================================
from .base_options import BaseOptions
class TrainOptions(BaseOptions):
def initialize(self):
BaseOptions.initialize(self)
self.isTrain = True
self.parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model')
self.parser.add_argument('--which_epoch', type=int, default=0, help='which epoch to load if continuing training')
self.parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc (determines name of folder to load from)')
self.parser.add_argument('--niter', required=True, type=int, help='# of epochs at starting learning rate (try 50*n_domains)')
self.parser.add_argument('--niter_decay', required=True, type=int, help='# of epochs to linearly decay learning rate to zero (try 50*n_domains)')
self.parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for ADAM')
self.parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of ADAM')
self.parser.add_argument('--lambda_cycle', type=float, default=10.0, help='weight for cycle loss (A -> B -> A)')
self.parser.add_argument('--lambda_identity', type=float, default=0.0, help='weight for identity "autoencode" mapping (A -> A)')
self.parser.add_argument('--lambda_latent', type=float, default=0.0, help='weight for latent-space loss (A -> z -> B -> z)')
self.parser.add_argument('--lambda_forward', type=float, default=0.0, help='weight for forward loss (A -> B; try 0.2)')
self.parser.add_argument('--save_epoch_freq', type=int, default=5, help='frequency of saving checkpoints at the end of epochs')
self.parser.add_argument('--display_freq', type=int, default=100, help='frequency of showing training results on screen')
self.parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console')
self.parser.add_argument('--no_lsgan', action='store_true', help='use vanilla discriminator in place of least-squares one')
self.parser.add_argument('--pool_size', type=int, default=50, help='the size of image buffer that stores previously generated images')
self.parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/')
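The `--niter` and `--niter_decay` help strings describe a constant learning rate followed by a linear decay to zero. `update_hyperparams` is not shown here, so the sketch below is only one reading of that help text. `lr_at_epoch` is a hypothetical helper, and the repository's exact schedule may differ in detail.

```python
def lr_at_epoch(epoch, base_lr, niter, niter_decay):
    """Learning rate after `epoch` completed epochs, ASSUMING a constant rate
    for the first `niter` epochs and a linear decay to zero over the next
    `niter_decay` epochs (an interpretation of the option help text)."""
    if epoch <= niter:
        return base_lr
    decayed = base_lr * (1.0 - (epoch - niter) / float(niter_decay))
    return max(decayed, 0.0)


# values from scripts/train_combogan.sh (--niter 200 --niter_decay 200), default --lr
for e in (100, 200, 300, 400):
    print(e, lr_at_epoch(e, 0.0002, 200, 200))
```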
================================================
FILE: scripts/continue_combogan.sh
================================================
python train.py \
--dataroot ./datasets/alps \
--name alps_combogan \
--continue_train \
--which_epoch 117 \
--n_domains 4 \
--niter 200 \
--niter_decay 200 \
--lambda_identity 0.0 \
--lambda_forward 0.0
================================================
FILE: scripts/test_combogan.sh
================================================
python test.py \
--phase test \
--dataroot ./datasets/alps \
--name alps_combogan \
--n_domains 4 \
--which_epoch 400 \
--show_matrix
================================================
FILE: scripts/train_combogan.sh
================================================
python train.py \
--dataroot ./datasets/alps \
--name alps_combogan \
--n_domains 4 \
--niter 200 \
--niter_decay 200 \
--lambda_identity 0.0 \
--lambda_forward 0.0
================================================
FILE: test.py
================================================
import time
import os
from options.test_options import TestOptions
from data.data_loader import DataLoader
from models.combogan_model import ComboGANModel
from util.visualizer import Visualizer
from util import html
opt = TestOptions().parse()
opt.nThreads = 1 # test code only supports nThreads = 1
opt.batchSize = 1 # test code only supports batchSize = 1
dataset = DataLoader(opt)
model = ComboGANModel(opt)
visualizer = Visualizer(opt)
# create website
web_dir = os.path.join(opt.results_dir, opt.name, '%s_%d' % (opt.phase, opt.which_epoch))
webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %d' % (opt.name, opt.phase, opt.which_epoch))
# store images for matrix visualization
vis_buffer = []
# test
for i, data in enumerate(dataset):
if not opt.serial_test and i >= opt.how_many:
break
model.set_input(data)
model.test()
visuals = model.get_current_visuals(testing=True)
img_path = model.get_image_paths()
print('process image... %s' % img_path)
visualizer.save_images(webpage, visuals, img_path)
if opt.show_matrix:
vis_buffer.append(visuals)
if (i+1) % opt.n_domains == 0:
save_path = os.path.join(web_dir, 'mat_%d.png' % (i//opt.n_domains))
visualizer.save_image_matrix(vis_buffer, save_path)
vis_buffer.clear()
webpage.save()
================================================
FILE: train.py
================================================
import time
from options.train_options import TrainOptions
from data.data_loader import DataLoader
from models.combogan_model import ComboGANModel
from util.visualizer import Visualizer
opt = TrainOptions().parse()
dataset = DataLoader(opt)
print('# training images = %d' % len(dataset))
model = ComboGANModel(opt)
visualizer = Visualizer(opt)
total_steps = 0
# Update initially if continuing
if opt.which_epoch > 0:
model.update_hyperparams(opt.which_epoch)
for epoch in range(opt.which_epoch + 1, opt.niter + opt.niter_decay + 1):
epoch_start_time = time.time()
epoch_iter = 0
for i, data in enumerate(dataset):
iter_start_time = time.time()
total_steps += opt.batchSize
epoch_iter += opt.batchSize
model.set_input(data)
model.optimize_parameters()
if total_steps % opt.display_freq == 0:
visualizer.display_current_results(model.get_current_visuals(), epoch)
if total_steps % opt.print_freq == 0:
errors = model.get_current_errors()
t = (time.time() - iter_start_time) / opt.batchSize
visualizer.print_current_errors(epoch, epoch_iter, errors, t)
if opt.display_id > 0:
visualizer.plot_current_errors(epoch, float(epoch_iter)/len(dataset), opt, errors)
if epoch % opt.save_epoch_freq == 0:
print('saving the model at the end of epoch %d, iters %d' % (epoch, total_steps))
model.save(epoch)
print('End of epoch %d / %d \t Time Taken: %d sec' %
(epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time))
model.update_hyperparams(epoch)
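When you resume with `--which_epoch N`, the loop in `train.py` starts at epoch N+1 and runs through `niter + niter_decay`. It writes a checkpoint whenever the epoch is a multiple of `save_epoch_freq`. The sketch below reproduces that arithmetic; the helper name `epoch_plan` is hypothetical.

```python
def epoch_plan(which_epoch, niter, niter_decay, save_epoch_freq):
    """Epochs that train.py will run, and the subset at which it saves."""
    epochs = list(range(which_epoch + 1, niter + niter_decay + 1))
    saved = [e for e in epochs if e % save_epoch_freq == 0]
    return epochs, saved


# values from scripts/continue_combogan.sh, with save_epoch_freq at its default of 5
epochs, saved = epoch_plan(117, 200, 200, 5)
print(epochs[0], epochs[-1], len(epochs))  # 118 400 283
print(saved[:3])                           # [120, 125, 130]
```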
================================================
FILE: util/__init__.py
================================================
================================================
FILE: util/get_data.py
================================================
from __future__ import print_function
import os
import tarfile
import requests
from warnings import warn
from zipfile import ZipFile
from bs4 import BeautifulSoup
from os.path import abspath, isdir, join, basename
class GetData(object):
"""
Download CycleGAN or Pix2Pix Data.
Args:
technique : str
One of: 'cyclegan' or 'pix2pix'.
verbose : bool
If True, print additional information.
Examples:
>>> from util.get_data import GetData
>>> gd = GetData(technique='cyclegan')
>>> new_data_path = gd.get(save_path='./datasets') # options will be displayed.
"""
def __init__(self, technique='cyclegan', verbose=True):
url_dict = {
'pix2pix': 'https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets',
'cyclegan': 'https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets'
}
self.url = url_dict.get(technique.lower())
self._verbose = verbose
def _print(self, text):
if self._verbose:
print(text)
@staticmethod
def _get_options(r):
soup = BeautifulSoup(r.text, 'lxml')
options = [h.text for h in soup.find_all('a', href=True)
if h.text.endswith(('.zip', '.tar.gz'))]
return options
def _present_options(self):
r = requests.get(self.url)
options = self._get_options(r)
print('Options:\n')
for i, o in enumerate(options):
print("{0}: {1}".format(i, o))
choice = input("\nPlease enter the number of the "
"dataset above you wish to download:")
return options[int(choice)]
def _download_data(self, dataset_url, save_path):
if not isdir(save_path):
os.makedirs(save_path)
base = basename(dataset_url)
temp_save_path = join(save_path, base)
with open(temp_save_path, "wb") as f:
r = requests.get(dataset_url)
f.write(r.content)
if base.endswith('.tar.gz'):
obj = tarfile.open(temp_save_path)
elif base.endswith('.zip'):
obj = ZipFile(temp_save_path, 'r')
else:
raise ValueError("Unknown File Type: {0}.".format(base))
self._print("Unpacking Data...")
obj.extractall(save_path)
obj.close()
os.remove(temp_save_path)
def get(self, save_path, dataset=None):
"""
Download a dataset.
Args:
save_path : str
A directory to save the data to.
dataset : str, optional
A specific dataset to download.
Note: this must include the file extension.
If None, options will be presented for you
to choose from.
Returns:
save_path_full : str
The absolute path to the downloaded data.
"""
if dataset is None:
selected_dataset = self._present_options()
else:
selected_dataset = dataset
save_path_full = join(save_path, selected_dataset.split('.')[0])
if isdir(save_path_full):
warn("\n'{0}' already exists. Voiding Download.".format(
save_path_full))
else:
self._print('Downloading Data...')
url = "{0}/{1}".format(self.url, selected_dataset)
self._download_data(url, save_path=save_path)
return abspath(save_path_full)
================================================
FILE: util/html.py
================================================
import dominate
from dominate.tags import *
import os
class HTML:
def __init__(self, web_dir, title, reflesh=0):
self.title = title
self.web_dir = web_dir
self.img_dir = os.path.join(self.web_dir, 'images')
if not os.path.exists(self.web_dir):
os.makedirs(self.web_dir)
if not os.path.exists(self.img_dir):
os.makedirs(self.img_dir)
# print(self.img_dir)
self.doc = dominate.document(title=title)
if reflesh > 0:
with self.doc.head:
meta(http_equiv="refresh", content=str(reflesh))
def get_image_dir(self):
return self.img_dir
def add_header(self, str):
with self.doc:
h3(str)
def add_table(self, border=1):
self.t = table(border=border, style="table-layout: fixed;")
self.doc.add(self.t)
def add_images(self, ims, txts, links, width=400):
self.add_table()
with self.t:
with tr():
for im, txt, link in zip(ims, txts, links):
with td(style="word-wrap: break-word;", halign="center", valign="top"):
with p():
with a(href=os.path.join('images', link)):
img(style="width:%dpx" % width, src=os.path.join('images', im))
br()
p(txt)
def save(self):
html_file = '%s/index.html' % self.web_dir
with open(html_file, 'wt') as f:
f.write(self.doc.render())
if __name__ == '__main__':
html = HTML('web/', 'test_html')
html.add_header('hello world')
ims = []
txts = []
links = []
for n in range(4):
ims.append('image_%d.png' % n)
txts.append('text_%d' % n)
links.append('image_%d.png' % n)
html.add_images(ims, txts, links)
html.save()
================================================
FILE: util/image_pool.py
================================================
import random
import numpy as np
import torch
from torch.autograd import Variable
class ImagePool():
def __init__(self, pool_size):
self.pool_size = pool_size
if self.pool_size > 0:
self.num_imgs = 0
self.images = []
def query(self, images):
if self.pool_size == 0:
return images
return_images = []
for image in images.data:
image = torch.unsqueeze(image, 0)
if self.num_imgs < self.pool_size:
self.num_imgs = self.num_imgs + 1
self.images.append(image)
return_images.append(image)
else:
p = random.uniform(0, 1)
if p > 0.5:
random_id = random.randint(0, self.pool_size-1)
tmp = self.images[random_id].clone()
self.images[random_id] = image
return_images.append(tmp)
else:
return_images.append(image)
return_images = Variable(torch.cat(return_images, 0))
return return_images
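`ImagePool.query` keeps a history of generated images for discriminator training. Until the pool holds `pool_size` images, every image is stored and returned unchanged. After that, each incoming image has a 50% chance of being swapped with a randomly chosen stored one. The sketch below applies the same policy to plain Python objects, with no PyTorch involved. The class name `HistoryPool` and the injectable `rng` argument are assumptions added for illustration and testability.

```python
import random


class HistoryPool:
    """Same replacement policy as ImagePool.query, applied to arbitrary items."""

    def __init__(self, pool_size, rng=None):
        self.pool_size = pool_size
        self.items = []
        self.rng = rng if rng is not None else random.Random()

    def query(self, batch):
        if self.pool_size == 0:
            return list(batch)  # pool disabled: pass images straight through
        returned = []
        for item in batch:
            if len(self.items) < self.pool_size:
                self.items.append(item)  # fill phase: store and return as-is
                returned.append(item)
            elif self.rng.uniform(0, 1) > 0.5:
                idx = self.rng.randint(0, self.pool_size - 1)
                returned.append(self.items[idx])  # hand back an older image...
                self.items[idx] = item            # ...and keep the new one
            else:
                returned.append(item)
        return returned


pool = HistoryPool(pool_size=2, rng=random.Random(0))
print(pool.query(['a', 'b']))  # ['a', 'b'] while the pool is filling
```

Each swap moves one stored item out and the new item in, so no image is lost or duplicated across the pool and the returned batch.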
================================================
FILE: util/png.py
================================================
import struct
import zlib
def encode(buf, width, height):
""" buf: must be bytes or a bytearray in py3, a regular string in py2. formatted RGBRGB... """
assert (width * height * 3 == len(buf))
bpp = 3
def raw_data():
# reverse the vertical line order and add null bytes at the start
row_bytes = width * bpp
for row_start in range((height - 1) * width * bpp, -1, -row_bytes):
yield b'\x00'
yield buf[row_start:row_start + row_bytes]
def chunk(tag, data):
return [
struct.pack("!I", len(data)),
tag,
data,
struct.pack("!I", 0xFFFFFFFF & zlib.crc32(data, zlib.crc32(tag)))
]
SIGNATURE = b'\x89PNG\r\n\x1a\n'
COLOR_TYPE_RGB = 2
COLOR_TYPE_RGBA = 6
bit_depth = 8
return b''.join(
[ SIGNATURE ] +
chunk(b'IHDR', struct.pack("!2I5B", width, height, bit_depth, COLOR_TYPE_RGB, 0, 0, 0)) +
chunk(b'IDAT', zlib.compress(b''.join(raw_data()), 9)) +
chunk(b'IEND', b'')
)
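Each chunk written by `chunk()` consists of a big-endian length, the tag, the data, and a CRC-32. The call `zlib.crc32(data, zlib.crc32(tag))` continues the running CRC, so it equals the CRC of `tag + data`. The stdlib-only sketch below builds a 1x1 image the same way and then parses it back to verify the layout. The names `make_chunk` and `parse_chunks` are hypothetical.

```python
import struct
import zlib

PNG_SIGNATURE = b'\x89PNG\r\n\x1a\n'


def make_chunk(tag, data):
    # same layout as chunk() in util/png.py: length, tag, data, CRC(tag + data)
    crc = zlib.crc32(data, zlib.crc32(tag)) & 0xFFFFFFFF
    return struct.pack('!I', len(data)) + tag + data + struct.pack('!I', crc)


def parse_chunks(png_bytes):
    """Split a PNG byte string into (tag, data) pairs, verifying each CRC."""
    if png_bytes[:8] != PNG_SIGNATURE:
        raise ValueError('missing PNG signature')
    pos, chunks = 8, []
    while pos < len(png_bytes):
        (length,) = struct.unpack('!I', png_bytes[pos:pos + 4])
        tag = png_bytes[pos + 4:pos + 8]
        data = png_bytes[pos + 8:pos + 8 + length]
        (crc,) = struct.unpack('!I', png_bytes[pos + 8 + length:pos + 12 + length])
        if crc != zlib.crc32(tag + data) & 0xFFFFFFFF:
            raise ValueError('bad CRC in %r chunk' % tag)
        chunks.append((tag, data))
        pos += 12 + length
    return chunks


# a 1x1 red RGB image: filter byte 0 followed by one RGB pixel
ihdr = struct.pack('!2I5B', 1, 1, 8, 2, 0, 0, 0)
idat = zlib.compress(b'\x00' + b'\xff\x00\x00', 9)
png = PNG_SIGNATURE + make_chunk(b'IHDR', ihdr) + make_chunk(b'IDAT', idat) + make_chunk(b'IEND', b'')
print([tag for tag, _ in parse_chunks(png)])  # [b'IHDR', b'IDAT', b'IEND']
```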
================================================
FILE: util/util.py
================================================
from __future__ import print_function
import torch
import numpy as np
from scipy.ndimage import gaussian_filter
from PIL import Image
import inspect, re
import os
import collections
# Converts a Tensor into a Numpy array
# |imtype|: the desired type of the converted numpy array
def tensor2im(image_tensor, imtype=np.uint8):
image_numpy = image_tensor[0].cpu().float().numpy()
image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0
return image_numpy.astype(imtype)
def gkern_2d(size=5, sigma=3):
# Create 2D gaussian kernel
dirac = np.zeros((size, size))
dirac[size//2, size//2] = 1
mask = gaussian_filter(dirac, sigma)
# Adjust dimensions for torch conv2d
return np.stack([np.expand_dims(mask, axis=0)] * 3)
def diagnose_network(net, name='network'):
mean = 0.0
count = 0
for param in net.parameters():
if param.grad is not None:
mean += torch.mean(torch.abs(param.grad.data))
count += 1
if count > 0:
mean = mean / count
print(name)
print(mean)
def save_image(image_numpy, image_path):
image_pil = Image.fromarray(image_numpy)
image_pil.save(image_path)
def info(object, spacing=10, collapse=1):
"""Print methods and doc strings.
Takes module, class, list, dictionary, or string."""
methodList = [e for e in dir(object) if callable(getattr(object, e))]
processFunc = collapse and (lambda s: " ".join(s.split())) or (lambda s: s)
print( "\n".join(["%s %s" %
(method.ljust(spacing),
processFunc(str(getattr(object, method).__doc__)))
for method in methodList]) )
def varname(p):
for line in inspect.getframeinfo(inspect.currentframe().f_back)[3]:
m = re.search(r'\bvarname\s*\(\s*([A-Za-z_][A-Za-z0-9_]*)\s*\)', line)
if m:
return m.group(1)
def print_numpy(x, val=True, shp=False):
x = x.astype(np.float64)
if shp:
print('shape,', x.shape)
if val:
x = x.flatten()
print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % (
np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x)))
def mkdirs(paths):
if isinstance(paths, list) and not isinstance(paths, str):
for path in paths:
mkdir(path)
else:
mkdir(paths)
def mkdir(path):
if not os.path.exists(path):
os.makedirs(path)
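`tensor2im` assumes pixel values in [-1, 1]. It takes the first image of the batch, moves the channel axis last, and rescales the values to [0, 255]. The sketch below is a NumPy-only version of the same mapping that takes an array instead of a torch tensor; the name `array2im` is hypothetical.

```python
import numpy as np


def array2im(image_array, imtype=np.uint8):
    """NumPy counterpart of tensor2im: (N, C, H, W) in [-1, 1] -> (H, W, C) in [0, 255]."""
    image = image_array[0].astype(np.float32)           # first image in the batch
    image = (np.transpose(image, (1, 2, 0)) + 1) / 2.0 * 255.0
    return image.astype(imtype)


batch = np.stack([np.full((3, 2, 2), -1.0), np.full((3, 2, 2), 1.0)])
img = array2im(batch)
print(img.shape, img.dtype, img.max())  # (2, 2, 3) uint8 0
```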
================================================
FILE: util/visualizer.py
================================================
import numpy as np
import os
import ntpath
import time
from . import util
from . import html
class Visualizer():
def __init__(self, opt):
# self.opt = opt
self.display_id = opt.display_id
self.use_html = opt.isTrain and not opt.no_html
self.win_size = opt.display_winsize
self.name = opt.name
if self.display_id > 0:
import visdom
self.vis = visdom.Visdom(port = opt.display_port)
self.display_single_pane_ncols = opt.display_single_pane_ncols
if self.use_html:
self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web')
self.img_dir = os.path.join(self.web_dir, 'images')
print('create web directory %s...' % self.web_dir)
util.mkdirs([self.web_dir, self.img_dir])
self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt')
with open(self.log_name, "a") as log_file:
now = time.strftime("%c")
log_file.write('================ Training Loss (%s) ================\n' % now)
# |visuals|: dictionary of images to display or save
def display_current_results(self, visuals, epoch):
if self.display_id > 0: # show images in the browser
if self.display_single_pane_ncols > 0:
h, w = next(iter(visuals.values())).shape[:2]
table_css = """<style>
table {border-collapse: separate; border-spacing:4px; white-space:nowrap; text-align:center}
table td {width: %dpx; height: %dpx; padding: 4px; outline: 4px solid black}
</style>""" % (w, h)
ncols = self.display_single_pane_ncols
title = self.name
label_html = ''
label_html_row = ''
nrows = int(np.ceil(len(visuals.items()) / ncols))
images = []
idx = 0
for label, image_numpy in visuals.items():
label_html_row += '<td>%s</td>' % label
images.append(image_numpy.transpose([2, 0, 1]))
idx += 1
if idx % ncols == 0:
label_html += '<tr>%s</tr>' % label_html_row
label_html_row = ''
white_image = np.ones_like(image_numpy.transpose([2, 0, 1]))*255
while idx % ncols != 0:
images.append(white_image)
label_html_row += '<td></td>'
idx += 1
if label_html_row != '':
label_html += '<tr>%s</tr>' % label_html_row
# pane col = image row
self.vis.images(images, nrow=ncols, win=self.display_id + 1,
padding=2, opts=dict(title=title + ' images'))
label_html = '<table>%s</table>' % label_html
self.vis.text(table_css + label_html, win = self.display_id + 2,
opts=dict(title=title + ' labels'))
else:
idx = 1
for label, image_numpy in visuals.items():
#image_numpy = np.flipud(image_numpy)
self.vis.image(image_numpy.transpose([2,0,1]), opts=dict(title=label),
win=self.display_id + idx)
idx += 1
if self.use_html: # save images to a html file
for label, image_numpy in visuals.items():
img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.png' % (epoch, label))
util.save_image(image_numpy, img_path)
# update website
webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, reflesh=1)
for n in range(epoch, 0, -1):
webpage.add_header('epoch [%d]' % n)
ims = []
txts = []
links = []
for label, image_numpy in visuals.items():
img_path = 'epoch%.3d_%s.png' % (n, label)
ims.append(img_path)
txts.append(label)
links.append(img_path)
webpage.add_images(ims, txts, links, width=self.win_size)
webpage.save()
# errors: dictionary of error labels and values
def plot_current_errors(self, epoch, counter_ratio, opt, errors):
if not hasattr(self, 'plot_data'):
self.plot_data = {'X':[],'Y':[], 'legend':list(errors.keys())}
self.plot_data['X'].append(epoch + counter_ratio)
self.plot_data['Y'].append([errors[k] for k in self.plot_data['legend']])
self.vis.line(
X=np.stack([np.array(self.plot_data['X'])]*len(self.plot_data['legend']),1),
Y=np.array(self.plot_data['Y']),
opts={
'title': self.name + ' loss over time',
'legend': self.plot_data['legend'],
'xlabel': 'epoch',
'ylabel': 'loss'},
win=self.display_id)
# errors: same format as |errors| of plot_current_errors
def print_current_errors(self, epoch, i, errors, t):
message = '(epoch: %d, iters: %d, time: %.3f) ' % (epoch, i, t)
for k, v in errors.items():
v = ['%.3f' % iv for iv in v]
message += k + ': ' + ', '.join(v) + ' | '
print(message)
with open(self.log_name, "a") as log_file:
log_file.write('%s\n' % message)
# save image to the disk
def save_images(self, webpage, visuals, image_path):
image_dir = webpage.get_image_dir()
short_path = ntpath.basename(image_path[0])
name = os.path.splitext(short_path)[0]
webpage.add_header(name)
ims = []
txts = []
links = []
for label, image_numpy in visuals.items():
image_name = '%s_%s.png' % (name, label)
save_path = os.path.join(image_dir, image_name)
util.save_image(image_numpy, save_path)
ims.append(image_name)
txts.append(label)
links.append(image_name)
webpage.add_images(ims, txts, links, width=self.win_size)
def save_image_matrix(self, visuals_list, save_path):
images_list = []
get_domain = lambda x: x.split('_')[-1]
for visuals in visuals_list:
pairs = list(visuals.items())
real_label, real_img = pairs[0]
real_dom = get_domain(real_label)
for label, img in pairs:
if 'fake' not in label:
continue
if get_domain(label) == real_dom:
images_list.append(real_img)
else:
images_list.append(img)
immat = self.stack_images(images_list)
util.save_image(immat, save_path)
# reshape a list of images into a square matrix of them
def stack_images(self, list_np_images):
n = int(np.ceil(np.sqrt(len(list_np_images))))
# add padding between images
for i, im in enumerate(list_np_images):
val = 255 if i%n == i//n else 0
r_pad = np.pad(im[:,:,0], (3,3), mode='constant', constant_values=0)
g_pad = np.pad(im[:,:,1], (3,3), mode='constant', constant_values=val)
b_pad = np.pad(im[:,:,2], (3,3), mode='constant', constant_values=0)
list_np_images[i] = np.stack([r_pad,g_pad,b_pad], axis=2)
data = np.array(list_np_images)
data = data.reshape((n, n) + data.shape[1:]).transpose((0, 2, 1, 3) + tuple(range(4, data.ndim + 1)))
data = data.reshape((n * data.shape[1], n * data.shape[3]) + data.shape[4:])
return data
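`stack_images` pads every image by 3 pixels, using a green border for images on the diagonal and black elsewhere. It then tiles the list into an n x n grid with a reshape and a transpose. The sketch below isolates the grid step in NumPy. The function name `tile_square` is hypothetical, and like the original it requires the number of images to be a perfect square.

```python
import numpy as np


def tile_square(images):
    """Tile a list of n*n equally sized (H, W, C) arrays into one
    (n*H, n*W, C) array, row by row, as stack_images does after padding."""
    n = int(np.ceil(np.sqrt(len(images))))
    data = np.array(images)                 # (n*n, H, W, C)
    h, w, c = data.shape[1:]
    data = data.reshape(n, n, h, w, c)      # (row, col, H, W, C)
    data = data.transpose(0, 2, 1, 3, 4)    # (row, H, col, W, C)
    return data.reshape(n * h, n * w, c)


# four 2x2 single-value tiles, values 0..3, laid out as [[0, 1], [2, 3]]
tiles = [np.full((2, 2, 3), k, dtype=np.uint8) for k in range(4)]
grid = tile_square(tiles)
print(grid.shape)  # (4, 4, 3)
```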
SYMBOL INDEX (148 symbols across 16 files)
FILE: data/base_dataset.py
class BaseDataset (line 5) | class BaseDataset(data.Dataset):
method __init__ (line 6) | def __init__(self):
method name (line 9) | def name(self):
method initialize (line 12) | def initialize(self, opt):
function get_transform (line 15) | def get_transform(opt):
FILE: data/data_loader.py
class DataLoader (line 5) | class DataLoader():
method name (line 6) | def name(self):
method __init__ (line 9) | def __init__(self, opt):
method __len__ (line 17) | def __len__(self):
method __iter__ (line 20) | def __iter__(self):
FILE: data/image_folder.py
function is_image_file (line 20) | def is_image_file(filename):
function make_dataset (line 24) | def make_dataset(dir):
function default_loader (line 37) | def default_loader(path):
class ImageFolder (line 41) | class ImageFolder(data.Dataset):
method __init__ (line 43) | def __init__(self, root, transform=None, return_paths=False,
method __getitem__ (line 57) | def __getitem__(self, index):
method __len__ (line 67) | def __len__(self):
FILE: data/unaligned_dataset.py
class UnalignedDataset (line 8) | class UnalignedDataset(BaseDataset):
method __init__ (line 9) | def __init__(self, opt):
method load_image (line 20) | def load_image(self, dom, idx):
method __getitem__ (line 26) | def __getitem__(self, index):
method __len__ (line 52) | def __len__(self):
method name (line 57) | def name(self):
FILE: models/base_model.py
class BaseModel (line 5) | class BaseModel():
method name (line 6) | def name(self):
method __init__ (line 9) | def __init__(self, opt):
method set_input (line 16) | def set_input(self, input):
method forward (line 19) | def forward(self):
method test (line 23) | def test(self):
method get_image_paths (line 26) | def get_image_paths(self):
method optimize_parameters (line 29) | def optimize_parameters(self):
method get_current_visuals (line 32) | def get_current_visuals(self):
method get_current_errors (line 35) | def get_current_errors(self):
method save (line 38) | def save(self, label):
method save_network (line 42) | def save_network(self, network, network_label, epoch, gpu_ids):
method load_network (line 50) | def load_network(self, network, network_label, epoch):
method update_learning_rate (line 55) | def update_learning_rate():
FILE: models/combogan_model.py
class ComboGANModel (line 10) | class ComboGANModel(BaseModel):
method name (line 11) | def name(self):
method __init__ (line 14) | def __init__(self, opt):
method set_input (line 63) | def set_input(self, input):
method test (line 73) | def test(self):
method get_image_paths (line 91) | def get_image_paths(self):
method backward_D_basic (line 94) | def backward_D_basic(self, real, fake, domain):
method backward_D (line 107) | def backward_D(self):
method backward_G (line 115) | def backward_G(self):
method optimize_parameters (line 169) | def optimize_parameters(self):
method get_current_errors (line 179) | def get_current_errors(self):
method get_current_visuals (line 184) | def get_current_visuals(self, testing=False):
method save (line 191) | def save(self, label):
method update_hyperparams (line 195) | def update_hyperparams(self, curr_iter):
FILE: models/networks.py
function weights_init (line 10) | def weights_init(m):
function get_norm_layer (line 21) | def get_norm_layer(norm_type='instance'):
function define_G (line 30) | def define_G(input_nc, output_nc, ngf, n_blocks, n_blocks_shared, n_doma...
function define_D (line 62) | def define_D(input_nc, ndf, netD_n_layers, n_domains, blur_fn, norm='bat...
class GANLoss (line 85) | class GANLoss(nn.Module):
method __init__ (line 86) | def __init__(self, use_lsgan=True, tensor=torch.FloatTensor):
method get_target_tensor (line 93) | def get_target_tensor(self, inputs, is_real):
method __call__ (line 101) | def __call__(self, inputs, is_real):
class ResnetGenEncoder (line 115) | class ResnetGenEncoder(nn.Module):
method __init__ (line 116) | def __init__(self, input_nc, n_blocks=4, ngf=64, norm_layer=nn.BatchNo...
method forward (line 143) | def forward(self, input):
class ResnetGenShared (line 148) | class ResnetGenShared(nn.Module):
method __init__ (line 149) | def __init__(self, n_domains, n_blocks=2, ngf=64, norm_layer=nn.BatchN...
method forward (line 165) | def forward(self, input, domain):
class ResnetGenDecoder (line 170) | class ResnetGenDecoder(nn.Module):
method __init__ (line 171) | def __init__(self, output_nc, n_blocks=5, ngf=64, norm_layer=nn.BatchN...
method forward (line 200) | def forward(self, input):
class ResnetBlock (line 207) | class ResnetBlock(nn.Module):
method __init__ (line 208) | def __init__(self, dim, norm_layer, use_dropout, use_bias, padding_typ...
method forward (line 242) | def forward(self, input):
class NLayerDiscriminator (line 249) | class NLayerDiscriminator(nn.Module):
method __init__ (line 250) | def __init__(self, input_nc, ndf=64, n_layers=3, blur_fn=None, norm_la...
method model (line 259) | def model(self, input_nc, ndf, n_layers, norm_layer):
method forward (line 297) | def forward(self, input):
class Plexer (line 309) | class Plexer(nn.Module):
method __init__ (line 310) | def __init__(self):
method apply (line 313) | def apply(self, func):
method cuda (line 317) | def cuda(self, device_id):
method init_optimizers (line 321) | def init_optimizers(self, opt, lr, betas):
method zero_grads (line 325) | def zero_grads(self, dom_a, dom_b):
method step_grads (line 329) | def step_grads(self, dom_a, dom_b):
method update_lr (line 333) | def update_lr(self, new_lr):
method save (line 338) | def save(self, save_path):
method load (line 343) | def load(self, save_path):
class G_Plexer (line 348) | class G_Plexer(Plexer):
method __init__ (line 349) | def __init__(self, n_domains, encoder, enc_args, decoder, dec_args,
method init_optimizers (line 363) | def init_optimizers(self, opt, lr, betas):
method forward (line 369) | def forward(self, input, in_domain, out_domain):
method encode (line 373) | def encode(self, input, domain):
method decode (line 379) | def decode(self, input, domain):
method zero_grads (line 384) | def zero_grads(self, dom_a, dom_b):
method step_grads (line 390) | def step_grads(self, dom_a, dom_b):
method __repr__ (line 396) | def __repr__(self):
class D_Plexer (line 405) | class D_Plexer(Plexer):
method __init__ (line 406) | def __init__(self, n_domains, model, model_args):
method forward (line 410) | def forward(self, input, domain):
method __repr__ (line 414) | def __repr__(self):
class SequentialContext (line 422) | class SequentialContext(nn.Sequential):
method __init__ (line 423) | def __init__(self, n_classes, *args):
method prepare_context (line 428) | def prepare_context(self, input, domain):
method forward (line 438) | def forward(self, *input):
class SequentialOutput (line 452) | class SequentialOutput(nn.Sequential):
method __init__ (line 453) | def __init__(self, *args):
method forward (line 457) | def forward(self, input):
FILE: options/base_options.py
class BaseOptions (line 6) | class BaseOptions():
method __init__ (line 7) | def __init__(self):
method initialize (line 11) | def initialize(self):
method parse (line 48) | def parse(self):
FILE: options/test_options.py
class TestOptions (line 4) | class TestOptions(BaseOptions):
method initialize (line 5) | def initialize(self):
FILE: options/train_options.py
class TrainOptions (line 4) | class TrainOptions(BaseOptions):
method initialize (line 5) | def initialize(self):
FILE: util/get_data.py
class GetData (line 11) | class GetData(object):
method __init__ (line 29) | def __init__(self, technique='cyclegan', verbose=True):
method _print (line 37) | def _print(self, text):
method _get_options (line 42) | def _get_options(r):
method _present_options (line 48) | def _present_options(self):
method _download_data (line 58) | def _download_data(self, dataset_url, save_path):
method get (line 81) | def get(self, save_path, dataset=None):
FILE: util/html.py
class HTML (line 6) | class HTML:
method __init__ (line 7) | def __init__(self, web_dir, title, reflesh=0):
method get_image_dir (line 22) | def get_image_dir(self):
method add_header (line 25) | def add_header(self, str):
method add_table (line 29) | def add_table(self, border=1):
method add_images (line 33) | def add_images(self, ims, txts, links, width=400):
method save (line 45) | def save(self):
FILE: util/image_pool.py
class ImagePool (line 5) | class ImagePool():
method __init__ (line 6) | def __init__(self, pool_size):
method query (line 12) | def query(self, images):
FILE: util/png.py
function encode (line 4) | def encode(buf, width, height):
FILE: util/util.py
function tensor2im (line 12) | def tensor2im(image_tensor, imtype=np.uint8):
function gkern_2d (line 17) | def gkern_2d(size=5, sigma=3):
function diagnose_network (line 26) | def diagnose_network(net, name='network'):
function save_image (line 39) | def save_image(image_numpy, image_path):
function info (line 43) | def info(object, spacing=10, collapse=1):
function varname (line 53) | def varname(p):
function print_numpy (line 59) | def print_numpy(x, val=True, shp=False):
function mkdirs (line 69) | def mkdirs(paths):
function mkdir (line 77) | def mkdir(path):
FILE: util/visualizer.py
class Visualizer (line 8) | class Visualizer():
method __init__ (line 9) | def __init__(self, opt):
method display_current_results (line 31) | def display_current_results(self, visuals, epoch):
method plot_current_errors (line 95) | def plot_current_errors(self, epoch, counter_ratio, opt, errors):
method print_current_errors (line 111) | def print_current_errors(self, epoch, i, errors, t):
method save_images (line 122) | def save_images(self, webpage, visuals, image_path):
method save_image_matrix (line 142) | def save_image_matrix(self, visuals_list, save_path):
method stack_images (line 163) | def stack_images(self, list_np_images):
Condensed preview: 28 files, each showing path, character count, and a content snippet.
[
{
"path": ".gitignore",
"chars": 161,
"preview": "datasets/\ncheckpoints/\nresults/\n*.png\n*/**/__pycache__\n*/*.pyc\n*/**/*.pyc\n*/**/**/*.pyc\n*/**/**/**/*.pyc\n*/**/**/**/**/*"
},
{
"path": "LICENSE",
"chars": 1296,
"preview": "Copyright (c) 2017, Asha Anoosheh\nAll rights reserved.\n\nRedistribution and use in source and binary forms, with or witho"
},
{
"path": "README.md",
"chars": 5075,
"preview": "\n# ComboGAN\n\nThis is our ongoing PyTorch implementation for ComboGAN.\nCode was written by [Asha Anoosheh](https://github"
},
{
"path": "data/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "data/base_dataset.py",
"chars": 907,
"preview": "import torch.utils.data as data\nfrom PIL import Image\nimport torchvision.transforms as transforms\n\nclass BaseDataset(dat"
},
{
"path": "data/data_loader.py",
"chars": 664,
"preview": "import torch.utils.data\nfrom data.unaligned_dataset import UnalignedDataset\n\n\nclass DataLoader():\n def name(self):\n "
},
{
"path": "data/image_folder.py",
"chars": 1946,
"preview": "###############################################################################\n# Code from\n# https://github.com/pytorch"
},
{
"path": "data/unaligned_dataset.py",
"chars": 1904,
"preview": "import os.path, glob\nimport torchvision.transforms as transforms\nfrom data.base_dataset import BaseDataset, get_transfor"
},
{
"path": "models/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "models/base_model.py",
"chars": 1474,
"preview": "import os\nimport torch\n\n\nclass BaseModel():\n def name(self):\n return 'BaseModel'\n\n def __init__(self, opt):"
},
{
"path": "models/combogan_model.py",
"chars": 8933,
"preview": "import numpy as np\nimport torch\nfrom collections import OrderedDict\nimport util.util as util\nfrom util.image_pool import"
},
{
"path": "models/networks.py",
"chars": 17956,
"preview": "import torch\nimport torch.nn as nn\nfrom torch.nn import init\nimport functools, itertools\nimport numpy as np\n\n\n\n\ndef weig"
},
{
"path": "options/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "options/base_options.py",
"chars": 4663,
"preview": "import argparse\nimport os\nfrom util import util\nimport torch\n\nclass BaseOptions():\n def __init__(self):\n self."
},
{
"path": "options/test_options.py",
"chars": 1285,
"preview": "from .base_options import BaseOptions\n\n\nclass TestOptions(BaseOptions):\n def initialize(self):\n BaseOptions.in"
},
{
"path": "options/train_options.py",
"chars": 2399,
"preview": "from .base_options import BaseOptions\n\n\nclass TrainOptions(BaseOptions):\n def initialize(self):\n BaseOptions.i"
},
{
"path": "scripts/continue_combogan.sh",
"chars": 249,
"preview": "python train.py \\\n --dataroot ./datasets/alps \\\n --name alps_combogan \\\n --continue_train \\\n --which_epo"
},
{
"path": "scripts/test_combogan.sh",
"chars": 164,
"preview": "python test.py \\\n --phase test \\\n --dataroot ./datasets/alps \\\n --name alps_combogan \\\n --n_domains 4 \\"
},
{
"path": "scripts/train_combogan.sh",
"chars": 200,
"preview": "python train.py \\\n --dataroot ./datasets/alps \\\n --name alps_combogan \\\n --n_domains 4 \\\n --niter 200 \\"
},
{
"path": "test.py",
"chars": 1356,
"preview": "import time\nimport os\nfrom options.test_options import TestOptions\nfrom data.data_loader import DataLoader\nfrom models.c"
},
{
"path": "train.py",
"chars": 1643,
"preview": "import time\nfrom options.train_options import TrainOptions\nfrom data.data_loader import DataLoader\nfrom models.combogan_"
},
{
"path": "util/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "util/get_data.py",
"chars": 3511,
"preview": "from __future__ import print_function\nimport os\nimport tarfile\nimport requests\nfrom warnings import warn\nfrom zipfile im"
},
{
"path": "util/html.py",
"chars": 1912,
"preview": "import dominate\nfrom dominate.tags import *\nimport os\n\n\nclass HTML:\n def __init__(self, web_dir, title, reflesh=0):\n "
},
{
"path": "util/image_pool.py",
"chars": 1109,
"preview": "import random\nimport numpy as np\nimport torch\nfrom torch.autograd import Variable\nclass ImagePool():\n def __init__(se"
},
{
"path": "util/png.py",
"chars": 978,
"preview": "import struct\nimport zlib\n\ndef encode(buf, width, height):\n \"\"\" buf: must be bytes or a bytearray in py3, a regular str"
},
{
"path": "util/util.py",
"chars": 2477,
"preview": "from __future__ import print_function\nimport torch\nimport numpy as np\nfrom scipy.ndimage.filters import gaussian_filter\n"
},
{
"path": "util/visualizer.py",
"chars": 7695,
"preview": "import numpy as np\nimport os\nimport ntpath\nimport time\nfrom . import util\nfrom . import html\n\nclass Visualizer():\n de"
}
]
About this extraction
This page contains the full source code of the AAnoosheh/ComboGAN GitHub repository, extracted and formatted as plain text. The extraction includes 28 files (68.3 KB), approximately 17.6k tokens, and a symbol index with 148 extracted functions, classes, methods, constants, and types.
Extracted by GitExtract (built by Nikandr Surkov).