Repository: google/retiming
Branch: main
Commit: 0e7dbce941b8
Files: 40
Total size: 149.7 KB
Directory structure:
gitextract_wurthk5z/
├── LICENSE
├── README.md
├── data/
│ ├── __init__.py
│ ├── kpuv_dataset.py
│ └── layered_video_dataset.py
├── datasets/
│ ├── download_data.sh
│ ├── iuv_crop2full.py
│ └── prepare_iuv.sh
├── docs/
│ ├── contributing.md
│ └── data.md
├── environment.yml
├── models/
│ ├── __init__.py
│ ├── kp2uv_model.py
│ ├── lnr_model.py
│ └── networks.py
├── options/
│ ├── __init__.py
│ ├── base_options.py
│ ├── test_options.py
│ └── train_options.py
├── requirements.txt
├── run_kp2uv.py
├── scripts/
│ ├── download_kp2uv_model.sh
│ ├── run_cartwheel.sh
│ ├── run_reflection.sh
│ ├── run_splash.sh
│ └── run_trampoline.sh
├── test.py
├── third_party/
│ ├── __init__.py
│ ├── data/
│ │ ├── __init__.py
│ │ ├── base_dataset.py
│ │ ├── fast_data_loader.py
│ │ └── image_folder.py
│ ├── models/
│ │ ├── __init__.py
│ │ ├── base_model.py
│ │ └── networks.py
│ └── util/
│ ├── __init__.py
│ ├── html.py
│ ├── util.py
│ └── visualizer.py
└── train.py
================================================
FILE CONTENTS
================================================
================================================
FILE: LICENSE
================================================
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
================================================
FILE: README.md
================================================
# Layered Neural Rendering in PyTorch
This repository contains training code for the examples in the SIGGRAPH Asia 2020 paper "[Layered Neural Rendering for Retiming People in Video](https://retiming.github.io/)."
This is not an officially supported Google product.
## Prerequisites
- Linux
- Python 3.6+
- NVIDIA GPU + CUDA CuDNN
## Installation
This code has been tested with PyTorch 1.4 and Python 3.8.
- Install [PyTorch](http://pytorch.org) 1.4 and other dependencies.
- For pip users, please type the command `pip install -r requirements.txt`.
- For Conda users, you can create a new Conda environment using `conda env create -f environment.yml`.
## Data Processing
- Download the data for a video used in our paper (e.g. "reflection"):
```bash
bash ./datasets/download_data.sh reflection
```
- Alternatively, download all the data by specifying `all`.
- Download the pretrained keypoint-to-UV model weights:
```bash
bash ./scripts/download_kp2uv_model.sh
```
The pretrained model will be saved at `./checkpoints/kp2uv/latest_net_Kp2uv.pth`.
- Generate the UV maps from the keypoints:
```bash
bash datasets/prepare_iuv.sh ./datasets/reflection
```
## Training
- To train a model on a video (e.g. "reflection"), run:
```bash
python train.py --name reflection --dataroot ./datasets/reflection --gpu_ids 0,1
```
- To view training results and loss plots, visit the URL http://localhost:8097.
Intermediate results are also at `./checkpoints/reflection/web/index.html`.
You can find more scripts in the `scripts` directory, e.g. `run_${VIDEO}.sh` which combines data processing, training, and saving layer results for a video.
**Note**:
- It is recommended to use >=2 GPUs, each with >=16GB memory.
- The training script first trains the low-resolution model for `--num_epochs` at `--batch_size`, and then trains the upsampling module for `--num_epochs_upsample` at `--batch_size_upsample`.
If you do not need the upsampled result, pass `--num_epochs_upsample 0`.
- Training the upsampling module requires ~2.5x the memory of the low-resolution model, so set `--batch_size_upsample` accordingly.
The provided scripts set the batch sizes appropriately for 2 GPUs with 16GB memory.
- GPU memory scales linearly with the number of layers.
## Saving layer results from a trained model
- Run the trained model:
```bash
python test.py --name reflection --dataroot ./datasets/reflection --do_upsampling
```
- The results (RGBA layers, videos) will be saved to `./results/reflection/test_latest/`.
- Passing `--do_upsampling` uses the results of the upsampling module. If the upsampling module hasn't been trained (`num_epochs_upsample=0`), then remove this flag.
## Custom video
To train on your own video, you will have to preprocess the data:
1. Extract the frames, e.g.
```
mkdir ./datasets/my_video && cd ./datasets/my_video
mkdir rgb && ffmpeg -i video.mp4 rgb/%04d.png
```
1. Resize the video to 256x448 and save the frames in `my_video/rgb_256`, and resize it to 512x896 and save the frames in `my_video/rgb_512`.
1. Run [AlphaPose and Pose Tracking](https://github.com/MVIG-SJTU/AlphaPose) on the frames, and save the results as `my_video/keypoints.json`.
1. Create `my_video/metadata.json` following [these instructions](docs/data.md).
1. If your video has camera motion, either (1) stabilize the video, or (2) maintain the camera motion by computing homographies and saving as `my_video/homographies.txt`.
See `scripts/run_cartwheel.sh` for a training example with camera motion, and see `./datasets/cartwheel/homographies.txt` for formatting.
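Step 2 above (producing `rgb_256` and `rgb_512`) can be scripted. A minimal sketch using Pillow; the flat `rgb/` frame layout comes from step 1, and `resize_frames` is a hypothetical helper, not part of this repository:

```python
import os
from PIL import Image

def resize_frames(video_dir, sizes=((448, 256), (896, 512))):
    """Resize every frame in video_dir/rgb into video_dir/rgb_<height>."""
    src = os.path.join(video_dir, 'rgb')
    for width, height in sizes:
        dst = os.path.join(video_dir, f'rgb_{height}')
        os.makedirs(dst, exist_ok=True)
        for name in sorted(os.listdir(src)):
            frame = Image.open(os.path.join(src, name)).convert('RGB')
            # PIL sizes are (width, height), i.e. (448, 256) for the 256x448 model input
            frame.resize((width, height), Image.LANCZOS).save(os.path.join(dst, name))
```

For example, `resize_frames('./datasets/my_video')` after extracting the frames in step 1.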
**Note**: Videos that are suitable for our method have the following attributes:
- Static camera or limited camera motion that can be represented with a homography.
- Limited number of people, due to GPU memory limitations. We tested up to 7 people and 7 layers.
Multiple people can be grouped onto the same layer, though they cannot be individually retimed.
- People that move relative to the background (static people will be absorbed into the background layer).
- We tested a video length of up to 200 frames (~7 seconds).
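For `homographies.txt`, the parser in `data/layered_video_dataset.py` (`init_homographies`) expects a first line whose second and third tokens are the original frame width and height, a second line whose second through fifth tokens are the x and y bounds of the stabilized coordinate space, and then one row-major 3x3 homography per frame. A writer sketch matching that layout (the leading `size`/`bounds` labels are placeholders, since the parser skips the first token of each header line):

```python
import numpy as np

def write_homographies(path, frame_width, frame_height, bounds, homographies):
    """Write homographies.txt in the layout read by init_homographies().

    bounds -- (x_min, x_max, y_min, y_max) of the stabilized space
    homographies -- one 3x3 matrix per frame
    """
    with open(path, 'w') as f:
        f.write(f'size {frame_width} {frame_height}\n')
        f.write('bounds {} {} {} {}\n'.format(*bounds))
        for H in homographies:
            # one row-major 3x3 matrix per line, space-separated
            values = np.asarray(H, dtype=np.float32).reshape(-1)
            f.write(' '.join(str(v) for v in values) + '\n')
```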
## Citation
If you use this code for your research, please cite the following paper:
```
@inproceedings{lu2020,
title={Layered Neural Rendering for Retiming People in Video},
author={Lu, Erika and Cole, Forrester and Dekel, Tali and Xie, Weidi and Zisserman, Andrew and Salesin, David and Freeman, William T and Rubinstein, Michael},
booktitle={SIGGRAPH Asia},
year={2020}
}
```
## Acknowledgments
This code is based on [pytorch-CycleGAN-and-pix2pix](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix).
================================================
FILE: data/__init__.py
================================================
================================================
FILE: data/kpuv_dataset.py
================================================
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from third_party.data.base_dataset import BaseDataset
from PIL import Image, ImageDraw
import json
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import numpy as np
import os
import torchvision.transforms as transforms
class KpuvDataset(BaseDataset):
"""A dataset class for keypoint data.
It assumes that the directory specified by 'dataroot' contains the file 'keypoints.json'.
"""
@staticmethod
def modify_commandline_options(parser, is_train):
parser.add_argument('--inp_size', type=int, default=256, help='image size')
return parser
def __init__(self, opt):
"""Initialize this dataset class by reading in keypoints.
Parameters:
opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
"""
BaseDataset.__init__(self, opt)
self.inp_size = opt.inp_size
inner_crop_size = int(.75*self.inp_size)
kps = []
image_paths = []
with open(os.path.join(self.root, 'keypoints.json'), 'rb') as f:
kp_data = json.load(f)
for frame in sorted(kp_data):
for skeleton in kp_data[frame]:
id = skeleton['idx']
image_paths.append(f'{id:02d}_{frame}')
kp = np.array(skeleton['keypoints']).reshape(17, 3)
kp = self.crop_kps(kp, crop_size=self.inp_size, inner_crop_size=inner_crop_size)
kps.append(kp)
self.keypoints = kps
self.image_paths = image_paths # filenames for output UVs
# for keypoint rendering
self.cmap = plt.cm.get_cmap("hsv", 17)
self.color_seq = np.array([ 9, 14, 6, 7, 13, 16, 2, 11, 3, 5, 10, 15, 1, 8, 0, 12, 4])
self.pairs = [[0,1],[0,2],[1,3],[2,4],[5,6],[5,7],[7,9],[6,8],[8,10],[11,12],[11,13],[13,15],[12,14],[14,16],[6,12],[5,11]]
def __getitem__(self, index):
"""Return a data point and its metadata information.
Parameters:
index - - a random integer for data indexing
Returns a dictionary that contains keypoints and path
keypoints (tensor) - - an RGB image representing a skeleton
path (str) - - an identifying filename that can be used for saving the result
"""
uv_path = self.image_paths[index] # output UV path
kps = self.keypoints[index]
# draw keypoints
kp_im = Image.new(size=(self.inp_size, self.inp_size), mode='RGB')
draw = ImageDraw.Draw(kp_im)
self.render_kps(kps, draw)
kp_im = transforms.ToTensor()(kp_im)
kp_im = 2 * kp_im - 1
return {'keypoints': kp_im, 'path': uv_path}
def __len__(self):
"""Return the total number of images."""
return len(self.image_paths)
def crop_kps(self, kps, crop_size=256, inner_crop_size=192):
"""'Crops' keypoints to fit into ['crop_size', 'crop_size'].
Parameters:
kps - - a numpy array of shape [17, 3], where each row is X, Y, score (only X and Y are remapped)
crop_size - - the new size of the world, which the keypoints will be centered inside
inner_crop_size - - the box size that the keypoints will fit inside (must be <=crop_size)
Returns keypoints mapped to fit inside a box of 'inner_crop_size', centered in 'crop_size'
"""
# get coordinates of bounding box, in original image coordinates
left = kps[:, 0].min()
right = kps[:, 0].max()
top = kps[:, 1].min()
bottom = kps[:, 1].max()
# map keypoints
keypoints = kps.copy()
center = ((right + left) // 2, (bottom + top) // 2)
# first place center of bounding box at origin
keypoints[:, 0] -= center[0]
keypoints[:, 1] -= center[1]
# scale bounding box to inner_crop_size
scale = float(inner_crop_size) / max(right - left, bottom - top)
keypoints[:, :2] *= scale
# move center to crop_size//2
keypoints[:, :2] += crop_size // 2
new_kps = keypoints
return new_kps
def render_kps(self, keypoints, draw, thresh=1., min_weight=0.25):
"""Render skeleton as RGB image.
Parameters:
keypoints - - a numpy array of shape [17, 3], where the keypoint order is X, Y, score
draw - - an ImageDraw object, which the keypoints will be drawn onto
thresh - - keypoints with a confidence score below this value will have a color weighted by the score
min_weight - - minimum weighting for color (scores will be mapped to the range [min_weight, 1])
"""
# first draw keypoints
ksize = 3
for i in range(keypoints.shape[0]):
x1 = keypoints[i,0] - ksize
x2 = keypoints[i,0] + ksize
y1 = keypoints[i,1] - ksize
y2 = keypoints[i,1] + ksize
if x1 < 0 or y1 < 0 or x2 > self.inp_size or y2 > self.inp_size:
continue
color = np.array(self.cmap(self.color_seq[i]))
if keypoints.shape[1] > 2:
score = keypoints[i,2]
if score < thresh: # weight color by confidence score
# first map [0,1] -> [min_weight, 1]
alpha_weight = score * (1.-min_weight) + min_weight
color[:3] *= alpha_weight
color = (255*color).astype('uint8')
draw.rectangle([x1, y1, x2, y2], fill=tuple(color))
# now draw segments
for pair in self.pairs:
x1 = keypoints[pair[0],0]
y1 = keypoints[pair[0],1]
x2 = keypoints[pair[1],0]
y2 = keypoints[pair[1],1]
if x1 < 0 or y1 < 0 or x2 < 0 or y2 < 0 or x1 > self.inp_size or y1 > self.inp_size or x2 > self.inp_size or y2 > self.inp_size:
continue
avg_color = .5*(np.array(self.cmap(self.color_seq[pair[0]])) + np.array(self.cmap(self.color_seq[pair[1]])))
if keypoints.shape[1] > 2:
score = min(keypoints[pair[0],2], keypoints[pair[1],2])
if score < thresh:
alpha_weight = score * (1.-min_weight) + min_weight
avg_color[:3] *= alpha_weight # weight color by confidence score
avg_color = (255*avg_color).astype('uint8')
draw.line([x1,y1,x2,y2], fill=tuple(avg_color), width=3)
================================================
FILE: data/layered_video_dataset.py
================================================
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import cv2
from third_party.data.base_dataset import BaseDataset
from third_party.data.image_folder import make_dataset
from PIL import Image
import torchvision.transforms as transforms
import torch.nn.functional as F
import os
import torch
import numpy as np
import json
class LayeredVideoDataset(BaseDataset):
"""A dataset class for video layers.
It assumes that the directory specified by 'dataroot' contains metadata.json, and the directories iuv, rgb_256, and rgb_512.
The 'iuv' directory should contain directories named 01, 02, etc. for each layer, each containing per-frame UV images.
"""
@staticmethod
def modify_commandline_options(parser, is_train):
parser.add_argument('--height', type=int, default=256, help='image height')
parser.add_argument('--width', type=int, default=448, help='image width')
parser.add_argument('--trimap_width', type=int, default=20, help='trimap gray area width')
parser.add_argument('--use_mask_images', action='store_true', default=False, help='use custom masks')
parser.add_argument('--use_homographies', action='store_true', default=False, help='handle camera motion')
return parser
def __init__(self, opt):
"""Initialize this dataset class.
Parameters:
opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
"""
BaseDataset.__init__(self, opt)
rgbdir = os.path.join(opt.dataroot, 'rgb_256')
if opt.do_upsampling:
rgbdir = os.path.join(opt.dataroot, 'rgb_512')
uvdir = os.path.join(opt.dataroot, 'iuv')
self.image_paths = sorted(make_dataset(rgbdir, opt.max_dataset_size))
n_images = len(self.image_paths)
layers = sorted(os.listdir(uvdir))
layers = [l for l in layers if l.isdigit()]
self.iuv_paths = []
for l in layers:
layer_iuv_paths = sorted(make_dataset(os.path.join(uvdir, l), n_images))
if len(layer_iuv_paths) != n_images:
print(f'UNEQUAL NUMBER OF IMAGES AND IUVs: {len(layer_iuv_paths)} and {n_images}')
self.iuv_paths.append(layer_iuv_paths)
# set up per-frame compositing order
with open(os.path.join(opt.dataroot, 'metadata.json')) as f:
metadata = json.load(f)
if 'composite_order' in metadata:
self.composite_order = metadata['composite_order']
else:
self.composite_order = [tuple(range(1, 1 + len(layers)))] * n_images
if opt.use_homographies:
self.init_homographies(os.path.join(opt.dataroot, 'homographies.txt'), n_images)
def __getitem__(self, index):
"""Return a data point and its metadata information.
Parameters:
index - - a random integer for data indexing
Returns a dictionary that contains:
image (tensor) - - the original RGB frame to reconstruct
uv_map (tensor) - - the UV maps for all layers, concatenated channel-wise
mask (tensor) - - the trimaps for all layers, concatenated channel-wise
pids (tensor) - - the person IDs for all layers, concatenated channel-wise
image_path (str) - - image path
"""
# Read the target image.
image_path = self.image_paths[index]
target_image = self.load_and_process_image(image_path)
# Read the layer IUVs and convert to network inputs.
people_layers = [self.load_and_process_iuv(self.iuv_paths[l - 1][index], index) for l in self.composite_order[index]]
iuv_h, iuv_w = people_layers[0][0].shape[-2:]
# Create the background layer UV from homographies.
background_layer = self.get_background_inputs(index, iuv_w, iuv_h)
uv_maps, masks, pids = zip(*([background_layer] + people_layers))
uv_maps = torch.cat(uv_maps) # [L*2, H, W]
masks = torch.stack(masks) # [L, H, W]
pids = torch.stack(pids) # [L, H, W]
if self.opt.use_mask_images:
for i in range(1, len(people_layers) + 1): # masks[0] is the background layer; people are masks[1..L]
mask_path = os.path.join(self.opt.dataroot, 'mask', f'{i:02d}', os.path.basename(image_path))
if os.path.exists(mask_path):
mask = Image.open(mask_path).convert('L').resize((masks.shape[-1], masks.shape[-2]))
mask = transforms.ToTensor()(mask) * 2 - 1
masks[i] = mask
transform_params = self.get_params(do_jitter=self.opt.phase=='train')
pids = self.apply_transform(pids, transform_params, 'nearest')
masks = self.apply_transform(masks, transform_params, 'bilinear')
uv_maps = self.apply_transform(uv_maps, transform_params, 'nearest')
image_transform_params = transform_params
if self.opt.do_upsampling:
image_transform_params = { p: transform_params[p] * 2 for p in transform_params}
target_image = self.apply_transform(target_image, image_transform_params, 'bilinear')
return {'image': target_image, 'uv_map': uv_maps, 'mask': masks, 'pids': pids, 'image_path': image_path}
def __len__(self):
"""Return the total number of images."""
return len(self.image_paths)
def get_params(self, do_jitter=False, jitter_rate=0.75):
"""Get transformation parameters."""
if do_jitter:
if np.random.uniform() > jitter_rate or self.opt.do_upsampling:
scale = 1.
else:
scale = np.random.uniform(1, 1.25)
jitter_size = (scale * np.array([self.opt.height, self.opt.width])).astype(int) # np.int was removed in NumPy 1.24
start1 = np.random.randint(jitter_size[0] - self.opt.height + 1)
start2 = np.random.randint(jitter_size[1] - self.opt.width + 1)
else:
jitter_size = np.array([self.opt.height, self.opt.width])
start1 = 0
start2 = 0
crop_pos = np.array([start1, start2])
crop_size = np.array([self.opt.height, self.opt.width])
return {'jitter size': jitter_size, 'crop pos': crop_pos, 'crop size': crop_size}
def apply_transform(self, data, params, interp_mode='bilinear'):
"""Apply the transform to the data tensor."""
tensor_size = params['jitter size'].tolist()
crop_pos = params['crop pos']
crop_size = params['crop size']
data = F.interpolate(data.unsqueeze(0), size=tensor_size, mode=interp_mode).squeeze(0)
data = data[:, crop_pos[0]:crop_pos[0] + crop_size[0], crop_pos[1]:crop_pos[1] + crop_size[1]]
return data
def init_homographies(self, homography_path, n_images):
"""Read homography file and set up homography data."""
with open(homography_path) as f:
h_data = f.readlines()
h_scale = h_data[0].rstrip().split(' ')
self.h_scale_x = int(h_scale[1])
self.h_scale_y = int(h_scale[2])
h_bounds = h_data[1].rstrip().split(' ')
self.h_bounds_x = [float(h_bounds[1]), float(h_bounds[2])]
self.h_bounds_y = [float(h_bounds[3]), float(h_bounds[4])]
homographies = h_data[2:2 + n_images]
homographies = [torch.from_numpy(np.array(line.rstrip().split(' ')).astype(np.float32).reshape(3, 3)) for line in homographies]
self.homographies = homographies
def load_and_process_image(self, im_path):
"""Read image file and return as tensor in range [-1, 1]."""
image = Image.open(im_path).convert('RGB')
image = transforms.ToTensor()(image)
image = 2 * image - 1
return image
def load_and_process_iuv(self, iuv_path, i):
"""Read IUV file and convert to network inputs."""
iuv_map = Image.open(iuv_path).convert('RGBA')
iuv_map = transforms.ToTensor()(iuv_map)
uv_map, mask, pids = self.iuv2input(iuv_map, i)
return uv_map, mask, pids
def iuv2input(self, iuv, index):
"""Create network inputs from IUV.
Parameters:
iuv - - a tensor of shape [4, H, W], where the channels are: body part ID, U, V, person ID.
index - - index of iuv
Returns:
uv (tensor) - - a UV map for a single layer, ready to pass to grid sampler (values in range [-1,1])
mask (tensor) - - the corresponding mask
person_id (tensor) - - the person IDs
grid sampler indexes into texture map of size tile_width x tile_width*n_textures
"""
# Extract body part and person IDs.
part_id = (iuv[0] * 255 / 10).round()
part_id[part_id > 24] = 24
part_id_mask = (part_id > 0).float()
person_id = (255 - 255 * iuv[-1]).round() # person ID is saved as 255 - person_id
person_id *= part_id_mask # background id is 0
maxId = self.opt.n_textures // 24
person_id[person_id>maxId] = maxId
# Convert body part ID to texture map ID.
# Essentially, each of the 24 body parts for each person, plus the background have their own texture 'tile'
# The tiles are concatenated horizontally to create the texture map.
tex_id = part_id + part_id_mask * 24 * (person_id - 1)
uv = iuv[1:3]
# Convert the per-body-part UVs to UVs that correspond to the full texture map.
uv[0] += tex_id
# Get the mask.
bg_mask = (tex_id == 0).float()
mask = 1.0 - bg_mask
mask = mask * 2 - 1 # make 1 the foreground and -1 the background mask
mask = self.mask2trimap(mask)
# Composite background UV behind person UV.
h, w = iuv.shape[1:]
bg_uv = self.get_background_uv(index, w, h)
uv = bg_mask * bg_uv + (1 - bg_mask) * uv
# Map to [-1, 1] range.
uv[0] /= self.opt.n_textures
uv = uv * 2 - 1
uv = torch.clamp(uv, -1, 1)
return uv, mask, person_id
def get_background_inputs(self, index, w, h):
"""Return data for background layer at 'index'."""
uv = self.get_background_uv(index, w, h)
# normalize to correct range, of full texture atlas
uv[0] /= self.opt.n_textures
uv = uv * 2 - 1 # [0,1] -> [-1,1]
uv = torch.clamp(uv, -1, 1)
mask = -torch.ones(*uv.shape[1:])
pids = torch.zeros(*uv.shape[1:])
return uv, mask, pids
def get_background_uv(self, index, w, h):
"""Return background layer UVs at 'index' (output range [0, 1])."""
ramp_u = torch.linspace(0, 1, steps=w).unsqueeze(0).repeat(h, 1)
ramp_v = torch.linspace(0, 1, steps=h).unsqueeze(-1).repeat(1, w)
ramp = torch.stack([ramp_u, ramp_v], 0)
if hasattr(self, 'homographies'):
# scale to [0, orig width/height]
ramp[0] *= self.h_scale_x
ramp[1] *= self.h_scale_y
# apply homography
ramp = ramp.reshape(2, -1) # flatten to [2, H*W]
H = self.homographies[index]
[xt, yt] = self.transform2h(ramp[0], ramp[1], torch.inverse(H))
# scale from world to [0,1]
xt -= self.h_bounds_x[0]
xt /= (self.h_bounds_x[1] - self.h_bounds_x[0])
yt -= self.h_bounds_y[0]
yt /= (self.h_bounds_y[1] - self.h_bounds_y[0])
# restore shape
ramp = torch.stack([xt.reshape(h, w), yt.reshape(h, w)], 0)
return ramp
def transform2h(self, x, y, m):
"""Applies 2d homogeneous transformation."""
A = torch.matmul(m, torch.stack([x, y, torch.ones(len(x))]))
xt = A[0, :] / A[2, :]
yt = A[1, :] / A[2, :]
return xt, yt
def mask2trimap(self, mask):
"""Convert binary mask to trimap with values in [-1, 0, 1]."""
fg_mask = (mask > 0).float()
bg_mask = (mask < 0).float()
trimap_width = getattr(self.opt, 'trimap_width', 20)
trimap_width *= bg_mask.shape[-1] / self.opt.width
trimap_width = int(trimap_width)
bg_mask = cv2.erode(bg_mask.numpy(), kernel=np.ones((trimap_width, trimap_width)), iterations=1)
bg_mask = torch.from_numpy(bg_mask)
mask = fg_mask - bg_mask
return mask
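The tile mapping above (each of a person's 24 body parts plus the background gets its own horizontal tile in the texture map) can be sketched in numpy; the IDs below are toy values for illustration:

```python
import numpy as np

# part_id: 0 = background, 1..24 = DensePose body parts.
part_id = np.array([0., 5., 24., 1.])
person_id = np.array([0., 2., 1., 3.])
part_mask = (part_id > 0).astype(float)
# Tile 0 is the background; person p's part b lands in tile b + 24*(p - 1).
tex_id = part_id + part_mask * 24 * (person_id - 1)
# tex_id is [0., 29., 24., 49.]
```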
================================================
FILE: datasets/download_data.sh
================================================
#!/bin/bash
#
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
NAME=$1
if [[ $NAME != "cartwheel" && $NAME != "reflection" && $NAME != "splash" && $NAME != "trampoline" && $NAME != "all" ]]; then
echo "Available videos are: cartwheel, reflection, splash, trampoline"
exit 1
fi
if [[ $NAME == "all" ]]; then
declare -a NAMES=("cartwheel" "reflection" "splash" "trampoline")
else
declare -a NAMES=($NAME)
fi
for NAME in "${NAMES[@]}"
do
echo "Specified [$NAME]"
URL=https://www.robots.ox.ac.uk/~erika/retiming/data/$NAME.zip
ZIP_FILE=./datasets/$NAME.zip
TARGET_DIR=./datasets/$NAME/
wget -N $URL -O $ZIP_FILE
mkdir -p $TARGET_DIR
unzip $ZIP_FILE -d ./datasets/
rm $ZIP_FILE
done
================================================
FILE: datasets/iuv_crop2full.py
================================================
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert UV crops to full UV maps."""
import os
import sys
import json
from PIL import Image
import numpy as np
def place_crop(crop, image, center_x, center_y):
"""Place the crop in the image at the specified location."""
im_height, im_width = image.shape[:2]
crop_height, crop_width = crop.shape[:2]
left = center_x - crop_width // 2
right = left + crop_width
top = center_y - crop_height // 2
bottom = top + crop_height
adjusted_crop = crop # remove regions of crop that go beyond image bounds
if left < 0:
adjusted_crop = adjusted_crop[:, -left:]
if right > im_width:
adjusted_crop = adjusted_crop[:, :(im_width - right)]
if top < 0:
adjusted_crop = adjusted_crop[-top:]
if bottom > im_height:
adjusted_crop = adjusted_crop[:(im_height - bottom)]
crop_mask = (adjusted_crop > 0).any(-1, keepdims=True).astype(crop.dtype)
image[max(0, top):min(im_height, bottom), max(0, left):min(im_width, right)] *= (1 - crop_mask)
image[max(0, top):min(im_height, bottom), max(0, left):min(im_width, right)] += adjusted_crop
return image
def crop2full(keypoints_path, metadata_path, uvdir, outdir):
"""Create each frame's layer UVs from predicted UV crops"""
with open(keypoints_path) as f:
kp_data = json.load(f)
# Get all people ids
people_ids = set()
for frame in kp_data:
for skeleton in kp_data[frame]:
people_ids.add(skeleton['idx'])
people_ids = sorted(list(people_ids))
with open(metadata_path) as f:
metadata = json.load(f)
orig_size = np.array(metadata['alphapose_input_size'][::-1])
out_size = np.array(metadata['size_LR'][::-1])
if 'people_layers' in metadata:
people_layers = metadata['people_layers']
else:
people_layers = [[pid] for pid in people_ids]
# Create output directories.
for layer_i in range(1, 1 + len(people_layers)):
os.makedirs(os.path.join(outdir, f'{layer_i:02d}'), exist_ok=True)
print(f'Writing UVs to {outdir}')
for frame in sorted(kp_data):
for layer_i, layer in enumerate(people_layers, 1):
out_path = os.path.join(outdir, f'{layer_i:02d}', frame)
sys.stdout.flush()
sys.stdout.write('processing frame %s\r' % out_path)
uv_map = np.zeros([out_size[0], out_size[1], 4])
for person_id in layer:
matches = [p for p in kp_data[frame] if p['idx'] == person_id]
if len(matches) == 0: # person doesn't appear in this frame
continue
skeleton = matches[0]
kps = np.array(skeleton['keypoints']).reshape(17, 3)
# Get kps bounding box.
left = kps[:, 0].min()
right = kps[:, 0].max()
top = kps[:, 1].min()
bottom = kps[:, 1].max()
height = bottom - top
width = right - left
orig_crop_size = max(height, width)
orig_center_x = (left + right) // 2
orig_center_y = (top + bottom) // 2
# read predicted uv map
uv_crop_path = os.path.join(uvdir, f'{person_id:02d}_{os.path.basename(out_path)[:-4]}_output_uv.png')
if os.path.exists(uv_crop_path):
uv_crop = np.array(Image.open(uv_crop_path))
else:
uv_crop = np.zeros([256, 256, 3])
# add person ID channel
person_mask = (uv_crop[..., 0:1] > 0).astype('uint8')
person_ids = (255 - person_id) * person_mask
uv_crop = np.concatenate([uv_crop, person_ids], -1)
# scale crop to desired output size
# 256 is the crop size, 192 is the inner crop size
out_crop_size = orig_crop_size * 256./192 * out_size / orig_size
out_crop_size = out_crop_size.astype(int)  # np.int is removed in NumPy >= 1.24
uv_crop = uv_crop.astype(np.uint8)
uv_crop = np.array(Image.fromarray(uv_crop).resize((out_crop_size[1], out_crop_size[0]), resample=Image.NEAREST))
# scale center coordinate accordingly
out_center_x = (orig_center_x * out_size[1] / orig_size[1]).astype(int)
out_center_y = (orig_center_y * out_size[0] / orig_size[0]).astype(int)
# Place UV crop in full UV map and save.
uv_map = place_crop(uv_crop, uv_map, out_center_x, out_center_y)
uv_map = Image.fromarray(uv_map.astype('uint8'))
uv_map.save(out_path)
if __name__ == "__main__":
import argparse
arguments = argparse.ArgumentParser()
arguments.add_argument('--dataroot', type=str)
opt = arguments.parse_args()
keypoints_path = os.path.join(opt.dataroot, 'keypoints.json')
metadata_path = os.path.join(opt.dataroot, 'metadata.json')
uvdir = os.path.join(opt.dataroot, 'kp2uv/test_latest/images')
outdir = os.path.join(opt.dataroot, 'iuv')
crop2full(keypoints_path, metadata_path, uvdir, outdir)
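The crop-size arithmetic above (256 is the network crop size, 192 the inner crop) can be sketched with hypothetical frame sizes; all concrete numbers here are assumptions for illustration:

```python
import numpy as np

# Hypothetical sizes: AlphaPose ran on 1920x1080 frames; output frames are 448x256.
orig_size = np.array([1080, 1920])  # [height, width]
out_size = np.array([256, 448])     # [height, width]
orig_crop_size = 192                # side of the square keypoint bounding box
# The 256-px network crop frames the person inside an inner 192-px window,
# so scale up by 256/192 before converting to output resolution.
out_crop_size = (orig_crop_size * 256.0 / 192 * out_size / orig_size).astype(int)
```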
================================================
FILE: datasets/prepare_iuv.sh
================================================
#!/bin/bash
#
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
DATA_PATH=$1
# Predict UVs from keypoints and save the crops.
python run_kp2uv.py --model kp2uv --dataroot $DATA_PATH --results_dir $DATA_PATH
# Convert the cropped UVs to full UV maps.
python datasets/iuv_crop2full.py --dataroot $DATA_PATH
================================================
FILE: docs/contributing.md
================================================
# How to Contribute
We'd love to accept your patches and contributions to this project. There are
just a few small guidelines you need to follow.
## Contributor License Agreement
Contributions to this project must be accompanied by a Contributor License
Agreement (CLA). You (or your employer) retain the copyright to your
contribution; this simply gives us permission to use and redistribute your
contributions as part of the project. Head over to
<https://cla.developers.google.com/> to see your current agreements on file or
to sign a new one.
You generally only need to submit a CLA once, so if you've already submitted one
(even if it was for a different project), you probably don't need to do it
again.
## Code reviews
All submissions, including submissions by project members, require review. We
use GitHub pull requests for this purpose. Consult
[GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
information on using pull requests.
## Community Guidelines
This project follows
[Google's Open Source Community Guidelines](https://opensource.google/conduct/).
================================================
FILE: docs/data.md
================================================
### Data
The data directory for a video is structured as follows:
```
video_name/
|-- rgb_256/
| |-- 0001.png, etc.
|-- rgb_512/
| |-- 0001.png, etc.
|-- mask/ (optional)
|-- |-- 01, etc.
|-- |-- |-- 0001.png, etc.
|-- keypoints.json
|-- metadata.json
|-- homographies.txt (optional)
```
- `metadata.json` contains a dictionary:
```
'alphapose_input_size': [width, height] # size of frames input to AlphaPose
'size_LR': [width, height] # size of low-resolution frames (multiple of 16; height should be 256)
'n_textures': int # number of texture maps required, calculated by 24*num_people + 1
'composite_order': [[1, 2, 3], [1, 3, 2], ... ] # optional per-frame back-to-front layer compositing order
```
- `keypoints.json` is in the format output by the [AlphaPose Pose Tracker](https://github.com/MVIG-SJTU/AlphaPose).
See [here](https://github.com/MVIG-SJTU/AlphaPose/tree/master/trackers/PoseFlow) for details.
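A minimal sketch of checking a `metadata.json` dictionary against the constraints listed above; the concrete values and `num_people` are assumptions for illustration:

```python
# Hypothetical metadata for a two-person video.
metadata = {
    'alphapose_input_size': [1920, 1080],  # [width, height]
    'size_LR': [448, 256],                 # [width, height]
    'n_textures': 49,                      # 24 * num_people + 1
}
num_people = 2
assert metadata['n_textures'] == 24 * num_people + 1
assert metadata['size_LR'][1] == 256        # height should be 256
assert metadata['size_LR'][0] % 16 == 0     # width a multiple of 16
```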
================================================
FILE: environment.yml
================================================
name: retiming
channels:
- pytorch
- defaults
dependencies:
- python=3.8
- pytorch=1.4.0
- pip:
- dominate==2.4.0
- torchvision==0.5.0
- Pillow>=6.1.0
- numpy==1.19.2
- visdom==0.1.8
- opencv-python>=4.2.0
- matplotlib
================================================
FILE: models/__init__.py
================================================
================================================
FILE: models/kp2uv_model.py
================================================
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from third_party.models.base_model import BaseModel
from . import networks
class Kp2uvModel(BaseModel):
"""This class implements the keypoint-to-UV model (inference only)."""
@staticmethod
def modify_commandline_options(parser, is_train=True):
parser.set_defaults(dataset_mode='kpuv')
return parser
def __init__(self, opt):
"""Initialize this model class.
Parameters:
opt -- test options
"""
BaseModel.__init__(self, opt)
self.visual_names = ['keypoints', 'output_uv']
self.model_names = ['Kp2uv']
self.netKp2uv = networks.define_kp2uv(gpu_ids=self.gpu_ids)
self.isTrain = False # only test mode supported
# Our program will automatically call <BaseModel.setup> to define schedulers, load networks, and print networks
def set_input(self, input):
"""Unpack input data from the dataloader.
Parameters:
input: a dictionary that contains the data itself and its metadata information.
"""
self.keypoints = input['keypoints'].to(self.device)
self.image_paths = input['path']
def forward(self):
"""Run forward pass. This will be called by ."""
output = self.netKp2uv.forward(self.keypoints)
self.output_uv = self.output2rgb(output)
def output2rgb(self, output):
"""Convert network outputs to RGB image."""
pred_id, pred_uv = output
_, pred_id_class = pred_id.max(1)
pred_id_class = pred_id_class.unsqueeze(1)
# extract UV from pred_uv (48 channels); select based on class ID
selected_uv = -1 * torch.ones(pred_uv.shape[0], 2, pred_uv.shape[2], pred_uv.shape[3], device=pred_uv.device)
for partid in range(1, 25):
mask = (pred_id_class == partid).float()
selected_uv *= (1. - mask)
selected_uv += mask * pred_uv[:, (partid - 1) * 2:(partid - 1) * 2 + 2]
pred_uv = selected_uv
rgb = torch.cat([pred_id_class.float() * 10 / 255. * 2 - 1, pred_uv], 1)
return rgb
def optimize_parameters(self):
pass
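The per-part UV selection in `output2rgb` (pick the (u, v) channel pair matching the argmax part ID, defaulting to -1 for background) can be sketched in numpy with random toy predictions:

```python
import numpy as np

rng = np.random.default_rng(0)
pred_id = rng.standard_normal((1, 25, 2, 2))  # logits: background + 24 parts
pred_uv = rng.standard_normal((1, 48, 2, 2))  # one (u, v) pair per body part
part = pred_id.argmax(1)                      # [1, 2, 2], values in 0..24
selected = -np.ones((1, 2, 2, 2))             # background UV defaults to -1
for p in range(1, 25):
    m = part == p
    selected[:, 0][m] = pred_uv[:, 2 * (p - 1)][m]
    selected[:, 1][m] = pred_uv[:, 2 * (p - 1) + 1][m]
```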
================================================
FILE: models/lnr_model.py
================================================
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from third_party.models.base_model import BaseModel
from . import networks
import numpy as np
import torch.nn.functional as F
class LnrModel(BaseModel):
"""This class implements the layered neural rendering model for decomposing a video into layers."""
@staticmethod
def modify_commandline_options(parser, is_train=True):
parser.set_defaults(dataset_mode='layered_video')
parser.add_argument('--texture_res', type=int, default=16, help='texture resolution')
parser.add_argument('--texture_channels', type=int, default=16, help='# channels for neural texture')
parser.add_argument('--n_textures', type=int, default=25, help='# individual texture maps, 24 per person (1 per body part) + 1 for background')
if is_train:
parser.add_argument('--lambda_alpha_l1', type=float, default=0.01, help='alpha L1 sparsity loss weight')
parser.add_argument('--lambda_alpha_l0', type=float, default=0.005, help='alpha L0 sparsity loss weight')
parser.add_argument('--alpha_l1_rolloff_epoch', type=int, default=200, help='turn off L1 alpha sparsity loss weight after this epoch')
parser.add_argument('--lambda_mask', type=float, default=50, help='layer matting loss weight')
parser.add_argument('--mask_thresh', type=float, default=0.02, help='turn off masking loss when error falls below this value')
parser.add_argument('--mask_loss_rolloff_epoch', type=int, default=-1, help='decrease masking loss after this epoch; if <0, use mask_thresh instead')
parser.add_argument('--n_epochs_upsample', type=int, default=500,
help='number of epochs to train the upsampling module')
parser.add_argument('--batch_size_upsample', type=int, default=16, help='batch size for upsampling')
parser.add_argument('--jitter_rgb', type=float, default=0.2, help='amount of jitter to add to RGB')
parser.add_argument('--jitter_epochs', type=int, default=400, help='number of epochs to jitter RGB')
parser.add_argument('--do_upsampling', action='store_true', help='whether to use upsampling module')
return parser
def __init__(self, opt):
"""Initialize this model class.
Parameters:
opt -- training/test options
"""
BaseModel.__init__(self, opt)
# specify the images you want to save and display. The program will call base_model.get_current_visuals to save and display these images.
self.visual_names = ['target_image', 'reconstruction', 'rgba_vis', 'alpha_vis', 'input_vis']
self.model_names = ['LNR']
self.netLNR = networks.define_LNR(opt.num_filters, opt.texture_channels, opt.texture_res, opt.n_textures, gpu_ids=self.gpu_ids)
self.do_upsampling = opt.do_upsampling
if self.isTrain:
self.setup_train(opt)
# Our program will automatically call <BaseModel.setup> to define schedulers, load networks, and print networks
def setup_train(self, opt):
"""Setup the model for training mode."""
print('setting up model')
# specify the training losses you want to print out. The program will call base_model.get_current_losses to plot the losses to the console and save them to the disk.
self.loss_names = ['total', 'recon', 'alpha_reg', 'mask']
self.visual_names = ['target_image', 'reconstruction', 'rgba_vis', 'alpha_vis', 'input_vis']
self.do_upsampling = opt.do_upsampling
if not self.do_upsampling:
self.visual_names += ['mask_vis']
self.criterionLoss = torch.nn.L1Loss()
self.criterionLossMask = networks.MaskLoss().to(self.device)
self.lambda_mask = opt.lambda_mask
self.lambda_alpha_l0 = opt.lambda_alpha_l0
self.lambda_alpha_l1 = opt.lambda_alpha_l1
self.mask_loss_rolloff_epoch = opt.mask_loss_rolloff_epoch
self.jitter_rgb = opt.jitter_rgb
self.do_upsampling = opt.do_upsampling
self.optimizer = torch.optim.Adam(self.netLNR.parameters(), lr=opt.lr)
self.optimizers = [self.optimizer]
def set_input(self, input):
"""Unpack input data from the dataloader and perform necessary pre-processing steps.
Parameters:
input: a dictionary that contains the data itself and its metadata information.
"""
self.target_image = input['image'].to(self.device)
if self.isTrain and self.jitter_rgb > 0:
# add brightness jitter to rgb
self.target_image += self.jitter_rgb * torch.randn(self.target_image.shape[0], 1, 1, 1).to(self.device)
self.target_image = torch.clamp(self.target_image, -1, 1)
self.input_uv = input['uv_map'].to(self.device)
self.input_id = input['pids'].to(self.device)
self.mask = input['mask'].to(self.device)
self.image_paths = input['image_path']
def gen_crop_params(self, orig_h, orig_w, crop_size=256):
"""Generate random square cropping parameters."""
starty = np.random.randint(orig_h - crop_size + 1)
startx = np.random.randint(orig_w - crop_size + 1)
endy = starty + crop_size
endx = startx + crop_size
return starty, endy, startx, endx
def forward(self):
"""Run forward pass. This will be called by both functions and ."""
if self.do_upsampling:
input_uv_up = F.interpolate(self.input_uv, scale_factor=2, mode='bilinear')
crop_params = None
if self.isTrain:
# Take random crop to decrease memory requirement.
crop_params = self.gen_crop_params(*input_uv_up.shape[-2:])
starty, endy, startx, endx = crop_params
self.target_image = self.target_image[:, :, starty:endy, startx:endx]
outputs = self.netLNR.forward(self.input_uv, self.input_id, uv_map_upsampled=input_uv_up, crop_params=crop_params)
else:
outputs = self.netLNR(self.input_uv, self.input_id)
self.reconstruction = outputs['reconstruction'][:, :3]
self.alpha_composite = outputs['reconstruction'][:, 3]
self.output_rgba = outputs['layers']
n_layers = outputs['layers'].shape[1]
layers = outputs['layers'].clone()
layers[:, 0, -1] = 1 # Background layer's alpha is always 1
layers = torch.cat([layers[:, l] for l in range(n_layers)], -2)
self.alpha_vis = layers[:, 3:4]
self.rgba_vis = layers
self.mask_vis = torch.cat([self.mask[:, l:l+1] for l in range(n_layers)], -2)
self.input_vis = torch.cat([self.input_uv[:, 2*l:2*l+2] for l in range(n_layers)], -2)
self.input_vis = torch.cat([torch.zeros_like(self.input_vis[:, :1]), self.input_vis], 1)
def backward(self):
"""Calculate losses, gradients, and update network weights; called in every training iteration"""
self.loss_recon = self.criterionLoss(self.reconstruction[:, :3], self.target_image)
self.loss_total = self.loss_recon
if not self.do_upsampling:
self.loss_alpha_reg = networks.cal_alpha_reg(self.alpha_composite * .5 + .5, self.lambda_alpha_l1, self.lambda_alpha_l0)
alpha_layers = self.output_rgba[:, :, 3]
self.loss_mask = self.lambda_mask * self.criterionLossMask(alpha_layers, self.mask)
self.loss_total += self.loss_alpha_reg + self.loss_mask
else:
self.loss_mask = 0.
self.loss_alpha_reg = 0.
self.loss_total.backward()
def optimize_parameters(self):
"""Update network weights; it will be called in every training iteration."""
self.forward()
self.optimizer.zero_grad()
self.backward()
self.optimizer.step()
def update_lambdas(self, epoch):
"""Update loss weights based on current epochs and losses."""
if epoch == self.opt.alpha_l1_rolloff_epoch:
self.lambda_alpha_l1 = 0
if self.mask_loss_rolloff_epoch >= 0:
if epoch == 2*self.mask_loss_rolloff_epoch:
self.lambda_mask = 0
elif epoch > self.opt.epoch_count:
if self.loss_mask < self.opt.mask_thresh * self.opt.lambda_mask:
self.mask_loss_rolloff_epoch = epoch
self.lambda_mask *= .1
if epoch == self.opt.jitter_epochs:
self.jitter_rgb = 0
def transfer_detail(self):
"""Transfer detail to layers."""
residual = self.target_image - self.reconstruction
transmission_comp = torch.zeros_like(self.target_image[:, 0:1])
rgba_detail = self.output_rgba
n_layers = self.output_rgba.shape[1]
for i in range(n_layers - 1, 0, -1): # Don't do detail transfer for background layer, due to ghosting effects.
transmission_i = 1. - transmission_comp
rgba_detail[:, i, :3] += transmission_i * residual
alpha_i = self.output_rgba[:, i, 3:4] * .5 + .5
transmission_comp = alpha_i + (1. - alpha_i) * transmission_comp
self.rgba = torch.clamp(rgba_detail, -1, 1)
def get_results(self):
"""Return results. This is different from get_current_visuals, which gets visuals for monitoring training.
Returns a dictionary:
original - - original frame
recon - - reconstruction
rgba_l* - - RGBA for each layer
mask_l* - - mask for each layer
"""
self.transfer_detail()
# Split layers
results = {'reconstruction': self.reconstruction, 'original': self.target_image}
n_layers = self.rgba.shape[1]
for i in range(n_layers):
results[f'mask_l{i}'] = self.mask[:, i:i+1]
results[f'rgba_l{i}'] = self.rgba[:, i]
if i == 0:
results[f'rgba_l{i}'][:, -1:] = 1.
return results
def freeze_basenet(self):
"""Freeze all parameters except for the upsampling module."""
net = self.netLNR
if isinstance(net, torch.nn.DataParallel):
net = net.module
self.set_requires_grad([net.encoder, net.decoder, net.final_rgba], False)
net.texture.requires_grad = False
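The detail-transfer pass in `transfer_detail` adds the reconstruction residual to each foreground layer weighted by its transmittance through the layers composited in front of it. A 1-D numpy sketch with made-up alphas (back-to-front order, background excluded):

```python
import numpy as np

residual = np.array([0.1, -0.2])               # target minus reconstruction
fg_alphas = [np.array([0.0, 1.0]),             # back foreground layer
             np.array([1.0, 0.5])]             # front foreground layer
transmission_comp = np.zeros(2)                # alpha accumulated in front so far
detail = {}
for i, alpha in reversed(list(enumerate(fg_alphas))):  # front-most first
    # Each layer receives the residual scaled by what is still visible of it.
    detail[i] = (1.0 - transmission_comp) * residual
    transmission_comp = alpha + (1.0 - alpha) * transmission_comp
```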
================================================
FILE: models/networks.py
================================================
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from third_party.models.networks import init_net
###############################################################################
# Helper Functions
###############################################################################
def define_LNR(nf=64, texture_channels=16, texture_res=16, n_textures=25, gpu_ids=[]):
"""Create a layered neural renderer.
Parameters:
nf (int) -- the number of channels in the first/last conv layers
texture_channels (int) -- the number of channels in the neural texture
texture_res (int) -- the size of each individual texture map
n_textures (int) -- the number of individual texture maps
gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2
Returns a layered neural rendering model.
"""
net = LayeredNeuralRenderer(nf, texture_channels, texture_res, n_textures)
return init_net(net, gpu_ids)
def define_kp2uv(nf=64, gpu_ids=[]):
"""Create a keypoint-to-UV model.
Parameters:
nf (int) -- the number of channels in the first/last conv layers
Returns a keypoint-to-UV model.
"""
net = kp2uv(nf)
return init_net(net, gpu_ids)
def cal_alpha_reg(prediction, lambda_alpha_l1, lambda_alpha_l0):
"""Calculate the alpha regularization term.
Parameters:
prediction (tensor) - - composite of predicted alpha layers
lambda_alpha_l1 (float) - - weight for the L1 regularization term
lambda_alpha_l0 (float) - - weight for the L0 regularization term
Returns the alpha regularization loss term
"""
assert prediction.max() <= 1.
assert prediction.min() >= 0.
loss = 0.
if lambda_alpha_l1 > 0:
loss += lambda_alpha_l1 * torch.mean(prediction)
if lambda_alpha_l0 > 0:
# Pseudo L0 loss using a squished sigmoid curve.
l0_prediction = (torch.sigmoid(prediction * 5.0) - 0.5) * 2.0
loss += lambda_alpha_l0 * torch.mean(l0_prediction)
return loss
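The pseudo-L0 term is a squished sigmoid that maps alpha = 0 to 0 and saturates toward 1 for larger alphas, approximating a count of nonzero pixels; a numpy re-derivation, assuming the same 5.0 slope:

```python
import numpy as np

def pseudo_l0(alpha):
    # Squished sigmoid: 0 -> 0, and values near 1 saturate toward 1.
    sig = 1.0 / (1.0 + np.exp(-5.0 * alpha))
    return (sig - 0.5) * 2.0
```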
##############################################################################
# Classes
##############################################################################
class MaskLoss(nn.Module):
"""Define the loss which encourages the predicted alpha matte to match the mask (trimap)."""
def __init__(self):
super(MaskLoss, self).__init__()
self.loss = nn.L1Loss(reduction='none')
def __call__(self, prediction, target):
"""Calculate loss given predicted alpha matte and trimap.
Balance positive and negative regions. Exclude 'unknown' region from loss.
Parameters:
prediction (tensor) - - predicted alpha
target (tensor) - - trimap
Returns: the computed loss
"""
mask_err = self.loss(prediction, target)
pos_mask = F.relu(target)
neg_mask = F.relu(-target)
pos_mask_loss = (pos_mask * mask_err).sum() / (1 + pos_mask.sum())
neg_mask_loss = (neg_mask * mask_err).sum() / (1 + neg_mask.sum())
loss = .5 * (pos_mask_loss + neg_mask_loss)
return loss
class ConvBlock(nn.Module):
"""Helper module consisting of a convolution, optional normalization and activation, with padding='same'."""
def __init__(self, conv, in_channels, out_channels, ksize=4, stride=1, dil=1, norm=None, activation='relu'):
"""Create a conv block.
Parameters:
conv (convolutional layer) - - the type of conv layer, e.g. Conv2d, ConvTranspose2d
in_channels (int) - - the number of input channels
out_channels (int) - - the number of output channels
ksize (int) - - the kernel size
stride (int) - - stride
dil (int) - - dilation
norm (norm layer) - - the type of normalization layer, e.g. BatchNorm2d, InstanceNorm2d
activation (str) -- the type of activation: relu | leaky | tanh | none
"""
super(ConvBlock, self).__init__()
self.k = ksize
self.s = stride
self.d = dil
self.conv = conv(in_channels, out_channels, ksize, stride=stride, dilation=dil)
if norm is not None:
self.norm = norm(out_channels)
else:
self.norm = None
if activation == 'leaky':
self.activation = nn.LeakyReLU(0.2)
elif activation == 'relu':
self.activation = nn.ReLU()
elif activation == 'tanh':
self.activation = nn.Tanh()
else:
self.activation = None
def forward(self, x):
"""Forward pass. Compute necessary padding and cropping because pytorch doesn't have pad=same."""
height, width = x.shape[-2:]
if isinstance(self.conv, nn.modules.ConvTranspose2d):
desired_height = height * self.s
desired_width = width * self.s
pady = 0
padx = 0
else:
# o = [i + 2*p - k - (k-1)*(d-1)]/s + 1
# padding = .5 * (stride * (output-1) + (k-1)(d-1) + k - input)
desired_height = height // self.s
desired_width = width // self.s
pady = .5 * (self.s * (desired_height - 1) + (self.k - 1) * (self.d - 1) + self.k - height)
padx = .5 * (self.s * (desired_width - 1) + (self.k - 1) * (self.d - 1) + self.k - width)
x = F.pad(x, [int(np.floor(padx)), int(np.ceil(padx)), int(np.floor(pady)), int(np.ceil(pady))])
x = self.conv(x)
if x.shape[-2] != desired_height or x.shape[-1] != desired_width:
cropy = x.shape[-2] - desired_height
cropx = x.shape[-1] - desired_width
x = x[:, :, int(np.floor(cropy / 2.)):-int(np.ceil(cropy / 2.)),
int(np.floor(cropx / 2.)):-int(np.ceil(cropx / 2.))]
if self.norm:
x = self.norm(x)
if self.activation:
x = self.activation(x)
return x
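The 'same'-padding arithmetic in `ConvBlock.forward` follows the commented formula; a small pure-Python sketch with illustrative values:

```python
# For a strided conv we want output = input // stride, so the total padding
# per axis is s*(out - 1) + (k - 1)*(d - 1) + k - in, split across both sides.
def same_pad(size, k, s, d=1):
    out = size // s
    return 0.5 * (s * (out - 1) + (k - 1) * (d - 1) + k - size)

pad = same_pad(256, k=4, s=2)  # a 4x4 stride-2 conv on a 256-px axis
```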
class ResBlock(nn.Module):
"""Define a residual block."""
def __init__(self, channels, ksize=4, stride=1, dil=1, norm=None, activation='relu'):
"""Initialize the residual block, which consists of 2 conv blocks with a skip connection."""
super(ResBlock, self).__init__()
self.convblock1 = ConvBlock(nn.Conv2d, channels, channels, ksize=ksize, stride=stride, dil=dil, norm=norm,
activation=activation)
self.convblock2 = ConvBlock(nn.Conv2d, channels, channels, ksize=ksize, stride=stride, dil=dil, norm=norm,
activation=None)
def forward(self, x):
identity = x
x = self.convblock1(x)
x = self.convblock2(x)
x += identity
return x
class kp2uv(nn.Module):
"""UNet architecture for converting keypoint image to UV map.
Same person UV map format as described in https://arxiv.org/pdf/1802.00434.pdf.
"""
def __init__(self, nf=64):
super(kp2uv, self).__init__()
self.encoder = nn.ModuleList([
ConvBlock(nn.Conv2d, 3, nf, ksize=4, stride=2),
ConvBlock(nn.Conv2d, nf, nf * 2, ksize=4, stride=2, norm=nn.InstanceNorm2d, activation='leaky'),
ConvBlock(nn.Conv2d, nf * 2, nf * 4, ksize=4, stride=2, norm=nn.InstanceNorm2d, activation='leaky'),
ConvBlock(nn.Conv2d, nf * 4, nf * 4, ksize=4, stride=2, norm=nn.InstanceNorm2d, activation='leaky'),
ConvBlock(nn.Conv2d, nf * 4, nf * 4, ksize=4, stride=2, norm=nn.InstanceNorm2d, activation='leaky'),
ConvBlock(nn.Conv2d, nf * 4, nf * 4, ksize=3, stride=1, norm=nn.InstanceNorm2d, activation='leaky'),
ConvBlock(nn.Conv2d, nf * 4, nf * 4, ksize=3, stride=1, norm=nn.InstanceNorm2d, activation='leaky')])
self.decoder = nn.ModuleList([
ConvBlock(nn.ConvTranspose2d, nf * 4 * 2, nf * 4, ksize=4, stride=2, norm=nn.InstanceNorm2d),
ConvBlock(nn.ConvTranspose2d, nf * 4 * 2, nf * 4, ksize=4, stride=2, norm=nn.InstanceNorm2d),
ConvBlock(nn.ConvTranspose2d, nf * 4 * 2, nf * 2, ksize=4, stride=2, norm=nn.InstanceNorm2d),
ConvBlock(nn.ConvTranspose2d, nf * 2 * 2, nf, ksize=4, stride=2, norm=nn.InstanceNorm2d),
ConvBlock(nn.ConvTranspose2d, nf * 2, nf, ksize=4, stride=2, norm=nn.InstanceNorm2d)])
# head to predict body part class (25 classes - 24 body parts, 1 background.)
self.id_pred = ConvBlock(nn.Conv2d, nf + 3, 25, ksize=3, stride=1, activation='none')
# head to predict UV coordinates for every body part class
self.uv_pred = ConvBlock(nn.Conv2d, nf + 3, 2 * 24, ksize=3, stride=1, activation='tanh')
def forward(self, x):
"""Forward pass through UNet, handling skip connections.
Parameters:
x (tensor) - - rendered keypoint image, shape [B, 3, H, W]
Returns:
x_id (tensor): part id class probabilities
x_uv (tensor): uv coordinates for each part id
"""
skips = [x]
for i, layer in enumerate(self.encoder):
x = layer(x)
if i < 5:
skips.append(x)
for layer in self.decoder:
x = torch.cat((x, skips.pop()), 1)
x = layer(x)
x = torch.cat((x, skips.pop()), 1)
x_id = self.id_pred(x)
x_uv = self.uv_pred(x)
return x_id, x_uv
class LayeredNeuralRenderer(nn.Module):
"""Layered Neural Rendering model for video decomposition.
Consists of neural texture, UNet, upsampling module.
"""
def __init__(self, nf=64, texture_channels=16, texture_res=16, n_textures=25):
"""Initialize layered neural renderer.
Parameters:
nf (int) -- the number of channels in the first/last conv layers
texture_channels (int) -- the number of channels in the neural texture
texture_res (int) -- the size of each individual texture map
n_textures (int) -- the number of individual texture maps
"""
super(LayeredNeuralRenderer, self).__init__()
# Neural texture is implemented as 'n_textures' concatenated horizontally
self.texture = nn.Parameter(torch.randn(1, texture_channels, texture_res, n_textures * texture_res))
# Define UNet
self.encoder = nn.ModuleList([
ConvBlock(nn.Conv2d, texture_channels + 1, nf, ksize=4, stride=2),
ConvBlock(nn.Conv2d, nf, nf * 2, ksize=4, stride=2, norm=nn.BatchNorm2d, activation='leaky'),
ConvBlock(nn.Conv2d, nf * 2, nf * 4, ksize=4, stride=2, norm=nn.BatchNorm2d, activation='leaky'),
ConvBlock(nn.Conv2d, nf * 4, nf * 4, ksize=4, stride=2, norm=nn.BatchNorm2d, activation='leaky'),
ConvBlock(nn.Conv2d, nf * 4, nf * 4, ksize=4, stride=2, norm=nn.BatchNorm2d, activation='leaky'),
ConvBlock(nn.Conv2d, nf * 4, nf * 4, ksize=4, stride=1, dil=2, norm=nn.BatchNorm2d, activation='leaky'),
ConvBlock(nn.Conv2d, nf * 4, nf * 4, ksize=4, stride=1, dil=2, norm=nn.BatchNorm2d, activation='leaky')])
self.decoder = nn.ModuleList([
ConvBlock(nn.ConvTranspose2d, nf * 4 * 2, nf * 4, ksize=4, stride=2, norm=nn.BatchNorm2d),
ConvBlock(nn.ConvTranspose2d, nf * 4 * 2, nf * 4, ksize=4, stride=2, norm=nn.BatchNorm2d),
ConvBlock(nn.ConvTranspose2d, nf * 4 * 2, nf * 2, ksize=4, stride=2, norm=nn.BatchNorm2d),
ConvBlock(nn.ConvTranspose2d, nf * 2 * 2, nf, ksize=4, stride=2, norm=nn.BatchNorm2d),
ConvBlock(nn.ConvTranspose2d, nf * 2, nf, ksize=4, stride=2, norm=nn.BatchNorm2d)])
self.final_rgba = ConvBlock(nn.Conv2d, nf, 4, ksize=4, stride=1, activation='tanh')
# Define upsampling block, which outputs a residual
upsampling_ic = texture_channels + 5 + nf
self.upsample_block = nn.Sequential(
ConvBlock(nn.Conv2d, upsampling_ic, nf, ksize=3, stride=1, norm=nn.InstanceNorm2d),
ResBlock(nf, ksize=3, stride=1, norm=nn.InstanceNorm2d),
ResBlock(nf, ksize=3, stride=1, norm=nn.InstanceNorm2d),
ResBlock(nf, ksize=3, stride=1, norm=nn.InstanceNorm2d),
ConvBlock(nn.Conv2d, nf, 4, ksize=3, stride=1, activation='none'))
def render(self, x):
"""Pass inputs for a single layer through UNet.
Parameters:
x (tensor) - - sampled texture concatenated with person IDs
Returns RGBA for the input layer and the final feature maps.
"""
skips = [x]
for i, layer in enumerate(self.encoder):
x = layer(x)
if i < 5:
skips.append(x)
for layer in self.decoder:
x = torch.cat((x, skips.pop()), 1)
x = layer(x)
rgba = self.final_rgba(x)
return rgba, x
def forward(self, uv_map, id_layers, uv_map_upsampled=None, crop_params=None):
"""Forward pass through layered neural renderer.
Steps:
1. Sample from the neural texture using uv_map
2. Input uv_map and id_layers into UNet
2a. If doing upsampling, then pass upsampled inputs and results through upsampling module
3. Composite RGBA outputs.
Parameters:
uv_map (tensor) - - UV maps for all layers, with shape [B, (2*L), H, W]
id_layers (tensor) - - person ID for all layers, with shape [B, L, H, W]
uv_map_upsampled (tensor) - - upsampled UV maps to input to upsampling module (if None, skip upsampling)
crop_params (tuple) - - (starty, endy, startx, endx) crop applied to the upsampled outputs (if None, no crop)
"""
b_sz = uv_map.shape[0]
n_layers = uv_map.shape[1] // 2
texture = self.texture.repeat(b_sz, 1, 1, 1)
composite = None
layers = []
sampled_textures = []
for i in range(n_layers):
# Get RGBA for this layer.
uv_map_i = uv_map[:, i * 2:(i + 1) * 2, ...]
uv_map_i = uv_map_i.permute(0, 2, 3, 1)
sampled_texture = F.grid_sample(texture, uv_map_i, mode='bilinear', padding_mode='zeros')
inputs = torch.cat([sampled_texture, id_layers[:, i:i + 1]], 1)
rgba, last_feat = self.render(inputs)
if uv_map_upsampled is not None:
uv_map_up_i = uv_map_upsampled[:, i * 2:(i + 1) * 2, ...]
uv_map_up_i = uv_map_up_i.permute(0, 2, 3, 1)
sampled_texture_up = F.grid_sample(texture, uv_map_up_i, mode='bilinear', padding_mode='zeros')
id_layers_up = F.interpolate(id_layers[:, i:i + 1], size=sampled_texture_up.shape[-2:],
mode='bilinear')
inputs_up = torch.cat([sampled_texture_up, id_layers_up], 1)
upsampled_size = inputs_up.shape[-2:]
rgba = F.interpolate(rgba, size=upsampled_size, mode='bilinear')
last_feat = F.interpolate(last_feat, size=upsampled_size, mode='bilinear')
if crop_params is not None:
starty, endy, startx, endx = crop_params
rgba = rgba[:, :, starty:endy, startx:endx]
last_feat = last_feat[:, :, starty:endy, startx:endx]
inputs_up = inputs_up[:, :, starty:endy, startx:endx]
rgba_residual = self.upsample_block(torch.cat((rgba, inputs_up, last_feat), 1))
rgba += .01 * rgba_residual
rgba = torch.clamp(rgba, -1, 1)
sampled_texture = sampled_texture_up
# Update the composite with this layer's RGBA output
if composite is None:
composite = rgba
else:
alpha = rgba[:, 3:4] * .5 + .5
composite = rgba * alpha + composite * (1.0 - alpha)
layers.append(rgba)
sampled_textures.append(sampled_texture)
outputs = {
'reconstruction': composite,
'layers': torch.stack(layers, 1),
'sampled texture': sampled_textures, # for debugging
}
return outputs
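The back-to-front compositing loop in `LayeredNeuralRenderer.forward` can be isolated as a small sketch (tensor shapes here are illustrative; `composite_layers` is a hypothetical helper, not part of the repo):

```python
import torch

def composite_layers(rgbas):
    """Back-to-front 'over' compositing of per-layer RGBA tensors in [-1, 1],
    mirroring LayeredNeuralRenderer.forward: the first (back) layer seeds the
    composite, and each later layer's alpha channel, mapped from [-1, 1] to
    [0, 1], blends that layer over the running result."""
    composite = None
    for rgba in rgbas:
        if composite is None:
            composite = rgba
        else:
            alpha = rgba[:, 3:4] * .5 + .5  # tanh output -> [0, 1] opacity
            composite = rgba * alpha + composite * (1.0 - alpha)
    return composite

# A fully opaque top layer (alpha channel == 1) completely hides the back layer.
back = torch.zeros(1, 4, 8, 8)
top = torch.ones(1, 4, 8, 8)
out = composite_layers([back, top])
assert torch.allclose(out, top)
```

Note that the first (back) layer is taken as-is and its alpha is ignored, matching the `composite is None` branch in `forward`.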
================================================
FILE: options/__init__.py
================================================
================================================
FILE: options/base_options.py
================================================
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
from third_party.util import util
from third_party import models
from third_party import data
import torch
import json
class BaseOptions():
"""This class defines options used during both training and test time.
It also implements several helper functions such as parsing, printing, and saving the options.
It also gathers additional options defined in functions in both dataset class and model class.
"""
def __init__(self):
"""Reset the class; indicates the class hasn't been initialized"""
self.initialized = False
def initialize(self, parser):
"""Define the common options that are used in both training and test."""
# basic parameters
parser.add_argument('--dataroot', required=True, help='path to images (should have subfolders rgb_256, etc)')
parser.add_argument('--name', type=str, default='experiment_name', help='name of the experiment. It decides where to store samples and models')
parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU')
parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')
parser.add_argument('--seed', type=int, default=1, help='initial random seed')
# model parameters
parser.add_argument('--model', type=str, default='lnr', help='chooses which model to use. [lnr | kp2uv]')
parser.add_argument('--num_filters', type=int, default=64, help='# filters in the first and last conv layers')
# dataset parameters
parser.add_argument('--dataset_mode', type=str, default='layered_video', help='chooses how datasets are loaded. [layered_video | kpuv]')
parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly')
parser.add_argument('--num_threads', default=4, type=int, help='# threads for loading data')
parser.add_argument('--batch_size', type=int, default=32, help='input batch size')
parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.')
parser.add_argument('--display_winsize', type=int, default=256, help='display window size for both visdom and HTML')
# additional parameters
parser.add_argument('--epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
parser.add_argument('--verbose', action='store_true', help='if specified, print more debugging information')
parser.add_argument('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}')
self.initialized = True
return parser
def gather_options(self):
"""Initialize our parser with basic options(only once).
Add additional model-specific and dataset-specific options.
These options are defined in the function
in model and dataset classes.
"""
if not self.initialized: # check if it has been initialized
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser = self.initialize(parser)
# get the basic options
opt, _ = parser.parse_known_args()
# modify model-related parser options
model_name = opt.model
model_option_setter = models.get_option_setter(model_name)
parser = model_option_setter(parser, self.isTrain)
opt, _ = parser.parse_known_args() # parse again with new defaults
# modify dataset-related parser options
dataset_name = opt.dataset_mode
dataset_option_setter = data.get_option_setter(dataset_name)
parser = dataset_option_setter(parser, self.isTrain)
# save and return the parser
self.parser = parser
return parser.parse_args()
def print_options(self, opt):
"""Print and save options
It will print both current options and default values (if different).
It will save options into a text file / [checkpoints_dir] / opt.txt
"""
message = ''
message += '----------------- Options ---------------\n'
for k, v in sorted(vars(opt).items()):
comment = ''
default = self.parser.get_default(k)
if v != default:
comment = '\t[default: %s]' % str(default)
message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment)
message += '----------------- End -------------------'
print(message)
# save to the disk
expr_dir = os.path.join(opt.checkpoints_dir, opt.name)
util.mkdirs(expr_dir)
file_name = os.path.join(expr_dir, '{}_opt.txt'.format(opt.phase))
with open(file_name, 'wt') as opt_file:
opt_file.write(message)
opt_file.write('\n')
def parse(self):
"""Parse our options, create checkpoints directory suffix, and set up gpu device."""
opt = self.gather_options()
opt.isTrain = self.isTrain # train or test
# process opt.suffix
if opt.suffix:
suffix = ('_' + opt.suffix.format(**vars(opt))) if opt.suffix != '' else ''
opt.name = opt.name + suffix
self.print_options(opt)
# set gpu ids
str_ids = opt.gpu_ids.split(',')
opt.gpu_ids = []
for str_id in str_ids:
id = int(str_id)
if id >= 0:
opt.gpu_ids.append(id)
if len(opt.gpu_ids) > 0:
torch.cuda.set_device(opt.gpu_ids[0])
self.opt = opt
return self.opt
def parse_dataset_meta(self):
"""Parse options from the 'metadata.json' file in the dataroot."""
with open(os.path.join(self.opt.dataroot, 'metadata.json')) as f:
metadata = json.load(f)
self.opt.n_textures = metadata['n_textures']
self.opt.width = metadata['size_LR'][0]
self.opt.height = metadata['size_LR'][1]
return self.opt
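`parse_dataset_meta` expects a `metadata.json` file at the dataroot containing at least the keys it reads; a minimal sketch with illustrative values (the real file is produced by the dataset preparation scripts):

```python
import json
import os
import tempfile

# Hypothetical metadata.json contents; only 'n_textures' and 'size_LR'
# are read by BaseOptions.parse_dataset_meta.
metadata = {
    "n_textures": 25,        # number of maps in the neural texture
    "size_LR": [448, 256],   # low-resolution frame size as [width, height]
}

root = tempfile.mkdtemp()
with open(os.path.join(root, "metadata.json"), "w") as f:
    json.dump(metadata, f)

# Reading it back mirrors parse_dataset_meta.
with open(os.path.join(root, "metadata.json")) as f:
    meta = json.load(f)
width, height = meta["size_LR"]
```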
================================================
FILE: options/test_options.py
================================================
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .base_options import BaseOptions
class TestOptions(BaseOptions):
"""This class includes test options.
It also includes shared options defined in BaseOptions.
"""
def initialize(self, parser):
parser = BaseOptions.initialize(self, parser) # define shared options
parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.')
parser.add_argument('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images')
parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc')
parser.add_argument('--num_test', type=int, default=float("inf"), help='how many test images to run')
self.isTrain = False
return parser
================================================
FILE: options/train_options.py
================================================
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .base_options import BaseOptions
class TrainOptions(BaseOptions):
"""This class includes training options.
It also includes shared options defined in BaseOptions.
"""
def initialize(self, parser):
parser = BaseOptions.initialize(self, parser)
# visdom and HTML visualization parameters
parser.add_argument('--display_freq', type=int, default=20, help='frequency of showing training results on screen (in epochs)')
parser.add_argument('--display_ncols', type=int, default=0, help='if positive, display all images in a single visdom web panel with certain number of images per row.')
parser.add_argument('--display_id', type=int, default=1, help='window id of the web display')
parser.add_argument('--display_server', type=str, default="http://localhost", help='visdom server of the web display')
parser.add_argument('--display_env', type=str, default='main', help='visdom display environment name (default is "main")')
parser.add_argument('--display_port', type=int, default=8097, help='visdom port of the web display')
parser.add_argument('--update_html_freq', type=int, default=50, help='frequency of saving training results to html')
parser.add_argument('--print_freq', type=int, default=10, help='frequency of showing training results on console (in steps per epoch)')
parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/')
# network saving and loading parameters
parser.add_argument('--save_latest_freq', type=int, default=50, help='frequency of saving the latest results (in epochs)')
parser.add_argument('--save_by_epoch', action='store_true', help='whether saves model as "epoch" or "latest" (overwrites previous)')
parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model')
parser.add_argument('--epoch_count', type=int, default=1, help='the starting epoch count, we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>, ...')
parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc')
# training parameters
parser.add_argument('--n_epochs', type=int, default=2000, help='number of epochs with the initial learning rate')
parser.add_argument('--n_epochs_decay', type=int, default=0, help='number of epochs to linearly decay learning rate to zero')
parser.add_argument('--lr', type=float, default=0.001, help='initial learning rate for adam')
parser.add_argument('--lr_policy', type=str, default='linear', help='learning rate policy. [linear | step | plateau | cosine]')
self.isTrain = True
return parser
================================================
FILE: requirements.txt
================================================
torch==1.4.0
torchvision>=0.5.0
dominate>=2.4.0
visdom>=0.1.8
matplotlib>=3.2.1
opencv-python>=4.2.0
================================================
FILE: run_kp2uv.py
================================================
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Script for running the keypoint-to-UV network and saving the predicted UVs.
Example:
python run_kp2uv.py --model kp2uv --dataroot ./datasets/reflection --results_dir ./datasets/reflection
It will load the pre-trained model from '--checkpoints_dir' and save the results to '--results_dir'.
See options/base_options.py and options/test_options.py for more test options.
"""
import os
from options.test_options import TestOptions
from third_party.data import create_dataset
from third_party.models import create_model
from third_party.util.visualizer import save_images
from third_party.util import html
if __name__ == '__main__':
opt = TestOptions().parse() # get test options
# hard-code some parameters
opt.name = 'kp2uv'
opt.num_threads = 0 # test code only supports num_threads = 0
opt.batch_size = 1 # test code only supports batch_size = 1
opt.serial_batches = True # disable data shuffling; comment this line if results on randomly chosen images are needed.
opt.display_id = -1 # no visdom display; the test code saves the results to a HTML file.
dataset = create_dataset(opt) # create a dataset given opt.dataset_mode and other options
model = create_model(opt) # create a model given opt.model and other options
model.setup(opt) # regular setup: load and print networks; create schedulers
# create a website
web_dir = os.path.join(opt.results_dir, opt.name, '{}_{}'.format(opt.phase, opt.epoch)) # define the website directory
print('creating web directory', web_dir)
webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.epoch))
for i, data in enumerate(dataset):
if i >= opt.num_test: # only apply our model to opt.num_test images.
break
model.set_input(data) # unpack data from data loader
model.test() # run inference
visuals = model.get_current_visuals() # get image results
img_path = model.get_image_paths() # get image paths
if i % 5 == 0:
print('processing (%04d)-th image... %s' % (i, img_path))
save_images(webpage, visuals, img_path, aspect_ratio=opt.aspect_ratio, width=opt.display_winsize)
webpage.save() # save the HTML
================================================
FILE: scripts/download_kp2uv_model.sh
================================================
#!/bin/bash
#
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
mkdir -p ./checkpoints/kp2uv
MODEL_FILE=./checkpoints/kp2uv/latest_net_Kp2uv.pth
URL=https://www.robots.ox.ac.uk/~erika/retiming/pretrained_models/kp2uv.pth
wget -N $URL -O $MODEL_FILE
================================================
FILE: scripts/run_cartwheel.sh
================================================
#!/bin/bash
#
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
GPUS=0,1
DATA_PATH=./datasets/cartwheel
bash datasets/prepare_iuv.sh $DATA_PATH
python train.py \
--name cartwheel \
--dataroot $DATA_PATH \
--use_homographies \
--gpu_ids $GPUS
python test.py \
--name cartwheel \
--dataroot $DATA_PATH \
--do_upsampling \
--use_homographies
================================================
FILE: scripts/run_reflection.sh
================================================
#!/bin/bash
#
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
GPUS=0,1
DATA_PATH=./datasets/reflection
bash datasets/prepare_iuv.sh $DATA_PATH
python train.py \
--name reflection \
--dataroot $DATA_PATH \
--gpu_ids $GPUS
python test.py \
--name reflection \
--dataroot $DATA_PATH \
--do_upsampling
================================================
FILE: scripts/run_splash.sh
================================================
#!/bin/bash
#
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
GPUS=0,1
DATA_PATH=./datasets/splash
bash datasets/prepare_iuv.sh $DATA_PATH
python train.py \
--name splash \
--dataroot $DATA_PATH \
--batch_size 24 \
--batch_size_upsample 12 \
--use_mask_images \
--gpu_ids $GPUS
python test.py \
--name splash \
--dataroot $DATA_PATH \
--do_upsampling
================================================
FILE: scripts/run_trampoline.sh
================================================
#!/bin/bash
#
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
GPUS=0,1
DATA_PATH=./datasets/trampoline
bash datasets/prepare_iuv.sh $DATA_PATH
python train.py \
--name trampoline \
--dataroot $DATA_PATH \
--batch_size 16 \
--batch_size_upsample 6 \
--gpu_ids $GPUS
python test.py \
--name trampoline \
--dataroot $DATA_PATH \
--do_upsampling
================================================
FILE: test.py
================================================
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Script to save the full outputs of a layered neural renderer (LNR).
Once you have trained the LNR with train.py, you can use this script to save the model's final layer decomposition.
It will load a saved model from '--checkpoints_dir' and save the results to '--results_dir'.
It first creates a model and dataset given the options. It will hard-code some parameters.
It then runs inference for '--num_test' images and saves the results to an HTML file.
Example (You need to train models first or download pre-trained models from our website):
python test.py --dataroot ./datasets/reflection --name reflection --do_upsampling
If the upsampling module isn't trained (train.py is used with '--n_epochs_upsample 0'), remove --do_upsampling.
Use '--results_dir <directory_path_to_save_result>' to specify the results directory.
See options/base_options.py and options/test_options.py for more test options.
"""
import os
from options.test_options import TestOptions
from third_party.data import create_dataset
from third_party.models import create_model
from third_party.util.visualizer import save_images, save_videos
from third_party.util import html
import torch
if __name__ == '__main__':
testopt = TestOptions()
testopt.parse()
opt = testopt.parse_dataset_meta()
# hard-code some parameters for test
opt.num_threads = 0 # test code only supports num_threads = 0
opt.batch_size = 1 # test code only supports batch_size = 1
opt.serial_batches = True # disable data shuffling; comment this line if results on randomly chosen images are needed.
opt.display_id = -1 # no visdom display; the test code saves the results to a HTML file.
dataset = create_dataset(opt) # create a dataset given opt.dataset_mode and other options
model = create_model(opt) # create a model given opt.model and other options
model.setup(opt) # regular setup: load and print networks; create schedulers
# create a website
web_dir = os.path.join(opt.results_dir, opt.name, '{}_{}'.format(opt.phase, opt.epoch)) # define the website directory
print('creating web directory', web_dir)
webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.epoch))
video_visuals = None
for i, data in enumerate(dataset):
if i >= opt.num_test: # only apply our model to opt.num_test images.
break
model.set_input(data) # unpack data from data loader
model.test() # run inference
img_path = model.get_image_paths() # get image paths
if i % 5 == 0: # save images to an HTML file
print('processing (%04d)-th image... %s' % (i, img_path))
visuals = model.get_results() # rgba, reconstruction, original, mask
if video_visuals is None:
video_visuals = visuals
else:
for k in video_visuals:
video_visuals[k] = torch.cat((video_visuals[k], visuals[k]))
rgba = { k: visuals[k] for k in visuals if 'rgba' in k }
# save RGBA layers
save_images(webpage, rgba, img_path, aspect_ratio=opt.aspect_ratio, width=opt.display_winsize)
save_videos(webpage, video_visuals, width=opt.display_winsize)
webpage.save() # save the HTML of videos
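The per-frame accumulation of `video_visuals` in the loop above reduces to concatenating each frame's visuals dict along the batch/time dimension; a minimal sketch (`accumulate_visuals` is a hypothetical helper, not part of the repo):

```python
import torch

def accumulate_visuals(video_visuals, visuals):
    """Append one frame's visuals (a dict of [1, C, H, W] tensors) to the
    running per-video dict, concatenating along the batch/time dimension,
    as test.py does across dataset iterations."""
    if video_visuals is None:
        return dict(visuals)
    for k in video_visuals:
        video_visuals[k] = torch.cat((video_visuals[k], visuals[k]))
    return video_visuals

# Three single-frame dicts accumulate into one tensor with 3 frames.
video = None
for _ in range(3):
    frame = {'reconstruction': torch.zeros(1, 3, 4, 4)}
    video = accumulate_visuals(video, frame)
```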
================================================
FILE: third_party/__init__.py
================================================
================================================
FILE: third_party/data/__init__.py
================================================
"""This package includes all the modules related to data loading and preprocessing
To add a custom dataset class called 'dummy', you need to add a file called 'dummy_dataset.py' and define a subclass 'DummyDataset' inherited from BaseDataset.
You need to implement four functions:
-- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt).
-- <__len__>: return the size of dataset.
-- <__getitem__>: get a data point from data loader.
-- <modify_commandline_options>: (optionally) add dataset-specific options and set default options.
Now you can use the dataset class by specifying flag '--dataset_mode dummy'.
See our template dataset class 'template_dataset.py' for more details.
"""
import importlib
import torch.utils.data
from .base_dataset import BaseDataset
from .fast_data_loader import FastDataLoader
def find_dataset_using_name(dataset_name):
"""Import the module "data/[dataset_name]_dataset.py".
In the file, the class called DatasetNameDataset() will
be instantiated. It has to be a subclass of BaseDataset,
and it is case-insensitive.
"""
dataset_filename = "data." + dataset_name + "_dataset"
datasetlib = importlib.import_module(dataset_filename)
dataset = None
target_dataset_name = dataset_name.replace('_', '') + 'dataset'
for name, cls in datasetlib.__dict__.items():
if name.lower() == target_dataset_name.lower() \
and issubclass(cls, BaseDataset):
dataset = cls
if dataset is None:
raise NotImplementedError("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name))
return dataset
def get_option_setter(dataset_name):
"""Return the static method of the dataset class."""
dataset_class = find_dataset_using_name(dataset_name)
return dataset_class.modify_commandline_options
def create_dataset(opt, use_fast_loader=False):
"""Create a dataset given the option.
This function wraps the class CustomDatasetDataLoader.
This is the main interface between this package and 'train.py'/'test.py'
If use_fast_loader=False, use the default pytorch dataloader. Otherwise, use FastDataLoader.
Example:
>>> from data import create_dataset
>>> dataset = create_dataset(opt)
"""
data_loader = CustomDatasetDataLoader(opt, use_fast_loader=use_fast_loader)
dataset = data_loader.load_data()
return dataset
class CustomDatasetDataLoader():
"""Wrapper class of Dataset class that performs multi-threaded data loading"""
def __init__(self, opt, use_fast_loader=False):
"""Initialize this class
Step 1: create a dataset instance given the name [dataset_mode]
Step 2: create a multi-threaded data loader.
If use_fast_loader=False, use the default pytorch dataloader. Otherwise, use FastDataLoader.
"""
self.opt = opt
dataset_class = find_dataset_using_name(opt.dataset_mode)
self.dataset = dataset_class(opt)
print("dataset [%s] was created" % type(self.dataset).__name__)
loader = torch.utils.data.DataLoader
if use_fast_loader:
loader = FastDataLoader
self.dataloader = loader(
self.dataset,
batch_size=opt.batch_size,
shuffle=not opt.serial_batches,
num_workers=int(opt.num_threads))
def load_data(self):
return self
def __len__(self):
"""Return the number of data in the dataset"""
return min(len(self.dataset), self.opt.max_dataset_size)
def __iter__(self):
"""Return a batch of data"""
for i, data in enumerate(self.dataloader):
if i * self.opt.batch_size >= self.opt.max_dataset_size:
break
yield data
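The case-insensitive class lookup in `find_dataset_using_name` boils down to a simple naming convention; a minimal sketch of just the string matching (`matches_dataset_name` is a hypothetical helper, not part of the repo):

```python
def matches_dataset_name(class_name, dataset_mode):
    """True when a class name matches '--dataset_mode' under the repo's
    convention: drop underscores from the mode, append 'dataset', and
    compare case-insensitively."""
    target = dataset_mode.replace('_', '') + 'dataset'
    return class_name.lower() == target.lower()

# '--dataset_mode layered_video' resolves to the class LayeredVideoDataset.
assert matches_dataset_name('LayeredVideoDataset', 'layered_video')
# '--dataset_mode kpuv' resolves to KpuvDataset instead.
assert not matches_dataset_name('LayeredVideoDataset', 'kpuv')
```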
================================================
FILE: third_party/data/base_dataset.py
================================================
"""This module implements an abstract base class (ABC) 'BaseDataset' for datasets.
It also includes common transformation functions (e.g., get_transform, __scale_width), which can be later used in subclasses.
"""
import random
import numpy as np
import torch.utils.data as data
from PIL import Image
import torchvision.transforms as transforms
from abc import ABC, abstractmethod
class BaseDataset(data.Dataset, ABC):
"""This class is an abstract base class (ABC) for datasets.
To create a subclass, you need to implement the following four functions:
-- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt).
-- <__len__>: return the size of dataset.
-- <__getitem__>: get a data point.
-- <modify_commandline_options>: (optionally) add dataset-specific options and set default options.
"""
def __init__(self, opt):
"""Initialize the class; save the options in the class
Parameters:
opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
"""
self.opt = opt
self.root = opt.dataroot
@staticmethod
def modify_commandline_options(parser, is_train):
"""Add new dataset-specific options, and rewrite default values for existing options.
Parameters:
parser -- original option parser
is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
Returns:
the modified parser.
"""
return parser
@abstractmethod
def __len__(self):
"""Return the total number of images in the dataset."""
return 0
@abstractmethod
def __getitem__(self, index):
"""Return a data point and its metadata information.
Parameters:
            index (int) -- a random integer for data indexing
        Returns:
            a dictionary of data with their names. It usually contains the data itself and its metadata information.
"""
pass
def get_params(opt, size):
w, h = size
new_h = h
new_w = w
if opt.preprocess == 'resize_and_crop':
new_h = new_w = opt.load_size
elif opt.preprocess == 'scale_width_and_crop':
new_w = opt.load_size
new_h = opt.load_size * h // w
x = random.randint(0, np.maximum(0, new_w - opt.crop_size))
y = random.randint(0, np.maximum(0, new_h - opt.crop_size))
flip = random.random() > 0.5
return {'crop_pos': (x, y), 'flip': flip}
def get_transform(opt, params=None, grayscale=False, method=Image.BICUBIC, convert=True):
transform_list = []
if grayscale:
transform_list.append(transforms.Grayscale(1))
if 'resize' in opt.preprocess:
osize = [opt.load_size, opt.load_size]
transform_list.append(transforms.Resize(osize, method))
elif 'scale_width' in opt.preprocess:
transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.load_size, opt.crop_size, method)))
if 'crop' in opt.preprocess:
if params is None:
transform_list.append(transforms.RandomCrop(opt.crop_size))
else:
transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.crop_size)))
if opt.preprocess == 'none':
transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base=4, method=method)))
if not opt.no_flip:
if params is None:
transform_list.append(transforms.RandomHorizontalFlip())
elif params['flip']:
transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip'])))
if convert:
transform_list += [transforms.ToTensor()]
if grayscale:
transform_list += [transforms.Normalize((0.5,), (0.5,))]
else:
transform_list += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
return transforms.Compose(transform_list)
def __make_power_2(img, base, method=Image.BICUBIC):
ow, oh = img.size
h = int(round(oh / base) * base)
w = int(round(ow / base) * base)
if h == oh and w == ow:
return img
__print_size_warning(ow, oh, w, h)
return img.resize((w, h), method)
def __scale_width(img, target_size, crop_size, method=Image.BICUBIC):
ow, oh = img.size
if ow == target_size and oh >= crop_size:
return img
w = target_size
h = int(max(target_size * oh / ow, crop_size))
return img.resize((w, h), method)
def __crop(img, pos, size):
ow, oh = img.size
x1, y1 = pos
tw = th = size
if (ow > tw or oh > th):
return img.crop((x1, y1, x1 + tw, y1 + th))
return img
def __flip(img, flip):
if flip:
return img.transpose(Image.FLIP_LEFT_RIGHT)
return img
def __print_size_warning(ow, oh, w, h):
"""Print warning information about image size(only print once)"""
if not hasattr(__print_size_warning, 'has_printed'):
print("The image size needs to be a multiple of 4. "
"The loaded image size was (%d, %d), so it was adjusted to "
"(%d, %d). This adjustment will be done to all images "
"whose sizes are not multiples of 4" % (ow, oh, w, h))
__print_size_warning.has_printed = True
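The `get_params` / `get_transform(params=...)` pair above exists so that one set of random decisions can be applied identically to several images (e.g. an input/target pair). A pure-Python sketch of that contract, with plain 2-D lists standing in for PIL images (the helper names are hypothetical):

```python
import random

# Sketch of the get_params / get_transform(params=...) contract: sample the
# random decisions once, then apply them deterministically to each image so a
# pair of images receives identical crops and flips.
def get_params_sketch(size, crop_size):
    w, h = size
    x = random.randint(0, max(0, w - crop_size))
    y = random.randint(0, max(0, h - crop_size))
    return {'crop_pos': (x, y), 'flip': random.random() > 0.5}

def apply_sketch(img, params, crop_size):
    x, y = params['crop_pos']
    out = [row[x:x + crop_size] for row in img[y:y + crop_size]]
    if params['flip']:
        out = [row[::-1] for row in out]
    return out

random.seed(0)
img_a = [[10 * r + c for c in range(4)] for r in range(4)]
img_b = [[v + 100 for v in row] for row in img_a]   # shifted copy of img_a
params = get_params_sketch((4, 4), crop_size=2)     # sampled once
crop_a = apply_sketch(img_a, params, 2)
crop_b = apply_sketch(img_b, params, 2)
# Because params are shared, both crops cover the same spatial region:
same_region = all(b - a == 100 for ra, rb in zip(crop_a, crop_b)
                  for a, b in zip(ra, rb))
```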
================================================
FILE: third_party/data/fast_data_loader.py
================================================
""" Fixes the issue where DataLoader is slow because processes aren't reused
See https://github.com/pytorch/pytorch/issues/15849
Warning: overrides batch sampler.
"""
import torch.utils.data
class _RepeatSampler(object):
""" Sampler that repeats forever.
Args:
sampler (Sampler)
"""
def __init__(self, sampler):
self.sampler = sampler
def __iter__(self):
while True:
yield from iter(self.sampler)
class FastDataLoader(torch.utils.data.dataloader.DataLoader):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
object.__setattr__(self, 'batch_sampler', _RepeatSampler(self.batch_sampler))
self.iterator = super().__iter__()
def __len__(self):
return len(self.batch_sampler.sampler)
def __iter__(self):
for i in range(len(self)):
yield next(self.iterator)
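A dependency-free sketch of the `_RepeatSampler` idea (the `RepeatForever` name is made up for illustration): because the wrapped sampler restarts forever, a single long-lived iterator can serve many epochs, which is what lets `FastDataLoader` keep one iterator — and therefore PyTorch's worker processes — alive across epochs.

```python
# Pure-Python sketch of the _RepeatSampler pattern: wrap a finite sampler in
# an iterator that restarts it forever, so one persistent iterator can serve
# many epochs without being recreated.
class RepeatForever:
    def __init__(self, sampler):
        self.sampler = sampler

    def __iter__(self):
        while True:
            yield from iter(self.sampler)

sampler = [0, 1, 2]                  # stand-in for a finite index sampler
it = iter(RepeatForever(sampler))    # single persistent iterator

# Two "epochs" drawn from the same iterator, len(sampler) items each:
epoch1 = [next(it) for _ in range(len(sampler))]
epoch2 = [next(it) for _ in range(len(sampler))]
```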
================================================
FILE: third_party/data/image_folder.py
================================================
"""A modified image folder class
We modify the official PyTorch image folder (https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py)
so that this class can load images from both current directory and its subdirectories.
"""
import torch.utils.data as data
from PIL import Image
import os
IMG_EXTENSIONS = [
'.jpg', '.JPG', '.jpeg', '.JPEG',
'.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
'.tif', '.TIF', '.tiff', '.TIFF',
]
def is_image_file(filename):
return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
def make_dataset(dir, max_dataset_size=float("inf")):
images = []
assert os.path.isdir(dir), '%s is not a valid directory' % dir
for root, _, fnames in sorted(os.walk(dir)):
for fname in fnames:
if is_image_file(fname):
path = os.path.join(root, fname)
images.append(path)
images = sorted(images)
return images[:min(max_dataset_size, len(images))]
def default_loader(path):
return Image.open(path).convert('RGB')
class ImageFolder(data.Dataset):
def __init__(self, root, transform=None, return_paths=False,
loader=default_loader):
imgs = make_dataset(root)
if len(imgs) == 0:
raise(RuntimeError("Found 0 images in: " + root + "\n"
"Supported image extensions are: " + ",".join(IMG_EXTENSIONS)))
self.root = root
self.imgs = imgs
self.transform = transform
self.return_paths = return_paths
self.loader = loader
def __getitem__(self, index):
path = self.imgs[index]
img = self.loader(path)
if self.transform is not None:
img = self.transform(img)
if self.return_paths:
return img, path
else:
return img
def __len__(self):
return len(self.imgs)
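`make_dataset` above walks the directory tree and keeps files with a known image extension, returned sorted. A self-contained sketch of the same pattern against a temporary directory (abbreviated extension list, hypothetical helper name):

```python
import os
import tempfile

# Sketch of make_dataset: recursively walk a directory, keep files whose
# extension matches, and return the paths sorted.
EXTS = ('.jpg', '.png')   # abbreviated stand-in for IMG_EXTENSIONS

def list_images(root):
    found = []
    for base, _, fnames in sorted(os.walk(root)):
        for fname in fnames:
            if fname.lower().endswith(EXTS):
                found.append(os.path.join(base, fname))
    return sorted(found)

with tempfile.TemporaryDirectory() as root:
    os.makedirs(os.path.join(root, 'sub'))
    for rel in ('a.png', 'notes.txt', os.path.join('sub', 'b.jpg')):
        open(os.path.join(root, rel), 'w').close()
    # non-image 'notes.txt' is filtered out; subdirectory files are included
    names = [os.path.relpath(p, root) for p in list_images(root)]
```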
================================================
FILE: third_party/models/__init__.py
================================================
"""This package contains modules related to objective functions, optimizations, and network architectures.
To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel.
You need to implement the following five functions:
-- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
    -- <set_input>: unpack data from dataset and apply preprocessing.
    -- <forward>: produce intermediate results.
    -- <optimize_parameters>: calculate loss, gradients, and update network weights.
    -- <modify_commandline_options>: (optionally) add model-specific options and set default options.
In the function <__init__>, you need to define four lists:
-- self.loss_names (str list): specify the training losses that you want to plot and save.
-- self.model_names (str list): define networks used in our training.
-- self.visual_names (str list): specify the images that you want to display and save.
    -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for example usage.
Now you can use the model class by specifying flag '--model dummy'.
See our template model class 'template_model.py' for more details.
"""
import importlib
from .base_model import BaseModel
def find_model_using_name(model_name):
"""Import the module "models/[model_name]_model.py".
In the file, the class called DatasetNameModel() will
be instantiated. It has to be a subclass of BaseModel,
and it is case-insensitive.
"""
model_filename = "models." + model_name + "_model"
modellib = importlib.import_module(model_filename)
model = None
target_model_name = model_name.replace('_', '') + 'model'
for name, cls in modellib.__dict__.items():
if name.lower() == target_model_name.lower() \
and issubclass(cls, BaseModel):
model = cls
if model is None:
print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name))
exit(0)
return model
def get_option_setter(model_name):
"""Return the static method of the model class."""
model_class = find_model_using_name(model_name)
return model_class.modify_commandline_options
def create_model(opt):
"""Create a model given the option.
    This function wraps the model class.
This is the main interface between this package and 'train.py'/'test.py'
Example:
>>> from models import create_model
>>> model = create_model(opt)
"""
model = find_model_using_name(opt.model)
instance = model(opt)
print("model [%s] was created" % type(instance).__name__)
return instance
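`find_model_using_name` relies on a naming convention: a model named `lnr` must live in `models/lnr_model.py` and define a class whose lowercased name equals `lnrmodel`. A sketch of just the name-matching step, with a plain dict standing in for the imported module's namespace (names here are illustrative):

```python
# Sketch of the lookup convention in find_model_using_name: strip underscores
# from the model name, append 'model', and scan the module namespace for a
# class whose lowercased name matches.
class LNRModel:           # stand-in for the class defined in lnr_model.py
    pass

def find_in_namespace(model_name, namespace):
    target = model_name.replace('_', '') + 'model'
    for name, cls in namespace.items():
        if isinstance(cls, type) and name.lower() == target.lower():
            return cls
    return None

found = find_in_namespace('lnr', {'LNRModel': LNRModel, 'helper': object()})
```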
================================================
FILE: third_party/models/base_model.py
================================================
import os
import torch
from collections import OrderedDict
from abc import ABC, abstractmethod
from . import networks
class BaseModel(ABC):
"""This class is an abstract base class (ABC) for models.
To create a subclass, you need to implement the following five functions:
-- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
    -- <set_input>: unpack data from dataset and apply preprocessing.
    -- <forward>: produce intermediate results.
    -- <optimize_parameters>: calculate losses, gradients, and update network weights.
    -- <modify_commandline_options>: (optionally) add model-specific options and set default options.
"""
def __init__(self, opt):
"""Initialize the BaseModel class.
Parameters:
opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
When creating your custom class, you need to implement your own initialization.
        In this function, you should first call <BaseModel.__init__(self, opt)>
Then, you need to define four lists:
-- self.loss_names (str list): specify the training losses that you want to plot and save.
-- self.model_names (str list): define networks used in our training.
-- self.visual_names (str list): specify the images that you want to display and save.
-- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
"""
self.opt = opt
self.gpu_ids = opt.gpu_ids
self.isTrain = opt.isTrain
self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu') # get device name: CPU or GPU
self.save_dir = os.path.join(opt.checkpoints_dir, opt.name) # save all the checkpoints to save_dir
self.loss_names = []
self.model_names = []
self.visual_names = []
self.optimizers = []
self.image_paths = []
self.metric = 0 # used for learning rate policy 'plateau'
@staticmethod
def modify_commandline_options(parser, is_train):
"""Add new model-specific options, and rewrite default values for existing options.
Parameters:
parser -- original option parser
is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
Returns:
the modified parser.
"""
return parser
@abstractmethod
def set_input(self, input):
"""Unpack input data from the dataloader and perform necessary pre-processing steps.
Parameters:
input (dict): includes the data itself and its metadata information.
"""
pass
@abstractmethod
def forward(self):
"""Run forward pass; called by both functions and ."""
pass
@abstractmethod
def optimize_parameters(self):
"""Calculate losses, gradients, and update network weights; called in every training iteration"""
pass
def setup(self, opt):
"""Load and print networks; create schedulers
Parameters:
opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
"""
if self.isTrain:
self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers]
if not self.isTrain or opt.continue_train:
load_suffix = opt.epoch
self.load_networks(load_suffix)
self.print_networks(opt.verbose)
def eval(self):
"""Make models eval mode during test time"""
for name in self.model_names:
if isinstance(name, str):
net = getattr(self, 'net' + name)
net.eval()
def test(self):
"""Forward function used in test time.
This function wraps function in no_grad() so we don't save intermediate steps for backprop
It also calls to produce additional visualization results
"""
with torch.no_grad():
self.forward()
self.compute_visuals()
def compute_visuals(self):
"""Calculate additional output images for visdom and HTML visualization"""
pass
def get_image_paths(self):
""" Return image paths that are used to load current data"""
return self.image_paths
def update_learning_rate(self):
"""Update learning rates for all the networks; called at the end of every epoch"""
old_lr = self.optimizers[0].param_groups[0]['lr']
for scheduler in self.schedulers:
if self.opt.lr_policy == 'plateau':
scheduler.step(self.metric)
else:
scheduler.step()
lr = self.optimizers[0].param_groups[0]['lr']
if old_lr != lr:
print('learning rate %.7f -> %.7f' % (old_lr, lr))
def get_current_visuals(self):
"""Return visualization images. train.py will display these images with visdom, and save the images to a HTML"""
visual_ret = OrderedDict()
for name in self.visual_names:
if isinstance(name, str):
visual_ret[name] = getattr(self, name)
return visual_ret
def get_current_losses(self):
"""Return traning losses / errors. train.py will print out these errors on console, and save them to a file"""
errors_ret = OrderedDict()
for name in self.loss_names:
if isinstance(name, str):
errors_ret[name] = float(getattr(self, 'loss_' + name)) # float(...) works for both scalar tensor and float number
return errors_ret
def save_networks(self, epoch):
"""Save all the networks to the disk.
Parameters:
epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
"""
for name in self.model_names:
if isinstance(name, str):
save_filename = '%s_net_%s.pth' % (epoch, name)
save_path = os.path.join(self.save_dir, save_filename)
net = getattr(self, 'net' + name)
if len(self.gpu_ids) > 0 and torch.cuda.is_available():
torch.save(net.module.cpu().state_dict(), save_path)
net.cuda(self.gpu_ids[0])
else:
torch.save(net.cpu().state_dict(), save_path)
def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0):
"""Fix InstanceNorm checkpoints incompatibility (prior to 0.4)"""
key = keys[i]
if i + 1 == len(keys): # at the end, pointing to a parameter/buffer
if module.__class__.__name__.startswith('InstanceNorm') and \
(key == 'running_mean' or key == 'running_var'):
if getattr(module, key) is None:
state_dict.pop('.'.join(keys))
if module.__class__.__name__.startswith('InstanceNorm') and \
(key == 'num_batches_tracked'):
state_dict.pop('.'.join(keys))
else:
self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1)
def load_networks(self, epoch):
"""Load all the networks from the disk.
Parameters:
epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
"""
for name in self.model_names:
if isinstance(name, str):
load_filename = '%s_net_%s.pth' % (epoch, name)
load_path = os.path.join(self.save_dir, load_filename)
net = getattr(self, 'net' + name)
if isinstance(net, torch.nn.DataParallel):
net = net.module
print('loading the model from %s' % load_path)
# if you are using PyTorch newer than 0.4 (e.g., built from
# GitHub source), you can remove str() on self.device
state_dict = torch.load(load_path, map_location=str(self.device))
if hasattr(state_dict, '_metadata'):
del state_dict._metadata
# patch InstanceNorm checkpoints prior to 0.4
for key in list(state_dict.keys()): # need to copy keys here because we mutate in loop
self.__patch_instance_norm_state_dict(state_dict, net, key.split('.'))
net.load_state_dict(state_dict)
def print_networks(self, verbose):
"""Print the total number of parameters in the network and (if verbose) network architecture
Parameters:
verbose (bool) -- if verbose: print the network architecture
"""
print('---------- Networks initialized -------------')
for name in self.model_names:
if isinstance(name, str):
net = getattr(self, 'net' + name)
num_params = 0
num_trainable_params = 0
for param in net.parameters():
num_params += param.numel()
if param.requires_grad:
num_trainable_params += param.numel()
if verbose:
print(net)
print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6))
print('[Network %s] Total number of trainable parameters : %.3f M' % (name, num_trainable_params / 1e6))
print('-----------------------------------------------')
def set_requires_grad(self, nets, requires_grad=False):
"""Set requies_grad=Fasle for all the networks to avoid unnecessary computations
Parameters:
nets (network list) -- a list of networks
requires_grad (bool) -- whether the networks require gradients or not
"""
if not isinstance(nets, list):
nets = [nets]
for net in nets:
if net is not None:
for param in net.parameters():
param.requires_grad = requires_grad
================================================
FILE: third_party/models/networks.py
================================================
import torch
import torch.nn as nn
from torch.optim import lr_scheduler
###############################################################################
# Helper Functions
###############################################################################
def get_scheduler(optimizer, opt):
"""Return a learning rate scheduler
Parameters:
optimizer -- the optimizer of the network
opt (option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions.
opt.lr_policy is the name of learning rate policy: linear | step | plateau | cosine
For 'linear', we keep the same learning rate for the first epochs
and linearly decay the rate to zero over the next epochs.
For other schedulers (step, plateau, and cosine), we use the default PyTorch schedulers.
See https://pytorch.org/docs/stable/optim.html for more details.
"""
if opt.lr_policy == 'linear':
def lambda_rule(epoch):
lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.n_epochs) / float(opt.n_epochs_decay + 1)
return lr_l
scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
elif opt.lr_policy == 'step':
scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
elif opt.lr_policy == 'plateau':
scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
elif opt.lr_policy == 'cosine':
scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.n_epochs, eta_min=0)
else:
        raise NotImplementedError('learning rate policy [%s] is not implemented' % opt.lr_policy)
return scheduler
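For the `'linear'` policy, the multiplier stays at 1.0 through the first `n_epochs` and then decays linearly over `n_epochs_decay` epochs. A standalone copy of `lambda_rule` with the option fields lifted into plain arguments (the defaults are illustrative; `epoch_count` is the starting epoch, usually 1):

```python
# Standalone version of the 'linear' policy's lambda_rule: the returned value
# multiplies the base learning rate (LambdaLR semantics).
def lambda_rule(epoch, epoch_count=1, n_epochs=100, n_epochs_decay=100):
    return 1.0 - max(0, epoch + epoch_count - n_epochs) / float(n_epochs_decay + 1)

flat = lambda_rule(0)      # constant phase: multiplier stays at 1.0
mid = lambda_rule(150)     # roughly halfway through the decay phase
```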
def init_net(net, gpu_ids=[]):
"""Initialize a network by registering CPU/GPU device (with multi-GPU support)
Parameters:
net (network) -- the network to be initialized
gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2
Return an initialized network.
"""
if len(gpu_ids) > 0:
assert (torch.cuda.is_available())
net.to(gpu_ids[0])
net = torch.nn.DataParallel(net, gpu_ids) # multi-GPUs
return net
================================================
FILE: third_party/util/__init__.py
================================================
"""This package includes a miscellaneous collection of useful helper functions."""
================================================
FILE: third_party/util/html.py
================================================
import dominate
from dominate.tags import meta, h3, table, tr, td, p, a, img, br, video, source
import os
class HTML:
"""This HTML class allows us to save images and write texts into a single HTML file.
    It consists of functions such as <add_header> (add a text header to the HTML file),
    <add_images> (add a row of images to the HTML file), and <save> (save the HTML to the disk).
It is based on Python library 'dominate', a Python library for creating and manipulating HTML documents using a DOM API.
"""
    def __init__(self, web_dir, title, refresh=0):
        """Initialize the HTML classes
        Parameters:
            web_dir (str) -- a directory that stores the webpage. HTML file will be created at <web_dir>/index.html; images will be saved at <web_dir>/images/, videos at <web_dir>/videos/
            refresh (int) -- how often the website refreshes itself; if 0, no refreshing
        """
        self.title = title
        self.web_dir = web_dir
        self.img_dir = os.path.join(self.web_dir, 'images')
        self.vid_dir = os.path.join(self.web_dir, 'videos')
        for d in (self.web_dir, self.img_dir, self.vid_dir):
            if not os.path.exists(d):
                os.makedirs(d)
        self.doc = dominate.document(title=title)
        if refresh > 0:
            with self.doc.head:
                meta(http_equiv="refresh", content=str(refresh))
def get_image_dir(self):
"""Return the directory that stores images"""
return self.img_dir
def get_video_dir(self):
"""Return the directory that stores videos"""
return self.vid_dir
def add_header(self, text):
"""Insert a header to the HTML file
Parameters:
text (str) -- the header text
"""
with self.doc:
h3(text)
def add_images(self, ims, txts, links, width=400):
"""add images to the HTML file
Parameters:
ims (str list) -- a list of image paths
txts (str list) -- a list of image names shown on the website
links (str list) -- a list of hyperref links; when you click an image, it will redirect you to a new page
"""
self.t = table(border=1, style="table-layout: fixed;") # Insert a table
self.doc.add(self.t)
with self.t:
with tr():
for im, txt, link in zip(ims, txts, links):
with td(style="word-wrap: break-word;", halign="center", valign="top"):
with p():
with a(href=os.path.join('images', link)):
img(style="width:%dpx" % width, src=os.path.join('images', im))
br()
p(txt)
def add_videos(self, vids, txts, links, width=400):
"""add images to the HTML file
Parameters:
ims (str list) -- a list of image paths
txts (str list) -- a list of image names shown on the website
links (str list) -- a list of hyperref links; when you click an image, it will redirect you to a new page
"""
self.t = table(border=1, style="table-layout: fixed;") # Insert a table
self.doc.add(self.t)
with self.t:
with tr():
for vid, txt, link in zip(vids, txts, links):
with td(style="word-wrap: break-word;", halign="center", valign="top"):
with p():
with a(href=os.path.join('videos', link)):
with video(style="width:%dpx" % width, controls=True):
source(src=os.path.join('videos', vid), type="video/mp4")
br()
p(txt)
def save(self):
"""save the current content to the HMTL file"""
html_file = '%s/index.html' % self.web_dir
f = open(html_file, 'wt')
f.write(self.doc.render())
f.close()
if __name__ == '__main__': # we show an example usage here.
html = HTML('web/', 'test_html')
html.add_header('hello world')
ims, txts, links = [], [], []
for n in range(4):
ims.append('image_%d.png' % n)
txts.append('text_%d' % n)
links.append('image_%d.png' % n)
html.add_images(ims, txts, links)
html.save()
================================================
FILE: third_party/util/util.py
================================================
"""This module contains simple helper functions """
from __future__ import print_function
import torch
import numpy as np
from PIL import Image
import os
def tensor2im(input_image, imtype=np.uint8):
""""Converts a Tensor array into a numpy image array.
Parameters:
input_image (tensor) -- the input image tensor array
imtype (type) -- the desired type of the converted numpy array
"""
if not isinstance(input_image, np.ndarray):
if isinstance(input_image, torch.Tensor): # get the data from a variable
image_tensor = input_image.data
else:
return input_image
image_numpy = image_tensor[0].cpu().float().numpy() # convert it into a numpy array
if image_numpy.shape[0] == 1: # grayscale to RGB
image_numpy = np.tile(image_numpy, (3, 1, 1))
        image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0  # post-processing: transpose and scaling
else: # if it is a numpy array, do nothing
image_numpy = input_image
return image_numpy.astype(imtype)
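The post-processing step in `tensor2im` assumes network outputs normalized to [-1, 1] (matching the `Normalize((0.5, ...), (0.5, ...))` transform in base_dataset.py) and maps them to displayable [0, 255] pixel values. The scalar version of that mapping:

```python
# tensor2im's rescaling: map a value in [-1, 1] to a pixel value in [0, 255]
# via (x + 1) / 2 * 255.
def to_pixel(x):
    return (x + 1) / 2.0 * 255.0

lo, mid, hi = to_pixel(-1.0), to_pixel(0.0), to_pixel(1.0)
```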
def render_png(image, background='checker'):
height, width = image.shape[:2]
if background == 'checker':
checkerboard = np.kron([[136, 120] * (width//128+1), [120, 136] * (width//128+1)] * (height//128+1), np.ones((16, 16)))
checkerboard = np.expand_dims(np.tile(checkerboard, (4, 4)), -1)
bg = checkerboard[:height, :width]
elif background == 'black':
bg = np.zeros([height, width, 1])
else:
bg = 255 * np.ones([height, width, 1])
image = image.astype(np.float32)
alpha = image[:, :, 3:] / 255
rendered_image = alpha * image[:, :, :3] + (1 - alpha) * bg
return rendered_image.astype(np.uint8)
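`render_png` applies the standard "over" compositing formula, `out = alpha * fg + (1 - alpha) * bg`, after rescaling the 8-bit alpha channel to [0, 1]. A single-channel sketch of the same arithmetic:

```python
# Single-channel sketch of render_png's alpha compositing: blend a foreground
# value over a background according to an 8-bit alpha.
def composite(fg, fg_alpha_8bit, bg):
    alpha = fg_alpha_8bit / 255.0
    return alpha * fg + (1 - alpha) * bg

opaque = composite(200.0, 255, 50.0)    # alpha = 1 -> foreground only
clear = composite(200.0, 0, 50.0)       # alpha = 0 -> background only
half = composite(200.0, 127.5, 50.0)    # midway blend
```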
def diagnose_network(net, name='network'):
"""Calculate and print the mean of average absolute(gradients)
Parameters:
net (torch network) -- Torch network
name (str) -- the name of the network
"""
mean = 0.0
count = 0
for param in net.parameters():
if param.grad is not None:
mean += torch.mean(torch.abs(param.grad.data))
count += 1
if count > 0:
mean = mean / count
print(name)
print(mean)
def save_image(image_numpy, image_path, aspect_ratio=1.0):
"""Save a numpy image to the disk
Parameters:
image_numpy (numpy array) -- input numpy array
image_path (str) -- the path of the image
"""
image_pil = Image.fromarray(image_numpy)
h, w, _ = image_numpy.shape
if aspect_ratio > 1.0:
image_pil = image_pil.resize((h, int(w * aspect_ratio)), Image.BICUBIC)
if aspect_ratio < 1.0:
image_pil = image_pil.resize((int(h / aspect_ratio), w), Image.BICUBIC)
image_pil.save(image_path)
def print_numpy(x, val=True, shp=False):
"""Print the mean, min, max, median, std, and size of a numpy array
Parameters:
val (bool) -- if print the values of the numpy array
shp (bool) -- if print the shape of the numpy array
"""
x = x.astype(np.float64)
if shp:
print('shape,', x.shape)
if val:
x = x.flatten()
print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % (
np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x)))
def mkdirs(paths):
"""create empty directories if they don't exist
Parameters:
paths (str list) -- a list of directory paths
"""
if isinstance(paths, list) and not isinstance(paths, str):
for path in paths:
mkdir(path)
else:
mkdir(paths)
def mkdir(path):
"""create a single empty directory if it didn't exist
Parameters:
path (str) -- a single directory path
"""
if not os.path.exists(path):
os.makedirs(path)
================================================
FILE: third_party/util/visualizer.py
================================================
import cv2
import numpy as np
import os
import sys
import ntpath
import time
from . import util, html
from subprocess import Popen, PIPE
if sys.version_info[0] == 2:
VisdomExceptionBase = Exception
else:
VisdomExceptionBase = ConnectionError
def save_images(webpage, visuals, image_path, aspect_ratio=1.0, width=256):
"""Save images to the disk.
Parameters:
        webpage (the HTML class) -- the HTML webpage class that stores these images (see html.py for more details)
visuals (OrderedDict) -- an ordered dictionary that stores (name, images (either tensor or numpy) ) pairs
image_path (str) -- the string is used to create image paths
aspect_ratio (float) -- the aspect ratio of saved images
width (int) -- the images will be resized to width x width
This function will save images stored in 'visuals' to the HTML file specified by 'webpage'.
"""
image_dir = webpage.get_image_dir()
short_path = ntpath.basename(image_path[0])
name = os.path.splitext(short_path)[0]
webpage.add_header(name)
ims, txts, links = [], [], []
for label, im_data in visuals.items():
im = util.tensor2im(im_data)
image_name = '%s_%s.png' % (name, label)
save_path = os.path.join(image_dir, image_name)
util.save_image(im, save_path, aspect_ratio=aspect_ratio)
ims.append(image_name)
txts.append(label)
links.append(image_name)
webpage.add_images(ims, txts, links, width=width)
def save_videos(webpage, visuals, width=256):
"""Save videos to the disk.
Parameters:
webpage (the HTML class) -- the HTML webpage class that stores these videos (see html.py for more details)
visuals (OrderedDict) -- an ordered dictionary that stores (name, video (either tensor or numpy) ) pairs
        width (int) -- the videos will be displayed at this width
This function will save videos stored in 'visuals' to the HTML file specified by 'webpage'.
"""
video_dir = webpage.get_video_dir()
webpage.add_header('videos')
vids, txts, links = [], [], []
for label, vid_data in sorted(visuals.items()):
video_name = f'{label}.webm'
video_path = os.path.join(video_dir, video_name)
frame_height, frame_width = vid_data.shape[-2:]
video = cv2.VideoWriter(video_path, cv2.VideoWriter_fourcc(*'vp80'), 25, (frame_width, frame_height))
for i in range(vid_data.shape[0]):
frame = util.tensor2im(vid_data[i:i+1])
if frame.shape[-1] == 4:
# render png
frame = util.render_png(frame, background='checker')
frame = frame[:, :, ::-1] # RGB -> BGR
video.write(frame)
video.release()
cv2.destroyAllWindows()
print("You may see an OpenCV 'vp80 not supported' error message despite the video saving correctly. Please ignore it.")
vids.append(video_name)
txts.append(label)
links.append(video_name)
webpage.add_videos(vids, txts, links, width=width)
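`save_videos` reverses the channel axis (`frame[:, :, ::-1]`) because OpenCV's `VideoWriter` expects BGR channel order while `tensor2im` produces RGB. The same reversal on a tiny nested-list frame:

```python
# RGB -> BGR conversion as done before cv2.VideoWriter.write: reverse the
# last (channel) axis of each pixel. One row of two RGB pixels:
frame = [[[1, 2, 3], [4, 5, 6]]]
bgr = [[px[::-1] for px in row] for row in frame]
```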
class Visualizer():
"""This class includes several functions that can display/save images and print/save logging information.
It uses a Python library 'visdom' for display, and a Python library 'dominate' (wrapped in 'HTML') for creating HTML files with images.
"""
def __init__(self, opt):
"""Initialize the Visualizer class
Parameters:
opt -- stores all the experiment flags; needs to be a subclass of BaseOptions
Step 1: Cache the training/test options
Step 2: connect to a visdom server
        Step 3: create an HTML object for saving HTML files
Step 4: create a logging file to store training losses
"""
self.opt = opt # cache the option
self.display_id = opt.display_id
self.use_html = opt.isTrain and not opt.no_html
self.win_size = opt.display_winsize
self.name = opt.name
self.port = opt.display_port
self.saved = False
        if self.display_id > 0:  # connect to a visdom server given <display_port> and <display_server>
import visdom
self.ncols = opt.display_ncols
self.vis = visdom.Visdom(server=opt.display_server, port=opt.display_port, env=opt.display_env)
if not self.vis.check_connection():
self.create_visdom_connections()
        if self.use_html:  # create an HTML object at <checkpoints_dir>/web/; images will be saved under <checkpoints_dir>/web/images/
self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web')
self.img_dir = os.path.join(self.web_dir, 'images')
print('create web directory %s...' % self.web_dir)
util.mkdirs([self.web_dir, self.img_dir])
# create a logging file to store training losses
self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt')
with open(self.log_name, "a") as log_file:
now = time.strftime("%c")
log_file.write('================ Training Loss (%s) ================\n' % now)
def reset(self):
"""Reset the self.saved status"""
self.saved = False
def create_visdom_connections(self):
"""If the program could not connect to Visdom server, this function will start a new server at port < self.port > """
cmd = sys.executable + ' -m visdom.server -p %d &>/dev/null &' % self.port
print('\n\nCould not connect to Visdom server. \n Trying to start a server....')
print('Command: %s' % cmd)
Popen(cmd, shell=True, stdout=PIPE, stderr=PIPE)
def display_current_results(self, visuals, epoch, save_result):
"""Display current results on visdom; save current results to an HTML file.
Parameters:
visuals (OrderedDict) - - dictionary of images to display or save
epoch (int) - - the current epoch
save_result (bool) - - if save the current results to an HTML file
"""
if self.display_id > 0: # show images in the browser using visdom
ncols = self.ncols
if ncols > 0: # show all the images in one visdom panel
ncols = min(ncols, len(visuals))
h, w = next(iter(visuals.values())).shape[:2]
table_css = """""" % (w, h) # create a table css
# create a table of images.
title = self.name
label_html = ''
label_html_row = ''
images = []
idx = 0
for label, image in visuals.items():
image_numpy = util.tensor2im(image)
                    label_html_row += '<td>%s</td>' % label
images.append(image_numpy.transpose([2, 0, 1]))
idx += 1
if idx % ncols == 0:
                        label_html += '<tr>%s</tr>' % label_html_row
label_html_row = ''
white_image = np.ones_like(image_numpy.transpose([2, 0, 1])) * 255
while idx % ncols != 0:
images.append(white_image)
                    label_html_row += '<td></td>'
idx += 1
if label_html_row != '':
                    label_html += '<tr>%s</tr>' % label_html_row
try:
self.vis.images(images, nrow=ncols, win=self.display_id + 1,
padding=2, opts=dict(title=title + ' images'))
label_html = '<table>%s</table>' % label_html
self.vis.text(table_css + label_html, win=self.display_id + 2,
opts=dict(title=title + ' labels'))
except VisdomExceptionBase:
self.create_visdom_connections()
else: # show each image in a separate visdom panel;
idx = 1
try:
for label, image in visuals.items():
image_numpy = util.tensor2im(image)
self.vis.image(image_numpy.transpose([2, 0, 1]), opts=dict(title=label),
win=self.display_id + idx)
idx += 1
except VisdomExceptionBase:
self.create_visdom_connections()
if self.use_html and (save_result or not self.saved): # save images to an HTML file if they haven't been saved.
self.saved = True
# save images to the disk
for label, image in visuals.items():
image_numpy = util.tensor2im(image)
img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.png' % (epoch, label))
util.save_image(image_numpy, img_path)
# update website
webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, refresh=1)
for n in range(epoch, 0, -1):
label = list(visuals.keys())[0]
img_path = 'epoch%.3d_%s.png' % (n, label)
if not os.path.exists(os.path.join(webpage.img_dir, img_path)):
continue
webpage.add_header('epoch [%d]' % n)
ims, txts, links = [], [], []
for label, image_numpy in visuals.items():
img_path = 'epoch%.3d_%s.png' % (n, label)
ims.append(img_path)
txts.append(label)
links.append(img_path)
webpage.add_images(ims, txts, links, width=self.win_size)
webpage.save()
def plot_current_losses(self, epoch, counter_ratio, losses):
"""display the current losses on visdom display: dictionary of error labels and values
Parameters:
epoch (int) -- current epoch
counter_ratio (float) -- progress (percentage) in the current epoch, between 0 and 1
losses (OrderedDict) -- training losses stored in the format of (name, float) pairs
"""
if not hasattr(self, 'plot_data'):
self.plot_data = {'X': [], 'Y': [], 'legend': list(losses.keys())}
self.plot_data['X'].append(epoch + counter_ratio)
self.plot_data['Y'].append([losses[k] for k in self.plot_data['legend']])
try:
self.vis.line(
X=np.stack([np.array(self.plot_data['X'])] * len(self.plot_data['legend']), 1),
Y=np.array(self.plot_data['Y']),
opts={
'title': self.name + ' loss over time',
'legend': self.plot_data['legend'],
'xlabel': 'epoch',
'ylabel': 'loss'},
win=self.display_id)
except VisdomExceptionBase:
self.create_visdom_connections()
# losses: same format as |losses| of plot_current_losses
def print_current_losses(self, epoch, iters, losses, t_comp, t_data):
"""print current losses on console; also save the losses to the disk
Parameters:
epoch (int) -- current epoch
iters (int) -- current training iteration during this epoch (reset to 0 at the end of every epoch)
losses (OrderedDict) -- training losses stored in the format of (name, float) pairs
t_comp (float) -- computational time per data point (normalized by batch_size)
t_data (float) -- data loading time per data point (normalized by batch_size)
"""
message = '(epoch: %d, iters: %d, time: %.3f, data: %.3f) ' % (epoch, iters, t_comp, t_data)
for k, v in losses.items():
message += '%s: %.3f ' % (k, v)
print(message) # print the message
with open(self.log_name, "a") as log_file:
log_file.write('%s\n' % message) # save the message
================================================
FILE: train.py
================================================
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Script for training a layered neural renderer on a video.
You need to specify the dataset ('--dataroot') and experiment name ('--name').
Example:
python train.py --dataroot ./datasets/reflection --name reflection --gpu_ids 0,1
The script first creates a model, dataset, and visualizer given the options.
It then does standard network training. During training, it also visualizes/saves the images, prints/saves the loss
plot, and saves the model.
Use '--continue_train' to resume your previous training.
The default setting is to first train the base model, which produces the low-resolution result (256x448), and then
train the upsampling module to produce the 512x896 result. If the upsampling module is unnecessary, use
'--n_epochs_upsample 0'.
See options/base_options.py and options/train_options.py for more training options.
"""
import time
from options.train_options import TrainOptions
from third_party.data import create_dataset
from third_party.models import create_model
from third_party.util.visualizer import Visualizer
import torch
import numpy as np
def main():
trainopt = TrainOptions()
trainopt.parse()
opt = trainopt.parse_dataset_meta()
torch.manual_seed(opt.seed)
np.random.seed(opt.seed)
opt.do_upsampling = False # Train low-res network first
dataset = create_dataset(opt, use_fast_loader=True)
dataset_size = len(dataset)
print('The number of training images = %d' % dataset_size)
model = create_model(opt)
model.setup(opt) # regular setup: load and print networks; create schedulers
visualizer = Visualizer(opt)
# Train base model (produces low-resolution output)
train(model, dataset, visualizer, opt)
# Optionally train upsampling module
if opt.n_epochs_upsample > 0:
opt.do_upsampling = True
opt.batch_size = opt.batch_size_upsample
# load dataset for upsampling
dataset = create_dataset(opt, use_fast_loader=True)
dataset_size = len(dataset)
print('The number of training images = %d' % dataset_size)
# set lambdas for upsampling training
opt.lambda_mask = 0
opt.lambda_alpha_l0 = 0
opt.lambda_alpha_l1 = 0
opt.mask_loss_rolloff_epoch = -1
opt.jitter_rgb = 0
# reinit optimizers and schedulers, lambdas
model.setup_train(opt)
# freeze base model and just train upsampling module
model.freeze_basenet()
model.setup(opt)
# update epoch count to resume training
opt.epoch_count = opt.n_epochs + opt.n_epochs_decay + 1
opt.n_epochs += opt.n_epochs_upsample
train(model, dataset, visualizer, opt)
def train(model, dataset, visualizer, opt):
dataset_size = len(dataset)
total_iters = 0 # the total number of training iterations
for epoch in range(opt.epoch_count, opt.n_epochs + opt.n_epochs_decay + 1): # outer loop for different epochs; we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>
epoch_start_time = time.time() # timer for entire epoch
iter_data_time = time.time() # timer for data loading per iteration
epoch_iter = 0 # the number of training iterations in current epoch, reset to 0 every epoch
model.update_lambdas(epoch)
for i, data in enumerate(dataset): # inner loop within one epoch
iter_start_time = time.time() # timer for computation per iteration
if i % opt.print_freq == 0:
t_data = iter_start_time - iter_data_time
total_iters += opt.batch_size
epoch_iter += opt.batch_size
model.set_input(data)
model.optimize_parameters()
if i % opt.print_freq == 0: # print training losses and save logging information to the disk
losses = model.get_current_losses()
t_comp = (time.time() - iter_start_time) / opt.batch_size
visualizer.print_current_losses(epoch, epoch_iter, losses, t_comp, t_data)
if opt.display_id > 0:
visualizer.plot_current_losses(epoch, float(epoch_iter) / dataset_size, losses)
iter_data_time = time.time()
if epoch % opt.display_freq == 1: # display images on visdom and save images to an HTML file
save_result = epoch % opt.update_html_freq == 1
model.compute_visuals()
visualizer.display_current_results(model.get_current_visuals(), epoch, save_result)
if epoch % opt.save_latest_freq == 0: # cache our latest model every <save_latest_freq> epochs
print('saving the latest model (epoch %d, total_iters %d)' % (epoch, total_iters))
save_suffix = 'epoch_%d' % epoch if opt.save_by_epoch else 'latest'
model.save_networks(save_suffix)
model.update_learning_rate() # update learning rates at the end of every epoch.
print('End of epoch %d / %d \t Time Taken: %d sec' % (epoch, opt.n_epochs + opt.n_epochs_decay, time.time() - epoch_start_time))
if __name__ == '__main__':
main()