Repository: AlamiMejjati/Unsupervised-Attention-guided-Image-to-Image-Translation
Branch: master
Commit: a691e3a94b7a
Files: 25
Total size: 63.5 KB
Directory structure:
gitextract__oabw19d/
├── LICENSE
├── README.md
├── Trained_models/
│   └── README.md
├── __init__.py
├── configs/
│   ├── exp_01.json
│   ├── exp_01_test.json
│   ├── exp_02.json
│   ├── exp_02_test.json
│   ├── exp_04.json
│   ├── exp_04_test.json
│   ├── exp_05.json
│   └── exp_05_test.json
├── create_cyclegan_dataset.py
├── cyclegan_datasets.py
├── data_loader.py
├── download_datasets.sh
├── layers.py
├── losses.py
├── main.py
└── test/
    ├── __init__.py
    ├── evaluate_losses.py
    ├── evaluate_networks.py
    ├── test_losses.py
    └── test_model.py
================================================
FILE CONTENTS
================================================
================================================
FILE: LICENSE
================================================
MIT License
Copyright (c) 2017 Harry Yang
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
================================================
FILE: README.md
================================================
# Unsupervised Attention-guided Image-to-Image Translation
This repository contains the TensorFlow code for our NeurIPS 2018 paper [“Unsupervised Attention-guided Image-to-Image Translation”](https://arxiv.org/pdf/1806.02311.pdf). This code is based on the TensorFlow implementation of CycleGAN provided by [Harry Yang](https://github.com/leehomyc/cyclegan-1). You may need to train several times, as the quality of the results is sensitive to the initialization.
By leveraging attention, our architecture (shown in the figure below) maps only the relevant areas of the image, and by doing so further enhances the quality of image-to-image translation.
Our model architecture is depicted below; please refer to the paper for more details:
## Mapping results
### Our learned attention maps
The figure below displays automatically learned attention maps on various translation datasets:
### Horse-to-Zebra image translation results:
#### Horse-to-Zebra:
The top row in the figure below shows input images; the bottom row shows the mappings produced by our algorithm.
#### Zebra-to-Horse:
The top row in the figure below shows input images; the bottom row shows the mappings produced by our algorithm.
### Apple-to-Orange image translation results:
#### Apple-to-Orange:
The top row in the figure below shows input images; the bottom row shows the mappings produced by our algorithm.
#### Orange-to-Apple:
The top row in the figure below shows input images; the bottom row shows the mappings produced by our algorithm.
### Getting Started with the code
### Prepare dataset
* You can either download one of the default CycleGAN datasets or use your own dataset.
* Download a CycleGAN dataset (e.g. horse2zebra, apple2orange):
```bash
bash ./download_datasets.sh horse2zebra
```
* Use your own dataset: put the images from each domain in folder_a and folder_b, respectively.
* Create the csv file as input to the data loader.
* Edit the [```cyclegan_datasets.py```](cyclegan_datasets.py) file. For example, if you have a horse2zebra_train dataset containing 1067 horse images and 1334 zebra images (both in JPG format), edit [```cyclegan_datasets.py```](cyclegan_datasets.py) as follows:
```python
DATASET_TO_SIZES = {
    'horse2zebra_train': 1334
}

PATH_TO_CSV = {
    'horse2zebra_train': './AGGAN/input/horse2zebra/horse2zebra_train.csv'
}

DATASET_TO_IMAGETYPE = {
    'horse2zebra_train': '.jpg'
}
```
* Run create_cyclegan_dataset.py:
```bash
python -m create_cyclegan_dataset --image_path_a='./input/horse2zebra/trainB' --image_path_b='./input/horse2zebra/trainA' --dataset_name="horse2zebra_train" --do_shuffle=0
```
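The script pairs one image path from each domain per CSV row, wrapping the shorter list around so every row is complete. A minimal sketch of that pairing logic (hypothetical file names, not repo code):

```python
def pair_domains(list_a, list_b, num_rows):
    # Wrap whichever list is shorter so every row gets one path per domain.
    return [(list_a[i % len(list_a)], list_b[i % len(list_b)])
            for i in range(num_rows)]

# With 2 images in domain A and 3 in domain B, the third row reuses a0.jpg.
pairs = pair_domains(['a0.jpg', 'a1.jpg'], ['b0.jpg', 'b1.jpg', 'b2.jpg'], 3)
```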
### Training
* Create the configuration file. The configuration file contains basic information for training/testing. An example configuration file can be found at [```configs/exp_01.json```](configs/exp_01.json).
* Start training:
```bash
python main.py --to_train=1 --log_dir=./output/AGGAN/exp_01 --config_filename=./configs/exp_01.json
```
* Check the intermediate results:
* Tensorboard
```bash
tensorboard --port=6006 --logdir=./output/AGGAN/exp_01/#timestamp#
```
* Check the html visualization at ./output/AGGAN/exp_01/#timestamp#/epoch_#id#.html.
### Restoring from the previous checkpoint
```bash
python main.py --to_train=2 --log_dir=./output/AGGAN/exp_01 --config_filename=./configs/exp_01.json --checkpoint_dir=./output/AGGAN/exp_01/#timestamp#
```
### Testing
* Create the testing dataset:
* Edit the cyclegan_datasets.py file in the same way as for training.
* Create the csv file as the input to the data loader:
```bash
python -m create_cyclegan_dataset --image_path_a='./input/horse2zebra/testB' --image_path_b='./input/horse2zebra/testA' --dataset_name="horse2zebra_test" --do_shuffle=0
```
* Run testing:
```bash
python main.py --to_train=0 --log_dir=./output/AGGAN/exp_01 --config_filename=./configs/exp_01_test.json --checkpoint_dir=./output/AGGAN/exp_01/#old_timestamp#
```
* Trained models:
Our trained models can be downloaded from https://drive.google.com/open?id=1YEQMJK41KQj_-HfKFneSI12nWpTajgzT
================================================
FILE: Trained_models/README.md
================================================
#### Trained models
Our trained models can be downloaded from https://drive.google.com/open?id=1YEQMJK41KQj_-HfKFneSI12nWpTajgzT
When using the trained parameters for horse-to-zebra image translation, note that in our case the source (input_a) is zebra and the target (input_b) is horse, not the opposite.
================================================
FILE: __init__.py
================================================
================================================
FILE: configs/exp_01.json
================================================
{
    "description": "The official PyTorch version of CycleGAN.",
    "pool_size": 50,
    "base_lr": 0.0001,
    "max_step": 100,
    "network_version": "pytorch",
    "dataset_name": "horse2zebra_train",
    "do_flipping": 1,
    "_LAMBDA_A": 10,
    "_LAMBDA_B": 10
}
================================================
FILE: configs/exp_01_test.json
================================================
{
    "description": "Testing with trained model.",
    "network_version": "pytorch",
    "dataset_name": "horse2zebra_test",
    "do_flipping": 0
}
================================================
FILE: configs/exp_02.json
================================================
{
    "description": "The official PyTorch version of CycleGAN.",
    "pool_size": 50,
    "base_lr": 0.0001,
    "max_step": 100,
    "network_version": "pytorch",
    "dataset_name": "apple2orange_train",
    "do_flipping": 1,
    "_LAMBDA_A": 10,
    "_LAMBDA_B": 10
}
================================================
FILE: configs/exp_02_test.json
================================================
{
    "description": "Testing with trained model.",
    "network_version": "pytorch",
    "dataset_name": "apple2orange_test",
    "do_flipping": 0
}
================================================
FILE: configs/exp_04.json
================================================
{
    "description": "The official PyTorch version of CycleGAN.",
    "pool_size": 50,
    "base_lr": 0.0001,
    "max_step": 100,
    "network_version": "pytorch",
    "dataset_name": "lion2tiger_train",
    "do_flipping": 1,
    "_LAMBDA_A": 10,
    "_LAMBDA_B": 10
}
================================================
FILE: configs/exp_04_test.json
================================================
{
    "description": "Testing with trained model.",
    "network_version": "pytorch",
    "dataset_name": "lion2tiger_test",
    "do_flipping": 0
}
================================================
FILE: configs/exp_05.json
================================================
{
    "description": "The official PyTorch version of CycleGAN.",
    "pool_size": 50,
    "base_lr": 0.0002,
    "max_step": 200,
    "network_version": "pytorch",
    "dataset_name": "summer2winter_yosemite_train",
    "do_flipping": 1,
    "_LAMBDA_A": 10,
    "_LAMBDA_B": 10
}
================================================
FILE: configs/exp_05_test.json
================================================
{
    "description": "Testing with trained model.",
    "network_version": "pytorch",
    "dataset_name": "summer2winter_yosemite_test",
    "do_flipping": 0
}
================================================
FILE: create_cyclegan_dataset.py
================================================
"""Create datasets for training and testing."""
import csv
import os
import random
import click
import cyclegan_datasets
def create_list(foldername, fulldir=True, suffix=".jpg"):
"""
:param foldername: The full path of the folder.
:param fulldir: Whether to return the full path or not.
:param suffix: Filter by suffix.
:return: The list of filenames in the folder with given suffix.
"""
file_list_tmp = os.listdir(foldername)
file_list = []
if fulldir:
for item in file_list_tmp:
if item.endswith(suffix):
file_list.append(os.path.join(foldername, item))
else:
for item in file_list_tmp:
if item.endswith(suffix):
file_list.append(item)
return file_list
@click.command()
@click.option('--image_path_a',
type=click.STRING,
default='./input/horse2zebra/trainA',
help='The path to the images from domain_a.')
@click.option('--image_path_b',
type=click.STRING,
default='./input/horse2zebra/trainB',
help='The path to the images from domain_b.')
@click.option('--dataset_name',
type=click.STRING,
default='horse2zebra_train',
help='The name of the dataset in cyclegan_dataset.')
@click.option('--do_shuffle',
type=click.BOOL,
default=False,
help='Whether to shuffle images when creating the dataset.')
def create_dataset(image_path_a, image_path_b,
dataset_name, do_shuffle):
list_a = create_list(image_path_a, True,
cyclegan_datasets.DATASET_TO_IMAGETYPE[dataset_name])
list_b = create_list(image_path_b, True,
cyclegan_datasets.DATASET_TO_IMAGETYPE[dataset_name])
output_path = cyclegan_datasets.PATH_TO_CSV[dataset_name]
num_rows = cyclegan_datasets.DATASET_TO_SIZES[dataset_name]
all_data_tuples = []
for i in range(num_rows):
all_data_tuples.append((
list_a[i % len(list_a)],
list_b[i % len(list_b)]
))
if do_shuffle is True:
random.shuffle(all_data_tuples)
with open(output_path, 'w') as csv_file:
csv_writer = csv.writer(csv_file)
for data_tuple in enumerate(all_data_tuples):
csv_writer.writerow(list(data_tuple[1]))
if __name__ == '__main__':
create_dataset()
================================================
FILE: cyclegan_datasets.py
================================================
"""Contains the standard train/test splits for the cyclegan data."""
"""The size of each dataset. Usually it is the maximum number of images from
each domain."""
DATASET_TO_SIZES = {
'horse2zebra_train': 1334,
'horse2zebra_test': 140,
'apple2orange_train': 1019,
'apple2orange_test': 266,
'lion2tiger_train': 916,
'lion2tiger_test': 103,
'summer2winter_yosemite_train': 1231,
'summer2winter_yosemite_test': 309,
}
"""The image types of each dataset. Currently only supports .jpg or .png"""
DATASET_TO_IMAGETYPE = {
'horse2zebra_train': '.jpg',
'horse2zebra_test': '.jpg',
'apple2orange_train': '.jpg',
'apple2orange_test': '.jpg',
'lion2tiger_train': '.jpg',
'lion2tiger_test': '.jpg',
'summer2winter_yosemite_train': '.jpg',
'summer2winter_yosemite_test': '.jpg',
}
"""The path to the output csv file."""
PATH_TO_CSV = {
'horse2zebra_train': './input/horse2zebra/horse2zebra_train.csv',
'horse2zebra_test': './input/horse2zebra/horse2zebra_test.csv',
'apple2orange_train': './input/apple2orange/apple2orange_train.csv',
'apple2orange_test': './input/apple2orange/apple2orange_test.csv',
'lion2tiger_train': './input/lion2tiger/lion2tiger_train.csv',
'lion2tiger_test': './input/lion2tiger/lion2tiger_test.csv',
'summer2winter_yosemite_train': './input/summer2winter_yosemite/summer2winter_yosemite_train.csv',
'summer2winter_yosemite_test': './input/summer2winter_yosemite/summer2winter_yosemite_test.csv'
}
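To register a new dataset, add one entry to each of the three dictionaries; the loader looks the dataset name up in all of them. A sketch with a hypothetical `mydata_train` dataset (the name, size, and paths are illustrative, not repo defaults):

```python
# Hypothetical example: registering a custom dataset named 'mydata_train'.
DATASET_TO_SIZES = {'mydata_train': 500}          # rows to write into the CSV
DATASET_TO_IMAGETYPE = {'mydata_train': '.png'}   # only '.jpg' or '.png' supported
PATH_TO_CSV = {'mydata_train': './input/mydata/mydata_train.csv'}

# Every dataset name must appear in all three dictionaries.
consistent = (DATASET_TO_SIZES.keys() == DATASET_TO_IMAGETYPE.keys()
              == PATH_TO_CSV.keys())
```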
================================================
FILE: data_loader.py
================================================
import tensorflow as tf

import cyclegan_datasets
import model


def _load_samples(csv_name, image_type):
    filename_queue = tf.train.string_input_producer(
        [csv_name])

    reader = tf.TextLineReader()
    _, csv_filename = reader.read(filename_queue)

    record_defaults = [tf.constant([], dtype=tf.string),
                       tf.constant([], dtype=tf.string)]

    filename_i, filename_j = tf.decode_csv(
        csv_filename, record_defaults=record_defaults)

    file_contents_i = tf.read_file(filename_i)
    file_contents_j = tf.read_file(filename_j)
    if image_type == '.jpg':
        image_decoded_A = tf.image.decode_jpeg(
            file_contents_i, channels=model.IMG_CHANNELS)
        image_decoded_B = tf.image.decode_jpeg(
            file_contents_j, channels=model.IMG_CHANNELS)
    elif image_type == '.png':
        image_decoded_A = tf.image.decode_png(
            file_contents_i, channels=model.IMG_CHANNELS, dtype=tf.uint8)
        image_decoded_B = tf.image.decode_png(
            file_contents_j, channels=model.IMG_CHANNELS, dtype=tf.uint8)

    return image_decoded_A, image_decoded_B


def load_data(dataset_name, image_size_before_crop,
              do_shuffle=True, do_flipping=False):
    """
    :param dataset_name: The name of the dataset.
    :param image_size_before_crop: Resize to this size before random cropping.
    :param do_shuffle: Shuffle switch.
    :param do_flipping: Flip switch.
    :return: The input dictionary with batched, preprocessed image tensors.
    """
    if dataset_name not in cyclegan_datasets.DATASET_TO_SIZES:
        raise ValueError('split name %s was not recognized.'
                         % dataset_name)

    csv_name = cyclegan_datasets.PATH_TO_CSV[dataset_name]

    image_i, image_j = _load_samples(
        csv_name, cyclegan_datasets.DATASET_TO_IMAGETYPE[dataset_name])
    inputs = {
        'image_i': image_i,
        'image_j': image_j
    }

    # Preprocessing:
    inputs['image_i'] = tf.image.resize_images(
        inputs['image_i'], [image_size_before_crop, image_size_before_crop])
    inputs['image_j'] = tf.image.resize_images(
        inputs['image_j'], [image_size_before_crop, image_size_before_crop])

    if do_flipping is True:
        inputs['image_i'] = tf.image.random_flip_left_right(
            inputs['image_i'], seed=1)
        inputs['image_j'] = tf.image.random_flip_left_right(
            inputs['image_j'], seed=1)

    inputs['image_i'] = tf.random_crop(
        inputs['image_i'], [model.IMG_HEIGHT, model.IMG_WIDTH, 3], seed=1)
    inputs['image_j'] = tf.random_crop(
        inputs['image_j'], [model.IMG_HEIGHT, model.IMG_WIDTH, 3], seed=1)

    # Scale pixel values from [0, 255] to [-1, 1].
    inputs['image_i'] = tf.subtract(tf.div(inputs['image_i'], 127.5), 1)
    inputs['image_j'] = tf.subtract(tf.div(inputs['image_j'], 127.5), 1)

    # Batch
    if do_shuffle is True:
        inputs['images_i'], inputs['images_j'] = tf.train.shuffle_batch(
            [inputs['image_i'], inputs['image_j']], 1, 5000, 100, seed=1)
    else:
        inputs['images_i'], inputs['images_j'] = tf.train.batch(
            [inputs['image_i'], inputs['image_j']], 1)

    return inputs
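The final preprocessing step above maps uint8 pixels into the generator's [-1, 1] range. A small NumPy check of that arithmetic and its inverse (illustrative only, not repo code):

```python
import numpy as np

def normalize(img_uint8):
    # Same arithmetic as tf.subtract(tf.div(x, 127.5), 1) in load_data:
    # pixel 0 maps to -1.0 and pixel 255 maps to 1.0.
    return img_uint8.astype(np.float32) / 127.5 - 1.0

def denormalize(img_float):
    # Inverse mapping, matching the (x + 1) * 127.5 step used when the
    # training loop saves images to disk.
    return ((img_float + 1.0) * 127.5).astype(np.uint8)
```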
================================================
FILE: download_datasets.sh
================================================
FILE=$1

if [[ $FILE != "ae_photos" && $FILE != "apple2orange" && $FILE != "summer2winter_yosemite" && $FILE != "horse2zebra" && $FILE != "monet2photo" && $FILE != "cezanne2photo" && $FILE != "ukiyoe2photo" && $FILE != "vangogh2photo" && $FILE != "maps" && $FILE != "cityscapes" && $FILE != "facades" && $FILE != "iphone2dslr_flower" ]]; then
    echo "Available datasets are: apple2orange, summer2winter_yosemite, horse2zebra, monet2photo, cezanne2photo, ukiyoe2photo, vangogh2photo, maps, cityscapes, facades, iphone2dslr_flower, ae_photos"
    exit 1
fi

mkdir -p ./input
URL=https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/$FILE.zip
ZIP_FILE=./input/$FILE.zip
TARGET_DIR=./input/$FILE/
wget -N $URL -O $ZIP_FILE
mkdir -p $TARGET_DIR
unzip $ZIP_FILE -d ./input/
rm $ZIP_FILE
================================================
FILE: layers.py
================================================
import tensorflow as tf


def lrelu(x, leak=0.2, name="lrelu", alt_relu_impl=False):
    with tf.variable_scope(name):
        if alt_relu_impl:
            f1 = 0.5 * (1 + leak)
            f2 = 0.5 * (1 - leak)
            return f1 * x + f2 * abs(x)
        else:
            return tf.maximum(x, leak * x)


def instance_norm(x):
    with tf.variable_scope("instance_norm"):
        epsilon = 1e-5
        mean, var = tf.nn.moments(x, [1, 2], keep_dims=True)
        scale = tf.get_variable('scale', [x.get_shape()[-1]],
                                initializer=tf.truncated_normal_initializer(
                                    mean=1.0, stddev=0.02
                                ))
        offset = tf.get_variable(
            'offset', [x.get_shape()[-1]],
            initializer=tf.constant_initializer(0.0)
        )
        out = scale * tf.div(x - mean, tf.sqrt(var + epsilon)) + offset
        return out


def instance_norm_bis(x, mask):
    # Note: the mask argument is currently unused (the original masked-moment
    # computation was incomplete), so this reduces to standard instance
    # normalization over the full spatial extent.
    with tf.variable_scope("instance_norm"):
        epsilon = 1e-5
        mean, var = tf.nn.moments(x, [1, 2], keep_dims=True)
        scale = tf.get_variable('scale', [x.get_shape()[-1]],
                                initializer=tf.truncated_normal_initializer(
                                    mean=1.0, stddev=0.02
                                ))
        offset = tf.get_variable(
            'offset', [x.get_shape()[-1]],
            initializer=tf.constant_initializer(0.0)
        )
        out = scale * tf.div(x - mean, tf.sqrt(var + epsilon)) + offset
        return out


def general_conv2d_(inputconv, o_d=64, f_h=7, f_w=7, s_h=1, s_w=1, stddev=0.02,
                    padding="VALID", name="conv2d", do_norm=True, do_relu=True,
                    relufactor=0):
    with tf.variable_scope(name):
        conv = tf.contrib.layers.conv2d(
            inputconv, o_d, f_w, s_w, padding,
            activation_fn=None,
            weights_initializer=tf.truncated_normal_initializer(
                stddev=stddev
            ),
            biases_initializer=tf.constant_initializer(0.0)
        )
        if do_norm:
            conv = instance_norm(conv)

        if do_relu:
            if (relufactor == 0):
                conv = tf.nn.relu(conv, "relu")
            else:
                conv = lrelu(conv, relufactor, "lrelu")

        return conv


def general_conv2d(inputconv, do_norm, o_d=64, f_h=7, f_w=7, s_h=1, s_w=1,
                   stddev=0.02, padding="VALID", name="conv2d", do_relu=True,
                   relufactor=0):
    with tf.variable_scope(name):
        conv = tf.contrib.layers.conv2d(
            inputconv, o_d, f_w, s_w, padding,
            activation_fn=None,
            weights_initializer=tf.truncated_normal_initializer(
                stddev=stddev
            ),
            biases_initializer=tf.constant_initializer(0.0)
        )
        # do_norm is a boolean tensor here, so use tf.cond at graph level.
        conv = tf.cond(do_norm, lambda: instance_norm(conv), lambda: conv)

        if do_relu:
            if (relufactor == 0):
                conv = tf.nn.relu(conv, "relu")
            else:
                conv = lrelu(conv, relufactor, "lrelu")

        return conv


def general_deconv2d(inputconv, outshape, o_d=64, f_h=7, f_w=7, s_h=1, s_w=1,
                     stddev=0.02, padding="VALID", name="deconv2d",
                     do_norm=True, do_relu=True, relufactor=0):
    with tf.variable_scope(name):
        conv = tf.contrib.layers.conv2d_transpose(
            inputconv, o_d, [f_h, f_w],
            [s_h, s_w], padding,
            activation_fn=None,
            weights_initializer=tf.truncated_normal_initializer(stddev=stddev),
            biases_initializer=tf.constant_initializer(0.0)
        )
        if do_norm:
            conv = instance_norm(conv)

        if do_relu:
            if (relufactor == 0):
                conv = tf.nn.relu(conv, "relu")
            else:
                conv = lrelu(conv, relufactor, "lrelu")

        return conv


def upsamplingDeconv(inputconv, size, is_scale, method, align_corners, name):
    if len(inputconv.get_shape()) == 3:
        if is_scale:
            size_h = size[0] * int(inputconv.get_shape()[0])
            size_w = size[1] * int(inputconv.get_shape()[1])
            size = [int(size_h), int(size_w)]
    elif len(inputconv.get_shape()) == 4:
        if is_scale:
            size_h = size[0] * int(inputconv.get_shape()[1])
            size_w = size[1] * int(inputconv.get_shape()[2])
            size = [int(size_h), int(size_w)]
    else:
        raise Exception("Does not support shape %s" % inputconv.get_shape())
    print(" [TL] UpSampling2dLayer %s: is_scale:%s size:%s method:%d align_corners:%s" %
          (name, is_scale, size, method, align_corners))
    with tf.variable_scope(name):
        try:
            out = tf.image.resize_images(inputconv, size=size, method=method,
                                         align_corners=align_corners)
        except Exception:  # for TF 0.10
            out = tf.image.resize_images(inputconv, new_height=size[0],
                                         new_width=size[1], method=method,
                                         align_corners=align_corners)
    return out


def general_fc_layers(inpfc, outshape, name):
    with tf.variable_scope(name):
        fcw = tf.Variable(tf.truncated_normal(outshape,
                                              dtype=tf.float32,
                                              stddev=1e-1), name='weights')
        fcb = tf.Variable(tf.constant(1.0, shape=[outshape[-1]], dtype=tf.float32),
                          trainable=True, name='biases')
        fcl = tf.nn.bias_add(tf.matmul(inpfc, fcw), fcb)
        fc_out = tf.nn.relu(fcl)
        return fc_out
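The two `lrelu` branches above are mathematically identical for leak in [0, 1]: 0.5(1+leak)·x + 0.5(1-leak)·|x| = max(x, leak·x). A quick NumPy check of that identity (illustrative, not repo code):

```python
import numpy as np

def lrelu_max(x, leak=0.2):
    # Direct form: max(x, leak * x).
    return np.maximum(x, leak * x)

def lrelu_alt(x, leak=0.2):
    # Alternative form used when alt_relu_impl=True.
    return 0.5 * (1 + leak) * x + 0.5 * (1 - leak) * np.abs(x)

x = np.array([-2.0, -0.5, 0.0, 1.0, 3.0])
same = np.allclose(lrelu_max(x), lrelu_alt(x))
```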
================================================
FILE: losses.py
================================================
"""Contains losses used for performing image-to-image domain adaptation."""
import tensorflow as tf
def cycle_consistency_loss(real_images, generated_images):
"""Compute the cycle consistency loss.
The cycle consistency loss is defined as the sum of the L1 distances
between the real images from each domain and their generated (fake)
counterparts.
This definition is derived from Equation 2 in:
Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial
Networks.
Jun-Yan Zhu, Taesung Park, Phillip Isola, Alexei A. Efros.
Args:
real_images: A batch of images from domain X, a `Tensor` of shape
[batch_size, height, width, channels].
generated_images: A batch of generated images made to look like they
came from domain X, a `Tensor` of shape
[batch_size, height, width, channels].
Returns:
The cycle consistency loss.
"""
return tf.reduce_mean(tf.abs(real_images - generated_images))
def mask_loss(gen_image, mask):
return tf.reduce_mean(tf.abs(tf.multiply(gen_image,1-mask)))
def lsgan_loss_generator(prob_fake_is_real):
"""Computes the LS-GAN loss as minimized by the generator.
Rather than compute the negative loglikelihood, a least-squares loss is
used to optimize the discriminators as per Equation 2 in:
Least Squares Generative Adversarial Networks
Xudong Mao, Qing Li, Haoran Xie, Raymond Y.K. Lau, Zhen Wang, and
Stephen Paul Smolley.
https://arxiv.org/pdf/1611.04076.pdf
Args:
prob_fake_is_real: The discriminator's estimate that generated images
made to look like real images are real.
Returns:
The total LS-GAN loss.
"""
return tf.reduce_mean(tf.squared_difference(prob_fake_is_real, 1))
def lsgan_loss_discriminator(prob_real_is_real, prob_fake_is_real):
"""Computes the LS-GAN loss as minimized by the discriminator.
Rather than compute the negative loglikelihood, a least-squares loss is
used to optimize the discriminators as per Equation 2 in:
Least Squares Generative Adversarial Networks
Xudong Mao, Qing Li, Haoran Xie, Raymond Y.K. Lau, Zhen Wang, and
Stephen Paul Smolley.
https://arxiv.org/pdf/1611.04076.pdf
Args:
prob_real_is_real: The discriminator's estimate that images actually
drawn from the real domain are in fact real.
prob_fake_is_real: The discriminator's estimate that generated images
made to look like real images are real.
Returns:
The total LS-GAN loss.
"""
return (tf.reduce_mean(tf.squared_difference(prob_real_is_real, 1)) +
tf.reduce_mean(tf.squared_difference(prob_fake_is_real, 0))) * 0.5
================================================
FILE: main.py
================================================
"""Code for training CycleGAN."""
from datetime import datetime
import json
import numpy as np
import os
import random
from scipy.misc import imsave
import argparse
import tensorflow as tf
import cyclegan_datasets
import data_loader, losses, model
tf.set_random_seed(1)
np.random.seed(0)
slim = tf.contrib.slim
class CycleGAN:
"""The CycleGAN module."""
def __init__(self, pool_size, lambda_a,
lambda_b, output_root_dir, to_restore,
base_lr, max_step, network_version,
dataset_name, checkpoint_dir, do_flipping, skip, switch, threshold_fg):
current_time = datetime.now().strftime("%Y%m%d-%H%M%S")
self._pool_size = pool_size
self._size_before_crop = 286
self._switch = switch
self._threshold_fg = threshold_fg
self._lambda_a = lambda_a
self._lambda_b = lambda_b
self._output_dir = os.path.join(output_root_dir, current_time +
'_switch'+str(switch)+'_thres_'+str(threshold_fg))
self._images_dir = os.path.join(self._output_dir, 'imgs')
self._num_imgs_to_save = 20
self._to_restore = to_restore
self._base_lr = base_lr
self._max_step = max_step
self._network_version = network_version
self._dataset_name = dataset_name
self._checkpoint_dir = checkpoint_dir
self._do_flipping = do_flipping
self._skip = skip
self.fake_images_A = []
self.fake_images_B = []
    def model_setup(self):
        """
        This function sets up the model to train.

        self.input_A/self.input_B -> Set of training images.
        self.fake_A/self.fake_B -> Images generated by the corresponding
            generators from input_A and input_B.
        self.lr -> Learning rate variable.
        self.cyc_A/self.cyc_B -> Images generated after feeding
            self.fake_A/self.fake_B to the corresponding generators.
            These are used to calculate the cyclic loss.
        """
        self.input_a = tf.placeholder(
            tf.float32, [
                1,
                model.IMG_WIDTH,
                model.IMG_HEIGHT,
                model.IMG_CHANNELS
            ], name="input_A")
        self.input_b = tf.placeholder(
            tf.float32, [
                1,
                model.IMG_WIDTH,
                model.IMG_HEIGHT,
                model.IMG_CHANNELS
            ], name="input_B")

        self.fake_pool_A = tf.placeholder(
            tf.float32, [
                None,
                model.IMG_WIDTH,
                model.IMG_HEIGHT,
                model.IMG_CHANNELS
            ], name="fake_pool_A")
        self.fake_pool_B = tf.placeholder(
            tf.float32, [
                None,
                model.IMG_WIDTH,
                model.IMG_HEIGHT,
                model.IMG_CHANNELS
            ], name="fake_pool_B")
        self.fake_pool_A_mask = tf.placeholder(
            tf.float32, [
                None,
                model.IMG_WIDTH,
                model.IMG_HEIGHT,
                model.IMG_CHANNELS
            ], name="fake_pool_A_mask")
        self.fake_pool_B_mask = tf.placeholder(
            tf.float32, [
                None,
                model.IMG_WIDTH,
                model.IMG_HEIGHT,
                model.IMG_CHANNELS
            ], name="fake_pool_B_mask")

        self.global_step = slim.get_or_create_global_step()

        self.num_fake_inputs = 0

        self.learning_rate = tf.placeholder(tf.float32, shape=[], name="lr")
        self.transition_rate = tf.placeholder(tf.float32, shape=[], name="tr")
        self.donorm = tf.placeholder(tf.bool, shape=[], name="donorm")

        inputs = {
            'images_a': self.input_a,
            'images_b': self.input_b,
            'fake_pool_a': self.fake_pool_A,
            'fake_pool_b': self.fake_pool_B,
            'fake_pool_a_mask': self.fake_pool_A_mask,
            'fake_pool_b_mask': self.fake_pool_B_mask,
            'transition_rate': self.transition_rate,
            'donorm': self.donorm,
        }

        outputs = model.get_outputs(
            inputs, skip=self._skip)

        self.prob_real_a_is_real = outputs['prob_real_a_is_real']
        self.prob_real_b_is_real = outputs['prob_real_b_is_real']
        self.fake_images_a = outputs['fake_images_a']
        self.fake_images_b = outputs['fake_images_b']
        self.prob_fake_a_is_real = outputs['prob_fake_a_is_real']
        self.prob_fake_b_is_real = outputs['prob_fake_b_is_real']
        self.cycle_images_a = outputs['cycle_images_a']
        self.cycle_images_b = outputs['cycle_images_b']
        self.prob_fake_pool_a_is_real = outputs['prob_fake_pool_a_is_real']
        self.prob_fake_pool_b_is_real = outputs['prob_fake_pool_b_is_real']
        self.masks = outputs['masks']
        self.masked_gen_ims = outputs['masked_gen_ims']
        self.masked_ims = outputs['masked_ims']
        self.masks_ = outputs['mask_tmp']
    def compute_losses(self):
        """
        In this function we define the variables for loss calculation
        and model training.

        d_loss_A/d_loss_B -> loss for discriminator A/B
        g_loss_A/g_loss_B -> loss for generator A/B
        *_trainer -> Various trainers for the above loss functions
        *_summ -> Summary variables for the above loss functions
        """
        cycle_consistency_loss_a = \
            self._lambda_a * losses.cycle_consistency_loss(
                real_images=self.input_a, generated_images=self.cycle_images_a,
            )
        cycle_consistency_loss_b = \
            self._lambda_b * losses.cycle_consistency_loss(
                real_images=self.input_b, generated_images=self.cycle_images_b,
            )

        lsgan_loss_a = losses.lsgan_loss_generator(self.prob_fake_a_is_real)
        lsgan_loss_b = losses.lsgan_loss_generator(self.prob_fake_b_is_real)

        g_loss_A = \
            cycle_consistency_loss_a + cycle_consistency_loss_b + lsgan_loss_b
        g_loss_B = \
            cycle_consistency_loss_b + cycle_consistency_loss_a + lsgan_loss_a

        d_loss_A = losses.lsgan_loss_discriminator(
            prob_real_is_real=self.prob_real_a_is_real,
            prob_fake_is_real=self.prob_fake_pool_a_is_real,
        )
        d_loss_B = losses.lsgan_loss_discriminator(
            prob_real_is_real=self.prob_real_b_is_real,
            prob_fake_is_real=self.prob_fake_pool_b_is_real,
        )

        optimizer = tf.train.AdamOptimizer(self.learning_rate, beta1=0.5)

        self.model_vars = tf.trainable_variables()

        d_A_vars = [var for var in self.model_vars if 'd_A' in var.name]
        g_A_vars = [var for var in self.model_vars if 'g_A/' in var.name]
        d_B_vars = [var for var in self.model_vars if 'd_B' in var.name]
        g_B_vars = [var for var in self.model_vars if 'g_B/' in var.name]
        g_Ae_vars = [var for var in self.model_vars if 'g_A_ae' in var.name]
        g_Be_vars = [var for var in self.model_vars if 'g_B_ae' in var.name]

        self.g_A_trainer = optimizer.minimize(g_loss_A,
                                              var_list=g_A_vars + g_Ae_vars)
        self.g_B_trainer = optimizer.minimize(g_loss_B,
                                              var_list=g_B_vars + g_Be_vars)
        self.g_A_trainer_bis = optimizer.minimize(g_loss_A, var_list=g_A_vars)
        self.g_B_trainer_bis = optimizer.minimize(g_loss_B, var_list=g_B_vars)
        self.d_A_trainer = optimizer.minimize(d_loss_A, var_list=d_A_vars)
        self.d_B_trainer = optimizer.minimize(d_loss_B, var_list=d_B_vars)

        self.params_ae_c1 = g_A_vars[0]
        self.params_ae_c1_B = g_B_vars[0]
        for var in self.model_vars:
            print(var.name)

        # Summary variables for tensorboard
        self.g_A_loss_summ = tf.summary.scalar("g_A_loss", g_loss_A)
        self.g_B_loss_summ = tf.summary.scalar("g_B_loss", g_loss_B)
        self.d_A_loss_summ = tf.summary.scalar("d_A_loss", d_loss_A)
        self.d_B_loss_summ = tf.summary.scalar("d_B_loss", d_loss_B)
    def save_images(self, sess, epoch, curr_tr):
        """
        Saves input and output images.

        :param sess: The session.
        :param epoch: Current epoch.
        """
        if not os.path.exists(self._images_dir):
            os.makedirs(self._images_dir)

        if curr_tr > 0:
            donorm = False
        else:
            donorm = True

        names = ['inputA_', 'inputB_', 'fakeA_',
                 'fakeB_', 'cycA_', 'cycB_',
                 'mask_a', 'mask_b']

        with open(os.path.join(
                self._output_dir, 'epoch_' + str(epoch) + '.html'
        ), 'w') as v_html:
            for i in range(0, self._num_imgs_to_save):
                print("Saving image {}/{}".format(i, self._num_imgs_to_save))
                inputs = sess.run(self.inputs)
                fake_A_temp, fake_B_temp, cyc_A_temp, cyc_B_temp, masks = sess.run([
                    self.fake_images_a,
                    self.fake_images_b,
                    self.cycle_images_a,
                    self.cycle_images_b,
                    self.masks,
                ], feed_dict={
                    self.input_a: inputs['images_i'],
                    self.input_b: inputs['images_j'],
                    self.transition_rate: curr_tr,
                    self.donorm: donorm,
                })

                tensors = [inputs['images_i'], inputs['images_j'],
                           fake_B_temp, fake_A_temp, cyc_A_temp, cyc_B_temp,
                           masks[0], masks[1]]

                for name, tensor in zip(names, tensors):
                    image_name = name + str(epoch) + "_" + str(i) + ".jpg"
                    if 'mask_' in name:
                        imsave(os.path.join(self._images_dir, image_name),
                               (np.squeeze(tensor[0]))
                               )
                    else:
                        imsave(os.path.join(self._images_dir, image_name),
                               ((np.squeeze(tensor[0]) + 1) * 127.5).astype(np.uint8)
                               )
                    # (Per-image HTML markup written to v_html here was lost
                    # in extraction.)
                # (Row-break HTML markup written to v_html here was lost in
                # extraction.)
def save_images_bis(self, sess, epoch):
"""
Saves input and output images.
:param sess: The session.
:param epoch: Currnt epoch.
"""
if not os.path.exists(self._images_dir):
os.makedirs(self._images_dir)
names = ['input_A_', 'mask_A_', 'masked_inputA_', 'fakeB_',
'input_B_', 'mask_B_', 'masked_inputB_', 'fakeA_']
space = '                        ' \
'                        ' \
'         '
with open(os.path.join(self._output_dir, 'results_' + str(epoch) + '.html'), 'w') as v_html:
v_html.write("INPUT" + space + "MASK" + space + "MASKED_IMAGE" + space + "GENERATED_IMAGE")
v_html.write("
")
for i in range(0, self._num_imgs_to_save):
print("Saving image {}/{}".format(i, self._num_imgs_to_save))
inputs = sess.run(self.inputs)
fake_A_temp, fake_B_temp, masks, masked_ims = sess.run([
self.fake_images_a,
self.fake_images_b,
self.masks,
self.masked_ims
], feed_dict={
self.input_a: inputs['images_i'],
self.input_b: inputs['images_j'],
self.transition_rate: 0.1
})
tensors = [inputs['images_i'], masks[0], masked_ims[0], fake_B_temp,
inputs['images_j'], masks[1], masked_ims[1], fake_A_temp]
for name, tensor in zip(names, tensors):
image_name = name + str(i) + ".jpg"
if 'mask_' in name:
imsave(os.path.join(self._images_dir, image_name),
(np.squeeze(tensor[0]))
)
else:
imsave(os.path.join(self._images_dir, image_name),
((np.squeeze(tensor[0]) + 1) * 127.5).astype(np.uint8)
)
v_html.write(
"<img src=\"" + os.path.join('imgs', image_name) + "\">"
)
if 'fakeB_' in name:
v_html.write("<br>")
v_html.write("<br>")
def fake_image_pool(self, num_fakes, fake, mask, fake_pool):
"""
This function saves the generated image to corresponding
pool of images.
It keeps on feeling the pool till it is full and then randomly
selects an already stored image and replace it with new one.
"""
tmp = {}
tmp['im'] = fake
tmp['mask'] = mask
if num_fakes < self._pool_size:
fake_pool.append(tmp)
return tmp
else:
p = random.random()
if p > 0.5:
random_id = random.randint(0, self._pool_size - 1)
temp = fake_pool[random_id]
fake_pool[random_id] = tmp
return temp
else:
return tmp
def train(self):
"""Training Function."""
# Load Dataset from the dataset folder
self.inputs = data_loader.load_data(
self._dataset_name, self._size_before_crop,
False, self._do_flipping)
# Build the network
self.model_setup()
# Loss function calculations
self.compute_losses()
# Initializing the global variables
init = (tf.global_variables_initializer(),
tf.local_variables_initializer())
saver = tf.train.Saver(max_to_keep=None)
max_images = cyclegan_datasets.DATASET_TO_SIZES[self._dataset_name]
half_training = int(self._max_step / 2)
with tf.Session() as sess:
sess.run(init)
# Restore the model to run the model from last checkpoint
if self._to_restore:
chkpt_fname = tf.train.latest_checkpoint(self._checkpoint_dir)
saver.restore(sess, chkpt_fname)
writer = tf.summary.FileWriter(self._output_dir)
if not os.path.exists(self._output_dir):
os.makedirs(self._output_dir)
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord)
# Training Loop
for epoch in range(sess.run(self.global_step), self._max_step):
print("In the epoch ", epoch)
saver.save(sess, os.path.join(
self._output_dir, "AGGAN"), global_step=epoch)
# Dealing with the learning rate as per the epoch number
if epoch < half_training:
curr_lr = self._base_lr
else:
curr_lr = self._base_lr - \
self._base_lr * (epoch - half_training) / half_training
if epoch < self._switch:
curr_tr = 0.
donorm = True
to_train_A = self.g_A_trainer
to_train_B = self.g_B_trainer
else:
curr_tr = self._threshold_fg
donorm = False
to_train_A = self.g_A_trainer_bis
to_train_B = self.g_B_trainer_bis
self.save_images(sess, epoch, curr_tr)
for i in range(0, max_images):
print("Processing batch {}/{}".format(i, max_images))
inputs = sess.run(self.inputs)
# Optimizing the G_A network
_, fake_B_temp, smask_a, summary_str = sess.run(
[to_train_A,
self.fake_images_b,
self.masks[0],
self.g_A_loss_summ],
feed_dict={
self.input_a:
inputs['images_i'],
self.input_b:
inputs['images_j'],
self.learning_rate: curr_lr,
self.transition_rate: curr_tr,
self.donorm: donorm,
}
)
writer.add_summary(summary_str, epoch * max_images + i)
fake_B_temp1 = self.fake_image_pool(
self.num_fake_inputs, fake_B_temp, smask_a, self.fake_images_B)
# Optimizing the D_B network
_, summary_str = sess.run(
[self.d_B_trainer, self.d_B_loss_summ],
feed_dict={
self.input_a:
inputs['images_i'],
self.input_b:
inputs['images_j'],
self.learning_rate: curr_lr,
self.fake_pool_B: fake_B_temp1['im'],
self.fake_pool_B_mask: fake_B_temp1['mask'],
self.transition_rate: curr_tr,
self.donorm: donorm,
}
)
writer.add_summary(summary_str, epoch * max_images + i)
# Optimizing the G_B network
_, fake_A_temp, smask_b, summary_str = sess.run(
[to_train_B,
self.fake_images_a,
self.masks[1],
self.g_B_loss_summ],
feed_dict={
self.input_a:
inputs['images_i'],
self.input_b:
inputs['images_j'],
self.learning_rate: curr_lr,
self.transition_rate: curr_tr,
self.donorm: donorm,
}
)
writer.add_summary(summary_str, epoch * max_images + i)
fake_A_temp1 = self.fake_image_pool(
self.num_fake_inputs, fake_A_temp, smask_b, self.fake_images_A)
# Optimizing the D_A network
_, mask_tmp__, summary_str = sess.run(
[self.d_A_trainer, self.masks_, self.d_A_loss_summ],
feed_dict={
self.input_a:
inputs['images_i'],
self.input_b:
inputs['images_j'],
self.learning_rate: curr_lr,
self.fake_pool_A: fake_A_temp1['im'],
self.fake_pool_A_mask: fake_A_temp1['mask'],
self.transition_rate: curr_tr,
self.donorm: donorm,
}
)
writer.add_summary(summary_str, epoch * max_images + i)
writer.flush()
self.num_fake_inputs += 1
sess.run(tf.assign(self.global_step, epoch + 1))
coord.request_stop()
coord.join(threads)
writer.add_graph(sess.graph)
def test(self):
"""Test Function."""
print("Testing the results")
self.inputs = data_loader.load_data(
self._dataset_name, self._size_before_crop,
False, self._do_flipping)
self.model_setup()
saver = tf.train.Saver()
init = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init)
chkpt_fname = tf.train.latest_checkpoint(self._checkpoint_dir)
saver.restore(sess, chkpt_fname)
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord)
self._num_imgs_to_save = cyclegan_datasets.DATASET_TO_SIZES[
self._dataset_name]
self.save_images_bis(sess, sess.run(self.global_step))
coord.request_stop()
coord.join(threads)
def parse_args():
desc = "Tensorflow implementation of cycleGan using attention"
parser = argparse.ArgumentParser(description=desc)
parser.add_argument('--to_train', type=int, default=1,
help='Training mode: 1 to train from scratch, 2 to resume from the latest checkpoint, 0 to test.')
parser.add_argument('--log_dir',
type=str,
default=None,
help='Where the data is logged to.')
parser.add_argument('--config_filename', type=str, default='train', help='The name of the configuration file.')
parser.add_argument('--checkpoint_dir', type=str, default='', help='The directory that saves the latest checkpoint (used when to_train == 2).')
parser.add_argument('--skip', type=bool, default=False,
help='Whether to add skip connection between input and output.')
parser.add_argument('--switch', type=int, default=30,
help='In what epoch the FG starts to be fed to the discriminator')
parser.add_argument('--threshold', type=float, default=0.1,
help='The threshold value to select the FG')
return parser.parse_args()
def main():
"""
:param to_train: Specify whether it is training or testing. 1: training; 2:
resuming from latest checkpoint; 0: testing.
:param log_dir: The root dir to save checkpoints and imgs. The actual dir
is the root dir appended by the folder with the name timestamp.
:param config_filename: The configuration file.
:param checkpoint_dir: The directory that saves the latest checkpoint. It
only takes effect when to_train == 2.
:param skip: A boolean indicating whether to add skip connection between
input and output.
"""
args = parse_args()
if args is None:
exit()
to_train = args.to_train
log_dir = args.log_dir
config_filename = args.config_filename
checkpoint_dir = args.checkpoint_dir
skip = args.skip
switch = args.switch
threshold_fg = args.threshold
if not os.path.isdir(log_dir):
os.makedirs(log_dir)
with open(config_filename) as config_file:
config = json.load(config_file)
lambda_a = float(config['_LAMBDA_A']) if '_LAMBDA_A' in config else 10.0
lambda_b = float(config['_LAMBDA_B']) if '_LAMBDA_B' in config else 10.0
pool_size = int(config['pool_size']) if 'pool_size' in config else 50
to_restore = (to_train == 2)
base_lr = float(config['base_lr']) if 'base_lr' in config else 0.0002
max_step = int(config['max_step']) if 'max_step' in config else 200
network_version = str(config['network_version'])
dataset_name = str(config['dataset_name'])
do_flipping = bool(config['do_flipping'])
cyclegan_model = CycleGAN(pool_size, lambda_a, lambda_b, log_dir,
to_restore, base_lr, max_step, network_version,
dataset_name, checkpoint_dir, do_flipping, skip,
switch, threshold_fg)
if to_train > 0:
cyclegan_model.train()
else:
cyclegan_model.test()
if __name__ == '__main__':
main()
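The `fake_image_pool` method above implements the discriminator history buffer: fakes are stored until the pool is full, after which each new fake either replaces a random stored one (and the displaced image is fed to the discriminator) or passes straight through. A minimal framework-free sketch of that logic, using plain dicts in place of tensors (the driver loop and its names are illustrative, not part of the repo API):

```python
import random

POOL_SIZE = 50  # matches the default pool_size in the configs

def fake_image_pool(num_fakes, fake, mask, fake_pool):
    """Store generated images; once the pool is full, with probability
    0.5 swap the new sample for a randomly chosen stored one and return
    the old sample, otherwise return the new sample unchanged."""
    tmp = {'im': fake, 'mask': mask}
    if num_fakes < POOL_SIZE:
        fake_pool.append(tmp)
        return tmp
    if random.random() > 0.5:
        idx = random.randint(0, POOL_SIZE - 1)
        old = fake_pool[idx]
        fake_pool[idx] = tmp
        return old
    return tmp

pool = []
for step in range(60):
    out = fake_image_pool(step, fake='im%d' % step,
                          mask='m%d' % step, fake_pool=pool)

# the pool never grows beyond its capacity
assert len(pool) == POOL_SIZE
```

Showing the discriminator a buffer of past fakes rather than only the latest one stabilizes adversarial training, which is why `train()` threads the pooled image and its mask back in through `fake_pool_A`/`fake_pool_B` placeholders.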
================================================
FILE: model.py
================================================
"""Code for constructing the model and get the outputs from the model."""
import tensorflow as tf
import numpy as np
import layers
# The number of samples per batch.
BATCH_SIZE = 1
# The height of each image.
IMG_HEIGHT = 256
# The width of each image.
IMG_WIDTH = 256
# The number of color channels per image.
IMG_CHANNELS = 3
POOL_SIZE = 50
ngf = 32
ndf = 64
def get_outputs(inputs, skip=False):
images_a = inputs['images_a']
images_b = inputs['images_b']
fake_pool_a = inputs['fake_pool_a']
fake_pool_b = inputs['fake_pool_b']
fake_pool_a_mask = inputs['fake_pool_a_mask']
fake_pool_b_mask = inputs['fake_pool_b_mask']
transition_rate = inputs['transition_rate']
donorm = inputs['donorm']
with tf.variable_scope("Model") as scope:
current_autoenc = autoenc_upsample
current_discriminator = discriminator
current_generator = build_generator_resnet_9blocks
mask_a = current_autoenc(images_a, "g_A_ae")
mask_b = current_autoenc(images_b, "g_B_ae")
mask_a = tf.concat([mask_a] * 3, axis=3)
mask_b = tf.concat([mask_b] * 3, axis=3)
mask_a_on_a = tf.multiply(images_a, mask_a)
mask_b_on_b = tf.multiply(images_b, mask_b)
prob_real_a_is_real = current_discriminator(images_a, mask_a, transition_rate, donorm, "d_A")
prob_real_b_is_real = current_discriminator(images_b, mask_b, transition_rate, donorm, "d_B")
fake_images_b_from_g = current_generator(images_a, name="g_A", skip=skip)
fake_images_b = tf.multiply(fake_images_b_from_g, mask_a) + tf.multiply(images_a, 1-mask_a)
fake_images_a_from_g = current_generator(images_b, name="g_B", skip=skip)
fake_images_a = tf.multiply(fake_images_a_from_g, mask_b) + tf.multiply(images_b, 1-mask_b)
scope.reuse_variables()
prob_fake_a_is_real = current_discriminator(fake_images_a, mask_b, transition_rate, donorm, "d_A")
prob_fake_b_is_real = current_discriminator(fake_images_b, mask_a, transition_rate, donorm, "d_B")
mask_acycle = current_autoenc(fake_images_a, "g_A_ae")
mask_bcycle = current_autoenc(fake_images_b, "g_B_ae")
mask_bcycle = tf.concat([mask_bcycle] * 3, axis=3)
mask_acycle = tf.concat([mask_acycle] * 3, axis=3)
mask_acycle_on_fakeA = tf.multiply(fake_images_a, mask_acycle)
mask_bcycle_on_fakeB = tf.multiply(fake_images_b, mask_bcycle)
cycle_images_a_from_g = current_generator(fake_images_b, name="g_B", skip=skip)
cycle_images_b_from_g = current_generator(fake_images_a, name="g_A", skip=skip)
cycle_images_a = tf.multiply(cycle_images_a_from_g,
mask_bcycle) + tf.multiply(fake_images_b, 1 - mask_bcycle)
cycle_images_b = tf.multiply(cycle_images_b_from_g,
mask_acycle) + tf.multiply(fake_images_a, 1 - mask_acycle)
scope.reuse_variables()
prob_fake_pool_a_is_real = current_discriminator(fake_pool_a, fake_pool_a_mask, transition_rate, donorm, "d_A")
prob_fake_pool_b_is_real = current_discriminator(fake_pool_b, fake_pool_b_mask, transition_rate, donorm, "d_B")
return {
'prob_real_a_is_real': prob_real_a_is_real,
'prob_real_b_is_real': prob_real_b_is_real,
'prob_fake_a_is_real': prob_fake_a_is_real,
'prob_fake_b_is_real': prob_fake_b_is_real,
'prob_fake_pool_a_is_real': prob_fake_pool_a_is_real,
'prob_fake_pool_b_is_real': prob_fake_pool_b_is_real,
'cycle_images_a': cycle_images_a,
'cycle_images_b': cycle_images_b,
'fake_images_a': fake_images_a,
'fake_images_b': fake_images_b,
'masked_ims': [mask_a_on_a, mask_b_on_b, mask_acycle_on_fakeA, mask_bcycle_on_fakeB],
'masks': [mask_a, mask_b, mask_acycle, mask_bcycle],
'masked_gen_ims' : [fake_images_b_from_g, fake_images_a_from_g , cycle_images_a_from_g, cycle_images_b_from_g],
'mask_tmp' : mask_a,
}
def autoenc_upsample(inputae, name):
with tf.variable_scope(name):
f = 7
ks = 3
padding = "REFLECT"
pad_input = tf.pad(inputae, [[0, 0], [ks, ks], [
ks, ks], [0, 0]], padding)
o_c1 = layers.general_conv2d(
pad_input, tf.constant(True, dtype=bool), ngf, f, f, 2, 2, 0.02, name="c1")
o_c2 = layers.general_conv2d(
o_c1, tf.constant(True, dtype=bool), ngf * 2, ks, ks, 2, 2, 0.02, "SAME", "c2")
o_r1 = build_resnet_block_Att(o_c2, ngf * 2, "r1", padding)
size_d1 = o_r1.get_shape().as_list()
o_c4 = layers.upsamplingDeconv(o_r1, size=[size_d1[1] * 2, size_d1[2] * 2], is_scale=False, method=1,
align_corners=False,name= 'up1')
# o_c4_pad = tf.pad(o_c4, [[0, 0], [1, 1], [1, 1], [0, 0]], "REFLECT", name='padup1')
o_c4_end = layers.general_conv2d(o_c4, tf.constant(True, dtype=bool), ngf * 2, (3, 3), (1, 1), padding='VALID', name='c4')
size_d2 = o_c4_end.get_shape().as_list()
o_c5 = layers.upsamplingDeconv(o_c4_end, size=[size_d2[1] * 2, size_d2[2] * 2], is_scale=False, method=1,
align_corners=False, name='up2')
# o_c5_pad = tf.pad(o_c5, [[0, 0], [1, 1], [1, 1], [0, 0]], "REFLECT", name='padup2')
oc5_end = layers.general_conv2d(o_c5, tf.constant(True, dtype=bool), ngf , (3, 3), (1, 1), padding='VALID', name='c5')
# o_c6 = tf.pad(oc5_end, [[0, 0], [3, 3], [3, 3], [0, 0]], "REFLECT", name='padup3')
o_c6_end = layers.general_conv2d(oc5_end, tf.constant(False, dtype=bool),
1 , (f, f), (1, 1), padding='VALID', name='c6', do_relu=False)
return tf.nn.sigmoid(o_c6_end,'sigmoid')
def build_resnet_block(inputres, dim, name="resnet", padding="REFLECT"):
"""build a single block of resnet.
:param inputres: inputres
:param dim: dim
:param name: name
:param padding: for tensorflow version use REFLECT; for pytorch version use
CONSTANT
:return: a single block of resnet.
"""
with tf.variable_scope(name):
out_res = tf.pad(inputres, [[0, 0], [1, 1], [
1, 1], [0, 0]], padding)
out_res = layers.general_conv2d(
out_res, tf.constant(True, dtype=bool), dim, 3, 3, 1, 1, 0.02, "VALID", "c1")
out_res = tf.pad(out_res, [[0, 0], [1, 1], [1, 1], [0, 0]], padding)
out_res = layers.general_conv2d(
out_res, tf.constant(True, dtype=bool), dim, 3, 3, 1, 1, 0.02, "VALID", "c2", do_relu=False)
return tf.nn.relu(out_res + inputres)
def build_resnet_block_Att(inputres, dim, name="resnet", padding="REFLECT"):
"""build a single block of resnet.
:param inputres: inputres
:param dim: dim
:param name: name
:param padding: for tensorflow version use REFLECT; for pytorch version use
CONSTANT
:return: a single block of resnet.
"""
with tf.variable_scope(name):
out_res = tf.pad(inputres, [[0, 0], [1, 1], [
1, 1], [0, 0]], padding)
out_res = layers.general_conv2d(
out_res, tf.constant(True, dtype=bool), dim, 3, 3, 1, 1, 0.02, "VALID", "c1")
out_res = tf.pad(out_res, [[0, 0], [1, 1], [1, 1], [0, 0]], padding)
out_res = layers.general_conv2d(
out_res, tf.constant(True, dtype=bool), dim, 3, 3, 1, 1, 0.02, "VALID", "c2", do_relu=False)
return tf.nn.relu(out_res + inputres)
def build_generator_resnet_9blocks(inputgen, name="generator", skip=False):
with tf.variable_scope(name):
f = 7
ks = 3
padding = "CONSTANT"
inputgen = tf.pad(inputgen, [[0, 0], [ks, ks], [
ks, ks], [0, 0]], padding)
o_c1 = layers.general_conv2d(
inputgen, tf.constant(True, dtype=bool), ngf, f, f, 1, 1, 0.02, name="c1")
o_c2 = layers.general_conv2d(
o_c1, tf.constant(True, dtype=bool),ngf * 2, ks, ks, 2, 2, 0.02, padding='same', name="c2")
o_c3 = layers.general_conv2d(
o_c2, tf.constant(True, dtype=bool), ngf * 4, ks, ks, 2, 2, 0.02, padding='same', name="c3")
o_r1 = build_resnet_block(o_c3, ngf * 4, "r1", padding)
o_r2 = build_resnet_block(o_r1, ngf * 4, "r2", padding)
o_r3 = build_resnet_block(o_r2, ngf * 4, "r3", padding)
o_r4 = build_resnet_block(o_r3, ngf * 4, "r4", padding)
o_r5 = build_resnet_block(o_r4, ngf * 4, "r5", padding)
o_r6 = build_resnet_block(o_r5, ngf * 4, "r6", padding)
o_r7 = build_resnet_block(o_r6, ngf * 4, "r7", padding)
o_r8 = build_resnet_block(o_r7, ngf * 4, "r8", padding)
o_r9 = build_resnet_block(o_r8, ngf * 4, "r9", padding)
o_c4 = layers.general_deconv2d(
o_r9, [BATCH_SIZE, 128, 128, ngf * 2], ngf * 2, ks, ks, 2, 2, 0.02,
"SAME", "c4")
o_c5 = layers.general_deconv2d(
o_c4, [BATCH_SIZE, 256, 256, ngf], ngf, ks, ks, 2, 2, 0.02,
"SAME", "c5")
o_c6 = layers.general_conv2d(o_c5, tf.constant(False, dtype=bool), IMG_CHANNELS, f, f, 1, 1,
0.02, "SAME", "c6", do_relu=False)
if skip is True:
out_gen = tf.nn.tanh(inputgen + o_c6, "t1")
else:
out_gen = tf.nn.tanh(o_c6, "t1")
return out_gen
def discriminator(inputdisc, mask, transition_rate, donorm, name="discriminator"):
with tf.variable_scope(name):
mask = tf.cast(tf.greater_equal(mask, transition_rate), tf.float32)
inputdisc = tf.multiply(inputdisc, mask)
f = 4
padw = 2
pad_input = tf.pad(inputdisc, [[0, 0], [padw, padw], [
padw, padw], [0, 0]], "CONSTANT")
o_c1 = layers.general_conv2d(pad_input, donorm, ndf, f, f, 2, 2,
0.02, "VALID", "c1",
relufactor=0.2)
pad_o_c1 = tf.pad(o_c1, [[0, 0], [padw, padw], [
padw, padw], [0, 0]], "CONSTANT")
o_c2 = layers.general_conv2d(pad_o_c1, donorm, ndf * 2, f, f, 2, 2,
0.02, "VALID", "c2", relufactor=0.2)
pad_o_c2 = tf.pad(o_c2, [[0, 0], [padw, padw], [
padw, padw], [0, 0]], "CONSTANT")
o_c3 = layers.general_conv2d(pad_o_c2, donorm, ndf * 4, f, f, 2, 2,
0.02, "VALID", "c3", relufactor=0.2)
pad_o_c3 = tf.pad(o_c3, [[0, 0], [padw, padw], [
padw, padw], [0, 0]], "CONSTANT")
o_c4 = layers.general_conv2d(pad_o_c3, donorm, ndf * 8, f, f, 1, 1,
0.02, "VALID", "c4", relufactor=0.2)
# o_c4 = tf.multiply(o_c4, mask_4)
pad_o_c4 = tf.pad(o_c4, [[0, 0], [padw, padw], [
padw, padw], [0, 0]], "CONSTANT")
o_c5 = layers.general_conv2d(
pad_o_c4, tf.constant(False, dtype=bool), 1, f, f, 1, 1, 0.02, "VALID", "c5", do_relu=False)
return o_c5
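The key step in `get_outputs` is compositing the generator output with the untouched input through the attention mask, e.g. `fake_images_b = g(a) * mask_a + a * (1 - mask_a)`: the generator only edits foreground pixels the mask selects. A small NumPy sketch of that blend (shapes chosen for illustration):

```python
import numpy as np

def composite(generated, source, mask):
    """Blend the generator output into the source image: foreground
    (mask ~ 1) comes from the generator, background (mask ~ 0) is
    copied through unchanged."""
    return generated * mask + source * (1.0 - mask)

rng = np.random.default_rng(0)
source = rng.uniform(-1, 1, size=(1, 4, 4, 3))
generated = rng.uniform(-1, 1, size=(1, 4, 4, 3))

# mask of ones: output is purely the generator; mask of zeros: pure input
assert np.allclose(composite(generated, source, np.ones_like(source)), generated)
assert np.allclose(composite(generated, source, np.zeros_like(source)), source)
```

This is also why the single-channel attention map is tiled to three channels with `tf.concat([mask] * 3, axis=3)` before the multiply: the blend is applied per color channel.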
================================================
FILE: test/__init__.py
================================================
================================================
FILE: test/evaluate_losses.py
================================================
import numpy as np
import tensorflow as tf
from .. import losses
def test_evaluate_g_losses(sess):
_LAMBDA_A = 10
_LAMBDA_B = 10
input_a = tf.random_uniform((5, 7), maxval=1)
cycle_images_a = input_a + 1
input_b = tf.random_uniform((5, 7), maxval=1)
cycle_images_b = input_b - 2
cycle_consistency_loss_a = _LAMBDA_A * losses.cycle_consistency_loss(
real_images=input_a, generated_images=cycle_images_a,
)
cycle_consistency_loss_b = _LAMBDA_B * losses.cycle_consistency_loss(
real_images=input_b, generated_images=cycle_images_b,
)
prob_fake_a_is_real = tf.constant([0, 1.0, 0])
prob_fake_b_is_real = tf.constant([1.0, 1.0, 0])
lsgan_loss_a = losses.lsgan_loss_generator(prob_fake_a_is_real)
lsgan_loss_b = losses.lsgan_loss_generator(prob_fake_b_is_real)
assert np.isclose(sess.run(lsgan_loss_a), 0.66666669) and \
np.isclose(sess.run(lsgan_loss_b), 0.3333333) and \
np.isclose(sess.run(cycle_consistency_loss_a), 10) and \
np.isclose(sess.run(cycle_consistency_loss_b), 20)
def test_evaluate_d_losses(sess):
prob_real_a_is_real = tf.constant([1.0, 1.0, 0])
prob_fake_pool_a_is_real = tf.constant([1.0, 0, 0])
d_loss_A = losses.lsgan_loss_discriminator(
prob_real_is_real=prob_real_a_is_real,
prob_fake_is_real=prob_fake_pool_a_is_real)
assert np.isclose(sess.run(d_loss_A), 0.3333333)
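The expected values above follow from the cycle-consistency loss being a mean absolute error scaled by lambda: an offset of 1 gives loss 10 under `_LAMBDA_A = 10`, an offset of 2 gives 20. A NumPy re-derivation (assuming `losses.cycle_consistency_loss` is `reduce_mean(abs(real - generated))`, which is what these expected values imply):

```python
import numpy as np

def cycle_consistency_loss(real_images, generated_images):
    # mean absolute error between the input and its reconstruction
    return np.mean(np.abs(real_images - generated_images))

LAMBDA_A, LAMBDA_B = 10.0, 10.0
rng = np.random.default_rng(0)
a = rng.uniform(0, 1, size=(5, 7))
b = rng.uniform(0, 1, size=(5, 7))

loss_a = LAMBDA_A * cycle_consistency_loss(a, a + 1)  # constant offset 1
loss_b = LAMBDA_B * cycle_consistency_loss(b, b - 2)  # constant offset 2
assert np.isclose(loss_a, 10.0)
assert np.isclose(loss_b, 20.0)
```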
================================================
FILE: test/evaluate_networks.py
================================================
import numpy as np
import tensorflow as tf
from .. import model
def test_evaluate_g(sess):
x_val = np.ones_like(np.random.randn(1, 16, 16, 3)).astype(np.float32)
for i in range(16):
for j in range(16):
for k in range(3):
x_val[0][i][j][k] = ((i + j + k) % 2) / 2.0
inputs = {
'images_a': tf.stack(x_val),
'images_b': tf.stack(x_val),
'fake_pool_a': tf.zeros([1, 16, 16, 3]),
'fake_pool_b': tf.zeros([1, 16, 16, 3]),
}
outputs = model.get_outputs(inputs)
sess.run(tf.global_variables_initializer())
assert sess.run(outputs['fake_images_a'][0][5][7][0]) == 5
def test_evaluate_d(sess):
x_val = np.ones_like(np.random.randn(1, 16, 16, 3)).astype(np.float32)
for i in range(16):
for j in range(16):
for k in range(3):
x_val[0][i][j][k] = ((i + j + k) % 2) / 2.0
inputs = {
'images_a': tf.stack(x_val),
'images_b': tf.stack(x_val),
'fake_pool_a': tf.zeros([1, 16, 16, 3]),
'fake_pool_b': tf.zeros([1, 16, 16, 3]),
}
outputs = model.get_outputs(inputs)
sess.run(tf.global_variables_initializer())
assert sess.run(outputs['prob_real_a_is_real'][0][3][3][0]) == 5
================================================
FILE: test/test_losses.py
================================================
import numpy as np
import tensorflow as tf
from .. import losses
def test_cycle_consistency_loss_is_none_with_perfect_fakes(sess):
batch_size, height, width, channels = [16, 2, 3, 1]
tf.set_random_seed(0)
images = tf.random_uniform((batch_size, height, width, channels), maxval=1)
loss = losses.cycle_consistency_loss(
real_images=images,
generated_images=images,
)
assert sess.run(loss) == 0
def test_cycle_consistency_loss_is_positive_with_imperfect_fake_x(sess):
batch_size, height, width, channels = [16, 2, 3, 1]
tf.set_random_seed(0)
real_images = tf.random_uniform(
(batch_size, height, width, channels), maxval=1,
)
generated_images = real_images + 1
loss = losses.cycle_consistency_loss(
real_images=real_images,
generated_images=generated_images,
)
assert sess.run(loss) == 1
def test_lsgan_loss_discrim_is_none_with_perfect_discrimination(sess):
batch_size = 100
prob_real_is_real = tf.ones((batch_size))
prob_fake_is_real = tf.zeros((batch_size))
loss = losses.lsgan_loss_discriminator(
prob_real_is_real, prob_fake_is_real,
)
assert sess.run(loss) == 0
def test_lsgan_loss_discrim_is_positive_with_imperfect_discrimination(sess):
batch_size = 100
prob_real_is_real = tf.ones((batch_size)) * 0.4
prob_fake_is_real = tf.ones((batch_size)) * 0.7
loss = losses.lsgan_loss_discriminator(
prob_real_is_real, prob_fake_is_real,
)
loss = sess.run(loss)
np.testing.assert_almost_equal(loss, (0.6 * 0.6 + 0.7 * 0.7) / 2)
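The expected value `(0.6 * 0.6 + 0.7 * 0.7) / 2` pins down the LSGAN discriminator loss as the halved sum of two squared-error terms: real outputs penalized for distance from 1, fake outputs for distance from 0. A NumPy check of that formula (inferred from the test's expected values, not read from `losses.py` directly); it also reproduces the `1/3` expected in `evaluate_losses.py`:

```python
import numpy as np

def lsgan_loss_discriminator(prob_real_is_real, prob_fake_is_real):
    # least-squares GAN: push real scores toward 1 and fake scores toward 0
    return (np.mean((prob_real_is_real - 1.0) ** 2)
            + np.mean(prob_fake_is_real ** 2)) * 0.5

# the test case above: real scored 0.4, fake scored 0.7
real = np.full(100, 0.4)
fake = np.full(100, 0.7)
assert np.isclose(lsgan_loss_discriminator(real, fake),
                  (0.6 ** 2 + 0.7 ** 2) / 2)

# the d-loss case from evaluate_losses.py
assert np.isclose(lsgan_loss_discriminator(np.array([1.0, 1.0, 0.0]),
                                           np.array([1.0, 0.0, 0.0])),
                  1.0 / 3.0)
```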
================================================
FILE: test/test_model.py
================================================
import tensorflow as tf
from dl_research.testing import slow
from .. import model
# -----------------------------------------------------------------------------
@slow
def test_output_sizes(sess):
images_size = [
model.BATCH_SIZE,
model.IMG_HEIGHT,
model.IMG_WIDTH,
model.IMG_CHANNELS,
]
pool_size = [
model.POOL_SIZE,
model.IMG_HEIGHT,
model.IMG_WIDTH,
model.IMG_CHANNELS,
]
inputs = {
'images_a': tf.ones(images_size),
'images_b': tf.ones(images_size),
'fake_pool_a': tf.ones(pool_size),
'fake_pool_b': tf.ones(pool_size),
}
outputs = model.get_outputs(inputs)
assert outputs['prob_real_a_is_real'].get_shape().as_list() == [
model.BATCH_SIZE, 32, 32, 1,
]
assert outputs['prob_real_b_is_real'].get_shape().as_list() == [
model.BATCH_SIZE, 32, 32, 1,
]
assert outputs['prob_fake_a_is_real'].get_shape().as_list() == [
model.BATCH_SIZE, 32, 32, 1,
]
assert outputs['prob_fake_b_is_real'].get_shape().as_list() == [
model.BATCH_SIZE, 32, 32, 1,
]
assert outputs['prob_fake_pool_a_is_real'].get_shape().as_list() == [
model.POOL_SIZE, 32, 32, 1,
]
assert outputs['prob_fake_pool_b_is_real'].get_shape().as_list() == [
model.POOL_SIZE, 32, 32, 1,
]
assert outputs['cycle_images_a'].get_shape().as_list() == images_size
assert outputs['cycle_images_b'].get_shape().as_list() == images_size