Repository: google/retiming Branch: main Commit: 0e7dbce941b8 Files: 40 Total size: 149.7 KB Directory structure: gitextract_wurthk5z/ ├── LICENSE ├── README.md ├── data/ │ ├── __init__.py │ ├── kpuv_dataset.py │ └── layered_video_dataset.py ├── datasets/ │ ├── download_data.sh │ ├── iuv_crop2full.py │ └── prepare_iuv.sh ├── docs/ │ ├── contributing.md │ └── data.md ├── environment.yml ├── models/ │ ├── __init__.py │ ├── kp2uv_model.py │ ├── lnr_model.py │ └── networks.py ├── options/ │ ├── __init__.py │ ├── base_options.py │ ├── test_options.py │ └── train_options.py ├── requirements.txt ├── run_kp2uv.py ├── scripts/ │ ├── download_kp2uv_model.sh │ ├── run_cartwheel.sh │ ├── run_reflection.sh │ ├── run_splash.sh │ └── run_trampoline.sh ├── test.py ├── third_party/ │ ├── __init__.py │ ├── data/ │ │ ├── __init__.py │ │ ├── base_dataset.py │ │ ├── fast_data_loader.py │ │ └── image_folder.py │ ├── models/ │ │ ├── __init__.py │ │ ├── base_model.py │ │ └── networks.py │ └── util/ │ ├── __init__.py │ ├── html.py │ ├── util.py │ └── visualizer.py └── train.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: LICENSE ================================================ Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. 
"Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. 
You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. 
In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. Copyright [yyyy] [name of copyright owner] Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ================================================ FILE: README.md ================================================ # Layered Neural Rendering in PyTorch This repository contains training code for the examples in the SIGGRAPH Asia 2020 paper "[Layered Neural Rendering for Retiming People in Video](https://retiming.github.io/)." This is not an officially supported Google product. ## Prerequisites - Linux - Python 3.6+ - NVIDIA GPU + CUDA CuDNN ## Installation This code has been tested with PyTorch 1.4 and Python 3.8. - Install [PyTorch](http://pytorch.org) 1.4 and other dependencies. - For pip users, please type the command `pip install -r requirements.txt`. - For Conda users, you can create a new Conda environment using `conda env create -f environment.yml`. ## Data Processing - Download the data for a video used in our paper (e.g. "reflection"): ```bash bash ./datasets/download_data.sh reflection ``` - Or alternatively, download all the data by specifying `all`. 
- Download the pretrained keypoint-to-UV model weights: ```bash bash ./scripts/download_kp2uv_model.sh ``` The pretrained model will be saved at `./checkpoints/kp2uv/latest_net_Kp2uv.pth`. - Generate the UV maps from the keypoints: ```bash bash datasets/prepare_iuv.sh ./datasets/reflection ``` ## Training - To train a model on a video (e.g. "reflection"), run: ```bash python train.py --name reflection --dataroot ./datasets/reflection --gpu_ids 0,1 ``` - To view training results and loss plots, visit the URL http://localhost:8097. Intermediate results are also at `./checkpoints/reflection/web/index.html`. You can find more scripts in the `scripts` directory, e.g. `run_${VIDEO}.sh` which combines data processing, training, and saving layer results for a video. **Note**: - It is recommended to use >=2 GPUs, each with >=16GB memory. - The training script first trains the low-resolution model for `--num_epochs` at `--batch_size`, and then trains the upsampling module for `--num_epochs_upsample` at `--batch_size_upsample`. If you do not need the upsampled result, pass `--num_epochs_upsample 0`. - Training the upsampling module requires ~2.5x memory as the low-resolution model, so set `batch_size_upsample` accordingly. The provided scripts set the batch sizes appropriately for 2 GPUs with 16GB memory. - GPU memory scales linearly with the number of layers. ## Saving layer results from a trained model - Run the trained model: ```bash python test.py --name reflection --dataroot ./datasets/reflection --do_upsampling ``` - The results (RGBA layers, videos) will be saved to `./results/reflection/test_latest/`. - Passing `--do_upsampling` uses the results of the upsampling module. If the upsampling module hasn't been trained (`num_epochs_upsample=0`), then remove this flag. ## Custom video To train on your own video, you will have to preprocess the data: 1. Extract the frames, e.g. ``` mkdir ./datasets/my_video && cd ./datasets/my_video mkdir rgb && ffmpeg -i video.mp4 rgb/%04d.png ``` 1. Resize the video to 256x448 and save the frames in `my_video/rgb_256`, and resize the video to 512x896 and save in `my_video/rgb_512`. 1. Run [AlphaPose and Pose Tracking](https://github.com/MVIG-SJTU/AlphaPose) on the frames. Save results as `my_video/keypoints.json` 1. Create `my_video/metadata.json` following [these instructions](docs/data.md). 1. If your video has camera motion, either (1) stabilize the video, or (2) maintain the camera motion by computing homographies and saving as `my_video/homographies.txt`. See `scripts/run_cartwheel.sh` for a training example with camera motion, and see `./datasets/cartwheel/homographies.txt` for formatting. **Note**: Videos that are suitable for our method have the following attributes: - Static camera or limited camera motion that can be represented with a homography. - Limited number of people, due to GPU memory limitations. We tested up to 7 people and 7 layers. Multiple people can be grouped onto the same layer, though they cannot be individually retimed. - People that move relative to the background (static people will be absorbed into the background layer). - We tested a video length of up to 200 frames (~7 seconds). 
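The frame-resizing step (step 2 above) and the `metadata.json` creation (step 4) can be scripted. Below is a minimal sketch using PIL, assuming the original frames were extracted to `my_video/rgb/` as in step 1 and that AlphaPose is run on the original-resolution frames; `DATAROOT` and `NUM_PEOPLE` are placeholders for your own video.

```python
import json
import os
from PIL import Image

DATAROOT = './datasets/my_video'  # hypothetical video directory
NUM_PEOPLE = 2                    # assumption: number of people (layers) in the video

os.makedirs(os.path.join(DATAROOT, 'rgb_256'), exist_ok=True)
os.makedirs(os.path.join(DATAROOT, 'rgb_512'), exist_ok=True)

orig_size = None
for name in sorted(os.listdir(os.path.join(DATAROOT, 'rgb'))):
    if not name.endswith('.png'):
        continue
    im = Image.open(os.path.join(DATAROOT, 'rgb', name)).convert('RGB')
    orig_size = im.size  # (width, height) of the original frames
    # Low-resolution frames (448x256) used to train the layer decomposition.
    im.resize((448, 256), Image.LANCZOS).save(os.path.join(DATAROOT, 'rgb_256', name))
    # Double-resolution frames (896x512) used by the upsampling module.
    im.resize((896, 512), Image.LANCZOS).save(os.path.join(DATAROOT, 'rgb_512', name))

# Minimal metadata.json (see docs/data.md for the full description). This assumes
# AlphaPose was run on the original-resolution frames.
metadata = {
    'alphapose_input_size': list(orig_size),  # [width, height]
    'size_LR': [448, 256],                    # [width, height]; height should be 256
    'n_textures': 24 * NUM_PEOPLE + 1,        # 24 body-part textures per person + 1 background
}
with open(os.path.join(DATAROOT, 'metadata.json'), 'w') as f:
    json.dump(metadata, f, indent=2)
```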
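For step 5, the layout of `homographies.txt` can be inferred from `init_homographies` in `data/layered_video_dataset.py`: a size line, a bounds line, then one row-major 3x3 homography per frame. The sketch below writes a file in that layout with placeholder identity homographies; the header labels are illustrative (the loader ignores the first token of the first two lines), and `./datasets/cartwheel/homographies.txt` remains the authoritative example.

```python
import numpy as np

n_frames = 200                 # hypothetical frame count
frame_w, frame_h = 448, 256    # assumed size at which the homographies were estimated
Hs = [np.eye(3) for _ in range(n_frames)]  # placeholder per-frame 3x3 homographies

# Bounds of the stabilized coordinate system. With identity homographies these are
# just the frame extent; in general use the min/max over all warped frame corners.
xmin, xmax, ymin, ymax = 0.0, float(frame_w), 0.0, float(frame_h)

with open('./datasets/my_video/homographies.txt', 'w') as f:
    # Line 1: '<label> <width> <height>' (the loader reads tokens 1 and 2 as ints).
    f.write(f'size: {frame_w} {frame_h}\n')
    # Line 2: '<label> <xmin> <xmax> <ymin> <ymax>' (read as floats).
    f.write(f'bounds: {xmin} {xmax} {ymin} {ymax}\n')
    # One homography per frame: 9 space-separated values, row-major.
    for H in Hs:
        f.write(' '.join(f'{v:.8f}' for v in H.reshape(-1)) + '\n')
```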
## Citation If you use this code for your research, please cite the following paper: ``` @inproceedings{lu2020, title={Layered Neural Rendering for Retiming People in Video}, author={Lu, Erika and Cole, Forrester and Dekel, Tali and Xie, Weidi and Zisserman, Andrew and Salesin, David and Freeman, William T and Rubinstein, Michael}, booktitle={SIGGRAPH Asia}, year={2020} } ``` ## Acknowledgments This code is based on [pytorch-CycleGAN-and-pix2pix](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix). ================================================ FILE: data/__init__.py ================================================ ================================================ FILE: data/kpuv_dataset.py ================================================ # Copyright 2020 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from third_party.data.base_dataset import BaseDataset from PIL import Image, ImageDraw import json import matplotlib matplotlib.use('Agg') import matplotlib.pyplot as plt import numpy as np import os import torchvision.transforms as transforms class KpuvDataset(BaseDataset): """A dataset class for keypoint data. It assumes that the directory specified by 'dataroot' contains the file 'keypoints.json'. """ @staticmethod def modify_commandline_options(parser, is_train): parser.add_argument('--inp_size', type=int, default=256, help='image size') return parser def __init__(self, opt): """Initialize this dataset class by reading in keypoints. Parameters: opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions """ BaseDataset.__init__(self, opt) self.inp_size = opt.inp_size inner_crop_size = int(.75*self.inp_size) kps = [] image_paths = [] with open(os.path.join(self.root, 'keypoints.json'), 'rb') as f: kp_data = json.load(f) for frame in sorted(kp_data): for skeleton in kp_data[frame]: id = skeleton['idx'] image_paths.append(f'{id:02d}_{frame}') kp = np.array(skeleton['keypoints']).reshape(17, 3) kp = self.crop_kps(kp, crop_size=self.inp_size, inner_crop_size=inner_crop_size) kps.append(kp) self.keypoints = kps self.image_paths = image_paths # filenames for output UVs # for keypoint rendering self.cmap = plt.cm.get_cmap("hsv", 17) self.color_seq = np.array([ 9, 14, 6, 7, 13, 16, 2, 11, 3, 5, 10, 15, 1, 8, 0, 12, 4]) self.pairs = [[0,1],[0,2],[1,3],[2,4],[5,6],[5,7],[7,9],[6,8],[8,10],[11,12],[11,13],[13,15],[12,14],[14,16],[6,12],[5,11]] def __getitem__(self, index): """Return a data point and its metadata information. 
Parameters: index - - a random integer for data indexing Returns a dictionary that contains keypoints and path keyoints (tensor) - - an RGB image representing a skeleton path (str) - - an identifying filename that can be used for saving the result """ uv_path = self.image_paths[index] # output UV path kps = self.keypoints[index] # draw keypoints kp_im = Image.new(size=(self.inp_size, self.inp_size), mode='RGB') draw = ImageDraw.Draw(kp_im) self.render_kps(kps, draw) kp_im = transforms.ToTensor()(kp_im) kp_im = 2 * kp_im - 1 return {'keypoints': kp_im, 'path': uv_path} def __len__(self): """Return the total number of images.""" return len(self.image_paths) def crop_kps(self, kps, crop_size=256, inner_crop_size=192): """'Crops' keypoints to fit into ['crop_size', 'crop_size']. Parameters: kps - - a numpy array of shape [17, 2], where the keypoint order is X, Y crop_size - - the new size of the world, which the keypoints will be centered inside inner_crop_size - - the box size that the keypoints will fit inside (must be <=crop_size) Returns keypoints mapped to fit inside a box of 'inner_crop_size', centered in 'crop_size' """ # get coordinates of bounding box, in original image coordinates left = kps[:, 0].min() right = kps[:, 0].max() top = kps[:, 1].min() bottom = kps[:, 1].max() # map keypoints keypoints = kps.copy() center = ((right + left) // 2, (bottom + top) // 2) # first place center of bounding box at origin keypoints[:, 0] -= center[0] keypoints[:, 1] -= center[1] # scale bounding box to inner_crop_size scale = float(inner_crop_size) / max(right - left, bottom - top) keypoints[:, :2] *= scale # move center to crop_size//2 keypoints[:, :2] += crop_size // 2 new_kps = keypoints return new_kps def render_kps(self, keypoints, draw, thresh=1., min_weight=0.25): """Render skeleton as RGB image. 
Parameters: keypoints - - a numpy array of shape [17, 3], where the keypoint order is X, Y, score draw - - an ImageDraw object, which the keypoints will be drawn onto thresh - - keypoints with a confidence score below this value will have a color weighted by the score min_weight - - minimum weighting for color (scores will be mapped to the range [min_weight, 1]) """ # first draw keypoints ksize = 3 for i in range(keypoints.shape[0]): x1 = keypoints[i,0] - ksize x2 = keypoints[i,0] + ksize y1 = keypoints[i,1] - ksize y2 = keypoints[i,1] + ksize if x1 < 0 or y1 < 0 or x2 > self.inp_size or y2 > self.inp_size: continue color = np.array(self.cmap(self.color_seq[i])) if keypoints.shape[1] > 2: score = keypoints[i,2] if score < thresh: # weight color by confidence score # first map [0,1] -> [min_weight, 1] alpha_weight = score * (1.-min_weight) + min_weight color[:3] *= alpha_weight color = (255*color).astype('uint8') draw.rectangle([x1, y1, x2, y2], fill=tuple(color)) # now draw segments for pair in self.pairs: x1 = keypoints[pair[0],0] y1 = keypoints[pair[0],1] x2 = keypoints[pair[1],0] y2 = keypoints[pair[1],1] if x1 < 0 or y1 < 0 or x2 < 0 or y2 < 0 or x1 > self.inp_size or y1 > self.inp_size or x2 > self.inp_size or y2 > self.inp_size: continue avg_color = .5*(np.array(self.cmap(self.color_seq[pair[0]])) + np.array(self.cmap(self.color_seq[pair[1]]))) if keypoints.shape[1] > 2: score = min(keypoints[pair[0],2], keypoints[pair[1],2]) if score < thresh: alpha_weight = score * (1.-min_weight) + min_weight avg_color[:3] *= alpha_weight # alpha channel weigh by score avg_color = (255*avg_color).astype('uint8') draw.line([x1,y1,x2,y2], fill=tuple(avg_color), width=3) ================================================ FILE: data/layered_video_dataset.py ================================================ # Copyright 2020 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import cv2 from third_party.data.base_dataset import BaseDataset from third_party.data.image_folder import make_dataset from PIL import Image import torchvision.transforms as transforms import torch.nn.functional as F import os import torch import numpy as np import json class LayeredVideoDataset(BaseDataset): """A dataset class for video layers. It assumes that the directory specified by 'dataroot' contains metadata.json, and the directories iuv, rgb_256, and rgb_512. The 'iuv' directory should contain directories named 01, 02, etc. for each layer, each containing per-frame UV images. 
""" @staticmethod def modify_commandline_options(parser, is_train): parser.add_argument('--height', type=int, default=256, help='image height') parser.add_argument('--width', type=int, default=448, help='image width') parser.add_argument('--trimap_width', type=int, default=20, help='trimap gray area width') parser.add_argument('--use_mask_images', action='store_true', default=False, help='use custom masks') parser.add_argument('--use_homographies', action='store_true', default=False, help='handle camera motion') return parser def __init__(self, opt): """Initialize this dataset class. Parameters: opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions """ BaseDataset.__init__(self, opt) rgbdir = os.path.join(opt.dataroot, 'rgb_256') if opt.do_upsampling: rgbdir = os.path.join(opt.dataroot, 'rgb_512') uvdir = os.path.join(opt.dataroot, 'iuv') self.image_paths = sorted(make_dataset(rgbdir, opt.max_dataset_size)) n_images = len(self.image_paths) layers = sorted(os.listdir(uvdir)) layers = [l for l in layers if l.isdigit()] self.iuv_paths = [] for l in layers: layer_iuv_paths = sorted(make_dataset(os.path.join(uvdir, l), n_images)) if len(layer_iuv_paths) != n_images: print(f'UNEQUAL NUMBER OF IMAGES AND IUVs: {len(layer_iuv_paths)} and {n_images}') self.iuv_paths.append(layer_iuv_paths) # set up per-frame compositing order with open(os.path.join(opt.dataroot, 'metadata.json')) as f: metadata = json.load(f) if 'composite_order' in metadata: self.composite_order = metadata['composite_order'] else: self.composite_order = [tuple(range(1, 1 + len(layers)))] * n_images if opt.use_homographies: self.init_homographies(os.path.join(opt.dataroot, 'homographies.txt'), n_images) def __getitem__(self, index): """Return a data point and its metadata information. Parameters: index - - a random integer for data indexing Returns a dictionary that contains: image (tensor) - - the original RGB frame to reconstruct uv_map (tensor) - - the UV maps for all layers, concatenated channel-wise mask (tensor) - - the trimaps for all layers, concatenated channel-wise pids (tensor) - - the person IDs for all layers, concatenated channel-wise image_path (str) - - image path """ # Read the target image. image_path = self.image_paths[index] target_image = self.load_and_process_image(image_path) # Read the layer IUVs and convert to network inputs. people_layers = [self.load_and_process_iuv(self.iuv_paths[l - 1][index], index) for l in self.composite_order[index]] iuv_h, iuv_w = people_layers[0][0].shape[-2:] # Create the background layer UV from homographies. 
background_layer = self.get_background_inputs(index, iuv_w, iuv_h) uv_maps, masks, pids = zip(*([background_layer] + people_layers)) uv_maps = torch.cat(uv_maps) # [L*2, H, W] masks = torch.stack(masks) # [L, H, W] pids = torch.stack(pids) # [L, H, W] if self.opt.use_mask_images: for i in range(1, len(people_layers)): mask_path = os.path.join(self.opt.dataroot, 'mask', f'{i:02d}', os.path.basename(image_path)) if os.path.exists(mask_path): mask = Image.open(mask_path).convert('L').resize((masks.shape[-1], masks.shape[-2])) mask = transforms.ToTensor()(mask) * 2 - 1 masks[i] = mask transform_params = self.get_params(do_jitter=self.opt.phase=='train') pids = self.apply_transform(pids, transform_params, 'nearest') masks = self.apply_transform(masks, transform_params, 'bilinear') uv_maps = self.apply_transform(uv_maps, transform_params, 'nearest') image_transform_params = transform_params if self.opt.do_upsampling: image_transform_params = { p: transform_params[p] * 2 for p in transform_params} target_image = self.apply_transform(target_image, image_transform_params, 'bilinear') return {'image': target_image, 'uv_map': uv_maps, 'mask': masks, 'pids': pids, 'image_path': image_path} def __len__(self): """Return the total number of images.""" return len(self.image_paths) def get_params(self, do_jitter=False, jitter_rate=0.75): """Get transformation parameters.""" if do_jitter: if np.random.uniform() > jitter_rate or self.opt.do_upsampling: scale = 1. else: scale = np.random.uniform(1, 1.25) jitter_size = (scale * np.array([self.opt.height, self.opt.width])).astype(np.int) start1 = np.random.randint(jitter_size[0] - self.opt.height + 1) start2 = np.random.randint(jitter_size[1] - self.opt.width + 1) else: jitter_size = np.array([self.opt.height, self.opt.width]) start1 = 0 start2 = 0 crop_pos = np.array([start1, start2]) crop_size = np.array([self.opt.height, self.opt.width]) return {'jitter size': jitter_size, 'crop pos': crop_pos, 'crop size': crop_size} def apply_transform(self, data, params, interp_mode='bilinear'): """Apply the transform to the data tensor.""" tensor_size = params['jitter size'].tolist() crop_pos = params['crop pos'] crop_size = params['crop size'] data = F.interpolate(data.unsqueeze(0), size=tensor_size, mode=interp_mode).squeeze(0) data = data[:, crop_pos[0]:crop_pos[0] + crop_size[0], crop_pos[1]:crop_pos[1] + crop_size[1]] return data def init_homographies(self, homography_path, n_images): """Read homography file and set up homography data.""" with open(homography_path) as f: h_data = f.readlines() h_scale = h_data[0].rstrip().split(' ') self.h_scale_x = int(h_scale[1]) self.h_scale_y = int(h_scale[2]) h_bounds = h_data[1].rstrip().split(' ') self.h_bounds_x = [float(h_bounds[1]), float(h_bounds[2])] self.h_bounds_y = [float(h_bounds[3]), float(h_bounds[4])] homographies = h_data[2:2 + n_images] homographies = [torch.from_numpy(np.array(line.rstrip().split(' ')).astype(np.float32).reshape(3, 3)) for line in homographies] self.homographies = homographies def load_and_process_image(self, im_path): """Read image file and return as tensor in range [-1, 1].""" image = Image.open(im_path).convert('RGB') image = transforms.ToTensor()(image) image = 2 * image - 1 return image def load_and_process_iuv(self, iuv_path, i): """Read IUV file and convert to network inputs.""" iuv_map = Image.open(iuv_path).convert('RGBA') iuv_map = transforms.ToTensor()(iuv_map) uv_map, mask, pids = self.iuv2input(iuv_map, i) return uv_map, mask, pids def iuv2input(self, iuv, index): """Create network 
inputs from IUV. Parameters: iuv - - a tensor of shape [4, H, W], where the channels are: body part ID, U, V, person ID. index - - index of iuv Returns: uv (tensor) - - a UV map for a single layer, ready to pass to grid sampler (values in range [-1,1]) mask (tensor) - - the corresponding mask person_id (tensor) - - the person IDs grid sampler indexes into texture map of size tile_width x tile_width*n_textures """ # Extract body part and person IDs. part_id = (iuv[0] * 255 / 10).round() part_id[part_id > 24] = 24 part_id_mask = (part_id > 0).float() person_id = (255 - 255 * iuv[-1]).round() # person ID is saved as 255 - person_id person_id *= part_id_mask # background id is 0 maxId = self.opt.n_textures // 24 person_id[person_id>maxId] = maxId # Convert body part ID to texture map ID. # Essentially, each of the 24 body parts for each person, plus the background have their own texture 'tile' # The tiles are concatenated horizontally to create the texture map. tex_id = part_id + part_id_mask * 24 * (person_id - 1) uv = iuv[1:3] # Convert the per-body-part UVs to UVs that correspond to the full texture map. uv[0] += tex_id # Get the mask. bg_mask = (tex_id == 0).float() mask = 1.0 - bg_mask mask = mask * 2 - 1 # make 1 the foreground and -1 the background mask mask = self.mask2trimap(mask) # Composite background UV behind person UV. h, w = iuv.shape[1:] bg_uv = self.get_background_uv(index, w, h) uv = bg_mask * bg_uv + (1 - bg_mask) * uv # Map to [-1, 1] range. uv[0] /= self.opt.n_textures uv = uv * 2 - 1 uv = torch.clamp(uv, -1, 1) return uv, mask, person_id def get_background_inputs(self, index, w, h): """Return data for background layer at 'index'.""" uv = self.get_background_uv(index, w, h) # normalize to correct range, of full texture atlas uv[0] /= self.opt.n_textures uv = uv * 2 - 1 # [0,1] -> [-1,1] uv = torch.clamp(uv, -1, 1) mask = -torch.ones(*uv.shape[1:]) pids = torch.zeros(*uv.shape[1:]) return uv, mask, pids def get_background_uv(self, index, w, h): """Return background layer UVs at 'index' (output range [0, 1]).""" ramp_u = torch.linspace(0, 1, steps=w).unsqueeze(0).repeat(h, 1) ramp_v = torch.linspace(0, 1, steps=h).unsqueeze(-1).repeat(1, w) ramp = torch.stack([ramp_u, ramp_v], 0) if hasattr(self, 'homographies'): # scale to [0, orig width/height] ramp[0] *= self.h_scale_x ramp[1] *= self.h_scale_y # apply homography ramp = ramp.reshape(2, -1) # [2, H, W] H = self.homographies[index] [xt, yt] = self.transform2h(ramp[0], ramp[1], torch.inverse(H)) # scale from world to [0,1] xt -= self.h_bounds_x[0] xt /= (self.h_bounds_x[1] - self.h_bounds_x[0]) yt -= self.h_bounds_y[0] yt /= (self.h_bounds_y[1] - self.h_bounds_y[0]) # restore shape ramp = torch.stack([xt.reshape(h, w), yt.reshape(h, w)], 0) return ramp def transform2h(self, x, y, m): """Applies 2d homogeneous transformation.""" A = torch.matmul(m, torch.stack([x, y, torch.ones(len(x))])) xt = A[0, :] / A[2, :] yt = A[1, :] / A[2, :] return xt, yt def mask2trimap(self, mask): """Convert binary mask to trimap with values in [-1, 0, 1].""" fg_mask = (mask > 0).float() bg_mask = (mask < 0).float() trimap_width = getattr(self.opt, 'trimap_width', 20) trimap_width *= bg_mask.shape[-1] / self.opt.width trimap_width = int(trimap_width) bg_mask = cv2.erode(bg_mask.numpy(), kernel=np.ones((trimap_width, trimap_width)), iterations=1) bg_mask = torch.from_numpy(bg_mask) mask = fg_mask - bg_mask return mask ================================================ FILE: datasets/download_data.sh ================================================ 
#!/bin/bash # # Copyright 2020 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. NAME=$1 if [[ $NAME != "cartwheel" && $NAME != "reflection" && $NAME != "splash" && $NAME != "trampoline" && $NAME != "all" ]]; then echo "Available videos are: cartwheel, reflection, splash, trampoline" exit 1 fi if [[ $NAME == "all" ]]; then declare -a NAMES=("cartwheel" "reflection" "splash" "trampoline") else declare -a NAMES=($NAME) fi for NAME in "${NAMES[@]}" do echo "Specified [$NAME]" URL=https://www.robots.ox.ac.uk/~erika/retiming/data/$NAME.zip ZIP_FILE=./datasets/$NAME.zip TARGET_DIR=./datasets/$NAME/ wget -N $URL -O $ZIP_FILE mkdir $TARGET_DIR unzip $ZIP_FILE -d ./datasets/ rm $ZIP_FILE done ================================================ FILE: datasets/iuv_crop2full.py ================================================ # Copyright 2020 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
"""Convert UV crops to full UV maps.""" import os import sys import json from PIL import Image import numpy as np def place_crop(crop, image, center_x, center_y): """Place the crop in the image at the specified location.""" im_height, im_width = image.shape[:2] crop_height, crop_width = crop.shape[:2] left = center_x - crop_width // 2 right = left + crop_width top = center_y - crop_height // 2 bottom = top + crop_height adjusted_crop = crop # remove regions of crop that go beyond image bounds if left < 0: adjusted_crop = adjusted_crop[:, -left:] if right > im_width: adjusted_crop = adjusted_crop[:, :(im_width - right)] if top < 0: adjusted_crop = adjusted_crop[-top:] if bottom > im_height: adjusted_crop = adjusted_crop[:(im_height - bottom)] crop_mask = (adjusted_crop > 0).astype(crop.dtype).sum(-1, keepdims=True) image[max(0, top):min(im_height, bottom), max(0, left):min(im_width, right)] *= (1 - crop_mask) image[max(0, top):min(im_height, bottom), max(0, left):min(im_width, right)] += adjusted_crop return image def crop2full(keypoints_path, metadata_path, uvdir, outdir): """Create each frame's layer UVs from predicted UV crops""" with open(keypoints_path) as f: kp_data = json.load(f) # Get all people ids people_ids = set() for frame in kp_data: for skeleton in kp_data[frame]: people_ids.add(skeleton['idx']) people_ids = sorted(list(people_ids)) with open(metadata_path) as f: metadata = json.load(f) orig_size = np.array(metadata['alphapose_input_size'][::-1]) out_size = np.array(metadata['size_LR'][::-1]) if 'people_layers' in metadata: people_layers = metadata['people_layers'] else: people_layers = [[pid] for pid in people_ids] # Create output directories. for layer_i in range(1, 1 + len(people_layers)): os.makedirs(os.path.join(outdir, f'{layer_i:02d}'), exist_ok=True) print(f'Writing UVs to {outdir}') for frame in sorted(kp_data): for layer_i, layer in enumerate(people_layers, 1): out_path = os.path.join(outdir, f'{layer_i:02d}', frame) sys.stdout.flush() sys.stdout.write('processing frame %s\r' % out_path) uv_map = np.zeros([out_size[0], out_size[1], 4]) for person_id in layer: matches = [p for p in kp_data[frame] if p['idx'] == person_id] if len(matches) == 0: # person doesn't appear in this frame continue skeleton = matches[0] kps = np.array(skeleton['keypoints']).reshape(17, 3) # Get kps bounding box. 
left = kps[:, 0].min() right = kps[:, 0].max() top = kps[:, 1].min() bottom = kps[:, 1].max() height = bottom - top width = right - left orig_crop_size = max(height, width) orig_center_x = (left + right) // 2 orig_center_y = (top + bottom) // 2 # read predicted uv map uv_crop_path = os.path.join(uvdir, f'{person_id:02d}_{os.path.basename(out_path)[:-4]}_output_uv.png') if os.path.exists(uv_crop_path): uv_crop = np.array(Image.open(uv_crop_path)) else: uv_crop = np.zeros([256, 256, 3]) # add person ID channel person_mask = (uv_crop[..., 0:1] > 0).astype('uint8') person_ids = (255 - person_id) * person_mask uv_crop = np.concatenate([uv_crop, person_ids], -1) # scale crop to desired output size # 256 is the crop size, 192 is the inner crop size out_crop_size = orig_crop_size * 256./192 * out_size / orig_size out_crop_size = out_crop_size.astype(np.int) uv_crop = uv_crop.astype(np.uint8) uv_crop = np.array(Image.fromarray(uv_crop).resize((out_crop_size[1], out_crop_size[0]), resample=Image.NEAREST)) # scale center coordinate accordingly out_center_x = (orig_center_x * out_size[1] / orig_size[1]).astype(np.int) out_center_y = (orig_center_y * out_size[0] / orig_size[0]).astype(np.int) # Place UV crop in full UV map and save. uv_map = place_crop(uv_crop, uv_map, out_center_x, out_center_y) uv_map = Image.fromarray(uv_map.astype('uint8')) uv_map.save(out_path) if __name__ == "__main__": import argparse arguments = argparse.ArgumentParser() arguments.add_argument('--dataroot', type=str) opt = arguments.parse_args() keypoints_path = os.path.join(opt.dataroot, 'keypoints.json') metadata_path = os.path.join(opt.dataroot, 'metadata.json') uvdir = os.path.join(opt.dataroot, 'kp2uv/test_latest/images') outdir = os.path.join(opt.dataroot, 'iuv') crop2full(keypoints_path, metadata_path, uvdir, outdir) ================================================ FILE: datasets/prepare_iuv.sh ================================================ #!/bin/bash # # Copyright 2020 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. DATA_PATH=$1 # Predict UVs from keypoints and save the crops. python run_kp2uv.py --model kp2uv --dataroot $DATA_PATH --results_dir $DATA_PATH # Convert the cropped UVs to full UV maps. python datasets/iuv_crop2full.py --dataroot $DATA_PATH ================================================ FILE: docs/contributing.md ================================================ # How to Contribute We'd love to accept your patches and contributions to this project. There are just a few small guidelines you need to follow. ## Contributor License Agreement Contributions to this project must be accompanied by a Contributor License Agreement (CLA). You (or your employer) retain the copyright to your contribution; this simply gives us permission to use and redistribute your contributions as part of the project. Head over to https://cla.developers.google.com/ to see your current agreements on file or to sign a new one.
You generally only need to submit a CLA once, so if you've already submitted one (even if it was for a different project), you probably don't need to do it again. ## Code reviews All submissions, including submissions by project members, require review. We use GitHub pull requests for this purpose. Consult [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more information on using pull requests. ## Community Guidelines This project follows [Google's Open Source Community Guidelines](https://opensource.google/conduct/). ================================================ FILE: docs/data.md ================================================ ### Data The data directory for a video is structured as follows: ``` video_name/ |-- rgb_256/ | |-- 0001.png, etc. |-- rgb_512/ | |-- 0001.png, etc. |-- mask/ (optional) |-- |-- 01, etc. |-- |-- |-- 0001.png, etc. |-- keypoints.json |-- metadata.json |-- homographies.txt (optional) ``` - `metadata.json` contains a dictionary: ``` 'alphapose_input_size': [width, height] # size of frames input to AlphaPose 'size_LR': [width, height] # size of low-resolution frames (multiple of 16; height should be 256) 'n_textures': int # number of texture maps required, calculated by 24*num_people + 1 'composite_order': [[1, 2, 3], [1, 3, 2], ... ] # optional per-frame back-to-front layer compositing order ``` - `keypoints.json` is in the format output by the [AlphaPose Pose Tracker](https://github.com/MVIG-SJTU/AlphaPose). See [here](https://github.com/MVIG-SJTU/AlphaPose/tree/master/trackers/PoseFlow) for details. ================================================ FILE: environment.yml ================================================ name: retiming channels: - pytorch - defaults dependencies: - python=3.8 - pytorch=1.4.0 - pip: - dominate==2.4.0 - torchvision==0.5.0 - Pillow>=6.1.0 - numpy==1.19.2 - visdom==0.1.8 - opencv-python>=4.2.0 - matplotlib ================================================ FILE: models/__init__.py ================================================ ================================================ FILE: models/kp2uv_model.py ================================================ # Copyright 2020 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import torch from third_party.models.base_model import BaseModel from . import networks class Kp2uvModel(BaseModel): """This class implements the keypoint-to-UV model (inference only).""" @staticmethod def modify_commandline_options(parser, is_train=True): parser.set_defaults(dataset_mode='kpuv') return parser def __init__(self, opt): """Initialize this model class. Parameters: opt -- test options """ BaseModel.__init__(self, opt) self.visual_names = ['keypoints', 'output_uv'] self.model_names = ['Kp2uv'] self.netKp2uv = networks.define_kp2uv(gpu_ids=self.gpu_ids) self.isTrain = False # only test mode supported # Our program will automatically call to define schedulers, load networks, and print networks def set_input(self, input): """Unpack input data from the dataloader. 
Parameters: input: a dictionary that contains the data itself and its metadata information. """ self.keypoints = input['keypoints'].to(self.device) self.image_paths = input['path'] def forward(self): """Run forward pass. This will be called by .""" output = self.netKp2uv.forward(self.keypoints) self.output_uv = self.output2rgb(output) def output2rgb(self, output): """Convert network outputs to RGB image.""" pred_id, pred_uv = output _, pred_id_class = pred_id.max(1) pred_id_class = pred_id_class.unsqueeze(1) # extract UV from pred_uv (48 channels); select based on class ID selected_uv = -1 * torch.ones(pred_uv.shape[0], 2, pred_uv.shape[2], pred_uv.shape[3], device=pred_uv.device) for partid in range(1, 25): mask = (pred_id_class == partid).float() selected_uv *= (1. - mask) selected_uv += mask * pred_uv[:, (partid - 1) * 2:(partid - 1) * 2 + 2] pred_uv = selected_uv rgb = torch.cat([pred_id_class.float() * 10 / 255. * 2 - 1, pred_uv], 1) return rgb def optimize_parameters(self): pass ================================================ FILE: models/lnr_model.py ================================================ # Copyright 2020 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import torch from third_party.models.base_model import BaseModel from . import networks import numpy as np import torch.nn.functional as F class LnrModel(BaseModel): """This class implements the layered neural rendering model for decomposing a video into layers.""" @staticmethod def modify_commandline_options(parser, is_train=True): parser.set_defaults(dataset_mode='layered_video') parser.add_argument('--texture_res', type=int, default=16, help='texture resolution') parser.add_argument('--texture_channels', type=int, default=16, help='# channels for neural texture') parser.add_argument('--n_textures', type=int, default=25, help='# individual texture maps, 24 per person (1 per body part) + 1 for background') if is_train: parser.add_argument('--lambda_alpha_l1', type=float, default=0.01, help='alpha L1 sparsity loss weight') parser.add_argument('--lambda_alpha_l0', type=float, default=0.005, help='alpha L0 sparsity loss weight') parser.add_argument('--alpha_l1_rolloff_epoch', type=int, default=200, help='turn off L1 alpha sparsity loss weight after this epoch') parser.add_argument('--lambda_mask', type=float, default=50, help='layer matting loss weight') parser.add_argument('--mask_thresh', type=float, default=0.02, help='turn off masking loss when error falls below this value') parser.add_argument('--mask_loss_rolloff_epoch', type=int, default=-1, help='decrease masking loss after this epoch; if <0, use mask_thresh instead') parser.add_argument('--n_epochs_upsample', type=int, default=500, help='number of epochs to train the upsampling module') parser.add_argument('--batch_size_upsample', type=int, default=16, help='batch size for upsampling') parser.add_argument('--jitter_rgb', type=float, default=0.2, help='amount of jitter to add to RGB') parser.add_argument('--jitter_epochs', type=int, default=400, help='number of epochs to 
jitter RGB') parser.add_argument('--do_upsampling', action='store_true', help='whether to use upsampling module') return parser def __init__(self, opt): """Initialize this model class. Parameters: opt -- training/test options """ BaseModel.__init__(self, opt) # specify the images you want to save and display. The program will call base_model.get_current_visuals to save and display these images. self.visual_names = ['target_image', 'reconstruction', 'rgba_vis', 'alpha_vis', 'input_vis'] self.model_names = ['LNR'] self.netLNR = networks.define_LNR(opt.num_filters, opt.texture_channels, opt.texture_res, opt.n_textures, gpu_ids=self.gpu_ids) self.do_upsampling = opt.do_upsampling if self.isTrain: self.setup_train(opt) # Our program will automatically call to define schedulers, load networks, and print networks def setup_train(self, opt): """Setup the model for training mode.""" print('setting up model') # specify the training losses you want to print out. The program will call base_model.get_current_losses to plot the losses to the console and save them to the disk. self.loss_names = ['total', 'recon', 'alpha_reg', 'mask'] self.visual_names = ['target_image', 'reconstruction', 'rgba_vis', 'alpha_vis', 'input_vis'] self.do_upsampling = opt.do_upsampling if not self.do_upsampling: self.visual_names += ['mask_vis'] self.criterionLoss = torch.nn.L1Loss() self.criterionLossMask = networks.MaskLoss().to(self.device) self.lambda_mask = opt.lambda_mask self.lambda_alpha_l0 = opt.lambda_alpha_l0 self.lambda_alpha_l1 = opt.lambda_alpha_l1 self.mask_loss_rolloff_epoch = opt.mask_loss_rolloff_epoch self.jitter_rgb = opt.jitter_rgb self.do_upsampling = opt.do_upsampling self.optimizer = torch.optim.Adam(self.netLNR.parameters(), lr=opt.lr) self.optimizers = [self.optimizer] def set_input(self, input): """Unpack input data from the dataloader and perform necessary pre-processing steps. Parameters: input: a dictionary that contains the data itself and its metadata information. """ self.target_image = input['image'].to(self.device) if self.isTrain and self.jitter_rgb > 0: # add brightness jitter to rgb self.target_image += self.jitter_rgb * torch.randn(self.target_image.shape[0], 1, 1, 1).to(self.device) self.target_image = torch.clamp(self.target_image, -1, 1) self.input_uv = input['uv_map'].to(self.device) self.input_id = input['pids'].to(self.device) self.mask = input['mask'].to(self.device) self.image_paths = input['image_path'] def gen_crop_params(self, orig_h, orig_w, crop_size=256): """Generate random square cropping parameters.""" starty = np.random.randint(orig_h - crop_size + 1) startx = np.random.randint(orig_w - crop_size + 1) endy = starty + crop_size endx = startx + crop_size return starty, endy, startx, endx def forward(self): """Run forward pass. This will be called by both functions and .""" if self.do_upsampling: input_uv_up = F.interpolate(self.input_uv, scale_factor=2, mode='bilinear') crop_params = None if self.isTrain: # Take random crop to decrease memory requirement. 
crop_params = self.gen_crop_params(*input_uv_up.shape[-2:]) starty, endy, startx, endx = crop_params self.target_image = self.target_image[:, :, starty:endy, startx:endx] outputs = self.netLNR.forward(self.input_uv, self.input_id, uv_map_upsampled=input_uv_up, crop_params=crop_params) else: outputs = self.netLNR(self.input_uv, self.input_id) self.reconstruction = outputs['reconstruction'][:, :3] self.alpha_composite = outputs['reconstruction'][:, 3] self.output_rgba = outputs['layers'] n_layers = outputs['layers'].shape[1] layers = outputs['layers'].clone() layers[:, 0, -1] = 1 # Background layer's alpha is always 1 layers = torch.cat([layers[:, l] for l in range(n_layers)], -2) self.alpha_vis = layers[:, 3:4] self.rgba_vis = layers self.mask_vis = torch.cat([self.mask[:, l:l+1] for l in range(n_layers)], -2) self.input_vis = torch.cat([self.input_uv[:, 2*l:2*l+2] for l in range(n_layers)], -2) self.input_vis = torch.cat([torch.zeros_like(self.input_vis[:, :1]), self.input_vis], 1) def backward(self): """Calculate losses, gradients, and update network weights; called in every training iteration""" self.loss_recon = self.criterionLoss(self.reconstruction[:, :3], self.target_image) self.loss_total = self.loss_recon if not self.do_upsampling: self.loss_alpha_reg = networks.cal_alpha_reg(self.alpha_composite * .5 + .5, self.lambda_alpha_l1, self.lambda_alpha_l0) alpha_layers = self.output_rgba[:, :, 3] self.loss_mask = self.lambda_mask * self.criterionLossMask(alpha_layers, self.mask) self.loss_total += self.loss_alpha_reg + self.loss_mask else: self.loss_mask = 0. self.loss_alpha_reg = 0. self.loss_total.backward() def optimize_parameters(self): """Update network weights; it will be called in every training iteration.""" self.forward() self.optimizer.zero_grad() self.backward() self.optimizer.step() def update_lambdas(self, epoch): """Update loss weights based on current epochs and losses.""" if epoch == self.opt.alpha_l1_rolloff_epoch: self.lambda_alpha_l1 = 0 if self.mask_loss_rolloff_epoch >= 0: if epoch == 2*self.mask_loss_rolloff_epoch: self.lambda_mask = 0 elif epoch > self.opt.epoch_count: if self.loss_mask < self.opt.mask_thresh * self.opt.lambda_mask: self.mask_loss_rolloff_epoch = epoch self.lambda_mask *= .1 if epoch == self.opt.jitter_epochs: self.jitter_rgb = 0 def transfer_detail(self): """Transfer detail to layers.""" residual = self.target_image - self.reconstruction transmission_comp = torch.zeros_like(self.target_image[:, 0:1]) rgba_detail = self.output_rgba n_layers = self.output_rgba.shape[1] for i in range(n_layers - 1, 0, -1): # Don't do detail transfer for background layer, due to ghosting effects. transmission_i = 1. - transmission_comp rgba_detail[:, i, :3] += transmission_i * residual alpha_i = self.output_rgba[:, i, 3:4] * .5 + .5 transmission_comp = alpha_i + (1. - alpha_i) * transmission_comp self.rgba = torch.clamp(rgba_detail, -1, 1) def get_results(self): """Return results. This is different from get_current_visuals, which gets visuals for monitoring training. Returns a dictionary: original - - original frame recon - - reconstruction rgba_l* - - RGBA for each layer mask_l* - - mask for each layer """ self.transfer_detail() # Split layers results = {'reconstruction': self.reconstruction, 'original': self.target_image} n_layers = self.rgba.shape[1] for i in range(n_layers): results[f'mask_l{i}'] = self.mask[:, i:i+1] results[f'rgba_l{i}'] = self.rgba[:, i] if i == 0: results[f'rgba_l{i}'][:, -1:] = 1.
return results def freeze_basenet(self): """Freeze all parameters except for the upsampling module.""" net = self.netLNR if isinstance(net, torch.nn.DataParallel): net = net.module self.set_requires_grad([net.encoder, net.decoder, net.final_rgba], False) net.texture.requires_grad = False ================================================ FILE: models/networks.py ================================================ # Copyright 2020 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import torch import torch.nn as nn import torch.nn.functional as F import numpy as np from third_party.models.networks import init_net ############################################################################### # Helper Functions ############################################################################### def define_LNR(nf=64, texture_channels=16, texture_res=16, n_textures=25, gpu_ids=[]): """Create a layered neural renderer. Parameters: nf (int) -- the number of channels in the first/last conv layers texture_channels (int) -- the number of channels in the neural texture texture_res (int) -- the size of each individual texture map n_textures (int) -- the number of individual texture maps gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2 Returns a layered neural rendering model. """ net = LayeredNeuralRenderer(nf, texture_channels, texture_res, n_textures) return init_net(net, gpu_ids) def define_kp2uv(nf=64, gpu_ids=[]): """Create a keypoint-to-UV model. Parameters: nf (int) -- the number of channels in the first/last conv layers Returns a keypoint-to-UV model. """ net = kp2uv(nf) return init_net(net, gpu_ids) def cal_alpha_reg(prediction, lambda_alpha_l1, lambda_alpha_l0): """Calculate the alpha regularization term. Parameters: prediction (tensor) - - composite of predicted alpha layers lambda_alpha_l1 (float) - - weight for the L1 regularization term lambda_alpha_l0 (float) - - weight for the L0 regularization term Returns the alpha regularization loss term """ assert prediction.max() <= 1. assert prediction.min() >= 0. loss = 0. if lambda_alpha_l1 > 0: loss += lambda_alpha_l1 * torch.mean(prediction) if lambda_alpha_l0 > 0: # Pseudo L0 loss using a squished sigmoid curve. l0_prediction = (torch.sigmoid(prediction * 5.0) - 0.5) * 2.0 loss += lambda_alpha_l0 * torch.mean(l0_prediction) return loss ############################################################################## # Classes ############################################################################## class MaskLoss(nn.Module): """Define the loss which encourages the predicted alpha matte to match the mask (trimap).""" def __init__(self): super(MaskLoss, self).__init__() self.loss = nn.L1Loss(reduction='none') def __call__(self, prediction, target): """Calculate loss given predicted alpha matte and trimap. Balance positive and negative regions. Exclude 'unknown' region from loss. 
Parameters: prediction (tensor) - - predicted alpha target (tensor) - - trimap Returns: the computed loss """ mask_err = self.loss(prediction, target) pos_mask = F.relu(target) neg_mask = F.relu(-target) pos_mask_loss = (pos_mask * mask_err).sum() / (1 + pos_mask.sum()) neg_mask_loss = (neg_mask * mask_err).sum() / (1 + neg_mask.sum()) loss = .5 * (pos_mask_loss + neg_mask_loss) return loss class ConvBlock(nn.Module): """Helper module consisting of a convolution, optional normalization and activation, with padding='same'.""" def __init__(self, conv, in_channels, out_channels, ksize=4, stride=1, dil=1, norm=None, activation='relu'): """Create a conv block. Parameters: conv (convolutional layer) - - the type of conv layer, e.g. Conv2d, ConvTranspose2d in_channels (int) - - the number of input channels in_channels (int) - - the number of output channels ksize (int) - - the kernel size stride (int) - - stride dil (int) - - dilation norm (norm layer) - - the type of normalization layer, e.g. BatchNorm2d, InstanceNorm2d activation (str) -- the type of activation: relu | leaky | tanh | none """ super(ConvBlock, self).__init__() self.k = ksize self.s = stride self.d = dil self.conv = conv(in_channels, out_channels, ksize, stride=stride, dilation=dil) if norm is not None: self.norm = norm(out_channels) else: self.norm = None if activation == 'leaky': self.activation = nn.LeakyReLU(0.2) elif activation == 'relu': self.activation = nn.ReLU() elif activation == 'tanh': self.activation = nn.Tanh() else: self.activation = None def forward(self, x): """Forward pass. Compute necessary padding and cropping because pytorch doesn't have pad=same.""" height, width = x.shape[-2:] if isinstance(self.conv, nn.modules.ConvTranspose2d): desired_height = height * self.s desired_width = width * self.s pady = 0 padx = 0 else: # o = [i + 2*p - k - (k-1)*(d-1)]/s + 1 # padding = .5 * (stride * (output-1) + (k-1)(d-1) + k - input) desired_height = height // self.s desired_width = width // self.s pady = .5 * (self.s * (desired_height - 1) + (self.k - 1) * (self.d - 1) + self.k - height) padx = .5 * (self.s * (desired_width - 1) + (self.k - 1) * (self.d - 1) + self.k - width) x = F.pad(x, [int(np.floor(padx)), int(np.ceil(padx)), int(np.floor(pady)), int(np.ceil(pady))]) x = self.conv(x) if x.shape[-2] != desired_height or x.shape[-1] != desired_width: cropy = x.shape[-2] - desired_height cropx = x.shape[-1] - desired_width x = x[:, :, int(np.floor(cropy / 2.)):-int(np.ceil(cropy / 2.)), int(np.floor(cropx / 2.)):-int(np.ceil(cropx / 2.))] if self.norm: x = self.norm(x) if self.activation: x = self.activation(x) return x class ResBlock(nn.Module): """Define a residual block.""" def __init__(self, channels, ksize=4, stride=1, dil=1, norm=None, activation='relu'): """Initialize the residual block, which consists of 2 conv blocks with a skip connection.""" super(ResBlock, self).__init__() self.convblock1 = ConvBlock(nn.Conv2d, channels, channels, ksize=ksize, stride=stride, dil=dil, norm=norm, activation=activation) self.convblock2 = ConvBlock(nn.Conv2d, channels, channels, ksize=ksize, stride=stride, dil=dil, norm=norm, activation=None) def forward(self, x): identity = x x = self.convblock1(x) x = self.convblock2(x) x += identity return x class kp2uv(nn.Module): """UNet architecture for converting keypoint image to UV map. Same person UV map format as described in https://arxiv.org/pdf/1802.00434.pdf. 
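In other words, the network predicts a 25-way body-part segmentation (24 parts plus background) along with a pair of UV coordinates for every part; combined, these form a DensePose-style IUV representation.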
""" def __init__(self, nf=64): super(kp2uv, self).__init__(), self.encoder = nn.ModuleList([ ConvBlock(nn.Conv2d, 3, nf, ksize=4, stride=2), ConvBlock(nn.Conv2d, nf, nf * 2, ksize=4, stride=2, norm=nn.InstanceNorm2d, activation='leaky'), ConvBlock(nn.Conv2d, nf * 2, nf * 4, ksize=4, stride=2, norm=nn.InstanceNorm2d, activation='leaky'), ConvBlock(nn.Conv2d, nf * 4, nf * 4, ksize=4, stride=2, norm=nn.InstanceNorm2d, activation='leaky'), ConvBlock(nn.Conv2d, nf * 4, nf * 4, ksize=4, stride=2, norm=nn.InstanceNorm2d, activation='leaky'), ConvBlock(nn.Conv2d, nf * 4, nf * 4, ksize=3, stride=1, norm=nn.InstanceNorm2d, activation='leaky'), ConvBlock(nn.Conv2d, nf * 4, nf * 4, ksize=3, stride=1, norm=nn.InstanceNorm2d, activation='leaky')]) self.decoder = nn.ModuleList([ ConvBlock(nn.ConvTranspose2d, nf * 4 * 2, nf * 4, ksize=4, stride=2, norm=nn.InstanceNorm2d), ConvBlock(nn.ConvTranspose2d, nf * 4 * 2, nf * 4, ksize=4, stride=2, norm=nn.InstanceNorm2d), ConvBlock(nn.ConvTranspose2d, nf * 4 * 2, nf * 2, ksize=4, stride=2, norm=nn.InstanceNorm2d), ConvBlock(nn.ConvTranspose2d, nf * 2 * 2, nf, ksize=4, stride=2, norm=nn.InstanceNorm2d), ConvBlock(nn.ConvTranspose2d, nf * 2, nf, ksize=4, stride=2, norm=nn.InstanceNorm2d)]) # head to predict body part class (25 classes - 24 body parts, 1 background.) self.id_pred = ConvBlock(nn.Conv2d, nf + 3, 25, ksize=3, stride=1, activation='none') # head to predict UV coordinates for every body part class self.uv_pred = ConvBlock(nn.Conv2d, nf + 3, 2 * 24, ksize=3, stride=1, activation='tanh') def forward(self, x): """Forward pass through UNet, handling skip connections. Parameters: x (tensor) - - rendered keypoint image, shape [B, 3, H, W] Returns: x_id (tensor): part id class probabilities x_uv (tensor): uv coordinates for each part id """ skips = [x] for i, layer in enumerate(self.encoder): x = layer(x) if i < 5: skips.append(x) for layer in self.decoder: x = torch.cat((x, skips.pop()), 1) x = layer(x) x = torch.cat((x, skips.pop()), 1) x_id = self.id_pred(x) x_uv = self.uv_pred(x) return x_id, x_uv class LayeredNeuralRenderer(nn.Module): """Layered Neural Rendering model for video decomposition. Consists of neural texture, UNet, upsampling module. """ def __init__(self, nf=64, texture_channels=16, texture_res=16, n_textures=25): super(LayeredNeuralRenderer, self).__init__(), """Initialize layered neural renderer. 
Parameters: nf (int) -- the number of channels in the first/last conv layers texture_channels (int) -- the number of channels in the neural texture texture_res (int) -- the size of each individual texture map n_textures (int) -- the number of individual texture maps """ # Neural texture is implemented as 'n_textures' concatenated horizontally self.texture = nn.Parameter(torch.randn(1, texture_channels, texture_res, n_textures * texture_res)) # Define UNet self.encoder = nn.ModuleList([ ConvBlock(nn.Conv2d, texture_channels + 1, nf, ksize=4, stride=2), ConvBlock(nn.Conv2d, nf, nf * 2, ksize=4, stride=2, norm=nn.BatchNorm2d, activation='leaky'), ConvBlock(nn.Conv2d, nf * 2, nf * 4, ksize=4, stride=2, norm=nn.BatchNorm2d, activation='leaky'), ConvBlock(nn.Conv2d, nf * 4, nf * 4, ksize=4, stride=2, norm=nn.BatchNorm2d, activation='leaky'), ConvBlock(nn.Conv2d, nf * 4, nf * 4, ksize=4, stride=2, norm=nn.BatchNorm2d, activation='leaky'), ConvBlock(nn.Conv2d, nf * 4, nf * 4, ksize=4, stride=1, dil=2, norm=nn.BatchNorm2d, activation='leaky'), ConvBlock(nn.Conv2d, nf * 4, nf * 4, ksize=4, stride=1, dil=2, norm=nn.BatchNorm2d, activation='leaky')]) self.decoder = nn.ModuleList([ ConvBlock(nn.ConvTranspose2d, nf * 4 * 2, nf * 4, ksize=4, stride=2, norm=nn.BatchNorm2d), ConvBlock(nn.ConvTranspose2d, nf * 4 * 2, nf * 4, ksize=4, stride=2, norm=nn.BatchNorm2d), ConvBlock(nn.ConvTranspose2d, nf * 4 * 2, nf * 2, ksize=4, stride=2, norm=nn.BatchNorm2d), ConvBlock(nn.ConvTranspose2d, nf * 2 * 2, nf, ksize=4, stride=2, norm=nn.BatchNorm2d), ConvBlock(nn.ConvTranspose2d, nf * 2, nf, ksize=4, stride=2, norm=nn.BatchNorm2d)]) self.final_rgba = ConvBlock(nn.Conv2d, nf, 4, ksize=4, stride=1, activation='tanh') # Define upsampling block, which outputs a residual upsampling_ic = texture_channels + 5 + nf self.upsample_block = nn.Sequential( ConvBlock(nn.Conv2d, upsampling_ic, nf, ksize=3, stride=1, norm=nn.InstanceNorm2d), ResBlock(nf, ksize=3, stride=1, norm=nn.InstanceNorm2d), ResBlock(nf, ksize=3, stride=1, norm=nn.InstanceNorm2d), ResBlock(nf, ksize=3, stride=1, norm=nn.InstanceNorm2d), ConvBlock(nn.Conv2d, nf, 4, ksize=3, stride=1, activation='none')) def render(self, x): """Pass inputs for a single layer through UNet. Parameters: x (tensor) - - sampled texture concatenated with person IDs Returns RGBA for the input layer and the final feature maps. """ skips = [x] for i, layer in enumerate(self.encoder): x = layer(x) if i < 5: skips.append(x) for layer in self.decoder: x = torch.cat((x, skips.pop()), 1) x = layer(x) rgba = self.final_rgba(x) return rgba, x def forward(self, uv_map, id_layers, uv_map_upsampled=None, crop_params=None): """Forward pass through layered neural renderer. Steps: 1. Sample from the neural texture using uv_map 2. Input uv_map and id_layers into UNet 2a. If doing upsampling, then pass upsampled inputs and results through upsampling module 3. Composite RGBA outputs. Parameters: uv_map (tensor) - - UV maps for all layers, with shape [B, (2*L), H, W] id_layers (tensor) - - person ID for all layers, with shape [B, L, H, W] uv_map_upsampled (tensor) - - upsampled UV maps to input to upsampling module (if None, skip upsampling) crop_params """ b_sz = uv_map.shape[0] n_layers = uv_map.shape[1] // 2 texture = self.texture.repeat(b_sz, 1, 1, 1) composite = None layers = [] sampled_textures = [] for i in range(n_layers): # Get RGBA for this layer. uv_map_i = uv_map[:, i * 2:(i + 1) * 2, ...] 
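# F.grid_sample expects its sampling grid as [B, H, W, 2] with (x, y) coordinates normalized to
# [-1, 1], so the two UV channels for this layer are permuted from [B, 2, H, W] below; the UV
# values address positions in the full-width texture atlas.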
uv_map_i = uv_map_i.permute(0, 2, 3, 1) sampled_texture = F.grid_sample(texture, uv_map_i, mode='bilinear', padding_mode='zeros') inputs = torch.cat([sampled_texture, id_layers[:, i:i + 1]], 1) rgba, last_feat = self.render(inputs) if uv_map_upsampled is not None: uv_map_up_i = uv_map_upsampled[:, i * 2:(i + 1) * 2, ...] uv_map_up_i = uv_map_up_i.permute(0, 2, 3, 1) sampled_texture_up = F.grid_sample(texture, uv_map_up_i, mode='bilinear', padding_mode='zeros') id_layers_up = F.interpolate(id_layers[:, i:i + 1], size=sampled_texture_up.shape[-2:], mode='bilinear') inputs_up = torch.cat([sampled_texture_up, id_layers_up], 1) upsampled_size = inputs_up.shape[-2:] rgba = F.interpolate(rgba, size=upsampled_size, mode='bilinear') last_feat = F.interpolate(last_feat, size=upsampled_size, mode='bilinear') if crop_params is not None: starty, endy, startx, endx = crop_params rgba = rgba[:, :, starty:endy, startx:endx] last_feat = last_feat[:, :, starty:endy, startx:endx] inputs_up = inputs_up[:, :, starty:endy, startx:endx] rgba_residual = self.upsample_block(torch.cat((rgba, inputs_up, last_feat), 1)) rgba += .01 * rgba_residual rgba = torch.clamp(rgba, -1, 1) sampled_texture = sampled_texture_up # Update the composite with this layer's RGBA output if composite is None: composite = rgba else: alpha = rgba[:, 3:4] * .5 + .5 composite = rgba * alpha + composite * (1.0 - alpha) layers.append(rgba) sampled_textures.append(sampled_texture) outputs = { 'reconstruction': composite, 'layers': torch.stack(layers, 1), 'sampled texture': sampled_textures, # for debugging } return outputs ================================================ FILE: options/__init__.py ================================================ ================================================ FILE: options/base_options.py ================================================ # Copyright 2020 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import argparse import os from third_party.util import util from third_party import models from third_party import data import torch import json class BaseOptions(): """This class defines options used during both training and test time. It also implements several helper functions such as parsing, printing, and saving the options. It also gathers additional options defined in functions in both dataset class and model class. """ def __init__(self): """Reset the class; indicates the class hasn't been initialized""" self.initialized = False def initialize(self, parser): """Define the common options that are used in both training and test.""" # basic parameters parser.add_argument('--dataroot', required=True, help='path to images (should have subfolders rgb_256, etc)') parser.add_argument('--name', type=str, default='experiment_name', help='name of the experiment. It decides where to store samples and models') parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. 
use -1 for CPU') parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here') parser.add_argument('--seed', type=int, default=1, help='initial random seed') # model parameters parser.add_argument('--model', type=str, default='lnr', help='chooses which model to use. [lnr | kp2uv]') parser.add_argument('--num_filters', type=int, default=64, help='# filters in the first and last conv layers') # dataset parameters parser.add_argument('--dataset_mode', type=str, default='layered_video', help='chooses how datasets are loaded. [layered_video | kpuv]') parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly') parser.add_argument('--num_threads', default=4, type=int, help='# threads for loading data') parser.add_argument('--batch_size', type=int, default=32, help='input batch size') parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.') parser.add_argument('--display_winsize', type=int, default=256, help='display window size for both visdom and HTML') # additional parameters parser.add_argument('--epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model') parser.add_argument('--verbose', action='store_true', help='if specified, print more debugging information') parser.add_argument('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}') self.initialized = True return parser def gather_options(self): """Initialize our parser with basic options(only once). Add additional model-specific and dataset-specific options. These options are defined in the function in model and dataset classes. """ if not self.initialized: # check if it has been initialized parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) parser = self.initialize(parser) # get the basic options opt, _ = parser.parse_known_args() # modify model-related parser options model_name = opt.model model_option_setter = models.get_option_setter(model_name) parser = model_option_setter(parser, self.isTrain) opt, _ = parser.parse_known_args() # parse again with new defaults # modify dataset-related parser options dataset_name = opt.dataset_mode dataset_option_setter = data.get_option_setter(dataset_name) parser = dataset_option_setter(parser, self.isTrain) # save and return the parser self.parser = parser return parser.parse_args() def print_options(self, opt): """Print and save options It will print both current options and default values(if different). 
It will save options into a text file / [checkpoints_dir] / opt.txt """ message = '' message += '----------------- Options ---------------\n' for k, v in sorted(vars(opt).items()): comment = '' default = self.parser.get_default(k) if v != default: comment = '\t[default: %s]' % str(default) message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment) message += '----------------- End -------------------' print(message) # save to the disk expr_dir = os.path.join(opt.checkpoints_dir, opt.name) util.mkdirs(expr_dir) file_name = os.path.join(expr_dir, '{}_opt.txt'.format(opt.phase)) with open(file_name, 'wt') as opt_file: opt_file.write(message) opt_file.write('\n') def parse(self): """Parse our options, create checkpoints directory suffix, and set up gpu device.""" opt = self.gather_options() opt.isTrain = self.isTrain # train or test # process opt.suffix if opt.suffix: suffix = ('_' + opt.suffix.format(**vars(opt))) if opt.suffix != '' else '' opt.name = opt.name + suffix self.print_options(opt) # set gpu ids str_ids = opt.gpu_ids.split(',') opt.gpu_ids = [] for str_id in str_ids: id = int(str_id) if id >= 0: opt.gpu_ids.append(id) if len(opt.gpu_ids) > 0: torch.cuda.set_device(opt.gpu_ids[0]) self.opt = opt return self.opt def parse_dataset_meta(self): """Parse options from the 'metadata.json' file in the dataroot.""" with open(os.path.join(self.opt.dataroot, 'metadata.json')) as f: metadata = json.load(f) self.opt.n_textures = metadata['n_textures'] self.opt.width = metadata['size_LR'][0] self.opt.height = metadata['size_LR'][1] return self.opt ================================================ FILE: options/test_options.py ================================================ # Copyright 2020 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from .base_options import BaseOptions class TestOptions(BaseOptions): """This class includes test options. It also includes shared options defined in BaseOptions. """ def initialize(self, parser): parser = BaseOptions.initialize(self, parser) # define shared options parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.') parser.add_argument('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images') parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc') parser.add_argument('--num_test', type=int, default=float("inf"), help='how many test images to run') self.isTrain = False return parser ================================================ FILE: options/train_options.py ================================================ # Copyright 2020 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from .base_options import BaseOptions class TrainOptions(BaseOptions): """This class includes training options. It also includes shared options defined in BaseOptions. """ def initialize(self, parser): parser = BaseOptions.initialize(self, parser) # visdom and HTML visualization parameters parser.add_argument('--display_freq', type=int, default=20, help='frequency of showing training results on screen (in epochs)') parser.add_argument('--display_ncols', type=int, default=0, help='if positive, display all images in a single visdom web panel with certain number of images per row.') parser.add_argument('--display_id', type=int, default=1, help='window id of the web display') parser.add_argument('--display_server', type=str, default="http://localhost", help='visdom server of the web display') parser.add_argument('--display_env', type=str, default='main', help='visdom display environment name (default is "main")') parser.add_argument('--display_port', type=int, default=8097, help='visdom port of the web display') parser.add_argument('--update_html_freq', type=int, default=50, help='frequency of saving training results to html') parser.add_argument('--print_freq', type=int, default=10, help='frequency of showing training results on console (in steps per epoch)') parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/') # network saving and loading parameters parser.add_argument('--save_latest_freq', type=int, default=50, help='frequency of saving the latest results (in epochs)') parser.add_argument('--save_by_epoch', action='store_true', help='whether saves model as "epoch" or "latest" (overwrites previous)') parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model') parser.add_argument('--epoch_count', type=int, default=1, help='the starting epoch count, we save the model by , +, ...') parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc') # training parameters parser.add_argument('--n_epochs', type=int, default=2000, help='number of epochs with the initial learning rate') parser.add_argument('--n_epochs_decay', type=int, default=0, help='number of epochs to linearly decay learning rate to zero') parser.add_argument('--lr', type=float, default=0.001, help='initial learning rate for adam') parser.add_argument('--lr_policy', type=str, default='linear', help='learning rate policy. [linear | step | plateau | cosine]') self.isTrain = True return parser ================================================ FILE: requirements.txt ================================================ torch==1.4.0 torchvision>=0.5.0 dominate>=2.4.0 visdom>=0.1.8 matplotlib>=3.2.1 opencv-python>=4.2.0 ================================================ FILE: run_kp2uv.py ================================================ # Copyright 2020 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Script for running the keypoint-to-UV network and saving the predicted UVs. Example: python run_kp2uv.py --model kp2uv --dataroot ./datasets/reflection --results_dir ./datasets/reflection It will load the pre-trained model from '--checkpoints_dir' and save the results to '--results_dir'. See options/base_options.py and options/test_options.py for more test options. """ import os from options.test_options import TestOptions from third_party.data import create_dataset from third_party.models import create_model from third_party.util.visualizer import save_images from third_party.util import html if __name__ == '__main__': opt = TestOptions().parse() # get test options # hard-code some parameters opt.name = 'kp2uv' opt.num_threads = 0 # test code only supports num_threads = 0 opt.batch_size = 1 # test code only supports batch_size = 1 opt.serial_batches = True # disable data shuffling; comment this line if results on randomly chosen images are needed. opt.display_id = -1 # no visdom display; the test code saves the results to a HTML file. dataset = create_dataset(opt) # create a dataset given opt.dataset_mode and other options model = create_model(opt) # create a model given opt.model and other options model.setup(opt) # regular setup: load and print networks; create schedulers # create a website web_dir = os.path.join(opt.results_dir, opt.name, '{}_{}'.format(opt.phase, opt.epoch)) # define the website directory print('creating web directory', web_dir) webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.epoch)) for i, data in enumerate(dataset): if i >= opt.num_test: # only apply our model to opt.num_test images. break model.set_input(data) # unpack data from data loader model.test() # run inference visuals = model.get_current_visuals() # get image results img_path = model.get_image_paths() # get image paths if i % 5 == 0: print('processing (%04d)-th image... %s' % (i, img_path)) save_images(webpage, visuals, img_path, aspect_ratio=opt.aspect_ratio, width=opt.display_winsize) webpage.save() # save the HTML ================================================ FILE: scripts/download_kp2uv_model.sh ================================================ #!/bin/bash # # Copyright 2020 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
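# Run this script from the repository root; run_kp2uv.py (with the default
# --checkpoints_dir ./checkpoints) should then pick up the downloaded weights
# as the 'latest' kp2uv checkpoint.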
mkdir -p ./checkpoints/kp2uv MODEL_FILE=./checkpoints/kp2uv/latest_net_Kp2uv.pth URL=https://www.robots.ox.ac.uk/~erika/retiming/pretrained_models/kp2uv.pth wget -N $URL -O $MODEL_FILE ================================================ FILE: scripts/run_cartwheel.sh ================================================ #!/bin/bash # # Copyright 2020 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. GPUS=0,1 DATA_PATH=./datasets/cartwheel bash datasets/prepare_iuv.sh $DATA_PATH python train.py \ --name cartwheel \ --dataroot $DATA_PATH \ --use_homographies \ --gpu_ids $GPUS python test.py \ --name cartwheel \ --dataroot $DATA_PATH \ --do_upsampling \ --use_homographies ================================================ FILE: scripts/run_reflection.sh ================================================ #!/bin/bash # # Copyright 2020 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. GPUS=0,1 DATA_PATH=./datasets/reflection bash datasets/prepare_iuv.sh $DATA_PATH python train.py \ --name reflection \ --dataroot $DATA_PATH \ --gpu_ids $GPUS python test.py \ --name reflection \ --dataroot $DATA_PATH \ --do_upsampling ================================================ FILE: scripts/run_splash.sh ================================================ #!/bin/bash # # Copyright 2020 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. GPUS=0,1 DATA_PATH=./datasets/splash bash datasets/prepare_iuv.sh $DATA_PATH python train.py \ --name splash \ --dataroot $DATA_PATH \ --batch_size 24 \ --batch_size_upsample 12 \ --use_mask_images \ --gpu_ids $GPUS python test.py \ --name splash \ --dataroot $DATA_PATH \ --do_upsampling ================================================ FILE: scripts/run_trampoline.sh ================================================ #!/bin/bash # # Copyright 2020 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. GPUS=0,1 DATA_PATH=./datasets/trampoline bash datasets/prepare_iuv.sh $DATA_PATH python train.py \ --name trampoline \ --dataroot $DATA_PATH \ --batch_size 16 \ --batch_size_upsample 6 \ --gpu_ids $GPUS python test.py \ --name trampoline \ --dataroot $DATA_PATH \ --do_upsampling ================================================ FILE: test.py ================================================ # Copyright 2020 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Script to save the full outputs of a layered neural renderer (LNR). Once you have trained the LNR with train.py, you can use this script to save the model's final layer decomposition. It will load a saved model from '--checkpoints_dir' and save the results to '--results_dir'. It first creates a model and dataset given the options. It will hard-code some parameters. It then runs inference for '--num_test' images and save results to an HTML file. Example (You need to train models first or download pre-trained models from our website): python test.py --dataroot ./datasets/reflection --name reflection --do_upsampling If the upsampling module isn't trained (train.py is used with '--n_epochs_upsample 0'), remove --do_upsampling. Use '--results_dir ' to specify the results directory. See options/base_options.py and options/test_options.py for more test options. """ import os from options.test_options import TestOptions from third_party.data import create_dataset from third_party.models import create_model from third_party.util.visualizer import save_images, save_videos from third_party.util import html import torch if __name__ == '__main__': testopt = TestOptions() testopt.parse() opt = testopt.parse_dataset_meta() # hard-code some parameters for test opt.num_threads = 0 # test code only supports num_threads = 0 opt.batch_size = 1 # test code only supports batch_size = 1 opt.serial_batches = True # disable data shuffling; comment this line if results on randomly chosen images are needed. opt.display_id = -1 # no visdom display; the test code saves the results to a HTML file. 
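# The loop below runs inference frame by frame, saves the per-layer RGBA outputs as images,
# and accumulates all results (layers, masks, reconstruction, original) so that they can be
# written out as videos under [results_dir]/[name]/[phase]_[epoch].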
dataset = create_dataset(opt) # create a dataset given opt.dataset_mode and other options model = create_model(opt) # create a model given opt.model and other options model.setup(opt) # regular setup: load and print networks; create schedulers # create a website web_dir = os.path.join(opt.results_dir, opt.name, '{}_{}'.format(opt.phase, opt.epoch)) # define the website directory print('creating web directory', web_dir) webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.epoch)) video_visuals = None for i, data in enumerate(dataset): if i >= opt.num_test: # only apply our model to opt.num_test images. break model.set_input(data) # unpack data from data loader model.test() # run inference img_path = model.get_image_paths() # get image paths if i % 5 == 0: # save images to an HTML file print('processing (%04d)-th image... %s' % (i, img_path)) visuals = model.get_results() # rgba, reconstruction, original, mask if video_visuals is None: video_visuals = visuals else: for k in video_visuals: video_visuals[k] = torch.cat((video_visuals[k], visuals[k])) rgba = { k: visuals[k] for k in visuals if 'rgba' in k } # save RGBA layers save_images(webpage, rgba, img_path, aspect_ratio=opt.aspect_ratio, width=opt.display_winsize) save_videos(webpage, video_visuals, width=opt.display_winsize) webpage.save() # save the HTML of videos ================================================ FILE: third_party/__init__.py ================================================ ================================================ FILE: third_party/data/__init__.py ================================================ """This package includes all the modules related to data loading and preprocessing To add a custom dataset class called 'dummy', you need to add a file called 'dummy_dataset.py' and define a subclass 'DummyDataset' inherited from BaseDataset. You need to implement four functions: -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt). -- <__len__>: return the size of dataset. -- <__getitem__>: get a data point from data loader. -- : (optionally) add dataset-specific options and set default options. Now you can use the dataset class by specifying flag '--dataset_mode dummy'. See our template dataset class 'template_dataset.py' for more details. """ import importlib import torch.utils.data from .base_dataset import BaseDataset from .fast_data_loader import FastDataLoader def find_dataset_using_name(dataset_name): """Import the module "data/[dataset_name]_dataset.py". In the file, the class called DatasetNameDataset() will be instantiated. It has to be a subclass of BaseDataset, and it is case-insensitive. """ dataset_filename = "data." + dataset_name + "_dataset" datasetlib = importlib.import_module(dataset_filename) dataset = None target_dataset_name = dataset_name.replace('_', '') + 'dataset' for name, cls in datasetlib.__dict__.items(): if name.lower() == target_dataset_name.lower() \ and issubclass(cls, BaseDataset): dataset = cls if dataset is None: raise NotImplementedError("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name)) return dataset def get_option_setter(dataset_name): """Return the static method of the dataset class.""" dataset_class = find_dataset_using_name(dataset_name) return dataset_class.modify_commandline_options def create_dataset(opt, use_fast_loader=False): """Create a dataset given the option. 
This function wraps the class CustomDatasetDataLoader. This is the main interface between this package and 'train.py'/'test.py' If use_fast_loader=False, use the default pytorch dataloader. Otherwise, use FastDatasetLoader. Example: >>> from data import create_dataset >>> dataset = create_dataset(opt) """ data_loader = CustomDatasetDataLoader(opt, use_fast_loader=use_fast_loader) dataset = data_loader.load_data() return dataset class CustomDatasetDataLoader(): """Wrapper class of Dataset class that performs multi-threaded data loading""" def __init__(self, opt, use_fast_loader=False): """Initialize this class Step 1: create a dataset instance given the name [dataset_mode] Step 2: create a multi-threaded data loader. If use_fast_loader=False, use the default pytorch dataloader. Otherwise, use FastDatasetLoader. """ self.opt = opt dataset_class = find_dataset_using_name(opt.dataset_mode) self.dataset = dataset_class(opt) print("dataset [%s] was created" % type(self.dataset).__name__) loader = torch.utils.data.DataLoader if use_fast_loader: loader = FastDataLoader self.dataloader = loader( self.dataset, batch_size=opt.batch_size, shuffle=not opt.serial_batches, num_workers=int(opt.num_threads)) def load_data(self): return self def __len__(self): """Return the number of data in the dataset""" return min(len(self.dataset), self.opt.max_dataset_size) def __iter__(self): """Return a batch of data""" for i, data in enumerate(self.dataloader): if i * self.opt.batch_size >= self.opt.max_dataset_size: break yield data ================================================ FILE: third_party/data/base_dataset.py ================================================ """This module implements an abstract base class (ABC) 'BaseDataset' for datasets. It also includes common transformation functions (e.g., get_transform, __scale_width), which can be later used in subclasses. """ import random import numpy as np import torch.utils.data as data from PIL import Image import torchvision.transforms as transforms from abc import ABC, abstractmethod class BaseDataset(data.Dataset, ABC): """This class is an abstract base class (ABC) for datasets. To create a subclass, you need to implement the following four functions: -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt). -- <__len__>: return the size of dataset. -- <__getitem__>: get a data point. -- : (optionally) add dataset-specific options and set default options. """ def __init__(self, opt): """Initialize the class; save the options in the class Parameters: opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions """ self.opt = opt self.root = opt.dataroot @staticmethod def modify_commandline_options(parser, is_train): """Add new dataset-specific options, and rewrite default values for existing options. Parameters: parser -- original option parser is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options. Returns: the modified parser. """ return parser @abstractmethod def __len__(self): """Return the total number of images in the dataset.""" return 0 @abstractmethod def __getitem__(self, index): """Return a data point and its metadata information. Parameters: index - - a random integer for data indexing Returns: a dictionary of data with their names. It ususally contains the data itself and its metadata information. 
""" pass def get_params(opt, size): w, h = size new_h = h new_w = w if opt.preprocess == 'resize_and_crop': new_h = new_w = opt.load_size elif opt.preprocess == 'scale_width_and_crop': new_w = opt.load_size new_h = opt.load_size * h // w x = random.randint(0, np.maximum(0, new_w - opt.crop_size)) y = random.randint(0, np.maximum(0, new_h - opt.crop_size)) flip = random.random() > 0.5 return {'crop_pos': (x, y), 'flip': flip} def get_transform(opt, params=None, grayscale=False, method=Image.BICUBIC, convert=True): transform_list = [] if grayscale: transform_list.append(transforms.Grayscale(1)) if 'resize' in opt.preprocess: osize = [opt.load_size, opt.load_size] transform_list.append(transforms.Resize(osize, method)) elif 'scale_width' in opt.preprocess: transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.load_size, opt.crop_size, method))) if 'crop' in opt.preprocess: if params is None: transform_list.append(transforms.RandomCrop(opt.crop_size)) else: transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.crop_size))) if opt.preprocess == 'none': transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base=4, method=method))) if not opt.no_flip: if params is None: transform_list.append(transforms.RandomHorizontalFlip()) elif params['flip']: transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip']))) if convert: transform_list += [transforms.ToTensor()] if grayscale: transform_list += [transforms.Normalize((0.5,), (0.5,))] else: transform_list += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))] return transforms.Compose(transform_list) def __make_power_2(img, base, method=Image.BICUBIC): ow, oh = img.size h = int(round(oh / base) * base) w = int(round(ow / base) * base) if h == oh and w == ow: return img __print_size_warning(ow, oh, w, h) return img.resize((w, h), method) def __scale_width(img, target_size, crop_size, method=Image.BICUBIC): ow, oh = img.size if ow == target_size and oh >= crop_size: return img w = target_size h = int(max(target_size * oh / ow, crop_size)) return img.resize((w, h), method) def __crop(img, pos, size): ow, oh = img.size x1, y1 = pos tw = th = size if (ow > tw or oh > th): return img.crop((x1, y1, x1 + tw, y1 + th)) return img def __flip(img, flip): if flip: return img.transpose(Image.FLIP_LEFT_RIGHT) return img def __print_size_warning(ow, oh, w, h): """Print warning information about image size(only print once)""" if not hasattr(__print_size_warning, 'has_printed'): print("The image size needs to be a multiple of 4. " "The loaded image size was (%d, %d), so it was adjusted to " "(%d, %d). This adjustment will be done to all images " "whose sizes are not multiples of 4" % (ow, oh, w, h)) __print_size_warning.has_printed = True ================================================ FILE: third_party/data/fast_data_loader.py ================================================ """ Fixes the issue where DataLoader is slow because processes aren't reused See https://github.com/pytorch/pytorch/issues/15849 Warning: overrides batch sampler. """ import torch.utils.data class _RepeatSampler(object): """ Sampler that repeats forever. 
Args: sampler (Sampler) """ def __init__(self, sampler): self.sampler = sampler def __iter__(self): while True: yield from iter(self.sampler) class FastDataLoader(torch.utils.data.dataloader.DataLoader): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) object.__setattr__(self, 'batch_sampler', _RepeatSampler(self.batch_sampler)) self.iterator = super().__iter__() def __len__(self): return len(self.batch_sampler.sampler) def __iter__(self): for i in range(len(self)): yield next(self.iterator) ================================================ FILE: third_party/data/image_folder.py ================================================ """A modified image folder class We modify the official PyTorch image folder (https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py) so that this class can load images from both current directory and its subdirectories. """ import torch.utils.data as data from PIL import Image import os IMG_EXTENSIONS = [ '.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tif', '.TIF', '.tiff', '.TIFF', ] def is_image_file(filename): return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) def make_dataset(dir, max_dataset_size=float("inf")): images = [] assert os.path.isdir(dir), '%s is not a valid directory' % dir for root, _, fnames in sorted(os.walk(dir)): for fname in fnames: if is_image_file(fname): path = os.path.join(root, fname) images.append(path) images = sorted(images) return images[:min(max_dataset_size, len(images))] def default_loader(path): return Image.open(path).convert('RGB') class ImageFolder(data.Dataset): def __init__(self, root, transform=None, return_paths=False, loader=default_loader): imgs = make_dataset(root) if len(imgs) == 0: raise(RuntimeError("Found 0 images in: " + root + "\n" "Supported image extensions are: " + ",".join(IMG_EXTENSIONS))) self.root = root self.imgs = imgs self.transform = transform self.return_paths = return_paths self.loader = loader def __getitem__(self, index): path = self.imgs[index] img = self.loader(path) if self.transform is not None: img = self.transform(img) if self.return_paths: return img, path else: return img def __len__(self): return len(self.imgs) ================================================ FILE: third_party/models/__init__.py ================================================ """This package contains modules related to objective functions, optimizations, and network architectures. To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel. You need to implement the following five functions: -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt). -- : unpack data from dataset and apply preprocessing. -- : produce intermediate results. -- : calculate loss, gradients, and update network weights. -- : (optionally) add model-specific options and set default options. In the function <__init__>, you need to define four lists: -- self.loss_names (str list): specify the training losses that you want to plot and save. -- self.model_names (str list): define networks used in our training. -- self.visual_names (str list): specify the images that you want to display and save. -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an usage. 
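A minimal sketch of such a subclass (hypothetical names; not part of this repository) could look like:

    # models/dummy_model.py
    from .base_model import BaseModel

    class DummyModel(BaseModel):
        @staticmethod
        def modify_commandline_options(parser, is_train=True):
            parser.add_argument('--lambda_dummy', type=float, default=1.0, help='weight for the dummy loss')
            return parser

        def __init__(self, opt):
            BaseModel.__init__(self, opt)
            self.loss_names = ['dummy']
            self.model_names = []  # no networks to save in this toy example
            self.visual_names = ['data']
            self.optimizers = []

        def set_input(self, input):
            self.data = input['data'].to(self.device)

        def forward(self):
            self.output = 2 * self.data

        def optimize_parameters(self):
            self.forward()
            self.loss_dummy = self.opt.lambda_dummy * self.output.mean()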
Now you can use the model class by specifying flag '--model dummy'. See our template model class 'template_model.py' for more details. """ import importlib from .base_model import BaseModel def find_model_using_name(model_name): """Import the module "models/[model_name]_model.py". In the file, the class called DatasetNameModel() will be instantiated. It has to be a subclass of BaseModel, and it is case-insensitive. """ model_filename = "models." + model_name + "_model" modellib = importlib.import_module(model_filename) model = None target_model_name = model_name.replace('_', '') + 'model' for name, cls in modellib.__dict__.items(): if name.lower() == target_model_name.lower() \ and issubclass(cls, BaseModel): model = cls if model is None: print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name)) exit(0) return model def get_option_setter(model_name): """Return the static method of the model class.""" model_class = find_model_using_name(model_name) return model_class.modify_commandline_options def create_model(opt): """Create a model given the option. This function warps the class CustomDatasetDataLoader. This is the main interface between this package and 'train.py'/'test.py' Example: >>> from models import create_model >>> model = create_model(opt) """ model = find_model_using_name(opt.model) instance = model(opt) print("model [%s] was created" % type(instance).__name__) return instance ================================================ FILE: third_party/models/base_model.py ================================================ import os import torch from collections import OrderedDict from abc import ABC, abstractmethod from . import networks class BaseModel(ABC): """This class is an abstract base class (ABC) for models. To create a subclass, you need to implement the following five functions: -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt). -- : unpack data from dataset and apply preprocessing. -- : produce intermediate results. -- : calculate losses, gradients, and update network weights. -- : (optionally) add model-specific options and set default options. """ def __init__(self, opt): """Initialize the BaseModel class. Parameters: opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions When creating your custom class, you need to implement your own initialization. In this function, you should first call Then, you need to define four lists: -- self.loss_names (str list): specify the training losses that you want to plot and save. -- self.model_names (str list): define networks used in our training. -- self.visual_names (str list): specify the images that you want to display and save. -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example. 
""" self.opt = opt self.gpu_ids = opt.gpu_ids self.isTrain = opt.isTrain self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu') # get device name: CPU or GPU self.save_dir = os.path.join(opt.checkpoints_dir, opt.name) # save all the checkpoints to save_dir self.loss_names = [] self.model_names = [] self.visual_names = [] self.optimizers = [] self.image_paths = [] self.metric = 0 # used for learning rate policy 'plateau' @staticmethod def modify_commandline_options(parser, is_train): """Add new model-specific options, and rewrite default values for existing options. Parameters: parser -- original option parser is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options. Returns: the modified parser. """ return parser @abstractmethod def set_input(self, input): """Unpack input data from the dataloader and perform necessary pre-processing steps. Parameters: input (dict): includes the data itself and its metadata information. """ pass @abstractmethod def forward(self): """Run forward pass; called by both functions and .""" pass @abstractmethod def optimize_parameters(self): """Calculate losses, gradients, and update network weights; called in every training iteration""" pass def setup(self, opt): """Load and print networks; create schedulers Parameters: opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions """ if self.isTrain: self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers] if not self.isTrain or opt.continue_train: load_suffix = opt.epoch self.load_networks(load_suffix) self.print_networks(opt.verbose) def eval(self): """Make models eval mode during test time""" for name in self.model_names: if isinstance(name, str): net = getattr(self, 'net' + name) net.eval() def test(self): """Forward function used in test time. This function wraps function in no_grad() so we don't save intermediate steps for backprop It also calls to produce additional visualization results """ with torch.no_grad(): self.forward() self.compute_visuals() def compute_visuals(self): """Calculate additional output images for visdom and HTML visualization""" pass def get_image_paths(self): """ Return image paths that are used to load current data""" return self.image_paths def update_learning_rate(self): """Update learning rates for all the networks; called at the end of every epoch""" old_lr = self.optimizers[0].param_groups[0]['lr'] for scheduler in self.schedulers: if self.opt.lr_policy == 'plateau': scheduler.step(self.metric) else: scheduler.step() lr = self.optimizers[0].param_groups[0]['lr'] if old_lr != lr: print('learning rate %.7f -> %.7f' % (old_lr, lr)) def get_current_visuals(self): """Return visualization images. train.py will display these images with visdom, and save the images to a HTML""" visual_ret = OrderedDict() for name in self.visual_names: if isinstance(name, str): visual_ret[name] = getattr(self, name) return visual_ret def get_current_losses(self): """Return traning losses / errors. train.py will print out these errors on console, and save them to a file""" errors_ret = OrderedDict() for name in self.loss_names: if isinstance(name, str): errors_ret[name] = float(getattr(self, 'loss_' + name)) # float(...) works for both scalar tensor and float number return errors_ret def save_networks(self, epoch): """Save all the networks to the disk. 
Parameters: epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name) """ for name in self.model_names: if isinstance(name, str): save_filename = '%s_net_%s.pth' % (epoch, name) save_path = os.path.join(self.save_dir, save_filename) net = getattr(self, 'net' + name) if len(self.gpu_ids) > 0 and torch.cuda.is_available(): torch.save(net.module.cpu().state_dict(), save_path) net.cuda(self.gpu_ids[0]) else: torch.save(net.cpu().state_dict(), save_path) def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0): """Fix InstanceNorm checkpoints incompatibility (prior to 0.4)""" key = keys[i] if i + 1 == len(keys): # at the end, pointing to a parameter/buffer if module.__class__.__name__.startswith('InstanceNorm') and \ (key == 'running_mean' or key == 'running_var'): if getattr(module, key) is None: state_dict.pop('.'.join(keys)) if module.__class__.__name__.startswith('InstanceNorm') and \ (key == 'num_batches_tracked'): state_dict.pop('.'.join(keys)) else: self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1) def load_networks(self, epoch): """Load all the networks from the disk. Parameters: epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name) """ for name in self.model_names: if isinstance(name, str): load_filename = '%s_net_%s.pth' % (epoch, name) load_path = os.path.join(self.save_dir, load_filename) net = getattr(self, 'net' + name) if isinstance(net, torch.nn.DataParallel): net = net.module print('loading the model from %s' % load_path) # if you are using PyTorch newer than 0.4 (e.g., built from # GitHub source), you can remove str() on self.device state_dict = torch.load(load_path, map_location=str(self.device)) if hasattr(state_dict, '_metadata'): del state_dict._metadata # patch InstanceNorm checkpoints prior to 0.4 for key in list(state_dict.keys()): # need to copy keys here because we mutate in loop self.__patch_instance_norm_state_dict(state_dict, net, key.split('.')) net.load_state_dict(state_dict) def print_networks(self, verbose): """Print the total number of parameters in the network and (if verbose) network architecture Parameters: verbose (bool) -- if verbose: print the network architecture """ print('---------- Networks initialized -------------') for name in self.model_names: if isinstance(name, str): net = getattr(self, 'net' + name) num_params = 0 num_trainable_params = 0 for param in net.parameters(): num_params += param.numel() if param.requires_grad: num_trainable_params += param.numel() if verbose: print(net) print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6)) print('[Network %s] Total number of trainable parameters : %.3f M' % (name, num_trainable_params / 1e6)) print('-----------------------------------------------') def set_requires_grad(self, nets, requires_grad=False): """Set requies_grad=Fasle for all the networks to avoid unnecessary computations Parameters: nets (network list) -- a list of networks requires_grad (bool) -- whether the networks require gradients or not """ if not isinstance(nets, list): nets = [nets] for net in nets: if net is not None: for param in net.parameters(): param.requires_grad = requires_grad ================================================ FILE: third_party/models/networks.py ================================================ import torch import torch.nn as nn from torch.optim import lr_scheduler ############################################################################### # Helper Functions 
############################################################################### def get_scheduler(optimizer, opt): """Return a learning rate scheduler Parameters: optimizer -- the optimizer of the network opt (option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions.  opt.lr_policy is the name of learning rate policy: linear | step | plateau | cosine For 'linear', we keep the same learning rate for the first epochs and linearly decay the rate to zero over the next epochs. For other schedulers (step, plateau, and cosine), we use the default PyTorch schedulers. See https://pytorch.org/docs/stable/optim.html for more details. """ if opt.lr_policy == 'linear': def lambda_rule(epoch): lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.n_epochs) / float(opt.n_epochs_decay + 1) return lr_l scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule) elif opt.lr_policy == 'step': scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1) elif opt.lr_policy == 'plateau': scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5) elif opt.lr_policy == 'cosine': scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.n_epochs, eta_min=0) else: return NotImplementedError('learning rate policy [%s] is not implemented', opt.lr_policy) return scheduler def init_net(net, gpu_ids=[]): """Initialize a network by registering CPU/GPU device (with multi-GPU support) Parameters: net (network) -- the network to be initialized gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2 Return an initialized network. """ if len(gpu_ids) > 0: assert (torch.cuda.is_available()) net.to(gpu_ids[0]) net = torch.nn.DataParallel(net, gpu_ids) # multi-GPUs return net ================================================ FILE: third_party/util/__init__.py ================================================ """This package includes a miscellaneous collection of useful helper functions.""" ================================================ FILE: third_party/util/html.py ================================================ import dominate from dominate.tags import meta, h3, table, tr, td, p, a, img, br, video, source import os class HTML: """This HTML class allows us to save images and write texts into a single HTML file. It consists of functions such as (add a text header to the HTML file), (add a row of images to the HTML file), and (save the HTML to the disk). It is based on Python library 'dominate', a Python library for creating and manipulating HTML documents using a DOM API. """ def __init__(self, web_dir, title, refresh=0): """Initialize the HTML classes Parameters: web_dir (str) -- a directory that stores the webpage. 
================================================ FILE: third_party/util/__init__.py ================================================
"""This package includes a miscellaneous collection of useful helper functions."""


================================================ FILE: third_party/util/html.py ================================================
import dominate
from dominate.tags import meta, h3, table, tr, td, p, a, img, br, video, source
import os


class HTML:
    """This HTML class allows us to save images and write texts into a single HTML file.

    It consists of functions such as <add_header> (add a text header to the HTML file),
    <add_images> (add a row of images to the HTML file), and <save> (save the HTML to the disk).
    It is based on Python library 'dominate', a Python library for creating and manipulating
    HTML documents using a DOM API.
    """

    def __init__(self, web_dir, title, refresh=0):
        """Initialize the HTML classes

        Parameters:
            web_dir (str) -- a directory that stores the webpage. HTML file will be created at <web_dir>/index.html; images will be saved at <web_dir>/images/
            title (str)   -- the webpage name
            refresh (int) -- how often the website refreshes itself; if 0, no refreshing
        """
        self.title = title
        self.web_dir = web_dir
        self.img_dir = os.path.join(self.web_dir, 'images')
        self.vid_dir = os.path.join(self.web_dir, 'videos')
        if not os.path.exists(self.web_dir):
            os.makedirs(self.web_dir)
        if not os.path.exists(self.img_dir):
            os.makedirs(self.img_dir)
        if not os.path.exists(self.vid_dir):
            os.makedirs(self.vid_dir)

        self.doc = dominate.document(title=title)
        if refresh > 0:
            with self.doc.head:
                meta(http_equiv="refresh", content=str(refresh))

    def get_image_dir(self):
        """Return the directory that stores images"""
        return self.img_dir

    def get_video_dir(self):
        """Return the directory that stores videos"""
        return self.vid_dir

    def add_header(self, text):
        """Insert a header to the HTML file

        Parameters:
            text (str) -- the header text
        """
        with self.doc:
            h3(text)

    def add_images(self, ims, txts, links, width=400):
        """add images to the HTML file

        Parameters:
            ims (str list)   -- a list of image paths
            txts (str list)  -- a list of image names shown on the website
            links (str list) -- a list of hyperref links; when you click an image, it will redirect you to a new page
        """
        self.t = table(border=1, style="table-layout: fixed;")  # Insert a table
        self.doc.add(self.t)
        with self.t:
            with tr():
                for im, txt, link in zip(ims, txts, links):
                    with td(style="word-wrap: break-word;", halign="center", valign="top"):
                        with p():
                            with a(href=os.path.join('images', link)):
                                img(style="width:%dpx" % width, src=os.path.join('images', im))
                            br()
                            p(txt)

    def add_videos(self, vids, txts, links, width=400):
        """add videos to the HTML file

        Parameters:
            vids (str list)  -- a list of video paths
            txts (str list)  -- a list of video names shown on the website
            links (str list) -- a list of hyperref links; when you click a video, it will redirect you to a new page
        """
        self.t = table(border=1, style="table-layout: fixed;")  # Insert a table
        self.doc.add(self.t)
        with self.t:
            with tr():
                for vid, txt, link in zip(vids, txts, links):
                    with td(style="word-wrap: break-word;", halign="center", valign="top"):
                        with p():
                            with a(href=os.path.join('videos', link)):
                                with video(style="width:%dpx" % width, controls=True):
                                    source(src=os.path.join('videos', vid), type="video/mp4")
                            br()
                            p(txt)

    def save(self):
        """save the current content to the HTML file"""
        html_file = '%s/index.html' % self.web_dir
        f = open(html_file, 'wt')
        f.write(self.doc.render())
        f.close()


if __name__ == '__main__':  # we show an example usage here.
    html = HTML('web/', 'test_html')
    html.add_header('hello world')

    ims, txts, links = [], [], []
    for n in range(4):
        ims.append('image_%d.png' % n)
        txts.append('text_%d' % n)
        links.append('image_%d.png' % n)
    html.add_images(ims, txts, links)
    html.save()
================================================ FILE: third_party/util/util.py ================================================
"""This module contains simple helper functions """
from __future__ import print_function
import torch
import numpy as np
from PIL import Image
import os


def tensor2im(input_image, imtype=np.uint8):
    """Converts a Tensor array into a numpy image array.

    Parameters:
        input_image (tensor) -- the input image tensor array
        imtype (type)        -- the desired type of the converted numpy array
    """
    if not isinstance(input_image, np.ndarray):
        if isinstance(input_image, torch.Tensor):  # get the data from a variable
            image_tensor = input_image.data
        else:
            return input_image
        image_numpy = image_tensor[0].cpu().float().numpy()  # convert it into a numpy array
        if image_numpy.shape[0] == 1:  # grayscale to RGB
            image_numpy = np.tile(image_numpy, (3, 1, 1))
        image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0  # post-processing: transpose and scaling
    else:  # if it is a numpy array, do nothing
        image_numpy = input_image
    return image_numpy.astype(imtype)


def render_png(image, background='checker'):
    """Composite an RGBA image over a checkerboard, black, or white background."""
    height, width = image.shape[:2]
    if background == 'checker':
        checkerboard = np.kron([[136, 120] * (width // 128 + 1), [120, 136] * (width // 128 + 1)] * (height // 128 + 1), np.ones((16, 16)))
        checkerboard = np.expand_dims(np.tile(checkerboard, (4, 4)), -1)
        bg = checkerboard[:height, :width]
    elif background == 'black':
        bg = np.zeros([height, width, 1])
    else:
        bg = 255 * np.ones([height, width, 1])
    image = image.astype(np.float32)
    alpha = image[:, :, 3:] / 255
    rendered_image = alpha * image[:, :, :3] + (1 - alpha) * bg
    return rendered_image.astype(np.uint8)


def diagnose_network(net, name='network'):
    """Calculate and print the mean of average absolute(gradients)

    Parameters:
        net (torch network) -- Torch network
        name (str)          -- the name of the network
    """
    mean = 0.0
    count = 0
    for param in net.parameters():
        if param.grad is not None:
            mean += torch.mean(torch.abs(param.grad.data))
            count += 1
    if count > 0:
        mean = mean / count
    print(name)
    print(mean)


def save_image(image_numpy, image_path, aspect_ratio=1.0):
    """Save a numpy image to the disk

    Parameters:
        image_numpy (numpy array) -- input numpy array
        image_path (str)          -- the path of the image
        aspect_ratio (float)      -- the aspect ratio of the saved image
    """
    image_pil = Image.fromarray(image_numpy)
    h, w, _ = image_numpy.shape

    if aspect_ratio > 1.0:
        image_pil = image_pil.resize((h, int(w * aspect_ratio)), Image.BICUBIC)
    if aspect_ratio < 1.0:
        image_pil = image_pil.resize((int(h / aspect_ratio), w), Image.BICUBIC)
    image_pil.save(image_path)


def print_numpy(x, val=True, shp=False):
    """Print the mean, min, max, median, std, and size of a numpy array

    Parameters:
        val (bool) -- if print the values of the numpy array
        shp (bool) -- if print the shape of the numpy array
    """
    x = x.astype(np.float64)
    if shp:
        print('shape,', x.shape)
    if val:
        x = x.flatten()
        print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % (
            np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x)))


def mkdirs(paths):
    """create empty directories if they don't exist

    Parameters:
        paths (str list) -- a list of directory paths
    """
    if isinstance(paths, list) and not isinstance(paths, str):
        for path in paths:
            mkdir(path)
    else:
        mkdir(paths)


def mkdir(path):
    """create a single empty directory if it didn't exist

    Parameters:
        path (str) -- a single directory path
    """
    if not os.path.exists(path):
        os.makedirs(path)
================================================ FILE: third_party/util/visualizer.py ================================================
import cv2
import numpy as np
import os
import sys
import ntpath
import time
from . import util, html
from subprocess import Popen, PIPE

if sys.version_info[0] == 2:
    VisdomExceptionBase = Exception
else:
    VisdomExceptionBase = ConnectionError


def save_images(webpage, visuals, image_path, aspect_ratio=1.0, width=256):
    """Save images to the disk.

    Parameters:
        webpage (the HTML class) -- the HTML webpage class that stores these images (see html.py for more details)
        visuals (OrderedDict)    -- an ordered dictionary that stores (name, images (either tensor or numpy)) pairs
        image_path (str)         -- the string is used to create image paths
        aspect_ratio (float)     -- the aspect ratio of saved images
        width (int)              -- the images will be resized to width x width

    This function will save images stored in 'visuals' to the HTML file specified by 'webpage'.
    """
    image_dir = webpage.get_image_dir()
    short_path = ntpath.basename(image_path[0])
    name = os.path.splitext(short_path)[0]

    webpage.add_header(name)
    ims, txts, links = [], [], []

    for label, im_data in visuals.items():
        im = util.tensor2im(im_data)
        image_name = '%s_%s.png' % (name, label)
        save_path = os.path.join(image_dir, image_name)
        util.save_image(im, save_path, aspect_ratio=aspect_ratio)
        ims.append(image_name)
        txts.append(label)
        links.append(image_name)
    webpage.add_images(ims, txts, links, width=width)


def save_videos(webpage, visuals, width=256):
    """Save videos to the disk.

    Parameters:
        webpage (the HTML class) -- the HTML webpage class that stores these videos (see html.py for more details)
        visuals (OrderedDict)    -- an ordered dictionary that stores (name, video (either tensor or numpy)) pairs
        width (int)              -- the videos will be displayed at this width on the webpage

    This function will save videos stored in 'visuals' to the HTML file specified by 'webpage'.
    """
    video_dir = webpage.get_video_dir()
    webpage.add_header('videos')
    vids, txts, links = [], [], []
    for label, vid_data in sorted(visuals.items()):
        video_name = f'{label}.webm'
        video_path = os.path.join(video_dir, video_name)
        frame_height, frame_width = vid_data.shape[-2:]
        video = cv2.VideoWriter(video_path, cv2.VideoWriter_fourcc(*'vp80'), 25, (frame_width, frame_height))
        for i in range(vid_data.shape[0]):
            frame = util.tensor2im(vid_data[i:i + 1])
            if frame.shape[-1] == 4:  # render png
                frame = util.render_png(frame, background='checker')
            frame = frame[:, :, ::-1]  # RGB -> BGR
            video.write(frame)
        video.release()
        cv2.destroyAllWindows()
        print("You may see an OpenCV 'vp80 not supported' error message despite the video saving correctly. Please ignore it.")
        vids.append(video_name)
        txts.append(label)
        links.append(video_name)
    webpage.add_videos(vids, txts, links, width=width)
""" def __init__(self, opt): """Initialize the Visualizer class Parameters: opt -- stores all the experiment flags; needs to be a subclass of BaseOptions Step 1: Cache the training/test options Step 2: connect to a visdom server Step 3: create an HTML object for saveing HTML filters Step 4: create a logging file to store training losses """ self.opt = opt # cache the option self.display_id = opt.display_id self.use_html = opt.isTrain and not opt.no_html self.win_size = opt.display_winsize self.name = opt.name self.port = opt.display_port self.saved = False if self.display_id > 0: # connect to a visdom server given and import visdom self.ncols = opt.display_ncols self.vis = visdom.Visdom(server=opt.display_server, port=opt.display_port, env=opt.display_env) if not self.vis.check_connection(): self.create_visdom_connections() if self.use_html: # create an HTML object at /web/; images will be saved under /web/images/ self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web') self.img_dir = os.path.join(self.web_dir, 'images') print('create web directory %s...' % self.web_dir) util.mkdirs([self.web_dir, self.img_dir]) # create a logging file to store training losses self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt') with open(self.log_name, "a") as log_file: now = time.strftime("%c") log_file.write('================ Training Loss (%s) ================\n' % now) def reset(self): """Reset the self.saved status""" self.saved = False def create_visdom_connections(self): """If the program could not connect to Visdom server, this function will start a new server at port < self.port > """ cmd = sys.executable + ' -m visdom.server -p %d &>/dev/null &' % self.port print('\n\nCould not connect to Visdom server. \n Trying to start a server....') print('Command: %s' % cmd) Popen(cmd, shell=True, stdout=PIPE, stderr=PIPE) def display_current_results(self, visuals, epoch, save_result): """Display current results on visdom; save current results to an HTML file. Parameters: visuals (OrderedDict) - - dictionary of images to display or save epoch (int) - - the current epoch save_result (bool) - - if save the current results to an HTML file """ if self.display_id > 0: # show images in the browser using visdom ncols = self.ncols if ncols > 0: # show all the images in one visdom panel ncols = min(ncols, len(visuals)) h, w = next(iter(visuals.values())).shape[:2] table_css = """""" % (w, h) # create a table css # create a table of images. title = self.name label_html = '' label_html_row = '' images = [] idx = 0 for label, image in visuals.items(): image_numpy = util.tensor2im(image) label_html_row += '%s' % label images.append(image_numpy.transpose([2, 0, 1])) idx += 1 if idx % ncols == 0: label_html += '%s' % label_html_row label_html_row = '' white_image = np.ones_like(image_numpy.transpose([2, 0, 1])) * 255 while idx % ncols != 0: images.append(white_image) label_html_row += '' idx += 1 if label_html_row != '': label_html += '%s' % label_html_row try: self.vis.images(images, nrow=ncols, win=self.display_id + 1, padding=2, opts=dict(title=title + ' images')) label_html = '%s
    def display_current_results(self, visuals, epoch, save_result):
        """Display current results on visdom; save current results to an HTML file.

        Parameters:
            visuals (OrderedDict) - - dictionary of images to display or save
            epoch (int) - - the current epoch
            save_result (bool) - - if save the current results to an HTML file
        """
        if self.display_id > 0:  # show images in the browser using visdom
            ncols = self.ncols
            if ncols > 0:  # show all the images in one visdom panel
                ncols = min(ncols, len(visuals))
                h, w = next(iter(visuals.values())).shape[:2]
                table_css = """<style>
                        table {border-collapse: separate; border-spacing: 4px; white-space: nowrap; text-align: center}
                        table td {width: %dpx; height: %dpx; padding: 4px; outline: 4px solid black}
                        </style>""" % (w, h)  # create a table css
                # create a table of images.
                title = self.name
                label_html = ''
                label_html_row = ''
                images = []
                idx = 0
                for label, image in visuals.items():
                    image_numpy = util.tensor2im(image)
                    label_html_row += '<td>%s</td>' % label
                    images.append(image_numpy.transpose([2, 0, 1]))
                    idx += 1
                    if idx % ncols == 0:
                        label_html += '<tr>%s</tr>' % label_html_row
                        label_html_row = ''
                white_image = np.ones_like(image_numpy.transpose([2, 0, 1])) * 255
                while idx % ncols != 0:
                    images.append(white_image)
                    label_html_row += '<td></td>'
                    idx += 1
                if label_html_row != '':
                    label_html += '<tr>%s</tr>' % label_html_row
                try:
                    self.vis.images(images, nrow=ncols, win=self.display_id + 1,
                                    padding=2, opts=dict(title=title + ' images'))
                    label_html = '<table>%s</table>' % label_html
                    self.vis.text(table_css + label_html, win=self.display_id + 2,
                                  opts=dict(title=title + ' labels'))
                except VisdomExceptionBase:
                    self.create_visdom_connections()
            else:  # show each image in a separate visdom panel;
                idx = 1
                try:
                    for label, image in visuals.items():
                        image_numpy = util.tensor2im(image)
                        self.vis.image(image_numpy.transpose([2, 0, 1]), opts=dict(title=label),
                                       win=self.display_id + idx)
                        idx += 1
                except VisdomExceptionBase:
                    self.create_visdom_connections()

        if self.use_html and (save_result or not self.saved):  # save images to an HTML file if they haven't been saved.
            self.saved = True
            # save images to the disk
            for label, image in visuals.items():
                image_numpy = util.tensor2im(image)
                img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.png' % (epoch, label))
                util.save_image(image_numpy, img_path)

            # update website
            webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, refresh=1)
            for n in range(epoch, 0, -1):
                label = list(visuals.keys())[0]
                img_path = 'epoch%.3d_%s.png' % (n, label)
                if not os.path.exists(os.path.join(webpage.img_dir, img_path)):
                    continue
                webpage.add_header('epoch [%d]' % n)
                ims, txts, links = [], [], []
                for label, image_numpy in visuals.items():
                    img_path = 'epoch%.3d_%s.png' % (n, label)
                    ims.append(img_path)
                    txts.append(label)
                    links.append(img_path)
                webpage.add_images(ims, txts, links, width=self.win_size)
            webpage.save()

    def plot_current_losses(self, epoch, counter_ratio, losses):
        """display the current losses on visdom display: dictionary of error labels and values

        Parameters:
            epoch (int)           -- current epoch
            counter_ratio (float) -- progress (percentage) in the current epoch, between 0 to 1
            losses (OrderedDict)  -- training losses stored in the format of (name, float) pairs
        """
        if not hasattr(self, 'plot_data'):
            self.plot_data = {'X': [], 'Y': [], 'legend': list(losses.keys())}
        self.plot_data['X'].append(epoch + counter_ratio)
        self.plot_data['Y'].append([losses[k] for k in self.plot_data['legend']])
        try:
            self.vis.line(
                X=np.stack([np.array(self.plot_data['X'])] * len(self.plot_data['legend']), 1),
                Y=np.array(self.plot_data['Y']),
                opts={
                    'title': self.name + ' loss over time',
                    'legend': self.plot_data['legend'],
                    'xlabel': 'epoch',
                    'ylabel': 'loss'},
                win=self.display_id)
        except VisdomExceptionBase:
            self.create_visdom_connections()

    # losses: same format as |losses| of plot_current_losses
    def print_current_losses(self, epoch, iters, losses, t_comp, t_data):
        """print current losses on console; also save the losses to the disk

        Parameters:
            epoch (int) -- current epoch
            iters (int) -- current training iteration during this epoch (reset to 0 at the end of every epoch)
            losses (OrderedDict) -- training losses stored in the format of (name, float) pairs
            t_comp (float) -- computational time per data point (normalized by batch_size)
            t_data (float) -- data loading time per data point (normalized by batch_size)
        """
        message = '(epoch: %d, iters: %d, time: %.3f, data: %.3f) ' % (epoch, iters, t_comp, t_data)
        for k, v in losses.items():
            message += '%s: %.3f ' % (k, v)

        print(message)  # print the message
        with open(self.log_name, "a") as log_file:
            log_file.write('%s\n' % message)  # save the message

================================================ FILE: train.py ================================================
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Script for training a layered neural renderer on a video. You need to specify the dataset ('--dataroot') and experiment name ('--name'). Example: python train.py --dataroot ./datasets/reflection --name reflection --gpu_ids 0,1 The script first creates a model, dataset, and visualizer given the options. It then does standard network training. During training, it also visualizes/saves the images, prints/saves the loss plot, and saves the model. Use '--continue_train' to resume your previous training. The default setting is to first train the base model, which produces the low-resolution result (256x448), and then train the upsampling module to produce the 512x896 result. If the upsampling module is unnecessary, use '--n_epochs_upsample 0'. See options/base_options.py and options/train_options.py for more training options. """ import time from options.train_options import TrainOptions from third_party.data import create_dataset from third_party.models import create_model from third_party.util.visualizer import Visualizer import torch import numpy as np def main(): trainopt = TrainOptions() trainopt.parse() opt = trainopt.parse_dataset_meta() torch.manual_seed(opt.seed) np.random.seed(opt.seed) opt.do_upsampling = False # Train low-res network first dataset = create_dataset(opt, use_fast_loader=True) dataset_size = len(dataset) print('The number of training images = %d' % dataset_size) model = create_model(opt) model.setup(opt) # regular setup: load and print networks; create schedulers visualizer = Visualizer(opt) # Train base model (produces low-resolution output) train(model, dataset, visualizer, opt) # Optionally train upsampling module if opt.n_epochs_upsample > 0: opt.do_upsampling = True opt.batch_size = opt.batch_size_upsample # load dataset for upsampling dataset = create_dataset(opt, use_fast_loader=True) dataset_size = len(dataset) print('The number of training images = %d' % dataset_size) # set lambdas for upsampling training opt.lambda_mask = 0 opt.lambda_alpha_l0 = 0 opt.lambda_alpha_l1 = 0 opt.mask_loss_rolloff_epoch = -1 opt.jitter_rgb = 0 # reinit optimizers and schedulers, lambdas model.setup_train(opt) # freeze base model and just train upsampling module model.freeze_basenet() model.setup(opt) # update epoch count to resume training opt.epoch_count = opt.n_epochs + opt.n_epochs_decay + 1 opt.n_epochs += opt.n_epochs_upsample train(model, dataset, visualizer, opt) def train(model, dataset, visualizer, opt): dataset_size = len(dataset) total_iters = 0 # the total number of training iterations for epoch in range(opt.epoch_count, opt.n_epochs + opt.n_epochs_decay + 1): # outer loop for different epochs; we save the model by , + epoch_start_time = time.time() # timer for entire epoch iter_data_time = time.time() # timer for data loading per iteration epoch_iter = 0 # the number of training iterations in current epoch, reset to 0 every epoch model.update_lambdas(epoch) for i, data in enumerate(dataset): # inner loop within one epoch iter_start_time = time.time() # timer for computation per iteration if i % opt.print_freq == 0: t_data = iter_start_time - iter_data_time 
            total_iters += opt.batch_size
            epoch_iter += opt.batch_size
            model.set_input(data)
            model.optimize_parameters()

            if i % opt.print_freq == 0:  # print training losses and save logging information to the disk
                losses = model.get_current_losses()
                t_comp = (time.time() - iter_start_time) / opt.batch_size
                visualizer.print_current_losses(epoch, epoch_iter, losses, t_comp, t_data)
                if opt.display_id > 0:
                    visualizer.plot_current_losses(epoch, float(epoch_iter) / dataset_size, losses)

            iter_data_time = time.time()

        if epoch % opt.display_freq == 1:  # display images on visdom and save images to an HTML file
            save_result = epoch % opt.update_html_freq == 1
            model.compute_visuals()
            visualizer.display_current_results(model.get_current_visuals(), epoch, save_result)

        if epoch % opt.save_latest_freq == 0:  # cache our latest model every <save_latest_freq> epochs
            print('saving the latest model (epoch %d, total_iters %d)' % (epoch, total_iters))
            save_suffix = 'epoch_%d' % epoch if opt.save_by_epoch else 'latest'
            model.save_networks(save_suffix)

        model.update_learning_rate()  # update learning rates at the end of every epoch.
        print('End of epoch %d / %d \t Time Taken: %d sec' % (epoch, opt.n_epochs + opt.n_epochs_decay, time.time() - epoch_start_time))


if __name__ == '__main__':
    main()