Repository: google/retiming
Branch: main
Commit: 0e7dbce941b8
Files: 40
Total size: 149.7 KB

Directory structure:
gitextract_wurthk5z/

├── LICENSE
├── README.md
├── data/
│   ├── __init__.py
│   ├── kpuv_dataset.py
│   └── layered_video_dataset.py
├── datasets/
│   ├── download_data.sh
│   ├── iuv_crop2full.py
│   └── prepare_iuv.sh
├── docs/
│   ├── contributing.md
│   └── data.md
├── environment.yml
├── models/
│   ├── __init__.py
│   ├── kp2uv_model.py
│   ├── lnr_model.py
│   └── networks.py
├── options/
│   ├── __init__.py
│   ├── base_options.py
│   ├── test_options.py
│   └── train_options.py
├── requirements.txt
├── run_kp2uv.py
├── scripts/
│   ├── download_kp2uv_model.sh
│   ├── run_cartwheel.sh
│   ├── run_reflection.sh
│   ├── run_splash.sh
│   └── run_trampoline.sh
├── test.py
├── third_party/
│   ├── __init__.py
│   ├── data/
│   │   ├── __init__.py
│   │   ├── base_dataset.py
│   │   ├── fast_data_loader.py
│   │   └── image_folder.py
│   ├── models/
│   │   ├── __init__.py
│   │   ├── base_model.py
│   │   └── networks.py
│   └── util/
│       ├── __init__.py
│       ├── html.py
│       ├── util.py
│       └── visualizer.py
└── train.py

================================================
FILE CONTENTS
================================================

================================================
FILE: LICENSE
================================================

                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!)  The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.


================================================
FILE: README.md
================================================
# Layered Neural Rendering in PyTorch

This repository contains training code for the examples in the SIGGRAPH Asia 2020 paper "[Layered Neural Rendering for Retiming People in Video](https://retiming.github.io/)."

<img src='./img/teaser.gif' height="160px"/>

This is not an officially supported Google product.


## Prerequisites
- Linux
- Python 3.6+
- NVIDIA GPU + CUDA CuDNN

## Installation
This code has been tested with PyTorch 1.4 and Python 3.8.

- Install [PyTorch](http://pytorch.org) 1.4 and other dependencies.
  - For pip users, run `pip install -r requirements.txt`.
  - For Conda users, create a new Conda environment with `conda env create -f environment.yml`.

## Data Processing
- Download the data for a video used in our paper (e.g. "reflection"):
```bash
bash ./datasets/download_data.sh reflection
```
- Or alternatively, download all the data by specifying `all`.
- Download the pretrained keypoint-to-UV model weights:
```bash
bash ./scripts/download_kp2uv_model.sh
``` 
The pretrained model will be saved at `./checkpoints/kp2uv/latest_net_Kp2uv.pth`.
- Generate the UV maps from the keypoints:
```bash
bash datasets/prepare_iuv.sh ./datasets/reflection
```
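Before training, it can help to confirm that the data folder matches what `data/layered_video_dataset.py` reads: `rgb_256/`, `rgb_512/`, digit-named layer folders under `iuv/` with one UV image per frame, `metadata.json`, and `homographies.txt` when `--use_homographies` is set. The sketch below is a hypothetical helper (`check_layered_video_dir` is not part of the repository); it only counts image files by extension, whereas the real loader uses `make_dataset()`.

```python
import os

# Extensions counted as images (assumption); the repository's make_dataset() accepts a longer list.
IMAGE_EXTS = ('.png', '.jpg', '.jpeg')


def _count_images(directory):
    """Number of files in `directory` whose name ends with an image extension."""
    return sum(1 for f in os.listdir(directory) if f.lower().endswith(IMAGE_EXTS))


def check_layered_video_dir(dataroot, use_homographies=False):
    """Return a list of layout problems in `dataroot`; an empty list means none were found."""
    problems = []
    counts = {}
    for sub in ('rgb_256', 'rgb_512'):
        path = os.path.join(dataroot, sub)
        if os.path.isdir(path):
            counts[sub] = _count_images(path)
        else:
            problems.append(f'missing directory: {sub}')
    if len(counts) == 2 and counts['rgb_256'] != counts['rgb_512']:
        problems.append(f"rgb_256 has {counts['rgb_256']} images, rgb_512 has {counts['rgb_512']}")

    iuv_dir = os.path.join(dataroot, 'iuv')
    if not os.path.isdir(iuv_dir):
        problems.append('missing directory: iuv')
    else:
        # Like the loader, only digit-named subdirectories (01, 02, ...) count as layers.
        layers = sorted(name for name in os.listdir(iuv_dir)
                        if name.isdigit() and os.path.isdir(os.path.join(iuv_dir, name)))
        if not layers:
            problems.append('iuv has no digit-named layer directories')
        n_frames = counts.get('rgb_256')
        for layer in layers:
            n = _count_images(os.path.join(iuv_dir, layer))
            if n_frames is not None and n != n_frames:
                problems.append(f'iuv/{layer} has {n} images, rgb_256 has {n_frames}')

    if not os.path.isfile(os.path.join(dataroot, 'metadata.json')):
        problems.append('missing file: metadata.json')
    if use_homographies and not os.path.isfile(os.path.join(dataroot, 'homographies.txt')):
        problems.append('missing file: homographies.txt')
    return problems
```

For example, `check_layered_video_dir('./datasets/reflection')` is expected to return an empty list once `prepare_iuv.sh` has finished.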
## Training
- To train a model on a video (e.g. "reflection"), run:
```bash
python train.py --name reflection --dataroot ./datasets/reflection --gpu_ids 0,1
```
- To view training results and loss plots, visit the URL http://localhost:8097.
Intermediate results are also saved to `./checkpoints/reflection/web/index.html`.

You can find more scripts in the `scripts` directory, e.g. `run_${VIDEO}.sh` which combines data processing, training, and saving layer results for a video. 

**Note**:
- It is recommended to use >=2 GPUs, each with >=16GB memory.
- The training script first trains the low-resolution model for `--num_epochs` at `--batch_size`, and then trains the upsampling module for `--num_epochs_upsample` at `--batch_size_upsample`.
If you do not need the upsampled result, pass `--num_epochs_upsample 0`.
- Training the upsampling module requires ~2.5x the memory of the low-resolution model, so set `--batch_size_upsample` accordingly.
The provided scripts set the batch sizes appropriately for 2 GPUs with 16GB memory.
- GPU memory scales linearly with the number of layers.
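The two memory notes above can be turned into a back-of-the-envelope calculation when changing the number of layers. This hypothetical helper assumes memory grows in proportion to `batch_size * num_layers` and that one upsampling sample costs about 2.5 low-resolution samples; it ignores fixed overheads such as model weights, so its output is a starting point to try, not a guarantee. `ref_batch_size` and `ref_num_layers` should come from a configuration you know fits, such as one of the provided scripts.

```python
import math


def rough_batch_sizes(ref_batch_size, ref_num_layers, num_layers, upsample_cost=2.5):
    """Estimate (batch_size, batch_size_upsample) for a new layer count.

    Assumes memory is proportional to batch_size * num_layers and that the
    upsampling stage costs `upsample_cost` times more memory per sample.
    """
    # Keep batch_size * num_layers roughly equal to the known-good configuration.
    budget = ref_batch_size * ref_num_layers / num_layers
    batch_size = max(1, math.floor(budget))
    # Each upsampling sample is assumed to cost ~upsample_cost low-resolution samples.
    batch_size_upsample = max(1, math.floor(budget / upsample_cost))
    # Both values are clamped to at least 1, which may still exceed available memory.
    return batch_size, batch_size_upsample
```

For example, going from 3 to 6 layers with a known-good `--batch_size 8` gives `(4, 1)`, i.e. `--batch_size 4 --batch_size_upsample 1`.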

## Saving layer results from a trained model
- Run the trained model:
```bash
python test.py --name reflection --dataroot ./datasets/reflection --do_upsampling
```
- The results (RGBA layers, videos) will be saved to `./results/reflection/test_latest/`.
- Passing `--do_upsampling` uses the results of the upsampling module. If the upsampling module hasn't been trained (`num_epochs_upsample=0`), then remove this flag.

## Custom video
To train on your own video, you will have to preprocess the data:
1. Extract the frames, e.g.
    ```
    mkdir ./datasets/my_video && cd ./datasets/my_video 
    mkdir rgb && ffmpeg -i video.mp4 rgb/%04d.png
    ```
1. Resize the frames to 256x448 (height x width) and save them in `my_video/rgb_256`, then resize them to 512x896 and save them in `my_video/rgb_512`.
1. Run [AlphaPose and Pose Tracking](https://github.com/MVIG-SJTU/AlphaPose) on the frames and save the results as `my_video/keypoints.json`.
1. Create `my_video/metadata.json` following [these instructions](docs/data.md).
1. If your video has camera motion, either (1) stabilize the video, or (2) maintain the camera motion by computing homographies and saving as `my_video/homographies.txt`.
See `scripts/run_cartwheel.sh` for a training example with camera motion, and see `./datasets/cartwheel/homographies.txt` for formatting.
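The expected layout of `homographies.txt` can be read off `init_homographies` in `data/layered_video_dataset.py`: the first line holds two integers after a leading token, the second holds four floats after a leading token (two x bounds, then two y bounds), and each following line holds one 3x3 homography as nine space-separated numbers in row-major order. The hypothetical `read_homographies` below parses a file the same way so formatting errors surface before training. What the scale and bounds values mean is not documented in this section, so compare against `./datasets/cartwheel/homographies.txt`.

```python
import numpy as np


def read_homographies(path, n_images):
    """Parse homographies.txt following the layout used by init_homographies.

    Returns ((scale_x, scale_y), bounds_x, bounds_y, list of 3x3 float32 arrays).
    """
    with open(path) as f:
        lines = f.readlines()
    # Like the loader, split on single spaces: repeated spaces or tabs yield fields that fail to parse.
    scale = lines[0].rstrip().split(' ')
    scale_x, scale_y = int(scale[1]), int(scale[2])  # the first token is skipped
    bounds = lines[1].rstrip().split(' ')
    bounds_x = [float(bounds[1]), float(bounds[2])]
    bounds_y = [float(bounds[3]), float(bounds[4])]
    rows = lines[2:2 + n_images]
    # Extra check (not in the original loader): require one homography line per frame.
    if len(rows) < n_images:
        raise ValueError(f'expected {n_images} homographies, found {len(rows)}')
    # reshape(3, 3) raises ValueError unless a line has exactly nine values.
    homographies = [np.array(r.rstrip().split(' ')).astype(np.float32).reshape(3, 3) for r in rows]
    return (scale_x, scale_y), bounds_x, bounds_y, homographies
```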

**Note**: Videos that are suitable for our method have the following attributes:
- Static camera or limited camera motion that can be represented with a homography.
- Limited number of people, due to GPU memory limitations. We tested up to 7 people and 7 layers.
Multiple people can be grouped onto the same layer, though they cannot be individually retimed.
- People that move relative to the background (static people will be absorbed into the background layer).
- We tested a video length of up to 200 frames (~7 seconds).

## Citation
If you use this code for your research, please cite the following paper:
```
@inproceedings{lu2020,
  title={Layered Neural Rendering for Retiming People in Video},
  author={Lu, Erika and Cole, Forrester and Dekel, Tali and Xie, Weidi and Zisserman, Andrew and Salesin, David and Freeman, William T and Rubinstein, Michael},
  booktitle={SIGGRAPH Asia},
  year={2020}
}
```

## Acknowledgments
This code is based on [pytorch-CycleGAN-and-pix2pix](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix).


================================================
FILE: data/__init__.py
================================================


================================================
FILE: data/kpuv_dataset.py
================================================
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from third_party.data.base_dataset import BaseDataset
from PIL import Image, ImageDraw
import json
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import numpy as np
import os
import torchvision.transforms as transforms


class KpuvDataset(BaseDataset):
    """A dataset class for keypoint data.

    It assumes that the directory specified by 'dataroot' contains the file 'keypoints.json'.
    """
    @staticmethod
    def modify_commandline_options(parser, is_train):
        parser.add_argument('--inp_size', type=int, default=256, help='image size')
        return parser

    def __init__(self, opt):
        """Initialize this dataset class by reading in keypoints.

        Parameters:
            opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
        """
        BaseDataset.__init__(self, opt)

        self.inp_size = opt.inp_size
        inner_crop_size = int(.75*self.inp_size)
        kps = []
        image_paths = []
        with open(os.path.join(self.root, 'keypoints.json'), 'rb') as f:
            kp_data = json.load(f)
        for frame in sorted(kp_data):
            for skeleton in kp_data[frame]:
                id = skeleton['idx']
                image_paths.append(f'{id:02d}_{frame}')
                kp = np.array(skeleton['keypoints']).reshape(17, 3)
                kp = self.crop_kps(kp, crop_size=self.inp_size, inner_crop_size=inner_crop_size)
                kps.append(kp)

        self.keypoints = kps
        self.image_paths = image_paths  # filenames for output UVs

        # for keypoint rendering
        self.cmap = plt.get_cmap("hsv", 17)
        self.color_seq = np.array([ 9, 14,  6,  7, 13, 16,  2, 11,  3,  5, 10, 15,  1,  8,  0, 12,  4])
        self.pairs = [[0,1],[0,2],[1,3],[2,4],[5,6],[5,7],[7,9],[6,8],[8,10],[11,12],[11,13],[13,15],[12,14],[14,16],[6,12],[5,11]]


    def __getitem__(self, index):
        """Return a data point and its metadata information.

        Parameters:
            index - - a random integer for data indexing

        Returns a dictionary that contains keypoints and path
            keypoints (tensor) - - an RGB image representing a skeleton
            path (str) - - an identifying filename that can be used for saving the result
        """
        uv_path = self.image_paths[index]  # output UV path
        kps = self.keypoints[index]

        # draw keypoints
        kp_im = Image.new(size=(self.inp_size, self.inp_size), mode='RGB')
        draw = ImageDraw.Draw(kp_im)
        self.render_kps(kps, draw)
        kp_im = transforms.ToTensor()(kp_im)
        kp_im = 2 * kp_im - 1

        return {'keypoints': kp_im, 'path': uv_path}

    def __len__(self):
        """Return the total number of images."""
        return len(self.image_paths)

    def crop_kps(self, kps, crop_size=256, inner_crop_size=192):
        """'Crops' keypoints to fit into ['crop_size', 'crop_size'].

        Parameters:
            kps - - a numpy array of shape [17, 2] or [17, 3], where the keypoint order is X, Y (, score); only X and Y are modified
            crop_size - - the new size of the world, which the keypoints will be centered inside
            inner_crop_size - - the box size that the keypoints will fit inside (must be <=crop_size)

        Returns keypoints mapped to fit inside a box of 'inner_crop_size', centered in 'crop_size'
        """
        # get coordinates of bounding box, in original image coordinates
        left = kps[:, 0].min()
        right = kps[:, 0].max()
        top = kps[:, 1].min()
        bottom = kps[:, 1].max()

        # map keypoints
        keypoints = kps.copy()
        center = ((right + left) // 2, (bottom + top) // 2)
        # first place center of bounding box at origin
        keypoints[:, 0] -= center[0]
        keypoints[:, 1] -= center[1]
        # scale bounding box to inner_crop_size
        scale = float(inner_crop_size) / max(right - left, bottom - top)
        keypoints[:, :2] *= scale
        # move center to crop_size//2
        keypoints[:, :2] += crop_size // 2
        new_kps = keypoints

        return new_kps

    def render_kps(self, keypoints, draw, thresh=1., min_weight=0.25):
        """Render skeleton as RGB image.

        Parameters:
            keypoints - - a numpy array of shape [17, 3], where the keypoint order is X, Y, score
            draw - - an ImageDraw object, which the keypoints will be drawn onto
            thresh - - keypoints with a confidence score below this value will have a color weighted by the score
            min_weight - - minimum weighting for color (scores will be mapped to the range [min_weight, 1])
        """
        # first draw keypoints
        ksize = 3
        for i in range(keypoints.shape[0]):
            x1 = keypoints[i,0] - ksize
            x2 = keypoints[i,0] + ksize
            y1 = keypoints[i,1] - ksize
            y2 = keypoints[i,1] + ksize
            if x1 < 0 or y1 < 0 or x2 > self.inp_size or y2 > self.inp_size:
                continue
            color = np.array(self.cmap(self.color_seq[i]))
            if keypoints.shape[1] > 2:
                score = keypoints[i,2]
                if score < thresh:  # weight color by confidence score
                    # first map [0,1] -> [min_weight, 1]
                    alpha_weight = score * (1.-min_weight) + min_weight
                    color[:3] *= alpha_weight
            color = (255*color).astype('uint8')
            draw.rectangle([x1, y1, x2, y2], fill=tuple(color))
        # now draw segments
        for pair in self.pairs:
            x1 = keypoints[pair[0],0]
            y1 = keypoints[pair[0],1]
            x2 = keypoints[pair[1],0]
            y2 = keypoints[pair[1],1]
            if x1 < 0 or y1 < 0 or x2 < 0 or y2 < 0 or x1 > self.inp_size or y1 > self.inp_size or x2 > self.inp_size or y2 > self.inp_size:
                continue
            avg_color = .5*(np.array(self.cmap(self.color_seq[pair[0]])) + np.array(self.cmap(self.color_seq[pair[1]])))
            if keypoints.shape[1] > 2:
                score = min(keypoints[pair[0],2], keypoints[pair[1],2])
                if score < thresh:
                    alpha_weight = score * (1.-min_weight) + min_weight
                    avg_color[:3] *= alpha_weight  # weight RGB channels by confidence score
            avg_color = (255*avg_color).astype('uint8')
            draw.line([x1,y1,x2,y2], fill=tuple(avg_color), width=3)

================================================
FILE: data/layered_video_dataset.py
================================================
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import cv2
from third_party.data.base_dataset import BaseDataset
from third_party.data.image_folder import make_dataset
from PIL import Image
import torchvision.transforms as transforms
import torch.nn.functional as F
import os
import torch
import numpy as np
import json


class LayeredVideoDataset(BaseDataset):
    """A dataset class for video layers.

    It assumes that the directory specified by 'dataroot' contains metadata.json, and the directories iuv, rgb_256, and rgb_512.
    The 'iuv' directory should contain directories named 01, 02, etc. for each layer, each containing per-frame UV images.
    """
    @staticmethod
    def modify_commandline_options(parser, is_train):
        parser.add_argument('--height', type=int, default=256, help='image height')
        parser.add_argument('--width', type=int, default=448, help='image width')
        parser.add_argument('--trimap_width', type=int, default=20, help='trimap gray area width')
        parser.add_argument('--use_mask_images', action='store_true', default=False, help='use custom masks')
        parser.add_argument('--use_homographies', action='store_true', default=False, help='handle camera motion')
        return parser

    def __init__(self, opt):
        """Initialize this dataset class.

        Parameters:
            opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
        """
        BaseDataset.__init__(self, opt)
        rgbdir = os.path.join(opt.dataroot, 'rgb_256')
        if opt.do_upsampling:
            rgbdir = os.path.join(opt.dataroot, 'rgb_512')
        uvdir = os.path.join(opt.dataroot, 'iuv')
        self.image_paths = sorted(make_dataset(rgbdir, opt.max_dataset_size))
        n_images = len(self.image_paths)
        layers = sorted(os.listdir(uvdir))
        layers = [l for l in layers if l.isdigit()]
        self.iuv_paths = []
        for l in layers:
            layer_iuv_paths = sorted(make_dataset(os.path.join(uvdir, l), n_images))
            if len(layer_iuv_paths) != n_images:
                print(f'UNEQUAL NUMBER OF IMAGES AND IUVs: {len(layer_iuv_paths)} and {n_images}')
            self.iuv_paths.append(layer_iuv_paths)

        # set up per-frame compositing order
        with open(os.path.join(opt.dataroot, 'metadata.json')) as f:
            metadata = json.load(f)
        if 'composite_order' in metadata:
            self.composite_order = metadata['composite_order']
        else:
            self.composite_order = [tuple(range(1, 1 + len(layers)))] * n_images

        if opt.use_homographies:
            self.init_homographies(os.path.join(opt.dataroot, 'homographies.txt'), n_images)

    def __getitem__(self, index):
        """Return a data point and its metadata information.

        Parameters:
            index - - a random integer for data indexing

        Returns a dictionary that contains:
            image (tensor) - - the original RGB frame to reconstruct
            uv_map (tensor) - - the UV maps for all layers, concatenated channel-wise
            mask (tensor) - - the trimaps for all layers, concatenated channel-wise
            pids (tensor) - - the person IDs for all layers, concatenated channel-wise
            image_path (str) - - image path
        """
        # Read the target image.
        image_path = self.image_paths[index]
        target_image = self.load_and_process_image(image_path)

        # Read the layer IUVs and convert to network inputs.
        people_layers = [self.load_and_process_iuv(self.iuv_paths[l - 1][index], index) for l in
                         self.composite_order[index]]
        iuv_h, iuv_w = people_layers[0][0].shape[-2:]

        # Create the background layer UV from homographies.
        background_layer = self.get_background_inputs(index, iuv_w, iuv_h)

        uv_maps, masks, pids = zip(*([background_layer] + people_layers))
        uv_maps = torch.cat(uv_maps)  # [L*2, H, W]
        masks = torch.stack(masks)  # [L, H, W]
        pids = torch.stack(pids)  # [L, H, W]

        if self.opt.use_mask_images:
            for i in range(1, len(people_layers)):
                mask_path = os.path.join(self.opt.dataroot, 'mask', f'{i:02d}', os.path.basename(image_path))
                if os.path.exists(mask_path):
                    mask = Image.open(mask_path).convert('L').resize((masks.shape[-1], masks.shape[-2]))
                    mask = transforms.ToTensor()(mask) * 2 - 1
                    masks[i] = mask

        transform_params = self.get_params(do_jitter=self.opt.phase=='train')
        pids = self.apply_transform(pids, transform_params, 'nearest')
        masks = self.apply_transform(masks, transform_params, 'bilinear')
        uv_maps = self.apply_transform(uv_maps, transform_params, 'nearest')
        image_transform_params = transform_params
        if self.opt.do_upsampling:
            image_transform_params = { p: transform_params[p] * 2 for p in transform_params}
        target_image = self.apply_transform(target_image, image_transform_params, 'bilinear')

        return {'image': target_image, 'uv_map': uv_maps, 'mask': masks, 'pids': pids, 'image_path': image_path}

    def __len__(self):
        """Return the total number of images."""
        return len(self.image_paths)

    def get_params(self, do_jitter=False, jitter_rate=0.75):
        """Get transformation parameters."""
        if do_jitter:
            if np.random.uniform() > jitter_rate or self.opt.do_upsampling:
                scale = 1.
            else:
                scale = np.random.uniform(1, 1.25)
            jitter_size = (scale * np.array([self.opt.height, self.opt.width])).astype(int)
            start1 = np.random.randint(jitter_size[0] - self.opt.height + 1)
            start2 = np.random.randint(jitter_size[1] - self.opt.width + 1)
        else:
            jitter_size = np.array([self.opt.height, self.opt.width])
            start1 = 0
            start2 = 0
        crop_pos = np.array([start1, start2])
        crop_size = np.array([self.opt.height, self.opt.width])
        return {'jitter size': jitter_size, 'crop pos': crop_pos, 'crop size': crop_size}

    def apply_transform(self, data, params, interp_mode='bilinear'):
        """Apply the transform to the data tensor."""
        tensor_size = params['jitter size'].tolist()
        crop_pos = params['crop pos']
        crop_size = params['crop size']
        data = F.interpolate(data.unsqueeze(0), size=tensor_size, mode=interp_mode).squeeze(0)
        data = data[:, crop_pos[0]:crop_pos[0] + crop_size[0], crop_pos[1]:crop_pos[1] + crop_size[1]]
        return data

    def init_homographies(self, homography_path, n_images):
        """Read homography file and set up homography data."""
        with open(homography_path) as f:
            h_data = f.readlines()
        h_scale = h_data[0].rstrip().split(' ')
        self.h_scale_x = int(h_scale[1])
        self.h_scale_y = int(h_scale[2])
        h_bounds = h_data[1].rstrip().split(' ')
        self.h_bounds_x = [float(h_bounds[1]), float(h_bounds[2])]
        self.h_bounds_y = [float(h_bounds[3]), float(h_bounds[4])]
        homographies = h_data[2:2 + n_images]
        homographies = [torch.from_numpy(np.array(line.rstrip().split(' ')).astype(np.float32).reshape(3, 3))
                        for line in homographies]
        self.homographies = homographies

    def load_and_process_image(self, im_path):
        """Read image file and return as tensor in range [-1, 1]."""
        image = Image.open(im_path).convert('RGB')
        image = transforms.ToTensor()(image)
        image = 2 * image - 1
        return image

    def load_and_process_iuv(self, iuv_path, i):
        """Read IUV file and convert to network inputs."""
        iuv_map = Image.open(iuv_path).convert('RGBA')
        iuv_map = transforms.ToTensor()(iuv_map)
        uv_map, mask, pids = self.iuv2input(iuv_map, i)
        return uv_map, mask, pids

    def iuv2input(self, iuv, index):
        """Create network inputs from IUV.
        Parameters:
            iuv - - a tensor of shape [4, H, W], where the channels are: body part ID, U, V, person ID.
            index - - index of iuv

        Returns:
            uv (tensor) - - a UV map for a single layer, ready to pass to grid sampler (values in range [-1,1])
            mask (tensor) - - the corresponding mask
            person_id (tensor) - - the person IDs

        grid sampler indexes into texture map of size tile_width x tile_width*n_textures
        """
        # Extract body part and person IDs.
        part_id = (iuv[0] * 255 / 10).round()
        part_id[part_id > 24] = 24
        part_id_mask = (part_id > 0).float()
        person_id = (255 - 255 * iuv[-1]).round()  # person ID is saved as 255 - person_id
        person_id *= part_id_mask  # background id is 0
        maxId = self.opt.n_textures // 24
        person_id[person_id>maxId] = maxId

        # Convert body part ID to texture map ID.
        # Essentially, each of the 24 body parts for each person, plus the background have their own texture 'tile'
        # The tiles are concatenated horizontally to create the texture map.
        tex_id = part_id + part_id_mask * 24 * (person_id - 1)

        uv = iuv[1:3]
        # Convert the per-body-part UVs to UVs that correspond to the full texture map.
        uv[0] += tex_id

        # Get the mask.
        bg_mask = (tex_id == 0).float()
        mask = 1.0 - bg_mask
        mask = mask * 2 - 1  # make 1 the foreground and -1 the background mask
        mask = self.mask2trimap(mask)

        # Composite background UV behind person UV.
        h, w = iuv.shape[1:]
        bg_uv = self.get_background_uv(index, w, h)
        uv = bg_mask * bg_uv + (1 - bg_mask) * uv

        # Map to [-1, 1] range.
        uv[0] /= self.opt.n_textures
        uv = uv * 2 - 1
        uv = torch.clamp(uv, -1, 1)

        return uv, mask, person_id

    def get_background_inputs(self, index, w, h):
        """Return data for background layer at 'index'."""
        uv = self.get_background_uv(index, w, h)
        # normalize to correct range, of full texture atlas
        uv[0] /= self.opt.n_textures
        uv = uv * 2 - 1  # [0,1] -> [-1,1]
        uv = torch.clamp(uv, -1, 1)

        mask = -torch.ones(*uv.shape[1:])
        pids = torch.zeros(*uv.shape[1:])
        return uv, mask, pids

    def get_background_uv(self, index, w, h):
        """Return background layer UVs at 'index' (output range [0, 1])."""
        ramp_u = torch.linspace(0, 1, steps=w).unsqueeze(0).repeat(h, 1)
        ramp_v = torch.linspace(0, 1, steps=h).unsqueeze(-1).repeat(1, w)
        ramp = torch.stack([ramp_u, ramp_v], 0)
        if hasattr(self, 'homographies'):
            # scale to [0, orig width/height]
            ramp[0] *= self.h_scale_x
            ramp[1] *= self.h_scale_y
            # apply homography
            ramp = ramp.reshape(2, -1)  # [2, H*W]
            H = self.homographies[index]
            [xt, yt] = self.transform2h(ramp[0], ramp[1], torch.inverse(H))
            # scale from world to [0,1]
            xt -= self.h_bounds_x[0]
            xt /= (self.h_bounds_x[1] - self.h_bounds_x[0])
            yt -= self.h_bounds_y[0]
            yt /= (self.h_bounds_y[1] - self.h_bounds_y[0])
            # restore shape
            ramp = torch.stack([xt.reshape(h, w), yt.reshape(h, w)], 0)
        return ramp

    def transform2h(self, x, y, m):
        """Applies 2d homogeneous transformation."""
        A = torch.matmul(m, torch.stack([x, y, torch.ones(len(x))]))
        xt = A[0, :] / A[2, :]
        yt = A[1, :] / A[2, :]
        return xt, yt

    def mask2trimap(self, mask):
        """Convert binary mask to trimap with values in [-1, 0, 1]."""
        fg_mask = (mask > 0).float()
        bg_mask = (mask < 0).float()
        trimap_width = getattr(self.opt, 'trimap_width', 20)
        trimap_width *= bg_mask.shape[-1] / self.opt.width
        trimap_width = int(trimap_width)
        bg_mask = cv2.erode(bg_mask.numpy(), kernel=np.ones((trimap_width, trimap_width)), iterations=1)
        bg_mask = torch.from_numpy(bg_mask)
        mask = fg_mask - bg_mask
        return mask
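The per-pixel arithmetic in `iuv2input` can be traced on a single pixel. The following is a plain-Python re-derivation, not code from the repository. `decode_iuv_pixel` and its `bg_uv` argument are hypothetical names, and the background UV is passed in directly rather than computed by `get_background_uv`. It assumes 8-bit channel values as stored in the IUV PNGs: part ID × 10 in the first channel and 255 − person ID in the alpha channel.

```python
def _clamp(x, lo=-1.0, hi=1.0):
    return max(lo, min(hi, x))


def decode_iuv_pixel(i_byte, u_byte, v_byte, a_byte, n_textures=25, bg_uv=(0.0, 0.0)):
    """Map one 8-bit IUV pixel to a texture-atlas UV in [-1, 1] (hypothetical helper).

    Mirrors the arithmetic of iuv2input for a single pixel. bg_uv is the
    background UV in [0, 1] that get_background_uv would otherwise supply.
    """
    part_id = min(round(i_byte / 10), 24)       # body part 1..24, 0 = background
    if part_id == 0:
        tex_id = 0                               # tile 0 holds the background texture
        u, v = bg_uv
    else:
        max_id = n_textures // 24                # largest person ID the atlas can hold
        person_id = min(255 - a_byte, max_id)    # person ID is stored as 255 - id
        tex_id = part_id + 24 * (person_id - 1)  # tiles are concatenated horizontally
        u, v = u_byte / 255, v_byte / 255
    u_atlas = (u + tex_id) / n_textures          # shift U into this pixel's tile
    # Map [0, 1] -> [-1, 1] for the grid sampler and clamp, as the original does.
    return _clamp(u_atlas * 2 - 1), _clamp(v * 2 - 1), tex_id
```

Before the final mapping, tile `k` occupies the atlas range `[k / n_textures, (k + 1) / n_textures]`. That layout is why a single UV map can address every body part of every person, plus the background.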

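`get_background_uv` proceeds in four steps:

1. Build a pixel ramp.
2. Scale it by `h_scale` (the first line of `homographies.txt`).
3. Warp it by the inverse of the frame's homography.
4. Normalize by the world-space `h_bounds` (the second line of the file).

A NumPy sketch of the same steps for a single pixel follows. The function name and the example matrices and bounds are made up for illustration. As in `transform2h`, the homography is assumed to act on column vectors `[x, y, 1]`.

```python
import numpy as np


def background_uv(px, py, w, h, H, scale_xy, bounds_x, bounds_y):
    """Background UV in [0, 1] for pixel (px, py) of a w x h frame (hypothetical helper)."""
    # Pixel -> [0, 1] ramp, matching torch.linspace(0, 1, steps=w) / steps=h.
    u = px / (w - 1)
    v = py / (h - 1)
    # Scale to the resolution the homographies were estimated at.
    x, y = u * scale_xy[0], v * scale_xy[1]
    # Apply the inverse homography in homogeneous coordinates, then dehomogenize.
    xt, yt, wt = np.linalg.inv(H) @ np.array([x, y, 1.0])
    xt, yt = xt / wt, yt / wt
    # Normalize world coordinates by the background bounds.
    xt = (xt - bounds_x[0]) / (bounds_x[1] - bounds_x[0])
    yt = (yt - bounds_y[0]) / (bounds_y[1] - bounds_y[0])
    return float(xt), float(yt)
```

When a homography translates a frame by +10 in x, the inverse shifts that frame's ramp by −10. This warp is what keeps the background texture fixed in world space while the camera moves.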

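`mask2trimap` erodes the background mask, so a band of width roughly `trimap_width` around each person becomes 0 (unknown). `MaskLoss` then excludes that band. Because `cv2` may not be installed, this sketch swaps in a small NumPy min-filter for `cv2.erode`. It assumes an odd kernel width and pads with 1, so pixels outside the image never erode the mask; this is intended to mirror `cv2.erode`'s default border handling. The rescaling of the width by `shape[-1] / opt.width` is omitted.

```python
import numpy as np


def erode(mask, k):
    """Binary erosion with a k x k square kernel (k odd); a stand-in for cv2.erode."""
    r = k // 2
    # Pad with 1 so out-of-image pixels never erode the mask.
    padded = np.pad(mask, r, mode="constant", constant_values=1)
    h, w = mask.shape
    out = np.ones_like(mask)
    for dy in range(k):
        for dx in range(k):
            out = np.minimum(out, padded[dy:dy + h, dx:dx + w])
    return out


def mask2trimap(mask, k):
    """mask in {-1, 1} -> trimap in {-1, 0, 1} (1 = foreground, -1 = sure background, 0 = unknown)."""
    fg = (mask > 0).astype(np.float32)
    bg = (mask < 0).astype(np.float32)
    return fg - erode(bg, k)
```

On a single row `[-1, -1, -1, -1, 1, 1, -1, -1, -1]` with `k = 3`, the two background pixels that touch the person become 0. All other pixels keep their label.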
================================================
FILE: datasets/download_data.sh
================================================
#!/bin/bash
#
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

NAME=$1

if [[ $NAME != "cartwheel" && $NAME != "reflection" && $NAME != "splash" && $NAME != "trampoline" && $NAME != "all" ]]; then
    echo "Available options are: cartwheel, reflection, splash, trampoline, or all"
    exit 1
fi

if [[ $NAME == "all" ]]; then
  declare -a NAMES=("cartwheel" "reflection" "splash" "trampoline")
else
  declare -a NAMES=($NAME)
fi

for NAME in "${NAMES[@]}"
do
  echo "Specified [$NAME]"
  URL=https://www.robots.ox.ac.uk/~erika/retiming/data/$NAME.zip
  ZIP_FILE=./datasets/$NAME.zip
  TARGET_DIR=./datasets/$NAME/
  wget -N $URL -O $ZIP_FILE
  mkdir -p $TARGET_DIR
  unzip $ZIP_FILE -d ./datasets/
  rm $ZIP_FILE
done

================================================
FILE: datasets/iuv_crop2full.py
================================================
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Convert UV crops to full UV maps."""
import os
import sys
import json
from PIL import Image
import numpy as np


def place_crop(crop, image, center_x, center_y):
    """Place the crop in the image at the specified location."""
    im_height, im_width = image.shape[:2]
    crop_height, crop_width = crop.shape[:2]

    left = center_x - crop_width // 2
    right = left + crop_width
    top = center_y - crop_height // 2
    bottom = top + crop_height

    adjusted_crop = crop  # remove regions of crop that go beyond image bounds
    if left < 0:
        adjusted_crop = adjusted_crop[:, -left:]
    if right > im_width:
        adjusted_crop = adjusted_crop[:, :(im_width - right)]
    if top < 0:
        adjusted_crop = adjusted_crop[-top:]
    if bottom > im_height:
        adjusted_crop = adjusted_crop[:(im_height - bottom)]
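    # 1 where any channel of the crop is set; summing channels would exceed 1 and corrupt overlapping crops.
    crop_mask = (adjusted_crop > 0).any(-1, keepdims=True).astype(image.dtype)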
    image[max(0, top):min(im_height, bottom), max(0, left):min(im_width, right)] *= (1 - crop_mask)
    image[max(0, top):min(im_height, bottom), max(0, left):min(im_width, right)] += adjusted_crop

    return image

def crop2full(keypoints_path, metadata_path, uvdir, outdir):
    """Create each frame's layer UVs from predicted UV crops."""
    with open(keypoints_path) as f:
        kp_data = json.load(f)

    # Get all people ids
    people_ids = set()
    for frame in kp_data:
        for skeleton in kp_data[frame]:
            people_ids.add(skeleton['idx'])
    people_ids = sorted(list(people_ids))

    with open(metadata_path) as f:
        metadata = json.load(f)

    orig_size = np.array(metadata['alphapose_input_size'][::-1])
    out_size = np.array(metadata['size_LR'][::-1])

    if 'people_layers' in metadata:
        people_layers = metadata['people_layers']
    else:
        people_layers = [[pid] for pid in people_ids]

    # Create output directories.
    for layer_i in range(1, 1 + len(people_layers)):
        os.makedirs(os.path.join(outdir, f'{layer_i:02d}'), exist_ok=True)
    print(f'Writing UVs to {outdir}')

    for frame in sorted(kp_data):
        for layer_i, layer in enumerate(people_layers, 1):
            out_path = os.path.join(outdir, f'{layer_i:02d}', frame)
            sys.stdout.write('processing frame %s\r' % out_path)
            sys.stdout.flush()
            uv_map = np.zeros([out_size[0], out_size[1], 4])
            for person_id in layer:
                matches = [p for p in kp_data[frame] if p['idx'] == person_id]
                if len(matches) == 0:  # person doesn't appear in this frame
                    continue
                skeleton = matches[0]
                kps = np.array(skeleton['keypoints']).reshape(17, 3)
                # Get kps bounding box.
                left = kps[:, 0].min()
                right = kps[:, 0].max()
                top = kps[:, 1].min()
                bottom = kps[:, 1].max()
                height = bottom - top
                width = right - left
                orig_crop_size = max(height, width)
                orig_center_x = (left + right) // 2
                orig_center_y = (top + bottom) // 2

                # read predicted uv map
                uv_crop_path = os.path.join(uvdir, f'{person_id:02d}_{os.path.basename(out_path)[:-4]}_output_uv.png')
                if os.path.exists(uv_crop_path):
                    uv_crop = np.array(Image.open(uv_crop_path))
                else:
                    uv_crop = np.zeros([256, 256, 3])

                # add person ID channel
                person_mask = (uv_crop[..., 0:1] > 0).astype('uint8')
                person_ids = (255 - person_id) * person_mask
                uv_crop = np.concatenate([uv_crop, person_ids], -1)

                # scale crop to desired output size
                # 256 is the crop size, 192 is the inner crop size
                out_crop_size = orig_crop_size * 256./192 * out_size / orig_size
                out_crop_size = out_crop_size.astype(int)
                uv_crop = uv_crop.astype(np.uint8)
                uv_crop = np.array(Image.fromarray(uv_crop).resize((out_crop_size[1], out_crop_size[0]), resample=Image.NEAREST))

                # scale center coordinate accordingly
                out_center_x = (orig_center_x * out_size[1] / orig_size[1]).astype(int)
                out_center_y = (orig_center_y * out_size[0] / orig_size[0]).astype(int)

                # Place UV crop in full UV map and save.
                uv_map = place_crop(uv_crop, uv_map, out_center_x, out_center_y)
            uv_map = Image.fromarray(uv_map.astype('uint8'))
            uv_map.save(out_path)


if __name__ == "__main__":
    import argparse
    arguments = argparse.ArgumentParser()
    arguments.add_argument('--dataroot', type=str)
    opt = arguments.parse_args()

    keypoints_path = os.path.join(opt.dataroot, 'keypoints.json')
    metadata_path = os.path.join(opt.dataroot, 'metadata.json')
    uvdir = os.path.join(opt.dataroot, 'kp2uv/test_latest/images')
    outdir = os.path.join(opt.dataroot, 'iuv')
    crop2full(keypoints_path, metadata_path, uvdir, outdir)
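`place_crop` centers a crop at `(center_x, center_y)` and trims whatever falls outside the frame before pasting. The index arithmetic is easy to get wrong, so the plain-Python sketch below computes only the destination and source slices. `crop_slices` is a hypothetical helper, not part of the repository.

```python
def crop_slices(im_h, im_w, crop_h, crop_w, cx, cy):
    """Return (dst_rows, dst_cols, src_rows, src_cols) slices for a crop centered at (cx, cy)."""
    left = cx - crop_w // 2
    top = cy - crop_h // 2
    right, bottom = left + crop_w, top + crop_h
    # Destination: the part of the crop's footprint that lies inside the image.
    dst_rows = slice(max(0, top), min(im_h, bottom))
    dst_cols = slice(max(0, left), min(im_w, right))
    # Source: drop as many crop rows/cols as were clipped on each side.
    src_rows = slice(max(0, -top), crop_h - max(0, bottom - im_h))
    src_cols = slice(max(0, -left), crop_w - max(0, right - im_w))
    return dst_rows, dst_cols, src_rows, src_cols
```

The source slices match the sequential trimming in `place_crop`. For example, `adjusted_crop[:, -left:]` followed by `[:, :(im_width - right)]` removes the same columns as a single `[-left : crop_w - (right - im_width)]` slice.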

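The resize step in `crop2full` undoes the kp2uv crop framing. The keypoint bounding box's longer side (`orig_crop_size`) fills the inner 192 px of each 256 px crop, so the whole crop spans `orig_crop_size * 256/192` pixels of the AlphaPose input. That length is then rescaled per axis to `size_LR`. Below is a worked example with made-up frame sizes; sizes are `(height, width)`, mirroring the `[::-1]` flip applied to the metadata values.

```python
def output_crop_size(orig_crop_size, orig_hw, out_hw):
    """(h, w) in low-res pixels that a 256x256 UV crop is resized to (hypothetical helper)."""
    full_crop = orig_crop_size * 256.0 / 192.0  # keypoint box fills the inner 192 of 256 px
    # Rescale per axis, then truncate to int as astype(int) does.
    return tuple(int(full_crop * o / i) for o, i in zip(out_hw, orig_hw))
```

For a 300 px keypoint box in a 1280×720 AlphaPose frame and a 448×256 low-res frame, the full crop is 400 px. That becomes `400 * 256 / 720 ≈ 142` rows by `400 * 448 / 1280 = 140` columns. The resized crop is therefore not square when the aspect ratio changes.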

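`crop2full` pairs each person in each frame with a kp2uv output by file name, and it writes one directory per layer. By default each tracked person gets their own layer. A `people_layers` entry in `metadata.json` (for example `[[1, 2], [3]]`) can group several people into one layer. The plain-Python sketch below reproduces only that bookkeeping. The helper names are hypothetical, and the frame names and IDs are illustrative.

```python
import os


def default_people_layers(kp_data):
    """One layer per tracked person, ordered by AlphaPose person index."""
    ids = sorted({skel["idx"] for skels in kp_data.values() for skel in skels})
    return [[pid] for pid in ids]


def layer_paths(frame, people_layers, uvdir, outdir):
    """Map each layer's output path to the kp2uv crop paths composited into it."""
    stem = os.path.basename(frame)[:-4]  # strip a 4-character extension such as ".png"
    paths = {}
    for layer_i, layer in enumerate(people_layers, 1):
        out_path = os.path.join(outdir, f"{layer_i:02d}", frame)
        paths[out_path] = [os.path.join(uvdir, f"{pid:02d}_{stem}_output_uv.png") for pid in layer]
    return paths
```

If a crop file is missing, `crop2full` substitutes an all-zero crop. A person absent from a frame's keypoints is skipped for that frame.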
================================================
FILE: datasets/prepare_iuv.sh
================================================
#!/bin/bash
#
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

DATA_PATH=$1
# Predict UVs from keypoints and save the crops.
python run_kp2uv.py --model kp2uv --dataroot $DATA_PATH --results_dir $DATA_PATH
# Convert the cropped UVs to full UV maps.
python datasets/iuv_crop2full.py --dataroot $DATA_PATH


================================================
FILE: docs/contributing.md
================================================
# How to Contribute

We'd love to accept your patches and contributions to this project. There are
just a few small guidelines you need to follow.

## Contributor License Agreement

Contributions to this project must be accompanied by a Contributor License
Agreement (CLA). You (or your employer) retain the copyright to your
contribution; this simply gives us permission to use and redistribute your
contributions as part of the project. Head over to
<https://cla.developers.google.com/> to see your current agreements on file or
to sign a new one.

You generally only need to submit a CLA once, so if you've already submitted one
(even if it was for a different project), you probably don't need to do it
again.

## Code reviews

All submissions, including submissions by project members, require review. We
use GitHub pull requests for this purpose. Consult
[GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
information on using pull requests.

## Community Guidelines

This project follows
[Google's Open Source Community Guidelines](https://opensource.google/conduct/).

================================================
FILE: docs/data.md
================================================
### Data
The data directory for a video is structured as follows:
```
video_name/
|-- rgb_256/
|   |-- 0001.png, etc.
|-- rgb_512/
|   |-- 0001.png, etc.
|-- mask/ (optional)
|   |-- 01/, etc.
|   |   |-- 0001.png, etc.
|-- keypoints.json
|-- metadata.json
|-- homographies.txt (optional)
```
- `metadata.json` contains a dictionary:
```
'alphapose_input_size': [width, height]  # size of frames input to AlphaPose
'size_LR': [width, height]  # size of low-resolution frames (multiple of 16; height should be 256)
'n_textures': int  # number of texture maps required, calculated by 24*num_people + 1
'composite_order': [[1, 2, 3], [1, 3, 2], ... ]  # optional per-frame back-to-front layer compositing order
```
- `keypoints.json` is in the format output by the [AlphaPose Pose Tracker](https://github.com/MVIG-SJTU/AlphaPose).
See [here](https://github.com/MVIG-SJTU/AlphaPose/tree/master/trackers/PoseFlow) for details.
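The `metadata.json` constraints listed above can be checked before training. Here is a minimal sketch. It assumes the file has already been parsed into a dict and that the number of people is known. `check_metadata` is a hypothetical helper, and the example values are illustrative.

```python
def check_metadata(meta, num_people):
    """Return a list of problems found in a parsed metadata.json dict (hypothetical helper)."""
    problems = []
    for key in ("alphapose_input_size", "size_LR", "n_textures"):
        if key not in meta:
            problems.append(f"missing key: {key}")
    if problems:
        return problems
    width, height = meta["size_LR"]  # stored as [width, height]
    if width % 16 or height % 16:
        problems.append("size_LR must be a multiple of 16")
    if height != 256:
        problems.append("size_LR height should be 256")
    expected = 24 * num_people + 1   # 24 body-part textures per person + 1 background
    if meta["n_textures"] != expected:
        problems.append(f"n_textures should be {expected}")
    return problems
```

For two people, `{"alphapose_input_size": [1280, 720], "size_LR": [448, 256], "n_textures": 49}` passes all checks.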

================================================
FILE: environment.yml
================================================
name: retiming
channels:
  - pytorch
  - defaults
dependencies:
- python=3.8
- pytorch=1.4.0
- pip:
  - dominate==2.4.0
  - torchvision==0.5.0
  - Pillow>=6.1.0
  - numpy==1.19.2
  - visdom==0.1.8
  - opencv-python>=4.2.0
  - matplotlib



================================================
FILE: models/__init__.py
================================================


================================================
FILE: models/kp2uv_model.py
================================================
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from third_party.models.base_model import BaseModel
from . import networks


class Kp2uvModel(BaseModel):
    """This class implements the keypoint-to-UV model (inference only)."""
    @staticmethod
    def modify_commandline_options(parser, is_train=True):
        parser.set_defaults(dataset_mode='kpuv')
        return parser

    def __init__(self, opt):
        """Initialize this model class.

        Parameters:
            opt -- test options
        """
        BaseModel.__init__(self, opt)
        self.visual_names = ['keypoints', 'output_uv']
        self.model_names = ['Kp2uv']
        self.netKp2uv = networks.define_kp2uv(gpu_ids=self.gpu_ids)
        self.isTrain = False  # only test mode supported

        # Our program will automatically call <model.setup> to define schedulers, load networks, and print networks

    def set_input(self, input):
        """Unpack input data from the dataloader.

        Parameters:
            input: a dictionary that contains the data itself and its metadata information.
        """
        self.keypoints = input['keypoints'].to(self.device)
        self.image_paths = input['path']

    def forward(self):
        """Run forward pass. This will be called by <test>."""
        output = self.netKp2uv.forward(self.keypoints)
        self.output_uv = self.output2rgb(output)

    def output2rgb(self, output):
        """Convert network outputs to RGB image."""
        pred_id, pred_uv = output
        _, pred_id_class = pred_id.max(1)
        pred_id_class = pred_id_class.unsqueeze(1)
        # extract UV from pred_uv (48 channels); select based on class ID
        selected_uv = -1 * torch.ones(pred_uv.shape[0], 2, pred_uv.shape[2], pred_uv.shape[3], device=pred_uv.device)
        for partid in range(1, 25):
            mask = (pred_id_class == partid).float()
            selected_uv *= (1. - mask)
            selected_uv += mask * pred_uv[:, (partid - 1) * 2:(partid - 1) * 2 + 2]
        pred_uv = selected_uv
        rgb = torch.cat([pred_id_class.float() * 10 / 255. * 2 - 1, pred_uv], 1)
        return rgb

    def optimize_parameters(self):
        pass
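`output2rgb` merges the two network heads into one 3-channel image. For each pixel:

- The argmax over the part-ID logits picks a body part.
- That part's (U, V) pair is taken from the 48-channel UV head, which holds two channels per part.
- Background pixels keep UV = −1.
- The part ID is written as `id * 10 / 255` mapped to [−1, 1]; this is the encoding `iuv2input` later decodes with `* 255 / 10`.

A NumPy sketch of the same selection follows. It uses fancy indexing in place of the per-part loop and drops the batch dimension. It assumes 25 ID channels (background plus 24 parts), and the inputs in the test are synthetic.

```python
import numpy as np


def output2rgb(pred_id, pred_uv):
    """pred_id: [25, H, W] logits; pred_uv: [48, H, W] -> [3, H, W] in [-1, 1]."""
    part = pred_id.argmax(0)                      # [H, W], 0 = background
    uv = -np.ones((2,) + part.shape, dtype=pred_uv.dtype)
    fg = part > 0
    rows, cols = np.nonzero(fg)                   # row-major, same order as part[fg]
    ch = (part[fg] - 1) * 2                       # first UV channel of each pixel's part
    uv[0, rows, cols] = pred_uv[ch, rows, cols]
    uv[1, rows, cols] = pred_uv[ch + 1, rows, cols]
    # Part ID encoded so that the saved byte value is 10 * part.
    id_channel = part[None].astype(pred_uv.dtype) * 10 / 255.0 * 2 - 1
    return np.concatenate([id_channel, uv], 0)
```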


================================================
FILE: models/lnr_model.py
================================================
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from third_party.models.base_model import BaseModel
from . import networks
import numpy as np
import torch.nn.functional as F


class LnrModel(BaseModel):
    """This class implements the layered neural rendering model for decomposing a video into layers."""
    @staticmethod
    def modify_commandline_options(parser, is_train=True):
        parser.set_defaults(dataset_mode='layered_video')
        parser.add_argument('--texture_res', type=int, default=16, help='texture resolution')
        parser.add_argument('--texture_channels', type=int, default=16, help='# channels for neural texture')
        parser.add_argument('--n_textures', type=int, default=25, help='# individual texture maps, 24 per person (1 per body part) + 1 for background')
        if is_train:
            parser.add_argument('--lambda_alpha_l1', type=float, default=0.01, help='alpha L1 sparsity loss weight')
            parser.add_argument('--lambda_alpha_l0', type=float, default=0.005, help='alpha L0 sparsity loss weight')
            parser.add_argument('--alpha_l1_rolloff_epoch', type=int, default=200, help='turn off L1 alpha sparsity loss weight after this epoch')
            parser.add_argument('--lambda_mask', type=float, default=50, help='layer matting loss weight')
            parser.add_argument('--mask_thresh', type=float, default=0.02, help='turn off masking loss when error falls below this value')
            parser.add_argument('--mask_loss_rolloff_epoch', type=int, default=-1, help='decrease masking loss after this epoch; if <0, use mask_thresh instead')
            parser.add_argument('--n_epochs_upsample', type=int, default=500,
                                help='number of epochs to train the upsampling module')
            parser.add_argument('--batch_size_upsample', type=int, default=16, help='batch size for upsampling')
            parser.add_argument('--jitter_rgb', type=float, default=0.2, help='amount of jitter to add to RGB')
            parser.add_argument('--jitter_epochs', type=int, default=400, help='number of epochs to jitter RGB')
        parser.add_argument('--do_upsampling', action='store_true', help='whether to use upsampling module')

        return parser

    def __init__(self, opt):
        """Initialize this model class.

        Parameters:
            opt -- training/test options
        """
        BaseModel.__init__(self, opt)
        # specify the images you want to save and display. The program will call base_model.get_current_visuals to save and display these images.
        self.visual_names = ['target_image', 'reconstruction', 'rgba_vis', 'alpha_vis', 'input_vis']
        self.model_names = ['LNR']
        self.netLNR = networks.define_LNR(opt.num_filters, opt.texture_channels, opt.texture_res, opt.n_textures, gpu_ids=self.gpu_ids)
        self.do_upsampling = opt.do_upsampling
        if self.isTrain:
            self.setup_train(opt)

        # Our program will automatically call <model.setup> to define schedulers, load networks, and print networks

    def setup_train(self, opt):
        """Setup the model for training mode."""
        print('setting up model')
        # specify the training losses you want to print out. The program will call base_model.get_current_losses to plot the losses to the console and save them to the disk.
        self.loss_names = ['total', 'recon', 'alpha_reg', 'mask']
        self.visual_names = ['target_image', 'reconstruction', 'rgba_vis', 'alpha_vis', 'input_vis']
        self.do_upsampling = opt.do_upsampling
        if not self.do_upsampling:
            self.visual_names += ['mask_vis']
        self.criterionLoss = torch.nn.L1Loss()
        self.criterionLossMask = networks.MaskLoss().to(self.device)
        self.lambda_mask = opt.lambda_mask
        self.lambda_alpha_l0 = opt.lambda_alpha_l0
        self.lambda_alpha_l1 = opt.lambda_alpha_l1
        self.mask_loss_rolloff_epoch = opt.mask_loss_rolloff_epoch
        self.jitter_rgb = opt.jitter_rgb
        self.optimizer = torch.optim.Adam(self.netLNR.parameters(), lr=opt.lr)
        self.optimizers = [self.optimizer]

    def set_input(self, input):
        """Unpack input data from the dataloader and perform necessary pre-processing steps.

        Parameters:
            input: a dictionary that contains the data itself and its metadata information.
        """
        self.target_image = input['image'].to(self.device)
        if self.isTrain and self.jitter_rgb > 0:
            # add brightness jitter to rgb
            self.target_image += self.jitter_rgb * torch.randn(self.target_image.shape[0], 1, 1, 1).to(self.device)
            self.target_image = torch.clamp(self.target_image, -1, 1)
        self.input_uv = input['uv_map'].to(self.device)
        self.input_id = input['pids'].to(self.device)
        self.mask = input['mask'].to(self.device)
        self.image_paths = input['image_path']

    def gen_crop_params(self, orig_h, orig_w, crop_size=256):
        """Generate random square cropping parameters."""
        starty = np.random.randint(orig_h - crop_size + 1)
        startx = np.random.randint(orig_w - crop_size + 1)
        endy = starty + crop_size
        endx = startx + crop_size
        return starty, endy, startx, endx

    def forward(self):
        """Run forward pass. This will be called by both functions <optimize_parameters> and <test>."""
        if self.do_upsampling:
            input_uv_up = F.interpolate(self.input_uv, scale_factor=2, mode='bilinear')
            crop_params = None
            if self.isTrain:
                # Take random crop to decrease memory requirement.
                crop_params = self.gen_crop_params(*input_uv_up.shape[-2:])
                starty, endy, startx, endx = crop_params
                self.target_image = self.target_image[:, :, starty:endy, startx:endx]
            outputs = self.netLNR(self.input_uv, self.input_id, uv_map_upsampled=input_uv_up, crop_params=crop_params)
        else:
            outputs = self.netLNR(self.input_uv, self.input_id)
        self.reconstruction = outputs['reconstruction'][:, :3]
        self.alpha_composite = outputs['reconstruction'][:, 3]
        self.output_rgba = outputs['layers']
        n_layers = outputs['layers'].shape[1]
        layers = outputs['layers'].clone()
        layers[:, 0, -1] = 1  # Background layer's alpha is always 1
        layers = torch.cat([layers[:, l] for l in range(n_layers)], -2)
        self.alpha_vis = layers[:, 3:4]
        self.rgba_vis = layers
        self.mask_vis = torch.cat([self.mask[:, l:l+1] for l in range(n_layers)], -2)
        self.input_vis = torch.cat([self.input_uv[:, 2*l:2*l+2] for l in range(n_layers)], -2)
        self.input_vis = torch.cat([torch.zeros_like(self.input_vis[:, :1]), self.input_vis], 1)

    def backward(self):
        """Calculate losses, gradients, and update network weights; called in every training iteration"""
        self.loss_recon = self.criterionLoss(self.reconstruction[:, :3], self.target_image)
        self.loss_total = self.loss_recon
        if not self.do_upsampling:
            self.loss_alpha_reg = networks.cal_alpha_reg(self.alpha_composite * .5 + .5, self.lambda_alpha_l1, self.lambda_alpha_l0)
            alpha_layers = self.output_rgba[:, :, 3]
            self.loss_mask = self.lambda_mask * self.criterionLossMask(alpha_layers, self.mask)
            self.loss_total += self.loss_alpha_reg + self.loss_mask
        else:
            self.loss_mask = 0.
            self.loss_alpha_reg = 0.
        self.loss_total.backward()

    def optimize_parameters(self):
        """Update network weights; it will be called in every training iteration."""
        self.forward()
        self.optimizer.zero_grad()
        self.backward()
        self.optimizer.step()

    def update_lambdas(self, epoch):
        """Update loss weights based on current epochs and losses."""
        if epoch == self.opt.alpha_l1_rolloff_epoch:
            self.lambda_alpha_l1 = 0
        if self.mask_loss_rolloff_epoch >= 0:
            if epoch == 2*self.mask_loss_rolloff_epoch:
                self.lambda_mask = 0
        elif epoch > self.opt.epoch_count:
            if self.loss_mask < self.opt.mask_thresh * self.opt.lambda_mask:
                self.mask_loss_rolloff_epoch = epoch
                self.lambda_mask *= .1
        if epoch == self.opt.jitter_epochs:
            self.jitter_rgb = 0

    def transfer_detail(self):
        """Transfer detail to layers."""
        residual = self.target_image - self.reconstruction
        transmission_comp = torch.zeros_like(self.target_image[:, 0:1])
        rgba_detail = self.output_rgba
        n_layers = self.output_rgba.shape[1]
        for i in range(n_layers - 1, 0, -1):  # Don't do detail transfer for background layer, due to ghosting effects.
            transmission_i = 1. - transmission_comp
            rgba_detail[:, i, :3] += transmission_i * residual
            alpha_i = self.output_rgba[:, i, 3:4] * .5 + .5
            transmission_comp = alpha_i + (1. - alpha_i) * transmission_comp
        self.rgba = torch.clamp(rgba_detail, -1, 1)

    def get_results(self):
        """Return results. This is different from get_current_visuals, which gets visuals for monitoring training.

        Returns a dictionary:
            original - - original frame
            reconstruction - - reconstruction
            rgba_l* - - RGBA for each layer
            mask_l* - - mask for each layer
        """
        self.transfer_detail()
        # Split layers
        results = {'reconstruction': self.reconstruction, 'original': self.target_image}
        n_layers = self.rgba.shape[1]
        for i in range(n_layers):
            results[f'mask_l{i}'] = self.mask[:, i:i+1]
            results[f'rgba_l{i}'] = self.rgba[:, i]
            if i == 0:
                results[f'rgba_l{i}'][:, -1:] = 1.
        return results

    def freeze_basenet(self):
        """Freeze all parameters except for the upsampling module."""
        net = self.netLNR
        if isinstance(net, torch.nn.DataParallel):
            net = net.module
        self.set_requires_grad([net.encoder, net.decoder, net.final_rgba], False)
        net.texture.requires_grad = False
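`transfer_detail` adds the reconstruction residual back to the foreground layers, visiting them from the highest index down. Each layer receives the residual scaled by how much of it is still visible through the layers already visited (`1 - transmission_comp`). The background layer is skipped, and the alpha channels are left unchanged. Below is a per-pixel plain-Python sketch. It assumes alphas already in [0, 1]; the model stores them in [−1, 1] and converts with `* .5 + .5`. As the loop in the model does, it treats higher layer indices as nearer the camera. The values in the test are illustrative.

```python
def transfer_detail(colors, alphas, residual):
    """Distribute one pixel's residual over the foreground layers (hypothetical helper).

    colors, alphas: per-layer values, index 0 = background, highest index = front-most.
    alphas are in [0, 1]; colors and the returned values are in [-1, 1].
    """
    out = list(colors)
    covered = 0.0                                # combined opacity of the layers visited so far
    for i in range(len(colors) - 1, 0, -1):      # front to back; background (i = 0) skipped
        out[i] += (1.0 - covered) * residual     # only the visible fraction receives detail
        covered = alphas[i] + (1.0 - alphas[i]) * covered
    return [max(-1.0, min(1.0, c)) for c in out]
```

An opaque front layer absorbs the entire residual, and the layers behind it receive none.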

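The mask-loss schedule in `update_lambdas` has two modes:

- **Default, `--mask_loss_rolloff_epoch -1`.** After `epoch_count`, the first time the (weighted) mask loss drops below `mask_thresh * lambda_mask`, the weight is cut to 10%. That epoch is recorded, and the weight is zeroed at twice that epoch.
- **Non-negative `--mask_loss_rolloff_epoch`.** The weight stays at its initial value until it is zeroed at twice the given epoch.

The sketch below replays that logic on a made-up sequence of mask losses. It assumes `update_lambdas` runs once per epoch. `simulate_mask_weight` is a hypothetical name.

```python
def simulate_mask_weight(mask_losses, lambda_mask=50.0, mask_thresh=0.02,
                         rolloff_epoch=-1, epoch_count=1):
    """Mask-loss weight in effect after each epoch (hypothetical replay of update_lambdas).

    mask_losses[k] is the weighted mask loss observed at epoch epoch_count + k.
    """
    weights = []
    lam = lambda_mask
    for k, loss in enumerate(mask_losses):
        epoch = epoch_count + k
        if rolloff_epoch >= 0:
            if epoch == 2 * rolloff_epoch:
                lam = 0.0
        elif epoch > epoch_count and loss < mask_thresh * lambda_mask:
            rolloff_epoch = epoch  # remember when the loss first became small
            lam *= 0.1
        weights.append(lam)
    return weights
```

With the defaults (threshold 50 × 0.02 = 1.0), losses `[5.0, 2.0, 0.5, 0.5, 0.5, 0.5, 0.5]` produce the following weights: 50 until epoch 3, then 5 until epoch 6, and 0 from epoch 6 on.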
================================================
FILE: models/networks.py
================================================
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from third_party.models.networks import init_net


###############################################################################
# Helper Functions
###############################################################################
def define_LNR(nf=64, texture_channels=16, texture_res=16, n_textures=25, gpu_ids=[]):
    """Create a layered neural renderer.

    Parameters:
        nf (int) -- the number of channels in the first/last conv layers
        texture_channels (int) -- the number of channels in the neural texture
        texture_res (int) -- the size of each individual texture map
        n_textures (int) -- the number of individual texture maps
        gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2

    Returns a layered neural rendering model.
    """
    net = LayeredNeuralRenderer(nf, texture_channels, texture_res, n_textures)
    return init_net(net, gpu_ids)


def define_kp2uv(nf=64, gpu_ids=[]):
    """Create a keypoint-to-UV model.

    Parameters:
        nf (int) -- the number of channels in the first/last conv layers
        gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2

    Returns a keypoint-to-UV model.
    """
    net = kp2uv(nf)
    return init_net(net, gpu_ids)


def cal_alpha_reg(prediction, lambda_alpha_l1, lambda_alpha_l0):
    """Calculate the alpha regularization term.

    Parameters:
        prediction (tensor) -- composite of predicted alpha layers, with values in [0, 1]
        lambda_alpha_l1 (float) -- weight for the L1 regularization term
        lambda_alpha_l0 (float) -- weight for the L0 regularization term

    Returns the alpha regularization loss term.
    """
    assert prediction.max() <= 1.
    assert prediction.min() >= 0.
    loss = 0.
    if lambda_alpha_l1 > 0:
        loss += lambda_alpha_l1 * torch.mean(prediction)
    if lambda_alpha_l0 > 0:
        # Pseudo L0 loss using a squished sigmoid curve.
        l0_prediction = (torch.sigmoid(prediction * 5.0) - 0.5) * 2.0
        loss += lambda_alpha_l0 * torch.mean(l0_prediction)
    return loss
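

# Illustrative sketch, not used by the model (the helper name is hypothetical):
# a NumPy mirror of the pseudo-L0 term in cal_alpha_reg. It maps an alpha value
# a in [0, 1] to (sigmoid(5 * a) - 0.5) * 2, which is 0 at a = 0 and saturates
# near 1 (about 0.987) at a = 1. An alpha of 0.5 already costs about 0.85, versus
# 0.5 under the L1 term, so this steep penalty presumably favors sparse mattes in
# which most pixels are exactly transparent.
import numpy as np


def _pseudo_l0_numpy(alpha):
    """Hypothetical helper: squished-sigmoid pseudo-L0 penalty, per element."""
    alpha = np.asarray(alpha, dtype=np.float64)
    sigmoid = 1.0 / (1.0 + np.exp(-5.0 * alpha))
    return (sigmoid - 0.5) * 2.0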


##############################################################################
# Classes
##############################################################################
class MaskLoss(nn.Module):
    """Define the loss which encourages the predicted alpha matte to match the mask (trimap)."""

    def __init__(self):
        super(MaskLoss, self).__init__()
        self.loss = nn.L1Loss(reduction='none')

    def forward(self, prediction, target):
        """Calculate loss given predicted alpha matte and trimap.

        Positive and negative regions are balanced, and the 'unknown' region is excluded from the loss.
        Positive target values mark the foreground, negative values mark the background,
        and zeros mark the unknown region.

        Parameters:
            prediction (tensor) -- predicted alpha
            target (tensor) -- trimap

        Returns: the computed loss
        """
        mask_err = self.loss(prediction, target)
        pos_mask = F.relu(target)
        neg_mask = F.relu(-target)
        pos_mask_loss = (pos_mask * mask_err).sum() / (1 + pos_mask.sum())
        neg_mask_loss = (neg_mask * mask_err).sum() / (1 + neg_mask.sum())
        loss = .5 * (pos_mask_loss + neg_mask_loss)
        return loss
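

# Illustrative sketch, not used by the model (the helper name is hypothetical):
# a NumPy mirror of MaskLoss. Positive trimap entries mark foreground, negative
# entries mark background, and zeros mark the excluded 'unknown' region. The L1
# error is averaged separately over each region (the +1 in the denominator avoids
# division by zero), so a small foreground is not swamped by a large background.
# We assume alpha is expressed in [-1, 1], like the tanh RGBA output of
# LayeredNeuralRenderer.
import numpy as np


def _mask_loss_numpy(prediction, target):
    """Hypothetical helper: balanced L1 loss between alpha and a trimap."""
    prediction = np.asarray(prediction, dtype=np.float64)
    target = np.asarray(target, dtype=np.float64)
    err = np.abs(prediction - target)
    pos = np.maximum(target, 0.0)   # equivalent to F.relu(target)
    neg = np.maximum(-target, 0.0)  # equivalent to F.relu(-target)
    pos_loss = (pos * err).sum() / (1.0 + pos.sum())
    neg_loss = (neg * err).sum() / (1.0 + neg.sum())
    return 0.5 * (pos_loss + neg_loss)

# Example: a fully opaque prediction (+1) on the trimap [1, -1, 0] has error 0 on
# the foreground pixel and 2 on the background pixel; the unknown pixel is ignored,
# so _mask_loss_numpy([1, 1, 1], [1, -1, 0]) == 0.5 * (0 / 2 + 2 / 2) == 0.5.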


class ConvBlock(nn.Module):
    """Helper module consisting of a convolution, optional normalization and activation, with padding='same'."""

    def __init__(self, conv, in_channels, out_channels, ksize=4, stride=1, dil=1, norm=None, activation='relu'):
        """Create a conv block.

        Parameters:
            conv (convolutional layer) -- the type of conv layer, e.g. Conv2d, ConvTranspose2d
            in_channels (int) -- the number of input channels
            out_channels (int) -- the number of output channels
            ksize (int) -- the kernel size
            stride (int) -- stride
            dil (int) -- dilation
            norm (norm layer) -- the type of normalization layer, e.g. BatchNorm2d, InstanceNorm2d
            activation (str) -- the type of activation: relu | leaky | tanh | none
        """
        super(ConvBlock, self).__init__()
        self.k = ksize
        self.s = stride
        self.d = dil
        self.conv = conv(in_channels, out_channels, ksize, stride=stride, dilation=dil)

        if norm is not None:
            self.norm = norm(out_channels)
        else:
            self.norm = None

        if activation == 'leaky':
            self.activation = nn.LeakyReLU(0.2)
        elif activation == 'relu':
            self.activation = nn.ReLU()
        elif activation == 'tanh':
            self.activation = nn.Tanh()
        else:
            self.activation = None

    def forward(self, x):
        """Forward pass. Compute necessary padding and cropping because pytorch doesn't have pad=same."""
        height, width = x.shape[-2:]
        if isinstance(self.conv, nn.modules.ConvTranspose2d):
            desired_height = height * self.s
            desired_width = width * self.s
            pady = 0
            padx = 0
        else:
            # o = [i + 2*p - k - (k-1)*(d-1)]/s + 1
            # padding = .5 * (stride * (output-1) + (k-1)(d-1) + k - input)
            desired_height = height // self.s
            desired_width = width // self.s
            pady = .5 * (self.s * (desired_height - 1) + (self.k - 1) * (self.d - 1) + self.k - height)
            padx = .5 * (self.s * (desired_width - 1) + (self.k - 1) * (self.d - 1) + self.k - width)
        x = F.pad(x, [int(np.floor(padx)), int(np.ceil(padx)), int(np.floor(pady)), int(np.ceil(pady))])
        x = self.conv(x)
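        if x.shape[-2] != desired_height or x.shape[-1] != desired_width:
            cropy = x.shape[-2] - desired_height
            cropx = x.shape[-1] - desired_width
            top = int(np.floor(cropy / 2.))
            left = int(np.floor(cropx / 2.))
            # Slice to explicit end indices: an end index of -0 (i.e. 0) would give an empty
            # tensor whenever only one of the two dimensions needs cropping.
            x = x[:, :, top:top + desired_height, left:left + desired_width]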
        if self.norm:
            x = self.norm(x)
        if self.activation:
            x = self.activation(x)
        return x
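

# Illustrative sketch, not used by the model (helper names are hypothetical): the
# 'same' padding arithmetic of ConvBlock.forward in plain Python. For a regular
# conv it returns the (before, after) padding that makes the output size
# input // stride; for a transposed conv (which ConvBlock runs unpadded) it returns
# the (before, after) crop that makes the output size input * stride.
import math


def _same_pad_or_crop(size, ksize, stride=1, dil=1, transposed=False):
    if transposed:
        # ConvTranspose2d output size with padding=0: (in - 1) * s + d * (k - 1) + 1
        out = (size - 1) * stride + dil * (ksize - 1) + 1
        crop = out - size * stride
        return math.floor(crop / 2), math.ceil(crop / 2)
    desired = size // stride
    pad = 0.5 * (stride * (desired - 1) + (ksize - 1) * (dil - 1) + ksize - size)
    return math.floor(pad), math.ceil(pad)


def _conv_out(size, ksize, stride, dil, pad_total):
    # Standard Conv2d output-size formula, used here to check the padding.
    return (size + pad_total - dil * (ksize - 1) - 1) // stride + 1

# Example: a 4x4 stride-2 conv on 256 pixels needs padding (1, 1) and yields 128;
# a 4x4 stride-1 conv needs the asymmetric padding (1, 2) to preserve the size.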


class ResBlock(nn.Module):
    """Define a residual block."""

    def __init__(self, channels, ksize=4, stride=1, dil=1, norm=None, activation='relu'):
        """Initialize the residual block, which consists of 2 conv blocks with a skip connection."""
        super(ResBlock, self).__init__()
        self.convblock1 = ConvBlock(nn.Conv2d, channels, channels, ksize=ksize, stride=stride, dil=dil, norm=norm,
                                    activation=activation)
        self.convblock2 = ConvBlock(nn.Conv2d, channels, channels, ksize=ksize, stride=stride, dil=dil, norm=norm,
                                    activation=None)

    def forward(self, x):
        identity = x
        x = self.convblock1(x)
        x = self.convblock2(x)
        x += identity
        return x


class kp2uv(nn.Module):
    """UNet architecture for converting a keypoint image to a UV map.

    Uses the same person UV map format as DensePose (https://arxiv.org/pdf/1802.00434.pdf).
    """

    def __init__(self, nf=64):
        super(kp2uv, self).__init__()
        self.encoder = nn.ModuleList([
            ConvBlock(nn.Conv2d, 3, nf, ksize=4, stride=2),
            ConvBlock(nn.Conv2d, nf, nf * 2, ksize=4, stride=2, norm=nn.InstanceNorm2d, activation='leaky'),
            ConvBlock(nn.Conv2d, nf * 2, nf * 4, ksize=4, stride=2, norm=nn.InstanceNorm2d, activation='leaky'),
            ConvBlock(nn.Conv2d, nf * 4, nf * 4, ksize=4, stride=2, norm=nn.InstanceNorm2d, activation='leaky'),
            ConvBlock(nn.Conv2d, nf * 4, nf * 4, ksize=4, stride=2, norm=nn.InstanceNorm2d, activation='leaky'),
            ConvBlock(nn.Conv2d, nf * 4, nf * 4, ksize=3, stride=1, norm=nn.InstanceNorm2d, activation='leaky'),
            ConvBlock(nn.Conv2d, nf * 4, nf * 4, ksize=3, stride=1, norm=nn.InstanceNorm2d, activation='leaky')])

        self.decoder = nn.ModuleList([
            ConvBlock(nn.ConvTranspose2d, nf * 4 * 2, nf * 4, ksize=4, stride=2, norm=nn.InstanceNorm2d),
            ConvBlock(nn.ConvTranspose2d, nf * 4 * 2, nf * 4, ksize=4, stride=2, norm=nn.InstanceNorm2d),
            ConvBlock(nn.ConvTranspose2d, nf * 4 * 2, nf * 2, ksize=4, stride=2, norm=nn.InstanceNorm2d),
            ConvBlock(nn.ConvTranspose2d, nf * 2 * 2, nf, ksize=4, stride=2, norm=nn.InstanceNorm2d),
            ConvBlock(nn.ConvTranspose2d, nf * 2, nf, ksize=4, stride=2, norm=nn.InstanceNorm2d)])

        # head to predict body part class (25 classes: 24 body parts and 1 background)
        self.id_pred = ConvBlock(nn.Conv2d, nf + 3, 25, ksize=3, stride=1, activation='none')
        # head to predict UV coordinates for every body part class
        self.uv_pred = ConvBlock(nn.Conv2d, nf + 3, 2 * 24, ksize=3, stride=1, activation='tanh')

    def forward(self, x):
        """Forward pass through UNet, handling skip connections.

        Parameters:
            x (tensor) -- rendered keypoint image, shape [B, 3, H, W]

        Returns:
            x_id (tensor): part ID class logits (no activation applied), shape [B, 25, H, W]
            x_uv (tensor): UV coordinates in [-1, 1] for each of the 24 body parts, shape [B, 48, H, W]
        """
        skips = [x]
        for i, layer in enumerate(self.encoder):
            x = layer(x)
            if i < 5:
                skips.append(x)
        for layer in self.decoder:
            x = torch.cat((x, skips.pop()), 1)
            x = layer(x)
        x = torch.cat((x, skips.pop()), 1)
        x_id = self.id_pred(x)
        x_uv = self.uv_pred(x)
        return x_id, x_uv
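

# Illustrative sketch, not used by the model (the helper name is hypothetical):
# shape bookkeeping for the kp2uv UNet skip connections, tracking only
# (channels, spatial size). The input and the outputs of the first five encoder
# blocks are pushed onto `skips`; each decoder block first concatenates the most
# recent skip, which is why its input width is the sum of two widths. It assumes
# the input size is divisible by 32.
def _trace_kp2uv_shapes(nf=64, in_ch=3, size=256):
    # (out_channels, stride) of the seven encoder ConvBlocks
    enc = [(nf, 2), (nf * 2, 2), (nf * 4, 2), (nf * 4, 2), (nf * 4, 2), (nf * 4, 1), (nf * 4, 1)]
    dec_out = [nf * 4, nf * 4, nf * 2, nf, nf]  # each decoder block doubles the size
    skips = [(in_ch, size)]
    ch, sz = in_ch, size
    for i, (out_ch, stride) in enumerate(enc):
        ch, sz = out_ch, sz // stride
        if i < 5:
            skips.append((ch, sz))
    dec_in = []
    for out_ch in dec_out:
        skip_ch, skip_sz = skips.pop()
        assert skip_sz == sz, 'skip and feature map sizes must match'
        dec_in.append(ch + skip_ch)
        ch, sz = out_ch, sz * 2
    skip_ch, skip_sz = skips.pop()  # the raw input, concatenated before the heads
    assert skip_sz == sz
    return dec_in, ch + skip_ch, sz

# With nf=64 the decoder input widths are [512, 512, 512, 256, 128] and the two
# heads receive nf + 3 = 67 channels at full resolution, matching the layers above.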


class LayeredNeuralRenderer(nn.Module):
    """Layered Neural Rendering model for video decomposition.

    Consists of a neural texture, a UNet, and an upsampling module.
    """

    def __init__(self, nf=64, texture_channels=16, texture_res=16, n_textures=25):
        """Initialize layered neural renderer.

        Parameters:
            nf (int) -- the number of channels in the first/last conv layers
            texture_channels (int) -- the number of channels in the neural texture
            texture_res (int) -- the size of each individual texture map
            n_textures (int) -- the number of individual texture maps
        """
        super(LayeredNeuralRenderer, self).__init__()
        # The neural texture is implemented as 'n_textures' texture maps concatenated horizontally
        self.texture = nn.Parameter(torch.randn(1, texture_channels, texture_res, n_textures * texture_res))

        # Define UNet
        self.encoder = nn.ModuleList([
            ConvBlock(nn.Conv2d, texture_channels + 1, nf, ksize=4, stride=2),
            ConvBlock(nn.Conv2d, nf, nf * 2, ksize=4, stride=2, norm=nn.BatchNorm2d, activation='leaky'),
            ConvBlock(nn.Conv2d, nf * 2, nf * 4, ksize=4, stride=2, norm=nn.BatchNorm2d, activation='leaky'),
            ConvBlock(nn.Conv2d, nf * 4, nf * 4, ksize=4, stride=2, norm=nn.BatchNorm2d, activation='leaky'),
            ConvBlock(nn.Conv2d, nf * 4, nf * 4, ksize=4, stride=2, norm=nn.BatchNorm2d, activation='leaky'),
            ConvBlock(nn.Conv2d, nf * 4, nf * 4, ksize=4, stride=1, dil=2, norm=nn.BatchNorm2d, activation='leaky'),
            ConvBlock(nn.Conv2d, nf * 4, nf * 4, ksize=4, stride=1, dil=2, norm=nn.BatchNorm2d, activation='leaky')])
        self.decoder = nn.ModuleList([
            ConvBlock(nn.ConvTranspose2d, nf * 4 * 2, nf * 4, ksize=4, stride=2, norm=nn.BatchNorm2d),
            ConvBlock(nn.ConvTranspose2d, nf * 4 * 2, nf * 4, ksize=4, stride=2, norm=nn.BatchNorm2d),
            ConvBlock(nn.ConvTranspose2d, nf * 4 * 2, nf * 2, ksize=4, stride=2, norm=nn.BatchNorm2d),
            ConvBlock(nn.ConvTranspose2d, nf * 2 * 2, nf, ksize=4, stride=2, norm=nn.BatchNorm2d),
            ConvBlock(nn.ConvTranspose2d, nf * 2, nf, ksize=4, stride=2, norm=nn.BatchNorm2d)])
        self.final_rgba = ConvBlock(nn.Conv2d, nf, 4, ksize=4, stride=1, activation='tanh')

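        # Define the upsampling block, which outputs an RGBA residual. Its input is the upsampled
        # RGBA (4) + upsampled texture and person ID (texture_channels + 1) + last UNet features (nf).
        upsampling_ic = texture_channels + 5 + nf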
        self.upsample_block = nn.Sequential(
            ConvBlock(nn.Conv2d, upsampling_ic, nf, ksize=3, stride=1, norm=nn.InstanceNorm2d),
            ResBlock(nf, ksize=3, stride=1, norm=nn.InstanceNorm2d),
            ResBlock(nf, ksize=3, stride=1, norm=nn.InstanceNorm2d),
            ResBlock(nf, ksize=3, stride=1, norm=nn.InstanceNorm2d),
            ConvBlock(nn.Conv2d, nf, 4, ksize=3, stride=1, activation='none'))

    def render(self, x):
        """Pass inputs for a single layer through UNet.

        Parameters:
            x (tensor) -- sampled texture concatenated with person IDs

        Returns RGBA for the input layer and the final feature maps.
        """
        skips = [x]
        for i, layer in enumerate(self.encoder):
            x = layer(x)
            if i < 5:
                skips.append(x)
        for layer in self.decoder:
            x = torch.cat((x, skips.pop()), 1)
            x = layer(x)
        rgba = self.final_rgba(x)
        return rgba, x

    def forward(self, uv_map, id_layers, uv_map_upsampled=None, crop_params=None):
        """Forward pass through layered neural renderer.

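        Steps:
        1. Sample from the neural texture using uv_map
        2. Input the sampled texture and id_layers into the UNet
            2a. If doing upsampling, pass the upsampled inputs and results through the upsampling module
        3. Composite the RGBA outputs, back to front.

        Parameters:
            uv_map (tensor) -- UV maps for all layers, with shape [B, (2*L), H, W]
            id_layers (tensor) -- person ID for all layers, with shape [B, L, H, W]
            uv_map_upsampled (tensor) -- upsampled UV maps to input to the upsampling module (if None, skip upsampling)
            crop_params (tuple) -- optional (starty, endy, startx, endx) crop applied to the upsampled tensors
                before the upsampling module

        Returns a dict with the composite ('reconstruction'), the stacked per-layer RGBA outputs ('layers'),
        and the per-layer sampled textures ('sampled texture', for debugging).
        """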
        b_sz = uv_map.shape[0]
        n_layers = uv_map.shape[1] // 2
        texture = self.texture.repeat(b_sz, 1, 1, 1)
        composite = None
        layers = []
        sampled_textures = []
        for i in range(n_layers):
            # Get RGBA for this layer.
            uv_map_i = uv_map[:, i * 2:(i + 1) * 2, ...]
            uv_map_i = uv_map_i.permute(0, 2, 3, 1)
            sampled_texture = F.grid_sample(texture, uv_map_i, mode='bilinear', padding_mode='zeros')
            inputs = torch.cat([sampled_texture, id_layers[:, i:i + 1]], 1)
            rgba, last_feat = self.render(inputs)

            if uv_map_upsampled is not None:
                uv_map_up_i = uv_map_upsampled[:, i * 2:(i + 1) * 2, ...]
                uv_map_up_i = uv_map_up_i.permute(0, 2, 3, 1)
                sampled_texture_up = F.grid_sample(texture, uv_map_up_i, mode='bilinear', padding_mode='zeros')
                id_layers_up = F.interpolate(id_layers[:, i:i + 1], size=sampled_texture_up.shape[-2:],
                                             mode='bilinear')
                inputs_up = torch.cat([sampled_texture_up, id_layers_up], 1)
                upsampled_size = inputs_up.shape[-2:]
                rgba = F.interpolate(rgba, size=upsampled_size, mode='bilinear')
                last_feat = F.interpolate(last_feat, size=upsampled_size, mode='bilinear')
                if crop_params is not None:
                    starty, endy, startx, endx = crop_params
                    rgba = rgba[:, :, starty:endy, startx:endx]
                    last_feat = last_feat[:, :, starty:endy, startx:endx]
                    inputs_up = inputs_up[:, :, starty:endy, startx:endx]
                rgba_residual = self.upsample_block(torch.cat((rgba, inputs_up, last_feat), 1))
                rgba += .01 * rgba_residual
                rgba = torch.clamp(rgba, -1, 1)
                sampled_texture = sampled_texture_up

            # Update the composite with this layer's RGBA output
            if composite is None:
                composite = rgba
            else:
                alpha = rgba[:, 3:4] * .5 + .5
                composite = rgba * alpha + composite * (1.0 - alpha)
            layers.append(rgba)
            sampled_textures.append(sampled_texture)

        outputs = {
            'reconstruction': composite,
            'layers': torch.stack(layers, 1),
            'sampled texture': sampled_textures,  # for debugging
        }
        return outputs
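

# Illustrative sketch, not used by the model (the helper name is hypothetical):
# the back-to-front 'over' compositing loop of LayeredNeuralRenderer.forward in
# NumPy. Layers are ordered back to front (layer 0 is the background), values are
# in [-1, 1] as produced by the tanh output, and alpha is remapped to [0, 1] with
# a * 0.5 + 0.5. As in the model, all four channels (including alpha) are blended.
import numpy as np


def _composite_layers_numpy(layers):
    composite = None
    for rgba in layers:
        rgba = np.asarray(rgba, dtype=np.float64)  # shape [4, H, W]
        if composite is None:
            composite = rgba
        else:
            alpha = rgba[3:4] * 0.5 + 0.5
            composite = rgba * alpha + composite * (1.0 - alpha)
    return composite

# A foreground layer with alpha -1 leaves the background unchanged, alpha +1
# replaces it, and alpha 0 mixes the two layers half and half.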


================================================
FILE: options/__init__.py
================================================


================================================
FILE: options/base_options.py
================================================
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os
from third_party.util import util
from third_party import models
from third_party import data
import torch
import json


class BaseOptions():
    """This class defines options used during both training and test time.

    It also implements several helper functions such as parsing, printing, and saving the options.
    It also gathers additional options defined in <modify_commandline_options> functions in both dataset class and model class.
    """

    def __init__(self):
        """Reset the class; indicates the class hasn't been initialized"""
        self.initialized = False

    def initialize(self, parser):
        """Define the common options that are used in both training and test."""
        # basic parameters
        parser.add_argument('--dataroot', required=True, help='path to images (should have subfolders rgb_256, etc)')
        parser.add_argument('--name', type=str, default='experiment_name', help='name of the experiment. It decides where to store samples and models')
        parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 or 0,1,2 or 0,2. Use -1 for CPU')
        parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')
        parser.add_argument('--seed', type=int, default=1, help='initial random seed')
        # model parameters
        parser.add_argument('--model', type=str, default='lnr', help='chooses which model to use. [lnr | kp2uv]')
        parser.add_argument('--num_filters', type=int, default=64, help='# filters in the first and last conv layers')
        # dataset parameters
        parser.add_argument('--dataset_mode', type=str, default='layered_video', help='chooses how datasets are loaded. [layered_video | kpuv]')
        parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly')
        parser.add_argument('--num_threads', default=4, type=int, help='# threads for loading data')
        parser.add_argument('--batch_size', type=int, default=32, help='input batch size')
        parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.')
        parser.add_argument('--display_winsize', type=int, default=256, help='display window size for both visdom and HTML')
        # additional parameters
        parser.add_argument('--epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
        parser.add_argument('--verbose', action='store_true', help='if specified, print more debugging information')
        parser.add_argument('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_nf{num_filters}')
        self.initialized = True
        return parser

    def gather_options(self):
        """Initialize our parser with basic options (only once).
        Add additional model-specific and dataset-specific options.
        These options are defined in the <modify_commandline_options> function
        in model and dataset classes.
        """
        if not self.initialized:  # check if it has been initialized
            parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
            parser = self.initialize(parser)

        # get the basic options
        opt, _ = parser.parse_known_args()

        # modify model-related parser options
        model_name = opt.model
        model_option_setter = models.get_option_setter(model_name)
        parser = model_option_setter(parser, self.isTrain)
        opt, _ = parser.parse_known_args()  # parse again with new defaults

        # modify dataset-related parser options
        dataset_name = opt.dataset_mode
        dataset_option_setter = data.get_option_setter(dataset_name)
        parser = dataset_option_setter(parser, self.isTrain)

        # save and return the parser
        self.parser = parser
        return parser.parse_args()

    def print_options(self, opt):
        """Print and save options

        It will print both current options and default values (if different).
        It will save options into a text file: [checkpoints_dir]/[name]/[phase]_opt.txt
        """
        message = ''
        message += '----------------- Options ---------------\n'
        for k, v in sorted(vars(opt).items()):
            comment = ''
            default = self.parser.get_default(k)
            if v != default:
                comment = '\t[default: %s]' % str(default)
            message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment)
        message += '----------------- End -------------------'
        print(message)

        # save to the disk
        expr_dir = os.path.join(opt.checkpoints_dir, opt.name)
        util.mkdirs(expr_dir)
        file_name = os.path.join(expr_dir, '{}_opt.txt'.format(opt.phase))
        with open(file_name, 'wt') as opt_file:
            opt_file.write(message)
            opt_file.write('\n')

    def parse(self):
        """Parse our options, create checkpoints directory suffix, and set up gpu device."""
        opt = self.gather_options()
        opt.isTrain = self.isTrain   # train or test

        # process opt.suffix
        if opt.suffix:
            suffix = '_' + opt.suffix.format(**vars(opt))
            opt.name = opt.name + suffix

        self.print_options(opt)

        # set gpu ids
        str_ids = opt.gpu_ids.split(',')
        opt.gpu_ids = []
        for str_id in str_ids:
            gpu_id = int(str_id)
            if gpu_id >= 0:
                opt.gpu_ids.append(gpu_id)
        if len(opt.gpu_ids) > 0:
            torch.cuda.set_device(opt.gpu_ids[0])

        self.opt = opt
        return self.opt

    def parse_dataset_meta(self):
        """Parse options from the 'metadata.json' file in the dataroot."""
        with open(os.path.join(self.opt.dataroot, 'metadata.json')) as f:
            metadata = json.load(f)
        self.opt.n_textures = metadata['n_textures']
        self.opt.width = metadata['size_LR'][0]
        self.opt.height = metadata['size_LR'][1]
        return self.opt
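

# Illustrative sketch, not used by the options code (helper names and values are
# hypothetical): parse_dataset_meta expects '<dataroot>/metadata.json' to provide
# at least 'n_textures' and 'size_LR', where size_LR is ordered [width, height].
# The demo writes a file with made-up numbers purely to show that layout.
import json
import os
import tempfile


def _read_dataset_meta(dataroot):
    with open(os.path.join(dataroot, 'metadata.json')) as f:
        metadata = json.load(f)
    # Same fields and ordering as BaseOptions.parse_dataset_meta.
    return metadata['n_textures'], metadata['size_LR'][0], metadata['size_LR'][1]


def _demo_dataset_meta():
    with tempfile.TemporaryDirectory() as dataroot:
        with open(os.path.join(dataroot, 'metadata.json'), 'w') as f:
            json.dump({'n_textures': 25, 'size_LR': [448, 256]}, f)  # hypothetical values
        return _read_dataset_meta(dataroot)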

================================================
FILE: options/test_options.py
================================================
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .base_options import BaseOptions


class TestOptions(BaseOptions):
    """This class includes test options.

    It also includes shared options defined in BaseOptions.
    """

    def initialize(self, parser):
        parser = BaseOptions.initialize(self, parser)  # define shared options
        parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.')
        parser.add_argument('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images')
        parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc')
        parser.add_argument('--num_test', type=int, default=float("inf"), help='how many test images to run')
        self.isTrain = False
        return parser


================================================
FILE: options/train_options.py
================================================
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .base_options import BaseOptions


class TrainOptions(BaseOptions):
    """This class includes training options.

    It also includes shared options defined in BaseOptions.
    """

    def initialize(self, parser):
        parser = BaseOptions.initialize(self, parser)
        # visdom and HTML visualization parameters
        parser.add_argument('--display_freq', type=int, default=20, help='frequency of showing training results on screen (in epochs)')
        parser.add_argument('--display_ncols', type=int, default=0, help='if positive, display all images in a single visdom web panel with certain number of images per row.')
        parser.add_argument('--display_id', type=int, default=1, help='window id of the web display')
        parser.add_argument('--display_server', type=str, default="http://localhost", help='visdom server of the web display')
        parser.add_argument('--display_env', type=str, default='main', help='visdom display environment name (default is "main")')
        parser.add_argument('--display_port', type=int, default=8097, help='visdom port of the web display')
        parser.add_argument('--update_html_freq', type=int, default=50, help='frequency of saving training results to html')
        parser.add_argument('--print_freq', type=int, default=10, help='frequency of showing training results on console (in steps per epoch)')
        parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/')
        # network saving and loading parameters
        parser.add_argument('--save_latest_freq', type=int, default=50, help='frequency of saving the latest results (in epochs)')
        parser.add_argument('--save_by_epoch', action='store_true', help='whether to save the model by epoch number or as "latest" (which overwrites the previous checkpoint)')
        parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model')
        parser.add_argument('--epoch_count', type=int, default=1, help='the starting epoch count, we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>, ...')
        parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc')
        # training parameters
        parser.add_argument('--n_epochs', type=int, default=2000, help='number of epochs with the initial learning rate')
        parser.add_argument('--n_epochs_decay', type=int, default=0, help='number of epochs to linearly decay learning rate to zero')
        parser.add_argument('--lr', type=float, default=0.001, help='initial learning rate for adam')
        parser.add_argument('--lr_policy', type=str, default='linear', help='learning rate policy. [linear | step | plateau | cosine]')

        self.isTrain = True
        return parser


================================================
FILE: requirements.txt
================================================
torch==1.4.0
torchvision>=0.5.0
dominate>=2.4.0
visdom>=0.1.8
matplotlib>=3.2.1
opencv-python>=4.2.0


================================================
FILE: run_kp2uv.py
================================================
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Script for running the keypoint-to-UV network and saving the predicted UVs.

Example:
    python run_kp2uv.py --model kp2uv --dataroot ./datasets/reflection --results_dir ./datasets/reflection

It will load the pre-trained model from '--checkpoints_dir' and save the results to '--results_dir'.

See options/base_options.py and options/test_options.py for more test options.
"""
import os
from options.test_options import TestOptions
from third_party.data import create_dataset
from third_party.models import create_model
from third_party.util.visualizer import save_images
from third_party.util import html


if __name__ == '__main__':
    opt = TestOptions().parse()  # get test options
    # hard-code some parameters
    opt.name = 'kp2uv'
    opt.num_threads = 0   # test code only supports num_threads = 0
    opt.batch_size = 1    # test code only supports batch_size = 1
    opt.serial_batches = True  # disable data shuffling; comment this line if results on randomly chosen images are needed.
    opt.display_id = -1   # no visdom display; the test code saves the results to an HTML file.
    dataset = create_dataset(opt)  # create a dataset given opt.dataset_mode and other options
    model = create_model(opt)      # create a model given opt.model and other options
    model.setup(opt)               # regular setup: load and print networks; create schedulers
    # create a website
    web_dir = os.path.join(opt.results_dir, opt.name, '{}_{}'.format(opt.phase, opt.epoch))  # define the website directory
    print('creating web directory', web_dir)
    webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.epoch))
    for i, data in enumerate(dataset):
        if i >= opt.num_test:  # only apply our model to opt.num_test images.
            break
        model.set_input(data)  # unpack data from data loader
        model.test()           # run inference
        visuals = model.get_current_visuals()  # get image results
        img_path = model.get_image_paths()     # get image paths
        if i % 5 == 0:
            print('processing (%04d)-th image... %s' % (i, img_path))
        save_images(webpage, visuals, img_path, aspect_ratio=opt.aspect_ratio, width=opt.display_winsize)
    webpage.save()  # save the HTML


================================================
FILE: scripts/download_kp2uv_model.sh
================================================
#!/bin/bash
#
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

mkdir -p ./checkpoints/kp2uv
MODEL_FILE=./checkpoints/kp2uv/latest_net_Kp2uv.pth
URL=https://www.robots.ox.ac.uk/~erika/retiming/pretrained_models/kp2uv.pth
wget "$URL" -O "$MODEL_FILE"


================================================
FILE: scripts/run_cartwheel.sh
================================================
#!/bin/bash
#
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

GPUS=0,1
DATA_PATH=./datasets/cartwheel
bash datasets/prepare_iuv.sh $DATA_PATH
python train.py \
  --name cartwheel \
  --dataroot $DATA_PATH \
  --use_homographies \
  --gpu_ids $GPUS
python test.py \
  --name cartwheel \
  --dataroot $DATA_PATH \
  --do_upsampling \
  --use_homographies

================================================
FILE: scripts/run_reflection.sh
================================================
#!/bin/bash
#
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

GPUS=0,1
DATA_PATH=./datasets/reflection
bash datasets/prepare_iuv.sh $DATA_PATH
python train.py \
  --name reflection \
  --dataroot $DATA_PATH \
  --gpu_ids $GPUS
python test.py \
  --name reflection \
  --dataroot $DATA_PATH \
  --do_upsampling

================================================
FILE: scripts/run_splash.sh
================================================
#!/bin/bash
#
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

GPUS=0,1
DATA_PATH=./datasets/splash
bash datasets/prepare_iuv.sh $DATA_PATH
python train.py \
  --name splash \
  --dataroot $DATA_PATH \
  --batch_size 24 \
  --batch_size_upsample 12 \
  --use_mask_images \
  --gpu_ids $GPUS
python test.py \
  --name splash \
  --dataroot $DATA_PATH \
  --do_upsampling

================================================
FILE: scripts/run_trampoline.sh
================================================
#!/bin/bash
#
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

GPUS=0,1
DATA_PATH=./datasets/trampoline
bash datasets/prepare_iuv.sh $DATA_PATH
python train.py \
  --name trampoline \
  --dataroot $DATA_PATH \
  --batch_size 16 \
  --batch_size_upsample 6 \
  --gpu_ids $GPUS
python test.py \
  --name trampoline \
  --dataroot $DATA_PATH \
  --do_upsampling

================================================
FILE: test.py
================================================
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Script to save the full outputs of a layered neural renderer (LNR).

Once you have trained the LNR with train.py, you can use this script to save the model's final layer decomposition.
It will load a saved model from '--checkpoints_dir' and save the results to '--results_dir'.

It first creates a model and dataset given the options. It will hard-code some parameters.
It then runs inference on '--num_test' frames and saves the RGBA layers as images and all outputs as videos to an HTML page.

Example (You need to train models first or download pre-trained models from our website):
    python test.py --dataroot ./datasets/reflection --name reflection --do_upsampling

    If the upsampling module isn't trained (train.py is used with '--n_epochs_upsample 0'), remove --do_upsampling.
    Use '--results_dir <directory_path_to_save_result>' to specify the results directory.

See options/base_options.py and options/test_options.py for more test options.
"""
import os
from options.test_options import TestOptions
from third_party.data import create_dataset
from third_party.models import create_model
from third_party.util.visualizer import save_images, save_videos
from third_party.util import html
import torch


if __name__ == '__main__':
    testopt = TestOptions()
    testopt.parse()
    opt = testopt.parse_dataset_meta()
    # hard-code some parameters for test
    opt.num_threads = 0   # test code only supports num_threads = 0
    opt.batch_size = 1    # test code only supports batch_size = 1
    opt.serial_batches = True  # disable data shuffling; comment this line if results on randomly chosen images are needed.
    opt.display_id = -1   # no visdom display; the test code saves the results to an HTML file.
    dataset = create_dataset(opt)  # create a dataset given opt.dataset_mode and other options
    model = create_model(opt)      # create a model given opt.model and other options
    model.setup(opt)               # regular setup: load and print networks; create schedulers
    # create a website
    web_dir = os.path.join(opt.results_dir, opt.name, '{}_{}'.format(opt.phase, opt.epoch))  # define the website directory
    print('creating web directory', web_dir)
    webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.epoch))
    video_visuals = None
    for i, data in enumerate(dataset):
        if i >= opt.num_test:  # only apply our model to opt.num_test images.
            break
        model.set_input(data)  # unpack data from data loader
        model.test()           # run inference
        img_path = model.get_image_paths()     # get image paths
        if i % 5 == 0:  # print progress every 5 frames
            print('processing (%04d)-th image... %s' % (i, img_path))
        visuals = model.get_results()  # rgba, reconstruction, original, mask
        if video_visuals is None:
            video_visuals = visuals
        else:
            for k in video_visuals:
                video_visuals[k] = torch.cat((video_visuals[k], visuals[k]))
        rgba = { k: visuals[k] for k in visuals if 'rgba' in k }
        # save RGBA layers
        save_images(webpage, rgba, img_path, aspect_ratio=opt.aspect_ratio, width=opt.display_winsize)
    save_videos(webpage, video_visuals, width=opt.display_winsize)
    webpage.save()  # save the HTML of videos

================================================
FILE: third_party/__init__.py
================================================


================================================
FILE: third_party/data/__init__.py
================================================
"""This package includes all the modules related to data loading and preprocessing.

To add a custom dataset class called 'dummy', you need to add a file called 'dummy_dataset.py' to the data/ directory and define a subclass 'DummyDataset' inherited from BaseDataset.
 You need to implement four functions:
    -- <__init__>:                      initialize the class, first call BaseDataset.__init__(self, opt).
    -- <__len__>:                       return the size of dataset.
    -- <__getitem__>:                   get a data point from data loader.
    -- <modify_commandline_options>:    (optionally) add dataset-specific options and set default options.

Now you can use the dataset class by specifying flag '--dataset_mode dummy'.
See data/layered_video_dataset.py for a concrete example.
"""
import importlib
import torch.utils.data
from .base_dataset import BaseDataset
from .fast_data_loader import FastDataLoader


def find_dataset_using_name(dataset_name):
    """Import the module "data/[dataset_name]_dataset.py".

    In the file, the class called DatasetNameDataset() will
    be instantiated. It has to be a subclass of BaseDataset,
    and it is case-insensitive.
    """
    dataset_filename = "data." + dataset_name + "_dataset"
    datasetlib = importlib.import_module(dataset_filename)

    dataset = None
    target_dataset_name = dataset_name.replace('_', '') + 'dataset'
    for name, cls in datasetlib.__dict__.items():
        if name.lower() == target_dataset_name.lower() \
           and issubclass(cls, BaseDataset):
            dataset = cls

    if dataset is None:
        raise NotImplementedError("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name))

    return dataset


def get_option_setter(dataset_name):
    """Return the static method <modify_commandline_options> of the dataset class."""
    dataset_class = find_dataset_using_name(dataset_name)
    return dataset_class.modify_commandline_options


def create_dataset(opt, use_fast_loader=False):
    """Create a dataset given the option.

    This function wraps the class CustomDatasetDataLoader.
    This is the main interface between this package and 'train.py'/'test.py'.

    If use_fast_loader=False, use the default PyTorch DataLoader. Otherwise, use FastDataLoader.

    Example:
        >>> from third_party.data import create_dataset
        >>> dataset = create_dataset(opt)
    """
    data_loader = CustomDatasetDataLoader(opt, use_fast_loader=use_fast_loader)
    dataset = data_loader.load_data()
    return dataset


class CustomDatasetDataLoader():
    """Wrapper class of Dataset class that performs multi-threaded data loading"""

    def __init__(self, opt, use_fast_loader=False):
        """Initialize this class

        Step 1: create a dataset instance given the name [dataset_mode]
        Step 2: create a multi-threaded data loader.

        If use_fast_loader=False, use the default PyTorch DataLoader. Otherwise, use FastDataLoader.
        """
        self.opt = opt
        dataset_class = find_dataset_using_name(opt.dataset_mode)
        self.dataset = dataset_class(opt)
        print("dataset [%s] was created" % type(self.dataset).__name__)
        loader = torch.utils.data.DataLoader
        if use_fast_loader:
            loader = FastDataLoader
        self.dataloader = loader(
            self.dataset,
            batch_size=opt.batch_size,
            shuffle=not opt.serial_batches,
            num_workers=int(opt.num_threads))

    def load_data(self):
        return self

    def __len__(self):
        """Return the number of data in the dataset"""
        return min(len(self.dataset), self.opt.max_dataset_size)

    def __iter__(self):
        """Yield batches of data; stop before any batch that would start at or beyond opt.max_dataset_size."""
        for i, data in enumerate(self.dataloader):
            if i * self.opt.batch_size >= self.opt.max_dataset_size:
                break
            yield data
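The lookup and iteration rules above can be checked without PyTorch. The sketch below re-implements them as two hypothetical helpers: `dataset_module_and_class` mirrors the string handling in `find_dataset_using_name`, and `num_batches_yielded` mirrors the cap in `CustomDatasetDataLoader.__iter__`, assuming the wrapped DataLoader keeps a final partial batch (PyTorch's default `drop_last=False`):

```python
import math


def dataset_module_and_class(dataset_name):
    """Hypothetical helper mirroring find_dataset_using_name's naming rules.

    Returns the module path that gets imported and the lowercase class name
    that a BaseDataset subclass in that module must match case-insensitively.
    """
    module_path = "data." + dataset_name + "_dataset"
    target_class = (dataset_name.replace('_', '') + 'dataset').lower()
    return module_path, target_class


def num_batches_yielded(dataset_len, batch_size, max_dataset_size):
    """Hypothetical helper counting the batches CustomDatasetDataLoader.__iter__ yields.

    Assumes drop_last=False, so the underlying loader produces
    ceil(dataset_len / batch_size) batches before the cap is applied.
    """
    loader_batches = math.ceil(dataset_len / batch_size)
    yielded = 0
    for i in range(loader_batches):
        # Same test as __iter__: stop once the batch start index reaches the cap.
        if i * batch_size >= max_dataset_size:
            break
        yielded += 1
    return yielded


print(dataset_module_and_class('layered_video'))
print(num_batches_yielded(100, 8, float('inf')))
print(num_batches_yielded(100, 8, 20))
```

Because the cap is tested at the start of each batch, the last yielded batch can run past `max_dataset_size` (3 batches, i.e. 24 samples, for a cap of 20 with batch size 8), even though `__len__` reports `min(len(dataset), max_dataset_size)`.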


================================================
FILE: third_party/data/base_dataset.py
================================================
"""This module implements an abstract base class (ABC) 'BaseDataset' for datasets.

It also includes common transformation functions (e.g., get_transform, __scale_width), which can be later used in subclasses.
"""
import random
import numpy as np
import torch.utils.data as data
from PIL import Image
import torchvision.transforms as transforms
from abc import ABC, abstractmethod


class BaseDataset(data.Dataset, ABC):
    """This class is an abstract base class (ABC) for datasets.

    To create a subclass, you need to implement the following four functions:
    -- <__init__>:                      initialize the class, first call BaseDataset.__init__(self, opt).
    -- <__len__>:                       return the size of dataset.
    -- <__getitem__>:                   get a data point.
    -- <modify_commandline_options>:    (optionally) add dataset-specific options and set default options.
    """

    def __init__(self, opt):
        """Initialize the class; save the options in the class

        Parameters:
            opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
        """
        self.opt = opt
        self.root = opt.dataroot

    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Add new dataset-specific options, and rewrite default values for existing options.

        Parameters:
            parser          -- original option parser
            is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.

        Returns:
            the modified parser.
        """
        return parser

    @abstractmethod
    def __len__(self):
        """Return the total number of images in the dataset."""
        return 0

    @abstractmethod
    def __getitem__(self, index):
        """Return a data point and its metadata information.

        Parameters:
            index (int) -- a random integer for data indexing

        Returns:
            a dictionary of data with their names. It usually contains the data itself and its metadata information.
        """
        pass


def get_params(opt, size):
    w, h = size
    new_h = h
    new_w = w
    if opt.preprocess == 'resize_and_crop':
        new_h = new_w = opt.load_size
    elif opt.preprocess == 'scale_width_and_crop':
        new_w = opt.load_size
        new_h = opt.load_size * h // w

    x = random.randint(0, np.maximum(0, new_w - opt.crop_size))
    y = random.randint(0, np.maximum(0, new_h - opt.crop_size))

    flip = random.random() > 0.5

    return {'crop_pos': (x, y), 'flip': flip}


def get_transform(opt, params=None, grayscale=False, method=Image.BICUBIC, convert=True):
    transform_list = []
    if grayscale:
        transform_list.append(transforms.Grayscale(1))
    if 'resize' in opt.preprocess:
        osize = [opt.load_size, opt.load_size]
        transform_list.append(transforms.Resize(osize, method))
    elif 'scale_width' in opt.preprocess:
        transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.load_size, opt.crop_size, method)))

    if 'crop' in opt.preprocess:
        if params is None:
            transform_list.append(transforms.RandomCrop(opt.crop_size))
        else:
            transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.crop_size)))

    if opt.preprocess == 'none':
        transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base=4, method=method)))

    if not opt.no_flip:
        if params is None:
            transform_list.append(transforms.RandomHorizontalFlip())
        elif params['flip']:
            transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip'])))

    if convert:
        transform_list += [transforms.ToTensor()]
        if grayscale:
            transform_list += [transforms.Normalize((0.5,), (0.5,))]
        else:
            transform_list += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    return transforms.Compose(transform_list)


def __make_power_2(img, base, method=Image.BICUBIC):
    ow, oh = img.size
    h = int(round(oh / base) * base)
    w = int(round(ow / base) * base)
    if h == oh and w == ow:
        return img

    __print_size_warning(ow, oh, w, h)
    return img.resize((w, h), method)


def __scale_width(img, target_size, crop_size, method=Image.BICUBIC):
    ow, oh = img.size
    if ow == target_size and oh >= crop_size:
        return img
    w = target_size
    h = int(max(target_size * oh / ow, crop_size))
    return img.resize((w, h), method)


def __crop(img, pos, size):
    ow, oh = img.size
    x1, y1 = pos
    tw = th = size
    if (ow > tw or oh > th):
        return img.crop((x1, y1, x1 + tw, y1 + th))
    return img


def __flip(img, flip):
    if flip:
        return img.transpose(Image.FLIP_LEFT_RIGHT)
    return img


def __print_size_warning(ow, oh, w, h):
    """Print a warning about the image size (printed only once)."""
    if not hasattr(__print_size_warning, 'has_printed'):
        print("The image size needs to be a multiple of 4. "
              "The loaded image size was (%d, %d), so it was adjusted to "
              "(%d, %d). This adjustment will be done to all images "
              "whose sizes are not multiples of 4" % (ow, oh, w, h))
        __print_size_warning.has_printed = True
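The resize arithmetic in `__make_power_2` and `__scale_width` can be traced with plain integers. A minimal sketch with hypothetical helper names that return only the target `(w, h)` instead of resizing a PIL image:

```python
def make_power_2_size(ow, oh, base=4):
    """Target (w, h) as computed by __make_power_2: each side rounded to the nearest multiple of base."""
    h = int(round(oh / base) * base)
    w = int(round(ow / base) * base)
    return w, h


def scale_width_size(ow, oh, target_size, crop_size):
    """Target (w, h) as computed by __scale_width: fixed width, proportional height, at least crop_size."""
    if ow == target_size and oh >= crop_size:
        return ow, oh  # already the right size; the original returns the image unchanged
    return target_size, int(max(target_size * oh / ow, crop_size))


print(make_power_2_size(255, 130))             # (256, 128)
print(make_power_2_size(254, 130))             # (256, 128)
print(scale_width_size(1920, 1080, 512, 256))  # (512, 288)
```

Despite its name, `__make_power_2` rounds each side to the nearest multiple of `base`, not to a power of two. Python 3's `round` also breaks ties toward the even integer, so 254 (63.5 × 4) rounds up to 256 while 130 (32.5 × 4) rounds down to 128.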


================================================
FILE: third_party/data/fast_data_loader.py
================================================
"""Fixes the issue where DataLoader is slow because its worker processes aren't reused across epochs.
See https://github.com/pytorch/pytorch/issues/15849
Warning: overrides batch sampler.
"""
import torch.utils.data


class _RepeatSampler(object):
    """ Sampler that repeats forever.

    Args:
        sampler (Sampler)
    """

    def __init__(self, sampler):
        self.sampler = sampler

    def __iter__(self):
        while True:
            yield from iter(self.sampler)


class FastDataLoader(torch.utils.data.dataloader.DataLoader):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        object.__setattr__(self, 'batch_sampler', _RepeatSampler(self.batch_sampler))
        self.iterator = super().__iter__()

    def __len__(self):
        return len(self.batch_sampler.sampler)

    def __iter__(self):
        for i in range(len(self)):
            yield next(self.iterator)
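The effect of `_RepeatSampler` can be seen with a plain list standing in for PyTorch's batch sampler. One long-lived iterator is created once and then sliced into epochs of `len(self)` batches, which is what lets `FastDataLoader` keep its workers alive. A pure-Python sketch; `RepeatSampler` and `run_epochs` are hypothetical names:

```python
class RepeatSampler:
    """Pure-Python stand-in for _RepeatSampler: re-iterates the wrapped sampler forever."""

    def __init__(self, sampler):
        self.sampler = sampler

    def __iter__(self):
        while True:
            yield from iter(self.sampler)


def run_epochs(sampler, n_epochs):
    """Mimic FastDataLoader: create one iterator up front, then take len(sampler) items per epoch."""
    iterator = iter(RepeatSampler(sampler))
    return [[next(iterator) for _ in range(len(sampler))] for _ in range(n_epochs)]


batches = [[0, 1], [2, 3], [4]]  # stand-in for the index batches a BatchSampler would produce
print(run_epochs(batches, 2))
```

Each pass of the `while True` loop calls `iter()` on the wrapped sampler again, so with a real shuffling sampler every epoch should still get a fresh order.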

================================================
FILE: third_party/data/image_folder.py
================================================
"""A modified image folder class

We modify the official PyTorch image folder (https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py)
so that this class can load images from both current directory and its subdirectories.
"""

import torch.utils.data as data

from PIL import Image
import os

IMG_EXTENSIONS = [
    '.jpg', '.JPG', '.jpeg', '.JPEG',
    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
    '.tif', '.TIF', '.tiff', '.TIFF',
]


def is_image_file(filename):
    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)


def make_dataset(dir, max_dataset_size=float("inf")):
    images = []
    assert os.path.isdir(dir), '%s is not a valid directory' % dir

    for root, _, fnames in sorted(os.walk(dir)):
        for fname in fnames:
            if is_image_file(fname):
                path = os.path.join(root, fname)
                images.append(path)
    images = sorted(images)
    return images[:min(max_dataset_size, len(images))]


def default_loader(path):
    return Image.open(path).convert('RGB')


class ImageFolder(data.Dataset):

    def __init__(self, root, transform=None, return_paths=False,
                 loader=default_loader):
        imgs = make_dataset(root)
        if len(imgs) == 0:
            raise(RuntimeError("Found 0 images in: " + root + "\n"
                               "Supported image extensions are: " + ",".join(IMG_EXTENSIONS)))

        self.root = root
        self.imgs = imgs
        self.transform = transform
        self.return_paths = return_paths
        self.loader = loader

    def __getitem__(self, index):
        path = self.imgs[index]
        img = self.loader(path)
        if self.transform is not None:
            img = self.transform(img)
        if self.return_paths:
            return img, path
        else:
            return img

    def __len__(self):
        return len(self.imgs)


================================================
FILE: third_party/models/__init__.py
================================================
"""This package contains modules related to objective functions, optimizations, and network architectures.

To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel.
You need to implement the following five functions:
    -- <__init__>:                      initialize the class; first call BaseModel.__init__(self, opt).
    -- <set_input>:                     unpack data from dataset and apply preprocessing.
    -- <forward>:                       produce intermediate results.
    -- <optimize_parameters>:           calculate loss, gradients, and update network weights.
    -- <modify_commandline_options>:    (optionally) add model-specific options and set default options.

In the function <__init__>, you need to define four lists:
    -- self.loss_names (str list):          specify the training losses that you want to plot and save.
    -- self.model_names (str list):         define networks used in our training.
    -- self.visual_names (str list):        specify the images that you want to display and save.
    -- self.optimizers (optimizer list):    define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.

Now you can use the model class by specifying flag '--model dummy'.
See models/lnr_model.py for a concrete example.
"""

import importlib
from .base_model import BaseModel


def find_model_using_name(model_name):
    """Import the module "models/[model_name]_model.py".

    In the file, the class called ModelNameModel() will
    be instantiated. It has to be a subclass of BaseModel,
    and it is case-insensitive.
    """
    model_filename = "models." + model_name + "_model"
    modellib = importlib.import_module(model_filename)
    model = None
    target_model_name = model_name.replace('_', '') + 'model'
    for name, cls in modellib.__dict__.items():
        if name.lower() == target_model_name.lower() \
           and issubclass(cls, BaseModel):
            model = cls

    if model is None:
        raise NotImplementedError("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name))

    return model


def get_option_setter(model_name):
    """Return the static method <modify_commandline_options> of the model class."""
    model_class = find_model_using_name(model_name)
    return model_class.modify_commandline_options


def create_model(opt):
    """Create a model given the option.

    This function looks up the model class selected by '--model' and instantiates it.
    This is the main interface between this package and 'train.py'/'test.py'.

    Example:
        >>> from third_party.models import create_model
        >>> model = create_model(opt)
    """
    model = find_model_using_name(opt.model)
    instance = model(opt)
    print("model [%s] was created" % type(instance).__name__)
    return instance


================================================
FILE: third_party/models/base_model.py
================================================
import os
import torch
from collections import OrderedDict
from abc import ABC, abstractmethod
from . import networks


class BaseModel(ABC):
    """This class is an abstract base class (ABC) for models.
    To create a subclass, you need to implement the following five functions:
        -- <__init__>:                      initialize the class; first call BaseModel.__init__(self, opt).
        -- <set_input>:                     unpack data from dataset and apply preprocessing.
        -- <forward>:                       produce intermediate results.
        -- <optimize_parameters>:           calculate losses, gradients, and update network weights.
        -- <modify_commandline_options>:    (optionally) add model-specific options and set default options.
    """

    def __init__(self, opt):
        """Initialize the BaseModel class.

        Parameters:
            opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions

        When creating your custom class, you need to implement your own initialization.
        In this function, you should first call <BaseModel.__init__(self, opt)>
        Then, you need to define four lists:
            -- self.loss_names (str list):          specify the training losses that you want to plot and save.
            -- self.model_names (str list):         define networks used in our training.
            -- self.visual_names (str list):        specify the images that you want to display and save.
            -- self.optimizers (optimizer list):    define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
        """
        self.opt = opt
        self.gpu_ids = opt.gpu_ids
        self.isTrain = opt.isTrain
        self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu')  # get device name: CPU or GPU
        self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)  # save all the checkpoints to save_dir
        self.loss_names = []
        self.model_names = []
        self.visual_names = []
        self.optimizers = []
        self.image_paths = []
        self.metric = 0  # used for learning rate policy 'plateau'

    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Add new model-specific options, and rewrite default values for existing options.

        Parameters:
            parser          -- original option parser
            is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.

        Returns:
            the modified parser.
        """
        return parser

    @abstractmethod
    def set_input(self, input):
        """Unpack input data from the dataloader and perform necessary pre-processing steps.

        Parameters:
            input (dict): includes the data itself and its metadata information.
        """
        pass

    @abstractmethod
    def forward(self):
        """Run forward pass; called by both functions <optimize_parameters> and <test>."""
        pass

    @abstractmethod
    def optimize_parameters(self):
        """Calculate losses, gradients, and update network weights; called in every training iteration"""
        pass

    def setup(self, opt):
        """Load and print networks; create schedulers

        Parameters:
            opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
        """
        if self.isTrain:
            self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers]
        if not self.isTrain or opt.continue_train:
            load_suffix = opt.epoch
            self.load_networks(load_suffix)
        self.print_networks(opt.verbose)

    def eval(self):
        """Switch all networks to eval mode at test time"""
        for name in self.model_names:
            if isinstance(name, str):
                net = getattr(self, 'net' + name)
                net.eval()

    def test(self):
        """Forward function used in test time.

        This function wraps <forward> function in no_grad() so we don't save intermediate steps for backprop
        It also calls <compute_visuals> to produce additional visualization results
        """
        with torch.no_grad():
            self.forward()
            self.compute_visuals()

    def compute_visuals(self):
        """Calculate additional output images for visdom and HTML visualization"""
        pass

    def get_image_paths(self):
        """ Return image paths that are used to load current data"""
        return self.image_paths

    def update_learning_rate(self):
        """Update learning rates for all the networks; called at the end of every epoch"""
        old_lr = self.optimizers[0].param_groups[0]['lr']
        for scheduler in self.schedulers:
            if self.opt.lr_policy == 'plateau':
                scheduler.step(self.metric)
            else:
                scheduler.step()

        lr = self.optimizers[0].param_groups[0]['lr']
        if old_lr != lr:
            print('learning rate %.7f -> %.7f' % (old_lr, lr))

    def get_current_visuals(self):
        """Return visualization images. train.py will display these images with visdom, and save the images to an HTML file"""
        visual_ret = OrderedDict()
        for name in self.visual_names:
            if isinstance(name, str):
                visual_ret[name] = getattr(self, name)
        return visual_ret

    def get_current_losses(self):
        """Return training losses / errors. train.py will print out these errors on console, and save them to a file"""
        errors_ret = OrderedDict()
        for name in self.loss_names:
            if isinstance(name, str):
                errors_ret[name] = float(getattr(self, 'loss_' + name))  # float(...) works for both scalar tensor and float number
        return errors_ret

    def save_networks(self, epoch):
        """Save all the networks to the disk.

        Parameters:
            epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
        """
        for name in self.model_names:
            if isinstance(name, str):
                save_filename = '%s_net_%s.pth' % (epoch, name)
                save_path = os.path.join(self.save_dir, save_filename)
                net = getattr(self, 'net' + name)

                if len(self.gpu_ids) > 0 and torch.cuda.is_available():
                    torch.save(net.module.cpu().state_dict(), save_path)
                    net.cuda(self.gpu_ids[0])
                else:
                    torch.save(net.cpu().state_dict(), save_path)

    def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0):
        """Fix InstanceNorm checkpoints incompatibility (prior to 0.4)"""
        key = keys[i]
        if i + 1 == len(keys):  # at the end, pointing to a parameter/buffer
            if module.__class__.__name__.startswith('InstanceNorm') and \
                    (key == 'running_mean' or key == 'running_var'):
                if getattr(module, key) is None:
                    state_dict.pop('.'.join(keys))
            if module.__class__.__name__.startswith('InstanceNorm') and \
               (key == 'num_batches_tracked'):
                state_dict.pop('.'.join(keys))
        else:
            self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1)

    def load_networks(self, epoch):
        """Load all the networks from the disk.

        Parameters:
            epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
        """
        for name in self.model_names:
            if isinstance(name, str):
                load_filename = '%s_net_%s.pth' % (epoch, name)
                load_path = os.path.join(self.save_dir, load_filename)
                net = getattr(self, 'net' + name)
                if isinstance(net, torch.nn.DataParallel):
                    net = net.module
                print('loading the model from %s' % load_path)
                # if you are using PyTorch newer than 0.4 (e.g., built from
                # GitHub source), you can remove str() on self.device
                state_dict = torch.load(load_path, map_location=str(self.device))
                if hasattr(state_dict, '_metadata'):
                    del state_dict._metadata

                # patch InstanceNorm checkpoints prior to 0.4
                for key in list(state_dict.keys()):  # need to copy keys here because we mutate in loop
                    self.__patch_instance_norm_state_dict(state_dict, net, key.split('.'))
                net.load_state_dict(state_dict)

    def print_networks(self, verbose):
        """Print the total number of parameters in the network and (if verbose) network architecture

        Parameters:
            verbose (bool) -- if verbose: print the network architecture
        """
        print('---------- Networks initialized -------------')
        for name in self.model_names:
            if isinstance(name, str):
                net = getattr(self, 'net' + name)
                num_params = 0
                num_trainable_params = 0
                for param in net.parameters():
                    num_params += param.numel()
                    if param.requires_grad:
                        num_trainable_params += param.numel()
                if verbose:
                    print(net)
                print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6))
                print('[Network %s] Total number of trainable parameters : %.3f M' % (name, num_trainable_params / 1e6))
        print('-----------------------------------------------')

    def set_requires_grad(self, nets, requires_grad=False):
        """Set requires_grad for all the given networks (pass False to avoid unnecessary gradient computations)
        Parameters:
            nets (network list)   -- a list of networks
            requires_grad (bool)  -- whether the networks require gradients or not
        """
        if not isinstance(nets, list):
            nets = [nets]
        for net in nets:
            if net is not None:
                for param in net.parameters():
                    param.requires_grad = requires_grad
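`__patch_instance_norm_state_dict` walks each dotted checkpoint key (for example `norm.running_mean`) one attribute at a time and drops InstanceNorm buffers that the current module does not track. The sketch below reproduces that walk without PyTorch. `FakeModule` and `InstanceNorm2d` are hypothetical stand-ins that only mimic attribute access and the class-name check. They are not real `torch.nn` modules.

```python
class FakeModule:
    """Hypothetical stand-in for torch.nn.Module: children and buffers are plain attributes."""

    def __init__(self, **attrs):
        for name, value in attrs.items():
            setattr(self, name, value)


class InstanceNorm2d(FakeModule):
    """Only the class name matters: the patch checks __name__.startswith('InstanceNorm')."""


def patch_instance_norm_keys(state_dict, module, keys, i=0):
    """Drop stale InstanceNorm entries for one dotted key, mirroring the recursive walk."""
    key = keys[i]
    if i + 1 == len(keys):  # reached the parameter/buffer name
        is_instance_norm = type(module).__name__.startswith('InstanceNorm')
        if is_instance_norm and key in ('running_mean', 'running_var'):
            if getattr(module, key, None) is None:  # buffer not tracked by this module
                state_dict.pop('.'.join(keys))
        elif is_instance_norm and key == 'num_batches_tracked':
            state_dict.pop('.'.join(keys))
    else:  # descend into the child module named by this key component
        patch_instance_norm_keys(state_dict, getattr(module, key), keys, i + 1)


net = FakeModule(norm=InstanceNorm2d(weight=[1.0], running_mean=None, running_var=None))
checkpoint = {
    'norm.weight': [1.0],
    'norm.running_mean': [0.0],
    'norm.running_var': [1.0],
    'norm.num_batches_tracked': 0,
}
for key in list(checkpoint.keys()):  # copy the keys: the dict is mutated inside the loop
    patch_instance_norm_keys(checkpoint, net, key.split('.'))
print(sorted(checkpoint))  # ['norm.weight']
```

As in `load_networks`, the loop iterates over `list(state_dict.keys())` rather than the dict itself, because popping entries while iterating a dict directly raises a `RuntimeError` in Python 3.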


================================================
FILE: third_party/models/networks.py
================================================
import torch
import torch.nn as nn
from torch.optim import lr_scheduler


###############################################################################
# Helper Functions
###############################################################################
def get_scheduler(optimizer, opt):
    """Return a learning rate scheduler

    Parameters:
        optimizer          -- the optimizer of the network
        opt (option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions. 
                              opt.lr_policy is the name of learning rate policy: linear | step | plateau | cosine

    For 'linear', we keep the same learning rate for the first <opt.n_epochs> epochs
    and linearly decay the rate to zero over the next <opt.n_epochs_decay> epochs.
    For other schedulers (step, plateau, and cosine), we use the default PyTorch schedulers.
    See https://pytorch.org/docs/stable/optim.html for more details.
    """
    if opt.lr_policy == 'linear':
        def lambda_rule(epoch):
            lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.n_epochs) / float(opt.n_epochs_decay + 1)
            return lr_l

        scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
    elif opt.lr_policy == 'step':
        scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
    elif opt.lr_policy == 'plateau':
        scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
    elif opt.lr_policy == 'cosine':
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.n_epochs, eta_min=0)
    else:
        raise NotImplementedError('learning rate policy [%s] is not implemented' % opt.lr_policy)
    return scheduler


def init_net(net, gpu_ids=[]):
    """Initialize a network by registering CPU/GPU device (with multi-GPU support)
    Parameters:
        net (network)      -- the network to be initialized
        gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2

    Return an initialized network.
    """
    if len(gpu_ids) > 0:
        assert (torch.cuda.is_available())
        net.to(gpu_ids[0])
        net = torch.nn.DataParallel(net, gpu_ids)  # multi-GPUs
    return net
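For the 'linear' policy, `LambdaLR` multiplies the initial learning rate by `lambda_rule(step)`, where `step` is the scheduler's `last_epoch` (0 during the first training epoch, incremented by each `scheduler.step()`). The plain-Python sketch below evaluates the same expression. The epoch counts are illustrative values only, not this repository's defaults.

```python
def linear_lr_factor(step, epoch_count=1, n_epochs=100, n_epochs_decay=100):
    """Same expression as lambda_rule in get_scheduler; `step` is LambdaLR's last_epoch."""
    return 1.0 - max(0, step + epoch_count - n_epochs) / float(n_epochs_decay + 1)


# Training epoch k (1-based, starting at epoch_count=1) uses the factor for step k - 1.
for epoch in (1, 100, 101, 150, 200):
    print(epoch, round(linear_lr_factor(epoch - 1), 4))
```

With these values the factor stays at 1.0 through epoch 100 and then decreases linearly. During the final epoch it is 1 / (n_epochs_decay + 1) of the initial rate, which is close to zero but not exactly zero.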


================================================
FILE: third_party/util/__init__.py
================================================
"""This package includes a miscellaneous collection of useful helper functions."""


================================================
FILE: third_party/util/html.py
================================================
import dominate
from dominate.tags import meta, h3, table, tr, td, p, a, img, br, video, source
import os


class HTML:
    """This HTML class allows us to save images and write texts into a single HTML file.

     It consists of functions such as <add_header> (add a text header to the HTML file),
     <add_images> and <add_videos> (add a row of images or videos to the HTML file), and <save> (save the HTML to the disk).
     It is based on 'dominate', a Python library for creating and manipulating HTML documents using a DOM API.
    """

    def __init__(self, web_dir, title, refresh=0):
        """Initialize the HTML class

        Parameters:
            web_dir (str) -- a directory that stores the webpage. The HTML file will be created at <web_dir>/index.html; images will be saved at <web_dir>/images/ and videos at <web_dir>/videos/
            title (str)   -- the webpage name
            refresh (int) -- how often (in seconds) the page refreshes itself; if 0, no refreshing
        """
        self.title = title
        self.web_dir = web_dir
        self.img_dir = os.path.join(self.web_dir, 'images')
        self.vid_dir = os.path.join(self.web_dir, 'videos')
        if not os.path.exists(self.web_dir):
            os.makedirs(self.web_dir)
        if not os.path.exists(self.img_dir):
            os.makedirs(self.img_dir)
        if not os.path.exists(self.vid_dir):
            os.makedirs(self.vid_dir)

        self.doc = dominate.document(title=title)
        if refresh > 0:
            with self.doc.head:
                meta(http_equiv="refresh", content=str(refresh))

    def get_image_dir(self):
        """Return the directory that stores images"""
        return self.img_dir

    def get_video_dir(self):
        """Return the directory that stores videos"""
        return self.vid_dir

    def add_header(self, text):
        """Insert a header to the HTML file

        Parameters:
            text (str) -- the header text
        """
        with self.doc:
            h3(text)

    def add_images(self, ims, txts, links, width=400):
        """Add a row of images to the HTML file

        Parameters:
            ims (str list)   -- a list of image paths, relative to <web_dir>/images/
            txts (str list)  -- a list of image names shown on the website
            links (str list) -- a list of hyperref links; when you click an image, it will redirect you to a new page
            width (int)      -- the display width of each image, in pixels
        """
        self.t = table(border=1, style="table-layout: fixed;")  # Insert a table
        self.doc.add(self.t)
        with self.t:
            with tr():
                for im, txt, link in zip(ims, txts, links):
                    with td(style="word-wrap: break-word;", halign="center", valign="top"):
                        with p():
                            with a(href=os.path.join('images', link)):
                                img(style="width:%dpx" % width, src=os.path.join('images', im))
                            br()
                            p(txt)

    def add_videos(self, vids, txts, links, width=400):
        """Add a row of videos to the HTML file

        Parameters:
            vids (str list)  -- a list of video paths, relative to <web_dir>/videos/
            txts (str list)  -- a list of video names shown on the website
            links (str list) -- a list of hyperref links; when you click a video, it will redirect you to a new page
            width (int)      -- the display width of each video, in pixels
        """
        self.t = table(border=1, style="table-layout: fixed;")  # Insert a table
        self.doc.add(self.t)
        with self.t:
            with tr():
                for vid, txt, link in zip(vids, txts, links):
                    with td(style="word-wrap: break-word;", halign="center", valign="top"):
                        with p():
                            with a(href=os.path.join('videos', link)):
                                with video(style="width:%dpx" % width, controls=True):
                                    source(src=os.path.join('videos', vid), type="video/mp4")
                            br()
                            p(txt)

    def save(self):
        """Save the current content to the HTML file"""
        html_file = os.path.join(self.web_dir, 'index.html')
        with open(html_file, 'wt') as f:
            f.write(self.doc.render())


if __name__ == '__main__':  # we show an example usage here.
    html = HTML('web/', 'test_html')
    html.add_header('hello world')

    ims, txts, links = [], [], []
    for n in range(4):
        ims.append('image_%d.png' % n)
        txts.append('text_%d' % n)
        links.append('image_%d.png' % n)
    html.add_images(ims, txts, links)
    html.save()
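`add_images` uses dominate to emit one table row per call, with each image wrapped in a link to its `images/` path and captioned by its label. The sketch below approximates that markup using only the standard library. `image_row_html` is a hypothetical helper, and its output is similar to, but not byte-identical with, what dominate renders.

```python
import posixpath
from html import escape


def image_row_html(ims, txts, links, width=400):
    """Hypothetical stdlib-only analogue of HTML.add_images: one table with one row of images.

    Paths are joined with posixpath because they end up in URLs, where '/' is the separator.
    """
    cells = []
    for im, txt, link in zip(ims, txts, links):
        href = escape(posixpath.join('images', link), quote=True)
        src = escape(posixpath.join('images', im), quote=True)
        cells.append(
            '<td style="word-wrap: break-word;" halign="center" valign="top">'
            f'<a href="{href}"><img style="width:{width}px" src="{src}"></a>'
            f'<br><p>{escape(txt)}</p></td>'
        )
    return ('<table border="1" style="table-layout: fixed;"><tr>'
            + ''.join(cells) + '</tr></table>')


print(image_row_html(['image_0.png'], ['text_0'], ['image_0.png'], width=256))
```

One design difference: the original joins URL paths with `os.path.join`, which would produce backslashes on Windows, whereas `posixpath.join` always uses '/'.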


================================================
FILE: third_party/util/util.py
================================================
"""This module contains simple helper functions """
from __future__ import print_function
import torch
import numpy as np
from PIL import Image
import os


def tensor2im(input_image, imtype=np.uint8):
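    """Convert a tensor into a numpy image array.

    Only the first image of the batch is converted; tensor values are assumed to lie in [-1, 1] and are mapped to [0, 255].

    Parameters:
        input_image (tensor) --  the input image tensor array
        imtype (type)        --  the desired type of the converted numpy array
    """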
    if not isinstance(input_image, np.ndarray):
        if isinstance(input_image, torch.Tensor):  # get the data from a variable
            image_tensor = input_image.data
        else:
            return input_image
        image_numpy = image_tensor[0].cpu().float().numpy()  # convert it into a numpy array
        if image_numpy.shape[0] == 1:  # grayscale to RGB
            image_numpy = np.tile(image_numpy, (3, 1, 1))
        image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0  # post-processing: transpose and scaling
    else:  # if it is a numpy array, do nothing
        image_numpy = input_image
    return image_numpy.astype(imtype)


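def render_png(image, background='checker'):
    """Composite an RGBA image over a background and return an RGB uint8 image.

    Parameters:
        image (numpy array) -- an H x W x 4 RGBA image with values in [0, 255]
        background (str)    -- 'checker' (gray checkerboard of 16x16-pixel squares), 'black', or any other value for white
    """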
    height, width = image.shape[:2]
    if background == 'checker':
        checkerboard = np.kron([[136, 120] * (width//128+1), [120, 136] * (width//128+1)] * (height//128+1), np.ones((16, 16)))
        checkerboard = np.expand_dims(np.tile(checkerboard, (4, 4)), -1)
        bg = checkerboard[:height, :width]
    elif background == 'black':
        bg = np.zeros([height, width, 1])
    else:
        bg = 255 * np.ones([height, width, 1])
    image = image.astype(np.float32)
    alpha = image[:, :, 3:] / 255
    rendered_image = alpha * image[:, :, :3] + (1 - alpha) * bg
    return rendered_image.astype(np.uint8)


def diagnose_network(net, name='network'):
    """Calculate and print the mean absolute gradient, averaged over all parameters that have gradients

    Parameters:
        net (torch network) -- Torch network
        name (str) -- the name of the network
    """
    mean = 0.0
    count = 0
    for param in net.parameters():
        if param.grad is not None:
            mean += torch.mean(torch.abs(param.grad.data))
            count += 1
    if count > 0:
        mean = mean / count
    print(name)
    print(mean)


def save_image(image_numpy, image_path, aspect_ratio=1.0):
    """Save a numpy image to the disk

    Parameters:
        image_numpy (numpy array) -- input numpy array
        image_path (str)          -- the path of the image
        aspect_ratio (float)      -- if not 1.0, one image dimension is stretched by this factor before saving
    """

    image_pil = Image.fromarray(image_numpy)
    h, w, _ = image_numpy.shape

    if aspect_ratio > 1.0:
        image_pil = image_pil.resize((h, int(w * aspect_ratio)), Image.BICUBIC)
    if aspect_ratio < 1.0:
        image_pil = image_pil.resize((int(h / aspect_ratio), w), Image.BICUBIC)
    image_pil.save(image_path)


def print_numpy(x, val=True, shp=False):
    """Print the mean, min, max, median, and std of a numpy array, and optionally its shape

    Parameters:
        x (numpy array) -- the array to summarize
        val (bool)      -- whether to print the statistics of the array values
        shp (bool)      -- whether to print the shape of the array
    """
    x = x.astype(np.float64)
    if shp:
        print('shape,', x.shape)
    if val:
        x = x.flatten()
        print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % (
            np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x)))


def mkdirs(paths):
    """create empty directories if they don't exist

    Parameters:
        paths (str list) -- a list of directory paths
    """
    if isinstance(paths, list) and not isinstance(paths, str):
        for path in paths:
            mkdir(path)
    else:
        mkdir(paths)


def mkdir(path):
    """create a single empty directory if it doesn't exist

    Parameters:
        path (str) -- a single directory path
    """
    if not os.path.exists(path):
        os.makedirs(path)
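`tensor2im` takes the first image of a batch, moves the channel axis last (CHW to HWC), and maps values from [-1, 1] to [0, 255]. The numpy-only sketch below shows that post-processing on a single CHW array. `chw_to_uint8_image` is a hypothetical name. The clipping step is an addition of this sketch and is not part of `tensor2im`: without it, out-of-range values would not saturate at 0 and 255 when cast to uint8.

```python
import numpy as np


def chw_to_uint8_image(chw):
    """Sketch of tensor2im's post-processing for one CHW array with values in [-1, 1]."""
    arr = np.asarray(chw, dtype=np.float32)
    if arr.shape[0] == 1:  # grayscale -> RGB by repeating the single channel
        arr = np.tile(arr, (3, 1, 1))
    hwc = np.transpose(arr, (1, 2, 0))  # CHW -> HWC, the layout PIL and OpenCV expect
    scaled = (hwc + 1) / 2.0 * 255.0    # [-1, 1] -> [0, 255]
    # Assumption: clip before casting; tensor2im itself casts directly without clipping.
    return np.clip(scaled, 0, 255).astype(np.uint8)


gray = np.zeros((1, 2, 2))             # mid-gray in [-1, 1] space
print(chw_to_uint8_image(gray).shape)  # (2, 2, 3)
```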

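`render_png` builds its checkerboard with `np.kron`, which expands every entry of a small light/dark pattern into a 16x16 block. It then alpha-blends the RGBA image over that background as `alpha * rgb + (1 - alpha) * bg`. The sketch below sizes the pattern directly from the tile size instead of in 128-pixel units followed by a 4x4 `np.tile`. It should produce the same pixel pattern, but this is a reparameterization, and both function names are hypothetical.

```python
import numpy as np


def checkerboard(height, width, tile=16, light=136, dark=120):
    """Gray checkerboard of tile x tile squares, shaped (height, width, 1) like render_png's bg."""
    reps_y = height // (2 * tile) + 1  # enough two-tile periods to cover the height
    reps_x = width // (2 * tile) + 1
    pattern = np.array([[light, dark] * reps_x, [dark, light] * reps_x] * reps_y)
    board = np.kron(pattern, np.ones((tile, tile)))  # expand each entry into a tile x tile block
    return board[:height, :width, None]


def composite_over(rgba, bg):
    """Alpha-blend an RGBA uint8 image over a background: a * rgb + (1 - a) * bg."""
    rgba = rgba.astype(np.float32)
    alpha = rgba[:, :, 3:] / 255.0
    return (alpha * rgba[:, :, :3] + (1 - alpha) * bg).astype(np.uint8)


rgba = np.zeros((40, 50, 4), dtype=np.uint8)
rgba[:20, :, :3] = 255   # top half: white ...
rgba[:20, :, 3] = 255    # ... and fully opaque; bottom half stays transparent
out = composite_over(rgba, checkerboard(40, 50))
print(out[0, 0], out[30, 0])  # [255 255 255] [120 120 120]
```

Keeping the background's trailing axis of size 1 lets numpy broadcast it against the three color channels, as `render_png` does.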

================================================
FILE: third_party/util/visualizer.py
================================================
import cv2
import numpy as np
import os
import sys
import ntpath
import time
from . import util, html
from subprocess import Popen, PIPE


if sys.version_info[0] == 2:
    VisdomExceptionBase = Exception
else:
    VisdomExceptionBase = ConnectionError


def save_images(webpage, visuals, image_path, aspect_ratio=1.0, width=256):
    """Save images to the disk.

    Parameters:
        webpage (the HTML class) -- the HTML webpage class that stores these images (see html.py for more details)
        visuals (OrderedDict)    -- an ordered dictionary that stores (name, images (either tensor or numpy) ) pairs
        image_path (str list)    -- the file name (without extension) of the first path is used to name the saved images
        aspect_ratio (float)     -- the aspect ratio of saved images
        width (int)              -- the display width (in pixels) of the images on the HTML page

    This function will save images stored in 'visuals' to the HTML file specified by 'webpage'.
    """
    image_dir = webpage.get_image_dir()
    short_path = ntpath.basename(image_path[0])
    name = os.path.splitext(short_path)[0]

    webpage.add_header(name)
    ims, txts, links = [], [], []

    for label, im_data in visuals.items():
        im = util.tensor2im(im_data)
        image_name = '%s_%s.png' % (name, label)
        save_path = os.path.join(image_dir, image_name)
        util.save_image(im, save_path, aspect_ratio=aspect_ratio)
        ims.append(image_name)
        txts.append(label)
        links.append(image_name)
    webpage.add_images(ims, txts, links, width=width)


def save_videos(webpage, visuals, width=256):
    """Save videos to the disk.

    Parameters:
        webpage (the HTML class) -- the HTML webpage class that stores these videos (see html.py for more details)
        visuals (OrderedDict)    -- an ordered dictionary that stores (name, video tensor) pairs; each video has shape (frames, channels, height, width) with values in [-1, 1]
        width (int)              -- the display width (in pixels) of each video on the HTML page

    Each video is encoded as VP8 at 25 fps, saved as <name>.webm in the webpage's video directory, and added to the HTML file specified by 'webpage'.
    """
    video_dir = webpage.get_video_dir()
    webpage.add_header('videos')
    vids, txts, links = [], [], []

    for label, vid_data in sorted(visuals.items()):
        video_name = f'{label}.webm'
        video_path = os.path.join(video_dir, video_name)
        frame_height, frame_width = vid_data.shape[-2:]
        video = cv2.VideoWriter(video_path, cv2.VideoWriter_fourcc(*'vp80'), 25, (frame_width, frame_height))
        for i in range(vid_data.shape[0]):
            frame = util.tensor2im(vid_data[i:i+1])
            if frame.shape[-1] == 4:
                # render png
                frame = util.render_png(frame, background='checker')
            frame = frame[:, :, ::-1]  # RGB -> BGR
            video.write(frame)
        video.release()
        cv2.destroyAllWindows()
        print("You may see an OpenCV 'vp80 not supported' error message despite the video saving correctly. Please ignore it.")
        vids.append(video_name)
        txts.append(label)
        links.append(video_name)
    webpage.add_videos(vids, txts, links, width=width)


class Visualizer():
    """This class includes several functions that can display/save images and print/save logging information.

    It uses a Python library 'visdom' for display, and a Python library 'dominate' (wrapped in 'HTML') for creating HTML files with images.
    """

    def __init__(self, opt):
        """Initialize the Visualizer class

        Parameters:
            opt -- stores all the experiment flags; needs to be a subclass of BaseOptions
        Step 1: Cache the training/test options
        Step 2: connect to a visdom server
        Step 3: create the web directory for saving HTML files
        Step 4: create a logging file to store training losses
        """
        self.opt = opt  # cache the option
        self.display_id = opt.display_id
        self.use_html = opt.isTrain and not opt.no_html
        self.win_size = opt.display_winsize
        self.name = opt.name
        self.port = opt.display_port
        self.saved = False
        if self.display_id > 0:  # connect to a visdom server given <display_port> and <display_server>
            import visdom
            self.ncols = opt.display_ncols
            self.vis = visdom.Visdom(server=opt.display_server, port=opt.display_port, env=opt.display_env)
            if not self.vis.check_connection():
                self.create_visdom_connections()

        if self.use_html:  # create an HTML object at <checkpoints_dir>/web/; images will be saved under <checkpoints_dir>/web/images/
            self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web')
            self.img_dir = os.path.join(self.web_dir, 'images')
            print('create web directory %s...' % self.web_dir)
            util.mkdirs([self.web_dir, self.img_dir])
        # create a logging file to store training losses
        self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt')
        with open(self.log_name, "a") as log_file:
            now = time.strftime("%c")
            log_file.write('================ Training Loss (%s) ================\n' % now)

    def reset(self):
        """Reset the self.saved status"""
        self.saved = False

    def create_visdom_connections(self):
        """If the program could not connect to Visdom server, this function will start a new server at port < self.port > """
        cmd = sys.executable + ' -m visdom.server -p %d &>/dev/null &' % self.port
        print('\n\nCould not connect to Visdom server. \n Trying to start a server....')
        print('Command: %s' % cmd)
        Popen(cmd, shell=True, stdout=PIPE, stderr=PIPE)

    def display_current_results(self, visuals, epoch, save_result):
        """Display current results on visdom; save current results to an HTML file.

        Parameters:
            visuals (OrderedDict) -- dictionary of images to display or save
            epoch (int)           -- the current epoch
            save_result (bool)    -- whether to save the current results to an HTML file
        """
        if self.display_id > 0:  # show images in the browser using visdom
            ncols = self.ncols
            if ncols > 0:        # show all the images in one visdom panel
                ncols = min(ncols, len(visuals))
                h, w = next(iter(visuals.values())).shape[:2]
                table_css = """<style>
                        table {border-collapse: separate; border-spacing: 4px; white-space: nowrap; text-align: center}
                        table td {width: %dpx; height: %dpx; padding: 4px; outline: 4px solid black}
                        </style>""" % (w, h)  # create a table css
                # create a table of images.
                title = self.name
                label_html = ''
                label_html_row = ''
                images = []
                idx = 0
                for label, image in visuals.items():
                    image_numpy = util.tensor2im(image)
                    label_html_row += '<td>%s</td>' % label
                    images.append(image_numpy.transpose([2, 0, 1]))
                    idx += 1
                    if idx % ncols == 0:
                        label_html += '<tr>%s</tr>' % label_html_row
                        label_html_row = ''
                white_image = np.ones_like(image_numpy.transpose([2, 0, 1])) * 255
                while idx % ncols != 0:
                    images.append(white_image)
                    label_html_row += '<td></td>'
                    idx += 1
                if label_html_row != '':
                    label_html += '<tr>%s</tr>' % label_html_row
                try:
                    self.vis.images(images, nrow=ncols, win=self.display_id + 1,
                                    padding=2, opts=dict(title=title + ' images'))
                    label_html = '<table>%s</table>' % label_html
                    self.vis.text(table_css + label_html, win=self.display_id + 2,
                                  opts=dict(title=title + ' labels'))
                except VisdomExceptionBase:
                    self.create_visdom_connections()

            else:     # show each image in a separate visdom panel;
                idx = 1
                try:
                    for label, image in visuals.items():
                        image_numpy = util.tensor2im(image)
                        self.vis.image(image_numpy.transpose([2, 0, 1]), opts=dict(title=label),
                                       win=self.display_id + idx)
                        idx += 1
                except VisdomExceptionBase:
                    self.create_visdom_connections()

        if self.use_html and (save_result or not self.saved):  # save images to an HTML file if they haven't been saved.
            self.saved = True
            # save images to the disk
            for label, image in visuals.items():
                image_numpy = util.tensor2im(image)
                img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.png' % (epoch, label))
                util.save_image(image_numpy, img_path)

            # update website
            webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, refresh=1)
            for n in range(epoch, 0, -1):
                label = list(visuals.keys())[0]
                img_path = 'epoch%.3d_%s.png' % (n, label)
                if not os.path.exists(os.path.join(webpage.img_dir, img_path)):
                    continue
                webpage.add_header('epoch [%d]' % n)
                ims, txts, links = [], [], []

                for label, image_numpy in visuals.items():
                    img_path = 'epoch%.3d_%s.png' % (n, label)
                    ims.append(img_path)
                    txts.append(label)
                    links.append(img_path)
                webpage.add_images(ims, txts, links, width=self.win_size)
            webpage.save()

    def plot_current_losses(self, epoch, counter_ratio, losses):
        """display the current losses on visdom display: dictionary of error labels and values

        Parameters:
            epoch (int)           -- current epoch
            counter_ratio (float) -- progress (fraction) in the current epoch, between 0 and 1
            losses (OrderedDict)  -- training losses stored in the format of (name, float) pairs
        """
        if not hasattr(self, 'plot_data'):
            self.plot_data = {'X': [], 'Y': [], 'legend': list(losses.keys())}
        self.plot_data['X'].append(epoch + counter_ratio)
        self.plot_data['Y'].append([losses[k] for k in self.plot_data['legend']])
        try:
            self.vis.line(
                X=np.stack([np.array(self.plot_data['X'])] * len(self.plot_data['legend']), 1),
                Y=np.array(self.plot_data['Y']),
                opts={
                    'title': self.name + ' loss over time',
                    'legend': self.plot_data['legend'],
                    'xlabel': 'epoch',
                    'ylabel': 'loss'},
                win=self.display_id)
        except VisdomExceptionBase:
            self.create_visdom_connections()

    # losses: same format as |losses| of plot_current_losses
    def print_current_losses(self, epoch, iters, losses, t_comp, t_data):
        """print current losses on console; also save the losses to the disk

        Parameters:
            epoch (int) -- current epoch
            iters (int) -- current training iteration during this epoch (reset to 0 at the end of every epoch)
            losses (OrderedDict) -- training losses stored in the format of (name, float) pairs
            t_comp (float) -- computational time per data point (normalized by batch_size)
            t_data (float) -- data loading time per data point (normalized by batch_size)
        """
        message = '(epoch: %d, iters: %d, time: %.3f, data: %.3f) ' % (epoch, iters, t_comp, t_data)
        for k, v in losses.items():
            message += '%s: %.3f ' % (k, v)

        print(message)  # print the message
        with open(self.log_name, "a") as log_file:
            log_file.write('%s\n' % message)  # save the message
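`print_current_losses` appends one line per logged iteration to `loss_log.txt`, formatted as a fixed header followed by `name: value` pairs. The sketch below rebuilds that string with the same format specifiers and parses it back, which can help when plotting losses without visdom. `format_loss_message` and `parse_loss_message` are hypothetical helpers. Values read back are rounded to three decimals, because that is all the log stores.

```python
import re

# Header written by Visualizer.print_current_losses, e.g. "(epoch: 3, iters: 40, time: 0.010, data: 0.002) "
HEADER = re.compile(r'\(epoch: (\d+), iters: (\d+), time: ([-\d.]+), data: ([-\d.]+)\) ')


def format_loss_message(epoch, iters, losses, t_comp, t_data):
    """Rebuild the message string exactly as print_current_losses formats it."""
    message = '(epoch: %d, iters: %d, time: %.3f, data: %.3f) ' % (epoch, iters, t_comp, t_data)
    for k, v in losses.items():
        message += '%s: %.3f ' % (k, v)
    return message


def parse_loss_message(line):
    """Hypothetical helper: turn one loss_log.txt line back into (epoch, iters, losses)."""
    match = HEADER.match(line)
    if match is None:  # e.g. the '==== Training Loss (...) ====' banner lines
        return None
    tokens = line[match.end():].split()  # alternating 'name:' and value tokens
    losses = {tokens[j].rstrip(':'): float(tokens[j + 1]) for j in range(0, len(tokens) - 1, 2)}
    return int(match.group(1)), int(match.group(2)), losses


line = format_loss_message(3, 40, {'total': 0.12345, 'recon': 0.5}, 0.01, 0.002)
print(parse_loss_message(line))  # (3, 40, {'total': 0.123, 'recon': 0.5})
```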

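When `display_ncols > 0`, `display_current_results` lays the images out in `ncols` columns and pads the last row with white images and empty `<td>` cells so that the label table stays aligned with the image grid. The sketch below isolates the label-table logic. `label_table_html` is a hypothetical name, and, as in the original, labels are inserted without HTML escaping.

```python
def label_table_html(labels, ncols):
    """Group labels into rows of ncols cells, padding the last row with empty cells."""
    rows, row = [], []
    for label in labels:
        row.append('<td>%s</td>' % label)
        if len(row) == ncols:
            rows.append('<tr>%s</tr>' % ''.join(row))
            row = []
    if row:  # pad the final partial row, as the white_image loop does for the image grid
        row.extend(['<td></td>'] * (ncols - len(row)))
        rows.append('<tr>%s</tr>' % ''.join(row))
    return '<table>%s</table>' % ''.join(rows)


print(label_table_html(['real', 'fake', 'alpha'], ncols=2))
# <table><tr><td>real</td><td>fake</td></tr><tr><td>alpha</td><td></td></tr></table>
```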

================================================
FILE: train.py
================================================
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Script for training a layered neural renderer on a video.

You need to specify the dataset ('--dataroot') and experiment name ('--name').

Example:
    python train.py --dataroot ./datasets/reflection --name reflection --gpu_ids 0,1

The script first creates a model, dataset, and visualizer given the options.
It then does standard network training. During training, it also visualizes/saves the images, prints/saves the loss
plot, and saves the model.
Use '--continue_train' to resume your previous training.

The default setting is to first train the base model, which produces the low-resolution result (256x448), and then
train the upsampling module to produce the 512x896 result. If the upsampling module is unnecessary, use
'--n_epochs_upsample 0'.

See options/base_options.py and options/train_options.py for more training options.
"""
import time
from options.train_options import TrainOptions
from third_party.data import create_dataset
from third_party.models import create_model
from third_party.util.visualizer import Visualizer
import torch
import numpy as np


def main():
    trainopt = TrainOptions()
    trainopt.parse()
    opt = trainopt.parse_dataset_meta()

    torch.manual_seed(opt.seed)
    np.random.seed(opt.seed)

    opt.do_upsampling = False  # Train low-res network first
    dataset = create_dataset(opt, use_fast_loader=True)
    dataset_size = len(dataset)
    print('The number of training images = %d' % dataset_size)

    model = create_model(opt)
    model.setup(opt)  # regular setup: load and print networks; create schedulers
    visualizer = Visualizer(opt)

    # Train base model (produces low-resolution output)
    train(model, dataset, visualizer, opt)

    # Optionally train upsampling module
    if opt.n_epochs_upsample > 0:
        opt.do_upsampling = True
        opt.batch_size = opt.batch_size_upsample
        # load dataset for upsampling
        dataset = create_dataset(opt, use_fast_loader=True)
        dataset_size = len(dataset)
        print('The number of training images = %d' % dataset_size)

        # set lambdas for upsampling training
        opt.lambda_mask = 0
        opt.lambda_alpha_l0 = 0
        opt.lambda_alpha_l1 = 0
        opt.mask_loss_rolloff_epoch = -1
        opt.jitter_rgb = 0

        # reinit optimizers and schedulers, lambdas
        model.setup_train(opt)
        # freeze base model and just train upsampling module
        model.freeze_basenet()
        model.setup(opt)

        # update epoch count to resume training
        opt.epoch_count = opt.n_epochs + opt.n_epochs_decay + 1
        opt.n_epochs += opt.n_epochs_upsample
        
        train(model, dataset, visualizer, opt)


def train(model, dataset, visualizer, opt):
    dataset_size = len(dataset)
    total_iters = 0  # the total number of training iterations

    for epoch in range(opt.epoch_count, opt.n_epochs + opt.n_epochs_decay + 1):  # outer loop over epochs; the model is saved every <save_latest_freq> epochs
        epoch_start_time = time.time()  # timer for entire epoch
        iter_data_time = time.time()    # timer for data loading per iteration
        epoch_iter = 0                  # the number of training iterations in current epoch, reset to 0 every epoch
        model.update_lambdas(epoch)
        for i, data in enumerate(dataset):  # inner loop within one epoch
            iter_start_time = time.time()  # timer for computation per iteration
            if i % opt.print_freq == 0:
                t_data = iter_start_time - iter_data_time

            total_iters += opt.batch_size
            epoch_iter += opt.batch_size
            model.set_input(data)
            model.optimize_parameters()

            if i % opt.print_freq == 0:  # print training losses and save logging information to the disk
                losses = model.get_current_losses()
                t_comp = (time.time() - iter_start_time) / opt.batch_size
                visualizer.print_current_losses(epoch, epoch_iter, losses, t_comp, t_data)
                if opt.display_id > 0:
                    visualizer.plot_current_losses(epoch, float(epoch_iter) / dataset_size, losses)

            iter_data_time = time.time()

        if (epoch - 1) % opt.display_freq == 0:   # display images on visdom and save images to an HTML file at epochs 1, 1 + display_freq, ...
            save_result = (epoch - 1) % opt.update_html_freq == 0
            model.compute_visuals()
            visualizer.display_current_results(model.get_current_visuals(), epoch, save_result)

        if epoch % opt.save_latest_freq == 0:   # cache our latest model every <save_latest_freq> epochs
            print('saving the latest model (epoch %d, total_iters %d)' % (epoch, total_iters))
            save_suffix = 'epoch_%d' % epoch if opt.save_by_epoch else 'latest'
            model.save_networks(save_suffix)

        model.update_learning_rate()    # update learning rates at the end of every epoch.
        print('End of epoch %d / %d \t Time Taken: %d sec' % (epoch, opt.n_epochs + opt.n_epochs_decay, time.time() - epoch_start_time))


if __name__ == '__main__':
    main()
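`main()` runs `train()` twice when `n_epochs_upsample > 0`. Before the second call it sets `epoch_count` to one past the last base epoch and increases `n_epochs` by `n_epochs_upsample`. The sketch below reproduces that bookkeeping to show which epoch numbers each stage visits. `stage_epoch_ranges` is a hypothetical helper, and the argument values are illustrative only. The real defaults live in options/train_options.py.

```python
def stage_epoch_ranges(n_epochs, n_epochs_decay, n_epochs_upsample, epoch_count=1):
    """Epochs visited by each call to train(), following the bookkeeping in main()."""
    ranges = [range(epoch_count, n_epochs + n_epochs_decay + 1)]  # base (low-res) stage
    if n_epochs_upsample > 0:
        epoch_count = n_epochs + n_epochs_decay + 1  # resume right after the base stage
        n_epochs += n_epochs_upsample
        ranges.append(range(epoch_count, n_epochs + n_epochs_decay + 1))  # upsampling stage
    return ranges


# Illustrative values only; see options/train_options.py for the real defaults.
base, upsample = stage_epoch_ranges(n_epochs=100, n_epochs_decay=20, n_epochs_upsample=50)
print(base.start, base.stop - 1, upsample.start, upsample.stop - 1)  # 1 120 121 170
```

The upsampling stage therefore runs exactly `n_epochs_upsample` epochs. Its epoch numbers continue where the base stage stopped, so checkpoints saved as `epoch_%d` (when `opt.save_by_epoch` is set) do not overwrite the base-stage ones.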
SYMBOL INDEX (158 symbols across 20 files)

FILE: data/kpuv_dataset.py
  class KpuvDataset (line 26) | class KpuvDataset(BaseDataset):
    method modify_commandline_options (line 32) | def modify_commandline_options(parser, is_train):
    method __init__ (line 36) | def __init__(self, opt):
    method __getitem__ (line 67) | def __getitem__(self, index):
    method __len__ (line 89) | def __len__(self):
    method crop_kps (line 93) | def crop_kps(self, kps, crop_size=256, inner_crop_size=192):
    method render_kps (line 124) | def render_kps(self, keypoints, draw, thresh=1., min_weight=0.25):

FILE: data/layered_video_dataset.py
  class LayeredVideoDataset (line 27) | class LayeredVideoDataset(BaseDataset):
    method modify_commandline_options (line 34) | def modify_commandline_options(parser, is_train):
    method __init__ (line 42) | def __init__(self, opt):
    method __getitem__ (line 75) | def __getitem__(self, index):
    method __len__ (line 124) | def __len__(self):
    method get_params (line 128) | def get_params(self, do_jitter=False, jitter_rate=0.75):
    method apply_transform (line 146) | def apply_transform(self, data, params, interp_mode='bilinear'):
    method init_homographies (line 155) | def init_homographies(self, homography_path, n_images):
    method load_and_process_image (line 171) | def load_and_process_image(self, im_path):
    method load_and_process_iuv (line 178) | def load_and_process_iuv(self, iuv_path, i):
    method iuv2input (line 185) | def iuv2input(self, iuv, index):
    method get_background_inputs (line 234) | def get_background_inputs(self, index, w, h):
    method get_background_uv (line 246) | def get_background_uv(self, index, w, h):
    method transform2h (line 268) | def transform2h(self, x, y, m):
    method mask2trimap (line 275) | def mask2trimap(self, mask):

FILE: datasets/iuv_crop2full.py
  function place_crop (line 23) | def place_crop(crop, image, center_x, center_y):
  function crop2full (line 48) | def crop2full(keypoints_path, metadata_path, uvdir, outdir):

FILE: models/kp2uv_model.py
  class Kp2uvModel (line 20) | class Kp2uvModel(BaseModel):
    method modify_commandline_options (line 23) | def modify_commandline_options(parser, is_train=True):
    method __init__ (line 27) | def __init__(self, opt):
    method set_input (line 41) | def set_input(self, input):
    method forward (line 50) | def forward(self):
    method output2rgb (line 55) | def output2rgb(self, output):
    method optimize_parameters (line 70) | def optimize_parameters(self):

FILE: models/lnr_model.py
  class LnrModel (line 22) | class LnrModel(BaseModel):
    method modify_commandline_options (line 25) | def modify_commandline_options(parser, is_train=True):
    method __init__ (line 46) | def __init__(self, opt):
    method setup_train (line 63) | def setup_train(self, opt):
    method set_input (line 83) | def set_input(self, input):
    method gen_crop_params (line 99) | def gen_crop_params(self, orig_h, orig_w, crop_size=256):
    method forward (line 107) | def forward(self):
    method backward (line 133) | def backward(self):
    method optimize_parameters (line 147) | def optimize_parameters(self):
    method update_lambdas (line 154) | def update_lambdas(self, epoch):
    method transfer_detail (line 168) | def transfer_detail(self):
    method get_results (line 181) | def get_results(self):
    method freeze_basenet (line 201) | def freeze_basenet(self):

FILE: models/networks.py
  function define_LNR (line 25) | def define_LNR(nf=64, texture_channels=16, texture_res=16, n_textures=25...
  function define_kp2uv (line 41) | def define_kp2uv(nf=64, gpu_ids=[]):
  function cal_alpha_reg (line 53) | def cal_alpha_reg(prediction, lambda_alpha_l1, lambda_alpha_l0):
  class MaskLoss (line 77) | class MaskLoss(nn.Module):
    method __init__ (line 80) | def __init__(self):
    method __call__ (line 84) | def __call__(self, prediction, target):
  class ConvBlock (line 104) | class ConvBlock(nn.Module):
    method __init__ (line 107) | def __init__(self, conv, in_channels, out_channels, ksize=4, stride=1,...
    method forward (line 140) | def forward(self, x):
  class ResBlock (line 169) | class ResBlock(nn.Module):
    method __init__ (line 172) | def __init__(self, channels, ksize=4, stride=1, dil=1, norm=None, acti...
    method forward (line 180) | def forward(self, x):
  class kp2uv (line 188) | class kp2uv(nn.Module):
    method __init__ (line 194) | def __init__(self, nf=64):
    method forward (line 217) | def forward(self, x):
  class LayeredNeuralRenderer (line 240) | class LayeredNeuralRenderer(nn.Module):
    method __init__ (line 246) | def __init__(self, nf=64, texture_channels=16, texture_res=16, n_textu...
    method render (line 285) | def render(self, x):
    method forward (line 304) | def forward(self, uv_map, id_layers, uv_map_upsampled=None, crop_param...

FILE: options/base_options.py
  class BaseOptions (line 24) | class BaseOptions():
    method __init__ (line 31) | def __init__(self):
    method initialize (line 35) | def initialize(self, parser):
    method gather_options (line 60) | def gather_options(self):
    method print_options (line 88) | def print_options(self, opt):
    method parse (line 113) | def parse(self):
    method parse_dataset_meta (line 138) | def parse_dataset_meta(self):

FILE: options/test_options.py
  class TestOptions (line 18) | class TestOptions(BaseOptions):
    method initialize (line 24) | def initialize(self, parser):

FILE: options/train_options.py
  class TrainOptions (line 18) | class TrainOptions(BaseOptions):
    method initialize (line 24) | def initialize(self, parser):

FILE: third_party/data/__init__.py
  function find_dataset_using_name (line 19) | def find_dataset_using_name(dataset_name):
  function get_option_setter (line 42) | def get_option_setter(dataset_name):
  function create_dataset (line 48) | def create_dataset(opt, use_fast_loader=False):
  class CustomDatasetDataLoader (line 65) | class CustomDatasetDataLoader():
    method __init__ (line 68) | def __init__(self, opt, use_fast_loader=False):
    method load_data (line 89) | def load_data(self):
    method __len__ (line 92) | def __len__(self):
    method __iter__ (line 96) | def __iter__(self):

FILE: third_party/data/base_dataset.py
  class BaseDataset (line 13) | class BaseDataset(data.Dataset, ABC):
    method __init__ (line 23) | def __init__(self, opt):
    method modify_commandline_options (line 33) | def modify_commandline_options(parser, is_train):
    method __len__ (line 46) | def __len__(self):
    method __getitem__ (line 51) | def __getitem__(self, index):
  function get_params (line 63) | def get_params(opt, size):
  function get_transform (line 81) | def get_transform(opt, params=None, grayscale=False, method=Image.BICUBI...
  function __make_power_2 (line 115) | def __make_power_2(img, base, method=Image.BICUBIC):
  function __scale_width (line 126) | def __scale_width(img, target_size, crop_size, method=Image.BICUBIC):
  function __crop (line 135) | def __crop(img, pos, size):
  function __flip (line 144) | def __flip(img, flip):
  function __print_size_warning (line 150) | def __print_size_warning(ow, oh, w, h):

FILE: third_party/data/fast_data_loader.py
  class _RepeatSampler (line 8) | class _RepeatSampler(object):
    method __init__ (line 15) | def __init__(self, sampler):
    method __iter__ (line 18) | def __iter__(self):
  class FastDataLoader (line 23) | class FastDataLoader(torch.utils.data.dataloader.DataLoader):
    method __init__ (line 25) | def __init__(self, *args, **kwargs):
    method __len__ (line 30) | def __len__(self):
    method __iter__ (line 33) | def __iter__(self):

FILE: third_party/data/image_folder.py
  function is_image_file (line 19) | def is_image_file(filename):
  function make_dataset (line 23) | def make_dataset(dir, max_dataset_size=float("inf")):
  function default_loader (line 36) | def default_loader(path):
  class ImageFolder (line 40) | class ImageFolder(data.Dataset):
    method __init__ (line 42) | def __init__(self, root, transform=None, return_paths=False,
    method __getitem__ (line 55) | def __getitem__(self, index):
    method __len__ (line 65) | def __len__(self):

FILE: third_party/models/__init__.py
  function find_model_using_name (line 25) | def find_model_using_name(model_name):
  function get_option_setter (line 48) | def get_option_setter(model_name):
  function create_model (line 54) | def create_model(opt):

FILE: third_party/models/base_model.py
  class BaseModel (line 8) | class BaseModel(ABC):
    method __init__ (line 18) | def __init__(self, opt):
    method modify_commandline_options (line 45) | def modify_commandline_options(parser, is_train):
    method set_input (line 58) | def set_input(self, input):
    method forward (line 67) | def forward(self):
    method optimize_parameters (line 72) | def optimize_parameters(self):
    method setup (line 76) | def setup(self, opt):
    method eval (line 89) | def eval(self):
    method test (line 96) | def test(self):
    method compute_visuals (line 106) | def compute_visuals(self):
    method get_image_paths (line 110) | def get_image_paths(self):
    method update_learning_rate (line 114) | def update_learning_rate(self):
    method get_current_visuals (line 127) | def get_current_visuals(self):
    method get_current_losses (line 135) | def get_current_losses(self):
    method save_networks (line 143) | def save_networks(self, epoch):
    method __patch_instance_norm_state_dict (line 161) | def __patch_instance_norm_state_dict(self, state_dict, module, keys, i...
    method load_networks (line 175) | def load_networks(self, epoch):
    method print_networks (line 200) | def print_networks(self, verbose):
    method set_requires_grad (line 222) | def set_requires_grad(self, nets, requires_grad=False):

FILE: third_party/models/networks.py
  function get_scheduler (line 9) | def get_scheduler(optimizer, opt):
  function init_net (line 39) | def init_net(net, gpu_ids=[]):

FILE: third_party/util/html.py
  class HTML (line 6) | class HTML:
    method __init__ (line 14) | def __init__(self, web_dir, title, refresh=0):
    method get_image_dir (line 38) | def get_image_dir(self):
    method get_video_dir (line 42) | def get_video_dir(self):
    method add_header (line 46) | def add_header(self, text):
    method add_images (line 55) | def add_images(self, ims, txts, links, width=400):
    method add_videos (line 75) | def add_videos(self, vids, txts, links, width=400):
    method save (line 96) | def save(self):

FILE: third_party/util/util.py
  function tensor2im (line 9) | def tensor2im(input_image, imtype=np.uint8):
  function render_png (line 30) | def render_png(image, background='checker'):
  function diagnose_network (line 46) | def diagnose_network(net, name='network'):
  function save_image (line 65) | def save_image(image_numpy, image_path, aspect_ratio=1.0):
  function print_numpy (line 83) | def print_numpy(x, val=True, shp=False):
  function mkdirs (line 99) | def mkdirs(paths):
  function mkdir (line 112) | def mkdir(path):

FILE: third_party/util/visualizer.py
  function save_images (line 17) | def save_images(webpage, visuals, image_path, aspect_ratio=1.0, width=256):
  function save_videos (line 47) | def save_videos(webpage, visuals, width=256):
  class Visualizer (line 84) | class Visualizer():
    method __init__ (line 90) | def __init__(self, opt):
    method reset (line 125) | def reset(self):
    method create_visdom_connections (line 129) | def create_visdom_connections(self):
    method display_current_results (line 136) | def display_current_results(self, visuals, epoch, save_result):
    method plot_current_losses (line 220) | def plot_current_losses(self, epoch, counter_ratio, losses):
    method print_current_losses (line 246) | def print_current_losses(self, epoch, iters, losses, t_comp, t_data):

FILE: train.py
  function main (line 42) | def main():
  function train (line 91) | def train(model, dataset, visualizer, opt):
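`train()` calls `model.update_learning_rate()` once per epoch and reports progress out of `n_epochs + n_epochs_decay` epochs, which suggests a constant-then-decay schedule built by `get_scheduler` in `third_party/models/networks.py`. That file appears to be adapted from the pytorch-CycleGAN-and-pix2pix code base, and the sketch below assumes that code base's `linear` policy; the actual policy, and whether an `epoch_count` option exists here, should be checked in `get_scheduler` itself. `linear_decay_factor` is a hypothetical name.

```python
def linear_decay_factor(epoch, n_epochs, n_epochs_decay, epoch_count=1):
    """Multiplicative learning-rate factor for 0-based scheduler step `epoch`.

    Assumed pix2pix-style 'linear' policy: keep the base learning rate for
    n_epochs epochs, then decay it linearly over n_epochs_decay epochs.
    epoch_count is the (assumed) 1-based epoch that training starts from.
    """
    return 1.0 - max(0, epoch + epoch_count - n_epochs) / float(n_epochs_decay + 1)


if __name__ == '__main__':
    # 2 constant epochs followed by 3 decaying epochs.
    for step in range(2 + 3):
        print(step, linear_decay_factor(step, n_epochs=2, n_epochs_decay=3))
```

In PyTorch, a factor like this would typically be passed as `lr_lambda` to `torch.optim.lr_scheduler.LambdaLR`. The `+ 1` in the denominator keeps the factor strictly positive on the final epoch instead of driving the learning rate to exactly zero.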
Condensed preview: 40 files, each showing path, character count, and a content snippet.
[
  {
    "path": "LICENSE",
    "chars": 11358,
    "preview": "\n                                 Apache License\n                           Version 2.0, January 2004\n                  "
  },
  {
    "path": "README.md",
    "chars": 4658,
    "preview": "# Layered Neural Rendering in PyTorch\n\nThis repository contains training code for the examples in the SIGGRAPH Asia 2020"
  },
  {
    "path": "data/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "data/kpuv_dataset.py",
    "chars": 6995,
    "preview": "# Copyright 2020 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
  },
  {
    "path": "data/layered_video_dataset.py",
    "chars": 12816,
    "preview": "# Copyright 2020 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
  },
  {
    "path": "datasets/download_data.sh",
    "chars": 1233,
    "preview": "#!/bin/bash\n#\n# Copyright 2020 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may "
  },
  {
    "path": "datasets/iuv_crop2full.py",
    "chars": 5698,
    "preview": "# Copyright 2020 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
  },
  {
    "path": "datasets/prepare_iuv.sh",
    "chars": 831,
    "preview": "#!/bin/bash\n#\n# Copyright 2020 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may "
  },
  {
    "path": "docs/contributing.md",
    "chars": 1102,
    "preview": "# How to Contribute\n\nWe'd love to accept your patches and contributions to this project. There are\njust a few small guid"
  },
  {
    "path": "docs/data.md",
    "chars": 923,
    "preview": "### Data\nThe data directory for a video is structured as follows:\n```\nvideo_name/\n|-- rgb_256/\n|   |-- 0001.png, etc.\n|-"
  },
  {
    "path": "environment.yml",
    "chars": 238,
    "preview": "name: retiming\nchannels:\n  - pytorch\n  - defaults\ndependencies:\n- python=3.8\n- pytorch=1.4.0\n- pip:\n  - dominate==2.4.0\n"
  },
  {
    "path": "models/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "models/kp2uv_model.py",
    "chars": 2722,
    "preview": "# Copyright 2020 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
  },
  {
    "path": "models/lnr_model.py",
    "chars": 10862,
    "preview": "# Copyright 2020 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
  },
  {
    "path": "models/networks.py",
    "chars": 16657,
    "preview": "# Copyright 2020 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
  },
  {
    "path": "options/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "options/base_options.py",
    "chars": 6888,
    "preview": "# Copyright 2020 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
  },
  {
    "path": "options/test_options.py",
    "chars": 1342,
    "preview": "# Copyright 2020 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
  },
  {
    "path": "options/train_options.py",
    "chars": 3389,
    "preview": "# Copyright 2020 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
  },
  {
    "path": "requirements.txt",
    "chars": 101,
    "preview": "torch==1.4.0\ntorchvision>=0.5.0\ndominate>=2.4.0\nvisdom>=0.1.8\nmatplotlib>=3.2.1\nopencv-python>=4.2.0\n"
  },
  {
    "path": "run_kp2uv.py",
    "chars": 2867,
    "preview": "# Copyright 2020 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
  },
  {
    "path": "scripts/download_kp2uv_model.sh",
    "chars": 775,
    "preview": "#!/bin/bash\n#\n# Copyright 2020 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may "
  },
  {
    "path": "scripts/run_cartwheel.sh",
    "chars": 880,
    "preview": "#!/bin/bash\n#\n# Copyright 2020 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may "
  },
  {
    "path": "scripts/run_reflection.sh",
    "chars": 837,
    "preview": "#!/bin/bash\n#\n# Copyright 2020 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may "
  },
  {
    "path": "scripts/run_splash.sh",
    "chars": 896,
    "preview": "#!/bin/bash\n#\n# Copyright 2020 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may "
  },
  {
    "path": "scripts/run_trampoline.sh",
    "chars": 885,
    "preview": "#!/bin/bash\n#\n# Copyright 2020 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may "
  },
  {
    "path": "test.py",
    "chars": 3883,
    "preview": "# Copyright 2020 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
  },
  {
    "path": "third_party/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "third_party/data/__init__.py",
    "chars": 3968,
    "preview": "\"\"\"This package includes all the modules related to data loading and preprocessing\n\n To add a custom dataset class calle"
  },
  {
    "path": "third_party/data/base_dataset.py",
    "chars": 5400,
    "preview": "\"\"\"This module implements an abstract base class (ABC) 'BaseDataset' for datasets.\n\nIt also includes common transformati"
  },
  {
    "path": "third_party/data/fast_data_loader.py",
    "chars": 903,
    "preview": "\"\"\" Fixes the issue where DataLoader is slow because processes aren't reused\nSee https://github.com/pytorch/pytorch/issu"
  },
  {
    "path": "third_party/data/image_folder.py",
    "chars": 1913,
    "preview": "\"\"\"A modified image folder class\n\nWe modify the official PyTorch image folder (https://github.com/pytorch/vision/blob/ma"
  },
  {
    "path": "third_party/models/__init__.py",
    "chars": 3066,
    "preview": "\"\"\"This package contains modules related to objective functions, optimizations, and network architectures.\n\nTo add a cus"
  },
  {
    "path": "third_party/models/base_model.py",
    "chars": 10446,
    "preview": "import os\nimport torch\nfrom collections import OrderedDict\nfrom abc import ABC, abstractmethod\nfrom . import networks\n\n\n"
  },
  {
    "path": "third_party/models/networks.py",
    "chars": 2250,
    "preview": "import torch\nimport torch.nn as nn\nfrom torch.optim import lr_scheduler\n\n\n##############################################"
  },
  {
    "path": "third_party/util/__init__.py",
    "chars": 83,
    "preview": "\"\"\"This package includes a miscellaneous collection of useful helper functions.\"\"\"\n"
  },
  {
    "path": "third_party/util/html.py",
    "chars": 4532,
    "preview": "import dominate\nfrom dominate.tags import meta, h3, table, tr, td, p, a, img, br, video, source\nimport os\n\n\nclass HTML:\n"
  },
  {
    "path": "third_party/util/util.py",
    "chars": 3843,
    "preview": "\"\"\"This module contains simple helper functions \"\"\"\nfrom __future__ import print_function\nimport torch\nimport numpy as n"
  },
  {
    "path": "third_party/util/visualizer.py",
    "chars": 12355,
    "preview": "import cv2\nimport numpy as np\nimport os\nimport sys\nimport ntpath\nimport time\nfrom . import util, html\nfrom subprocess im"
  },
  {
    "path": "train.py",
    "chars": 5691,
    "preview": "# Copyright 2020 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
  }
]
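The preview of `third_party/data/fast_data_loader.py` says it fixes slow `DataLoader`s caused by worker processes not being reused. The common pattern behind a `_RepeatSampler`/`FastDataLoader` pair like the one in the symbol index is to wrap the sampler so it never runs out, create a single iterator up front, and have each epoch pull exactly `len(...)` items from it; in PyTorch, that long-lived iterator is presumably what keeps the workers alive between epochs. The sketch below reproduces the idea without PyTorch, so `RepeatSampler` and `PersistentLoader` are hypothetical stand-ins rather than the file's actual code.

```python
class RepeatSampler:
    """Yield items from a finite sampler forever, restarting it on each pass."""

    def __init__(self, sampler):
        self.sampler = sampler

    def __iter__(self):
        while True:
            # A fresh iter() per pass, so a shuffling sampler reshuffles.
            yield from iter(self.sampler)


class PersistentLoader:
    """Torch-free analogue of a loader that reuses one iterator across epochs."""

    def __init__(self, sampler):
        self.sampler = sampler
        # Created once; never exhausted because RepeatSampler is infinite.
        self._iterator = iter(RepeatSampler(sampler))

    def __len__(self):
        return len(self.sampler)

    def __iter__(self):
        # One "epoch" = exactly len(sampler) items from the shared iterator.
        for _ in range(len(self)):
            yield next(self._iterator)


if __name__ == '__main__':
    loader = PersistentLoader([0, 1, 2])
    for epoch in range(2):
        print(epoch, list(loader))
```

One side effect of this design: an epoch abandoned midway leaves the shared iterator mid-pass, so the next epoch resumes from that point rather than from the start of the sampler.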

About this extraction

This page contains the full source code of the google/retiming GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 40 files (149.7 KB), approximately 37.2k tokens, and a symbol index with 158 extracted functions, classes, methods, constants, and types.

Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.