[
  {
    "path": "CONTRIBUTING.md",
    "content": "# How to Contribute\n\nWe'd love to accept your patches and contributions to this project. There are\njust a few small guidelines you need to follow.\n\n## Contributor License Agreement\n\nContributions to this project must be accompanied by a Contributor License\nAgreement. You (or your employer) retain the copyright to your contribution;\nthis simply gives us permission to use and redistribute your contributions as\npart of the project. Head over to <https://cla.developers.google.com/> to see\nyour current agreements on file or to sign a new one.\n\nYou generally only need to submit a CLA once, so if you've already submitted one\n(even if it was for a different project), you probably don't need to do it\nagain.\n\n## Code reviews\n\nAll submissions, including submissions by project members, require review. We\nuse GitHub pull requests for this purpose. Consult\n[GitHub Help](https://help.github.com/articles/about-pull-requests/) for more\ninformation on using pull requests.\n\n## Community Guidelines\n\nThis project follows\n[Google's Open Source Community Guidelines](https://opensource.google.com/conduct/).\n"
  },
  {
    "path": "LICENSE",
    "content": "\n                                 Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative Works\" shall mean any work, whether in Source or Object\n      form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. 
For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. 
Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n\n   APPENDIX: How to apply the Apache License to your work.\n\n      To apply the Apache License to your work, attach the following\n      boilerplate notice, with the fields enclosed by brackets \"[]\"\n      replaced with your own identifying information. (Don't include\n      the brackets!)  The text should be enclosed in the appropriate\n      comment syntax for the file format. We also recommend that a\n      file or class name and description of purpose be included on the\n      same \"printed page\" as the copyright notice for easier\n      identification within third-party archives.\n\n   Copyright [yyyy] [name of copyright owner]\n\n   Licensed under the Apache License, Version 2.0 (the \"License\");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       http://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License."
  },
  {
    "path": "README.md",
    "content": "# Neural Rerendering in the Wild\nMoustafa Meshry<sup>1</sup>,\n[Dan B Goldman](http://www.danbgoldman.com/)<sup>2</sup>,\n[Sameh Khamis](http://www.samehkhamis.com/)<sup>2</sup>,\n[Hugues Hoppe](http://hhoppe.com/)<sup>2</sup>,\nRohit Pandey<sup>2</sup>,\n[Noah Snavely](http://www.cs.cornell.edu/~snavely/)<sup>2</sup>,\n[Ricardo Martin-Brualla](http://www.ricardomartinbrualla.com/)<sup>2</sup>.\n\n<sup>1</sup>University of Maryland, College Park &nbsp;&nbsp;&nbsp;&nbsp; <sup>2</sup>Google Inc.\n\nTo appear at CVPR 2019 (Oral). <br><br>\n\n\n<figure class=\"image\">\n  <img align=\"center\" src=\"imgs/teaser_with_caption.jpg\" width=\"500px\">\n</figure>\n\n<!--- ![Teaser figure](https://github.com/MoustafaMeshry/neural_rerendering_in_the_wild/blob/master/imgs/teaser_with_caption.jpg?raw=true | width=450) --->\n\nWe will provide Tensorflow implementation and pretrained models for our paper soon.\n\n[**Paper**](https://arxiv.org/abs/1904.04290) | [**Video**](https://www.youtube.com/watch?v=E1crWQn_kmY) | [**Code**](https://github.com/MoustafaMeshry/neural_rerendering_in_the_wild) | [**Project page**](https://moustafameshry.github.io/neural_rerendering_in_the_wild/)\n\n### Abstract\n\nWe explore total scene capture — recording, modeling, and rerendering a scene under varying appearance such as season and time of day.\nStarting from internet photos of a tourist landmark, we apply traditional 3D reconstruction to register the photos and approximate the scene as a point cloud.\nFor each photo, we render the scene points into a deep framebuffer,\nand train a neural network to learn the mapping of these initial renderings to the actual photos.\nThis rerendering network also takes as input a latent appearance vector and a semantic mask indicating the location of transient objects like pedestrians.\nThe model is evaluated on several datasets of publicly available images spanning a broad range of illumination conditions.\nWe create short videos demonstrating realistic manipulation of the image viewpoint, appearance, and semantic labeling.\nWe also compare results with prior work on scene reconstruction from internet photos.\n\n### Video\n[![Supplementary material video](https://img.youtube.com/vi/E1crWQn_kmY/0.jpg)](https://www.youtube.com/watch?v=E1crWQn_kmY)\n\n\n### Appearance variation\n\nWe capture the appearance of the original images in the left column, and rerender several viewpoints under them. The last column is a detail of the previous one. The top row shows the renderings part of the input to the rerenderer, that exhibit artifacts like incomplete features in the statue, and an inconsistent mix of day and night appearances. Note the hallucinated twilight scene in the sky using the last appearance. Image credits: Flickr users William Warby, Neil Rickards, Rafael Jimenez, acme401 (Creative Commons).\n\n<figure class=\"image\">\n  <img src=\"imgs/app_variation.jpg\" width=\"900px\">\n</figure>\n\n### Appearance interpolation\nFrames from a synthesized camera path that smoothly transitions from the photo on the left to the photo on the right by smoothly interpolating both viewpoint and the latent appearance vectors. Please see the supplementary video. 
Photo Credits: Allie Caulfield, Tahbepet, Till Westermayer, Elliott Brown (Creative Commons).\n<figure class=\"image\">\n  <img src=\"imgs/app_interpolation.jpg\" width=\"900px\">\n</figure>\n\n### Acknowledgements\nWe thank Gregory Blascovich for his help in conducting the user study, and Johannes Schönberger and True Price for their help generating datasets.\n\n### Run and train instructions\n\nStaged training consists of three stages:\n\n-   Pretraining the appearance network.\n-   Training the rendering network while fixing the weights for the appearance\n    network.\n-   Finetuning both the appearance and the rendering networks.\n\n### Aligned dataset preprocessing\n\n#### Manual preparation\n\n*   Set a path to a base_dir that contains the source code:\n\n```\nbase_dir=/path/to/neural_rendering\nmkdir $base_dir\ncd $base_dir\n```\n\n*   We assume the following format for an aligned dataset:\n    * Each training image consists of 3 files with the following naming format:\n        * real image: %04d_reference.png\n        * render color: %04d_color.png\n        * render depth: %04d_depth.png\n*   Set the dataset name, e.g.\n```\ndataset_name='trevi3k'  # set to any name\n```\n*   Split the dataset into train and validation sets in two subdirectories:\n    *   $base_dir/datasets/$dataset_name/train\n    *   $base_dir/datasets/$dataset_name/val\n*   Download the DeepLab semantic segmentation model trained on the ADE20K\n    dataset from this link:\n    http://download.tensorflow.org/models/deeplabv3_xception_ade20k_train_2018_05_29.tar.gz\n*   Unzip the downloaded file to: $base_dir/deeplabv3_xception_ade20k_train\n*   Download this [file](https://github.com/MoustafaMeshry/vgg_loss/blob/master/vgg16.py) for an implementation of a vgg-based perceptual loss.\n*   Download trained weights for the vgg network as instructed in this link: https://github.com/machrisaa/tensorflow-vgg\n*   Save the vgg weights to $base_dir/vgg16_weights/vgg16.npy\n\n\n#### Data preprocessing\n\n*   Run the preprocessing pipeline, which consists of:\n    *   Filtering out sparse renders.\n    *   Semantic segmentation of ground truth images.\n    *   Exporting the dataset to tfrecord format.\n\n```\n# Run locally\npython dataset_utils.py \\\n--dataset_name=$dataset_name \\\n--dataset_parent_dir=$base_dir/datasets/$dataset_name \\\n--output_dir=$base_dir/datasets/$dataset_name \\\n--xception_frozen_graph_path=$base_dir/deeplabv3_xception_ade20k_train/frozen_inference_graph.pb \\\n--alsologtostderr\n```\n\n### Pretraining the appearance encoder network\n\n```\n# Run locally\npython pretrain_appearance.py \\\n  --dataset_name=$dataset_name \\\n  --train_dir=$base_dir/train_models/$dataset_name-app_pretrain \\\n  --imageset_dir=$base_dir/datasets/$dataset_name/train \\\n  --train_resolution=512 \\\n  --metadata_output_dir=$base_dir/datasets/$dataset_name\n```\n\n### Training the rerendering network with a fixed appearance encoder\n\nSet the dataset_parent_dir variable below to point to the directory containing\nthe generated TFRecords.\n\n```\n# Run locally:\ndataset_parent_dir=$base_dir/datasets/$dataset_name\ntrain_dir=$base_dir/train_models/$dataset_name-staged-fixed_appearance\nload_pretrained_app_encoder=true\nappearance_pretrain_dir=$base_dir/train_models/$dataset_name-app_pretrain\nload_from_another_ckpt=false\nfixed_appearance_train_dir=''\ntrain_app_encoder=false\n\npython neural_rerendering.py \\\n--dataset_name=$dataset_name \\\n--dataset_parent_dir=$dataset_parent_dir \\\n--train_dir=$train_dir 
\\\n--load_pretrained_app_encoder=$load_pretrained_app_encoder \\\n--appearance_pretrain_dir=$appearance_pretrain_dir \\\n--train_app_encoder=$train_app_encoder \\\n--load_from_another_ckpt=$load_from_another_ckpt \\\n--fixed_appearance_train_dir=$fixed_appearance_train_dir \\\n--total_kimg=4000\n```\n\n### Finetuning the rerendering network and the appearance encoder\n\nSet the fixed_appearance_train_dir to the train directory from the previous\nstep.\n\n```\n# Run locally:\ndataset_parent_dir=$base_dir/datasets/$dataset_name\ntrain_dir=$base_dir/train_models/$dataset_name-staged-finetune_appearance\nload_pretrained_app_encoder=false\nappearance_pretrain_dir=''\nload_from_another_ckpt=true\nfixed_appearance_train_dir=$base_dir/train_models/$dataset_name-staged-fixed_appearance\ntrain_app_encoder=true\n\npython neural_rerendering.py \\\n--dataset_name=$dataset_name \\\n--dataset_parent_dir=$dataset_parent_dir \\\n--train_dir=$train_dir \\\n--load_pretrained_app_encoder=$load_pretrained_app_encoder \\\n--appearance_pretrain_dir=$appearance_pretrain_dir \\\n--train_app_encoder=$train_app_encoder \\\n--load_from_another_ckpt=$load_from_another_ckpt \\\n--fixed_appearance_train_dir=$fixed_appearance_train_dir \\\n--total_kimg=4000\n```\n\n\n### Evaluate model on validation set\n\n```\nexperiment_title=$dataset_name-staged-finetune_appearance\nlocal_train_dir=$base_dir/train_models/$experiment_title\ndataset_parent_dir=$base_dir/datasets/$dataset_name\nval_set_out_dir=$local_train_dir/val_set_output\n\n# Run the model on validation set\necho \"Evaluating the validation set\"\npython neural_rerendering.py \\\n      --train_dir=$local_train_dir \\\n      --dataset_name=$dataset_name \\\n      --dataset_parent_dir=$dataset_parent_dir \\\n      --run_mode='eval_subset' \\\n      --virtual_seq_name='val' \\\n      --output_validation_dir=$val_set_out_dir \\\n      --logtostderr\n# Evaluate quantitative metrics\npython evaluate_quantitative_metrics.py \\\n      --val_set_out_dir=$val_set_out_dir \\\n      --experiment_title=$experiment_title \\\n      --logtostderr\n```\n"
  },
  {
    "path": "data.py",
    "content": "# Copyright 2019 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     https://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom options import FLAGS as opts\nimport functools\nimport glob\nimport numpy as np\nimport os.path as osp\nimport random\nimport tensorflow as tf\n\n\ndef provide_data(dataset_name='', parent_dir='', batch_size=8, subset=None,\n                 max_examples=None, crop_flag=False, crop_size=256, seeds=None,\n                 use_appearance=True, shuffle=128):\n  # Parsing function for each tfrecord example.\n  record_parse_fn = functools.partial(\n      _parser_rendered_dataset, crop_flag=crop_flag, crop_size=crop_size,\n      use_alpha=opts.use_alpha, use_depth=opts.use_depth,\n      use_semantics=opts.use_semantic, seeds=seeds,\n      use_appearance=use_appearance)\n\n  input_dict_var = multi_input_fn_record(\n      record_parse_fn, parent_dir, dataset_name, batch_size,\n      subset=subset, max_examples=max_examples, shuffle=shuffle)\n  return input_dict_var\n\n\ndef _parser_rendered_dataset(\n    serialized_example, crop_flag, crop_size, seeds, use_alpha, use_depth,\n    use_semantics, use_appearance):\n  \"\"\"\n  Parses a single tf.Example into a features dictionary with input tensors.\n  \"\"\"\n  # Structure of features_dict need to match the dictionary structure that was\n  # serialized to a tf.Example\n  features_dict = {'height': tf.FixedLenFeature([], tf.int64),\n                   'width': tf.FixedLenFeature([], tf.int64),\n                   'rendered': tf.FixedLenFeature([], tf.string),\n                   'depth': tf.FixedLenFeature([], tf.string),\n                   'real': tf.FixedLenFeature([], tf.string),\n                   'seg': tf.FixedLenFeature([], tf.string)}\n  features = tf.parse_single_example(serialized_example, features=features_dict)\n  height = tf.cast(features['height'], tf.int32)\n  width = tf.cast(features['width'], tf.int32)\n\n  # Parse the rendered image.\n  rendered = tf.decode_raw(features['rendered'], tf.uint8)\n  rendered = tf.cast(rendered, tf.float32) * (2.0 / 255) - 1.0\n  rendered = tf.reshape(rendered, [height, width, 4])\n  if not use_alpha:\n    rendered = tf.slice(rendered, [0, 0, 0], [height, width, 3])\n  conditional_input = rendered\n\n  # Parse the depth image.\n  if use_depth:\n    depth = tf.decode_raw(features['depth'], tf.uint16)\n    depth = tf.reshape(depth, [height, width, 1])\n    depth = tf.cast(depth, tf.float32) * (2.0 / 255) - 1.0\n    conditional_input = tf.concat([conditional_input, depth], axis=-1)\n\n  # Parse the semantic map.\n  if use_semantics:\n    seg_img = tf.decode_raw(features['seg'], tf.uint8)\n    seg_img = tf.reshape(seg_img, [height, width, 3])\n    seg_img = tf.cast(seg_img, tf.float32) * (2.0 / 255) - 1\n    conditional_input = tf.concat([conditional_input, seg_img], axis=-1)\n\n  # Verify that the parsed input has the correct number of channels.\n  assert conditional_input.shape[-1] == opts.deep_buffer_nc, ('num channels '\n      'in the parsed input doesn\\'t match num input channels specified in '\n   
   'opts.deep_buffer_nc!')\n\n  # Parse the ground truth image.\n  real = tf.decode_raw(features['real'], tf.uint8)\n  real = tf.cast(real, tf.float32) * (2.0 / 255) - 1.0\n  real = tf.reshape(real, [height, width, 3])\n\n  # Parse the appearance image (if any).\n  appearance_input = []\n  if use_appearance:\n    # Concatenate the deep buffer to the real image.\n    appearance_input = tf.concat([real, conditional_input], axis=-1)\n    # Verify that the parsed input has the correct number of channels.\n    assert appearance_input.shape[-1] == opts.appearance_nc, ('num channels '\n        'in the parsed appearance input doesn\\'t match num input channels '\n        'specified in opts.appearance_nc!')\n\n  # Crop conditional_input and real images, but keep the appearance input\n  # uncropped (learn a one-to-many mapping from appearance to output)\n  if crop_flag:\n    assert crop_size is not None, 'crop_size is not provided!'\n    if isinstance(crop_size, int):\n      crop_size = [crop_size, crop_size]\n    assert len(crop_size) == 2, 'crop_size is either an int or a 2-tuple!'\n\n    # Central crop\n    if seeds is not None and len(seeds) <= 1:\n      conditional_input = tf.image.resize_image_with_crop_or_pad(\n          conditional_input, crop_size[0], crop_size[1])\n      real = tf.image.resize_image_with_crop_or_pad(real, crop_size[0],\n                                                    crop_size[1])\n    else:\n      if not seeds:  # random crops\n        seed = random.randint(0, (1 << 31) - 1)\n      else:  # fixed crops\n        seed_idx = random.randint(0, len(seeds) - 1)\n        seed = seeds[seed_idx]\n      conditional_input = tf.random_crop(\n          conditional_input, crop_size + [opts.deep_buffer_nc], seed=seed)\n      real = tf.random_crop(real, crop_size + [3], seed=seed)\n\n  features = {'conditional_input': conditional_input,\n              'expected_output': real,\n              'peek_input': appearance_input}\n  return features\n\n\ndef multi_input_fn_record(\n    record_parse_fn, parent_dir, tfrecord_basename, batch_size, subset=None,\n    max_examples=None, shuffle=128):\n  \"\"\"Creates a Dataset pipeline for tfrecord files.\n\n  Returns:\n    Dataset iterator.\n  \"\"\"\n  subset_suffix = '*_%s.tfrecord' % subset if subset else '*.tfrecord'\n  input_pattern = osp.join(parent_dir, tfrecord_basename + subset_suffix)\n  filenames = sorted(glob.glob(input_pattern))\n  assert len(filenames) > 0, ('Error! 
input pattern \"%s\" didn\\'t match any '\n                              'files' % input_pattern)\n  dataset = tf.data.TFRecordDataset(filenames)\n  if shuffle == 0:  # keep input deterministic\n    # use one thread to get deterministic results\n    dataset = dataset.map(record_parse_fn, num_parallel_calls=None)\n  else:\n    dataset = dataset.repeat()  # Repeat indefinitely.\n    dataset = dataset.map(record_parse_fn,\n                          num_parallel_calls=max(4, batch_size // 4))\n    if opts.training_pipeline == 'drit':\n      dataset1 = dataset.shuffle(shuffle)\n      dataset2 = dataset.shuffle(shuffle)\n      paired_dataset = tf.data.Dataset.zip((dataset1, dataset2))\n\n      def _join_paired_dataset(features_a, features_b):\n        features_a['conditional_input_2'] = features_b['conditional_input']\n        features_a['expected_output_2'] = features_b['expected_output']\n        return features_a\n\n      joined_dataset = paired_dataset.map(_join_paired_dataset)\n      dataset = joined_dataset\n    else:\n      dataset = dataset.shuffle(shuffle)\n  if max_examples is not None:\n    dataset = dataset.take(max_examples)\n  dataset = dataset.batch(batch_size)\n  if shuffle > 0:  # input is not deterministic\n    dataset = dataset.prefetch(4)  # Prefetch a few batches.\n  return dataset.make_one_shot_iterator().get_next()\n"
  },
  {
    "path": "dataset_utils.py",
    "content": "# Copyright 2019 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     https://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom PIL import Image\nfrom absl import app\nfrom absl import flags\nfrom options import FLAGS as opts\nimport cv2\nimport data\nimport functools\nimport glob\nimport numpy as np\nimport os\nimport os.path as osp\nimport shutil\nimport six\nimport tensorflow as tf\nimport segment_dataset as segment_utils\nimport utils\n\nFLAGS = flags.FLAGS\nflags.DEFINE_string('output_dir', None, 'Directory to save exported tfrecords.')\nflags.DEFINE_string('xception_frozen_graph_path', None,\n                    'Path to the deeplab xception model frozen graph')\n\n\nclass AlignedRenderedDataset(object):\n  def __init__(self, rendered_filepattern, use_semantic_map=True):\n    \"\"\"\n    Args:\n      rendered_filepattern: string, path filepattern to 3D rendered images (\n        assumes filenames are '/path/to/dataset/%d_color.png')\n      use_semantic_map: bool, include semantic maps. in the TFRecord\n    \"\"\"\n    self.filenames = sorted(glob.glob(rendered_filepattern))\n    assert len(self.filenames) > 0, ('input %s didn\\'t match any files!' %\n                                     rendered_filepattern)\n    self.iter_idx = 0\n    self.use_semantic_map = use_semantic_map\n\n  def __iter__(self):\n    return self\n\n  def __next__(self):\n    return self.next()\n\n  def next(self):\n    if self.iter_idx < len(self.filenames):\n      rendered_img_name = self.filenames[self.iter_idx]\n      basename = rendered_img_name[:-9]  # remove the 'color.png' suffix\n      ref_img_name = basename + 'reference.png'\n      depth_img_name = basename + 'depth.png'\n      # Read the 3D rendered image\n      img_rendered = cv2.imread(rendered_img_name, cv2.IMREAD_UNCHANGED)\n      # Change BGR (default cv2 format) to RGB\n      img_rendered = img_rendered[:, :, [2,1,0,3]]  # it has a 4th alpha channel\n      # Read the depth image\n      img_depth = cv2.imread(depth_img_name, cv2.IMREAD_UNCHANGED)\n      # Workaround as some depth images are read with a different data type!\n      img_depth = img_depth.astype(np.uint16)\n      # Read reference image if exists, otherwise replace with a zero image.\n      if osp.exists(ref_img_name):\n        img_ref = cv2.imread(ref_img_name)\n        img_ref = img_ref[:, :, ::-1]  # Change BGR to RGB format.\n      else:  # use a dummy 3-channel zero image as a placeholder\n        print('Warning: no reference image found! 
Using a dummy placeholder!')\n        img_height, img_width = img_depth.shape\n        img_ref = np.zeros((img_height, img_width, 3), dtype=np.uint8)\n\n      if self.use_semantic_map:\n        semantic_seg_img_name = basename + 'seg_rgb.png'\n        img_seg = cv2.imread(semantic_seg_img_name)\n        img_seg = img_seg[:, :, ::-1]  # Change from BGR to RGB\n        if img_seg.shape[0] == 512 and img_seg.shape[1] == 512:\n          img_ref = utils.get_central_crop(img_ref)\n          img_rendered = utils.get_central_crop(img_rendered)\n          img_depth = utils.get_central_crop(img_depth)\n\n      img_shape = img_depth.shape\n      if self.use_semantic_map:\n        assert img_seg.shape == (img_shape + (3,)), 'error in seg image %s %s' % (\n          basename, str(img_seg.shape))\n      assert img_ref.shape == (img_shape + (3,)), 'error in ref image %s %s' % (\n        basename, str(img_ref.shape))\n      assert img_rendered.shape == (img_shape + (4,)), ('error in rendered '\n        'image %s %s' % (basename, str(img_rendered.shape)))\n      assert len(img_depth.shape) == 2, 'error in depth image %s %s' % (\n        basename, str(img_depth.shape))\n\n      raw_example = dict()\n      raw_example['height'] = img_ref.shape[0]\n      raw_example['width'] = img_ref.shape[1]\n      raw_example['rendered'] = img_rendered.tostring()\n      raw_example['depth'] = img_depth.tostring()\n      raw_example['real'] = img_ref.tostring()\n      if self.use_semantic_map:\n        raw_example['seg'] = img_seg.tostring()\n      self.iter_idx += 1\n      return raw_example\n    else:\n      raise StopIteration()\n\n\ndef filter_out_sparse_renders(dataset_dir, splits, ratio_threshold=0.15):\n  print('Filtering %s' % dataset_dir)\n  if splits is None:\n    imgs_dirs = [dataset_dir]\n  else:\n    imgs_dirs = [osp.join(dataset_dir, split) for split in splits]\n  \n  filtered_images = []\n  total_images = 0\n  sum_density = 0\n  for cur_dir in imgs_dirs:\n    filtered_dir = osp.join(cur_dir, 'sparse_renders')\n    if not osp.exists(filtered_dir):\n      os.makedirs(filtered_dir)\n    imgs_file_pattern = osp.join(cur_dir, '*_color.png')\n    images_path = sorted(glob.glob(imgs_file_pattern))\n    print('Processing %d files' % len(images_path))\n    total_images += len(images_path)\n    for ii, img_path in enumerate(images_path):\n      img = np.array(Image.open(img_path))\n      aggregate = np.squeeze(np.sum(img, axis=2))\n      height, width = aggregate.shape\n      mask = aggregate > 0\n      density = np.sum(mask) * 1. / (height * width)\n      sum_density += density\n      if density <= ratio_threshold:\n        parent, basename = osp.split(img_path)\n        basename = basename[:-10]  # remove the '_color.png' suffix\n        srcs = sorted(glob.glob(osp.join(parent, basename + '_*')))\n        dest = filtered_dir  # destination directory for filtered files\n        for src in srcs:\n          shutil.move(src, dest)\n        filtered_images.append(basename)\n        print('filtered file %d: %s with a density of %.3f' % (ii, basename,\n                                                               density))\n    print('Filtered %d/%d images' % (len(filtered_images), total_images))\n    print('Mean density = %.4f' % (sum_density / total_images))\n\n\ndef _to_example(dictionary):\n  \"\"\"Helper: build tf.Example from (string -> int/float/str list) dictionary.\"\"\"\n  features = {}\n  for (k, v) in six.iteritems(dictionary):\n    if isinstance(v, six.integer_types):\n      features[k] = tf.train.Feature(int64_list=tf.train.Int64List(value=[v]))\n    elif isinstance(v, float):\n      features[k] = tf.train.Feature(float_list=tf.train.FloatList(value=[v]))\n    elif isinstance(v, six.string_types):\n      features[k] = tf.train.Feature(bytes_list=tf.train.BytesList(value=[v]))\n    elif isinstance(v, bytes):\n      features[k] = tf.train.Feature(bytes_list=tf.train.BytesList(value=[v]))\n    else:\n      raise ValueError(\"Value for %s is not a recognized type; v: %s type: %s\" %\n                       (k, str(v[0]), str(type(v[0]))))\n\n  return tf.train.Example(features=tf.train.Features(feature=features))\n\n\ndef _generate_tfrecord_dataset(generator, output_name, output_dir):\n  \"\"\"Convert a dataset into TFRecord format.\"\"\"\n  output_file = os.path.join(output_dir, output_name)\n  tf.logging.info(\"Writing TFRecords to file %s\", output_file)\n  writer = tf.python_io.TFRecordWriter(output_file)\n\n  counter = 0\n  for case in generator:\n    if counter % 100 == 0:\n      print('Generating case %d for %s.' 
% (counter, output_name))\n    counter += 1\n    example = _to_example(case)\n    writer.write(example.SerializeToString())\n\n  writer.close()\n  return output_file\n\n\ndef export_aligned_dataset_to_tfrecord(\n    dataset_dir, output_dir, output_basename, splits,\n    xception_frozen_graph_path):\n\n  # Step 1: filter out sparse renders\n  filter_out_sparse_renders(dataset_dir, splits, 0.15)\n\n  # Step 2: generate semantic segmentation masks\n  segment_utils.segment_and_color_dataset(\n      dataset_dir, xception_frozen_graph_path, splits)\n\n  # Step 3: export dataset to TFRecord\n  if splits is None:\n    input_filepattern = osp.join(dataset_dir, '*_color.png')\n    dataset_iter = AlignedRenderedDataset(input_filepattern)\n    output_name = output_basename + '.tfrecord'\n    _generate_tfrecord_dataset(dataset_iter, output_name, output_dir)\n  else:\n    for split in splits:\n      input_filepattern = osp.join(dataset_dir, split, '*_color.png')\n      dataset_iter = AlignedRenderedDataset(input_filepattern)\n      output_name = '%s_%s.tfrecord' % (output_basename, split)\n      _generate_tfrecord_dataset(dataset_iter, output_name, output_dir)\n\n\ndef main(argv):\n  # Read input flags\n  dataset_name = opts.dataset_name\n  dataset_parent_dir = opts.dataset_parent_dir\n  output_dir = FLAGS.output_dir\n  xception_frozen_graph_path = FLAGS.xception_frozen_graph_path\n  splits = ['train', 'val']\n  # Run the preprocessing pipeline\n  export_aligned_dataset_to_tfrecord(\n    dataset_parent_dir, output_dir, dataset_name, splits,\n    xception_frozen_graph_path)\n\n\nif __name__ == '__main__':\n  app.run(main)\n"
  },
  {
    "path": "evaluate_quantitative_metrics.py",
    "content": "# Copyright 2019 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     https://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom PIL import Image\nfrom absl import app\nfrom absl import flags\nimport functools\nimport glob\nimport numpy as np\nimport os\nimport os.path as osp\nimport skimage.measure\nimport tensorflow as tf\nimport utils\n\nFLAGS = flags.FLAGS\nflags.DEFINE_string('val_set_out_dir', None,\n                    'Output directory with concatenated fake and real images.')\nflags.DEFINE_string('experiment_title', 'experiment',\n                    'Name for the experiment to evaluate')\n\n\ndef _extract_real_and_fake_from_concatenated_output(val_set_out_dir):\n      out_dir = osp.join(val_set_out_dir, 'fake')\n      gt_dir = osp.join(val_set_out_dir, 'real')\n      if not osp.exists(out_dir):\n        os.makedirs(out_dir)\n      if not osp.exists(gt_dir):\n        os.makedirs(gt_dir)\n      imgs_pattern = osp.join(val_set_out_dir, '*.png')\n      imgs_paths = sorted(glob.glob(imgs_pattern))\n      print('Separating %d images in %s.' % (len(imgs_paths), val_set_out_dir))\n      for img_path in imgs_paths:\n        basename = osp.basename(img_path)[:-4]  # remove the '.png' suffix\n        img = np.array(Image.open(img_path))\n        img_res = 512\n        fake = img[:, -2*img_res:-img_res, :]\n        real = img[:, -img_res:, :]\n        fake_path = osp.join(out_dir, '%s_fake.png' % basename)\n        real_path = osp.join(gt_dir, '%s_real.png' % basename)\n        Image.fromarray(fake).save(fake_path)\n        Image.fromarray(real).save(real_path)\n\n\ndef compute_l1_loss_metric(image_set1_paths, image_set2_paths):\n  assert len(image_set1_paths) == len(image_set2_paths)\n  assert len(image_set1_paths) > 0\n  print('Evaluating L1 loss for %d pairs' % len(image_set1_paths))\n\n  total_loss = 0.\n  for ii, (img1_path, img2_path) in enumerate(zip(image_set1_paths,\n                                                  image_set2_paths)):\n    img1_in_ar = np.array(Image.open(img1_path), dtype=np.float32)\n    img1_in_ar = utils.crop_to_multiple(img1_in_ar)\n\n    img2_in_ar = np.array(Image.open(img2_path), dtype=np.float32)\n    img2_in_ar = utils.crop_to_multiple(img2_in_ar)\n\n    loss_l1 = np.mean(np.abs(img1_in_ar - img2_in_ar))\n    total_loss += loss_l1\n\n  return total_loss / len(image_set1_paths)\n\n\ndef compute_psnr_loss_metric(image_set1_paths, image_set2_paths):\n  assert len(image_set1_paths) == len(image_set2_paths)\n  assert len(image_set1_paths) > 0\n  print('Evaluating PSNR loss for %d pairs' % len(image_set1_paths))\n\n  total_loss = 0.\n  for ii, (img1_path, img2_path) in enumerate(zip(image_set1_paths,\n                                                  image_set2_paths)):\n    img1_in_ar = np.array(Image.open(img1_path))\n    img1_in_ar = utils.crop_to_multiple(img1_in_ar)\n\n    img2_in_ar = np.array(Image.open(img2_path))\n    img2_in_ar = utils.crop_to_multiple(img2_in_ar)\n\n    loss_psnr = skimage.measure.compare_psnr(img1_in_ar, img2_in_ar)\n    total_loss += loss_psnr\n\n  return 
total_loss / len(image_set1_paths)\n\n\ndef evaluate_experiment(val_set_out_dir, title='experiment',\n                        metrics=['psnr', 'l1']):\n\n  out_dir = osp.join(val_set_out_dir, 'fake')\n  gt_dir = osp.join(val_set_out_dir, 'real')\n  _extract_real_and_fake_from_concatenated_output(val_set_out_dir)\n  input_pattern1 = osp.join(gt_dir, '*.png')\n  input_pattern2 = osp.join(out_dir, '*.png')\n  set1 = sorted(glob.glob(input_pattern1))\n  set2 = sorted(glob.glob(input_pattern2))\n  for metric in metrics:\n    if metric == 'l1':\n      mean_loss = compute_l1_loss_metric(set1, set2)\n    elif metric == 'psnr':\n      mean_loss = compute_psnr_loss_metric(set1, set2)\n    print('*** mean %s loss for %s = %f' % (metric, title, mean_loss))\n\n\ndef main(argv):\n  evaluate_experiment(FLAGS.val_set_out_dir, title=FLAGS.experiment_title,\n                      metrics=['psnr', 'l1'])\n\n\nif __name__ == '__main__':\n  app.run(main)\n"
  },
  {
    "path": "layers.py",
    "content": "# Copyright 2019 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     https://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport functools\nimport numpy as np\nimport tensorflow as tf\n\n\nclass LayerInstanceNorm(object):\n\n  def __init__(self, scope_suffix='instance_norm'):\n    curr_scope = tf.get_variable_scope().name\n    self._scope = curr_scope + '/' + scope_suffix\n\n  def __call__(self, x):\n    with tf.variable_scope(self._scope, reuse=tf.AUTO_REUSE):\n      return tf.contrib.layers.instance_norm(\n        x, epsilon=1e-05, center=True, scale=True)\n\n\ndef layer_norm(x, scope='layer_norm'):\n  return tf.contrib.layers.layer_norm(x, center=True, scale=True)\n\n\ndef pixel_norm(x):\n  \"\"\"Pixel normalization.\n\n  Args:\n    x: 4D image tensor in B01C format.\n\n  Returns:\n    4D tensor with pixel normalized channels.\n  \"\"\"\n  return x * tf.rsqrt(tf.reduce_mean(tf.square(x), [-1], keepdims=True) + 1e-8)\n\n\ndef global_avg_pooling(x):\n  return tf.reduce_mean(x, axis=[1, 2], keepdims=True)\n\n\nclass FullyConnected(object):\n\n  def __init__(self, n_out_units, scope_suffix='FC'):\n    weight_init = tf.random_normal_initializer(mean=0., stddev=0.02)\n    weight_regularizer = tf.contrib.layers.l2_regularizer(scale=0.0001)\n\n    curr_scope = tf.get_variable_scope().name\n    self._scope = curr_scope + '/' + scope_suffix\n    self.fc_layer = functools.partial(\n      tf.layers.dense, units=n_out_units, kernel_initializer=weight_init,\n      kernel_regularizer=weight_regularizer, use_bias=True)\n\n  def __call__(self, x):\n    with tf.variable_scope(self._scope, reuse=tf.AUTO_REUSE):\n      return self.fc_layer(x)\n\n\ndef init_he_scale(shape, slope=1.0):\n  \"\"\"He neural network random normal scaling for initialization.\n\n  Args:\n    shape: list of the dimensions of the tensor.\n    slope: float, slope of the ReLu following the layer.\n\n  Returns:\n    a float, He's standard deviation.\n  \"\"\"\n  fan_in = np.prod(shape[:-1])\n  return np.sqrt(2. / ((1. + slope**2) * fan_in))\n\n\nclass LayerConv(object):\n  \"\"\"Convolution layer with support for equalized learning.\"\"\"\n\n  def __init__(self,\n               name,\n               w,\n               n,\n               stride,\n               padding='SAME',\n               use_scaling=False,\n               relu_slope=1.):\n    \"\"\"Layer constructor.\n\n    Args:\n      name: string, layer name.\n      w: int or 2-tuple, width of the convolution kernel.\n      n: 2-tuple of ints, input and output channel depths.\n      stride: int or 2-tuple, stride for the convolution kernel.\n      padding: string, the padding method. 
{SAME, VALID, REFLECT}.\n      use_scaling: bool, whether to use weight norm and scaling.\n      relu_slope: float, the slope of the ReLu following the layer.\n    \"\"\"\n    assert padding in ['SAME', 'VALID', 'REFLECT'], 'Error: unsupported padding'\n    self._padding = padding\n    with tf.variable_scope(name):\n      if isinstance(stride, int):\n        stride = [1, stride, stride, 1]\n      else:\n        assert len(stride) == 2, \"stride is either an int or a 2-tuple\"\n        stride = [1, stride[0], stride[1], 1]\n      if isinstance(w, int):\n        w = [w, w]\n      self.w = w\n      shape = [w[0], w[1], n[0], n[1]]\n      init_scale, pre_scale = init_he_scale(shape, relu_slope), 1.\n      if use_scaling:\n        init_scale, pre_scale = pre_scale, init_scale\n      self._stride = stride\n      self._pre_scale = pre_scale\n      self._weight = tf.get_variable(\n          'weight',\n          shape=shape,\n          initializer=tf.random_normal_initializer(stddev=init_scale))\n      self._bias = tf.get_variable(\n          'bias', shape=[n[1]], initializer=tf.zeros_initializer)\n\n  def __call__(self, x):\n    \"\"\"Apply layer to tensor x.\"\"\"\n    if self._padding != 'REFLECT':\n      padding = self._padding\n    else:\n      padding = 'VALID'\n      pad_top = self.w[0] // 2\n      pad_left = self.w[1] // 2\n      if (self.w[0] - self._stride[1]) % 2 == 0:\n        pad_bottom = pad_top\n      else:\n        pad_bottom = self.w[0] - self._stride[1] - pad_top\n      if (self.w[1] - self._stride[2]) % 2 == 0:\n        pad_right = pad_left\n      else:\n        pad_right = self.w[1] - self._stride[2] - pad_left\n      x = tf.pad(x, [[0, 0], [pad_top, pad_bottom], [pad_left, pad_right],\n                     [0, 0]], mode='REFLECT')\n    y = tf.nn.conv2d(x, self._weight, strides=self._stride, padding=padding)\n    return self._pre_scale * y + self._bias\n\n\nclass LayerTransposedConv(object):\n  \"\"\"Transposed convolution layer with support for equalized learning.\"\"\"\n\n  def __init__(self,\n               name,\n               w,\n               n,\n               stride,\n               padding='SAME',\n               use_scaling=False,\n               relu_slope=1.):\n    \"\"\"Layer constructor.\n\n    Args:\n      name: string, layer name.\n      w: int or 2-tuple, width of the convolution kernel.\n      n: 2-tuple int, [n_in_channels, n_out_channels]\n      stride: int or 2-tuple, stride for the convolution kernel.\n      padding: string, the padding method (only 'SAME' is supported).\n      use_scaling: bool, whether to use weight norm and scaling.\n      relu_slope: float, the slope of the ReLu following the layer.\n    \"\"\"\n    assert padding in ['SAME'], 'Error: unsupported padding for transposed conv'\n    if isinstance(stride, int):\n      stride = [1, stride, stride, 1]\n    else:\n      assert len(stride) == 2, \"stride is either an int or a 2-tuple\"\n      stride = [1, stride[0], stride[1], 1]\n    if isinstance(w, int):\n      w = [w, w]\n    self.padding = padding\n    self.nc_in, self.nc_out = n\n    self.stride = stride\n    with tf.variable_scope(name):\n      kernel_shape = [w[0], w[1], self.nc_out, self.nc_in]\n      init_scale, pre_scale = init_he_scale(kernel_shape, relu_slope), 1.\n      if use_scaling:\n        init_scale, pre_scale = pre_scale, init_scale\n      self._pre_scale = pre_scale\n      self._weight = tf.get_variable(\n          'weight',\n          shape=kernel_shape,\n          initializer=tf.random_normal_initializer(stddev=init_scale))\n      
self._bias = tf.get_variable(\n          'bias', shape=[self.nc_out], initializer=tf.zeros_initializer)\n\n  def __call__(self, x):\n    \"\"\"Apply layer to tensor x.\"\"\"\n    x_shape = x.get_shape().as_list()\n    batch_size = tf.shape(x)[0]\n    stride_x, stride_y = self.stride[1], self.stride[2]\n    output_shape = tf.stack([\n      batch_size, x_shape[1] * stride_x, x_shape[2] * stride_y, self.nc_out])\n    y = tf.nn.conv2d_transpose(\n      x, filter=self._weight, output_shape=output_shape, strides=self.stride,\n      padding=self.padding)\n    return self._pre_scale * y + self._bias\n\n\nclass ResBlock(object):\n  def __init__(self,\n               name,\n               nc,\n               norm_layer_constructor,\n               activation,\n               padding='SAME',\n               use_scaling=False,\n               relu_slope=1.):\n    \"\"\"Layer constructor.\"\"\"\n    self.name = name\n    conv2d = functools.partial(\n        LayerConv, w=3, n=[nc, nc], stride=1, padding=padding,\n        use_scaling=use_scaling, relu_slope=relu_slope)\n    self.blocks = []\n    with tf.variable_scope(self.name):\n      with tf.variable_scope('res0'):\n        self.blocks.append(\n          LayerPipe([\n            conv2d('res0_conv'),\n            norm_layer_constructor('res0_norm'),\n            activation\n          ])\n        )\n      with tf.variable_scope('res1'):\n        self.blocks.append(\n          LayerPipe([\n            conv2d('res1_conv'),\n            norm_layer_constructor('res1_norm')\n          ])\n        )\n\n  def __call__(self, x_init):\n    \"\"\"Apply layer to tensor x.\"\"\"\n    x = x_init\n    for f in self.blocks:\n      x = f(x)\n    return x + x_init\n\n\nclass BasicBlock(object):\n  def __init__(self,\n               name,\n               n,\n               activation=functools.partial(tf.nn.leaky_relu, alpha=0.2),\n               padding='SAME',\n               use_scaling=True,\n               relu_slope=1.):\n    \"\"\"Layer constructor.\"\"\"\n    self.name = name\n    conv2d = functools.partial(\n        LayerConv, stride=1, padding=padding,\n        use_scaling=use_scaling, relu_slope=relu_slope)\n    avg_pool = functools.partial(downscale, n=2)\n    nc_in, nc_out = n  # n is a 2-tuple\n    with tf.variable_scope(self.name):\n      self.path1_blocks = []\n      with tf.variable_scope('bb_path1'):\n        self.path1_blocks.append(\n          LayerPipe([\n            activation,\n            conv2d('bb_conv0', w=3, n=[nc_in, nc_out]),\n            activation,\n            conv2d('bb_conv1', w=3, n=[nc_out, nc_out]),\n            downscale\n          ])\n        )\n\n      self.path2_blocks = []\n      with tf.variable_scope('bb_path2'):\n        self.path2_blocks.append(\n          LayerPipe([\n            downscale,\n            conv2d('path2_conv', w=1, n=[nc_in, nc_out])\n          ])\n        )\n\n  def __call__(self, x_init):\n    \"\"\"Apply layer to tensor x.\"\"\"\n    x1 = x_init\n    x2 = x_init\n    for f in self.path1_blocks:\n      x1 = f(x1)\n    for f in self.path2_blocks:\n      x2 = f(x2)\n    return x1 + x2\n\n\nclass LayerDense(object):\n  \"\"\"Dense layer with a non-linearity.\"\"\"\n\n  def __init__(self, name, n, use_scaling=False, relu_slope=1.):\n    \"\"\"Layer constructor.\n\n    Args:\n      name: string, layer name.\n      n: 2-tuple of ints, input and output widths.\n      use_scaling: bool, whether to use weight norm and scaling.\n      relu_slope: float, the slope of the ReLu following the layer.\n    \"\"\"\n    with 
tf.variable_scope(name):\n      init_scale, pre_scale = init_he_scale(n, relu_slope), 1.\n      if use_scaling:\n        init_scale, pre_scale = pre_scale, init_scale\n      self._pre_scale = pre_scale\n      self._weight = tf.get_variable(\n          'weight',\n          shape=n,\n          initializer=tf.random_normal_initializer(stddev=init_scale))\n      self._bias = tf.get_variable(\n          'bias', shape=[n[1]], initializer=tf.zeros_initializer)\n\n  def __call__(self, x):\n    \"\"\"Apply layer to tensor x.\"\"\"\n    return self._pre_scale * tf.matmul(x, self._weight) + self._bias\n\n\nclass LayerPipe(object):\n  \"\"\"Pipe a sequence of functions.\"\"\"\n\n  def __init__(self, functions):\n    \"\"\"Layer constructor.\n\n    Args:\n      functions: list, functions to pipe.\n    \"\"\"\n    self._functions = tuple(functions)\n\n  def __call__(self, x, **kwargs):\n    \"\"\"Apply pipe to tensor x and return result.\"\"\"\n    del kwargs\n    for f in self._functions:\n      x = f(x)\n    return x\n\n\ndef downscale(x, n=2):\n  \"\"\"Box downscaling.\n\n  Args:\n    x: 4D image tensor.\n    n: integer scale (must be a power of 2).\n\n  Returns:\n    4D tensor of images down scaled by a factor n.\n  \"\"\"\n  if n == 1:\n    return x\n  return tf.nn.avg_pool(x, [1, n, n, 1], [1, n, n, 1], 'VALID')\n\n\ndef upscale(x, n):\n  \"\"\"Box upscaling (also called nearest neighbors).\n\n  Args:\n    x: 4D image tensor in B01C format.\n    n: integer scale (must be a power of 2).\n\n  Returns:\n    4D tensor of images up scaled by a factor n.\n  \"\"\"\n  if n == 1:\n    return x\n  x_shape = tf.shape(x)\n  height, width = x_shape[1], x_shape[2]\n  return tf.image.resize_nearest_neighbor(x, [n * height, n * width])\n\n\ndef tile_and_concatenate(x, z, n_z):\n  z = tf.reshape(z, shape=[-1, 1, 1, n_z])\n  z = tf.tile(z, [1, tf.shape(x)[1], tf.shape(x)[2], 1])\n  x = tf.concat([x, z], axis=-1)\n  return x\n\n\ndef minibatch_mean_variance(x):\n  \"\"\"Computes the variance average.\n\n  This is used by the discriminator as a form of batch discrimination.\n\n  Args:\n    x: nD tensor for which to compute variance average.\n\n  Returns:\n    a scalar, the mean variance of variable x.\n  \"\"\"\n  mean = tf.reduce_mean(x, 0, keepdims=True)\n  vals = tf.sqrt(tf.reduce_mean(tf.squared_difference(x, mean), 0) + 1e-8)\n  vals = tf.reduce_mean(vals)\n  return vals\n\n\ndef scalar_concat(x, scalar):\n  \"\"\"Concatenate a scalar to a 4D tensor as an extra channel.\n\n  Args:\n    x: 4D image tensor in B01C format.\n    scalar: a scalar to concatenate to the tensor.\n\n  Returns:\n    a 4D tensor with one extra channel containing the value scalar at\n     every position.\n  \"\"\"\n  s = tf.shape(x)\n  return tf.concat([x, tf.ones([s[0], s[1], s[2], 1]) * scalar], axis=3)\n"
  },
  {
    "path": "losses.py",
    "content": "# Copyright 2019 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     https://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom options import FLAGS as opts\nimport layers\nimport os.path as osp\nimport tensorflow as tf\nimport vgg16\n\n\ndef gradient_penalty_loss(y_xy, xy, iwass_target=1, iwass_lambda=10):\n  grad = tf.gradients(tf.reduce_sum(y_xy), [xy])[0]\n  grad_norm = tf.sqrt(tf.reduce_sum(tf.square(grad), axis=[1, 2, 3]) + 1e-8)\n  loss_gp = tf.reduce_mean(\n      tf.square(grad_norm - iwass_target)) * iwass_lambda / iwass_target**2\n  return loss_gp\n\n\ndef KL_loss(mean, logvar):\n  loss = 0.5 * tf.reduce_sum(tf.square(mean) + tf.exp(logvar) - 1. - logvar,\n                             axis=-1)\n  return tf.reduce_sum(loss)  # just to match DRIT implementation\n\n\ndef l2_regularize(x):\n  return tf.reduce_mean(tf.square(x))\n\n\ndef L1_loss(x, y):\n  return tf.reduce_mean(tf.abs(x - y))\n\n\nclass PerceptualLoss:\n  def __init__(self, x, y, image_shape, layers, w_layers, w_act=0.1):\n    \"\"\"\n    Builds vgg16 network and computes the perceptual loss.\n    \"\"\"\n    assert len(image_shape) == 3 and image_shape[-1] == 3\n    assert osp.exists(opts.vgg16_path), 'Cannot find %s' % opts.vgg16_path\n\n    self.w_act = w_act\n    self.vgg_layers = layers\n    self.w_layers = w_layers\n    batch_shape = [None] + image_shape  # [None, H, W, 3]\n\n    vgg_net = vgg16.Vgg16(opts.vgg16_path)\n    self.x_acts = vgg_net.get_vgg_activations(x, layers)\n    self.y_acts = vgg_net.get_vgg_activations(y, layers)\n    loss = 0\n    for w, act1, act2 in zip(self.w_layers, self.x_acts, self.y_acts):\n      loss += w * tf.reduce_mean(tf.square(self.w_act * (act1 - act2)))\n    self.loss = loss\n\n  def __call__(self):\n    return self.loss\n\n\ndef lsgan_appearance_E_loss(disc_response):\n  disc_response = tf.squeeze(disc_response)\n  gt_label = 0.5\n  loss = tf.reduce_mean(tf.square(disc_response - gt_label))\n  return loss\n\n\ndef lsgan_loss(disc_response, is_real):\n  gt_label = 1 if is_real else 0\n  disc_response = tf.squeeze(disc_response)\n  # The following works for both regular and patchGAN discriminators\n  loss = tf.reduce_mean(tf.square(disc_response - gt_label))\n  return loss\n\n\ndef multiscale_discriminator_loss(Ds_responses, is_real):\n  num_D = len(Ds_responses)\n  loss = 0\n  for i in range(num_D):\n    curr_response = Ds_responses[i][-1][-1]\n    loss += lsgan_loss(curr_response, is_real)\n  return loss\n"
  },
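As a quick numeric check of the closed-form divergence implemented by KL_loss in losses.py, the following sketch (illustration only, not repo code) assumes a diagonal Gaussian posterior parameterized by mean and logvar.

import numpy as np

def kl_loss_np(mean, logvar):
  # KL(N(mean, exp(logvar)) || N(0, 1)) summed over latent dims, then summed
  # over the batch to match the DRIT-style reduction used in losses.KL_loss.
  per_example = 0.5 * np.sum(np.square(mean) + np.exp(logvar) - 1. - logvar,
                             axis=-1)
  return np.sum(per_example)

mean = np.zeros((2, 8), np.float32)
logvar = np.zeros((2, 8), np.float32)
print(kl_loss_np(mean, logvar))  # 0.0: a standard-normal posterior has zero KL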
  {
    "path": "networks.py",
    "content": "# Copyright 2019 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     https://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom options import FLAGS as opts\nimport functools\nimport layers\nimport tensorflow as tf\n\n\nclass RenderingModel(object):\n\n  def __init__(self, model_name, use_appearance=True):\n\n    if model_name == 'pggan':\n      self._model = ModelPGGAN(use_appearance)\n    else:\n      raise ValueError('Model %s not implemented!' % model_name)\n\n  def __call__(self, x_in, z_app=None):\n    return self._model(x_in, z_app)\n\n  def get_appearance_encoder(self):\n    return self._model._appearance_encoder\n\n  def get_generator(self):\n    return self._model._generator\n\n  def get_content_encoder(self):\n    return self._model._content_encoder\n\n\n# \"Progressive Growing of GANs (PGGAN)\"-inspired architecture. Implementation is\n# based on the implementation details in their paper, but code is not taken from\n# the authors' released code.\n# Main changes are:\n#  - conditional GAN setup by introducting an encoder + skip connections.\n#  - no progressive growing during training.\nclass ModelPGGAN(RenderingModel):\n\n  def __init__(self, use_appearance=True):\n    self._use_appearance = use_appearance\n    self._content_encoder = None\n    self._generator = GeneratorPGGAN(appearance_vec_size=opts.app_vector_size)\n    if use_appearance:\n      self._appearance_encoder = DRITAppearanceEncoderConcat(\n          'appearance_net', opts.appearance_nc, opts.normalize_drit_Ez)\n    else:\n      self._appearance_encoder = None\n\n  def __call__(self, x_in, z_app=None):\n    y = self._generator(x_in, z_app)\n    return y\n\n  def get_appearance_encoder(self):\n    return self._appearance_encoder\n\n  def get_generator(self):\n    return self._generator\n\n  def get_content_encoder(self):\n    return self._content_encoder\n\n\nclass PatchGANDiscriminator(object):\n\n  def __init__(self, name_scope, input_nc, nf=64, n_layers=3, get_fmaps=False):\n    \"\"\"Constructor for a patchGAN discriminators.\n\n    Args:\n      name_scope: str - tf name scope.\n      input_nc: int - number of input channels.\n      nf: int - starting number of discriminator filters.\n      n_layers: int - number of layers in the discriminator.\n      get_fmaps: bool - return intermediate feature maps for FeatLoss.\n    \"\"\"\n    self.get_fmaps = get_fmaps\n    self.n_layers = n_layers\n    kw = 4  # kernel width for convolution\n\n    activation = functools.partial(tf.nn.leaky_relu, alpha=0.2)\n    norm_layer = functools.partial(layers.LayerInstanceNorm)\n    conv2d = functools.partial(layers.LayerConv, use_scaling=opts.use_scaling,\n                               relu_slope=0.2)\n\n    def minibatch_stats(x):\n      return layers.scalar_concat(x, layers.minibatch_mean_variance(x))\n\n    # Create layers.\n    self.blocks = []\n    with tf.variable_scope(name_scope, tf.AUTO_REUSE):\n      with tf.variable_scope('block_0'):\n        self.blocks.append([\n            conv2d('conv0', w=kw, n=[input_nc, nf], stride=2),\n            
activation\n        ])\n      for ii_block in range(1, n_layers):\n        nf_prev = nf\n        nf = min(nf * 2, 512)\n        with tf.variable_scope('block_%d' % ii_block):\n          self.blocks.append([\n              conv2d('conv%d' % ii_block, w=kw, n=[nf_prev, nf], stride=2),\n              norm_layer(),\n              activation\n          ])\n      # Add minibatch_stats (from PGGAN) and do a stride1 convolution.\n      nf_prev = nf\n      nf = min(nf * 2, 512)\n      with tf.variable_scope('block_%d' % (n_layers + 1)):\n        self.blocks.append([\n            minibatch_stats,  # this is improvised by @meshry\n            conv2d('conv%d' % (n_layers + 1), w=kw, n=[nf_prev + 1, nf],\n                   stride=1),\n            norm_layer(),\n            activation\n        ])\n      # Get 1-channel patchGAN logits\n      with tf.variable_scope('patchGAN_logits'):\n        self.blocks.append([\n            conv2d('conv%d' % (n_layers + 2), w=kw, n=[nf, 1], stride=1)\n        ])\n\n  def __call__(self, x, x_cond=None):\n    # Concatenate extra conditioning input, if any.\n    if x_cond is not None:\n      x = tf.concat([x, x_cond], axis=3)\n\n    if self.get_fmaps:\n      # Dummy addition of x to D_fmaps, which will be removed before returning\n      D_fmaps = [[x]]\n      for i_block in range(len(self.blocks)):\n        # Apply layer #0 in the current block\n        block_fmaps = [self.blocks[i_block][0](D_fmaps[-1][-1])]\n        # Apply the remaining layers of this block\n        for i_layer in range(1, len(self.blocks[i_block])):\n          block_fmaps.append(self.blocks[i_block][i_layer](block_fmaps[-1]))\n        # Append the feature maps of this block to D_fmaps\n        D_fmaps.append(block_fmaps)\n      return D_fmaps[1:]  # exclude the input x which we added initially\n    else:\n      y = x\n      for i_block in range(len(self.blocks)):\n        for i_layer in range(len(self.blocks[i_block])):\n          y = self.blocks[i_block][i_layer](y)\n      return [[y]]\n\n\nclass MultiScaleDiscriminator(object):\n\n  def __init__(self, name_scope, input_nc, num_scales=3, nf=64, n_layers=3,\n               get_fmaps=False):\n    self.get_fmaps = get_fmaps\n    discs = []\n    with tf.variable_scope(name_scope):\n      for i in range(num_scales):\n        discs.append(PatchGANDiscriminator(\n            'D_scale%d' % i, input_nc, nf=nf, n_layers=n_layers,\n            get_fmaps=get_fmaps))\n    self.discriminators = discs\n\n  def __call__(self, x, x_cond=None, params=None):\n    del params\n    if x_cond is not None:\n      x = tf.concat([x, x_cond], axis=3)\n\n    responses = []\n    for ii, D in enumerate(self.discriminators):\n      responses.append(D(x, x_cond=None))  # x_cond is already concatenated\n      if ii != len(self.discriminators) - 1:\n        x = layers.downscale(x, n=2)\n    return responses\n\n\nclass GeneratorPGGAN(object):\n  def __init__(self, appearance_vec_size=8, use_scaling=True,\n               num_blocks=5, input_nc=7,\n               fmap_base=8192, fmap_decay=1.0, fmap_max=512):\n    \"\"\"Generator model.\n  \n    Args:\n      appearance_vec_size: int, size of the latent appearance vector.\n      use_scaling: bool, whether to use weight scaling.\n      num_blocks: int, number of encoder/decoder resolution blocks.\n      input_nc: int, number of input channels.\n      fmap_base: int, base number of channels.\n      fmap_decay: float, decay rate of channels with respect to depth.\n      fmap_max: int, max number of channels (supersedes fmap_base).\n  \n    
\"\"\"\n    def _num_filters(fmap_base, fmap_decay, fmap_max, stage):\n      if opts.g_nf == 32:\n        return min(int(2**(10 - stage)), fmap_max)  # nf32\n      elif opts.g_nf == 64:\n        return min(int(2**(11 - stage)), fmap_max)  # nf64\n      else:\n        raise ValueError('Currently unsupported num filters')\n\n    nf = functools.partial(_num_filters, fmap_base, fmap_decay, fmap_max)\n    self.num_blocks = num_blocks\n    activation = functools.partial(tf.nn.leaky_relu, alpha=0.2)\n    conv2d_stride1 = functools.partial(\n        layers.LayerConv, stride=1, use_scaling=use_scaling, relu_slope=0.2)\n    conv2d_rgb = functools.partial(layers.LayerConv, w=1, stride=1,\n                                   use_scaling=use_scaling)\n  \n    # Create encoder layers.\n    with tf.variable_scope('g_model_enc', tf.AUTO_REUSE):\n      self.enc_stage = []\n      self.from_rgb = []\n\n      if opts.use_appearance and opts.inject_z == 'to_encoder':\n        input_nc += appearance_vec_size\n  \n      for i in range(num_blocks, -1, -1):\n        with tf.variable_scope('res_%d' % i):\n          self.from_rgb.append(\n              layers.LayerPipe([\n                  conv2d_rgb('from_rgb', n=[input_nc, nf(i + 1)]),\n                  activation,\n              ])\n          )\n          self.enc_stage.append(\n              layers.LayerPipe([\n                  functools.partial(layers.downscale, n=2),\n                  conv2d_stride1('conv0', w=3, n=[nf(i + 1), nf(i)]),\n                  activation,\n                  layers.pixel_norm,\n                  conv2d_stride1('conv1', w=3, n=[nf(i), nf(i)]),\n                  activation,\n                  layers.pixel_norm\n              ])\n          )\n  \n    # Create decoder layers.\n    with tf.variable_scope('g_model_dec', tf.AUTO_REUSE):\n      self.dec_stage = []\n      self.to_rgb = []\n  \n      nf_bottleneck = nf(0)  # num input filters at the bottleneck\n      if opts.use_appearance and opts.inject_z == 'to_bottleneck':\n        nf_bottleneck += appearance_vec_size\n\n      with tf.variable_scope('res_0'):\n        self.dec_stage.append(\n          layers.LayerPipe([\n            functools.partial(layers.upscale, n=2),\n            conv2d_stride1('conv0', w=3, n=[nf_bottleneck, nf(1)]),\n            activation,\n            layers.pixel_norm,\n            conv2d_stride1('conv1', w=3, n=[nf(1), nf(1)]),\n            activation,\n            layers.pixel_norm\n          ])\n        )\n        self.to_rgb.append(conv2d_rgb('to_rgb', n=[nf(1), opts.output_nc]))\n  \n      multiply_factor = 2 if opts.concatenate_skip_layers else 1\n      for i in range(1, num_blocks + 1):\n        with tf.variable_scope('res_%d' % i):\n          self.dec_stage.append(\n              layers.LayerPipe([\n                  functools.partial(layers.upscale, n=2),\n                  conv2d_stride1('conv0', w=3,\n                                 n=[multiply_factor * nf(i), nf(i + 1)]),\n                  activation,\n                  layers.pixel_norm,\n                  conv2d_stride1('conv1', w=3, n=[nf(i + 1), nf(i + 1)]),\n                  activation,\n                  layers.pixel_norm\n              ])\n          )\n          self.to_rgb.append(conv2d_rgb('to_rgb',\n                                        n=[nf(i + 1), opts.output_nc]))\n\n  def __call__(self, x, appearance_embedding=None, encoder_fmaps=None):\n    \"\"\"Generator function.\n\n    Args:\n      x: 4D image tensor (NHWC), the conditioning 
input batch of images.\n      appearance_embedding: float tensor: latent appearance vector.\n    Returns:\n      4D tensor of images (NHWC), the generated images.\n    \"\"\"\n    del encoder_fmaps\n    enc_st_idx = 0\n    if opts.use_appearance and opts.inject_z == 'to_encoder':\n      x = layers.tile_and_concatenate(x, appearance_embedding,\n                                      opts.app_vector_size)\n    y = self.from_rgb[enc_st_idx](x)\n\n    enc_responses = []\n    for i in range(enc_st_idx, len(self.enc_stage)):\n      y = self.enc_stage[i](y)\n      enc_responses.insert(0, y)\n\n    # Concatenate appearance vector to y\n    if opts.use_appearance and opts.inject_z == 'to_bottleneck':\n      appearance_tensor = tf.tile(appearance_embedding,\n                                  [1, tf.shape(y)[1], tf.shape(y)[2], 1])\n      y = tf.concat([y, appearance_tensor], axis=3)\n\n    y_list = []\n    for i in range(self.num_blocks + 1):\n      if i > 0:\n        y_skip = enc_responses[i]  # skip layer\n        if opts.concatenate_skip_layers:\n          y = tf.concat([y, y_skip], axis=3)\n        else:\n          y = y + y_skip\n      y = self.dec_stage[i](y)\n      y_list.append(y)\n\n    return self.to_rgb[self.num_blocks](y_list[-1])\n\n\nclass DRITAppearanceEncoderConcat(object):\n\n  def __init__(self, name_scope, input_nc, normalize_encoder):\n    self.blocks = []\n    activation = functools.partial(tf.nn.leaky_relu, alpha=0.2)\n    conv2d = functools.partial(layers.LayerConv, use_scaling=opts.use_scaling,\n                               relu_slope=0.2, padding='SAME')\n    with tf.variable_scope(name_scope, tf.AUTO_REUSE):\n      if normalize_encoder:\n        self.blocks.append(layers.LayerPipe([\n            conv2d('conv0', w=4, n=[input_nc, 64], stride=2),\n            layers.BasicBlock('BB0', n=[64, 128], use_scaling=opts.use_scaling),\n            layers.pixel_norm,\n            layers.BasicBlock('BB1', n=[128, 192], use_scaling=opts.use_scaling),\n            layers.pixel_norm,\n            layers.BasicBlock('BB2', n=[192, 256], use_scaling=opts.use_scaling),\n            layers.pixel_norm,\n            activation,\n            layers.global_avg_pooling\n        ]))\n      else:\n        self.blocks.append(layers.LayerPipe([\n            conv2d('conv0', w=4, n=[input_nc, 64], stride=2),\n            layers.BasicBlock('BB0', n=[64, 128], use_scaling=opts.use_scaling),\n            layers.BasicBlock('BB1', n=[128, 192], use_scaling=opts.use_scaling),\n            layers.BasicBlock('BB2', n=[192, 256], use_scaling=opts.use_scaling),\n            activation,\n            layers.global_avg_pooling\n        ]))\n      # FC layers to get the mean and logvar\n      self.fc_mean = layers.FullyConnected(opts.app_vector_size, 'FC_mean')\n      self.fc_logvar = layers.FullyConnected(opts.app_vector_size, 'FC_logvar')\n\n  def __call__(self, x):\n    for f in self.blocks:\n      x = f(x)\n\n    mean = self.fc_mean(x)\n    logvar = self.fc_logvar(x)\n    # The following is an arbitrarily chosen *deterministic* latent vector\n    # computation. Another option is to let z = mean, but gradients from logvar\n    # will be None and will need to be removed.\n    z = mean + tf.exp(0.5 * logvar)\n    return z, mean, logvar\n"
  },
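The appearance encoder above deliberately returns a deterministic latent z = mean + exp(0.5 * logvar), as its inline comment notes. For comparison, this NumPy sketch (illustration only, not repo code) contrasts that choice with the usual stochastic reparameterization z = mean + exp(0.5 * logvar) * eps.

import numpy as np

def deterministic_z(mean, logvar):
  # Same arithmetic as DRITAppearanceEncoderConcat.__call__: no sampling.
  return mean + np.exp(0.5 * logvar)

def reparameterized_z(mean, logvar, rng=np.random):
  # Standard VAE-style reparameterization, shown only for comparison.
  eps = rng.standard_normal(mean.shape).astype(mean.dtype)
  return mean + np.exp(0.5 * logvar) * eps

mean = np.zeros((1, 8), np.float32)
logvar = np.full((1, 8), -2.0, dtype=np.float32)
print(deterministic_z(mean, logvar))    # always the same vector
print(reparameterized_z(mean, logvar))  # a different draw every call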
  {
    "path": "neural_rerendering.py",
    "content": "# Copyright 2019 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     https://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom PIL import Image\nfrom absl import app\nfrom options import FLAGS as opts\nimport data\nimport datetime\nimport functools\nimport glob\nimport losses\nimport networks\nimport numpy as np\nimport options\nimport os.path as osp\nimport random\nimport skimage.measure\nimport staged_model\nimport tensorflow as tf\nimport time\nimport utils\n\n\ndef build_model_fn(use_exponential_moving_average=True):\n  \"\"\"Builds and returns the model function for an estimator.\n\n  Args:\n    use_exponential_moving_average: bool. If true, the exponential moving\n    average will be used.\n\n  Returns:\n    function, the model_fn function typically required by an estimator.\n  \"\"\"\n  arch_type = opts.arch_type\n  use_appearance = opts.use_appearance\n  def model_fn(features, labels, mode, params):\n    \"\"\"An estimator build_fn.\"\"\"\n    del labels, params\n    if mode == tf.estimator.ModeKeys.TRAIN:\n      step = tf.train.get_global_step()\n\n      x_in = features['conditional_input']\n      x_gt = features['expected_output']  # ground truth output\n      x_app = features['peek_input']\n\n      if opts.training_pipeline == 'staged':\n        ops = staged_model.create_computation_graph(x_in, x_gt, x_app=x_app,\n                                                    arch_type=opts.arch_type)\n        op_increment_step = tf.assign_add(step, 1)\n        train_disc_op = ops['train_disc_op']\n        train_renderer_op = ops['train_renderer_op']\n        train_op = tf.group(train_disc_op, train_renderer_op, op_increment_step)\n\n        utils.HookReport.log_tensor(ops['total_loss_d'], 'total_loss_d')\n        utils.HookReport.log_tensor(ops['loss_d_real'], 'loss_d_real')\n        utils.HookReport.log_tensor(ops['loss_d_fake'], 'loss_d_fake')\n        utils.HookReport.log_tensor(ops['total_loss_g'], 'total_loss_g')\n        utils.HookReport.log_tensor(ops['loss_g_gan'], 'loss_g_gan')\n        utils.HookReport.log_tensor(ops['loss_g_recon'], 'loss_g_recon')\n        utils.HookReport.log_tensor(step, 'global_step')\n\n        return tf.estimator.EstimatorSpec(\n            mode=mode, loss=ops['total_loss_d'] + ops['total_loss_g'],\n            train_op=train_op)\n      else:\n        raise NotImplementedError('%s training is not implemented.' 
%\n                                  opts.training_pipeline)\n    elif mode == tf.estimator.ModeKeys.EVAL:\n      raise NotImplementedError('Eval is not implemented.')\n    else:  # all below modes are for difference inference tasks.\n      # Build network and initialize inference variables.\n      g_func = networks.RenderingModel(arch_type, use_appearance)\n      if use_appearance:\n        app_func = g_func.get_appearance_encoder()\n      if use_exponential_moving_average:\n        ema = tf.train.ExponentialMovingAverage(decay=0.999)\n        var_dict = ema.variables_to_restore()\n        tf.train.init_from_checkpoint(osp.join(opts.train_dir), var_dict)\n\n      if mode == tf.estimator.ModeKeys.PREDICT:\n        x_in = features['conditional_input']\n        if use_appearance:\n          x_app = features['peek_input']\n          x_app_embedding, _, _ = app_func(x_app)\n        else:\n          x_app_embedding = None\n        y = g_func(x_in, x_app_embedding)\n        tf.logging.info('DBG: shape of y during prediction %s.' % str(y.shape))\n        return tf.estimator.EstimatorSpec(mode=mode, predictions=y)\n\n      # 'eval_subset' mode is same as PREDICT but it concatenates the output to\n      # the input render, semantic map and ground truth for easy comparison.\n      elif mode == 'eval_subset':\n        x_in = features['conditional_input']\n        x_gt = features['expected_output']\n        if use_appearance:\n          x_app = features['peek_input']\n          x_app_embedding, _, _ = app_func(x_app)\n        else:\n          x_app_embedding = None\n        y = g_func(x_in, x_app_embedding)\n        tf.logging.info('DBG: shape of y during prediction %s.' % str(y.shape))\n        x_in_rgb = tf.slice(x_in, [0, 0, 0, 0], [-1, -1, -1, 3])\n        if opts.use_semantic:\n          x_in_semantic = tf.slice(x_in, [0, 0, 0, 4], [-1, -1, -1, 3])\n          output_tuple = tf.concat([x_in_rgb, x_in_semantic, y, x_gt], axis=2)\n        else:\n          output_tuple = tf.concat([x_in_rgb, y, x_gt], axis=2)\n        return tf.estimator.EstimatorSpec(mode=mode, predictions=output_tuple)\n\n      # 'compute_appearance' mode computes and returns the latent z vector.\n      elif mode == 'compute_appearance':\n        assert use_appearance, 'use_appearance is set to False!'\n        x_app_in = features['peek_input']\n        # NOTE the following line is a temporary hack (which is\n        # specially bad for inputs smaller than 512x512).\n        x_app_in = tf.image.resize_image_with_crop_or_pad(x_app_in, 512, 512)\n        app_embedding, _, _ = app_func(x_app_in)\n        return tf.estimator.EstimatorSpec(mode=mode, predictions=app_embedding)\n\n      # 'interpolate_appearance' mode expects an already computed latent z\n      # vector as input passed a value to the dict key 'appearance_embedding'.\n      elif mode == 'interpolate_appearance':\n        assert use_appearance, 'use_appearance is set to False!'\n        x_in = features['conditional_input']\n        x_app_embedding = features['appearance_embedding']\n        y = g_func(x_in, x_app_embedding)\n        tf.logging.info('DBG: shape of y during prediction %s.' 
% str(y.shape))\n        return tf.estimator.EstimatorSpec(mode=mode, predictions=y)\n      else:\n        raise ValueError('Unsupported mode: ' + mode)\n\n  return model_fn\n\n\ndef make_sample_grid_and_save(est, dataset_name, dataset_parent_dir, grid_dims,\n                              output_dir, cur_nimg):\n  \"\"\"Evaluate a fixed set of validation images and save output.\n\n  Args:\n    est: tf,estimator.Estimator, TF estimator to run the predictions.\n    dataset_name: basename for the validation tfrecord from which to load\n      validation images.\n    dataset_parent_dir: path to a directory containing the validation tfrecord.\n    grid_dims: 2-tuple int for the grid size (1 unit = 1 image).\n    output_dir: string, where to save image samples.\n    cur_nimg: int, current number of images seen by training.\n\n  Returns:\n    None.\n  \"\"\"\n  num_examples = grid_dims[0] * grid_dims[1]\n  def input_val_fn():\n    dict_inp = data.provide_data(\n        dataset_name=dataset_name, parent_dir=dataset_parent_dir, subset='val',\n        batch_size=1, crop_flag=True, crop_size=opts.train_resolution,\n        seeds=[0], max_examples=num_examples,\n        use_appearance=opts.use_appearance, shuffle=0)\n    x_in = dict_inp['conditional_input']\n    x_gt = dict_inp['expected_output']  # ground truth output\n    x_app = dict_inp['peek_input']\n    return x_in, x_gt, x_app\n\n  def est_input_val_fn():\n    x_in, _, x_app = input_val_fn()\n    features = {'conditional_input': x_in, 'peek_input': x_app}\n    return features\n\n  images = [x for x in est.predict(est_input_val_fn)]\n  images = np.array(images, 'f')\n  images = images.reshape(grid_dims + images.shape[1:])\n  utils.save_images(utils.to_png(utils.images_to_grid(images)), output_dir,\n                    cur_nimg)\n\n\ndef visualize_image_sequence(est, dataset_name, dataset_parent_dir,\n                             input_sequence_name, app_base_path, output_dir):\n  \"\"\"Generates an image sequence as a video and stores it to disk.\"\"\"\n  batch_sz = opts.batch_size\n  def input_seq_fn():\n    dict_inp = data.provide_data(\n        dataset_name=dataset_name, parent_dir=dataset_parent_dir,\n        subset=input_sequence_name, batch_size=batch_sz, crop_flag=False,\n        seeds=None, use_appearance=False, shuffle=0)\n    x_in = dict_inp['conditional_input']\n    return x_in\n\n  # Compute appearance embedding only once and use it for all input frames.\n  app_rgb_path = app_base_path + '_reference.png'\n  app_rendered_path = app_base_path + '_color.png'\n  app_depth_path = app_base_path + '_depth.png'\n  app_sem_path = app_base_path + '_seg_rgb.png'\n  x_app = _load_and_concatenate_image_channels(\n      app_rgb_path, app_rendered_path, app_depth_path, app_sem_path)\n  def seq_with_single_appearance_inp_fn():\n    \"\"\"input frames with a fixed latent appearance vector.\"\"\"\n    x_in_op = input_seq_fn()\n    x_app_op = tf.convert_to_tensor(x_app)\n    x_app_tiled_op = tf.tile(x_app_op, [tf.shape(x_in_op)[0], 1, 1, 1])\n    return {'conditional_input': x_in_op,\n            'peek_input': x_app_tiled_op}\n\n  images = [x for x in est.predict(seq_with_single_appearance_inp_fn)]\n  for i, gen_img in enumerate(images):\n    output_file_path = osp.join(output_dir, 'out_%04d.png' % i)\n    print('Saving frame #%d to %s' % (i, output_file_path))\n    with tf.gfile.Open(output_file_path, 'wb') as f:\n      f.write(utils.to_png(gen_img))\n\n\ndef train(dataset_name, dataset_parent_dir, load_pretrained_app_encoder,\n          
load_trained_fixed_app, save_samples_kimg=50):\n  \"\"\"Main training procedure.\n\n  The trained model is saved in opts.train_dir; the function itself does not\n  return anything.\n\n  Args:\n    dataset_name: str, name ID for the training dataset tfrecord.\n    dataset_parent_dir: str, path to a directory containing the dataset\n     tfrecords.\n    load_pretrained_app_encoder: bool, whether to warm-start the appearance\n     encoder from pretrained weights.\n    load_trained_fixed_app: bool, whether to warm-start from a model trained\n     with a fixed appearance encoder.\n    save_samples_kimg: int, period (in kimg, i.e. thousands of training\n     images) to save sample validation images.\n\n  Returns:\n    None.\n  \"\"\"\n  image_dir = osp.join(opts.train_dir, 'images')  # to save validation images.\n  tf.gfile.MakeDirs(image_dir)\n  config = tf.estimator.RunConfig(\n      save_summary_steps=(1 << 10) // opts.batch_size,\n      save_checkpoints_steps=(save_samples_kimg << 10) // opts.batch_size,\n      keep_checkpoint_max=5,\n      log_step_count_steps=1 << 30)\n  model_dir = opts.train_dir\n  if (opts.use_appearance and load_trained_fixed_app and\n      not tf.train.latest_checkpoint(model_dir)):\n    tf.logging.warning('***** Loading resume_step from %s!' %\n                       opts.fixed_appearance_train_dir)\n    resume_step = utils.load_global_step_from_checkpoint_dir(\n        opts.fixed_appearance_train_dir)\n  else:\n    tf.logging.warning('***** Loading resume_step (if any) from %s!' %\n                       model_dir)\n    resume_step = utils.load_global_step_from_checkpoint_dir(model_dir)\n  if resume_step != 0:\n    tf.logging.warning('****** Resuming training at %d!' % resume_step)\n\n  model_fn = build_model_fn()  # model function for TFEstimator.\n\n  hooks = [utils.HookReport(1 << 12, opts.batch_size)]\n\n  if opts.use_appearance and load_pretrained_app_encoder:\n    tf.logging.warning('***** will warm-start from %s!' %\n                       opts.appearance_pretrain_dir)\n    ws = tf.estimator.WarmStartSettings(\n        ckpt_to_initialize_from=opts.appearance_pretrain_dir,\n        vars_to_warm_start='appearance_net/.*')\n  elif opts.use_appearance and load_trained_fixed_app:\n    tf.logging.warning('****** finetuning will warm-start from %s!' 
%\n                       opts.fixed_appearance_train_dir)\n    ws = tf.estimator.WarmStartSettings(\n        ckpt_to_initialize_from=opts.fixed_appearance_train_dir,\n        vars_to_warm_start='.*')\n  else:\n    ws = None\n    tf.logging.warning('****** No warm-starting; using random initialization!')\n\n  est = tf.estimator.Estimator(model_fn, model_dir, config, params={},\n                               warm_start_from=ws)\n\n  for next_kimg in range(opts.save_samples_kimg, opts.total_kimg + 1,\n                         opts.save_samples_kimg):\n    next_step = (next_kimg << 10) // opts.batch_size\n    if opts.num_crops == -1:  # use random crops\n      crop_seeds = None\n    else:\n      crop_seeds = list(100 * np.arange(opts.num_crops))\n    input_train_fn = functools.partial(\n        data.provide_data, dataset_name=dataset_name,\n        parent_dir=dataset_parent_dir, subset='train',\n        batch_size=opts.batch_size, crop_flag=True,\n        crop_size=opts.train_resolution, seeds=crop_seeds,\n        use_appearance=opts.use_appearance)\n    est.train(input_train_fn, max_steps=next_step, hooks=hooks)\n    tf.logging.info('DBG: kimg=%d, cur_step=%d' % (next_kimg, next_step))\n    tf.logging.info('DBG: Saving a validation grid image %06d to %s' % (\n        next_kimg, image_dir))\n    make_sample_grid_and_save(est, dataset_name, dataset_parent_dir, (3, 3),\n                              image_dir, next_kimg << 10)\n\n\ndef _build_inference_estimator(model_dir):\n  model_fn = build_model_fn()\n  est = tf.estimator.Estimator(model_fn, model_dir)\n  return est\n\n\ndef evaluate_sequence(dataset_name, dataset_parent_dir, virtual_seq_name,\n                      app_base_path):\n  output_dir = osp.join(opts.train_dir, 'seq_output_%s' % virtual_seq_name)\n  tf.gfile.MakeDirs(output_dir)\n  est = _build_inference_estimator(opts.train_dir)\n  visualize_image_sequence(est, dataset_name, dataset_parent_dir,\n                           virtual_seq_name, app_base_path, output_dir)\n\n\ndef evaluate_image_set(dataset_name, dataset_parent_dir, subset_suffix,\n                       output_dir=None, batch_size=6):\n  if output_dir is None:\n    output_dir = osp.join(opts.train_dir, 'validation_output_%s' % subset_suffix)\n  tf.gfile.MakeDirs(output_dir)\n  model_fn_old = build_model_fn()\n  def model_fn_wrapper(features, labels, mode, params):\n    del mode\n    return model_fn_old(features, labels, 'eval_subset', params)\n  model_dir = opts.train_dir\n  est = tf.estimator.Estimator(model_fn_wrapper, model_dir)\n  est_inp_fn = functools.partial(\n      data.provide_data, dataset_name=dataset_name,\n      parent_dir=dataset_parent_dir, subset=subset_suffix,\n      batch_size=batch_size, use_appearance=opts.use_appearance, shuffle=0)\n\n  print('Evaluating images for subset %s' % subset_suffix)\n  images = [x for x in est.predict(est_inp_fn)]\n  print('Evaluated %d images' % len(images))\n  for i, img in enumerate(images):\n    output_file_path = osp.join(output_dir, 'out_%04d.png' % i)\n    print('Saving file #%d: %s' % (i, output_file_path))\n    with tf.gfile.Open(output_file_path, 'wb') as f:\n      f.write(utils.to_png(img))\n\n\ndef _load_and_concatenate_image_channels(rgb_path=None, rendered_path=None,\n                                         depth_path=None, seg_path=None,\n                                         size_multiple=64):\n  \"\"\"Prepares a single input for the network.\"\"\"\n  if (rgb_path is None and rendered_path is None and depth_path is None and\n      seg_path is 
None):\n    raise ValueError('At least one of the inputs has to be not None')\n\n  channels = ()\n  if rgb_path is not None:\n    rgb_img = np.array(Image.open(rgb_path)).astype(np.float32)\n    rgb_img = utils.crop_to_multiple(rgb_img, size_multiple)\n    channels = channels + (rgb_img,)\n  if rendered_path is not None:\n    rendered_img = np.array(Image.open(rendered_path)).astype(np.float32)\n    if not opts.use_alpha:\n      rendered_img = rendered_img[:, :, :3]  # drop the alpha channel\n    rendered_img = utils.crop_to_multiple(rendered_img, size_multiple)\n    channels = channels + (rendered_img,)\n  if depth_path is not None:\n    depth_img = np.array(Image.open(depth_path))\n    depth_img = depth_img.astype(np.float32)\n    depth_img = utils.crop_to_multiple(depth_img[:, :, np.newaxis],\n                                       size_multiple)\n    channels = channels + (depth_img,)\n    # depth_img = depth_img * (2.0 / 255) - 1.0\n  if seg_path is not None:\n    seg_img = np.array(Image.open(seg_path)).astype(np.float32)\n    seg_img = utils.crop_to_multiple(seg_img, size_multiple)\n    channels = channels + (seg_img,)\n  # Concatenate and normalize channels\n  img = np.dstack(channels)\n  img = np.expand_dims(img, axis=0)\n  img = img * (2.0 / 255) - 1.0\n  return img\n\n\ndef infer_dir(model_dir, input_dir, output_dir):\n  tf.gfile.MakeDirs(output_dir)\n  est = _build_inference_estimator(opts.train_dir)\n\n  def read_image(base_path, is_appearance=False):\n    if is_appearance:\n      ref_img_path = base_path + '_reference.png'\n    else:\n      ref_img_path = None\n    rendered_img_path = base_path + '_color.png'\n    depth_img_path = base_path + '_depth.png'\n    seg_img_path = base_path + '_seg_rgb.png'\n    img = _load_and_concatenate_image_channels(\n        rgb_path=ref_img_path, rendered_path=rendered_img_path,\n        depth_path=depth_img_path, seg_path=seg_img_path)\n    return img\n\n  def get_inference_input_fn(base_path, app_base_path):\n    x_in = read_image(base_path, False)\n    x_app_in = read_image(app_base_path, True)\n    def infer_input_fn():\n      return {'conditional_input': x_in, 'peek_input': x_app_in}\n    return infer_input_fn\n\n  file_paths = sorted(glob.glob(osp.join(input_dir, '*_depth.png')))\n  base_paths = [x[:-10] for x in file_paths]  # remove the '_depth.png' suffix\n  for inp_base_path in base_paths:\n    est_inp_fn = get_inference_input_fn(inp_base_path, inp_base_path)\n    img = next(est.predict(est_inp_fn))\n    basename = osp.basename(inp_base_path)\n    output_img_path = osp.join(output_dir, basename + '_out.png')\n    print('Saving generated image to %s' % output_img_path)\n    with tf.gfile.Open(output_img_path, 'wb') as f:\n      f.write(utils.to_png(img))\n\n\ndef joint_interpolation(model_dir, app_input_dir, st_app_basename,\n                        end_app_basename, camera_path_dir):\n  \"\"\"\n  Interpolates both viewpoint and appearance between two input images.\n  \"\"\"\n  # Create output direcotry\n  output_dir = osp.join(model_dir, 'joint_interpolation_out')\n  tf.gfile.MakeDirs(output_dir)\n\n  # Build estimator\n  model_fn_old = build_model_fn()\n  def model_fn_wrapper(features, labels, mode, params):\n    del mode\n    return model_fn_old(features, labels, 'interpolate_appearance', params)\n  def appearance_model_fn(features, labels, mode, params):\n    del mode\n    return model_fn_old(features, labels, 'compute_appearance', params)\n  config = tf.estimator.RunConfig(\n      save_summary_steps=1000, 
save_checkpoints_steps=50000,\n      keep_checkpoint_max=50, log_step_count_steps=1 << 30)\n  model_dir = model_dir\n  est = tf.estimator.Estimator(model_fn_wrapper, model_dir, config, params={})\n  est_app = tf.estimator.Estimator(appearance_model_fn, model_dir, config,\n                                   params={})\n\n  # Compute appearance embeddings for the two input appearance images.\n  app_inputs = []\n  for app_basename in [st_app_basename, end_app_basename]:\n    app_rgb_path = osp.join(app_input_dir, app_basename + '_reference.png')\n    app_rendered_path = osp.join(app_input_dir, app_basename + '_color.png')\n    app_depth_path = osp.join(app_input_dir, app_basename + '_depth.png')\n    app_seg_path = osp.join(app_input_dir, app_basename + '_seg_rgb.png')\n    app_in = _load_and_concatenate_image_channels(\n        rgb_path=app_rgb_path, rendered_path=app_rendered_path,\n        depth_path=app_depth_path, seg_path=app_seg_path)\n    # app_inputs.append(tf.convert_to_tensor(app_in))\n    app_inputs.append(app_in)\n\n  embedding1 = next(est_app.predict(\n      lambda: {'peek_input': app_inputs[0]}))\n  embedding1 = np.expand_dims(embedding1, axis=0)\n  embedding2 = next(est_app.predict(\n      lambda: {'peek_input': app_inputs[1]}))\n  embedding2 = np.expand_dims(embedding2, axis=0)\n\n  file_paths = sorted(glob.glob(osp.join(camera_path_dir, '*_depth.png')))\n  base_paths = [x[:-10] for x in file_paths]  # remove the '_depth.png' suffix\n\n  # Compute interpolated appearance embeddings\n  num_interpolations = len(base_paths)\n  interpolated_embeddings = []\n  delta_vec = (embedding2 - embedding1) / (num_interpolations - 1)\n  for delta_iter in range(num_interpolations):\n    x_app_embedding = embedding1 + delta_iter * delta_vec\n    interpolated_embeddings.append(x_app_embedding)\n\n  # Generate and save interpolated images\n  for frame_idx, embedding in enumerate(interpolated_embeddings):\n    # Read in input frame\n    frame_render_path = osp.join(base_paths[frame_idx] + '_color.png')\n    frame_depth_path = osp.join(base_paths[frame_idx] + '_depth.png')\n    frame_seg_path = osp.join(base_paths[frame_idx] + '_seg_rgb.png')\n    x_in = _load_and_concatenate_image_channels(\n        rgb_path=None, rendered_path=frame_render_path,\n        depth_path=frame_depth_path, seg_path=frame_seg_path)\n\n    img = next(est.predict(\n        lambda: {'conditional_input': tf.convert_to_tensor(x_in),\n                 'appearance_embedding': tf.convert_to_tensor(embedding)}))\n    output_img_name = '%s_%s_%03d.png' % (st_app_basename, end_app_basename,\n                                          frame_idx)\n    output_img_path = osp.join(output_dir, output_img_name)\n    print('Saving interpolated image to %s' % output_img_path)\n    with tf.gfile.Open(output_img_path, 'wb') as f:\n      f.write(utils.to_png(img))\n\n\ndef interpolate_appearance(model_dir, input_dir, target_img_basename,\n                           appearance_img1_basename, appearance_img2_basename):\n  # Create output direcotry\n  output_dir = osp.join(model_dir, 'interpolate_appearance_out')\n  tf.gfile.MakeDirs(output_dir)\n\n  # Build estimator\n  model_fn_old = build_model_fn()\n  def model_fn_wrapper(features, labels, mode, params):\n    del mode\n    return model_fn_old(features, labels, 'interpolate_appearance', params)\n  def appearance_model_fn(features, labels, mode, params):\n    del mode\n    return model_fn_old(features, labels, 'compute_appearance', params)\n  config = tf.estimator.RunConfig(\n      
save_summary_steps=1000, save_checkpoints_steps=50000,\n      keep_checkpoint_max=50, log_step_count_steps=1 << 30)\n  model_dir = model_dir\n  est = tf.estimator.Estimator(model_fn_wrapper, model_dir, config, params={})\n  est_app = tf.estimator.Estimator(appearance_model_fn, model_dir, config,\n                                   params={})\n\n  # Compute appearance embeddings for the two input appearance images.\n  app_inputs = []\n  for app_basename in [appearance_img1_basename, appearance_img2_basename]:\n    app_rgb_path = osp.join(input_dir, app_basename + '_reference.png')\n    app_rendered_path = osp.join(input_dir, app_basename + '_color.png')\n    app_depth_path = osp.join(input_dir, app_basename + '_depth.png')\n    app_seg_path = osp.join(input_dir, app_basename + '_seg_rgb.png')\n    app_in = _load_and_concatenate_image_channels(\n        rgb_path=app_rgb_path, rendered_path=app_rendered_path,\n        depth_path=app_depth_path, seg_path=app_seg_path)\n    # app_inputs.append(tf.convert_to_tensor(app_in))\n    app_inputs.append(app_in)\n\n  embedding1 = next(est_app.predict(\n      lambda: {'peek_input': app_inputs[0]}))\n  embedding2 = next(est_app.predict(\n      lambda: {'peek_input': app_inputs[1]}))\n  embedding1 = np.expand_dims(embedding1, axis=0)\n  embedding2 = np.expand_dims(embedding2, axis=0)\n\n  # Compute interpolated appearance embeddings\n  num_interpolations = 10\n  interpolated_embeddings = []\n  delta_vec = (embedding2 - embedding1) / num_interpolations\n  for delta_iter in range(num_interpolations + 1):\n    x_app_embedding = embedding1 + delta_iter * delta_vec\n    interpolated_embeddings.append(x_app_embedding)\n\n  # Read in the generator input for the target image to render\n  rendered_img_path = osp.join(input_dir, target_img_basename + '_color.png')\n  depth_img_path = osp.join(input_dir, target_img_basename + '_depth.png')\n  seg_img_path = osp.join(input_dir, target_img_basename + '_seg_rgb.png')\n  x_in = _load_and_concatenate_image_channels(\n      rgb_path=None, rendered_path=rendered_img_path,\n      depth_path=depth_img_path, seg_path=seg_img_path)\n\n  # Generate and save interpolated images\n  for interpolate_iter, embedding in enumerate(interpolated_embeddings):\n    img = next(est.predict(\n        lambda: {'conditional_input': tf.convert_to_tensor(x_in),\n                 'appearance_embedding': tf.convert_to_tensor(embedding)}))\n    output_img_name = 'interpolate_%s_%s_%s_%03d.png' % (\n        target_img_basename, appearance_img1_basename, appearance_img2_basename,\n        interpolate_iter)\n    output_img_path = osp.join(output_dir, output_img_name)\n    print('Saving interpolated image to %s' % output_img_path)\n    with tf.gfile.Open(output_img_path, 'wb') as f:\n      f.write(utils.to_png(img))\n\n\ndef main(argv):\n  del argv\n  configs_str = options.list_options()\n  tf.gfile.MakeDirs(opts.train_dir)\n  with tf.gfile.Open(osp.join(opts.train_dir, 'configs.txt'), 'wb') as f:\n    f.write(configs_str)\n  tf.logging.info('Local configs\\n%s' % configs_str)\n\n  if opts.run_mode == 'train':\n    dataset_name = opts.dataset_name\n    dataset_parent_dir = opts.dataset_parent_dir\n    load_pretrained_app_encoder = opts.load_pretrained_app_encoder\n    load_trained_fixed_app = opts.load_from_another_ckpt\n    batch_size = opts.batch_size\n    train(dataset_name, dataset_parent_dir, load_pretrained_app_encoder,\n          load_trained_fixed_app)\n  elif opts.run_mode == 'eval':  # generate a camera path output sequence from TFRecord 
inputs.\n    dataset_name = opts.dataset_name\n    dataset_parent_dir = opts.dataset_parent_dir\n    virtual_seq_name = opts.virtual_seq_name\n    inp_app_img_base_path = opts.inp_app_img_base_path\n    evaluate_sequence(dataset_name, dataset_parent_dir, virtual_seq_name,\n                      inp_app_img_base_path)\n  elif opts.run_mode == 'eval_subset':  # generate output for validation set (encoded as TFRecords)\n    dataset_name = opts.dataset_name\n    dataset_parent_dir = opts.dataset_parent_dir\n    virtual_seq_name = opts.virtual_seq_name\n    evaluate_image_set(dataset_name, dataset_parent_dir, virtual_seq_name,\n                       opts.output_validation_dir, opts.batch_size)\n  elif opts.run_mode == 'eval_dir':  # evaluate output for a directory with input images\n    input_dir = opts.inference_input_path\n    output_dir = opts.inference_output_dir\n    model_dir = opts.train_dir\n    infer_dir(model_dir, input_dir, output_dir)\n  elif opts.run_mode == 'interpolate_appearance':  # interpolate appearance only between two images.\n    model_dir = opts.train_dir\n    input_dir = opts.inference_input_path\n    target_img_basename = opts.target_img_basename\n    app_img1_basename = opts.appearance_img1_basename\n    app_img2_basename = opts.appearance_img2_basename\n    interpolate_appearance(model_dir, input_dir, target_img_basename,\n                           app_img1_basename, app_img2_basename)\n  elif opts.run_mode == 'joint_interpolation':  # interpolate viewpoint and appearance between two images\n    model_dir = opts.train_dir\n    app_input_dir = opts.inference_input_path\n    st_app_basename = opts.appearance_img1_basename\n    end_app_basename = opts.appearance_img2_basename\n    frames_dir = opts.frames_dir\n    joint_interpolation(model_dir, app_input_dir, st_app_basename,\n                        end_app_basename, frames_dir)\n  else:\n    raise ValueError('Unsupported --run_mode %s' % opts.run_mode)\n\n\nif __name__ == '__main__':\n  app.run(main)\n"
  },
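The appearance interpolation in interpolate_appearance() above boils down to evaluating evenly spaced points on a straight line in latent space between the two embeddings. A minimal NumPy sketch of that step (illustration only, not repo code):

import numpy as np

def interpolate_embeddings(z1, z2, num_interpolations=10):
  # num_interpolations + 1 points from z1 to z2 inclusive, matching the
  # delta_vec loop in interpolate_appearance().
  delta = (z2 - z1) / num_interpolations
  return [z1 + i * delta for i in range(num_interpolations + 1)]

z1 = np.zeros((1, 8), np.float32)
z2 = np.ones((1, 8), np.float32)
steps = interpolate_embeddings(z1, z2)
print(len(steps), steps[0][0, 0], steps[-1][0, 0])  # 11 0.0 1.0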
  {
    "path": "options.py",
    "content": "# Copyright 2019 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     https://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\nfrom absl import flags\nimport numpy as np\n\nFLAGS = flags.FLAGS\n\n# ------------------------------------------------------------------------------\n# Train flags\n# ------------------------------------------------------------------------------\n\n# Dataset, model directory and run mode\nflags.DEFINE_string('train_dir', '/tmp/nerual_rendering',\n                    'Directory for model training.')\nflags.DEFINE_string('dataset_name', 'sanmarco9k', 'name ID for a dataset.')\nflags.DEFINE_string(\n    'dataset_parent_dir', '',\n    'Directory containing generated tfrecord dataset.')\nflags.DEFINE_string('run_mode', 'train', \"{'train', 'eval', 'infer'}\")\nflags.DEFINE_string('imageset_dir', None, 'Directory containing trainset '\n                    'images for appearance pretraining.')\nflags.DEFINE_string('metadata_output_dir', None, 'Directory to save pickled '\n                    'pairwise distance matrix for appearance pretraining.')\nflags.DEFINE_integer('save_samples_kimg', 50, 'kimg cycle to save sample'\n                     'validation ouptut during training.')\n\n# Network inputs/outputs\nflags.DEFINE_boolean('use_depth', True, 'Add depth image to the deep buffer.')\nflags.DEFINE_boolean('use_alpha', False,\n                     'Add alpha channel to the deep buffer.')\nflags.DEFINE_boolean('use_semantic', True,\n                     'Add semantic map to the deep buffer.')\nflags.DEFINE_boolean('use_appearance', True,\n                     'Capture appearance from an input real image.')\nflags.DEFINE_integer('deep_buffer_nc', 7,\n                     'Number of input channels in the deep buffer.')\nflags.DEFINE_integer('appearance_nc', 10,\n                     'Number of input channels to the appearance encoder.')\nflags.DEFINE_integer('output_nc', 3,\n                     'Number of channels for the generated image.')\n\n# Staged training flags\nflags.DEFINE_string(\n    'vgg16_path', './vgg16_weights/vgg16.npy',\n    'path to a *.npy file with vgg16 pretrained weights')\nflags.DEFINE_boolean('load_pretrained_app_encoder', False,\n                     'Warmstart appearance encoder with pretrained weights.')\nflags.DEFINE_string('appearance_pretrain_dir', '',\n                    'Model dir for the pretrained appearance encoder.')\nflags.DEFINE_boolean('train_app_encoder', False, 'Whether to make the weights '\n                     'for the appearance encoder trainable or not.')\nflags.DEFINE_boolean(\n    'load_from_another_ckpt', False, 'Load weights from another trained model, '\n                     'e.g load model trained with a fixed appearance encoder.')\nflags.DEFINE_string('fixed_appearance_train_dir', '',\n                    'Model dir for training G with a fixed appearance net.')\n\n# -----------------------------------------------------------------------------\n\n# More hparams\nflags.DEFINE_integer('train_resolution', 256,\n                     'Crop train 
images to this resolution.')\nflags.DEFINE_float('d_lr', 0.001, 'Learning rate for the discriminator.')\nflags.DEFINE_float('g_lr', 0.001, 'Learning rate for the generator.')\nflags.DEFINE_float('ez_lr', 0.0001, 'Learning rate for appearance encoder.')\nflags.DEFINE_integer('batch_size', 8, 'Batch size for training.')\nflags.DEFINE_boolean('use_scaling', True, \"use He's scaling.\")\nflags.DEFINE_integer('num_crops', 30, 'num crops from train images'\n                     '(use -1 for random crops).')\nflags.DEFINE_integer('app_vector_size', 8, 'Size of latent appearance vector.')\nflags.DEFINE_integer('total_kimg', 20000,\n                     'Max number (in kilo) of training images for training.')\nflags.DEFINE_float('adam_beta1', 0.0, 'beta1 for adam optimizer.')\nflags.DEFINE_float('adam_beta2', 0.99, 'beta2 for adam optimizer.')\n\n# Loss weights\nflags.DEFINE_float('w_loss_vgg', 0.3, 'VGG loss weight.')\nflags.DEFINE_float('w_loss_feat', 10., 'Feature loss weight (from pix2pixHD).')\nflags.DEFINE_float('w_loss_l1', 50., 'L1 loss weight.')\nflags.DEFINE_float('w_loss_z_recon', 10., 'Z reconstruction loss weight.')\nflags.DEFINE_float('w_loss_gan', 1., 'Adversarial loss weight.')\nflags.DEFINE_float('w_loss_z_gan', 1., 'Z adversarial loss weight.')\nflags.DEFINE_float('w_loss_kl', 0.01, 'KL divergence weight.')\nflags.DEFINE_float('w_loss_l2_reg', 0.01, 'Weight for L2 regression on Z.')\n\n# -----------------------------------------------------------------------------\n\n# Architecture and training setup\nflags.DEFINE_string('arch_type', 'pggan',\n                    'Architecture type: {pggan, pix2pixhd}.')\nflags.DEFINE_string('training_pipeline', 'staged',\n                    'Training type type: {staged, bicycle_gan, drit}.')\nflags.DEFINE_integer('g_nf', 64,\n                     'num filters in the first/last layers of U-net.')\nflags.DEFINE_boolean('concatenate_skip_layers', True,\n                     'Use concatenation for skip connections.')\n\n## if arch_type == 'pggan':\nflags.DEFINE_integer('pggan_n_blocks', 5,\n                     'Num blocks for the pggan architecture.')\n## if arch_type == 'pix2pixhd':\nflags.DEFINE_integer('p2p_n_downsamples', 3,\n                     'Num downsamples for the pix2pixHD architecture.')\nflags.DEFINE_integer('p2p_n_resblocks', 4, 'Num residual blocks at the '\n                     'end/start of the pix2pixHD encoder/decoder.')\n## if use_drit_pipeline:\nflags.DEFINE_boolean('use_concat', True, '\"concat\" mode from DRIT.')\nflags.DEFINE_boolean('normalize_drit_Ez', True, 'Add pixelnorm layers to the '\n                     'appearance encoder.')\nflags.DEFINE_boolean('concat_z_in_all_layers', True, 'Inject z at each '\n                     'upsampling layer in the decoder (only for DRIT baseline)')\nflags.DEFINE_string('inject_z', 'to_bottleneck', 'Method for injecting z; '\n                     'one of {to_encoder, to_bottleneck}.')\nflags.DEFINE_boolean('use_vgg_loss', True, 'vgg v L1 reconstruction loss.')\n\n# ------------------------------------------------------------------------------\n# Inference flags\n# ------------------------------------------------------------------------------\n\nflags.DEFINE_string('inference_input_path', '',\n                    'Parent directory for input images at inference time.')\nflags.DEFINE_string('inference_output_dir', '', 'Output path for inference')\nflags.DEFINE_string('target_img_basename', '',\n                    'basename of target image to render for 
interpolation')\nflags.DEFINE_string('virtual_seq_name', 'full_camera_path',\n                    'name for the virtual camera path suffix for the TFRecord.')\nflags.DEFINE_string('inp_app_img_base_path', '',\n                    'base path for the input appearance image for camera paths')\n\nflags.DEFINE_string('appearance_img1_basename', '',\n                    'basename of the first appearance image for interpolation')\nflags.DEFINE_string('appearance_img2_basename', '',\n                    'basename of the second appearance image for interpolation')\nflags.DEFINE_list('input_basenames', [], 'input basenames for inference')\nflags.DEFINE_list('input_app_basenames', [], 'input appearance basenames for '\n                  'inference')\nflags.DEFINE_string('frames_dir', '',\n                    'Folder with input frames to a camera path')\nflags.DEFINE_string('output_validation_dir', '',\n                    'dataset_name for storing results in a structured folder')\nflags.DEFINE_string('input_rendered', '',\n                    'input rendered image name for inference')\nflags.DEFINE_string('input_depth', '', 'input depth image name for inference')\nflags.DEFINE_string('input_seg', '',\n                    'input segmentation mask image name for inference')\nflags.DEFINE_string('input_app_rgb', '',\n                    'input appearance rgb image name for inference')\nflags.DEFINE_string('input_app_rendered', '',\n                    'input appearance rendered image name for inference')\nflags.DEFINE_string('input_app_depth', '',\n                    'input appearance depth image name for inference')\nflags.DEFINE_string('input_app_seg', '',\n                    'input appearance segmentation mask image name for '\n                    'inference')\nflags.DEFINE_string('output_img_name', '',\n                    '[OPTIONAL] output image name for inference')\n\n# -----------------------------------------------------------------------------\n# Some validation and assertions\n# -----------------------------------------------------------------------------\n\ndef validate_options():\n  if FLAGS.training_pipeline == 'drit':\n    assert FLAGS.use_appearance, 'DRIT pipeline requires --use_appearance'\n  assert not (\n    FLAGS.load_pretrained_app_encoder and FLAGS.load_from_another_ckpt), (\n      'You cannot load weights for the appearance encoder from two different '\n      'checkpoints!')\n  if not FLAGS.use_appearance:\n    print('**Warning: setting --app_vector_size to 0 since '\n          '--use_appearance=False!')\n    FLAGS.set_default('app_vector_size', 0)\n  \n# -----------------------------------------------------------------------------\n# Print all options\n# -----------------------------------------------------------------------------\n\ndef list_options():\n  configs = ('# Run flags/options from options.py:\\n'\n             '# ----------------------------------\\n')\n  configs += ('## Train flags:\\n'\n              '## ------------\\n')\n  configs += 'train_dir = %s\\n' % FLAGS.train_dir\n  configs += 'dataset_name = %s\\n' % FLAGS.dataset_name\n  configs += 'dataset_parent_dir = %s\\n' % FLAGS.dataset_parent_dir\n  configs += 'run_mode = %s\\n' % FLAGS.run_mode\n  configs += 'save_samples_kimg = %d\\n' % FLAGS.save_samples_kimg\n  configs += '\\n# --------------------------------------------------------\\n\\n'\n\n  configs += ('## Network inputs and outputs:\\n'\n              '## ---------------------------\\n')\n  configs += 'use_depth = %s\\n' % str(FLAGS.use_depth)\n  configs 
+= 'use_alpha = %s\\n' % str(FLAGS.use_alpha)\n  configs += 'use_semantic = %s\\n' % str(FLAGS.use_semantic)\n  configs += 'use_appearance = %s\\n' % str(FLAGS.use_appearance)\n  configs += 'deep_buffer_nc = %d\\n' % FLAGS.deep_buffer_nc\n  configs += 'appearance_nc = %d\\n' % FLAGS.appearance_nc\n  configs += 'output_nc = %d\\n' % FLAGS.output_nc\n  configs += 'train_resolution = %d\\n' % FLAGS.train_resolution\n  configs += '\\n# --------------------------------------------------------\\n\\n'\n\n  configs += ('## Staged training flags:\\n'\n              '## ----------------------\\n')\n  configs += 'load_pretrained_app_encoder = %s\\n' % str(\n                                            FLAGS.load_pretrained_app_encoder)\n  configs += 'appearance_pretrain_dir = %s\\n' % FLAGS.appearance_pretrain_dir\n  configs += 'train_app_encoder = %s\\n' % str(FLAGS.train_app_encoder)\n  configs += 'load_from_another_ckpt = %s\\n' % str(FLAGS.load_from_another_ckpt)\n  configs += 'fixed_appearance_train_dir = %s\\n' % str(\n                                            FLAGS.fixed_appearance_train_dir)\n  configs += '\\n# --------------------------------------------------------\\n\\n'\n\n  configs += ('## More hyper-parameters:\\n'\n              '## ----------------------\\n')\n  configs += 'd_lr = %f\\n' % FLAGS.d_lr\n  configs += 'g_lr = %f\\n' % FLAGS.g_lr\n  configs += 'ez_lr = %f\\n' % FLAGS.ez_lr\n  configs += 'batch_size = %d\\n' % FLAGS.batch_size\n  configs += 'use_scaling = %s\\n' % str(FLAGS.use_scaling)\n  configs += 'num_crops = %d\\n' % FLAGS.num_crops\n  configs += 'app_vector_size = %d\\n' % FLAGS.app_vector_size\n  configs += 'total_kimg = %d\\n' % FLAGS.total_kimg\n  configs += 'adam_beta1 = %f\\n' % FLAGS.adam_beta1\n  configs += 'adam_beta2 = %f\\n' % FLAGS.adam_beta2\n  configs += '\\n# --------------------------------------------------------\\n\\n'\n\n  configs += ('## Loss weights:\\n'\n              '## -------------\\n')\n  configs += 'w_loss_vgg = %f\\n' % FLAGS.w_loss_vgg\n  configs += 'w_loss_feat = %f\\n' % FLAGS.w_loss_feat\n  configs += 'w_loss_l1 = %f\\n' % FLAGS.w_loss_l1\n  configs += 'w_loss_z_recon = %f\\n' % FLAGS.w_loss_z_recon\n  configs += 'w_loss_gan = %f\\n' % FLAGS.w_loss_gan\n  configs += 'w_loss_z_gan = %f\\n' % FLAGS.w_loss_z_gan\n  configs += 'w_loss_kl = %f\\n' % FLAGS.w_loss_kl\n  configs += 'w_loss_l2_reg = %f\\n' % FLAGS.w_loss_l2_reg\n  configs += '\\n# --------------------------------------------------------\\n\\n'\n\n  configs += ('## Architecture and training setup:\\n'\n              '## --------------------------------\\n')\n  configs += 'arch_type = %s\\n' % FLAGS.arch_type\n  configs += 'training_pipeline = %s\\n' % FLAGS.training_pipeline\n  configs += 'g_nf = %d\\n' % FLAGS.g_nf\n  configs += 'concatenate_skip_layers = %s\\n' % str(\n                                                FLAGS.concatenate_skip_layers)\n  configs += 'p2p_n_downsamples = %d\\n' % FLAGS.p2p_n_downsamples\n  configs += 'p2p_n_resblocks = %d\\n' % FLAGS.p2p_n_resblocks\n  configs += 'use_concat = %s\\n' % str(FLAGS.use_concat)\n  configs += 'normalize_drit_Ez = %s\\n' % str(FLAGS.normalize_drit_Ez)\n  configs += 'inject_z = %s\\n' % FLAGS.inject_z\n  configs += 'concat_z_in_all_layers = %s\\n' % str(FLAGS.concat_z_in_all_layers)\n  configs += 'use_vgg_loss = %s\\n' % str(FLAGS.use_vgg_loss)\n  configs += '\\n# --------------------------------------------------------\\n\\n'\n\n  return configs\n"
  },
  {
    "path": "pretrain_appearance.py",
    "content": "# Copyright 2019 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     https://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom PIL import Image\nfrom absl import app\nfrom absl import flags\nfrom options import FLAGS as opts\nimport glob\nimport networks\nimport numpy as np\nimport os\nimport os.path as osp\nimport pickle\nimport style_loss\nimport tensorflow as tf\nimport utils\n\n\ndef _load_and_concatenate_image_channels(\n    rgb_path=None, rendered_path=None, depth_path=None, seg_path=None,\n    crop_size=512):\n  if (rgb_path is None and rendered_path is None and depth_path is None and\n      seg_path is None):\n    raise ValueError('At least one of the inputs has to be not None')\n\n  channels = ()\n  if rgb_path is not None:\n    rgb_img = np.array(Image.open(rgb_path)).astype(np.float32)\n    rgb_img = utils.get_central_crop(rgb_img, crop_size, crop_size)\n    channels = channels + (rgb_img,)\n  if rendered_path is not None:\n    rendered_img = np.array(Image.open(rendered_path)).astype(np.float32)\n    rendered_img = utils.get_central_crop(rendered_img, crop_size, crop_size)\n    if not opts.use_alpha:\n      rendered_img = rendered_img[:,:, :3]  # drop the alpha channel\n    channels = channels + (rendered_img,)\n  if depth_path is not None:\n    depth_img = np.array(Image.open(depth_path))\n    depth_img = depth_img.astype(np.float32)\n    depth_img = utils.get_central_crop(depth_img, crop_size, crop_size)\n    channels = channels + (depth_img,)\n  if seg_path is not None:\n    seg_img = np.array(Image.open(seg_path)).astype(np.float32)\n    channels = channels + (seg_img,)\n  # Concatenate and normalize channels\n  img = np.dstack(channels)\n  img = img * (2.0 / 255) - 1.0\n  return img\n\n\ndef read_single_appearance_input(rgb_img_path):\n  base_path = rgb_img_path[:-14]  # remove the '_reference.png' suffix\n  rendered_img_path = base_path + '_color.png'\n  depth_img_path = base_path + '_depth.png'\n  semantic_img_path = base_path + '_seg_rgb.png'\n  network_input_img = _load_and_concatenate_image_channels(\n      rgb_img_path, rendered_img_path, depth_img_path, semantic_img_path,\n      crop_size=opts.train_resolution)\n  return network_input_img\n\n\ndef get_triplet_input_fn(dataset_path, dist_file_path=None, k_max_nearest=5,\n                         k_max_farthest=13):\n  input_images_pattern = osp.join(dataset_path, '*_reference.png')\n  filenames = sorted(glob.glob(input_images_pattern))\n  print('DBG: obtained %d input filenames for triplet inputs' % len(filenames))\n  print('DBG: Computing pairwise style distances:')\n  if dist_file_path is not None and osp.exists(dist_file_path):\n    print('*** Loading distance matrix from %s' % dist_file_path)\n    with open(dist_file_path, 'rb') as f:\n      dist_matrix = pickle.load(f)['dist_matrix']\n      print('loaded a dist_matrix of shape: %s' % str(dist_matrix.shape))\n  else:\n    dist_matrix = style_loss.compute_pairwise_style_loss_v2(filenames)\n    dist_dict = {'dist_matrix': dist_matrix}\n    print('Saving distance matrix to %s' % 
dist_file_path)\n    with open(dist_file_path, 'wb') as f:\n      pickle.dump(dist_dict, f)\n\n  # Sort neighbors for each anchor image\n  num_imgs = len(dist_matrix)\n  sorted_neighbors = [np.argsort(dist_matrix[ii, :]) for ii in range(num_imgs)]\n\n  def triplet_input_fn(anchor_idx):\n    # start from 1 to avoid getting the same image as its own neighbor\n    positive_neighbor_idx = np.random.randint(1, k_max_nearest + 1)\n    negative_neighbor_idx = num_imgs - 1 - np.random.randint(0, k_max_farthest)\n    positive_img_idx = sorted_neighbors[anchor_idx][positive_neighbor_idx]\n    negative_img_idx = sorted_neighbors[anchor_idx][negative_neighbor_idx]\n    # Read anchor image\n    anchor_rgb_path = osp.join(dataset_path, filenames[anchor_idx])\n    anchor_input = read_single_appearance_input(anchor_rgb_path)\n    # Read positive image\n    positive_rgb_path = osp.join(dataset_path, filenames[positive_img_idx])\n    positive_input = read_single_appearance_input(positive_rgb_path)\n    # Read negative image\n    negative_rgb_path = osp.join(dataset_path, filenames[negative_img_idx])\n    negative_input = read_single_appearance_input(negative_rgb_path)\n    # Return triplet\n    return anchor_input, positive_input, negative_input\n\n  return triplet_input_fn\n\n\ndef get_tf_triplet_dataset_iter(\n    dataset_path, trainset_size, dist_file_path, batch_size=4,\n    deterministic_flag=False, shuffle_buf_size=128, repeat_flag=True):\n  # Create a dataset of anchor image indices.\n  idx_dataset = tf.data.Dataset.range(trainset_size)\n  # Create a mapper function from anchor idx to triplet images.\n  triplet_mapper = lambda idx: tuple(tf.py_func(\n      get_triplet_input_fn(dataset_path, dist_file_path), [idx],\n      [tf.float32, tf.float32, tf.float32]))\n  # Convert triplet to a dictionary for the estimator input format.\n  triplet_to_dict_mapper = lambda anchor, pos, neg: {\n      'anchor_img': anchor, 'positive_img': pos, 'negative_img': neg}\n  if repeat_flag:\n    idx_dataset = idx_dataset.repeat()  # Repeat indefinitely.\n  if not deterministic_flag:\n    idx_dataset = idx_dataset.shuffle(shuffle_buf_size)\n    triplet_dataset = idx_dataset.map(\n        triplet_mapper, num_parallel_calls=max(4, batch_size // 4))\n    triplet_dataset = triplet_dataset.map(\n        triplet_to_dict_mapper, num_parallel_calls=max(4, batch_size // 4))\n  else:\n    triplet_dataset = idx_dataset.map(triplet_mapper, num_parallel_calls=None)\n    triplet_dataset = triplet_dataset.map(triplet_to_dict_mapper,\n                                          num_parallel_calls=None)\n  triplet_dataset = triplet_dataset.batch(batch_size)\n  if not deterministic_flag:\n    triplet_dataset = triplet_dataset.prefetch(4)  # Prefetch a few batches.\n  return triplet_dataset.make_one_shot_iterator()\n\n\ndef build_model_fn(batch_size, lr_app_pretrain=0.0001, adam_beta1=0.0,\n                   adam_beta2=0.99):\n  def model_fn(features, labels, mode, params):\n    del labels, params\n\n    step = tf.train.get_global_step()\n    app_func = networks.DRITAppearanceEncoderConcat(\n      'appearance_net', opts.appearance_nc, opts.normalize_drit_Ez)\n\n    if mode == tf.estimator.ModeKeys.TRAIN:\n      op_increment_step = tf.assign_add(step, 1)\n      with tf.name_scope('Appearance_Loss'):\n        anchor_img = features['anchor_img']\n        positive_img = features['positive_img']\n        negative_img = features['negative_img']\n        # Compute embeddings (each of shape [batch_sz, 1, 1, app_vector_sz])\n        z_anchor, _, _ = 
app_func(anchor_img)\n        z_pos, _, _ = app_func(positive_img)\n        z_neg, _, _ = app_func(negative_img)\n        # Squeeze into shape of [batch_sz x vec_sz]\n        anchor_embedding = tf.squeeze(z_anchor, axis=[1, 2], name='z_anchor')\n        positive_embedding = tf.squeeze(z_pos, axis=[1, 2])\n        negative_embedding = tf.squeeze(z_neg, axis=[1, 2])\n        # Compute triplet loss\n        margin = 0.1\n        anchor_positive_dist = tf.reduce_sum(\n            tf.square(anchor_embedding - positive_embedding), axis=1)\n        anchor_negative_dist = tf.reduce_sum(\n            tf.square(anchor_embedding - negative_embedding), axis=1)\n        triplet_loss = anchor_positive_dist - anchor_negative_dist + margin\n        triplet_loss = tf.maximum(triplet_loss, 0.)\n        triplet_loss = tf.reduce_sum(triplet_loss) / batch_size\n        tf.summary.scalar('appearance_triplet_loss', triplet_loss)\n\n        # Image summaries\n        anchor_rgb = tf.slice(anchor_img, [0, 0, 0, 0], [-1, -1, -1, 3])\n        positive_rgb = tf.slice(positive_img, [0, 0, 0, 0], [-1, -1, -1, 3])\n        negative_rgb = tf.slice(negative_img, [0, 0, 0, 0], [-1, -1, -1, 3])\n        tb_vis = tf.concat([anchor_rgb, positive_rgb, negative_rgb], axis=2)\n        with tf.name_scope('triplet_vis'):\n          tf.summary.image('anchor-pos-neg', tb_vis)\n\n      optimizer = tf.train.AdamOptimizer(lr_app_pretrain, adam_beta1,\n                                         adam_beta2)\n      optimizer = tf.contrib.estimator.TowerOptimizer(optimizer)\n      app_vars = utils.model_vars('appearance_net')[0]\n      print('\\n\\n***************************************************')\n      print('DBG: len(app_vars) = %d' % len(app_vars))\n      for ii, v in enumerate(app_vars):\n        print('%03d) %s' % (ii, str(v)))\n      print('***************************************************\\n\\n')\n      app_train_op = optimizer.minimize(triplet_loss, var_list=app_vars)\n      return tf.estimator.EstimatorSpec(\n          mode=mode, loss=triplet_loss,\n          train_op=tf.group(app_train_op, op_increment_step))\n    elif mode == tf.estimator.ModeKeys.PREDICT:\n      imgs = features['anchor_img']\n      # The appearance encoder returns a 3-tuple; its first output is the\n      # embedding, which is squeezed into shape [batch_sz x vec_sz].\n      z, _, _ = app_func(imgs)\n      embeddings = tf.squeeze(z, axis=[1, 2])\n      app_vars = utils.model_vars('appearance_net')[0]\n      tf.train.init_from_checkpoint(osp.join(opts.train_dir),\n                                    {'appearance_net/': 'appearance_net/'})\n      return tf.estimator.EstimatorSpec(mode=mode, predictions=embeddings)\n    else:\n      raise ValueError('Unsupported mode for the appearance model: ' + mode)\n\n  return model_fn\n\n\ndef compute_dist_matrix(imageset_dir, dist_file_path, recompute_dist=False):\n  if not recompute_dist and osp.exists(dist_file_path):\n    print('*** Loading distance matrix from %s' % dist_file_path)\n    with open(dist_file_path, 'rb') as f:\n      dist_matrix = pickle.load(f)['dist_matrix']\n      print('loaded a dist_matrix of shape: %s' % str(dist_matrix.shape))\n      return dist_matrix\n  else:\n    images_paths = sorted(glob.glob(osp.join(imageset_dir, '*_reference.png')))\n    dist_matrix = style_loss.compute_pairwise_style_loss_v2(images_paths)\n    dist_dict = {'dist_matrix': dist_matrix}\n    print('Saving distance matrix to %s' % dist_file_path)\n    with open(dist_file_path, 'wb') as f:\n      pickle.dump(dist_dict, f)\n    return dist_matrix\n\n\ndef train_appearance(train_dir, imageset_dir, dist_file_path):\n  batch_size = 8\n  lr_app_pretrain = 0.001\n\n
  trainset_size = len(glob.glob(osp.join(imageset_dir, '*_reference.png')))\n  resume_step = utils.load_global_step_from_checkpoint_dir(train_dir)\n  if resume_step != 0:\n    tf.logging.warning('DBG: resuming appearance pretraining at %d!' %\n                       resume_step)\n  model_fn = build_model_fn(batch_size, lr_app_pretrain)\n  config = tf.estimator.RunConfig(\n      save_summary_steps=50,\n      save_checkpoints_steps=500,\n      keep_checkpoint_max=5,\n      log_step_count_steps=100)\n  est = tf.estimator.Estimator(\n      tf.contrib.estimator.replicate_model_fn(model_fn), train_dir,\n      config, params={})\n  # Get input function\n  input_train_fn = lambda: get_tf_triplet_dataset_iter(\n      imageset_dir, trainset_size, dist_file_path,\n      batch_size=batch_size).get_next()\n  print('Starting pretraining steps...')\n  est.train(input_train_fn, steps=None, hooks=None)  # train indefinitely\n\n\ndef main(argv):\n  if len(argv) > 1:\n    raise app.UsageError('Too many command-line arguments.')\n\n  train_dir = opts.train_dir\n  dataset_name = opts.dataset_name\n  imageset_dir = opts.imageset_dir\n  output_dir = opts.metadata_output_dir\n  if not osp.exists(output_dir):\n    os.makedirs(output_dir)\n  dist_file_path = osp.join(output_dir, 'dist_%s.pckl' % dataset_name)\n  compute_dist_matrix(imageset_dir, dist_file_path)\n  train_appearance(train_dir, imageset_dir, dist_file_path)\n\n\nif __name__ == '__main__':\n  app.run(main)\n"
  },
  {
    "path": "segment_dataset.py",
    "content": "# Copyright 2019 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     https://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Generate semantic segmentations\nThis module uses Xception model trained on ADE20K dataset to generate semantic\nsegmentation mask to any set of images.\n\"\"\"\n\nfrom absl import app\nfrom absl import flags\nfrom PIL import Image\nimport glob\nimport matplotlib.pyplot as plt\nimport numpy as np\nimport os\nimport os.path as osp\nimport shutil\nimport tensorflow as tf\nimport utils\n\n\ndef get_semantic_color_coding():\n  \"\"\"\n  assigns the 30 (actually 29) semantic colors from cityscapes semantic mapping\n  to selected classes from the ADE20K150 semantic classes.\n  \"\"\"\n  # Below are the 30 cityscape colors (one is duplicate. so total is 29 not 30)\n  colors = [\n    (111, 74,  0),\n    ( 81,  0, 81),\n    (128, 64,128),\n    (244, 35,232),\n    (250,170,160),\n    (230,150,140),\n    ( 70, 70, 70),\n    (102,102,156),\n    (190,153,153),\n    (180,165,180),\n    (150,100,100),\n    (150,120, 90),\n    (153,153,153),\n    # (153,153,153),\n    (250,170, 30),\n    (220,220,  0),\n    (107,142, 35),\n    (152,251,152),\n    ( 70,130,180),\n    (220, 20, 60),\n    (255,  0,  0),\n    (  0,  0,142),\n    (  0,  0, 70),\n    (  0, 60,100),\n    (  0,  0, 90),\n    (  0,  0,110),\n    (  0, 80,100),\n    (  0,  0,230),\n    (119, 11, 32),\n    (  0,  0,142)]\n  k_num_ade20k_classes = 150\n  # initially all 150 classes are mapped to a single color (last color idx: -1)\n  # Some classes are to be assigned independent colors\n  # semantic classes are 1-based (1 thru 150)\n  semantic_to_color_idx = -1 * np.ones(k_num_ade20k_classes + 1, dtype=int)\n  semantic_to_color_idx [1] = 0    # wall\n  semantic_to_color_idx [2] = 1    # building;edifice\n  semantic_to_color_idx [3] = 2    # sky\n  semantic_to_color_idx [105] = 3  # fountain\n  semantic_to_color_idx [27] = 4   # sea\n  semantic_to_color_idx [60] = 5   # stairway;staircase \n  semantic_to_color_idx [5] = 6    # tree\n  semantic_to_color_idx [12] = 7   # sidewalk;pavement \n  semantic_to_color_idx [4]  = 7   # floor;flooring\n  semantic_to_color_idx [7]  = 7   # road;route\n  semantic_to_color_idx [13] = 8   # people\n  semantic_to_color_idx [18] = 9   # plant;flora;plant;life\n  semantic_to_color_idx [17] = 10  # mountain;mount\n  semantic_to_color_idx [20] = 11  # chair\n  semantic_to_color_idx [6] = 12   # ceiling\n  semantic_to_color_idx [22] = 13  # water\n  semantic_to_color_idx [35] = 14  # rock;stone\n  semantic_to_color_idx [14] = 15  # earth;ground\n  semantic_to_color_idx [10] = 16  # grass\n  semantic_to_color_idx [70] = 17  # bench\n  semantic_to_color_idx [54] = 18  # stairs;steps\n  semantic_to_color_idx [101] = 19 # poster\n  semantic_to_color_idx [77] = 20  # boat\n  semantic_to_color_idx [85] = 21  # tower\n  semantic_to_color_idx [23] = 22  # painting;picture\n  semantic_to_color_idx [88] = 23  # streetlight;stree;lamp\n  semantic_to_color_idx [43] = 24  # column;pillar\n  semantic_to_color_idx [9] = 25   # 
window;windowpane\n  semantic_to_color_idx [15] = 26  # door;\n  semantic_to_color_idx [133] = 27 # sculpture\n\n  semantic_to_rgb = np.array(\n    [colors[col_idx][:] for col_idx in semantic_to_color_idx])\n  return semantic_to_rgb\n\n\ndef _apply_colors(seg_images_path, save_dir, idx_to_color):\n  for i, img_path in enumerate(seg_images_path):\n    print('processing img #%05d / %05d: %s' % (i, len(seg_images_path),\n                                               osp.split(img_path)[1]))\n    seg = np.array(Image.open(img_path))\n    seg_rgb = np.zeros(seg.shape + (3,), dtype=np.uint8)\n    for col_idx in range(len(idx_to_color)):\n      if idx_to_color[col_idx][0] != -1:\n        mask = seg == col_idx\n        seg_rgb[mask, :] = idx_to_color[col_idx][:]\n\n    parent_dir, filename = osp.split(img_path)\n    basename, ext = osp.splitext(filename)\n    out_filename = basename + \"_rgb.png\"\n    out_filepath = osp.join(save_dir, out_filename)\n    # Save rescaled segmentation image\n    Image.fromarray(seg_rgb).save(out_filepath)\n\n\n# The frozen xception model only segments 512x512 images. But it would be better\n# to segment the full image instead!\ndef segment_images(images_path, xception_frozen_graph_path, save_dir,\n                   crop_height=512, crop_width=512):\n  if not osp.exists(xception_frozen_graph_path):\n    raise OSError('Xception frozen graph not found at %s' %\n                            xception_frozen_graph_path)\n  with tf.gfile.GFile(xception_frozen_graph_path, \"rb\") as f:\n    graph_def = tf.GraphDef()\n    graph_def.ParseFromString(f.read())\n\n  with tf.Graph().as_default() as graph:\n    new_input = tf.placeholder(tf.uint8, [1, crop_height, crop_width, 3],\n                               name=\"new_input\")\n    tf.import_graph_def(\n      graph_def,\n      input_map={\"ImageTensor:0\": new_input},\n      return_elements=None,\n      name=\"sem_seg\",\n      op_dict=None,\n      producer_op_list=None\n    )\n\n  corrupted_dir = osp.join(save_dir, 'corrupted')\n  if not osp.exists(corrupted_dir):\n    os.makedirs(corrupted_dir)\n  with tf.Session(graph=graph) as sess:\n    for i, img_path in enumerate(images_path):\n      print('Segmenting image %05d / %05d: %s' % (i + 1, len(images_path),\n                                                  img_path))\n      img = np.array(Image.open(img_path))\n      if len(img.shape) == 2 or img.shape[2] != 3:\n        print('Warning! 
corrupted image %s' % img_path)\n        img_base_path = img_path[:-14]  # remove the '_reference.png' suffix\n        srcs = sorted(glob.glob(img_base_path + '_*'))\n        dest = corrupted_dir + '/.'  # move corrupted frames aside\n        for src in srcs:\n          shutil.move(src, dest)\n        continue\n      img = utils.get_central_crop(img, crop_height=crop_height,\n                                   crop_width=crop_width)\n      img = np.expand_dims(img, 0)  # convert to NHWC format\n      seg = sess.run(\"sem_seg/SemanticPredictions:0\", feed_dict={\n          new_input: img})\n      assert np.max(seg[:]) <= 255, 'segmentation image is not of type uint8!'\n      seg = np.squeeze(np.uint8(seg))  # convert to uint8 and squeeze to WxH.\n      parent_dir, filename = osp.split(img_path)\n      basename, ext = osp.splitext(filename)\n      basename = basename[:-10]  # remove the '_reference' suffix\n      seg_filename = basename + \"_seg.png\"\n      seg_filepath = osp.join(save_dir, seg_filename)\n      # Save segmentation image\n      Image.fromarray(seg).save(seg_filepath)\n\n\ndef segment_and_color_dataset(dataset_dir, xception_frozen_graph_path,\n                              splits=None, resegment_images=True):\n  if splits is None:\n    imgs_dirs = [dataset_dir]\n  else:\n    imgs_dirs = [osp.join(dataset_dir, split) for split in splits]\n\n  for cur_dir in imgs_dirs:\n    imgs_file_pattern = osp.join(cur_dir, '*_reference.png')\n    images_path = sorted(glob.glob(imgs_file_pattern))\n    if resegment_images:\n      segment_images(images_path, xception_frozen_graph_path, cur_dir,\n                     crop_height=512, crop_width=512)\n\n  idx_to_col = get_semantic_color_coding()\n\n  for cur_dir in imgs_dirs:\n    save_dir = cur_dir\n    seg_file_pattern = osp.join(cur_dir, '*_seg.png')\n    seg_imgs_paths = sorted(glob.glob(seg_file_pattern))\n    _apply_colors(seg_imgs_paths, save_dir, idx_to_col)\n"
  },
  {
    "path": "staged_model.py",
    "content": "# Copyright 2019 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     https://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Neural re-rerendering in the wild.\n\nImplementation of the staged training pipeline.\n\"\"\"\n\nfrom options import FLAGS as opts\nimport losses\nimport networks\nimport tensorflow as tf\nimport utils\n\n\ndef create_computation_graph(x_in, x_gt, x_app=None, arch_type='pggan',\n                             use_appearance=True):\n  \"\"\"Create the models and the losses.\n\n  Args:\n    x_in: 4D tensor, batch of conditional input images in NHWC format.\n    x_gt: 2D tensor, batch ground-truth images in NHWC format.\n    x_app: 4D tensor, batch of input appearance images.\n\n  Returns:\n    Dictionary of placeholders and TF graph functions.\n  \"\"\"\n  # ---------------------------------------------------------------------------\n  # Build models/networks\n  # ---------------------------------------------------------------------------\n\n  rerenderer = networks.RenderingModel(arch_type, use_appearance)\n  app_enc = rerenderer.get_appearance_encoder()\n  discriminator = networks.MultiScaleDiscriminator(\n      'd_model', opts.appearance_nc, num_scales=3, nf=64, n_layers=3,\n      get_fmaps=False)\n\n  # ---------------------------------------------------------------------------\n  # Forward pass\n  # ---------------------------------------------------------------------------\n\n  if opts.use_appearance:\n    z_app, _, _ = app_enc(x_app)\n  else:\n    z_app = None\n\n  y = rerenderer(x_in, z_app)\n\n  # ---------------------------------------------------------------------------\n  # Losses\n  # ---------------------------------------------------------------------------\n\n  w_loss_gan = opts.w_loss_gan\n  w_loss_recon = opts.w_loss_vgg if opts.use_vgg_loss else opts.w_loss_l1\n\n  # compute discriminator logits\n  disc_real_featmaps = discriminator(x_gt, x_in)\n  disc_fake_featmaps = discriminator(y, x_in)\n\n  # discriminator loss\n  loss_d_real = losses.multiscale_discriminator_loss(disc_real_featmaps, True)\n  loss_d_fake = losses.multiscale_discriminator_loss(disc_fake_featmaps, False)\n  loss_d = loss_d_real + loss_d_fake\n\n  # generator loss\n  loss_g_gan = losses.multiscale_discriminator_loss(disc_fake_featmaps, True)\n  if opts.use_vgg_loss:\n    vgg_layers = ['conv%d_2' % i for i in range(1, 6)]  # conv1 through conv5\n    vgg_layer_weights = [1./32, 1./16, 1./8, 1./4, 1.]\n    vgg_loss = losses.PerceptualLoss(y, x_gt, [256, 256, 3], vgg_layers,\n                                     vgg_layer_weights)  # NOTE: shouldn't hardcode image size!\n    loss_g_recon = vgg_loss()\n  else:\n    loss_g_recon = losses.L1_loss(y, x_gt)\n  loss_g = w_loss_gan * loss_g_gan + w_loss_recon * loss_g_recon\n\n  # ---------------------------------------------------------------------------\n  # Tensorboard visualizations\n  # ---------------------------------------------------------------------------\n\n  x_in_render = tf.slice(x_in, [0, 0, 0, 0], [-1, -1, -1, 3])\n  if opts.use_semantic:\n    
x_in_semantic = tf.slice(x_in, [0, 0, 0, 4], [-1, -1, -1, 3])\n    tb_visualization = tf.concat([x_in_render, x_in_semantic, y, x_gt], axis=2)\n  else:\n    tb_visualization = tf.concat([x_in_render, y, x_gt], axis=2)\n  tf.summary.image('rendered-semantic-generated-gt tuple', tb_visualization)\n\n  # Show input appearance images\n  if opts.use_appearance:\n    x_app_rgb = tf.slice(x_app, [0, 0, 0, 0], [-1, -1, -1, 3])\n    x_app_sem = tf.slice(x_app, [0, 0, 0, 7], [-1, -1, -1, -1])\n    tb_app_visualization = tf.concat([x_app_rgb, x_app_sem], axis=2)\n    tf.summary.image('input appearance image', tb_app_visualization)\n\n  # Loss summaries\n  with tf.name_scope('Discriminator_Loss'):\n    tf.summary.scalar('D_real_loss', loss_d_real)\n    tf.summary.scalar('D_fake_loss', loss_d_fake)\n    tf.summary.scalar('D_total_loss', loss_d)\n  with tf.name_scope('Generator_Loss'):\n    tf.summary.scalar('G_GAN_loss', w_loss_gan * loss_g_gan)\n    tf.summary.scalar('G_reconstruction_loss', w_loss_recon * loss_g_recon)\n    tf.summary.scalar('G_total_loss', loss_g)\n\n  # ---------------------------------------------------------------------------\n  # Optimizers\n  # ---------------------------------------------------------------------------\n\n  def get_optimizer(lr, loss, var_list):\n    optimizer = tf.train.AdamOptimizer(lr, opts.adam_beta1, opts.adam_beta2)\n    # optimizer = tf.contrib.estimator.TowerOptimizer(optimizer)\n    return optimizer.minimize(loss, var_list=var_list)\n\n  # Training ops.\n  update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)\n  with tf.control_dependencies(update_ops):\n    with tf.variable_scope('optimizers'):\n      d_vars = utils.model_vars('d_model')[0]\n      g_vars_all = utils.model_vars('g_model')[0]\n      train_d = [get_optimizer(opts.d_lr, loss_d, d_vars)]\n      train_g = [get_optimizer(opts.g_lr, loss_g, g_vars_all)]\n\n      train_app_encoder = []\n      if opts.train_app_encoder:\n        lr_app = opts.ez_lr\n        app_enc_vars = utils.model_vars('appearance_net')[0]\n        train_app_encoder.append(get_optimizer(lr_app, loss_g, app_enc_vars))\n\n  ema = tf.train.ExponentialMovingAverage(decay=0.999)\n  with tf.control_dependencies(train_g + train_app_encoder):\n    inference_vars_all = g_vars_all\n    if opts.use_appearance:\n      app_enc_vars = utils.model_vars('appearance_net')[0]\n      inference_vars_all += app_enc_vars\n    ema_op = ema.apply(inference_vars_all)\n\n  print('***************************************************')\n  print('len(g_vars_all) = %d' % len(g_vars_all))\n  for ii, v in enumerate(g_vars_all):\n    print('%03d) %s' % (ii, str(v)))\n  print('-------------------------------------------------------')\n  print('len(d_vars) = %d' % len(d_vars))\n  for ii, v in enumerate(d_vars):\n    print('%03d) %s' % (ii, str(v)))\n  if opts.train_app_encoder:\n    print('-------------------------------------------------------')\n    print('len(app_enc_vars) = %d' % len(app_enc_vars))\n    for ii, v in enumerate(app_enc_vars):\n      print('%03d) %s' % (ii, str(v)))\n  print('***************************************************\\n\\n')\n\n  return {\n      'train_disc_op': tf.group(train_d),\n      'train_renderer_op': ema_op,\n      'total_loss_d': loss_d,\n      'loss_d_real': loss_d_real,\n      'loss_d_fake': loss_d_fake,\n      'loss_g_gan': w_loss_gan * loss_g_gan,\n      'loss_g_recon': w_loss_recon * loss_g_recon,\n      'total_loss_g': loss_g}\n"
  },
  {
    "path": "style_loss.py",
    "content": "# Copyright 2019 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     https://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom PIL import Image\nfrom options import FLAGS as opts\nimport data\nimport layers\nimport numpy as np\nimport tensorflow as tf\nimport utils\nimport vgg16\n\n\ndef gram_matrix(layer):\n  \"\"\"Computes the gram_matrix for a batch of single vgg layer\n  Input:\n    layer: a batch of vgg activations for a single conv layer\n  Returns:\n    gram: [batch_sz x num_channels x num_channels]: a batch of gram matrices\n  \"\"\"\n  batch_size, height, width, num_channels = layer.get_shape().as_list()\n  features = tf.reshape(layer, [batch_size, height * width, num_channels])\n  num_elements = tf.constant(num_channels * height * width, tf.float32)\n  gram = tf.matmul(features, features, adjoint_a=True) / num_elements\n  return gram\n\n\ndef compute_gram_matrices(\n    images, vgg_layers=['conv1_2', 'conv2_2', 'conv3_2', 'conv4_2', 'conv5_2']):\n  \"\"\"Computes the gram matrix representation of a batch of images\"\"\"\n  vgg_net = vgg16.Vgg16(opts.vgg16_path)\n  vgg_acts = vgg_net.get_vgg_activations(images, vgg_layers)\n  grams = [gram_matrix(layer) for layer in vgg_acts]\n  return grams\n\n\ndef compute_pairwise_style_loss_v2(image_paths_list):\n  grams_all = [None] * len(image_paths_list)\n  crop_height, crop_width = opts.train_resolution, opts.train_resolution\n  img_var = tf.placeholder(tf.float32, shape=[1, crop_height, crop_width, 3])\n  vgg_layers = ['conv%d_2' % i for i in range(1, 6)]  # conv1 through conv5\n  grams_ops = compute_gram_matrices(img_var, vgg_layers)\n  with tf.Session() as sess:\n    for ii, img_path in enumerate(image_paths_list):\n      print('Computing gram matrices for image #%d' % (ii + 1))\n      img = np.array(Image.open(img_path), dtype=np.float32)\n      img = img * 2. / 255. - 1  # normalize image\n      img = utils.get_central_crop(img, crop_height, crop_width)\n      img = np.expand_dims(img, axis=0)\n      grams_all[ii] = sess.run(grams_ops, feed_dict={img_var: img})\n  print('Number of images = %d' % len(grams_all))\n  print('Gram matrices per image:')\n  for i in range(len(grams_all[0])):\n    print('gram_matrix[%d].shape = %s' % (i, grams_all[0][i].shape))\n  n_imgs = len(grams_all)\n  dist_matrix = np.zeros((n_imgs, n_imgs))\n  for i in range(n_imgs):\n    print('Computing distances for image #%d' % i)\n    for j in range(i + 1, n_imgs):\n      loss_style = 0\n      # Compute loss using all gram matrices from all layers\n      for gram_i, gram_j in zip(grams_all[i], grams_all[j]):\n        loss_style += np.mean((gram_i - gram_j) ** 2, axis=(1, 2))\n      dist_matrix[i][j] = dist_matrix[j][i] = loss_style\n\n  return dist_matrix\n"
  },
  {
    "path": "utils.py",
    "content": "# Copyright 2019 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     https://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Utilities for GANs.\n\nBasic functions such as generating sample grid, exporting to PNG, etc...\n\"\"\"\n\nimport functools\nimport numpy as np\nimport os.path\nimport tensorflow as tf\nimport time\n\n\ndef crop_to_multiple(img, size_multiple=64):\n  \"\"\" Crops the image so that its dimensions are multiples of size_multiple.\"\"\"\n  new_width = (img.shape[1] // size_multiple) * size_multiple\n  new_height = (img.shape[0] // size_multiple) * size_multiple\n  offset_x = (img.shape[1] - new_width) // 2\n  offset_y = (img.shape[0] - new_height) // 2\n  return img[offset_y:offset_y + new_height, offset_x:offset_x + new_width, :]\n\n\ndef get_central_crop(img, crop_height=512, crop_width=512):\n  if len(img.shape) == 2:\n    img = np.expand_dims(img, axis=2)\n  assert len(img.shape) == 3, ('input image should be either a 2D or 3D matrix,'\n                               ' but input was of shape %s' % str(img.shape))\n  height, width, _ = img.shape\n  assert height >= crop_height and width >= crop_width, ('input image cannot '\n      'be smaller than the requested crop size')\n  st_y = (height - crop_height) // 2\n  st_x = (width - crop_width) // 2\n  return np.squeeze(img[st_y : st_y + crop_height, st_x : st_x + crop_width, :])\n\n\ndef load_global_step_from_checkpoint_dir(checkpoint_dir):\n  \"\"\"Loads  the global step from the checkpoint directory.\n\n  Args:\n    checkpoint_dir: string, path to the checkpoint directory.\n\n  Returns:\n    int, the global step of the latest checkpoint or 0 if none was found.\n  \"\"\"\n  try:\n    checkpoint_reader = tf.train.NewCheckpointReader(\n        tf.train.latest_checkpoint(checkpoint_dir))\n    return checkpoint_reader.get_tensor(tf.GraphKeys.GLOBAL_STEP)\n  except:\n    return 0\n\n\ndef model_vars(prefix):\n  \"\"\"Return trainable variables matching a prefix.\n\n  Args:\n    prefix: string, the prefix variable names must match.\n\n  Returns:\n    a tuple (match, others) of TF variables, 'match' contains the matched\n     variables and 'others' contains the remaining variables.\n  \"\"\"\n  match, no_match = [], []\n  for x in tf.trainable_variables():\n    if x.name.startswith(prefix):\n      match.append(x)\n    else:\n      no_match.append(x)\n  return match, no_match\n\n\ndef to_png(x):\n  \"\"\"Convert a 3D tensor to png.\n\n  Args:\n    x: Tensor, 01C formatted input image.\n\n  Returns:\n    Tensor, 1D string representing the image in png format.\n  \"\"\"\n  with tf.Graph().as_default():\n    with tf.Session() as sess_temp:\n      x = tf.constant(x)\n      y = tf.image.encode_png(\n          tf.cast(\n              tf.clip_by_value(tf.round(127.5 + 127.5 * x), 0, 255), tf.uint8),\n          compression=9)\n      return sess_temp.run(y)\n\n\ndef images_to_grid(images):\n  \"\"\"Converts a grid of images (5D tensor) to a single image.\n\n  Args:\n    images: 5D tensor (count_y, count_x, height, width, colors), grid of images.\n\n  
Returns:\n    a 3D tensor image of shape (count_y * height, count_x * width, colors).\n  \"\"\"\n  ny, nx, h, w, c = images.shape\n  images = images.transpose(0, 2, 1, 3, 4)\n  images = images.reshape([ny * h, nx * w, c])\n  return images\n\n\ndef save_images(image, output_dir, cur_nimg):\n  \"\"\"Saves images to disk.\n\n  Saves a file called 'name.png' containing the latest samples from the\n   generator and a file called 'name_123.png' where 123 is the KiB of trained\n   images.\n\n  Args:\n    image: 3D numpy array (height, width, colors), the image to save.\n    output_dir: string, the directory where to save the image.\n    cur_nimg: int, current number of images seen by training.\n\n  Returns:\n    None\n  \"\"\"\n  for name in ('name.png', 'name_%06d.png' % (cur_nimg >> 10)):\n    with tf.gfile.Open(os.path.join(output_dir, name), 'wb') as f:\n      f.write(image)\n\n\nclass HookReport(tf.train.SessionRunHook):\n  \"\"\"Custom reporting hook.\n\n  Register your tensor scalars with HookReport.log_tensor(my_tensor, 'my_name').\n  This hook will report their average values over report period argument\n  provided to the constructed. The values are printed in the order the tensors\n  were registered.\n\n  Attributes:\n    step: int, the current global step.\n    active: bool, whether logging is active or disabled.\n  \"\"\"\n  _REPORT_KEY = 'report'\n  _TENSOR_NAMES = {}\n\n  def __init__(self, period, batch_size):\n    self.step = 0\n    self.active = True\n    self._period = period // batch_size\n    self._batch_size = batch_size\n    self._sums = np.array([])\n    self._count = 0\n    self._nimgs_per_cycle = 0\n    self._step_ratio = 0\n    self._start = time.time()\n    self._nimgs = 0\n    self._batch_size = batch_size\n\n  def disable(self):\n    parent = self\n\n    class Disabler(object):\n\n      def __enter__(self):\n        parent.active = False\n        return parent\n\n      def __exit__(self, exc_type, exc_val, exc_tb):\n        parent.active = True\n\n    return Disabler()\n\n  def begin(self):\n    self.active = True\n    self._count = 0\n    self._nimgs_per_cycle = 0\n    self._start = time.time()\n\n  def before_run(self, run_context):\n    if not self.active:\n      return\n    del run_context\n    fetches = tf.get_collection(self._REPORT_KEY)\n    return tf.train.SessionRunArgs(fetches)\n\n  def after_run(self, run_context, run_values):\n    if not self.active:\n      return\n    del run_context\n    results = run_values.results\n    # Note: sometimes the returned step is incorrect (off by one) for some\n    # unknown reason.\n    self.step = results[-1] + 1\n    self._count += 1\n    self._nimgs_per_cycle += self._batch_size\n    self._nimgs += self._batch_size\n\n    if not self._sums.size:\n      self._sums = np.array(results[:-1], 'd')\n    else:\n      self._sums += np.array(results[:-1], 'd')\n\n    if self.step // self._period != self._step_ratio:\n      fetches = tf.get_collection(self._REPORT_KEY)[:-1]\n      stats = '  '.join('%s=% .2f' % (self._TENSOR_NAMES[tensor],\n                                      value / self._count)\n                        for tensor, value in zip(fetches, self._sums))\n      stop = time.time()\n      tf.logging.info('step=%d, kimg=%d  %s  [%.2f img/s]' %\n                      (self.step, ((self.step * self._batch_size) >> 10),\n                       stats, self._nimgs_per_cycle / (stop - self._start)))\n      self._step_ratio = self.step // self._period\n      self._start = stop\n      self._sums *= 0\n      self._count = 0\n  
    self._nimgs_per_cycle = 0\n\n  def end(self, session=None):\n    del session\n\n  @classmethod\n  def log_tensor(cls, tensor, name):\n    \"\"\"Adds a tensor to be reported by the hook.\n\n    Args:\n      tensor: `tensor scalar`, a value to report.\n      name: string, the name to give the value in the report.\n\n    Returns:\n      None.\n    \"\"\"\n    cls._TENSOR_NAMES[tensor] = name\n    tf.add_to_collection(cls._REPORT_KEY, tensor)\n"
  }
]