Repository: google/neural_rerendering_in_the_wild
Branch: master
Commit: 5f5226f00835
Files: 16
Total size: 144.2 KB

Directory structure:
neural_rerendering_in_the_wild/

├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── data.py
├── dataset_utils.py
├── evaluate_quantitative_metrics.py
├── layers.py
├── losses.py
├── networks.py
├── neural_rerendering.py
├── options.py
├── pretrain_appearance.py
├── segment_dataset.py
├── staged_model.py
├── style_loss.py
└── utils.py

================================================
FILE CONTENTS
================================================

================================================
FILE: CONTRIBUTING.md
================================================
# How to Contribute

We'd love to accept your patches and contributions to this project. There are
just a few small guidelines you need to follow.

## Contributor License Agreement

Contributions to this project must be accompanied by a Contributor License
Agreement. You (or your employer) retain the copyright to your contribution;
this simply gives us permission to use and redistribute your contributions as
part of the project. Head over to <https://cla.developers.google.com/> to see
your current agreements on file or to sign a new one.

You generally only need to submit a CLA once, so if you've already submitted one
(even if it was for a different project), you probably don't need to do it
again.

## Code reviews

All submissions, including submissions by project members, require review. We
use GitHub pull requests for this purpose. Consult
[GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
information on using pull requests.

## Community Guidelines

This project follows
[Google's Open Source Community Guidelines](https://opensource.google.com/conduct/).


================================================
FILE: LICENSE
================================================

                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!)  The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.

================================================
FILE: README.md
================================================
# Neural Rerendering in the Wild
Moustafa Meshry<sup>1</sup>,
[Dan B Goldman](http://www.danbgoldman.com/)<sup>2</sup>,
[Sameh Khamis](http://www.samehkhamis.com/)<sup>2</sup>,
[Hugues Hoppe](http://hhoppe.com/)<sup>2</sup>,
Rohit Pandey<sup>2</sup>,
[Noah Snavely](http://www.cs.cornell.edu/~snavely/)<sup>2</sup>,
[Ricardo Martin-Brualla](http://www.ricardomartinbrualla.com/)<sup>2</sup>.

<sup>1</sup>University of Maryland, College Park &nbsp;&nbsp;&nbsp;&nbsp; <sup>2</sup>Google Inc.

To appear at CVPR 2019 (Oral). <br><br>


<figure class="image">
  <img align="center" src="imgs/teaser_with_caption.jpg" width="500px">
</figure>

<!--- ![Teaser figure](https://github.com/MoustafaMeshry/neural_rerendering_in_the_wild/blob/master/imgs/teaser_with_caption.jpg?raw=true | width=450) --->

We will provide a TensorFlow implementation and pretrained models for our paper soon.

[**Paper**](https://arxiv.org/abs/1904.04290) | [**Video**](https://www.youtube.com/watch?v=E1crWQn_kmY) | [**Code**](https://github.com/MoustafaMeshry/neural_rerendering_in_the_wild) | [**Project page**](https://moustafameshry.github.io/neural_rerendering_in_the_wild/)

### Abstract

We explore total scene capture — recording, modeling, and rerendering a scene under varying appearance such as season and time of day.
Starting from internet photos of a tourist landmark, we apply traditional 3D reconstruction to register the photos and approximate the scene as a point cloud.
For each photo, we render the scene points into a deep framebuffer,
and train a neural network to learn the mapping of these initial renderings to the actual photos.
This rerendering network also takes as input a latent appearance vector and a semantic mask indicating the location of transient objects like pedestrians.
The model is evaluated on several datasets of publicly available images spanning a broad range of illumination conditions.
We create short videos demonstrating realistic manipulation of the image viewpoint, appearance, and semantic labeling.
We also compare results with prior work on scene reconstruction from internet photos.

### Video
[![Supplementary material video](https://img.youtube.com/vi/E1crWQn_kmY/0.jpg)](https://www.youtube.com/watch?v=E1crWQn_kmY)


### Appearance variation

We capture the appearance of the original images in the left column and rerender several viewpoints under each of them. The last column shows a detail of the previous one. The top row shows the renderings that form part of the input to the rerenderer; they exhibit artifacts such as incomplete features in the statue and an inconsistent mix of day and night appearances. Note the hallucinated twilight sky rendered under the last appearance. Image credits: Flickr users William Warby, Neil Rickards, Rafael Jimenez, acme401 (Creative Commons).

<figure class="image">
  <img src="imgs/app_variation.jpg" width="900px">
</figure>

### Appearance interpolation
Frames from a synthesized camera path that smoothly transitions from the photo on the left to the photo on the right by interpolating both the viewpoint and the latent appearance vectors. Please see the supplementary video. Photo Credits: Allie Caulfield, Tahbepet, Till Westermayer, Elliott Brown (Creative Commons).
<figure class="image">
  <img src="imgs/app_interpolation.jpg" width="900px">
</figure>

### Acknowledgements
We thank Gregory Blascovich for his help in conducting the user study, and Johannes Schönberger and True Price for their help generating datasets.

### Run and train instructions

Staged-training consists of three stages:

-   Pretraining the appearance network.
-   Training the rendering network while fixing the weights for the appearance
    network.
-   Finetuning both the appearance and the rendering networks.

### Aligned dataset preprocessing

#### Manual preparation

*   Set a path to a base_dir that contains the source code:

```
base_dir=/path/to/neural_rendering
mkdir $base_dir
cd $base_dir
```

*   We assume the following format for an aligned dataset (a sanity-check
    sketch follows at the end of this subsection):
    * Each training example consists of 3 files with the following naming format:
        * real image: %04d_reference.png
        * render color: %04d_color.png
        * render depth: %04d_depth.png
*   Set dataset name: e.g.
```
dataset_name='trevi3k'  # set to any name
```
*   Split the dataset into train and validation sets in two subdirectories:
    *   $base_dir/datasets/$dataset_name/train
    *   $base_dir/datasets/$dataset_name/val
*   Download the DeepLab semantic segmentation model trained on the ADE20K
    dataset from this link:
    http://download.tensorflow.org/models/deeplabv3_xception_ade20k_train_2018_05_29.tar.gz
*   Extract the downloaded archive to: $base_dir/deeplabv3_xception_ade20k_train
*   Download this [file](https://github.com/MoustafaMeshry/vgg_loss/blob/master/vgg16.py) for an implementation of a VGG-based perceptual loss.
*   Download trained weights for the VGG network as instructed at this link: https://github.com/machrisaa/tensorflow-vgg
*   Save the VGG weights to $base_dir/vgg16_weights/vgg16.npy
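
Below is a minimal sanity-check sketch for the aligned-dataset layout above. It is illustrative only (the split directory passed at the bottom is a placeholder):

```
import glob
import os.path as osp

# Every %04d_color.png should have matching %04d_reference.png and
# %04d_depth.png files in the same directory.
def check_aligned_dataset(split_dir):
  color_paths = sorted(glob.glob(osp.join(split_dir, '*_color.png')))
  missing = []
  for color_path in color_paths:
    prefix = color_path[:-len('color.png')]  # keeps the '%04d_' prefix
    for suffix in ('reference.png', 'depth.png'):
      if not osp.exists(prefix + suffix):
        missing.append(prefix + suffix)
  print('checked %d renders, found %d missing files' %
        (len(color_paths), len(missing)))
  return missing

check_aligned_dataset('datasets/trevi3k/train')  # placeholder path
```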


#### Data preprocessing

*   Run the preprocessing pipeline which consists of:
    *   Filtering out sparse renders.
    *   Semantic segmentation of ground truth images.
    *   Exporting the dataset to tfrecord format.

```
# Run locally
python dataset_utils.py \
--dataset_name=$dataset_name \
--dataset_parent_dir=$base_dir/datasets/$dataset_name \
--output_dir=$base_dir/datasets/$dataset_name \
--xception_frozen_graph_path=$base_dir/deeplabv3_xception_ade20k_train/frozen_inference_graph.pb \
--alsologtostderr
```

### Pretraining the appearance encoder network

```
# Run locally
python pretrain_appearance.py \
  --dataset_name=$dataset_name \
  --train_dir=$base_dir/train_models/$dataset_name-app_pretrain \
  --imageset_dir=$base_dir/datasets/$dataset_name/train \
  --train_resolution=512 \
  --metadata_output_dir=$base_dir/datasets/$dataset_name
```

### Training the rerendering network with a fixed appearance encoder

Set the dataset_parent_dir variable below to point to the directory containing
the generated TFRecords.

```
# Run locally:
dataset_parent_dir=$base_dir/datasets/$dataset_name
train_dir=$base_dir/train_models/$dataset_name-staged-fixed_appearance
load_pretrained_app_encoder=true
appearance_pretrain_dir=$base_dir/train_models/$dataset_name-app_pretrain
load_from_another_ckpt=false
fixed_appearance_train_dir=''
train_app_encoder=false

python neural_rerendering.py \
--dataset_name=$dataset_name \
--dataset_parent_dir=$dataset_parent_dir \
--train_dir=$train_dir \
--load_pretrained_app_encoder=$load_pretrained_app_encoder \
--appearance_pretrain_dir=$appearance_pretrain_dir \
--train_app_encoder=$train_app_encoder \
--load_from_another_ckpt=$load_from_another_ckpt \
--fixed_appearance_train_dir=$fixed_appearance_train_dir \
--total_kimg=4000
```

### Finetuning the rerendering network and the appearance encoder

Set the fixed_appearance_train_dir to the train directory from the previous
step.

```
# Run locally:
dataset_parent_dir=$base_dir/datasets/$dataset_name
train_dir=$base_dir/train_models/$dataset_name-staged-finetune_appearance
load_pretrained_app_encoder=false
appearance_pretrain_dir=''
load_from_another_ckpt=true
fixed_appearance_train_dir=$base_dir/train_models/$dataset_name-staged-fixed_appearance
train_app_encoder=true

python neural_rerendering.py \
--dataset_name=$dataset_name \
--dataset_parent_dir=$dataset_parent_dir \
--train_dir=$train_dir \
--load_pretrained_app_encoder=$load_pretrained_app_encoder \
--appearance_pretrain_dir=$appearance_pretrain_dir \
--train_app_encoder=$train_app_encoder \
--load_from_another_ckpt=$load_from_another_ckpt \
--fixed_appearance_train_dir=$fixed_appearance_train_dir \
--total_kimg=4000
```


### Evaluate model on validation set

```
experiment_title=$dataset_name-staged-finetune_appearance
local_train_dir=$base_dir/train_models/$experiment_title
dataset_parent_dir=$base_dir/datasets/$dataset_name
val_set_out_dir=$local_train_dir/val_set_output

# Run the model on validation set
echo "Evaluating the validation set"
python neural_rerendering.py \
      --train_dir=$local_train_dir \
      --dataset_name=$dataset_name \
      --dataset_parent_dir=$dataset_parent_dir \
      --run_mode='eval_subset' \
      --virtual_seq_name='val' \
      --output_validation_dir=$val_set_out_dir \
      --logtostderr
# Evaluate quantitative metrics
python evaluate_quantitative_metrics.py \
      --val_set_out_dir=$val_set_out_dir \
      --experiment_title=$experiment_title \
      --logtostderr
```


================================================
FILE: data.py
================================================
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from options import FLAGS as opts
import functools
import glob
import numpy as np
import os.path as osp
import random
import tensorflow as tf


def provide_data(dataset_name='', parent_dir='', batch_size=8, subset=None,
                 max_examples=None, crop_flag=False, crop_size=256, seeds=None,
                 use_appearance=True, shuffle=128):
  # Parsing function for each tfrecord example.
  record_parse_fn = functools.partial(
      _parser_rendered_dataset, crop_flag=crop_flag, crop_size=crop_size,
      use_alpha=opts.use_alpha, use_depth=opts.use_depth,
      use_semantics=opts.use_semantic, seeds=seeds,
      use_appearance=use_appearance)

  input_dict_var = multi_input_fn_record(
      record_parse_fn, parent_dir, dataset_name, batch_size,
      subset=subset, max_examples=max_examples, shuffle=shuffle)
  return input_dict_var


def _parser_rendered_dataset(
    serialized_example, crop_flag, crop_size, seeds, use_alpha, use_depth,
    use_semantics, use_appearance):
  """
  Parses a single tf.Example into a features dictionary with input tensors.
  """
  # Structure of features_dict need to match the dictionary structure that was
  # serialized to a tf.Example
  features_dict = {'height': tf.FixedLenFeature([], tf.int64),
                   'width': tf.FixedLenFeature([], tf.int64),
                   'rendered': tf.FixedLenFeature([], tf.string),
                   'depth': tf.FixedLenFeature([], tf.string),
                   'real': tf.FixedLenFeature([], tf.string),
                   'seg': tf.FixedLenFeature([], tf.string)}
  features = tf.parse_single_example(serialized_example, features=features_dict)
  height = tf.cast(features['height'], tf.int32)
  width = tf.cast(features['width'], tf.int32)

  # Parse the rendered image.
  rendered = tf.decode_raw(features['rendered'], tf.uint8)
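  # Map uint8 values in [0, 255] to floats in [-1, 1].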
  rendered = tf.cast(rendered, tf.float32) * (2.0 / 255) - 1.0
  rendered = tf.reshape(rendered, [height, width, 4])
  if not use_alpha:
    rendered = tf.slice(rendered, [0, 0, 0], [height, width, 3])
  conditional_input = rendered

  # Parse the depth image.
  if use_depth:
    depth = tf.decode_raw(features['depth'], tf.uint16)
    depth = tf.reshape(depth, [height, width, 1])
    depth = tf.cast(depth, tf.float32) * (2.0 / 255) - 1.0
    conditional_input = tf.concat([conditional_input, depth], axis=-1)

  # Parse the semantic map.
  if use_semantics:
    seg_img = tf.decode_raw(features['seg'], tf.uint8)
    seg_img = tf.reshape(seg_img, [height, width, 3])
    seg_img = tf.cast(seg_img, tf.float32) * (2.0 / 255) - 1
    conditional_input = tf.concat([conditional_input, seg_img], axis=-1)

  # Verify that the parsed input has the correct number of channels.
  assert conditional_input.shape[-1] == opts.deep_buffer_nc, ('num channels '
      'in the parsed input doesn\'t match num input channels specified in '
      'opts.deep_buffer_nc!')

  # Parse the ground truth image.
  real = tf.decode_raw(features['real'], tf.uint8)
  real = tf.cast(real, tf.float32) * (2.0 / 255) - 1.0
  real = tf.reshape(real, [height, width, 3])

  # Parse the appearance image (if any).
  appearance_input = []
  if use_appearance:
    # Concatenate the deep buffer to the real image.
    appearance_input = tf.concat([real, conditional_input], axis=-1)
    # Verify that the parsed input has the correct number of channels.
    assert appearance_input.shape[-1] == opts.appearance_nc, ('num channels '
        'in the parsed appearance input doesn\'t match num input channels '
        'specified in opts.appearance_nc!')

  # Crop conditional_input and real images, but keep the appearance input
  # uncropped (learn a one-to-many mapping from appearance to output)
  if crop_flag:
    assert crop_size is not None, 'crop_size is not provided!'
    if isinstance(crop_size, int):
      crop_size = [crop_size, crop_size]
    assert len(crop_size) == 2, 'crop_size is either an int or a 2-tuple!'

    # Central crop
    if seeds is not None and len(seeds) <= 1:
      conditional_input = tf.image.resize_image_with_crop_or_pad(
          conditional_input, crop_size[0], crop_size[1])
      real = tf.image.resize_image_with_crop_or_pad(real, crop_size[0],
                                                    crop_size[1])
    else:
      if not seeds:  # random crops
        seed = random.randint(0, (1 << 31) - 1)
      else:  # fixed crops
        seed_idx = random.randint(0, len(seeds) - 1)
        seed = seeds[seed_idx]
      conditional_input = tf.random_crop(
          conditional_input, crop_size + [opts.deep_buffer_nc], seed=seed)
      real = tf.random_crop(real, crop_size + [3], seed=seed)

  features = {'conditional_input': conditional_input,
              'expected_output': real,
              'peek_input': appearance_input}
  return features


def multi_input_fn_record(
    record_parse_fn, parent_dir, tfrecord_basename, batch_size, subset=None,
    max_examples=None, shuffle=128):
  """Creates a Dataset pipeline for tfrecord files.

  Returns:
    A dictionary of batched input tensors from iterator.get_next().
  """
  subset_suffix = '*_%s.tfrecord' % subset if subset else '*.tfrecord'
  input_pattern = osp.join(parent_dir, tfrecord_basename + subset_suffix)
  filenames = sorted(glob.glob(input_pattern))
  assert len(filenames) > 0, ('Error! input pattern "%s" didn\'t match any '
                              'files' % input_pattern)
  dataset = tf.data.TFRecordDataset(filenames)
  if shuffle == 0:  # keep input deterministic
    # use one thread to get deterministic results
    dataset = dataset.map(record_parse_fn, num_parallel_calls=None)
  else:
    dataset = dataset.repeat()  # Repeat indefinitely.
    dataset = dataset.map(record_parse_fn,
                          num_parallel_calls=max(4, batch_size // 4))
    if opts.training_pipeline == 'drit':
      dataset1 = dataset.shuffle(shuffle)
      dataset2 = dataset.shuffle(shuffle)
      paired_dataset = tf.data.Dataset.zip((dataset1, dataset2))

      def _join_paired_dataset(features_a, features_b):
        features_a['conditional_input_2'] = features_b['conditional_input']
        features_a['expected_output_2'] = features_b['expected_output']
        return features_a

      joined_dataset = paired_dataset.map(_join_paired_dataset)
      dataset = joined_dataset
    else:
      dataset = dataset.shuffle(shuffle)
  if max_examples is not None:
    dataset = dataset.take(max_examples)
  dataset = dataset.batch(batch_size)
  if shuffle > 0:  # input is not deterministic
    dataset = dataset.prefetch(4)  # Prefetch a few batches.
  return dataset.make_one_shot_iterator().get_next()
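

# Illustrative usage sketch (not part of the original pipeline): fetch one
# batch under TF1-style graph execution. The dataset name and path below are
# placeholders, and this assumes the flags defined in options.py have already
# been parsed (e.g. inside absl's app.run).
def _example_fetch_one_batch():
  batch = provide_data(dataset_name='trevi3k',
                       parent_dir='/path/to/datasets/trevi3k',
                       batch_size=4, subset='train', use_appearance=False)
  with tf.Session() as sess:
    np_batch = sess.run(batch)
    # conditional_input is batched as [4, H, W, opts.deep_buffer_nc].
    print(np_batch['conditional_input'].shape)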


================================================
FILE: dataset_utils.py
================================================
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from PIL import Image
from absl import app
from absl import flags
from options import FLAGS as opts
import cv2
import data
import functools
import glob
import numpy as np
import os
import os.path as osp
import shutil
import six
import tensorflow as tf
import segment_dataset as segment_utils
import utils

FLAGS = flags.FLAGS
flags.DEFINE_string('output_dir', None, 'Directory to save exported tfrecords.')
flags.DEFINE_string('xception_frozen_graph_path', None,
                    'Path to the deeplab xception model frozen graph')


class AlignedRenderedDataset(object):
  def __init__(self, rendered_filepattern, use_semantic_map=True):
    """
    Args:
      rendered_filepattern: string, path filepattern to 3D rendered images (
        assumes filenames are '/path/to/dataset/%d_color.png')
      use_semantic_map: bool, whether to include semantic maps in the TFRecord.
    """
    self.filenames = sorted(glob.glob(rendered_filepattern))
    assert len(self.filenames) > 0, ('input %s didn\'t match any files!' %
                                     rendered_filepattern)
    self.iter_idx = 0
    self.use_semantic_map = use_semantic_map

  def __iter__(self):
    return self

  def __next__(self):
    return self.next()

  def next(self):
    if self.iter_idx < len(self.filenames):
      rendered_img_name = self.filenames[self.iter_idx]
      basename = rendered_img_name[:-9]  # remove the 'color.png' suffix
      ref_img_name = basename + 'reference.png'
      depth_img_name = basename + 'depth.png'
      # Read the 3D rendered image
      img_rendered = cv2.imread(rendered_img_name, cv2.IMREAD_UNCHANGED)
      # Change BGR (default cv2 format) to RGB
      img_rendered = img_rendered[:, :, [2,1,0,3]]  # it has a 4th alpha channel
      # Read the depth image
      img_depth = cv2.imread(depth_img_name, cv2.IMREAD_UNCHANGED)
      # Workaround as some depth images are read with a different data type!
      img_depth = img_depth.astype(np.uint16)
      # Read reference image if exists, otherwise replace with a zero image.
      if osp.exists(ref_img_name):
        img_ref = cv2.imread(ref_img_name)
        img_ref = img_ref[:, :, ::-1]  # Change BGR to RGB format.
      else:  # use a dummy 3-channel zero image as a placeholder
        print('Warning: no reference image found! Using a dummy placeholder!')
        img_height, img_width = img_depth.shape
        img_ref = np.zeros((img_height, img_width, 3), dtype=np.uint8)

      if self.use_semantic_map:
        semantic_seg_img_name = basename + 'seg_rgb.png'
        img_seg = cv2.imread(semantic_seg_img_name)
        img_seg = img_seg[:, :, ::-1]  # Change from BGR to RGB
        if img_seg.shape[0] == 512 and img_seg.shape[1] == 512:
          img_ref = utils.get_central_crop(img_ref)
          img_rendered = utils.get_central_crop(img_rendered)
          img_depth = utils.get_central_crop(img_depth)

      img_shape = img_depth.shape
      # img_seg only exists when semantic maps are enabled.
      if self.use_semantic_map:
        assert img_seg.shape == (img_shape + (3,)), 'error in seg image %s %s' % (
          basename, str(img_seg.shape))
      assert img_ref.shape == (img_shape + (3,)), 'error in ref image %s %s' % (
        basename, str(img_ref.shape))
      assert img_rendered.shape == (img_shape + (4,)), ('error in rendered '
        'image %s %s' % (basename, str(img_rendered.shape)))
      assert len(img_depth.shape) == 2, 'error in depth image %s %s' % (
        basename, str(img_depth.shape))

      raw_example = dict()
      raw_example['height'] = img_ref.shape[0]
      raw_example['width'] = img_ref.shape[1]
      raw_example['rendered'] = img_rendered.tostring()
      raw_example['depth'] = img_depth.tostring()
      raw_example['real'] = img_ref.tostring()
      if self.use_semantic_map:
        raw_example['seg'] = img_seg.tostring()
      self.iter_idx += 1
      return raw_example
    else:
      raise StopIteration()


def filter_out_sparse_renders(dataset_dir, splits, ratio_threshold=0.15):
  print('Filtering %s' % dataset_dir)
  if splits is None:
    imgs_dirs = [dataset_dir]
  else:
    imgs_dirs = [osp.join(dataset_dir, split) for split in splits]
  
  filtered_images = []
  total_images = 0
  sum_density = 0
  for cur_dir in imgs_dirs:
    filtered_dir = osp.join(cur_dir, 'sparse_renders')
    if not osp.exists(filtered_dir):
      os.makedirs(filtered_dir)
    imgs_file_pattern = osp.join(cur_dir, '*_color.png')
    images_path = sorted(glob.glob(imgs_file_pattern))
    print('Processing %d files' % len(images_path))
    total_images += len(images_path)
    for ii, img_path in enumerate(images_path):
      img = np.array(Image.open(img_path))
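      # A pixel counts as covered if any channel is non-zero; density is the
      # fraction of covered pixels in the render.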
      aggregate = np.squeeze(np.sum(img, axis=2))
      height, width = aggregate.shape
      mask = aggregate > 0
      density = np.sum(mask) * 1. / (height * width)
      sum_density += density
      if density <= ratio_threshold:
        parent, basename = osp.split(img_path)
        basename = basename[:-10]  # remove the '_color.png' suffix
        srcs = sorted(glob.glob(osp.join(parent, basename + '_*')))
        dest = six.text_type(filtered_dir + '/.')
        for src in srcs:
          shutil.move(src, dest)
        filtered_images.append(basename)
        print('filtered file %d: %s with a density of %.3f' % (ii, basename,
                                                               density))
    print('Filtered %d/%d images' % (len(filtered_images), total_images))
    print('Mean density = %.4f' % (sum_density / total_images))


def _to_example(dictionary):
  """Helper: build tf.Example from (string -> int/float/str list) dictionary."""
  features = {}
  for (k, v) in six.iteritems(dictionary):
    if isinstance(v, six.integer_types):
      features[k] = tf.train.Feature(int64_list=tf.train.Int64List(value=[v]))
    elif isinstance(v, float):
      features[k] = tf.train.Feature(float_list=tf.train.FloatList(value=[v]))
    elif isinstance(v, six.string_types):
      features[k] = tf.train.Feature(bytes_list=tf.train.BytesList(value=[v]))
    elif isinstance(v, bytes):
      features[k] = tf.train.Feature(bytes_list=tf.train.BytesList(value=[v]))
    else:
      raise ValueError("Value for %s is not a recognized type; v: %s type: %s" %
                       (k, str(v[0]), str(type(v[0]))))

  return tf.train.Example(features=tf.train.Features(feature=features))


def _generate_tfrecord_dataset(generator,
                               output_name,
                               output_dir):
  """Convert a dataset into TFRecord format."""
  output_file = os.path.join(output_dir, output_name)
  tf.logging.info("Writing TFRecords to file %s", output_file)
  writer = tf.python_io.TFRecordWriter(output_file)

  counter = 0
  for case in generator:
    if counter % 100 == 0:
      print('Generating case %d for %s.' % (counter, output_name))
    counter += 1
    example = _to_example(case)
    writer.write(example.SerializeToString())

  writer.close()
  return output_file


def export_aligned_dataset_to_tfrecord(
    dataset_dir, output_dir, output_basename, splits,
    xception_frozen_graph_path):

  # Step 1: filter out sparse renders
  filter_out_sparse_renders(dataset_dir, splits, 0.15)

  # Step 2: generate semantic segmentation masks
  segment_utils.segment_and_color_dataset(
      dataset_dir, xception_frozen_graph_path, splits)

  # Step 3: export dataset to TFRecord
  if splits is None:
    input_filepattern = osp.join(dataset_dir, '*_color.png')
    dataset_iter = AlignedRenderedDataset(input_filepattern)
    output_name = output_basename + '.tfrecord'
    _generate_tfrecord_dataset(dataset_iter, output_name, output_dir)
  else:
    for split in splits:
      input_filepattern = osp.join(dataset_dir, split, '*_color.png')
      dataset_iter = AlignedRenderedDataset(input_filepattern)
      output_name = '%s_%s.tfrecord' % (output_basename, split)
      _generate_tfrecord_dataset(dataset_iter, output_name, output_dir)


def main(argv):
  # Read input flags
  dataset_name = opts.dataset_name
  dataset_parent_dir = opts.dataset_parent_dir
  output_dir = FLAGS.output_dir
  xception_frozen_graph_path = FLAGS.xception_frozen_graph_path
  splits = ['train', 'val']
  # Run the preprocessing pipeline
  export_aligned_dataset_to_tfrecord(
    dataset_parent_dir, output_dir, dataset_name, splits,
    xception_frozen_graph_path)


if __name__ == '__main__':
  app.run(main)


================================================
FILE: evaluate_quantitative_metrics.py
================================================
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from PIL import Image
from absl import app
from absl import flags
import functools
import glob
import numpy as np
import os
import os.path as osp
import skimage.measure
import tensorflow as tf
import utils

FLAGS = flags.FLAGS
flags.DEFINE_string('val_set_out_dir', None,
                    'Output directory with concatenated fake and real images.')
flags.DEFINE_string('experiment_title', 'experiment',
                    'Name for the experiment to evaluate')


def _extract_real_and_fake_from_concatenated_output(val_set_out_dir):
  out_dir = osp.join(val_set_out_dir, 'fake')
  gt_dir = osp.join(val_set_out_dir, 'real')
  if not osp.exists(out_dir):
    os.makedirs(out_dir)
  if not osp.exists(gt_dir):
    os.makedirs(gt_dir)
  imgs_pattern = osp.join(val_set_out_dir, '*.png')
  imgs_paths = sorted(glob.glob(imgs_pattern))
  print('Separating %d images in %s.' % (len(imgs_paths), val_set_out_dir))
  for img_path in imgs_paths:
    basename = osp.basename(img_path)[:-4]  # remove the '.png' suffix
    img = np.array(Image.open(img_path))
    img_res = 512
    fake = img[:, -2*img_res:-img_res, :]
    real = img[:, -img_res:, :]
    fake_path = osp.join(out_dir, '%s_fake.png' % basename)
    real_path = osp.join(gt_dir, '%s_real.png' % basename)
    Image.fromarray(fake).save(fake_path)
    Image.fromarray(real).save(real_path)


def compute_l1_loss_metric(image_set1_paths, image_set2_paths):
  assert len(image_set1_paths) == len(image_set2_paths)
  assert len(image_set1_paths) > 0
  print('Evaluating L1 loss for %d pairs' % len(image_set1_paths))

  total_loss = 0.
  for ii, (img1_path, img2_path) in enumerate(zip(image_set1_paths,
                                                  image_set2_paths)):
    img1_in_ar = np.array(Image.open(img1_path), dtype=np.float32)
    img1_in_ar = utils.crop_to_multiple(img1_in_ar)

    img2_in_ar = np.array(Image.open(img2_path), dtype=np.float32)
    img2_in_ar = utils.crop_to_multiple(img2_in_ar)

    loss_l1 = np.mean(np.abs(img1_in_ar - img2_in_ar))
    total_loss += loss_l1

  return total_loss / len(image_set1_paths)


def compute_psnr_loss_metric(image_set1_paths, image_set2_paths):
  assert len(image_set1_paths) == len(image_set2_paths)
  assert len(image_set1_paths) > 0
  print('Evaluating PSNR loss for %d pairs' % len(image_set1_paths))

  total_loss = 0.
  for ii, (img1_path, img2_path) in enumerate(zip(image_set1_paths,
                                                  image_set2_paths)):
    img1_in_ar = np.array(Image.open(img1_path))
    img1_in_ar = utils.crop_to_multiple(img1_in_ar)

    img2_in_ar = np.array(Image.open(img2_path))
    img2_in_ar = utils.crop_to_multiple(img2_in_ar)

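    # PSNR = 10 * log10(MAX^2 / MSE); compare_psnr infers MAX (the data range)
    # from the input dtype, i.e. 255 for uint8 images.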
    loss_psnr = skimage.measure.compare_psnr(img1_in_ar, img2_in_ar)
    total_loss += loss_psnr

  return total_loss / len(image_set1_paths)


def evaluate_experiment(val_set_out_dir, title='experiment',
                        metrics=['psnr', 'l1']):

  out_dir = osp.join(val_set_out_dir, 'fake')
  gt_dir = osp.join(val_set_out_dir, 'real')
  _extract_real_and_fake_from_concatenated_output(val_set_out_dir)
  input_pattern1 = osp.join(gt_dir, '*.png')
  input_pattern2 = osp.join(out_dir, '*.png')
  set1 = sorted(glob.glob(input_pattern1))
  set2 = sorted(glob.glob(input_pattern2))
  for metric in metrics:
    if metric == 'l1':
      mean_loss = compute_l1_loss_metric(set1, set2)
    elif metric == 'psnr':
      mean_loss = compute_psnr_loss_metric(set1, set2)
    print('*** mean %s loss for %s = %f' % (metric, title, mean_loss))


def main(argv):
  evaluate_experiment(FLAGS.val_set_out_dir, title=FLAGS.experiment_title,
                      metrics=['psnr', 'l1'])


if __name__ == '__main__':
  app.run(main)


================================================
FILE: layers.py
================================================
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
import numpy as np
import tensorflow as tf


class LayerInstanceNorm(object):

  def __init__(self, scope_suffix='instance_norm'):
    curr_scope = tf.get_variable_scope().name
    self._scope = curr_scope + '/' + scope_suffix

  def __call__(self, x):
    with tf.variable_scope(self._scope, reuse=tf.AUTO_REUSE):
      return tf.contrib.layers.instance_norm(
        x, epsilon=1e-05, center=True, scale=True)


def layer_norm(x, scope='layer_norm'):
  return tf.contrib.layers.layer_norm(x, center=True, scale=True)


def pixel_norm(x):
  """Pixel normalization.

  Args:
    x: 4D image tensor in B01C format.

  Returns:
    4D tensor with pixel normalized channels.
  """
  return x * tf.rsqrt(tf.reduce_mean(tf.square(x), [-1], keepdims=True) + 1e-8)


def global_avg_pooling(x):
  return tf.reduce_mean(x, axis=[1, 2], keepdims=True)


class FullyConnected(object):

  def __init__(self, n_out_units, scope_suffix='FC'):
    weight_init = tf.random_normal_initializer(mean=0., stddev=0.02)
    weight_regularizer = tf.contrib.layers.l2_regularizer(scale=0.0001)

    curr_scope = tf.get_variable_scope().name
    self._scope = curr_scope + '/' + scope_suffix
    self.fc_layer = functools.partial(
      tf.layers.dense, units=n_out_units, kernel_initializer=weight_init,
      kernel_regularizer=weight_regularizer, use_bias=True)

  def __call__(self, x):
    with tf.variable_scope(self._scope, reuse=tf.AUTO_REUSE):
      return self.fc_layer(x)


def init_he_scale(shape, slope=1.0):
  """He neural network random normal scaling for initialization.

  Args:
    shape: list of the dimensions of the tensor.
    slope: float, slope of the ReLu following the layer.

  Returns:
    a float, He's standard deviation.
  """
  fan_in = np.prod(shape[:-1])
  return np.sqrt(2. / ((1. + slope**2) * fan_in))


class LayerConv(object):
  """Convolution layer with support for equalized learning."""

  def __init__(self,
               name,
               w,
               n,
               stride,
               padding='SAME',
               use_scaling=False,
               relu_slope=1.):
    """Layer constructor.

    Args:
      name: string, layer name.
      w: int or 2-tuple, width of the convolution kernel.
      n: 2-tuple of ints, input and output channel depths.
      stride: int or 2-tuple, stride for the convolution kernel.
      padding: string, the padding method. {SAME, VALID, REFLECT}.
      use_scaling: bool, whether to use weight norm and scaling.
      relu_slope: float, the slope of the ReLu following the layer.
    """
    assert padding in ['SAME', 'VALID', 'REFLECT'], 'Error: unsupported padding'
    self._padding = padding
    with tf.variable_scope(name):
      if isinstance(stride, int):
        stride = [1, stride, stride, 1]
      else:
        assert len(stride) == 2, "stride is either an int or a 2-tuple"
        stride = [1, stride[0], stride[1], 1]
      if isinstance(w, int):
        w = [w, w]
      self.w = w
      shape = [w[0], w[1], n[0], n[1]]
      init_scale, pre_scale = init_he_scale(shape, relu_slope), 1.
      if use_scaling:
        init_scale, pre_scale = pre_scale, init_scale
      self._stride = stride
      self._pre_scale = pre_scale
      self._weight = tf.get_variable(
          'weight',
          shape=shape,
          initializer=tf.random_normal_initializer(stddev=init_scale))
      self._bias = tf.get_variable(
          'bias', shape=[n[1]], initializer=tf.zeros_initializer)

  def __call__(self, x):
    """Apply layer to tensor x."""
    if self._padding != 'REFLECT':
      padding = self._padding
    else:
      padding = 'VALID'
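      # Emulate SAME output size with reflection padding: pad the input so the
      # subsequent VALID convolution preserves the SAME-padded spatial size.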
      pad_top = self.w[0] // 2
      pad_left = self.w[1] // 2
      if (self.w[0] - self._stride[1]) % 2 == 0:
        pad_bottom = pad_top
      else:
        pad_bottom = self.w[0] - self._stride[1] - pad_top
      if (self.w[1] - self._stride[2]) % 2 == 0:
        pad_right = pad_left
      else:
        pad_right = self.w[1] - self._stride[2] - pad_left
      x = tf.pad(x, [[0, 0], [pad_top, pad_bottom], [pad_left, pad_right],
                     [0, 0]], mode='REFLECT')
    y = tf.nn.conv2d(x, self._weight, strides=self._stride, padding=padding)
    return self._pre_scale * y + self._bias


class LayerTransposedConv(object):
  """Convolution layer with support for equalized learning."""

  def __init__(self,
               name,
               w,
               n,
               stride,
               padding='SAME',
               use_scaling=False,
               relu_slope=1.):
    """Layer constructor.

    Args:
      name: string, layer name.
      w: int or 2-tuple, width of the convolution kernel.
      n: 2-tuple int, [n_in_channels, n_out_channels]
      stride: int or 2-tuple, stride for the convolution kernel.
      padding: string, the padding method {SAME, VALID, REFLECT}.
      use_scaling: bool, whether to use weight norm and scaling.
      relu_slope: float, the slope of the ReLu following the layer.
    """
    assert padding in ['SAME'], 'Error: unsupported padding for transposed conv'
    if isinstance(stride, int):
      stride = [1, stride, stride, 1]
    else:
      assert len(stride) == 2, "stride is either an int or a 2-tuple"
      stride = [1, stride[0], stride[1], 1]
    if isinstance(w, int):
      w = [w, w]
    self.padding = padding
    self.nc_in, self.nc_out = n
    self.stride = stride
    with tf.variable_scope(name):
      kernel_shape = [w[0], w[1], self.nc_out, self.nc_in]
      init_scale, pre_scale = init_he_scale(kernel_shape, relu_slope), 1.
      if use_scaling:
        init_scale, pre_scale = pre_scale, init_scale
      self._pre_scale = pre_scale
      self._weight = tf.get_variable(
          'weight',
          shape=kernel_shape,
          initializer=tf.random_normal_initializer(stddev=init_scale))
      self._bias = tf.get_variable(
          'bias', shape=[self.nc_out], initializer=tf.zeros_initializer)

  def __call__(self, x):
    """Apply layer to tensor x."""
    x_shape = x.get_shape().as_list()
    batch_size = tf.shape(x)[0]
    stride_x, stride_y = self.stride[1], self.stride[2]
    output_shape = tf.stack([
      batch_size, x_shape[1] * stride_x, x_shape[2] * stride_y, self.nc_out])
    y = tf.nn.conv2d_transpose(
      x, filter=self._weight, output_shape=output_shape, strides=self.stride,
      padding=self.padding)
    return self._pre_scale * y + self._bias


class ResBlock(object):
  def __init__(self,
               name,
               nc,
               norm_layer_constructor,
               activation,
               padding='SAME',
               use_scaling=False,
               relu_slope=1.):
    """Layer constructor."""
    self.name = name
    conv2d = functools.partial(
        LayerConv, w=3, n=[nc, nc], stride=1, padding=padding,
        use_scaling=use_scaling, relu_slope=relu_slope)
    self.blocks = []
    with tf.variable_scope(self.name):
      with tf.variable_scope('res0'):
        self.blocks.append(
          LayerPipe([
            conv2d('res0_conv'),
            norm_layer_constructor('res0_norm'),
            activation
          ])
        )
      with tf.variable_scope('res1'):
        self.blocks.append(
          LayerPipe([
            conv2d('res1_conv'),
            norm_layer_constructor('res1_norm')
          ])
        )

  def __call__(self, x_init):
    """Apply layer to tensor x."""
    x = x_init
    for f in self.blocks:
      x = f(x)
    return x + x_init


class BasicBlock(object):
  def __init__(self,
               name,
               n,
               activation=functools.partial(tf.nn.leaky_relu, alpha=0.2),
               padding='SAME',
               use_scaling=True,
               relu_slope=1.):
    """Layer constructor."""
    self.name = name
    conv2d = functools.partial(
        LayerConv, stride=1, padding=padding,
        use_scaling=use_scaling, relu_slope=relu_slope)
    avg_pool = functools.partial(downscale, n=2)
    nc_in, nc_out = n  # n is a 2-tuple
    with tf.variable_scope(self.name):
      self.path1_blocks = []
      with tf.variable_scope('bb_path1'):
        self.path1_blocks.append(
          LayerPipe([
            activation,
            conv2d('bb_conv0', w=3, n=[nc_in, nc_out]),
            activation,
            conv2d('bb_conv1', w=3, n=[nc_out, nc_out]),
            downscale
          ])
        )

      self.path2_blocks = []
      with tf.variable_scope('bb_path2'):
        self.path2_blocks.append(
          LayerPipe([
            downscale,
            conv2d('path2_conv', w=1, n=[nc_in, nc_out])
          ])
        )

  def __call__(self, x_init):
    """Apply layer to tensor x."""
    x1 = x_init
    x2 = x_init
    for f in self.path1_blocks:
      x1 = f(x1)
    for f in self.path2_blocks:
      x2 = f(x2)
    return x1 + x2


class LayerDense(object):
  """Dense layer with a non-linearity."""

  def __init__(self, name, n, use_scaling=False, relu_slope=1.):
    """Layer constructor.

    Args:
      name: string, layer name.
      n: 2-tuple of ints, input and output widths.
      use_scaling: bool, whether to use weight norm and scaling.
      relu_slope: float, the slope of the ReLu following the layer.
    """
    with tf.variable_scope(name):
      init_scale, pre_scale = init_he_scale(n, relu_slope), 1.
      if use_scaling:
        init_scale, pre_scale = pre_scale, init_scale
      self._pre_scale = pre_scale
      self._weight = tf.get_variable(
          'weight',
          shape=n,
          initializer=tf.random_normal_initializer(stddev=init_scale))
      self._bias = tf.get_variable(
          'bias', shape=[n[1]], initializer=tf.zeros_initializer)

  def __call__(self, x):
    """Apply layer to tensor x."""
    return self._pre_scale * tf.matmul(x, self._weight) + self._bias


class LayerPipe(object):
  """Pipe a sequence of functions."""

  def __init__(self, functions):
    """Layer constructor.

    Args:
      functions: list, functions to pipe.
    """
    self._functions = tuple(functions)

  def __call__(self, x, **kwargs):
    """Apply pipe to tensor x and return result."""
    del kwargs
    for f in self._functions:
      x = f(x)
    return x


def downscale(x, n=2):
  """Box downscaling.

  Args:
    x: 4D image tensor.
    n: integer scale (must be a power of 2).

  Returns:
    4D tensor of images down scaled by a factor n.
  """
  if n == 1:
    return x
  return tf.nn.avg_pool(x, [1, n, n, 1], [1, n, n, 1], 'VALID')


def upscale(x, n):
  """Box upscaling (also called nearest neighbors).

  Args:
    x: 4D image tensor in B01C format.
    n: integer scale (must be a power of 2).

  Returns:
    4D tensor of images up scaled by a factor n.
  """
  if n == 1:
    return x
  x_shape = tf.shape(x)
  height, width = x_shape[1], x_shape[2]
  return tf.image.resize_nearest_neighbor(x, [n * height, n * width])


def tile_and_concatenate(x, z, n_z):
  z = tf.reshape(z, shape=[-1, 1, 1, n_z])
  z = tf.tile(z, [1, tf.shape(x)[1], tf.shape(x)[2], 1])
  x = tf.concat([x, z], axis=-1)
  return x


def minibatch_mean_variance(x):
  """Computes the variance average.

  This is used by the discriminator as a form of batch discrimination.

  Args:
    x: nD tensor for which to compute variance average.

  Returns:
    a scalar, the mean variance of variable x.
  """
  mean = tf.reduce_mean(x, 0, keepdims=True)
  vals = tf.sqrt(tf.reduce_mean(tf.squared_difference(x, mean), 0) + 1e-8)
  vals = tf.reduce_mean(vals)
  return vals


def scalar_concat(x, scalar):
  """Concatenate a scalar to a 4D tensor as an extra channel.

  Args:
    x: 4D image tensor in B01C format.
    scalar: a scalar to concatenate to the tensor.

  Returns:
    a 4D tensor with one extra channel containing the value scalar at
     every position.
  """
  s = tf.shape(x)
  return tf.concat([x, tf.ones([s[0], s[1], s[2], 1]) * scalar], axis=3)
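

# Illustrative smoke test (not part of the original file): compose a small
# conv -> instance norm -> relu pipe and run it on random input under
# TF1-style graph execution.
def _example_layerpipe_smoke_test():
  with tf.variable_scope('example', reuse=tf.AUTO_REUSE):
    pipe = LayerPipe([
        LayerConv('conv0', w=3, n=[3, 8], stride=1, padding='REFLECT'),
        LayerInstanceNorm(),
        tf.nn.relu,
    ])
    y = pipe(tf.random_normal([1, 32, 32, 3]))
  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(y).shape)  # expected: (1, 32, 32, 8)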


================================================
FILE: losses.py
================================================
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from options import FLAGS as opts
import layers
import os.path as osp
import tensorflow as tf
import vgg16


def gradient_penalty_loss(y_xy, xy, iwass_target=1, iwass_lambda=10):
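  # WGAN-GP style penalty: pushes the norm of the discriminator gradient with
  # respect to its input towards iwass_target (cf. the improved WGAN paper).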
  grad = tf.gradients(tf.reduce_sum(y_xy), [xy])[0]
  grad_norm = tf.sqrt(tf.reduce_sum(tf.square(grad), axis=[1, 2, 3]) + 1e-8)
  loss_gp = tf.reduce_mean(
      tf.square(grad_norm - iwass_target)) * iwass_lambda / iwass_target**2
  return loss_gp


def KL_loss(mean, logvar):
  loss = 0.5 * tf.reduce_sum(tf.square(mean) + tf.exp(logvar) - 1. - logvar,
                             axis=-1)
  return tf.reduce_sum(loss)  # just to match DRIT implementation


def l2_regularize(x):
  return tf.reduce_mean(tf.square(x))


def L1_loss(x, y):
  return tf.reduce_mean(tf.abs(x - y))


class PerceptualLoss:
  def __init__(self, x, y, image_shape, layers, w_layers, w_act=0.1):
    """
    Builds vgg16 network and computes the perceptual loss.
    """
    assert len(image_shape) == 3 and image_shape[-1] == 3
    assert osp.exists(opts.vgg16_path), 'Cannot find %s' % opts.vgg16_path

    self.w_act = w_act
    self.vgg_layers = layers
    self.w_layers = w_layers
    batch_shape = [None] + image_shape  # [None, H, W, 3]

    vgg_net = vgg16.Vgg16(opts.vgg16_path)
    self.x_acts = vgg_net.get_vgg_activations(x, layers)
    self.y_acts = vgg_net.get_vgg_activations(y, layers)
    loss = 0
    for w, act1, act2 in zip(self.w_layers, self.x_acts, self.y_acts):
      loss += w * tf.reduce_mean(tf.square(self.w_act * (act1 - act2)))
    self.loss = loss

  def __call__(self):
    return self.loss
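

# Illustrative usage sketch (added for exposition; not part of the pipeline).
# The layer names below are hypothetical and must match whatever naming the
# vgg16 module expects.
def _example_perceptual_loss(x, y):
  p_loss = PerceptualLoss(x, y, image_shape=[256, 256, 3],
                          layers=['conv1_2', 'conv2_2'], w_layers=[1.0, 1.0])
  return p_loss()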


def lsgan_appearance_E_loss(disc_response):
  disc_response = tf.squeeze(disc_response)
  # Train the encoder adversarially to push the discriminator towards maximal
  # uncertainty: a response of 0.5, halfway between the real and fake labels.
  gt_label = 0.5
  loss = tf.reduce_mean(tf.square(disc_response - gt_label))
  return loss


def lsgan_loss(disc_response, is_real):
  gt_label = 1 if is_real else 0
  disc_response = tf.squeeze(disc_response)
  # The following works for both regular and patchGAN discriminators
  loss = tf.reduce_mean(tf.square(disc_response - gt_label))
  return loss
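

# Illustrative sketch (added for exposition; not part of the pipeline): how
# LSGAN losses are typically composed for a discriminator and a generator
# given the discriminator's responses to real and generated batches.
def _example_lsgan_losses(d_real_response, d_fake_response):
  loss_d = (lsgan_loss(d_real_response, True) +
            lsgan_loss(d_fake_response, False))
  # The generator is rewarded for making fakes look real.
  loss_g = lsgan_loss(d_fake_response, True)
  return loss_d, loss_g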


def multiscale_discriminator_loss(Ds_responses, is_real):
  num_D = len(Ds_responses)
  loss = 0
  for i in range(num_D):
    curr_response = Ds_responses[i][-1][-1]
    loss += lsgan_loss(curr_response, is_real)
  return loss


================================================
FILE: networks.py
================================================
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from options import FLAGS as opts
import functools
import layers
import tensorflow as tf


class RenderingModel(object):

  def __init__(self, model_name, use_appearance=True):

    if model_name == 'pggan':
      self._model = ModelPGGAN(use_appearance)
    else:
      raise ValueError('Model %s not implemented!' % model_name)

  def __call__(self, x_in, z_app=None):
    return self._model(x_in, z_app)

  def get_appearance_encoder(self):
    return self._model._appearance_encoder

  def get_generator(self):
    return self._model._generator

  def get_content_encoder(self):
    return self._model._content_encoder


# "Progressive Growing of GANs (PGGAN)"-inspired architecture. Implementation is
# based on the implementation details in their paper, but code is not taken from
# the authors' released code.
# Main changes are:
#  - conditional GAN setup by introducting an encoder + skip connections.
#  - no progressive growing during training.
class ModelPGGAN(RenderingModel):

  def __init__(self, use_appearance=True):
    self._use_appearance = use_appearance
    self._content_encoder = None
    self._generator = GeneratorPGGAN(appearance_vec_size=opts.app_vector_size)
    if use_appearance:
      self._appearance_encoder = DRITAppearanceEncoderConcat(
          'appearance_net', opts.appearance_nc, opts.normalize_drit_Ez)
    else:
      self._appearance_encoder = None

  def __call__(self, x_in, z_app=None):
    y = self._generator(x_in, z_app)
    return y

  def get_appearance_encoder(self):
    return self._appearance_encoder

  def get_generator(self):
    return self._generator

  def get_content_encoder(self):
    return self._content_encoder


class PatchGANDiscriminator(object):

  def __init__(self, name_scope, input_nc, nf=64, n_layers=3, get_fmaps=False):
    """Constructor for a patchGAN discriminators.

    Args:
      name_scope: str - tf name scope.
      input_nc: int - number of input channels.
      nf: int - starting number of discriminator filters.
      n_layers: int - number of layers in the discriminator.
      get_fmaps: bool - return intermediate feature maps for FeatLoss.
    """
    self.get_fmaps = get_fmaps
    self.n_layers = n_layers
    kw = 4  # kernel width for convolution

    activation = functools.partial(tf.nn.leaky_relu, alpha=0.2)
    norm_layer = functools.partial(layers.LayerInstanceNorm)
    conv2d = functools.partial(layers.LayerConv, use_scaling=opts.use_scaling,
                               relu_slope=0.2)

    def minibatch_stats(x):
      return layers.scalar_concat(x, layers.minibatch_mean_variance(x))

    # Create layers.
    self.blocks = []
    with tf.variable_scope(name_scope, reuse=tf.AUTO_REUSE):
      with tf.variable_scope('block_0'):
        self.blocks.append([
            conv2d('conv0', w=kw, n=[input_nc, nf], stride=2),
            activation
        ])
      for ii_block in range(1, n_layers):
        nf_prev = nf
        nf = min(nf * 2, 512)
        with tf.variable_scope('block_%d' % ii_block):
          self.blocks.append([
              conv2d('conv%d' % ii_block, w=kw, n=[nf_prev, nf], stride=2),
              norm_layer(),
              activation
          ])
      # Add minibatch_stats (from PGGAN) and do a stride1 convolution.
      nf_prev = nf
      nf = min(nf * 2, 512)
      with tf.variable_scope('block_%d' % (n_layers + 1)):
        self.blocks.append([
            minibatch_stats,  # this is improvised by @meshry
            conv2d('conv%d' % (n_layers + 1), w=kw, n=[nf_prev + 1, nf],
                   stride=1),
            norm_layer(),
            activation
        ])
      # Get 1-channel patchGAN logits
      with tf.variable_scope('patchGAN_logits'):
        self.blocks.append([
            conv2d('conv%d' % (n_layers + 2), w=kw, n=[nf, 1], stride=1)
        ])

  def __call__(self, x, x_cond=None):
    # Concatenate extra conditioning input, if any.
    if x_cond is not None:
      x = tf.concat([x, x_cond], axis=3)

    if self.get_fmaps:
      # Dummy addition of x to D_fmaps, which will be removed before returning
      D_fmaps = [[x]]
      for i_block in range(len(self.blocks)):
        # Apply layer #0 in the current block
        block_fmaps = [self.blocks[i_block][0](D_fmaps[-1][-1])]
        # Apply the remaining layers of this block
        for i_layer in range(1, len(self.blocks[i_block])):
          block_fmaps.append(self.blocks[i_block][i_layer](block_fmaps[-1]))
        # Append the feature maps of this block to D_fmaps
        D_fmaps.append(block_fmaps)
      return D_fmaps[1:]  # exclude the input x which we added initially
    else:
      y = x
      for i_block in range(len(self.blocks)):
        for i_layer in range(len(self.blocks[i_block])):
          y = self.blocks[i_block][i_layer](y)
      return [[y]]


class MultiScaleDiscriminator(object):

  def __init__(self, name_scope, input_nc, num_scales=3, nf=64, n_layers=3,
               get_fmaps=False):
    self.get_fmaps = get_fmaps
    discs = []
    with tf.variable_scope(name_scope):
      for i in range(num_scales):
        discs.append(PatchGANDiscriminator(
            'D_scale%d' % i, input_nc, nf=nf, n_layers=n_layers,
            get_fmaps=get_fmaps))
    self.discriminators = discs

  def __call__(self, x, x_cond=None, params=None):
    del params
    if x_cond is not None:
      x = tf.concat([x, x_cond], axis=3)

    responses = []
    for ii, D in enumerate(self.discriminators):
      responses.append(D(x, x_cond=None))  # x_cond is already concatenated
      if ii != len(self.discriminators) - 1:
        x = layers.downscale(x, n=2)
    return responses


class GeneratorPGGAN(object):
  def __init__(self, appearance_vec_size=8, use_scaling=True,
               num_blocks=5, input_nc=7,
               fmap_base=8192, fmap_decay=1.0, fmap_max=512):
    """Generator model.
  
    Args:
      appearance_vec_size: int, size of the latent appearance vector.
      use_scaling: bool, whether to use weight scaling.
      resolution: int, width of the images (assumed to be square).
      input_nc: int, number of input channles.
      fmap_base: int, base number of channels.
      fmap_decay: float, decay rate of channels with respect to depth.
      fmap_max: int, max number of channels (supersedes fmap_base).
  
    Returns:
      function of the model.
    """
    # NOTE: fmap_base and fmap_decay are currently unused; the filter count
    # depends only on the stage and is capped at fmap_max.
    def _num_filters(fmap_base, fmap_decay, fmap_max, stage):
      if opts.g_nf == 32:
        return min(int(2**(10 - stage)), fmap_max)  # nf32
      elif opts.g_nf == 64:
        return min(int(2**(11 - stage)), fmap_max)  # nf64
      else:
        raise ValueError('Currently unsupported num filters')

    nf = functools.partial(_num_filters, fmap_base, fmap_decay, fmap_max)
    self.num_blocks = num_blocks
    activation = functools.partial(tf.nn.leaky_relu, alpha=0.2)
    conv2d_stride1 = functools.partial(
        layers.LayerConv, stride=1, use_scaling=use_scaling, relu_slope=0.2)
    conv2d_rgb = functools.partial(layers.LayerConv, w=1, stride=1,
                                   use_scaling=use_scaling)
  
    # Create encoder layers.
    with tf.variable_scope('g_model_enc', reuse=tf.AUTO_REUSE):
      self.enc_stage = []
      self.from_rgb = []

      if opts.use_appearance and opts.inject_z == 'to_encoder':
        input_nc += appearance_vec_size
  
      for i in range(num_blocks, -1, -1):
        with tf.variable_scope('res_%d' % i):
          self.from_rgb.append(
              layers.LayerPipe([
                  conv2d_rgb('from_rgb', n=[input_nc, nf(i + 1)]),
                  activation,
              ])
          )
          self.enc_stage.append(
              layers.LayerPipe([
                  functools.partial(layers.downscale, n=2),
                  conv2d_stride1('conv0', w=3, n=[nf(i + 1), nf(i)]),
                  activation,
                  layers.pixel_norm,
                  conv2d_stride1('conv1', w=3, n=[nf(i), nf(i)]),
                  activation,
                  layers.pixel_norm
              ])
          )
  
    # Create decoder layers.
    with tf.variable_scope('g_model_dec', reuse=tf.AUTO_REUSE):
      self.dec_stage = []
      self.to_rgb = []
  
      nf_bottleneck = nf(0)  # num input filters at the bottleneck
      if opts.use_appearance and opts.inject_z == 'to_bottleneck':
        nf_bottleneck += appearance_vec_size

      with tf.variable_scope('res_0'):
        self.dec_stage.append(
          layers.LayerPipe([
            functools.partial(layers.upscale, n=2),
            conv2d_stride1('conv0', w=3, n=[nf_bottleneck, nf(1)]),
            activation,
            layers.pixel_norm,
            conv2d_stride1('conv1', w=3, n=[nf(1), nf(1)]),
            activation,
            layers.pixel_norm
          ])
        )
        self.to_rgb.append(conv2d_rgb('to_rgb', n=[nf(1), opts.output_nc]))
  
      multiply_factor = 2 if opts.concatenate_skip_layers else 1
      for i in range(1, num_blocks + 1):
        with tf.variable_scope('res_%d' % i):
          self.dec_stage.append(
              layers.LayerPipe([
                  functools.partial(layers.upscale, n=2),
                  conv2d_stride1('conv0', w=3,
                                 n=[multiply_factor * nf(i), nf(i + 1)]),
                  activation,
                  layers.pixel_norm,
                  conv2d_stride1('conv1', w=3, n=[nf(i + 1), nf(i + 1)]),
                  activation,
                  layers.pixel_norm
              ])
          )
          self.to_rgb.append(conv2d_rgb('to_rgb',
                                        n=[nf(i + 1), opts.output_nc]))

  def __call__(self, x, appearance_embedding=None, encoder_fmaps=None):
    """Generator function.

    Args:
      x: 2D tensor (batch, latents), the conditioning input batch of images.
      appearance_embedding: float tensor: latent appearance vector.
    Returns:
      4D tensor of images (NHWC), the generated images.
    """
    del encoder_fmaps
    enc_st_idx = 0
    if opts.use_appearance and opts.inject_z == 'to_encoder':
      x = layers.tile_and_concatenate(x, appearance_embedding,
                                      opts.app_vector_size)
    y = self.from_rgb[enc_st_idx](x)

    enc_responses = []
    for i in range(enc_st_idx, len(self.enc_stage)):
      y = self.enc_stage[i](y)
      enc_responses.insert(0, y)

    # Concatenate appearance vector to y
    if opts.use_appearance and opts.inject_z == 'to_bottleneck':
      appearance_tensor = tf.tile(appearance_embedding,
                                  [1, tf.shape(y)[1], tf.shape(y)[2], 1])
      y = tf.concat([y, appearance_tensor], axis=3)

    y_list = []
    for i in range(self.num_blocks + 1):
      if i > 0:
        y_skip = enc_responses[i]  # skip layer
        if opts.concatenate_skip_layers:
          y = tf.concat([y, y_skip], axis=3)
        else:
          y = y + y_skip
      y = self.dec_stage[i](y)
      y_list.append(y)

    return self.to_rgb[self.num_blocks](y_list[-1])


class DRITAppearanceEncoderConcat(object):

  def __init__(self, name_scope, input_nc, normalize_encoder):
    self.blocks = []
    activation = functools.partial(tf.nn.leaky_relu, alpha=0.2)
    conv2d = functools.partial(layers.LayerConv, use_scaling=opts.use_scaling,
                               relu_slope=0.2, padding='SAME')
    with tf.variable_scope(name_scope, reuse=tf.AUTO_REUSE):
      if normalize_encoder:
        self.blocks.append(layers.LayerPipe([
            conv2d('conv0', w=4, n=[input_nc, 64], stride=2),
            layers.BasicBlock('BB0', n=[64, 128], use_scaling=opts.use_scaling),
            layers.pixel_norm,
            layers.BasicBlock('BB1', n=[128, 192], use_scaling=opts.use_scaling),
            layers.pixel_norm,
            layers.BasicBlock('BB2', n=[192, 256], use_scaling=opts.use_scaling),
            layers.pixel_norm,
            activation,
            layers.global_avg_pooling
        ]))
      else:
        self.blocks.append(layers.LayerPipe([
            conv2d('conv0', w=4, n=[input_nc, 64], stride=2),
            layers.BasicBlock('BB0', n=[64, 128], use_scaling=opts.use_scaling),
            layers.BasicBlock('BB1', n=[128, 192], use_scaling=opts.use_scaling),
            layers.BasicBlock('BB2', n=[192, 256], use_scaling=opts.use_scaling),
            activation,
            layers.global_avg_pooling
        ]))
      # FC layers to get the mean and logvar
      self.fc_mean = layers.FullyConnected(opts.app_vector_size, 'FC_mean')
      self.fc_logvar = layers.FullyConnected(opts.app_vector_size, 'FC_logvar')

  def __call__(self, x):
    for f in self.blocks:
      x = f(x)

    mean = self.fc_mean(x)
    logvar = self.fc_logvar(x)
    # The following is an arbitrarily chosen *deterministic* latent vector
    # computation. Another option is to let z = mean, but gradients from logvar
    # will be None and will need to be removed.
    z = mean + tf.exp(0.5 * logvar)
    return z, mean, logvar
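
  # Note (added for exposition): the standard VAE reparameterization would
  # instead sample stochastically, e.g.:
  #   eps = tf.random_normal(tf.shape(mean))
  #   z = mean + eps * tf.exp(0.5 * logvar)
  # which keeps gradients flowing through both mean and logvar.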


================================================
FILE: neural_rerendering.py
================================================
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from PIL import Image
from absl import app
from options import FLAGS as opts
import data
import datetime
import functools
import glob
import losses
import networks
import numpy as np
import options
import os.path as osp
import random
import skimage.measure
import staged_model
import tensorflow as tf
import time
import utils


def build_model_fn(use_exponential_moving_average=True):
  """Builds and returns the model function for an estimator.

  Args:
    use_exponential_moving_average: bool. If true, inference restores
      variables from their exponential moving averages.

  Returns:
    function, the model_fn function typically required by an estimator.
  """
  arch_type = opts.arch_type
  use_appearance = opts.use_appearance
  def model_fn(features, labels, mode, params):
    """An estimator build_fn."""
    del labels, params
    if mode == tf.estimator.ModeKeys.TRAIN:
      step = tf.train.get_global_step()

      x_in = features['conditional_input']
      x_gt = features['expected_output']  # ground truth output
      x_app = features['peek_input']

      if opts.training_pipeline == 'staged':
        ops = staged_model.create_computation_graph(x_in, x_gt, x_app=x_app,
                                                    arch_type=opts.arch_type)
        op_increment_step = tf.assign_add(step, 1)
        train_disc_op = ops['train_disc_op']
        train_renderer_op = ops['train_renderer_op']
        train_op = tf.group(train_disc_op, train_renderer_op, op_increment_step)

        utils.HookReport.log_tensor(ops['total_loss_d'], 'total_loss_d')
        utils.HookReport.log_tensor(ops['loss_d_real'], 'loss_d_real')
        utils.HookReport.log_tensor(ops['loss_d_fake'], 'loss_d_fake')
        utils.HookReport.log_tensor(ops['total_loss_g'], 'total_loss_g')
        utils.HookReport.log_tensor(ops['loss_g_gan'], 'loss_g_gan')
        utils.HookReport.log_tensor(ops['loss_g_recon'], 'loss_g_recon')
        utils.HookReport.log_tensor(step, 'global_step')

        return tf.estimator.EstimatorSpec(
            mode=mode, loss=ops['total_loss_d'] + ops['total_loss_g'],
            train_op=train_op)
      else:
        raise NotImplementedError('%s training is not implemented.' %
                                  opts.training_pipeline)
    elif mode == tf.estimator.ModeKeys.EVAL:
      raise NotImplementedError('Eval is not implemented.')
    else:  # all modes below are for different inference tasks.
      # Build network and initialize inference variables.
      g_func = networks.RenderingModel(arch_type, use_appearance)
      if use_appearance:
        app_func = g_func.get_appearance_encoder()
      if use_exponential_moving_average:
        ema = tf.train.ExponentialMovingAverage(decay=0.999)
        var_dict = ema.variables_to_restore()
        tf.train.init_from_checkpoint(osp.join(opts.train_dir), var_dict)

      if mode == tf.estimator.ModeKeys.PREDICT:
        x_in = features['conditional_input']
        if use_appearance:
          x_app = features['peek_input']
          x_app_embedding, _, _ = app_func(x_app)
        else:
          x_app_embedding = None
        y = g_func(x_in, x_app_embedding)
        tf.logging.info('DBG: shape of y during prediction %s.' % str(y.shape))
        return tf.estimator.EstimatorSpec(mode=mode, predictions=y)

      # 'eval_subset' mode is the same as PREDICT, but it concatenates the
      # output with the input render, semantic map and ground truth for easy
      # comparison.
      elif mode == 'eval_subset':
        x_in = features['conditional_input']
        x_gt = features['expected_output']
        if use_appearance:
          x_app = features['peek_input']
          x_app_embedding, _, _ = app_func(x_app)
        else:
          x_app_embedding = None
        y = g_func(x_in, x_app_embedding)
        tf.logging.info('DBG: shape of y during prediction %s.' % str(y.shape))
        x_in_rgb = tf.slice(x_in, [0, 0, 0, 0], [-1, -1, -1, 3])
        if opts.use_semantic:
          x_in_semantic = tf.slice(x_in, [0, 0, 0, 4], [-1, -1, -1, 3])
          output_tuple = tf.concat([x_in_rgb, x_in_semantic, y, x_gt], axis=2)
        else:
          output_tuple = tf.concat([x_in_rgb, y, x_gt], axis=2)
        return tf.estimator.EstimatorSpec(mode=mode, predictions=output_tuple)

      # 'compute_appearance' mode computes and returns the latent z vector.
      elif mode == 'compute_appearance':
        assert use_appearance, 'use_appearance is set to False!'
        x_app_in = features['peek_input']
        # NOTE: the following line is a temporary hack (which is especially
        # bad for inputs smaller than 512x512).
        x_app_in = tf.image.resize_image_with_crop_or_pad(x_app_in, 512, 512)
        app_embedding, _, _ = app_func(x_app_in)
        return tf.estimator.EstimatorSpec(mode=mode, predictions=app_embedding)

      # 'interpolate_appearance' mode expects an already computed latent z
      # vector, passed as the value of the dict key 'appearance_embedding'.
      elif mode == 'interpolate_appearance':
        assert use_appearance, 'use_appearance is set to False!'
        x_in = features['conditional_input']
        x_app_embedding = features['appearance_embedding']
        y = g_func(x_in, x_app_embedding)
        tf.logging.info('DBG: shape of y during prediction %s.' % str(y.shape))
        return tf.estimator.EstimatorSpec(mode=mode, predictions=y)
      else:
        raise ValueError('Unsupported mode: ' + mode)

  return model_fn


def make_sample_grid_and_save(est, dataset_name, dataset_parent_dir, grid_dims,
                              output_dir, cur_nimg):
  """Evaluate a fixed set of validation images and save output.

  Args:
    est: tf.estimator.Estimator, TF estimator to run the predictions.
    dataset_name: basename for the validation tfrecord from which to load
      validation images.
    dataset_parent_dir: path to a directory containing the validation tfrecord.
    grid_dims: 2-tuple int for the grid size (1 unit = 1 image).
    output_dir: string, where to save image samples.
    cur_nimg: int, current number of images seen by training.

  Returns:
    None.
  """
  num_examples = grid_dims[0] * grid_dims[1]
  def input_val_fn():
    dict_inp = data.provide_data(
        dataset_name=dataset_name, parent_dir=dataset_parent_dir, subset='val',
        batch_size=1, crop_flag=True, crop_size=opts.train_resolution,
        seeds=[0], max_examples=num_examples,
        use_appearance=opts.use_appearance, shuffle=0)
    x_in = dict_inp['conditional_input']
    x_gt = dict_inp['expected_output']  # ground truth output
    x_app = dict_inp['peek_input']
    return x_in, x_gt, x_app

  def est_input_val_fn():
    x_in, _, x_app = input_val_fn()
    features = {'conditional_input': x_in, 'peek_input': x_app}
    return features

  images = [x for x in est.predict(est_input_val_fn)]
  images = np.array(images, 'f')
  images = images.reshape(grid_dims + images.shape[1:])
  utils.save_images(utils.to_png(utils.images_to_grid(images)), output_dir,
                    cur_nimg)


def visualize_image_sequence(est, dataset_name, dataset_parent_dir,
                             input_sequence_name, app_base_path, output_dir):
  """Generates an image sequence as a video and stores it to disk."""
  batch_sz = opts.batch_size
  def input_seq_fn():
    dict_inp = data.provide_data(
        dataset_name=dataset_name, parent_dir=dataset_parent_dir,
        subset=input_sequence_name, batch_size=batch_sz, crop_flag=False,
        seeds=None, use_appearance=False, shuffle=0)
    x_in = dict_inp['conditional_input']
    return x_in

  # Compute appearance embedding only once and use it for all input frames.
  app_rgb_path = app_base_path + '_reference.png'
  app_rendered_path = app_base_path + '_color.png'
  app_depth_path = app_base_path + '_depth.png'
  app_sem_path = app_base_path + '_seg_rgb.png'
  x_app = _load_and_concatenate_image_channels(
      app_rgb_path, app_rendered_path, app_depth_path, app_sem_path)
  def seq_with_single_appearance_inp_fn():
    """input frames with a fixed latent appearance vector."""
    x_in_op = input_seq_fn()
    x_app_op = tf.convert_to_tensor(x_app)
    x_app_tiled_op = tf.tile(x_app_op, [tf.shape(x_in_op)[0], 1, 1, 1])
    return {'conditional_input': x_in_op,
            'peek_input': x_app_tiled_op}

  images = [x for x in est.predict(seq_with_single_appearance_inp_fn)]
  for i, gen_img in enumerate(images):
    output_file_path = osp.join(output_dir, 'out_%04d.png' % i)
    print('Saving frame #%d to %s' % (i, output_file_path))
    with tf.gfile.Open(output_file_path, 'wb') as f:
      f.write(utils.to_png(gen_img))


def train(dataset_name, dataset_parent_dir, load_pretrained_app_encoder,
          load_trained_fixed_app, save_samples_kimg=50):
  """Main training procedure.

  The trained model is saved in opts.train_dir, the function itself does not
   return anything.

  Args:
    save_samples_kimg: int, period (in KiB) to save sample images.

  Returns:
    None.
  """
  image_dir = osp.join(opts.train_dir, 'images')  # to save validation images.
  tf.gfile.MakeDirs(image_dir)
  config = tf.estimator.RunConfig(
      save_summary_steps=(1 << 10) // opts.batch_size,
      save_checkpoints_steps=(save_samples_kimg << 10) // opts.batch_size,
      keep_checkpoint_max=5,
      log_step_count_steps=1 << 30)
  model_dir = opts.train_dir
  if (opts.use_appearance and load_trained_fixed_app and
      not tf.train.latest_checkpoint(model_dir)):
    tf.logging.warning('***** Loading resume_step from %s!' %
                       opts.fixed_appearance_train_dir)
    resume_step = utils.load_global_step_from_checkpoint_dir(
        opts.fixed_appearance_train_dir)
  else:
    tf.logging.warning('***** Loading resume_step (if any) from %s!' %
                       model_dir)
    resume_step = utils.load_global_step_from_checkpoint_dir(model_dir)
  if resume_step != 0:
    tf.logging.warning('****** Resuming training at %d!' % resume_step)

  model_fn = build_model_fn()  # model function for TFEstimator.

  hooks = [utils.HookReport(1 << 12, opts.batch_size)]

  if opts.use_appearance and load_pretrained_app_encoder:
    tf.logging.warning('***** will warm-start from %s!' %
                       opts.appearance_pretrain_dir)
    ws = tf.estimator.WarmStartSettings(
        ckpt_to_initialize_from=opts.appearance_pretrain_dir,
        vars_to_warm_start='appearance_net/.*')
  elif opts.use_appearance and load_trained_fixed_app:
    tf.logging.warning('****** finetuning will warm-start from %s!' %
                       opts.fixed_appearance_train_dir)
    ws = tf.estimator.WarmStartSettings(
        ckpt_to_initialize_from=opts.fixed_appearance_train_dir,
        vars_to_warm_start='.*')
  else:
    ws = None
    tf.logging.warning('****** No warm-starting; using random initialization!')

  est = tf.estimator.Estimator(model_fn, model_dir, config, params={},
                               warm_start_from=ws)

  for next_kimg in range(opts.save_samples_kimg, opts.total_kimg + 1,
                         opts.save_samples_kimg):
    next_step = (next_kimg << 10) // opts.batch_size
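    # Note (added for exposition): (next_kimg << 10) is the target number of
    # training images seen so far; dividing by the batch size converts this
    # image budget into a global-step limit for est.train() below.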
    if opts.num_crops == -1:  # use random crops
      crop_seeds = None
    else:
      crop_seeds = list(100 * np.arange(opts.num_crops))
    input_train_fn = functools.partial(
        data.provide_data, dataset_name=dataset_name,
        parent_dir=dataset_parent_dir, subset='train',
        batch_size=opts.batch_size, crop_flag=True,
        crop_size=opts.train_resolution, seeds=crop_seeds,
        use_appearance=opts.use_appearance)
    est.train(input_train_fn, max_steps=next_step, hooks=hooks)
    tf.logging.info('DBG: kimg=%d, cur_step=%d' % (next_kimg, next_step))
    tf.logging.info('DBG: Saving a validation grid image %06d to %s' % (
        next_kimg, image_dir))
    make_sample_grid_and_save(est, dataset_name, dataset_parent_dir, (3, 3),
                              image_dir, next_kimg << 10)


def _build_inference_estimator(model_dir):
  model_fn = build_model_fn()
  est = tf.estimator.Estimator(model_fn, model_dir)
  return est


def evaluate_sequence(dataset_name, dataset_parent_dir, virtual_seq_name,
                      app_base_path):
  output_dir = osp.join(opts.train_dir, 'seq_output_%s' % virtual_seq_name)
  tf.gfile.MakeDirs(output_dir)
  est = _build_inference_estimator(opts.train_dir)
  visualize_image_sequence(est, dataset_name, dataset_parent_dir,
                           virtual_seq_name, app_base_path, output_dir)


def evaluate_image_set(dataset_name, dataset_parent_dir, subset_suffix,
                       output_dir=None, batch_size=6):
  if output_dir is None:
    output_dir = osp.join(opts.train_dir, 'validation_output_%s' % subset_suffix)
  tf.gfile.MakeDirs(output_dir)
  model_fn_old = build_model_fn()
  def model_fn_wrapper(features, labels, mode, params):
    del mode
    return model_fn_old(features, labels, 'eval_subset', params)
  model_dir = opts.train_dir
  est = tf.estimator.Estimator(model_fn_wrapper, model_dir)
  est_inp_fn = functools.partial(
      data.provide_data, dataset_name=dataset_name,
      parent_dir=dataset_parent_dir, subset=subset_suffix,
      batch_size=batch_size, use_appearance=opts.use_appearance, shuffle=0)

  print('Evaluating images for subset %s' % subset_suffix)
  images = [x for x in est.predict(est_inp_fn)]
  print('Evaluated %d images' % len(images))
  for i, img in enumerate(images):
    output_file_path = osp.join(output_dir, 'out_%04d.png' % i)
    print('Saving file #%d: %s' % (i, output_file_path))
    with tf.gfile.Open(output_file_path, 'wb') as f:
      f.write(utils.to_png(img))


def _load_and_concatenate_image_channels(rgb_path=None, rendered_path=None,
                                         depth_path=None, seg_path=None,
                                         size_multiple=64):
  """Prepares a single input for the network."""
  if (rgb_path is None and rendered_path is None and depth_path is None and
      seg_path is None):
    raise ValueError('At least one of the inputs has to be not None')

  channels = ()
  if rgb_path is not None:
    rgb_img = np.array(Image.open(rgb_path)).astype(np.float32)
    rgb_img = utils.crop_to_multiple(rgb_img, size_multiple)
    channels = channels + (rgb_img,)
  if rendered_path is not None:
    rendered_img = np.array(Image.open(rendered_path)).astype(np.float32)
    if not opts.use_alpha:
      rendered_img = rendered_img[:, :, :3]  # drop the alpha channel
    rendered_img = utils.crop_to_multiple(rendered_img, size_multiple)
    channels = channels + (rendered_img,)
  if depth_path is not None:
    depth_img = np.array(Image.open(depth_path))
    depth_img = depth_img.astype(np.float32)
    depth_img = utils.crop_to_multiple(depth_img[:, :, np.newaxis],
                                       size_multiple)
    channels = channels + (depth_img,)
    # depth_img = depth_img * (2.0 / 255) - 1.0
  if seg_path is not None:
    seg_img = np.array(Image.open(seg_path)).astype(np.float32)
    seg_img = utils.crop_to_multiple(seg_img, size_multiple)
    channels = channels + (seg_img,)
  # Concatenate and normalize channels
  img = np.dstack(channels)
  img = np.expand_dims(img, axis=0)
  img = img * (2.0 / 255) - 1.0
  return img
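

# Note on channel layout (added for exposition): with all four inputs present
# and --use_alpha=False, the result has 10 channels (reference RGB 3 +
# rendered RGB 3 + depth 1 + semantic RGB 3), matching the default
# --appearance_nc; without the reference RGB it has 7 channels, matching the
# default --deep_buffer_nc.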


def infer_dir(model_dir, input_dir, output_dir):
  tf.gfile.MakeDirs(output_dir)
  est = _build_inference_estimator(opts.train_dir)

  def read_image(base_path, is_appearance=False):
    if is_appearance:
      ref_img_path = base_path + '_reference.png'
    else:
      ref_img_path = None
    rendered_img_path = base_path + '_color.png'
    depth_img_path = base_path + '_depth.png'
    seg_img_path = base_path + '_seg_rgb.png'
    img = _load_and_concatenate_image_channels(
        rgb_path=ref_img_path, rendered_path=rendered_img_path,
        depth_path=depth_img_path, seg_path=seg_img_path)
    return img

  def get_inference_input_fn(base_path, app_base_path):
    x_in = read_image(base_path, False)
    x_app_in = read_image(app_base_path, True)
    def infer_input_fn():
      return {'conditional_input': x_in, 'peek_input': x_app_in}
    return infer_input_fn

  file_paths = sorted(glob.glob(osp.join(input_dir, '*_depth.png')))
  base_paths = [x[:-10] for x in file_paths]  # remove the '_depth.png' suffix
  for inp_base_path in base_paths:
    est_inp_fn = get_inference_input_fn(inp_base_path, inp_base_path)
    img = next(est.predict(est_inp_fn))
    basename = osp.basename(inp_base_path)
    output_img_path = osp.join(output_dir, basename + '_out.png')
    print('Saving generated image to %s' % output_img_path)
    with tf.gfile.Open(output_img_path, 'wb') as f:
      f.write(utils.to_png(img))


def joint_interpolation(model_dir, app_input_dir, st_app_basename,
                        end_app_basename, camera_path_dir):
  """
  Interpolates both viewpoint and appearance between two input images.
  """
  # Create output directory
  output_dir = osp.join(model_dir, 'joint_interpolation_out')
  tf.gfile.MakeDirs(output_dir)

  # Build estimator
  model_fn_old = build_model_fn()
  def model_fn_wrapper(features, labels, mode, params):
    del mode
    return model_fn_old(features, labels, 'interpolate_appearance', params)
  def appearance_model_fn(features, labels, mode, params):
    del mode
    return model_fn_old(features, labels, 'compute_appearance', params)
  config = tf.estimator.RunConfig(
      save_summary_steps=1000, save_checkpoints_steps=50000,
      keep_checkpoint_max=50, log_step_count_steps=1 << 30)
  est = tf.estimator.Estimator(model_fn_wrapper, model_dir, config, params={})
  est_app = tf.estimator.Estimator(appearance_model_fn, model_dir, config,
                                   params={})

  # Compute appearance embeddings for the two input appearance images.
  app_inputs = []
  for app_basename in [st_app_basename, end_app_basename]:
    app_rgb_path = osp.join(app_input_dir, app_basename + '_reference.png')
    app_rendered_path = osp.join(app_input_dir, app_basename + '_color.png')
    app_depth_path = osp.join(app_input_dir, app_basename + '_depth.png')
    app_seg_path = osp.join(app_input_dir, app_basename + '_seg_rgb.png')
    app_in = _load_and_concatenate_image_channels(
        rgb_path=app_rgb_path, rendered_path=app_rendered_path,
        depth_path=app_depth_path, seg_path=app_seg_path)
    # app_inputs.append(tf.convert_to_tensor(app_in))
    app_inputs.append(app_in)

  embedding1 = next(est_app.predict(
      lambda: {'peek_input': app_inputs[0]}))
  embedding1 = np.expand_dims(embedding1, axis=0)
  embedding2 = next(est_app.predict(
      lambda: {'peek_input': app_inputs[1]}))
  embedding2 = np.expand_dims(embedding2, axis=0)

  file_paths = sorted(glob.glob(osp.join(camera_path_dir, '*_depth.png')))
  base_paths = [x[:-10] for x in file_paths]  # remove the '_depth.png' suffix

  # Compute interpolated appearance embeddings
  num_interpolations = len(base_paths)
  interpolated_embeddings = []
  delta_vec = (embedding2 - embedding1) / (num_interpolations - 1)
  for delta_iter in range(num_interpolations):
    x_app_embedding = embedding1 + delta_iter * delta_vec
    interpolated_embeddings.append(x_app_embedding)

  # Generate and save interpolated images
  for frame_idx, embedding in enumerate(interpolated_embeddings):
    # Read in input frame
    frame_render_path = osp.join(base_paths[frame_idx] + '_color.png')
    frame_depth_path = osp.join(base_paths[frame_idx] + '_depth.png')
    frame_seg_path = osp.join(base_paths[frame_idx] + '_seg_rgb.png')
    x_in = _load_and_concatenate_image_channels(
        rgb_path=None, rendered_path=frame_render_path,
        depth_path=frame_depth_path, seg_path=frame_seg_path)

    img = next(est.predict(
        lambda: {'conditional_input': tf.convert_to_tensor(x_in),
                 'appearance_embedding': tf.convert_to_tensor(embedding)}))
    output_img_name = '%s_%s_%03d.png' % (st_app_basename, end_app_basename,
                                          frame_idx)
    output_img_path = osp.join(output_dir, output_img_name)
    print('Saving interpolated image to %s' % output_img_path)
    with tf.gfile.Open(output_img_path, 'wb') as f:
      f.write(utils.to_png(img))


def interpolate_appearance(model_dir, input_dir, target_img_basename,
                           appearance_img1_basename, appearance_img2_basename):
  # Create output directory
  output_dir = osp.join(model_dir, 'interpolate_appearance_out')
  tf.gfile.MakeDirs(output_dir)

  # Build estimator
  model_fn_old = build_model_fn()
  def model_fn_wrapper(features, labels, mode, params):
    del mode
    return model_fn_old(features, labels, 'interpolate_appearance', params)
  def appearance_model_fn(features, labels, mode, params):
    del mode
    return model_fn_old(features, labels, 'compute_appearance', params)
  config = tf.estimator.RunConfig(
      save_summary_steps=1000, save_checkpoints_steps=50000,
      keep_checkpoint_max=50, log_step_count_steps=1 << 30)
  est = tf.estimator.Estimator(model_fn_wrapper, model_dir, config, params={})
  est_app = tf.estimator.Estimator(appearance_model_fn, model_dir, config,
                                   params={})

  # Compute appearance embeddings for the two input appearance images.
  app_inputs = []
  for app_basename in [appearance_img1_basename, appearance_img2_basename]:
    app_rgb_path = osp.join(input_dir, app_basename + '_reference.png')
    app_rendered_path = osp.join(input_dir, app_basename + '_color.png')
    app_depth_path = osp.join(input_dir, app_basename + '_depth.png')
    app_seg_path = osp.join(input_dir, app_basename + '_seg_rgb.png')
    app_in = _load_and_concatenate_image_channels(
        rgb_path=app_rgb_path, rendered_path=app_rendered_path,
        depth_path=app_depth_path, seg_path=app_seg_path)
    # app_inputs.append(tf.convert_to_tensor(app_in))
    app_inputs.append(app_in)

  embedding1 = next(est_app.predict(
      lambda: {'peek_input': app_inputs[0]}))
  embedding2 = next(est_app.predict(
      lambda: {'peek_input': app_inputs[1]}))
  embedding1 = np.expand_dims(embedding1, axis=0)
  embedding2 = np.expand_dims(embedding2, axis=0)

  # Compute interpolated appearance embeddings
  num_interpolations = 10
  interpolated_embeddings = []
  delta_vec = (embedding2 - embedding1) / num_interpolations
  for delta_iter in range(num_interpolations + 1):
    x_app_embedding = embedding1 + delta_iter * delta_vec
    interpolated_embeddings.append(x_app_embedding)
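
  # Note (added for exposition): the loop above yields num_interpolations + 1
  # embeddings, starting exactly at embedding1 (delta_iter=0) and ending
  # exactly at embedding2 (delta_iter=num_interpolations).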

  # Read in the generator input for the target image to render
  rendered_img_path = osp.join(input_dir, target_img_basename + '_color.png')
  depth_img_path = osp.join(input_dir, target_img_basename + '_depth.png')
  seg_img_path = osp.join(input_dir, target_img_basename + '_seg_rgb.png')
  x_in = _load_and_concatenate_image_channels(
      rgb_path=None, rendered_path=rendered_img_path,
      depth_path=depth_img_path, seg_path=seg_img_path)

  # Generate and save interpolated images
  for interpolate_iter, embedding in enumerate(interpolated_embeddings):
    img = next(est.predict(
        lambda: {'conditional_input': tf.convert_to_tensor(x_in),
                 'appearance_embedding': tf.convert_to_tensor(embedding)}))
    output_img_name = 'interpolate_%s_%s_%s_%03d.png' % (
        target_img_basename, appearance_img1_basename, appearance_img2_basename,
        interpolate_iter)
    output_img_path = osp.join(output_dir, output_img_name)
    print('Saving interpolated image to %s' % output_img_path)
    with tf.gfile.Open(output_img_path, 'wb') as f:
      f.write(utils.to_png(img))


def main(argv):
  del argv
  configs_str = options.list_options()
  tf.gfile.MakeDirs(opts.train_dir)
  with tf.gfile.Open(osp.join(opts.train_dir, 'configs.txt'), 'wb') as f:
    f.write(configs_str)
  tf.logging.info('Local configs\n%s' % configs_str)

  if opts.run_mode == 'train':
    dataset_name = opts.dataset_name
    dataset_parent_dir = opts.dataset_parent_dir
    load_pretrained_app_encoder = opts.load_pretrained_app_encoder
    load_trained_fixed_app = opts.load_from_another_ckpt
    batch_size = opts.batch_size
    train(dataset_name, dataset_parent_dir, load_pretrained_app_encoder,
          load_trained_fixed_app)
  elif opts.run_mode == 'eval':  # generate a camera path output sequence from TFRecord inputs.
    dataset_name = opts.dataset_name
    dataset_parent_dir = opts.dataset_parent_dir
    virtual_seq_name = opts.virtual_seq_name
    inp_app_img_base_path = opts.inp_app_img_base_path
    evaluate_sequence(dataset_name, dataset_parent_dir, virtual_seq_name,
                      inp_app_img_base_path)
  elif opts.run_mode == 'eval_subset':  # generate output for validation set (encoded as TFRecords)
    dataset_name = opts.dataset_name
    dataset_parent_dir = opts.dataset_parent_dir
    virtual_seq_name = opts.virtual_seq_name
    evaluate_image_set(dataset_name, dataset_parent_dir, virtual_seq_name,
                       opts.output_validation_dir, opts.batch_size)
  elif opts.run_mode == 'eval_dir':  # evaluate output for a directory with input images
    input_dir = opts.inference_input_path
    output_dir = opts.inference_output_dir
    model_dir = opts.train_dir
    infer_dir(model_dir, input_dir, output_dir)
  elif opts.run_mode == 'interpolate_appearance':  # interpolate appearance only between two images.
    model_dir = opts.train_dir
    input_dir = opts.inference_input_path
    target_img_basename = opts.target_img_basename
    app_img1_basename = opts.appearance_img1_basename
    app_img2_basename = opts.appearance_img2_basename
    interpolate_appearance(model_dir, input_dir, target_img_basename,
                           app_img1_basename, app_img2_basename)
  elif opts.run_mode == 'joint_interpolation':  # interpolate viewpoint and appearance between two images
    model_dir = opts.train_dir
    app_input_dir = opts.inference_input_path
    st_app_basename = opts.appearance_img1_basename
    end_app_basename = opts.appearance_img2_basename
    frames_dir = opts.frames_dir
    joint_interpolation(model_dir, app_input_dir, st_app_basename,
                        end_app_basename, frames_dir)
  else:
    raise ValueError('Unsupported --run_mode %s' % opts.run_mode)


if __name__ == '__main__':
  app.run(main)


================================================
FILE: options.py
================================================
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from absl import flags
import numpy as np

FLAGS = flags.FLAGS

# ------------------------------------------------------------------------------
# Train flags
# ------------------------------------------------------------------------------

# Dataset, model directory and run mode
flags.DEFINE_string('train_dir', '/tmp/neural_rendering',
                    'Directory for model training.')
flags.DEFINE_string('dataset_name', 'sanmarco9k', 'name ID for a dataset.')
flags.DEFINE_string(
    'dataset_parent_dir', '',
    'Directory containing generated tfrecord dataset.')
flags.DEFINE_string('run_mode', 'train',
                    "One of {'train', 'eval', 'eval_subset', 'eval_dir', "
                    "'interpolate_appearance', 'joint_interpolation'}.")
flags.DEFINE_string('imageset_dir', None, 'Directory containing trainset '
                    'images for appearance pretraining.')
flags.DEFINE_string('metadata_output_dir', None, 'Directory to save pickled '
                    'pairwise distance matrix for appearance pretraining.')
flags.DEFINE_integer('save_samples_kimg', 50, 'kimg cycle to save sample '
                     'validation output during training.')

# Network inputs/outputs
flags.DEFINE_boolean('use_depth', True, 'Add depth image to the deep buffer.')
flags.DEFINE_boolean('use_alpha', False,
                     'Add alpha channel to the deep buffer.')
flags.DEFINE_boolean('use_semantic', True,
                     'Add semantic map to the deep buffer.')
flags.DEFINE_boolean('use_appearance', True,
                     'Capture appearance from an input real image.')
flags.DEFINE_integer('deep_buffer_nc', 7,
                     'Number of input channels in the deep buffer.')
flags.DEFINE_integer('appearance_nc', 10,
                     'Number of input channels to the appearance encoder.')
flags.DEFINE_integer('output_nc', 3,
                     'Number of channels for the generated image.')

# Staged training flags
flags.DEFINE_string(
    'vgg16_path', './vgg16_weights/vgg16.npy',
    'path to a *.npy file with vgg16 pretrained weights')
flags.DEFINE_boolean('load_pretrained_app_encoder', False,
                     'Warmstart appearance encoder with pretrained weights.')
flags.DEFINE_string('appearance_pretrain_dir', '',
                    'Model dir for the pretrained appearance encoder.')
flags.DEFINE_boolean('train_app_encoder', False, 'Whether to make the weights '
                     'for the appearance encoder trainable or not.')
flags.DEFINE_boolean(
    'load_from_another_ckpt', False, 'Load weights from another trained model, '
    'e.g. a model trained with a fixed appearance encoder.')
flags.DEFINE_string('fixed_appearance_train_dir', '',
                    'Model dir for training G with a fixed appearance net.')

# -----------------------------------------------------------------------------

# More hparams
flags.DEFINE_integer('train_resolution', 256,
                     'Crop train images to this resolution.')
flags.DEFINE_float('d_lr', 0.001, 'Learning rate for the discriminator.')
flags.DEFINE_float('g_lr', 0.001, 'Learning rate for the generator.')
flags.DEFINE_float('ez_lr', 0.0001, 'Learning rate for appearance encoder.')
flags.DEFINE_integer('batch_size', 8, 'Batch size for training.')
flags.DEFINE_boolean('use_scaling', True, "use He's scaling.")
flags.DEFINE_integer('num_crops', 30, 'num crops from train images '
                     '(use -1 for random crops).')
flags.DEFINE_integer('app_vector_size', 8, 'Size of latent appearance vector.')
flags.DEFINE_integer('total_kimg', 20000,
                     'Max number (in kilo) of training images for training.')
flags.DEFINE_float('adam_beta1', 0.0, 'beta1 for adam optimizer.')
flags.DEFINE_float('adam_beta2', 0.99, 'beta2 for adam optimizer.')

# Loss weights
flags.DEFINE_float('w_loss_vgg', 0.3, 'VGG loss weight.')
flags.DEFINE_float('w_loss_feat', 10., 'Feature loss weight (from pix2pixHD).')
flags.DEFINE_float('w_loss_l1', 50., 'L1 loss weight.')
flags.DEFINE_float('w_loss_z_recon', 10., 'Z reconstruction loss weight.')
flags.DEFINE_float('w_loss_gan', 1., 'Adversarial loss weight.')
flags.DEFINE_float('w_loss_z_gan', 1., 'Z adversarial loss weight.')
flags.DEFINE_float('w_loss_kl', 0.01, 'KL divergence weight.')
flags.DEFINE_float('w_loss_l2_reg', 0.01, 'Weight for L2 regression on Z.')

# -----------------------------------------------------------------------------

# Architecture and training setup
flags.DEFINE_string('arch_type', 'pggan',
                    'Architecture type: {pggan, pix2pixhd}.')
flags.DEFINE_string('training_pipeline', 'staged',
                    'Training pipeline type: {staged, bicycle_gan, drit}.')
flags.DEFINE_integer('g_nf', 64,
                     'num filters in the first/last layers of U-net.')
flags.DEFINE_boolean('concatenate_skip_layers', True,
                     'Use concatenation for skip connections.')

## if arch_type == 'pggan':
flags.DEFINE_integer('pggan_n_blocks', 5,
                     'Num blocks for the pggan architecture.')
## if arch_type == 'pix2pixhd':
flags.DEFINE_integer('p2p_n_downsamples', 3,
                     'Num downsamples for the pix2pixHD architecture.')
flags.DEFINE_integer('p2p_n_resblocks', 4, 'Num residual blocks at the '
                     'end/start of the pix2pixHD encoder/decoder.')
## if use_drit_pipeline:
flags.DEFINE_boolean('use_concat', True, '"concat" mode from DRIT.')
flags.DEFINE_boolean('normalize_drit_Ez', True, 'Add pixelnorm layers to the '
                     'appearance encoder.')
flags.DEFINE_boolean('concat_z_in_all_layers', True, 'Inject z at each '
                     'upsampling layer in the decoder (only for DRIT baseline)')
flags.DEFINE_string('inject_z', 'to_bottleneck', 'Method for injecting z; '
                     'one of {to_encoder, to_bottleneck}.')
flags.DEFINE_boolean('use_vgg_loss', True,
                     'Use VGG loss (vs. plain L1) for reconstruction.')

# ------------------------------------------------------------------------------
# Inference flags
# ------------------------------------------------------------------------------

flags.DEFINE_string('inference_input_path', '',
                    'Parent directory for input images at inference time.')
flags.DEFINE_string('inference_output_dir', '', 'Output path for inference')
flags.DEFINE_string('target_img_basename', '',
                    'basename of target image to render for interpolation')
flags.DEFINE_string('virtual_seq_name', 'full_camera_path',
                    'name for the virtual camera path suffix for the TFRecord.')
flags.DEFINE_string('inp_app_img_base_path', '',
                    'base path for the input appearance image for camera paths')

flags.DEFINE_string('appearance_img1_basename', '',
                    'basename of the first appearance image for interpolation')
flags.DEFINE_string('appearance_img2_basename', '',
                    'basename of the second appearance image for interpolation')
flags.DEFINE_list('input_basenames', [], 'input basenames for inference')
flags.DEFINE_list('input_app_basenames', [], 'input appearance basenames for '
                  'inference')
flags.DEFINE_string('frames_dir', '',
                    'Folder with input frames to a camera path')
flags.DEFINE_string('output_validation_dir', '',
                    'Output directory for storing validation results.')
flags.DEFINE_string('input_rendered', '',
                    'input rendered image name for inference')
flags.DEFINE_string('input_depth', '', 'input depth image name for inference')
flags.DEFINE_string('input_seg', '',
                    'input segmentation mask image name for inference')
flags.DEFINE_string('input_app_rgb', '',
                    'input appearance rgb image name for inference')
flags.DEFINE_string('input_app_rendered', '',
                    'input appearance rendered image name for inference')
flags.DEFINE_string('input_app_depth', '',
                    'input appearance depth image name for inference')
flags.DEFINE_string('input_app_seg', '',
                    'input appearance segmentation mask image name for '
                    'inference')
flags.DEFINE_string('output_img_name', '',
                    '[OPTIONAL] output image name for inference')

# -----------------------------------------------------------------------------
# Some validation and assertions
# -----------------------------------------------------------------------------

def validate_options():
  if FLAGS.training_pipeline == 'drit':
    assert FLAGS.use_appearance, 'DRIT pipeline requires --use_appearance'
  assert not (
      FLAGS.load_pretrained_app_encoder and FLAGS.load_from_another_ckpt), (
          'You cannot load weights for the appearance encoder from two '
          'different checkpoints!')
  if not FLAGS.use_appearance:
    print('**Warning: setting --app_vector_size to 0 since '
          '--use_appearance=False!')
    FLAGS.set_default('app_vector_size', 0)
  
# -----------------------------------------------------------------------------
# Print all options
# -----------------------------------------------------------------------------

def list_options():
  configs = ('# Run flags/options from options.py:\n'
             '# ----------------------------------\n')
  configs += ('## Train flags:\n'
              '## ------------\n')
  configs += 'train_dir = %s\n' % FLAGS.train_dir
  configs += 'dataset_name = %s\n' % FLAGS.dataset_name
  configs += 'dataset_parent_dir = %s\n' % FLAGS.dataset_parent_dir
  configs += 'run_mode = %s\n' % FLAGS.run_mode
  configs += 'save_samples_kimg = %d\n' % FLAGS.save_samples_kimg
  configs += '\n# --------------------------------------------------------\n\n'

  configs += ('## Network inputs and outputs:\n'
              '## ---------------------------\n')
  configs += 'use_depth = %s\n' % str(FLAGS.use_depth)
  configs += 'use_alpha = %s\n' % str(FLAGS.use_alpha)
  configs += 'use_semantic = %s\n' % str(FLAGS.use_semantic)
  configs += 'use_appearance = %s\n' % str(FLAGS.use_appearance)
  configs += 'deep_buffer_nc = %d\n' % FLAGS.deep_buffer_nc
  configs += 'appearance_nc = %d\n' % FLAGS.appearance_nc
  configs += 'output_nc = %d\n' % FLAGS.output_nc
  configs += 'train_resolution = %d\n' % FLAGS.train_resolution
  configs += '\n# --------------------------------------------------------\n\n'

  configs += ('## Staged training flags:\n'
              '## ----------------------\n')
  configs += 'load_pretrained_app_encoder = %s\n' % str(
                                            FLAGS.load_pretrained_app_encoder)
  configs += 'appearance_pretrain_dir = %s\n' % FLAGS.appearance_pretrain_dir
  configs += 'train_app_encoder = %s\n' % str(FLAGS.train_app_encoder)
  configs += 'load_from_another_ckpt = %s\n' % str(FLAGS.load_from_another_ckpt)
  configs += 'fixed_appearance_train_dir = %s\n' % str(
                                            FLAGS.fixed_appearance_train_dir)
  configs += '\n# --------------------------------------------------------\n\n'

  configs += ('## More hyper-parameters:\n'
              '## ----------------------\n')
  configs += 'd_lr = %f\n' % FLAGS.d_lr
  configs += 'g_lr = %f\n' % FLAGS.g_lr
  configs += 'ez_lr = %f\n' % FLAGS.ez_lr
  configs += 'batch_size = %d\n' % FLAGS.batch_size
  configs += 'use_scaling = %s\n' % str(FLAGS.use_scaling)
  configs += 'num_crops = %d\n' % FLAGS.num_crops
  configs += 'app_vector_size = %d\n' % FLAGS.app_vector_size
  configs += 'total_kimg = %d\n' % FLAGS.total_kimg
  configs += 'adam_beta1 = %f\n' % FLAGS.adam_beta1
  configs += 'adam_beta2 = %f\n' % FLAGS.adam_beta2
  configs += '\n# --------------------------------------------------------\n\n'

  configs += ('## Loss weights:\n'
              '## -------------\n')
  configs += 'w_loss_vgg = %f\n' % FLAGS.w_loss_vgg
  configs += 'w_loss_feat = %f\n' % FLAGS.w_loss_feat
  configs += 'w_loss_l1 = %f\n' % FLAGS.w_loss_l1
  configs += 'w_loss_z_recon = %f\n' % FLAGS.w_loss_z_recon
  configs += 'w_loss_gan = %f\n' % FLAGS.w_loss_gan
  configs += 'w_loss_z_gan = %f\n' % FLAGS.w_loss_z_gan
  configs += 'w_loss_kl = %f\n' % FLAGS.w_loss_kl
  configs += 'w_loss_l2_reg = %f\n' % FLAGS.w_loss_l2_reg
  configs += '\n# --------------------------------------------------------\n\n'

  configs += ('## Architecture and training setup:\n'
              '## --------------------------------\n')
  configs += 'arch_type = %s\n' % FLAGS.arch_type
  configs += 'training_pipeline = %s\n' % FLAGS.training_pipeline
  configs += 'g_nf = %d\n' % FLAGS.g_nf
  configs += 'concatenate_skip_layers = %s\n' % str(
                                                FLAGS.concatenate_skip_layers)
  configs += 'p2p_n_downsamples = %d\n' % FLAGS.p2p_n_downsamples
  configs += 'p2p_n_resblocks = %d\n' % FLAGS.p2p_n_resblocks
  configs += 'use_concat = %s\n' % str(FLAGS.use_concat)
  configs += 'normalize_drit_Ez = %s\n' % str(FLAGS.normalize_drit_Ez)
  configs += 'inject_z = %s\n' % FLAGS.inject_z
  configs += 'concat_z_in_all_layers = %s\n' % str(FLAGS.concat_z_in_all_layers)
  configs += 'use_vgg_loss = %s\n' % str(FLAGS.use_vgg_loss)
  configs += '\n# --------------------------------------------------------\n\n'

  return configs


================================================
FILE: pretrain_appearance.py
================================================
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from PIL import Image
from absl import app
from absl import flags
from options import FLAGS as opts
import glob
import networks
import numpy as np
import os
import os.path as osp
import pickle
import style_loss
import tensorflow as tf
import utils


def _load_and_concatenate_image_channels(
    rgb_path=None, rendered_path=None, depth_path=None, seg_path=None,
    crop_size=512):
  if (rgb_path is None and rendered_path is None and depth_path is None and
      seg_path is None):
    raise ValueError('At least one of the inputs has to be not None')

  channels = ()
  if rgb_path is not None:
    rgb_img = np.array(Image.open(rgb_path)).astype(np.float32)
    rgb_img = utils.get_central_crop(rgb_img, crop_size, crop_size)
    channels = channels + (rgb_img,)
  if rendered_path is not None:
    rendered_img = np.array(Image.open(rendered_path)).astype(np.float32)
    rendered_img = utils.get_central_crop(rendered_img, crop_size, crop_size)
    if not opts.use_alpha:
      rendered_img = rendered_img[:, :, :3]  # drop the alpha channel
    channels = channels + (rendered_img,)
  if depth_path is not None:
    depth_img = np.array(Image.open(depth_path))
    depth_img = depth_img.astype(np.float32)
    depth_img = utils.get_central_crop(depth_img, crop_size, crop_size)
    channels = channels + (depth_img,)
  if seg_path is not None:
    seg_img = np.array(Image.open(seg_path)).astype(np.float32)
    channels = channels + (seg_img,)
  # Concatenate and normalize channels
  img = np.dstack(channels)
  img = img * (2.0 / 255) - 1.0
  return img
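

# For reference, the stacked input has 3 (RGB) + 3 or 4 (rendered, depending
# on opts.use_alpha) + 1 (depth, typically) + 3 (colored segmentation)
# channels when all four paths are given. A minimal sketch of the
# [0, 255] -> [-1, 1] normalization applied above (illustrative only):
#   img = img_uint8.astype(np.float32) * (2.0 / 255) - 1.0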


def read_single_appearance_input(rgb_img_path):
  base_path = rgb_img_path[:-14]  # remove the '_reference.png' suffix
  rendered_img_path = base_path + '_color.png'
  depth_img_path = base_path + '_depth.png'
  semantic_img_path = base_path + '_seg_rgb.png'
  network_input_img = _load_and_concatenate_image_channels(
      rgb_img_path, rendered_img_path, depth_img_path, semantic_img_path,
      crop_size=opts.train_resolution)
  return network_input_img


def get_triplet_input_fn(dataset_path, dist_file_path=None, k_max_nearest=5,
                         k_max_farthest=13):
  input_images_pattern = osp.join(dataset_path, '*_reference.png')
  filenames = sorted(glob.glob(input_images_pattern))
  print('DBG: obtained %d input filenames for triplet inputs' % len(filenames))
  print('DBG: Computing pairwise style distances:')
  if dist_file_path is not None and osp.exists(dist_file_path):
    print('*** Loading distance matrix from %s' % dist_file_path)
    with open(dist_file_path, 'rb') as f:
      dist_matrix = pickle.load(f)['dist_matrix']
      print('loaded a dist_matrix of shape: %s' % str(dist_matrix.shape))
  else:
    dist_matrix = style_loss.compute_pairwise_style_loss_v2(filenames)
    dist_dict = {'dist_matrix': dist_matrix}
    print('Saving distance matrix to %s' % dist_file_path)
    with open(dist_file_path, 'wb') as f:
      pickle.dump(dist_dict, f)

  # Sort neighbors for each anchor image
  num_imgs = len(dist_matrix)
  sorted_neighbors = [np.argsort(dist_matrix[ii, :]) for ii in range(num_imgs)]

  def triplet_input_fn(anchor_idx):
    # start from 1 to avoid getting the same image as its own neighbor
    positive_neighbor_idx = np.random.randint(1, k_max_nearest + 1)
    negative_neighbor_idx = num_imgs - 1 - np.random.randint(0, k_max_farthest)
    positive_img_idx = sorted_neighbors[anchor_idx][positive_neighbor_idx]
    negative_img_idx = sorted_neighbors[anchor_idx][negative_neighbor_idx]
    # Read anchor image. Note: glob already returns paths rooted at
    # dataset_path, so no extra osp.join is needed.
    anchor_input = read_single_appearance_input(filenames[anchor_idx])
    # Read positive image
    positive_input = read_single_appearance_input(filenames[positive_img_idx])
    # Read negative image
    negative_input = read_single_appearance_input(filenames[negative_img_idx])
    # Return triplet
    return anchor_input, positive_input, negative_input

  return triplet_input_fn
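

# A self-contained sketch of the sampling rule above on toy data (hypothetical
# values; not called anywhere): the nearest style-space neighbors, skipping
# index 0 (the anchor itself), are positive candidates, and the k_max_farthest
# most distant images are negative candidates.
def _triplet_sampling_sketch(k_max_nearest=2, k_max_farthest=2):
  dists = np.array([0.0, 0.3, 0.1, 0.9, 0.5])  # one row of a toy dist_matrix
  neighbors = np.argsort(dists)  # [0, 2, 1, 4, 3]
  pos = neighbors[np.random.randint(1, k_max_nearest + 1)]
  neg = neighbors[len(dists) - 1 - np.random.randint(0, k_max_farthest)]
  return pos, neg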


def get_tf_triplet_dataset_iter(
    dataset_path, trainset_size, dist_file_path, batch_size=4,
    deterministic_flag=False, shuffle_buf_size=128, repeat_flag=True):
  # Create a dataset of anchor image indices.
  idx_dataset = tf.data.Dataset.range(trainset_size)
  # Create a mapper function from anchor idx to triplet images.
  triplet_mapper = lambda idx: tuple(tf.py_func(
      get_triplet_input_fn(dataset_path, dist_file_path), [idx],
      [tf.float32, tf.float32, tf.float32]))
  # Convert triplet to a dictionary for the estimator input format.
  triplet_to_dict_mapper = lambda anchor, pos, neg: {
      'anchor_img': anchor, 'positive_img': pos, 'negative_img': neg}
  if repeat_flag:
    idx_dataset = idx_dataset.repeat()  # Repeat indefinitely.
  if not deterministic_flag:
    idx_dataset = idx_dataset.shuffle(shuffle_buf_size)
    triplet_dataset = idx_dataset.map(
        triplet_mapper, num_parallel_calls=max(4, batch_size // 4))
    triplet_dataset = triplet_dataset.map(
        triplet_to_dict_mapper, num_parallel_calls=max(4, batch_size // 4))
  else:
    triplet_dataset = idx_dataset.map(triplet_mapper, num_parallel_calls=None)
    triplet_dataset = triplet_dataset.map(triplet_to_dict_mapper,
                                          num_parallel_calls=None)
  triplet_dataset = triplet_dataset.batch(batch_size)
  if not deterministic_flag:
    triplet_dataset = triplet_dataset.prefetch(4)  # Prefetch a few batches.
  return triplet_dataset.make_one_shot_iterator()


def build_model_fn(batch_size, lr_app_pretrain=0.0001, adam_beta1=0.0,
                   adam_beta2=0.99):
  def model_fn(features, labels, mode, params):
    del labels, params

    step = tf.train.get_global_step()
    app_func = networks.DRITAppearanceEncoderConcat(
      'appearance_net', opts.appearance_nc, opts.normalize_drit_Ez)

    if mode == tf.estimator.ModeKeys.TRAIN:
      op_increment_step = tf.assign_add(step, 1)
      with tf.name_scope('Appearance_Loss'):
        anchor_img = features['anchor_img']
        positive_img = features['positive_img']
        negative_img = features['negative_img']
        # Compute embeddings (each of shape [batch_sz, 1, 1, app_vector_sz])
        z_anchor, _, _ = app_func(anchor_img)
        z_pos, _, _ = app_func(positive_img)
        z_neg, _, _ = app_func(negative_img)
        # Squeeze into shape of [batch_sz x vec_sz]
        anchor_embedding = tf.squeeze(z_anchor, axis=[1, 2], name='z_anchor')
        positive_embedding = tf.squeeze(z_pos, axis=[1, 2])
        negative_embedding = tf.squeeze(z_neg, axis=[1, 2])
        # Compute triplet loss
        margin = 0.1
        anchor_positive_dist = tf.reduce_sum(
            tf.square(anchor_embedding - positive_embedding), axis=1)
        anchor_negative_dist = tf.reduce_sum(
            tf.square(anchor_embedding - negative_embedding), axis=1)
        triplet_loss = anchor_positive_dist - anchor_negative_dist + margin
        triplet_loss = tf.maximum(triplet_loss, 0.)
        triplet_loss = tf.reduce_sum(triplet_loss) / batch_size
        tf.summary.scalar('appearance_triplet_loss', triplet_loss)

        # Image summaries
        anchor_rgb = tf.slice(anchor_img, [0, 0, 0, 0], [-1, -1, -1, 3])
        positive_rgb = tf.slice(positive_img, [0, 0, 0, 0], [-1, -1, -1, 3])
        negative_rgb = tf.slice(negative_img, [0, 0, 0, 0], [-1, -1, -1, 3])
        tb_vis = tf.concat([anchor_rgb, positive_rgb, negative_rgb], axis=2)
        with tf.name_scope('triplet_vis'):
          tf.summary.image('anchor-pos-neg', tb_vis)

      optimizer = tf.train.AdamOptimizer(lr_app_pretrain, adam_beta1,
                                         adam_beta2)
      optimizer = tf.contrib.estimator.TowerOptimizer(optimizer)
      app_vars = utils.model_vars('appearance_net')[0]
      print('\n\n***************************************************')
      print('DBG: len(app_vars) = %d' % len(app_vars))
      for ii, v in enumerate(app_vars):
        print('%03d) %s' % (ii, str(v)))
      print('***************************************************\n\n')
      app_train_op = optimizer.minimize(triplet_loss, var_list=app_vars)
      return tf.estimator.EstimatorSpec(
          mode=mode, loss=triplet_loss,
          train_op=tf.group(app_train_op, op_increment_step))
    elif mode == tf.estimator.ModeKeys.PREDICT:
      imgs = features['anchor_img']
      z_app, _, _ = app_func(imgs)  # the encoder returns (z, mean, logvar)
      embeddings = tf.squeeze(z_app, axis=[1, 2])
      app_vars = utils.model_vars('appearance_net')[0]
      tf.train.init_from_checkpoint(opts.train_dir,
                                    {'appearance_net/': 'appearance_net/'})
      return tf.estimator.EstimatorSpec(mode=mode, predictions=embeddings)
    else:
      raise ValueError('Unsupported mode for the appearance model: ' + mode)

  return model_fn
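

# For reference, a NumPy sketch of the margin-based triplet loss computed in
# model_fn above (illustrative only; assumes embeddings of shape
# [batch_sz, dim]):
def _triplet_loss_np_sketch(z_anchor, z_pos, z_neg, margin=0.1):
  d_pos = np.sum(np.square(z_anchor - z_pos), axis=1)
  d_neg = np.sum(np.square(z_anchor - z_neg), axis=1)
  return np.sum(np.maximum(d_pos - d_neg + margin, 0.)) / len(z_anchor)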


def compute_dist_matrix(imageset_dir, dist_file_path, recompute_dist=False):
  if not recompute_dist and osp.exists(dist_file_path):
    print('*** Loading distance matrix from %s' % dist_file_path)
    with open(dist_file_path, 'rb') as f:
      dist_matrix = pickle.load(f)['dist_matrix']
      print('loaded a dist_matrix of shape: %s' % str(dist_matrix.shape))
      return dist_matrix
  else:
    images_paths = sorted(glob.glob(osp.join(imageset_dir, '*_reference.png')))
    dist_matrix = style_loss.compute_pairwise_style_loss_v2(images_paths)
    dist_dict = {'dist_matrix': dist_matrix}
    print('Saving distance matrix to %s' % dist_file_path)
    with open(dist_file_path, 'wb') as f:
      pickle.dump(dist_dict, f)
    return dist_matrix


def train_appearance(train_dir, imageset_dir, dist_file_path):
  batch_size = 8
  lr_app_pretrain = 0.001

  trainset_size = len(glob.glob(osp.join(imageset_dir, '*_reference.png')))
  resume_step = utils.load_global_step_from_checkpoint_dir(train_dir)
  if resume_step != 0:
    tf.logging.warning('DBG: resuming appearance pretraining at %d!' %
                       resume_step)
  model_fn = build_model_fn(batch_size, lr_app_pretrain)
  config = tf.estimator.RunConfig(
      save_summary_steps=50,
      save_checkpoints_steps=500,
      keep_checkpoint_max=5,
      log_step_count_steps=100)
  est = tf.estimator.Estimator(
      tf.contrib.estimator.replicate_model_fn(model_fn), train_dir,
      config, params={})
  # Get input function
  input_train_fn = lambda: get_tf_triplet_dataset_iter(
      imageset_dir, trainset_size, dist_file_path,
      batch_size=batch_size).get_next()
  print('Starting pretraining steps...')
  est.train(input_train_fn, steps=None, hooks=None)  # train indefinitely


def main(argv):
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  train_dir = opts.train_dir
  dataset_name = opts.dataset_name
  imageset_dir = opts.imageset_dir
  output_dir = opts.metadata_output_dir
  if not osp.exists(output_dir):
    os.makedirs(output_dir)
  dist_file_path = osp.join(output_dir, 'dist_%s.pckl' % dataset_name)
  compute_dist_matrix(imageset_dir, dist_file_path)
  train_appearance(train_dir, imageset_dir, dist_file_path)

if __name__ == '__main__':
  app.run(main)
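
# Example invocation (paths are hypothetical; the flags are assumed to be
# registered in options.py):
#   python pretrain_appearance.py \
#     --train_dir=/tmp/app_pretrain \
#     --dataset_name=my_dataset \
#     --imageset_dir=/path/to/train_images \
#     --metadata_output_dir=/tmp/metadata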


================================================
FILE: segment_dataset.py
================================================
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Generate semantic segmentations
This module uses Xception model trained on ADE20K dataset to generate semantic
segmentation mask to any set of images.
"""

from absl import app
from absl import flags
from PIL import Image
import glob
import matplotlib.pyplot as plt
import numpy as np
import os
import os.path as osp
import shutil
import tensorflow as tf
import utils


def get_semantic_color_coding():
  """
  assigns the 30 (actually 29) semantic colors from cityscapes semantic mapping
  to selected classes from the ADE20K150 semantic classes.
  """
  # Below are the 30 cityscape colors (one is duplicate. so total is 29 not 30)
  colors = [
    (111, 74,  0),
    ( 81,  0, 81),
    (128, 64,128),
    (244, 35,232),
    (250,170,160),
    (230,150,140),
    ( 70, 70, 70),
    (102,102,156),
    (190,153,153),
    (180,165,180),
    (150,100,100),
    (150,120, 90),
    (153,153,153),
    # (153,153,153),
    (250,170, 30),
    (220,220,  0),
    (107,142, 35),
    (152,251,152),
    ( 70,130,180),
    (220, 20, 60),
    (255,  0,  0),
    (  0,  0,142),
    (  0,  0, 70),
    (  0, 60,100),
    (  0,  0, 90),
    (  0,  0,110),
    (  0, 80,100),
    (  0,  0,230),
    (119, 11, 32),
    (  0,  0,142)]
  k_num_ade20k_classes = 150
  # Initially, all 150 classes map to a single catch-all color (color idx -1,
  # i.e. the last color in the list); selected classes are then assigned
  # independent colors. Semantic class ids are 1-based (1 through 150).
  semantic_to_color_idx = -1 * np.ones(k_num_ade20k_classes + 1, dtype=int)
  semantic_to_color_idx[1] = 0     # wall
  semantic_to_color_idx[2] = 1     # building;edifice
  semantic_to_color_idx[3] = 2     # sky
  semantic_to_color_idx[105] = 3   # fountain
  semantic_to_color_idx[27] = 4    # sea
  semantic_to_color_idx[60] = 5    # stairway;staircase
  semantic_to_color_idx[5] = 6     # tree
  semantic_to_color_idx[12] = 7    # sidewalk;pavement
  semantic_to_color_idx[4] = 7     # floor;flooring
  semantic_to_color_idx[7] = 7     # road;route
  semantic_to_color_idx[13] = 8    # people
  semantic_to_color_idx[18] = 9    # plant;flora;plant;life
  semantic_to_color_idx[17] = 10   # mountain;mount
  semantic_to_color_idx[20] = 11   # chair
  semantic_to_color_idx[6] = 12    # ceiling
  semantic_to_color_idx[22] = 13   # water
  semantic_to_color_idx[35] = 14   # rock;stone
  semantic_to_color_idx[14] = 15   # earth;ground
  semantic_to_color_idx[10] = 16   # grass
  semantic_to_color_idx[70] = 17   # bench
  semantic_to_color_idx[54] = 18   # stairs;steps
  semantic_to_color_idx[101] = 19  # poster
  semantic_to_color_idx[77] = 20   # boat
  semantic_to_color_idx[85] = 21   # tower
  semantic_to_color_idx[23] = 22   # painting;picture
  semantic_to_color_idx[88] = 23   # streetlight;street;lamp
  semantic_to_color_idx[43] = 24   # column;pillar
  semantic_to_color_idx[9] = 25    # window;windowpane
  semantic_to_color_idx[15] = 26   # door
  semantic_to_color_idx[133] = 27  # sculpture

  semantic_to_rgb = np.array(
    [colors[col_idx][:] for col_idx in semantic_to_color_idx])
  return semantic_to_rgb
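

# Since every ADE20K class id (0..150) maps to a concrete RGB triplet in the
# returned table, colorizing a label map is a single fancy-indexing lookup;
# a minimal sketch on a hypothetical 2x2 label image (not called anywhere):
def _label_map_to_rgb_sketch():
  semantic_to_rgb = get_semantic_color_coding()  # shape [151, 3]
  seg = np.array([[2, 3], [5, 1]])  # ADE20K class ids (1-based)
  return semantic_to_rgb[seg]  # shape [2, 2, 3]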


def _apply_colors(seg_images_path, save_dir, idx_to_color):
  for i, img_path in enumerate(seg_images_path):
    print('processing img #%05d / %05d: %s' % (i, len(seg_images_path),
                                               osp.split(img_path)[1]))
    seg = np.array(Image.open(img_path))
    seg_rgb = np.zeros(seg.shape + (3,), dtype=np.uint8)
    for col_idx in range(len(idx_to_color)):
      if idx_to_color[col_idx][0] != -1:
        mask = seg == col_idx
        seg_rgb[mask, :] = idx_to_color[col_idx][:]

    parent_dir, filename = osp.split(img_path)
    basename, ext = osp.splitext(filename)
    out_filename = basename + "_rgb.png"
    out_filepath = osp.join(save_dir, out_filename)
    # Save rescaled segmentation image
    Image.fromarray(seg_rgb).save(out_filepath)
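
# Note: get_semantic_color_coding resolves the -1 catch-all index to a real
# color, so the `!= -1` guard above never fires and the per-class loop is
# equivalent to a single fancy-indexing lookup; a minimal sketch (assuming
# idx_to_color covers all class ids present in seg):
#   seg_rgb = idx_to_color[seg].astype(np.uint8)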


# The frozen Xception model only segments 512x512 images, but it would be
# better to segment the full image instead.
def segment_images(images_path, xception_frozen_graph_path, save_dir,
                   crop_height=512, crop_width=512):
  if not osp.exists(xception_frozen_graph_path):
    raise OSError('Xception frozen graph not found at %s' %
                  xception_frozen_graph_path)
  with tf.gfile.GFile(xception_frozen_graph_path, "rb") as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())

  with tf.Graph().as_default() as graph:
    new_input = tf.placeholder(tf.uint8, [1, crop_height, crop_width, 3],
                               name="new_input")
    tf.import_graph_def(
      graph_def,
      input_map={"ImageTensor:0": new_input},
      return_elements=None,
      name="sem_seg",
      op_dict=None,
      producer_op_list=None
    )

  corrupted_dir = osp.join(save_dir, 'corrupted')
  if not osp.exists(corrupted_dir):
    os.makedirs(corrupted_dir)
  with tf.Session(graph=graph) as sess:
    for i, img_path in enumerate(images_path):
      print('Segmenting image %05d / %05d: %s' % (i + 1, len(images_path),
                                                  img_path))
      img = np.array(Image.open(img_path))
      if len(img.shape) == 2 or img.shape[2] != 3:
        print('Warning! corrupted image %s' % img_path)
        img_base_path = img_path[:-14]  # remove the '_reference.png' suffix
        srcs = sorted(glob.glob(img_base_path + '_*'))
        dest = corrupted_dir + '/.'
        for src in srcs:
          shutil.move(src, dest)
        continue
      img = utils.get_central_crop(img, crop_height=crop_height,
                                   crop_width=crop_width)
      img = np.expand_dims(img, 0)  # convert to NHWC format
      seg = sess.run("sem_seg/SemanticPredictions:0", feed_dict={
          new_input: img})
      assert np.max(seg[:]) <= 255, 'segmentation image is not of type uint8!'
      seg = np.squeeze(np.uint8(seg))  # convert to uint8 and squeeze to WxH.
      parent_dir, filename = osp.split(img_path)
      basename, ext = osp.splitext(filename)
      basename = basename[:-10]  # remove the '_reference' suffix
      seg_filename = basename + "_seg.png"
      seg_filepath = osp.join(save_dir, seg_filename)
      # Save segmentation image
      Image.fromarray(seg).save(seg_filepath)

def segment_and_color_dataset(dataset_dir, xception_frozen_graph_path,
                              splits=None, resegment_images=True):
  if splits is None:
    imgs_dirs = [dataset_dir]
  else:
    imgs_dirs = [osp.join(dataset_dir, split) for split in splits]
  
  for cur_dir in imgs_dirs:
    imgs_file_pattern = osp.join(cur_dir, '*_reference.png')
    images_path = sorted(glob.glob(imgs_file_pattern))
    if resegment_images:
      segment_images(images_path, xception_frozen_graph_path, cur_dir,
                     crop_height=512, crop_width=512)

  idx_to_col = get_semantic_color_coding()

  for cur_dir in imgs_dirs:
    save_dir = cur_dir
    seg_file_pattern = osp.join(cur_dir, '*_seg.png')
    seg_imgs_paths = sorted(glob.glob(seg_file_pattern))
    _apply_colors(seg_imgs_paths, save_dir, idx_to_col)
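
# Example usage (paths and split names are hypothetical):
#   segment_and_color_dataset(
#       '/path/to/dataset', '/path/to/xception_frozen_graph.pb',
#       splits=['train', 'val'])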


================================================
FILE: staged_model.py
================================================
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Neural re-rerendering in the wild.

Implementation of the staged training pipeline.
"""

from options import FLAGS as opts
import losses
import networks
import tensorflow as tf
import utils


def create_computation_graph(x_in, x_gt, x_app=None, arch_type='pggan',
                             use_appearance=True):
  """Create the models and the losses.

  Args:
    x_in: 4D tensor, batch of conditional input images in NHWC format.
    x_gt: 4D tensor, batch of ground-truth images in NHWC format.
    x_app: 4D tensor, batch of input appearance images in NHWC format.
    arch_type: string, rerenderer architecture type (e.g. 'pggan').
    use_appearance: bool, whether to condition on an appearance embedding.

  Returns:
    Dictionary of placeholders and TF graph functions.
  """
  # ---------------------------------------------------------------------------
  # Build models/networks
  # ---------------------------------------------------------------------------

  rerenderer = networks.RenderingModel(arch_type, use_appearance)
  app_enc = rerenderer.get_appearance_encoder()
  discriminator = networks.MultiScaleDiscriminator(
      'd_model', opts.appearance_nc, num_scales=3, nf=64, n_layers=3,
      get_fmaps=False)

  # ---------------------------------------------------------------------------
  # Forward pass
  # ---------------------------------------------------------------------------

  if opts.use_appearance:
    z_app, _, _ = app_enc(x_app)
  else:
    z_app = None

  y = rerenderer(x_in, z_app)

  # ---------------------------------------------------------------------------
  # Losses
  # ---------------------------------------------------------------------------

  w_loss_gan = opts.w_loss_gan
  w_loss_recon = opts.w_loss_vgg if opts.use_vgg_loss else opts.w_loss_l1

  # compute discriminator logits
  disc_real_featmaps = discriminator(x_gt, x_in)
  disc_fake_featmaps = discriminator(y, x_in)

  # discriminator loss
  loss_d_real = losses.multiscale_discriminator_loss(disc_real_featmaps, True)
  loss_d_fake = losses.multiscale_discriminator_loss(disc_fake_featmaps, False)
  loss_d = loss_d_real + loss_d_fake

  # generator loss
  loss_g_gan = losses.multiscale_discriminator_loss(disc_fake_featmaps, True)
  if opts.use_vgg_loss:
    vgg_layers = ['conv%d_2' % i for i in range(1, 6)]  # conv1 through conv5
    vgg_layer_weights = [1./32, 1./16, 1./8, 1./4, 1.]
    vgg_loss = losses.PerceptualLoss(
        y, x_gt, [opts.train_resolution, opts.train_resolution, 3],
        vgg_layers, vgg_layer_weights)
    loss_g_recon = vgg_loss()
  else:
    loss_g_recon = losses.L1_loss(y, x_gt)
  loss_g = w_loss_gan * loss_g_gan + w_loss_recon * loss_g_recon

  # ---------------------------------------------------------------------------
  # Tensorboard visualizations
  # ---------------------------------------------------------------------------

  x_in_render = tf.slice(x_in, [0, 0, 0, 0], [-1, -1, -1, 3])
  if opts.use_semantic:
    x_in_semantic = tf.slice(x_in, [0, 0, 0, 4], [-1, -1, -1, 3])
    tb_visualization = tf.concat([x_in_render, x_in_semantic, y, x_gt], axis=2)
  else:
    tb_visualization = tf.concat([x_in_render, y, x_gt], axis=2)
  tf.summary.image('rendered-semantic-generated-gt tuple', tb_visualization)

  # Show input appearance images
  if opts.use_appearance:
    x_app_rgb = tf.slice(x_app, [0, 0, 0, 0], [-1, -1, -1, 3])
    x_app_sem = tf.slice(x_app, [0, 0, 0, 7], [-1, -1, -1, -1])
    tb_app_visualization = tf.concat([x_app_rgb, x_app_sem], axis=2)
    tf.summary.image('input appearance image', tb_app_visualization)

  # Loss summaries
  with tf.name_scope('Discriminator_Loss'):
    tf.summary.scalar('D_real_loss', loss_d_real)
    tf.summary.scalar('D_fake_loss', loss_d_fake)
    tf.summary.scalar('D_total_loss', loss_d)
  with tf.name_scope('Generator_Loss'):
    tf.summary.scalar('G_GAN_loss', w_loss_gan * loss_g_gan)
    tf.summary.scalar('G_reconstruction_loss', w_loss_recon * loss_g_recon)
    tf.summary.scalar('G_total_loss', loss_g)

  # ---------------------------------------------------------------------------
  # Optimizers
  # ---------------------------------------------------------------------------

  def get_optimizer(lr, loss, var_list):
    optimizer = tf.train.AdamOptimizer(lr, opts.adam_beta1, opts.adam_beta2)
    # optimizer = tf.contrib.estimator.TowerOptimizer(optimizer)
    return optimizer.minimize(loss, var_list=var_list)

  # Training ops.
  update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
  with tf.control_dependencies(update_ops):
    with tf.variable_scope('optimizers'):
      d_vars = utils.model_vars('d_model')[0]
      g_vars_all = utils.model_vars('g_model')[0]
      train_d = [get_optimizer(opts.d_lr, loss_d, d_vars)]
      train_g = [get_optimizer(opts.g_lr, loss_g, g_vars_all)]

      train_app_encoder = []
      if opts.train_app_encoder:
        lr_app = opts.ez_lr
        app_enc_vars = utils.model_vars('appearance_net')[0]
        train_app_encoder.append(get_optimizer(lr_app, loss_g, app_enc_vars))

  ema = tf.train.ExponentialMovingAverage(decay=0.999)
  with tf.control_dependencies(train_g + train_app_encoder):
    inference_vars_all = g_vars_all
    if opts.use_appearance:
      app_enc_vars = utils.model_vars('appearance_net')[0]
      inference_vars_all += app_enc_vars
    ema_op = ema.apply(inference_vars_all)

  print('***************************************************')
  print('len(g_vars_all) = %d' % len(g_vars_all))
  for ii, v in enumerate(g_vars_all):
    print('%03d) %s' % (ii, str(v)))
  print('-------------------------------------------------------')
  print('len(d_vars) = %d' % len(d_vars))
  for ii, v in enumerate(d_vars):
    print('%03d) %s' % (ii, str(v)))
  if opts.train_app_encoder:
    print('-------------------------------------------------------')
    print('len(app_enc_vars) = %d' % len(app_enc_vars))
    for ii, v in enumerate(app_enc_vars):
      print('%03d) %s' % (ii, str(v)))
  print('***************************************************\n\n')

  return {
      'train_disc_op': tf.group(train_d),
      'train_renderer_op': ema_op,
      'total_loss_d': loss_d,
      'loss_d_real': loss_d_real,
      'loss_d_fake': loss_d_fake,
      'loss_g_gan': w_loss_gan * loss_g_gan,
      'loss_g_recon': w_loss_recon * loss_g_recon,
      'total_loss_g': loss_g}
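

# A minimal sketch of how the returned ops might drive an alternating GAN
# training loop (hypothetical; the actual loop lives in the training code):
#
#   ops = create_computation_graph(x_in, x_gt, x_app)
#   for _ in range(num_steps):
#     sess.run(ops['train_disc_op'])      # discriminator update
#     sess.run(ops['train_renderer_op'])  # generator (+ EMA) update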


================================================
FILE: style_loss.py
================================================
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from PIL import Image
from options import FLAGS as opts
import data
import layers
import numpy as np
import tensorflow as tf
import utils
import vgg16


def gram_matrix(layer):
  """Computes the gram_matrix for a batch of single vgg layer
  Input:
    layer: a batch of vgg activations for a single conv layer
  Returns:
    gram: [batch_sz x num_channels x num_channels]: a batch of gram matrices
  """
  batch_size, height, width, num_channels = layer.get_shape().as_list()
  features = tf.reshape(layer, [batch_size, height * width, num_channels])
  num_elements = tf.constant(num_channels * height * width, tf.float32)
  gram = tf.matmul(features, features, adjoint_a=True) / num_elements
  return gram
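

# Equivalently in NumPy for a single activation map of shape [H, W, C]
# (an illustrative sketch of the same computation with batch size 1):
def _gram_matrix_np_sketch(act):
  feats = act.reshape(-1, act.shape[-1])  # [H*W, C]
  return feats.T.dot(feats) / act.size  # act.size == H * W * C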


def compute_gram_matrices(
    images, vgg_layers=['conv1_2', 'conv2_2', 'conv3_2', 'conv4_2', 'conv5_2']):
  """Computes the gram matrix representation of a batch of images"""
  vgg_net = vgg16.Vgg16(opts.vgg16_path)
  vgg_acts = vgg_net.get_vgg_activations(images, vgg_layers)
  grams = [gram_matrix(layer) for layer in vgg_acts]
  return grams


def compute_pairwise_style_loss_v2(image_paths_list):
  grams_all = [None] * len(image_paths_list)
  crop_height, crop_width = opts.train_resolution, opts.train_resolution
  img_var = tf.placeholder(tf.float32, shape=[1, crop_height, crop_width, 3])
  vgg_layers = ['conv%d_2' % i for i in range(1, 6)]  # conv1 through conv5
  grams_ops = compute_gram_matrices(img_var, vgg_layers)
  with tf.Session() as sess:
    for ii, img_path in enumerate(image_paths_list):
      print('Computing gram matrices for image #%d' % (ii + 1))
      img = np.array(Image.open(img_path), dtype=np.float32)
      img = img * 2. / 255. - 1  # normalize image
      img = utils.get_central_crop(img, crop_height, crop_width)
      img = np.expand_dims(img, axis=0)
      grams_all[ii] = sess.run(grams_ops, feed_dict={img_var: img})
  print('Number of images = %d' % len(grams_all))
  print('Gram matrices per image:')
  for i in range(len(grams_all[0])):
    print('gram_matrix[%d].shape = %s' % (i, grams_all[0][i].shape))
  n_imgs = len(grams_all)
  dist_matrix = np.zeros((n_imgs, n_imgs))
  for i in range(n_imgs):
    print('Computing distances for image #%d' % i)
    for j in range(i + 1, n_imgs):
      loss_style = 0
      # Compute loss using all gram matrices from all layers
      for gram_i, gram_j in zip(grams_all[i], grams_all[j]):
        # Each gram has shape [1, C, C]; reduce to a scalar.
        loss_style += np.mean((gram_i - gram_j) ** 2)
      dist_matrix[i][j] = dist_matrix[j][i] = loss_style

  return dist_matrix


================================================
FILE: utils.py
================================================
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utilities for GANs.

Basic functions such as generating sample grid, exporting to PNG, etc...
"""

import functools
import numpy as np
import os.path
import tensorflow as tf
import time


def crop_to_multiple(img, size_multiple=64):
  """ Crops the image so that its dimensions are multiples of size_multiple."""
  new_width = (img.shape[1] // size_multiple) * size_multiple
  new_height = (img.shape[0] // size_multiple) * size_multiple
  offset_x = (img.shape[1] - new_width) // 2
  offset_y = (img.shape[0] - new_height) // 2
  return img[offset_y:offset_y + new_height, offset_x:offset_x + new_width, :]


def get_central_crop(img, crop_height=512, crop_width=512):
  if len(img.shape) == 2:
    img = np.expand_dims(img, axis=2)
  assert len(img.shape) == 3, ('input image should be either a 2D or 3D matrix,'
                               ' but input was of shape %s' % str(img.shape))
  height, width, _ = img.shape
  assert height >= crop_height and width >= crop_width, ('input image cannot '
      'be smaller than the requested crop size')
  st_y = (height - crop_height) // 2
  st_x = (width - crop_width) // 2
  return np.squeeze(img[st_y : st_y + crop_height, st_x : st_x + crop_width, :])
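
# For example, a 600x800 image with the default 512x512 crop starts at
# st_y = (600 - 512) // 2 = 44 and st_x = (800 - 512) // 2 = 144.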


def load_global_step_from_checkpoint_dir(checkpoint_dir):
  """Loads  the global step from the checkpoint directory.

  Args:
    checkpoint_dir: string, path to the checkpoint directory.

  Returns:
    int, the global step of the latest checkpoint or 0 if none was found.
  """
  try:
    checkpoint_reader = tf.train.NewCheckpointReader(
        tf.train.latest_checkpoint(checkpoint_dir))
    return checkpoint_reader.get_tensor(tf.GraphKeys.GLOBAL_STEP)
  except Exception:  # no valid checkpoint found
    return 0


def model_vars(prefix):
  """Return trainable variables matching a prefix.

  Args:
    prefix: string, the prefix variable names must match.

  Returns:
    a tuple (match, others) of TF variables, 'match' contains the matched
     variables and 'others' contains the remaining variables.
  """
  match, no_match = [], []
  for x in tf.trainable_variables():
    if x.name.startswith(prefix):
      match.append(x)
    else:
      no_match.append(x)
  return match, no_match


def to_png(x):
  """Convert a 3D tensor to png.

  Args:
    x: Tensor, 01C formatted input image.

  Returns:
    Tensor, 1D string representing the image in png format.
  """
  with tf.Graph().as_default():
    with tf.Session() as sess_temp:
      x = tf.constant(x)
      y = tf.image.encode_png(
          tf.cast(
              tf.clip_by_value(tf.round(127.5 + 127.5 * x), 0, 255), tf.uint8),
          compression=9)
      return sess_temp.run(y)


def images_to_grid(images):
  """Converts a grid of images (5D tensor) to a single image.

  Args:
    images: 5D tensor (count_y, count_x, height, width, colors), grid of images.

  Returns:
    a 3D tensor image of shape (count_y * height, count_x * width, colors).
  """
  ny, nx, h, w, c = images.shape
  images = images.transpose(0, 2, 1, 3, 4)
  images = images.reshape([ny * h, nx * w, c])
  return images
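
# For example (illustrative): a grid of shape (2, 3, 4, 5, 3) -- a 2x3 grid of
# 4x5 RGB images -- becomes a single image of shape (8, 15, 3).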


def save_images(image, output_dir, cur_nimg):
  """Saves images to disk.

  Saves a file called 'name.png' containing the latest samples from the
   generator and a file called 'name_123.png' where 123 is the number of
   thousands of images (cur_nimg >> 10) seen by training.

  Args:
    image: 3D numpy array (height, width, colors), the image to save.
    output_dir: string, the directory where to save the image.
    cur_nimg: int, current number of images seen by training.

  Returns:
    None
  """
  for name in ('name.png', 'name_%06d.png' % (cur_nimg >> 10)):
    with tf.gfile.Open(os.path.join(output_dir, name), 'wb') as f:
      f.write(image)


class HookReport(tf.train.SessionRunHook):
  """Custom reporting hook.

  Register your scalar tensors with HookReport.log_tensor(my_tensor, 'my_name').
  This hook will report their average values over the report period passed to
  the constructor. The values are printed in the order the tensors were
  registered.

  Attributes:
    step: int, the current global step.
    active: bool, whether logging is active or disabled.
  """
  _REPORT_KEY = 'report'
  _TENSOR_NAMES = {}

  def __init__(self, period, batch_size):
    self.step = 0
    self.active = True
    self._period = period // batch_size
    self._batch_size = batch_size
    self._sums = np.array([])
    self._count = 0
    self._nimgs_per_cycle = 0
    self._step_ratio = 0
    self._start = time.time()
    self._nimgs = 0

  def disable(self):
    parent = self

    class Disabler(object):

      def __enter__(self):
        parent.active = False
        return parent

      def __exit__(self, exc_type, exc_val, exc_tb):
        parent.active = True

    return Disabler()

  def begin(self):
    self.active = True
    self._count = 0
    self._nimgs_per_cycle = 0
    self._start = time.time()

  def before_run(self, run_context):
    if not self.active:
      return
    del run_context
    fetches = tf.get_collection(self._REPORT_KEY)
    return tf.train.SessionRunArgs(fetches)

  def after_run(self, run_context, run_values):
    if not self.active:
      return
    del run_context
    results = run_values.results
    # Note: sometimes the returned step is incorrect (off by one) for some
    # unknown reason.
    self.step = results[-1] + 1
    self._count += 1
    self._nimgs_per_cycle += self._batch_size
    self._nimgs += self._batch_size

    if not self._sums.size:
      self._sums = np.array(results[:-1], 'd')
    else:
      self._sums += np.array(results[:-1], 'd')

    if self.step // self._period != self._step_ratio:
      fetches = tf.get_collection(self._REPORT_KEY)[:-1]
      stats = '  '.join('%s=% .2f' % (self._TENSOR_NAMES[tensor],
                                      value / self._count)
                        for tensor, value in zip(fetches, self._sums))
      stop = time.time()
      tf.logging.info('step=%d, kimg=%d  %s  [%.2f img/s]' %
                      (self.step, ((self.step * self._batch_size) >> 10),
                       stats, self._nimgs_per_cycle / (stop - self._start)))
      self._step_ratio = self.step // self._period
      self._start = stop
      self._sums *= 0
      self._count = 0
      self._nimgs_per_cycle = 0

  def end(self, session=None):
    del session

  @classmethod
  def log_tensor(cls, tensor, name):
    """Adds a tensor to be reported by the hook.

    Args:
      tensor: `tensor scalar`, a value to report.
      name: string, the name to give the value in the report.

    Returns:
      None.
    """
    cls._TENSOR_NAMES[tensor] = name
    tf.add_to_collection(cls._REPORT_KEY, tensor)
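
# A minimal usage sketch (hypothetical tensor and hyper-parameters; the hook is
# passed to whatever estimator or session the training code constructs):
#   HookReport.log_tensor(total_loss, 'total_loss')
#   hooks = [HookReport(period=2000, batch_size=8)]
#   estimator.train(input_fn, hooks=hooks)
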
SYMBOL INDEX (130 symbols across 13 files)

FILE: data.py
  function provide_data (line 24) | def provide_data(dataset_name='', parent_dir='', batch_size=8, subset=None,
  function _parser_rendered_dataset (line 40) | def _parser_rendered_dataset(
  function multi_input_fn_record (line 130) | def multi_input_fn_record(

FILE: dataset_utils.py
  class AlignedRenderedDataset (line 38) | class AlignedRenderedDataset(object):
    method __init__ (line 39) | def __init__(self, rendered_filepattern, use_semantic_map=True):
    method __iter__ (line 52) | def __iter__(self):
    method __next__ (line 55) | def __next__(self):
    method next (line 58) | def next(self):
  function filter_out_sparse_renders (line 114) | def filter_out_sparse_renders(dataset_dir, splits, ratio_threshold=0.15):
  function _to_example (line 153) | def _to_example(dictionary):
  function _generate_tfrecord_dataset (line 172) | def _generate_tfrecord_dataset(generator,
  function export_aligned_dataset_to_tfrecord (line 193) | def export_aligned_dataset_to_tfrecord(
  function main (line 218) | def main(argv):

FILE: evaluate_quantitative_metrics.py
  function _extract_real_and_fake_from_concatenated_output (line 34) | def _extract_real_and_fake_from_concatenated_output(val_set_out_dir):
  function compute_l1_loss_metric (line 56) | def compute_l1_loss_metric(image_set1_paths, image_set2_paths):
  function compute_psnr_loss_metric (line 76) | def compute_psnr_loss_metric(image_set1_paths, image_set2_paths):
  function evaluate_experiment (line 96) | def evaluate_experiment(val_set_out_dir, title='experiment',
  function main (line 114) | def main(argv):

FILE: layers.py
  class LayerInstanceNorm (line 20) | class LayerInstanceNorm(object):
    method __init__ (line 22) | def __init__(self, scope_suffix='instance_norm'):
    method __call__ (line 26) | def __call__(self, x):
  function layer_norm (line 32) | def layer_norm(x, scope='layer_norm'):
  function pixel_norm (line 36) | def pixel_norm(x):
  function global_avg_pooling (line 48) | def global_avg_pooling(x):
  class FullyConnected (line 52) | class FullyConnected(object):
    method __init__ (line 54) | def __init__(self, n_out_units, scope_suffix='FC'):
    method __call__ (line 64) | def __call__(self, x):
  function init_he_scale (line 69) | def init_he_scale(shape, slope=1.0):
  class LayerConv (line 83) | class LayerConv(object):
    method __init__ (line 86) | def __init__(self,
    method __call__ (line 129) | def __call__(self, x):
  class LayerTransposedConv (line 151) | class LayerTransposedConv(object):
    method __init__ (line 154) | def __init__(self,
    method __call__ (line 197) | def __call__(self, x):
  class ResBlock (line 210) | class ResBlock(object):
    method __init__ (line 211) | def __init__(self,
    method __call__ (line 242) | def __call__(self, x_init):
  class BasicBlock (line 250) | class BasicBlock(object):
    method __init__ (line 251) | def __init__(self,
    method __call__ (line 287) | def __call__(self, x_init):
  class LayerDense (line 298) | class LayerDense(object):
    method __init__ (line 301) | def __init__(self, name, n, use_scaling=False, relu_slope=1.):
    method __call__ (line 322) | def __call__(self, x):
  class LayerPipe (line 327) | class LayerPipe(object):
    method __init__ (line 330) | def __init__(self, functions):
    method __call__ (line 338) | def __call__(self, x, **kwargs):
  function downscale (line 346) | def downscale(x, n=2):
  function upscale (line 361) | def upscale(x, n):
  function tile_and_concatenate (line 378) | def tile_and_concatenate(x, z, n_z):
  function minibatch_mean_variance (line 385) | def minibatch_mean_variance(x):
  function scalar_concat (line 402) | def scalar_concat(x, scalar):

FILE: losses.py
  function gradient_penalty_loss (line 22) | def gradient_penalty_loss(y_xy, xy, iwass_target=1, iwass_lambda=10):
  function KL_loss (line 30) | def KL_loss(mean, logvar):
  function l2_regularize (line 36) | def l2_regularize(x):
  function L1_loss (line 40) | def L1_loss(x, y):
  class PerceptualLoss (line 44) | class PerceptualLoss:
    method __init__ (line 45) | def __init__(self, x, y, image_shape, layers, w_layers, w_act=0.1):
    method __call__ (line 65) | def __call__(self):
  function lsgan_appearance_E_loss (line 69) | def lsgan_appearance_E_loss(disc_response):
  function lsgan_loss (line 76) | def lsgan_loss(disc_response, is_real):
  function multiscale_discriminator_loss (line 84) | def multiscale_discriminator_loss(Ds_responses, is_real):

FILE: networks.py
  class RenderingModel (line 21) | class RenderingModel(object):
    method __init__ (line 23) | def __init__(self, model_name, use_appearance=True):
    method __call__ (line 30) | def __call__(self, x_in, z_app=None):
    method get_appearance_encoder (line 33) | def get_appearance_encoder(self):
    method get_generator (line 36) | def get_generator(self):
    method get_content_encoder (line 39) | def get_content_encoder(self):
  class ModelPGGAN (line 49) | class ModelPGGAN(RenderingModel):
    method __init__ (line 51) | def __init__(self, use_appearance=True):
    method __call__ (line 61) | def __call__(self, x_in, z_app=None):
    method get_appearance_encoder (line 65) | def get_appearance_encoder(self):
    method get_generator (line 68) | def get_generator(self):
    method get_content_encoder (line 71) | def get_content_encoder(self):
  class PatchGANDiscriminator (line 75) | class PatchGANDiscriminator(object):
    method __init__ (line 77) | def __init__(self, name_scope, input_nc, nf=64, n_layers=3, get_fmaps=...
    method __call__ (line 133) | def __call__(self, x, x_cond=None):
  class MultiScaleDiscriminator (line 158) | class MultiScaleDiscriminator(object):
    method __init__ (line 160) | def __init__(self, name_scope, input_nc, num_scales=3, nf=64, n_layers=3,
    method __call__ (line 171) | def __call__(self, x, x_cond=None, params=None):
  class GeneratorPGGAN (line 184) | class GeneratorPGGAN(object):
    method __init__ (line 185) | def __init__(self, appearance_vec_size=8, use_scaling=True,
    method __call__ (line 287) | def __call__(self, x, appearance_embedding=None, encoder_fmaps=None):
  class DRITAppearanceEncoderConcat (line 328) | class DRITAppearanceEncoderConcat(object):
    method __init__ (line 330) | def __init__(self, name_scope, input_nc, normalize_encoder):
    method __call__ (line 361) | def __call__(self, x):

FILE: neural_rerendering.py
  function build_model_fn (line 35) | def build_model_fn(use_exponential_moving_average=True):
  function make_sample_grid_and_save (line 147) | def make_sample_grid_and_save(est, dataset_name, dataset_parent_dir, gri...
  function visualize_image_sequence (line 187) | def visualize_image_sequence(est, dataset_name, dataset_parent_dir,
  function train (line 222) | def train(dataset_name, dataset_parent_dir, load_pretrained_app_encoder,
  function _build_inference_estimator (line 300) | def _build_inference_estimator(model_dir):
  function evaluate_sequence (line 306) | def evaluate_sequence(dataset_name, dataset_parent_dir, virtual_seq_name,
  function evaluate_image_set (line 315) | def evaluate_image_set(dataset_name, dataset_parent_dir, subset_suffix,
  function _load_and_concatenate_image_channels (line 341) | def _load_and_concatenate_image_channels(rgb_path=None, rendered_path=None,
  function infer_dir (line 378) | def infer_dir(model_dir, input_dir, output_dir):
  function joint_interpolation (line 414) | def joint_interpolation(model_dir, app_input_dir, st_app_basename,
  function interpolate_appearance (line 491) | def interpolate_appearance(model_dir, input_dir, target_img_basename,
  function main (line 563) | def main(argv):

FILE: options.py
  function validate_options (line 175) | def validate_options():
  function list_options (line 191) | def list_options():

FILE: pretrain_appearance.py
  function _load_and_concatenate_image_channels (line 30) | def _load_and_concatenate_image_channels(
  function read_single_appearance_input (line 62) | def read_single_appearance_input(rgb_img_path):
  function get_triplet_input_fn (line 73) | def get_triplet_input_fn(dataset_path, dist_file_path=None, k_max_neares...
  function get_tf_triplet_dataset_iter (line 116) | def get_tf_triplet_dataset_iter(
  function build_model_fn (line 146) | def build_model_fn(batch_size, lr_app_pretrain=0.0001, adam_beta1=0.0,
  function compute_dist_matrix (line 214) | def compute_dist_matrix(imageset_dir, dist_file_path, recompute_dist=Fal...
  function train_appearance (line 231) | def train_appearance(train_dir, imageset_dir, dist_file_path):
  function main (line 257) | def main(argv):

FILE: segment_dataset.py
  function get_semantic_color_coding (line 33) | def get_semantic_color_coding():
  function _apply_colors (line 111) | def _apply_colors(seg_images_path, save_dir, idx_to_color):
  function segment_images (line 132) | def segment_images(images_path, xception_frozen_graph_path, save_dir,
  function segment_and_color_dataset (line 184) | def segment_and_color_dataset(dataset_dir, xception_frozen_graph_path,

FILE: staged_model.py
  function create_computation_graph (line 27) | def create_computation_graph(x_in, x_gt, x_app=None, arch_type='pggan',

FILE: style_loss.py
  function gram_matrix (line 25) | def gram_matrix(layer):
  function compute_gram_matrices (line 39) | def compute_gram_matrices(
  function compute_pairwise_style_loss_v2 (line 48) | def compute_pairwise_style_loss_v2(image_paths_list):

FILE: utils.py
  function crop_to_multiple (line 27) | def crop_to_multiple(img, size_multiple=64):
  function get_central_crop (line 36) | def get_central_crop(img, crop_height=512, crop_width=512):
  function load_global_step_from_checkpoint_dir (line 49) | def load_global_step_from_checkpoint_dir(checkpoint_dir):
  function model_vars (line 66) | def model_vars(prefix):
  function to_png (line 85) | def to_png(x):
  function images_to_grid (line 104) | def images_to_grid(images):
  function save_images (line 119) | def save_images(image, output_dir, cur_nimg):
  class HookReport (line 139) | class HookReport(tf.train.SessionRunHook):
    method __init__ (line 154) | def __init__(self, period, batch_size):
    method disable (line 167) | def disable(self):
    method begin (line 181) | def begin(self):
    method before_run (line 187) | def before_run(self, run_context):
    method after_run (line 194) | def after_run(self, run_context, run_values):
    method end (line 226) | def end(self, session=None):
    method log_tensor (line 230) | def log_tensor(cls, tensor, name):