Repository: AlamiMejjati/Unsupervised-Attention-guided-Image-to-Image-Translation Branch: master Commit: a691e3a94b7a Files: 25 Total size: 63.5 KB Directory structure: gitextract__oabw19d/ ├── LICENSE ├── README.md ├── Trained_models/ │ └── README.md ├── __init__.py ├── configs/ │ ├── exp_01.json │ ├── exp_01_test.json │ ├── exp_02.json │ ├── exp_02_test.json │ ├── exp_04.json │ ├── exp_04_test.json │ ├── exp_05.json │ └── exp_05_test.json ├── create_cyclegan_dataset.py ├── cyclegan_datasets.py ├── data_loader.py ├── download_datasets.sh ├── layers.py ├── losses.py ├── main.py ├── model.py └── test/ ├── __init__.py ├── evaluate_losses.py ├── evaluate_networks.py ├── test_losses.py └── test_model.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: LICENSE ================================================ MIT License Copyright (c) 2017 Harry Yang Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
================================================ FILE: README.md ================================================ # Unsupervised Attention-guided Image-to-Image Translation This repository contains the TensorFlow code for our NeurIPS 2018 paper [“Unsupervised Attention-guided Image-to-Image Translation”](https://arxiv.org/pdf/1806.02311.pdf). This code is based on the TensorFlow implementation of CycleGAN provided by [Harry Yang](https://github.com/leehomyc/cyclegan-1). You may need to train several times, as the quality of the results is sensitive to the initialization. By leveraging attention, our architecture (shown in the figure below) maps only the relevant areas of the image and, by doing so, further enhances the quality of image-to-image translation. Our model architecture is depicted below; please refer to the paper for more details: ## Mapping results ### Our learned attention maps The figure below displays automatically learned attention maps on various translation datasets: ### Horse-to-Zebra image translation results: #### Horse-to-Zebra: The top row in the figure below shows input images and the bottom row shows the mappings produced by our algorithm. #### Zebra-to-Horse: The top row in the figure below shows input images and the bottom row shows the mappings produced by our algorithm. ### Apple-to-Orange image translation results: #### Apple-to-Orange: The top row in the figure below shows input images and the bottom row shows the mappings produced by our algorithm. #### Orange-to-Apple: The top row in the figure below shows input images and the bottom row shows the mappings produced by our algorithm. ### Getting Started with the code ### Prepare dataset * You can either download one of the default CycleGAN datasets or use your own dataset. * Download a CycleGAN dataset (e.g. horse2zebra, apple2orange): ```bash bash ./download_datasets.sh horse2zebra ``` * Use your own dataset: put the images from each domain in folder_a and folder_b, respectively.
* Create the csv file as input to the data loader. * Edit the [```cyclegan_datasets.py```](cyclegan_datasets.py) file. For example, if you have a horse2zebra_train dataset which contains 1067 horse images and 1334 zebra images (both in JPG format), you can just edit [```cyclegan_datasets.py```](cyclegan_datasets.py) as follows (the size entry is usually the larger of the two domain counts): ```python DATASET_TO_SIZES = { 'horse2zebra_train': 1334 } PATH_TO_CSV = { 'horse2zebra_train': './input/horse2zebra/horse2zebra_train.csv' } DATASET_TO_IMAGETYPE = { 'horse2zebra_train': '.jpg' } ``` * Run create_cyclegan_dataset.py: ```bash python -m create_cyclegan_dataset --image_path_a='./input/horse2zebra/trainB' --image_path_b='./input/horse2zebra/trainA' --dataset_name="horse2zebra_train" --do_shuffle=0 ``` ### Training * Create the configuration file. The configuration file contains basic information for training/testing. An example configuration file can be found at [```configs/exp_01.json```](configs/exp_01.json). * Start training: ```bash python main.py --to_train=1 --log_dir=./output/AGGAN/exp_01 --config_filename=./configs/exp_01.json ``` * Check the intermediate results: * TensorBoard ```bash tensorboard --port=6006 --logdir=./output/AGGAN/exp_01/#timestamp# ``` * Check the HTML visualization at ./output/AGGAN/exp_01/#timestamp#/epoch_#id#.html. ### Restoring from the previous checkpoint ```bash python main.py --to_train=2 --log_dir=./output/AGGAN/exp_01 --config_filename=./configs/exp_01.json --checkpoint_dir=./output/AGGAN/exp_01/#timestamp# ``` ### Testing * Create the testing dataset: * Edit the cyclegan_datasets.py file the same way as for training.
* Create the csv file as the input to the data loader: ```bash python -m create_cyclegan_dataset --image_path_a='./input/horse2zebra/testB' --image_path_b='./input/horse2zebra/testA' --dataset_name="horse2zebra_test" --do_shuffle=0 ``` * Run testing: ```bash python main.py --to_train=0 --log_dir=./output/AGGAN/exp_01 --config_filename=./configs/exp_01_test.json --checkpoint_dir=./output/AGGAN/exp_01/#old_timestamp# ``` * Trained models: Our trained models can be downloaded from https://drive.google.com/open?id=1YEQMJK41KQj_-HfKFneSI12nWpTajgzT ================================================ FILE: Trained_models/README.md ================================================ #### Trained models. Our trained models can be downloaded from https://drive.google.com/open?id=1YEQMJK41KQj_-HfKFneSI12nWpTajgzT When using the trained parameters for horse-to-zebra image translation, note that in our case the source (input_a) is zebra and the target (input_b) is horse, not the opposite. ================================================ FILE: __init__.py ================================================ ================================================ FILE: configs/exp_01.json ================================================ { "description": "The official PyTorch version of CycleGAN.", "pool_size": 50, "base_lr":0.0001, "max_step": 100, "network_version": "pytorch", "dataset_name": "horse2zebra_train", "do_flipping": 1, "_LAMBDA_A": 10, "_LAMBDA_B": 10 } ================================================ FILE: configs/exp_01_test.json ================================================ { "description": "Testing with trained model.", "network_version": "pytorch", "dataset_name": "horse2zebra_test", "do_flipping": 0 } ================================================ FILE: configs/exp_02.json ================================================ { "description": "The official PyTorch version of CycleGAN.", "pool_size": 50, "base_lr":0.0001, "max_step": 100, "network_version": "pytorch",
"dataset_name": "apple2orange_train", "do_flipping": 1, "_LAMBDA_A": 10, "_LAMBDA_B": 10 } ================================================ FILE: configs/exp_02_test.json ================================================ { "description": "Testing with trained model.", "network_version": "pytorch", "dataset_name": "apple2orange_test", "do_flipping": 0 } ================================================ FILE: configs/exp_04.json ================================================ { "description": "The official PyTorch version of CycleGAN.", "pool_size": 50, "base_lr":0.0001, "max_step": 100, "network_version": "pytorch", "dataset_name": "lion2tiger_train", "do_flipping": 1, "_LAMBDA_A": 10, "_LAMBDA_B": 10 } ================================================ FILE: configs/exp_04_test.json ================================================ { "description": "Testing with trained model.", "network_version": "pytorch", "dataset_name": "lion2tiger_test", "do_flipping": 0 } ================================================ FILE: configs/exp_05.json ================================================ { "description": "The official PyTorch version of CycleGAN.", "pool_size": 50, "base_lr":0.0002, "max_step": 200, "network_version": "pytorch", "dataset_name": "summer2winter_yosemite_train", "do_flipping": 1, "_LAMBDA_A": 10, "_LAMBDA_B": 10 } ================================================ FILE: configs/exp_05_test.json ================================================ { "description": "Testing with trained model.", "network_version": "pytorch", "dataset_name": "summer2winter_yosemite_test", "do_flipping": 0 } ================================================ FILE: create_cyclegan_dataset.py ================================================ """Create datasets for training and testing.""" import csv import os import random import click import cyclegan_datasets def create_list(foldername, fulldir=True, suffix=".jpg"): """ :param foldername: The full path of the folder. 
:param fulldir: Whether to return the full path or not. :param suffix: Filter by suffix. :return: The list of filenames in the folder with given suffix. """ file_list_tmp = os.listdir(foldername) file_list = [] if fulldir: for item in file_list_tmp: if item.endswith(suffix): file_list.append(os.path.join(foldername, item)) else: for item in file_list_tmp: if item.endswith(suffix): file_list.append(item) return file_list @click.command() @click.option('--image_path_a', type=click.STRING, default='./input/horse2zebra/trainA', help='The path to the images from domain_a.') @click.option('--image_path_b', type=click.STRING, default='./input/horse2zebra/trainB', help='The path to the images from domain_b.') @click.option('--dataset_name', type=click.STRING, default='horse2zebra_train', help='The name of the dataset in cyclegan_dataset.') @click.option('--do_shuffle', type=click.BOOL, default=False, help='Whether to shuffle images when creating the dataset.') def create_dataset(image_path_a, image_path_b, dataset_name, do_shuffle): list_a = create_list(image_path_a, True, cyclegan_datasets.DATASET_TO_IMAGETYPE[dataset_name]) list_b = create_list(image_path_b, True, cyclegan_datasets.DATASET_TO_IMAGETYPE[dataset_name]) output_path = cyclegan_datasets.PATH_TO_CSV[dataset_name] num_rows = cyclegan_datasets.DATASET_TO_SIZES[dataset_name] all_data_tuples = [] for i in range(num_rows): all_data_tuples.append(( list_a[i % len(list_a)], list_b[i % len(list_b)] )) if do_shuffle is True: random.shuffle(all_data_tuples) with open(output_path, 'w') as csv_file: csv_writer = csv.writer(csv_file) for data_tuple in enumerate(all_data_tuples): csv_writer.writerow(list(data_tuple[1])) if __name__ == '__main__': create_dataset() ================================================ FILE: cyclegan_datasets.py ================================================ """Contains the standard train/test splits for the cyclegan data.""" """The size of each dataset. 
Usually it is the maximum number of images from each domain.""" DATASET_TO_SIZES = { 'horse2zebra_train': 1334, 'horse2zebra_test': 140, 'apple2orange_train': 1019, 'apple2orange_test': 266, 'lion2tiger_train': 916, 'lion2tiger_test': 103, 'summer2winter_yosemite_train': 1231, 'summer2winter_yosemite_test': 309, } """The image types of each dataset. Currently only supports .jpg or .png""" DATASET_TO_IMAGETYPE = { 'horse2zebra_train': '.jpg', 'horse2zebra_test': '.jpg', 'apple2orange_train': '.jpg', 'apple2orange_test': '.jpg', 'lion2tiger_train': '.jpg', 'lion2tiger_test': '.jpg', 'summer2winter_yosemite_train': '.jpg', 'summer2winter_yosemite_test': '.jpg', } """The path to the output csv file.""" PATH_TO_CSV = { 'horse2zebra_train': './input/horse2zebra/horse2zebra_train.csv', 'horse2zebra_test': './input/horse2zebra/horse2zebra_test.csv', 'apple2orange_train': './input/apple2orange/apple2orange_train.csv', 'apple2orange_test': './input/apple2orange/apple2orange_test.csv', 'lion2tiger_train': './input/lion2tiger/lion2tiger_train.csv', 'lion2tiger_test': './input/lion2tiger/lion2tiger_test.csv', 'summer2winter_yosemite_train': './input/summer2winter_yosemite/summer2winter_yosemite_train.csv', 'summer2winter_yosemite_test': './input/summer2winter_yosemite/summer2winter_yosemite_test.csv' } ================================================ FILE: data_loader.py ================================================ import tensorflow as tf import cyclegan_datasets import model def _load_samples(csv_name, image_type): filename_queue = tf.train.string_input_producer( [csv_name]) reader = tf.TextLineReader() _, csv_filename = reader.read(filename_queue) record_defaults = [tf.constant([], dtype=tf.string), tf.constant([], dtype=tf.string)] filename_i, filename_j = tf.decode_csv( csv_filename, record_defaults=record_defaults) file_contents_i = tf.read_file(filename_i) file_contents_j = tf.read_file(filename_j) if image_type == '.jpg': image_decoded_A = tf.image.decode_jpeg( 
file_contents_i, channels=model.IMG_CHANNELS) image_decoded_B = tf.image.decode_jpeg( file_contents_j, channels=model.IMG_CHANNELS) elif image_type == '.png': image_decoded_A = tf.image.decode_png( file_contents_i, channels=model.IMG_CHANNELS, dtype=tf.uint8) image_decoded_B = tf.image.decode_png( file_contents_j, channels=model.IMG_CHANNELS, dtype=tf.uint8) return image_decoded_A, image_decoded_B def load_data(dataset_name, image_size_before_crop, do_shuffle=True, do_flipping=False): """ :param dataset_name: The name of the dataset. :param image_size_before_crop: Resize to this size before random cropping. :param do_shuffle: Shuffle switch. :param do_flipping: Flip switch. :return: """ if dataset_name not in cyclegan_datasets.DATASET_TO_SIZES: raise ValueError('split name %s was not recognized.' % dataset_name) csv_name = cyclegan_datasets.PATH_TO_CSV[dataset_name] image_i, image_j = _load_samples( csv_name, cyclegan_datasets.DATASET_TO_IMAGETYPE[dataset_name]) inputs = { 'image_i': image_i, 'image_j': image_j } # Preprocessing: inputs['image_i'] = tf.image.resize_images( inputs['image_i'], [image_size_before_crop, image_size_before_crop]) inputs['image_j'] = tf.image.resize_images( inputs['image_j'], [image_size_before_crop, image_size_before_crop]) if do_flipping is True: inputs['image_i'] = tf.image.random_flip_left_right(inputs['image_i'], seed=1) inputs['image_j'] = tf.image.random_flip_left_right(inputs['image_j'], seed=1) inputs['image_i'] = tf.random_crop( inputs['image_i'], [model.IMG_HEIGHT, model.IMG_WIDTH, 3], seed=1) inputs['image_j'] = tf.random_crop( inputs['image_j'], [model.IMG_HEIGHT, model.IMG_WIDTH, 3], seed=1) inputs['image_i'] = tf.subtract(tf.div(inputs['image_i'], 127.5), 1) inputs['image_j'] = tf.subtract(tf.div(inputs['image_j'], 127.5), 1) # Batch if do_shuffle is True: inputs['images_i'], inputs['images_j'] = tf.train.shuffle_batch( [inputs['image_i'], inputs['image_j']], 1, 5000, 100, seed=1) else: inputs['images_i'], 
inputs['images_j'] = tf.train.batch( [inputs['image_i'], inputs['image_j']], 1) return inputs ================================================ FILE: download_datasets.sh ================================================ FILE=$1 if [[ $FILE != "ae_photos" && $FILE != "apple2orange" && $FILE != "summer2winter_yosemite" && $FILE != "horse2zebra" && $FILE != "monet2photo" && $FILE != "cezanne2photo" && $FILE != "ukiyoe2photo" && $FILE != "vangogh2photo" && $FILE != "maps" && $FILE != "cityscapes" && $FILE != "facades" && $FILE != "iphone2dslr_flower" ]]; then echo "Available datasets are: apple2orange, summer2winter_yosemite, horse2zebra, monet2photo, cezanne2photo, ukiyoe2photo, vangogh2photo, maps, cityscapes, facades, iphone2dslr_flower, ae_photos" exit 1 fi mkdir -p ./input URL=https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/$FILE.zip ZIP_FILE=./input/$FILE.zip TARGET_DIR=./input/$FILE/ wget -N $URL -O $ZIP_FILE mkdir -p $TARGET_DIR unzip $ZIP_FILE -d ./input/ rm $ZIP_FILE ================================================ FILE: layers.py ================================================ import tensorflow as tf def lrelu(x, leak=0.2, name="lrelu", alt_relu_impl=False): with tf.variable_scope(name): if alt_relu_impl: f1 = 0.5 * (1 + leak) f2 = 0.5 * (1 - leak) return f1 * x + f2 * abs(x) else: return tf.maximum(x, leak * x) def instance_norm(x): with tf.variable_scope("instance_norm"): epsilon = 1e-5 mean, var = tf.nn.moments(x, [1, 2], keep_dims=True) scale = tf.get_variable('scale', [x.get_shape()[-1]], initializer=tf.truncated_normal_initializer( mean=1.0, stddev=0.02 )) offset = tf.get_variable( 'offset', [x.get_shape()[-1]], initializer=tf.constant_initializer(0.0) ) out = scale * tf.div(x - mean, tf.sqrt(var + epsilon)) + offset return out def instance_norm_bis(x,mask): with tf.variable_scope("instance_norm"): epsilon = 1e-5 for i in range(x.shape[-1]): slice = tf.gather(x, i, axis=3) slice_mask = tf.gather(mask, i, axis=3) tmp =
tf.boolean_mask(slice,slice_mask) mean, var = tf.nn.moments_bis(x, [1, 2], keep_dims=False) mean, var = tf.nn.moments_bis(x, [1, 2], keep_dims=True) scale = tf.get_variable('scale', [x.get_shape()[-1]], initializer=tf.truncated_normal_initializer( mean=1.0, stddev=0.02 )) offset = tf.get_variable( 'offset', [x.get_shape()[-1]], initializer=tf.constant_initializer(0.0) ) out = scale * tf.div(x - mean, tf.sqrt(var + epsilon)) + offset return out def general_conv2d_(inputconv, o_d=64, f_h=7, f_w=7, s_h=1, s_w=1, stddev=0.02, padding="VALID", name="conv2d", do_norm=True, do_relu=True, relufactor=0): with tf.variable_scope(name): conv = tf.contrib.layers.conv2d( inputconv, o_d, f_w, s_w, padding, activation_fn=None, weights_initializer=tf.truncated_normal_initializer( stddev=stddev ), biases_initializer=tf.constant_initializer(0.0) ) if do_norm: conv = instance_norm(conv) if do_relu: if(relufactor == 0): conv = tf.nn.relu(conv, "relu") else: conv = lrelu(conv, relufactor, "lrelu") return conv def general_conv2d(inputconv, do_norm, o_d=64, f_h=7, f_w=7, s_h=1, s_w=1, stddev=0.02, padding="VALID", name="conv2d", do_relu=True, relufactor=0): with tf.variable_scope(name): conv = tf.contrib.layers.conv2d( inputconv, o_d, f_w, s_w, padding, activation_fn=None, weights_initializer=tf.truncated_normal_initializer( stddev=stddev ), biases_initializer=tf.constant_initializer(0.0) ) conv = tf.cond(do_norm, lambda: instance_norm(conv), lambda: conv) if do_relu: if(relufactor == 0): conv = tf.nn.relu(conv, "relu") else: conv = lrelu(conv, relufactor, "lrelu") return conv def general_deconv2d(inputconv, outshape, o_d=64, f_h=7, f_w=7, s_h=1, s_w=1, stddev=0.02, padding="VALID", name="deconv2d", do_norm=True, do_relu=True, relufactor=0): with tf.variable_scope(name): conv = tf.contrib.layers.conv2d_transpose( inputconv, o_d, [f_h, f_w], [s_h, s_w], padding, activation_fn=None, weights_initializer=tf.truncated_normal_initializer(stddev=stddev), 
biases_initializer=tf.constant_initializer(0.0) ) if do_norm: conv = instance_norm(conv) if do_relu: if(relufactor == 0): conv = tf.nn.relu(conv, "relu") else: conv = lrelu(conv, relufactor, "lrelu") return conv def upsamplingDeconv(inputconv, size, is_scale, method,align_corners, name): if len(inputconv.get_shape()) == 3: if is_scale: size_h = size[0] * int(inputconv.get_shape()[0]) size_w = size[1] * int(inputconv.get_shape()[1]) size = [int(size_h), int(size_w)] elif len(inputconv.get_shape()) == 4: if is_scale: size_h = size[0] * int(inputconv.get_shape()[1]) size_w = size[1] * int(inputconv.get_shape()[2]) size = [int(size_h), int(size_w)] else: raise Exception("Does not support shape %s" % inputconv.get_shape()) print(" [TL] UpSampling2dLayer %s: is_scale:%s size:%s method:%d align_corners:%s" % (name, is_scale, size, method, align_corners)) with tf.variable_scope(name) as vs: try: out = tf.image.resize_images(inputconv, size=size, method=method, align_corners=align_corners) except: # for TF 0.10 out = tf.image.resize_images(inputconv, new_height=size[0], new_width=size[1], method=method, align_corners=align_corners) return out def general_fc_layers(inpfc, outshape, name): with tf.variable_scope(name): fcw = tf.Variable(tf.truncated_normal(outshape, dtype=tf.float32, stddev=1e-1), name='weights') fcb = tf.Variable(tf.constant(1.0, shape=[outshape[-1]], dtype=tf.float32), trainable=True, name='biases') fcl = tf.nn.bias_add(tf.matmul(inpfc, fcw), fcb) fc_out = tf.nn.relu(fcl) return fc_out ================================================ FILE: losses.py ================================================ """Contains losses used for performing image-to-image domain adaptation.""" import tensorflow as tf def cycle_consistency_loss(real_images, generated_images): """Compute the cycle consistency loss. The cycle consistency loss is defined as the sum of the L1 distances between the real images from each domain and their generated (fake) counterparts.
This definition is derived from Equation 2 in: Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks. Jun-Yan Zhu, Taesung Park, Phillip Isola, Alexei A. Efros. Args: real_images: A batch of images from domain X, a `Tensor` of shape [batch_size, height, width, channels]. generated_images: A batch of generated images made to look like they came from domain X, a `Tensor` of shape [batch_size, height, width, channels]. Returns: The cycle consistency loss. """ return tf.reduce_mean(tf.abs(real_images - generated_images)) def mask_loss(gen_image, mask): """Penalizes generated content outside the attention mask: the mean absolute value of gen_image * (1 - mask).""" return tf.reduce_mean(tf.abs(tf.multiply(gen_image,1-mask))) def lsgan_loss_generator(prob_fake_is_real): """Computes the LS-GAN loss as minimized by the generator. Rather than compute the negative log-likelihood, a least-squares loss is used to optimize the generators as per Equation 2 in: Least Squares Generative Adversarial Networks Xudong Mao, Qing Li, Haoran Xie, Raymond Y.K. Lau, Zhen Wang, and Stephen Paul Smolley. https://arxiv.org/pdf/1611.04076.pdf Args: prob_fake_is_real: The discriminator's estimate that generated images made to look like real images are real. Returns: The total LS-GAN loss. """ return tf.reduce_mean(tf.squared_difference(prob_fake_is_real, 1)) def lsgan_loss_discriminator(prob_real_is_real, prob_fake_is_real): """Computes the LS-GAN loss as minimized by the discriminator. Rather than compute the negative log-likelihood, a least-squares loss is used to optimize the discriminators as per Equation 2 in: Least Squares Generative Adversarial Networks Xudong Mao, Qing Li, Haoran Xie, Raymond Y.K. Lau, Zhen Wang, and Stephen Paul Smolley. https://arxiv.org/pdf/1611.04076.pdf Args: prob_real_is_real: The discriminator's estimate that images actually drawn from the real domain are in fact real. prob_fake_is_real: The discriminator's estimate that generated images made to look like real images are real. Returns: The total LS-GAN loss.
""" return (tf.reduce_mean(tf.squared_difference(prob_real_is_real, 1)) + tf.reduce_mean(tf.squared_difference(prob_fake_is_real, 0))) * 0.5 ================================================ FILE: main.py ================================================ """Code for training CycleGAN.""" from datetime import datetime import json import numpy as np import os import random from scipy.misc import imsave import argparse import tensorflow as tf import cyclegan_datasets import data_loader, losses, model tf.set_random_seed(1) np.random.seed(0) slim = tf.contrib.slim class CycleGAN: """The CycleGAN module.""" def __init__(self, pool_size, lambda_a, lambda_b, output_root_dir, to_restore, base_lr, max_step, network_version, dataset_name, checkpoint_dir, do_flipping, skip, switch, threshold_fg): current_time = datetime.now().strftime("%Y%m%d-%H%M%S") self._pool_size = pool_size self._size_before_crop = 286 self._switch = switch self._threshold_fg = threshold_fg self._lambda_a = lambda_a self._lambda_b = lambda_b self._output_dir = os.path.join(output_root_dir, current_time + '_switch'+str(switch)+'_thres_'+str(threshold_fg)) self._images_dir = os.path.join(self._output_dir, 'imgs') self._num_imgs_to_save = 20 self._to_restore = to_restore self._base_lr = base_lr self._max_step = max_step self._network_version = network_version self._dataset_name = dataset_name self._checkpoint_dir = checkpoint_dir self._do_flipping = do_flipping self._skip = skip self.fake_images_A = [] self.fake_images_B = [] def model_setup(self): """ This function sets up the model to train. self.input_A/self.input_B -> Set of training images. self.fake_A/self.fake_B -> Generated images by corresponding generator of input_A and input_B self.lr -> Learning rate variable self.cyc_A/ self.cyc_B -> Images generated after feeding self.fake_A/self.fake_B to corresponding generator. 
This is used to calculate the cycle-consistency loss. """ self.input_a = tf.placeholder( tf.float32, [ 1, model.IMG_HEIGHT, model.IMG_WIDTH, model.IMG_CHANNELS ], name="input_A") self.input_b = tf.placeholder( tf.float32, [ 1, model.IMG_HEIGHT, model.IMG_WIDTH, model.IMG_CHANNELS ], name="input_B") self.fake_pool_A = tf.placeholder( tf.float32, [ None, model.IMG_HEIGHT, model.IMG_WIDTH, model.IMG_CHANNELS ], name="fake_pool_A") self.fake_pool_B = tf.placeholder( tf.float32, [ None, model.IMG_HEIGHT, model.IMG_WIDTH, model.IMG_CHANNELS ], name="fake_pool_B") self.fake_pool_A_mask = tf.placeholder( tf.float32, [ None, model.IMG_HEIGHT, model.IMG_WIDTH, model.IMG_CHANNELS ], name="fake_pool_A_mask") self.fake_pool_B_mask = tf.placeholder( tf.float32, [ None, model.IMG_HEIGHT, model.IMG_WIDTH, model.IMG_CHANNELS ], name="fake_pool_B_mask") self.global_step = slim.get_or_create_global_step() self.num_fake_inputs = 0 self.learning_rate = tf.placeholder(tf.float32, shape=[], name="lr") self.transition_rate = tf.placeholder(tf.float32, shape=[], name="tr") self.donorm = tf.placeholder(tf.bool, shape=[], name="donorm") inputs = { 'images_a': self.input_a, 'images_b': self.input_b, 'fake_pool_a': self.fake_pool_A, 'fake_pool_b': self.fake_pool_B, 'fake_pool_a_mask': self.fake_pool_A_mask, 'fake_pool_b_mask': self.fake_pool_B_mask, 'transition_rate': self.transition_rate, 'donorm': self.donorm, } outputs = model.get_outputs( inputs, skip=self._skip) self.prob_real_a_is_real = outputs['prob_real_a_is_real'] self.prob_real_b_is_real = outputs['prob_real_b_is_real'] self.fake_images_a = outputs['fake_images_a'] self.fake_images_b = outputs['fake_images_b'] self.prob_fake_a_is_real = outputs['prob_fake_a_is_real'] self.prob_fake_b_is_real = outputs['prob_fake_b_is_real'] self.cycle_images_a = outputs['cycle_images_a'] self.cycle_images_b = outputs['cycle_images_b'] self.prob_fake_pool_a_is_real = outputs['prob_fake_pool_a_is_real'] self.prob_fake_pool_b_is_real =
outputs['prob_fake_pool_b_is_real'] self.masks = outputs['masks'] self.masked_gen_ims = outputs['masked_gen_ims'] self.masked_ims = outputs['masked_ims'] self.masks_ = outputs['mask_tmp'] def compute_losses(self): """ In this function we define the loss terms and the optimizers used to train the model. d_loss_A/d_loss_B -> loss for discriminator A/B g_loss_A/g_loss_B -> loss for generator A/B *_trainer -> optimizers for the above loss functions (g_*_trainer also updates the attention networks g_*_ae, while g_*_trainer_bis updates only the generator) *_summ -> Summary variables for the above loss functions """ cycle_consistency_loss_a = \ self._lambda_a * losses.cycle_consistency_loss( real_images=self.input_a, generated_images=self.cycle_images_a, ) cycle_consistency_loss_b = \ self._lambda_b * losses.cycle_consistency_loss( real_images=self.input_b, generated_images=self.cycle_images_b, ) lsgan_loss_a = losses.lsgan_loss_generator(self.prob_fake_a_is_real) lsgan_loss_b = losses.lsgan_loss_generator(self.prob_fake_b_is_real) g_loss_A = \ cycle_consistency_loss_a + cycle_consistency_loss_b + lsgan_loss_b g_loss_B = \ cycle_consistency_loss_b + cycle_consistency_loss_a + lsgan_loss_a d_loss_A = losses.lsgan_loss_discriminator( prob_real_is_real=self.prob_real_a_is_real, prob_fake_is_real=self.prob_fake_pool_a_is_real, ) d_loss_B = losses.lsgan_loss_discriminator( prob_real_is_real=self.prob_real_b_is_real, prob_fake_is_real=self.prob_fake_pool_b_is_real, ) optimizer = tf.train.AdamOptimizer(self.learning_rate, beta1=0.5) self.model_vars = tf.trainable_variables() d_A_vars = [var for var in self.model_vars if 'd_A' in var.name] g_A_vars = [var for var in self.model_vars if 'g_A/' in var.name] d_B_vars = [var for var in self.model_vars if 'd_B' in var.name] g_B_vars = [var for var in self.model_vars if 'g_B/' in var.name] g_Ae_vars = [var for var in self.model_vars if 'g_A_ae' in var.name] g_Be_vars = [var for var in self.model_vars if 'g_B_ae' in var.name] self.g_A_trainer = optimizer.minimize(g_loss_A, var_list=g_A_vars+g_Ae_vars) self.g_B_trainer =
optimizer.minimize(g_loss_B, var_list=g_B_vars+g_Be_vars) self.g_A_trainer_bis = optimizer.minimize(g_loss_A, var_list=g_A_vars) self.g_B_trainer_bis = optimizer.minimize(g_loss_B, var_list=g_B_vars) self.d_A_trainer = optimizer.minimize(d_loss_A, var_list=d_A_vars) self.d_B_trainer = optimizer.minimize(d_loss_B, var_list=d_B_vars) self.params_ae_c1 = g_A_vars[0] self.params_ae_c1_B = g_B_vars[0] for var in self.model_vars: print(var.name) # Summary variables for TensorBoard self.g_A_loss_summ = tf.summary.scalar("g_A_loss", g_loss_A) self.g_B_loss_summ = tf.summary.scalar("g_B_loss", g_loss_B) self.d_A_loss_summ = tf.summary.scalar("d_A_loss", d_loss_A) self.d_B_loss_summ = tf.summary.scalar("d_B_loss", d_loss_B) def save_images(self, sess, epoch, curr_tr): """ Saves input and output images. :param sess: The session. :param epoch: Current epoch. :param curr_tr: Current transition rate; normalization is disabled (donorm=False) when it is positive. """ if not os.path.exists(self._images_dir): os.makedirs(self._images_dir) if curr_tr >0: donorm = False else: donorm = True names = ['inputA_', 'inputB_', 'fakeA_', 'fakeB_', 'cycA_', 'cycB_', 'mask_a', 'mask_b'] with open(os.path.join( self._output_dir, 'epoch_' + str(epoch) + '.html' ), 'w') as v_html: for i in range(0, self._num_imgs_to_save): print("Saving image {}/{}".format(i, self._num_imgs_to_save)) inputs = sess.run(self.inputs) fake_A_temp, fake_B_temp, cyc_A_temp, cyc_B_temp, masks = sess.run([ self.fake_images_a, self.fake_images_b, self.cycle_images_a, self.cycle_images_b, self.masks, ], feed_dict={ self.input_a: inputs['images_i'], self.input_b: inputs['images_j'], self.transition_rate: curr_tr, self.donorm: donorm, }) tensors = [inputs['images_i'], inputs['images_j'], fake_B_temp, fake_A_temp, cyc_A_temp, cyc_B_temp, masks[0], masks[1]] for name, tensor in zip(names, tensors): image_name = name + str(epoch) + "_" + str(i) + ".jpg" if 'mask_' in name: imsave(os.path.join(self._images_dir, image_name), (np.squeeze(tensor[0])) ) else: imsave(os.path.join(self._images_dir, image_name), ((np.squeeze(tensor[0]) +
1) * 127.5).astype(np.uint8) ) v_html.write( "" ) v_html.write("
") def save_images_bis(self, sess, epoch): """ Saves input and output images. :param sess: The session. :param epoch: Currnt epoch. """ if not os.path.exists(self._images_dir): os.makedirs(self._images_dir) names = ['input_A_', 'mask_A_', 'masked_inputA_', 'fakeB_', 'input_B_', 'mask_B_', 'masked_inputB_', 'fakeA_'] space = '                        ' \ '                        ' \ '         ' with open(os.path.join(self._output_dir, 'results_' + str(epoch) + '.html'), 'w') as v_html: v_html.write("INPUT" + space + "MASK" + space + "MASKED_IMAGE" + space + "GENERATED_IMAGE") v_html.write("
") for i in range(0, self._num_imgs_to_save): print("Saving image {}/{}".format(i, self._num_imgs_to_save)) inputs = sess.run(self.inputs) fake_A_temp, fake_B_temp, masks, masked_ims = sess.run([ self.fake_images_a, self.fake_images_b, self.masks, self.masked_ims ], feed_dict={ self.input_a: inputs['images_i'], self.input_b: inputs['images_j'], self.transition_rate: 0.1, self.donorm: False, }) tensors = [inputs['images_i'], masks[0], masked_ims[0], fake_B_temp, inputs['images_j'], masks[1], masked_ims[1], fake_A_temp] for name, tensor in zip(names, tensors): image_name = name + str(i) + ".jpg" if 'mask_' in name: imsave(os.path.join(self._images_dir, image_name), (np.squeeze(tensor[0])) ) else: imsave(os.path.join(self._images_dir, image_name), ((np.squeeze(tensor[0]) + 1) * 127.5).astype(np.uint8) ) v_html.write( "" ) if 'fakeB_' in name: v_html.write("
") v_html.write("
") def fake_image_pool(self, num_fakes, fake, mask, fake_pool): """ This function saves the generated image to corresponding pool of images. It keeps on feeling the pool till it is full and then randomly selects an already stored image and replace it with new one. """ tmp = {} tmp['im'] = fake tmp['mask'] = mask if num_fakes < self._pool_size: fake_pool.append(tmp) return tmp else: p = random.random() if p > 0.5: random_id = random.randint(0, self._pool_size - 1) temp = fake_pool[random_id] fake_pool[random_id] = tmp return temp else: return tmp def train(self): """Training Function.""" # Load Dataset from the dataset folder self.inputs = data_loader.load_data( self._dataset_name, self._size_before_crop, False, self._do_flipping) # Build the network self.model_setup() # Loss function calculations self.compute_losses() # Initializing the global variables init = (tf.global_variables_initializer(), tf.local_variables_initializer()) saver = tf.train.Saver(max_to_keep=None) max_images = cyclegan_datasets.DATASET_TO_SIZES[self._dataset_name] half_training = int(self._max_step / 2) with tf.Session() as sess: sess.run(init) # Restore the model to run the model from last checkpoint if self._to_restore: chkpt_fname = tf.train.latest_checkpoint(self._checkpoint_dir) saver.restore(sess, chkpt_fname) writer = tf.summary.FileWriter(self._output_dir) if not os.path.exists(self._output_dir): os.makedirs(self._output_dir) coord = tf.train.Coordinator() threads = tf.train.start_queue_runners(coord=coord) # Training Loop for epoch in range(sess.run(self.global_step), self._max_step): print("In the epoch ", epoch) saver.save(sess, os.path.join( self._output_dir, "AGGAN"), global_step=epoch) # Dealing with the learning rate as per the epoch number if epoch < half_training: curr_lr = self._base_lr else: curr_lr = self._base_lr - \ self._base_lr * (epoch - half_training) / half_training if epoch < self._switch: curr_tr = 0. 
                    donorm = True
                    to_train_A = self.g_A_trainer
                    to_train_B = self.g_B_trainer
                else:
                    curr_tr = self._threshold_fg
                    donorm = False
                    to_train_A = self.g_A_trainer_bis
                    to_train_B = self.g_B_trainer_bis

                self.save_images(sess, epoch, curr_tr)

                for i in range(0, max_images):
                    print("Processing batch {}/{}".format(i, max_images))

                    inputs = sess.run(self.inputs)
                    # Optimizing the G_A network
                    _, fake_B_temp, smask_a, summary_str = sess.run(
                        [to_train_A,
                         self.fake_images_b,
                         self.masks[0],
                         self.g_A_loss_summ],
                        feed_dict={
                            self.input_a: inputs['images_i'],
                            self.input_b: inputs['images_j'],
                            self.learning_rate: curr_lr,
                            self.transition_rate: curr_tr,
                            self.donorm: donorm,
                        }
                    )
                    writer.add_summary(summary_str, epoch * max_images + i)

                    fake_B_temp1 = self.fake_image_pool(
                        self.num_fake_inputs, fake_B_temp, smask_a, self.fake_images_B)

                    # Optimizing the D_B network
                    _, summary_str = sess.run(
                        [self.d_B_trainer, self.d_B_loss_summ],
                        feed_dict={
                            self.input_a: inputs['images_i'],
                            self.input_b: inputs['images_j'],
                            self.learning_rate: curr_lr,
                            self.fake_pool_B: fake_B_temp1['im'],
                            self.fake_pool_B_mask: fake_B_temp1['mask'],
                            self.transition_rate: curr_tr,
                            self.donorm: donorm,
                        }
                    )
                    writer.add_summary(summary_str, epoch * max_images + i)

                    # Optimizing the G_B network
                    _, fake_A_temp, smask_b, summary_str = sess.run(
                        [to_train_B,
                         self.fake_images_a,
                         self.masks[1],
                         self.g_B_loss_summ],
                        feed_dict={
                            self.input_a: inputs['images_i'],
                            self.input_b: inputs['images_j'],
                            self.learning_rate: curr_lr,
                            self.transition_rate: curr_tr,
                            self.donorm: donorm,
                        }
                    )
                    writer.add_summary(summary_str, epoch * max_images + i)

                    fake_A_temp1 = self.fake_image_pool(
                        self.num_fake_inputs, fake_A_temp, smask_b, self.fake_images_A)

                    # Optimizing the D_A network
                    _, mask_tmp__, summary_str = sess.run(
                        [self.d_A_trainer, self.masks_, self.d_A_loss_summ],
                        feed_dict={
                            self.input_a: inputs['images_i'],
                            self.input_b: inputs['images_j'],
                            self.learning_rate: curr_lr,
                            self.fake_pool_A: fake_A_temp1['im'],
                            self.fake_pool_A_mask: fake_A_temp1['mask'],
                            self.transition_rate: curr_tr,
                            self.donorm: donorm,
                        }
                    )
                    writer.add_summary(summary_str, epoch * max_images + i)

                    writer.flush()
                    self.num_fake_inputs += 1

                sess.run(tf.assign(self.global_step, epoch + 1))

            coord.request_stop()
            coord.join(threads)
            writer.add_graph(sess.graph)

    def test(self):
        """Test Function."""
        print("Testing the results")

        self.inputs = data_loader.load_data(
            self._dataset_name, self._size_before_crop,
            False, self._do_flipping)

        self.model_setup()
        saver = tf.train.Saver()
        init = tf.global_variables_initializer()

        with tf.Session() as sess:
            sess.run(init)
            chkpt_fname = tf.train.latest_checkpoint(self._checkpoint_dir)
            saver.restore(sess, chkpt_fname)

            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(coord=coord)

            self._num_imgs_to_save = cyclegan_datasets.DATASET_TO_SIZES[
                self._dataset_name]
            self.save_images_bis(sess, sess.run(self.global_step))

            coord.request_stop()
            coord.join(threads)


def parse_args():
    desc = "Tensorflow implementation of cycleGan using attention"
    parser = argparse.ArgumentParser(description=desc)
    parser.add_argument('--to_train', type=int, default=1,
                        help='1: train; 2: resume training from the latest checkpoint; 0: test.')
    parser.add_argument('--log_dir', type=str, default=None,
                        help='Where the data is logged to.')
    parser.add_argument('--config_filename', type=str, default='train',
                        help='The name of the configuration file.')
    parser.add_argument('--checkpoint_dir', type=str, default='',
                        help='The directory containing the checkpoint to restore.')
    parser.add_argument('--skip', type=bool, default=False,
                        help='Whether to add skip connection between input and output.')
    parser.add_argument('--switch', type=int, default=30,
                        help='The epoch from which only the thresholded foreground is fed to the discriminator.')
    parser.add_argument('--threshold', type=float, default=0.1,
                        help='The attention-mask threshold used to select the foreground.')

    return parser.parse_args()


def main():
    """

    :param to_train: Specify whether it is training or testing. 1: training; 2:
     resuming from latest checkpoint; 0: testing.
    :param log_dir: The root dir to save checkpoints and imgs. The actual dir
    is the root dir appended by the folder with the name timestamp.
    :param config_filename: The configuration file.
    :param checkpoint_dir: The directory that holds the latest checkpoint. It
    is used when resuming training (to_train == 2) and when testing
    (to_train == 0).
    :param skip: A boolean indicating whether to add skip connection between
    input and output.
    """
    args = parse_args()
    if args is None:
        exit()

    to_train = args.to_train
    log_dir = args.log_dir
    config_filename = args.config_filename
    checkpoint_dir = args.checkpoint_dir
    skip = args.skip
    switch = args.switch
    threshold_fg = args.threshold

    if not os.path.isdir(log_dir):
        os.makedirs(log_dir)

    with open(config_filename) as config_file:
        config = json.load(config_file)

    lambda_a = float(config['_LAMBDA_A']) if '_LAMBDA_A' in config else 10.0
    lambda_b = float(config['_LAMBDA_B']) if '_LAMBDA_B' in config else 10.0
    pool_size = int(config['pool_size']) if 'pool_size' in config else 50

    to_restore = (to_train == 2)
    base_lr = float(config['base_lr']) if 'base_lr' in config else 0.0002
    max_step = int(config['max_step']) if 'max_step' in config else 200
    network_version = str(config['network_version'])
    dataset_name = str(config['dataset_name'])
    do_flipping = bool(config['do_flipping'])

    cyclegan_model = CycleGAN(pool_size, lambda_a, lambda_b, log_dir,
                              to_restore, base_lr, max_step, network_version,
                              dataset_name, checkpoint_dir, do_flipping, skip,
                              switch, threshold_fg)

    if to_train > 0:
        cyclegan_model.train()
    else:
        cyclegan_model.test()


if __name__ == '__main__':
    main()

================================================
FILE: model.py
================================================
"""Code for constructing the model and getting the outputs from the model."""

import tensorflow as tf
import numpy as np
import layers

# The number of samples per batch.
BATCH_SIZE = 1

# The height of each image.
IMG_HEIGHT = 256

# The width of each image.
IMG_WIDTH = 256

# The number of color channels per image.
IMG_CHANNELS = 3

POOL_SIZE = 50
ngf = 32
ndf = 64


def get_outputs(inputs, skip=False):
    images_a = inputs['images_a']
    images_b = inputs['images_b']

    fake_pool_a = inputs['fake_pool_a']
    fake_pool_b = inputs['fake_pool_b']
    fake_pool_a_mask = inputs['fake_pool_a_mask']
    fake_pool_b_mask = inputs['fake_pool_b_mask']
    transition_rate = inputs['transition_rate']
    donorm = inputs['donorm']
    with tf.variable_scope("Model") as scope:

        current_autoenc = autoenc_upsample
        current_discriminator = discriminator
        current_generator = build_generator_resnet_9blocks

        mask_a = current_autoenc(images_a, "g_A_ae")
        mask_b = current_autoenc(images_b, "g_B_ae")
        mask_a = tf.concat([mask_a] * 3, axis=3)
        mask_b = tf.concat([mask_b] * 3, axis=3)

        mask_a_on_a = tf.multiply(images_a, mask_a)
        mask_b_on_b = tf.multiply(images_b, mask_b)

        prob_real_a_is_real = current_discriminator(images_a, mask_a, transition_rate, donorm, "d_A")
        prob_real_b_is_real = current_discriminator(images_b, mask_b, transition_rate, donorm, "d_B")

        fake_images_b_from_g = current_generator(images_a, name="g_A", skip=skip)
        fake_images_b = tf.multiply(fake_images_b_from_g, mask_a) + tf.multiply(images_a, 1-mask_a)

        fake_images_a_from_g = current_generator(images_b, name="g_B", skip=skip)
        fake_images_a = tf.multiply(fake_images_a_from_g, mask_b) + tf.multiply(images_b, 1-mask_b)
        scope.reuse_variables()

        prob_fake_a_is_real = current_discriminator(fake_images_a, mask_b, transition_rate, donorm, "d_A")
        prob_fake_b_is_real = current_discriminator(fake_images_b, mask_a, transition_rate, donorm, "d_B")

        mask_acycle = current_autoenc(fake_images_a, "g_A_ae")
        mask_bcycle = current_autoenc(fake_images_b, "g_B_ae")

        mask_bcycle = tf.concat([mask_bcycle] * 3, axis=3)
        mask_acycle = tf.concat([mask_acycle] * 3, axis=3)

        mask_acycle_on_fakeA = tf.multiply(fake_images_a, mask_acycle)
        mask_bcycle_on_fakeB = tf.multiply(fake_images_b, mask_bcycle)

        cycle_images_a_from_g = current_generator(fake_images_b, name="g_B", skip=skip)
        cycle_images_b_from_g = current_generator(fake_images_a, name="g_A", skip=skip)

        cycle_images_a = tf.multiply(cycle_images_a_from_g,
                                     mask_bcycle) + tf.multiply(fake_images_b, 1 - mask_bcycle)

        cycle_images_b = tf.multiply(cycle_images_b_from_g,
                                     mask_acycle) + tf.multiply(fake_images_a, 1 - mask_acycle)

        scope.reuse_variables()

        prob_fake_pool_a_is_real = current_discriminator(fake_pool_a, fake_pool_a_mask, transition_rate, donorm, "d_A")
        prob_fake_pool_b_is_real = current_discriminator(fake_pool_b, fake_pool_b_mask, transition_rate, donorm, "d_B")

    return {
        'prob_real_a_is_real': prob_real_a_is_real,
        'prob_real_b_is_real': prob_real_b_is_real,
        'prob_fake_a_is_real': prob_fake_a_is_real,
        'prob_fake_b_is_real': prob_fake_b_is_real,
        'prob_fake_pool_a_is_real': prob_fake_pool_a_is_real,
        'prob_fake_pool_b_is_real': prob_fake_pool_b_is_real,
        'cycle_images_a': cycle_images_a,
        'cycle_images_b': cycle_images_b,
        'fake_images_a': fake_images_a,
        'fake_images_b': fake_images_b,
        'masked_ims': [mask_a_on_a, mask_b_on_b, mask_acycle_on_fakeA, mask_bcycle_on_fakeB],
        'masks': [mask_a, mask_b, mask_acycle, mask_bcycle],
        'masked_gen_ims': [fake_images_b_from_g, fake_images_a_from_g,
                           cycle_images_a_from_g, cycle_images_b_from_g],
        'mask_tmp': mask_a,
    }


def autoenc_upsample(inputae, name):

    with tf.variable_scope(name):
        f = 7
        ks = 3
        padding = "REFLECT"

        pad_input = tf.pad(inputae, [[0, 0], [ks, ks], [ks, ks], [0, 0]], padding)
        o_c1 = layers.general_conv2d(
            pad_input, tf.constant(True, dtype=bool), ngf, f, f, 2, 2, 0.02, name="c1")
        o_c2 = layers.general_conv2d(
            o_c1, tf.constant(True, dtype=bool), ngf * 2, ks, ks, 2, 2, 0.02, "SAME", "c2")

        o_r1 = build_resnet_block_Att(o_c2, ngf * 2, "r1", padding)

        size_d1 = o_r1.get_shape().as_list()
        o_c4 = layers.upsamplingDeconv(o_r1, size=[size_d1[1] * 2, size_d1[2] * 2], is_scale=False, method=1,
                                       align_corners=False, name='up1')
        # o_c4_pad = tf.pad(o_c4, [[0, 0], [1, 1], [1, 1], [0, 0]], "REFLECT", name='padup1')
        o_c4_end = layers.general_conv2d(o_c4, tf.constant(True,
                                                           dtype=bool), ngf * 2, (3, 3), (1, 1),
                                         padding='VALID', name='c4')

        size_d2 = o_c4_end.get_shape().as_list()

        o_c5 = layers.upsamplingDeconv(o_c4_end, size=[size_d2[1] * 2, size_d2[2] * 2], is_scale=False, method=1,
                                       align_corners=False, name='up2')
        # o_c5_pad = tf.pad(o_c5, [[0, 0], [1, 1], [1, 1], [0, 0]], "REFLECT", name='padup2')
        oc5_end = layers.general_conv2d(o_c5, tf.constant(True, dtype=bool), ngf, (3, 3), (1, 1),
                                        padding='VALID', name='c5')

        # o_c6 = tf.pad(oc5_end, [[0, 0], [3, 3], [3, 3], [0, 0]], "REFLECT", name='padup3')
        o_c6_end = layers.general_conv2d(oc5_end, tf.constant(False, dtype=bool), 1, (f, f), (1, 1),
                                         padding='VALID', name='c6', do_relu=False)

        return tf.nn.sigmoid(o_c6_end, 'sigmoid')


def build_resnet_block(inputres, dim, name="resnet", padding="REFLECT"):
    """Build a single residual block.

    :param inputres: Input tensor in NHWC layout.
    :param dim: Number of channels of both convolutions; it must match the
     channel count of inputres so that the residual sum is well defined.
    :param name: Variable scope name.
    :param padding: for tensorflow version use REFLECT; for pytorch version use
     CONSTANT
    :return: a single block of resnet.
    """
    with tf.variable_scope(name):
        out_res = tf.pad(inputres, [[0, 0], [1, 1], [1, 1], [0, 0]], padding)
        out_res = layers.general_conv2d(
            out_res, tf.constant(True, dtype=bool), dim, 3, 3, 1, 1, 0.02, "VALID", "c1")
        out_res = tf.pad(out_res, [[0, 0], [1, 1], [1, 1], [0, 0]], padding)
        out_res = layers.general_conv2d(
            out_res, tf.constant(True, dtype=bool), dim, 3, 3, 1, 1, 0.02, "VALID", "c2", do_relu=False)

        return tf.nn.relu(out_res + inputres)


def build_resnet_block_Att(inputres, dim, name="resnet", padding="REFLECT"):
    """Build a single residual block for the attention network.

    :param inputres: Input tensor in NHWC layout.
    :param dim: Number of channels of both convolutions; it must match the
     channel count of inputres so that the residual sum is well defined.
    :param name: Variable scope name.
    :param padding: for tensorflow version use REFLECT; for pytorch version use
     CONSTANT
    :return: a single block of resnet.
    """
    with tf.variable_scope(name):
        out_res = tf.pad(inputres, [[0, 0], [1, 1], [1, 1], [0, 0]], padding)
        out_res = layers.general_conv2d(
            out_res, tf.constant(True, dtype=bool), dim, 3, 3, 1, 1, 0.02, "VALID", "c1")
        out_res = tf.pad(out_res, [[0, 0], [1, 1], [1, 1], [0, 0]], padding)
        out_res = layers.general_conv2d(
            out_res, tf.constant(True, dtype=bool), dim, 3, 3, 1, 1, 0.02, "VALID", "c2", do_relu=False)

        return tf.nn.relu(out_res + inputres)


def build_generator_resnet_9blocks(inputgen, name="generator", skip=False):
    with tf.variable_scope(name):
        f = 7
        ks = 3
        padding = "CONSTANT"
        # Pad into a separate tensor so that the unpadded input can still be
        # added to the output when skip is True.
        pad_input = tf.pad(inputgen, [[0, 0], [ks, ks], [ks, ks], [0, 0]], padding)

        o_c1 = layers.general_conv2d(
            pad_input, tf.constant(True, dtype=bool), ngf, f, f, 1, 1, 0.02, name="c1")

        o_c2 = layers.general_conv2d(
            o_c1, tf.constant(True, dtype=bool), ngf * 2, ks, ks, 2, 2, 0.02, padding='same', name="c2")

        o_c3 = layers.general_conv2d(
            o_c2, tf.constant(True, dtype=bool), ngf * 4, ks, ks, 2, 2, 0.02, padding='same', name="c3")

        o_r1 = build_resnet_block(o_c3, ngf * 4, "r1", padding)
        o_r2 = build_resnet_block(o_r1, ngf * 4, "r2", padding)
        o_r3 = build_resnet_block(o_r2, ngf * 4, "r3", padding)
        o_r4 = build_resnet_block(o_r3, ngf * 4, "r4", padding)
        o_r5 = build_resnet_block(o_r4, ngf * 4, "r5", padding)
        o_r6 = build_resnet_block(o_r5, ngf * 4, "r6", padding)
        o_r7 = build_resnet_block(o_r6, ngf * 4, "r7", padding)
        o_r8 = build_resnet_block(o_r7, ngf * 4, "r8", padding)
        o_r9 = build_resnet_block(o_r8, ngf * 4, "r9", padding)

        o_c4 = layers.general_deconv2d(
            o_r9, [BATCH_SIZE, 128, 128, ngf * 2], ngf * 2, ks, ks, 2, 2, 0.02,
            "SAME", "c4")

        o_c5 = layers.general_deconv2d(
            o_c4, [BATCH_SIZE, 256, 256, ngf], ngf, ks, ks, 2, 2, 0.02,
            "SAME", "c5")

        o_c6 = layers.general_conv2d(o_c5, tf.constant(False, dtype=bool), IMG_CHANNELS, f, f, 1, 1,
                                     0.02, "SAME", "c6",
                                     do_relu=False)

        if skip is True:
            out_gen = tf.nn.tanh(inputgen + o_c6, "t1")
        else:
            out_gen = tf.nn.tanh(o_c6, "t1")

        return out_gen


def discriminator(inputdisc, mask, transition_rate, donorm, name="discriminator"):
    with tf.variable_scope(name):
        mask = tf.cast(tf.greater_equal(mask, transition_rate), tf.float32)
        inputdisc = tf.multiply(inputdisc, mask)
        f = 4
        padw = 2

        pad_input = tf.pad(inputdisc, [[0, 0], [padw, padw], [padw, padw], [0, 0]], "CONSTANT")
        o_c1 = layers.general_conv2d(pad_input, donorm, ndf, f, f, 2, 2, 0.02, "VALID", "c1", relufactor=0.2)

        pad_o_c1 = tf.pad(o_c1, [[0, 0], [padw, padw], [padw, padw], [0, 0]], "CONSTANT")
        o_c2 = layers.general_conv2d(pad_o_c1, donorm, ndf * 2, f, f, 2, 2, 0.02, "VALID", "c2", relufactor=0.2)

        pad_o_c2 = tf.pad(o_c2, [[0, 0], [padw, padw], [padw, padw], [0, 0]], "CONSTANT")
        o_c3 = layers.general_conv2d(pad_o_c2, donorm, ndf * 4, f, f, 2, 2, 0.02, "VALID", "c3", relufactor=0.2)

        pad_o_c3 = tf.pad(o_c3, [[0, 0], [padw, padw], [padw, padw], [0, 0]], "CONSTANT")
        o_c4 = layers.general_conv2d(pad_o_c3, donorm, ndf * 8, f, f, 1, 1, 0.02, "VALID", "c4", relufactor=0.2)
        # o_c4 = tf.multiply(o_c4, mask_4)
        pad_o_c4 = tf.pad(o_c4, [[0, 0], [padw, padw], [padw, padw], [0, 0]], "CONSTANT")

        o_c5 = layers.general_conv2d(
            pad_o_c4, tf.constant(False, dtype=bool), 1, f, f, 1, 1, 0.02, "VALID", "c5", do_relu=False)

        return o_c5

================================================
FILE: test/__init__.py
================================================

================================================
FILE: test/evaluate_losses.py
================================================
import numpy as np
import tensorflow as tf

from .. import losses


def test_evaluate_g_losses(sess):

    _LAMBDA_A = 10
    _LAMBDA_B = 10

    input_a = tf.random_uniform((5, 7), maxval=1)
    cycle_images_a = input_a + 1
    input_b = tf.random_uniform((5, 7), maxval=1)
    cycle_images_b = input_b - 2

    cycle_consistency_loss_a = _LAMBDA_A * losses.cycle_consistency_loss(
        real_images=input_a, generated_images=cycle_images_a,
    )
    cycle_consistency_loss_b = _LAMBDA_B * losses.cycle_consistency_loss(
        real_images=input_b, generated_images=cycle_images_b,
    )

    prob_fake_a_is_real = tf.constant([0, 1.0, 0])
    prob_fake_b_is_real = tf.constant([1.0, 1.0, 0])

    lsgan_loss_a = losses.lsgan_loss_generator(prob_fake_a_is_real)
    lsgan_loss_b = losses.lsgan_loss_generator(prob_fake_b_is_real)

    assert np.isclose(sess.run(lsgan_loss_a), 0.66666669) and \
        np.isclose(sess.run(lsgan_loss_b), 0.3333333) and \
        np.isclose(sess.run(cycle_consistency_loss_a), 10) and \
        np.isclose(sess.run(cycle_consistency_loss_b), 20)


def test_evaluate_d_losses(sess):
    prob_real_a_is_real = tf.constant([1.0, 1.0, 0])
    prob_fake_pool_a_is_real = tf.constant([1.0, 0, 0])

    d_loss_A = losses.lsgan_loss_discriminator(
        prob_real_is_real=prob_real_a_is_real,
        prob_fake_is_real=prob_fake_pool_a_is_real)

    assert np.isclose(sess.run(d_loss_A), 0.3333333)

================================================
FILE: test/evaluate_networks.py
================================================
import numpy as np
import tensorflow as tf

from .. import model


def test_evaluate_g(sess):
    x_val = np.ones_like(np.random.randn(1, 16, 16, 3)).astype(np.float32)
    for i in range(16):
        for j in range(16):
            for k in range(3):
                x_val[0][i][j][k] = ((i + j + k) % 2) / 2.0

    inputs = {
        'images_a': tf.stack(x_val),
        'images_b': tf.stack(x_val),
        'fake_pool_a': tf.zeros([1, 16, 16, 3]),
        'fake_pool_b': tf.zeros([1, 16, 16, 3]),
    }

    outputs = model.get_outputs(inputs)

    sess.run(tf.global_variables_initializer())
    assert sess.run(outputs['fake_images_a'][0][5][7][0]) == 5


def test_evaluate_d(sess):
    x_val = np.ones_like(np.random.randn(1, 16, 16, 3)).astype(np.float32)
    for i in range(16):
        for j in range(16):
            for k in range(3):
                x_val[0][i][j][k] = ((i + j + k) % 2) / 2.0

    inputs = {
        'images_a': tf.stack(x_val),
        'images_b': tf.stack(x_val),
        'fake_pool_a': tf.zeros([1, 16, 16, 3]),
        'fake_pool_b': tf.zeros([1, 16, 16, 3]),
    }

    outputs = model.get_outputs(inputs)

    sess.run(tf.global_variables_initializer())
    assert sess.run(outputs['prob_real_a_is_real'][0][3][3][0]) == 5

================================================
FILE: test/test_losses.py
================================================
import numpy as np
import tensorflow as tf

from .. import losses


def test_cycle_consistency_loss_is_none_with_perfect_fakes(sess):
    batch_size, height, width, channels = [16, 2, 3, 1]

    tf.set_random_seed(0)

    images = tf.random_uniform((batch_size, height, width, channels), maxval=1)

    loss = losses.cycle_consistency_loss(
        real_images=images,
        generated_images=images,
    )

    assert sess.run(loss) == 0


def test_cycle_consistency_loss_is_positive_with_imperfect_fake_x(sess):
    batch_size, height, width, channels = [16, 2, 3, 1]

    tf.set_random_seed(0)

    real_images = tf.random_uniform(
        (batch_size, height, width, channels), maxval=1,
    )
    generated_images = real_images + 1

    loss = losses.cycle_consistency_loss(
        real_images=real_images,
        generated_images=generated_images,
    )

    assert sess.run(loss) == 1


def test_lsgan_loss_discrim_is_none_with_perfect_discrimination(sess):
    batch_size = 100
    prob_real_is_real = tf.ones((batch_size))
    prob_fake_is_real = tf.zeros((batch_size))

    loss = losses.lsgan_loss_discriminator(
        prob_real_is_real, prob_fake_is_real,
    )

    assert sess.run(loss) == 0


def test_lsgan_loss_discrim_is_positive_with_imperfect_discrimination(sess):
    batch_size = 100
    prob_real_is_real = tf.ones((batch_size)) * 0.4
    prob_fake_is_real = tf.ones((batch_size)) * 0.7

    loss = losses.lsgan_loss_discriminator(
        prob_real_is_real, prob_fake_is_real,
    )

    loss = sess.run(loss)
    np.testing.assert_almost_equal(loss, (0.6 * 0.6 + 0.7 * 0.7) / 2)

================================================
FILE: test/test_model.py
================================================
import tensorflow as tf

from dl_research.testing import slow

from .. import model


# -----------------------------------------------------------------------------

@slow
def test_output_sizes(sess):
    images_size = [
        model.BATCH_SIZE,
        model.IMG_HEIGHT,
        model.IMG_WIDTH,
        model.IMG_CHANNELS,
    ]

    pool_size = [
        model.POOL_SIZE,
        model.IMG_HEIGHT,
        model.IMG_WIDTH,
        model.IMG_CHANNELS,
    ]

    inputs = {
        'images_a': tf.ones(images_size),
        'images_b': tf.ones(images_size),
        'fake_pool_a': tf.ones(pool_size),
        'fake_pool_b': tf.ones(pool_size),
    }

    outputs = model.get_outputs(inputs)

    assert outputs['prob_real_a_is_real'].get_shape().as_list() == [
        model.BATCH_SIZE, 32, 32, 1,
    ]
    assert outputs['prob_real_b_is_real'].get_shape().as_list() == [
        model.BATCH_SIZE, 32, 32, 1,
    ]
    assert outputs['prob_fake_a_is_real'].get_shape().as_list() == [
        model.BATCH_SIZE, 32, 32, 1,
    ]
    assert outputs['prob_fake_b_is_real'].get_shape().as_list() == [
        model.BATCH_SIZE, 32, 32, 1,
    ]
    assert outputs['prob_fake_pool_a_is_real'].get_shape().as_list() == [
        model.POOL_SIZE, 32, 32, 1,
    ]
    assert outputs['prob_fake_pool_b_is_real'].get_shape().as_list() == [
        model.POOL_SIZE, 32, 32, 1,
    ]
    assert outputs['cycle_images_a'].get_shape().as_list() == images_size
    assert outputs['cycle_images_b'].get_shape().as_list() == images_size
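`CycleGAN.fake_image_pool` in `main.py` keeps a history of generated images, each stored with its attention mask. Once the history is full, it gives the discriminator an older sample instead of the newest one with probability 0.5. The sketch below reproduces that logic without TensorFlow so it can run on its own. The class name `FakeImagePool`, its `query` method and the injectable `rng` argument are hypothetical and are not part of the repository.

```python
import random


class FakeImagePool:
    """Hypothetical standalone version of CycleGAN.fake_image_pool.

    Entries are dicts {'im': ..., 'mask': ...}, as in main.py.
    """

    def __init__(self, pool_size=50, rng=None):
        self.pool_size = pool_size
        self.pool = []
        self.num_fakes = 0  # plays the role of self.num_fake_inputs
        self.rng = rng or random.Random()

    def query(self, fake, mask):
        entry = {'im': fake, 'mask': mask}
        if self.num_fakes < self.pool_size:
            # Pool not full yet: store the new sample and return it unchanged.
            self.pool.append(entry)
            result = entry
        elif self.rng.random() > 0.5:
            # With probability 0.5, return a random stored sample and
            # put the new sample in its slot.
            idx = self.rng.randint(0, self.pool_size - 1)
            result = self.pool[idx]
            self.pool[idx] = entry
        else:
            # Otherwise pass the new sample straight through.
            result = entry
        # The training loop increments its counter once per step.
        self.num_fakes += 1
        return result


pool = FakeImagePool(pool_size=2, rng=random.Random(0))
for step in range(5):
    print(step, pool.query("fake%d" % step, "mask%d" % step)['im'])
```

Passing a seeded `random.Random` makes the swap decisions reproducible, which helps when testing. The repository calls the module-level `random` functions directly.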
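In `train()` in `main.py`, the learning rate stays at `base_lr` for the first `int(max_step / 2)` epochs. After that it decays linearly towards zero at epoch `max_step`. The sketch below pulls the schedule out into a pure function so it can be inspected directly. The function name `learning_rate_at` is hypothetical, and its default arguments are the fallbacks that `main()` uses when the config omits them (`base_lr = 0.0002`, `max_step = 200`).

```python
def learning_rate_at(epoch, base_lr=0.0002, max_step=200):
    """Constant for the first half of training, then linear decay (as in train())."""
    half_training = int(max_step / 2)
    if epoch < half_training:
        return base_lr
    # base_lr at epoch == half_training, approaching 0 as epoch -> max_step.
    return base_lr - base_lr * (epoch - half_training) / half_training


for e in (0, 99, 100, 150, 199):
    print(e, learning_rate_at(e))
```

Keeping the schedule as a function of `epoch` alone means it gives the same value after resuming from a checkpoint, because `train()` restarts the loop at the restored `global_step`.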
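In `model.py`, `get_outputs` composites each translated image as `g(x) * mask + x * (1 - mask)`. Pixels outside the attended region are therefore copied unchanged from the input. Separately, `discriminator` binarises the mask with `mask >= transition_rate` before multiplying its input by it. `train()` sets `transition_rate` to 0 before the `--switch` epoch, so at first the discriminator sees the whole image. From that epoch on it uses `--threshold`. The NumPy sketch below mirrors these two steps on plain arrays. The function names are hypothetical. The shapes assume the NHWC layout used in `model.py`, with the single-channel mask tiled to three channels in the same way as `tf.concat([mask] * 3, axis=3)`.

```python
import numpy as np


def composite(generated, source, mask):
    """Blend as in get_outputs(): attended pixels from the generator, the rest from the input."""
    mask3 = np.concatenate([mask] * 3, axis=3)  # tile the 1-channel mask to RGB
    return generated * mask3 + source * (1.0 - mask3)


def discriminator_input(images, mask3, transition_rate):
    """Binarise the tiled mask at transition_rate and zero out everything below it."""
    hard_mask = (mask3 >= transition_rate).astype(np.float32)
    return images * hard_mask


rng = np.random.default_rng(0)
source = rng.uniform(-1, 1, size=(1, 4, 4, 3)).astype(np.float32)
generated = rng.uniform(-1, 1, size=(1, 4, 4, 3)).astype(np.float32)
mask = np.zeros((1, 4, 4, 1), dtype=np.float32)
mask[:, :2, :, :] = 0.9  # pretend the attention network selected the top half

blended = composite(generated, source, mask)
seen_by_d = discriminator_input(blended, np.concatenate([mask] * 3, axis=3), 0.1)
```

The sigmoid mask is always non-negative, so with `transition_rate = 0` the hard mask is all ones. This is why the discriminator input is the full image before the switch epoch.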
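`parse_args()` in `main.py` declares `--skip` with `type=bool`. argparse applies `bool()` to the raw command-line string, so any non-empty value parses as `True`, including `--skip False`. A common workaround is a small string-to-boolean converter. The helper `str2bool` below is a hypothetical addition and is not part of the repository.

```python
import argparse


def str2bool(value):
    """Map common textual spellings of a boolean to True/False for argparse."""
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('yes', 'true', 't', 'y', '1'):
        return True
    if lowered in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected, got %r' % value)


parser = argparse.ArgumentParser()
parser.add_argument('--skip', type=str2bool, default=False,
                    help='Whether to add skip connection between input and output.')

print(parser.parse_args(['--skip', 'False']).skip)
```

A value outside the accepted spellings raises `ArgumentTypeError`. argparse reports that as a usage error instead of silently treating the flag as `True`.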