[
  {
    "path": "LICENSE",
    "content": "MIT License\n\nCopyright (c) 2017 Harry Yang\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n"
  },
  {
    "path": "README.md",
    "content": "# Unsupervised Attention-guided Image-to-Image Translation\n\nThis repository contains the TensorFlow code for our NeurIPS 2018 paper [“Unsupervised Attention-guided Image-to-Image Translation”](https://arxiv.org/pdf/1806.02311.pdf). This code is based on the TensorFlow implementation of CycleGAN provided by [Harry Yang](https://github.com/leehomyc/cyclegan-1). You may need to train several times, as the quality of the results is sensitive to the initialization.\n\nBy leveraging attention, our architecture (shown in the figure below) only maps relevant areas of the image and, by doing so, further enhances the quality of image-to-image translation.\n\nOur model architecture is depicted below; please refer to the paper for more details: \n<img src='imgs/AGGANDiagram.jpg' width=\"900px\"/>\n\n## Mapping results\n### Our learned attention maps\n\nThe figure below displays automatically learned attention maps on various translation datasets:  \n<img src='imgs/attentionMaps.jpg' width=\"900px\"/>\n\n### Horse-to-Zebra image translation results: \n#### Horse-to-Zebra:\nThe top row in the figure below shows input images; the bottom row shows the mappings produced by our algorithm.\n<img src='imgs/HtZ.jpg' width=\"900px\"/>\n#### Zebra-to-Horse:\nThe top row in the figure below shows input images; the bottom row shows the mappings produced by our algorithm.\n<img src='imgs/ZtH.jpg' width=\"900px\"/>\n\n### Apple-to-Orange image translation results: \n#### Apple-to-Orange:\nThe top row in the figure below shows input images; the bottom row shows the mappings produced by our algorithm.\n<img src='imgs/AtO.jpg' width=\"900px\"/>\n#### Orange-to-Apple:\nThe top row in the figure below shows input images; the bottom row shows the mappings produced by our algorithm.\n<img src='imgs/OtA.jpg' width=\"900px\"/>\n\n## Getting Started with the code\n### Prepare dataset\n* You can either download one of the default CycleGAN datasets or use your own dataset. 
\n\t* Download a CycleGAN dataset (e.g. horse2zebra, apple2orange):\n\t\t```bash\n\t\tbash ./download_datasets.sh horse2zebra\n\t\t```\n\t* Use your own dataset: put the images from each domain in folder_a and folder_b, respectively. \n\n* Create the csv file as input to the data loader. \n\t* Edit the [```cyclegan_datasets.py```](cyclegan_datasets.py) file. For example, if you have a horse2zebra_train dataset that contains 1067 horse images and 1334 zebra images (both in JPG format), edit [```cyclegan_datasets.py```](cyclegan_datasets.py) as follows (the dataset size is the larger of the two domain sizes):\n\t\t```python\n\t\tDATASET_TO_SIZES = {\n\t\t  'horse2zebra_train': 1334\n\t\t}\n\n\t\tPATH_TO_CSV = {\n\t\t  'horse2zebra_train': './AGGAN/input/horse2zebra/horse2zebra_train.csv'\n\t\t}\n\n\t\tDATASET_TO_IMAGETYPE = {\n\t\t  'horse2zebra_train': '.jpg'\n\t\t}\n\t\t``` \n\t* Run create_cyclegan_dataset.py:\n\t\t```bash\n\t\tpython -m create_cyclegan_dataset --image_path_a='./input/horse2zebra/trainB' --image_path_b='./input/horse2zebra/trainA' --dataset_name=\"horse2zebra_train\" --do_shuffle=0\n\t\t```\n### Training\n* Create the configuration file. The configuration file contains basic information for training/testing. An example of the configuration file can be found at [```configs/exp_01.json```](configs/exp_01.json).\n\n* Start training:\n\t```bash\n\tpython main.py --to_train=1 --log_dir=./output/AGGAN/exp_01 --config_filename=./configs/exp_01.json\n\t```\n* Check the intermediate results:\n\t* Tensorboard\n\t\t```bash\n\t\ttensorboard --port=6006 --logdir=./output/AGGAN/exp_01/#timestamp# \n\t\t```\n\t* Check the HTML visualization at ./output/AGGAN/exp_01/#timestamp#/epoch_#id#.html.  
\n\n### Restoring from the previous checkpoint\n```bash\npython main.py --to_train=2 --log_dir=./output/AGGAN/exp_01 --config_filename=./configs/exp_01.json --checkpoint_dir=./output/AGGAN/exp_01/#timestamp#\n```\n\n### Testing\n* Create the testing dataset:\n\t* Edit the cyclegan_datasets.py file the same way as for training.\n\t* Create the csv file as the input to the data loader:\n\t\t```bash\n\t\tpython -m create_cyclegan_dataset --image_path_a='./input/horse2zebra/testB' --image_path_b='./input/horse2zebra/testA' --dataset_name=\"horse2zebra_test\" --do_shuffle=0\n\t\t```\n* Run testing:\n\t```bash\n\tpython main.py --to_train=0 --log_dir=./output/AGGAN/exp_01 --config_filename=./configs/exp_01_test.json --checkpoint_dir=./output/AGGAN/exp_01/#old_timestamp# \n\t```\n* Trained models:\nOur trained models can be downloaded from https://drive.google.com/open?id=1YEQMJK41KQj_-HfKFneSI12nWpTajgzT\n"
  },
  {
    "path": "Trained_models/README.md",
    "content": "#### Trained models\nOur trained models can be downloaded from https://drive.google.com/open?id=1YEQMJK41KQj_-HfKFneSI12nWpTajgzT\n\nWhen using the trained parameters for horse-to-zebra image translation, note that in our case the source (input_a) is zebra and the target (input_b) is horse, not the opposite.\n"
  },
  {
    "path": "__init__.py",
    "content": ""
  },
  {
    "path": "configs/exp_01.json",
    "content": "{\n  \"description\": \"Training configuration for the horse2zebra dataset.\", \n  \"pool_size\": 50,\n  \"base_lr\": 0.0001,\n  \"max_step\": 100,\n  \"network_version\": \"pytorch\",\n  \"dataset_name\": \"horse2zebra_train\",\n  \"do_flipping\": 1,\n  \"_LAMBDA_A\": 10,\n  \"_LAMBDA_B\": 10\n}"
  },
  {
    "path": "configs/exp_01_test.json",
    "content": "{\n  \"description\": \"Testing with trained model.\",\n  \"network_version\": \"pytorch\",\n  \"dataset_name\": \"horse2zebra_test\",\n  \"do_flipping\": 0\n}\n"
  },
  {
    "path": "configs/exp_02.json",
    "content": "{\n  \"description\": \"Training configuration for the apple2orange dataset.\", \n  \"pool_size\": 50,\n  \"base_lr\": 0.0001,\n  \"max_step\": 100,\n  \"network_version\": \"pytorch\",\n  \"dataset_name\": \"apple2orange_train\",\n  \"do_flipping\": 1,\n  \"_LAMBDA_A\": 10,\n  \"_LAMBDA_B\": 10\n}\n"
  },
  {
    "path": "configs/exp_02_test.json",
    "content": "{\n  \"description\": \"Testing with trained model.\",\n  \"network_version\": \"pytorch\",\n  \"dataset_name\": \"apple2orange_test\",\n  \"do_flipping\": 0\n}"
  },
  {
    "path": "configs/exp_04.json",
    "content": "{\n  \"description\": \"Training configuration for the lion2tiger dataset.\", \n  \"pool_size\": 50,\n  \"base_lr\": 0.0001,\n  \"max_step\": 100,\n  \"network_version\": \"pytorch\",\n  \"dataset_name\": \"lion2tiger_train\",\n  \"do_flipping\": 1,\n  \"_LAMBDA_A\": 10,\n  \"_LAMBDA_B\": 10\n}\n"
  },
  {
    "path": "configs/exp_04_test.json",
    "content": "{\n  \"description\": \"Testing with trained model.\",\n  \"network_version\": \"pytorch\",\n  \"dataset_name\": \"lion2tiger_test\",\n  \"do_flipping\": 0\n}\n"
  },
  {
    "path": "configs/exp_05.json",
    "content": "{\n  \"description\": \"Training configuration for the summer2winter_yosemite dataset.\", \n  \"pool_size\": 50,\n  \"base_lr\": 0.0002,\n  \"max_step\": 200,\n  \"network_version\": \"pytorch\",\n  \"dataset_name\": \"summer2winter_yosemite_train\",\n  \"do_flipping\": 1,\n  \"_LAMBDA_A\": 10,\n  \"_LAMBDA_B\": 10\n}\n"
  },
  {
    "path": "configs/exp_05_test.json",
    "content": "{\n  \"description\": \"Testing with trained model.\",\n  \"network_version\": \"pytorch\",\n  \"dataset_name\": \"summer2winter_yosemite_test\",\n  \"do_flipping\": 0\n}"
  },
  {
    "path": "create_cyclegan_dataset.py",
    "content": "\"\"\"Create datasets for training and testing.\"\"\"\nimport csv\nimport os\nimport random\n\nimport click\n\nimport cyclegan_datasets\n\n\ndef create_list(foldername, fulldir=True, suffix=\".jpg\"):\n    \"\"\"\n\n    :param foldername: The full path of the folder.\n    :param fulldir: Whether to return the full path or not.\n    :param suffix: Filter by suffix.\n\n    :return: The list of filenames in the folder with given suffix.\n\n    \"\"\"\n    file_list_tmp = os.listdir(foldername)\n    file_list = []\n    if fulldir:\n        for item in file_list_tmp:\n            if item.endswith(suffix):\n                file_list.append(os.path.join(foldername, item))\n    else:\n        for item in file_list_tmp:\n            if item.endswith(suffix):\n                file_list.append(item)\n    return file_list\n\n\n@click.command()\n@click.option('--image_path_a',\n              type=click.STRING,\n              default='./input/horse2zebra/trainA',\n              help='The path to the images from domain_a.')\n@click.option('--image_path_b',\n              type=click.STRING,\n              default='./input/horse2zebra/trainB',\n              help='The path to the images from domain_b.')\n@click.option('--dataset_name',\n              type=click.STRING,\n              default='horse2zebra_train',\n              help='The name of the dataset in cyclegan_dataset.')\n@click.option('--do_shuffle',\n              type=click.BOOL,\n              default=False,\n              help='Whether to shuffle images when creating the dataset.')\ndef create_dataset(image_path_a, image_path_b,\n                   dataset_name, do_shuffle):\n    list_a = create_list(image_path_a, True,\n                         cyclegan_datasets.DATASET_TO_IMAGETYPE[dataset_name])\n    list_b = create_list(image_path_b, True,\n                         cyclegan_datasets.DATASET_TO_IMAGETYPE[dataset_name])\n\n    output_path = cyclegan_datasets.PATH_TO_CSV[dataset_name]\n\n    num_rows = 
cyclegan_datasets.DATASET_TO_SIZES[dataset_name]\n    all_data_tuples = []\n    for i in range(num_rows):\n        all_data_tuples.append((\n            list_a[i % len(list_a)],\n            list_b[i % len(list_b)]\n        ))\n    if do_shuffle:\n        random.shuffle(all_data_tuples)\n    with open(output_path, 'w') as csv_file:\n        csv_writer = csv.writer(csv_file)\n        for data_tuple in all_data_tuples:\n            csv_writer.writerow(list(data_tuple))\n\n\nif __name__ == '__main__':\n    create_dataset()\n"
  },
  {
    "path": "cyclegan_datasets.py",
    "content": "\"\"\"Contains the standard train/test splits for the cyclegan data.\"\"\"\n\n\"\"\"The size of each dataset. Usually it is the maximum number of images from\neach domain.\"\"\"\nDATASET_TO_SIZES = {\n    'horse2zebra_train': 1334,\n    'horse2zebra_test': 140,\n    'apple2orange_train': 1019,\n    'apple2orange_test': 266,\n    'lion2tiger_train': 916,\n    'lion2tiger_test': 103,\n    'summer2winter_yosemite_train': 1231,\n    'summer2winter_yosemite_test': 309,\n}\n\n\"\"\"The image types of each dataset. Currently only supports .jpg or .png\"\"\"\nDATASET_TO_IMAGETYPE = {\n    'horse2zebra_train': '.jpg',\n    'horse2zebra_test': '.jpg',\n    'apple2orange_train': '.jpg',\n    'apple2orange_test': '.jpg',\n    'lion2tiger_train': '.jpg',\n    'lion2tiger_test': '.jpg',\n    'summer2winter_yosemite_train': '.jpg',\n    'summer2winter_yosemite_test': '.jpg',\n}\n\n\"\"\"The path to the output csv file.\"\"\"\nPATH_TO_CSV = {\n    'horse2zebra_train': './input/horse2zebra/horse2zebra_train.csv',\n    'horse2zebra_test': './input/horse2zebra/horse2zebra_test.csv',\n    'apple2orange_train': './input/apple2orange/apple2orange_train.csv',\n    'apple2orange_test': './input/apple2orange/apple2orange_test.csv',\n    'lion2tiger_train': './input/lion2tiger/lion2tiger_train.csv',\n    'lion2tiger_test': './input/lion2tiger/lion2tiger_test.csv',\n    'summer2winter_yosemite_train': './input/summer2winter_yosemite/summer2winter_yosemite_train.csv',\n    'summer2winter_yosemite_test': './input/summer2winter_yosemite/summer2winter_yosemite_test.csv'\n}\n"
  },
  {
    "path": "data_loader.py",
    "content": "import tensorflow as tf\n\nimport cyclegan_datasets\nimport model\n\n\ndef _load_samples(csv_name, image_type):\n    filename_queue = tf.train.string_input_producer(\n        [csv_name])\n\n    reader = tf.TextLineReader()\n    _, csv_filename = reader.read(filename_queue)\n\n    record_defaults = [tf.constant([], dtype=tf.string),\n                       tf.constant([], dtype=tf.string)]\n\n    filename_i, filename_j = tf.decode_csv(\n        csv_filename, record_defaults=record_defaults)\n\n    file_contents_i = tf.read_file(filename_i)\n    file_contents_j = tf.read_file(filename_j)\n    if image_type == '.jpg':\n        image_decoded_A = tf.image.decode_jpeg(\n            file_contents_i, channels=model.IMG_CHANNELS)\n        image_decoded_B = tf.image.decode_jpeg(\n            file_contents_j, channels=model.IMG_CHANNELS)\n    elif image_type == '.png':\n        image_decoded_A = tf.image.decode_png(\n            file_contents_i, channels=model.IMG_CHANNELS, dtype=tf.uint8)\n        image_decoded_B = tf.image.decode_png(\n            file_contents_j, channels=model.IMG_CHANNELS, dtype=tf.uint8)\n    else:\n        raise ValueError('Unsupported image type %s; only .jpg and .png '\n                         'are supported.' % image_type)\n\n    return image_decoded_A, image_decoded_B\n\n\ndef load_data(dataset_name, image_size_before_crop,\n              do_shuffle=True, do_flipping=False):\n    \"\"\"\n\n    :param dataset_name: The name of the dataset.\n    :param image_size_before_crop: Resize to this size before random cropping.\n    :param do_shuffle: Shuffle switch.\n    :param do_flipping: Flip switch.\n    :return: A dict containing the batched and preprocessed image tensors.\n    \"\"\"\n    if dataset_name not in cyclegan_datasets.DATASET_TO_SIZES:\n        raise ValueError('split name %s was not recognized.'\n                         % dataset_name)\n\n    csv_name = cyclegan_datasets.PATH_TO_CSV[dataset_name]\n\n    image_i, image_j = _load_samples(\n        csv_name, cyclegan_datasets.DATASET_TO_IMAGETYPE[dataset_name])\n    inputs = {\n        'image_i': image_i,\n        'image_j': image_j\n    }\n\n    # Preprocessing:\n    
inputs['image_i'] = tf.image.resize_images(\n        inputs['image_i'], [image_size_before_crop, image_size_before_crop])\n    inputs['image_j'] = tf.image.resize_images(\n        inputs['image_j'], [image_size_before_crop, image_size_before_crop])\n\n    # Use truthiness rather than `is True` so integer flags from the\n    # config (e.g. do_flipping: 1) also enable flipping.\n    if do_flipping:\n        inputs['image_i'] = tf.image.random_flip_left_right(inputs['image_i'], seed=1)\n        inputs['image_j'] = tf.image.random_flip_left_right(inputs['image_j'], seed=1)\n\n    inputs['image_i'] = tf.random_crop(\n        inputs['image_i'], [model.IMG_HEIGHT, model.IMG_WIDTH, 3], seed=1)\n    inputs['image_j'] = tf.random_crop(\n        inputs['image_j'], [model.IMG_HEIGHT, model.IMG_WIDTH, 3], seed=1)\n\n    # Scale pixel values from [0, 255] to [-1, 1].\n    inputs['image_i'] = tf.subtract(tf.div(inputs['image_i'], 127.5), 1)\n    inputs['image_j'] = tf.subtract(tf.div(inputs['image_j'], 127.5), 1)\n\n    # Batch\n    if do_shuffle:\n        inputs['images_i'], inputs['images_j'] = tf.train.shuffle_batch(\n            [inputs['image_i'], inputs['image_j']], 1, 5000, 100, seed=1)\n    else:\n        inputs['images_i'], inputs['images_j'] = tf.train.batch(\n            [inputs['image_i'], inputs['image_j']], 1)\n\n    return inputs\n"
  },
  {
    "path": "download_datasets.sh",
    "content": "#!/bin/bash\nFILE=$1\n\nif [[ $FILE != \"apple2orange\" && $FILE != \"summer2winter_yosemite\" && $FILE != \"horse2zebra\" && $FILE != \"monet2photo\" && $FILE != \"cezanne2photo\" && $FILE != \"ukiyoe2photo\" && $FILE != \"vangogh2photo\" && $FILE != \"maps\" && $FILE != \"cityscapes\" && $FILE != \"facades\" && $FILE != \"iphone2dslr_flower\" && $FILE != \"ae_photos\" ]]; then\n    echo \"Available datasets are: apple2orange, summer2winter_yosemite, horse2zebra, monet2photo, cezanne2photo, ukiyoe2photo, vangogh2photo, maps, cityscapes, facades, iphone2dslr_flower, ae_photos\"\n    exit 1\nfi\n\nmkdir -p ./input\nURL=https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/$FILE.zip\nZIP_FILE=./input/$FILE.zip\nTARGET_DIR=./input/$FILE/\nwget -N \"$URL\" -O \"$ZIP_FILE\"\nmkdir -p \"$TARGET_DIR\"\nunzip \"$ZIP_FILE\" -d ./input/\nrm \"$ZIP_FILE\"\n"
  },
  {
    "path": "layers.py",
    "content": "import tensorflow as tf\n\n\ndef lrelu(x, leak=0.2, name=\"lrelu\", alt_relu_impl=False):\n\n    with tf.variable_scope(name):\n        if alt_relu_impl:\n            f1 = 0.5 * (1 + leak)\n            f2 = 0.5 * (1 - leak)\n            return f1 * x + f2 * abs(x)\n        else:\n            return tf.maximum(x, leak * x)\n\n\ndef instance_norm(x):\n\n    with tf.variable_scope(\"instance_norm\"):\n        epsilon = 1e-5\n        mean, var = tf.nn.moments(x, [1, 2], keep_dims=True)\n        scale = tf.get_variable('scale', [x.get_shape()[-1]],\n                                initializer=tf.truncated_normal_initializer(\n                                    mean=1.0, stddev=0.02\n        ))\n        offset = tf.get_variable(\n            'offset', [x.get_shape()[-1]],\n            initializer=tf.constant_initializer(0.0)\n        )\n        out = scale * tf.div(x - mean, tf.sqrt(var + epsilon)) + offset\n\n        return out\n\ndef instance_norm_bis(x, mask):\n    # NOTE: the mask argument is currently unused; this variant falls back\n    # to standard instance normalization over the spatial dimensions.\n    with tf.variable_scope(\"instance_norm\"):\n        epsilon = 1e-5\n        mean, var = tf.nn.moments(x, [1, 2], keep_dims=True)\n        scale = tf.get_variable('scale', [x.get_shape()[-1]],\n                                initializer=tf.truncated_normal_initializer(\n                                    mean=1.0, stddev=0.02\n        ))\n        offset = tf.get_variable(\n            'offset', [x.get_shape()[-1]],\n            initializer=tf.constant_initializer(0.0)\n        )\n        out = scale * tf.div(x - mean, tf.sqrt(var + epsilon)) + offset\n\n        return out\n\n\ndef general_conv2d_(inputconv, o_d=64, f_h=7, f_w=7, s_h=1, s_w=1, stddev=0.02,\n                   padding=\"VALID\", name=\"conv2d\", do_norm=True, 
do_relu=True,\n                   relufactor=0):\n    with tf.variable_scope(name):\n\n        conv = tf.contrib.layers.conv2d(\n            inputconv, o_d, f_w, s_w, padding,\n            activation_fn=None,\n            weights_initializer=tf.truncated_normal_initializer(\n                stddev=stddev\n            ),\n            biases_initializer=tf.constant_initializer(0.0)\n        )\n        if do_norm:\n            conv = instance_norm(conv)\n\n        if do_relu:\n            if(relufactor == 0):\n                conv = tf.nn.relu(conv, \"relu\")\n            else:\n                conv = lrelu(conv, relufactor, \"lrelu\")\n\n        return conv\n\ndef general_conv2d(inputconv, do_norm, o_d=64, f_h=7, f_w=7, s_h=1, s_w=1, stddev=0.02,\n                   padding=\"VALID\", name=\"conv2d\", do_relu=True,\n                   relufactor=0):\n    with tf.variable_scope(name):\n        conv = tf.contrib.layers.conv2d(\n            inputconv, o_d, f_w, s_w, padding,\n            activation_fn=None,\n            weights_initializer=tf.truncated_normal_initializer(\n                stddev=stddev\n            ),\n            biases_initializer=tf.constant_initializer(0.0)\n        )\n\n        conv = tf.cond(do_norm, lambda: instance_norm(conv), lambda: conv)\n\n\n        if do_relu:\n            if(relufactor == 0):\n                conv = tf.nn.relu(conv, \"relu\")\n            else:\n                conv = lrelu(conv, relufactor, \"lrelu\")\n\n\n        return conv\n\ndef general_deconv2d(inputconv, outshape, o_d=64, f_h=7, f_w=7, s_h=1, s_w=1,\n                     stddev=0.02, padding=\"VALID\", name=\"deconv2d\",\n                     do_norm=True, do_relu=True, relufactor=0):\n    with tf.variable_scope(name):\n\n        conv = tf.contrib.layers.conv2d_transpose(\n            inputconv, o_d, [f_h, f_w],\n            [s_h, s_w], padding,\n            activation_fn=None,\n            weights_initializer=tf.truncated_normal_initializer(stddev=stddev),\n        
    biases_initializer=tf.constant_initializer(0.0)\n        )\n\n        if do_norm:\n            conv = instance_norm(conv)\n\n        if do_relu:\n            if(relufactor == 0):\n                conv = tf.nn.relu(conv, \"relu\")\n            else:\n                conv = lrelu(conv, relufactor, \"lrelu\")\n\n        return conv\n\ndef upsamplingDeconv(inputconv, size, is_scale, method, align_corners, name):\n    if len(inputconv.get_shape()) == 3:\n        if is_scale:\n            size_h = size[0] * int(inputconv.get_shape()[0])\n            size_w = size[1] * int(inputconv.get_shape()[1])\n            size = [int(size_h), int(size_w)]\n    elif len(inputconv.get_shape()) == 4:\n        if is_scale:\n            size_h = size[0] * int(inputconv.get_shape()[1])\n            size_w = size[1] * int(inputconv.get_shape()[2])\n            size = [int(size_h), int(size_w)]\n    else:\n        raise Exception(\"Do not support shape %s\" % inputconv.get_shape())\n    print(\"  [TL] UpSampling2dLayer %s: is_scale:%s size:%s method:%d align_corners:%s\" %\n          (name, is_scale, size, method, align_corners))\n    with tf.variable_scope(name) as vs:\n        try:\n            out = tf.image.resize_images(inputconv, size=size, method=method, align_corners=align_corners)\n        except:  # for TF 0.10\n            out = tf.image.resize_images(inputconv, new_height=size[0], new_width=size[1], method=method,\n                                                  align_corners=align_corners)\n    return out\n\ndef general_fc_layers(inpfc, outshape, name):\n    with tf.variable_scope(name):\n\n        fcw = tf.Variable(tf.truncated_normal(outshape,\n                                               dtype=tf.float32,\n                                               stddev=1e-1), name='weights')\n        fcb = tf.Variable(tf.constant(1.0, shape=[outshape[-1]], dtype=tf.float32),\n                           trainable=True, name='biases')\n\n        fcl = 
tf.nn.bias_add(tf.matmul(inpfc, fcw), fcb)\n        fc_out = tf.nn.relu(fcl)\n\n        return fc_out\n"
  },
  {
    "path": "losses.py",
    "content": "\"\"\"Contains losses used for performing image-to-image domain adaptation.\"\"\"\nimport tensorflow as tf\n\n\ndef cycle_consistency_loss(real_images, generated_images):\n    \"\"\"Compute the cycle consistency loss.\n\n    The cycle consistency loss is defined as the sum of the L1 distances\n    between the real images from each domain and their generated (fake)\n    counterparts.\n\n    This definition is derived from Equation 2 in:\n        Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial\n        Networks.\n        Jun-Yan Zhu, Taesung Park, Phillip Isola, Alexei A. Efros.\n\n\n    Args:\n        real_images: A batch of images from domain X, a `Tensor` of shape\n            [batch_size, height, width, channels].\n        generated_images: A batch of generated images made to look like they\n            came from domain X, a `Tensor` of shape\n            [batch_size, height, width, channels].\n\n    Returns:\n        The cycle consistency loss.\n    \"\"\"\n    return tf.reduce_mean(tf.abs(real_images - generated_images))\n\n\ndef mask_loss(gen_image, mask):\n    \"\"\"Penalizes generated content outside the attention mask.\n\n    Computes the mean L1 norm of the generated image restricted to the\n    region where the mask is zero.\n    \"\"\"\n    return tf.reduce_mean(tf.abs(tf.multiply(gen_image, 1 - mask)))\n\ndef lsgan_loss_generator(prob_fake_is_real):\n    \"\"\"Computes the LS-GAN loss as minimized by the generator.\n\n    Rather than compute the negative loglikelihood, a least-squares loss is\n    used to optimize the discriminators as per Equation 2 in:\n        Least Squares Generative Adversarial Networks\n        Xudong Mao, Qing Li, Haoran Xie, Raymond Y.K. 
Lau, Zhen Wang, and\n        Stephen Paul Smolley.\n        https://arxiv.org/pdf/1611.04076.pdf\n\n    Args:\n        prob_fake_is_real: The discriminator's estimate that generated images\n            made to look like real images are real.\n\n    Returns:\n        The total LS-GAN loss.\n    \"\"\"\n    return tf.reduce_mean(tf.squared_difference(prob_fake_is_real, 1))\n\n\ndef lsgan_loss_discriminator(prob_real_is_real, prob_fake_is_real):\n    \"\"\"Computes the LS-GAN loss as minimized by the discriminator.\n\n    Rather than compute the negative loglikelihood, a least-squares loss is\n    used to optimize the discriminators as per Equation 2 in:\n        Least Squares Generative Adversarial Networks\n        Xudong Mao, Qing Li, Haoran Xie, Raymond Y.K. Lau, Zhen Wang, and\n        Stephen Paul Smolley.\n        https://arxiv.org/pdf/1611.04076.pdf\n\n    Args:\n        prob_real_is_real: The discriminator's estimate that images actually\n            drawn from the real domain are in fact real.\n        prob_fake_is_real: The discriminator's estimate that generated images\n            made to look like real images are real.\n\n    Returns:\n        The total LS-GAN loss.\n    \"\"\"\n    return (tf.reduce_mean(tf.squared_difference(prob_real_is_real, 1)) +\n            tf.reduce_mean(tf.squared_difference(prob_fake_is_real, 0))) * 0.5\n"
  },
  {
    "path": "main.py",
    "content": "\"\"\"Code for training CycleGAN.\"\"\"\r\nfrom datetime import datetime\r\nimport json\r\nimport numpy as np\r\nimport os\r\nimport random\r\nfrom scipy.misc import imsave\r\n\r\nimport argparse\r\nimport tensorflow as tf\r\n\r\nimport cyclegan_datasets\r\nimport data_loader, losses, model\r\n\r\ntf.set_random_seed(1)\r\nnp.random.seed(0)\r\nslim = tf.contrib.slim\r\n\r\n\r\nclass CycleGAN:\r\n    \"\"\"The CycleGAN module.\"\"\"\r\n\r\n    def __init__(self, pool_size, lambda_a,\r\n                 lambda_b, output_root_dir, to_restore,\r\n                 base_lr, max_step, network_version,\r\n                 dataset_name, checkpoint_dir, do_flipping, skip, switch, threshold_fg):\r\n        current_time = datetime.now().strftime(\"%Y%m%d-%H%M%S\")\r\n\r\n        self._pool_size = pool_size\r\n        self._size_before_crop = 286\r\n        self._switch = switch\r\n        self._threshold_fg = threshold_fg\r\n        self._lambda_a = lambda_a\r\n        self._lambda_b = lambda_b\r\n        self._output_dir = os.path.join(output_root_dir, current_time +\r\n                                        '_switch'+str(switch)+'_thres_'+str(threshold_fg))\r\n        self._images_dir = os.path.join(self._output_dir, 'imgs')\r\n        self._num_imgs_to_save = 20\r\n        self._to_restore = to_restore\r\n        self._base_lr = base_lr\r\n        self._max_step = max_step\r\n        self._network_version = network_version\r\n        self._dataset_name = dataset_name\r\n        self._checkpoint_dir = checkpoint_dir\r\n        self._do_flipping = do_flipping\r\n        self._skip = skip\r\n\r\n        self.fake_images_A = []\r\n        self.fake_images_B = []\r\n\r\n    def model_setup(self):\r\n        \"\"\"\r\n        This function sets up the model to train.\r\n\r\n        self.input_A/self.input_B -> Set of training images.\r\n        self.fake_A/self.fake_B -> Generated images by corresponding generator\r\n        of input_A and input_B\r\n        
self.lr -> Learning rate variable\r\n        self.cyc_A/self.cyc_B -> Images generated after feeding\r\n        self.fake_A/self.fake_B to the corresponding generator.\r\n        These are used to calculate the cyclic loss.\r\n        \"\"\"\r\n        self.input_a = tf.placeholder(\r\n            tf.float32, [\r\n                1,\r\n                model.IMG_WIDTH,\r\n                model.IMG_HEIGHT,\r\n                model.IMG_CHANNELS\r\n            ], name=\"input_A\")\r\n        self.input_b = tf.placeholder(\r\n            tf.float32, [\r\n                1,\r\n                model.IMG_WIDTH,\r\n                model.IMG_HEIGHT,\r\n                model.IMG_CHANNELS\r\n            ], name=\"input_B\")\r\n\r\n        self.fake_pool_A = tf.placeholder(\r\n            tf.float32, [\r\n                None,\r\n                model.IMG_WIDTH,\r\n                model.IMG_HEIGHT,\r\n                model.IMG_CHANNELS\r\n            ], name=\"fake_pool_A\")\r\n        self.fake_pool_B = tf.placeholder(\r\n            tf.float32, [\r\n                None,\r\n                model.IMG_WIDTH,\r\n                model.IMG_HEIGHT,\r\n                model.IMG_CHANNELS\r\n            ], name=\"fake_pool_B\")\r\n        self.fake_pool_A_mask = tf.placeholder(\r\n            tf.float32, [\r\n                None,\r\n                model.IMG_WIDTH,\r\n                model.IMG_HEIGHT,\r\n                model.IMG_CHANNELS\r\n            ], name=\"fake_pool_A_mask\")\r\n        self.fake_pool_B_mask = tf.placeholder(\r\n            tf.float32, [\r\n                None,\r\n                model.IMG_WIDTH,\r\n                model.IMG_HEIGHT,\r\n                model.IMG_CHANNELS\r\n            ], name=\"fake_pool_B_mask\")\r\n\r\n        self.global_step = slim.get_or_create_global_step()\r\n\r\n        self.num_fake_inputs = 0\r\n\r\n        self.learning_rate = tf.placeholder(tf.float32, shape=[], name=\"lr\")\r\n        self.transition_rate = tf.placeholder(tf.float32, 
shape=[], name=\"tr\")\r\n        self.donorm = tf.placeholder(tf.bool, shape=[], name=\"donorm\")\r\n\r\n        inputs = {\r\n            'images_a': self.input_a,\r\n            'images_b': self.input_b,\r\n            'fake_pool_a': self.fake_pool_A,\r\n            'fake_pool_b': self.fake_pool_B,\r\n            'fake_pool_a_mask': self.fake_pool_A_mask,\r\n            'fake_pool_b_mask': self.fake_pool_B_mask,\r\n            'transition_rate': self.transition_rate,\r\n            'donorm': self.donorm,\r\n        }\r\n\r\n        outputs = model.get_outputs(\r\n            inputs, skip=self._skip)\r\n\r\n        self.prob_real_a_is_real = outputs['prob_real_a_is_real']\r\n        self.prob_real_b_is_real = outputs['prob_real_b_is_real']\r\n        self.fake_images_a = outputs['fake_images_a']\r\n        self.fake_images_b = outputs['fake_images_b']\r\n        self.prob_fake_a_is_real = outputs['prob_fake_a_is_real']\r\n        self.prob_fake_b_is_real = outputs['prob_fake_b_is_real']\r\n\r\n        self.cycle_images_a = outputs['cycle_images_a']\r\n        self.cycle_images_b = outputs['cycle_images_b']\r\n\r\n        self.prob_fake_pool_a_is_real = outputs['prob_fake_pool_a_is_real']\r\n        self.prob_fake_pool_b_is_real = outputs['prob_fake_pool_b_is_real']\r\n        self.masks = outputs['masks']\r\n        self.masked_gen_ims = outputs['masked_gen_ims']\r\n        self.masked_ims = outputs['masked_ims']\r\n        self.masks_ = outputs['mask_tmp']\r\n\r\n    def compute_losses(self):\r\n        \"\"\"\r\n        In this function we are defining the variables for loss calculations\r\n        and training model.\r\n\r\n        d_loss_A/d_loss_B -> loss for discriminator A/B\r\n        g_loss_A/g_loss_B -> loss for generator A/B\r\n        *_trainer -> Various trainer for above loss functions\r\n        *_summ -> Summary variables for above loss functions\r\n        \"\"\"\r\n\r\n\r\n        cycle_consistency_loss_a = \\\r\n            self._lambda_a * 
losses.cycle_consistency_loss(\r\n                real_images=self.input_a, generated_images=self.cycle_images_a,\r\n            )\r\n        cycle_consistency_loss_b = \\\r\n            self._lambda_b * losses.cycle_consistency_loss(\r\n                real_images=self.input_b, generated_images=self.cycle_images_b,\r\n            )\r\n\r\n        lsgan_loss_a = losses.lsgan_loss_generator(self.prob_fake_a_is_real)\r\n        lsgan_loss_b = losses.lsgan_loss_generator(self.prob_fake_b_is_real)\r\n\r\n        g_loss_A = \\\r\n            cycle_consistency_loss_a + cycle_consistency_loss_b + lsgan_loss_b\r\n        g_loss_B = \\\r\n            cycle_consistency_loss_b + cycle_consistency_loss_a + lsgan_loss_a\r\n\r\n        d_loss_A = losses.lsgan_loss_discriminator(\r\n            prob_real_is_real=self.prob_real_a_is_real,\r\n            prob_fake_is_real=self.prob_fake_pool_a_is_real,\r\n        )\r\n        d_loss_B = losses.lsgan_loss_discriminator(\r\n            prob_real_is_real=self.prob_real_b_is_real,\r\n            prob_fake_is_real=self.prob_fake_pool_b_is_real,\r\n        )\r\n\r\n        optimizer = tf.train.AdamOptimizer(self.learning_rate, beta1=0.5)\r\n        self.model_vars = tf.trainable_variables()\r\n\r\n        d_A_vars = [var for var in self.model_vars if 'd_A' in var.name]\r\n        g_A_vars = [var for var in self.model_vars if 'g_A/' in var.name]\r\n        d_B_vars = [var for var in self.model_vars if 'd_B' in var.name]\r\n        g_B_vars = [var for var in self.model_vars if 'g_B/' in var.name]\r\n        g_Ae_vars = [var for var in self.model_vars if 'g_A_ae' in var.name]\r\n        g_Be_vars = [var for var in self.model_vars if 'g_B_ae' in var.name]\r\n\r\n\r\n        self.g_A_trainer = optimizer.minimize(g_loss_A, var_list=g_A_vars+g_Ae_vars)\r\n        self.g_B_trainer = optimizer.minimize(g_loss_B, var_list=g_B_vars+g_Be_vars)\r\n        self.g_A_trainer_bis = optimizer.minimize(g_loss_A, var_list=g_A_vars)\r\n        
self.g_B_trainer_bis = optimizer.minimize(g_loss_B, var_list=g_B_vars)\r\n        self.d_A_trainer = optimizer.minimize(d_loss_A, var_list=d_A_vars)\r\n        self.d_B_trainer = optimizer.minimize(d_loss_B, var_list=d_B_vars)\r\n\r\n        self.params_ae_c1 = g_A_vars[0]\r\n        self.params_ae_c1_B = g_B_vars[0]\r\n        for var in self.model_vars:\r\n            print(var.name)\r\n\r\n        # Summary variables for tensorboard\r\n        self.g_A_loss_summ = tf.summary.scalar(\"g_A_loss\", g_loss_A)\r\n        self.g_B_loss_summ = tf.summary.scalar(\"g_B_loss\", g_loss_B)\r\n        self.d_A_loss_summ = tf.summary.scalar(\"d_A_loss\", d_loss_A)\r\n        self.d_B_loss_summ = tf.summary.scalar(\"d_B_loss\", d_loss_B)\r\n\r\n    def save_images(self, sess, epoch, curr_tr):\r\n        \"\"\"\r\n        Saves input and output images.\r\n\r\n        :param sess: The session.\r\n        :param epoch: Current epoch.\r\n        :param curr_tr: Current transition rate.\r\n        \"\"\"\r\n        if not os.path.exists(self._images_dir):\r\n            os.makedirs(self._images_dir)\r\n\r\n        if curr_tr > 0:\r\n            donorm = False\r\n        else:\r\n            donorm = True\r\n\r\n        names = ['inputA_', 'inputB_', 'fakeA_',\r\n                 'fakeB_', 'cycA_', 'cycB_',\r\n                 'mask_a', 'mask_b']\r\n\r\n        with open(os.path.join(\r\n                self._output_dir, 'epoch_' + str(epoch) + '.html'\r\n        ), 'w') as v_html:\r\n            for i in range(0, self._num_imgs_to_save):\r\n                print(\"Saving image {}/{}\".format(i, self._num_imgs_to_save))\r\n                inputs = sess.run(self.inputs)\r\n                fake_A_temp, fake_B_temp, cyc_A_temp, cyc_B_temp, masks = sess.run([\r\n                    self.fake_images_a,\r\n                    self.fake_images_b,\r\n                    self.cycle_images_a,\r\n                    self.cycle_images_b,\r\n                    self.masks,\r\n                ], feed_dict={\r\n                    self.input_a: 
inputs['images_i'],\r\n                    self.input_b: inputs['images_j'],\r\n                    self.transition_rate: curr_tr,\r\n                    self.donorm: donorm,\r\n                })\r\n\r\n                tensors = [inputs['images_i'], inputs['images_j'],\r\n                           fake_B_temp, fake_A_temp, cyc_A_temp, cyc_B_temp, masks[0], masks[1]]\r\n\r\n                for name, tensor in zip(names, tensors):\r\n                    image_name = name + str(epoch) + \"_\" + str(i) + \".jpg\"\r\n                    if 'mask_' in name:\r\n                        imsave(os.path.join(self._images_dir, image_name),\r\n                               (np.squeeze(tensor[0]))\r\n                               )\r\n                    else:\r\n\r\n                        imsave(os.path.join(self._images_dir, image_name),\r\n                               ((np.squeeze(tensor[0]) + 1) * 127.5).astype(np.uint8)\r\n                               )\r\n                    v_html.write(\r\n                        \"<img src=\\\"\" +\r\n                        os.path.join('imgs', image_name) + \"\\\">\"\r\n                    )\r\n                v_html.write(\"<br>\")\r\n\r\n    def save_images_bis(self, sess, epoch):\r\n        \"\"\"\r\n        Saves input and output images.\r\n\r\n        :param sess: The session.\r\n        :param epoch: Current epoch.\r\n        \"\"\"\r\n        if not os.path.exists(self._images_dir):\r\n            os.makedirs(self._images_dir)\r\n\r\n        names = ['input_A_', 'mask_A_', 'masked_inputA_', 'fakeB_',\r\n                 'input_B_', 'mask_B_', 'masked_inputB_', 'fakeA_']\r\n\r\n        space = '&nbsp &nbsp &nbsp &nbsp &nbsp &nbsp &nbsp &nbsp &nbsp &nbsp &nbsp &nbsp ' \\\r\n                '&nbsp &nbsp &nbsp &nbsp &nbsp &nbsp &nbsp &nbsp &nbsp &nbsp &nbsp &nbsp ' \\\r\n                '&nbsp &nbsp &nbsp &nbsp &nbsp'\r\n        with open(os.path.join(self._output_dir, 'results_' + str(epoch) + '.html'), 'w') as v_html:\r\n 
           v_html.write(\"<b>INPUT\" + space + \"MASK\" + space + \"MASKED_IMAGE\" + space + \"GENERATED_IMAGE</b>\")\r\n            v_html.write(\"<br>\")\r\n            for i in range(0, self._num_imgs_to_save):\r\n                print(\"Saving image {}/{}\".format(i, self._num_imgs_to_save))\r\n                inputs = sess.run(self.inputs)\r\n                fake_A_temp, fake_B_temp, masks, masked_ims = sess.run([\r\n                    self.fake_images_a,\r\n                    self.fake_images_b,\r\n                    self.masks,\r\n                    self.masked_ims\r\n                ], feed_dict={\r\n                    self.input_a: inputs['images_i'],\r\n                    self.input_b: inputs['images_j'],\r\n                    self.transition_rate: 0.1\r\n                })\r\n                tensors = [inputs['images_i'], masks[0], masked_ims[0], fake_B_temp,\r\n                           inputs['images_j'], masks[1], masked_ims[1], fake_A_temp]\r\n\r\n                for name, tensor in zip(names, tensors):\r\n                    image_name = name + str(i) + \".jpg\"\r\n\r\n                    if 'mask_' in name:\r\n                        imsave(os.path.join(self._images_dir, image_name),\r\n                               (np.squeeze(tensor[0]))\r\n                               )\r\n                    else:\r\n\r\n                        imsave(os.path.join(self._images_dir, image_name),\r\n                               ((np.squeeze(tensor[0]) + 1) * 127.5).astype(np.uint8)\r\n                               )\r\n\r\n                    v_html.write(\r\n                        \"<img src=\\\"\" +\r\n                        os.path.join('imgs', image_name) + \"\\\">\"\r\n                    )\r\n\r\n                    if 'fakeB_' in name:\r\n                        v_html.write(\"<br>\")\r\n                v_html.write(\"<br>\")\r\n\r\n    def fake_image_pool(self, num_fakes, fake, mask, fake_pool):\r\n        \"\"\"\r\n        This function 
saves the generated image to the corresponding\r\n        pool of images.\r\n\r\n        It keeps filling the pool until it is full, then randomly\r\n        selects an already stored image and replaces it with the new one.\r\n        \"\"\"\r\n        tmp = {}\r\n        tmp['im'] = fake\r\n        tmp['mask'] = mask\r\n        if num_fakes < self._pool_size:\r\n            fake_pool.append(tmp)\r\n            return tmp\r\n        else:\r\n            p = random.random()\r\n            if p > 0.5:\r\n                random_id = random.randint(0, self._pool_size - 1)\r\n                temp = fake_pool[random_id]\r\n                fake_pool[random_id] = tmp\r\n                return temp\r\n            else:\r\n                return tmp\r\n\r\n    def train(self):\r\n        \"\"\"Training Function.\"\"\"\r\n        # Load Dataset from the dataset folder\r\n        self.inputs = data_loader.load_data(\r\n            self._dataset_name, self._size_before_crop,\r\n            False, self._do_flipping)\r\n\r\n        # Build the network\r\n        self.model_setup()\r\n\r\n        # Loss function calculations\r\n        self.compute_losses()\r\n\r\n        # Initializing the global variables\r\n        init = (tf.global_variables_initializer(),\r\n                tf.local_variables_initializer())\r\n\r\n        saver = tf.train.Saver(max_to_keep=None)\r\n\r\n        max_images = cyclegan_datasets.DATASET_TO_SIZES[self._dataset_name]\r\n        half_training = int(self._max_step / 2)\r\n        with tf.Session() as sess:\r\n            sess.run(init)\r\n            # Restore the model to run the model from last checkpoint\r\n            if self._to_restore:\r\n                chkpt_fname = tf.train.latest_checkpoint(self._checkpoint_dir)\r\n                saver.restore(sess, chkpt_fname)\r\n\r\n            writer = tf.summary.FileWriter(self._output_dir)\r\n\r\n            if not os.path.exists(self._output_dir):\r\n                os.makedirs(self._output_dir)\r\n\r\n 
           coord = tf.train.Coordinator()\r\n            threads = tf.train.start_queue_runners(coord=coord)\r\n\r\n            # Training Loop\r\n            for epoch in range(sess.run(self.global_step), self._max_step):\r\n                print(\"In the epoch \", epoch)\r\n                saver.save(sess, os.path.join(\r\n                    self._output_dir, \"AGGAN\"), global_step=epoch)\r\n\r\n                # Dealing with the learning rate as per the epoch number\r\n                if epoch < half_training:\r\n                    curr_lr = self._base_lr\r\n                else:\r\n                    curr_lr = self._base_lr - \\\r\n                        self._base_lr * (epoch - half_training) / half_training\r\n\r\n                if epoch < self._switch:\r\n                    curr_tr = 0.\r\n                    donorm = True\r\n                    to_train_A = self.g_A_trainer\r\n                    to_train_B = self.g_B_trainer\r\n                else:\r\n                    curr_tr = self._threshold_fg\r\n                    donorm = False\r\n                    to_train_A = self.g_A_trainer_bis\r\n                    to_train_B = self.g_B_trainer_bis\r\n\r\n\r\n                self.save_images(sess, epoch, curr_tr)\r\n\r\n                for i in range(0, max_images):\r\n                    print(\"Processing batch {}/{}\".format(i, max_images))\r\n\r\n                    inputs = sess.run(self.inputs)\r\n                    # Optimizing the G_A network\r\n                    _, fake_B_temp, smask_a,summary_str = sess.run(\r\n                        [to_train_A,\r\n                         self.fake_images_b,\r\n                         self.masks[0],\r\n                         self.g_A_loss_summ],\r\n                        feed_dict={\r\n                            self.input_a:\r\n                                inputs['images_i'],\r\n                            self.input_b:\r\n                                inputs['images_j'],\r\n              
              self.learning_rate: curr_lr,\r\n                            self.transition_rate: curr_tr,\r\n                            self.donorm: donorm,\r\n                        }\r\n                    )\r\n                    writer.add_summary(summary_str, epoch * max_images + i)\r\n\r\n                    fake_B_temp1 = self.fake_image_pool(\r\n                        self.num_fake_inputs, fake_B_temp, smask_a, self.fake_images_B)\r\n\r\n                    # Optimizing the D_B network\r\n                    _,summary_str = sess.run(\r\n                        [self.d_B_trainer, self.d_B_loss_summ],\r\n                        feed_dict={\r\n                            self.input_a:\r\n                                inputs['images_i'],\r\n                            self.input_b:\r\n                                inputs['images_j'],\r\n                            self.learning_rate: curr_lr,\r\n                            self.fake_pool_B: fake_B_temp1['im'],\r\n                            self.fake_pool_B_mask: fake_B_temp1['mask'],\r\n                            self.transition_rate: curr_tr,\r\n                            self.donorm: donorm,\r\n                        }\r\n                    )\r\n                    writer.add_summary(summary_str, epoch * max_images + i)\r\n\r\n\r\n                    # Optimizing the G_B network\r\n                    _, fake_A_temp, smask_b, summary_str = sess.run(\r\n                        [to_train_B,\r\n                         self.fake_images_a,\r\n                         self.masks[1],\r\n                         self.g_B_loss_summ],\r\n                        feed_dict={\r\n                            self.input_a:\r\n                                inputs['images_i'],\r\n                            self.input_b:\r\n                                inputs['images_j'],\r\n                            self.learning_rate: curr_lr,\r\n                            self.transition_rate: curr_tr,\r\n                
            self.donorm: donorm,\r\n                        }\r\n                    )\r\n                    writer.add_summary(summary_str, epoch * max_images + i)\r\n\r\n                    fake_A_temp1 = self.fake_image_pool(\r\n                        self.num_fake_inputs, fake_A_temp, smask_b ,self.fake_images_A)\r\n\r\n                    # Optimizing the D_A network\r\n                    _, mask_tmp__,summary_str = sess.run(\r\n                        [self.d_A_trainer,self.masks_, self.d_A_loss_summ],\r\n                        feed_dict={\r\n                            self.input_a:\r\n                                inputs['images_i'],\r\n                            self.input_b:\r\n                                inputs['images_j'],\r\n                            self.learning_rate: curr_lr,\r\n                            self.fake_pool_A: fake_A_temp1['im'],\r\n                            self.fake_pool_A_mask: fake_A_temp1['mask'],\r\n                            self.transition_rate: curr_tr,\r\n                            self.donorm: donorm,\r\n                        }\r\n                    )\r\n                    writer.add_summary(summary_str, epoch * max_images + i)\r\n\r\n                    writer.flush()\r\n                    self.num_fake_inputs += 1\r\n\r\n                sess.run(tf.assign(self.global_step, epoch + 1))\r\n\r\n            coord.request_stop()\r\n            coord.join(threads)\r\n            writer.add_graph(sess.graph)\r\n\r\n    def test(self):\r\n        \"\"\"Test Function.\"\"\"\r\n        print(\"Testing the results\")\r\n\r\n        self.inputs = data_loader.load_data(\r\n            self._dataset_name, self._size_before_crop,\r\n            False, self._do_flipping)\r\n\r\n        self.model_setup()\r\n        saver = tf.train.Saver()\r\n        init = tf.global_variables_initializer()\r\n\r\n        with tf.Session() as sess:\r\n            sess.run(init)\r\n\r\n            chkpt_fname = 
tf.train.latest_checkpoint(self._checkpoint_dir)\r\n            saver.restore(sess, chkpt_fname)\r\n\r\n            coord = tf.train.Coordinator()\r\n            threads = tf.train.start_queue_runners(coord=coord)\r\n\r\n            self._num_imgs_to_save = cyclegan_datasets.DATASET_TO_SIZES[\r\n                self._dataset_name]\r\n            self.save_images_bis(sess, sess.run(self.global_step))\r\n\r\n            coord.request_stop()\r\n            coord.join(threads)\r\n\r\n\r\ndef parse_args():\r\n    desc = \"Tensorflow implementation of CycleGAN using attention\"\r\n    parser = argparse.ArgumentParser(description=desc)\r\n\r\n    parser.add_argument('--to_train', type=int, default=1, help='1: training; 2: resuming from latest checkpoint; 0: testing.')\r\n    parser.add_argument('--log_dir',\r\n              type=str,\r\n              default=None,\r\n              help='Where the data is logged to.')\r\n\r\n    parser.add_argument('--config_filename', type=str, default='train', help='The name of the configuration file.')\r\n\r\n    parser.add_argument('--checkpoint_dir', type=str, default='', help='The directory that saves the latest checkpoint; only takes effect when to_train == 2.')\r\n    parser.add_argument('--skip', type=bool, default=False,\r\n                        help='Whether to add a skip connection between input and output.')\r\n    parser.add_argument('--switch', type=int, default=30,\r\n                        help='The epoch at which the FG starts to be fed to the discriminator.')\r\n    parser.add_argument('--threshold', type=float, default=0.1,\r\n                        help='The threshold value used to select the FG.')\r\n\r\n\r\n    return parser.parse_args()\r\n\r\ndef main():\r\n    \"\"\"\r\n\r\n    :param to_train: Specify whether it is training or testing. 1: training; 2:\r\n     resuming from latest checkpoint; 0: testing.\r\n    :param log_dir: The root dir to save checkpoints and imgs. 
The actual dir\r\n    is the root dir appended by the folder with the name timestamp.\r\n    :param config_filename: The configuration file.\r\n    :param checkpoint_dir: The directory that saves the latest checkpoint. It\r\n    only takes effect when to_train == 2.\r\n    :param skip: A boolean indicating whether to add skip connection between\r\n    input and output.\r\n    \"\"\"\r\n    args = parse_args()\r\n    if args is None:\r\n        exit()\r\n\r\n    to_train = args.to_train\r\n    log_dir = args.log_dir\r\n    config_filename = args.config_filename\r\n    checkpoint_dir = args.checkpoint_dir\r\n    skip = args.skip\r\n    switch = args.switch\r\n    threshold_fg = args.threshold\r\n\r\n    if not os.path.isdir(log_dir):\r\n        os.makedirs(log_dir)\r\n\r\n    with open(config_filename) as config_file:\r\n        config = json.load(config_file)\r\n\r\n\r\n\r\n    lambda_a = float(config['_LAMBDA_A']) if '_LAMBDA_A' in config else 10.0\r\n    lambda_b = float(config['_LAMBDA_B']) if '_LAMBDA_B' in config else 10.0\r\n    pool_size = int(config['pool_size']) if 'pool_size' in config else 50\r\n\r\n    to_restore = (to_train == 2)\r\n    base_lr = float(config['base_lr']) if 'base_lr' in config else 0.0002\r\n    max_step = int(config['max_step']) if 'max_step' in config else 200\r\n    network_version = str(config['network_version'])\r\n    dataset_name = str(config['dataset_name'])\r\n    do_flipping = bool(config['do_flipping'])\r\n\r\n    cyclegan_model = CycleGAN(pool_size, lambda_a, lambda_b, log_dir,\r\n                              to_restore, base_lr, max_step, network_version,\r\n                              dataset_name, checkpoint_dir, do_flipping, skip,\r\n                              switch, threshold_fg)\r\n\r\n    if to_train > 0:\r\n        cyclegan_model.train()\r\n    else:\r\n        cyclegan_model.test()\r\n\r\n\r\nif __name__ == '__main__':\r\n    main()\r\n"
  },
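The `fake_image_pool` method in `main.py` above is the standard CycleGAN history buffer: the discriminators are shown a mix of current and previously generated images. A minimal, self-contained sketch of the same buffering policy (plain Python; the names are illustrative, not the repo's API):

```python
import random

def fake_image_pool(pool, pool_size, num_fakes, fake_im, fake_mask):
    """Return an image/mask pair for the discriminator, possibly an old one.

    Mirrors the policy above: fill the pool first; once it is full, with
    probability 0.5 swap the new sample for a randomly stored one.
    """
    sample = {'im': fake_im, 'mask': fake_mask}
    if num_fakes < pool_size:
        pool.append(sample)      # still filling: always return the fresh sample
        return sample
    if random.random() > 0.5:
        idx = random.randint(0, pool_size - 1)
        old = pool[idx]
        pool[idx] = sample       # store the new sample in its place
        return old               # train the discriminator on the old one
    return sample

# Usage: until the pool holds pool_size items, the fresh sample always comes back.
pool = []
for step in range(3):
    out = fake_image_pool(pool, pool_size=2, num_fakes=step,
                          fake_im=step, fake_mask=step)
```

Returning an old sample half of the time keeps the discriminator from overfitting to the most recent generator output.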
  {
    "path": "model.py",
    "content": "\"\"\"Code for constructing the model and get the outputs from the model.\"\"\"\n\nimport tensorflow as tf\nimport numpy as np\nimport layers\n\n# The number of samples per batch.\nBATCH_SIZE = 1\n\n# The height of each image.\nIMG_HEIGHT = 256\n\n# The width of each image.\nIMG_WIDTH = 256\n\n# The number of color channels per image.\nIMG_CHANNELS = 3\n\nPOOL_SIZE = 50\nngf = 32\nndf = 64\n\n\ndef get_outputs(inputs, skip=False):\n\n    images_a = inputs['images_a']\n    images_b = inputs['images_b']\n    fake_pool_a = inputs['fake_pool_a']\n    fake_pool_b = inputs['fake_pool_b']\n    fake_pool_a_mask = inputs['fake_pool_a_mask']\n    fake_pool_b_mask = inputs['fake_pool_b_mask']\n    transition_rate = inputs['transition_rate']\n    donorm = inputs['donorm']\n    with tf.variable_scope(\"Model\") as scope:\n\n        current_autoenc = autoenc_upsample\n        current_discriminator = discriminator\n        current_generator = build_generator_resnet_9blocks\n\n        mask_a = current_autoenc(images_a, \"g_A_ae\")\n        mask_b = current_autoenc(images_b, \"g_B_ae\")\n        mask_a = tf.concat([mask_a] * 3, axis=3)\n        mask_b = tf.concat([mask_b] * 3, axis=3)\n\n        mask_a_on_a = tf.multiply(images_a, mask_a)\n        mask_b_on_b = tf.multiply(images_b, mask_b)\n\n        prob_real_a_is_real = current_discriminator(images_a, mask_a, transition_rate, donorm, \"d_A\")\n        prob_real_b_is_real = current_discriminator(images_b, mask_b, transition_rate, donorm, \"d_B\")\n\n        fake_images_b_from_g = current_generator(images_a, name=\"g_A\", skip=skip)\n        fake_images_b = tf.multiply(fake_images_b_from_g, mask_a) + tf.multiply(images_a, 1-mask_a)\n\n        fake_images_a_from_g = current_generator(images_b, name=\"g_B\", skip=skip)\n        fake_images_a = tf.multiply(fake_images_a_from_g, mask_b) + tf.multiply(images_b, 1-mask_b)\n        scope.reuse_variables()\n\n        prob_fake_a_is_real = 
current_discriminator(fake_images_a, mask_b, transition_rate, donorm, \"d_A\")\n        prob_fake_b_is_real = current_discriminator(fake_images_b, mask_a, transition_rate, donorm, \"d_B\")\n\n        mask_acycle = current_autoenc(fake_images_a, \"g_A_ae\")\n        mask_bcycle = current_autoenc(fake_images_b, \"g_B_ae\")\n        mask_bcycle = tf.concat([mask_bcycle] * 3, axis=3)\n        mask_acycle = tf.concat([mask_acycle] * 3, axis=3)\n\n        mask_acycle_on_fakeA = tf.multiply(fake_images_a, mask_acycle)\n        mask_bcycle_on_fakeB = tf.multiply(fake_images_b, mask_bcycle)\n\n        cycle_images_a_from_g = current_generator(fake_images_b, name=\"g_B\", skip=skip)\n        cycle_images_b_from_g = current_generator(fake_images_a, name=\"g_A\", skip=skip)\n\n        cycle_images_a = tf.multiply(cycle_images_a_from_g,\n                                     mask_bcycle) + tf.multiply(fake_images_b, 1 - mask_bcycle)\n\n        cycle_images_b = tf.multiply(cycle_images_b_from_g,\n                                     mask_acycle) + tf.multiply(fake_images_a, 1 - mask_acycle)\n\n        scope.reuse_variables()\n\n        prob_fake_pool_a_is_real = current_discriminator(fake_pool_a, fake_pool_a_mask, transition_rate, donorm, \"d_A\")\n        prob_fake_pool_b_is_real = current_discriminator(fake_pool_b, fake_pool_b_mask, transition_rate, donorm, \"d_B\")\n\n    return {\n        'prob_real_a_is_real': prob_real_a_is_real,\n        'prob_real_b_is_real': prob_real_b_is_real,\n        'prob_fake_a_is_real': prob_fake_a_is_real,\n        'prob_fake_b_is_real': prob_fake_b_is_real,\n        'prob_fake_pool_a_is_real': prob_fake_pool_a_is_real,\n        'prob_fake_pool_b_is_real': prob_fake_pool_b_is_real,\n        'cycle_images_a': cycle_images_a,\n        'cycle_images_b': cycle_images_b,\n        'fake_images_a': fake_images_a,\n        'fake_images_b': fake_images_b,\n        'masked_ims': [mask_a_on_a, mask_b_on_b, mask_acycle_on_fakeA, mask_bcycle_on_fakeB],\n      
  'masks': [mask_a, mask_b, mask_acycle, mask_bcycle],\n        'masked_gen_ims' : [fake_images_b_from_g, fake_images_a_from_g , cycle_images_a_from_g, cycle_images_b_from_g],\n        'mask_tmp' : mask_a,\n    }\n\ndef autoenc_upsample(inputae, name):\n\n    with tf.variable_scope(name):\n        f = 7\n        ks = 3\n        padding = \"REFLECT\"\n\n        pad_input = tf.pad(inputae, [[0, 0], [ks, ks], [\n            ks, ks], [0, 0]], padding)\n        o_c1 = layers.general_conv2d(\n            pad_input, tf.constant(True, dtype=bool), ngf, f, f, 2, 2, 0.02, name=\"c1\")\n        o_c2 = layers.general_conv2d(\n            o_c1, tf.constant(True, dtype=bool), ngf * 2, ks, ks, 2, 2, 0.02, \"SAME\", \"c2\")\n\n        o_r1 = build_resnet_block_Att(o_c2, ngf * 2, \"r1\", padding)\n\n        size_d1 = o_r1.get_shape().as_list()\n        o_c4 = layers.upsamplingDeconv(o_r1, size=[size_d1[1] * 2, size_d1[2] * 2], is_scale=False, method=1,\n                                   align_corners=False,name= 'up1')\n        # o_c4_pad = tf.pad(o_c4, [[0, 0], [1, 1], [1, 1], [0, 0]], \"REFLECT\", name='padup1')\n        o_c4_end = layers.general_conv2d(o_c4, tf.constant(True, dtype=bool), ngf * 2, (3, 3), (1, 1), padding='VALID', name='c4')\n\n        size_d2 = o_c4_end.get_shape().as_list()\n        o_c5 = layers.upsamplingDeconv(o_c4_end, size=[size_d2[1] * 2, size_d2[2] * 2], is_scale=False, method=1,\n                                       align_corners=False, name='up2')\n        # o_c5_pad = tf.pad(o_c5, [[0, 0], [1, 1], [1, 1], [0, 0]], \"REFLECT\", name='padup2')\n        oc5_end = layers.general_conv2d(o_c5, tf.constant(True, dtype=bool), ngf , (3, 3), (1, 1), padding='VALID', name='c5')\n\n        # o_c6 = tf.pad(oc5_end, [[0, 0], [3, 3], [3, 3], [0, 0]], \"REFLECT\", name='padup3')\n        o_c6_end = layers.general_conv2d(oc5_end, tf.constant(False, dtype=bool),\n                                         1 , (f, f), (1, 1), padding='VALID', name='c6', 
do_relu=False)\n\n        return tf.nn.sigmoid(o_c6_end,'sigmoid')\n\ndef build_resnet_block(inputres, dim, name=\"resnet\", padding=\"REFLECT\"):\n    \"\"\"build a single block of resnet.\n\n    :param inputres: inputres\n    :param dim: dim\n    :param name: name\n    :param padding: for tensorflow version use REFLECT; for pytorch version use\n     CONSTANT\n    :return: a single block of resnet.\n    \"\"\"\n    with tf.variable_scope(name):\n        out_res = tf.pad(inputres, [[0, 0], [1, 1], [\n            1, 1], [0, 0]], padding)\n        out_res = layers.general_conv2d(\n            out_res, tf.constant(True, dtype=bool), dim, 3, 3, 1, 1, 0.02, \"VALID\", \"c1\")\n        out_res = tf.pad(out_res, [[0, 0], [1, 1], [1, 1], [0, 0]], padding)\n        out_res = layers.general_conv2d(\n            out_res, tf.constant(True, dtype=bool), dim, 3, 3, 1, 1, 0.02, \"VALID\", \"c2\", do_relu=False)\n\n        return tf.nn.relu(out_res + inputres)\n\ndef build_resnet_block_Att(inputres, dim, name=\"resnet\", padding=\"REFLECT\"):\n    \"\"\"build a single block of resnet.\n\n    :param inputres: inputres\n    :param dim: dim\n    :param name: name\n    :param padding: for tensorflow version use REFLECT; for pytorch version use\n     CONSTANT\n    :return: a single block of resnet.\n    \"\"\"\n    with tf.variable_scope(name):\n        out_res = tf.pad(inputres, [[0, 0], [1, 1], [\n            1, 1], [0, 0]], padding)\n        out_res = layers.general_conv2d(\n            out_res, tf.constant(True, dtype=bool), dim, 3, 3, 1, 1, 0.02, \"VALID\", \"c1\")\n        out_res = tf.pad(out_res, [[0, 0], [1, 1], [1, 1], [0, 0]], padding)\n        out_res = layers.general_conv2d(\n            out_res, tf.constant(True, dtype=bool), dim, 3, 3, 1, 1, 0.02, \"VALID\", \"c2\", do_relu=False)\n\n        return tf.nn.relu(out_res + inputres)\n\ndef build_generator_resnet_9blocks(inputgen, name=\"generator\", skip=False):\n\n    with tf.variable_scope(name):\n        f = 7\n        ks 
= 3\n        padding = \"CONSTANT\"\n        inputgen = tf.pad(inputgen, [[0, 0], [ks, ks], [\n            ks, ks], [0, 0]], padding)\n\n        o_c1 = layers.general_conv2d(\n            inputgen, tf.constant(True, dtype=bool), ngf, f, f, 1, 1, 0.02, name=\"c1\")\n\n        o_c2 = layers.general_conv2d(\n            o_c1, tf.constant(True, dtype=bool),ngf * 2, ks, ks, 2, 2, 0.02, padding='same', name=\"c2\")\n\n        o_c3 = layers.general_conv2d(\n            o_c2, tf.constant(True, dtype=bool), ngf * 4, ks, ks, 2, 2, 0.02, padding='same', name=\"c3\")\n\n\n        o_r1 = build_resnet_block(o_c3, ngf * 4, \"r1\", padding)\n        o_r2 = build_resnet_block(o_r1, ngf * 4, \"r2\", padding)\n        o_r3 = build_resnet_block(o_r2, ngf * 4, \"r3\", padding)\n        o_r4 = build_resnet_block(o_r3, ngf * 4, \"r4\", padding)\n        o_r5 = build_resnet_block(o_r4, ngf * 4, \"r5\", padding)\n        o_r6 = build_resnet_block(o_r5, ngf * 4, \"r6\", padding)\n        o_r7 = build_resnet_block(o_r6, ngf * 4, \"r7\", padding)\n        o_r8 = build_resnet_block(o_r7, ngf * 4, \"r8\", padding)\n        o_r9 = build_resnet_block(o_r8, ngf * 4, \"r9\", padding)\n\n        o_c4 = layers.general_deconv2d(\n            o_r9, [BATCH_SIZE, 128, 128, ngf * 2], ngf * 2, ks, ks, 2, 2, 0.02,\n            \"SAME\", \"c4\")\n\n        o_c5 = layers.general_deconv2d(\n            o_c4, [BATCH_SIZE, 256, 256, ngf], ngf, ks, ks, 2, 2, 0.02,\n            \"SAME\", \"c5\")\n\n        o_c6 = layers.general_conv2d(o_c5, tf.constant(False, dtype=bool), IMG_CHANNELS, f, f, 1, 1,\n                                     0.02, \"SAME\", \"c6\", do_relu=False)\n\n        if skip is True:\n            out_gen = tf.nn.tanh(inputgen + o_c6, \"t1\")\n        else:\n            out_gen = tf.nn.tanh(o_c6, \"t1\")\n\n        return out_gen\n\ndef discriminator(inputdisc,  mask, transition_rate, donorm,  name=\"discriminator\"):\n\n    with tf.variable_scope(name):\n        mask = 
tf.cast(tf.greater_equal(mask, transition_rate), tf.float32)\n        inputdisc = tf.multiply(inputdisc, mask)\n        f = 4\n        padw = 2\n        pad_input = tf.pad(inputdisc, [[0, 0], [padw, padw], [\n            padw, padw], [0, 0]], \"CONSTANT\")\n\n        o_c1 = layers.general_conv2d(pad_input, donorm, ndf, f, f, 2, 2,\n                                     0.02, \"VALID\", \"c1\",\n                                     relufactor=0.2)\n\n        pad_o_c1 = tf.pad(o_c1, [[0, 0], [padw, padw], [\n            padw, padw], [0, 0]], \"CONSTANT\")\n\n        o_c2 = layers.general_conv2d(pad_o_c1, donorm, ndf * 2, f, f, 2, 2,\n                                     0.02, \"VALID\", \"c2\",  relufactor=0.2)\n\n        pad_o_c2 = tf.pad(o_c2, [[0, 0], [padw, padw], [\n            padw, padw], [0, 0]], \"CONSTANT\")\n\n        o_c3 = layers.general_conv2d(pad_o_c2, donorm, ndf * 4, f, f, 2, 2,\n                                     0.02, \"VALID\", \"c3\", relufactor=0.2)\n\n        pad_o_c3 = tf.pad(o_c3, [[0, 0], [padw, padw], [\n            padw, padw], [0, 0]], \"CONSTANT\")\n\n        o_c4 = layers.general_conv2d(pad_o_c3, donorm, ndf * 8, f, f, 1, 1,\n                                     0.02, \"VALID\", \"c4\", relufactor=0.2)\n        # o_c4 = tf.multiply(o_c4, mask_4)\n        pad_o_c4 = tf.pad(o_c4, [[0, 0], [padw, padw], [\n            padw, padw], [0, 0]], \"CONSTANT\")\n\n        o_c5 = layers.general_conv2d(\n            pad_o_c4, tf.constant(False, dtype=bool), 1, f, f, 1, 1, 0.02, \"VALID\", \"c5\", do_relu=False)\n\n\n        return o_c5\n"
  },
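`get_outputs` in `model.py` composites the generator output with the learned attention mask, so only attended pixels are translated: `fake_images_b = g(a) * mask_a + a * (1 - mask_a)`. A scalar sketch of that blend (plain Python lists standing in for image tensors; the values are illustrative):

```python
def masked_composite(generated, source, mask):
    """Blend translated and source pixels by the attention mask.

    mask == 1 keeps the generator output, mask == 0 passes the input pixel
    through, matching the fake_images_b composition in model.py.
    """
    return [g * m + s * (1.0 - m) for g, s, m in zip(generated, source, mask)]

# Foreground pixel (mask 1.0) is fully translated; background (mask 0.0)
# is untouched; intermediate mask values interpolate.
out = masked_composite(generated=[0.9, -0.2, 0.5],
                       source=[0.1, 0.3, -0.5],
                       mask=[1.0, 0.0, 0.5])
```

This is why the cycle loss only needs to reconstruct the attended region: unattended pixels are copied from the input unchanged.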
  {
    "path": "test/__init__.py",
    "content": ""
  },
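The `train()` loop in `main.py` holds the base learning rate constant for the first half of training and then decays it linearly toward zero. A small sketch of that schedule (the 0.0002 / 200 values mirror the config defaults in `main.py`):

```python
def learning_rate(epoch, base_lr, max_step):
    """Constant for the first half of training, then linear decay to zero,
    as computed per epoch in the training loop."""
    half = max_step // 2
    if epoch < half:
        return base_lr
    return base_lr - base_lr * (epoch - half) / half

# Flat at base_lr until epoch 100, then decays linearly over the second half.
rates = [learning_rate(e, 0.0002, 200) for e in (0, 99, 100, 150, 199)]
```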
  {
    "path": "test/evaluate_losses.py",
    "content": "import numpy as np\nimport tensorflow as tf\n\nfrom .. import losses\n\n\ndef test_evaluate_g_losses(sess):\n\n    _LAMBDA_A = 10\n    _LAMBDA_B = 10\n\n    input_a = tf.random_uniform((5, 7), maxval=1)\n    cycle_images_a = input_a + 1\n    input_b = tf.random_uniform((5, 7), maxval=1)\n    cycle_images_b = input_b - 2\n\n    cycle_consistency_loss_a = _LAMBDA_A * losses.cycle_consistency_loss(\n        real_images=input_a, generated_images=cycle_images_a,\n    )\n    cycle_consistency_loss_b = _LAMBDA_B * losses.cycle_consistency_loss(\n        real_images=input_b, generated_images=cycle_images_b,\n    )\n\n    prob_fake_a_is_real = tf.constant([0, 1.0, 0])\n    prob_fake_b_is_real = tf.constant([1.0, 1.0, 0])\n\n    lsgan_loss_a = losses.lsgan_loss_generator(prob_fake_a_is_real)\n    lsgan_loss_b = losses.lsgan_loss_generator(prob_fake_b_is_real)\n\n    assert np.isclose(sess.run(lsgan_loss_a), 0.66666669) and \\\n        np.isclose(sess.run(lsgan_loss_b), 0.3333333) and \\\n        np.isclose(sess.run(cycle_consistency_loss_a), 10) and \\\n        np.isclose(sess.run(cycle_consistency_loss_b), 20)\n\n\ndef test_evaluate_d_losses(sess):\n\n    prob_real_a_is_real = tf.constant([1.0, 1.0, 0])\n    prob_fake_pool_a_is_real = tf.constant([1.0, 0, 0])\n    d_loss_A = losses.lsgan_loss_discriminator(\n        prob_real_is_real=prob_real_a_is_real,\n        prob_fake_is_real=prob_fake_pool_a_is_real)\n    assert np.isclose(sess.run(d_loss_A), 0.3333333)\n"
  },
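The values asserted in `test/evaluate_losses.py` above follow from the standard LSGAN and L1 cycle-consistency objectives: the generator loss is mean((p - 1)^2), the discriminator loss is (mean((p_real - 1)^2) + mean(p_fake^2)) / 2, and the cycle loss is the mean absolute difference. A plain-Python sketch reproducing those numbers (assuming `losses.py` implements these usual forms):

```python
def mean(xs):
    return sum(xs) / len(xs)

def lsgan_loss_generator(prob_fake_is_real):
    # The generator wants its fakes classified as real (probability 1).
    return mean([(p - 1.0) ** 2 for p in prob_fake_is_real])

def lsgan_loss_discriminator(prob_real_is_real, prob_fake_is_real):
    # The discriminator wants reals at 1 and fakes at 0.
    return 0.5 * (mean([(p - 1.0) ** 2 for p in prob_real_is_real])
                  + mean([p ** 2 for p in prob_fake_is_real]))

def cycle_consistency_loss(real, generated):
    # L1 distance between an image and its round-trip reconstruction.
    return mean([abs(r - g) for r, g in zip(real, generated)])

# Same numbers as the test above: probs [0, 1, 0] give 2/3 for the generator,
# and real [1, 1, 0] vs fake-pool [1, 0, 0] give 1/3 for the discriminator.
g_loss = lsgan_loss_generator([0.0, 1.0, 0.0])
d_loss = lsgan_loss_discriminator([1.0, 1.0, 0.0], [1.0, 0.0, 0.0])
```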
  {
    "path": "test/evaluate_networks.py",
    "content": "import numpy as np\nimport tensorflow as tf\n\nfrom .. import model\n\n\ndef test_evaluate_g(sess):\n    x_val = np.ones_like(np.random.randn(1, 16, 16, 3)).astype(np.float32)\n    for i in range(16):\n        for j in range(16):\n            for k in range(3):\n                x_val[0][i][j][k] = ((i + j + k) % 2) / 2.0\n    inputs = {\n        'images_a': tf.stack(x_val),\n        'images_b': tf.stack(x_val),\n        'fake_pool_a': tf.zeros([1, 16, 16, 3]),\n        'fake_pool_b': tf.zeros([1, 16, 16, 3]),\n    }\n\n    outputs = model.get_outputs(inputs)\n\n    sess.run(tf.global_variables_initializer())\n    assert sess.run(outputs['fake_images_a'][0][5][7][0]) == 5\n\n\ndef test_evaluate_d(sess):\n    x_val = np.ones_like(np.random.randn(1, 16, 16, 3)).astype(np.float32)\n    for i in range(16):\n        for j in range(16):\n            for k in range(3):\n                x_val[0][i][j][k] = ((i + j + k) % 2) / 2.0\n    inputs = {\n        'images_a': tf.stack(x_val),\n        'images_b': tf.stack(x_val),\n        'fake_pool_a': tf.zeros([1, 16, 16, 3]),\n        'fake_pool_b': tf.zeros([1, 16, 16, 3]),\n    }\n\n    outputs = model.get_outputs(inputs)\n\n    sess.run(tf.global_variables_initializer())\n    assert sess.run(outputs['prob_real_a_is_real'][0][3][3][0]) == 5\n"
  },
  {
    "path": "test/test_losses.py",
    "content": "import numpy as np\nimport tensorflow as tf\n\nfrom .. import losses\n\n\ndef test_cycle_consistency_loss_is_none_with_perfect_fakes(sess):\n    batch_size, height, width, channels = [16, 2, 3, 1]\n\n    tf.set_random_seed(0)\n\n    images = tf.random_uniform((batch_size, height, width, channels), maxval=1)\n\n    loss = losses.cycle_consistency_loss(\n        real_images=images,\n        generated_images=images,\n    )\n\n    assert sess.run(loss) == 0\n\n\ndef test_cycle_consistency_loss_is_positive_with_imperfect_fake_x(sess):\n    batch_size, height, width, channels = [16, 2, 3, 1]\n\n    tf.set_random_seed(0)\n\n    real_images = tf.random_uniform(\n        (batch_size, height, width, channels), maxval=1,\n    )\n    generated_images = real_images + 1\n\n    loss = losses.cycle_consistency_loss(\n        real_images=real_images,\n        generated_images=generated_images,\n    )\n\n    assert sess.run(loss) == 1\n\n\ndef test_lsgan_loss_discrim_is_none_with_perfect_discrimination(sess):\n    batch_size = 100\n    prob_real_is_real = tf.ones((batch_size))\n    prob_fake_is_real = tf.zeros((batch_size))\n    loss = losses.lsgan_loss_discriminator(\n        prob_real_is_real, prob_fake_is_real,\n    )\n    assert sess.run(loss) == 0\n\n\ndef test_lsgan_loss_discrim_is_positive_with_imperfect_discrimination(sess):\n    batch_size = 100\n    prob_real_is_real = tf.ones((batch_size)) * 0.4\n    prob_fake_is_real = tf.ones((batch_size)) * 0.7\n    loss = losses.lsgan_loss_discriminator(\n        prob_real_is_real, prob_fake_is_real,\n    )\n    loss = sess.run(loss)\n\n    np.testing.assert_almost_equal(loss, (0.6 * 0.6 + 0.7 * 0.7) / 2)\n"
  },
  {
    "path": "test/test_model.py",
    "content": "import tensorflow as tf\n\nfrom dl_research.testing import slow\n\nfrom .. import model\n\n# -----------------------------------------------------------------------------\n\n\n@slow\ndef test_output_sizes(sess):\n    images_size = [\n        model.BATCH_SIZE,\n        model.IMG_HEIGHT,\n        model.IMG_WIDTH,\n        model.IMG_CHANNELS,\n    ]\n\n    pool_size = [\n        model.POOL_SIZE,\n        model.IMG_HEIGHT,\n        model.IMG_WIDTH,\n        model.IMG_CHANNELS,\n    ]\n\n    inputs = {\n        'images_a': tf.ones(images_size),\n        'images_b': tf.ones(images_size),\n        'fake_pool_a': tf.ones(pool_size),\n        'fake_pool_b': tf.ones(pool_size),\n    }\n\n    outputs = model.get_outputs(inputs)\n\n    assert outputs['prob_real_a_is_real'].get_shape().as_list() == [\n        model.BATCH_SIZE, 32, 32, 1,\n    ]\n    assert outputs['prob_real_b_is_real'].get_shape().as_list() == [\n        model.BATCH_SIZE, 32, 32, 1,\n    ]\n    assert outputs['prob_fake_a_is_real'].get_shape().as_list() == [\n        model.BATCH_SIZE, 32, 32, 1,\n    ]\n    assert outputs['prob_fake_b_is_real'].get_shape().as_list() == [\n        model.BATCH_SIZE, 32, 32, 1,\n    ]\n    assert outputs['prob_fake_pool_a_is_real'].get_shape().as_list() == [\n        model.POOL_SIZE, 32, 32, 1,\n    ]\n    assert outputs['prob_fake_pool_b_is_real'].get_shape().as_list() == [\n        model.POOL_SIZE, 32, 32, 1,\n    ]\n    assert outputs['cycle_images_a'].get_shape().as_list() == images_size\n    assert outputs['cycle_images_b'].get_shape().as_list() == images_size\n"
  }
]