[
  {
    "path": "LICENSE",
    "content": "\n                                 Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative Works\" shall mean any work, whether in Source or Object\n      form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. 
For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. 
Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n\n   APPENDIX: How to apply the Apache License to your work.\n\n      To apply the Apache License to your work, attach the following\n      boilerplate notice, with the fields enclosed by brackets \"[]\"\n      replaced with your own identifying information. (Don't include\n      the brackets!)  The text should be enclosed in the appropriate\n      comment syntax for the file format. We also recommend that a\n      file or class name and description of purpose be included on the\n      same \"printed page\" as the copyright notice for easier\n      identification within third-party archives.\n\n   Copyright [yyyy] [name of copyright owner]\n\n   Licensed under the Apache License, Version 2.0 (the \"License\");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       http://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License.\n"
  },
  {
    "path": "README.md",
    "content": "# Layered Neural Rendering in PyTorch\n\nThis repository contains training code for the examples in the SIGGRAPH Asia 2020 paper \"[Layered Neural Rendering for Retiming People in Video](https://retiming.github.io/).\"\n\n<img src='./img/teaser.gif' height=\"160px\"/>\n\nThis is not an officially supported Google product.\n\n\n## Prerequisites\n- Linux\n- Python 3.6+\n- NVIDIA GPU + CUDA CuDNN\n\n## Installation\nThis code has been tested with PyTorch 1.4 and Python 3.8.\n\n- Install [PyTorch](http://pytorch.org) 1.4 and other dependencies.\n  - For pip users, please type the command `pip install -r requirements.txt`.\n  - For Conda users, you can create a new Conda environment using `conda env create -f environment.yml`.\n\n## Data Processing\n- Download the data for a video used in our paper (e.g. \"reflection\"):\n```bash\nbash ./datasets/download_data.sh reflection\n```\n- Or alternatively, download all the data by specifying `all`.\n- Download the pretrained keypoint-to-UV model weights:\n```bash\nbash ./scripts/download_kp2uv_model.sh\n``` \nThe pretrained model will be saved at `./checkpoints/kp2uv/latest_net_Kp2uv.pth`.\n- Generate the UV maps from the keypoints:\n```bash\nbash datasets/prepare_iuv.sh ./datasets/reflection\n```\n## Training\n- To train a model on a video (e.g. \"reflection\"), run:\n```bash\npython train.py --name reflection --dataroot ./datasets/reflection --gpu_ids 0,1\n```\n- To view training results and loss plots, visit the URL http://localhost:8097.\nIntermediate results are also at `./checkpoints/reflection/web/index.html`.\n\nYou can find more scripts in the `scripts` directory, e.g. `run_${VIDEO}.sh` which combines data processing, training, and saving layer results for a video. \n\n**Note**:\n- It is recommended to use >=2 GPUs, each with >=16GB memory.\n- The training script first trains the low-resolution model for `--num_epochs` at `--batch_size`, and then trains the upsampling module for `--num_epochs_upsample` at `--batch_size_upsample`.\nIf you do not need the upsampled result, pass `--num_epochs_upsample 0`.\n- Training the upsampling module requires ~2.5x memory as the low-resolution model, so set `batch_size_upsample` accordingly.\nThe provided scripts set the batch sizes appropriately for 2 GPUs with 16GB memory.\n- GPU memory scales linearly with the number of layers.\n\n## Saving layer results from a trained model\n- Run the trained model:\n```bash\npython test.py --name reflection --dataroot ./datasets/reflection --do_upsampling\n```\n- The results (RGBA layers, videos) will be saved to `./results/reflection/test_latest/`.\n- Passing `--do_upsampling` uses the results of the upsampling module. If the upsampling module hasn't been trained (`num_epochs_upsample=0`), then remove this flag.\n\n## Custom video\nTo train on your own video, you will have to preprocess the data:\n1. Extract the frames, e.g.\n    ```\n    mkdir ./datasets/my_video && cd ./datasets/my_video \n    mkdir rgb && ffmpeg -i video.mp4 rgb/%04d.png\n    ```\n1. Resize the video to 256x448 and save the frames in `my_video/rgb_256`, and resize the video to 512x896 and save in `my_video/rgb_512`.\n1. Run [AlphaPose and Pose Tracking](https://github.com/MVIG-SJTU/AlphaPose) on the frames. Save results as `my_video/keypoints.json`\n1. Create `my_video/metadata.json` following [these instructions](docs/data.md).\n1. 
If your video has camera motion, either (1) stabilize the video, or (2) maintain the camera motion by computing homographies and saving as `my_video/homographies.txt`.\nSee `scripts/run_cartwheel.sh` for a training example with camera motion, and see `./datasets/cartwheel/homographies.txt` for formatting.\n\n**Note**: Videos that are suitable for our method have the following attributes:\n- Static camera or limited camera motion that can be represented with a homography.\n- Limited number of people, due to GPU memory limitations. We tested up to 7 people and 7 layers.\nMultiple people can be grouped onto the same layer, though they cannot be individually retimed.\n- People that move relative to the background (static people will be absorbed into the background layer).\n- We tested a video length of up to 200 frames (~7 seconds).\n\n## Citation\nIf you use this code for your research, please cite the following paper:\n```\n@inproceedings{lu2020,\n  title={Layered Neural Rendering for Retiming People in Video},\n  author={Lu, Erika and Cole, Forrester and Dekel, Tali and Xie, Weidi and Zisserman, Andrew and Salesin, David and Freeman, William T and Rubinstein, Michael},\n  booktitle={SIGGRAPH Asia},\n  year={2020}\n}\n```\n\n## Acknowledgments\nThis code is based on [pytorch-CycleGAN-and-pix2pix](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix).\n"
  },
  {
    "path": "data/__init__.py",
    "content": ""
  },
  {
    "path": "data/kpuv_dataset.py",
    "content": "# Copyright 2020 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom third_party.data.base_dataset import BaseDataset\nfrom PIL import Image, ImageDraw\nimport json\nimport matplotlib\nmatplotlib.use('Agg')\nimport matplotlib.pyplot as plt\nimport numpy as np\nimport os\nimport torchvision.transforms as transforms\n\n\nclass KpuvDataset(BaseDataset):\n    \"\"\"A dataset class for keypoint data.\n\n    It assumes that the directory specified by 'dataroot' contains the file 'keypoints.json'.\n    \"\"\"\n    @staticmethod\n    def modify_commandline_options(parser, is_train):\n        parser.add_argument('--inp_size', type=int, default=256, help='image size')\n        return parser\n\n    def __init__(self, opt):\n        \"\"\"Initialize this dataset class by reading in keypoints.\n\n        Parameters:\n            opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions\n        \"\"\"\n        BaseDataset.__init__(self, opt)\n\n        self.inp_size = opt.inp_size\n        inner_crop_size = int(.75*self.inp_size)\n        kps = []\n        image_paths = []\n        with open(os.path.join(self.root, 'keypoints.json'), 'rb') as f:\n            kp_data = json.load(f)\n        for frame in sorted(kp_data):\n            for skeleton in kp_data[frame]:\n                id = skeleton['idx']\n                image_paths.append(f'{id:02d}_{frame}')\n                kp = np.array(skeleton['keypoints']).reshape(17, 3)\n                kp = self.crop_kps(kp, crop_size=self.inp_size, inner_crop_size=inner_crop_size)\n                kps.append(kp)\n\n        self.keypoints = kps\n        self.image_paths = image_paths  # filenames for output UVs\n\n        # for keypoint rendering\n        self.cmap = plt.cm.get_cmap(\"hsv\", 17)\n        self.color_seq = np.array([ 9, 14,  6,  7, 13, 16,  2, 11,  3,  5, 10, 15,  1,  8,  0, 12,  4])\n        self.pairs = [[0,1],[0,2],[1,3],[2,4],[5,6],[5,7],[7,9],[6,8],[8,10],[11,12],[11,13],[13,15],[12,14],[14,16],[6,12],[5,11]]\n\n\n    def __getitem__(self, index):\n        \"\"\"Return a data point and its metadata information.\n\n        Parameters:\n            index - - a random integer for data indexing\n\n        Returns a dictionary that contains keypoints and path\n            keyoints (tensor) - - an RGB image representing a skeleton\n            path (str) - - an identifying filename that can be used for saving the result\n        \"\"\"\n        uv_path = self.image_paths[index]  # output UV path\n        kps = self.keypoints[index]\n\n        # draw keypoints\n        kp_im = Image.new(size=(self.inp_size, self.inp_size), mode='RGB')\n        draw = ImageDraw.Draw(kp_im)\n        self.render_kps(kps, draw)\n        kp_im = transforms.ToTensor()(kp_im)\n        kp_im = 2 * kp_im - 1\n\n        return {'keypoints': kp_im, 'path': uv_path}\n\n    def __len__(self):\n        \"\"\"Return the total number of images.\"\"\"\n        return len(self.image_paths)\n\n    def crop_kps(self, 
kps, crop_size=256, inner_crop_size=192):\n        \"\"\"'Crops' keypoints to fit into ['crop_size', 'crop_size'].\n\n        Parameters:\n            kps - - a numpy array of shape [17, 3], where each row is X, Y, score (only X and Y are modified)\n            crop_size - - the new size of the world, which the keypoints will be centered inside\n            inner_crop_size - - the box size that the keypoints will fit inside (must be <=crop_size)\n\n        Returns keypoints mapped to fit inside a box of 'inner_crop_size', centered in 'crop_size'\n        \"\"\"\n        # get coordinates of bounding box, in original image coordinates\n        left = kps[:, 0].min()\n        right = kps[:, 0].max()\n        top = kps[:, 1].min()\n        bottom = kps[:, 1].max()\n\n        # map keypoints\n        keypoints = kps.copy()\n        center = ((right + left) // 2, (bottom + top) // 2)\n        # first place center of bounding box at origin\n        keypoints[:, 0] -= center[0]\n        keypoints[:, 1] -= center[1]\n        # scale bounding box to inner_crop_size\n        scale = float(inner_crop_size) / max(right - left, bottom - top)\n        keypoints[:, :2] *= scale\n        # move center to crop_size//2\n        keypoints[:, :2] += crop_size // 2\n        new_kps = keypoints\n\n        return new_kps\n\n    def render_kps(self, keypoints, draw, thresh=1., min_weight=0.25):\n        \"\"\"Render skeleton as RGB image.\n\n        Parameters:\n            keypoints - - a numpy array of shape [17, 3], where the keypoint order is X, Y, score\n            draw - - an ImageDraw object, which the keypoints will be drawn onto\n            thresh - - keypoints with a confidence score below this value will have a color weighted by the score\n            min_weight - - minimum weighting for color (scores will be mapped to the range [min_weight, 1])\n        \"\"\"\n        # first draw keypoints\n        ksize = 3\n        for i in range(keypoints.shape[0]):\n            x1 = keypoints[i,0] - ksize\n            x2 = keypoints[i,0] + ksize\n            y1 = keypoints[i,1] - ksize\n            y2 = keypoints[i,1] + ksize\n            if x1 < 0 or y1 < 0 or x2 > self.inp_size or y2 > self.inp_size:\n                continue\n            color = np.array(self.cmap(self.color_seq[i]))\n            if keypoints.shape[1] > 2:\n                score = keypoints[i,2]\n                if score < thresh:  # weight color by confidence score\n                    # first map [0,1] -> [min_weight, 1]\n                    alpha_weight = score * (1.-min_weight) + min_weight\n                    color[:3] *= alpha_weight\n            color = (255*color).astype('uint8')\n            draw.rectangle([x1, y1, x2, y2], fill=tuple(color))\n        # now draw segments\n        for pair in self.pairs:\n            x1 = keypoints[pair[0],0]\n            y1 = keypoints[pair[0],1]\n            x2 = keypoints[pair[1],0]\n            y2 = keypoints[pair[1],1]\n            if x1 < 0 or y1 < 0 or x2 < 0 or y2 < 0 or x1 > self.inp_size or y1 > self.inp_size or x2 > self.inp_size or y2 > self.inp_size:\n                continue\n            avg_color = .5*(np.array(self.cmap(self.color_seq[pair[0]])) + np.array(self.cmap(self.color_seq[pair[1]])))\n            if keypoints.shape[1] > 2:\n                score = min(keypoints[pair[0],2], keypoints[pair[1],2])\n                if score < thresh:\n                    alpha_weight = score * (1.-min_weight) + min_weight\n                    avg_color[:3] *= alpha_weight  # weight color by confidence score\n            avg_color = (255*avg_color).astype('uint8')\n            draw.line([x1,y1,x2,y2], fill=tuple(avg_color), width=3)\n"
  },
  {
    "path": "data/layered_video_dataset.py",
    "content": "# Copyright 2020 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport cv2\nfrom third_party.data.base_dataset import BaseDataset\nfrom third_party.data.image_folder import make_dataset\nfrom PIL import Image\nimport torchvision.transforms as transforms\nimport torch.nn.functional as F\nimport os\nimport torch\nimport numpy as np\nimport json\n\n\nclass LayeredVideoDataset(BaseDataset):\n    \"\"\"A dataset class for video layers.\n\n    It assumes that the directory specified by 'dataroot' contains metadata.json, and the directories iuv, rgb_256, and rgb_512.\n    The 'iuv' directory should contain directories named 01, 02, etc. for each layer, each containing per-frame UV images.\n    \"\"\"\n    @staticmethod\n    def modify_commandline_options(parser, is_train):\n        parser.add_argument('--height', type=int, default=256, help='image height')\n        parser.add_argument('--width', type=int, default=448, help='image width')\n        parser.add_argument('--trimap_width', type=int, default=20, help='trimap gray area width')\n        parser.add_argument('--use_mask_images', action='store_true', default=False, help='use custom masks')\n        parser.add_argument('--use_homographies', action='store_true', default=False, help='handle camera motion')\n        return parser\n\n    def __init__(self, opt):\n        \"\"\"Initialize this dataset class.\n\n        Parameters:\n            opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions\n        \"\"\"\n        BaseDataset.__init__(self, opt)\n        rgbdir = os.path.join(opt.dataroot, 'rgb_256')\n        if opt.do_upsampling:\n            rgbdir = os.path.join(opt.dataroot, 'rgb_512')\n        uvdir = os.path.join(opt.dataroot, 'iuv')\n        self.image_paths = sorted(make_dataset(rgbdir, opt.max_dataset_size))\n        n_images = len(self.image_paths)\n        layers = sorted(os.listdir(uvdir))\n        layers = [l for l in layers if l.isdigit()]\n        self.iuv_paths = []\n        for l in layers:\n            layer_iuv_paths = sorted(make_dataset(os.path.join(uvdir, l), n_images))\n            if len(layer_iuv_paths) != n_images:\n                print(f'UNEQUAL NUMBER OF IMAGES AND IUVs: {len(layer_iuv_paths)} and {n_images}')\n            self.iuv_paths.append(layer_iuv_paths)\n\n        # set up per-frame compositing order\n        with open(os.path.join(opt.dataroot, 'metadata.json')) as f:\n            metadata = json.load(f)\n        if 'composite_order' in metadata:\n            self.composite_order = metadata['composite_order']\n        else:\n            self.composite_order = [tuple(range(1, 1 + len(layers)))] * n_images\n\n        if opt.use_homographies:\n            self.init_homographies(os.path.join(opt.dataroot, 'homographies.txt'), n_images)\n\n    def __getitem__(self, index):\n        \"\"\"Return a data point and its metadata information.\n\n        Parameters:\n            index - - a random integer for data indexing\n\n        Returns 
a dictionary that contains:\n            image (tensor) - - the original RGB frame to reconstruct\n            uv_map (tensor) - - the UV maps for all layers, concatenated channel-wise\n            mask (tensor) - - the trimaps for all layers, concatenated channel-wise\n            pids (tensor) - - the person IDs for all layers, concatenated channel-wise\n            image_path (str) - - image path\n        \"\"\"\n        # Read the target image.\n        image_path = self.image_paths[index]\n        target_image = self.load_and_process_image(image_path)\n\n        # Read the layer IUVs and convert to network inputs.\n        people_layers = [self.load_and_process_iuv(self.iuv_paths[l - 1][index], index) for l in\n                         self.composite_order[index]]\n        iuv_h, iuv_w = people_layers[0][0].shape[-2:]\n\n        # Create the background layer UV from homographies.\n        background_layer = self.get_background_inputs(index, iuv_w, iuv_h)\n\n        uv_maps, masks, pids = zip(*([background_layer] + people_layers))\n        uv_maps = torch.cat(uv_maps)  # [L*2, H, W]\n        masks = torch.stack(masks)  # [L, H, W]\n        pids = torch.stack(pids)  # [L, H, W]\n\n        if self.opt.use_mask_images:\n            # People layers sit at indices 1..len(people_layers); index 0 is the background.\n            for i in range(1, len(people_layers) + 1):\n                mask_path = os.path.join(self.opt.dataroot, 'mask', f'{i:02d}', os.path.basename(image_path))\n                if os.path.exists(mask_path):\n                    mask = Image.open(mask_path).convert('L').resize((masks.shape[-1], masks.shape[-2]))\n                    mask = transforms.ToTensor()(mask) * 2 - 1\n                    masks[i] = mask[0]  # [1, H, W] -> [H, W]\n\n        transform_params = self.get_params(do_jitter=self.opt.phase=='train')\n        pids = self.apply_transform(pids, transform_params, 'nearest')\n        masks = self.apply_transform(masks, transform_params, 'bilinear')\n        uv_maps = self.apply_transform(uv_maps, transform_params, 'nearest')\n        image_transform_params = transform_params\n        if self.opt.do_upsampling:\n            image_transform_params = {p: transform_params[p] * 2 for p in transform_params}\n        target_image = self.apply_transform(target_image, image_transform_params, 'bilinear')\n\n        return {'image': target_image, 'uv_map': uv_maps, 'mask': masks, 'pids': pids, 'image_path': image_path}\n\n    def __len__(self):\n        \"\"\"Return the total number of images.\"\"\"\n        return len(self.image_paths)\n\n    def get_params(self, do_jitter=False, jitter_rate=0.75):\n        \"\"\"Get transformation parameters.\"\"\"\n        if do_jitter:\n            if np.random.uniform() > jitter_rate or self.opt.do_upsampling:\n                scale = 1.\n            else:\n                scale = np.random.uniform(1, 1.25)\n            jitter_size = (scale * np.array([self.opt.height, self.opt.width])).astype(np.int)\n            start1 = np.random.randint(jitter_size[0] - self.opt.height + 1)\n            start2 = np.random.randint(jitter_size[1] - self.opt.width + 1)\n        else:\n            jitter_size = np.array([self.opt.height, self.opt.width])\n            start1 = 0\n            start2 = 0\n        crop_pos = np.array([start1, start2])\n        crop_size = np.array([self.opt.height, self.opt.width])\n        return {'jitter size': jitter_size, 'crop pos': crop_pos, 'crop size': crop_size}\n\n    def apply_transform(self, data, params, interp_mode='bilinear'):\n        \"\"\"Apply the transform to the data tensor.\"\"\"\n        tensor_size = params['jitter size'].tolist()\n        crop_pos = params['crop pos']\n        crop_size = params['crop size']\n        data = F.interpolate(data.unsqueeze(0), size=tensor_size, mode=interp_mode).squeeze(0)\n        data = data[:, crop_pos[0]:crop_pos[0] + crop_size[0], crop_pos[1]:crop_pos[1] + crop_size[1]]\n        return data\n
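\n    # homographies.txt layout, as parsed below:\n    #   line 0: '<tag> <width> <height>' -- pixel scale of the frame coordinates\n    #   line 1: '<tag> <x_min> <x_max> <y_min> <y_max>' -- bounds of the stabilized 'world' space\n    #   lines 2..n_images+1: nine floats per line, one row-major 3x3 homography per frame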
    def init_homographies(self, homography_path, n_images):\n        \"\"\"Read homography file and set up homography data.\"\"\"\n        with open(homography_path) as f:\n            h_data = f.readlines()\n        h_scale = h_data[0].rstrip().split(' ')\n        self.h_scale_x = int(h_scale[1])\n        self.h_scale_y = int(h_scale[2])\n        h_bounds = h_data[1].rstrip().split(' ')\n        self.h_bounds_x = [float(h_bounds[1]), float(h_bounds[2])]\n        self.h_bounds_y = [float(h_bounds[3]), float(h_bounds[4])]\n        homographies = h_data[2:2 + n_images]\n        homographies = [torch.from_numpy(np.array(line.rstrip().split(' ')).astype(np.float32).reshape(3, 3))\n                        for line in homographies]\n        self.homographies = homographies\n\n    def load_and_process_image(self, im_path):\n        \"\"\"Read image file and return as tensor in range [-1, 1].\"\"\"\n        image = Image.open(im_path).convert('RGB')\n        image = transforms.ToTensor()(image)\n        image = 2 * image - 1\n        return image\n\n    def load_and_process_iuv(self, iuv_path, i):\n        \"\"\"Read IUV file and convert to network inputs.\"\"\"\n        iuv_map = Image.open(iuv_path).convert('RGBA')\n        iuv_map = transforms.ToTensor()(iuv_map)\n        uv_map, mask, pids = self.iuv2input(iuv_map, i)\n        return uv_map, mask, pids\n\n    def iuv2input(self, iuv, index):\n        \"\"\"Create network inputs from IUV.\n        Parameters:\n            iuv - - a tensor of shape [4, H, W], where the channels are: body part ID, U, V, person ID.\n            index - - index of iuv\n\n        Returns:\n            uv (tensor) - - a UV map for a single layer, ready to pass to grid sampler (values in range [-1,1])\n            mask (tensor) - - the corresponding mask\n            person_id (tensor) - - the person IDs\n\n        grid sampler indexes into texture map of size tile_width x tile_width*n_textures\n        \"\"\"\n        # Extract body part and person IDs.\n        part_id = (iuv[0] * 255 / 10).round()\n        part_id[part_id > 24] = 24\n        part_id_mask = (part_id > 0).float()\n        person_id = (255 - 255 * iuv[-1]).round()  # person ID is saved as 255 - person_id\n        person_id *= part_id_mask  # background id is 0\n        maxId = self.opt.n_textures // 24\n        person_id[person_id > maxId] = maxId\n\n        # Convert body part ID to texture map ID.\n        # Essentially, each of the 24 body parts for each person, plus the background, has its own texture 'tile'.\n        # The tiles are concatenated horizontally to create the texture map.\n        tex_id = part_id + part_id_mask * 24 * (person_id - 1)\n\n        uv = iuv[1:3]\n        # Convert the per-body-part UVs to UVs that correspond to the full texture map.\n        uv[0] += tex_id\n\n        # Get the mask.\n        bg_mask = (tex_id == 0).float()\n        mask = 1.0 - bg_mask\n        mask = mask * 2 - 1  # map to 1 for foreground, -1 for background\n        mask = self.mask2trimap(mask)\n\n        # Composite background UV behind person UV.\n        h, w = iuv.shape[1:]\n        bg_uv = self.get_background_uv(index, w, h)\n        uv = bg_mask * bg_uv + (1 - bg_mask) * uv\n\n        # Map to [-1, 1] range.\n        uv[0] /= self.opt.n_textures\n        uv = uv * 2 - 1\n        uv = torch.clamp(uv, -1, 1)\n\n        return uv, mask, person_id\n
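\n    # The background layer's UV map is an identity ramp over the frame; when homographies\n    # are provided, get_background_uv warps the ramp into the shared 'world' frame so that\n    # all frames index a single background texture.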
    def get_background_inputs(self, index, w, h):\n        \"\"\"Return data for background layer at 'index'.\"\"\"\n        uv = self.get_background_uv(index, w, h)\n        # normalize to the range of the full texture atlas\n        uv[0] /= self.opt.n_textures\n        uv = uv * 2 - 1  # [0,1] -> [-1,1]\n        uv = torch.clamp(uv, -1, 1)\n\n        mask = -torch.ones(*uv.shape[1:])\n        pids = torch.zeros(*uv.shape[1:])\n        return uv, mask, pids\n\n    def get_background_uv(self, index, w, h):\n        \"\"\"Return background layer UVs at 'index' (output range [0, 1]).\"\"\"\n        ramp_u = torch.linspace(0, 1, steps=w).unsqueeze(0).repeat(h, 1)\n        ramp_v = torch.linspace(0, 1, steps=h).unsqueeze(-1).repeat(1, w)\n        ramp = torch.stack([ramp_u, ramp_v], 0)\n        if hasattr(self, 'homographies'):\n            # scale to [0, orig width/height]\n            ramp[0] *= self.h_scale_x\n            ramp[1] *= self.h_scale_y\n            # apply homography\n            ramp = ramp.reshape(2, -1)  # [2, H*W]\n            H = self.homographies[index]\n            [xt, yt] = self.transform2h(ramp[0], ramp[1], torch.inverse(H))\n            # scale from world to [0,1]\n            xt -= self.h_bounds_x[0]\n            xt /= (self.h_bounds_x[1] - self.h_bounds_x[0])\n            yt -= self.h_bounds_y[0]\n            yt /= (self.h_bounds_y[1] - self.h_bounds_y[0])\n            # restore shape\n            ramp = torch.stack([xt.reshape(h, w), yt.reshape(h, w)], 0)\n        return ramp\n\n    def transform2h(self, x, y, m):\n        \"\"\"Applies a 2D homogeneous transformation.\"\"\"\n        A = torch.matmul(m, torch.stack([x, y, torch.ones(len(x))]))\n        xt = A[0, :] / A[2, :]\n        yt = A[1, :] / A[2, :]\n        return xt, yt\n\n    def mask2trimap(self, mask):\n        \"\"\"Convert binary mask to trimap with values in [-1, 0, 1].\"\"\"\n        fg_mask = (mask > 0).float()\n        bg_mask = (mask < 0).float()\n        trimap_width = getattr(self.opt, 'trimap_width', 20)\n        trimap_width *= bg_mask.shape[-1] / self.opt.width\n        trimap_width = int(trimap_width)\n        bg_mask = cv2.erode(bg_mask.numpy(), kernel=np.ones((trimap_width, trimap_width)), iterations=1)\n        bg_mask = torch.from_numpy(bg_mask)\n        mask = fg_mask - bg_mask\n        return mask\n"
  },
  {
    "path": "datasets/download_data.sh",
    "content": "#!/bin/bash\n#\n# Copyright 2020 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nNAME=$1\n\nif [[ $NAME != \"cartwheel\" && $NAME != \"reflection\" && $NAME != \"splash\" &&  $NAME != \"trampoline\" && $NAME != \"all\" ]]; then\n    echo \"Available videos are: cartwheel, reflection, splash, trampoline\"\n    exit 1\nfi\n\nif [[ $NAME == \"all\" ]]; then\n  declare -a NAMES=(\"cartwheel\" \"reflection\" \"splash\" \"trampoline\")\nelse\n  declare -a NAMES=($NAME)\nfi\n\nfor NAME in \"${NAMES[@]}\"\ndo\n  echo \"Specified [$NAME]\"\n  URL=https://www.robots.ox.ac.uk/~erika/retiming/data/$NAME.zip\n  ZIP_FILE=./datasets/$NAME.zip\n  TARGET_DIR=./datasets/$NAME/\n  wget -N $URL -O $ZIP_FILE\n  mkdir $TARGET_DIR\n  unzip $ZIP_FILE -d ./datasets/\n  rm $ZIP_FILE\ndone"
  },
  {
    "path": "datasets/iuv_crop2full.py",
    "content": "# Copyright 2020 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Convert UV crops to full UV maps.\"\"\"\nimport os\nimport sys\nimport json\nfrom PIL import Image\nimport numpy as np\n\n\ndef place_crop(crop, image, center_x, center_y):\n    \"\"\"Place the crop in the image at the specified location.\"\"\"\n    im_height, im_width = image.shape[:2]\n    crop_height, crop_width = crop.shape[:2]\n\n    left = center_x - crop_width // 2\n    right = left + crop_width\n    top = center_y - crop_height // 2\n    bottom = top + crop_height\n\n    adjusted_crop = crop  # remove regions of crop that go beyond image bounds\n    if left < 0:\n        adjusted_crop = adjusted_crop[:, -left:]\n    if right > im_width:\n        adjusted_crop = adjusted_crop[:, :(im_width - right)]\n    if top < 0:\n        adjusted_crop = adjusted_crop[-top:]\n    if bottom > im_height:\n        adjusted_crop = adjusted_crop[:(im_height - bottom)]\n    crop_mask = (adjusted_crop > 0).astype(crop.dtype).sum(-1, keepdims=True)\n    image[max(0, top):min(im_height, bottom), max(0, left):min(im_width, right)] *= (1 - crop_mask)\n    image[max(0, top):min(im_height, bottom), max(0, left):min(im_width, right)] += adjusted_crop\n\n    return image\n\ndef crop2full(keypoints_path, metadata_path, uvdir, outdir):\n    \"\"\"Create each frame's layer UVs from predicted UV crops\"\"\"\n    with open(keypoints_path) as f:\n        kp_data = json.load(f)\n\n    # Get all people ids\n    people_ids = set()\n    for frame in kp_data:\n        for skeleton in kp_data[frame]:\n            people_ids.add(skeleton['idx'])\n    people_ids = sorted(list(people_ids))\n\n    with open(metadata_path) as f:\n        metadata = json.load(f)\n\n    orig_size = np.array(metadata['alphapose_input_size'][::-1])\n    out_size = np.array(metadata['size_LR'][::-1])\n\n    if 'people_layers' in metadata:\n        people_layers = metadata['people_layers']\n    else:\n        people_layers = [[pid] for pid in people_ids]\n\n    # Create output directories.\n    for layer_i in range(1, 1 + len(people_layers)):\n        os.makedirs(os.path.join(outdir, f'{layer_i:02d}'), exist_ok=True)\n    print(f'Writing UVs to {outdir}')\n\n    for frame in sorted(kp_data):\n        for layer_i, layer in enumerate(people_layers, 1):\n            out_path = os.path.join(outdir, f'{layer_i:02d}', frame)\n            sys.stdout.flush()\n            sys.stdout.write('processing frame %s\\r' % out_path)\n            uv_map = np.zeros([out_size[0], out_size[1], 4])\n            for person_id in layer:\n                matches = [p for p in kp_data[frame] if p['idx'] == person_id]\n                if len(matches) == 0:  # person doesn't appear in this frame\n                    continue\n                skeleton = matches[0]\n                kps = np.array(skeleton['keypoints']).reshape(17, 3)\n                # Get kps bounding box.\n                left = kps[:, 0].min()\n                right = kps[:, 0].max()\n                top = 
kps[:, 1].min()\n                bottom = kps[:, 1].max()\n                height = bottom - top\n                width = right - left\n                orig_crop_size = max(height, width)\n                orig_center_x = (left + right) // 2\n                orig_center_y = (top + bottom) // 2\n\n                # read predicted uv map\n                uv_crop_path = os.path.join(uvdir, f'{person_id:02d}_{os.path.basename(out_path)[:-4]}_output_uv.png')\n                if os.path.exists(uv_crop_path):\n                    uv_crop = np.array(Image.open(uv_crop_path))\n                else:\n                    uv_crop = np.zeros([256, 256, 3])\n\n                # add person ID channel\n                person_mask = (uv_crop[..., 0:1] > 0).astype('uint8')\n                person_ids = (255 - person_id) * person_mask\n                uv_crop = np.concatenate([uv_crop, person_ids], -1)\n\n                # scale crop to desired output size\n                # 256 is the crop size, 192 is the inner crop size\n                out_crop_size = orig_crop_size * 256./192 * out_size / orig_size\n                out_crop_size = out_crop_size.astype(np.int)\n                uv_crop = uv_crop.astype(np.uint8)\n                uv_crop = np.array(Image.fromarray(uv_crop).resize((out_crop_size[1], out_crop_size[0]), resample=Image.NEAREST))\n\n                # scale center coordinate accordingly\n                out_center_x = (orig_center_x * out_size[1] / orig_size[1]).astype(np.int)\n                out_center_y = (orig_center_y * out_size[0] / orig_size[0]).astype(np.int)\n\n                # Place UV crop in full UV map and save.\n                uv_map = place_crop(uv_crop, uv_map, out_center_x, out_center_y)\n            uv_map = Image.fromarray(uv_map.astype('uint8'))\n            uv_map.save(out_path)\n\n\nif __name__ == \"__main__\":\n    import argparse\n    arguments = argparse.ArgumentParser()\n    arguments.add_argument('--dataroot', type=str)\n    opt = arguments.parse_args()\n\n    keypoints_path = os.path.join(opt.dataroot, 'keypoints.json')\n    metadata_path = os.path.join(opt.dataroot, 'metadata.json')\n    uvdir = os.path.join(opt.dataroot, 'kp2uv/test_latest/images')\n    outdir = os.path.join(opt.dataroot, 'iuv')\n    crop2full(keypoints_path, metadata_path, uvdir, outdir)\n"
  },
  {
    "path": "datasets/prepare_iuv.sh",
    "content": "#!/bin/bash\n#\n# Copyright 2020 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nDATA_PATH=$1\n# Predict UVs from keypoints and save the crops.\npython run_kp2uv.py --model kp2uv --dataroot $DATA_PATH --results_dir $DATA_PATH\n# Convert the cropped UVs to full UV maps.\npython datasets/iuv_crop2full.py --dataroot $DATA_PATH\n"
  },
  {
    "path": "docs/contributing.md",
    "content": "# How to Contribute\n\nWe'd love to accept your patches and contributions to this project. There are\njust a few small guidelines you need to follow.\n\n## Contributor License Agreement\n\nContributions to this project must be accompanied by a Contributor License\nAgreement (CLA). You (or your employer) retain the copyright to your\ncontribution; this simply gives us permission to use and redistribute your\ncontributions as part of the project. Head over to\n<https://cla.developers.google.com/> to see your current agreements on file or\nto sign a new one.\n\nYou generally only need to submit a CLA once, so if you've already submitted one\n(even if it was for a different project), you probably don't need to do it\nagain.\n\n## Code reviews\n\nAll submissions, including submissions by project members, require review. We\nuse GitHub pull requests for this purpose. Consult\n[GitHub Help](https://help.github.com/articles/about-pull-requests/) for more\ninformation on using pull requests.\n\n## Community Guidelines\n\nThis project follows\n[Google's Open Source Community Guidelines](https://opensource.google/conduct/)."
  },
  {
    "path": "docs/data.md",
    "content": "### Data\nThe data directory for a video is structured as follows:\n```\nvideo_name/\n|-- rgb_256/\n|   |-- 0001.png, etc.\n|-- rgb_512/\n|   |-- 0001.png, etc.\n|-- mask/ (optional)\n|-- |-- 01, etc.\n|-- |-- |-- 0001.png, etc.   \n|-- keypoints.json\n|-- metadata.json\n|-- homographies.txt (optional)\n```\n- `metadata.json` contains a dictionary:\n```\n'alphapose_input_size': [width, height]  # size of frames input to AlphaPose\n'size_LR': [width, height]  # size of low-resolution frames (multiple of 16; height should be 256)\n'n_textures': int  # number of texture maps required, calculated by 24*num_people + 1\n'composite_order': [[1, 2, 3], [1, 3, 2], ... ]  # optional per-frame back-to-front layer compositing order\n```\n- `keypoints.json` is in the format output by the [AlphaPose Pose Tracker](https://github.com/MVIG-SJTU/AlphaPose).\nSee [here](https://github.com/MVIG-SJTU/AlphaPose/tree/master/trackers/PoseFlow) for details."
  },
  {
    "path": "environment.yml",
    "content": "name: retiming\nchannels:\n  - pytorch\n  - defaults\ndependencies:\n- python=3.8\n- pytorch=1.4.0\n- pip:\n  - dominate==2.4.0\n  - torchvision==0.5.0\n  - Pillow>=6.1.0\n  - numpy==1.19.2\n  - visdom==0.1.8\n  - opencv-python>=4.2.0\n  - matplotlib\n\n"
  },
  {
    "path": "models/__init__.py",
    "content": ""
  },
  {
    "path": "models/kp2uv_model.py",
    "content": "# Copyright 2020 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport torch\nfrom third_party.models.base_model import BaseModel\nfrom . import networks\n\n\nclass Kp2uvModel(BaseModel):\n    \"\"\"This class implements the keypoint-to-UV model (inference only).\"\"\"\n    @staticmethod\n    def modify_commandline_options(parser, is_train=True):\n        parser.set_defaults(dataset_mode='kpuv')\n        return parser\n\n    def __init__(self, opt):\n        \"\"\"Initialize this model class.\n\n        Parameters:\n            opt -- test options\n        \"\"\"\n        BaseModel.__init__(self, opt)\n        self.visual_names = ['keypoints', 'output_uv']\n        self.model_names = ['Kp2uv']\n        self.netKp2uv = networks.define_kp2uv(gpu_ids=self.gpu_ids)\n        self.isTrain = False  # only test mode supported\n\n        # Our program will automatically call <model.setup> to define schedulers, load networks, and print networks\n\n    def set_input(self, input):\n        \"\"\"Unpack input data from the dataloader.\n\n        Parameters:\n            input: a dictionary that contains the data itself and its metadata information.\n        \"\"\"\n        self.keypoints = input['keypoints'].to(self.device)\n        self.image_paths = input['path']\n\n    def forward(self):\n        \"\"\"Run forward pass. This will be called by <test>.\"\"\"\n        output = self.netKp2uv.forward(self.keypoints)\n        self.output_uv = self.output2rgb(output)\n\n    def output2rgb(self, output):\n        \"\"\"Convert network outputs to RGB image.\"\"\"\n        pred_id, pred_uv = output\n        _, pred_id_class = pred_id.max(1)\n        pred_id_class = pred_id_class.unsqueeze(1)\n        # extract UV from pred_uv (48 channels); select based on class ID\n        selected_uv = -1 * torch.ones(pred_uv.shape[0], 2, pred_uv.shape[2], pred_uv.shape[3], device=pred_uv.device)\n        for partid in range(1, 25):\n            mask = (pred_id_class == partid).float()\n            selected_uv *= (1. - mask)\n            selected_uv += mask * pred_uv[:, (partid - 1) * 2:(partid - 1) * 2 + 2]\n        pred_uv = selected_uv\n        rgb = torch.cat([pred_id_class.float() * 10 / 255. * 2 - 1, pred_uv], 1)\n        return rgb\n\n    def optimize_parameters(self):\n        pass\n"
  },
  {
    "path": "models/lnr_model.py",
    "content": "# Copyright 2020 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport torch\nfrom third_party.models.base_model import BaseModel\nfrom . import networks\nimport numpy as np\nimport torch.nn.functional as F\n\n\nclass LnrModel(BaseModel):\n    \"\"\"This class implements the layered neural rendering model for decomposing a video into layers.\"\"\"\n    @staticmethod\n    def modify_commandline_options(parser, is_train=True):\n        parser.set_defaults(dataset_mode='layered_video')\n        parser.add_argument('--texture_res', type=int, default=16, help='texture resolution')\n        parser.add_argument('--texture_channels', type=int, default=16, help='# channels for neural texture')\n        parser.add_argument('--n_textures', type=int, default=25, help='# individual texture maps, 24 per person (1 per body part) + 1 for background')\n        if is_train:\n            parser.add_argument('--lambda_alpha_l1', type=float, default=0.01, help='alpha L1 sparsity loss weight')\n            parser.add_argument('--lambda_alpha_l0', type=float, default=0.005, help='alpha L0 sparsity loss weight')\n            parser.add_argument('--alpha_l1_rolloff_epoch', type=int, default=200, help='turn off L1 alpha sparsity loss weight after this epoch')\n            parser.add_argument('--lambda_mask', type=float, default=50, help='layer matting loss weight')\n            parser.add_argument('--mask_thresh', type=float, default=0.02, help='turn off masking loss when error falls below this value')\n            parser.add_argument('--mask_loss_rolloff_epoch', type=int, default=-1, help='decrease masking loss after this epoch; if <0, use mask_thresh instead')\n            parser.add_argument('--n_epochs_upsample', type=int, default=500,\n                                help='number of epochs to train the upsampling module')\n            parser.add_argument('--batch_size_upsample', type=int, default=16, help='batch size for upsampling')\n            parser.add_argument('--jitter_rgb', type=float, default=0.2, help='amount of jitter to add to RGB')\n            parser.add_argument('--jitter_epochs', type=int, default=400, help='number of epochs to jitter RGB')\n        parser.add_argument('--do_upsampling', action='store_true', help='whether to use upsampling module')\n\n        return parser\n\n    def __init__(self, opt):\n        \"\"\"Initialize this model class.\n\n        Parameters:\n            opt -- training/test options\n        \"\"\"\n        BaseModel.__init__(self, opt)\n        # specify the images you want to save and display. 
The program will call base_model.get_current_visuals to save and display these images.\n        self.visual_names = ['target_image', 'reconstruction', 'rgba_vis', 'alpha_vis', 'input_vis']\n        self.model_names = ['LNR']\n        self.netLNR = networks.define_LNR(opt.num_filters, opt.texture_channels, opt.texture_res, opt.n_textures, gpu_ids=self.gpu_ids)\n        self.do_upsampling = opt.do_upsampling\n        if self.isTrain:\n            self.setup_train(opt)\n\n        # Our program will automatically call <model.setup> to define schedulers, load networks, and print networks\n\n    def setup_train(self, opt):\n        \"\"\"Setup the model for training mode.\"\"\"\n        print('setting up model')\n        # specify the training losses you want to print out. The program will call base_model.get_current_losses to plot the losses to the console and save them to the disk.\n        self.loss_names = ['total', 'recon', 'alpha_reg', 'mask']\n        self.visual_names = ['target_image', 'reconstruction', 'rgba_vis', 'alpha_vis', 'input_vis']\n        self.do_upsampling = opt.do_upsampling\n        if not self.do_upsampling:\n            self.visual_names += ['mask_vis']\n        self.criterionLoss = torch.nn.L1Loss()\n        self.criterionLossMask = networks.MaskLoss().to(self.device)\n        self.lambda_mask = opt.lambda_mask\n        self.lambda_alpha_l0 = opt.lambda_alpha_l0\n        self.lambda_alpha_l1 = opt.lambda_alpha_l1\n        self.mask_loss_rolloff_epoch = opt.mask_loss_rolloff_epoch\n        self.jitter_rgb = opt.jitter_rgb\n        self.optimizer = torch.optim.Adam(self.netLNR.parameters(), lr=opt.lr)\n        self.optimizers = [self.optimizer]\n\n    def set_input(self, input):\n        \"\"\"Unpack input data from the dataloader and perform necessary pre-processing steps.\n\n        Parameters:\n            input: a dictionary that contains the data itself and its metadata information.\n        \"\"\"\n        self.target_image = input['image'].to(self.device)\n        if self.isTrain and self.jitter_rgb > 0:\n            # add brightness jitter to rgb\n            self.target_image += self.jitter_rgb * torch.randn(self.target_image.shape[0], 1, 1, 1).to(self.device)\n            self.target_image = torch.clamp(self.target_image, -1, 1)\n        self.input_uv = input['uv_map'].to(self.device)\n        self.input_id = input['pids'].to(self.device)\n        self.mask = input['mask'].to(self.device)\n        self.image_paths = input['image_path']\n\n    def gen_crop_params(self, orig_h, orig_w, crop_size=256):\n        \"\"\"Generate random square cropping parameters.\"\"\"\n        starty = np.random.randint(orig_h - crop_size + 1)\n        startx = np.random.randint(orig_w - crop_size + 1)\n        endy = starty + crop_size\n        endx = startx + crop_size\n        return starty, endy, startx, endx\n\n    def forward(self):\n        \"\"\"Run forward pass. This will be called by both functions <optimize_parameters> and <test>.\"\"\"\n
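        # netLNR returns a dict:\n        #   'reconstruction': [B, 4, H, W] composited RGBA for the full frame\n        #   'layers': [B, L, 4, H, W] per-layer RGBA, with the background at index 0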
        if self.do_upsampling:\n            input_uv_up = F.interpolate(self.input_uv, scale_factor=2, mode='bilinear')\n            crop_params = None\n            if self.isTrain:\n                # Take random crop to decrease memory requirement.\n                crop_params = self.gen_crop_params(*input_uv_up.shape[-2:])\n                starty, endy, startx, endx = crop_params\n                self.target_image = self.target_image[:, :, starty:endy, startx:endx]\n            outputs = self.netLNR.forward(self.input_uv, self.input_id, uv_map_upsampled=input_uv_up, crop_params=crop_params)\n        else:\n            outputs = self.netLNR(self.input_uv, self.input_id)\n        self.reconstruction = outputs['reconstruction'][:, :3]\n        self.alpha_composite = outputs['reconstruction'][:, 3]\n        self.output_rgba = outputs['layers']\n        n_layers = outputs['layers'].shape[1]\n        layers = outputs['layers'].clone()\n        layers[:, 0, -1] = 1  # Background layer's alpha is always 1\n        layers = torch.cat([layers[:, l] for l in range(n_layers)], -2)\n        self.alpha_vis = layers[:, 3:4]\n        self.rgba_vis = layers\n        self.mask_vis = torch.cat([self.mask[:, l:l+1] for l in range(n_layers)], -2)\n        self.input_vis = torch.cat([self.input_uv[:, 2*l:2*l+2] for l in range(n_layers)], -2)\n        self.input_vis = torch.cat([torch.zeros_like(self.input_vis[:, :1]), self.input_vis], 1)\n\n    def backward(self):\n        \"\"\"Calculate losses, gradients, and update network weights; called in every training iteration\"\"\"\n        self.loss_recon = self.criterionLoss(self.reconstruction[:, :3], self.target_image)\n        self.loss_total = self.loss_recon\n        if not self.do_upsampling:\n            self.loss_alpha_reg = networks.cal_alpha_reg(self.alpha_composite * .5 + .5, self.lambda_alpha_l1, self.lambda_alpha_l0)\n            alpha_layers = self.output_rgba[:, :, 3]\n            self.loss_mask = self.lambda_mask * self.criterionLossMask(alpha_layers, self.mask)\n            self.loss_total += self.loss_alpha_reg + self.loss_mask\n        else:\n            self.loss_mask = 0.\n            self.loss_alpha_reg = 0.\n        self.loss_total.backward()\n\n    def optimize_parameters(self):\n        \"\"\"Update network weights; it will be called in every training iteration.\"\"\"\n        self.forward()\n        self.optimizer.zero_grad()\n        self.backward()\n        self.optimizer.step()\n\n    def update_lambdas(self, epoch):\n        \"\"\"Update loss weights based on current epochs and losses.\"\"\"\n        if epoch == self.opt.alpha_l1_rolloff_epoch:\n            self.lambda_alpha_l1 = 0\n        if self.mask_loss_rolloff_epoch >= 0:\n            if epoch == 2*self.mask_loss_rolloff_epoch:\n                self.lambda_mask = 0\n        elif epoch > self.opt.epoch_count:\n            if self.loss_mask < self.opt.mask_thresh * self.opt.lambda_mask:\n                self.mask_loss_rolloff_epoch = epoch\n                self.lambda_mask *= .1\n        if epoch == self.opt.jitter_epochs:\n            self.jitter_rgb = 0\n\n    def transfer_detail(self):\n        \"\"\"Transfer detail to layers.\"\"\"\n        residual = self.target_image - self.reconstruction\n        transmission_comp = torch.zeros_like(self.target_image[:, 0:1])\n        rgba_detail = self.output_rgba\n        n_layers = self.output_rgba.shape[1]\n
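        # Walk the layers front to back (the front-most layer has the highest index);\n        # transmission_comp accumulates how much the already-visited layers occlude\n        # whatever lies behind them.\n        for i in range(n_layers - 1, 0, -1):  # Don't 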
do detail transfer for background layer, due to ghosting effects.\n            transmission_i = 1. - transmission_comp\n            rgba_detail[:, i, :3] += transmission_i * residual\n            alpha_i = self.output_rgba[:, i, 3:4] * .5 + .5\n            transmission_comp = alpha_i + (1. - alpha_i) * transmission_comp\n        self.rgba = torch.clamp(rgba_detail, -1, 1)\n\n    def get_results(self):\n        \"\"\"Return results. This is different from get_current_visuals, which gets visuals for monitoring training.\n\n        Returns a dictionary:\n            original - - original frame\n            reconstruction - - reconstruction of the frame\n            rgba_l* - - RGBA for each layer\n            mask_l* - - mask for each layer\n        \"\"\"\n        self.transfer_detail()\n        # Split layers\n        results = {'reconstruction': self.reconstruction, 'original': self.target_image}\n        n_layers = self.rgba.shape[1]\n        for i in range(n_layers):\n            results[f'mask_l{i}'] = self.mask[:, i:i+1]\n            results[f'rgba_l{i}'] = self.rgba[:, i]\n            if i == 0:\n                results[f'rgba_l{i}'][:, -1:] = 1.\n        return results\n\n    def freeze_basenet(self):\n        \"\"\"Freeze all parameters except for the upsampling module.\"\"\"\n        net = self.netLNR\n        if isinstance(net, torch.nn.DataParallel):\n            net = net.module\n        self.set_requires_grad([net.encoder, net.decoder, net.final_rgba], False)\n        net.texture.requires_grad = False"
  },
  {
    "path": "models/networks.py",
    "content": "# Copyright 2020 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport numpy as np\nfrom third_party.models.networks import init_net\n\n\n###############################################################################\n# Helper Functions\n###############################################################################\ndef define_LNR(nf=64, texture_channels=16, texture_res=16, n_textures=25, gpu_ids=[]):\n    \"\"\"Create a layered neural renderer.\n\n    Parameters:\n        nf (int) -- the number of channels in the first/last conv layers\n        texture_channels (int) -- the number of channels in the neural texture\n        texture_res (int) -- the size of each individual texture map\n        n_textures (int) -- the number of individual texture maps\n        gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2\n\n    Returns a layered neural rendering model.\n    \"\"\"\n    net = LayeredNeuralRenderer(nf, texture_channels, texture_res, n_textures)\n    return init_net(net, gpu_ids)\n\n\ndef define_kp2uv(nf=64, gpu_ids=[]):\n    \"\"\"Create a keypoint-to-UV model.\n\n    Parameters:\n        nf (int) -- the number of channels in the first/last conv layers\n\n    Returns a keypoint-to-UV model.\n    \"\"\"\n    net = kp2uv(nf)\n    return init_net(net, gpu_ids)\n\n\ndef cal_alpha_reg(prediction, lambda_alpha_l1, lambda_alpha_l0):\n    \"\"\"Calculate the alpha regularization term.\n\n    Parameters:\n        prediction (tensor) - - composite of predicted alpha layers\n        lambda_alpha_l1 (float) - - weight for the L1 regularization term\n        lambda_alpha_l0 (float) - - weight for the L0 regularization term\n    Returns the alpha regularization loss term\n    \"\"\"\n    assert prediction.max() <= 1.\n    assert prediction.min() >= 0.\n    loss = 0.\n    if lambda_alpha_l1 > 0:\n        loss += lambda_alpha_l1 * torch.mean(prediction)\n    if lambda_alpha_l0 > 0:\n        # Pseudo L0 loss using a squished sigmoid curve.\n        l0_prediction = (torch.sigmoid(prediction * 5.0) - 0.5) * 2.0\n        loss += lambda_alpha_l0 * torch.mean(l0_prediction)\n    return loss\n\n\n##############################################################################\n# Classes\n##############################################################################\nclass MaskLoss(nn.Module):\n    \"\"\"Define the loss which encourages the predicted alpha matte to match the mask (trimap).\"\"\"\n\n    def __init__(self):\n        super(MaskLoss, self).__init__()\n        self.loss = nn.L1Loss(reduction='none')\n\n    def __call__(self, prediction, target):\n        \"\"\"Calculate loss given predicted alpha matte and trimap.\n\n        Balance positive and negative regions. 
Exclude 'unknown' region from loss.\n\n        Parameters:\n            prediction (tensor) - - predicted alpha\n            target (tensor) - - trimap\n\n        Returns: the computed loss\n        \"\"\"\n        mask_err = self.loss(prediction, target)\n        pos_mask = F.relu(target)\n        neg_mask = F.relu(-target)\n        pos_mask_loss = (pos_mask * mask_err).sum() / (1 + pos_mask.sum())\n        neg_mask_loss = (neg_mask * mask_err).sum() / (1 + neg_mask.sum())\n        loss = .5 * (pos_mask_loss + neg_mask_loss)\n        return loss\n\n\nclass ConvBlock(nn.Module):\n    \"\"\"Helper module consisting of a convolution, optional normalization and activation, with padding='same'.\"\"\"\n\n    def __init__(self, conv, in_channels, out_channels, ksize=4, stride=1, dil=1, norm=None, activation='relu'):\n        \"\"\"Create a conv block.\n\n        Parameters:\n            conv (convolutional layer) - - the type of conv layer, e.g. Conv2d, ConvTranspose2d\n            in_channels (int) - - the number of input channels\n            out_channels (int) - - the number of output channels\n            ksize (int) - - the kernel size\n            stride (int) - - stride\n            dil (int) - - dilation\n            norm (norm layer) - - the type of normalization layer, e.g. BatchNorm2d, InstanceNorm2d\n            activation (str) - - the type of activation: relu | leaky | tanh | none\n        \"\"\"\n        super(ConvBlock, self).__init__()\n        self.k = ksize\n        self.s = stride\n        self.d = dil\n        self.conv = conv(in_channels, out_channels, ksize, stride=stride, dilation=dil)\n\n        if norm is not None:\n            self.norm = norm(out_channels)\n        else:\n            self.norm = None\n\n        if activation == 'leaky':\n            self.activation = nn.LeakyReLU(0.2)\n        elif activation == 'relu':\n            self.activation = nn.ReLU()\n        elif activation == 'tanh':\n            self.activation = nn.Tanh()\n        else:\n            self.activation = None\n\n    def forward(self, x):\n        \"\"\"Forward pass. 
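The block keeps the output spatial size at input_size // stride for convolutions, or input_size * stride for ConvTranspose2d. 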
Compute necessary padding and cropping because pytorch doesn't have pad=same.\"\"\"\n        height, width = x.shape[-2:]\n        if isinstance(self.conv, nn.modules.ConvTranspose2d):\n            desired_height = height * self.s\n            desired_width = width * self.s\n            pady = 0\n            padx = 0\n        else:\n            # o = [i + 2*p - k - (k-1)*(d-1)]/s + 1\n            # padding = .5 * (stride * (output-1) + (k-1)(d-1) + k - input)\n            desired_height = height // self.s\n            desired_width = width // self.s\n            pady = .5 * (self.s * (desired_height - 1) + (self.k - 1) * (self.d - 1) + self.k - height)\n            padx = .5 * (self.s * (desired_width - 1) + (self.k - 1) * (self.d - 1) + self.k - width)\n        x = F.pad(x, [int(np.floor(padx)), int(np.ceil(padx)), int(np.floor(pady)), int(np.ceil(pady))])\n        x = self.conv(x)\n        if x.shape[-2] != desired_height or x.shape[-1] != desired_width:\n            # Crop symmetrically; use explicit end indices so that a zero crop in one\n            # dimension doesn't produce an empty slice (x[..., 0:-0] would be empty).\n            cropy = x.shape[-2] - desired_height\n            cropx = x.shape[-1] - desired_width\n            x = x[:, :, int(np.floor(cropy / 2.)):x.shape[-2] - int(np.ceil(cropy / 2.)),\n                int(np.floor(cropx / 2.)):x.shape[-1] - int(np.ceil(cropx / 2.))]\n        if self.norm:\n            x = self.norm(x)\n        if self.activation:\n            x = self.activation(x)\n        return x\n\n\nclass ResBlock(nn.Module):\n    \"\"\"Define a residual block.\"\"\"\n\n    def __init__(self, channels, ksize=4, stride=1, dil=1, norm=None, activation='relu'):\n        \"\"\"Initialize the residual block, which consists of 2 conv blocks with a skip connection.\"\"\"\n        super(ResBlock, self).__init__()\n        self.convblock1 = ConvBlock(nn.Conv2d, channels, channels, ksize=ksize, stride=stride, dil=dil, norm=norm,\n                                    activation=activation)\n        self.convblock2 = ConvBlock(nn.Conv2d, channels, channels, ksize=ksize, stride=stride, dil=dil, norm=norm,\n                                    activation=None)\n\n    def forward(self, x):\n        identity = x\n        x = self.convblock1(x)\n        x = self.convblock2(x)\n        x += identity\n        return x\n\n\nclass kp2uv(nn.Module):\n    \"\"\"UNet architecture for converting a keypoint image to a UV map.\n\n    Uses the same person UV-map format as described in https://arxiv.org/pdf/1802.00434.pdf.\n    \"\"\"\n\n    def __init__(self, nf=64):\n        super(kp2uv, self).__init__()\n        self.encoder = nn.ModuleList([\n            ConvBlock(nn.Conv2d, 3, nf, ksize=4, stride=2),\n            ConvBlock(nn.Conv2d, nf, nf * 2, ksize=4, stride=2, norm=nn.InstanceNorm2d, activation='leaky'),\n            ConvBlock(nn.Conv2d, nf * 2, nf * 4, ksize=4, stride=2, norm=nn.InstanceNorm2d, activation='leaky'),\n            ConvBlock(nn.Conv2d, nf * 4, nf * 4, ksize=4, stride=2, norm=nn.InstanceNorm2d, activation='leaky'),\n            ConvBlock(nn.Conv2d, nf * 4, nf * 4, ksize=4, stride=2, norm=nn.InstanceNorm2d, activation='leaky'),\n            ConvBlock(nn.Conv2d, nf * 4, nf * 4, ksize=3, stride=1, norm=nn.InstanceNorm2d, activation='leaky'),\n            ConvBlock(nn.Conv2d, nf * 4, nf * 4, ksize=3, stride=1, norm=nn.InstanceNorm2d, activation='leaky')])\n\n        self.decoder = nn.ModuleList([\n            ConvBlock(nn.ConvTranspose2d, nf * 4 * 2, nf * 4, ksize=4, stride=2, norm=nn.InstanceNorm2d),\n            ConvBlock(nn.ConvTranspose2d, nf * 4 * 2, nf * 4, ksize=4, stride=2, norm=nn.InstanceNorm2d),\n            ConvBlock(nn.ConvTranspose2d, nf * 4 * 2, nf * 2, ksize=4, 
stride=2, norm=nn.InstanceNorm2d),\n            ConvBlock(nn.ConvTranspose2d, nf * 2 * 2, nf, ksize=4, stride=2, norm=nn.InstanceNorm2d),\n            ConvBlock(nn.ConvTranspose2d, nf * 2, nf, ksize=4, stride=2, norm=nn.InstanceNorm2d)])\n\n        # head to predict the body part class (25 classes: 24 body parts + 1 background)\n        self.id_pred = ConvBlock(nn.Conv2d, nf + 3, 25, ksize=3, stride=1, activation='none')\n        # head to predict UV coordinates for every body part class\n        self.uv_pred = ConvBlock(nn.Conv2d, nf + 3, 2 * 24, ksize=3, stride=1, activation='tanh')\n\n    def forward(self, x):\n        \"\"\"Forward pass through UNet, handling skip connections.\n\n        Parameters:\n            x (tensor) - - rendered keypoint image, shape [B, 3, H, W]\n\n        Returns:\n            x_id (tensor): part id class probabilities\n            x_uv (tensor): uv coordinates for each part id\n        \"\"\"\n        skips = [x]\n        for i, layer in enumerate(self.encoder):\n            x = layer(x)\n            if i < 5:\n                skips.append(x)\n        for layer in self.decoder:\n            x = torch.cat((x, skips.pop()), 1)\n            x = layer(x)\n        x = torch.cat((x, skips.pop()), 1)\n        x_id = self.id_pred(x)\n        x_uv = self.uv_pred(x)\n        return x_id, x_uv\n\n\nclass LayeredNeuralRenderer(nn.Module):\n    \"\"\"Layered Neural Rendering model for video decomposition.\n\n    Consists of a neural texture, a UNet, and an upsampling module.\n    \"\"\"\n\n    def __init__(self, nf=64, texture_channels=16, texture_res=16, n_textures=25):\n        \"\"\"Initialize the layered neural renderer.\n\n        Parameters:\n            nf (int) -- the number of channels in the first/last conv layers\n            texture_channels (int) -- the number of channels in the neural texture\n            texture_res (int) -- the size of each individual texture map\n            n_textures (int) -- the number of individual texture maps\n        \"\"\"\n        super(LayeredNeuralRenderer, self).__init__()\n        # Neural texture is implemented as 'n_textures' individual maps concatenated horizontally\n        self.texture = nn.Parameter(torch.randn(1, texture_channels, texture_res, n_textures * texture_res))\n\n        # Define UNet\n        self.encoder = nn.ModuleList([\n            ConvBlock(nn.Conv2d, texture_channels + 1, nf, ksize=4, stride=2),\n            ConvBlock(nn.Conv2d, nf, nf * 2, ksize=4, stride=2, norm=nn.BatchNorm2d, activation='leaky'),\n            ConvBlock(nn.Conv2d, nf * 2, nf * 4, ksize=4, stride=2, norm=nn.BatchNorm2d, activation='leaky'),\n            ConvBlock(nn.Conv2d, nf * 4, nf * 4, ksize=4, stride=2, norm=nn.BatchNorm2d, activation='leaky'),\n            ConvBlock(nn.Conv2d, nf * 4, nf * 4, ksize=4, stride=2, norm=nn.BatchNorm2d, activation='leaky'),\n            ConvBlock(nn.Conv2d, nf * 4, nf * 4, ksize=4, stride=1, dil=2, norm=nn.BatchNorm2d, activation='leaky'),\n            ConvBlock(nn.Conv2d, nf * 4, nf * 4, ksize=4, stride=1, dil=2, norm=nn.BatchNorm2d, activation='leaky')])\n        self.decoder = nn.ModuleList([\n            ConvBlock(nn.ConvTranspose2d, nf * 4 * 2, nf * 4, ksize=4, stride=2, norm=nn.BatchNorm2d),\n            ConvBlock(nn.ConvTranspose2d, nf * 4 * 2, nf * 4, ksize=4, stride=2, norm=nn.BatchNorm2d),\n            ConvBlock(nn.ConvTranspose2d, nf * 4 * 2, nf * 2, ksize=4, stride=2, norm=nn.BatchNorm2d),\n            ConvBlock(nn.ConvTranspose2d, nf * 2 * 2, nf, ksize=4, stride=2, norm=nn.BatchNorm2d),\n            
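# Note: input channel counts are doubled because each up-block consumes the previous\n            # features concatenated with the matching encoder skip connection (see render()).\n            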
ConvBlock(nn.ConvTranspose2d, nf * 2, nf, ksize=4, stride=2, norm=nn.BatchNorm2d)])\n        self.final_rgba = ConvBlock(nn.Conv2d, nf, 4, ksize=4, stride=1, activation='tanh')\n\n        # Define upsampling block, which outputs a residual\n        # rgba (4) + sampled texture (texture_channels) + id layer (1) + last UNet features (nf)\n        upsampling_ic = texture_channels + 5 + nf\n        self.upsample_block = nn.Sequential(\n            ConvBlock(nn.Conv2d, upsampling_ic, nf, ksize=3, stride=1, norm=nn.InstanceNorm2d),\n            ResBlock(nf, ksize=3, stride=1, norm=nn.InstanceNorm2d),\n            ResBlock(nf, ksize=3, stride=1, norm=nn.InstanceNorm2d),\n            ResBlock(nf, ksize=3, stride=1, norm=nn.InstanceNorm2d),\n            ConvBlock(nn.Conv2d, nf, 4, ksize=3, stride=1, activation='none'))\n\n    def render(self, x):\n        \"\"\"Pass inputs for a single layer through the UNet.\n\n        Parameters:\n            x (tensor) - - sampled texture concatenated with person IDs\n\n        Returns RGBA for the input layer and the final feature maps.\n        \"\"\"\n        skips = [x]\n        for i, layer in enumerate(self.encoder):\n            x = layer(x)\n            if i < 5:\n                skips.append(x)\n        for layer in self.decoder:\n            x = torch.cat((x, skips.pop()), 1)\n            x = layer(x)\n        rgba = self.final_rgba(x)\n        return rgba, x\n\n    def forward(self, uv_map, id_layers, uv_map_upsampled=None, crop_params=None):\n        \"\"\"Forward pass through the layered neural renderer.\n\n        Steps:\n        1. Sample from the neural texture using uv_map\n        2. Input uv_map and id_layers into UNet\n            2a. If doing upsampling, then pass upsampled inputs and results through upsampling module\n        3. Composite RGBA outputs.\n\n        Parameters:\n            uv_map (tensor) - - UV maps for all layers, with shape [B, (2*L), H, W]\n            id_layers (tensor) - - person ID for all layers, with shape [B, L, H, W]\n            uv_map_upsampled (tensor) - - upsampled UV maps to input to upsampling module (if None, skip upsampling)\n            crop_params (tuple) - - (starty, endy, startx, endx) window used to crop the upsampled outputs (if None, no cropping)\n        \"\"\"\n        b_sz = uv_map.shape[0]\n        n_layers = uv_map.shape[1] // 2\n        texture = self.texture.repeat(b_sz, 1, 1, 1)\n        composite = None\n        layers = []\n        sampled_textures = []\n        for i in range(n_layers):\n            # Get RGBA for this layer.\n            uv_map_i = uv_map[:, i * 2:(i + 1) * 2, ...]\n            uv_map_i = uv_map_i.permute(0, 2, 3, 1)\n            sampled_texture = F.grid_sample(texture, uv_map_i, mode='bilinear', padding_mode='zeros')\n            inputs = torch.cat([sampled_texture, id_layers[:, i:i + 1]], 1)\n            rgba, last_feat = self.render(inputs)\n\n            if uv_map_upsampled is not None:\n                uv_map_up_i = uv_map_upsampled[:, i * 2:(i + 1) * 2, ...]\n                uv_map_up_i = uv_map_up_i.permute(0, 2, 3, 1)\n                sampled_texture_up = F.grid_sample(texture, uv_map_up_i, mode='bilinear', padding_mode='zeros')\n                id_layers_up = F.interpolate(id_layers[:, i:i + 1], size=sampled_texture_up.shape[-2:],\n                                             mode='bilinear')\n                inputs_up = torch.cat([sampled_texture_up, id_layers_up], 1)\n                upsampled_size = inputs_up.shape[-2:]\n                rgba = F.interpolate(rgba, size=upsampled_size, mode='bilinear')\n                last_feat = F.interpolate(last_feat, size=upsampled_size, mode='bilinear')\n                if crop_params is not None:\n                    starty, 
endy, startx, endx = crop_params\n                    rgba = rgba[:, :, starty:endy, startx:endx]\n                    last_feat = last_feat[:, :, starty:endy, startx:endx]\n                    inputs_up = inputs_up[:, :, starty:endy, startx:endx]\n                rgba_residual = self.upsample_block(torch.cat((rgba, inputs_up, last_feat), 1))\n                rgba += .01 * rgba_residual\n                rgba = torch.clamp(rgba, -1, 1)\n                sampled_texture = sampled_texture_up\n\n            # Update the composite with this layer's RGBA output\n            if composite is None:\n                composite = rgba\n            else:\n                alpha = rgba[:, 3:4] * .5 + .5\n                composite = rgba * alpha + composite * (1.0 - alpha)\n            layers.append(rgba)\n            sampled_textures.append(sampled_texture)\n\n        outputs = {\n            'reconstruction': composite,\n            'layers': torch.stack(layers, 1),\n            'sampled texture': sampled_textures,  # for debugging\n        }\n        return outputs\n"
  },
  {
    "path": "options/__init__.py",
    "content": ""
  },
  {
    "path": "options/base_options.py",
    "content": "# Copyright 2020 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport argparse\nimport os\nfrom third_party.util import util\nfrom third_party import models\nfrom third_party import data\nimport torch\nimport json\n\n\nclass BaseOptions():\n    \"\"\"This class defines options used during both training and test time.\n\n    It also implements several helper functions such as parsing, printing, and saving the options.\n    It also gathers additional options defined in <modify_commandline_options> functions in both dataset class and model class.\n    \"\"\"\n\n    def __init__(self):\n        \"\"\"Reset the class; indicates the class hasn't been initialized\"\"\"\n        self.initialized = False\n\n    def initialize(self, parser):\n        \"\"\"Define the common options that are used in both training and test.\"\"\"\n        # basic parameters\n        parser.add_argument('--dataroot', required=True, help='path to images (should have subfolders rgb_256, etc)')\n        parser.add_argument('--name', type=str, default='experiment_name', help='name of the experiment. It decides where to store samples and models')\n        parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0  0,1,2, 0,2. use -1 for CPU')\n        parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')\n        parser.add_argument('--seed', type=int, default=1, help='initial random seed')\n        # model parameters\n        parser.add_argument('--model', type=str, default='lnr', help='chooses which model to use. [lnr | kp2uv]')\n        parser.add_argument('--num_filters', type=int, default=64, help='# filters in the first and last conv layers')\n        # dataset parameters\n        parser.add_argument('--dataset_mode', type=str, default='layered_video', help='chooses how datasets are loaded. [layered_video | kpuv]')\n        parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly')\n        parser.add_argument('--num_threads', default=4, type=int, help='# threads for loading data')\n        parser.add_argument('--batch_size', type=int, default=32, help='input batch size')\n        parser.add_argument('--max_dataset_size', type=int, default=float(\"inf\"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.')\n        parser.add_argument('--display_winsize', type=int, default=256, help='display window size for both visdom and HTML')\n        # additional parameters\n        parser.add_argument('--epoch', type=str, default='latest', help='which epoch to load? 
set to latest to use latest cached model')\n        parser.add_argument('--verbose', action='store_true', help='if specified, print more debugging information')\n        parser.add_argument('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}')\n        self.initialized = True\n        return parser\n\n    def gather_options(self):\n        \"\"\"Initialize our parser with basic options(only once).\n        Add additional model-specific and dataset-specific options.\n        These options are defined in the <modify_commandline_options> function\n        in model and dataset classes.\n        \"\"\"\n        if not self.initialized:  # check if it has been initialized\n            parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)\n            parser = self.initialize(parser)\n\n        # get the basic options\n        opt, _ = parser.parse_known_args()\n\n        # modify model-related parser options\n        model_name = opt.model\n        model_option_setter = models.get_option_setter(model_name)\n        parser = model_option_setter(parser, self.isTrain)\n        opt, _ = parser.parse_known_args()  # parse again with new defaults\n\n        # modify dataset-related parser options\n        dataset_name = opt.dataset_mode\n        dataset_option_setter = data.get_option_setter(dataset_name)\n        parser = dataset_option_setter(parser, self.isTrain)\n\n        # save and return the parser\n        self.parser = parser\n        return parser.parse_args()\n\n    def print_options(self, opt):\n        \"\"\"Print and save options\n\n        It will print both current options and default values(if different).\n        It will save options into a text file / [checkpoints_dir] / opt.txt\n        \"\"\"\n        message = ''\n        message += '----------------- Options ---------------\\n'\n        for k, v in sorted(vars(opt).items()):\n            comment = ''\n            default = self.parser.get_default(k)\n            if v != default:\n                comment = '\\t[default: %s]' % str(default)\n            message += '{:>25}: {:<30}{}\\n'.format(str(k), str(v), comment)\n        message += '----------------- End -------------------'\n        print(message)\n\n        # save to the disk\n        expr_dir = os.path.join(opt.checkpoints_dir, opt.name)\n        util.mkdirs(expr_dir)\n        file_name = os.path.join(expr_dir, '{}_opt.txt'.format(opt.phase))\n        with open(file_name, 'wt') as opt_file:\n            opt_file.write(message)\n            opt_file.write('\\n')\n\n    def parse(self):\n        \"\"\"Parse our options, create checkpoints directory suffix, and set up gpu device.\"\"\"\n        opt = self.gather_options()\n        opt.isTrain = self.isTrain   # train or test\n\n        # process opt.suffix\n        if opt.suffix:\n            suffix = ('_' + opt.suffix.format(**vars(opt))) if opt.suffix != '' else ''\n            opt.name = opt.name + suffix\n\n        self.print_options(opt)\n\n        # set gpu ids\n        str_ids = opt.gpu_ids.split(',')\n        opt.gpu_ids = []\n        for str_id in str_ids:\n            id = int(str_id)\n            if id >= 0:\n                opt.gpu_ids.append(id)\n        if len(opt.gpu_ids) > 0:\n            torch.cuda.set_device(opt.gpu_ids[0])\n\n        self.opt = opt\n        return self.opt\n\n    def parse_dataset_meta(self):\n        \"\"\"Parse options from the 'metadata.json' file in the dataroot.\"\"\"\n        
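# A minimal sketch of the expected file contents; metadata.json must provide\n        # 'n_textures' and 'size_LR' (width, height). The values below are purely\n        # illustrative:  {\"n_textures\": 25, \"size_LR\": [448, 256]}\n        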
with open(os.path.join(self.opt.dataroot, 'metadata.json')) as f:\n            metadata = json.load(f)\n        self.opt.n_textures = metadata['n_textures']\n        self.opt.width = metadata['size_LR'][0]\n        self.opt.height = metadata['size_LR'][1]\n        return self.opt"
  },
  {
    "path": "options/test_options.py",
    "content": "# Copyright 2020 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom .base_options import BaseOptions\n\n\nclass TestOptions(BaseOptions):\n    \"\"\"This class includes test options.\n\n    It also includes shared options defined in BaseOptions.\n    \"\"\"\n\n    def initialize(self, parser):\n        parser = BaseOptions.initialize(self, parser)  # define shared options\n        parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.')\n        parser.add_argument('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images')\n        parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc')\n        parser.add_argument('--num_test', type=int, default=float(\"inf\"), help='how many test images to run')\n        self.isTrain = False\n        return parser\n"
  },
  {
    "path": "options/train_options.py",
    "content": "# Copyright 2020 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom .base_options import BaseOptions\n\n\nclass TrainOptions(BaseOptions):\n    \"\"\"This class includes training options.\n\n    It also includes shared options defined in BaseOptions.\n    \"\"\"\n\n    def initialize(self, parser):\n        parser = BaseOptions.initialize(self, parser)\n        # visdom and HTML visualization parameters\n        parser.add_argument('--display_freq', type=int, default=20, help='frequency of showing training results on screen (in epochs)')\n        parser.add_argument('--display_ncols', type=int, default=0, help='if positive, display all images in a single visdom web panel with certain number of images per row.')\n        parser.add_argument('--display_id', type=int, default=1, help='window id of the web display')\n        parser.add_argument('--display_server', type=str, default=\"http://localhost\", help='visdom server of the web display')\n        parser.add_argument('--display_env', type=str, default='main', help='visdom display environment name (default is \"main\")')\n        parser.add_argument('--display_port', type=int, default=8097, help='visdom port of the web display')\n        parser.add_argument('--update_html_freq', type=int, default=50, help='frequency of saving training results to html')\n        parser.add_argument('--print_freq', type=int, default=10, help='frequency of showing training results on console (in steps per epoch)')\n        parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/')\n        # network saving and loading parameters\n        parser.add_argument('--save_latest_freq', type=int, default=50, help='frequency of saving the latest results (in epochs)')\n        parser.add_argument('--save_by_epoch', action='store_true', help='whether saves model as \"epoch\" or \"latest\" (overwrites previous)')\n        parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model')\n        parser.add_argument('--epoch_count', type=int, default=1, help='the starting epoch count, we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>, ...')\n        parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc')\n        # training parameters\n        parser.add_argument('--n_epochs', type=int, default=2000, help='number of epochs with the initial learning rate')\n        parser.add_argument('--n_epochs_decay', type=int, default=0, help='number of epochs to linearly decay learning rate to zero')\n        parser.add_argument('--lr', type=float, default=0.001, help='initial learning rate for adam')\n        parser.add_argument('--lr_policy', type=str, default='linear', help='learning rate policy. [linear | step | plateau | cosine]')\n\n        self.isTrain = True\n        return parser\n"
  },
  {
    "path": "requirements.txt",
    "content": "torch==1.4.0\ntorchvision>=0.5.0\ndominate>=2.4.0\nvisdom>=0.1.8\nmatplotlib>=3.2.1\nopencv-python>=4.2.0\n"
  },
  {
    "path": "run_kp2uv.py",
    "content": "# Copyright 2020 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Script for running the keypoint-to-UV network and saving the predicted UVs.\n\nExample:\n    python run_kp2uv.py --model kp2uv --dataroot ./datasets/reflection --results_dir ./datasets/reflection\n\nIt will load the pre-trained model from '--checkpoints_dir' and save the results to '--results_dir'.\n\nSee options/base_options.py and options/test_options.py for more test options.\n\"\"\"\nimport os\nfrom options.test_options import TestOptions\nfrom third_party.data import create_dataset\nfrom third_party.models import create_model\nfrom third_party.util.visualizer import save_images\nfrom third_party.util import html\n\n\nif __name__ == '__main__':\n    opt = TestOptions().parse()  # get test options\n    # hard-code some parameters\n    opt.name = 'kp2uv'\n    opt.num_threads = 0   # test code only supports num_threads = 0\n    opt.batch_size = 1    # test code only supports batch_size = 1\n    opt.serial_batches = True  # disable data shuffling; comment this line if results on randomly chosen images are needed.\n    opt.display_id = -1   # no visdom display; the test code saves the results to a HTML file.\n    dataset = create_dataset(opt)  # create a dataset given opt.dataset_mode and other options\n    model = create_model(opt)      # create a model given opt.model and other options\n    model.setup(opt)               # regular setup: load and print networks; create schedulers\n    # create a website\n    web_dir = os.path.join(opt.results_dir, opt.name, '{}_{}'.format(opt.phase, opt.epoch))  # define the website directory\n    print('creating web directory', web_dir)\n    webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.epoch))\n    for i, data in enumerate(dataset):\n        if i >= opt.num_test:  # only apply our model to opt.num_test images.\n            break\n        model.set_input(data)  # unpack data from data loader\n        model.test()           # run inference\n        visuals = model.get_current_visuals()  # get image results\n        img_path = model.get_image_paths()     # get image paths\n        if i % 5 == 0:\n            print('processing (%04d)-th image... %s' % (i, img_path))\n        save_images(webpage, visuals, img_path, aspect_ratio=opt.aspect_ratio, width=opt.display_winsize)\n    webpage.save()  # save the HTML\n"
  },
  {
    "path": "scripts/download_kp2uv_model.sh",
    "content": "#!/bin/bash\n#\n# Copyright 2020 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nmkdir -p ./checkpoints/kp2uv\nMODEL_FILE=./checkpoints/kp2uv/latest_net_Kp2uv.pth\nURL=https://www.robots.ox.ac.uk/~erika/retiming/pretrained_models/kp2uv.pth\nwget -N $URL -O $MODEL_FILE\n"
  },
  {
    "path": "scripts/run_cartwheel.sh",
    "content": "#!/bin/bash\n#\n# Copyright 2020 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nGPUS=0,1\nDATA_PATH=./datasets/cartwheel\nbash datasets/prepare_iuv.sh $DATA_PATH\npython train.py \\\n  --name cartwheel \\\n  --dataroot $DATA_PATH \\\n  --use_homographies \\\n  --gpu_ids $GPUS\npython test.py \\\n  --name cartwheel \\\n  --dataroot $DATA_PATH \\\n  --do_upsampling \\\n  --use_homographies"
  },
  {
    "path": "scripts/run_reflection.sh",
    "content": "#!/bin/bash\n#\n# Copyright 2020 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nGPUS=0,1\nDATA_PATH=./datasets/reflection\nbash datasets/prepare_iuv.sh $DATA_PATH\npython train.py \\\n  --name reflection \\\n  --dataroot $DATA_PATH \\\n  --gpu_ids $GPUS\npython test.py \\\n  --name reflection \\\n  --dataroot $DATA_PATH \\\n  --do_upsampling"
  },
  {
    "path": "scripts/run_splash.sh",
    "content": "#!/bin/bash\n#\n# Copyright 2020 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nGPUS=0,1\nDATA_PATH=./datasets/splash\nbash datasets/prepare_iuv.sh $DATA_PATH\npython train.py \\\n  --name splash \\\n  --dataroot $DATA_PATH \\\n  --batch_size 24 \\\n  --batch_size_upsample 12 \\\n  --use_mask_images \\\n  --gpu_ids $GPUS\npython test.py \\\n  --name splash \\\n  --dataroot $DATA_PATH \\\n  --do_upsampling"
  },
  {
    "path": "scripts/run_trampoline.sh",
    "content": "#!/bin/bash\n#\n# Copyright 2020 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nGPUS=0,1\nDATA_PATH=./datasets/trampoline\nbash datasets/prepare_iuv.sh $DATA_PATH\npython train.py \\\n  --name trampoline \\\n  --dataroot $DATA_PATH \\\n  --batch_size 16 \\\n  --batch_size_upsample 6 \\\n  --gpu_ids $GPUS\npython test.py \\\n  --name trampoline \\\n  --dataroot $DATA_PATH \\\n  --do_upsampling"
  },
  {
    "path": "test.py",
    "content": "# Copyright 2020 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Script to save the full outputs of a layered neural renderer (LNR).\n\nOnce you have trained the LNR with train.py, you can use this script to save the model's final layer decomposition.\nIt will load a saved model from '--checkpoints_dir' and save the results to '--results_dir'.\n\nIt first creates a model and dataset given the options. It will hard-code some parameters.\nIt then runs inference for '--num_test' images and save results to an HTML file.\n\nExample (You need to train models first or download pre-trained models from our website):\n    python test.py --dataroot ./datasets/reflection --name reflection --do_upsampling\n\n    If the upsampling module isn't trained (train.py is used with '--n_epochs_upsample 0'), remove --do_upsampling.\n    Use '--results_dir <directory_path_to_save_result>' to specify the results directory.\n\nSee options/base_options.py and options/test_options.py for more test options.\n\"\"\"\nimport os\nfrom options.test_options import TestOptions\nfrom third_party.data import create_dataset\nfrom third_party.models import create_model\nfrom third_party.util.visualizer import save_images, save_videos\nfrom third_party.util import html\nimport torch\n\n\nif __name__ == '__main__':\n    testopt = TestOptions()\n    testopt.parse()\n    opt = testopt.parse_dataset_meta()\n    # hard-code some parameters for test\n    opt.num_threads = 0   # test code only supports num_threads = 0\n    opt.batch_size = 1    # test code only supports batch_size = 1\n    opt.serial_batches = True  # disable data shuffling; comment this line if results on randomly chosen images are needed.\n    opt.display_id = -1   # no visdom display; the test code saves the results to a HTML file.\n    dataset = create_dataset(opt)  # create a dataset given opt.dataset_mode and other options\n    model = create_model(opt)      # create a model given opt.model and other options\n    model.setup(opt)               # regular setup: load and print networks; create schedulers\n    # create a website\n    web_dir = os.path.join(opt.results_dir, opt.name, '{}_{}'.format(opt.phase, opt.epoch))  # define the website directory\n    print('creating web directory', web_dir)\n    webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.epoch))\n    video_visuals = None\n    for i, data in enumerate(dataset):\n        if i >= opt.num_test:  # only apply our model to opt.num_test images.\n            break\n        model.set_input(data)  # unpack data from data loader\n        model.test()           # run inference\n        img_path = model.get_image_paths()     # get image paths\n        if i % 5 == 0:  # save images to an HTML file\n            print('processing (%04d)-th image... 
%s' % (i, img_path))\n        visuals = model.get_results()  # rgba, reconstruction, original, mask\n        if video_visuals is None:\n            video_visuals = visuals\n        else:\n            for k in video_visuals:\n                video_visuals[k] = torch.cat((video_visuals[k], visuals[k]))\n        rgba = { k: visuals[k] for k in visuals if 'rgba' in k }\n        # save RGBA layers\n        save_images(webpage, rgba, img_path, aspect_ratio=opt.aspect_ratio, width=opt.display_winsize)\n    save_videos(webpage, video_visuals, width=opt.display_winsize)\n    webpage.save()  # save the HTML of videos"
  },
  {
    "path": "third_party/__init__.py",
    "content": ""
  },
  {
    "path": "third_party/data/__init__.py",
    "content": "\"\"\"This package includes all the modules related to data loading and preprocessing\n\n To add a custom dataset class called 'dummy', you need to add a file called 'dummy_dataset.py' and define a subclass 'DummyDataset' inherited from BaseDataset.\n You need to implement four functions:\n    -- <__init__>:                      initialize the class, first call BaseDataset.__init__(self, opt).\n    -- <__len__>:                       return the size of dataset.\n    -- <__getitem__>:                   get a data point from data loader.\n    -- <modify_commandline_options>:    (optionally) add dataset-specific options and set default options.\n\nNow you can use the dataset class by specifying flag '--dataset_mode dummy'.\nSee our template dataset class 'template_dataset.py' for more details.\n\"\"\"\nimport importlib\nimport torch.utils.data\nfrom .base_dataset import BaseDataset\nfrom .fast_data_loader import FastDataLoader\n\n\ndef find_dataset_using_name(dataset_name):\n    \"\"\"Import the module \"data/[dataset_name]_dataset.py\".\n\n    In the file, the class called DatasetNameDataset() will\n    be instantiated. It has to be a subclass of BaseDataset,\n    and it is case-insensitive.\n    \"\"\"\n    dataset_filename = \"data.\" + dataset_name + \"_dataset\"\n    datasetlib = importlib.import_module(dataset_filename)\n\n    dataset = None\n    target_dataset_name = dataset_name.replace('_', '') + 'dataset'\n    for name, cls in datasetlib.__dict__.items():\n        if name.lower() == target_dataset_name.lower() \\\n           and issubclass(cls, BaseDataset):\n            dataset = cls\n\n    if dataset is None:\n        raise NotImplementedError(\"In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase.\" % (dataset_filename, target_dataset_name))\n\n    return dataset\n\n\ndef get_option_setter(dataset_name):\n    \"\"\"Return the static method <modify_commandline_options> of the dataset class.\"\"\"\n    dataset_class = find_dataset_using_name(dataset_name)\n    return dataset_class.modify_commandline_options\n\n\ndef create_dataset(opt, use_fast_loader=False):\n    \"\"\"Create a dataset given the option.\n\n    This function wraps the class CustomDatasetDataLoader.\n        This is the main interface between this package and 'train.py'/'test.py'\n\n    If use_fast_loader=False, use the default pytorch dataloader. Otherwise, use FastDatasetLoader.\n\n    Example:\n        >>> from data import create_dataset\n        >>> dataset = create_dataset(opt)\n    \"\"\"\n    data_loader = CustomDatasetDataLoader(opt, use_fast_loader=use_fast_loader)\n    dataset = data_loader.load_data()\n    return dataset\n\n\nclass CustomDatasetDataLoader():\n    \"\"\"Wrapper class of Dataset class that performs multi-threaded data loading\"\"\"\n\n    def __init__(self, opt, use_fast_loader=False):\n        \"\"\"Initialize this class\n\n        Step 1: create a dataset instance given the name [dataset_mode]\n        Step 2: create a multi-threaded data loader.\n\n        If use_fast_loader=False, use the default pytorch dataloader. 
Otherwise, use FastDataLoader.\n        \"\"\"\n        self.opt = opt\n        dataset_class = find_dataset_using_name(opt.dataset_mode)\n        self.dataset = dataset_class(opt)\n        print(\"dataset [%s] was created\" % type(self.dataset).__name__)\n        loader = torch.utils.data.DataLoader\n        if use_fast_loader:\n            loader = FastDataLoader\n        self.dataloader = loader(\n            self.dataset,\n            batch_size=opt.batch_size,\n            shuffle=not opt.serial_batches,\n            num_workers=int(opt.num_threads))\n\n    def load_data(self):\n        return self\n\n    def __len__(self):\n        \"\"\"Return the number of data in the dataset\"\"\"\n        return min(len(self.dataset), self.opt.max_dataset_size)\n\n    def __iter__(self):\n        \"\"\"Return a batch of data\"\"\"\n        for i, data in enumerate(self.dataloader):\n            if i * self.opt.batch_size >= self.opt.max_dataset_size:\n                break\n            yield data\n"
  },
  {
    "path": "third_party/data/base_dataset.py",
    "content": "\"\"\"This module implements an abstract base class (ABC) 'BaseDataset' for datasets.\n\nIt also includes common transformation functions (e.g., get_transform, __scale_width), which can be later used in subclasses.\n\"\"\"\nimport random\nimport numpy as np\nimport torch.utils.data as data\nfrom PIL import Image\nimport torchvision.transforms as transforms\nfrom abc import ABC, abstractmethod\n\n\nclass BaseDataset(data.Dataset, ABC):\n    \"\"\"This class is an abstract base class (ABC) for datasets.\n\n    To create a subclass, you need to implement the following four functions:\n    -- <__init__>:                      initialize the class, first call BaseDataset.__init__(self, opt).\n    -- <__len__>:                       return the size of dataset.\n    -- <__getitem__>:                   get a data point.\n    -- <modify_commandline_options>:    (optionally) add dataset-specific options and set default options.\n    \"\"\"\n\n    def __init__(self, opt):\n        \"\"\"Initialize the class; save the options in the class\n\n        Parameters:\n            opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions\n        \"\"\"\n        self.opt = opt\n        self.root = opt.dataroot\n\n    @staticmethod\n    def modify_commandline_options(parser, is_train):\n        \"\"\"Add new dataset-specific options, and rewrite default values for existing options.\n\n        Parameters:\n            parser          -- original option parser\n            is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.\n\n        Returns:\n            the modified parser.\n        \"\"\"\n        return parser\n\n    @abstractmethod\n    def __len__(self):\n        \"\"\"Return the total number of images in the dataset.\"\"\"\n        return 0\n\n    @abstractmethod\n    def __getitem__(self, index):\n        \"\"\"Return a data point and its metadata information.\n\n        Parameters:\n            index - - a random integer for data indexing\n\n        Returns:\n            a dictionary of data with their names. 
It usually contains the data itself and its metadata information.\"\"\"\n        pass\n\n\ndef get_params(opt, size):\n    w, h = size\n    new_h = h\n    new_w = w\n    if opt.preprocess == 'resize_and_crop':\n        new_h = new_w = opt.load_size\n    elif opt.preprocess == 'scale_width_and_crop':\n        new_w = opt.load_size\n        new_h = opt.load_size * h // w\n\n    x = random.randint(0, np.maximum(0, new_w - opt.crop_size))\n    y = random.randint(0, np.maximum(0, new_h - opt.crop_size))\n\n    flip = random.random() > 0.5\n\n    return {'crop_pos': (x, y), 'flip': flip}\n\n\ndef get_transform(opt, params=None, grayscale=False, method=Image.BICUBIC, convert=True):\n    transform_list = []\n    if grayscale:\n        transform_list.append(transforms.Grayscale(1))\n    if 'resize' in opt.preprocess:\n        osize = [opt.load_size, opt.load_size]\n        transform_list.append(transforms.Resize(osize, method))\n    elif 'scale_width' in opt.preprocess:\n        transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.load_size, opt.crop_size, method)))\n\n    if 'crop' in opt.preprocess:\n        if params is None:\n            transform_list.append(transforms.RandomCrop(opt.crop_size))\n        else:\n            transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.crop_size)))\n\n    if opt.preprocess == 'none':\n        transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base=4, method=method)))\n\n    if not opt.no_flip:\n        if params is None:\n            transform_list.append(transforms.RandomHorizontalFlip())\n        elif params['flip']:\n            transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip'])))\n\n    if convert:\n        transform_list += [transforms.ToTensor()]\n        if grayscale:\n            transform_list += [transforms.Normalize((0.5,), (0.5,))]\n        else:\n            transform_list += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]\n    return transforms.Compose(transform_list)\n\n\ndef __make_power_2(img, base, method=Image.BICUBIC):\n    ow, oh = img.size\n    h = int(round(oh / base) * base)\n    w = int(round(ow / base) * base)\n    if h == oh and w == ow:\n        return img\n\n    __print_size_warning(ow, oh, w, h)\n    return img.resize((w, h), method)\n\n\ndef __scale_width(img, target_size, crop_size, method=Image.BICUBIC):\n    ow, oh = img.size\n    if ow == target_size and oh >= crop_size:\n        return img\n    w = target_size\n    h = int(max(target_size * oh / ow, crop_size))\n    return img.resize((w, h), method)\n\n\ndef __crop(img, pos, size):\n    ow, oh = img.size\n    x1, y1 = pos\n    tw = th = size\n    if (ow > tw or oh > th):\n        return img.crop((x1, y1, x1 + tw, y1 + th))\n    return img\n\n\ndef __flip(img, flip):\n    if flip:\n        return img.transpose(Image.FLIP_LEFT_RIGHT)\n    return img\n\n\ndef __print_size_warning(ow, oh, w, h):\n    \"\"\"Print warning information about image size (only print once)\"\"\"\n    if not hasattr(__print_size_warning, 'has_printed'):\n        print(\"The image size needs to be a multiple of 4. \"\n              \"The loaded image size was (%d, %d), so it was adjusted to \"\n              \"(%d, %d). This adjustment will be done to all images \"\n              \"whose sizes are not multiples of 4\" % (ow, oh, w, h))\n        __print_size_warning.has_printed = True\n"
  },
  {
    "path": "third_party/data/fast_data_loader.py",
    "content": "\"\"\" Fixes the issue where DataLoader is slow because processes aren't reused\nSee https://github.com/pytorch/pytorch/issues/15849\nWarning: overrides batch sampler.\n\"\"\"\nimport torch.utils.data\n\n\nclass _RepeatSampler(object):\n    \"\"\" Sampler that repeats forever.\n\n    Args:\n        sampler (Sampler)\n    \"\"\"\n\n    def __init__(self, sampler):\n        self.sampler = sampler\n\n    def __iter__(self):\n        while True:\n            yield from iter(self.sampler)\n\n\nclass FastDataLoader(torch.utils.data.dataloader.DataLoader):\n\n    def __init__(self, *args, **kwargs):\n        super().__init__(*args, **kwargs)\n        object.__setattr__(self, 'batch_sampler', _RepeatSampler(self.batch_sampler))\n        self.iterator = super().__iter__()\n\n    def __len__(self):\n        return len(self.batch_sampler.sampler)\n\n    def __iter__(self):\n        for i in range(len(self)):\n            yield next(self.iterator)"
  },
  {
    "path": "third_party/data/image_folder.py",
    "content": "\"\"\"A modified image folder class\n\nWe modify the official PyTorch image folder (https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py)\nso that this class can load images from both current directory and its subdirectories.\n\"\"\"\n\nimport torch.utils.data as data\n\nfrom PIL import Image\nimport os\n\nIMG_EXTENSIONS = [\n    '.jpg', '.JPG', '.jpeg', '.JPEG',\n    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',\n    '.tif', '.TIF', '.tiff', '.TIFF',\n]\n\n\ndef is_image_file(filename):\n    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)\n\n\ndef make_dataset(dir, max_dataset_size=float(\"inf\")):\n    images = []\n    assert os.path.isdir(dir), '%s is not a valid directory' % dir\n\n    for root, _, fnames in sorted(os.walk(dir)):\n        for fname in fnames:\n            if is_image_file(fname):\n                path = os.path.join(root, fname)\n                images.append(path)\n    images = sorted(images)\n    return images[:min(max_dataset_size, len(images))]\n\n\ndef default_loader(path):\n    return Image.open(path).convert('RGB')\n\n\nclass ImageFolder(data.Dataset):\n\n    def __init__(self, root, transform=None, return_paths=False,\n                 loader=default_loader):\n        imgs = make_dataset(root)\n        if len(imgs) == 0:\n            raise(RuntimeError(\"Found 0 images in: \" + root + \"\\n\"\n                               \"Supported image extensions are: \" + \",\".join(IMG_EXTENSIONS)))\n\n        self.root = root\n        self.imgs = imgs\n        self.transform = transform\n        self.return_paths = return_paths\n        self.loader = loader\n\n    def __getitem__(self, index):\n        path = self.imgs[index]\n        img = self.loader(path)\n        if self.transform is not None:\n            img = self.transform(img)\n        if self.return_paths:\n            return img, path\n        else:\n            return img\n\n    def __len__(self):\n        return len(self.imgs)\n"
  },
  {
    "path": "third_party/models/__init__.py",
    "content": "\"\"\"This package contains modules related to objective functions, optimizations, and network architectures.\n\nTo add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel.\nYou need to implement the following five functions:\n    -- <__init__>:                      initialize the class; first call BaseModel.__init__(self, opt).\n    -- <set_input>:                     unpack data from dataset and apply preprocessing.\n    -- <forward>:                       produce intermediate results.\n    -- <optimize_parameters>:           calculate loss, gradients, and update network weights.\n    -- <modify_commandline_options>:    (optionally) add model-specific options and set default options.\n\nIn the function <__init__>, you need to define four lists:\n    -- self.loss_names (str list):          specify the training losses that you want to plot and save.\n    -- self.model_names (str list):         define networks used in our training.\n    -- self.visual_names (str list):        specify the images that you want to display and save.\n    -- self.optimizers (optimizer list):    define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an usage.\n\nNow you can use the model class by specifying flag '--model dummy'.\nSee our template model class 'template_model.py' for more details.\n\"\"\"\n\nimport importlib\nfrom .base_model import BaseModel\n\n\ndef find_model_using_name(model_name):\n    \"\"\"Import the module \"models/[model_name]_model.py\".\n\n    In the file, the class called DatasetNameModel() will\n    be instantiated. It has to be a subclass of BaseModel,\n    and it is case-insensitive.\n    \"\"\"\n    model_filename = \"models.\" + model_name + \"_model\"\n    modellib = importlib.import_module(model_filename)\n    model = None\n    target_model_name = model_name.replace('_', '') + 'model'\n    for name, cls in modellib.__dict__.items():\n        if name.lower() == target_model_name.lower() \\\n           and issubclass(cls, BaseModel):\n            model = cls\n\n    if model is None:\n        print(\"In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase.\" % (model_filename, target_model_name))\n        exit(0)\n\n    return model\n\n\ndef get_option_setter(model_name):\n    \"\"\"Return the static method <modify_commandline_options> of the model class.\"\"\"\n    model_class = find_model_using_name(model_name)\n    return model_class.modify_commandline_options\n\n\ndef create_model(opt):\n    \"\"\"Create a model given the option.\n\n    This function warps the class CustomDatasetDataLoader.\n    This is the main interface between this package and 'train.py'/'test.py'\n\n    Example:\n        >>> from models import create_model\n        >>> model = create_model(opt)\n    \"\"\"\n    model = find_model_using_name(opt.model)\n    instance = model(opt)\n    print(\"model [%s] was created\" % type(instance).__name__)\n    return instance\n"
  },
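  {
    "path": "examples/sketch_model_name_resolution.py",
    "content": "\"\"\"Editor's illustrative sketch of the name-resolution convention used by\nfind_model_using_name above; this file is not part of the original repository.\nA flag '--model dummy' maps to the module 'models.dummy_model' and, inside it,\nto a class whose lowercased name equals 'dummymodel'.\n\"\"\"\n\nmodel_name = 'dummy'  # hypothetical value of the --model flag\n\n# The module name is built by appending '_model' to the flag value.\nmodel_filename = 'models.' + model_name + '_model'  # -> 'models.dummy_model'\n\n# The target class name drops underscores and appends 'model'.\ntarget_model_name = model_name.replace('_', '') + 'model'  # -> 'dummymodel'\n\nprint(model_filename, target_model_name)\n# A class DummyModel defined in models/dummy_model.py satisfies the\n# case-insensitive comparison below, so it is the one instantiated.\nprint('DummyModel'.lower() == target_model_name.lower())  # True\n"
  },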
  {
    "path": "third_party/models/base_model.py",
    "content": "import os\nimport torch\nfrom collections import OrderedDict\nfrom abc import ABC, abstractmethod\nfrom . import networks\n\n\nclass BaseModel(ABC):\n    \"\"\"This class is an abstract base class (ABC) for models.\n    To create a subclass, you need to implement the following five functions:\n        -- <__init__>:                      initialize the class; first call BaseModel.__init__(self, opt).\n        -- <set_input>:                     unpack data from dataset and apply preprocessing.\n        -- <forward>:                       produce intermediate results.\n        -- <optimize_parameters>:           calculate losses, gradients, and update network weights.\n        -- <modify_commandline_options>:    (optionally) add model-specific options and set default options.\n    \"\"\"\n\n    def __init__(self, opt):\n        \"\"\"Initialize the BaseModel class.\n\n        Parameters:\n            opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions\n\n        When creating your custom class, you need to implement your own initialization.\n        In this function, you should first call <BaseModel.__init__(self, opt)>\n        Then, you need to define four lists:\n            -- self.loss_names (str list):          specify the training losses that you want to plot and save.\n            -- self.model_names (str list):         define networks used in our training.\n            -- self.visual_names (str list):        specify the images that you want to display and save.\n            -- self.optimizers (optimizer list):    define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.\n        \"\"\"\n        self.opt = opt\n        self.gpu_ids = opt.gpu_ids\n        self.isTrain = opt.isTrain\n        self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu')  # get device name: CPU or GPU\n        self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)  # save all the checkpoints to save_dir\n        self.loss_names = []\n        self.model_names = []\n        self.visual_names = []\n        self.optimizers = []\n        self.image_paths = []\n        self.metric = 0  # used for learning rate policy 'plateau'\n\n    @staticmethod\n    def modify_commandline_options(parser, is_train):\n        \"\"\"Add new model-specific options, and rewrite default values for existing options.\n\n        Parameters:\n            parser          -- original option parser\n            is_train (bool) -- whether training phase or test phase. 
 You can use this flag to add training-specific or test-specific options.\n\n        Returns:\n            the modified parser.\n        \"\"\"\n        return parser\n\n    @abstractmethod\n    def set_input(self, input):\n        \"\"\"Unpack input data from the dataloader and perform necessary pre-processing steps.\n\n        Parameters:\n            input (dict): includes the data itself and its metadata information.\n        \"\"\"\n        pass\n\n    @abstractmethod\n    def forward(self):\n        \"\"\"Run forward pass; called by both functions <optimize_parameters> and <test>.\"\"\"\n        pass\n\n    @abstractmethod\n    def optimize_parameters(self):\n        \"\"\"Calculate losses, gradients, and update network weights; called in every training iteration\"\"\"\n        pass\n\n    def setup(self, opt):\n        \"\"\"Load and print networks; create schedulers\n\n        Parameters:\n            opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions\n        \"\"\"\n        if self.isTrain:\n            self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers]\n        if not self.isTrain or opt.continue_train:\n            load_suffix = opt.epoch\n            self.load_networks(load_suffix)\n        self.print_networks(opt.verbose)\n\n    def eval(self):\n        \"\"\"Set models to eval mode during test time\"\"\"\n        for name in self.model_names:\n            if isinstance(name, str):\n                net = getattr(self, 'net' + name)\n                net.eval()\n\n    def test(self):\n        \"\"\"Forward function used in test time.\n\n        This function wraps the <forward> function in no_grad() so we don't save intermediate steps for backprop.\n        It also calls <compute_visuals> to produce additional visualization results.\n        \"\"\"\n        with torch.no_grad():\n            self.forward()\n            self.compute_visuals()\n\n    def compute_visuals(self):\n        \"\"\"Calculate additional output images for visdom and HTML visualization\"\"\"\n        pass\n\n    def get_image_paths(self):\n        \"\"\"Return image paths that are used to load the current data\"\"\"\n        return self.image_paths\n\n    def update_learning_rate(self):\n        \"\"\"Update learning rates for all the networks; called at the end of every epoch\"\"\"\n        old_lr = self.optimizers[0].param_groups[0]['lr']\n        for scheduler in self.schedulers:\n            if self.opt.lr_policy == 'plateau':\n                scheduler.step(self.metric)\n            else:\n                scheduler.step()\n\n        lr = self.optimizers[0].param_groups[0]['lr']\n        if old_lr != lr:\n            print('learning rate %.7f -> %.7f' % (old_lr, lr))\n\n    def get_current_visuals(self):\n        \"\"\"Return visualization images. train.py will display these images with visdom, and save the images to an HTML file\"\"\"\n        visual_ret = OrderedDict()\n        for name in self.visual_names:\n            if isinstance(name, str):\n                visual_ret[name] = getattr(self, name)\n        return visual_ret\n\n    def get_current_losses(self):\n        \"\"\"Return training losses / errors. train.py will print out these errors on the console, and save them to a file\"\"\"\n        errors_ret = OrderedDict()\n        for name in self.loss_names:\n            if isinstance(name, str):\n                errors_ret[name] = float(getattr(self, 'loss_' + name))  # float(...) 
works for both scalar tensor and float number\n        return errors_ret\n\n    def save_networks(self, epoch):\n        \"\"\"Save all the networks to the disk.\n\n        Parameters:\n            epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)\n        \"\"\"\n        for name in self.model_names:\n            if isinstance(name, str):\n                save_filename = '%s_net_%s.pth' % (epoch, name)\n                save_path = os.path.join(self.save_dir, save_filename)\n                net = getattr(self, 'net' + name)\n\n                if len(self.gpu_ids) > 0 and torch.cuda.is_available():\n                    torch.save(net.module.cpu().state_dict(), save_path)\n                    net.cuda(self.gpu_ids[0])\n                else:\n                    torch.save(net.cpu().state_dict(), save_path)\n\n    def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0):\n        \"\"\"Fix InstanceNorm checkpoints incompatibility (prior to 0.4)\"\"\"\n        key = keys[i]\n        if i + 1 == len(keys):  # at the end, pointing to a parameter/buffer\n            if module.__class__.__name__.startswith('InstanceNorm') and \\\n                    (key == 'running_mean' or key == 'running_var'):\n                if getattr(module, key) is None:\n                    state_dict.pop('.'.join(keys))\n            if module.__class__.__name__.startswith('InstanceNorm') and \\\n               (key == 'num_batches_tracked'):\n                state_dict.pop('.'.join(keys))\n        else:\n            self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1)\n\n    def load_networks(self, epoch):\n        \"\"\"Load all the networks from the disk.\n\n        Parameters:\n            epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)\n        \"\"\"\n        for name in self.model_names:\n            if isinstance(name, str):\n                load_filename = '%s_net_%s.pth' % (epoch, name)\n                load_path = os.path.join(self.save_dir, load_filename)\n                net = getattr(self, 'net' + name)\n                if isinstance(net, torch.nn.DataParallel):\n                    net = net.module\n                print('loading the model from %s' % load_path)\n                # if you are using PyTorch newer than 0.4 (e.g., built from\n                # GitHub source), you can remove str() on self.device\n                state_dict = torch.load(load_path, map_location=str(self.device))\n                if hasattr(state_dict, '_metadata'):\n                    del state_dict._metadata\n\n                # patch InstanceNorm checkpoints prior to 0.4\n                for key in list(state_dict.keys()):  # need to copy keys here because we mutate in loop\n                    self.__patch_instance_norm_state_dict(state_dict, net, key.split('.'))\n                net.load_state_dict(state_dict)\n\n    def print_networks(self, verbose):\n        \"\"\"Print the total number of parameters in the network and (if verbose) network architecture\n\n        Parameters:\n            verbose (bool) -- if verbose: print the network architecture\n        \"\"\"\n        print('---------- Networks initialized -------------')\n        for name in self.model_names:\n            if isinstance(name, str):\n                net = getattr(self, 'net' + name)\n                num_params = 0\n                num_trainable_params = 0\n                for param in net.parameters():\n                    num_params += 
param.numel()\n                    if param.requires_grad:\n                        num_trainable_params += param.numel()\n                if verbose:\n                    print(net)\n                print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6))\n                print('[Network %s] Total number of trainable parameters : %.3f M' % (name, num_trainable_params / 1e6))\n        print('-----------------------------------------------')\n\n    def set_requires_grad(self, nets, requires_grad=False):\n        \"\"\"Set requires_grad=False for all the networks to avoid unnecessary computations\n        Parameters:\n            nets (network list)   -- a list of networks\n            requires_grad (bool)  -- whether the networks require gradients or not\n        \"\"\"\n        if not isinstance(nets, list):\n            nets = [nets]\n        for net in nets:\n            if net is not None:\n                for param in net.parameters():\n                    param.requires_grad = requires_grad\n"
  },
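  {
    "path": "examples/sketch_dummy_model.py",
    "content": "\"\"\"Editor's illustrative skeleton of a BaseModel subclass following the\nfive-function contract documented in base_model.py above; this file is not part\nof the original repository. The names DummyModel and netG, and options such as\nopt.lr, are hypothetical.\n\"\"\"\nimport torch\nfrom third_party.models.base_model import BaseModel\n\n\nclass DummyModel(BaseModel):\n    def __init__(self, opt):\n        BaseModel.__init__(self, opt)\n        # The four required lists: 'G' in loss_names implies self.loss_G,\n        # 'G' in model_names implies self.netG, and 'fake' in visual_names\n        # implies self.fake.\n        self.loss_names = ['G']\n        self.model_names = ['G']\n        self.visual_names = ['fake']\n        self.netG = torch.nn.Linear(10, 10).to(self.device)\n        if self.isTrain:\n            self.criterion = torch.nn.MSELoss()\n            self.optimizer = torch.optim.Adam(self.netG.parameters(), lr=opt.lr)\n            self.optimizers = [self.optimizer]\n\n    def set_input(self, input):\n        # Unpack the dataloader dict; the 'data'/'target' keys are hypothetical.\n        self.real = input['data'].to(self.device)\n        self.target = input['target'].to(self.device)\n\n    def forward(self):\n        self.fake = self.netG(self.real)\n\n    def optimize_parameters(self):\n        self.forward()\n        self.optimizer.zero_grad()\n        self.loss_G = self.criterion(self.fake, self.target)\n        self.loss_G.backward()\n        self.optimizer.step()\n"
  },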
  {
    "path": "third_party/models/networks.py",
    "content": "import torch\nimport torch.nn as nn\nfrom torch.optim import lr_scheduler\n\n\n###############################################################################\n# Helper Functions\n###############################################################################\ndef get_scheduler(optimizer, opt):\n    \"\"\"Return a learning rate scheduler\n\n    Parameters:\n        optimizer          -- the optimizer of the network\n        opt (option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions．　\n                              opt.lr_policy is the name of learning rate policy: linear | step | plateau | cosine\n\n    For 'linear', we keep the same learning rate for the first <opt.n_epochs> epochs\n    and linearly decay the rate to zero over the next <opt.n_epochs_decay> epochs.\n    For other schedulers (step, plateau, and cosine), we use the default PyTorch schedulers.\n    See https://pytorch.org/docs/stable/optim.html for more details.\n    \"\"\"\n    if opt.lr_policy == 'linear':\n        def lambda_rule(epoch):\n            lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.n_epochs) / float(opt.n_epochs_decay + 1)\n            return lr_l\n\n        scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)\n    elif opt.lr_policy == 'step':\n        scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)\n    elif opt.lr_policy == 'plateau':\n        scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)\n    elif opt.lr_policy == 'cosine':\n        scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.n_epochs, eta_min=0)\n    else:\n        return NotImplementedError('learning rate policy [%s] is not implemented', opt.lr_policy)\n    return scheduler\n\n\ndef init_net(net, gpu_ids=[]):\n    \"\"\"Initialize a network by registering CPU/GPU device (with multi-GPU support)\n    Parameters:\n        net (network)      -- the network to be initialized\n        gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2\n\n    Return an initialized network.\n    \"\"\"\n    if len(gpu_ids) > 0:\n        assert (torch.cuda.is_available())\n        net.to(gpu_ids[0])\n        net = torch.nn.DataParallel(net, gpu_ids)  # multi-GPUs\n    return net\n"
  },
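  {
    "path": "examples/sketch_linear_lr_schedule.py",
    "content": "\"\"\"Editor's illustrative sketch of the 'linear' policy built by get_scheduler\nabove; this file is not part of the original repository. The option values\n(n_epochs=5, n_epochs_decay=5, epoch_count=1) are hypothetical: the multiplier\nstays at 1.0 for the warm epochs and then decays linearly toward zero.\n\"\"\"\nimport torch\nfrom torch.optim import lr_scheduler\n\nn_epochs, n_epochs_decay, epoch_count = 5, 5, 1\n\n\ndef lambda_rule(epoch):\n    # Same rule as in networks.get_scheduler, with the opt values inlined.\n    return 1.0 - max(0, epoch + epoch_count - n_epochs) / float(n_epochs_decay + 1)\n\n\nnet = torch.nn.Linear(2, 2)\noptimizer = torch.optim.Adam(net.parameters(), lr=0.001)\nscheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)\nfor epoch in range(n_epochs + n_epochs_decay):\n    print('epoch %d: lr = %.7f' % (epoch, optimizer.param_groups[0]['lr']))\n    scheduler.step()\n"
  },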
  {
    "path": "third_party/util/__init__.py",
    "content": "\"\"\"This package includes a miscellaneous collection of useful helper functions.\"\"\"\n"
  },
  {
    "path": "third_party/util/html.py",
    "content": "import dominate\nfrom dominate.tags import meta, h3, table, tr, td, p, a, img, br, video, source\nimport os\n\n\nclass HTML:\n    \"\"\"This HTML class allows us to save images and write texts into a single HTML file.\n\n     It consists of functions such as <add_header> (add a text header to the HTML file),\n     <add_images> (add a row of images to the HTML file), and <save> (save the HTML to the disk).\n     It is based on Python library 'dominate', a Python library for creating and manipulating HTML documents using a DOM API.\n    \"\"\"\n\n    def __init__(self, web_dir, title, refresh=0):\n        \"\"\"Initialize the HTML classes\n\n        Parameters:\n            web_dir (str) -- a directory that stores the webpage. HTML file will be created at <web_dir>/index.html; images will be saved at <web_dir/images/\n            title (str)   -- the webpage name\n            refresh (int) -- how often the website refresh itself; if 0; no refreshing\n        \"\"\"\n        self.title = title\n        self.web_dir = web_dir\n        self.img_dir = os.path.join(self.web_dir, 'images')\n        self.vid_dir = os.path.join(self.web_dir, 'videos')\n        if not os.path.exists(self.web_dir):\n            os.makedirs(self.web_dir)\n        if not os.path.exists(self.img_dir):\n            os.makedirs(self.img_dir)\n        if not os.path.exists(self.vid_dir):\n            os.makedirs(self.vid_dir)\n\n        self.doc = dominate.document(title=title)\n        if refresh > 0:\n            with self.doc.head:\n                meta(http_equiv=\"refresh\", content=str(refresh))\n\n    def get_image_dir(self):\n        \"\"\"Return the directory that stores images\"\"\"\n        return self.img_dir\n\n    def get_video_dir(self):\n        \"\"\"Return the directory that stores videos\"\"\"\n        return self.vid_dir\n\n    def add_header(self, text):\n        \"\"\"Insert a header to the HTML file\n\n        Parameters:\n            text (str) -- the header text\n        \"\"\"\n        with self.doc:\n            h3(text)\n\n    def add_images(self, ims, txts, links, width=400):\n        \"\"\"add images to the HTML file\n\n        Parameters:\n            ims (str list)   -- a list of image paths\n            txts (str list)  -- a list of image names shown on the website\n            links (str list) --  a list of hyperref links; when you click an image, it will redirect you to a new page\n        \"\"\"\n        self.t = table(border=1, style=\"table-layout: fixed;\")  # Insert a table\n        self.doc.add(self.t)\n        with self.t:\n            with tr():\n                for im, txt, link in zip(ims, txts, links):\n                    with td(style=\"word-wrap: break-word;\", halign=\"center\", valign=\"top\"):\n                        with p():\n                            with a(href=os.path.join('images', link)):\n                                img(style=\"width:%dpx\" % width, src=os.path.join('images', im))\n                            br()\n                            p(txt)\n\n    def add_videos(self, vids, txts, links, width=400):\n        \"\"\"add images to the HTML file\n\n        Parameters:\n            ims (str list)   -- a list of image paths\n            txts (str list)  -- a list of image names shown on the website\n            links (str list) --  a list of hyperref links; when you click an image, it will redirect you to a new page\n        \"\"\"\n        self.t = table(border=1, style=\"table-layout: fixed;\")  # Insert a table\n        
        self.doc.add(self.t)\n        with self.t:\n            with tr():\n                for vid, txt, link in zip(vids, txts, links):\n                    with td(style=\"word-wrap: break-word;\", halign=\"center\", valign=\"top\"):\n                        with p():\n                            with a(href=os.path.join('videos', link)):\n                                with video(style=\"width:%dpx\" % width, controls=True):\n                                    source(src=os.path.join('videos', vid), type=\"video/mp4\")\n                            br()\n                            p(txt)\n\n    def save(self):\n        \"\"\"Save the current content to the HTML file\"\"\"\n        html_file = '%s/index.html' % self.web_dir\n        with open(html_file, 'wt') as f:\n            f.write(self.doc.render())\n\n\nif __name__ == '__main__':  # we show an example usage here.\n    html = HTML('web/', 'test_html')\n    html.add_header('hello world')\n\n    ims, txts, links = [], [], []\n    for n in range(4):\n        ims.append('image_%d.png' % n)\n        txts.append('text_%d' % n)\n        links.append('image_%d.png' % n)\n    html.add_images(ims, txts, links)\n    html.save()\n"
  },
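  {
    "path": "examples/sketch_html_videos.py",
    "content": "\"\"\"Editor's illustrative sketch of the video branch of the HTML helper above\n(<add_videos> and <get_video_dir>); this file is not part of the original\nrepository. The clip file names are hypothetical placeholders; the actual\nencoding of video files is done elsewhere (see visualizer.save_videos).\n\"\"\"\nfrom third_party.util import html\n\npage = html.HTML('web/', 'test_videos')\npage.add_header('example videos')\n\n# Paths are interpreted relative to <web_dir>/videos/ by add_videos.\nvids = ['clip_0.mp4', 'clip_1.mp4']\ntxts = ['clip 0', 'clip 1']\nlinks = vids  # clicking a video links to the file itself\npage.add_videos(vids, txts, links, width=400)\npage.save()\n"
  },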
  {
    "path": "third_party/util/util.py",
    "content": "\"\"\"This module contains simple helper functions \"\"\"\nfrom __future__ import print_function\nimport torch\nimport numpy as np\nfrom PIL import Image\nimport os\n\n\ndef tensor2im(input_image, imtype=np.uint8):\n    \"\"\"\"Converts a Tensor array into a numpy image array.\n\n    Parameters:\n        input_image (tensor) --  the input image tensor array\n        imtype (type)        --  the desired type of the converted numpy array\n    \"\"\"\n    if not isinstance(input_image, np.ndarray):\n        if isinstance(input_image, torch.Tensor):  # get the data from a variable\n            image_tensor = input_image.data\n        else:\n            return input_image\n        image_numpy = image_tensor[0].cpu().float().numpy()  # convert it into a numpy array\n        if image_numpy.shape[0] == 1:  # grayscale to RGB\n            image_numpy = np.tile(image_numpy, (3, 1, 1))\n        image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0  # post-processing: tranpose and scaling\n    else:  # if it is a numpy array, do nothing\n        image_numpy = input_image\n    return image_numpy.astype(imtype)\n\n\ndef render_png(image, background='checker'):\n    height, width = image.shape[:2]\n    if background == 'checker':\n        checkerboard = np.kron([[136, 120] * (width//128+1), [120, 136] * (width//128+1)] * (height//128+1), np.ones((16, 16)))\n        checkerboard = np.expand_dims(np.tile(checkerboard, (4, 4)), -1)\n        bg = checkerboard[:height, :width]\n    elif background == 'black':\n        bg = np.zeros([height, width, 1])\n    else:\n        bg = 255 * np.ones([height, width, 1])\n    image = image.astype(np.float32)\n    alpha = image[:, :, 3:] / 255\n    rendered_image = alpha * image[:, :, :3] + (1 - alpha) * bg\n    return rendered_image.astype(np.uint8)\n\n\ndef diagnose_network(net, name='network'):\n    \"\"\"Calculate and print the mean of average absolute(gradients)\n\n    Parameters:\n        net (torch network) -- Torch network\n        name (str) -- the name of the network\n    \"\"\"\n    mean = 0.0\n    count = 0\n    for param in net.parameters():\n        if param.grad is not None:\n            mean += torch.mean(torch.abs(param.grad.data))\n            count += 1\n    if count > 0:\n        mean = mean / count\n    print(name)\n    print(mean)\n\n\ndef save_image(image_numpy, image_path, aspect_ratio=1.0):\n    \"\"\"Save a numpy image to the disk\n\n    Parameters:\n        image_numpy (numpy array) -- input numpy array\n        image_path (str)          -- the path of the image\n    \"\"\"\n\n    image_pil = Image.fromarray(image_numpy)\n    h, w, _ = image_numpy.shape\n\n    if aspect_ratio > 1.0:\n        image_pil = image_pil.resize((h, int(w * aspect_ratio)), Image.BICUBIC)\n    if aspect_ratio < 1.0:\n        image_pil = image_pil.resize((int(h / aspect_ratio), w), Image.BICUBIC)\n    image_pil.save(image_path)\n\n\ndef print_numpy(x, val=True, shp=False):\n    \"\"\"Print the mean, min, max, median, std, and size of a numpy array\n\n    Parameters:\n        val (bool) -- if print the values of the numpy array\n        shp (bool) -- if print the shape of the numpy array\n    \"\"\"\n    x = x.astype(np.float64)\n    if shp:\n        print('shape,', x.shape)\n    if val:\n        x = x.flatten()\n        print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % (\n            np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x)))\n\n\ndef mkdirs(paths):\n    \"\"\"create empty directories if they don't 
 exist\n\n    Parameters:\n        paths (str list) -- a list of directory paths\n    \"\"\"\n    if isinstance(paths, list) and not isinstance(paths, str):\n        for path in paths:\n            mkdir(path)\n    else:\n        mkdir(paths)\n\n\ndef mkdir(path):\n    \"\"\"create a single empty directory if it doesn't exist\n\n    Parameters:\n        path (str) -- a single directory path\n    \"\"\"\n    if not os.path.exists(path):\n        os.makedirs(path)\n"
  },
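  {
    "path": "examples/sketch_util_tensor2im.py",
    "content": "\"\"\"Editor's illustrative sketch of the tensor-to-image helpers above; this file\nis not part of the original repository. tensor2im assumes an (N, C, H, W)\ntensor scaled to [-1, 1] and returns the first item as an (H, W, 3) uint8\narray; the output directory here is a hypothetical placeholder.\n\"\"\"\nimport torch\nfrom third_party.util import util\n\nfake = torch.rand(1, 3, 64, 64) * 2 - 1  # a batch of one RGB image in [-1, 1]\nimage_numpy = util.tensor2im(fake)\nprint(image_numpy.shape, image_numpy.dtype)  # (64, 64, 3) uint8\n\nutil.mkdir('./tmp_out')\nutil.save_image(image_numpy, './tmp_out/example.png')\n"
  },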
  {
    "path": "third_party/util/visualizer.py",
    "content": "import cv2\nimport numpy as np\nimport os\nimport sys\nimport ntpath\nimport time\nfrom . import util, html\nfrom subprocess import Popen, PIPE\n\n\nif sys.version_info[0] == 2:\n    VisdomExceptionBase = Exception\nelse:\n    VisdomExceptionBase = ConnectionError\n\n\ndef save_images(webpage, visuals, image_path, aspect_ratio=1.0, width=256):\n    \"\"\"Save images to the disk.\n\n    Parameters:\n        webpage (the HTML class) -- the HTML webpage class that stores these imaegs (see html.py for more details)\n        visuals (OrderedDict)    -- an ordered dictionary that stores (name, images (either tensor or numpy) ) pairs\n        image_path (str)         -- the string is used to create image paths\n        aspect_ratio (float)     -- the aspect ratio of saved images\n        width (int)              -- the images will be resized to width x width\n\n    This function will save images stored in 'visuals' to the HTML file specified by 'webpage'.\n    \"\"\"\n    image_dir = webpage.get_image_dir()\n    short_path = ntpath.basename(image_path[0])\n    name = os.path.splitext(short_path)[0]\n\n    webpage.add_header(name)\n    ims, txts, links = [], [], []\n\n    for label, im_data in visuals.items():\n        im = util.tensor2im(im_data)\n        image_name = '%s_%s.png' % (name, label)\n        save_path = os.path.join(image_dir, image_name)\n        util.save_image(im, save_path, aspect_ratio=aspect_ratio)\n        ims.append(image_name)\n        txts.append(label)\n        links.append(image_name)\n    webpage.add_images(ims, txts, links, width=width)\n\n\ndef save_videos(webpage, visuals, width=256):\n    \"\"\"Save videos to the disk.\n\n    Parameters:\n        webpage (the HTML class) -- the HTML webpage class that stores these videos (see html.py for more details)\n        visuals (OrderedDict)    -- an ordered dictionary that stores (name, video (either tensor or numpy) ) pairs\n        save_dir (str)           -- the string is used to create video paths\n        aspect_ratio (float)     -- the aspect ratio of saved images\n        width (int)              -- the images will be resized to width x width\n\n    This function will save videos stored in 'visuals' to the HTML file specified by 'webpage'.\n    \"\"\"\n    video_dir = webpage.get_video_dir()\n    webpage.add_header('videos')\n    vids, txts, links = [], [], []\n\n    for label, vid_data in sorted(visuals.items()):\n        video_name = f'{label}.webm'\n        video_path = os.path.join(video_dir, video_name)\n        frame_height, frame_width = vid_data.shape[-2:]\n        video = cv2.VideoWriter(video_path, cv2.VideoWriter_fourcc(*'vp80'), 25, (frame_width, frame_height))\n        for i in range(vid_data.shape[0]):\n            frame = util.tensor2im(vid_data[i:i+1])\n            if frame.shape[-1] == 4:\n                # render png\n                frame = util.render_png(frame, background='checker')\n            frame = frame[:, :, ::-1]  # RGB -> BGR\n            video.write(frame)\n        video.release()\n        cv2.destroyAllWindows()\n        print(\"You may see an OpenCV 'vp80 not supported' error message despite the video saving correctly. 
        vids.append(video_name)\n        txts.append(label)\n        links.append(video_name)\n    webpage.add_videos(vids, txts, links, width=width)\n\n\nclass Visualizer():\n    \"\"\"This class includes several functions that can display/save images and print/save logging information.\n\n    It uses a Python library 'visdom' for display, and a Python library 'dominate' (wrapped in 'HTML') for creating HTML files with images.\n    \"\"\"\n\n    def __init__(self, opt):\n        \"\"\"Initialize the Visualizer class\n\n        Parameters:\n            opt -- stores all the experiment flags; needs to be a subclass of BaseOptions\n        Step 1: Cache the training/test options\n        Step 2: connect to a visdom server\n        Step 3: create an HTML object for saving HTML files\n        Step 4: create a logging file to store training losses\n        \"\"\"\n        self.opt = opt  # cache the option\n        self.display_id = opt.display_id\n        self.use_html = opt.isTrain and not opt.no_html\n        self.win_size = opt.display_winsize\n        self.name = opt.name\n        self.port = opt.display_port\n        self.saved = False\n        if self.display_id > 0:  # connect to a visdom server given <display_port> and <display_server>\n            import visdom\n            self.ncols = opt.display_ncols\n            self.vis = visdom.Visdom(server=opt.display_server, port=opt.display_port, env=opt.display_env)\n            if not self.vis.check_connection():\n                self.create_visdom_connections()\n\n        if self.use_html:  # create an HTML object at <checkpoints_dir>/web/; images will be saved under <checkpoints_dir>/web/images/\n            self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web')\n            self.img_dir = os.path.join(self.web_dir, 'images')\n            print('create web directory %s...' % self.web_dir)\n            util.mkdirs([self.web_dir, self.img_dir])\n        # create a logging file to store training losses\n        self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt')\n        with open(self.log_name, \"a\") as log_file:\n            now = time.strftime(\"%c\")\n            log_file.write('================ Training Loss (%s) ================\\n' % now)\n\n    def reset(self):\n        \"\"\"Reset the self.saved status\"\"\"\n        self.saved = False\n\n    def create_visdom_connections(self):\n        \"\"\"If the program could not connect to a Visdom server, this function will start a new server at port <self.port>\"\"\"\n        cmd = sys.executable + ' -m visdom.server -p %d &>/dev/null &' % self.port\n        print('\\n\\nCould not connect to Visdom server.
 \n Trying to start a server...')\n        print('Command: %s' % cmd)\n        Popen(cmd, shell=True, stdout=PIPE, stderr=PIPE)\n\n    def display_current_results(self, visuals, epoch, save_result):\n        \"\"\"Display current results on visdom; save current results to an HTML file.\n\n        Parameters:\n            visuals (OrderedDict) -- dictionary of images to display or save\n            epoch (int) -- the current epoch\n            save_result (bool) -- whether to save the current results to an HTML file\n        \"\"\"\n        if self.display_id > 0:  # show images in the browser using visdom\n            ncols = self.ncols\n            if ncols > 0:        # show all the images in one visdom panel\n                ncols = min(ncols, len(visuals))\n                h, w = next(iter(visuals.values())).shape[:2]\n                table_css = \"\"\"<style>\n                        table {border-collapse: separate; border-spacing: 4px; white-space: nowrap; text-align: center}\n                        table td {width: %dpx; height: %dpx; padding: 4px; outline: 4px solid black}\n                        </style>\"\"\" % (w, h)  # create a table css\n                # create a table of images.\n                title = self.name\n                label_html = ''\n                label_html_row = ''\n                images = []\n                idx = 0\n                for label, image in visuals.items():\n                    image_numpy = util.tensor2im(image)\n                    label_html_row += '<td>%s</td>' % label\n                    images.append(image_numpy.transpose([2, 0, 1]))\n                    idx += 1\n                    if idx % ncols == 0:\n                        label_html += '<tr>%s</tr>' % label_html_row\n                        label_html_row = ''\n                white_image = np.ones_like(image_numpy.transpose([2, 0, 1])) * 255\n                while idx % ncols != 0:\n                    images.append(white_image)\n                    label_html_row += '<td></td>'\n                    idx += 1\n                if label_html_row != '':\n                    label_html += '<tr>%s</tr>' % label_html_row\n                try:\n                    self.vis.images(images, nrow=ncols, win=self.display_id + 1,\n                                    padding=2, opts=dict(title=title + ' images'))\n                    label_html = '<table>%s</table>' % label_html\n                    self.vis.text(table_css + label_html, win=self.display_id + 2,\n                                  opts=dict(title=title + ' labels'))\n                except VisdomExceptionBase:\n                    self.create_visdom_connections()\n\n            else:     # show each image in a separate visdom panel;\n                idx = 1\n                try:\n                    for label, image in visuals.items():\n                        image_numpy = util.tensor2im(image)\n                        self.vis.image(image_numpy.transpose([2, 0, 1]), opts=dict(title=label),\n                                       win=self.display_id + idx)\n                        idx += 1\n                except VisdomExceptionBase:\n                    self.create_visdom_connections()\n\n        if self.use_html and (save_result or not self.saved):  # save images to an HTML file if they haven't been saved.\n            self.saved = True\n            # save images to the disk\n            for label, image in visuals.items():\n                image_numpy = util.tensor2im(image)\n                img_path =
 os.path.join(self.img_dir, 'epoch%.3d_%s.png' % (epoch, label))\n                util.save_image(image_numpy, img_path)\n\n            # update website\n            webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, refresh=1)\n            for n in range(epoch, 0, -1):\n                label = list(visuals.keys())[0]\n                img_path = 'epoch%.3d_%s.png' % (n, label)\n                if not os.path.exists(os.path.join(webpage.img_dir, img_path)):\n                    continue\n                webpage.add_header('epoch [%d]' % n)\n                ims, txts, links = [], [], []\n\n                for label, image_numpy in visuals.items():\n                    img_path = 'epoch%.3d_%s.png' % (n, label)\n                    ims.append(img_path)\n                    txts.append(label)\n                    links.append(img_path)\n                webpage.add_images(ims, txts, links, width=self.win_size)\n            webpage.save()\n\n    def plot_current_losses(self, epoch, counter_ratio, losses):\n        \"\"\"Display the current losses on the visdom display: dictionary of error labels and values\n\n        Parameters:\n            epoch (int)           -- current epoch\n            counter_ratio (float) -- progress (percentage) in the current epoch, between 0 and 1\n            losses (OrderedDict)  -- training losses stored in the format of (name, float) pairs\n        \"\"\"\n        if not hasattr(self, 'plot_data'):\n            self.plot_data = {'X': [], 'Y': [], 'legend': list(losses.keys())}\n        self.plot_data['X'].append(epoch + counter_ratio)\n        self.plot_data['Y'].append([losses[k] for k in self.plot_data['legend']])\n        try:\n            self.vis.line(\n                X=np.stack([np.array(self.plot_data['X'])] * len(self.plot_data['legend']), 1),\n                Y=np.array(self.plot_data['Y']),\n                opts={\n                    'title': self.name + ' loss over time',\n                    'legend': self.plot_data['legend'],\n                    'xlabel': 'epoch',\n                    'ylabel': 'loss'},\n                win=self.display_id)\n        except VisdomExceptionBase:\n            self.create_visdom_connections()\n\n    # losses: same format as |losses| of plot_current_losses\n    def print_current_losses(self, epoch, iters, losses, t_comp, t_data):\n        \"\"\"Print current losses on the console; also save the losses to the disk\n\n        Parameters:\n            epoch (int) -- current epoch\n            iters (int) -- current training iteration during this epoch (reset to 0 at the end of every epoch)\n            losses (OrderedDict) -- training losses stored in the format of (name, float) pairs\n            t_comp (float) -- computational time per data point (normalized by batch_size)\n            t_data (float) -- data loading time per data point (normalized by batch_size)\n        \"\"\"\n        message = '(epoch: %d, iters: %d, time: %.3f, data: %.3f) ' % (epoch, iters, t_comp, t_data)\n        for k, v in losses.items():\n            message += '%s: %.3f ' % (k, v)\n\n        print(message)  # print the message\n        with open(self.log_name, \"a\") as log_file:\n            log_file.write('%s\\n' % message)  # save the message\n"
  },
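  {
    "path": "examples/sketch_visualizer_save_images.py",
    "content": "\"\"\"Editor's illustrative sketch of pushing a visuals dictionary through\nsave_images above; this file is not part of the original repository. It assumes\ncv2, dominate, and torch are installed; the directory and frame path are\nhypothetical placeholders.\n\"\"\"\nfrom collections import OrderedDict\n\nimport torch\n\nfrom third_party.util import html\nfrom third_party.util.visualizer import save_images\n\nwebpage = html.HTML('./tmp_web', 'example')\n\n# visuals maps a label to an (N, C, H, W) tensor in [-1, 1], as produced by\n# BaseModel.get_current_visuals().\nvisuals = OrderedDict([\n    ('input', torch.rand(1, 3, 64, 64) * 2 - 1),\n    ('output', torch.rand(1, 3, 64, 64) * 2 - 1),\n])\n\n# image_path is a list; the basename of its first entry names the saved files.\nsave_images(webpage, visuals, ['./frames/0001.png'], width=256)\nwebpage.save()\n"
  },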
  {
    "path": "train.py",
    "content": "# Copyright 2020 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Script for training a layered neural renderer on a video.\n\nYou need to specify the dataset ('--dataroot') and experiment name ('--name').\n\nExample:\n    python train.py --dataroot ./datasets/reflection --name reflection --gpu_ids 0,1\n\nThe script first creates a model, dataset, and visualizer given the options.\nIt then does standard network training. During training, it also visualizes/saves the images, prints/saves the loss\nplot, and saves the model.\nUse '--continue_train' to resume your previous training.\n\nThe default setting is to first train the base model, which produces the low-resolution result (256x448), and then\ntrain the upsampling module to produce the 512x896 result. If the upsampling module is unnecessary, use\n'--n_epochs_upsample 0'.\n\nSee options/base_options.py and options/train_options.py for more training options.\n\"\"\"\nimport time\nfrom options.train_options import TrainOptions\nfrom third_party.data import create_dataset\nfrom third_party.models import create_model\nfrom third_party.util.visualizer import Visualizer\nimport torch\nimport numpy as np\n\n\ndef main():\n    trainopt = TrainOptions()\n    trainopt.parse()\n    opt = trainopt.parse_dataset_meta()\n\n    torch.manual_seed(opt.seed)\n    np.random.seed(opt.seed)\n\n    opt.do_upsampling = False  # Train low-res network first\n    dataset = create_dataset(opt, use_fast_loader=True)\n    dataset_size = len(dataset)\n    print('The number of training images = %d' % dataset_size)\n\n    model = create_model(opt)\n    model.setup(opt)  # regular setup: load and print networks; create schedulers\n    visualizer = Visualizer(opt)\n\n    # Train base model (produces low-resolution output)\n    train(model, dataset, visualizer, opt)\n\n    # Optionally train upsampling module\n    if opt.n_epochs_upsample > 0:\n        opt.do_upsampling = True\n        opt.batch_size = opt.batch_size_upsample\n        # load dataset for upsampling\n        dataset = create_dataset(opt, use_fast_loader=True)\n        dataset_size = len(dataset)\n        print('The number of training images = %d' % dataset_size)\n\n        # set lambdas for upsampling training\n        opt.lambda_mask = 0\n        opt.lambda_alpha_l0 = 0\n        opt.lambda_alpha_l1 = 0\n        opt.mask_loss_rolloff_epoch = -1\n        opt.jitter_rgb = 0\n\n        # reinit optimizers and schedulers, lambdas\n        model.setup_train(opt)\n        # freeze base model and just train upsampling module\n        model.freeze_basenet()\n        model.setup(opt)\n\n        # update epoch count to resume training\n        opt.epoch_count = opt.n_epochs + opt.n_epochs_decay + 1\n        opt.n_epochs += opt.n_epochs_upsample\n        \n        train(model, dataset, visualizer, opt)\n\n\ndef train(model, dataset, visualizer, opt):\n    dataset_size = len(dataset)\n    total_iters = 0  # the total number of training iterations\n\n    for epoch in 
range(opt.epoch_count, opt.n_epochs + opt.n_epochs_decay + 1):  # outer loop over epochs; the model is saved every <save_latest_freq> epochs\n        epoch_start_time = time.time()  # timer for entire epoch\n        iter_data_time = time.time()    # timer for data loading per iteration\n        epoch_iter = 0                  # the number of training iterations in current epoch, reset to 0 every epoch\n        model.update_lambdas(epoch)\n        for i, data in enumerate(dataset):  # inner loop within one epoch\n            iter_start_time = time.time()  # timer for computation per iteration\n            if i % opt.print_freq == 0:\n                t_data = iter_start_time - iter_data_time\n\n            total_iters += opt.batch_size\n            epoch_iter += opt.batch_size\n            model.set_input(data)\n            model.optimize_parameters()\n\n            if i % opt.print_freq == 0:  # print training losses and save logging information to the disk\n                losses = model.get_current_losses()\n                t_comp = (time.time() - iter_start_time) / opt.batch_size\n                visualizer.print_current_losses(epoch, epoch_iter, losses, t_comp, t_data)\n                if opt.display_id > 0:\n                    visualizer.plot_current_losses(epoch, float(epoch_iter) / dataset_size, losses)\n\n            iter_data_time = time.time()\n\n        if epoch % opt.display_freq == 1:   # display images on visdom and save images to an HTML file\n            save_result = epoch % opt.update_html_freq == 1\n            model.compute_visuals()\n            visualizer.display_current_results(model.get_current_visuals(), epoch, save_result)\n\n        if epoch % opt.save_latest_freq == 0:   # cache our latest model every <save_latest_freq> epochs\n            print('saving the latest model (epoch %d, total_iters %d)' % (epoch, total_iters))\n            save_suffix = 'epoch_%d' % epoch if opt.save_by_epoch else 'latest'\n            model.save_networks(save_suffix)\n\n        model.update_learning_rate()    # update learning rates at the end of every epoch.\n        print('End of epoch %d / %d \\t Time Taken: %d sec' % (epoch, opt.n_epochs + opt.n_epochs_decay, time.time() - epoch_start_time))\n\n\nif __name__ == '__main__':\n    main()\n"
  }
]