[
  {
    "path": ".gitignore",
    "content": "*.pyc\n*.so\n*.o"
  },
  {
    "path": "LICENSE",
    "content": "                                 Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative Works\" shall mean any work, whether in Source or Object\n      
form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. 
Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. 
You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. 
You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n\n   APPENDIX: How to apply the Apache License to your work.\n\n      To apply the Apache License to your work, attach the following\n      boilerplate notice, with the fields enclosed by brackets \"[]\"\n      replaced with your own identifying information. 
(Don't include\n      the brackets!)  The text should be enclosed in the appropriate\n      comment syntax for the file format. We also recommend that a\n      file or class name and description of purpose be included on the\n      same \"printed page\" as the copyright notice for easier\n      identification within third-party archives.\n\n   Copyright [yyyy] [name of copyright owner]\n\n   Licensed under the Apache License, Version 2.0 (the \"License\");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       http://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License.\n"
  },
  {
    "path": "README.md",
    "content": "# Iterative Residual Refinement <br/> for Joint Optical Flow and Occlusion Estimation\n\n<img src=output.gif>\n\nThis repository is the PyTorch implementation of the paper:\n\n**Iterative Residual Refinement for Joint Optical Flow and Occlusion Estimation (CVPR 2019)**  \n[Junhwa Hur](https://sites.google.com/site/hurjunhwa) and [Stefan Roth](https://www.visinf.tu-darmstadt.de/team_members/sroth/sroth.en.jsp)  \nDepartment of Computer Science, TU Darmstadt  \n[[Preprint]](https://arxiv.org/pdf/1904.05290.pdf) &ensp; [[Proceeding]](http://openaccess.thecvf.com/content_CVPR_2019/papers/Hur_Iterative_Residual_Refinement_for_Joint_Optical_Flow_and_Occlusion_Estimation_CVPR_2019_paper.pdf) &ensp; [[Supplemental]](http://openaccess.thecvf.com/content_CVPR_2019/supplemental/Hur_Iterative_Residual_Refinement_CVPR_2019_supplemental.pdf)\n\n\nPlease cite the paper below if you find our paper and source codes are useful.  \n\n    @inproceedings{Hur:2019:IRR,  \n      Author = {Junhwa Hur and Stefan Roth},  \n      Booktitle = {CVPR},  \n      Title = {Iterative Residual Refinement for Joint Optical Flow and Occlusion Estimation},  \n      Year = {2019}  \n    }\n\nContact: junhwa.hur[at]visinf.tu-darmstadt.de\n\n## Getting started\nThis code has been orginally developed under Anaconda(Python 3.6), PyTorch 0.4.1 and CUDA 8.0 on Ubuntu 16.04.\n\n1. Please install the followings:\n\n   - Anaconda\n   - PyTorch (now compatible with __PyTorch 1.5.0__)\n   - tqdm (`conda install -c conda-forge tqdm==4.40.0`)\n   - (any missing packages that the code requires)\n\n\n2. 
The datasets used for this projects are followings:\n    - [FlyingChairsOcc dataset](https://github.com/visinf/irr/tree/master/flyingchairsocc)\n    - [FlyingThings3D subset](https://lmb.informatik.uni-freiburg.de/resources/datasets/SceneFlowDatasets.en.html)\n    - [MPI Sintel Dataset](http://sintel.is.tue.mpg.de/downloads) + [revised occlusion GT](https://download.visinf.tu-darmstadt.de/data/flyingchairs_occ/occlusions_rev.zip)\n    - [KITTI Optical Flow 2015](http://www.cvlibs.net/datasets/kitti/eval_scene_flow.php?benchmark=flow) and [KITTI Optical Flow 2012](http://www.cvlibs.net/datasets/kitti/eval_stereo_flow.php?benchmark=flow)\n    \n\n  \n## Training\n\nThe `scripts` folder contains training scripts of experiments demonstrated in the paper.  \nTo train the model, you can simply run the script file, e.g., `./IRR-PWC_flyingChairsOcc.sh`.  \nIn script files, please configure your own experiment directory (EXPERIMENTS_HOME) and dataset directory in your local system (e.g., SINTEL_HOME or KITTI_HOME).\n\n\n## Pretrained Models\n\nThe `saved_check_point` contains the pretrained models of *i)* baseline, *ii)* baseline + irr, and *iii)* full models.  \nAdditional pretrained models in the ablations study (Table 1 in the main paper) and their training scripts are available upon request.\n\n  \n## Inference\n\nThe scripts for testing the pre-trained models are located in `scripts/validation`.\n\n\n## Acknowledgement\n\nPortions of the source code (e.g., training pipeline, runtime, argument parser, and logger) are from [Jochen Gast](https://scholar.google.com/citations?user=tmRcFacAAAAJ&hl=en)\n\n"
  },
  {
    "path": "__init__.py",
    "content": ""
  },
  {
    "path": "augmentations.py",
    "content": "## Portions of Code from, copyright 2018 Jochen Gast\n\nfrom __future__ import absolute_import, division, print_function\n\nimport torch\nimport torch.nn as nn\nfrom utils.interpolation import Interp2, Interp2MaskBinary\nfrom utils.interpolation import Meshgrid\nimport numpy as np\n\n\ndef denormalize_coords(xx, yy, width, height):\n    \"\"\" scale indices from [-1, 1] to [0, width/height] \"\"\"\n    xx = 0.5 * (width - 1.0) * (xx.float() + 1.0)\n    yy = 0.5 * (height - 1.0) * (yy.float() + 1.0)\n    return xx, yy\n\n\ndef normalize_coords(xx, yy, width, height):\n    \"\"\" scale indices from [0, width/height] to [-1, 1] \"\"\"\n    xx = (2.0 / (width - 1.0)) * xx.float() - 1.0\n    yy = (2.0 / (height - 1.0)) * yy.float() - 1.0\n    return xx, yy\n\n\ndef apply_transform_to_params(theta0, theta_transform):\n    a1 = theta0[:, 0]\n    a2 = theta0[:, 1]\n    a3 = theta0[:, 2]\n    a4 = theta0[:, 3]\n    a5 = theta0[:, 4]\n    a6 = theta0[:, 5]\n    #\n    b1 = theta_transform[:, 0]\n    b2 = theta_transform[:, 1]\n    b3 = theta_transform[:, 2]\n    b4 = theta_transform[:, 3]\n    b5 = theta_transform[:, 4]\n    b6 = theta_transform[:, 5]\n    #\n    c1 = a1 * b1 + a4 * b2\n    c2 = a2 * b1 + a5 * b2\n    c3 = b3 + a3 * b1 + a6 * b2\n    c4 = a1 * b4 + a4 * b5\n    c5 = a2 * b4 + a5 * b5\n    c6 = b6 + a3 * b4 + a6 * b5\n    #\n    new_theta = torch.stack([c1, c2, c3, c4, c5, c6], dim=1)\n    return new_theta\n\n\nclass _IdentityParams(nn.Module):\n    def __init__(self):\n        super(_IdentityParams, self).__init__()\n        self._batch_size = 0\n        self.register_buffer(\"_o\", torch.FloatTensor())\n        self.register_buffer(\"_i\", torch.FloatTensor())\n\n    def _update(self, batch_size):\n        torch.zeros([batch_size, 1], out=self._o)\n        torch.ones([batch_size, 1], out=self._i)\n        return torch.cat([self._i, self._o, self._o, self._o, self._i, self._o], dim=1)\n\n    def forward(self, batch_size):\n        if 
self._batch_size != batch_size:\n            self._identity_params = self._update(batch_size)\n            self._batch_size = batch_size\n        return self._identity_params\n\n\nclass RandomMirror(nn.Module):\n    def __init__(self, vertical=True, p=0.5):\n        super(RandomMirror, self).__init__()\n        self._batch_size = 0\n        self._p = p\n        self._vertical = vertical\n        self.register_buffer(\"_mirror_probs\", torch.FloatTensor())\n\n    def update_probs(self, batch_size):\n        torch.ones([batch_size, 1], out=self._mirror_probs)\n        self._mirror_probs *= self._p\n\n    def forward(self, theta1, theta2):\n        batch_size = theta1.size(0)\n        if batch_size != self._batch_size:\n            self.update_probs(batch_size)\n            self._batch_size = batch_size\n\n        # apply random sign to a1 a2 a3 (these are the guys responsible for x)\n        sign = torch.sign(2.0 * torch.bernoulli(self._mirror_probs) - 1.0)\n        i = torch.ones_like(sign)\n        horizontal_mirror = torch.cat([sign, sign, sign, i, i, i], dim=1)\n        theta1 *= horizontal_mirror\n        theta2 *= horizontal_mirror\n\n        # apply random sign to a4 a5 a6 (these are the guys responsible for y)\n        if self._vertical:\n            sign = torch.sign(2.0 * torch.bernoulli(self._mirror_probs) - 1.0)\n            vertical_mirror = torch.cat([i, i, i, sign, sign, sign], dim=1)\n            theta1 *= vertical_mirror\n            theta2 *= vertical_mirror\n\n        return theta1, theta2\n\n\nclass RandomCrop(nn.Module):\n    \"\"\"Crops the given PIL.Image at a random location to have a region of\n    the given size. 
size can be a tuple (target_height, target_width)\n    or an integer, in which case the target will be of a square shape (size, size)\n    \"\"\"\n\n    def __init__(self, crop):\n        super(RandomCrop, self).__init__()\n        self._crop_size = crop\n        self.register_buffer(\"_x\", torch.LongTensor())\n        self.register_buffer(\"_y\", torch.LongTensor())\n\n    def forward(self, im1, im2, flo):\n        batch_size, _, height, width = im1.size()\n        crop_height, crop_width = self._crop_size\n\n        # check whether there is anything to do\n        if any(self._size < 1):\n            return im1, im2, flo\n\n        # get starting positions\n        self._x.random_(0, width - crop_width)\n        self._y.random_(0, height - crop_height)\n\n        im1 = im1[:, :, self._y:self._y + crop_height, self._x:self._x + crop_width]\n        im2 = im2[:, :, self._y:self._y + crop_height, self._x:self._x + crop_width]\n        flo = flo[:, :, self._y:self._y + crop_height, self._x:self._x + crop_width]\n\n\nclass RandomAffineFlow(nn.Module):\n    def __init__(self, args, addnoise=True):\n        super(RandomAffineFlow, self).__init__()\n        self._args = args\n        self._interp2 = Interp2(clamp=False)\n        self._flow_interp2 = Interp2(clamp=False)\n        self._meshgrid = Meshgrid()\n        self._identity = _IdentityParams()\n        self._random_mirror = RandomMirror()\n        self._addnoise = addnoise\n        self.register_buffer(\"_noise1\", torch.FloatTensor())\n        self.register_buffer(\"_noise2\", torch.FloatTensor())\n        self.register_buffer(\"_xbounds\", torch.FloatTensor([-1, -1, 1, 1]))\n        self.register_buffer(\"_ybounds\", torch.FloatTensor([-1, 1, -1, 1]))\n\n    def inverse_transform_coords(self, width, height, thetas, offset_x=None, offset_y=None):\n        xx, yy = self._meshgrid(width=width, height=height, device=thetas.device, dtype=thetas.dtype)\n\n        xx = torch.unsqueeze(xx, dim=0).float()\n        yy = 
torch.unsqueeze(yy, dim=0).float()\n\n        if offset_x is not None:\n            xx = xx + offset_x\n        if offset_y is not None:\n            yy = yy + offset_y\n\n        a1 = thetas[:, 0].contiguous().view(-1, 1, 1)\n        a2 = thetas[:, 1].contiguous().view(-1, 1, 1)\n        a3 = thetas[:, 2].contiguous().view(-1, 1, 1)\n        a4 = thetas[:, 3].contiguous().view(-1, 1, 1)\n        a5 = thetas[:, 4].contiguous().view(-1, 1, 1)\n        a6 = thetas[:, 5].contiguous().view(-1, 1, 1)\n\n        xx, yy = normalize_coords(xx, yy, width=width, height=height)\n        xq = a1 * xx + a2 * yy + a3\n        yq = a4 * xx + a5 * yy + a6\n        xq, yq = denormalize_coords(xq, yq, width=width, height=height)\n        return xq, yq\n\n    def transform_coords(self, width, height, thetas):\n        xx1, yy1 = self._meshgrid(width=width, height=height, device=thetas.device, dtype=thetas.dtype)\n        xx, yy = normalize_coords(xx1, yy1, width=width, height=height)\n\n        def _unsqueeze12(u):\n            return torch.unsqueeze(torch.unsqueeze(u, dim=1), dim=1)\n\n        a1 = _unsqueeze12(thetas[:, 0])\n        a2 = _unsqueeze12(thetas[:, 1])\n        a3 = _unsqueeze12(thetas[:, 2])\n        a4 = _unsqueeze12(thetas[:, 3])\n        a5 = _unsqueeze12(thetas[:, 4])\n        a6 = _unsqueeze12(thetas[:, 5])\n        #\n        z = a1 * a5 - a2 * a4\n        b1 = a5 / z\n        b2 = - a2 / z\n        b4 = - a4 / z\n        b5 = a1 / z\n        #\n        xhat = xx - a3\n        yhat = yy - a6\n        xq = b1 * xhat + b2 * yhat\n        yq = b4 * xhat + b5 * yhat\n\n        xq, yq = denormalize_coords(xq, yq, width=width, height=height)\n        return xq, yq\n\n    def find_invalid(self, width, height, thetas):\n        x = self._xbounds\n        y = self._ybounds\n        #\n        a1 = torch.unsqueeze(thetas[:, 0], dim=1)\n        a2 = torch.unsqueeze(thetas[:, 1], dim=1)\n        a3 = torch.unsqueeze(thetas[:, 2], dim=1)\n        a4 = 
torch.unsqueeze(thetas[:, 3], dim=1)\n        a5 = torch.unsqueeze(thetas[:, 4], dim=1)\n        a6 = torch.unsqueeze(thetas[:, 5], dim=1)\n        #\n        z = a1 * a5 - a2 * a4\n        b1 = a5 / z\n        b2 = - a2 / z\n        b4 = - a4 / z\n        b5 = a1 / z\n        #\n        xhat = x - a3\n        yhat = y - a6\n        xq = b1 * xhat + b2 * yhat\n        yq = b4 * xhat + b5 * yhat\n        xq, yq = denormalize_coords(xq, yq, width=width, height=height)\n        #\n        invalid = (\n                      (xq < 0) | (yq < 0) | (xq >= width) | (yq >= height)\n                  ).sum(dim=1, keepdim=True) > 0\n\n        return invalid\n\n    def apply_random_transforms_to_params(self,\n                                          theta0,\n                                          max_translate,\n                                          min_zoom, max_zoom,\n                                          min_squeeze, max_squeeze,\n                                          min_rotate, max_rotate,\n                                          validate_size=None):\n        max_translate *= 0.5\n        batch_size = theta0.size(0)\n        height, width = validate_size\n\n        # collect valid params here\n        thetas = torch.zeros_like(theta0)\n\n        zoom = theta0.new(batch_size, 1).zero_()\n        squeeze = torch.zeros_like(zoom)\n        tx = torch.zeros_like(zoom)\n        ty = torch.zeros_like(zoom)\n        phi = torch.zeros_like(zoom)\n        invalid = torch.ones_like(zoom).byte()\n\n        while invalid.sum() > 0:\n            # random sampling\n            zoom.uniform_(min_zoom, max_zoom)\n            squeeze.uniform_(min_squeeze, max_squeeze)\n            tx.uniform_(-max_translate, max_translate)\n            ty.uniform_(-max_translate, max_translate)\n            phi.uniform_(min_rotate, max_rotate)\n\n            # construct affine parameters\n            sx = zoom * squeeze\n            sy = zoom / squeeze\n            sin_phi = 
torch.sin(phi)\n            cos_phi = torch.cos(phi)\n            b1 = cos_phi * sx\n            b2 = sin_phi * sy\n            b3 = tx\n            b4 = - sin_phi * sx\n            b5 = cos_phi * sy\n            b6 = ty\n\n            theta_transform = torch.cat([b1, b2, b3, b4, b5, b6], dim=1)\n            theta_try = apply_transform_to_params(theta0, theta_transform)\n            thetas = invalid.float() * theta_try + (1 - invalid.float()) * thetas\n\n            # compute new invalid ones\n            invalid = self.find_invalid(width=width, height=height, thetas=thetas)\n\n        # here we should have good thetas within borders\n        return thetas\n\n    def transform_image(self, images, thetas):\n        batch_size, channels, height, width = images.size()\n        xq, yq = self.transform_coords(width=width, height=height, thetas=thetas)\n        transformed = self._interp2(images, xq, yq)\n        return transformed\n\n    def transform_flow(self, flow, theta1, theta2):\n        batch_size, channels, height, width = flow.size()\n        u = flow[:, 0, :, :]\n        v = flow[:, 1, :, :]\n\n        # inverse transform coords\n        x0, y0 = self.inverse_transform_coords(\n            width=width, height=height, thetas=theta1)\n\n        x1, y1 = self.inverse_transform_coords(\n            width=width, height=height, thetas=theta2, offset_x=u, offset_y=v)\n\n        # subtract and create new flow\n        u = x1 - x0\n        v = y1 - y0\n        new_flow = torch.stack([u, v], dim=1)\n\n        # transform coords\n        xq, yq = self.transform_coords(width=width, height=height, thetas=theta1)\n\n        # interp2\n        transformed = self._flow_interp2(new_flow, xq, yq)\n        return transformed\n\n    def forward(self, example_dict):\n        im1 = example_dict[\"input1\"]\n        im2 = example_dict[\"input2\"]\n        flo = example_dict[\"target1\"]\n\n        batch_size = im1.size(0)\n        height = im1.size(2)\n        width = 
im1.size(3)\n\n        # identity = no transform\n        theta0 = self._identity(batch_size)\n\n        # # global transform\n        theta1 = self.apply_random_transforms_to_params(\n            theta0,\n            max_translate=0.2,\n            min_zoom=1.0, max_zoom=1.5,\n            min_squeeze=0.86, max_squeeze=1.16,\n            min_rotate=-0.2, max_rotate=0.2,\n            validate_size=[height, width])\n\n        # relative transform\n        theta2 = self.apply_random_transforms_to_params(\n            theta1,\n            max_translate=0.015,\n            min_zoom=0.985, max_zoom=1.015,\n            min_squeeze=1.0, max_squeeze=1.0,\n            min_rotate=-0.015, max_rotate=0.015,\n            validate_size=[height, width])\n\n        # random flip images\n        theta1, theta2 = self._random_mirror(theta1, theta2)\n\n        im1 = self.transform_image(im1, theta1)\n        im2 = self.transform_image(im2, theta2)\n        flo = self.transform_flow(flo, theta1, theta2)\n\n        if self._addnoise:\n            stddev = np.random.uniform(0.0, 0.04)\n            self._noise1.resize_as_(im1)\n            self._noise2.resize_as_(im2)\n            self._noise1.normal_(std=stddev)\n            self._noise2.normal_(std=stddev)\n            im1 += self._noise1\n            im2 += self._noise2\n            im1.clamp_(0.0, 1.0)\n            im2.clamp_(0.0, 1.0)\n\n        # construct updated dictionaries\n        example_dict[\"input1\"] = im1\n        example_dict[\"input2\"] = im2\n        example_dict[\"target1\"] = flo\n\n        return example_dict\n\n\nclass RandomAffineFlowOcc(nn.Module):\n    def __init__(self, args, addnoise=True, crop=None):\n        super(RandomAffineFlowOcc, self).__init__()\n        self._args = args\n        self._interp2 = Interp2(clamp=False)\n        self._flow_interp2 = Interp2(clamp=False)\n        self._meshgrid = Meshgrid()\n        self._identity = _IdentityParams()\n        self._random_mirror = RandomMirror()\n        
self._addnoise = addnoise\n        self._crop = crop\n\n        self.register_buffer(\"_noise1\", torch.FloatTensor())\n        self.register_buffer(\"_noise2\", torch.FloatTensor())\n        self.register_buffer(\"_xbounds\", torch.FloatTensor([-1, -1, 1, 1]))\n        self.register_buffer(\"_ybounds\", torch.FloatTensor([-1, 1, -1, 1]))\n        self.register_buffer(\"_x\", torch.IntTensor(1))\n        self.register_buffer(\"_y\", torch.IntTensor(1))\n\n    def inverse_transform_coords(self, width, height, thetas, offset_x=None, offset_y=None):\n        xx, yy = self._meshgrid(width=width, height=height, device=thetas.device, dtype=thetas.dtype)\n\n        xx = torch.unsqueeze(xx, dim=0).float()\n        yy = torch.unsqueeze(yy, dim=0).float()\n\n        if offset_x is not None:\n            xx = xx + offset_x\n        if offset_y is not None:\n            yy = yy + offset_y\n\n        a1 = thetas[:, 0].contiguous().view(-1, 1, 1)\n        a2 = thetas[:, 1].contiguous().view(-1, 1, 1)\n        a3 = thetas[:, 2].contiguous().view(-1, 1, 1)\n        a4 = thetas[:, 3].contiguous().view(-1, 1, 1)\n        a5 = thetas[:, 4].contiguous().view(-1, 1, 1)\n        a6 = thetas[:, 5].contiguous().view(-1, 1, 1)\n\n        xx, yy = normalize_coords(xx, yy, width=width, height=height)\n        xq = a1 * xx + a2 * yy + a3\n        yq = a4 * xx + a5 * yy + a6\n        xq, yq = denormalize_coords(xq, yq, width=width, height=height)\n        return xq, yq\n\n    def transform_coords(self, width, height, thetas):\n        xx1, yy1 = self._meshgrid(width=width, height=height, device=thetas.device, dtype=thetas.dtype)\n        xx, yy = normalize_coords(xx1, yy1, width=width, height=height)\n\n        def _unsqueeze12(u):\n            return torch.unsqueeze(torch.unsqueeze(u, dim=1), dim=1)\n\n        a1 = _unsqueeze12(thetas[:, 0])\n        a2 = _unsqueeze12(thetas[:, 1])\n        a3 = _unsqueeze12(thetas[:, 2])\n        a4 = _unsqueeze12(thetas[:, 3])\n        a5 = 
_unsqueeze12(thetas[:, 4])\n        a6 = _unsqueeze12(thetas[:, 5])\n        #\n        z = a1 * a5 - a2 * a4\n        b1 = a5 / z\n        b2 = - a2 / z\n        b4 = - a4 / z\n        b5 = a1 / z\n        #\n        xhat = xx - a3\n        yhat = yy - a6\n        xq = b1 * xhat + b2 * yhat\n        yq = b4 * xhat + b5 * yhat\n\n        xq, yq = denormalize_coords(xq, yq, width=width, height=height)\n        return xq, yq\n\n    def find_invalid(self, width, height, thetas):\n        x = self._xbounds\n        y = self._ybounds\n        #\n        a1 = torch.unsqueeze(thetas[:, 0], dim=1)\n        a2 = torch.unsqueeze(thetas[:, 1], dim=1)\n        a3 = torch.unsqueeze(thetas[:, 2], dim=1)\n        a4 = torch.unsqueeze(thetas[:, 3], dim=1)\n        a5 = torch.unsqueeze(thetas[:, 4], dim=1)\n        a6 = torch.unsqueeze(thetas[:, 5], dim=1)\n        #\n        z = a1 * a5 - a2 * a4\n        b1 = a5 / z\n        b2 = - a2 / z\n        b4 = - a4 / z\n        b5 = a1 / z\n        #\n        xhat = x - a3\n        yhat = y - a6\n        xq = b1 * xhat + b2 * yhat\n        yq = b4 * xhat + b5 * yhat\n        xq, yq = denormalize_coords(xq, yq, width=width, height=height)\n        #\n        invalid = (\n                      (xq < 0) | (yq < 0) | (xq >= width) | (yq >= height)\n                  ).sum(dim=1, keepdim=True) > 0\n\n        return invalid\n\n    def apply_random_transforms_to_params(self,\n                                          theta0,\n                                          max_translate,\n                                          min_zoom, max_zoom,\n                                          min_squeeze, max_squeeze,\n                                          min_rotate, max_rotate,\n                                          validate_size=None):\n        max_translate *= 0.5\n        batch_size = theta0.size(0)\n        height, width = validate_size\n\n        # collect valid params here\n        thetas = torch.zeros_like(theta0)\n\n        zoom = 
theta0.new(batch_size, 1).zero_()\n        squeeze = torch.zeros_like(zoom)\n        tx = torch.zeros_like(zoom)\n        ty = torch.zeros_like(zoom)\n        phi = torch.zeros_like(zoom)\n        invalid = torch.ones_like(zoom).byte()\n\n        while invalid.sum() > 0:\n            # random sampling\n            zoom.uniform_(min_zoom, max_zoom)\n            squeeze.uniform_(min_squeeze, max_squeeze)\n            tx.uniform_(-max_translate, max_translate)\n            ty.uniform_(-max_translate, max_translate)\n            phi.uniform_(min_rotate, max_rotate)\n\n            # construct affine parameters\n            sx = zoom * squeeze\n            sy = zoom / squeeze\n            sin_phi = torch.sin(phi)\n            cos_phi = torch.cos(phi)\n            b1 = cos_phi * sx\n            b2 = sin_phi * sy\n            b3 = tx\n            b4 = - sin_phi * sx\n            b5 = cos_phi * sy\n            b6 = ty\n\n            theta_transform = torch.cat([b1, b2, b3, b4, b5, b6], dim=1)\n            theta_try = apply_transform_to_params(theta0, theta_transform)\n            thetas = invalid.float() * theta_try + (1. 
- invalid.float()) * thetas\n\n            # compute new invalid ones\n            invalid = self.find_invalid(width=width, height=height, thetas=thetas)\n\n        # here we should have good thetas within borders\n        return thetas\n\n    def transform_image(self, images, thetas):\n        batch_size, channels, height, width = images.size()\n        xq, yq = self.transform_coords(width=width, height=height, thetas=thetas)\n        transformed = self._interp2(images, xq, yq)\n        return transformed\n\n    def transform_flow(self, flow, theta1, theta2):\n        batch_size, channels, height, width = flow.size()\n        u = flow[:, 0, :, :]\n        v = flow[:, 1, :, :]\n\n        # inverse transform coords\n        x0, y0 = self.inverse_transform_coords(\n            width=width, height=height, thetas=theta1)\n\n        x1, y1 = self.inverse_transform_coords(\n            width=width, height=height, thetas=theta2, offset_x=u, offset_y=v)\n\n        # subtract and create new flow\n        u = x1 - x0\n        v = y1 - y0\n        new_flow = torch.stack([u, v], dim=1)\n\n        # transform coords\n        xq, yq = self.transform_coords(width=width, height=height, thetas=theta1)\n\n        # interp2\n        transformed = self._flow_interp2(new_flow, xq, yq)\n        return transformed\n\n    def check_out_of_bound(self, flow, occ, batch_size):\n        _, _, height, width = flow.size()\n        u = flow[:, 0, :, :]\n        v = flow[:, 1, :, :]\n        xx, yy = self._meshgrid(width=width, height=height, device=flow.device, dtype=flow.dtype)\n        xx = torch.unsqueeze(xx, dim=0).float()\n        yy = torch.unsqueeze(yy, dim=0).float()\n        xx = xx.expand(batch_size, -1, -1) + u\n        yy = yy.expand(batch_size, -1, -1) + v\n\n        out_of_bound = ((xx < 0) | (yy < 0) | (xx >= width) | (yy >= height)).float().unsqueeze(1)\n        occ = torch.clamp(out_of_bound + occ, 0, 1)\n\n        return occ\n\n    def random_crop(self, im1, im2, flo_f, flo_b, 
occ1, occ2):\n\n        _, _, height, width = im1.size()\n        crop_height, crop_width = self._crop\n\n        # get starting positions\n        self._x.random_(0, width - crop_width + 1)\n        self._y.random_(0, height - crop_height + 1)\n        str_x = int(self._x)\n        str_y = int(self._y)\n        end_x = int(self._x + crop_width)\n        end_y = int(self._y + crop_height)\n\n        im1 = im1[:, :, str_y:end_y, str_x:end_x]\n        im2 = im2[:, :, str_y:end_y, str_x:end_x]\n        flo_f = flo_f[:, :, str_y:end_y, str_x:end_x]\n        flo_b = flo_b[:, :, str_y:end_y, str_x:end_x]\n        occ1 = occ1[:, :, str_y:end_y, str_x:end_x]\n        occ2 = occ2[:, :, str_y:end_y, str_x:end_x]\n\n        return im1, im2, flo_f, flo_b, occ1, occ2\n\n    def forward(self, example_dict):\n        im1 = example_dict[\"input1\"]\n        im2 = example_dict[\"input2\"]\n        flo_f = example_dict[\"target1\"]\n        flo_b = example_dict[\"target2\"]\n        occ1 = example_dict[\"target_occ1\"]\n        occ2 = example_dict[\"target_occ2\"]\n\n        batch_size = im1.size(0)\n        height = im1.size(2)\n        width = im1.size(3)\n\n        # identity = no transform\n        theta0 = self._identity(batch_size)\n\n        # # global transform\n        theta1 = self.apply_random_transforms_to_params(\n            theta0,\n            max_translate=0.2,\n            min_zoom=1.0, max_zoom=1.5,\n            min_squeeze=0.86, max_squeeze=1.16,\n            min_rotate=-0.2, max_rotate=0.2,\n            validate_size=[height, width])\n\n        # relative transform\n        theta2 = self.apply_random_transforms_to_params(\n            theta1,\n            max_translate=0.015,\n            min_zoom=0.985, max_zoom=1.015,\n            min_squeeze=1.0, max_squeeze=1.0,\n            min_rotate=-0.015, max_rotate=0.015,\n            validate_size=[height, width])\n\n        # random flip images\n        theta1, theta2 = self._random_mirror(theta1, theta2)\n\n        
im1 = self.transform_image(im1, theta1)\n        im2 = self.transform_image(im2, theta2)\n        flo_f = self.transform_flow(flo_f, theta1, theta2)\n        flo_b = self.transform_flow(flo_b, theta2, theta1)\n        occ1 = self.transform_image(occ1, theta1)\n        occ2 = self.transform_image(occ2, theta2)\n\n        if self._addnoise:\n            stddev = np.random.uniform(0.0, 0.04)\n            self._noise1.resize_as_(im1)\n            self._noise2.resize_as_(im2)\n            self._noise1.normal_(std=stddev)\n            self._noise2.normal_(std=stddev)\n            im1 += self._noise1\n            im2 += self._noise2\n            im1.clamp_(0.0, 1.0)\n            im2.clamp_(0.0, 1.0)\n\n        if self._crop is not None:\n            im1, im2, flo_f, flo_b, occ1, occ2 = self.random_crop(im1, im2, flo_f, flo_b, occ1, occ2)\n\n        occ1 = self.check_out_of_bound(flo_f, occ1, batch_size)\n        occ2 = self.check_out_of_bound(flo_b, occ2, batch_size)\n\n        example_dict[\"input1\"] = im1\n        example_dict[\"input2\"] = im2\n        example_dict[\"target1\"] = flo_f\n        example_dict[\"target2\"] = flo_b\n        example_dict[\"target_occ1\"] = occ1\n        example_dict[\"target_occ2\"] = occ2\n\n        return example_dict\n\n\nclass RandomAffineFlowOccSintel(nn.Module):\n    def __init__(self, args, addnoise=True, crop=None):\n        super(RandomAffineFlowOccSintel, self).__init__()\n        self._args = args\n        self._interp2 = Interp2(clamp=False)\n        self._flow_interp2 = Interp2(clamp=False)\n        self._meshgrid = Meshgrid()\n        self._identity = _IdentityParams()\n        self._random_mirror = RandomMirror()\n        self._addnoise = addnoise\n        self._crop = crop\n\n        self.register_buffer(\"_noise1\", torch.FloatTensor())\n        self.register_buffer(\"_noise2\", torch.FloatTensor())\n        self.register_buffer(\"_xbounds\", torch.FloatTensor([-1, -1, 1, 1]))\n        self.register_buffer(\"_ybounds\", 
torch.FloatTensor([-1, 1, -1, 1]))\n        self.register_buffer(\"_x\", torch.IntTensor(1))\n        self.register_buffer(\"_y\", torch.IntTensor(1))\n\n    def inverse_transform_coords(self, width, height, thetas, offset_x=None, offset_y=None):\n        xx, yy = self._meshgrid(width=width, height=height, device=thetas.device, dtype=thetas.dtype)\n\n        xx = torch.unsqueeze(xx, dim=0).float()\n        yy = torch.unsqueeze(yy, dim=0).float()\n\n        if offset_x is not None:\n            xx = xx + offset_x\n        if offset_y is not None:\n            yy = yy + offset_y\n\n        a1 = thetas[:, 0].contiguous().view(-1, 1, 1)\n        a2 = thetas[:, 1].contiguous().view(-1, 1, 1)\n        a3 = thetas[:, 2].contiguous().view(-1, 1, 1)\n        a4 = thetas[:, 3].contiguous().view(-1, 1, 1)\n        a5 = thetas[:, 4].contiguous().view(-1, 1, 1)\n        a6 = thetas[:, 5].contiguous().view(-1, 1, 1)\n\n        xx, yy = normalize_coords(xx, yy, width=width, height=height)\n        xq = a1 * xx + a2 * yy + a3\n        yq = a4 * xx + a5 * yy + a6\n        xq, yq = denormalize_coords(xq, yq, width=width, height=height)\n        return xq, yq\n\n    def transform_coords(self, width, height, thetas):\n        xx1, yy1 = self._meshgrid(width=width, height=height, device=thetas.device, dtype=thetas.dtype)\n        xx, yy = normalize_coords(xx1, yy1, width=width, height=height)\n\n        def _unsqueeze12(u):\n            return torch.unsqueeze(torch.unsqueeze(u, dim=1), dim=1)\n\n        a1 = _unsqueeze12(thetas[:, 0])\n        a2 = _unsqueeze12(thetas[:, 1])\n        a3 = _unsqueeze12(thetas[:, 2])\n        a4 = _unsqueeze12(thetas[:, 3])\n        a5 = _unsqueeze12(thetas[:, 4])\n        a6 = _unsqueeze12(thetas[:, 5])\n        #\n        z = a1 * a5 - a2 * a4\n        b1 = a5 / z\n        b2 = - a2 / z\n        b4 = - a4 / z\n        b5 = a1 / z\n        #\n        xhat = xx - a3\n        yhat = yy - a6\n        xq = b1 * xhat + b2 * yhat\n        yq = b4 * xhat + b5 
* yhat\n\n        xq, yq = denormalize_coords(xq, yq, width=width, height=height)\n        return xq, yq\n\n    def find_invalid(self, width, height, thetas):\n        x = self._xbounds\n        y = self._ybounds\n        #\n        a1 = torch.unsqueeze(thetas[:, 0], dim=1)\n        a2 = torch.unsqueeze(thetas[:, 1], dim=1)\n        a3 = torch.unsqueeze(thetas[:, 2], dim=1)\n        a4 = torch.unsqueeze(thetas[:, 3], dim=1)\n        a5 = torch.unsqueeze(thetas[:, 4], dim=1)\n        a6 = torch.unsqueeze(thetas[:, 5], dim=1)\n        #\n        z = a1 * a5 - a2 * a4\n        b1 = a5 / z\n        b2 = - a2 / z\n        b4 = - a4 / z\n        b5 = a1 / z\n        #\n        xhat = x - a3\n        yhat = y - a6\n        xq = b1 * xhat + b2 * yhat\n        yq = b4 * xhat + b5 * yhat\n        xq, yq = denormalize_coords(xq, yq, width=width, height=height)\n        #\n        invalid = (\n                      (xq < 0) | (yq < 0) | (xq >= width) | (yq >= height)\n                  ).sum(dim=1, keepdim=True) > 0\n\n        return invalid\n\n    def apply_random_transforms_to_params(self,\n                                          theta0,\n                                          max_translate,\n                                          min_zoom, max_zoom,\n                                          min_squeeze, max_squeeze,\n                                          min_rotate, max_rotate,\n                                          validate_size=None):\n        max_translate *= 0.5\n        batch_size = theta0.size(0)\n        height, width = validate_size\n\n        # collect valid params here\n        thetas = torch.zeros_like(theta0)\n\n        zoom = theta0.new(batch_size, 1).zero_()\n        squeeze = torch.zeros_like(zoom)\n        tx = torch.zeros_like(zoom)\n        ty = torch.zeros_like(zoom)\n        phi = torch.zeros_like(zoom)\n        invalid = torch.ones_like(zoom).byte()\n\n        while invalid.sum() > 0:\n            # random sampling\n            
zoom.uniform_(min_zoom, max_zoom)\n            squeeze.uniform_(min_squeeze, max_squeeze)\n            tx.uniform_(-max_translate, max_translate)\n            ty.uniform_(-max_translate, max_translate)\n            phi.uniform_(min_rotate, max_rotate)\n\n            # construct affine parameters\n            sx = zoom * squeeze\n            sy = zoom / squeeze\n            sin_phi = torch.sin(phi)\n            cos_phi = torch.cos(phi)\n            b1 = cos_phi * sx\n            b2 = sin_phi * sy\n            b3 = tx\n            b4 = - sin_phi * sx\n            b5 = cos_phi * sy\n            b6 = ty\n\n            theta_transform = torch.cat([b1, b2, b3, b4, b5, b6], dim=1)\n            theta_try = apply_transform_to_params(theta0, theta_transform)\n            thetas = invalid.float() * theta_try + (1 - invalid.float()) * thetas\n\n            # compute new invalid ones\n            invalid = self.find_invalid(width=width, height=height, thetas=thetas)\n\n        # here we should have good thetas within borders\n        return thetas\n\n    def transform_image(self, images, thetas):\n        batch_size, channels, height, width = images.size()\n        xq, yq = self.transform_coords(width=width, height=height, thetas=thetas)\n        transformed = self._interp2(images, xq, yq)\n        return transformed\n\n    def transform_flow(self, flow, theta1, theta2):\n        batch_size, channels, height, width = flow.size()\n        u = flow[:, 0, :, :]\n        v = flow[:, 1, :, :]\n\n        # inverse transform coords\n        x0, y0 = self.inverse_transform_coords(\n            width=width, height=height, thetas=theta1)\n\n        x1, y1 = self.inverse_transform_coords(\n            width=width, height=height, thetas=theta2, offset_x=u, offset_y=v)\n\n        # subtract and create new flow\n        u = x1 - x0\n        v = y1 - y0\n        new_flow = torch.stack([u, v], dim=1)\n\n        # transform coords\n        xq, yq = self.transform_coords(width=width, 
height=height, thetas=theta1)\n\n        # interp2\n        transformed = self._flow_interp2(new_flow, xq, yq)\n        return transformed\n\n    def check_out_of_bound(self, flow, occ, batch_size):\n        _, _, height, width = flow.size()\n        u = flow[:, 0, :, :]\n        v = flow[:, 1, :, :]\n        xx, yy = self._meshgrid(width=width, height=height, device=flow.device, dtype=flow.dtype)\n        xx = torch.unsqueeze(xx, dim=0)\n        yy = torch.unsqueeze(yy, dim=0)\n        xx = xx.expand(batch_size, -1, -1) + u\n        yy = yy.expand(batch_size, -1, -1) + v\n\n        out_of_bound = ((xx < 0) | (yy < 0) | (xx >= width) | (yy >= height)).float().unsqueeze(1)\n        occ = torch.clamp(out_of_bound + occ, 0, 1)\n\n        return occ\n\n    def random_crop(self, im1, im2, flo_f, occ1):\n\n        _, _, height, width = im1.size()\n        crop_height, crop_width = self._crop\n\n        # get starting positions\n        self._x.random_(0, width - crop_width + 1)\n        self._y.random_(0, height - crop_height + 1)\n        str_x = int(self._x)\n        str_y = int(self._y)\n        end_x = int(self._x + crop_width)\n        end_y = int(self._y + crop_height)\n\n        im1 = im1[:, :, str_y:end_y, str_x:end_x]\n        im2 = im2[:, :, str_y:end_y, str_x:end_x]\n        flo_f = flo_f[:, :, str_y:end_y, str_x:end_x]\n        occ1 = occ1[:, :, str_y:end_y, str_x:end_x]\n\n        return im1, im2, flo_f, occ1\n\n    def forward(self, example_dict):\n        im1 = example_dict[\"input1\"]\n        im2 = example_dict[\"input2\"]\n        flo_f = example_dict[\"target1\"]\n        occ1 = example_dict[\"target_occ1\"]\n\n        batch_size = im1.size(0)\n        height = im1.size(2)\n        width = im1.size(3)\n\n        # identity = no transform\n        theta0 = self._identity(batch_size)\n\n        # # global transform\n        theta1 = self.apply_random_transforms_to_params(\n            theta0,\n            max_translate=0.2,\n            min_zoom=1.0, 
max_zoom=1.5,\n            min_squeeze=0.86, max_squeeze=1.16,\n            min_rotate=-0.2, max_rotate=0.2,\n            validate_size=[height, width])\n\n        # relative transform\n        theta2 = self.apply_random_transforms_to_params(\n            theta1,\n            max_translate=0.015,\n            min_zoom=0.985, max_zoom=1.015,\n            min_squeeze=1.0, max_squeeze=1.0,\n            min_rotate=-0.015, max_rotate=0.015,\n            validate_size=[height, width])\n\n        # random flip images\n        theta1, theta2 = self._random_mirror(theta1, theta2)\n\n        im1 = self.transform_image(im1, theta1)\n        im2 = self.transform_image(im2, theta2)\n        flo_f = self.transform_flow(flo_f, theta1, theta2)\n        occ1 = self.transform_image(occ1, theta1)\n\n        if self._addnoise:\n            stddev = np.random.uniform(0.0, 0.04)\n            self._noise1.resize_as_(im1)\n            self._noise2.resize_as_(im2)\n            self._noise1.normal_(std=stddev)\n            self._noise2.normal_(std=stddev)\n            im1 += self._noise1\n            im2 += self._noise2\n            im1.clamp_(0.0, 1.0)\n            im2.clamp_(0.0, 1.0)\n\n        if self._crop is not None:\n            im1, im2, flo_f, occ1 = self.random_crop(im1, im2, flo_f, occ1)\n\n        occ1 = self.check_out_of_bound(flo_f, occ1, batch_size)\n\n        example_dict[\"input1\"] = im1\n        example_dict[\"input2\"] = im2\n        example_dict[\"target1\"] = flo_f\n        example_dict[\"target_occ1\"] = occ1\n\n        return example_dict\n\n\nclass RandomAffineFlowOccKITTI(nn.Module):\n    def __init__(self, args, addnoise=True, crop=None):\n        super(RandomAffineFlowOccKITTI, self).__init__()\n        self._args = args\n        self._interp2 = Interp2(clamp=False)\n        self._flow_interp2 = Interp2MaskBinary(clamp=False)\n        self._meshgrid = Meshgrid()\n        self._identity = _IdentityParams()\n        self._random_mirror = 
RandomMirror(vertical=False)\n        self._addnoise = addnoise\n        self._crop = crop\n\n        self.register_buffer(\"_noise1\", torch.FloatTensor())\n        self.register_buffer(\"_noise2\", torch.FloatTensor())\n        self.register_buffer(\"_xbounds\", torch.FloatTensor([-1, -1, 1, 1]))\n        self.register_buffer(\"_ybounds\", torch.FloatTensor([-1, 1, -1, 1]))\n        self.register_buffer(\"_x\", torch.IntTensor(1))\n        self.register_buffer(\"_y\", torch.IntTensor(1))\n\n    def inverse_transform_coords(self, width, height, thetas, offset_x=None, offset_y=None):\n        xx, yy = self._meshgrid(width=width, height=height, device=thetas.device, dtype=thetas.dtype)\n\n        xx = torch.unsqueeze(xx, dim=0).float()\n        yy = torch.unsqueeze(yy, dim=0).float()\n\n        if offset_x is not None:\n            xx = xx + offset_x\n        if offset_y is not None:\n            yy = yy + offset_y\n\n        a1 = thetas[:, 0].contiguous().view(-1, 1, 1)\n        a2 = thetas[:, 1].contiguous().view(-1, 1, 1)\n        a3 = thetas[:, 2].contiguous().view(-1, 1, 1)\n        a4 = thetas[:, 3].contiguous().view(-1, 1, 1)\n        a5 = thetas[:, 4].contiguous().view(-1, 1, 1)\n        a6 = thetas[:, 5].contiguous().view(-1, 1, 1)\n\n        xx, yy = normalize_coords(xx, yy, width=width, height=height)\n        xq = a1 * xx + a2 * yy + a3\n        yq = a4 * xx + a5 * yy + a6\n        xq, yq = denormalize_coords(xq, yq, width=width, height=height)\n        return xq, yq\n\n    def transform_coords(self, width, height, thetas):\n        xx1, yy1 = self._meshgrid(width=width, height=height, device=thetas.device, dtype=thetas.dtype)\n        xx, yy = normalize_coords(xx1, yy1, width=width, height=height)\n\n        def _unsqueeze12(u):\n            return torch.unsqueeze(torch.unsqueeze(u, dim=1), dim=1)\n\n        a1 = _unsqueeze12(thetas[:, 0])\n        a2 = _unsqueeze12(thetas[:, 1])\n        a3 = _unsqueeze12(thetas[:, 2])\n        a4 = 
_unsqueeze12(thetas[:, 3])\n        a5 = _unsqueeze12(thetas[:, 4])\n        a6 = _unsqueeze12(thetas[:, 5])\n        #\n        z = a1 * a5 - a2 * a4\n        b1 = a5 / z\n        b2 = - a2 / z\n        b4 = - a4 / z\n        b5 = a1 / z\n        #\n        xhat = xx - a3\n        yhat = yy - a6\n        xq = b1 * xhat + b2 * yhat\n        yq = b4 * xhat + b5 * yhat\n\n        xq, yq = denormalize_coords(xq, yq, width=width, height=height)\n        return xq, yq\n\n    def find_invalid(self, width, height, thetas):\n        x = self._xbounds\n        y = self._ybounds\n        #\n        a1 = torch.unsqueeze(thetas[:, 0], dim=1)\n        a2 = torch.unsqueeze(thetas[:, 1], dim=1)\n        a3 = torch.unsqueeze(thetas[:, 2], dim=1)\n        a4 = torch.unsqueeze(thetas[:, 3], dim=1)\n        a5 = torch.unsqueeze(thetas[:, 4], dim=1)\n        a6 = torch.unsqueeze(thetas[:, 5], dim=1)\n        #\n        z = a1 * a5 - a2 * a4\n        b1 = a5 / z\n        b2 = - a2 / z\n        b4 = - a4 / z\n        b5 = a1 / z\n        #\n        xhat = x - a3\n        yhat = y - a6\n        xq = b1 * xhat + b2 * yhat\n        yq = b4 * xhat + b5 * yhat\n        xq, yq = denormalize_coords(xq, yq, width=width, height=height)\n        #\n        invalid = (\n                      (xq < 0) | (yq < 0) | (xq >= width) | (yq >= height)\n                  ).sum(dim=1, keepdim=True) > 0\n\n        return invalid\n\n    def apply_random_transforms_to_params(self,\n                                          theta0,\n                                          max_translate,\n                                          min_zoom, max_zoom,\n                                          min_squeeze, max_squeeze,\n                                          min_rotate, max_rotate,\n                                          validate_size=None):\n        max_translate *= 0.5\n        batch_size = theta0.size(0)\n        height, width = validate_size\n\n        # collect valid params here\n        thetas = 
torch.zeros_like(theta0)\n\n        zoom = theta0.new(batch_size, 1).zero_()\n        squeeze = torch.zeros_like(zoom)\n        tx = torch.zeros_like(zoom)\n        ty = torch.zeros_like(zoom)\n        phi = torch.zeros_like(zoom)\n        invalid = torch.ones_like(zoom).byte()\n\n        while invalid.sum() > 0:\n            # random sampling\n            zoom.uniform_(min_zoom, max_zoom)\n            squeeze.uniform_(min_squeeze, max_squeeze)\n            tx.uniform_(-max_translate, max_translate)\n            ty.uniform_(-max_translate, max_translate)\n            phi.uniform_(min_rotate, max_rotate)\n\n            # construct affine parameters\n            sx = zoom * squeeze\n            sy = zoom / squeeze\n            sin_phi = torch.sin(phi)\n            cos_phi = torch.cos(phi)\n            b1 = cos_phi * sx\n            b2 = sin_phi * sy\n            b3 = tx\n            b4 = - sin_phi * sx\n            b5 = cos_phi * sy\n            b6 = ty\n\n            theta_transform = torch.cat([b1, b2, b3, b4, b5, b6], dim=1)\n            theta_try = apply_transform_to_params(theta0, theta_transform)\n            thetas = invalid.float() * theta_try + (1 - invalid.float()) * thetas\n\n            # compute new invalid ones\n            invalid = self.find_invalid(width=width, height=height, thetas=thetas)\n\n        # here we should have good thetas within borders\n        return thetas\n\n    def transform_image(self, images, thetas):\n        batch_size, channels, height, width = images.size()\n        xq, yq = self.transform_coords(width=width, height=height, thetas=thetas)\n        transformed = self._interp2(images, xq, yq)\n        return transformed\n\n    def transform_flow(self, flow, theta1, theta2, valid_mask):\n        batch_size, channels, height, width = flow.size()\n        u = flow[:, 0, :, :]\n        v = flow[:, 1, :, :]\n\n        # inverse transform coords\n        x0, y0 = self.inverse_transform_coords(\n            width=width, height=height, 
thetas=theta1)\n\n        x1, y1 = self.inverse_transform_coords(\n            width=width, height=height, thetas=theta2, offset_x=u, offset_y=v)\n\n        # subtract and create new flow\n        u = x1 - x0\n        v = y1 - y0\n        new_flow = torch.stack([u, v], dim=1)\n\n        # transform coords\n        xq, yq = self.transform_coords(width=width, height=height, thetas=theta1)\n\n        # interp2\n        # transformed = self._interp2(new_flow, xq, yq)\n        transformed, valid_mask = self._flow_interp2(new_flow, xq, yq, valid_mask)\n        return transformed, valid_mask\n\n    def check_out_of_bound(self, flow, occ, batch_size):\n        _, _, height, width = flow.size()\n        u = flow[:, 0, :, :]\n        v = flow[:, 1, :, :]\n        xx, yy = self._meshgrid(width=width, height=height, device=flow.device, dtype=flow.dtype)\n        xx = torch.unsqueeze(xx, dim=0).float()\n        yy = torch.unsqueeze(yy, dim=0).float()\n        xx = xx.expand(batch_size, -1, -1) + u\n        yy = yy.expand(batch_size, -1, -1) + v\n\n        out_of_bound = ((xx < 0) | (yy < 0) | (xx >= width) | (yy >= height)).float().unsqueeze(1)\n        occ = torch.clamp(out_of_bound + occ, 0, 1)\n\n        return occ\n\n    def random_crop(self, im1, im2, flo_f, valid_mask):\n\n        _, _, height, width = im1.size()\n        crop_height, crop_width = self._crop\n\n        # get starting positions\n        self._x.random_(0, width - crop_width + 1)\n        self._y.random_(0, height - crop_height + 1)\n        str_x = int(self._x)\n        str_y = int(self._y)\n        end_x = int(self._x + crop_width)\n        end_y = int(self._y + crop_height)\n\n        im1 = im1[:, :, str_y:end_y, str_x:end_x]\n        im2 = im2[:, :, str_y:end_y, str_x:end_x]\n        flo_f = flo_f[:, :, str_y:end_y, str_x:end_x]\n        valid_mask = valid_mask[:, :, str_y:end_y, str_x:end_x]\n\n        return im1, im2, flo_f, valid_mask\n\n    def forward(self, example_dict):\n        im1 = 
example_dict[\"input1\"]\n        im2 = example_dict[\"input2\"]\n        flo_f = example_dict[\"target1\"]\n        valid_mask = example_dict[\"input_valid\"]\n\n        batch_size = im1.size(0)\n        height = im1.size(2)\n        width = im1.size(3)\n\n        # identity = no transform\n        theta0 = self._identity(batch_size)\n\n        # # global transform\n        theta1 = self.apply_random_transforms_to_params(\n            theta0,\n            max_translate=0.04,\n            min_zoom=0.98, max_zoom=1.02,\n            min_squeeze=1.0, max_squeeze=1.0,\n            min_rotate=-0.01, max_rotate=0.01,\n            validate_size=[height, width])\n\n        # relative transform\n        theta2 = self.apply_random_transforms_to_params(\n            theta1,\n            max_translate=0.005,\n            min_zoom=0.99, max_zoom=1.01,\n            min_squeeze=1.0, max_squeeze=1.0,\n            min_rotate=-0.01, max_rotate=0.01,\n            validate_size=[height, width])\n\n        # random flip images\n        theta1, theta2 = self._random_mirror(theta1, theta2)\n\n        im1 = self.transform_image(im1, theta1)\n        im2 = self.transform_image(im2, theta2)\n        flo_f, valid_mask = self.transform_flow(flo_f, theta1, theta2, valid_mask)\n\n\n        if self._addnoise:\n            stddev = np.random.uniform(0.0, 0.04)\n            self._noise1.resize_as_(im1)\n            self._noise2.resize_as_(im2)\n            self._noise1.normal_(std=stddev)\n            self._noise2.normal_(std=stddev)\n            im1 += self._noise1\n            im2 += self._noise2\n            im1.clamp_(0.0, 1.0)\n            im2.clamp_(0.0, 1.0)\n\n        if self._crop is not None:\n            im1, im2, flo_f, valid_mask = self.random_crop(im1, im2, flo_f, valid_mask)\n\n        example_dict[\"input1\"] = im1\n        example_dict[\"input2\"] = im2\n        example_dict[\"target1\"] = flo_f\n        example_dict[\"input_valid\"] = valid_mask\n\n        return example_dict\n"
  },
  {
    "path": "commandline.py",
    "content": "## Portions of Code from, copyright 2018 Jochen Gast\n\nfrom __future__ import absolute_import, division, print_function\n\nimport argparse\nimport colorama\nimport inspect\nimport os\nimport sys\nimport torch\n\nimport datasets\nimport losses\nimport models\nimport augmentations\nimport tools\nimport logger\nimport logging\nimport optim\n\n\ndef _get_type_from_arg(arg):\n    if isinstance(arg, bool):\n        return tools.str2bool\n    else:\n        return type(arg)\n\n\ndef _add_arguments_for_module(parser,\n                              module,\n                              name,\n                              default_class,\n                              add_class_argument=True,   # whether to add class choice as argument\n                              include_classes=\"*\",\n                              exclude_classes=[],\n                              exclude_params=[\"self\",\"args\"],\n                              param_defaults={},          # allows to overwrite any default param\n                              forced_default_types={},    # allows to set types for known arguments\n                              unknown_default_types={}):  # allows to set types for unknown arguments\n\n    # -------------------------------------------------------------------------\n    # Determine possible choices from class names in module, possibly apply include/exclude filters\n    # -------------------------------------------------------------------------\n    module_dict = tools.module_classes_to_dict(\n        module, include_classes=include_classes, exclude_classes=exclude_classes)\n\n    # -------------------------------------------------------------------------\n    # Parse known arguments to determine choice for argument name\n    # -------------------------------------------------------------------------\n    if add_class_argument:\n        parser.add_argument(\n            \"--%s\" % name, type=str, default=default_class, 
choices=module_dict.keys())\n        known_args = parser.parse_known_args(sys.argv[1:])[0]\n    else:\n        # build a temporary parser, and do not add the class as argument\n        tmp_parser = argparse.ArgumentParser()\n        tmp_parser.add_argument(\n            \"--%s\" % name, type=str, default=default_class, choices=module_dict.keys())\n        known_args = tmp_parser.parse_known_args(sys.argv[1:])[0]\n\n    class_name = vars(known_args)[name]\n\n    # -------------------------------------------------------------------------\n    # If class is None, there is no point in trying to parse further arguments\n    # -------------------------------------------------------------------------\n    if class_name is None:\n        return\n\n    # -------------------------------------------------------------------------\n    # Get constructor of that argument choice\n    # -------------------------------------------------------------------------\n    class_constructor = module_dict[class_name]\n\n    # -------------------------------------------------------------------------\n    # Determine constructor argument names and defaults\n    # -------------------------------------------------------------------------\n    try:\n        argspec = inspect.getargspec(class_constructor.__init__)\n        argspec_defaults = argspec.defaults if argspec.defaults is not None else []\n        full_args = argspec.args\n        default_args_dict = dict(zip(argspec.args[-len(argspec_defaults):], argspec_defaults))\n    except TypeError:\n        print(argspec)\n        print(argspec.defaults)\n        raise ValueError(\"unknown_default_types should be adjusted for module: '%s.py'\" % name)\n\n    # -------------------------------------------------------------------------\n    # Add sub_arguments\n    # -------------------------------------------------------------------------\n    for argname in full_args:\n\n        # 
---------------------------------------------------------------------\n        # Skip\n        # ---------------------------------------------------------------------\n        if argname in exclude_params:\n            continue\n\n        # ---------------------------------------------------------------------\n        # Sub argument name\n        # ---------------------------------------------------------------------\n        sub_arg_name = \"%s_%s\" % (name, argname)\n\n        # ---------------------------------------------------------------------\n        # If a default argument is given, take that one\n        # ---------------------------------------------------------------------\n        if argname in param_defaults.keys():\n            parser.add_argument(\n                \"--%s\" % sub_arg_name,\n                type=_get_type_from_arg(param_defaults[argname]),\n                default=param_defaults[argname])\n\n        # ---------------------------------------------------------------------\n        # If a default parameter can be inferred from the module, pick that one\n        # ---------------------------------------------------------------------\n        elif argname in default_args_dict.keys():\n\n            # -----------------------------------------------------------------\n            # Check for forced default types\n            # -----------------------------------------------------------------\n            if argname in forced_default_types.keys():\n                argtype = forced_default_types[argname]\n            else:\n                argtype = _get_type_from_arg(default_args_dict[argname])\n            parser.add_argument(\n                \"--%s\" % sub_arg_name, type=argtype, default=default_args_dict[argname])\n\n        # ---------------------------------------------------------------------\n        # Take from the unknowns list\n        # ---------------------------------------------------------------------\n        elif argname in 
unknown_default_types.keys():\n            parser.add_argument(\"--%s\" % sub_arg_name, type=unknown_default_types[argname])\n\n        else:\n            raise ValueError(\n                \"Do not know how to handle argument '%s' for class '%s'\" % (argname, name))\n\n\ndef _add_special_arguments(parser):\n    # -------------------------------------------------------------------------\n    # Known arguments so far\n    # -------------------------------------------------------------------------\n    known_args = vars(parser.parse_known_args(sys.argv[1:])[0])\n\n    # -------------------------------------------------------------------------\n    # Add special arguments for training\n    # -------------------------------------------------------------------------\n    training_loss = known_args[\"training_loss\"]\n    if training_loss is not None:\n        parser.add_argument(\"--training_key\", type=str, default=\"total_loss\")\n\n    # -------------------------------------------------------------------------\n    # Add special arguments for validation\n    # -------------------------------------------------------------------------\n    validation_loss = known_args[\"validation_loss\"]\n    if validation_loss is not None:\n        parser.add_argument(\"--validation_key\", type=str, default=\"total_loss\")\n        parser.add_argument(\"--validation_key_minimize\", type=tools.str2bool, default=True)\n\n    # -------------------------------------------------------------------------\n    # Add special arguments for checkpoints\n    # -------------------------------------------------------------------------\n    checkpoint = known_args[\"checkpoint\"]\n    if checkpoint is not None:\n        parser.add_argument(\n            \"--checkpoint_mode\", type=str, default=\"resume_from_latest\",\n            choices=[\"resume_from_latest\", \"resume_from_best\"])\n\n        parser.add_argument(\n            \"--checkpoint_include_params\", type=tools.str2list, 
default=\"[*]\")\n        parser.add_argument(\n            \"--checkpoint_exclude_params\", type=tools.str2list, default=\"[]\")\n\n    # -------------------------------------------------------------------------\n    # Add special arguments for optimizer groups\n    # -------------------------------------------------------------------------\n    parser.add_argument(\"--optimizer_group\", action=\"append\", type=tools.str2dict, default=None)\n\n\ndef _parse_arguments():\n\n    # -------------------------------------------------------------------------\n    # Argument parser and shortcut function to add arguments\n    # -------------------------------------------------------------------------\n    parser = argparse.ArgumentParser()\n    add = parser.add_argument\n\n    # -------------------------------------------------------------------------\n    # Standard arguments\n    # -------------------------------------------------------------------------\n    add(\"--batch_size\", type=int, default=1)\n    add(\"--batch_size_val\", type=int, default=1)\n    add(\"--checkpoint\", type=tools.str2str_or_none, default=None)\n    add(\"--cuda\", type=tools.str2bool, default=True)\n    add(\"--evaluation\", type=tools.str2bool, default=False)\n    add(\"--name\", default=\"run\", type=str)\n    add(\"--num_workers\", type=int, default=4)\n    add(\"--save\", \"-s\", default=\"/tmp/work\", type=str)\n    add(\"--seed\", type=int, default=1)\n    add(\"--start_epoch\", type=int, default=1)\n    add(\"--total_epochs\", type=int, default=10)\n    add(\"--save_result_path_name\", default=\"\", type=str)\n    add(\"--save_result_img\", type=tools.str2bool, default=False)\n    add(\"--save_result_occ\", type=tools.str2bool, default=False)\n    add(\"--save_result_flo\", type=tools.str2bool, default=False)\n    add(\"--save_result_png\", type=tools.str2bool, default=False)\n    add(\"--save_result_bidirection\", type=tools.str2bool, default=False)\n    add(\"--num_iters\", type=int, 
default=1)\n\n    # -------------------------------------------------------------------------\n    # Arguments inferred from losses\n    # -------------------------------------------------------------------------\n    _add_arguments_for_module(\n        parser,\n        losses,\n        name=\"training_loss\",\n        default_class=None,\n        exclude_classes=[\"_*\", \"Variable\"],\n        exclude_params=[\"self\",\"args\"])\n\n    _add_arguments_for_module(\n        parser,\n        losses,\n        name=\"validation_loss\",\n        default_class=None,\n        exclude_classes=[\"_*\", \"Variable\"],\n        exclude_params=[\"self\",\"args\"])\n\n    # -------------------------------------------------------------------------\n    # Arguments inferred from models\n    # -------------------------------------------------------------------------\n    _add_arguments_for_module(\n        parser,\n        models,\n        name=\"model\",\n        default_class=\"FlowNet1S\",\n        exclude_classes=[\"_*\", \"Variable\"],\n        exclude_params=[\"self\",\"args\"])\n\n    # -------------------------------------------------------------------------\n    # Arguments inferred from augmentations for training\n    # -------------------------------------------------------------------------\n    _add_arguments_for_module(\n        parser,\n        augmentations,\n        name=\"training_augmentation\",\n        default_class=None,\n        exclude_classes=[\"_*\"],\n        exclude_params=[\"self\",\"args\"],\n        forced_default_types={\"crop\": tools.str2intlist})\n\n    # -------------------------------------------------------------------------\n    # Arguments inferred from augmentations for validation\n    # -------------------------------------------------------------------------\n    _add_arguments_for_module(\n        parser,\n        augmentations,\n        name=\"validation_augmentation\",\n        default_class=None,\n        exclude_classes=[\"_*\"],\n   
     exclude_params=[\"self\",\"args\"])\n\n    # -------------------------------------------------------------------------\n    # Arguments inferred from datasets for training\n    # -------------------------------------------------------------------------\n    _add_arguments_for_module(\n        parser,\n        datasets,\n        name=\"training_dataset\",\n        default_class=None,\n        exclude_params=[\"self\", \"args\", \"is_cropped\"],\n        exclude_classes=[\"_*\"],\n        unknown_default_types={\"root\": str})\n\n    # -------------------------------------------------------------------------\n    # Arguments inferred from datasets for validation\n    # -------------------------------------------------------------------------\n    _add_arguments_for_module(\n        parser,\n        datasets,\n        name=\"validation_dataset\",\n        default_class=None,\n        exclude_params=[\"self\", \"args\", \"is_cropped\"],\n        exclude_classes=[\"_*\"],\n        unknown_default_types={\"root\": str})\n\n    # -------------------------------------------------------------------------\n    # Arguments inferred from PyTorch optimizers\n    # -------------------------------------------------------------------------\n    _add_arguments_for_module(\n        parser,\n        optim,\n        name=\"optimizer\",\n        default_class=\"Adam\",\n        exclude_classes=[\"_*\",\"Optimizer\", \"constructor\"],\n        exclude_params=[\"self\", \"args\", \"params\"],\n        forced_default_types={\"lr\": float,\n                              \"momentum\": float,\n                              \"dampening\": float,\n                              \"weight_decay\": float,\n                              \"nesterov\": tools.str2bool})\n\n    # -------------------------------------------------------------------------\n    # Arguments inferred from PyTorch lr schedulers\n    # -------------------------------------------------------------------------\n    
_add_arguments_for_module(\n        parser,\n        torch.optim.lr_scheduler,\n        name=\"lr_scheduler\",\n        default_class=None,\n        exclude_classes=[\"_*\",\"Optimizer\"],\n        exclude_params=[\"self\", \"args\", \"optimizer\"],\n        unknown_default_types={\"T_max\": int,\n                               \"lr_lambda\": str,\n                               \"step_size\": int,\n                               \"milestones\": tools.str2intlist,\n                               \"gamma\": float})\n\n    # -------------------------------------------------------------------------\n    # Special arguments\n    # -------------------------------------------------------------------------\n    _add_special_arguments(parser)\n\n    # -------------------------------------------------------------------------\n    # Parse arguments\n    # -------------------------------------------------------------------------\n    args = parser.parse_args()\n\n    # -------------------------------------------------------------------------\n    # Parse default arguments from a dummy commandline not specifying any args\n    # -------------------------------------------------------------------------\n    defaults = vars(parser.parse_known_args(['--dummy'])[0])\n\n    # -------------------------------------------------------------------------\n    # Consistency checks\n    # -------------------------------------------------------------------------\n    args.cuda = args.cuda and torch.cuda.is_available()\n\n    return args, defaults\n\n\ndef postprocess_args(args):\n\n    # ----------------------------------------------------------------------------\n    # Get appropriate class constructors from modules\n    # ----------------------------------------------------------------------------\n    args.model_class = tools.module_classes_to_dict(models)[args.model]\n\n    if args.optimizer is not None:\n        optimizer_classes = tools.module_classes_to_dict(optim)\n        
args.optimizer_class = optimizer_classes[args.optimizer]\n\n    if args.training_loss is not None:\n        loss_classes = tools.module_classes_to_dict(losses)\n        args.training_loss_class = loss_classes[args.training_loss]\n\n    if args.validation_loss is not None:\n        loss_classes = tools.module_classes_to_dict(losses)\n        args.validation_loss_class = loss_classes[args.validation_loss]\n\n    if args.lr_scheduler is not None:\n        scheduler_classes = tools.module_classes_to_dict(torch.optim.lr_scheduler)\n        args.lr_scheduler_class = scheduler_classes[args.lr_scheduler]\n\n    if args.training_dataset is not None:\n        dataset_classes = tools.module_classes_to_dict(datasets)\n        args.training_dataset_class = dataset_classes[args.training_dataset]\n\n    if args.validation_dataset is not None:\n        dataset_classes = tools.module_classes_to_dict(datasets)\n        args.validation_dataset_class = dataset_classes[args.validation_dataset]\n\n    if args.training_augmentation is not None:\n        augmentation_classes = tools.module_classes_to_dict(augmentations)\n        args.training_augmentation_class = augmentation_classes[args.training_augmentation]\n\n    if args.validation_augmentation is not None:\n        augmentation_classes = tools.module_classes_to_dict(augmentations)\n        args.validation_augmentation_class = augmentation_classes[args.validation_augmentation]\n\n    return args\n\n\ndef setup_logging_and_parse_arguments(blocktitle):\n    # ----------------------------------------------------------------------------\n    # Get parse commandline and default arguments\n    # ----------------------------------------------------------------------------\n    args, defaults = _parse_arguments()\n\n    # ----------------------------------------------------------------------------\n    # Setup logbook before everything else\n    # ----------------------------------------------------------------------------\n    
logger.configure_logging(os.path.join(args.save, 'logbook.txt'))\n\n    # ----------------------------------------------------------------------------\n    # Write arguments to file, as txt\n    # ----------------------------------------------------------------------------\n    tools.write_dictionary_to_file(\n        sorted(vars(args).items()),\n        filename=os.path.join(args.save, 'args.txt'))\n\n    # ----------------------------------------------------------------------------\n    # Log arguments\n    # ----------------------------------------------------------------------------\n    with logger.LoggingBlock(blocktitle, emph=True):\n        for argument, value in sorted(vars(args).items()):\n            reset = colorama.Style.RESET_ALL\n            color = reset if value == defaults[argument] else colorama.Fore.CYAN\n            logging.info('{}{}: {}{}'.format(color, argument, value, reset))\n\n    # ----------------------------------------------------------------------------\n    # Postprocess\n    # ----------------------------------------------------------------------------\n    args = postprocess_args(args)\n\n    return args\n"
  },
  {
    "path": "configuration.py",
    "content": "## Portions of Code from, copyright 2018 Jochen Gast\n\nfrom __future__ import absolute_import, division, print_function\n\nimport os\nimport torch\nfrom torch import nn\nimport numpy as np\nfrom torch.utils.data import DataLoader\nimport logger\nimport tools\nimport logging\nimport shutil\nimport random\nimport fnmatch\n\n# ---------------------------------------------------\n# Class that contains both the network model and loss\n# ---------------------------------------------------\nclass ModelAndLoss(nn.Module):\n    def __init__(self, args, model, training_loss, evaluation_loss=None):\n        super(ModelAndLoss, self).__init__()\n        self._model = model\n        self._training_loss = training_loss\n        self._evaluation_loss = evaluation_loss\n\n    @property\n    def training_loss(self):\n        return self._training_loss\n\n    @property\n    def evaluation_loss(self):\n        return self._evaluation_loss\n\n    @property\n    def model(self):\n        return self._model\n\n    def num_parameters(self):\n        return sum([p.data.nelement() if p.requires_grad else 0 for p in self.parameters()])\n\n    # -------------------------------------------------------------\n    # Note: We merge inputs and targets into a single dictionary !\n    # -------------------------------------------------------------\n    def forward(self, example_dict):\n        # -------------------------------------\n        # Run forward pass\n        # -------------------------------------\n        output_dict = self._model(example_dict)\n\n        # -------------------------------------\n        # Compute losses\n        # -------------------------------------\n        if self.training:\n            loss_dict = self._training_loss(output_dict, example_dict)\n        else:\n            loss_dict = self._evaluation_loss(output_dict, example_dict)\n\n        # -------------------------------------\n        # Return losses and outputs\n        # 
-------------------------------------\n        return loss_dict, output_dict\n\n\ndef configure_runtime_augmentations(args):\n    with logger.LoggingBlock(\"Runtime Augmentations\", emph=True):\n\n        training_augmentation = None\n        validation_augmentation = None\n\n        # ----------------------------------------------------\n        # Training Augmentation\n        # ----------------------------------------------------\n        if args.training_augmentation is not None:\n            kwargs = tools.kwargs_from_args(args, \"training_augmentation\")\n            logging.info(\"training_augmentation: %s\" % args.training_augmentation)\n            for param, default in sorted(kwargs.items()):\n                logging.info(\"  %s: %s\" % (param, default))\n            kwargs[\"args\"] = args\n            training_augmentation = tools.instance_from_kwargs(\n                args.training_augmentation_class, kwargs)\n            if args.cuda:\n                training_augmentation = training_augmentation.cuda()\n\n        else:\n            logging.info(\"training_augmentation: None\")\n\n        # ----------------------------------------------------\n        # Validation Augmentation\n        # ----------------------------------------------------\n        if args.validation_augmentation is not None:\n            kwargs = tools.kwargs_from_args(args, \"validation_augmentation\")\n            logging.info(\"validation_augmentation: %s\" % args.validation_augmentation)\n            for param, default in sorted(kwargs.items()):\n                logging.info(\"  %s: %s\" % (param, default))\n            kwargs[\"args\"] = args\n            validation_augmentation = tools.instance_from_kwargs(\n                args.validation_augmentation_class, kwargs)\n            if args.cuda:\n                validation_augmentation = validation_augmentation.cuda()\n\n        else:\n            logging.info(\"validation_augmentation: None\")\n\n    return 
training_augmentation, validation_augmentation\n\n\ndef configure_model_and_loss(args):\n\n    # ----------------------------------------------------\n    # Dynamically load model and loss class with parameters\n    # passed in via \"--model_[param]=[value]\" or \"--loss_[param]=[value]\" arguments\n    # ----------------------------------------------------\n    with logger.LoggingBlock(\"Model and Loss\", emph=True):\n\n        # ----------------------------------------------------\n        # Model\n        # ----------------------------------------------------\n        kwargs = tools.kwargs_from_args(args, \"model\")\n        kwargs[\"args\"] = args\n        model = tools.instance_from_kwargs(args.model_class, kwargs)\n\n        # ----------------------------------------------------\n        # Training loss\n        # ----------------------------------------------------\n        training_loss = None\n        if args.training_loss is not None:\n            kwargs = tools.kwargs_from_args(args, \"training_loss\")\n            kwargs[\"args\"] = args\n            training_loss = tools.instance_from_kwargs(args.training_loss_class, kwargs)\n\n        # ----------------------------------------------------\n        # Validation loss\n        # ----------------------------------------------------\n        validation_loss = None\n        if args.validation_loss is not None:\n            kwargs = tools.kwargs_from_args(args, \"validation_loss\")\n            kwargs[\"args\"] = args\n            validation_loss = tools.instance_from_kwargs(args.validation_loss_class, kwargs)\n\n        # ----------------------------------------------------\n        # Model and loss\n        # ----------------------------------------------------\n        model_and_loss = ModelAndLoss(args, model, training_loss, validation_loss)\n\n        # -----------------------------------------------------------\n        # If Cuda, transfer model to Cuda and wrap with DataParallel.\n        # 
-----------------------------------------------------------\n        if args.cuda:\n            model_and_loss = model_and_loss.cuda()\n\n        # ---------------------------------------------------------------\n        # Report some network statistics\n        # ---------------------------------------------------------------\n        logging.info(\"Batch Size: %i\" % args.batch_size)\n        logging.info(\"GPGPU: Cuda\") if args.cuda else logging.info(\"GPGPU: off\")\n        logging.info(\"Network: %s\" % args.model)\n        logging.info(\"Number of parameters: %i\" % tools.x2module(model_and_loss).num_parameters())\n        if training_loss is not None:\n            logging.info(\"Training Key: %s\" % args.training_key)\n            logging.info(\"Training Loss: %s\" % args.training_loss)\n        if validation_loss is not None:\n            logging.info(\"Validation Key: %s\" % args.validation_key)\n            logging.info(\"Validation Loss: %s\" % args.validation_loss)\n\n    return model_and_loss\n\n\ndef configure_random_seed(args):\n    with logger.LoggingBlock(\"Random Seeds\", emph=True):\n        # python\n        seed = args.seed\n        random.seed(seed)\n        logging.info(\"Python seed: %i\" % seed)\n        # numpy\n        seed += 1\n        np.random.seed(seed)\n        logging.info(\"Numpy seed: %i\" % seed)\n        # torch\n        seed += 1\n        torch.manual_seed(seed)\n        logging.info(\"Torch CPU seed: %i\" % seed)\n        # torch cuda\n        seed += 1\n        torch.cuda.manual_seed(seed)\n        logging.info(\"Torch CUDA seed: %i\" % seed)\n\n\n# --------------------------------------------------------------------------\n# Checkpoint loader/saver.\n# --------------------------------------------------------------------------\nclass CheckpointSaver:\n    def __init__(self,\n                 prefix=\"checkpoint\",\n                 latest_postfix=\"_latest\",\n                 best_postfix=\"_best\",\n                 
model_key=\"state_dict\",\n                 extension=\".ckpt\"):\n\n        self._prefix = prefix\n        self._model_key = model_key\n        self._latest_postfix = latest_postfix\n        self._best_postfix = best_postfix\n        self._extension = extension\n\n    # the purpose of rewriting the loading function is we sometimes want to\n    # initialize parameters in modules without knowing the dimensions at runtime\n    #\n    # This function here will resize these parameters to whatever size required.\n    \n    def _load_state_dict_into_module(self, state_dict, module, strict=True):\n        own_state = module.state_dict()\n\n        for name, param in state_dict.items():\n            if name in own_state:\n                if isinstance(param, nn.Parameter):\n                    # backwards compatibility for serialized parameters\n                    param = param.data\n                try:\n                    own_state[name].resize_as_(param)\n                    own_state[name].copy_(param)\n                except Exception:\n                    raise RuntimeError('While copying the parameter named {}, '\n                                       'whose dimensions in the model are {} and '\n                                       'whose dimensions in the checkpoint are {}.'\n                                       .format(name, own_state[name].size(), param.size()))\n            elif strict:\n                raise KeyError('unexpected key \"{}\" in state_dict'\n                               .format(name))\n        if strict:\n            missing = set(own_state.keys()) - set(state_dict.keys())\n            if len(missing) > 0:\n                raise KeyError('missing keys in state_dict: \"{}\"'.format(missing))\n\n    def restore(self, filename, model_and_loss, include_params=\"*\", exclude_params=()):\n        # -----------------------------------------------------------------------------------------\n        # Make sure file exists\n        # 
-----------------------------------------------------------------------------------------\n        if not os.path.isfile(filename):\n            logging.info(\"Could not find checkpoint file '%s'!\" % filename)\n            quit()\n\n        # -----------------------------------------------------------------------------------------\n        # Load checkpoint from file including the state_dict\n        # -----------------------------------------------------------------------------------------\n        checkpoint_with_state = torch.load(filename)\n\n        # -----------------------------------------------------------------------------------------\n        # Load filtered state dictionary\n        # -----------------------------------------------------------------------------------------\n        state_dict = checkpoint_with_state[self._model_key]\n        restore_keys = tools.filter_list_of_strings(\n            state_dict.keys(),\n            include=include_params,\n            exclude=exclude_params)\n        state_dict = {key: value for key, value in state_dict.items() if key in restore_keys}\n        self._load_state_dict_into_module(state_dict, model_and_loss)\n        # logging.info(\"  Restore keys:\")\n        # for key in restore_keys:\n        #     logging.info(\"    %s\" % key)\n\n        # -----------------------------------------------------------------------------------------\n        # Get checkpoint statistics without the state dict\n        # -----------------------------------------------------------------------------------------\n        checkpoint_stats = {\n            key: value for key, value in checkpoint_with_state.items() if key != self._model_key\n        }\n\n        return checkpoint_stats, filename\n\n    def restore_latest(self, directory, model_and_loss, include_params=\"*\", exclude_params=()):\n        latest_checkpoint_filename = os.path.join(\n            directory, self._prefix + self._latest_postfix + self._extension)\n        
return self.restore(latest_checkpoint_filename, model_and_loss, include_params, exclude_params)\n\n    def restore_best(self, directory, model_and_loss, include_params=\"*\", exclude_params=()):\n        best_checkpoint_filename = os.path.join(\n            directory, self._prefix + self._best_postfix + self._extension)\n        return self.restore(best_checkpoint_filename, model_and_loss, include_params, exclude_params)\n\n    def save_latest(self, directory, model_and_loss, stats_dict, store_as_best=False):\n        # -----------------------------------------------------------------------------------------\n        # Make sure directory exists\n        # -----------------------------------------------------------------------------------------\n        tools.ensure_dir(directory)\n\n        # -----------------------------------------------------------------------------------------\n        # Save\n        # -----------------------------------------------------------------------------------------\n        save_dict = dict(stats_dict)\n        save_dict[self._model_key] = model_and_loss.state_dict()\n\n        latest_checkpoint_filename = os.path.join(\n            directory, self._prefix + self._latest_postfix + self._extension)\n\n        latest_statistics_filename = os.path.join(\n            directory, self._prefix + self._latest_postfix + \".json\")\n\n        torch.save(save_dict, latest_checkpoint_filename)\n        tools.write_json(data_dict=stats_dict, filename=latest_statistics_filename)\n\n        # -----------------------------------------------------------------------------------------\n        # Possibly store as best\n        # -----------------------------------------------------------------------------------------\n        if store_as_best:\n            best_checkpoint_filename = os.path.join(\n                directory, self._prefix + self._best_postfix + self._extension)\n\n            best_statistics_filename = os.path.join(\n                
directory, self._prefix + self._best_postfix + \".json\")\n\n            logging.info(\"Saved checkpoint as best model..\")\n            shutil.copyfile(latest_checkpoint_filename, best_checkpoint_filename)\n            shutil.copyfile(latest_statistics_filename, best_statistics_filename)\n\n\ndef configure_checkpoint_saver(args, model_and_loss):\n    with logger.LoggingBlock(\"Checkpoint\", emph=True):\n        checkpoint_saver = CheckpointSaver()\n        checkpoint_stats = None\n\n        if args.checkpoint is None:\n            logging.info(\"No checkpoint given.\")\n            logging.info(\"Starting from scratch with random initialization.\")\n\n        elif os.path.isfile(args.checkpoint):\n            checkpoint_stats, filename = checkpoint_saver.restore(\n                filename=args.checkpoint,\n                model_and_loss=model_and_loss,\n                include_params=args.checkpoint_include_params,\n                exclude_params=args.checkpoint_exclude_params)\n\n        elif os.path.isdir(args.checkpoint):\n            if args.checkpoint_mode in [\"resume_from_best\"]:\n                logging.info(\"Loading best checkpoint in %s\" % args.checkpoint)\n                checkpoint_stats, filename = checkpoint_saver.restore_best(\n                    directory=args.checkpoint,\n                    model_and_loss=model_and_loss,\n                    include_params=args.checkpoint_include_params,\n                    exclude_params=args.checkpoint_exclude_params)\n\n            elif args.checkpoint_mode in [\"resume_from_latest\"]:\n                logging.info(\"Loading latest checkpoint in %s\" % args.checkpoint)\n                checkpoint_stats, filename = checkpoint_saver.restore_latest(\n                    directory=args.checkpoint,\n                    model_and_loss=model_and_loss,\n                    include_params=args.checkpoint_include_params,\n                    exclude_params=args.checkpoint_exclude_params)\n            else:\n        
        logging.info(\"Unknown checkpoint_restore '%s' given!\" % args.checkpoint_restore)\n                quit()\n        else:\n            logging.info(\"Could not find checkpoint file or directory '%s'\" % args.checkpoint)\n            quit()\n\n    return checkpoint_saver, checkpoint_stats\n\n\n# -------------------------------------------------------------------------------------------------\n# Configure data loading\n# -------------------------------------------------------------------------------------------------\ndef configure_data_loaders(args):\n    with logger.LoggingBlock(\"Datasets\", emph=True):\n\n        def _sizes_to_str(value):\n            if np.isscalar(value):\n                return '[1L]'\n            else:\n                return ' '.join([str([d for d in value.size()])])\n\n        def _log_statistics(dataset, prefix, name):\n            with logger.LoggingBlock(\"%s Dataset: %s\" % (prefix, name)):\n                example_dict = dataset[0]  # get sizes from first dataset example\n                for key, value in sorted(example_dict.items()):\n                    if key in [\"index\", \"basename\"]:  # no need to display these\n                        continue\n                    if isinstance(value, str):\n                        logging.info(\"{}: {}\".format(key, value))\n                    else:\n                        logging.info(\"%s: %s\" % (key, _sizes_to_str(value)))\n                logging.info(\"num_examples: %i\" % len(dataset))\n\n        # -----------------------------------------------------------------------------------------\n        # GPU parameters -- turning off pin_memory? 
for resolving the deadlock?\n        # -----------------------------------------------------------------------------------------\n        gpuargs = {\"num_workers\": args.num_workers, \"pin_memory\": False} if args.cuda else {}\n\n        train_loader = None\n        validation_loader = None\n        inference_loader = None\n\n        # -----------------------------------------------------------------------------------------\n        # Training dataset\n        # -----------------------------------------------------------------------------------------\n        if args.training_dataset is not None:\n\n            # ----------------------------------------------\n            # Figure out training_dataset arguments\n            # ----------------------------------------------\n            kwargs = tools.kwargs_from_args(args, \"training_dataset\")\n            kwargs[\"is_cropped\"] = True\n            kwargs[\"args\"] = args\n\n            # ----------------------------------------------\n            # Create training dataset\n            # ----------------------------------------------\n            train_dataset = tools.instance_from_kwargs(args.training_dataset_class, kwargs)\n\n            # ----------------------------------------------\n            # Create training loader\n            # ----------------------------------------------\n            train_loader = DataLoader(\n                train_dataset,\n                batch_size=args.batch_size,\n                shuffle=True,\n                drop_last=False,\n                **gpuargs)\n\n            _log_statistics(train_dataset, prefix=\"Training\", name=args.training_dataset)\n\n        # -----------------------------------------------------------------------------------------\n        # Validation dataset\n        # -----------------------------------------------------------------------------------------\n        if args.validation_dataset is not None:\n\n            # 
----------------------------------------------\n            # Figure out validation_dataset arguments\n            # ----------------------------------------------\n            kwargs = tools.kwargs_from_args(args, \"validation_dataset\")\n            kwargs[\"is_cropped\"] = True\n            kwargs[\"args\"] = args\n\n            # ----------------------------------------------\n            # Create validation dataset\n            # ----------------------------------------------\n            validation_dataset = tools.instance_from_kwargs(args.validation_dataset_class, kwargs)\n\n            # ----------------------------------------------\n            # Create validation loader\n            # ----------------------------------------------\n            validation_loader = DataLoader(\n                validation_dataset,\n                batch_size=args.batch_size_val,\n                shuffle=False,\n                drop_last=False,\n                **gpuargs)\n\n            _log_statistics(validation_dataset, prefix=\"Validation\", name=args.validation_dataset)\n\n    return train_loader, validation_loader, inference_loader\n\n\n# ------------------------------------------------------------\n# Generator for trainable parameters by pattern matching\n# ------------------------------------------------------------\ndef _print_trainable_params(model_and_loss, match=\"*\"):\n    sum = 0\n    for name, p in model_and_loss.named_parameters():\n        if fnmatch.fnmatch(name, match):\n            if p.requires_grad:\n                logging.info(name)\n                logging.info(str(p.numel()))\n                print(name)\n                print(p.numel())\n                sum += p.numel()\n    logging.info(str(sum))\n\ndef _generate_trainable_params(model_and_loss, match=\"*\"):\n    for name, p in model_and_loss.named_parameters():\n        if fnmatch.fnmatch(name, match):\n            if p.requires_grad:\n                yield p\n\n\ndef 
_param_names_and_trainable_generator(model_and_loss, match=\"*\"):\n    names = []\n    for name, p in model_and_loss.named_parameters():\n        if fnmatch.fnmatch(name, match):\n            if p.requires_grad:\n                names.append(name)\n\n    return names, _generate_trainable_params(model_and_loss, match=match)\n\n\n# -------------------------------------------------------------------------------------------------\n# Build optimizer:\n# -------------------------------------------------------------------------------------------------\ndef configure_optimizer(args, model_and_loss):\n    optimizer = None\n    with logger.LoggingBlock(\"Optimizer\", emph=True):\n        if args.optimizer is not None:\n            if model_and_loss.num_parameters() == 0:\n                logging.info(\"No trainable parameters detected.\")\n                logging.info(\"Setting optimizer to None.\")\n            else:\n                logging.info(args.optimizer)\n\n                # -------------------------------------------\n                # Figure out all optimizer arguments\n                # -------------------------------------------\n                all_kwargs = tools.kwargs_from_args(args, \"optimizer\")\n\n                # -------------------------------------------\n                # Get the split of param groups\n                # -------------------------------------------\n                kwargs_without_groups = {\n                    key: value for key,value in all_kwargs.items() if key != \"group\"\n                }\n                param_groups = all_kwargs[\"group\"]\n\n                # ----------------------------------------------------------------------\n                # Print arguments (without groups)\n                # ----------------------------------------------------------------------\n                for param, default in sorted(kwargs_without_groups.items()):\n                    logging.info(\"%s: %s\" % (param, default))\n\n              
  # ----------------------------------------------------------------------\n                # Construct actual optimizer params\n                # ----------------------------------------------------------------------\n                kwargs = dict(kwargs_without_groups)\n                if param_groups is None:\n                    # ---------------------------------------------------------\n                    # Add all trainable parameters if there is no param groups\n                    # ---------------------------------------------------------\n                    all_trainable_parameters = _generate_trainable_params(model_and_loss)\n                    kwargs[\"params\"] = all_trainable_parameters\n                else:\n                    # -------------------------------------------\n                    # Add list of parameter groups instead\n                    # -------------------------------------------\n                    trainable_parameter_groups = []\n                    dnames, dparams = _param_names_and_trainable_generator(model_and_loss)\n                    dnames = set(dnames)\n                    dparams = set(list(dparams))\n                    with logger.LoggingBlock(\"parameter_groups:\"):\n                        for group in param_groups:\n                            #  log group settings\n                            group_match = group[\"params\"]\n                            group_args = {\n                                key: value for key, value in group.items() if key != \"params\"\n                            }\n\n                            with logger.LoggingBlock(\"%s: %s\" % (group_match, group_args)):\n                                # retrieve parameters by matching name\n                                gnames, gparams = _param_names_and_trainable_generator(\n                                    model_and_loss, match=group_match)\n                                # log all names affected\n                                for 
n in sorted(gnames):\n                                    logging.info(n)\n                                # set generator for group\n                                group_args[\"params\"] = gparams\n                                # append parameter group\n                                trainable_parameter_groups.append(group_args)\n                                # update remaining trainable parameters\n                                dnames -= set(gnames)\n                                dparams -= set(list(gparams))\n\n                        # append default parameter group\n                        trainable_parameter_groups.append({\"params\": list(dparams)})\n                        # and log its parameter names\n                        with logger.LoggingBlock(\"default:\"):\n                            for dname in sorted(dnames):\n                                logging.info(dname)\n\n                    # set params in optimizer kwargs\n                    kwargs[\"params\"] = trainable_parameter_groups\n\n                # -------------------------------------------\n                # Create optimizer instance\n                # -------------------------------------------\n                optimizer = tools.instance_from_kwargs(args.optimizer_class, kwargs)\n\n    return optimizer\n\n\n# -------------------------------------------------------------------------------------------------\n# Configure learning rate scheduler\n# -------------------------------------------------------------------------------------------------\ndef configure_lr_scheduler(args, optimizer):\n    lr_scheduler = None\n\n    with logger.LoggingBlock(\"Learning Rate Scheduler\", emph=True):\n        logging.info(\"class: %s\" % args.lr_scheduler)\n\n        if args.lr_scheduler is not None:\n\n            # ----------------------------------------------\n            # Figure out lr_scheduler arguments\n            # ----------------------------------------------\n            kwargs = 
tools.kwargs_from_args(args, \"lr_scheduler\")\n\n            # -------------------------------------------\n            # Print arguments\n            # -------------------------------------------\n            for param, default in sorted(kwargs.items()):\n                logging.info(\"%s: %s\" % (param, default))\n\n            # -------------------------------------------\n            # Add optimizer\n            # -------------------------------------------\n            kwargs[\"optimizer\"] = optimizer\n\n            # -------------------------------------------\n            # Create lr_scheduler instance\n            # -------------------------------------------\n            lr_scheduler = tools.instance_from_kwargs(args.lr_scheduler_class, kwargs)\n\n    return lr_scheduler\n"
  },
  {
    "path": "datasets/__init__.py",
    "content": "from . import flyingchairs\nfrom . import flyingchairsOcc\nfrom . import sintel\nfrom . import flyingThings3D\nfrom . import kitti_combined\nfrom . import sintel\n\n## FlyingChairs\nFlyingChairsTrain = flyingchairs.FlyingChairsTrain\nFlyingChairsValid = flyingchairs.FlyingChairsValid\nFlyingChairsFull = flyingchairs.FlyingChairsFull\n\n## Our custom FlyingChairs + Occ\nFlyingChairsOccTrain = flyingchairsOcc.FlyingChairsOccTrain\nFlyingChairsOccValid = flyingchairsOcc.FlyingChairsOccValid\nFlyingChairsOccFull = flyingchairsOcc.FlyingChairsOccFull\n\n\n## FlyingThings3D_subset\nFlyingThings3dFinalTrain = flyingThings3D.FlyingThings3dFinalTrain\nFlyingThings3dFinalTest = flyingThings3D.FlyingThings3dFinalTest\nFlyingThings3dCleanTrain = flyingThings3D.FlyingThings3dCleanTrain\nFlyingThings3dCleanTest = flyingThings3D.FlyingThings3dCleanTest\n\n\n## Sintel\nSintelTestClean = sintel.SintelTestClean\nSintelTestFinal = sintel.SintelTestFinal\n\nSintelTrainingCombFull = sintel.SintelTrainingCombFull\nSintelTrainingCombTrain = sintel.SintelTrainingCombTrain\nSintelTrainingCombValid = sintel.SintelTrainingCombValid\n\nSintelTrainingCleanFull = sintel.SintelTrainingCleanFull\nSintelTrainingCleanTrain = sintel.SintelTrainingCleanTrain\nSintelTrainingCleanValid = sintel.SintelTrainingCleanValid\n\nSintelTrainingFinalFull = sintel.SintelTrainingFinalFull\nSintelTrainingFinalTrain = sintel.SintelTrainingFinalTrain\nSintelTrainingFinalValid = sintel.SintelTrainingFinalValid\n\n\n## KITTI Optical Flow 2012 + 2015\nKittiCombTrain = kitti_combined.KittiCombTrain\nKittiCombVal = kitti_combined.KittiCombVal\nKittiCombFull = kitti_combined.KittiCombFull\n\nKittiComb2012Train = kitti_combined.KittiComb2012Train\nKittiComb2012Val = kitti_combined.KittiComb2012Val\nKittiComb2012Full = kitti_combined.KittiComb2012Full\nKittiComb2012Test = kitti_combined.KittiComb2012Test\n\nKittiComb2015Train = kitti_combined.KittiComb2015Train\nKittiComb2015Val = 
kitti_combined.KittiComb2015Val\nKittiComb2015Full = kitti_combined.KittiComb2015Full\nKittiComb2015Test = kitti_combined.KittiComb2015Test"
  },
  {
    "path": "datasets/common.py",
    "content": "## Portions of Code from, copyright 2018 Jochen Gast\n\nfrom __future__ import absolute_import, division, print_function\n\nimport torch\nimport numpy as np\nimport skimage.io as io\n\n\ndef numpy2torch(array):\n    assert(isinstance(array, np.ndarray))\n    if array.ndim == 3:\n        array = np.transpose(array, (2, 0, 1))\n    else:\n        array = np.expand_dims(array, axis=0)\n    return torch.from_numpy(array.copy()).float()\n\n\ndef read_flo_as_float32(filename):\n    with open(filename, 'rb') as file:\n        magic = np.fromfile(file, np.float32, count=1)\n        assert(202021.25 == magic), \"Magic number incorrect. Invalid .flo file\"\n        w = np.fromfile(file, np.int32, count=1)[0]\n        h = np.fromfile(file, np.int32, count=1)[0]\n        data = np.fromfile(file, np.float32, count=2*h*w)\n    data2D = np.resize(data, (h, w, 2))\n    return data2D\n\n\ndef read_occ_image_as_float32(filename):\n    occ = io.imread(filename).astype(np.float32) / np.float32(255.0)\n    if occ.ndim == 3:\n        occ = occ[:, :, 0]\n    return occ\n\n\ndef read_image_as_float32(filename):\n    return io.imread(filename).astype(np.float32) / np.float32(255.0)\n\n\ndef read_image_as_byte(filename):\n    return io.imread(filename)\n"
  },
  {
    "path": "datasets/flyingThings3D.py",
    "content": "from __future__ import absolute_import, division, print_function\n\nimport os\nimport torch.utils.data as data\nfrom glob import glob\n\nfrom torchvision import transforms as vision_transforms\n\nfrom . import transforms\nfrom . import common\n\nimport numpy as np\n\n\ndef fillingInNaN(flow):\n    h, w, c = flow.shape\n    indices = np.argwhere(np.isnan(flow))\n    neighbors = [[-1, 0], [1, 0], [0, -1], [0, 1]]\n    for ii, idx in enumerate(indices):\n        sum_sample = 0\n        count = 0\n        for jj in range(0, len(neighbors) - 1):\n            hh = idx[0] + neighbors[jj][0]\n            ww = idx[1] + neighbors[jj][1]\n            if hh < 0 or hh >= h:\n                continue\n            if ww < 0 or ww >= w:\n                continue\n            sample_flow = flow[hh, ww, idx[2]]\n            if np.isnan(sample_flow):\n                continue\n            sum_sample += sample_flow\n            count += 1\n        if count is 0:\n            print('FATAL ERROR: no sample')\n        flow[idx[0], idx[1], idx[2]] = sum_sample / count\n\n    return flow\n\n\nclass FlyingThings3d(data.Dataset):\n    def __init__(self,\n                 args,\n                 images_root,\n                 flow_root,\n                 occ_root,\n                 photometric_augmentations=False):\n\n        self._args = args\n        if not os.path.isdir(images_root):\n            raise ValueError(\"Image directory '%s' not found!\")\n        if flow_root is not None and not os.path.isdir(flow_root):\n            raise ValueError(\"Flow directory '%s' not found!\")\n        if occ_root is not None and not os.path.isdir(occ_root):\n            raise ValueError(\"Occ directory '%s' not found!\")\n\n        if flow_root is not None:\n            flow_f_filenames = sorted(glob(os.path.join(flow_root, \"into_future/*.flo\")))\n            flow_b_filenames = sorted(glob(os.path.join(flow_root, \"into_past/*.flo\")))\n\n        if occ_root is not None:\n            
occ1_filenames = sorted(glob(os.path.join(occ_root, \"into_future/*.png\")))\n            occ2_filenames = sorted(glob(os.path.join(occ_root, \"into_past/*.png\")))\n\n        all_img_filenames = sorted(glob(os.path.join(images_root, \"*.png\")))\n\n        self._image_list = []\n        self._flow_list = [] if flow_root is not None else None\n        self._occ_list = [] if occ_root is not None else None\n\n        assert len(all_img_filenames) != 0\n        assert len(flow_f_filenames) != 0\n        assert len(flow_b_filenames) != 0\n        assert len(occ1_filenames) != 0\n        assert len(occ2_filenames) != 0\n\n        ## path definition\n        path_flow_f = os.path.join(flow_root, \"into_future\")\n        path_flow_b = os.path.join(flow_root, \"into_past\")\n        path_occ_f = os.path.join(occ_root, \"into_future\")\n        path_occ_b = os.path.join(occ_root, \"into_past\")\n\n        # ----------------------------------------------------------\n        # Save list of actual filenames for inputs and flows\n        # ----------------------------------------------------------\n\n        for ii in range(0, len(flow_f_filenames)):\n\n            flo_f = flow_f_filenames[ii]\n\n            idx_f = os.path.splitext(os.path.basename(flo_f))[0]\n            idx_b = str(int(idx_f) + 1).zfill(7)\n\n            flo_b = os.path.join(path_flow_b, idx_b + \".flo\")\n\n            im1 = os.path.join(images_root, idx_f + \".png\")\n            im2 = os.path.join(images_root, idx_b + \".png\")\n            occ1 = os.path.join(path_occ_f, idx_f + \".png\")\n            occ2 = os.path.join(path_occ_b, idx_b + \".png\")\n\n            if not os.path.isfile(flo_f) or not os.path.isfile(flo_b) or not os.path.isfile(im1) or not os.path.isfile(\n                    im2) or not os.path.isfile(occ1) or not os.path.isfile(occ2):\n                continue\n\n            self._image_list += [[im1, im2]]\n            self._flow_list += [[flo_f, flo_b]]\n            self._occ_list 
+= [[occ1, occ2]]\n\n        self._size = len(self._image_list)\n\n        assert len(self._image_list) == len(self._flow_list)\n        assert len(self._occ_list) == len(self._flow_list)\n        assert len(self._image_list) != 0\n\n        # ----------------------------------------------------------\n        # photometric_augmentations\n        # ----------------------------------------------------------\n        if photometric_augmentations:\n            self._photometric_transform = transforms.ConcatTransformSplitChainer([\n                # uint8 -> PIL\n                vision_transforms.ToPILImage(),\n                # PIL -> PIL : random hsv and contrast\n                vision_transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5),\n                # PIL -> FloatTensor\n                vision_transforms.transforms.ToTensor(),\n                transforms.RandomGamma(min_gamma=0.7, max_gamma=1.5, clip_image=True),\n            ], from_numpy=True, to_numpy=False)\n\n        else:\n            self._photometric_transform = transforms.ConcatTransformSplitChainer([\n                # uint8 -> FloatTensor\n                vision_transforms.transforms.ToTensor(),\n            ], from_numpy=True, to_numpy=False)\n\n    def __getitem__(self, index):\n        index = index % self._size\n\n        im1_filename = self._image_list[index][0]\n        im2_filename = self._image_list[index][1]\n        flo_f_filename = self._flow_list[index][0]\n        flo_b_filename = self._flow_list[index][1]\n        occ1_filename = self._occ_list[index][0]\n        occ2_filename = self._occ_list[index][1]\n\n        # read float32 images and flow\n        im1_np0 = common.read_image_as_byte(im1_filename)\n        im2_np0 = common.read_image_as_byte(im2_filename)\n        flo_f_np0 = common.read_flo_as_float32(flo_f_filename)\n        flo_b_np0 = common.read_flo_as_float32(flo_b_filename)\n        occ1_np0 = common.read_occ_image_as_float32(occ1_filename)\n        
occ2_np0 = common.read_occ_image_as_float32(occ2_filename)\n\n        # temp - check isnan\n        if np.any(np.isnan(flo_f_np0)):\n            flo_f_np0 = fillingInNaN(flo_f_np0)\n\n        if np.any(np.isnan(flo_b_np0)):\n            flo_b_np0 = fillingInNaN(flo_b_np0)\n\n        # possibly apply photometric transformations\n        im1, im2 = self._photometric_transform(im1_np0, im2_np0)\n\n        # convert flow to FloatTensor\n        flo_f = common.numpy2torch(flo_f_np0)\n        flo_b = common.numpy2torch(flo_b_np0)\n\n        # convert occ to FloatTensor\n        occ1 = common.numpy2torch(occ1_np0)\n        occ2 = common.numpy2torch(occ2_np0)\n\n        # example filename\n        basename = os.path.basename(im1_filename)[:5]\n\n        example_dict = {\n            \"input1\": im1,\n            \"input2\": im2,\n            \"target1\": flo_f,\n            \"target2\": flo_b,\n            \"target_occ1\": occ1,\n            \"target_occ2\": occ2,\n            \"index\": index,\n            \"basename\": basename\n        }\n\n        return example_dict\n\n    def __len__(self):\n        return self._size\n\n\nclass FlyingThings3dFinalTrain(FlyingThings3d):\n    def __init__(self,\n                 args,\n                 root,\n                 photometric_augmentations=True):\n        images_root = os.path.join(root, \"frames_finalpass\")\n        flow_root = os.path.join(root, \"optical_flow\")\n        occ_root = os.path.join(root, \"occlusion\")\n        super(FlyingThings3dFinalTrain, self).__init__(\n            args,\n            images_root=images_root,\n            flow_root=flow_root,\n            occ_root=occ_root,\n            photometric_augmentations=photometric_augmentations)\n\n\nclass FlyingThings3dFinalTest(FlyingThings3d):\n    def __init__(self,\n                 args,\n                 root,\n                 photometric_augmentations=False):\n        images_root = os.path.join(root, \"frames_finalpass\")\n        flow_root = 
os.path.join(root, \"optical_flow\")\n        occ_root = os.path.join(root, \"occlusion\")\n        super(FlyingThings3dFinalTest, self).__init__(\n            args,\n            images_root=images_root,\n            flow_root=flow_root,\n            occ_root=occ_root,\n            photometric_augmentations=photometric_augmentations)\n\n\nclass FlyingThings3dCleanTrain(FlyingThings3d):\n    def __init__(self,\n                 args,\n                 root,\n                 photometric_augmentations=True):\n        images_root = os.path.join(root, \"train\", \"image_clean\", \"left\")\n        flow_root = os.path.join(root, \"train\", \"flow\", \"left\")\n        occ_root = os.path.join(root, \"train\", \"flow_occlusions\", \"left\")\n        super(FlyingThings3dCleanTrain, self).__init__(\n            args,\n            images_root=images_root,\n            flow_root=flow_root,\n            occ_root=occ_root,\n            photometric_augmentations=photometric_augmentations)\n\n\nclass FlyingThings3dCleanTest(FlyingThings3d):\n    def __init__(self,\n                 args,\n                 root,\n                 photometric_augmentations=False):\n        images_root = os.path.join(root, \"frames_cleanpass\")\n        flow_root = os.path.join(root, \"optical_flow\")\n        occ_root = os.path.join(root, \"occlusion\")\n        super(FlyingThings3dCleanTest, self).__init__(\n            args,\n            images_root=images_root,\n            flow_root=flow_root,\n            occ_root=occ_root,\n            photometric_augmentations=photometric_augmentations)\n"
  },
  {
    "path": "datasets/flyingchairs.py",
    "content": "from __future__ import absolute_import, division, print_function\n\nimport os\nimport torch.utils.data as data\nfrom glob import glob\n\nfrom torchvision import transforms as vision_transforms\n\nfrom . import transforms\nfrom . import common\n\n\nVALIDATE_INDICES = [\n    5, 17, 42, 45, 58, 62, 96, 111, 117, 120, 121, 131, 132,\n    152, 160, 248, 263, 264, 291, 293, 295, 299, 316, 320, 336,\n    337, 343, 358, 399, 401, 429, 438, 468, 476, 494, 509, 528,\n    531, 572, 581, 583, 588, 593, 681, 688, 696, 714, 767, 786,\n    810, 825, 836, 841, 883, 917, 937, 942, 970, 974, 980, 1016,\n    1043, 1064, 1118, 1121, 1133, 1153, 1155, 1158, 1159, 1173,\n    1187, 1219, 1237, 1238, 1259, 1266, 1278, 1296, 1354, 1378,\n    1387, 1494, 1508, 1518, 1574, 1601, 1614, 1668, 1673, 1699,\n    1712, 1714, 1737, 1841, 1872, 1879, 1901, 1921, 1934, 1961,\n    1967, 1978, 2018, 2030, 2039, 2043, 2061, 2113, 2204, 2216,\n    2236, 2250, 2274, 2292, 2310, 2342, 2359, 2374, 2382, 2399,\n    2415, 2419, 2483, 2502, 2504, 2576, 2589, 2590, 2622, 2624,\n    2636, 2651, 2655, 2658, 2659, 2664, 2672, 2706, 2707, 2709,\n    2725, 2732, 2761, 2827, 2864, 2866, 2905, 2922, 2929, 2966,\n    2972, 2993, 3010, 3025, 3031, 3040, 3041, 3070, 3113, 3124,\n    3129, 3137, 3141, 3157, 3183, 3206, 3219, 3247, 3253, 3272,\n    3276, 3321, 3328, 3333, 3338, 3341, 3346, 3351, 3396, 3419,\n    3430, 3433, 3448, 3455, 3463, 3503, 3526, 3529, 3537, 3555,\n    3577, 3584, 3591, 3594, 3597, 3603, 3613, 3615, 3670, 3676,\n    3678, 3697, 3723, 3728, 3734, 3745, 3750, 3752, 3779, 3782,\n    3813, 3817, 3819, 3854, 3885, 3944, 3947, 3970, 3985, 4011,\n    4022, 4071, 4075, 4132, 4158, 4167, 4190, 4194, 4207, 4246,\n    4249, 4298, 4307, 4317, 4318, 4319, 4320, 4382, 4399, 4401,\n    4407, 4416, 4423, 4484, 4491, 4493, 4517, 4525, 4538, 4578,\n    4606, 4609, 4620, 4623, 4637, 4646, 4662, 4668, 4716, 4739,\n    4747, 4770, 4774, 4776, 4785, 4800, 4845, 4863, 4891, 4904,\n    4922, 4925, 4956, 
4963, 4964, 4994, 5011, 5019, 5036, 5038,\n    5041, 5055, 5118, 5122, 5130, 5162, 5164, 5178, 5196, 5227,\n    5266, 5270, 5273, 5279, 5299, 5310, 5314, 5363, 5375, 5384,\n    5393, 5414, 5417, 5433, 5448, 5494, 5505, 5509, 5525, 5566,\n    5581, 5602, 5609, 5620, 5653, 5670, 5678, 5690, 5700, 5703,\n    5724, 5752, 5765, 5803, 5811, 5860, 5881, 5895, 5912, 5915,\n    5940, 5952, 5966, 5977, 5988, 6007, 6037, 6061, 6069, 6080,\n    6111, 6127, 6146, 6161, 6166, 6168, 6178, 6182, 6190, 6220,\n    6235, 6253, 6270, 6343, 6372, 6379, 6410, 6411, 6442, 6453,\n    6481, 6498, 6500, 6509, 6532, 6541, 6543, 6560, 6576, 6580,\n    6594, 6595, 6609, 6625, 6629, 6644, 6658, 6673, 6680, 6698,\n    6699, 6702, 6705, 6741, 6759, 6785, 6792, 6794, 6809, 6810,\n    6830, 6838, 6869, 6871, 6889, 6925, 6995, 7003, 7026, 7029,\n    7080, 7082, 7097, 7102, 7116, 7165, 7200, 7232, 7271, 7282,\n    7324, 7333, 7335, 7372, 7387, 7407, 7472, 7474, 7482, 7489,\n    7499, 7516, 7533, 7536, 7566, 7620, 7654, 7691, 7704, 7722,\n    7746, 7750, 7773, 7806, 7821, 7827, 7851, 7873, 7880, 7884,\n    7904, 7912, 7948, 7964, 7965, 7984, 7989, 7992, 8035, 8050,\n    8074, 8091, 8094, 8113, 8116, 8151, 8159, 8171, 8179, 8194,\n    8195, 8239, 8263, 8290, 8295, 8312, 8367, 8374, 8387, 8407,\n    8437, 8439, 8518, 8556, 8588, 8597, 8601, 8651, 8657, 8723,\n    8759, 8763, 8785, 8802, 8813, 8826, 8854, 8856, 8866, 8918,\n    8922, 8923, 8932, 8958, 8967, 9003, 9018, 9078, 9095, 9104,\n    9112, 9129, 9147, 9170, 9171, 9197, 9200, 9249, 9253, 9270,\n    9282, 9288, 9295, 9321, 9323, 9324, 9347, 9399, 9403, 9417,\n    9426, 9427, 9439, 9468, 9486, 9496, 9511, 9516, 9518, 9529,\n    9557, 9563, 9564, 9584, 9586, 9591, 9599, 9600, 9601, 9632,\n    9654, 9667, 9678, 9696, 9716, 9723, 9740, 9820, 9824, 9825,\n    9828, 9863, 9866, 9868, 9889, 9929, 9938, 9953, 9967, 10019,\n    10020, 10025, 10059, 10111, 10118, 10125, 10174, 10194,\n    10201, 10202, 10220, 10221, 10226, 10242, 10250, 10276,\n    10295, 
10302, 10305, 10327, 10351, 10360, 10369, 10393,\n    10407, 10438, 10455, 10463, 10465, 10470, 10478, 10503,\n    10508, 10509, 10809, 11080, 11331, 11607, 11610, 11864,\n    12390, 12393, 12396, 12399, 12671, 12921, 12930, 13178,\n    13453, 13717, 14499, 14517, 14775, 15297, 15556, 15834,\n    15839, 16126, 16127, 16386, 16633, 16644, 16651, 17166,\n    17169, 17958, 17959, 17962, 18224, 21176, 21180, 21190,\n    21802, 21803, 21806, 22584, 22857, 22858, 22866]\n\n\nclass FlyingChairs(data.Dataset):\n    def __init__(self,\n                 args,\n                 root,\n                 photometric_augmentations=False,\n                 dstype=\"train\"):\n\n        self._args = args\n\n        # -------------------------------------------------------------\n        # filenames for all input images and target flows\n        # -------------------------------------------------------------\n        image_filenames = sorted( glob( os.path.join(root, \"*.ppm\")) )\n        flow_filenames = sorted( glob( os.path.join(root, \"*.flo\")) )\n        assert (len(image_filenames)/2 == len(flow_filenames))\n        num_flows = len(flow_filenames)\n\n        # -------------------------------------------------------------\n        # Remove invalid validation indices\n        # -------------------------------------------------------------\n        validate_indices = [x for x in VALIDATE_INDICES if x in range(num_flows)]\n\n        # ----------------------------------------------------------\n        # Construct list of indices for training/validation\n        # ----------------------------------------------------------\n        list_of_indices = None\n        if dstype == \"train\":\n            list_of_indices = [x for x in range(num_flows) if x not in validate_indices]\n        elif dstype == \"valid\":\n            list_of_indices = validate_indices\n        elif dstype == \"full\":\n            list_of_indices = range(num_flows)\n        else:\n            raise 
ValueError(\"FlyingChairs: dstype '%s' unknown!\", dstype)\n\n\n        # ----------------------------------------------------------\n        # Save list of actual filenames for inputs and flows\n        # ----------------------------------------------------------\n        self._image_list = []\n        self._flow_list = []\n        for i in list_of_indices:\n            flo = flow_filenames[i]\n            im1 = image_filenames[2*i]\n            im2 = image_filenames[2*i + 1]\n            self._image_list += [ [ im1, im2 ] ]\n            self._flow_list += [ flo ]\n        self._size = len(self._image_list)\n        assert len(self._image_list) == len(self._flow_list)\n\n        # ----------------------------------------------------------\n        # photometric_augmentations\n        # ----------------------------------------------------------\n        if photometric_augmentations:\n            self._photometric_transform = transforms.ConcatTransformSplitChainer([\n                # uint8 -> PIL\n                vision_transforms.ToPILImage(),\n                # PIL -> PIL : random hsv and contrast\n                vision_transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5),\n                # PIL -> FloatTensor\n                vision_transforms.transforms.ToTensor(),\n                transforms.RandomGamma(min_gamma=0.7, max_gamma=1.5, clip_image=True),\n            ], from_numpy=True, to_numpy=False)\n        else:\n            self._photometric_transform = transforms.ConcatTransformSplitChainer([\n                # uint8 -> FloatTensor\n                vision_transforms.transforms.ToTensor(),\n            ], from_numpy=True, to_numpy=False)\n\n    def __getitem__(self, index):\n        index = index % self._size\n\n        im1_filename = self._image_list[index][0]\n        im2_filename = self._image_list[index][1]\n        flo_filename = self._flow_list[index]\n\n        # read float32 images and flow\n        im1_np0 = 
common.read_image_as_byte(im1_filename)\n        im2_np0 = common.read_image_as_byte(im2_filename)\n        flo_np0 = common.read_flo_as_float32(flo_filename)\n\n        # possibly apply photometric transformations\n        im1, im2 = self._photometric_transform(im1_np0, im2_np0)\n\n        # convert flow to FloatTensor\n        flo = common.numpy2torch(flo_np0)\n\n        # target_occ: initialized by zero (not used)\n        target_occ = common.numpy2torch(common.read_occ_image_as_float32(im1_filename)) * 0\n\n        # example filename\n        basename = os.path.basename(im1_filename)[:5]\n\n        example_dict = {\n            \"input1\": im1,\n            \"input2\": im2,\n            \"target1\": flo,\n            \"target_occ1\": target_occ,\n            \"index\": index,\n            \"basename\": basename\n        }\n\n        return example_dict\n\n    def __len__(self):\n        return self._size\n\n\nclass FlyingChairsTrain(FlyingChairs):\n    def __init__(self,\n                 args,\n                 root,\n                 photometric_augmentations=True):\n        super(FlyingChairsTrain, self).__init__(\n            args,\n            root=root,\n            photometric_augmentations=photometric_augmentations,\n            dstype=\"train\")\n\n\nclass FlyingChairsValid(FlyingChairs):\n    def __init__(self,\n                 args,\n                 root,\n                 photometric_augmentations=False):\n        super(FlyingChairsValid, self).__init__(\n            args,\n            root=root,\n            photometric_augmentations=photometric_augmentations,\n            dstype=\"valid\")\n\n\nclass FlyingChairsFull(FlyingChairs):\n    def __init__(self,\n                 args,\n                 root,\n                 photometric_augmentations=False):\n        super(FlyingChairsFull, self).__init__(\n            args,\n            root=root,\n            photometric_augmentations=photometric_augmentations,\n            dstype=\"full\")\n"
  },
  {
    "path": "datasets/flyingchairsOcc.py",
    "content": "from __future__ import absolute_import, division, print_function\n\nimport os\nimport torch.utils.data as data\nfrom glob import glob\n\nfrom torchvision import transforms as vision_transforms\n\nfrom . import transforms\nfrom . import common\n\n\nVALIDATE_INDICES = [\n    5, 17, 42, 45, 58, 62, 96, 111, 117, 120, 121, 131, 132,\n    152, 160, 248, 263, 264, 291, 293, 295, 299, 316, 320, 336,\n    337, 343, 358, 399, 401, 429, 438, 468, 476, 494, 509, 528,\n    531, 572, 581, 583, 588, 593, 681, 688, 696, 714, 767, 786,\n    810, 825, 836, 841, 883, 917, 937, 942, 970, 974, 980, 1016,\n    1043, 1064, 1118, 1121, 1133, 1153, 1155, 1158, 1159, 1173,\n    1187, 1219, 1237, 1238, 1259, 1266, 1278, 1296, 1354, 1378,\n    1387, 1494, 1508, 1518, 1574, 1601, 1614, 1668, 1673, 1699,\n    1712, 1714, 1737, 1841, 1872, 1879, 1901, 1921, 1934, 1961,\n    1967, 1978, 2018, 2030, 2039, 2043, 2061, 2113, 2204, 2216,\n    2236, 2250, 2274, 2292, 2310, 2342, 2359, 2374, 2382, 2399,\n    2415, 2419, 2483, 2502, 2504, 2576, 2589, 2590, 2622, 2624,\n    2636, 2651, 2655, 2658, 2659, 2664, 2672, 2706, 2707, 2709,\n    2725, 2732, 2761, 2827, 2864, 2866, 2905, 2922, 2929, 2966,\n    2972, 2993, 3010, 3025, 3031, 3040, 3041, 3070, 3113, 3124,\n    3129, 3137, 3141, 3157, 3183, 3206, 3219, 3247, 3253, 3272,\n    3276, 3321, 3328, 3333, 3338, 3341, 3346, 3351, 3396, 3419,\n    3430, 3433, 3448, 3455, 3463, 3503, 3526, 3529, 3537, 3555,\n    3577, 3584, 3591, 3594, 3597, 3603, 3613, 3615, 3670, 3676,\n    3678, 3697, 3723, 3728, 3734, 3745, 3750, 3752, 3779, 3782,\n    3813, 3817, 3819, 3854, 3885, 3944, 3947, 3970, 3985, 4011,\n    4022, 4071, 4075, 4132, 4158, 4167, 4190, 4194, 4207, 4246,\n    4249, 4298, 4307, 4317, 4318, 4319, 4320, 4382, 4399, 4401,\n    4407, 4416, 4423, 4484, 4491, 4493, 4517, 4525, 4538, 4578,\n    4606, 4609, 4620, 4623, 4637, 4646, 4662, 4668, 4716, 4739,\n    4747, 4770, 4774, 4776, 4785, 4800, 4845, 4863, 4891, 4904,\n    4922, 4925, 4956, 
4963, 4964, 4994, 5011, 5019, 5036, 5038,\n    5041, 5055, 5118, 5122, 5130, 5162, 5164, 5178, 5196, 5227,\n    5266, 5270, 5273, 5279, 5299, 5310, 5314, 5363, 5375, 5384,\n    5393, 5414, 5417, 5433, 5448, 5494, 5505, 5509, 5525, 5566,\n    5581, 5602, 5609, 5620, 5653, 5670, 5678, 5690, 5700, 5703,\n    5724, 5752, 5765, 5803, 5811, 5860, 5881, 5895, 5912, 5915,\n    5940, 5952, 5966, 5977, 5988, 6007, 6037, 6061, 6069, 6080,\n    6111, 6127, 6146, 6161, 6166, 6168, 6178, 6182, 6190, 6220,\n    6235, 6253, 6270, 6343, 6372, 6379, 6410, 6411, 6442, 6453,\n    6481, 6498, 6500, 6509, 6532, 6541, 6543, 6560, 6576, 6580,\n    6594, 6595, 6609, 6625, 6629, 6644, 6658, 6673, 6680, 6698,\n    6699, 6702, 6705, 6741, 6759, 6785, 6792, 6794, 6809, 6810,\n    6830, 6838, 6869, 6871, 6889, 6925, 6995, 7003, 7026, 7029,\n    7080, 7082, 7097, 7102, 7116, 7165, 7200, 7232, 7271, 7282,\n    7324, 7333, 7335, 7372, 7387, 7407, 7472, 7474, 7482, 7489,\n    7499, 7516, 7533, 7536, 7566, 7620, 7654, 7691, 7704, 7722,\n    7746, 7750, 7773, 7806, 7821, 7827, 7851, 7873, 7880, 7884,\n    7904, 7912, 7948, 7964, 7965, 7984, 7989, 7992, 8035, 8050,\n    8074, 8091, 8094, 8113, 8116, 8151, 8159, 8171, 8179, 8194,\n    8195, 8239, 8263, 8290, 8295, 8312, 8367, 8374, 8387, 8407,\n    8437, 8439, 8518, 8556, 8588, 8597, 8601, 8651, 8657, 8723,\n    8759, 8763, 8785, 8802, 8813, 8826, 8854, 8856, 8866, 8918,\n    8922, 8923, 8932, 8958, 8967, 9003, 9018, 9078, 9095, 9104,\n    9112, 9129, 9147, 9170, 9171, 9197, 9200, 9249, 9253, 9270,\n    9282, 9288, 9295, 9321, 9323, 9324, 9347, 9399, 9403, 9417,\n    9426, 9427, 9439, 9468, 9486, 9496, 9511, 9516, 9518, 9529,\n    9557, 9563, 9564, 9584, 9586, 9591, 9599, 9600, 9601, 9632,\n    9654, 9667, 9678, 9696, 9716, 9723, 9740, 9820, 9824, 9825,\n    9828, 9863, 9866, 9868, 9889, 9929, 9938, 9953, 9967, 10019,\n    10020, 10025, 10059, 10111, 10118, 10125, 10174, 10194,\n    10201, 10202, 10220, 10221, 10226, 10242, 10250, 10276,\n    10295, 
10302, 10305, 10327, 10351, 10360, 10369, 10393,\n    10407, 10438, 10455, 10463, 10465, 10470, 10478, 10503,\n    10508, 10509, 10809, 11080, 11331, 11607, 11610, 11864,\n    12390, 12393, 12396, 12399, 12671, 12921, 12930, 13178,\n    13453, 13717, 14499, 14517, 14775, 15297, 15556, 15834,\n    15839, 16126, 16127, 16386, 16633, 16644, 16651, 17166,\n    17169, 17958, 17959, 17962, 18224, 21176, 21180, 21190,\n    21802, 21803, 21806, 22584, 22857, 22858, 22866]\n\n\nclass FlyingChairsOcc(data.Dataset):\n    def __init__(self,\n                 args,\n                 root,\n                 photometric_augmentations=False,\n                 dstype=\"train\"):\n\n        self._args = args\n\n        # -------------------------------------------------------------\n        # filenames for all input images and target flows\n        # -------------------------------------------------------------\n        image1_filenames = sorted(glob(os.path.join(root, \"*_img1.png\")))\n        image2_filenames = sorted(glob(os.path.join(root, \"*_img2.png\")))\n        occ1_filenames = sorted(glob(os.path.join(root, \"*_occ1.png\")))\n        occ2_filenames = sorted(glob(os.path.join(root, \"*_occ2.png\")))\n        flow_f_filenames = sorted(glob(os.path.join(root, \"*_flow.flo\")))\n        flow_b_filenames = sorted(glob(os.path.join(root, \"*_flow_b.flo\")))\n        assert (len(image1_filenames) == len(image2_filenames))\n        assert (len(image2_filenames) == len(occ1_filenames))\n        assert (len(occ1_filenames) == len(occ2_filenames))\n        assert (len(occ2_filenames) == len(flow_f_filenames))\n        assert (len(flow_f_filenames) == len(flow_b_filenames))\n\n        num_flows = len(flow_f_filenames)\n\n        # -------------------------------------------------------------\n        # Remove invalid validation indices\n        # -------------------------------------------------------------\n        validate_indices = [x for x in VALIDATE_INDICES if x in 
range(num_flows)]\n\n        # ----------------------------------------------------------\n        # Construct list of indices for training/validation\n        # ----------------------------------------------------------\n        list_of_indices = None\n        if dstype == \"train\":\n            list_of_indices = [x for x in range(num_flows) if x not in validate_indices]\n        elif dstype == \"valid\":\n            list_of_indices = validate_indices\n        elif dstype == \"full\":\n            list_of_indices = range(num_flows)\n        else:\n            raise ValueError(\"FlyingChairs: dstype '%s' unknown!\" % dstype)\n\n        # ----------------------------------------------------------\n        # Save list of actual filenames for inputs and flows\n        # ----------------------------------------------------------\n        self._image_list = []\n        self._flow_list = []\n        self._occ_list = []\n        for i in list_of_indices:\n            flo_f = flow_f_filenames[i]\n            flo_b = flow_b_filenames[i]\n            im1 = image1_filenames[i]\n            im2 = image2_filenames[i]\n            occ1 = occ1_filenames[i]\n            occ2 = occ2_filenames[i]\n            self._image_list += [[im1, im2]]\n            self._flow_list += [[flo_f, flo_b]]\n            self._occ_list += [[occ1, occ2]]\n        self._size = len(self._image_list)\n        assert len(self._image_list) == len(self._flow_list)\n        assert len(self._occ_list) == len(self._flow_list)\n\n        # ----------------------------------------------------------\n        # photometric_augmentations\n        # ----------------------------------------------------------\n        if photometric_augmentations:\n            self._photometric_transform = transforms.ConcatTransformSplitChainer([\n                # uint8 -> PIL\n                vision_transforms.ToPILImage(),\n                # PIL -> PIL : random hsv and contrast\n                
vision_transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5),\n                # PIL -> FloatTensor\n                vision_transforms.transforms.ToTensor(),\n                transforms.RandomGamma(min_gamma=0.7, max_gamma=1.5, clip_image=True),\n            ], from_numpy=True, to_numpy=False)\n\n        else:\n            self._photometric_transform = transforms.ConcatTransformSplitChainer([\n                # uint8 -> FloatTensor\n                vision_transforms.transforms.ToTensor(),\n            ], from_numpy=True, to_numpy=False)\n\n    def __getitem__(self, index):\n        index = index % self._size\n\n        im1_filename = self._image_list[index][0]\n        im2_filename = self._image_list[index][1]\n        flo_f_filename = self._flow_list[index][0]\n        flo_b_filename = self._flow_list[index][1]\n        occ1_filename = self._occ_list[index][0]\n        occ2_filename = self._occ_list[index][1]\n\n        # read float32 images and flow\n        im1_np0 = common.read_image_as_byte(im1_filename)\n        im2_np0 = common.read_image_as_byte(im2_filename)\n        flo_f_np0 = common.read_flo_as_float32(flo_f_filename)\n        flo_b_np0 = common.read_flo_as_float32(flo_b_filename)\n        occ1_np0 = common.read_occ_image_as_float32(occ1_filename)\n        occ2_np0 = common.read_occ_image_as_float32(occ2_filename)\n\n        # possibly apply photometric transformations\n        im1, im2 = self._photometric_transform(im1_np0, im2_np0)\n\n        # convert flow to FloatTensor\n        flo_f = common.numpy2torch(flo_f_np0)\n        flo_b = common.numpy2torch(flo_b_np0)\n\n        # convert occ to FloatTensor\n        occ1 = common.numpy2torch(occ1_np0)\n        occ2 = common.numpy2torch(occ2_np0)\n\n        # example filename\n        basename = os.path.basename(im1_filename)[:5]\n\n        example_dict = {\n            \"input1\": im1,\n            \"input2\": im2,\n            \"target1\": flo_f,\n            \"target2\": flo_b,\n 
           \"target_occ1\": occ1,\n            \"target_occ2\": occ2,\n            \"index\": index,\n            \"basename\": basename\n        }\n\n        return example_dict\n\n    def __len__(self):\n        return self._size\n\n\nclass FlyingChairsOccTrain(FlyingChairsOcc):\n    def __init__(self,\n                 args,\n                 root,\n                 photometric_augmentations=True):\n        super(FlyingChairsOccTrain, self).__init__(\n            args,\n            root=root,\n            photometric_augmentations=photometric_augmentations,\n            dstype=\"train\")\n\n\nclass FlyingChairsOccValid(FlyingChairsOcc):\n    def __init__(self,\n                 args,\n                 root,\n                 photometric_augmentations=False):\n        super(FlyingChairsOccValid, self).__init__(\n            args,\n            root=root,\n            photometric_augmentations=photometric_augmentations,\n            dstype=\"valid\")\n\n\nclass FlyingChairsOccFull(FlyingChairsOcc):\n    def __init__(self,\n                 args,\n                 root,\n                 photometric_augmentations=False):\n        super(FlyingChairsOccFull, self).__init__(\n            args,\n            root=root,\n            photometric_augmentations=photometric_augmentations,\n            dstype=\"full\")\n"
  },
  {
    "path": "datasets/kitti_combined.py",
    "content": "from __future__ import absolute_import, division, print_function\n\nimport os\nimport torch.utils.data as data\nfrom glob import glob\n\nfrom torchvision import transforms as vision_transforms\n\nfrom . import transforms\nfrom . import common\n\nimport numpy as np\nimport png\n\nVALIDATE_INDICES_2015 = [10, 11, 12, 25, 26, 30, 31, 40, 41, 42, 46, 52, 53, 72, 73, 74, 75, 76, 80, 81, 85, 86, 95, 96, 97, 98, 104, 116, 117, 120, 121, 126, 127, 153, 172, 175, 183, 184, 190, 199]\nVALIDATE_INDICES_2012 = [0, 12, 15, 16, 17, 18, 24, 30, 38, 39, 42, 50, 54, 59, 60, 61, 77, 78, 81, 89, 97, 101, 107, 121, 124, 142, 145, 146, 152, 154, 155, 158, 159, 160, 164, 182, 183, 184, 190]\n\n\ndef read_png_flow(flow_file):\n    flow_object = png.Reader(filename=flow_file)\n    flow_direct = flow_object.asDirect()\n    flow_data = list(flow_direct[2])\n    (w, h) = flow_direct[3]['size']\n    flow = np.zeros((h, w, 3), dtype=np.float64)\n    for i in range(len(flow_data)):\n        flow[i, :, 0] = flow_data[i][0::3]\n        flow[i, :, 1] = flow_data[i][1::3]\n        flow[i, :, 2] = flow_data[i][2::3]\n\n    invalid_idx = (flow[:, :, 2] == 0)\n    flow[:, :, 0:2] = (flow[:, :, 0:2] - 2 ** 15) / 64.0\n    flow[invalid_idx, 0] = 0\n    flow[invalid_idx, 1] = 0\n    return flow[:, :, 0:2], (1 - invalid_idx * 1)[:, :, None]\n\n\ndef kitti_random_crop(im1, im2, flo_f, valid_mask, crop_height=370, crop_width=1224):\n    height, width, _ = im1.shape\n    # get starting positions\n    x = np.random.uniform(0, width - crop_width + 1)\n    y = np.random.uniform(0, height - crop_height + 1)\n    str_x = int(x)\n    str_y = int(y)\n    end_x = int(x + crop_width)\n    end_y = int(y + crop_height)\n\n    im1 = im1[str_y:end_y, str_x:end_x, :]\n    im2 = im2[str_y:end_y, str_x:end_x, :]\n    flo_f = flo_f[str_y:end_y, str_x:end_x, :]\n    valid_mask = valid_mask[str_y:end_y, str_x:end_x, :]\n\n    return im1, im2, flo_f, valid_mask\n\n\nclass Kitti_comb_test(data.Dataset):\n    def 
__init__(self,\n                 args,\n                 images_root_2015=None,\n                 images_root_2012=None,\n                 photometric_augmentations=False,\n                 preprocessing_crop=True):\n\n        self._args = args\n        self.preprocessing_crop = preprocessing_crop\n\n        list_of_indices_2012 = []\n        list_of_indices_2015 = []\n\n        # ----------------------------------------------------------\n        # KITTI 2015\n        # ----------------------------------------------------------        \n        if images_root_2015 is not None:\n\n            if not os.path.isdir(images_root_2015):\n                raise ValueError(\"Image directory not found! {}\".format(images_root_2015))\n\n            all_img1_2015_filenames = sorted(glob(os.path.join(images_root_2015, \"*_10.png\")))\n            all_img2_2015_filenames = sorted(glob(os.path.join(images_root_2015, \"*_11.png\")))\n            assert len(all_img1_2015_filenames) != 0\n            assert len(all_img2_2015_filenames) == len(all_img1_2015_filenames)\n            list_of_indices_2015 = range(len(all_img1_2015_filenames))           \n\n        # ----------------------------------------------------------\n        # KITTI 2012\n        # ----------------------------------------------------------        \n        if images_root_2012 is not None:\n\n            if not os.path.isdir(images_root_2012):\n                raise ValueError(\"Image directory not found! 
{}\".format(images_root_2012))\n\n            all_img1_2012_filenames = sorted(glob(os.path.join(images_root_2012, \"*_10.png\")))\n            all_img2_2012_filenames = sorted(glob(os.path.join(images_root_2012, \"*_11.png\")))\n            assert len(all_img1_2012_filenames) != 0\n            assert len(all_img2_2012_filenames) == len(all_img1_2012_filenames)\n            list_of_indices_2012 = range(len(all_img1_2012_filenames))\n\n        # ----------------------------------------------------------\n        # Save list of actual filenames for inputs and flows\n        # ----------------------------------------------------------\n        self._image_list = []\n        self._flow_list = []\n\n        for ii in list_of_indices_2015:\n\n            im1 = all_img1_2015_filenames[ii]\n            im2 = all_img2_2015_filenames[ii]\n            idx1 = os.path.splitext(os.path.basename(im1))[0][:-3]\n            idx2 = os.path.splitext(os.path.basename(im2))[0][:-3]\n            assert idx1 == idx2\n\n            if not os.path.isfile(im1) or not os.path.isfile(im2):\n                continue\n\n            self._image_list += [[im1, im2]]\n\n\n        for ii in list_of_indices_2012:\n\n            im1 = all_img1_2012_filenames[ii]\n            im2 = all_img2_2012_filenames[ii]\n            idx1 = os.path.splitext(os.path.basename(im1))[0][:-3]\n            idx2 = os.path.splitext(os.path.basename(im2))[0][:-3]\n            assert idx1 == idx2\n\n            if not os.path.isfile(im1) or not os.path.isfile(im2):\n                continue\n\n            self._image_list += [[im1, im2]]\n\n        self._size = len(self._image_list)\n\n        assert len(self._image_list) != 0\n\n\n        # ----------------------------------------------------------\n        # photometric_augmentations\n        # ----------------------------------------------------------\n        if photometric_augmentations:\n            self._photometric_transform = 
transforms.ConcatTransformSplitChainer([\n                # uint8 -> PIL\n                vision_transforms.ToPILImage(),\n                # PIL -> PIL : random hsv and contrast\n                vision_transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5),\n                # PIL -> FloatTensor\n                vision_transforms.transforms.ToTensor(),\n                transforms.RandomGamma(min_gamma=0.7, max_gamma=1.5, clip_image=True),\n            ], from_numpy=True, to_numpy=False)\n\n        else:\n            self._photometric_transform = transforms.ConcatTransformSplitChainer([\n                # uint8 -> FloatTensor\n                vision_transforms.transforms.ToTensor(),\n            ], from_numpy=True, to_numpy=False)\n\n\n    def __getitem__(self, index):\n        index = index % self._size\n\n        im1_filename = self._image_list[index][0]\n        im2_filename = self._image_list[index][1]\n\n        # read float32 images and flow\n        im1_np0 = common.read_image_as_byte(im1_filename)\n        im2_np0 = common.read_image_as_byte(im2_filename)\n\n        # possibly apply photometric transformations\n        im1, im2 = self._photometric_transform(im1_np0, im2_np0)\n\n        # example filename\n        basename = os.path.basename(im1_filename)[:6]\n\n        example_dict = {\n            \"input1\": im1,\n            \"input2\": im2,\n            \"index\": index,\n            \"basename\": basename\n        }\n\n        return example_dict\n\n    def __len__(self):\n        return self._size\n\n\nclass Kitti_comb(data.Dataset):\n    def __init__(self,\n                 args,\n                 images_root_2015=None,\n                 flow_root_2015=None,\n                 images_root_2012=None,\n                 flow_root_2012=None,\n                 photometric_augmentations=False,\n                 preprocessing_crop=True,\n                 dstype=\"full\"):\n\n        self._args = args\n        self.preprocessing_crop = 
preprocessing_crop\n\n        list_of_indices_2012 = []\n        list_of_indices_2015 = []\n\n        # ----------------------------------------------------------\n        # KITTI 2015\n        # ----------------------------------------------------------        \n        if images_root_2015 is not None and flow_root_2015 is not None:\n\n            if not os.path.isdir(images_root_2015):\n                raise ValueError(\"Image directory not found!  {}\".format(images_root_2015))\n            if not os.path.isdir(flow_root_2015):\n                raise ValueError(\"Flow directory not found!  {}\".format(flow_root_2015))\n\n            all_img1_2015_filenames = sorted(glob(os.path.join(images_root_2015, \"*_10.png\")))\n            all_img2_2015_filenames = sorted(glob(os.path.join(images_root_2015, \"*_11.png\")))            \n            flow_f_2015_filenames = sorted(glob(os.path.join(flow_root_2015, \"*_10.png\")))\n            assert len(all_img1_2015_filenames) != 0\n            assert len(all_img2_2015_filenames) == len(all_img1_2015_filenames)\n            assert len(flow_f_2015_filenames) == len(all_img1_2015_filenames)\n            num_flows_2015 = len(flow_f_2015_filenames)           \n            validate_indices_2015 = [x for x in VALIDATE_INDICES_2015 if x in range(num_flows_2015)]\n\n            if dstype == \"train\":\n                list_of_indices_2015 = [x for x in range(num_flows_2015) if x not in validate_indices_2015]\n            elif dstype == \"valid\":\n                list_of_indices_2015 = validate_indices_2015\n            elif dstype == \"full\":\n                list_of_indices_2015 = range(len(all_img1_2015_filenames))\n            else:\n                raise ValueError(\"KITTI 2015: dstype '%s' unknown!\", dstype)\n\n\n        # ----------------------------------------------------------\n        # KITTI 2012\n        # ----------------------------------------------------------        \n        if images_root_2012 is not None:\n\n  
          if not os.path.isdir(images_root_2012):\n                raise ValueError(\"Image directory '%s' not found!\" % images_root_2012)\n            if not os.path.isdir(flow_root_2012):\n                raise ValueError(\"Flow directory '%s' not found!\" % flow_root_2012)\n\n            all_img1_2012_filenames = sorted(glob(os.path.join(images_root_2012, \"*_10.png\")))\n            all_img2_2012_filenames = sorted(glob(os.path.join(images_root_2012, \"*_11.png\")))            \n            flow_f_2012_filenames = sorted(glob(os.path.join(flow_root_2012, \"*_10.png\")))\n            assert len(all_img1_2012_filenames) != 0\n            assert len(all_img2_2012_filenames) == len(all_img1_2012_filenames)\n            assert len(flow_f_2012_filenames) == len(all_img1_2012_filenames)\n            num_flows_2012 = len(flow_f_2012_filenames)           \n            validate_indices_2012 = [x for x in VALIDATE_INDICES_2012 if x in range(num_flows_2012)]\n\n            if dstype == \"train\":\n                list_of_indices_2012 = [x for x in range(num_flows_2012) if x not in validate_indices_2012]\n            elif dstype == \"valid\":\n                list_of_indices_2012 = validate_indices_2012\n            elif dstype == \"full\":\n                list_of_indices_2012 = range(len(all_img1_2012_filenames))\n            else:\n                raise ValueError(\"KITTI 2012: dstype '%s' unknown!\" % dstype)\n\n\n        # ----------------------------------------------------------\n        # Save list of actual filenames for inputs and flows\n        # ----------------------------------------------------------\n        self._image_list = []\n        self._flow_list = []\n\n        for ii in list_of_indices_2015:\n\n            im1 = all_img1_2015_filenames[ii]\n            im2 = all_img2_2015_filenames[ii]\n            idx1 = os.path.splitext(os.path.basename(im1))[0][:-3]\n            idx2 = os.path.splitext(os.path.basename(im2))[0][:-3]\n            assert idx1 == idx2\n\n            if not 
os.path.isfile(im1) or not os.path.isfile(im2):\n                continue\n\n            self._image_list += [[im1, im2]]\n\n            if dstype != \"test\":\n                flo_f = flow_f_2015_filenames[ii]\n                idx_f = os.path.splitext(os.path.basename(flo_f))[0][:-3]\n                assert idx1 == idx_f                \n                if not os.path.isfile(flo_f):\n                    continue\n                self._flow_list += [[flo_f]]\n\n\n        for ii in list_of_indices_2012:\n\n            im1 = all_img1_2012_filenames[ii]\n            im2 = all_img2_2012_filenames[ii]\n            idx1 = os.path.splitext(os.path.basename(im1))[0][:-3]\n            idx2 = os.path.splitext(os.path.basename(im2))[0][:-3]\n            assert idx1 == idx2\n\n            if not os.path.isfile(im1) or not os.path.isfile(im2):\n                continue\n\n            self._image_list += [[im1, im2]]\n\n            if dstype != \"test\":\n                flo_f = flow_f_2012_filenames[ii]\n                idx_f = os.path.splitext(os.path.basename(flo_f))[0][:-3]\n                assert idx1 == idx_f                \n                if not os.path.isfile(flo_f):\n                    continue\n                self._flow_list += [[flo_f]]\n\n\n        self._size = len(self._image_list)\n\n        assert len(self._image_list) != 0\n        if dstype != \"test\":\n            assert len(self._image_list) == len(self._flow_list)\n\n        # ----------------------------------------------------------\n        # photometric_augmentations\n        # ----------------------------------------------------------\n        if photometric_augmentations:\n            self._photometric_transform = transforms.ConcatTransformSplitChainer([\n                # uint8 -> PIL\n                vision_transforms.ToPILImage(),\n                # PIL -> PIL : random hsv and contrast\n                vision_transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, 
hue=0.5),\n                # PIL -> FloatTensor\n                vision_transforms.transforms.ToTensor(),\n                transforms.RandomGamma(min_gamma=0.7, max_gamma=1.5, clip_image=True),\n            ], from_numpy=True, to_numpy=False)\n\n        else:\n            self._photometric_transform = transforms.ConcatTransformSplitChainer([\n                # uint8 -> FloatTensor\n                vision_transforms.transforms.ToTensor(),\n            ], from_numpy=True, to_numpy=False)\n\n\n    def __getitem__(self, index):\n        index = index % self._size\n\n        im1_filename = self._image_list[index][0]\n        im2_filename = self._image_list[index][1]\n        flo_f_filename = self._flow_list[index][0]\n\n        # read float32 images and flow\n        im1_np0 = common.read_image_as_byte(im1_filename)\n        im2_np0 = common.read_image_as_byte(im2_filename)\n        flo_f_np0, valid_mask = read_png_flow(flo_f_filename)\n\n        if self.preprocessing_crop:\n            im1_np0, im2_np0, flo_f_np0, valid_mask = kitti_random_crop(im1_np0, im2_np0, flo_f_np0, valid_mask)\n\n        # possibly apply photometric transformations\n        im1, im2 = self._photometric_transform(im1_np0, im2_np0)\n\n        # convert flow to FloatTensor\n        flo_f = common.numpy2torch(flo_f_np0)\n        valid_mask_f = common.numpy2torch(valid_mask)\n\n        # example filename\n        basename = os.path.basename(im1_filename)[:6]\n\n        example_dict = {\n            \"input1\": im1,\n            \"input2\": im2,\n            \"target1\": flo_f,\n            \"target2\": flo_f,\n            \"index\": index,\n            \"basename\": basename,\n            \"input_valid\": valid_mask_f\n        }\n\n        return example_dict\n\n    def __len__(self):\n        return self._size\n\n\nclass KittiCombTrain(Kitti_comb):\n    def __init__(self,\n                 args,\n                 root,\n                 photometric_augmentations=True,\n                 
preprocessing_crop=True):\n        images_root_2015 = os.path.join(root, \"data_scene_flow\", \"training\", \"image_2\")\n        flow_root_2015 = os.path.join(root, \"data_scene_flow\", \"training\", \"flow_occ\")\n        images_root_2012 = os.path.join(root, \"data_stereo_flow\", \"training\", \"colored_0\")\n        flow_root_2012 = os.path.join(root, \"data_stereo_flow\",  \"training\", \"flow_occ\")\n        super(KittiCombTrain, self).__init__(\n            args,\n            images_root_2015=images_root_2015,\n            flow_root_2015=flow_root_2015,\n            images_root_2012=images_root_2012,\n            flow_root_2012=flow_root_2012,\n            photometric_augmentations=photometric_augmentations,\n            preprocessing_crop=preprocessing_crop,\n            dstype=\"train\")\n\n\nclass KittiCombVal(Kitti_comb):\n    def __init__(self,\n                 args,\n                 root,\n                 photometric_augmentations=False,\n                 preprocessing_crop=False):\n        images_root_2015 = os.path.join(root, \"data_scene_flow\", \"training\", \"image_2\")\n        flow_root_2015 = os.path.join(root, \"data_scene_flow\", \"training\", \"flow_occ\")\n        images_root_2012 = os.path.join(root, \"data_stereo_flow\", \"training\", \"colored_0\")\n        flow_root_2012 = os.path.join(root, \"data_stereo_flow\",  \"training\", \"flow_occ\")\n        super(KittiCombVal, self).__init__(\n            args,\n            images_root_2015=images_root_2015,\n            flow_root_2015=flow_root_2015,\n            images_root_2012=images_root_2012,\n            flow_root_2012=flow_root_2012,\n            photometric_augmentations=photometric_augmentations,\n            preprocessing_crop=preprocessing_crop,\n            dstype=\"valid\")\n\n\nclass KittiCombFull(Kitti_comb):\n    def __init__(self,\n                 args,\n                 root,\n                 photometric_augmentations=True,\n                 preprocessing_crop=True):\n  
      images_root_2015 = os.path.join(root, \"data_scene_flow\", \"training\", \"image_2\")\n        flow_root_2015 = os.path.join(root, \"data_scene_flow\", \"training\", \"flow_occ\")\n        images_root_2012 = os.path.join(root, \"data_stereo_flow\", \"training\", \"colored_0\")\n        flow_root_2012 = os.path.join(root, \"data_stereo_flow\",  \"training\", \"flow_occ\")\n        super(KittiCombFull, self).__init__(\n            args,\n            images_root_2015=images_root_2015,\n            flow_root_2015=flow_root_2015,\n            images_root_2012=images_root_2012,\n            flow_root_2012=flow_root_2012,\n            photometric_augmentations=photometric_augmentations,\n            preprocessing_crop=preprocessing_crop,\n            dstype=\"full\")\n\n\nclass KittiComb2015Train(Kitti_comb):\n    def __init__(self,\n                 args,\n                 root,\n                 photometric_augmentations=True,\n                 preprocessing_crop=True):\n        images_root_2015 = os.path.join(root, \"data_scene_flow\", \"training\", \"image_2\")\n        flow_root_2015 = os.path.join(root, \"data_scene_flow\", \"training\", \"flow_occ\")\n        super(KittiComb2015Train, self).__init__(\n            args,\n            images_root_2015=images_root_2015,\n            flow_root_2015=flow_root_2015,\n            photometric_augmentations=photometric_augmentations,\n            preprocessing_crop=preprocessing_crop,\n            dstype=\"train\")\n\n\nclass KittiComb2015Val(Kitti_comb):\n    def __init__(self,\n                 args,\n                 root,\n                 photometric_augmentations=False,\n                 preprocessing_crop=False):\n        images_root_2015 = os.path.join(root, \"data_scene_flow\", \"training\", \"image_2\")\n        flow_root_2015 = os.path.join(root, \"data_scene_flow\", \"training\", \"flow_occ\")\n        super(KittiComb2015Val, self).__init__(\n            args,\n            
images_root_2015=images_root_2015,\n            flow_root_2015=flow_root_2015,\n            photometric_augmentations=photometric_augmentations,\n            preprocessing_crop=preprocessing_crop,\n            dstype=\"valid\")\n\n\nclass KittiComb2015Full(Kitti_comb):\n    def __init__(self,\n                 args,\n                 root,\n                 photometric_augmentations=True,\n                 preprocessing_crop=True):\n        images_root_2015 = os.path.join(root, \"data_scene_flow\", \"training\", \"image_2\")\n        flow_root_2015 = os.path.join(root, \"data_scene_flow\", \"training\", \"flow_occ\")\n        super(KittiComb2015Full, self).__init__(\n            args,\n            images_root_2015=images_root_2015,\n            flow_root_2015=flow_root_2015,\n            photometric_augmentations=photometric_augmentations,\n            preprocessing_crop=preprocessing_crop,\n            dstype=\"full\")\n\n\nclass KittiComb2015Test(Kitti_comb_test):\n    def __init__(self,\n                 args,\n                 root,\n                 photometric_augmentations=False,\n                 preprocessing_crop=False):\n        images_root_2015 = os.path.join(root, \"data_scene_flow\", \"testing\", \"image_2\")\n        super(KittiComb2015Test, self).__init__(\n            args,\n            images_root_2015=images_root_2015,\n            photometric_augmentations=photometric_augmentations,\n            preprocessing_crop=preprocessing_crop)\n\n\nclass KittiComb2012Train(Kitti_comb):\n    def __init__(self,\n                 args,\n                 root,\n                 photometric_augmentations=True,\n                 preprocessing_crop=True):\n        images_root_2012 = os.path.join(root, \"data_stereo_flow\", \"training\", \"colored_0\")\n        flow_root_2012 = os.path.join(root, \"data_stereo_flow\",  \"training\", \"flow_occ\")\n        super(KittiComb2012Train, self).__init__(\n            args,\n            
images_root_2012=images_root_2012,\n            flow_root_2012=flow_root_2012,\n            photometric_augmentations=photometric_augmentations,\n            preprocessing_crop=preprocessing_crop,\n            dstype=\"train\")\n\n\nclass KittiComb2012Val(Kitti_comb):\n    def __init__(self,\n                 args,\n                 root,\n                 photometric_augmentations=False,\n                 preprocessing_crop=False):\n        images_root_2012 = os.path.join(root, \"data_stereo_flow\", \"training\", \"colored_0\")\n        flow_root_2012 = os.path.join(root, \"data_stereo_flow\",  \"training\", \"flow_occ\")\n        super(KittiComb2012Val, self).__init__(\n            args,\n            images_root_2012=images_root_2012,\n            flow_root_2012=flow_root_2012,\n            photometric_augmentations=photometric_augmentations,\n            preprocessing_crop=preprocessing_crop,\n            dstype=\"valid\")\n\n\nclass KittiComb2012Full(Kitti_comb):\n    def __init__(self,\n                 args,\n                 root,\n                 photometric_augmentations=True,\n                 preprocessing_crop=True):\n        images_root_2012 = os.path.join(root, \"data_stereo_flow\", \"training\", \"colored_0\")\n        flow_root_2012 = os.path.join(root, \"data_stereo_flow\",  \"training\", \"flow_occ\")\n        super(KittiComb2012Full, self).__init__(\n            args,\n            images_root_2012=images_root_2012,\n            flow_root_2012=flow_root_2012,\n            photometric_augmentations=photometric_augmentations,\n            preprocessing_crop=preprocessing_crop,\n            dstype=\"full\")\n\n\nclass KittiComb2012Test(Kitti_comb_test):\n    def __init__(self,\n                 args,\n                 root,\n                 photometric_augmentations=False,\n                 preprocessing_crop=False):\n        images_root_2012 = os.path.join(root, \"data_stereo_flow\", \"testing\", \"colored_0\")\n        super(KittiComb2012Test, 
self).__init__(\n            args,\n            images_root_2012=images_root_2012,\n            photometric_augmentations=photometric_augmentations,\n            preprocessing_crop=preprocessing_crop)"
  },
  {
    "path": "datasets/sintel.py",
    "content": "from __future__ import absolute_import, division, print_function\n\nimport os\nimport torch.utils.data as data\nfrom glob import glob\n\nfrom torchvision import transforms as vision_transforms\n\nfrom . import transforms\nfrom . import common\n\nimport tools\n\n\nVALIDATE_INDICES = [\n    199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210,\n    211, 212, 213, 214, 215, 216, 217, 340, 341, 342, 343, 344,\n    345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356,\n    357, 358, 359, 360, 361, 362, 363, 364, 536, 537, 538, 539,\n    540, 541, 542, 543, 544, 545, 546, 547, 548, 549, 550, 551,\n    552, 553, 554, 555, 556, 557, 558, 559, 560, 659, 660, 661,\n    662, 663, 664, 665, 666, 667, 668, 669, 670, 671, 672, 673,\n    674, 675, 676, 677, 678, 679, 680, 681, 682, 683, 684, 685,\n    686, 687, 688, 689, 690, 691, 692, 693, 694, 695, 696, 697,\n    967, 968, 969, 970, 971, 972, 973, 974, 975, 976, 977, 978,\n    979, 980, 981, 982, 983, 984, 985, 986, 987, 988, 989, 990,\n    991]\n\n\nclass _Sintel(data.Dataset):\n    def __init__(self,\n                 args,\n                 dir_root=None,\n                 photometric_augmentations=False,\n                 imgtype=None,\n                 dstype=None):\n\n        self._args = args\n\n        images_root = os.path.join(dir_root, imgtype)\n        if imgtype is \"comb\":\n            images_root = os.path.join(dir_root, \"clean\")\n        flow_root = os.path.join(dir_root, \"flow\")\n        occ_root = os.path.join(dir_root, \"occlusions_rev\")\n\n        if not os.path.isdir(images_root):\n            raise ValueError(\"Image directory '%s' not found!\")\n        if flow_root is not None and not os.path.isdir(flow_root):\n            raise ValueError(\"Flow directory '%s' not found!\")\n        if occ_root is not None and not os.path.isdir(occ_root):\n            raise ValueError(\"Occ directory '%s' not found!\")\n        \n        all_flo_filenames = 
sorted(glob(os.path.join(flow_root, \"*/*.flo\")))\n        all_occ_filenames = sorted(glob(os.path.join(occ_root, \"*/*.png\")))\n        all_img_filenames = sorted(glob(os.path.join(images_root, \"*/*.png\")))\n\n        # Remember base for substraction at runtime\n        # e.g. subtract_base = \"/home/user/.../MPI-Sintel-Complete/training/clean\"\n        self._substract_base = tools.cd_dotdot(images_root)\n\n        # ------------------------------------------------------------------------\n        # Get unique basenames\n        # ------------------------------------------------------------------------\n        # e.g. base_folders = [alley_1\", \"alley_2\", \"ambush_2\", ...]\n        substract_full_base = tools.cd_dotdot(all_img_filenames[0])\n        base_folders = sorted(list(set([\n            os.path.dirname(fn.replace(substract_full_base, \"\"))[1:] for fn in all_img_filenames\n        ])))\n\n        self._image_list = []\n        self._flow_list = []\n        self._occ_list = []\n\n        for base_folder in base_folders:            \n            img_filenames = [x for x in all_img_filenames if base_folder in x]\n            flo_filenames = [x for x in all_flo_filenames if base_folder in x]\n            occ_filenames = [x for x in all_occ_filenames if base_folder in x]\n\n            for i in range(len(img_filenames) - 1):\n\n                im1 = img_filenames[i]\n                im2 = img_filenames[i + 1]\n                flo = flo_filenames[i]\n                occ = occ_filenames[i]\n\n                self._image_list += [[im1, im2]]\n                self._flow_list += [flo]\n                self._occ_list += [occ]\n\n                # Sanity check\n                im1_base_filename = os.path.splitext(os.path.basename(im1))[0]\n                im2_base_filename = os.path.splitext(os.path.basename(im2))[0]\n                flo_base_filename = os.path.splitext(os.path.basename(flo))[0]\n                occ_base_filename = 
os.path.splitext(os.path.basename(occ))[0]\n                im1_frame, im1_no = im1_base_filename.split(\"_\")\n                im2_frame, im2_no = im2_base_filename.split(\"_\")\n                assert(im1_frame == im2_frame)\n                assert(int(im1_no) == int(im2_no) - 1)\n                \n                flo_frame, flo_no = flo_base_filename.split(\"_\")\n                assert(im1_frame == flo_frame)\n                assert(int(im1_no) == int(flo_no))\n                \n                occ_frame, occ_no = occ_base_filename.split(\"_\")\n                assert(im1_frame == occ_frame)\n                assert(int(im1_no) == int(occ_no))\n        \n        assert len(self._image_list) == len(self._flow_list)        \n        assert len(self._image_list) == len(self._occ_list)\n\n        # -------------------------------------------------------------\n        # Remove invalid validation indices\n        # -------------------------------------------------------------\n        full_num_examples = len(self._image_list)\n        validate_indices = [x for x in VALIDATE_INDICES if x in range(full_num_examples)]\n\n        # ----------------------------------------------------------\n        # Construct list of indices for training/validation\n        # ----------------------------------------------------------\n        list_of_indices = None\n        if dstype == \"train\":\n            list_of_indices = [x for x in range(full_num_examples) if x not in validate_indices]\n        elif dstype == \"valid\":\n            list_of_indices = validate_indices\n        elif dstype == \"full\":\n            list_of_indices = range(full_num_examples)\n        else:\n            raise ValueError(\"dstype '%s' unknown!\", dstype)\n\n        # ----------------------------------------------------------\n        # Save list of actual filenames for inputs and flows\n        # ----------------------------------------------------------\n        self._image_list = 
[self._image_list[i] for i in list_of_indices]\n        self._flow_list = [self._flow_list[i] for i in list_of_indices]\n        self._occ_list = [self._occ_list[i] for i in list_of_indices]\n\n        if imgtype is \"comb\":\n            image_list_final = [[val[0].replace(\"clean\", \"final\"), val[1].replace(\"clean\", \"final\")] for idx, val in enumerate(self._image_list)]\n            self._image_list += image_list_final\n            self._flow_list += self._flow_list\n            self._occ_list += self._occ_list\n\n        assert len(self._image_list) == len(self._flow_list)\n        assert len(self._image_list) == len(self._occ_list)\n\n        # ----------------------------------------------------------\n        # photometric_augmentations\n        # ----------------------------------------------------------\n        if photometric_augmentations:\n            self._photometric_transform = transforms.ConcatTransformSplitChainer([\n                # uint8 -> PIL\n                vision_transforms.ToPILImage(),\n                # PIL -> PIL : random hsv and contrast\n                vision_transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5),\n                # PIL -> FloatTensor\n                vision_transforms.transforms.ToTensor(),\n                transforms.RandomGamma(min_gamma=0.7, max_gamma=1.5, clip_image=True),\n            ], from_numpy=True, to_numpy=False)\n\n        else:\n            self._photometric_transform = transforms.ConcatTransformSplitChainer([\n                # uint8 -> FloatTensor\n                vision_transforms.transforms.ToTensor(),\n            ], from_numpy=True, to_numpy=False)\n\n        self._size = len(self._image_list)\n\n    def __getitem__(self, index):\n        index = index % self._size\n\n        im1_filename = self._image_list[index][0]\n        im2_filename = self._image_list[index][1]\n        flo_filename = self._flow_list[index]\n        occ_filename = self._occ_list[index]\n\n        
# read float32 images and flow\n        im1_np0 = common.read_image_as_byte(im1_filename)\n        im2_np0 = common.read_image_as_byte(im2_filename)\n        flo_np0 = common.read_flo_as_float32(flo_filename)\n        occ_np0 = common.read_occ_image_as_float32(occ_filename)\n\n        # possibly apply photometric transformations\n        im1, im2 = self._photometric_transform(im1_np0, im2_np0)\n        flo = common.numpy2torch(flo_np0)\n        occ = common.numpy2torch(occ_np0)\n\n        # e.g. \"clean/alley_1/\"\n        basedir = os.path.splitext(os.path.dirname(im1_filename).replace(self._substract_base, \"\")[1:])[0]\n\n        # example filename\n        basename = os.path.splitext(os.path.basename(im1_filename))[0]\n\n        example_dict = {\n            \"input1\": im1,\n            \"input2\": im2,\n            \"index\": index,\n            \"basedir\": basedir,\n            \"basename\": basename,\n            \"target1\": flo,\n            \"target_occ1\": occ\n        }\n\n        return example_dict\n\n    def __len__(self):\n        return self._size\n\n\nclass _Sintel_test(data.Dataset):\n    def __init__(self,\n                 args,\n                 dir_root=None,\n                 photometric_augmentations=False,\n                 imgtype=None):\n\n        self._args = args\n        images_root = os.path.join(dir_root, imgtype)\n        if not os.path.isdir(images_root):\n            raise ValueError(\"Image directory '%s' not found!\")\n\n        all_img_filenames = sorted(glob(os.path.join(images_root, \"*/*.png\")))\n\n        # Remember base for substraction at runtime\n        # e.g. subtract_base = \"/home/user/.../MPI-Sintel-Complete/training/clean\"\n        self._substract_base = tools.cd_dotdot(images_root)\n\n        # ------------------------------------------------------------------------\n        # Get unique basenames\n        # ------------------------------------------------------------------------\n        # e.g. 
base_folders = [alley_1\", \"alley_2\", \"ambush_2\", ...]\n        substract_full_base = tools.cd_dotdot(all_img_filenames[0])\n        base_folders = sorted(list(set([\n            os.path.dirname(fn.replace(substract_full_base, \"\"))[1:] for fn in all_img_filenames\n        ])))\n\n        self._image_list = []\n\n        for base_folder in base_folders:            \n            img_filenames = [x for x in all_img_filenames if base_folder in x]\n\n            for i in range(len(img_filenames) - 1):\n\n                im1 = img_filenames[i]\n                im2 = img_filenames[i + 1]\n                self._image_list += [[im1, im2]]\n\n                # Sanity check\n                im1_base_filename = os.path.splitext(os.path.basename(im1))[0]\n                im2_base_filename = os.path.splitext(os.path.basename(im2))[0]\n                im1_frame, im1_no = im1_base_filename.split(\"_\")\n                im2_frame, im2_no = im2_base_filename.split(\"_\")\n                assert(im1_frame == im2_frame)\n                assert(int(im1_no) == int(im2_no) - 1)                \n\n        full_num_examples = len(self._image_list)\n        list_of_indices = range(full_num_examples)\n\n        # ----------------------------------------------------------\n        # Save list of actual filenames for inputs and flows\n        # ----------------------------------------------------------\n        self._image_list = [self._image_list[i] for i in list_of_indices]\n        \n        # ----------------------------------------------------------\n        # photometric_augmentations\n        # ----------------------------------------------------------\n        if photometric_augmentations:\n            self._photometric_transform = transforms.ConcatTransformSplitChainer([\n                # uint8 -> PIL\n                vision_transforms.ToPILImage(),\n                # PIL -> PIL : random hsv and contrast\n                vision_transforms.ColorJitter(brightness=0.5, 
contrast=0.5, saturation=0.5, hue=0.5),\n                # PIL -> FloatTensor\n                vision_transforms.transforms.ToTensor(),\n                transforms.RandomGamma(min_gamma=0.7, max_gamma=1.5, clip_image=True),\n            ], from_numpy=True, to_numpy=False)\n\n        else:\n            self._photometric_transform = transforms.ConcatTransformSplitChainer([\n                # uint8 -> FloatTensor\n                vision_transforms.transforms.ToTensor(),\n            ], from_numpy=True, to_numpy=False)\n\n        self._size = len(self._image_list)\n\n    def __getitem__(self, index):\n        index = index % self._size\n\n        im1_filename = self._image_list[index][0]\n        im2_filename = self._image_list[index][1]\n\n        # read float32 images and flow\n        im1_np0 = common.read_image_as_byte(im1_filename)\n        im2_np0 = common.read_image_as_byte(im2_filename)\n\n        # possibly apply photometric transformations\n        im1, im2 = self._photometric_transform(im1_np0, im2_np0)\n\n        # e.g. 
\"clean/alley_1/\"\n        basedir = os.path.splitext(os.path.dirname(im1_filename).replace(self._substract_base, \"\")[1:])[0]\n\n        # example filename\n        basename = os.path.splitext(os.path.basename(im1_filename))[0]\n\n        example_dict = {\n            \"input1\": im1,\n            \"input2\": im2,\n            \"index\": index,\n            \"basedir\": basedir,\n            \"basename\": basename\n        }\n\n        return example_dict\n\n    def __len__(self):\n        return self._size\n\n\nclass SintelTrainingCleanTrain(_Sintel):\n    def __init__(self, args, root, photometric_augmentations=True):\n        dir_root = os.path.join(root, \"training\")\n        super(SintelTrainingCleanTrain, self).__init__(\n            args,\n            dir_root=dir_root,\n            photometric_augmentations=photometric_augmentations,\n            imgtype=\"clean\",\n            dstype=\"train\")\n\n\nclass SintelTrainingCleanValid(_Sintel):\n    def __init__(self, args, root, photometric_augmentations=False):\n        dir_root = os.path.join(root, \"training\")\n        super(SintelTrainingCleanValid, self).__init__(\n            args,\n            dir_root=dir_root,\n            photometric_augmentations=photometric_augmentations,\n            imgtype=\"clean\",\n            dstype=\"valid\")\n\n\nclass SintelTrainingCleanFull(_Sintel):\n    def __init__(self, args, root, photometric_augmentations=True):\n        dir_root = os.path.join(root, \"training\")\n        super(SintelTrainingCleanFull, self).__init__(\n            args,\n            dir_root=dir_root,\n            photometric_augmentations=photometric_augmentations,\n            imgtype=\"clean\",\n            dstype=\"full\")\n\n\nclass SintelTrainingFinalTrain(_Sintel):\n    def __init__(self, args, root, photometric_augmentations=True):\n        dir_root = os.path.join(root, \"training\")\n        super(SintelTrainingFinalTrain, self).__init__(\n            args,\n            
dir_root=dir_root,\n            photometric_augmentations=photometric_augmentations,\n            imgtype=\"final\",\n            dstype=\"train\")\n\n\nclass SintelTrainingFinalValid(_Sintel):\n    def __init__(self, args, root, photometric_augmentations=False):\n        dir_root = os.path.join(root, \"training\")\n        super(SintelTrainingFinalValid, self).__init__(\n            args,\n            dir_root=dir_root,\n            photometric_augmentations=photometric_augmentations,\n            imgtype=\"final\",\n            dstype=\"valid\")\n\n\nclass SintelTrainingFinalFull(_Sintel):\n    def __init__(self, args, root, photometric_augmentations=True):\n        dir_root = os.path.join(root, \"training\")\n        super(SintelTrainingFinalFull, self).__init__(\n            args,\n            dir_root=dir_root,\n            photometric_augmentations=photometric_augmentations,\n            imgtype=\"final\",\n            dstype=\"full\")\n\n\nclass SintelTrainingCombTrain(_Sintel):\n    def __init__(self, args, root, photometric_augmentations=True):\n        dir_root = os.path.join(root, \"training\")\n        super(SintelTrainingCombTrain, self).__init__(\n            args,\n            dir_root=dir_root,\n            photometric_augmentations=photometric_augmentations,\n            imgtype=\"comb\",\n            dstype=\"train\")\n\n\nclass SintelTrainingCombValid(_Sintel):\n    def __init__(self, args, root, photometric_augmentations=False):\n        dir_root = os.path.join(root, \"training\")\n        super(SintelTrainingCombValid, self).__init__(\n            args,\n            dir_root=dir_root,\n            photometric_augmentations=photometric_augmentations,\n            imgtype=\"comb\",\n            dstype=\"valid\")\n\n\nclass SintelTrainingCombFull(_Sintel):\n    def __init__(self, args, root, photometric_augmentations=True):\n        dir_root = os.path.join(root, \"training\")\n        super(SintelTrainingCombFull, self).__init__(\n            
args,\n            dir_root=dir_root,\n            photometric_augmentations=photometric_augmentations,\n            imgtype=\"comb\",\n            dstype=\"full\")\n\n\nclass SintelTestClean(_Sintel_test):\n    def __init__(self, args, root, photometric_augmentations=False):\n        dir_root = os.path.join(root, \"test\")\n        super(SintelTestClean, self).__init__(\n            args,\n            dir_root=dir_root,\n            photometric_augmentations=photometric_augmentations,\n            imgtype=\"clean\")\n\n\nclass SintelTestFinal(_Sintel_test):\n    def __init__(self, args, root, photometric_augmentations=False):\n        dir_root = os.path.join(root, \"test\")\n        super(SintelTestFinal, self).__init__(\n            args,\n            dir_root=dir_root,\n            photometric_augmentations=photometric_augmentations,\n            imgtype=\"final\")\n"
  },
  {
    "path": "datasets/transforms.py",
    "content": "## Portions of Code from, copyright 2018 Jochen Gast\n\nfrom __future__ import absolute_import, division, print_function\n\nimport numpy as np\nimport torch\n\n\ndef image_random_gamma(image, min_gamma=0.7, max_gamma=1.5, clip_image=False):\n    gamma = np.random.uniform(min_gamma, max_gamma)\n    adjusted = torch.pow(image, gamma)\n    if clip_image:\n        adjusted.clamp_(0.0, 1.0)\n    return adjusted\n\n\nclass RandomGamma:\n    def __init__(self, min_gamma=0.7, max_gamma=1.5, clip_image=False):\n        self._min_gamma = min_gamma\n        self._max_gamma = max_gamma\n        self._clip_image = clip_image\n\n    def __call__(self, image):\n        return image_random_gamma(\n            image,\n            min_gamma=self._min_gamma,\n            max_gamma=self._max_gamma,\n            clip_image=self._clip_image)\n\n\n# ------------------------------------------------------------------\n# Allow transformation chains of the type:\n#   im1, im2, .... = transform(im1, im2, ...)\n# ------------------------------------------------------------------\nclass TransformChainer:\n    def __init__(self, list_of_transforms):\n        self._list_of_transforms = list_of_transforms\n\n    def __call__(self, *args):\n        list_of_args = list(args)\n        for transform in self._list_of_transforms:\n            list_of_args = [transform(arg) for arg in list_of_args]\n        if len(args) == 1:\n            return list_of_args[0]\n        else:\n            return list_of_args\n\n\n# ------------------------------------------------------------------\n# Allow transformation chains of the type:\n#   im1, im2, .... = split( transform( concatenate(im1, im2, ...) 
))\n# ------------------------------------------------------------------\nclass ConcatTransformSplitChainer:\n    def __init__(self, list_of_transforms, from_numpy=True, to_numpy=False):\n        self._chainer = TransformChainer(list_of_transforms)\n        self._from_numpy = from_numpy\n        self._to_numpy = to_numpy\n\n    def __call__(self, *args):\n        num_splits = len(args)\n\n        if self._from_numpy:\n            concatenated = np.concatenate(args, axis=0)\n        else:\n            concatenated = torch.cat(args, dim=1)\n\n        transformed = self._chainer(concatenated)\n\n        if self._to_numpy:\n            split = np.split(transformed, indices_or_sections=num_splits, axis=0)\n        else:\n            split = torch.chunk(transformed, num_splits, dim=1)\n\n        return split\n"
  },
  {
    "path": "flyingchairsocc/README.md",
    "content": "# FlyingChairsOcc dataset\n\n<img src=demo_img.png>\n\nThe FlyingChairsOcc dataset is an extended version of the <a href=\"https://lmb.informatik.uni-freiburg.de/resources/datasets/FlyingChairs.en.html\" target=\"_blank\">Flying Chairs Dataset</a>, including bi-directional optical flow ground truth and two occlusion maps for each image.\nYou may also find that another concurrent dataset, the <a href=\"https://lmb.informatik.uni-freiburg.de/resources/datasets/FlyingChairs.en.html\" target=\"_blank\">Flying Chairs 2 Dataset</a>, is useful.\n\n\n## License agreement\n\nThis dataset is made freely available to academic and non-academic entities for non-commercial purposes such as academic research, teaching, scientific publications, or personal experimentation. Permission is granted to use the data given that you agree:\n\n1. That the dataset comes “AS IS”, without express or implied warranty. Although every effort has been made to ensure accuracy, we (TU Darmstadt) do not accept any responsibility for errors or omissions.\n2. That you include a reference to the FlyingChairsOcc Dataset in any work that makes use of the dataset. For research papers, cite our publication: J. Hur and S. Roth, “Iterative residual refinement for joint optical flow and occlusion estimation,” in Proc. of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR), Long Beach, California, June 2019\n3. That you do not distribute this dataset or modified versions. It is permissible to distribute derivative works in as far as they are abstract representations of this dataset (such as models trained on it or additional annotations that do not directly include any of our data) and do not allow to recover the dataset or something similar in character.\n4. That you may not use the dataset or any derivative work for commercial purposes as, for example, licensing or selling the data, or using the data with a purpose to procure a commercial gain.\n5. 
That all rights not expressly granted to you are reserved by us (TU Darmstadt).\n\n\n## Download link\n\n<a href=\"https://download.visinf.tu-darmstadt.de/data/flyingchairs_occ/FlyingChairsOcc.tar.gz\" target=\"_blank\"><b>Download</b></a>\n(82GB)\n\n\n## Reference\n\nPlease cite the paper below if you find the dataset and source code useful.  \n\n    @inproceedings{Hur:2019:IRR,  \n      Author = {Junhwa Hur and Stefan Roth},  \n      Booktitle = {CVPR},  \n      Title = {Iterative Residual Refinement for Joint Optical Flow and Occlusion Estimation},  \n      Year = {2019}  \n    }\n\nContact: junhwa.hur[at]visinf.tu-darmstadt.de\n"
  },
  {
    "path": "install.sh",
    "content": "#!/bin/bash\ncd ./models/correlation_package\npython setup.py install\ncd ..\n"
  },
  {
    "path": "logger.py",
    "content": "## Portions of Code from, copyright 2018 Jochen Gast\n\nfrom __future__ import absolute_import, division, print_function\n\nimport colorama\nimport logging\nimport os\nimport re\nimport tools\nimport sys\n\n\ndef get_default_logging_format(colorize=False, brackets=False):\n    style = colorama.Style.DIM if colorize else ''\n    # color = colorama.Fore.CYAN if colorize else ''\n    color = colorama.Fore.WHITE if colorize else ''\n    reset = colorama.Style.RESET_ALL if colorize else ''\n    if brackets:\n        result = \"{}{}[%(asctime)s]{} %(message)s\".format(style, color, reset)\n    else:\n        result = \"{}{}%(asctime)s{} %(message)s\".format(style, color, reset)\n    return result\n\n\ndef get_default_logging_datefmt():\n    return \"%Y-%m-%d %H:%M:%S\"\n\n\ndef log_module_info(module):\n    lines = module.__str__().split(\"\\n\")\n    for line in lines:\n        logging.info(line)\n\n\nclass LogbookFormatter(logging.Formatter):\n    def __init__(self, fmt=None, datefmt=None):\n        super(LogbookFormatter, self).__init__(fmt=fmt, datefmt=datefmt)\n        self._re = re.compile(r\"\\033\\[[0-9]+m\")\n\n    def remove_colors_from_msg(self, msg):\n        msg = re.sub(self._re, \"\", msg)\n        return msg\n\n    def format(self, record=None):\n        record.msg = self.remove_colors_from_msg(record.msg)\n        return super(LogbookFormatter, self).format(record)\n\n\nclass ConsoleFormatter(logging.Formatter):\n    def __init__(self, fmt=None, datefmt=None):\n        super(ConsoleFormatter, self).__init__(fmt=fmt, datefmt=datefmt)\n\n    def format(self, record=None):\n        indent = sys.modules[__name__].global_indent\n        record.msg = \" \" * indent + record.msg\n        return super(ConsoleFormatter, self).format(record)\n\n\nclass SkipLogbookFilter(logging.Filter):\n    def filter(self, record):\n        return record.levelno != logging.LOGBOOK\n\n\ndef configure_logging(filename=None):\n    # set global indent level\n    
sys.modules[__name__].global_indent = 0\n\n    # add custom tqdm logger\n    tools.addLoggingLevel(\"LOGBOOK\", 1000)\n\n    # create logger\n    root_logger = logging.getLogger(\"\")\n    root_logger.setLevel(logging.INFO)\n\n    # create console handler and set level to debug\n    console = logging.StreamHandler()\n    console.setLevel(logging.INFO)\n    fmt = get_default_logging_format(colorize=True, brackets=False)\n    datefmt = get_default_logging_datefmt()\n    formatter = ConsoleFormatter(fmt=fmt, datefmt=datefmt)\n    console.setFormatter(formatter)\n\n    # Skip logging.tqdm requests for console outputs\n    skip_logbook_filter = SkipLogbookFilter()\n    console.addFilter(skip_logbook_filter)\n\n    # add console to root_logger\n    root_logger.addHandler(console)\n\n    # add logbook\n    if filename is not None:\n        # ensure dir\n        d = os.path.dirname(filename)\n        if not os.path.exists(d):\n            os.makedirs(d)\n\n        # --------------------------------------------------------------------------------------\n        # Configure handler that removes color codes from logbook\n        # --------------------------------------------------------------------------------------\n        logbook = logging.FileHandler(filename=filename, mode=\"a\", encoding=\"utf-8\")\n        logbook.setLevel(logging.INFO)\n        fmt = get_default_logging_format(colorize=False, brackets=True)\n        logbook_formatter = LogbookFormatter(fmt=fmt, datefmt=datefmt)\n        logbook.setFormatter(logbook_formatter)\n        root_logger.addHandler(logbook)\n\n\nclass LoggingBlock:\n    def __init__(self, title, emph=False):\n        self._emph = emph\n        bright = colorama.Style.BRIGHT\n        cyan = colorama.Fore.CYAN\n        reset = colorama.Style.RESET_ALL\n        if emph:\n            logging.info(\"%s==>%s %s%s%s\" % (cyan, reset, bright, title, reset))\n        else:\n            logging.info(title)\n\n    def __enter__(self):\n        
sys.modules[__name__].global_indent += 2\n        return self\n\n    def __exit__(self, exc_type, exc_value, traceback):\n        sys.modules[__name__].global_indent -= 2\n"
  },
  {
    "path": "losses.py",
    "content": "from __future__ import absolute_import, division, print_function\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as tf\n\n\ndef _elementwise_epe(input_flow, target_flow):\n    residual = target_flow - input_flow\n    return torch.norm(residual, p=2, dim=1, keepdim=True)\n\ndef _elementwise_robust_epe_char(input_flow, target_flow):\n    residual = target_flow - input_flow\n    return torch.pow(torch.norm(residual, p=2, dim=1, keepdim=True) + 0.01, 0.4)\n\ndef _downsample2d_as(inputs, target_as):\n    _, _, h, w = target_as.size()\n    return tf.adaptive_avg_pool2d(inputs, [h, w])\n\ndef _upsample2d_as(inputs, target_as, mode=\"bilinear\"):\n    _, _, h, w = target_as.size()\n    return tf.interpolate(inputs, [h, w], mode=mode, align_corners=True)\n\ndef f1_score(y_true, y_pred):\n    return fbeta_score(y_true, y_pred, 1)\n\ndef fbeta_score(y_true, y_pred, beta, eps=1e-8):\n    beta2 = beta ** 2\n\n    y_pred = y_pred.float()\n    y_true = y_true.float()\n\n    true_positive = (y_pred * y_true).sum(dim=2).sum(dim=2)\n    precision = true_positive / (y_pred.sum(dim=2).sum(dim=2) + eps)\n    recall = true_positive / (y_true.sum(dim=2).sum(dim=2) + eps)\n\n    return torch.mean(precision * recall / (precision * beta2 + recall + eps) * (1 + beta2))\n\ndef f1_score_bal_loss(y_pred, y_true):\n    eps = 1e-8\n\n    tp = -(y_true * torch.log(y_pred + eps)).sum(dim=2).sum(dim=2).sum(dim=1)\n    fn = -((1 - y_true) * torch.log((1 - y_pred) + eps)).sum(dim=2).sum(dim=2).sum(dim=1)\n\n    denom_tp = y_true.sum(dim=2).sum(dim=2).sum(dim=1) + y_pred.sum(dim=2).sum(dim=2).sum(dim=1) + eps\n    denom_fn = (1 - y_true).sum(dim=2).sum(dim=2).sum(dim=1) + (1 - y_pred).sum(dim=2).sum(dim=2).sum(dim=1) + eps\n\n    return ((tp / denom_tp).sum() + (fn / denom_fn).sum()) * y_pred.size(2) * y_pred.size(3) * 0.5\n\n\nclass MultiScaleEPE_FlowNet(nn.Module):\n    def __init__(self,\n                 args):\n\n        super(MultiScaleEPE_FlowNet, 
self).__init__()\n        self._args = args        \n        self._batch_size = args.batch_size\n        self._weights = [0.005, 0.01, 0.02, 0.08, 0.32]\n\n    def forward(self, output_dict, target_dict):\n        loss_dict = {}\n\n        if self.training:\n            outputs = [output_dict[key] for key in [\"flow2\", \"flow3\", \"flow4\", \"flow5\", \"flow6\"]]\n\n            # div_flow trick\n            target = self._args.model_div_flow * target_dict[\"target1\"]\n\n            total_loss = 0\n            for i, output_i in enumerate(outputs):\n                target_i = _downsample2d_as(target, output_i)\n                epe_i = _elementwise_epe(output_i, target_i)\n                total_loss = total_loss + self._weights[i] * epe_i.sum() / self._batch_size\n                loss_dict[\"epe%i\" % (i + 2)] = epe_i.mean()\n            loss_dict[\"total_loss\"] = total_loss\n        else:\n            output = output_dict[\"flow1\"]\n            target = target_dict[\"target1\"]\n            epe = _elementwise_epe(output, target)\n            loss_dict[\"epe\"] = epe.mean()\n\n        return loss_dict\n\nclass MultiScaleEPE_FlowNet_IRR(nn.Module):\n    def __init__(self,\n                 args):\n\n        super(MultiScaleEPE_FlowNet_IRR, self).__init__()\n        self._args = args        \n        self._batch_size = args.batch_size\n        self._weights = [0.005, 0.01, 0.02, 0.08, 0.32]\n        self._num_iters = args.num_iters\n\n    def forward(self, output_dict, target_dict):\n        loss_dict = {}\n\n        if self.training:\n            outputs_flo = [output_dict[key] for key in [\"flow2\", \"flow3\", \"flow4\", \"flow5\", \"flow6\"]]\n\n            # div_flow trick\n            target_f = self._args.model_div_flow * target_dict[\"target1\"]\n\n            total_loss = 0\n            for ii, output_ii in enumerate(outputs_flo):\n                target_f_ii = _downsample2d_as(target_f, output_ii[0])\n                for jj, output_ii_jj in 
enumerate(output_ii):\n                    epe_f_ii = _elementwise_epe(output_ii_jj, target_f_ii)\n                    total_loss = total_loss + self._weights[ii] * epe_f_ii.sum()\n                    loss_dict[\"epe%i\" % (ii + 2)] = epe_f_ii.mean()\n            loss_dict[\"total_loss\"] = total_loss / self._batch_size / self._num_iters\n\n        else:\n            output = output_dict[\"flow1\"]\n            target_f = target_dict[\"target1\"]\n            epe_f = _elementwise_epe(target_f, output)\n            loss_dict[\"epe\"] = epe_f.mean()\n\n        return loss_dict\n\nclass MultiScaleEPE_FlowNet_IRR_Bi(nn.Module):\n    def __init__(self,\n                 args):\n\n        super(MultiScaleEPE_FlowNet_IRR_Bi, self).__init__()\n        self._args = args        \n        self._batch_size = args.batch_size\n        self._weights = [0.005, 0.01, 0.02, 0.08, 0.32]\n        self._num_iters = args.num_iters\n\n    def forward(self, output_dict, target_dict):\n\n        loss_dict = {}\n\n        if self.training:\n            outputs_flo = [output_dict[key] for key in [\"flow2\", \"flow3\", \"flow4\", \"flow5\", \"flow6\"]]\n\n            # div_flow trick\n            target_f = self._args.model_div_flow * target_dict[\"target1\"]\n            target_b = self._args.model_div_flow * target_dict[\"target2\"]\n\n            total_loss = 0\n            for ii, output_ii in enumerate(outputs_flo):\n                target_f_ii = _downsample2d_as(target_f, output_ii[0][0])\n                target_b_ii = _downsample2d_as(target_b, output_ii[0][1])\n                for jj, output_ii_jj in enumerate(output_ii):\n                    epe_f_ii = _elementwise_epe(output_ii_jj[0], target_f_ii)\n                    epe_b_ii = _elementwise_epe(output_ii_jj[1], target_b_ii)\n                    total_loss = total_loss + self._weights[ii] * (epe_f_ii.sum() + epe_b_ii.sum())\n                    loss_dict[\"epe%i\" % (ii + 2)] = (epe_f_ii.mean() + epe_b_ii.mean()) / 2\n            
loss_dict[\"total_loss\"] = total_loss / self._batch_size / self._num_iters / 2\n        else:\n            epe_f = _elementwise_epe(output_dict[\"flow1\"], target_dict[\"target1\"])\n            loss_dict[\"epe\"] = epe_f.mean()\n\n        return loss_dict\n\nclass MultiScaleEPE_FlowNet_IRR_Occ(nn.Module):\n    def __init__(self,\n                 args):\n\n        super(MultiScaleEPE_FlowNet_IRR_Occ, self).__init__()\n        self._args = args        \n        self._batch_size = args.batch_size\n        self._weights = [0.005, 0.01, 0.02, 0.08, 0.32]\n        self._num_iters = args.num_iters\n\n        self.f1_score_bal_loss = f1_score_bal_loss\n        self.occ_activ = nn.Sigmoid()\n        \n    def forward(self, output_dict, target_dict):\n        loss_dict = {}\n\n        if self.training:\n            outputs_flo = [output_dict[key] for key in [\"flow2\", \"flow3\", \"flow4\", \"flow5\", \"flow6\"]]\n            outputs_occ = [output_dict[key] for key in [\"occ2\", \"occ3\", \"occ4\", \"occ5\", \"occ6\"]]\n\n            # div_flow trick\n            target = self._args.model_div_flow * target_dict[\"target1\"]\n            target_occ = target_dict[\"target_occ1\"]\n\n            flow_loss = 0\n            occ_loss = 0\n\n            for ii, output_ii in enumerate(outputs_flo):\n                target_ii = _downsample2d_as(target, output_ii[0])\n                for jj, output_ii_jj in enumerate(output_ii):\n                    flow_loss = flow_loss + self._weights[ii] * _elementwise_epe(output_ii_jj, target_ii).sum()\n\n            for ii, output_ii in enumerate(outputs_occ):\n                target_occ_f = _downsample2d_as(target_occ, output_ii[0])\n                for jj, output_ii_jj in enumerate(output_ii):\n                    occ_loss = occ_loss + self._weights[ii] * self.f1_score_bal_loss(self.occ_activ(output_ii_jj), target_occ_f)\n\n            f_loss = flow_loss.detach()\n            o_loss = occ_loss.detach()\n            if f_loss > o_loss:\n      
          f_l_w = 1\n                o_l_w = f_loss / o_loss\n            else:\n                f_l_w = o_loss / f_loss\n                o_l_w = 1\n\n            loss_dict[\"flow_loss\"] = flow_loss / self._batch_size / self._num_iters\n            loss_dict[\"occ_loss\"] = occ_loss / self._batch_size / self._num_iters\n            loss_dict[\"total_loss\"] = (flow_loss * f_l_w + occ_loss * o_l_w) / self._batch_size / self._num_iters\n\n        else:\n            loss_dict[\"epe\"] = _elementwise_epe(output_dict[\"flow1\"], target_dict[\"target1\"]).mean()\n            loss_dict[\"F1\"] = f1_score(target_dict[\"target_occ1\"], torch.round(self.occ_activ(output_dict[\"occ1\"])))\n\n        return loss_dict\n\nclass MultiScaleEPE_FlowNet_IRR_Bi_Occ(nn.Module):\n    def __init__(self,\n                 args):\n\n        super(MultiScaleEPE_FlowNet_IRR_Bi_Occ, self).__init__()\n        self._args = args        \n        self._batch_size = args.batch_size\n        self._weights = [0.005, 0.01, 0.02, 0.08, 0.32]\n        self._num_iters = args.num_iters\n\n        self.f1_score_bal_loss = f1_score_bal_loss\n        self.occ_activ = nn.Sigmoid()\n\n    def forward(self, output_dict, target_dict):\n        loss_dict = {}\n\n        if self.training:\n            outputs_flo = [output_dict[key] for key in [\"flow2\", \"flow3\", \"flow4\", \"flow5\", \"flow6\"]]\n            outputs_occ = [output_dict[key] for key in [\"occ2\", \"occ3\", \"occ4\", \"occ5\", \"occ6\"]]\n\n            # div_flow trick\n            target_f = self._args.model_div_flow * target_dict[\"target1\"]\n            target_b = self._args.model_div_flow * target_dict[\"target2\"]\n            target_occ_f = target_dict[\"target_occ1\"]\n            target_occ_b = target_dict[\"target_occ2\"]\n\n            flow_loss = 0\n            occ_loss = 0\n\n            for ii, output_ii in enumerate(outputs_flo):\n                target_f_ii = _downsample2d_as(target_f, output_ii[0][0])\n                
target_b_ii = _downsample2d_as(target_b, output_ii[0][1])\n                for jj, output_ii_jj in enumerate(output_ii):\n                    epe_f_ii = _elementwise_epe(output_ii_jj[0], target_f_ii)\n                    epe_b_ii = _elementwise_epe(output_ii_jj[1], target_b_ii)\n                    flow_loss = flow_loss + self._weights[ii] * (epe_f_ii.sum() + epe_b_ii.sum()) * 0.5\n\n            for ii, output_ii in enumerate(outputs_occ):\n                target_occ_f = _downsample2d_as(target_occ_f, output_ii[0][0])\n                target_occ_b = _downsample2d_as(target_occ_b, output_ii[0][1])\n                for jj, output_ii_jj in enumerate(output_ii):\n                    output_occ_f = self.occ_activ(output_ii_jj[0])\n                    output_occ_b = self.occ_activ(output_ii_jj[1])\n                    bce_f_ii = self.f1_score_bal_loss(output_occ_f, target_occ_f)\n                    bce_b_ii = self.f1_score_bal_loss(output_occ_b, target_occ_b)\n                    occ_loss = occ_loss + self._weights[ii] * (bce_f_ii + bce_b_ii) * 0.5\n\n            f_loss = flow_loss.detach()\n            o_loss = occ_loss.detach()\n            if f_loss > o_loss:\n                f_l_w = 1\n                o_l_w = f_loss / o_loss\n            else:\n                f_l_w = o_loss / f_loss\n                o_l_w = 1\n\n            loss_dict[\"flow_loss\"] = flow_loss / self._batch_size / self._num_iters\n            loss_dict[\"occ_loss\"] = occ_loss / self._batch_size / self._num_iters\n            loss_dict[\"total_loss\"] = (flow_loss * f_l_w + occ_loss * o_l_w) / self._batch_size / self._num_iters\n        else:\n            loss_dict[\"epe\"] = _elementwise_epe(output_dict[\"flow1\"], target_dict[\"target1\"]).mean()\n            loss_dict[\"F1\"] = f1_score(target_dict[\"target_occ1\"], torch.round(self.occ_activ(output_dict[\"occ1\"])))\n\n        return loss_dict\n\nclass MultiScaleEPE_FlowNet_IRR_Bi_Occ_upsample(nn.Module):\n    def __init__(self,\n               
  args):\n        super(MultiScaleEPE_FlowNet_IRR_Bi_Occ_upsample, self).__init__()\n        self._args = args\n        self._batch_size = args.batch_size        \n        self._weights = [0.0003125, 0.00125, 0.005, 0.01, 0.02, 0.08, 0.32]\n        \n        self.occ_activ = nn.Sigmoid()\n        self.f1_score_bal_loss = f1_score_bal_loss\n\n    def forward(self, output_dict, target_dict):\n        loss_dict = {}\n\n        if self.training:\n            outputs_flo = [output_dict[key] for key in [\"flow\", \"flow1\", \"flow2\", \"flow3\", \"flow4\", \"flow5\", \"flow6\"]]\n            outputs_occ = [output_dict[key] for key in [\"occ\", \"occ1\", \"occ2\", \"occ3\", \"occ4\", \"occ5\", \"occ6\"]]\n\n            # div_flow trick\n            target_f = self._args.model_div_flow * target_dict[\"target1\"]\n            target_b = self._args.model_div_flow * target_dict[\"target2\"]\n            target_occ_f = target_dict[\"target_occ1\"]\n            target_occ_b = target_dict[\"target_occ2\"]\n\n            num_iters = len(outputs_flo[0])\n            flow_loss = 0\n            occ_loss = 0\n\n            for ii, output_ii in enumerate(outputs_flo):\n                target_f_ii = _downsample2d_as(target_f, output_ii[0][0])\n                target_b_ii = _downsample2d_as(target_b, output_ii[0][1])\n                for jj, output_ii_jj in enumerate(output_ii):\n                    epe_f_ii = _elementwise_epe(output_ii_jj[0], target_f_ii)\n                    epe_b_ii = _elementwise_epe(output_ii_jj[1], target_b_ii)\n                    flow_loss = flow_loss + self._weights[ii] * (epe_f_ii.sum() + epe_b_ii.sum()) * 0.5\n\n            for ii, output_ii in enumerate(outputs_occ):\n                target_occ_f = _downsample2d_as(target_occ_f, output_ii[0][0])\n                target_occ_b = _downsample2d_as(target_occ_b, output_ii[0][1])\n                for jj, output_ii_jj in enumerate(output_ii):\n                    output_occ_f = self.occ_activ(output_ii_jj[0])\n     
               output_occ_b = self.occ_activ(output_ii_jj[1])\n                    bce_f_ii = self.f1_score_bal_loss(output_occ_f, target_occ_f)\n                    bce_b_ii = self.f1_score_bal_loss(output_occ_b, target_occ_b)\n                    occ_loss = occ_loss + self._weights[ii] * (bce_f_ii + bce_b_ii) * 0.5\n\n            f_loss = flow_loss.detach()\n            o_loss = occ_loss.detach()\n            if f_loss > o_loss:\n                f_l_w = 1\n                o_l_w = f_loss / o_loss\n            else:\n                f_l_w = o_loss / f_loss\n                o_l_w = 1\n\n            loss_dict[\"flow_loss\"] = flow_loss / self._batch_size / num_iters\n            loss_dict[\"occ_loss\"] = occ_loss / self._batch_size / num_iters\n            loss_dict[\"total_loss\"] = (flow_loss * f_l_w + occ_loss * o_l_w) / self._batch_size / num_iters\n        else:\n            loss_dict[\"epe\"] = _elementwise_epe(output_dict[\"flow\"], target_dict[\"target1\"]).mean()\n            loss_dict[\"F1\"] = f1_score(target_dict[\"target_occ1\"], torch.round(self.occ_activ(output_dict[\"occ\"])))\n\n        return loss_dict\n\n\n\nclass MultiScaleEPE_PWC(nn.Module):\n    def __init__(self,\n                 args):\n\n        super(MultiScaleEPE_PWC, self).__init__()\n        self._args = args\n        self._batch_size = args.batch_size\n        self._weights = [0.32, 0.08, 0.02, 0.01, 0.005]\n\n    def forward(self, output_dict, target_dict):\n        loss_dict = {}\n\n        if self.training:\n            outputs = output_dict['flow']\n\n            # div_flow trick\n            target = self._args.model_div_flow * target_dict[\"target1\"]\n\n            total_loss = 0\n            for ii, output_ii in enumerate(outputs):\n                loss_ii = _elementwise_epe(output_ii, _downsample2d_as(target, output_ii)).sum()\n                total_loss = total_loss + self._weights[ii] * loss_ii\n            loss_dict[\"total_loss\"] = total_loss / self._batch_size\n\n        
else:\n            epe = _elementwise_epe(output_dict[\"flow\"], target_dict[\"target1\"])\n            loss_dict[\"epe\"] = epe.mean()\n\n        return loss_dict\n\nclass MultiScaleEPE_PWC_Bi(nn.Module):\n    def __init__(self,\n                 args):\n\n        super(MultiScaleEPE_PWC_Bi, self).__init__()\n        self._args = args\n        self._batch_size = args.batch_size\n        self._weights = [0.32, 0.08, 0.02, 0.01, 0.005]\n\n    def forward(self, output_dict, target_dict):\n        loss_dict = {}\n\n        if self.training:\n            outputs = output_dict['flow']\n\n            # div_flow trick\n            target_f = self._args.model_div_flow * target_dict[\"target1\"]\n            target_b = self._args.model_div_flow * target_dict[\"target2\"]\n\n            total_loss = 0\n            for i, output_i in enumerate(outputs):\n                epe_i_f = _elementwise_epe(output_i[0], _downsample2d_as(target_f, output_i[0]))\n                epe_i_b = _elementwise_epe(output_i[1], _downsample2d_as(target_b, output_i[1]))\n                total_loss = total_loss + self._weights[i] * (epe_i_f.sum() + epe_i_b.sum())\n            loss_dict[\"total_loss\"] = total_loss / (2 * self._batch_size)\n        else:\n            epe = _elementwise_epe(output_dict[\"flow\"], target_dict[\"target1\"])\n            loss_dict[\"epe\"] = epe.mean()\n\n        return loss_dict\n\nclass MultiScaleEPE_PWC_Occ(nn.Module):\n    def __init__(self,\n                 args):\n\n        super(MultiScaleEPE_PWC_Occ, self).__init__()\n        self._args = args\n        self._batch_size = args.batch_size\n        self._weights = [0.32, 0.08, 0.02, 0.01, 0.005]\n\n        self.occ_activ = nn.Sigmoid()\n        self.f1_score_bal_loss = f1_score_bal_loss\n\n    def forward(self, output_dict, target_dict):\n        loss_dict = {}\n\n        if self.training:\n            output_flo = output_dict['flow']\n            output_occ = output_dict['occ']\n\n            # div_flow trick\n      
      target_flo = self._args.model_div_flow * target_dict[\"target1\"]\n            target_occ = target_dict[\"target_occ1\"]\n\n            flow_loss = 0\n            occ_loss = 0\n\n            for i, output_i in enumerate(output_flo):\n                flow_loss = flow_loss + self._weights[i] * _elementwise_epe(output_i, _downsample2d_as(target_flo, output_i)).sum()\n\n            for i, output_i in enumerate(output_occ):\n                output_occ = self.occ_activ(output_i)\n                occ_loss = occ_loss + self._weights[i] * self.f1_score_bal_loss(output_occ, _downsample2d_as(target_occ, output_occ))\n\n            f_loss = flow_loss.detach()\n            o_loss = occ_loss.detach()\n            if f_loss > o_loss:\n                f_l_w = 1\n                o_l_w = f_loss / o_loss\n            else:\n                f_l_w = o_loss / f_loss\n                o_l_w = 1\n\n            loss_dict[\"flow_loss\"] = flow_loss / self._batch_size\n            loss_dict[\"occ_loss\"] = occ_loss / self._batch_size\n            loss_dict[\"total_loss\"] = (flow_loss * f_l_w + occ_loss * o_l_w) / self._batch_size\n\n        else:\n            loss_dict[\"epe\"] = _elementwise_epe(output_dict[\"flow\"], target_dict[\"target1\"]).mean()\n            loss_dict[\"F1\"] = f1_score(target_dict[\"target_occ1\"], torch.round(self.occ_activ(output_dict[\"occ\"])))\n\n        return loss_dict\n\nclass MultiScaleEPE_PWC_Bi_Occ(nn.Module):\n    def __init__(self,\n                 args):\n\n        super(MultiScaleEPE_PWC_Bi_Occ, self).__init__()\n        self._args = args\n        self._batch_size = args.batch_size        \n        self._weights = [0.32, 0.08, 0.02, 0.01, 0.005]\n\n        self.occ_activ = nn.Sigmoid()\n        self.f1_score_bal_loss = f1_score_bal_loss\n\n    def forward(self, output_dict, target_dict):\n        loss_dict = {}\n\n        if self.training:\n            output_flo = output_dict['flow']\n            output_occ = output_dict['occ']\n\n            # 
div_flow trick\n            target_flo_f = self._args.model_div_flow * target_dict[\"target1\"]\n            target_flo_b = self._args.model_div_flow * target_dict[\"target2\"]\n            target_occ_f = target_dict[\"target_occ1\"]\n            target_occ_b = target_dict[\"target_occ2\"]\n\n            # bchw\n            flow_loss = 0\n            occ_loss = 0\n\n            for i, output_i in enumerate(output_flo):\n                flow_loss = flow_loss + self._weights[i] * _elementwise_epe(output_i[0], _downsample2d_as(target_flo_f, output_i[0])).sum()\n                flow_loss = flow_loss + self._weights[i] * _elementwise_epe(output_i[1], _downsample2d_as(target_flo_b, output_i[1])).sum()\n\n            for i, output_i in enumerate(output_occ):\n                output_occ_f = self.occ_activ(output_i[0])\n                output_occ_b = self.occ_activ(output_i[1])\n                occ_loss = occ_loss + self._weights[i] * self.f1_score_bal_loss(output_occ_f, _downsample2d_as(target_occ_f, output_occ_f))\n                occ_loss = occ_loss + self._weights[i] * self.f1_score_bal_loss(output_occ_b, _downsample2d_as(target_occ_b, output_occ_b))\n\n            f_loss = flow_loss.detach()\n            o_loss = occ_loss.detach()\n            if f_loss > o_loss:\n                f_l_w = 1\n                o_l_w = f_loss / o_loss\n            else:\n                f_l_w = o_loss / f_loss\n                o_l_w = 1\n\n            loss_dict[\"flow_loss\"] = flow_loss / (2 * self._batch_size)\n            loss_dict[\"occ_loss\"] = occ_loss / (2 * self._batch_size) \n            loss_dict[\"total_loss\"] = (flow_loss * f_l_w + occ_loss * o_l_w) / (2 * self._batch_size)\n\n        else:\n            loss_dict[\"epe\"] = _elementwise_epe(output_dict[\"flow\"], target_dict[\"target1\"]).mean()\n            loss_dict[\"F1\"] = f1_score(target_dict[\"target_occ1\"], torch.round(self.occ_activ(output_dict[\"occ\"])))\n\n        return loss_dict\n\nclass 
MultiScaleEPE_PWC_Bi_Occ_upsample(nn.Module):\n    def __init__(self,\n                 args):\n\n        super(MultiScaleEPE_PWC_Bi_Occ_upsample, self).__init__()\n        self._args = args\n        self._batch_size = args.batch_size\n        self._weights = [0.32, 0.08, 0.02, 0.01, 0.005, 0.00125, 0.0003125]\n\n        self.occ_activ = nn.Sigmoid()\n        self.f1_score_bal_loss = f1_score_bal_loss\n\n    def forward(self, output_dict, target_dict):\n        loss_dict = {}\n\n        if self.training:\n            output_flo = output_dict['flow']\n            output_occ = output_dict['occ']\n\n            # div_flow trick\n            target_flo_f = self._args.model_div_flow * target_dict[\"target1\"]\n            target_flo_b = self._args.model_div_flow * target_dict[\"target2\"]\n            target_occ_f = target_dict[\"target_occ1\"]\n            target_occ_b = target_dict[\"target_occ2\"]\n\n            # bchw\n            flow_loss = 0\n            occ_loss = 0\n\n            for ii, output_ii in enumerate(output_flo):\n                loss_ii = 0\n                for jj in range(0, len(output_ii) // 2):\n                    loss_ii = loss_ii + _elementwise_epe(output_ii[2 * jj], _downsample2d_as(target_flo_f, output_ii[2 * jj])).sum()\n                    loss_ii = loss_ii + _elementwise_epe(output_ii[2 * jj + 1], _downsample2d_as(target_flo_b, output_ii[2 * jj + 1])).sum()\n                flow_loss = flow_loss + self._weights[ii] * loss_ii / len(output_ii)\n\n            for ii, output_ii in enumerate(output_occ):\n                loss_ii = 0\n                for jj in range(0, len(output_ii) // 2):\n                    output_occ_f = self.occ_activ(output_ii[2 * jj])\n                    output_occ_b = self.occ_activ(output_ii[2 * jj + 1])\n                    loss_ii = loss_ii + self.f1_score_bal_loss(output_occ_f, _downsample2d_as(target_occ_f, output_occ_f))\n                    loss_ii = loss_ii + self.f1_score_bal_loss(output_occ_b, 
_downsample2d_as(target_occ_b, output_occ_b))\n                occ_loss = occ_loss + self._weights[ii] * loss_ii / len(output_ii)\n\n            f_loss = flow_loss.detach()\n            o_loss = occ_loss.detach()\n            if f_loss > o_loss:\n                f_l_w = 1\n                o_l_w = f_loss / o_loss\n            else:\n                f_l_w = o_loss / f_loss\n                o_l_w = 1\n\n            loss_dict[\"flow_loss\"] = flow_loss / self._batch_size\n            loss_dict[\"occ_loss\"] = occ_loss / self._batch_size\n            loss_dict[\"total_loss\"] = (flow_loss * f_l_w + occ_loss * o_l_w) / self._batch_size\n\n        else:\n            loss_dict[\"epe\"] = _elementwise_epe(output_dict[\"flow\"], target_dict[\"target1\"]).mean()\n            loss_dict[\"F1\"] = f1_score(target_dict[\"target_occ1\"], torch.round(self.occ_activ(output_dict[\"occ\"])))\n\n        return loss_dict\n\nclass MultiScaleEPE_PWC_Bi_Occ_upsample_Sintel(nn.Module):\n    def __init__(self,\n                 args):\n\n        super(MultiScaleEPE_PWC_Bi_Occ_upsample_Sintel, self).__init__()\n        self._args = args\n        self._batch_size = args.batch_size        \n        self._weights = [0.32, 0.08, 0.02, 0.01, 0.005, 0.00125, 0.0003125]\n\n        self.occ_activ = nn.Sigmoid()\n        self.occ_loss_bce = nn.BCELoss(reduction='sum')\n\n    def forward(self, output_dict, target_dict):\n        loss_dict = {}\n\n        if self.training:\n            output_flo = output_dict['flow']\n            output_occ = output_dict['occ']\n\n            # div_flow trick\n            target_flo_f = self._args.model_div_flow * target_dict[\"target1\"]\n            target_occ_f = target_dict[\"target_occ1\"]\n\n            # bchw\n            flow_loss = 0\n            occ_loss = 0\n\n            for ii, output_ii in enumerate(output_flo):\n                loss_ii = 0\n                for jj in range(0, len(output_ii) // 2):\n                    loss_ii = loss_ii + 
_elementwise_robust_epe_char(output_ii[2 * jj], _downsample2d_as(target_flo_f, output_ii[2 * jj])).sum()\n                    output_ii[2 * jj + 1] = output_ii[2 * jj + 1].detach()\n                flow_loss = flow_loss + self._weights[ii] * loss_ii / len(output_ii) * 2\n\n            for ii, output_ii in enumerate(output_occ):\n                loss_ii = 0\n                for jj in range(0, len(output_ii) // 2):\n                    output_occ_f = self.occ_activ(output_ii[2 * jj])\n                    output_ii[2 * jj + 1] = output_ii[2 * jj + 1].detach()\n                    loss_ii = loss_ii + self.occ_loss_bce(output_occ_f, _downsample2d_as(target_occ_f, output_occ_f))\n                occ_loss = occ_loss + self._weights[ii] * loss_ii / len(output_ii) * 2\n\n            f_loss = flow_loss.detach()\n            o_loss = occ_loss.detach()\n            if f_loss > o_loss:\n                f_l_w = 1\n                o_l_w = f_loss / o_loss\n            else:\n                f_l_w = o_loss / f_loss\n                o_l_w = 1\n\n            loss_dict[\"flow_loss\"] = flow_loss / self._batch_size\n            loss_dict[\"occ_loss\"] = occ_loss / self._batch_size\n            loss_dict[\"total_loss\"] = (flow_loss * f_l_w + occ_loss * o_l_w) / self._batch_size\n\n        else:\n            loss_dict[\"epe\"] = _elementwise_epe(output_dict[\"flow\"], target_dict[\"target1\"]).mean()\n            loss_dict[\"F1\"] = f1_score(target_dict[\"target_occ1\"], torch.round(self.occ_activ(output_dict[\"occ\"])))\n\n        return loss_dict\n\nclass MultiScaleEPE_PWC_Bi_Occ_upsample_KITTI(nn.Module):\n    def __init__(self,\n                 args):\n\n        super(MultiScaleEPE_PWC_Bi_Occ_upsample_KITTI, self).__init__()\n        self._args = args\n        self._batch_size = args.batch_size\n        self._weights = [0.001, 0.001, 0.001, 0.002, 0.004, 0.004, 0.004]\n\n        self.occ_activ = nn.Sigmoid()\n        \n    def forward(self, output_dict, target_dict):\n        
loss_dict = {}\n\n        valid_mask = target_dict[\"input_valid\"]\n        b, _, h, w = target_dict[\"target1\"].size()\n\n        if self.training:\n            output_flo = output_dict['flow']\n            output_occ = output_dict['occ']\n\n            # div_flow trick\n            target_flo_f = self._args.model_div_flow * target_dict[\"target1\"]\n\n            # bchw\n            flow_loss = 0\n\n            for ii, output_ii in enumerate(output_flo):\n                loss_ii = 0\n                for jj in range(0, len(output_ii) // 2):\n                    valid_epe = _elementwise_robust_epe_char(_upsample2d_as(output_ii[2 * jj], target_flo_f), target_flo_f) * valid_mask\n\n                    for bb in range(0, b):\n                        valid_epe[bb, ...][valid_mask[bb, ...] == 0] = valid_epe[bb, ...][valid_mask[bb, ...] == 0].detach()\n                        norm_const = h * w / (valid_mask[bb, ...].sum())\n                        loss_ii = loss_ii + valid_epe[bb, ...][valid_mask[bb, ...] 
!= 0].sum() * norm_const\n\n                    output_ii[2 * jj + 1] = output_ii[2 * jj + 1].detach()\n                flow_loss = flow_loss + self._weights[ii] * loss_ii / len(output_ii) * 2\n\n            for ii, output_ii in enumerate(output_occ):\n                for jj in range(0, len(output_ii) // 2):\n                    output_ii[2 * jj] = output_ii[2 * jj].detach()\n                    output_ii[2 * jj + 1] = output_ii[2 * jj + 1].detach()\n\n            loss_dict[\"flow_loss\"] = flow_loss / self._batch_size\n            loss_dict[\"total_loss\"] = flow_loss / self._batch_size\n\n        else:\n            flow_gt_mag = torch.norm(target_dict[\"target1\"], p=2, dim=1, keepdim=True) + 1e-8\n            flow_epe = _elementwise_epe(output_dict[\"flow\"], target_dict[\"target1\"]) * valid_mask\n\n            epe_per_image = (flow_epe.view(b, -1).sum(1)) / (valid_mask.view(b, -1).sum(1))\n            loss_dict[\"epe\"] = epe_per_image.mean()\n\n            outlier_epe = (flow_epe > 3).float() * ((flow_epe / flow_gt_mag) > 0.05).float() * valid_mask\n            outlier_per_image = (outlier_epe.view(b, -1).sum(1)) / (valid_mask.view(b, -1).sum(1))\n            loss_dict[\"outlier\"] = outlier_per_image.mean()\n\n        return loss_dict\n\n\n\n\n\n\n\n"
  },
  {
    "path": "main.py",
    "content": "from __future__ import absolute_import, division, print_function\n\nimport os\nimport subprocess\nimport commandline\nimport configuration as config\nimport runtime\nimport logger\nimport logging\nimport tools\nimport torch\n\n\ndef main():\n\n    # Change working directory    \n    os.chdir(os.path.dirname(os.path.realpath(__file__)))\n\n    # Parse commandline arguments    \n    args = commandline.setup_logging_and_parse_arguments(blocktitle=\"Commandline Arguments\")\n\n    # Set random seed, possibly on Cuda    \n    config.configure_random_seed(args)    \n\n    # DataLoader\n    train_loader, validation_loader, inference_loader = config.configure_data_loaders(args)\n    success = any(loader is not None for loader in [train_loader, validation_loader, inference_loader])\n    if not success:\n        logging.info(\"No dataset could be loaded successfully. Please check dataset paths!\")\n        quit()\n\n    # Configure data augmentation\n    training_augmentation, validation_augmentation = config.configure_runtime_augmentations(args)\n\n    # Configure model and loss    \n    model_and_loss = config.configure_model_and_loss(args)\n\n    # Resume from checkpoint if available    \n    checkpoint_saver, checkpoint_stats = config.configure_checkpoint_saver(args, model_and_loss)\n\n    # Checkpoint and save directory    \n    with logger.LoggingBlock(\"Save Directory\", emph=True):\n        logging.info(\"Save directory: %s\" % args.save)\n        if not os.path.exists(args.save):\n            os.makedirs(args.save)\n\n    # # Multi-GPU automation    \n    # with logger.LoggingBlock(\"Multi GPU\", emph=True):\n    #     if torch.cuda.device_count() > 1:\n    #         logging.info(\"Let's use %d GPUs!\" % torch.cuda.device_count())\n    #         model_and_loss._model = torch.nn.DataParallel(model_and_loss._model)\n    #     else:\n    #         logging.info(\"Let's use %d GPU!\" % torch.cuda.device_count())\n\n    \n    # Configure optimizer    \n    
optimizer = config.configure_optimizer(args, model_and_loss)\n    \n    # Configure learning rate    \n    lr_scheduler = config.configure_lr_scheduler(args, optimizer)\n\n    # If this is just an evaluation: overwrite savers and epochs\n    if args.evaluation:\n        args.start_epoch = 1\n        args.total_epochs = 1\n        train_loader = None\n        checkpoint_saver = None\n        optimizer = None\n        lr_scheduler = None\n\n    # Cuda optimization    \n    if args.cuda:\n        torch.backends.cudnn.benchmark = True\n\n    # Kickoff training, validation and/or testing    \n    return runtime.exec_runtime(\n        args,\n        checkpoint_saver=checkpoint_saver,\n        model_and_loss=model_and_loss,\n        optimizer=optimizer,\n        lr_scheduler=lr_scheduler,\n        train_loader=train_loader,\n        validation_loader=validation_loader,\n        inference_loader=inference_loader,\n        training_augmentation=training_augmentation,\n        validation_augmentation=validation_augmentation)\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "models/IRR_FlowNet.py",
    "content": "from __future__ import absolute_import, division, print_function\n\nimport torch\nimport torch.nn as nn\nfrom .flownet_modules import conv, deconv\nfrom .flownet_modules import concatenate_as, upsample2d_as\nfrom .flownet_modules import initialize_msra\nfrom .flownet_modules import WarpingLayer\nfrom .irr_modules import OccUpsampleNetwork, RefineFlow, RefineOcc\n\nclass FlowNetS(nn.Module):\n    def __init__(self, args):\n        super(FlowNetS, self).__init__()\n\n        def make_conv(in_planes, out_planes, kernel_size, stride):\n            pad = kernel_size // 2\n            return conv(in_planes, out_planes, kernel_size=kernel_size,\n                        stride=stride, pad=pad, nonlinear=True, bias=True)\n\n        self._conv3_1 = make_conv( 256,  256, kernel_size=3, stride=1)\n        self._conv4   = make_conv( 256,  512, kernel_size=3, stride=2)\n        self._conv4_1 = make_conv( 512,  512, kernel_size=3, stride=1)\n        self._conv5   = make_conv( 512,  512, kernel_size=3, stride=2)\n        self._conv5_1 = make_conv( 512,  512, kernel_size=3, stride=1)\n        self._conv6   = make_conv( 512, 1024, kernel_size=3, stride=2)\n        self._conv6_1 = make_conv(1024, 1024, kernel_size=3, stride=1)\n\n        def make_deconv(in_planes, out_planes):\n            return deconv(in_planes, out_planes, kernel_size=4, stride=2, pad=1,\n                          nonlinear=True, bias=False)\n\n        self._deconv5 = make_deconv(1024    , 512)\n        self._deconv4 = make_deconv(1024 + 2, 256)\n        self._deconv3 = make_deconv( 768 + 2, 128)\n        self._deconv2 = make_deconv( 384 + 2,  64)\n\n        self._deconv_occ5 = make_deconv(1024    , 512)\n        self._deconv_occ4 = make_deconv(1024 + 1, 256)\n        self._deconv_occ3 = make_deconv( 768 + 1, 128)\n        self._deconv_occ2 = make_deconv( 384 + 1,  64)\n\n        def make_predict(in_planes, out_planes):\n            return conv(in_planes, out_planes, kernel_size=3, stride=1, 
pad=1,\n                        nonlinear=False, bias=True)\n\n        self._predict_flow6 = make_predict(1024    , 2)\n        self._predict_flow5 = make_predict(1024 + 2, 2)\n        self._predict_flow4 = make_predict( 768 + 2, 2)\n        self._predict_flow3 = make_predict( 384 + 2, 2)\n        self._predict_flow2 = make_predict( 128 + 2, 2)\n\n        self._predict_occ6 = make_predict(1024    , 1)\n        self._predict_occ5 = make_predict(1024 + 1, 1)\n        self._predict_occ4 = make_predict( 768 + 1, 1)\n        self._predict_occ3 = make_predict( 384 + 1, 1)\n        self._predict_occ2 = make_predict( 128 + 1, 1)\n\n        def make_upsample(in_planes, out_planes):\n            return deconv(in_planes, out_planes, kernel_size=4, stride=2, pad=1,\n                          nonlinear=False, bias=False)\n\n        self._upsample_flow6_to_5 = make_upsample(2, 2)\n        self._upsample_flow5_to_4 = make_upsample(2, 2)\n        self._upsample_flow4_to_3 = make_upsample(2, 2)\n        self._upsample_flow3_to_2 = make_upsample(2, 2)\n\n        self._upsample_occ6_to_5 = make_upsample(1, 1)\n        self._upsample_occ5_to_4 = make_upsample(1, 1)\n        self._upsample_occ4_to_3 = make_upsample(1, 1)\n        self._upsample_occ3_to_2 = make_upsample(1, 1)\n\n    def forward(self, conv2_im1, conv3_im1, conv3_im2):\n\n        conv_concat3 = torch.cat((conv3_im1, conv3_im2), dim=1)\n\n        conv3_1 = self._conv3_1(conv_concat3)\n        conv4_1 = self._conv4_1(self._conv4(conv3_1))\n        conv5_1 = self._conv5_1(self._conv5(conv4_1))\n        conv6_1 = self._conv6_1(self._conv6(conv5_1))\n\n        # Flow Decoder\n        predict_flow6        = self._predict_flow6(conv6_1)\n\n        upsampled_flow6_to_5 = self._upsample_flow6_to_5(predict_flow6)\n        deconv5              = self._deconv5(conv6_1)\n        concat5              = concatenate_as((conv5_1, deconv5, upsampled_flow6_to_5), conv5_1, dim=1)\n        predict_flow5        = 
self._predict_flow5(concat5)\n\n        upsampled_flow5_to_4 = self._upsample_flow5_to_4(predict_flow5)\n        deconv4              = self._deconv4(concat5)\n        concat4              = concatenate_as((conv4_1, deconv4, upsampled_flow5_to_4), conv4_1, dim=1)\n        predict_flow4        = self._predict_flow4(concat4)\n\n        upsampled_flow4_to_3 = self._upsample_flow4_to_3(predict_flow4)\n        deconv3              = self._deconv3(concat4)\n        concat3              = concatenate_as((conv3_1, deconv3, upsampled_flow4_to_3), conv3_1, dim=1)\n        predict_flow3        = self._predict_flow3(concat3)\n\n        upsampled_flow3_to_2 = self._upsample_flow3_to_2(predict_flow3)\n        deconv2              = self._deconv2(concat3)\n        concat2              = concatenate_as((conv2_im1, deconv2, upsampled_flow3_to_2), conv2_im1, dim=1)\n        predict_flow2        = self._predict_flow2(concat2)\n\n        # Occ Decoder\n        predict_occ6 = self._predict_occ6(conv6_1)\n\n        upsampled_occ6_to_5 = self._upsample_occ6_to_5(predict_occ6)\n        deconv_occ5         = self._deconv_occ5(conv6_1)\n        concat_occ5         = concatenate_as((conv5_1, deconv_occ5, upsampled_occ6_to_5), conv5_1, dim=1)\n        predict_occ5        = self._predict_occ5(concat_occ5)\n\n        upsampled_occ5_to_4 = self._upsample_occ5_to_4(predict_occ5)\n        deconv_occ4         = self._deconv_occ4(concat_occ5)\n        concat_occ4         = concatenate_as((conv4_1, deconv_occ4, upsampled_occ5_to_4), conv4_1, dim=1)\n        predict_occ4        = self._predict_occ4(concat_occ4)\n\n        upsampled_occ4_to_3 = self._upsample_occ4_to_3(predict_occ4)\n        deconv_occ3         = self._deconv_occ3(concat_occ4)\n        concat_occ3         = concatenate_as((conv3_1, deconv_occ3, upsampled_occ4_to_3), conv3_1, dim=1)\n        predict_occ3        = self._predict_occ3(concat_occ3)\n\n        upsampled_occ3_to_2 = self._upsample_occ3_to_2(predict_occ3)\n        deconv_occ2  
       = self._deconv_occ2(concat_occ3)\n        concat_occ2         = concatenate_as((conv2_im1, deconv_occ2, upsampled_occ3_to_2), conv2_im1, dim=1)\n        predict_occ2        = self._predict_occ2(concat_occ2)\n\n        return predict_flow2, predict_flow3, predict_flow4, predict_flow5, predict_flow6, predict_occ2, predict_occ3, predict_occ4, predict_occ5, predict_occ6\n\n\nclass FlowNet1S(nn.Module):\n    def __init__(self, args, div_flow=0.05):\n        super(FlowNet1S, self).__init__()\n        self._flownets = FlowNetS(args)\n        self._warping_layer = WarpingLayer()\n        self._div_flow = div_flow\n        self._num_iters = args.num_iters\n\n        def make_conv(in_planes, out_planes, kernel_size, stride):\n            pad = kernel_size // 2\n            return conv(in_planes, out_planes, kernel_size=kernel_size,\n                        stride=stride, pad=pad, nonlinear=True, bias=True)\n\n        self._conv1   = make_conv(   3,   32, kernel_size=7, stride=2)\n        self._conv2   = make_conv(  32,   64, kernel_size=5, stride=2)\n        self._conv3   = make_conv(  64,  128, kernel_size=5, stride=2)\n\n        self.occ_shuffle_upsample = OccUpsampleNetwork(11, 1)\n        self.refine_flow = RefineFlow(2 + 1 + 64)\n        self.refine_occ = RefineOcc(1 + 64 + 64)\n\n        initialize_msra(self.modules())\n\n    def forward(self, input_dict):\n        im1 = input_dict['input1']\n        im2 = input_dict['input2']\n\n        conv1_im1 = self._conv1(im1)\n        conv2_im1 = self._conv2(conv1_im1)\n        conv3_im1 = self._conv3(conv2_im1)\n        conv3_im1_wp = conv3_im1\n\n        conv1_im2 = self._conv1(im2)\n        conv2_im2 = self._conv2(conv1_im2)\n        conv3_im2 = self._conv3(conv2_im2)\n        conv3_im2_wp = conv3_im2\n\n        out_dict = {}\n        out_dict['flow'] = []\n        out_dict['flow1'] = []\n        out_dict['flow2'] = []\n        out_dict['flow3'] = []\n        out_dict['flow4'] = []\n        out_dict['flow5'] = []\n     
   out_dict['flow6'] = []\n        out_dict['occ'] = []\n        out_dict['occ1'] = []\n        out_dict['occ2'] = []\n        out_dict['occ3'] = []\n        out_dict['occ4'] = []\n        out_dict['occ5'] = []\n        out_dict['occ6'] = []\n\n        # warping:\n        _, _, height_im, width_im = im1.size()\n\n        # for iterative\n        for ii in range(0, self._num_iters):\n            flo2_f, flo3_f, flo4_f, flo5_f, flo6_f, occ2_f, occ3_f, occ4_f, occ5_f, occ6_f = self._flownets(conv2_im1,\n                                                                                                            conv3_im1,\n                                                                                                            conv3_im2_wp)\n            flo2_b, flo3_b, flo4_b, flo5_b, flo6_b, occ2_b, occ3_b, occ4_b, occ5_b, occ6_b = self._flownets(conv2_im2,\n                                                                                                            conv3_im2,\n                                                                                                            conv3_im1_wp)\n\n            if ii == 0:\n                out_dict['flow2'].append([flo2_f, flo2_b])\n                out_dict['flow3'].append([flo3_f, flo3_b])\n                out_dict['flow4'].append([flo4_f, flo4_b])\n                out_dict['flow5'].append([flo5_f, flo5_b])\n                out_dict['flow6'].append([flo6_f, flo6_b])\n                out_dict['occ2'].append([occ2_f, occ2_b])\n                out_dict['occ3'].append([occ3_f, occ3_b])\n                out_dict['occ4'].append([occ4_f, occ4_b])\n                out_dict['occ5'].append([occ5_f, occ5_b])\n                out_dict['occ6'].append([occ6_f, occ6_b])\n                flo2_f_out = flo2_f\n                flo2_b_out = flo2_b\n                occ2_f_out = occ2_f\n                occ2_b_out = occ2_b\n            else:\n                out_dict['flow2'].append([flo2_f + out_dict['flow2'][ii - 1][0], flo2_b + 
out_dict['flow2'][ii - 1][1]])\n                out_dict['flow3'].append([flo3_f + out_dict['flow3'][ii - 1][0], flo3_b + out_dict['flow3'][ii - 1][1]])\n                out_dict['flow4'].append([flo4_f + out_dict['flow4'][ii - 1][0], flo4_b + out_dict['flow4'][ii - 1][1]])\n                out_dict['flow5'].append([flo5_f + out_dict['flow5'][ii - 1][0], flo5_b + out_dict['flow5'][ii - 1][1]])\n                out_dict['flow6'].append([flo6_f + out_dict['flow6'][ii - 1][0], flo6_b + out_dict['flow6'][ii - 1][1]])\n                out_dict['occ2'].append([occ2_f + out_dict['occ2'][ii - 1][0], occ2_b + out_dict['occ2'][ii - 1][1]])\n                out_dict['occ3'].append([occ3_f + out_dict['occ3'][ii - 1][0], occ3_b + out_dict['occ3'][ii - 1][1]])\n                out_dict['occ4'].append([occ4_f + out_dict['occ4'][ii - 1][0], occ4_b + out_dict['occ4'][ii - 1][1]])\n                out_dict['occ5'].append([occ5_f + out_dict['occ5'][ii - 1][0], occ5_b + out_dict['occ5'][ii - 1][1]])\n                out_dict['occ6'].append([occ6_f + out_dict['occ6'][ii - 1][0], occ6_b + out_dict['occ6'][ii - 1][1]])\n                flo2_f_out = flo2_f + upsample2d_as(out_dict['flow1'][ii - 1][0], flo2_f, mode=\"bilinear\")\n                flo2_b_out = flo2_b + upsample2d_as(out_dict['flow1'][ii - 1][1], flo2_b, mode=\"bilinear\")\n                occ2_f_out = occ2_f + upsample2d_as(out_dict['occ1'][ii - 1][0], occ2_f, mode=\"bilinear\")\n                occ2_b_out = occ2_b + upsample2d_as(out_dict['occ1'][ii - 1][1], occ2_b, mode=\"bilinear\")\n\n            ## refine layer\n            flo2_f_out = upsample2d_as(flo2_f_out, conv2_im1, mode=\"bilinear\")  \n            flo2_b_out = upsample2d_as(flo2_b_out, conv2_im2, mode=\"bilinear\") \n            occ2_f_out = upsample2d_as(occ2_f_out, conv2_im1, mode=\"bilinear\")\n            occ2_b_out = upsample2d_as(occ2_b_out, conv2_im2, mode=\"bilinear\")\n\n            img1_resize = upsample2d_as(im1, flo2_f_out, mode=\"bilinear\")\n      
      img2_resize = upsample2d_as(im2, flo2_b_out, mode=\"bilinear\")\n            img2_warp = self._warping_layer(img2_resize, flo2_f_out, height_im, width_im, self._div_flow)\n            img1_warp = self._warping_layer(img1_resize, flo2_b_out, height_im, width_im, self._div_flow)\n\n            # flow refine\n            flow_f = self.refine_flow(flo2_f_out.detach(), img1_resize - img2_warp, conv2_im1)\n            flow_b = self.refine_flow(flo2_b_out.detach(), img2_resize - img1_warp, conv2_im2)\n\n            # occ refine\n            conv2_im2_warp = self._warping_layer(conv2_im2, flow_f, height_im, width_im, self._div_flow)\n            conv2_im1_warp = self._warping_layer(conv2_im1, flow_b, height_im, width_im, self._div_flow)\n\n            occ_f = self.refine_occ(occ2_f_out.detach(), conv2_im1, conv2_im1 - conv2_im2_warp)\n            occ_b = self.refine_occ(occ2_b_out.detach(), conv2_im2, conv2_im2 - conv2_im1_warp)\n            out_dict['flow1'].append([flow_f, flow_b])\n            out_dict['occ1'].append([occ_f, occ_b])\n\n            ## upsample layer\n            flow_f = upsample2d_as(flow_f, im1, mode=\"bilinear\")\n            flow_b = upsample2d_as(flow_b, im2, mode=\"bilinear\")\n            out_dict['flow'].append([flow_f, flow_b])\n\n            im2_warp = self._warping_layer(im2, flow_f, height_im, width_im, self._div_flow)\n            im1_warp = self._warping_layer(im1, flow_b, height_im, width_im, self._div_flow)\n            flow_b_warp = self._warping_layer(flow_b, flow_f, height_im, width_im, self._div_flow)\n            flow_f_warp = self._warping_layer(flow_f, flow_b, height_im, width_im, self._div_flow)\n\n            occ_f = self.occ_shuffle_upsample(occ_f, torch.cat([im1, im2_warp, flow_f, flow_b_warp], dim=1))\n            occ_b = self.occ_shuffle_upsample(occ_b, torch.cat([im2, im1_warp, flow_b, flow_f_warp], dim=1))\n\n            out_dict['occ'].append([occ_f, occ_b])\n\n            if ii < (self._num_iters - 1):\n             
   flow_f_resized = upsample2d_as(flow_f, conv3_im2, mode=\"bilinear\")\n                flow_b_resized = upsample2d_as(flow_b, conv3_im1, mode=\"bilinear\")\n                conv3_im2_wp = self._warping_layer(conv3_im2, flow_f_resized, height_im, width_im, self._div_flow)\n                conv3_im1_wp = self._warping_layer(conv3_im1, flow_b_resized, height_im, width_im, self._div_flow)\n                \n        if self.training:\n            return out_dict\n        else:\n            out_dict_eval = {}\n            out_dict_eval['flow'] = upsample2d_as(out_dict['flow'][self._num_iters - 1][0], im1, mode=\"bilinear\") / self._div_flow\n            out_dict_eval['occ'] = upsample2d_as(out_dict['occ'][self._num_iters - 1][0], im1, mode=\"bilinear\")\n            return out_dict_eval\n"
  },
  {
    "path": "models/IRR_PWC.py",
    "content": "from __future__ import absolute_import, division, print_function\n\nimport torch\nimport torch.nn as nn\n\nfrom .pwc_modules import conv, upsample2d_as, rescale_flow, initialize_msra, compute_cost_volume\nfrom .pwc_modules import WarpingLayer, FeatureExtractor, ContextNetwork, FlowEstimatorDense, OccContextNetwork, OccEstimatorDense\nfrom .irr_modules import OccUpsampleNetwork, RefineFlow, RefineOcc\nimport copy\n\n\n\n\nclass PWCNet(nn.Module):\n    def __init__(self, args, div_flow=0.05):\n        super(PWCNet, self).__init__()\n        self.args = args\n        self._div_flow = div_flow\n        self.search_range = 4\n        self.num_chs = [3, 16, 32, 64, 96, 128, 196]\n        self.output_level = 4\n        self.num_levels = 7\n        self.leakyRELU = nn.LeakyReLU(0.1, inplace=True)\n\n        self.feature_pyramid_extractor = FeatureExtractor(self.num_chs)\n        self.warping_layer = WarpingLayer()\n\n        self.dim_corr = (self.search_range * 2 + 1) ** 2\n        self.num_ch_in_flo = self.dim_corr + 32 + 2\n        self.num_ch_in_occ = self.dim_corr + 32 + 1\n\n        self.flow_estimators = FlowEstimatorDense(self.num_ch_in_flo)\n        self.context_networks = ContextNetwork(self.num_ch_in_flo + 448 + 2)\n        self.occ_estimators = OccEstimatorDense(self.num_ch_in_occ)\n        self.occ_context_networks = OccContextNetwork(self.num_ch_in_occ + 448 + 1)\n        self.occ_shuffle_upsample = OccUpsampleNetwork(11, 1)\n\n        self.conv_1x1 = nn.ModuleList([conv(196, 32, kernel_size=1, stride=1, dilation=1),\n                                       conv(128, 32, kernel_size=1, stride=1, dilation=1),\n                                       conv(96, 32, kernel_size=1, stride=1, dilation=1),\n                                       conv(64, 32, kernel_size=1, stride=1, dilation=1)])\n\n        self.conv_1x1_1 = conv(16, 3, kernel_size=1, stride=1, dilation=1)\n\n        self.refine_flow = RefineFlow(2 + 1 + 32)\n        self.refine_occ = 
RefineOcc(1 + 32 + 32)\n        self.corr_params = {\"pad_size\": self.search_range, \"kernel_size\": 1, \"max_disp\": self.search_range, \"stride1\": 1, \"stride2\": 1, \"corr_multiply\": 1}\n\n        initialize_msra(self.modules())\n\n    def forward(self, input_dict):\n\n        x1_raw = input_dict['input1']\n        x2_raw = input_dict['input2']\n        batch_size, _, height_im, width_im = x1_raw.size()\n\n        # on the bottom level are original images\n        x1_pyramid = self.feature_pyramid_extractor(x1_raw) + [x1_raw]\n        x2_pyramid = self.feature_pyramid_extractor(x2_raw) + [x2_raw]\n\n        # outputs\n        output_dict = {}\n        output_dict_eval = {}\n        flows = []\n        occs = []\n\n        _, _, h_x1, w_x1, = x1_pyramid[0].size()\n        flow_f = torch.zeros(batch_size, 2, h_x1, w_x1).float().cuda()\n        flow_b = torch.zeros(batch_size, 2, h_x1, w_x1).float().cuda()\n        occ_f = torch.zeros(batch_size, 1, h_x1, w_x1).float().cuda()\n        occ_b = torch.zeros(batch_size, 1, h_x1, w_x1).float().cuda()\n\n        for l, (x1, x2) in enumerate(zip(x1_pyramid, x2_pyramid)):\n\n            if l <= self.output_level:\n\n                # warping\n                if l == 0:\n                    x2_warp = x2\n                    x1_warp = x1\n                else:\n                    flow_f = upsample2d_as(flow_f, x1, mode=\"bilinear\")\n                    flow_b = upsample2d_as(flow_b, x2, mode=\"bilinear\")\n                    occ_f = upsample2d_as(occ_f, x1, mode=\"bilinear\")\n                    occ_b = upsample2d_as(occ_b, x2, mode=\"bilinear\")\n                    x2_warp = self.warping_layer(x2, flow_f, height_im, width_im, self._div_flow)\n                    x1_warp = self.warping_layer(x1, flow_b, height_im, width_im, self._div_flow)\n\n                # correlation\n                out_corr_f = compute_cost_volume(x1, x2_warp, self.corr_params)\n                out_corr_b = compute_cost_volume(x2, x1_warp, 
self.corr_params)\n\n\n                out_corr_relu_f = self.leakyRELU(out_corr_f)\n                out_corr_relu_b = self.leakyRELU(out_corr_b)\n\n                if l != self.output_level:\n                    x1_1by1 = self.conv_1x1[l](x1)\n                    x2_1by1 = self.conv_1x1[l](x2)\n                else:\n                    x1_1by1 = x1\n                    x2_1by1 = x2\n\n                # concat and estimate flow\n                flow_f = rescale_flow(flow_f, self._div_flow, width_im, height_im, to_local=True)\n                flow_b = rescale_flow(flow_b, self._div_flow, width_im, height_im, to_local=True)\n\n                x_intm_f, flow_res_f = self.flow_estimators(torch.cat([out_corr_relu_f, x1_1by1, flow_f], dim=1))\n                x_intm_b, flow_res_b = self.flow_estimators(torch.cat([out_corr_relu_b, x2_1by1, flow_b], dim=1))\n                flow_est_f = flow_f + flow_res_f\n                flow_est_b = flow_b + flow_res_b\n\n                flow_cont_f = flow_est_f + self.context_networks(torch.cat([x_intm_f, flow_est_f], dim=1))\n                flow_cont_b = flow_est_b + self.context_networks(torch.cat([x_intm_b, flow_est_b], dim=1))\n\n                # occ estimation\n                x_intm_occ_f, occ_res_f = self.occ_estimators(torch.cat([out_corr_relu_f, x1_1by1, occ_f], dim=1))\n                x_intm_occ_b, occ_res_b = self.occ_estimators(torch.cat([out_corr_relu_b, x2_1by1, occ_b], dim=1))\n                occ_est_f = occ_f + occ_res_f\n                occ_est_b = occ_b + occ_res_b\n\n                occ_cont_f = occ_est_f + self.occ_context_networks(torch.cat([x_intm_occ_f, occ_est_f], dim=1))\n                occ_cont_b = occ_est_b + self.occ_context_networks(torch.cat([x_intm_occ_b, occ_est_b], dim=1))\n\n                # refinement\n                img1_resize = upsample2d_as(x1_raw, flow_f, mode=\"bilinear\")\n                img2_resize = upsample2d_as(x2_raw, flow_b, mode=\"bilinear\")\n                img2_warp = 
self.warping_layer(img2_resize, rescale_flow(flow_cont_f, self._div_flow, width_im, height_im, to_local=False), height_im, width_im, self._div_flow)\n                img1_warp = self.warping_layer(img1_resize, rescale_flow(flow_cont_b, self._div_flow, width_im, height_im, to_local=False), height_im, width_im, self._div_flow)\n\n                # flow refine\n                flow_f = self.refine_flow(flow_cont_f.detach(), img1_resize - img2_warp, x1_1by1)\n                flow_b = self.refine_flow(flow_cont_b.detach(), img2_resize - img1_warp, x2_1by1)\n\n                flow_cont_f = rescale_flow(flow_cont_f, self._div_flow, width_im, height_im, to_local=False)\n                flow_cont_b = rescale_flow(flow_cont_b, self._div_flow, width_im, height_im, to_local=False)\n                flow_f = rescale_flow(flow_f, self._div_flow, width_im, height_im, to_local=False)\n                flow_b = rescale_flow(flow_b, self._div_flow, width_im, height_im, to_local=False)\n\n                # occ refine\n                x2_1by1_warp = self.warping_layer(x2_1by1, flow_f, height_im, width_im, self._div_flow)\n                x1_1by1_warp = self.warping_layer(x1_1by1, flow_b, height_im, width_im, self._div_flow)\n\n                occ_f = self.refine_occ(occ_cont_f.detach(), x1_1by1, x1_1by1 - x2_1by1_warp)\n                occ_b = self.refine_occ(occ_cont_b.detach(), x2_1by1, x2_1by1 - x1_1by1_warp)\n\n                flows.append([flow_cont_f, flow_cont_b, flow_f, flow_b])\n                occs.append([occ_cont_f, occ_cont_b, occ_f, occ_b])\n\n            else:\n                flow_f = upsample2d_as(flow_f, x1, mode=\"bilinear\")\n                flow_b = upsample2d_as(flow_b, x2, mode=\"bilinear\")\n                flows.append([flow_f, flow_b])\n\n                x2_warp = self.warping_layer(x2, flow_f, height_im, width_im, self._div_flow)\n                x1_warp = self.warping_layer(x1, flow_b, height_im, width_im, self._div_flow)\n                flow_b_warp = 
self.warping_layer(flow_b, flow_f, height_im, width_im, self._div_flow)\n                flow_f_warp = self.warping_layer(flow_f, flow_b, height_im, width_im, self._div_flow)\n\n                if l != self.num_levels-1:\n                    x1_in = self.conv_1x1_1(x1)\n                    x2_in = self.conv_1x1_1(x2)\n                    x1_w_in = self.conv_1x1_1(x1_warp)\n                    x2_w_in = self.conv_1x1_1(x2_warp)\n                else:\n                    x1_in = x1\n                    x2_in = x2\n                    x1_w_in = x1_warp\n                    x2_w_in = x2_warp\n\n                occ_f = self.occ_shuffle_upsample(occ_f, torch.cat([x1_in, x2_w_in, flow_f, flow_b_warp], dim=1))\n                occ_b = self.occ_shuffle_upsample(occ_b, torch.cat([x2_in, x1_w_in, flow_b, flow_f_warp], dim=1))\n\n                occs.append([occ_f, occ_b])\n\n        output_dict_eval['flow'] = upsample2d_as(flow_f, x1_raw, mode=\"bilinear\") * (1.0 / self._div_flow)\n        output_dict_eval['occ'] = upsample2d_as(occ_f, x1_raw, mode=\"bilinear\")\n        output_dict['flow'] = flows\n        output_dict['occ'] = occs\n\n        if self.training:\n            return output_dict\n        else:\n            return output_dict_eval\n"
  },
  {
    "path": "models/__init__.py",
    "content": "from . import flownet1s\nfrom . import flownet1s_irr\nfrom . import flownet1s_irr_bi\nfrom . import flownet1s_irr_occ\nfrom . import flownet1s_irr_occ_bi\nfrom . import IRR_FlowNet\n\nfrom . import pwcnet\nfrom . import pwcnet_bi\nfrom . import pwcnet_occ\nfrom . import pwcnet_occ_bi\nfrom . import pwcnet_irr\nfrom . import pwcnet_irr_bi\nfrom . import pwcnet_irr_occ\nfrom . import pwcnet_irr_occ_bi\nfrom . import IRR_PWC\n\n\nFlowNet1S            = flownet1s.FlowNet1S\nFlowNet1S_irr        = flownet1s_irr.FlowNet1S\nFlowNet1S_irr_bi     = flownet1s_irr_bi.FlowNet1S\nFlowNet1S_irr_occ    = flownet1s_irr_occ.FlowNet1S\nFlowNet1S_irr_occ_bi = flownet1s_irr_occ_bi.FlowNet1S\n\nPWCNet               = pwcnet.PWCNet\nPWCNet_bi            = pwcnet_bi.PWCNet\nPWCNet_occ           = pwcnet_occ.PWCNet\nPWCNet_occ_bi        = pwcnet_occ_bi.PWCNet\nPWCNet_irr           = pwcnet_irr.PWCNet\nPWCNet_irr_bi        = pwcnet_irr_bi.PWCNet\nPWCNet_irr_occ       = pwcnet_irr_occ.PWCNet\nPWCNet_irr_occ_bi    = pwcnet_irr_occ_bi.PWCNet\n\nIRR_FlowNet          = IRR_FlowNet.FlowNet1S\nIRR_PWC              = IRR_PWC.PWCNet\n\n"
  },
  {
    "path": "models/correlation_package/__init__.py",
    "content": ""
  },
  {
    "path": "models/correlation_package/correlation.py",
    "content": "import torch\nfrom torch.nn.modules.module import Module\nfrom torch.autograd import Function\nimport correlation_cuda\n\nclass CorrelationFunction(Function):\n\n    def __init__(self, pad_size=3, kernel_size=3, max_displacement=20, stride1=1, stride2=2, corr_multiply=1):\n        super(CorrelationFunction, self).__init__()\n        self.pad_size = pad_size\n        self.kernel_size = kernel_size\n        self.max_displacement = max_displacement\n        self.stride1 = stride1\n        self.stride2 = stride2\n        self.corr_multiply = corr_multiply\n        # self.out_channel = ((max_displacement/stride2)*2 + 1) * ((max_displacement/stride2)*2 + 1)\n\n    def forward(self, input1, input2):\n        self.save_for_backward(input1, input2)\n\n        with torch.cuda.device_of(input1):\n            rbot1 = input1.new()\n            rbot2 = input2.new()\n            output = input1.new()\n\n            correlation_cuda.forward(input1, input2, rbot1, rbot2, output, \n                self.pad_size, self.kernel_size, self.max_displacement,self.stride1, self.stride2, self.corr_multiply)\n\n        return output\n\n    def backward(self, grad_output):\n        input1, input2 = self.saved_tensors\n\n        with torch.cuda.device_of(input1):\n            rbot1 = input1.new()\n            rbot2 = input2.new()\n\n            grad_input1 = input1.new()\n            grad_input2 = input2.new()\n\n            correlation_cuda.backward(input1, input2, rbot1, rbot2, grad_output, grad_input1, grad_input2,\n                self.pad_size, self.kernel_size, self.max_displacement,self.stride1, self.stride2, self.corr_multiply)\n\n        return grad_input1, grad_input2\n\n\nclass Correlation(Module):\n    def __init__(self, pad_size=0, kernel_size=0, max_displacement=0, stride1=1, stride2=2, corr_multiply=1):\n        super(Correlation, self).__init__()\n        self.pad_size = pad_size\n        self.kernel_size = kernel_size\n        self.max_displacement = 
max_displacement\n        self.stride1 = stride1\n        self.stride2 = stride2\n        self.corr_multiply = corr_multiply\n\n    def forward(self, input1, input2):\n\n        result = CorrelationFunction(self.pad_size, self.kernel_size, self.max_displacement, self.stride1, self.stride2, self.corr_multiply)(input1, input2)\n\n        return result\n\n"
  },
  {
    "path": "models/correlation_package/correlation_cuda.cc",
    "content": "#include <torch/torch.h>\n#include <ATen/ATen.h>\n#include <stdio.h>\n#include <iostream>\n\n#include \"correlation_cuda_kernel.cuh\"\n\nint correlation_forward_cuda(at::Tensor& input1, at::Tensor& input2, at::Tensor& rInput1, at::Tensor& rInput2, at::Tensor& output,\n                       int pad_size,\n                       int kernel_size,\n                       int max_displacement,\n                       int stride1,\n                       int stride2,\n                       int corr_type_multiply)\n{\n\n  int batchSize = input1.size(0);\n\n  int nInputChannels = input1.size(1);\n  int inputHeight = input1.size(2);\n  int inputWidth = input1.size(3);\n\n  int kernel_radius = (kernel_size - 1) / 2;\n  int border_radius = kernel_radius + max_displacement;\n\n  int paddedInputHeight = inputHeight + 2 * pad_size;\n  int paddedInputWidth = inputWidth + 2 * pad_size;\n\n  int nOutputChannels = ((max_displacement/stride2)*2 + 1) * ((max_displacement/stride2)*2 + 1);\n\n  int outputHeight = ceil(static_cast<float>(paddedInputHeight - 2 * border_radius) / static_cast<float>(stride1));\n  int outputwidth = ceil(static_cast<float>(paddedInputWidth - 2 * border_radius) / static_cast<float>(stride1));\n\n  rInput1.resize_({batchSize, paddedInputHeight, paddedInputWidth, nInputChannels});\n  rInput2.resize_({batchSize, paddedInputHeight, paddedInputWidth, nInputChannels});\n  output.resize_({batchSize, nOutputChannels, outputHeight, outputwidth});\n\n  rInput1.fill_(0);\n  rInput2.fill_(0);\n  output.fill_(0);\n\n  int success = correlation_forward_cuda_kernel(\n    output,\n    output.size(0), \n    output.size(1),\n    output.size(2),\n    output.size(3),\n    output.stride(0),\n    output.stride(1),\n    output.stride(2),\n    output.stride(3),\n    input1,\n    input1.size(1),\n    input1.size(2),\n    input1.size(3),\n    input1.stride(0),\n    input1.stride(1),\n    input1.stride(2),\n    input1.stride(3),\n    input2,\n    input2.size(1),\n    
input2.stride(0),\n    input2.stride(1),\n    input2.stride(2),\n    input2.stride(3),\n    rInput1,\n    rInput2,\n    pad_size,     \n    kernel_size,\n    max_displacement,\n    stride1,\n    stride2,\n    corr_type_multiply,\n    at::globalContext().getCurrentCUDAStream()\n  );\n\n  //check for errors\n  if (!success) {\n    AT_ERROR(\"CUDA call failed\");\n  }\n\n  return 1;\n\n}\n\nint correlation_backward_cuda(at::Tensor& input1, at::Tensor& input2, at::Tensor& rInput1, at::Tensor& rInput2, at::Tensor& gradOutput, \n                       at::Tensor& gradInput1, at::Tensor& gradInput2,\n                       int pad_size,\n                       int kernel_size,\n                       int max_displacement,\n                       int stride1,\n                       int stride2,\n                       int corr_type_multiply)\n{\n\n  int batchSize = input1.size(0);\n  int nInputChannels = input1.size(1);\n  int paddedInputHeight = input1.size(2)+ 2 * pad_size;\n  int paddedInputWidth = input1.size(3)+ 2 * pad_size;\n\n  int height = input1.size(2);\n  int width = input1.size(3);\n\n  rInput1.resize_({batchSize, paddedInputHeight, paddedInputWidth, nInputChannels});\n  rInput2.resize_({batchSize, paddedInputHeight, paddedInputWidth, nInputChannels});\n  gradInput1.resize_({batchSize, nInputChannels, height, width});\n  gradInput2.resize_({batchSize, nInputChannels, height, width});\n\n  rInput1.fill_(0);\n  rInput2.fill_(0);\n  gradInput1.fill_(0);\n  gradInput2.fill_(0);\n\n  int success = correlation_backward_cuda_kernel(gradOutput,\n                                                gradOutput.size(0),\n                                                gradOutput.size(1),\n                                                gradOutput.size(2),\n                                                gradOutput.size(3),\n                                                gradOutput.stride(0),\n                                                gradOutput.stride(1),\n            
                                    gradOutput.stride(2),\n                                                gradOutput.stride(3),\n                                                input1,\n                                                input1.size(1),\n                                                input1.size(2),\n                                                input1.size(3),\n                                                input1.stride(0),\n                                                input1.stride(1),\n                                                input1.stride(2),\n                                                input1.stride(3),\n                                                input2,  \n                                                input2.stride(0),\n                                                input2.stride(1),\n                                                input2.stride(2),\n                                                input2.stride(3),\n                                                gradInput1,\n                                                gradInput1.stride(0),\n                                                gradInput1.stride(1),\n                                                gradInput1.stride(2),\n                                                gradInput1.stride(3),\n                                                gradInput2,\n                                                gradInput2.size(1),\n                                                gradInput2.stride(0),\n                                                gradInput2.stride(1),\n                                                gradInput2.stride(2),\n                                                gradInput2.stride(3),\n                                                rInput1,\n                                                rInput2,\n                                                pad_size,\n                                                kernel_size,\n                           
                     max_displacement,\n                                                stride1, \n                                                stride2,\n                                                corr_type_multiply,\n                                                at::globalContext().getCurrentCUDAStream()\n                                               );\n\n  if (!success) {\n    AT_ERROR(\"CUDA call failed\");\n  }\n\n  return 1;\n}\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n  m.def(\"forward\", &correlation_forward_cuda, \"Correlation forward (CUDA)\");\n  m.def(\"backward\", &correlation_backward_cuda, \"Correlation backward (CUDA)\");\n}\n\n"
  },
  {
    "path": "models/correlation_package/correlation_cuda_kernel.cu",
    "content": "#include <stdio.h>\n\n#include \"correlation_cuda_kernel.cuh\"\n\n#define CUDA_NUM_THREADS 1024\n#define THREADS_PER_BLOCK 32\n\n#include <ATen/ATen.h>\n#include <ATen/NativeFunctions.h>\n#include <ATen/Dispatch.h>\n#include <ATen/cuda/CUDAApplyUtils.cuh>\n\nusing at::Half;\n\ntemplate <typename scalar_t>\n__global__ void channels_first(const scalar_t* __restrict__ input, scalar_t* rinput, int channels, int height, int width, int pad_size)\n{\n\n\t// n (batch size), c (num of channels), y (height), x (width)\n\tint n = blockIdx.x;\n\tint y = blockIdx.y;\n\tint x = blockIdx.z;\n\n\tint ch_off = threadIdx.x;\n\tscalar_t value;\n\n\tint dimcyx = channels * height * width;\n\tint dimyx = height * width;\n\n\tint p_dimx = (width + 2 * pad_size);\n\tint p_dimy = (height + 2 * pad_size);\n\tint p_dimyxc = channels * p_dimy * p_dimx;\n\tint p_dimxc = p_dimx * channels;\n\n\tfor (int c = ch_off; c < channels; c += THREADS_PER_BLOCK) {\n\t\tvalue = input[n * dimcyx + c * dimyx + y * width + x];\n\t\trinput[n * p_dimyxc + (y + pad_size) * p_dimxc + (x + pad_size) * channels + c] = value;\n\t}\n}\n\ntemplate <typename scalar_t>\n__global__ void correlation_forward(scalar_t*  output, int nOutputChannels, int outputHeight, int outputWidth,\n\tconst scalar_t* __restrict__ rInput1, int nInputChannels, int inputHeight, int inputWidth,\n\tconst scalar_t* __restrict__ rInput2,\n\tint pad_size,\n\tint kernel_size,\n\tint max_displacement,\n\tint stride1,\n\tint stride2)\n{\n\t// n (batch size), c (num of channels), y (height), x (width)\n\n\tint pInputWidth = inputWidth + 2 * pad_size;\n\tint pInputHeight = inputHeight + 2 * pad_size;\n\n\tint kernel_rad = (kernel_size - 1) / 2;\n\tint displacement_rad = max_displacement / stride2;\n\tint displacement_size = 2 * displacement_rad + 1;\n\n\tint n = blockIdx.x;\n\tint y1 = blockIdx.y * stride1 + max_displacement;\n\tint x1 = blockIdx.z * stride1 + max_displacement;\n\tint c = threadIdx.x;\n\n\tint pdimyxc = pInputHeight * 
pInputWidth * nInputChannels;\n\tint pdimxc = pInputWidth * nInputChannels;\n\tint pdimc = nInputChannels;\n\n\tint tdimcyx = nOutputChannels * outputHeight * outputWidth;\n\tint tdimyx = outputHeight * outputWidth;\n\tint tdimx = outputWidth;\n\n\tscalar_t nelems = kernel_size * kernel_size * pdimc;\n\n\t__shared__ scalar_t prod_sum[THREADS_PER_BLOCK];\n\n\t// no significant speed-up in using chip memory for input1 sub-data, \n\t// not enough chip memory size to accomodate memory per block for input2 sub-data\n\t// instead i've used device memory for both \n\n\t// element-wise product along channel axis\n\tfor (int tj = -displacement_rad; tj <= displacement_rad; ++tj) {\n\t\tfor (int ti = -displacement_rad; ti <= displacement_rad; ++ti) {\n\t\t\tprod_sum[c] = 0;\n\t\t\tint x2 = x1 + ti*stride2;\n\t\t\tint y2 = y1 + tj*stride2;\n\n\t\t\tfor (int j = -kernel_rad; j <= kernel_rad; ++j) {\n\t\t\t\tfor (int i = -kernel_rad; i <= kernel_rad; ++i) {\n\t\t\t\t\tfor (int ch = c; ch < pdimc; ch += THREADS_PER_BLOCK) {\n\t\t\t\t\t\tint indx1 = n * pdimyxc + (y1 + j) * pdimxc + (x1 + i) * pdimc + ch;\n\t\t\t\t\t\tint indx2 = n * pdimyxc + (y2 + j) * pdimxc + (x2 + i) * pdimc + ch;\n\n\t\t\t\t\t\tprod_sum[c] += rInput1[indx1] * rInput2[indx2];\n\t\t\t\t\t}\n\t\t\t\t}\n\t\t\t}\n\n\t\t\t// accumulate \n\t\t\t__syncthreads();\n\t\t\tif (c == 0) {\n\t\t\t\tscalar_t reduce_sum = 0;\n\t\t\t\tfor (int index = 0; index < THREADS_PER_BLOCK; ++index) {\n\t\t\t\t\treduce_sum += prod_sum[index];\n\t\t\t\t}\n\t\t\t\tint tc = (tj + displacement_rad) * displacement_size + (ti + displacement_rad);\n\t\t\t\tconst int tindx = n * tdimcyx + tc * tdimyx + blockIdx.y * tdimx + blockIdx.z;\n\t\t\t\toutput[tindx] = reduce_sum / nelems;\n\t\t\t}\n\n\t\t}\n\t}\n\n}\n\ntemplate <typename scalar_t>\n__global__ void correlation_backward_input1(int item, scalar_t* gradInput1, int nInputChannels, int inputHeight, int inputWidth,\n\tconst scalar_t* __restrict__ gradOutput, int nOutputChannels, int 
outputHeight, int outputWidth,\n\tconst scalar_t* __restrict__ rInput2,\n\tint pad_size,\n\tint kernel_size,\n\tint max_displacement,\n\tint stride1,\n\tint stride2)\n{\n\t// n (batch size), c (num of channels), y (height), x (width)\n\n\tint n = item;\n\tint y = blockIdx.x * stride1 + pad_size;\n\tint x = blockIdx.y * stride1 + pad_size;\n\tint c = blockIdx.z;\n\tint tch_off = threadIdx.x;\n\n\tint kernel_rad = (kernel_size - 1) / 2;\n\tint displacement_rad = max_displacement / stride2;\n\tint displacement_size = 2 * displacement_rad + 1;\n\n\tint xmin = (x - kernel_rad - max_displacement) / stride1;\n\tint ymin = (y - kernel_rad - max_displacement) / stride1;\n\n\tint xmax = (x + kernel_rad - max_displacement) / stride1;\n\tint ymax = (y + kernel_rad - max_displacement) / stride1;\n\n\tif (xmax < 0 || ymax < 0 || xmin >= outputWidth || ymin >= outputHeight) {\n\t\t// assumes gradInput1 is pre-allocated and zero filled\n\t\treturn;\n\t}\n\n\tif (xmin > xmax || ymin > ymax) {\n\t\t// assumes gradInput1 is pre-allocated and zero filled\n\t\treturn;\n\t}\n\n\txmin = max(0, xmin);\n\txmax = min(outputWidth - 1, xmax);\n\n\tymin = max(0, ymin);\n\tymax = min(outputHeight - 1, ymax);\n\n\tint pInputWidth = inputWidth + 2 * pad_size;\n\tint pInputHeight = inputHeight + 2 * pad_size;\n\n\tint pdimyxc = pInputHeight * pInputWidth * nInputChannels;\n\tint pdimxc = pInputWidth * nInputChannels;\n\tint pdimc = nInputChannels;\n\n\tint tdimcyx = nOutputChannels * outputHeight * outputWidth;\n\tint tdimyx = outputHeight * outputWidth;\n\tint tdimx = outputWidth;\n\n\tint odimcyx = nInputChannels * inputHeight* inputWidth;\n\tint odimyx = inputHeight * inputWidth;\n\tint odimx = inputWidth;\n\n\tscalar_t nelems = kernel_size * kernel_size * nInputChannels;\n\n\t__shared__ scalar_t prod_sum[THREADS_PER_BLOCK];\n\tprod_sum[tch_off] = 0;\n\n\tfor (int tc = tch_off; tc < nOutputChannels; tc += THREADS_PER_BLOCK) {\n\n\t\tint i2 = (tc % displacement_size - displacement_rad) * 
stride2;\n\t\tint j2 = (tc / displacement_size - displacement_rad) * stride2;\n\n\t\tint indx2 = n * pdimyxc + (y + j2)* pdimxc + (x + i2) * pdimc + c;\n\n\t\tscalar_t val2 = rInput2[indx2];\n\n\t\tfor (int j = ymin; j <= ymax; ++j) {\n\t\t\tfor (int i = xmin; i <= xmax; ++i) {\n\t\t\t\tint tindx = n * tdimcyx + tc * tdimyx + j * tdimx + i;\n\t\t\t\tprod_sum[tch_off] += gradOutput[tindx] * val2;\n\t\t\t}\n\t\t}\n\t}\n\t__syncthreads();\n\n\tif (tch_off == 0) {\n\t\tscalar_t reduce_sum = 0;\n\t\tfor (int idx = 0; idx < THREADS_PER_BLOCK; idx++) {\n\t\t\treduce_sum += prod_sum[idx];\n\t\t}\n\t\tconst int indx1 = n * odimcyx + c * odimyx + (y - pad_size) * odimx + (x - pad_size);\n\t\tgradInput1[indx1] = reduce_sum / nelems;\n\t}\n\n}\n\ntemplate <typename scalar_t>\n__global__ void correlation_backward_input2(int item, scalar_t*  gradInput2, int nInputChannels, int inputHeight, int inputWidth,\n\tconst scalar_t* __restrict__ gradOutput, int nOutputChannels, int outputHeight, int outputWidth,\n\tconst scalar_t* __restrict__ rInput1,\n\tint pad_size,\n\tint kernel_size,\n\tint max_displacement,\n\tint stride1,\n\tint stride2)\n{\n\t// n (batch size), c (num of channels), y (height), x (width)\n\n\tint n = item;\n\tint y = blockIdx.x * stride1 + pad_size;\n\tint x = blockIdx.y * stride1 + pad_size;\n\tint c = blockIdx.z;\n\n\tint tch_off = threadIdx.x;\n\n\tint kernel_rad = (kernel_size - 1) / 2;\n\tint displacement_rad = max_displacement / stride2;\n\tint displacement_size = 2 * displacement_rad + 1;\n\n\tint pInputWidth = inputWidth + 2 * pad_size;\n\tint pInputHeight = inputHeight + 2 * pad_size;\n\n\tint pdimyxc = pInputHeight * pInputWidth * nInputChannels;\n\tint pdimxc = pInputWidth * nInputChannels;\n\tint pdimc = nInputChannels;\n\n\tint tdimcyx = nOutputChannels * outputHeight * outputWidth;\n\tint tdimyx = outputHeight * outputWidth;\n\tint tdimx = outputWidth;\n\n\tint odimcyx = nInputChannels * inputHeight* inputWidth;\n\tint odimyx = inputHeight * 
inputWidth;\n\tint odimx = inputWidth;\n\n\tscalar_t nelems = kernel_size * kernel_size * nInputChannels;\n\n\t__shared__ scalar_t prod_sum[THREADS_PER_BLOCK];\n\tprod_sum[tch_off] = 0;\n\n\tfor (int tc = tch_off; tc < nOutputChannels; tc += THREADS_PER_BLOCK) {\n\t\tint i2 = (tc % displacement_size - displacement_rad) * stride2;\n\t\tint j2 = (tc / displacement_size - displacement_rad) * stride2;\n\n\t\tint xmin = (x - kernel_rad - max_displacement - i2) / stride1;\n\t\tint ymin = (y - kernel_rad - max_displacement - j2) / stride1;\n\n\t\tint xmax = (x + kernel_rad - max_displacement - i2) / stride1;\n\t\tint ymax = (y + kernel_rad - max_displacement - j2) / stride1;\n\n\t\tif (xmax < 0 || ymax < 0 || xmin >= outputWidth || ymin >= outputHeight) {\n\t\t\t// assumes gradInput2 is pre-allocated and zero filled\n\t\t\tcontinue;\n\t\t}\n\n\t\tif (xmin > xmax || ymin > ymax) {\n\t\t\t// assumes gradInput2 is pre-allocated and zero filled\n\t\t\tcontinue;\n\t\t}\n\n\t\txmin = max(0, xmin);\n\t\txmax = min(outputWidth - 1, xmax);\n\n\t\tymin = max(0, ymin);\n\t\tymax = min(outputHeight - 1, ymax);\n\n\t\tint indx1 = n * pdimyxc + (y - j2)* pdimxc + (x - i2) * pdimc + c;\n\t\tscalar_t val1 = rInput1[indx1];\n\n\t\tfor (int j = ymin; j <= ymax; ++j) {\n\t\t\tfor (int i = xmin; i <= xmax; ++i) {\n\t\t\t\tint tindx = n * tdimcyx + tc * tdimyx + j * tdimx + i;\n\t\t\t\tprod_sum[tch_off] += gradOutput[tindx] * val1;\n\t\t\t}\n\t\t}\n\t}\n\n\t__syncthreads();\n\n\tif (tch_off == 0) {\n\t\tscalar_t reduce_sum = 0;\n\t\tfor (int idx = 0; idx < THREADS_PER_BLOCK; idx++) {\n\t\t\treduce_sum += prod_sum[idx];\n\t\t}\n\t\tconst int indx2 = n * odimcyx + c * odimyx + (y - pad_size) * odimx + (x - pad_size);\n\t\tgradInput2[indx2] = reduce_sum / nelems;\n\t}\n\n}\n\nint correlation_forward_cuda_kernel(at::Tensor& output,\n\tint ob,\n\tint oc,\n\tint oh,\n\tint ow,\n\tint osb,\n\tint osc,\n\tint osh,\n\tint osw,\n\n\tat::Tensor& input1,\n\tint ic,\n\tint ih,\n\tint iw,\n\tint 
isb,\n\tint isc,\n\tint ish,\n\tint isw,\n\n\tat::Tensor& input2,\n\tint gc,\n\tint gsb,\n\tint gsc,\n\tint gsh,\n\tint gsw,\n\n\tat::Tensor& rInput1,\n\tat::Tensor& rInput2,\n\tint pad_size,\n\tint kernel_size,\n\tint max_displacement,\n\tint stride1,\n\tint stride2,\n\tint corr_type_multiply,\n\tcudaStream_t stream)\n{\n\n\tint batchSize = ob;\n\n\tint nInputChannels = ic;\n\tint inputWidth = iw;\n\tint inputHeight = ih;\n\n\tint nOutputChannels = oc;\n\tint outputWidth = ow;\n\tint outputHeight = oh;\n\n\tdim3 blocks_grid(batchSize, inputHeight, inputWidth);\n\tdim3 threads_block(THREADS_PER_BLOCK);\n\n\tAT_DISPATCH_FLOATING_TYPES_AND_HALF(input1.type(), \"channels_first_fwd_1\", ([&] {\n\n\t\tchannels_first<scalar_t> << <blocks_grid, threads_block, 0, stream >> >(\n\t\t\tinput1.data<scalar_t>(), rInput1.data<scalar_t>(), nInputChannels, inputHeight, inputWidth, pad_size);\n\n\t}));\n\n\tAT_DISPATCH_FLOATING_TYPES_AND_HALF(input2.type(), \"channels_first_fwd_2\", ([&] {\n\n\t\tchannels_first<scalar_t> << <blocks_grid, threads_block, 0, stream >> > (\n\t\t\tinput2.data<scalar_t>(), rInput2.data<scalar_t>(), nInputChannels, inputHeight, inputWidth, pad_size);\n\n\t}));\n\n\tdim3 threadsPerBlock(THREADS_PER_BLOCK);\n\tdim3 totalBlocksCorr(batchSize, outputHeight, outputWidth);\n\n\tAT_DISPATCH_FLOATING_TYPES_AND_HALF(input1.type(), \"correlation_forward\", ([&] {\n\n\t\tcorrelation_forward<scalar_t> << <totalBlocksCorr, threadsPerBlock, 0, stream >> >\n\t\t\t(output.data<scalar_t>(), nOutputChannels, outputHeight, outputWidth,\n\t\t\trInput1.data<scalar_t>(), nInputChannels, inputHeight, inputWidth,\n\t\t\trInput2.data<scalar_t>(),\n\t\t\tpad_size,\n\t\t\tkernel_size,\n\t\t\tmax_displacement,\n\t\t\tstride1,\n\t\t\tstride2);\n\n\t}));\n\n\tcudaError_t err = cudaGetLastError();\n\n\n\t// check for errors\n\tif (err != cudaSuccess) {\n\t\tprintf(\"error in correlation_forward_cuda_kernel: %s\\n\", cudaGetErrorString(err));\n\t\treturn 0;\n\t}\n\n\treturn 
1;\n}\n\n\nint correlation_backward_cuda_kernel(\n\tat::Tensor& gradOutput,\n\tint gob,\n\tint goc,\n\tint goh,\n\tint gow,\n\tint gosb,\n\tint gosc,\n\tint gosh,\n\tint gosw,\n\n\tat::Tensor& input1,\n\tint ic,\n\tint ih,\n\tint iw,\n\tint isb,\n\tint isc,\n\tint ish,\n\tint isw,\n\n\tat::Tensor& input2,\n\tint gsb,\n\tint gsc,\n\tint gsh,\n\tint gsw,\n\n\tat::Tensor& gradInput1,\n\tint gisb,\n\tint gisc,\n\tint gish,\n\tint gisw,\n\n\tat::Tensor& gradInput2,\n\tint ggc,\n\tint ggsb,\n\tint ggsc,\n\tint ggsh,\n\tint ggsw,\n\n\tat::Tensor& rInput1,\n\tat::Tensor& rInput2,\n\tint pad_size,\n\tint kernel_size,\n\tint max_displacement,\n\tint stride1,\n\tint stride2,\n\tint corr_type_multiply,\n\tcudaStream_t stream)\n{\n\n\tint batchSize = gob;\n\tint num = batchSize;\n\n\tint nInputChannels = ic;\n\tint inputWidth = iw;\n\tint inputHeight = ih;\n\n\tint nOutputChannels = goc;\n\tint outputWidth = gow;\n\tint outputHeight = goh;\n\n\tdim3 blocks_grid(batchSize, inputHeight, inputWidth);\n\tdim3 threads_block(THREADS_PER_BLOCK);\n\n\n\tAT_DISPATCH_FLOATING_TYPES_AND_HALF(input1.type(), \"lltm_forward_cuda\", ([&] {\n\n\t\tchannels_first<scalar_t> << <blocks_grid, threads_block, 0, stream >> >(\n\t\t\tinput1.data<scalar_t>(),\n\t\t\trInput1.data<scalar_t>(),\n\t\t\tnInputChannels,\n\t\t\tinputHeight,\n\t\t\tinputWidth,\n\t\t\tpad_size\n\t\t\t);\n\t}));\n\n\tAT_DISPATCH_FLOATING_TYPES_AND_HALF(input2.type(), \"lltm_forward_cuda\", ([&] {\n\n\t\tchannels_first<scalar_t> << <blocks_grid, threads_block, 0, stream >> >(\n\t\t\tinput2.data<scalar_t>(),\n\t\t\trInput2.data<scalar_t>(),\n\t\t\tnInputChannels,\n\t\t\tinputHeight,\n\t\t\tinputWidth,\n\t\t\tpad_size\n\t\t\t);\n\t}));\n\n\tdim3 threadsPerBlock(THREADS_PER_BLOCK);\n\tdim3 totalBlocksCorr(inputHeight, inputWidth, nInputChannels);\n\n\tfor (int n = 0; n < num; ++n) {\n\n\t\tAT_DISPATCH_FLOATING_TYPES_AND_HALF(input2.type(), \"lltm_forward_cuda\", ([&] {\n\n\n\t\t\tcorrelation_backward_input1<scalar_t> << 
<totalBlocksCorr, threadsPerBlock, 0, stream >> > (\n\t\t\t\tn, gradInput1.data<scalar_t>(), nInputChannels, inputHeight, inputWidth,\n\t\t\t\tgradOutput.data<scalar_t>(), nOutputChannels, outputHeight, outputWidth,\n\t\t\t\trInput2.data<scalar_t>(),\n\t\t\t\tpad_size,\n\t\t\t\tkernel_size,\n\t\t\t\tmax_displacement,\n\t\t\t\tstride1,\n\t\t\t\tstride2);\n\t\t}));\n\t}\n\n\tfor (int n = 0; n < batchSize; n++) {\n\n\t\tAT_DISPATCH_FLOATING_TYPES_AND_HALF(rInput1.type(), \"lltm_forward_cuda\", ([&] {\n\n\t\t\tcorrelation_backward_input2<scalar_t> << <totalBlocksCorr, threadsPerBlock, 0, stream >> >(\n\t\t\t\tn, gradInput2.data<scalar_t>(), nInputChannels, inputHeight, inputWidth,\n\t\t\t\tgradOutput.data<scalar_t>(), nOutputChannels, outputHeight, outputWidth,\n\t\t\t\trInput1.data<scalar_t>(),\n\t\t\t\tpad_size,\n\t\t\t\tkernel_size,\n\t\t\t\tmax_displacement,\n\t\t\t\tstride1,\n\t\t\t\tstride2);\n\n\t\t}));\n\t}\n\n\t// check for errors\n\tcudaError_t err = cudaGetLastError();\n\tif (err != cudaSuccess) {\n\t\tprintf(\"error in correlation_backward_cuda_kernel: %s\\n\", cudaGetErrorString(err));\n\t\treturn 0;\n\t}\n\n\treturn 1;\n}\n"
  },
  {
    "path": "models/correlation_package/correlation_cuda_kernel.cuh",
    "content": "#pragma once\n\n#include <ATen/ATen.h>\n#include <ATen/Context.h>\n#include <cuda_runtime.h>\n\nint correlation_forward_cuda_kernel(at::Tensor& output,\n    int ob,\n    int oc,\n    int oh,\n    int ow,\n    int osb,\n    int osc,\n    int osh,\n    int osw,\n\n    at::Tensor& input1,\n    int ic,\n    int ih,\n    int iw,\n    int isb,\n    int isc,\n    int ish,\n    int isw,\n\n    at::Tensor& input2,\n    int gc,\n    int gsb,\n    int gsc,\n    int gsh,\n    int gsw,\n\n    at::Tensor& rInput1,\n    at::Tensor& rInput2,\n    int pad_size,\n    int kernel_size,\n    int max_displacement,\n    int stride1,\n    int stride2,\n    int corr_type_multiply,\n    cudaStream_t stream);\n\n\nint correlation_backward_cuda_kernel(   \n    at::Tensor& gradOutput,\n    int gob,\n    int goc,\n    int goh,\n    int gow,\n    int gosb,\n    int gosc,\n    int gosh,\n    int gosw,\n\n    at::Tensor& input1,\n    int ic,\n    int ih,\n    int iw,\n    int isb,\n    int isc,\n    int ish,\n    int isw,\n\n    at::Tensor& input2,\n    int gsb,\n    int gsc,\n    int gsh,\n    int gsw,\n\n    at::Tensor& gradInput1, \n    int gisb,\n    int gisc,\n    int gish,\n    int gisw,\n\n    at::Tensor& gradInput2,\n    int ggc,\n    int ggsb,\n    int ggsc,\n    int ggsh,\n    int ggsw,\n\n    at::Tensor& rInput1,\n    at::Tensor& rInput2,\n    int pad_size,\n    int kernel_size,\n    int max_displacement,\n    int stride1,\n    int stride2,\n    int corr_type_multiply,\n    cudaStream_t stream);\n"
  },
  {
    "path": "models/correlation_package/setup.py",
    "content": "#!/usr/bin/env python3\nimport os\nimport torch\n\nfrom setuptools import setup, find_packages\nfrom torch.utils.cpp_extension import BuildExtension, CUDAExtension\n\ncxx_args = ['-std=c++11']\n\nnvcc_args = [\n    '-gencode', 'arch=compute_50,code=sm_50',\n    '-gencode', 'arch=compute_52,code=sm_52',\n    '-gencode', 'arch=compute_60,code=sm_60',\n    '-gencode', 'arch=compute_61,code=sm_61',\n    '-gencode', 'arch=compute_61,code=compute_61'\n]\n\nsetup(\n    name='correlation_cuda',\n    ext_modules=[\n        CUDAExtension('correlation_cuda', [\n            'correlation_cuda.cc',\n            'correlation_cuda_kernel.cu'\n        ], extra_compile_args={'cxx': cxx_args, 'nvcc': nvcc_args})\n    ],\n    cmdclass={\n        'build_ext': BuildExtension\n    })\n"
  },
  {
    "path": "models/correlation_package_cu9/__init__.py",
    "content": ""
  },
  {
    "path": "models/correlation_package_cu9/correlation.py",
    "content": "import torch\nfrom torch.nn.modules.module import Module\nfrom torch.autograd import Function\nimport correlation_cuda\n\nclass CorrelationFunction(Function):\n\n    def __init__(self, pad_size=3, kernel_size=3, max_displacement=20, stride1=1, stride2=2, corr_multiply=1):\n        super(CorrelationFunction, self).__init__()\n        self.pad_size = pad_size\n        self.kernel_size = kernel_size\n        self.max_displacement = max_displacement\n        self.stride1 = stride1\n        self.stride2 = stride2\n        self.corr_multiply = corr_multiply\n        # self.out_channel = ((max_displacement/stride2)*2 + 1) * ((max_displacement/stride2)*2 + 1)\n\n    def forward(self, input1, input2):\n        self.save_for_backward(input1, input2)\n\n        with torch.cuda.device_of(input1):\n            rbot1 = input1.new()\n            rbot2 = input2.new()\n            output = input1.new()\n\n            correlation_cuda.forward(input1, input2, rbot1, rbot2, output, \n                self.pad_size, self.kernel_size, self.max_displacement,self.stride1, self.stride2, self.corr_multiply)\n\n        return output\n\n    def backward(self, grad_output):\n        input1, input2 = self.saved_tensors\n\n        with torch.cuda.device_of(input1):\n            rbot1 = input1.new()\n            rbot2 = input2.new()\n\n            grad_input1 = input1.new()\n            grad_input2 = input2.new()\n\n            correlation_cuda.backward(input1, input2, rbot1, rbot2, grad_output, grad_input1, grad_input2,\n                self.pad_size, self.kernel_size, self.max_displacement,self.stride1, self.stride2, self.corr_multiply)\n\n        return grad_input1, grad_input2\n\n\nclass Correlation(Module):\n    def __init__(self, pad_size=0, kernel_size=0, max_displacement=0, stride1=1, stride2=2, corr_multiply=1):\n        super(Correlation, self).__init__()\n        self.pad_size = pad_size\n        self.kernel_size = kernel_size\n        self.max_displacement = 
max_displacement\n        self.stride1 = stride1\n        self.stride2 = stride2\n        self.corr_multiply = corr_multiply\n\n    def forward(self, input1, input2):\n\n        result = CorrelationFunction(self.pad_size, self.kernel_size, self.max_displacement, self.stride1, self.stride2, self.corr_multiply)(input1, input2)\n\n        return result\n\n"
  },
  {
    "path": "models/correlation_package_cu9/correlation_cuda.cc",
    "content": "#include <torch/extension.h>\n#include <ATen/ATen.h>\n#include <ATen/Context.h>\n#include <ATen/cuda/CUDAContext.h>\n#include <stdio.h>\n#include <iostream>\n\n#include \"correlation_cuda_kernel.cuh\"\n\nint correlation_forward_cuda(at::Tensor& input1, at::Tensor& input2, at::Tensor& rInput1, at::Tensor& rInput2, at::Tensor& output,\n                       int pad_size,\n                       int kernel_size,\n                       int max_displacement,\n                       int stride1,\n                       int stride2,\n                       int corr_type_multiply)\n{\n\n  int batchSize = input1.size(0);\n\n  int nInputChannels = input1.size(1);\n  int inputHeight = input1.size(2);\n  int inputWidth = input1.size(3);\n\n  int kernel_radius = (kernel_size - 1) / 2;\n  int border_radius = kernel_radius + max_displacement;\n\n  int paddedInputHeight = inputHeight + 2 * pad_size;\n  int paddedInputWidth = inputWidth + 2 * pad_size;\n\n  int nOutputChannels = ((max_displacement/stride2)*2 + 1) * ((max_displacement/stride2)*2 + 1);\n\n  int outputHeight = ceil(static_cast<float>(paddedInputHeight - 2 * border_radius) / static_cast<float>(stride1));\n  int outputwidth = ceil(static_cast<float>(paddedInputWidth - 2 * border_radius) / static_cast<float>(stride1));\n\n  rInput1.resize_({batchSize, paddedInputHeight, paddedInputWidth, nInputChannels});\n  rInput2.resize_({batchSize, paddedInputHeight, paddedInputWidth, nInputChannels});\n  output.resize_({batchSize, nOutputChannels, outputHeight, outputwidth});\n\n  rInput1.fill_(0);\n  rInput2.fill_(0);\n  output.fill_(0);\n\n  int success = correlation_forward_cuda_kernel(\n    output,\n    output.size(0), \n    output.size(1),\n    output.size(2),\n    output.size(3),\n    output.stride(0),\n    output.stride(1),\n    output.stride(2),\n    output.stride(3),\n    input1,\n    input1.size(1),\n    input1.size(2),\n    input1.size(3),\n    input1.stride(0),\n    input1.stride(1),\n    
input1.stride(2),\n    input1.stride(3),\n    input2,\n    input2.size(1),\n    input2.stride(0),\n    input2.stride(1),\n    input2.stride(2),\n    input2.stride(3),\n    rInput1,\n    rInput2,\n    pad_size,     \n    kernel_size,\n    max_displacement,\n    stride1,\n    stride2,\n    corr_type_multiply,\n  at::cuda::getCurrentCUDAStream()\n  //at::globalContext().getCurrentCUDAStream()\n  );\n\n  //check for errors\n  if (!success) {\n    AT_ERROR(\"CUDA call failed\");\n  }\n\n  return 1;\n\n}\n\nint correlation_backward_cuda(at::Tensor& input1, at::Tensor& input2, at::Tensor& rInput1, at::Tensor& rInput2, at::Tensor& gradOutput, \n                       at::Tensor& gradInput1, at::Tensor& gradInput2,\n                       int pad_size,\n                       int kernel_size,\n                       int max_displacement,\n                       int stride1,\n                       int stride2,\n                       int corr_type_multiply)\n{\n\n  int batchSize = input1.size(0);\n  int nInputChannels = input1.size(1);\n  int paddedInputHeight = input1.size(2)+ 2 * pad_size;\n  int paddedInputWidth = input1.size(3)+ 2 * pad_size;\n\n  int height = input1.size(2);\n  int width = input1.size(3);\n\n  rInput1.resize_({batchSize, paddedInputHeight, paddedInputWidth, nInputChannels});\n  rInput2.resize_({batchSize, paddedInputHeight, paddedInputWidth, nInputChannels});\n  gradInput1.resize_({batchSize, nInputChannels, height, width});\n  gradInput2.resize_({batchSize, nInputChannels, height, width});\n\n  rInput1.fill_(0);\n  rInput2.fill_(0);\n  gradInput1.fill_(0);\n  gradInput2.fill_(0);\n\n  int success = correlation_backward_cuda_kernel(gradOutput,\n                                                gradOutput.size(0),\n                                                gradOutput.size(1),\n                                                gradOutput.size(2),\n                                                gradOutput.size(3),\n                                      
          gradOutput.stride(0),\n                                                gradOutput.stride(1),\n                                                gradOutput.stride(2),\n                                                gradOutput.stride(3),\n                                                input1,\n                                                input1.size(1),\n                                                input1.size(2),\n                                                input1.size(3),\n                                                input1.stride(0),\n                                                input1.stride(1),\n                                                input1.stride(2),\n                                                input1.stride(3),\n                                                input2,  \n                                                input2.stride(0),\n                                                input2.stride(1),\n                                                input2.stride(2),\n                                                input2.stride(3),\n                                                gradInput1,\n                                                gradInput1.stride(0),\n                                                gradInput1.stride(1),\n                                                gradInput1.stride(2),\n                                                gradInput1.stride(3),\n                                                gradInput2,\n                                                gradInput2.size(1),\n                                                gradInput2.stride(0),\n                                                gradInput2.stride(1),\n                                                gradInput2.stride(2),\n                                                gradInput2.stride(3),\n                                                rInput1,\n                                                rInput2,\n                                
                pad_size,\n                                                kernel_size,\n                                                max_displacement,\n                                                stride1, \n                                                stride2,\n                                                corr_type_multiply,\n                        at::cuda::getCurrentCUDAStream()\n                                                //at::globalContext().getCurrentCUDAStream()\n                                               );\n\n  if (!success) {\n    AT_ERROR(\"CUDA call failed\");\n  }\n\n  return 1;\n}\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n  m.def(\"forward\", &correlation_forward_cuda, \"Correlation forward (CUDA)\");\n  m.def(\"backward\", &correlation_backward_cuda, \"Correlation backward (CUDA)\");\n}\n\n"
  },
  {
    "path": "models/correlation_package_cu9/correlation_cuda_kernel.cu",
    "content": "#include <stdio.h>\n\n#include \"correlation_cuda_kernel.cuh\"\n\n#define CUDA_NUM_THREADS 1024\n#define THREADS_PER_BLOCK 32\n\n#include <ATen/ATen.h>\n#include <ATen/NativeFunctions.h>\n#include <ATen/Dispatch.h>\n#include <ATen/cuda/CUDAApplyUtils.cuh>\n\nusing at::Half;\n\ntemplate <typename scalar_t>\n__global__ void channels_first(const scalar_t* __restrict__ input, scalar_t* rinput, int channels, int height, int width, int pad_size)\n{\n\n\t// n (batch size), c (num of channels), y (height), x (width)\n\tint n = blockIdx.x;\n\tint y = blockIdx.y;\n\tint x = blockIdx.z;\n\n\tint ch_off = threadIdx.x;\n\tscalar_t value;\n\n\tint dimcyx = channels * height * width;\n\tint dimyx = height * width;\n\n\tint p_dimx = (width + 2 * pad_size);\n\tint p_dimy = (height + 2 * pad_size);\n\tint p_dimyxc = channels * p_dimy * p_dimx;\n\tint p_dimxc = p_dimx * channels;\n\n\tfor (int c = ch_off; c < channels; c += THREADS_PER_BLOCK) {\n\t\tvalue = input[n * dimcyx + c * dimyx + y * width + x];\n\t\trinput[n * p_dimyxc + (y + pad_size) * p_dimxc + (x + pad_size) * channels + c] = value;\n\t}\n}\n\ntemplate <typename scalar_t>\n__global__ void correlation_forward(scalar_t*  output, int nOutputChannels, int outputHeight, int outputWidth,\n\tconst scalar_t* __restrict__ rInput1, int nInputChannels, int inputHeight, int inputWidth,\n\tconst scalar_t* __restrict__ rInput2,\n\tint pad_size,\n\tint kernel_size,\n\tint max_displacement,\n\tint stride1,\n\tint stride2)\n{\n\t// n (batch size), c (num of channels), y (height), x (width)\n\n\tint pInputWidth = inputWidth + 2 * pad_size;\n\tint pInputHeight = inputHeight + 2 * pad_size;\n\n\tint kernel_rad = (kernel_size - 1) / 2;\n\tint displacement_rad = max_displacement / stride2;\n\tint displacement_size = 2 * displacement_rad + 1;\n\n\tint n = blockIdx.x;\n\tint y1 = blockIdx.y * stride1 + max_displacement;\n\tint x1 = blockIdx.z * stride1 + max_displacement;\n\tint c = threadIdx.x;\n\n\tint pdimyxc = pInputHeight * 
pInputWidth * nInputChannels;\n\tint pdimxc = pInputWidth * nInputChannels;\n\tint pdimc = nInputChannels;\n\n\tint tdimcyx = nOutputChannels * outputHeight * outputWidth;\n\tint tdimyx = outputHeight * outputWidth;\n\tint tdimx = outputWidth;\n\n\tscalar_t nelems = kernel_size * kernel_size * pdimc;\n\n\t__shared__ scalar_t prod_sum[THREADS_PER_BLOCK];\n\n\t// no significant speed-up in using chip memory for input1 sub-data, \n\t// not enough chip memory size to accomodate memory per block for input2 sub-data\n\t// instead i've used device memory for both \n\n\t// element-wise product along channel axis\n\tfor (int tj = -displacement_rad; tj <= displacement_rad; ++tj) {\n\t\tfor (int ti = -displacement_rad; ti <= displacement_rad; ++ti) {\n\t\t\tprod_sum[c] = 0;\n\t\t\tint x2 = x1 + ti*stride2;\n\t\t\tint y2 = y1 + tj*stride2;\n\n\t\t\tfor (int j = -kernel_rad; j <= kernel_rad; ++j) {\n\t\t\t\tfor (int i = -kernel_rad; i <= kernel_rad; ++i) {\n\t\t\t\t\tfor (int ch = c; ch < pdimc; ch += THREADS_PER_BLOCK) {\n\t\t\t\t\t\tint indx1 = n * pdimyxc + (y1 + j) * pdimxc + (x1 + i) * pdimc + ch;\n\t\t\t\t\t\tint indx2 = n * pdimyxc + (y2 + j) * pdimxc + (x2 + i) * pdimc + ch;\n\n\t\t\t\t\t\tprod_sum[c] += rInput1[indx1] * rInput2[indx2];\n\t\t\t\t\t}\n\t\t\t\t}\n\t\t\t}\n\n\t\t\t// accumulate \n\t\t\t__syncthreads();\n\t\t\tif (c == 0) {\n\t\t\t\tscalar_t reduce_sum = 0;\n\t\t\t\tfor (int index = 0; index < THREADS_PER_BLOCK; ++index) {\n\t\t\t\t\treduce_sum += prod_sum[index];\n\t\t\t\t}\n\t\t\t\tint tc = (tj + displacement_rad) * displacement_size + (ti + displacement_rad);\n\t\t\t\tconst int tindx = n * tdimcyx + tc * tdimyx + blockIdx.y * tdimx + blockIdx.z;\n\t\t\t\toutput[tindx] = reduce_sum / nelems;\n\t\t\t}\n\n\t\t}\n\t}\n\n}\n\ntemplate <typename scalar_t>\n__global__ void correlation_backward_input1(int item, scalar_t* gradInput1, int nInputChannels, int inputHeight, int inputWidth,\n\tconst scalar_t* __restrict__ gradOutput, int nOutputChannels, int 
outputHeight, int outputWidth,\n\tconst scalar_t* __restrict__ rInput2,\n\tint pad_size,\n\tint kernel_size,\n\tint max_displacement,\n\tint stride1,\n\tint stride2)\n{\n\t// n (batch size), c (num of channels), y (height), x (width)\n\n\tint n = item;\n\tint y = blockIdx.x * stride1 + pad_size;\n\tint x = blockIdx.y * stride1 + pad_size;\n\tint c = blockIdx.z;\n\tint tch_off = threadIdx.x;\n\n\tint kernel_rad = (kernel_size - 1) / 2;\n\tint displacement_rad = max_displacement / stride2;\n\tint displacement_size = 2 * displacement_rad + 1;\n\n\tint xmin = (x - kernel_rad - max_displacement) / stride1;\n\tint ymin = (y - kernel_rad - max_displacement) / stride1;\n\n\tint xmax = (x + kernel_rad - max_displacement) / stride1;\n\tint ymax = (y + kernel_rad - max_displacement) / stride1;\n\n\tif (xmax < 0 || ymax < 0 || xmin >= outputWidth || ymin >= outputHeight) {\n\t\t// assumes gradInput1 is pre-allocated and zero filled\n\t\treturn;\n\t}\n\n\tif (xmin > xmax || ymin > ymax) {\n\t\t// assumes gradInput1 is pre-allocated and zero filled\n\t\treturn;\n\t}\n\n\txmin = max(0, xmin);\n\txmax = min(outputWidth - 1, xmax);\n\n\tymin = max(0, ymin);\n\tymax = min(outputHeight - 1, ymax);\n\n\tint pInputWidth = inputWidth + 2 * pad_size;\n\tint pInputHeight = inputHeight + 2 * pad_size;\n\n\tint pdimyxc = pInputHeight * pInputWidth * nInputChannels;\n\tint pdimxc = pInputWidth * nInputChannels;\n\tint pdimc = nInputChannels;\n\n\tint tdimcyx = nOutputChannels * outputHeight * outputWidth;\n\tint tdimyx = outputHeight * outputWidth;\n\tint tdimx = outputWidth;\n\n\tint odimcyx = nInputChannels * inputHeight* inputWidth;\n\tint odimyx = inputHeight * inputWidth;\n\tint odimx = inputWidth;\n\n\tscalar_t nelems = kernel_size * kernel_size * nInputChannels;\n\n\t__shared__ scalar_t prod_sum[THREADS_PER_BLOCK];\n\tprod_sum[tch_off] = 0;\n\n\tfor (int tc = tch_off; tc < nOutputChannels; tc += THREADS_PER_BLOCK) {\n\n\t\tint i2 = (tc % displacement_size - displacement_rad) * 
stride2;\n\t\tint j2 = (tc / displacement_size - displacement_rad) * stride2;\n\n\t\tint indx2 = n * pdimyxc + (y + j2)* pdimxc + (x + i2) * pdimc + c;\n\n\t\tscalar_t val2 = rInput2[indx2];\n\n\t\tfor (int j = ymin; j <= ymax; ++j) {\n\t\t\tfor (int i = xmin; i <= xmax; ++i) {\n\t\t\t\tint tindx = n * tdimcyx + tc * tdimyx + j * tdimx + i;\n\t\t\t\tprod_sum[tch_off] += gradOutput[tindx] * val2;\n\t\t\t}\n\t\t}\n\t}\n\t__syncthreads();\n\n\tif (tch_off == 0) {\n\t\tscalar_t reduce_sum = 0;\n\t\tfor (int idx = 0; idx < THREADS_PER_BLOCK; idx++) {\n\t\t\treduce_sum += prod_sum[idx];\n\t\t}\n\t\tconst int indx1 = n * odimcyx + c * odimyx + (y - pad_size) * odimx + (x - pad_size);\n\t\tgradInput1[indx1] = reduce_sum / nelems;\n\t}\n\n}\n\ntemplate <typename scalar_t>\n__global__ void correlation_backward_input2(int item, scalar_t*  gradInput2, int nInputChannels, int inputHeight, int inputWidth,\n\tconst scalar_t* __restrict__ gradOutput, int nOutputChannels, int outputHeight, int outputWidth,\n\tconst scalar_t* __restrict__ rInput1,\n\tint pad_size,\n\tint kernel_size,\n\tint max_displacement,\n\tint stride1,\n\tint stride2)\n{\n\t// n (batch size), c (num of channels), y (height), x (width)\n\n\tint n = item;\n\tint y = blockIdx.x * stride1 + pad_size;\n\tint x = blockIdx.y * stride1 + pad_size;\n\tint c = blockIdx.z;\n\n\tint tch_off = threadIdx.x;\n\n\tint kernel_rad = (kernel_size - 1) / 2;\n\tint displacement_rad = max_displacement / stride2;\n\tint displacement_size = 2 * displacement_rad + 1;\n\n\tint pInputWidth = inputWidth + 2 * pad_size;\n\tint pInputHeight = inputHeight + 2 * pad_size;\n\n\tint pdimyxc = pInputHeight * pInputWidth * nInputChannels;\n\tint pdimxc = pInputWidth * nInputChannels;\n\tint pdimc = nInputChannels;\n\n\tint tdimcyx = nOutputChannels * outputHeight * outputWidth;\n\tint tdimyx = outputHeight * outputWidth;\n\tint tdimx = outputWidth;\n\n\tint odimcyx = nInputChannels * inputHeight* inputWidth;\n\tint odimyx = inputHeight * 
inputWidth;\n\tint odimx = inputWidth;\n\n\tscalar_t nelems = kernel_size * kernel_size * nInputChannels;\n\n\t__shared__ scalar_t prod_sum[THREADS_PER_BLOCK];\n\tprod_sum[tch_off] = 0;\n\n\tfor (int tc = tch_off; tc < nOutputChannels; tc += THREADS_PER_BLOCK) {\n\t\tint i2 = (tc % displacement_size - displacement_rad) * stride2;\n\t\tint j2 = (tc / displacement_size - displacement_rad) * stride2;\n\n\t\tint xmin = (x - kernel_rad - max_displacement - i2) / stride1;\n\t\tint ymin = (y - kernel_rad - max_displacement - j2) / stride1;\n\n\t\tint xmax = (x + kernel_rad - max_displacement - i2) / stride1;\n\t\tint ymax = (y + kernel_rad - max_displacement - j2) / stride1;\n\n\t\tif (xmax < 0 || ymax < 0 || xmin >= outputWidth || ymin >= outputHeight) {\n\t\t\t// assumes gradInput2 is pre-allocated and zero filled\n\t\t\tcontinue;\n\t\t}\n\n\t\tif (xmin > xmax || ymin > ymax) {\n\t\t\t// assumes gradInput2 is pre-allocated and zero filled\n\t\t\tcontinue;\n\t\t}\n\n\t\txmin = max(0, xmin);\n\t\txmax = min(outputWidth - 1, xmax);\n\n\t\tymin = max(0, ymin);\n\t\tymax = min(outputHeight - 1, ymax);\n\n\t\tint indx1 = n * pdimyxc + (y - j2)* pdimxc + (x - i2) * pdimc + c;\n\t\tscalar_t val1 = rInput1[indx1];\n\n\t\tfor (int j = ymin; j <= ymax; ++j) {\n\t\t\tfor (int i = xmin; i <= xmax; ++i) {\n\t\t\t\tint tindx = n * tdimcyx + tc * tdimyx + j * tdimx + i;\n\t\t\t\tprod_sum[tch_off] += gradOutput[tindx] * val1;\n\t\t\t}\n\t\t}\n\t}\n\n\t__syncthreads();\n\n\tif (tch_off == 0) {\n\t\tscalar_t reduce_sum = 0;\n\t\tfor (int idx = 0; idx < THREADS_PER_BLOCK; idx++) {\n\t\t\treduce_sum += prod_sum[idx];\n\t\t}\n\t\tconst int indx2 = n * odimcyx + c * odimyx + (y - pad_size) * odimx + (x - pad_size);\n\t\tgradInput2[indx2] = reduce_sum / nelems;\n\t}\n\n}\n\nint correlation_forward_cuda_kernel(at::Tensor& output,\n\tint ob,\n\tint oc,\n\tint oh,\n\tint ow,\n\tint osb,\n\tint osc,\n\tint osh,\n\tint osw,\n\n\tat::Tensor& input1,\n\tint ic,\n\tint ih,\n\tint iw,\n\tint 
isb,\n\tint isc,\n\tint ish,\n\tint isw,\n\n\tat::Tensor& input2,\n\tint gc,\n\tint gsb,\n\tint gsc,\n\tint gsh,\n\tint gsw,\n\n\tat::Tensor& rInput1,\n\tat::Tensor& rInput2,\n\tint pad_size,\n\tint kernel_size,\n\tint max_displacement,\n\tint stride1,\n\tint stride2,\n\tint corr_type_multiply,\n\tcudaStream_t stream)\n{\n\n\tint batchSize = ob;\n\n\tint nInputChannels = ic;\n\tint inputWidth = iw;\n\tint inputHeight = ih;\n\n\tint nOutputChannels = oc;\n\tint outputWidth = ow;\n\tint outputHeight = oh;\n\n\tdim3 blocks_grid(batchSize, inputHeight, inputWidth);\n\tdim3 threads_block(THREADS_PER_BLOCK);\n\n\tAT_DISPATCH_FLOATING_TYPES_AND_HALF(input1.type(), \"channels_first_fwd_1\", ([&] {\n\n\t\tchannels_first<scalar_t> << <blocks_grid, threads_block, 0, stream >> >(\n\t\t\tinput1.data<scalar_t>(), rInput1.data<scalar_t>(), nInputChannels, inputHeight, inputWidth, pad_size);\n\n\t}));\n\n\tAT_DISPATCH_FLOATING_TYPES_AND_HALF(input2.type(), \"channels_first_fwd_2\", ([&] {\n\n\t\tchannels_first<scalar_t> << <blocks_grid, threads_block, 0, stream >> > (\n\t\t\tinput2.data<scalar_t>(), rInput2.data<scalar_t>(), nInputChannels, inputHeight, inputWidth, pad_size);\n\n\t}));\n\n\tdim3 threadsPerBlock(THREADS_PER_BLOCK);\n\tdim3 totalBlocksCorr(batchSize, outputHeight, outputWidth);\n\n\tAT_DISPATCH_FLOATING_TYPES_AND_HALF(input1.type(), \"correlation_forward\", ([&] {\n\n\t\tcorrelation_forward<scalar_t> << <totalBlocksCorr, threadsPerBlock, 0, stream >> >\n\t\t\t(output.data<scalar_t>(), nOutputChannels, outputHeight, outputWidth,\n\t\t\trInput1.data<scalar_t>(), nInputChannels, inputHeight, inputWidth,\n\t\t\trInput2.data<scalar_t>(),\n\t\t\tpad_size,\n\t\t\tkernel_size,\n\t\t\tmax_displacement,\n\t\t\tstride1,\n\t\t\tstride2);\n\n\t}));\n\n\tcudaError_t err = cudaGetLastError();\n\n\n\t// check for errors\n\tif (err != cudaSuccess) {\n\t\tprintf(\"error in correlation_forward_cuda_kernel: %s\\n\", cudaGetErrorString(err));\n\t\treturn 0;\n\t}\n\n\treturn 
1;\n}\n\n\nint correlation_backward_cuda_kernel(\n\tat::Tensor& gradOutput,\n\tint gob,\n\tint goc,\n\tint goh,\n\tint gow,\n\tint gosb,\n\tint gosc,\n\tint gosh,\n\tint gosw,\n\n\tat::Tensor& input1,\n\tint ic,\n\tint ih,\n\tint iw,\n\tint isb,\n\tint isc,\n\tint ish,\n\tint isw,\n\n\tat::Tensor& input2,\n\tint gsb,\n\tint gsc,\n\tint gsh,\n\tint gsw,\n\n\tat::Tensor& gradInput1,\n\tint gisb,\n\tint gisc,\n\tint gish,\n\tint gisw,\n\n\tat::Tensor& gradInput2,\n\tint ggc,\n\tint ggsb,\n\tint ggsc,\n\tint ggsh,\n\tint ggsw,\n\n\tat::Tensor& rInput1,\n\tat::Tensor& rInput2,\n\tint pad_size,\n\tint kernel_size,\n\tint max_displacement,\n\tint stride1,\n\tint stride2,\n\tint corr_type_multiply,\n\tcudaStream_t stream)\n{\n\n\tint batchSize = gob;\n\tint num = batchSize;\n\n\tint nInputChannels = ic;\n\tint inputWidth = iw;\n\tint inputHeight = ih;\n\n\tint nOutputChannels = goc;\n\tint outputWidth = gow;\n\tint outputHeight = goh;\n\n\tdim3 blocks_grid(batchSize, inputHeight, inputWidth);\n\tdim3 threads_block(THREADS_PER_BLOCK);\n\n\n\tAT_DISPATCH_FLOATING_TYPES_AND_HALF(input1.type(), \"lltm_forward_cuda\", ([&] {\n\n\t\tchannels_first<scalar_t> << <blocks_grid, threads_block, 0, stream >> >(\n\t\t\tinput1.data<scalar_t>(),\n\t\t\trInput1.data<scalar_t>(),\n\t\t\tnInputChannels,\n\t\t\tinputHeight,\n\t\t\tinputWidth,\n\t\t\tpad_size\n\t\t\t);\n\t}));\n\n\tAT_DISPATCH_FLOATING_TYPES_AND_HALF(input2.type(), \"lltm_forward_cuda\", ([&] {\n\n\t\tchannels_first<scalar_t> << <blocks_grid, threads_block, 0, stream >> >(\n\t\t\tinput2.data<scalar_t>(),\n\t\t\trInput2.data<scalar_t>(),\n\t\t\tnInputChannels,\n\t\t\tinputHeight,\n\t\t\tinputWidth,\n\t\t\tpad_size\n\t\t\t);\n\t}));\n\n\tdim3 threadsPerBlock(THREADS_PER_BLOCK);\n\tdim3 totalBlocksCorr(inputHeight, inputWidth, nInputChannels);\n\n\tfor (int n = 0; n < num; ++n) {\n\n\t\tAT_DISPATCH_FLOATING_TYPES_AND_HALF(input2.type(), \"lltm_forward_cuda\", ([&] {\n\n\n\t\t\tcorrelation_backward_input1<scalar_t> << 
<totalBlocksCorr, threadsPerBlock, 0, stream >> > (\n\t\t\t\tn, gradInput1.data<scalar_t>(), nInputChannels, inputHeight, inputWidth,\n\t\t\t\tgradOutput.data<scalar_t>(), nOutputChannels, outputHeight, outputWidth,\n\t\t\t\trInput2.data<scalar_t>(),\n\t\t\t\tpad_size,\n\t\t\t\tkernel_size,\n\t\t\t\tmax_displacement,\n\t\t\t\tstride1,\n\t\t\t\tstride2);\n\t\t}));\n\t}\n\n\tfor (int n = 0; n < batchSize; n++) {\n\n\t\tAT_DISPATCH_FLOATING_TYPES_AND_HALF(rInput1.type(), \"lltm_forward_cuda\", ([&] {\n\n\t\t\tcorrelation_backward_input2<scalar_t> << <totalBlocksCorr, threadsPerBlock, 0, stream >> >(\n\t\t\t\tn, gradInput2.data<scalar_t>(), nInputChannels, inputHeight, inputWidth,\n\t\t\t\tgradOutput.data<scalar_t>(), nOutputChannels, outputHeight, outputWidth,\n\t\t\t\trInput1.data<scalar_t>(),\n\t\t\t\tpad_size,\n\t\t\t\tkernel_size,\n\t\t\t\tmax_displacement,\n\t\t\t\tstride1,\n\t\t\t\tstride2);\n\n\t\t}));\n\t}\n\n\t// check for errors\n\tcudaError_t err = cudaGetLastError();\n\tif (err != cudaSuccess) {\n\t\tprintf(\"error in correlation_backward_cuda_kernel: %s\\n\", cudaGetErrorString(err));\n\t\treturn 0;\n\t}\n\n\treturn 1;\n}\n"
  },
  {
    "path": "models/correlation_package_cu9/correlation_cuda_kernel.cuh",
    "content": "#pragma once\n\n#include <ATen/ATen.h>\n#include <ATen/Context.h>\n#include <cuda_runtime.h>\n\nint correlation_forward_cuda_kernel(at::Tensor& output,\n    int ob,\n    int oc,\n    int oh,\n    int ow,\n    int osb,\n    int osc,\n    int osh,\n    int osw,\n\n    at::Tensor& input1,\n    int ic,\n    int ih,\n    int iw,\n    int isb,\n    int isc,\n    int ish,\n    int isw,\n\n    at::Tensor& input2,\n    int gc,\n    int gsb,\n    int gsc,\n    int gsh,\n    int gsw,\n\n    at::Tensor& rInput1,\n    at::Tensor& rInput2,\n    int pad_size,\n    int kernel_size,\n    int max_displacement,\n    int stride1,\n    int stride2,\n    int corr_type_multiply,\n    cudaStream_t stream);\n\n\nint correlation_backward_cuda_kernel(   \n    at::Tensor& gradOutput,\n    int gob,\n    int goc,\n    int goh,\n    int gow,\n    int gosb,\n    int gosc,\n    int gosh,\n    int gosw,\n\n    at::Tensor& input1,\n    int ic,\n    int ih,\n    int iw,\n    int isb,\n    int isc,\n    int ish,\n    int isw,\n\n    at::Tensor& input2,\n    int gsb,\n    int gsc,\n    int gsh,\n    int gsw,\n\n    at::Tensor& gradInput1, \n    int gisb,\n    int gisc,\n    int gish,\n    int gisw,\n\n    at::Tensor& gradInput2,\n    int ggc,\n    int ggsb,\n    int ggsc,\n    int ggsh,\n    int ggsw,\n\n    at::Tensor& rInput1,\n    at::Tensor& rInput2,\n    int pad_size,\n    int kernel_size,\n    int max_displacement,\n    int stride1,\n    int stride2,\n    int corr_type_multiply,\n    cudaStream_t stream);\n"
  },
  {
    "path": "models/correlation_package_cu9/setup.py",
    "content": "#!/usr/bin/env python3\nimport os\nimport torch\n\nfrom setuptools import setup, find_packages\nfrom torch.utils.cpp_extension import BuildExtension, CUDAExtension\n\ncxx_args = ['-std=c++11']\n\nnvcc_args = [\n    '-gencode', 'arch=compute_50,code=sm_50',\n    '-gencode', 'arch=compute_52,code=sm_52',\n    '-gencode', 'arch=compute_60,code=sm_60',\n    '-gencode', 'arch=compute_61,code=sm_61',\n    '-gencode', 'arch=compute_61,code=compute_61',\n    '-ccbin', '/usr/bin/gcc-4.9'\n]\n\nsetup(\n    name='correlation_cuda',\n    ext_modules=[\n        CUDAExtension('correlation_cuda', [\n            'correlation_cuda.cc',\n            'correlation_cuda_kernel.cu'\n        ], extra_compile_args={'cxx': cxx_args, 'nvcc': nvcc_args, 'cuda-path': ['/usr/local/cuda-9.0']})\n    ],\n    cmdclass={\n        'build_ext': BuildExtension\n    })\n"
  },
  {
    "path": "models/flownet1s.py",
    "content": "from __future__ import absolute_import, division, print_function\n\nimport torch\nimport torch.nn as nn\nfrom .flownet_modules import conv, deconv\nfrom .flownet_modules import concatenate_as, upsample2d_as\nfrom .flownet_modules import initialize_msra\n\n\nclass FlowNetS(nn.Module):\n    def __init__(self, args):\n        super(FlowNetS, self).__init__()\n\n        def make_conv(in_planes, out_planes, kernel_size, stride):\n            pad = kernel_size // 2\n            return conv(in_planes, out_planes, kernel_size=kernel_size,\n                        stride=stride, pad=pad, nonlinear=True, bias=True)\n\n        self._conv1   = make_conv(   6,   64, kernel_size=7, stride=2)\n        self._conv2   = make_conv(  64,  128, kernel_size=5, stride=2)\n        self._conv3   = make_conv( 128,  256, kernel_size=5, stride=2)\n        self._conv3_1 = make_conv( 256,  256, kernel_size=3, stride=1)\n        self._conv4   = make_conv( 256,  512, kernel_size=3, stride=2)\n        self._conv4_1 = make_conv( 512,  512, kernel_size=3, stride=1)\n        self._conv5   = make_conv( 512,  512, kernel_size=3, stride=2)\n        self._conv5_1 = make_conv( 512,  512, kernel_size=3, stride=1)\n        self._conv6   = make_conv( 512, 1024, kernel_size=3, stride=2)\n        self._conv6_1 = make_conv(1024, 1024, kernel_size=3, stride=1)\n\n        def make_deconv(in_planes, out_planes):\n            return deconv(in_planes, out_planes, kernel_size=4, stride=2, pad=1,\n                          nonlinear=True, bias=False)\n\n        self._deconv5 = make_deconv(1024    , 512)\n        self._deconv4 = make_deconv(1024 + 2, 256)\n        self._deconv3 = make_deconv( 768 + 2, 128)\n        self._deconv2 = make_deconv( 384 + 2,  64)\n\n        def make_predict(in_planes, out_planes):\n            return conv(in_planes, out_planes, kernel_size=3, stride=1, pad=1,\n                        nonlinear=False, bias=True)\n\n        self._predict_flow6 = make_predict(1024    , 2)\n      
  self._predict_flow5 = make_predict(1024 + 2, 2)\n        self._predict_flow4 = make_predict( 768 + 2, 2)\n        self._predict_flow3 = make_predict( 384 + 2, 2)\n        self._predict_flow2 = make_predict( 192 + 2, 2)\n\n        def make_upsample(in_planes, out_planes):\n            return deconv(in_planes, out_planes, kernel_size=4, stride=2, pad=1,\n                          nonlinear=False, bias=False)\n\n        self._upsample_flow6_to_5 = make_upsample(2, 2)\n        self._upsample_flow5_to_4 = make_upsample(2, 2)\n        self._upsample_flow4_to_3 = make_upsample(2, 2)\n        self._upsample_flow3_to_2 = make_upsample(2, 2)\n\n        initialize_msra(self.modules())\n\n    def forward(self, inputs):\n        conv1 = self._conv1(inputs)\n        conv2 = self._conv2(conv1)\n        conv3_1 = self._conv3_1(self._conv3(conv2))\n        conv4_1 = self._conv4_1(self._conv4(conv3_1))\n        conv5_1 = self._conv5_1(self._conv5(conv4_1))\n        conv6_1 = self._conv6_1(self._conv6(conv5_1))\n\n        predict_flow6        = self._predict_flow6(conv6_1)\n\n        upsampled_flow6_to_5 = self._upsample_flow6_to_5(predict_flow6)\n        deconv5              = self._deconv5(conv6_1)\n        concat5              = concatenate_as((conv5_1, deconv5, upsampled_flow6_to_5), conv5_1, dim=1)\n        predict_flow5        = self._predict_flow5(concat5)\n\n        upsampled_flow5_to_4 = self._upsample_flow5_to_4(predict_flow5)\n        deconv4              = self._deconv4(concat5)\n        concat4              = concatenate_as((conv4_1, deconv4, upsampled_flow5_to_4), conv4_1, dim=1)\n        predict_flow4        = self._predict_flow4(concat4)\n\n        upsampled_flow4_to_3 = self._upsample_flow4_to_3(predict_flow4)\n        deconv3              = self._deconv3(concat4)\n        concat3              = concatenate_as((conv3_1, deconv3, upsampled_flow4_to_3), conv3_1, dim=1)\n        predict_flow3        = self._predict_flow3(concat3)\n\n        upsampled_flow3_to_2 = 
self._upsample_flow3_to_2(predict_flow3)\n        deconv2              = self._deconv2(concat3)\n        concat2              = concatenate_as((conv2, deconv2, upsampled_flow3_to_2), conv2, dim=1)\n        predict_flow2        = self._predict_flow2(concat2)\n\n        if self.training:\n            return predict_flow2, predict_flow3, predict_flow4, predict_flow5, predict_flow6\n        else:\n            return predict_flow2\n\n\nclass FlowNet1S(nn.Module):\n    def __init__(self, args, div_flow=0.05):\n        super(FlowNet1S, self).__init__()\n        self._flownets = FlowNetS(args)\n        self._div_flow = div_flow\n\n    def forward(self, input_dict):\n        im1 = input_dict['input1']\n        im2 = input_dict['input2']\n        inputs = torch.cat((im1, im2), dim=1)\n\n        output_dict = {}\n        if self.training:\n            flow2, flow3, flow4, flow5, flow6 = self._flownets(inputs)\n            output_dict['flow2'] = flow2\n            output_dict['flow3'] = flow3\n            output_dict['flow4'] = flow4\n            output_dict['flow5'] = flow5\n            output_dict['flow6'] = flow6\n        else:\n            flow2 = self._flownets(inputs)\n            output_dict['flow1'] = (1.0 / self._div_flow) * upsample2d_as(flow2, im1, mode=\"bilinear\")\n\n        return output_dict\n"
  },
  {
    "path": "models/flownet1s_irr.py",
    "content": "from __future__ import absolute_import, division, print_function\r\n\r\nimport torch\r\nimport torch.nn as nn\r\nfrom .flownet_modules import conv, deconv\r\nfrom .flownet_modules import concatenate_as, upsample2d_as\r\nfrom .flownet_modules import initialize_msra\r\nfrom .flownet_modules import WarpingLayer\r\n\r\nclass FlowNetS(nn.Module):\r\n    def __init__(self, args):\r\n        super(FlowNetS, self).__init__()\r\n\r\n        def make_conv(in_planes, out_planes, kernel_size, stride):\r\n            pad = kernel_size // 2\r\n            return conv(in_planes, out_planes, kernel_size=kernel_size,\r\n                        stride=stride, pad=pad, nonlinear=True, bias=True)\r\n\r\n        self._conv3_1 = make_conv( 256,  256, kernel_size=3, stride=1)\r\n        self._conv4   = make_conv( 256,  512, kernel_size=3, stride=2)\r\n        self._conv4_1 = make_conv( 512,  512, kernel_size=3, stride=1)\r\n        self._conv5   = make_conv( 512,  512, kernel_size=3, stride=2)\r\n        self._conv5_1 = make_conv( 512,  512, kernel_size=3, stride=1)\r\n        self._conv6   = make_conv( 512, 1024, kernel_size=3, stride=2)\r\n        self._conv6_1 = make_conv(1024, 1024, kernel_size=3, stride=1)\r\n\r\n        def make_deconv(in_planes, out_planes):\r\n            return deconv(in_planes, out_planes, kernel_size=4, stride=2, pad=1,\r\n                          nonlinear=True, bias=False)\r\n\r\n        self._deconv5 = make_deconv(1024    , 512)\r\n        self._deconv4 = make_deconv(1024 + 2, 256)\r\n        self._deconv3 = make_deconv( 768 + 2, 128)\r\n        self._deconv2 = make_deconv( 384 + 2,  64)\r\n\r\n        def make_predict(in_planes, out_planes):\r\n            return conv(in_planes, out_planes, kernel_size=3, stride=1, pad=1,\r\n                        nonlinear=False, bias=True)\r\n\r\n        self._predict_flow6 = make_predict(1024    , 2)\r\n        self._predict_flow5 = make_predict(1024 + 2, 2)\r\n        self._predict_flow4 = 
make_predict( 768 + 2, 2)\r\n        self._predict_flow3 = make_predict( 384 + 2, 2)\r\n        self._predict_flow2 = make_predict( 128 + 2, 2)\r\n\r\n        def make_upsample(in_planes, out_planes):\r\n            return deconv(in_planes, out_planes, kernel_size=4, stride=2, pad=1,\r\n                          nonlinear=False, bias=False)\r\n\r\n        self._upsample_flow6_to_5 = make_upsample(2, 2)\r\n        self._upsample_flow5_to_4 = make_upsample(2, 2)\r\n        self._upsample_flow4_to_3 = make_upsample(2, 2)\r\n        self._upsample_flow3_to_2 = make_upsample(2, 2)\r\n\r\n    def forward(self, conv2_im1, conv3_im1, conv3_im2):\r\n\r\n        conv_concat3 = torch.cat((conv3_im1, conv3_im2), dim=1)\r\n\r\n        conv3_1 = self._conv3_1(conv_concat3)\r\n        conv4_1 = self._conv4_1(self._conv4(conv3_1))\r\n        conv5_1 = self._conv5_1(self._conv5(conv4_1))\r\n        conv6_1 = self._conv6_1(self._conv6(conv5_1))\r\n\r\n        predict_flow6        = self._predict_flow6(conv6_1)\r\n\r\n        upsampled_flow6_to_5 = self._upsample_flow6_to_5(predict_flow6)\r\n        deconv5              = self._deconv5(conv6_1)\r\n        concat5              = concatenate_as((conv5_1, deconv5, upsampled_flow6_to_5), conv5_1, dim=1)\r\n        predict_flow5        = self._predict_flow5(concat5)\r\n\r\n        upsampled_flow5_to_4 = self._upsample_flow5_to_4(predict_flow5)\r\n        deconv4              = self._deconv4(concat5)\r\n        concat4              = concatenate_as((conv4_1, deconv4, upsampled_flow5_to_4), conv4_1, dim=1)\r\n        predict_flow4        = self._predict_flow4(concat4)\r\n\r\n        upsampled_flow4_to_3 = self._upsample_flow4_to_3(predict_flow4)\r\n        deconv3              = self._deconv3(concat4)\r\n        concat3              = concatenate_as((conv3_1, deconv3, upsampled_flow4_to_3), conv3_1, dim=1)\r\n        predict_flow3        = self._predict_flow3(concat3)\r\n\r\n        upsampled_flow3_to_2 = 
self._upsample_flow3_to_2(predict_flow3)\r\n        deconv2              = self._deconv2(concat3)\r\n        concat2              = concatenate_as((conv2_im1, deconv2, upsampled_flow3_to_2), conv2_im1, dim=1)\r\n        predict_flow2        = self._predict_flow2(concat2)\r\n\r\n        return predict_flow2, predict_flow3, predict_flow4, predict_flow5, predict_flow6\r\n\r\n\r\nclass FlowNet1S(nn.Module):\r\n    def __init__(self, args, div_flow=0.05):\r\n        super(FlowNet1S, self).__init__()\r\n        self._flownets = FlowNetS(args)\r\n        self._warping_layer = WarpingLayer()\r\n        self._div_flow = div_flow\r\n        self._num_iters = args.num_iters\r\n\r\n        def make_conv(in_planes, out_planes, kernel_size, stride):\r\n            pad = kernel_size // 2\r\n            return conv(in_planes, out_planes, kernel_size=kernel_size,\r\n                        stride=stride, pad=pad, nonlinear=True, bias=True)\r\n\r\n        self._conv1   = make_conv(   3,   32, kernel_size=7, stride=2)\r\n        self._conv2   = make_conv(  32,   64, kernel_size=5, stride=2)\r\n        self._conv3   = make_conv(  64,  128, kernel_size=5, stride=2)\r\n\r\n        initialize_msra(self.modules())\r\n\r\n    def forward(self, input_dict):\r\n\r\n        im1 = input_dict['input1']\r\n        im2 = input_dict['input2']\r\n\r\n        conv1_im1 = self._conv1(im1)\r\n        conv2_im1 = self._conv2(conv1_im1)\r\n        conv3_im1 = self._conv3(conv2_im1)\r\n\r\n        conv1_im2 = self._conv1(im2)\r\n        conv2_im2 = self._conv2(conv1_im2)\r\n        conv3_im2_orig = self._conv3(conv2_im2)\r\n        conv3_im2 = conv3_im2_orig\r\n\r\n        output_dict = {}        \r\n        output_dict['flow2'] = []\r\n        output_dict['flow3'] = []\r\n        output_dict['flow4'] = []\r\n        output_dict['flow5'] = []\r\n        output_dict['flow6'] = []\r\n        \r\n        _, _, height_im, width_im = im1.size()\r\n\r\n        # for iterative\r\n        for ii in range(0, 
self._num_iters):\r\n            flow2, flow3, flow4, flow5, flow6 = self._flownets(conv2_im1, conv3_im1, conv3_im2)\r\n\r\n            if ii == 0:\r\n                output_dict['flow2'].append(flow2)\r\n                output_dict['flow3'].append(flow3)\r\n                output_dict['flow4'].append(flow4)\r\n                output_dict['flow5'].append(flow5)\r\n                output_dict['flow6'].append(flow6)\r\n            else:\r\n                output_dict['flow2'].append(flow2 + output_dict['flow2'][ii - 1])\r\n                output_dict['flow3'].append(flow3 + output_dict['flow3'][ii - 1])\r\n                output_dict['flow4'].append(flow4 + output_dict['flow4'][ii - 1])\r\n                output_dict['flow5'].append(flow5 + output_dict['flow5'][ii - 1])\r\n                output_dict['flow6'].append(flow6 + output_dict['flow6'][ii - 1])\r\n\r\n            if ii < (self._num_iters - 1):\r\n                up_flow = upsample2d_as(output_dict['flow2'][ii], conv3_im2_orig, mode=\"bilinear\")\r\n                conv3_im2 = self._warping_layer(conv3_im2_orig, up_flow, height_im, width_im, self._div_flow)\r\n\r\n        if self.training:\r\n            return output_dict\r\n        else:\r\n            output_dict_eval = {}\r\n            up_flow_final = upsample2d_as(output_dict['flow2'][self._num_iters - 1], im1, mode=\"bilinear\")\r\n            output_dict_eval['flow1'] = (1.0 / self._div_flow) * up_flow_final\r\n            return output_dict_eval"
  },
  {
    "path": "models/flownet1s_irr_bi.py",
    "content": "from __future__ import absolute_import, division, print_function\r\n\r\nimport torch\r\nimport torch.nn as nn\r\nfrom .flownet_modules import conv, deconv\r\nfrom .flownet_modules import concatenate_as, upsample2d_as\r\nfrom .flownet_modules import initialize_msra\r\nfrom .flownet_modules import WarpingLayer\r\n\r\nclass FlowNetS(nn.Module):\r\n    def __init__(self, args):\r\n        super(FlowNetS, self).__init__()\r\n\r\n        def make_conv(in_planes, out_planes, kernel_size, stride):\r\n            pad = kernel_size // 2\r\n            return conv(in_planes, out_planes, kernel_size=kernel_size,\r\n                        stride=stride, pad=pad, nonlinear=True, bias=True)\r\n\r\n        self._conv3_1 = make_conv( 256,  256, kernel_size=3, stride=1)\r\n        self._conv4   = make_conv( 256,  512, kernel_size=3, stride=2)\r\n        self._conv4_1 = make_conv( 512,  512, kernel_size=3, stride=1)\r\n        self._conv5   = make_conv( 512,  512, kernel_size=3, stride=2)\r\n        self._conv5_1 = make_conv( 512,  512, kernel_size=3, stride=1)\r\n        self._conv6   = make_conv( 512, 1024, kernel_size=3, stride=2)\r\n        self._conv6_1 = make_conv(1024, 1024, kernel_size=3, stride=1)\r\n\r\n        def make_deconv(in_planes, out_planes):\r\n            return deconv(in_planes, out_planes, kernel_size=4, stride=2, pad=1,\r\n                          nonlinear=True, bias=False)\r\n\r\n        self._deconv5 = make_deconv(1024    , 512)\r\n        self._deconv4 = make_deconv(1024 + 2, 256)\r\n        self._deconv3 = make_deconv( 768 + 2, 128)\r\n        self._deconv2 = make_deconv( 384 + 2,  64)\r\n\r\n        def make_predict(in_planes, out_planes):\r\n            return conv(in_planes, out_planes, kernel_size=3, stride=1, pad=1,\r\n                        nonlinear=False, bias=True)\r\n\r\n        self._predict_flow6 = make_predict(1024    , 2)\r\n        self._predict_flow5 = make_predict(1024 + 2, 2)\r\n        self._predict_flow4 = 
make_predict( 768 + 2, 2)\r\n        self._predict_flow3 = make_predict( 384 + 2, 2)\r\n        self._predict_flow2 = make_predict( 128 + 2, 2)\r\n\r\n        def make_upsample(in_planes, out_planes):\r\n            return deconv(in_planes, out_planes, kernel_size=4, stride=2, pad=1,\r\n                          nonlinear=False, bias=False)\r\n\r\n        self._upsample_flow6_to_5 = make_upsample(2, 2)\r\n        self._upsample_flow5_to_4 = make_upsample(2, 2)\r\n        self._upsample_flow4_to_3 = make_upsample(2, 2)\r\n        self._upsample_flow3_to_2 = make_upsample(2, 2)\r\n\r\n    def forward(self, conv2_im1, conv3_im1, conv3_im2):\r\n\r\n        conv_concat3 = torch.cat((conv3_im1, conv3_im2), dim=1)\r\n\r\n        conv3_1 = self._conv3_1(conv_concat3)\r\n        conv4_1 = self._conv4_1(self._conv4(conv3_1))\r\n        conv5_1 = self._conv5_1(self._conv5(conv4_1))\r\n        conv6_1 = self._conv6_1(self._conv6(conv5_1))\r\n\r\n        # Flow Decoder\r\n        predict_flow6        = self._predict_flow6(conv6_1)\r\n\r\n        upsampled_flow6_to_5 = self._upsample_flow6_to_5(predict_flow6)\r\n        deconv5              = self._deconv5(conv6_1)\r\n        concat5              = concatenate_as((conv5_1, deconv5, upsampled_flow6_to_5), conv5_1, dim=1)\r\n        predict_flow5        = self._predict_flow5(concat5)\r\n\r\n        upsampled_flow5_to_4 = self._upsample_flow5_to_4(predict_flow5)\r\n        deconv4              = self._deconv4(concat5)\r\n        concat4              = concatenate_as((conv4_1, deconv4, upsampled_flow5_to_4), conv4_1, dim=1)\r\n        predict_flow4        = self._predict_flow4(concat4)\r\n\r\n        upsampled_flow4_to_3 = self._upsample_flow4_to_3(predict_flow4)\r\n        deconv3              = self._deconv3(concat4)\r\n        concat3              = concatenate_as((conv3_1, deconv3, upsampled_flow4_to_3), conv3_1, dim=1)\r\n        predict_flow3        = self._predict_flow3(concat3)\r\n\r\n        upsampled_flow3_to_2 = 
self._upsample_flow3_to_2(predict_flow3)\r\n        deconv2              = self._deconv2(concat3)\r\n        concat2              = concatenate_as((conv2_im1, deconv2, upsampled_flow3_to_2), conv2_im1, dim=1)\r\n        predict_flow2        = self._predict_flow2(concat2)\r\n\r\n        return predict_flow2, predict_flow3, predict_flow4, predict_flow5, predict_flow6\r\n\r\n\r\nclass FlowNet1S(nn.Module):\r\n    def __init__(self, args, div_flow=0.05):\r\n        super(FlowNet1S, self).__init__()\r\n        self._flownets = FlowNetS(args)\r\n        self._warping_layer = WarpingLayer()\r\n        self._div_flow = div_flow\r\n        self._num_iters = args.num_iters     \r\n\r\n        def make_conv(in_planes, out_planes, kernel_size, stride):\r\n            pad = kernel_size // 2\r\n            return conv(in_planes, out_planes, kernel_size=kernel_size,\r\n                        stride=stride, pad=pad, nonlinear=True, bias=True)\r\n\r\n        self._conv1   = make_conv(   3,   32, kernel_size=7, stride=2)\r\n        self._conv2   = make_conv(  32,   64, kernel_size=5, stride=2)\r\n        self._conv3   = make_conv(  64,  128, kernel_size=5, stride=2)\r\n\r\n        initialize_msra(self.modules())\r\n\r\n    def forward(self, input_dict):\r\n        im1 = input_dict['input1']\r\n        im2 = input_dict['input2']\r\n\r\n        conv1_im1 = self._conv1(im1)\r\n        conv2_im1 = self._conv2(conv1_im1)\r\n        conv3_im1 = self._conv3(conv2_im1)\r\n        conv2_im1_wp = conv2_im1\r\n        conv3_im1_wp = conv3_im1\r\n\r\n        conv1_im2 = self._conv1(im2)\r\n        conv2_im2 = self._conv2(conv1_im2)\r\n        conv3_im2 = self._conv3(conv2_im2)\r\n        conv2_im2_wp = conv2_im2\r\n        conv3_im2_wp = conv3_im2\r\n\r\n        out_dict = {}\r\n        out_dict['flow2'] = []\r\n        out_dict['flow3'] = []\r\n        out_dict['flow4'] = []\r\n        out_dict['flow5'] = []\r\n        out_dict['flow6'] = []\r\n\r\n        _, _, height_im, width_im = 
im1.size()\r\n\r\n        # for iterative\r\n        for ii in range(0, self._num_iters):\r\n            flo2_f, flo3_f, flo4_f, flo5_f, flo6_f = self._flownets(conv2_im1, conv3_im1, conv3_im2_wp)\r\n            flo2_b, flo3_b, flo4_b, flo5_b, flo6_b = self._flownets(conv2_im2, conv3_im2, conv3_im1_wp)\r\n\r\n            if ii == 0:\r\n                out_dict['flow2'].append([flo2_f, flo2_b])\r\n                out_dict['flow3'].append([flo3_f, flo3_b])\r\n                out_dict['flow4'].append([flo4_f, flo4_b])\r\n                out_dict['flow5'].append([flo5_f, flo5_b])\r\n                out_dict['flow6'].append([flo6_f, flo6_b])\r\n            else:\r\n                out_dict['flow2'].append([flo2_f + out_dict['flow2'][ii - 1][0], flo2_b + out_dict['flow2'][ii - 1][1]])\r\n                out_dict['flow3'].append([flo3_f + out_dict['flow3'][ii - 1][0], flo3_b + out_dict['flow3'][ii - 1][1]])\r\n                out_dict['flow4'].append([flo4_f + out_dict['flow4'][ii - 1][0], flo4_b + out_dict['flow4'][ii - 1][1]])\r\n                out_dict['flow5'].append([flo5_f + out_dict['flow5'][ii - 1][0], flo5_b + out_dict['flow5'][ii - 1][1]])\r\n                out_dict['flow6'].append([flo6_f + out_dict['flow6'][ii - 1][0], flo6_b + out_dict['flow6'][ii - 1][1]])\r\n\r\n            if ii < (self._num_iters - 1):\r\n                up_flow_f_c3 = upsample2d_as(out_dict['flow2'][ii][0], conv3_im2, mode=\"bilinear\")\r\n                up_flow_b_c3 = upsample2d_as(out_dict['flow2'][ii][1], conv3_im1, mode=\"bilinear\")\r\n                conv3_im2_wp = self._warping_layer(conv3_im2, up_flow_f_c3, height_im, width_im, self._div_flow)\r\n                conv3_im1_wp = self._warping_layer(conv3_im1, up_flow_b_c3, height_im, width_im, self._div_flow)\r\n\r\n        if self.training:\r\n            return out_dict\r\n        else:\r\n            out_dict_eval = {}\r\n            up_flow_final = upsample2d_as(out_dict['flow2'][self._num_iters - 1][0], im1, 
mode=\"bilinear\")\r\n            out_dict_eval['flow1'] = (1.0 / self._div_flow) * up_flow_final\r\n            return out_dict_eval\r\n"
  },
  {
    "path": "models/flownet1s_irr_occ.py",
    "content": "from __future__ import absolute_import, division, print_function\r\n\r\nimport torch\r\nimport torch.nn as nn\r\nfrom .flownet_modules import conv, deconv\r\nfrom .flownet_modules import concatenate_as, upsample2d_as\r\nfrom .flownet_modules import initialize_msra\r\nfrom .flownet_modules import WarpingLayer\r\n\r\nclass FlowNetS(nn.Module):\r\n    def __init__(self, args):\r\n        super(FlowNetS, self).__init__()\r\n\r\n        def make_conv(in_planes, out_planes, kernel_size, stride):\r\n            pad = kernel_size // 2\r\n            return conv(in_planes, out_planes, kernel_size=kernel_size,\r\n                        stride=stride, pad=pad, nonlinear=True, bias=True)\r\n\r\n        self._conv3_1 = make_conv( 256,  256, kernel_size=3, stride=1)\r\n        self._conv4   = make_conv( 256,  512, kernel_size=3, stride=2)\r\n        self._conv4_1 = make_conv( 512,  512, kernel_size=3, stride=1)\r\n        self._conv5   = make_conv( 512,  512, kernel_size=3, stride=2)\r\n        self._conv5_1 = make_conv( 512,  512, kernel_size=3, stride=1)\r\n        self._conv6   = make_conv( 512, 1024, kernel_size=3, stride=2)\r\n        self._conv6_1 = make_conv(1024, 1024, kernel_size=3, stride=1)\r\n\r\n        def make_deconv(in_planes, out_planes):\r\n            return deconv(in_planes, out_planes, kernel_size=4, stride=2, pad=1,\r\n                          nonlinear=True, bias=False)\r\n\r\n        self._deconv5 = make_deconv(1024    , 512)\r\n        self._deconv4 = make_deconv(1024 + 2, 256)\r\n        self._deconv3 = make_deconv( 768 + 2, 128)\r\n        self._deconv2 = make_deconv( 384 + 2,  64)\r\n\r\n        self._deconv_occ5 = make_deconv(1024    , 512)\r\n        self._deconv_occ4 = make_deconv(1024 + 1, 256)\r\n        self._deconv_occ3 = make_deconv( 768 + 1, 128)\r\n        self._deconv_occ2 = make_deconv( 384 + 1,  64)\r\n\r\n        def make_predict(in_planes, out_planes):\r\n            return conv(in_planes, out_planes, kernel_size=3, 
stride=1, pad=1,\r\n                        nonlinear=False, bias=True)\r\n\r\n        self._predict_flow6 = make_predict(1024    , 2)\r\n        self._predict_flow5 = make_predict(1024 + 2, 2)\r\n        self._predict_flow4 = make_predict( 768 + 2, 2)\r\n        self._predict_flow3 = make_predict( 384 + 2, 2)\r\n        self._predict_flow2 = make_predict( 128 + 2, 2)\r\n\r\n        self._predict_occ6 = make_predict(1024    , 1)\r\n        self._predict_occ5 = make_predict(1024 + 1, 1)\r\n        self._predict_occ4 = make_predict( 768 + 1, 1)\r\n        self._predict_occ3 = make_predict( 384 + 1, 1)\r\n        self._predict_occ2 = make_predict( 128 + 1, 1)\r\n\r\n        def make_upsample(in_planes, out_planes):\r\n            return deconv(in_planes, out_planes, kernel_size=4, stride=2, pad=1,\r\n                          nonlinear=False, bias=False)\r\n\r\n        self._upsample_flow6_to_5 = make_upsample(2, 2)\r\n        self._upsample_flow5_to_4 = make_upsample(2, 2)\r\n        self._upsample_flow4_to_3 = make_upsample(2, 2)\r\n        self._upsample_flow3_to_2 = make_upsample(2, 2)\r\n\r\n        self._upsample_occ6_to_5 = make_upsample(1, 1)\r\n        self._upsample_occ5_to_4 = make_upsample(1, 1)\r\n        self._upsample_occ4_to_3 = make_upsample(1, 1)\r\n        self._upsample_occ3_to_2 = make_upsample(1, 1)\r\n\r\n    def forward(self, conv2_im1, conv3_im1, conv3_im2):\r\n\r\n        conv_concat3 = torch.cat((conv3_im1, conv3_im2), dim=1)\r\n\r\n        conv3_1 = self._conv3_1(conv_concat3)\r\n        conv4_1 = self._conv4_1(self._conv4(conv3_1))\r\n        conv5_1 = self._conv5_1(self._conv5(conv4_1))\r\n        conv6_1 = self._conv6_1(self._conv6(conv5_1))\r\n\r\n        # Flow Decoder\r\n        predict_flow6        = self._predict_flow6(conv6_1)\r\n\r\n        upsampled_flow6_to_5 = self._upsample_flow6_to_5(predict_flow6)\r\n        deconv5              = self._deconv5(conv6_1)\r\n        concat5              = concatenate_as((conv5_1, deconv5, 
upsampled_flow6_to_5), conv5_1, dim=1)\r\n        predict_flow5        = self._predict_flow5(concat5)\r\n\r\n        upsampled_flow5_to_4 = self._upsample_flow5_to_4(predict_flow5)\r\n        deconv4              = self._deconv4(concat5)\r\n        concat4              = concatenate_as((conv4_1, deconv4, upsampled_flow5_to_4), conv4_1, dim=1)\r\n        predict_flow4        = self._predict_flow4(concat4)\r\n\r\n        upsampled_flow4_to_3 = self._upsample_flow4_to_3(predict_flow4)\r\n        deconv3              = self._deconv3(concat4)\r\n        concat3              = concatenate_as((conv3_1, deconv3, upsampled_flow4_to_3), conv3_1, dim=1)\r\n        predict_flow3        = self._predict_flow3(concat3)\r\n\r\n        upsampled_flow3_to_2 = self._upsample_flow3_to_2(predict_flow3)\r\n        deconv2              = self._deconv2(concat3)\r\n        concat2              = concatenate_as((conv2_im1, deconv2, upsampled_flow3_to_2), conv2_im1, dim=1)\r\n        predict_flow2        = self._predict_flow2(concat2)\r\n\r\n        # Occ Decoder\r\n        predict_occ6 = self._predict_occ6(conv6_1)\r\n\r\n        upsampled_occ6_to_5 = self._upsample_occ6_to_5(predict_occ6)\r\n        deconv_occ5         = self._deconv_occ5(conv6_1)\r\n        concat_occ5         = concatenate_as((conv5_1, deconv_occ5, upsampled_occ6_to_5), conv5_1, dim=1)\r\n        predict_occ5        = self._predict_occ5(concat_occ5)\r\n\r\n        upsampled_occ5_to_4 = self._upsample_occ5_to_4(predict_occ5)\r\n        deconv_occ4         = self._deconv_occ4(concat_occ5)\r\n        concat_occ4         = concatenate_as((conv4_1, deconv_occ4, upsampled_occ5_to_4), conv4_1, dim=1)\r\n        predict_occ4        = self._predict_occ4(concat_occ4)\r\n\r\n        upsampled_occ4_to_3 = self._upsample_occ4_to_3(predict_occ4)\r\n        deconv_occ3         = self._deconv_occ3(concat_occ4)\r\n        concat_occ3         = concatenate_as((conv3_1, deconv_occ3, upsampled_occ4_to_3), conv3_1, dim=1)\r\n        
predict_occ3        = self._predict_occ3(concat_occ3)\r\n\r\n        upsampled_occ3_to_2 = self._upsample_occ3_to_2(predict_occ3)\r\n        deconv_occ2         = self._deconv_occ2(concat_occ3)\r\n        concat_occ2         = concatenate_as((conv2_im1, deconv_occ2, upsampled_occ3_to_2), conv2_im1, dim=1)\r\n        predict_occ2        = self._predict_occ2(concat_occ2)\r\n\r\n        return predict_flow2, predict_flow3, predict_flow4, predict_flow5, predict_flow6, predict_occ2, predict_occ3, predict_occ4, predict_occ5, predict_occ6\r\n\r\n\r\nclass FlowNet1S(nn.Module):\r\n    def __init__(self, args, div_flow=0.05):\r\n        super(FlowNet1S, self).__init__()\r\n        self._flownets = FlowNetS(args)\r\n        self._warping_layer = WarpingLayer()\r\n        self._div_flow = div_flow\r\n        self._num_iters = args.num_iters\r\n\r\n        def make_conv(in_planes, out_planes, kernel_size, stride):\r\n            pad = kernel_size // 2\r\n            return conv(in_planes, out_planes, kernel_size=kernel_size,\r\n                        stride=stride, pad=pad, nonlinear=True, bias=True)\r\n\r\n        self._conv1   = make_conv(   3,   32, kernel_size=7, stride=2)\r\n        self._conv2   = make_conv(  32,   64, kernel_size=5, stride=2)\r\n        self._conv3   = make_conv(  64,  128, kernel_size=5, stride=2)\r\n\r\n        initialize_msra(self.modules())\r\n\r\n    def forward(self, input_dict):\r\n        im1 = input_dict['input1']\r\n        im2 = input_dict['input2']\r\n\r\n        conv1_im1 = self._conv1(im1)\r\n        conv2_im1 = self._conv2(conv1_im1)\r\n        conv3_im1 = self._conv3(conv2_im1)\r\n\r\n        conv1_im2 = self._conv1(im2)\r\n        conv2_im2 = self._conv2(conv1_im2)\r\n        conv3_im2 = self._conv3(conv2_im2)\r\n        conv3_im2_wp = conv3_im2\r\n\r\n        output_dict = {}\r\n        output_dict['flow2'] = []\r\n        output_dict['flow3'] = []\r\n        output_dict['flow4'] = []\r\n        output_dict['flow5'] = []\r\n        
output_dict['flow6'] = []\r\n        output_dict['occ2'] = []\r\n        output_dict['occ3'] = []\r\n        output_dict['occ4'] = []\r\n        output_dict['occ5'] = []\r\n        output_dict['occ6'] = []\r\n\r\n        _, _, height_im, width_im = im1.size()\r\n        \r\n        # for iterative\r\n        for ii in range(0, self._num_iters):\r\n            flow2, flow3, flow4, flow5, flow6, occ2, occ3, occ4, occ5, occ6 = self._flownets(conv2_im1, conv3_im1, conv3_im2_wp)\r\n\r\n            if ii == 0:\r\n                output_dict['flow2'].append(flow2)\r\n                output_dict['flow3'].append(flow3)\r\n                output_dict['flow4'].append(flow4)\r\n                output_dict['flow5'].append(flow5)\r\n                output_dict['flow6'].append(flow6)\r\n                output_dict['occ2'].append(occ2)\r\n                output_dict['occ3'].append(occ3)\r\n                output_dict['occ4'].append(occ4)\r\n                output_dict['occ5'].append(occ5)\r\n                output_dict['occ6'].append(occ6)\r\n            else:\r\n                output_dict['flow2'].append(flow2 + output_dict['flow2'][ii - 1])\r\n                output_dict['flow3'].append(flow3 + output_dict['flow3'][ii - 1])\r\n                output_dict['flow4'].append(flow4 + output_dict['flow4'][ii - 1])\r\n                output_dict['flow5'].append(flow5 + output_dict['flow5'][ii - 1])\r\n                output_dict['flow6'].append(flow6 + output_dict['flow6'][ii - 1])\r\n                output_dict['occ2'].append(occ2 + output_dict['occ2'][ii - 1])\r\n                output_dict['occ3'].append(occ3 + output_dict['occ3'][ii - 1])\r\n                output_dict['occ4'].append(occ4 + output_dict['occ4'][ii - 1])\r\n                output_dict['occ5'].append(occ5 + output_dict['occ5'][ii - 1])\r\n                output_dict['occ6'].append(occ6 + output_dict['occ6'][ii - 1])\r\n\r\n            if ii < (self._num_iters - 1):\r\n                up_flow = 
upsample2d_as(output_dict['flow2'][ii], conv3_im2, mode=\"bilinear\")\r\n                conv3_im2_wp = self._warping_layer(conv3_im2, up_flow, height_im, width_im, self._div_flow)     \r\n\r\n        if self.training:\r\n            return output_dict\r\n        else:\r\n            output_dict_eval = {}\r\n            up_flow_final = upsample2d_as(output_dict['flow2'][self._num_iters - 1], im1, mode=\"bilinear\")  \r\n            up_occ_final = upsample2d_as(output_dict['occ2'][self._num_iters - 1], im1, mode=\"bilinear\")\r\n            output_dict_eval['flow1'] = (1.0 / self._div_flow) * up_flow_final\r\n            output_dict_eval['occ1'] = up_occ_final\r\n            return output_dict_eval"
  },
  {
    "path": "models/flownet1s_irr_occ_bi.py",
    "content": "from __future__ import absolute_import, division, print_function\r\n\r\nimport torch\r\nimport torch.nn as nn\r\nfrom .flownet_modules import conv, deconv\r\nfrom .flownet_modules import concatenate_as, upsample2d_as\r\nfrom .flownet_modules import initialize_msra\r\nfrom .flownet_modules import WarpingLayer\r\n\r\nclass FlowNetS(nn.Module):\r\n    def __init__(self, args):\r\n        super(FlowNetS, self).__init__()\r\n\r\n        def make_conv(in_planes, out_planes, kernel_size, stride):\r\n            pad = kernel_size // 2\r\n            return conv(in_planes, out_planes, kernel_size=kernel_size,\r\n                        stride=stride, pad=pad, nonlinear=True, bias=True)\r\n\r\n        self._conv3_1 = make_conv( 256,  256, kernel_size=3, stride=1)\r\n        self._conv4   = make_conv( 256,  512, kernel_size=3, stride=2)\r\n        self._conv4_1 = make_conv( 512,  512, kernel_size=3, stride=1)\r\n        self._conv5   = make_conv( 512,  512, kernel_size=3, stride=2)\r\n        self._conv5_1 = make_conv( 512,  512, kernel_size=3, stride=1)\r\n        self._conv6   = make_conv( 512, 1024, kernel_size=3, stride=2)\r\n        self._conv6_1 = make_conv(1024, 1024, kernel_size=3, stride=1)\r\n\r\n        def make_deconv(in_planes, out_planes):\r\n            return deconv(in_planes, out_planes, kernel_size=4, stride=2, pad=1,\r\n                          nonlinear=True, bias=False)\r\n\r\n        self._deconv5 = make_deconv(1024    , 512)\r\n        self._deconv4 = make_deconv(1024 + 2, 256)\r\n        self._deconv3 = make_deconv( 768 + 2, 128)\r\n        self._deconv2 = make_deconv( 384 + 2,  64)\r\n\r\n        self._deconv_occ5 = make_deconv(1024    , 512)\r\n        self._deconv_occ4 = make_deconv(1024 + 1, 256)\r\n        self._deconv_occ3 = make_deconv( 768 + 1, 128)\r\n        self._deconv_occ2 = make_deconv( 384 + 1,  64)\r\n\r\n        def make_predict(in_planes, out_planes):\r\n            return conv(in_planes, out_planes, kernel_size=3, 
stride=1, pad=1,\r\n                        nonlinear=False, bias=True)\r\n\r\n        self._predict_flow6 = make_predict(1024    , 2)\r\n        self._predict_flow5 = make_predict(1024 + 2, 2)\r\n        self._predict_flow4 = make_predict( 768 + 2, 2)\r\n        self._predict_flow3 = make_predict( 384 + 2, 2)\r\n        self._predict_flow2 = make_predict( 128 + 2, 2)\r\n\r\n        self._predict_occ6 = make_predict(1024    , 1)\r\n        self._predict_occ5 = make_predict(1024 + 1, 1)\r\n        self._predict_occ4 = make_predict( 768 + 1, 1)\r\n        self._predict_occ3 = make_predict( 384 + 1, 1)\r\n        self._predict_occ2 = make_predict( 128 + 1, 1)\r\n\r\n        def make_upsample(in_planes, out_planes):\r\n            return deconv(in_planes, out_planes, kernel_size=4, stride=2, pad=1,\r\n                          nonlinear=False, bias=False)\r\n\r\n        self._upsample_flow6_to_5 = make_upsample(2, 2)\r\n        self._upsample_flow5_to_4 = make_upsample(2, 2)\r\n        self._upsample_flow4_to_3 = make_upsample(2, 2)\r\n        self._upsample_flow3_to_2 = make_upsample(2, 2)\r\n\r\n        self._upsample_occ6_to_5 = make_upsample(1, 1)\r\n        self._upsample_occ5_to_4 = make_upsample(1, 1)\r\n        self._upsample_occ4_to_3 = make_upsample(1, 1)\r\n        self._upsample_occ3_to_2 = make_upsample(1, 1)\r\n\r\n    def forward(self, conv2_im1, conv3_im1, conv3_im2):\r\n\r\n        conv_concat3 = torch.cat((conv3_im1, conv3_im2), dim=1)\r\n\r\n        conv3_1 = self._conv3_1(conv_concat3)\r\n        conv4_1 = self._conv4_1(self._conv4(conv3_1))\r\n        conv5_1 = self._conv5_1(self._conv5(conv4_1))\r\n        conv6_1 = self._conv6_1(self._conv6(conv5_1))\r\n\r\n        # Flow Decoder\r\n        predict_flow6        = self._predict_flow6(conv6_1)\r\n\r\n        upsampled_flow6_to_5 = self._upsample_flow6_to_5(predict_flow6)\r\n        deconv5              = self._deconv5(conv6_1)\r\n        concat5              = concatenate_as((conv5_1, deconv5, 
upsampled_flow6_to_5), conv5_1, dim=1)\r\n        predict_flow5        = self._predict_flow5(concat5)\r\n\r\n        upsampled_flow5_to_4 = self._upsample_flow5_to_4(predict_flow5)\r\n        deconv4              = self._deconv4(concat5)\r\n        concat4              = concatenate_as((conv4_1, deconv4, upsampled_flow5_to_4), conv4_1, dim=1)\r\n        predict_flow4        = self._predict_flow4(concat4)\r\n\r\n        upsampled_flow4_to_3 = self._upsample_flow4_to_3(predict_flow4)\r\n        deconv3              = self._deconv3(concat4)\r\n        concat3              = concatenate_as((conv3_1, deconv3, upsampled_flow4_to_3), conv3_1, dim=1)\r\n        predict_flow3        = self._predict_flow3(concat3)\r\n\r\n        upsampled_flow3_to_2 = self._upsample_flow3_to_2(predict_flow3)\r\n        deconv2              = self._deconv2(concat3)\r\n        concat2              = concatenate_as((conv2_im1, deconv2, upsampled_flow3_to_2), conv2_im1, dim=1)\r\n        predict_flow2        = self._predict_flow2(concat2)\r\n\r\n        # Occ Decoder\r\n        predict_occ6 = self._predict_occ6(conv6_1)\r\n\r\n        upsampled_occ6_to_5 = self._upsample_occ6_to_5(predict_occ6)\r\n        deconv_occ5         = self._deconv_occ5(conv6_1)\r\n        concat_occ5         = concatenate_as((conv5_1, deconv_occ5, upsampled_occ6_to_5), conv5_1, dim=1)\r\n        predict_occ5        = self._predict_occ5(concat_occ5)\r\n\r\n        upsampled_occ5_to_4 = self._upsample_occ5_to_4(predict_occ5)\r\n        deconv_occ4         = self._deconv_occ4(concat_occ5)\r\n        concat_occ4         = concatenate_as((conv4_1, deconv_occ4, upsampled_occ5_to_4), conv4_1, dim=1)\r\n        predict_occ4        = self._predict_occ4(concat_occ4)\r\n\r\n        upsampled_occ4_to_3 = self._upsample_occ4_to_3(predict_occ4)\r\n        deconv_occ3         = self._deconv_occ3(concat_occ4)\r\n        concat_occ3         = concatenate_as((conv3_1, deconv_occ3, upsampled_occ4_to_3), conv3_1, dim=1)\r\n        
predict_occ3        = self._predict_occ3(concat_occ3)\r\n\r\n        upsampled_occ3_to_2 = self._upsample_occ3_to_2(predict_occ3)\r\n        deconv_occ2         = self._deconv_occ2(concat_occ3)\r\n        concat_occ2         = concatenate_as((conv2_im1, deconv_occ2, upsampled_occ3_to_2), conv2_im1, dim=1)\r\n        predict_occ2        = self._predict_occ2(concat_occ2)\r\n\r\n        return predict_flow2, predict_flow3, predict_flow4, predict_flow5, predict_flow6, predict_occ2, predict_occ3, predict_occ4, predict_occ5, predict_occ6\r\n\r\n\r\nclass FlowNet1S(nn.Module):\r\n    def __init__(self, args, div_flow=0.05):\r\n        super(FlowNet1S, self).__init__()\r\n        self._flownets = FlowNetS(args)\r\n        self._warping_layer = WarpingLayer()\r\n        self._div_flow = div_flow\r\n        self._num_iters = args.num_iters\r\n\r\n        def make_conv(in_planes, out_planes, kernel_size, stride):\r\n            pad = kernel_size // 2\r\n            return conv(in_planes, out_planes, kernel_size=kernel_size,\r\n                        stride=stride, pad=pad, nonlinear=True, bias=True)\r\n\r\n        self._conv1   = make_conv(   3,   32, kernel_size=7, stride=2)\r\n        self._conv2   = make_conv(  32,   64, kernel_size=5, stride=2)\r\n        self._conv3   = make_conv(  64,  128, kernel_size=5, stride=2)\r\n\r\n        initialize_msra(self.modules())\r\n\r\n    def forward(self, input_dict):\r\n        im1 = input_dict['input1']\r\n        im2 = input_dict['input2']\r\n\r\n        conv1_im1 = self._conv1(im1)\r\n        conv2_im1 = self._conv2(conv1_im1)\r\n        conv3_im1 = self._conv3(conv2_im1)\r\n        conv3_im1_wp = conv3_im1\r\n\r\n        conv1_im2 = self._conv1(im2)\r\n        conv2_im2 = self._conv2(conv1_im2)\r\n        conv3_im2 = self._conv3(conv2_im2)\r\n        conv3_im2_wp = conv3_im2\r\n\r\n        out_dict = {}        \r\n        out_dict['flow2'] = []\r\n        out_dict['flow3'] = []\r\n        out_dict['flow4'] = []\r\n        
out_dict['flow5'] = []\r\n        out_dict['flow6'] = []\r\n        out_dict['occ2'] = []\r\n        out_dict['occ3'] = []\r\n        out_dict['occ4'] = []\r\n        out_dict['occ5'] = []\r\n        out_dict['occ6'] = []\r\n\r\n        _, _, height_im, width_im = im1.size()\r\n\r\n        # for iterative\r\n        for ii in range(0, self._num_iters):\r\n            flo2_f, flo3_f, flo4_f, flo5_f, flo6_f, occ2_f, occ3_f, occ4_f, occ5_f, occ6_f = self._flownets(conv2_im1,\r\n                                                                                                            conv3_im1,\r\n                                                                                                            conv3_im2_wp)\r\n            flo2_b, flo3_b, flo4_b, flo5_b, flo6_b, occ2_b, occ3_b, occ4_b, occ5_b, occ6_b = self._flownets(conv2_im2,\r\n                                                                                                            conv3_im2,\r\n                                                                                                            conv3_im1_wp)\r\n\r\n            if ii == 0:\r\n                out_dict['flow2'].append([flo2_f, flo2_b])\r\n                out_dict['flow3'].append([flo3_f, flo3_b])\r\n                out_dict['flow4'].append([flo4_f, flo4_b])\r\n                out_dict['flow5'].append([flo5_f, flo5_b])\r\n                out_dict['flow6'].append([flo6_f, flo6_b])\r\n                out_dict['occ2'].append([occ2_f, occ2_b])\r\n                out_dict['occ3'].append([occ3_f, occ3_b])\r\n                out_dict['occ4'].append([occ4_f, occ4_b])\r\n                out_dict['occ5'].append([occ5_f, occ5_b])\r\n                out_dict['occ6'].append([occ6_f, occ6_b])\r\n            else:\r\n                out_dict['flow2'].append([flo2_f + out_dict['flow2'][ii - 1][0], flo2_b + out_dict['flow2'][ii - 1][1]])\r\n                out_dict['flow3'].append([flo3_f + out_dict['flow3'][ii - 1][0], flo3_b + out_dict['flow3'][ii - 
1][1]])\r\n                out_dict['flow4'].append([flo4_f + out_dict['flow4'][ii - 1][0], flo4_b + out_dict['flow4'][ii - 1][1]])\r\n                out_dict['flow5'].append([flo5_f + out_dict['flow5'][ii - 1][0], flo5_b + out_dict['flow5'][ii - 1][1]])\r\n                out_dict['flow6'].append([flo6_f + out_dict['flow6'][ii - 1][0], flo6_b + out_dict['flow6'][ii - 1][1]])\r\n                out_dict['occ2'].append([occ2_f + out_dict['occ2'][ii - 1][0], occ2_b + out_dict['occ2'][ii - 1][1]])\r\n                out_dict['occ3'].append([occ3_f + out_dict['occ3'][ii - 1][0], occ3_b + out_dict['occ3'][ii - 1][1]])\r\n                out_dict['occ4'].append([occ4_f + out_dict['occ4'][ii - 1][0], occ4_b + out_dict['occ4'][ii - 1][1]])\r\n                out_dict['occ5'].append([occ5_f + out_dict['occ5'][ii - 1][0], occ5_b + out_dict['occ5'][ii - 1][1]])\r\n                out_dict['occ6'].append([occ6_f + out_dict['occ6'][ii - 1][0], occ6_b + out_dict['occ6'][ii - 1][1]])\r\n\r\n            if ii < (self._num_iters - 1):\r\n                up_flow_f_c3 = upsample2d_as(out_dict['flow2'][ii][0], conv3_im2, mode=\"bilinear\")\r\n                up_flow_b_c3 = upsample2d_as(out_dict['flow2'][ii][1], conv3_im1, mode=\"bilinear\")\r\n                conv3_im2_wp = self._warping_layer(conv3_im2, up_flow_f_c3, height_im, width_im, self._div_flow)\r\n                conv3_im1_wp = self._warping_layer(conv3_im1, up_flow_b_c3, height_im, width_im, self._div_flow)\r\n\r\n        if self.training:\r\n            return out_dict\r\n        else:\r\n            out_dict_eval = {}\r\n            up_flow_final = upsample2d_as(out_dict['flow2'][self._num_iters - 1][0], im1, mode=\"bilinear\")\r\n            up_occ_final = upsample2d_as(out_dict['occ2'][self._num_iters - 1][0], im1, mode=\"bilinear\")\r\n            out_dict_eval['flow1'] = (1.0 / self._div_flow) * up_flow_final\r\n            out_dict_eval['occ1'] = up_occ_final\r\n            return out_dict_eval\r\n"
  },
  {
    "path": "models/flownet_modules.py",
    "content": "from __future__ import absolute_import, division, print_function\r\n\r\nimport torch\r\nimport torch.nn as nn\r\nimport torch.nn.functional as tf\r\nimport logging\r\n\r\n\r\ndef conv(in_planes, out_planes, kernel_size, stride, pad, nonlinear, bias):\r\n    if nonlinear:\r\n        return nn.Sequential(\r\n            nn.Conv2d(\r\n                in_planes, out_planes, kernel_size=kernel_size,\r\n                stride=stride, padding=pad, bias=bias),\r\n            nn.LeakyReLU(0.1, inplace=True)\r\n        )\r\n    else:\r\n        return nn.Conv2d(\r\n            in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=pad, bias=bias)\r\n\r\n\r\ndef deconv(in_planes, out_planes, kernel_size, stride, pad, nonlinear, bias):\r\n    if nonlinear:\r\n        return nn.Sequential(\r\n            nn.ConvTranspose2d(\r\n                in_planes, out_planes, kernel_size=kernel_size,\r\n                stride=stride, padding=pad, bias=bias),\r\n            nn.LeakyReLU(0.1, inplace=True)\r\n        )\r\n    else:\r\n        return nn.ConvTranspose2d(\r\n            in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=pad, bias=bias)\r\n\r\n\r\ndef resize2D(inputs, size_targets, mode=\"bilinear\"):\r\n    size_inputs = [inputs.size(2), inputs.size(3)]\r\n\r\n    if all([size_inputs == size_targets]):\r\n        return inputs  # nothing to do\r\n    elif any([size_targets < size_inputs]):\r\n        resized = tf.adaptive_avg_pool2d(inputs, size_targets)  # downscaling\r\n    else:\r\n        resized = tf.interpolate(inputs, size=size_targets, mode=mode, align_corners=True)\r\n\r\n    return resized\r\n\r\ndef resize2D_as(inputs, output_as, mode=\"bilinear\"):\r\n    size_targets = [output_as.size(2), output_as.size(3)]\r\n    return resize2D(inputs, size_targets, mode=mode)\r\n\r\n\r\ndef concatenate_as(tensor_list, tensor_as, dim, mode=\"bilinear\"):\r\n    tensor_list = [resize2D_as(x, tensor_as, mode=mode) for x in 
tensor_list]\r\n    return torch.cat(tensor_list, dim=dim)\r\n\r\n\r\ndef upsample2d_as(inputs, target_as, mode=\"bilinear\"):\r\n    _, _, h, w = target_as.size()\r\n    return tf.interpolate(inputs, [h, w], mode=mode, align_corners=True)\r\n\r\n\r\ndef initialize_msra(modules):\r\n    logging.info(\"Initializing MSRA\")\r\n    for layer in modules:\r\n        if isinstance(layer, nn.Conv2d):\r\n            nn.init.kaiming_normal_(layer.weight)\r\n            if layer.bias is not None:\r\n                nn.init.constant_(layer.bias, 0)\r\n\r\n        elif isinstance(layer, nn.ConvTranspose2d):\r\n            nn.init.kaiming_normal_(layer.weight)\r\n            if layer.bias is not None:\r\n                nn.init.constant_(layer.bias, 0)\r\n\r\n        elif isinstance(layer, nn.LeakyReLU):\r\n            pass\r\n\r\n        elif isinstance(layer, nn.Sequential):\r\n            pass\r\n\r\n        elif \"models\" in str(type(layer)) and \"FlowNet\" in str(type(layer)):            \r\n            pass\r\n\r\n\r\ndef get_grid(x):\r\n    grid_H = torch.linspace(-1.0, 1.0, x.size(3)).view(1, 1, 1, x.size(3)).expand(x.size(0), 1, x.size(2), x.size(3))\r\n    grid_V = torch.linspace(-1.0, 1.0, x.size(2)).view(1, 1, x.size(2), 1).expand(x.size(0), 1, x.size(2), x.size(3))\r\n    grid = torch.cat([grid_H, grid_V], 1)\r\n    grids_cuda = grid.float().requires_grad_(False).cuda()\r\n    return grids_cuda\r\n\r\n\r\nclass WarpingLayer(nn.Module):\r\n    def __init__(self):\r\n        super(WarpingLayer, self).__init__()\r\n\r\n    def forward(self, x, flow, height_im, width_im, div_flow):\r\n        flo_list = []\r\n        flo_w = flow[:, 0] * 2 / width_im / div_flow\r\n        flo_h = flow[:, 1] * 2 / height_im / div_flow\r\n        flo_list.append(flo_w)\r\n        flo_list.append(flo_h)\r\n        flow_for_grid = torch.stack(flo_list).transpose(0, 1)\r\n        grid = torch.add(get_grid(x), flow_for_grid).transpose(1, 2).transpose(2, 3)\r\n        x_warp = 
tf.grid_sample(x, grid, align_corners=True)\r\n\r\n        return x_warp"
  },
  {
    "path": "models/irr_modules.py",
    "content": "from __future__ import absolute_import, division, print_function\r\n\r\nimport torch\r\nimport torch.nn as nn\r\nimport torch.nn.functional as tf\r\n\r\ndef conv(in_planes, out_planes, kernel_size=3, stride=1, dilation=1, isReLU=True):\r\n    if isReLU:\r\n        return nn.Sequential(\r\n            nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, dilation=dilation,\r\n                      padding=((kernel_size - 1) * dilation) // 2, bias=True),\r\n            nn.LeakyReLU(0.1, inplace=True)\r\n        )\r\n    else:\r\n        return nn.Sequential(\r\n            nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, dilation=dilation,\r\n                      padding=((kernel_size - 1) * dilation) // 2, bias=True)\r\n        )\r\n\r\n\r\ndef upsample_factor2(inputs, target_as):\r\n    inputs = tf.interpolate(inputs, scale_factor=2, mode=\"nearest\")\r\n    _, _, h, w = target_as.size()\r\n    if inputs.size(2) != h or inputs.size(3) != w:\r\n        return tf.interpolate(inputs, [h, w], mode=\"bilinear\", align_corners=False)\r\n    else:\r\n        return inputs\r\n\r\n\r\nclass OccUpsampleNetwork(nn.Module):\r\n    def __init__(self, ch_in, ch_out):\r\n        super(OccUpsampleNetwork, self).__init__()\r\n\r\n        self.feat_dim = 32\r\n        self.init_conv = conv(ch_in, self.feat_dim)\r\n\r\n        self.res_convs = nn.Sequential(\r\n            conv(self.feat_dim, self.feat_dim),\r\n            conv(self.feat_dim, self.feat_dim, isReLU=False)\r\n        )\r\n        self.res_end_conv = conv(self.feat_dim, self.feat_dim)\r\n        self.mul_const = 0.1\r\n\r\n        self.out_convs = conv(self.feat_dim, ch_out)\r\n\r\n    def forward(self, occ, x):\r\n        occ = upsample_factor2(occ, x)\r\n        x_in = torch.cat([occ, x], dim=1)\r\n        x_init = self.init_conv(x_in)\r\n        x_res = x_init\r\n        x_res = x_res + self.res_convs(x_res) * self.mul_const\r\n        x_res = x_res + 
self.res_convs(x_res) * self.mul_const\r\n        x_res = x_res + self.res_convs(x_res) * self.mul_const\r\n        x_init = x_init + self.res_end_conv(x_res)\r\n\r\n        return self.out_convs(x_init) + occ\r\n\r\n\r\ndef subtract_mean(input):\r\n    return input - input.mean(2).mean(2).unsqueeze(2).unsqueeze(2).expand_as(input)\r\n\r\n    \r\nclass RefineFlow(nn.Module):\r\n    def __init__(self, ch_in):\r\n        super(RefineFlow, self).__init__()\r\n\r\n        self.kernel_size = 3\r\n        self.pad_size = 1\r\n        self.pad_ftn = nn.ReplicationPad2d(self.pad_size)\r\n\r\n        self.convs = nn.Sequential(\r\n            conv(ch_in, 128, 3, 1, 1),\r\n            conv(128, 128, 3, 1, 1),\r\n            conv(128, 64, 3, 1, 1),\r\n            conv(64, 64, 3, 1, 1),\r\n            conv(64, 32, 3, 1, 1),\r\n            conv(32, 32, 3, 1, 1),\r\n            conv(32, self.kernel_size * self.kernel_size, 3, 1, 1)\r\n        )\r\n\r\n        self.softmax_feat = nn.Softmax(dim=1)\r\n        self.unfold_flow = nn.Unfold(kernel_size=(self.kernel_size, self.kernel_size))\r\n        self.unfold_kernel = nn.Unfold(kernel_size=(1, 1))\r\n\r\n    def forward(self, flow, diff_img, feature):\r\n        b, _, h, w = flow.size()\r\n\r\n        flow_m = subtract_mean(flow)\r\n        norm2_img = torch.norm(diff_img, p=2, dim=1, keepdim=True)\r\n\r\n        feat = self.convs(torch.cat([flow_m, norm2_img, feature], dim=1))\r\n        feat_kernel = self.softmax_feat(-feat ** 2)\r\n\r\n        flow_x = flow[:, 0].unsqueeze(1)\r\n        flow_y = flow[:, 1].unsqueeze(1)\r\n\r\n        flow_x_unfold = self.unfold_flow(self.pad_ftn(flow_x))\r\n        flow_y_unfold = self.unfold_flow(self.pad_ftn(flow_y))\r\n        feat_kernel_unfold = self.unfold_kernel(feat_kernel)\r\n\r\n        flow_out_x = torch.sum(flow_x_unfold * feat_kernel_unfold, dim=1).unsqueeze(1).view(b, 1, h, w)\r\n        flow_out_y = torch.sum(flow_y_unfold * feat_kernel_unfold, dim=1).unsqueeze(1).view(b, 1, h, 
w)\r\n\r\n        return torch.cat([flow_out_x, flow_out_y], dim=1)\r\n\r\n\r\nclass RefineOcc(nn.Module):\r\n    def __init__(self, ch_in):\r\n        super(RefineOcc, self).__init__()\r\n\r\n        self.kernel_size = 3\r\n        self.pad_size = 1\r\n        self.pad_ftn = nn.ReplicationPad2d(self.pad_size)\r\n\r\n        self.convs = nn.Sequential(\r\n            conv(ch_in, 128, 3, 1, 1),\r\n            conv(128, 128, 3, 1, 1),\r\n            conv(128, 64, 3, 1, 1),\r\n            conv(64, 64, 3, 1, 1),\r\n            conv(64, 32, 3, 1, 1),\r\n            conv(32, 32, 3, 1, 1),\r\n            conv(32, self.kernel_size * self.kernel_size, 3, 1, 1)\r\n        )\r\n\r\n        self.softmax_feat = nn.Softmax(dim=1)\r\n        self.unfold_occ = nn.Unfold(kernel_size=(self.kernel_size, self.kernel_size))\r\n        self.unfold_kernel = nn.Unfold(kernel_size=(1, 1))\r\n\r\n    def forward(self, occ, feat1, feat2):\r\n        b, _, h, w = occ.size()\r\n\r\n        feat = self.convs(torch.cat([occ, feat1, feat2], dim=1))\r\n        feat_kernel = self.softmax_feat(-feat ** 2)\r\n\r\n        occ_unfold = self.unfold_occ(self.pad_ftn(occ))\r\n        feat_kernel_unfold = self.unfold_kernel(feat_kernel)\r\n\r\n        occ_out = torch.sum(occ_unfold * feat_kernel_unfold, dim=1).unsqueeze(1).view(b, 1, h, w)\r\n\r\n        return occ_out"
  },
  {
    "path": "models/pwc_modules.py",
    "content": "from __future__ import absolute_import, division, print_function\r\n\r\nimport torch\r\nimport torch.nn as nn\r\nimport torch.nn.functional as tf\r\nimport logging\r\n\r\ndef conv(in_planes, out_planes, kernel_size=3, stride=1, dilation=1, isReLU=True):\r\n    if isReLU:\r\n        return nn.Sequential(\r\n            nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, dilation=dilation,\r\n                      padding=((kernel_size - 1) * dilation) // 2, bias=True),\r\n            nn.LeakyReLU(0.1, inplace=True)\r\n        )\r\n    else:\r\n        return nn.Sequential(\r\n            nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, dilation=dilation,\r\n                      padding=((kernel_size - 1) * dilation) // 2, bias=True)\r\n        )\r\n\r\n\r\ndef initialize_msra(modules):\r\n    logging.info(\"Initializing MSRA\")\r\n    for layer in modules:\r\n        if isinstance(layer, nn.Conv2d):\r\n            nn.init.kaiming_normal_(layer.weight)\r\n            if layer.bias is not None:\r\n                nn.init.constant_(layer.bias, 0)\r\n\r\n        elif isinstance(layer, nn.ConvTranspose2d):\r\n            nn.init.kaiming_normal_(layer.weight)\r\n            if layer.bias is not None:\r\n                nn.init.constant_(layer.bias, 0)\r\n\r\n        elif isinstance(layer, nn.LeakyReLU):\r\n            pass\r\n\r\n        elif isinstance(layer, nn.Sequential):\r\n            pass\r\n\r\n\r\ndef compute_cost_volume(feat1, feat2, param_dict):\r\n    \"\"\"\r\n    only implemented for:\r\n        kernel_size = 1\r\n        stride1 = 1\r\n        stride2 = 1\r\n    \"\"\"\r\n\r\n    max_disp = param_dict[\"max_disp\"]\r\n\r\n    _, _, height, width = feat1.size()\r\n    num_shifts = 2 * max_disp + 1\r\n    feat2_padded = tf.pad(feat2, (max_disp, max_disp, max_disp, max_disp), \"constant\", 0)\r\n\r\n    cost_list = []\r\n    for i in range(num_shifts):\r\n        for j in range(num_shifts):\r\n      
      corr = torch.mean(feat1 * feat2_padded[:, :, i:(height + i), j:(width + j)], axis=1, keepdims=True)\r\n            cost_list.append(corr)\r\n    cost_volume = torch.cat(cost_list, axis=1)\r\n    return cost_volume\r\n\r\n\r\ndef upsample2d_as(inputs, target_as, mode=\"bilinear\"):\r\n    _, _, h, w = target_as.size()\r\n    return tf.interpolate(inputs, [h, w], mode=mode, align_corners=True)\r\n\r\n\r\ndef rescale_flow(flow, div_flow, width_im, height_im, to_local=True):\r\n    if to_local:\r\n        u_scale = float(flow.size(3) / width_im / div_flow)\r\n        v_scale = float(flow.size(2) / height_im / div_flow)\r\n    else:\r\n        u_scale = float(width_im * div_flow / flow.size(3))\r\n        v_scale = float(height_im * div_flow / flow.size(2))\r\n\r\n    u, v = flow.chunk(2, dim=1)\r\n    u *= u_scale\r\n    v *= v_scale\r\n\r\n    return torch.cat([u, v], dim=1)\r\n\r\n\r\nclass FeatureExtractor(nn.Module):\r\n    def __init__(self, num_chs):\r\n        super(FeatureExtractor, self).__init__()\r\n        self.num_chs = num_chs\r\n        self.convs = nn.ModuleList()\r\n\r\n        for l, (ch_in, ch_out) in enumerate(zip(num_chs[:-1], num_chs[1:])):\r\n            layer = nn.Sequential(\r\n                conv(ch_in, ch_out, stride=2),\r\n                conv(ch_out, ch_out)\r\n            )\r\n            self.convs.append(layer)\r\n\r\n    def forward(self, x):\r\n        feature_pyramid = []\r\n        for conv in self.convs:\r\n            x = conv(x)\r\n            feature_pyramid.append(x)\r\n\r\n        return feature_pyramid[::-1]\r\n\r\n\r\ndef get_grid(x):\r\n    grid_H = torch.linspace(-1.0, 1.0, x.size(3)).view(1, 1, 1, x.size(3)).expand(x.size(0), 1, x.size(2), x.size(3))\r\n    grid_V = torch.linspace(-1.0, 1.0, x.size(2)).view(1, 1, x.size(2), 1).expand(x.size(0), 1, x.size(2), x.size(3))\r\n    grid = torch.cat([grid_H, grid_V], 1)\r\n    grids_cuda = grid.float().requires_grad_(False).cuda()\r\n    return grids_cuda\r\n\r\n\r\nclass 
WarpingLayer(nn.Module):\r\n    def __init__(self):\r\n        super(WarpingLayer, self).__init__()\r\n\r\n    def forward(self, x, flow, height_im, width_im, div_flow):\r\n        flo_list = []\r\n        flo_w = flow[:, 0] * 2 / max(width_im - 1, 1) / div_flow\r\n        flo_h = flow[:, 1] * 2 / max(height_im - 1, 1) / div_flow\r\n        flo_list.append(flo_w)\r\n        flo_list.append(flo_h)\r\n        flow_for_grid = torch.stack(flo_list).transpose(0, 1)\r\n        grid = torch.add(get_grid(x), flow_for_grid).transpose(1, 2).transpose(2, 3)        \r\n        x_warp = tf.grid_sample(x, grid, align_corners=True)\r\n\r\n        mask = torch.ones(x.size(), requires_grad=False).cuda()\r\n        mask = tf.grid_sample(mask, grid, align_corners=True)\r\n        mask = (mask >= 1.0).float()\r\n\r\n        return x_warp * mask\r\n\r\nclass OpticalFlowEstimator(nn.Module):\r\n    def __init__(self, ch_in):\r\n        super(OpticalFlowEstimator, self).__init__()\r\n\r\n        self.convs = nn.Sequential(\r\n            conv(ch_in, 128),\r\n            conv(128, 128),\r\n            conv(128, 96),\r\n            conv(96, 64),\r\n            conv(64, 32)\r\n        )\r\n        self.conv_last = conv(32, 2, isReLU=False)\r\n\r\n    def forward(self, x):\r\n        x_intm = self.convs(x)\r\n        return x_intm, self.conv_last(x_intm)\r\n\r\n\r\nclass FlowEstimatorDense(nn.Module):\r\n    def __init__(self, ch_in):\r\n        super(FlowEstimatorDense, self).__init__()\r\n        self.conv1 = conv(ch_in, 128)\r\n        self.conv2 = conv(ch_in + 128, 128)\r\n        self.conv3 = conv(ch_in + 256, 96)\r\n        self.conv4 = conv(ch_in + 352, 64)\r\n        self.conv5 = conv(ch_in + 416, 32)\r\n        self.conv_last = conv(ch_in + 448, 2, isReLU=False)\r\n\r\n    def forward(self, x):\r\n        x1 = torch.cat([self.conv1(x), x], dim=1)\r\n        x2 = torch.cat([self.conv2(x1), x1], dim=1)\r\n        x3 = torch.cat([self.conv3(x2), x2], dim=1)\r\n        x4 = 
torch.cat([self.conv4(x3), x3], dim=1)\r\n        x5 = torch.cat([self.conv5(x4), x4], dim=1)\r\n        x_out = self.conv_last(x5)\r\n        return x5, x_out\r\n\r\n\r\nclass OcclusionEstimator(nn.Module):\r\n    def __init__(self, ch_in):\r\n        super(OcclusionEstimator, self).__init__()\r\n        self.convs = nn.Sequential(\r\n            conv(ch_in, 128),\r\n            conv(128, 128),\r\n            conv(128, 96),\r\n            conv(96, 64),\r\n            conv(64, 32)\r\n        )\r\n        self.conv_last = conv(32, 1, isReLU=False)\r\n\r\n    def forward(self, x):\r\n        x_intm = self.convs(x)\r\n        return x_intm, self.conv_last(x_intm)\r\n\r\n\r\nclass OccEstimatorDense(nn.Module):\r\n    def __init__(self, ch_in):\r\n        super(OccEstimatorDense, self).__init__()\r\n        self.conv1 = conv(ch_in, 128)\r\n        self.conv2 = conv(ch_in + 128, 128)\r\n        self.conv3 = conv(ch_in + 256, 96)\r\n        self.conv4 = conv(ch_in + 352, 64)\r\n        self.conv5 = conv(ch_in + 416, 32)\r\n        self.conv_last = conv(ch_in + 448, 1, isReLU=False)\r\n\r\n    def forward(self, x):\r\n        x1 = torch.cat([self.conv1(x), x], dim=1)\r\n        x2 = torch.cat([self.conv2(x1), x1], dim=1)\r\n        x3 = torch.cat([self.conv3(x2), x2], dim=1)\r\n        x4 = torch.cat([self.conv4(x3), x3], dim=1)\r\n        x5 = torch.cat([self.conv5(x4), x4], dim=1)\r\n        x_out = self.conv_last(x5)\r\n        return x5, x_out\r\n\r\n\r\nclass ContextNetwork(nn.Module):\r\n    def __init__(self, ch_in):\r\n        super(ContextNetwork, self).__init__()\r\n\r\n        self.convs = nn.Sequential(\r\n            conv(ch_in, 128, 3, 1, 1),\r\n            conv(128, 128, 3, 1, 2),\r\n            conv(128, 128, 3, 1, 4),\r\n            conv(128, 96, 3, 1, 8),\r\n            conv(96, 64, 3, 1, 16),\r\n            conv(64, 32, 3, 1, 1),\r\n            conv(32, 2, isReLU=False)\r\n        )\r\n\r\n    def forward(self, x):\r\n        return 
self.convs(x)\r\n\r\n\r\nclass OccContextNetwork(nn.Module):\r\n    def __init__(self, ch_in):\r\n        super(OccContextNetwork, self).__init__()\r\n\r\n        self.convs = nn.Sequential(\r\n            conv(ch_in, 128, 3, 1, 1),\r\n            conv(128, 128, 3, 1, 2),\r\n            conv(128, 128, 3, 1, 4),\r\n            conv(128, 96, 3, 1, 8),\r\n            conv(96, 64, 3, 1, 16),\r\n            conv(64, 32, 3, 1, 1),\r\n            conv(32, 1, isReLU=False)\r\n        )\r\n\r\n    def forward(self, x):\r\n        return self.convs(x)\r\n"
  },
  {
    "path": "models/pwcnet.py",
    "content": "from __future__ import absolute_import, division, print_function\r\n\r\nimport torch\r\nimport torch.nn as nn\r\n\r\nfrom .pwc_modules import upsample2d_as, initialize_msra, compute_cost_volume\r\nfrom .pwc_modules import WarpingLayer, FeatureExtractor, ContextNetwork, FlowEstimatorDense\r\n\r\nclass PWCNet(nn.Module):\r\n    def __init__(self, args, div_flow=0.05):\r\n        super(PWCNet, self).__init__()\r\n        self.args = args\r\n        self._div_flow = div_flow\r\n        self.search_range = 4\r\n        self.num_chs = [3, 16, 32, 64, 96, 128, 196]\r\n        self.output_level = 4\r\n        self.num_levels = 7\r\n        self.leakyRELU = nn.LeakyReLU(0.1, inplace=True)\r\n\r\n        self.feature_pyramid_extractor = FeatureExtractor(self.num_chs)\r\n        self.warping_layer = WarpingLayer()\r\n\r\n        self.flow_estimators = nn.ModuleList()\r\n        self.dim_corr = (self.search_range * 2 + 1) ** 2\r\n        for l, ch in enumerate(self.num_chs[::-1]):\r\n            if l > self.output_level:\r\n                break\r\n\r\n            if l == 0:\r\n                num_ch_in = self.dim_corr\r\n            else:\r\n                num_ch_in = self.dim_corr + ch + 2\r\n\r\n            layer = FlowEstimatorDense(num_ch_in)\r\n            self.flow_estimators.append(layer)\r\n\r\n        self.context_networks = ContextNetwork(self.dim_corr + 32 + 2 + 448 + 2)\r\n        \r\n        self.corr_params = {\"pad_size\": self.search_range, \"kernel_size\": 1, \"max_disp\": self.search_range, \"stride1\": 1, \"stride2\": 1, \"corr_multiply\": 1}\r\n        \r\n        initialize_msra(self.modules())\r\n\r\n    def forward(self, input_dict):\r\n\r\n        x1_raw = input_dict['input1']\r\n        x2_raw = input_dict['input2']\r\n        _, _, height_im, width_im = x1_raw.size()\r\n\r\n        # on the bottom level are original images\r\n        x1_pyramid = self.feature_pyramid_extractor(x1_raw) + [x1_raw]\r\n        x2_pyramid = 
self.feature_pyramid_extractor(x2_raw) + [x2_raw]\r\n\r\n        # outputs\r\n        output_dict = {}\r\n        flows = []\r\n\r\n        # init\r\n        b_size, _, h_x1, w_x1, = x1_pyramid[0].size()\r\n        init_dtype = x1_pyramid[0].dtype\r\n        init_device = x1_pyramid[0].device\r\n        flow = torch.zeros(b_size, 2, h_x1, w_x1, dtype=init_dtype, device=init_device).float()\r\n\r\n        for l, (x1, x2) in enumerate(zip(x1_pyramid, x2_pyramid)):\r\n\r\n            # warping\r\n            if l == 0:\r\n                x2_warp = x2\r\n            else:\r\n                flow = upsample2d_as(flow, x1, mode=\"bilinear\")\r\n                x2_warp = self.warping_layer(x2, flow, height_im, width_im, self._div_flow)\r\n\r\n            # correlation\r\n            out_corr = compute_cost_volume(x1, x2_warp, self.corr_params)\r\n            out_corr_relu = self.leakyRELU(out_corr)\r\n\r\n            # flow estimator\r\n            if l == 0:\r\n                x_intm, flow = self.flow_estimators[l](out_corr_relu)\r\n            else:\r\n                x_intm, flow = self.flow_estimators[l](torch.cat([out_corr_relu, x1, flow], dim=1))\r\n\r\n            # upsampling or post-processing\r\n            if l != self.output_level:\r\n                flows.append(flow)\r\n            else:\r\n                flow_res = self.context_networks(torch.cat([x_intm, flow], dim=1))\r\n                flow = flow + flow_res\r\n                flows.append(flow)                \r\n                break\r\n\r\n        output_dict['flow'] = flows\r\n\r\n        if self.training:\r\n            return output_dict\r\n        else:\r\n            output_dict_eval = {}\r\n            out_flow = upsample2d_as(flow, x1_raw, mode=\"bilinear\") * (1.0 / self._div_flow)\r\n            output_dict_eval['flow'] = out_flow\r\n            return output_dict_eval\r\n"
  },
  {
    "path": "models/pwcnet_bi.py",
    "content": "from __future__ import absolute_import, division, print_function\r\n\r\nimport torch\r\nimport torch.nn as nn\r\n\r\nfrom .pwc_modules import upsample2d_as, initialize_msra, compute_cost_volume\r\nfrom .pwc_modules import WarpingLayer, FeatureExtractor, ContextNetwork, FlowEstimatorDense\r\n\r\nclass PWCNet(nn.Module):\r\n    def __init__(self, args, div_flow=0.05):\r\n        super(PWCNet, self).__init__()\r\n        self.args = args\r\n        self._div_flow = div_flow\r\n        self.search_range = 4\r\n        self.num_chs = [3, 16, 32, 64, 96, 128, 196]\r\n        self.output_level = 4\r\n        self.num_levels = 7\r\n        self.leakyRELU = nn.LeakyReLU(0.1, inplace=True)\r\n\r\n        self.feature_pyramid_extractor = FeatureExtractor(self.num_chs)\r\n        self.warping_layer = WarpingLayer()\r\n\r\n        self.flow_estimators = nn.ModuleList()\r\n        self.dim_corr = (self.search_range * 2 + 1) ** 2\r\n        for l, ch in enumerate(self.num_chs[::-1]):\r\n            if l > self.output_level:\r\n                break\r\n\r\n            if l == 0:\r\n                num_ch_in = self.dim_corr\r\n            else:\r\n                num_ch_in = self.dim_corr + ch + 2\r\n\r\n            layer = FlowEstimatorDense(num_ch_in)\r\n            self.flow_estimators.append(layer)\r\n\r\n        self.context_networks = ContextNetwork(self.dim_corr + 32 + 2 + 448 + 2)\r\n        self.corr_params = {\"pad_size\": self.search_range, \"kernel_size\": 1, \"max_disp\": self.search_range, \"stride1\": 1, \"stride2\": 1, \"corr_multiply\": 1}\r\n        \r\n        initialize_msra(self.modules())\r\n\r\n    def forward(self, input_dict):\r\n\r\n        x1_raw = input_dict['input1']\r\n        x2_raw = input_dict['input2']\r\n        _, _, height_im, width_im = x1_raw.size()\r\n\r\n        # on the bottom level are original images\r\n        x1_pyramid = self.feature_pyramid_extractor(x1_raw) + [x1_raw]\r\n        x2_pyramid = 
self.feature_pyramid_extractor(x2_raw) + [x2_raw]\r\n\r\n        # outputs\r\n        output_dict = {}        \r\n        flows = []\r\n\r\n        # init\r\n        b_size, _, h_x1, w_x1, = x1_pyramid[0].size()\r\n        init_dtype = x1_pyramid[0].dtype\r\n        init_device = x1_pyramid[0].device\r\n        flow_f = torch.zeros(b_size, 2, h_x1, w_x1, dtype=init_dtype, device=init_device).float()\r\n        flow_b = torch.zeros(b_size, 2, h_x1, w_x1, dtype=init_dtype, device=init_device).float()\r\n\r\n        for l, (x1, x2) in enumerate(zip(x1_pyramid, x2_pyramid)):\r\n\r\n            # warping\r\n            if l == 0:\r\n                x2_warp = x2\r\n                x1_warp = x1\r\n            else:\r\n                flow_f = upsample2d_as(flow_f, x1, mode=\"bilinear\")\r\n                flow_b = upsample2d_as(flow_b, x2, mode=\"bilinear\")\r\n                x2_warp = self.warping_layer(x2, flow_f, height_im, width_im, self._div_flow)\r\n                x1_warp = self.warping_layer(x1, flow_b, height_im, width_im, self._div_flow)\r\n\r\n            # correlation\r\n            out_corr_f = compute_cost_volume(x1, x2_warp, self.corr_params)\r\n            out_corr_b = compute_cost_volume(x2, x1_warp, self.corr_params)\r\n                \r\n            out_corr_relu_f = self.leakyRELU(out_corr_f)\r\n            out_corr_relu_b = self.leakyRELU(out_corr_b)\r\n\r\n            # flow estimator\r\n            if l == 0:\r\n                x_intm_f, flow_f = self.flow_estimators[l](out_corr_relu_f)\r\n                x_intm_b, flow_b = self.flow_estimators[l](out_corr_relu_b)\r\n            else:\r\n                x_intm_f, flow_f = self.flow_estimators[l](torch.cat([out_corr_relu_f, x1, flow_f], dim=1))\r\n                x_intm_b, flow_b = self.flow_estimators[l](torch.cat([out_corr_relu_b, x2, flow_b], dim=1))\r\n\r\n            # upsampling or post-processing\r\n            if l != self.output_level:\r\n                flows.append([flow_f, flow_b])\r\n  
          else:\r\n                flow_fine_f = self.context_networks(torch.cat([x_intm_f, flow_f], dim=1))\r\n                flow_fine_b = self.context_networks(torch.cat([x_intm_b, flow_b], dim=1))\r\n                flow_f = flow_f + flow_fine_f\r\n                flow_b = flow_b + flow_fine_b\r\n                flows.append([flow_f, flow_b])\r\n                break\r\n\r\n        output_dict['flow'] = flows\r\n\r\n        if self.training:\r\n            return output_dict\r\n        else:\r\n            output_dict_eval = {}            \r\n            out_flow = upsample2d_as(flow_f, x1_raw, mode=\"bilinear\") * (1.0 / self._div_flow)\r\n            output_dict_eval['flow'] = out_flow\r\n            return output_dict_eval\r\n"
  },
  {
    "path": "models/pwcnet_irr.py",
    "content": "from __future__ import absolute_import, division, print_function\r\n\r\nimport torch\r\nimport torch.nn as nn\r\n\r\nfrom .pwc_modules import conv, rescale_flow, upsample2d_as, initialize_msra, compute_cost_volume\r\nfrom .pwc_modules import WarpingLayer, FeatureExtractor, ContextNetwork, FlowEstimatorDense\r\n\r\nclass PWCNet(nn.Module):\r\n    def __init__(self, args, div_flow=0.05):\r\n        super(PWCNet, self).__init__()\r\n        self.args = args\r\n        self._div_flow = div_flow\r\n        self.search_range = 4\r\n        self.num_chs = [3, 16, 32, 64, 96, 128, 196]\r\n        self.output_level = 4\r\n        self.num_levels = 7\r\n        self.leakyRELU = nn.LeakyReLU(0.1, inplace=True)\r\n\r\n        self.feature_pyramid_extractor = FeatureExtractor(self.num_chs)\r\n        self.warping_layer = WarpingLayer()\r\n\r\n        self.dim_corr = (self.search_range * 2 + 1) ** 2\r\n        self.num_ch_in = self.dim_corr + 32 + 2\r\n\r\n        self.flow_estimators = FlowEstimatorDense(self.num_ch_in)\r\n\r\n        self.context_networks = ContextNetwork(self.num_ch_in + 448 + 2)\r\n\r\n        self.conv_1x1 = nn.ModuleList([conv(196, 32, kernel_size=1, stride=1, dilation=1),\r\n                                       conv(128, 32, kernel_size=1, stride=1, dilation=1),\r\n                                       conv(96, 32, kernel_size=1, stride=1, dilation=1),\r\n                                       conv(64, 32, kernel_size=1, stride=1, dilation=1),\r\n                                       conv(32, 32, kernel_size=1, stride=1, dilation=1)])\r\n        \r\n        self.corr_params = {\"pad_size\": self.search_range, \"kernel_size\": 1, \"max_disp\": self.search_range, \"stride1\": 1, \"stride2\": 1, \"corr_multiply\": 1}\r\n        \r\n        initialize_msra(self.modules())\r\n\r\n    def forward(self, input_dict):\r\n\r\n        x1_raw = input_dict['input1']\r\n        x2_raw = input_dict['input2']\r\n        _, _, height_im, width_im = 
x1_raw.size()\r\n\r\n        # on the bottom level are original images\r\n        x1_pyramid = self.feature_pyramid_extractor(x1_raw) + [x1_raw]\r\n        x2_pyramid = self.feature_pyramid_extractor(x2_raw) + [x2_raw]\r\n\r\n        # outputs\r\n        output_dict = {}\r\n        flows = []\r\n\r\n        # init\r\n        b_size, _, h_x1, w_x1, = x1_pyramid[0].size()\r\n        init_dtype = x1_pyramid[0].dtype\r\n        init_device = x1_pyramid[0].device\r\n        flow = torch.zeros(b_size, 2, h_x1, w_x1, dtype=init_dtype, device=init_device).float()\r\n\r\n        for l, (x1, x2) in enumerate(zip(x1_pyramid, x2_pyramid)):\r\n\r\n            # warping\r\n            if l == 0:\r\n                x2_warp = x2\r\n            else:\r\n                flow = upsample2d_as(flow, x1, mode=\"bilinear\")\r\n                x2_warp = self.warping_layer(x2, flow, height_im, width_im, self._div_flow)\r\n\r\n            # correlation\r\n            out_corr = compute_cost_volume(x1, x2_warp, self.corr_params)\r\n            out_corr_relu = self.leakyRELU(out_corr)\r\n\r\n            # concat and estimate flow\r\n            flow = rescale_flow(flow, self._div_flow, width_im, height_im, to_local=True)\r\n\r\n            x1_1by1 = self.conv_1x1[l](x1)\r\n            x_intm, flow_res = self.flow_estimators(torch.cat([out_corr_relu, x1_1by1, flow], dim=1))\r\n            flow = flow + flow_res\r\n\r\n            flow_fine = self.context_networks(torch.cat([x_intm, flow], dim=1))\r\n            flow = flow + flow_fine\r\n\r\n            flow = rescale_flow(flow, self._div_flow, width_im, height_im, to_local=False)\r\n            flows.append(flow)\r\n\r\n            # upsampling or post-processing\r\n            if l == self.output_level:\r\n                break\r\n\r\n        output_dict['flow'] = flows\r\n\r\n        if self.training:\r\n            return output_dict\r\n        else:\r\n            output_dict_eval = {}\r\n            output_dict_eval['flow'] = 
upsample2d_as(flow, x1_raw, mode=\"bilinear\") * (1.0 / self._div_flow)                \r\n            return output_dict_eval\r\n"
  },
  {
    "path": "models/pwcnet_irr_bi.py",
    "content": "from __future__ import absolute_import, division, print_function\r\n\r\nimport torch\r\nimport torch.nn as nn\r\n\r\nfrom .pwc_modules import conv, rescale_flow, upsample2d_as, initialize_msra, compute_cost_volume\r\nfrom .pwc_modules import WarpingLayer, FeatureExtractor, ContextNetwork, FlowEstimatorDense\r\n\r\nclass PWCNet(nn.Module):\r\n    def __init__(self, args, div_flow=0.05):\r\n        super(PWCNet, self).__init__()\r\n        self.args = args\r\n        self._div_flow = div_flow\r\n        self.search_range = 4\r\n        self.num_chs = [3, 16, 32, 64, 96, 128, 196]\r\n        self.output_level = 4\r\n        self.num_levels = 7\r\n        self.leakyRELU = nn.LeakyReLU(0.1, inplace=True)\r\n\r\n        self.feature_pyramid_extractor = FeatureExtractor(self.num_chs)\r\n        self.warping_layer = WarpingLayer()\r\n\r\n        self.dim_corr = (self.search_range * 2 + 1) ** 2\r\n        self.num_ch_in = self.dim_corr + 32 + 2\r\n\r\n        self.flow_estimators = FlowEstimatorDense(self.num_ch_in)\r\n\r\n        self.context_networks = ContextNetwork(self.num_ch_in + 448 + 2)\r\n\r\n        self.conv_1x1 = nn.ModuleList([conv(196, 32, kernel_size=1, stride=1, dilation=1),\r\n                                       conv(128, 32, kernel_size=1, stride=1, dilation=1),\r\n                                       conv(96, 32, kernel_size=1, stride=1, dilation=1),\r\n                                       conv(64, 32, kernel_size=1, stride=1, dilation=1),\r\n                                       conv(32, 32, kernel_size=1, stride=1, dilation=1)])\r\n        \r\n        self.corr_params = {\"pad_size\": self.search_range, \"kernel_size\": 1, \"max_disp\": self.search_range, \"stride1\": 1, \"stride2\": 1, \"corr_multiply\": 1}\r\n        \r\n        initialize_msra(self.modules())\r\n\r\n    def forward(self, input_dict):\r\n\r\n        x1_raw = input_dict['input1']\r\n        x2_raw = input_dict['input2']\r\n        _, _, height_im, width_im = 
x1_raw.size()\r\n\r\n        # on the bottom level are original images\r\n        x1_pyramid = self.feature_pyramid_extractor(x1_raw) + [x1_raw]\r\n        x2_pyramid = self.feature_pyramid_extractor(x2_raw) + [x2_raw]\r\n\r\n        # outputs\r\n        output_dict = {}\r\n        flows = []\r\n\r\n        # init\r\n        b_size, _, h_x1, w_x1, = x1_pyramid[0].size()\r\n        init_dtype = x1_pyramid[0].dtype\r\n        init_device = x1_pyramid[0].device\r\n        flow_f = torch.zeros(b_size, 2, h_x1, w_x1, dtype=init_dtype, device=init_device).float()\r\n        flow_b = torch.zeros(b_size, 2, h_x1, w_x1, dtype=init_dtype, device=init_device).float()\r\n\r\n        for l, (x1, x2) in enumerate(zip(x1_pyramid, x2_pyramid)):\r\n\r\n            # warping\r\n            if l == 0:\r\n                x2_warp = x2\r\n                x1_warp = x1\r\n            else:\r\n                flow_f = upsample2d_as(flow_f, x1, mode=\"bilinear\")\r\n                flow_b = upsample2d_as(flow_b, x2, mode=\"bilinear\")\r\n                x2_warp = self.warping_layer(x2, flow_f, height_im, width_im, self._div_flow)\r\n                x1_warp = self.warping_layer(x1, flow_b, height_im, width_im, self._div_flow)\r\n\r\n            # correlation\r\n            out_corr_f = compute_cost_volume(x1, x2_warp, self.corr_params)\r\n            out_corr_b = compute_cost_volume(x2, x1_warp, self.corr_params)\r\n            out_corr_relu_f = self.leakyRELU(out_corr_f)\r\n            out_corr_relu_b = self.leakyRELU(out_corr_b)\r\n\r\n            # concat and estimate flow\r\n            flow_f = rescale_flow(flow_f, self._div_flow, width_im, height_im, to_local=True)\r\n            flow_b = rescale_flow(flow_b, self._div_flow, width_im, height_im, to_local=True)\r\n\r\n            x1_1by1 = self.conv_1x1[l](x1)\r\n            x2_1by1 = self.conv_1x1[l](x2)\r\n            x_intm_f, flow_res_f = self.flow_estimators(torch.cat([out_corr_relu_f, x1_1by1, flow_f], dim=1))\r\n            
x_intm_b, flow_res_b = self.flow_estimators(torch.cat([out_corr_relu_b, x2_1by1, flow_b], dim=1))\r\n            flow_f = flow_f + flow_res_f\r\n            flow_b = flow_b + flow_res_b\r\n\r\n            flow_fine_f = self.context_networks(torch.cat([x_intm_f, flow_f], dim=1))\r\n            flow_fine_b = self.context_networks(torch.cat([x_intm_b, flow_b], dim=1))\r\n            flow_f = flow_f + flow_fine_f\r\n            flow_b = flow_b + flow_fine_b\r\n\r\n            flow_f = rescale_flow(flow_f, self._div_flow, width_im, height_im, to_local=False)\r\n            flow_b = rescale_flow(flow_b, self._div_flow, width_im, height_im, to_local=False)\r\n\r\n            flows.append([flow_f, flow_b])\r\n\r\n            # upsampling or post-processing\r\n            if l == self.output_level:\r\n                break\r\n\r\n        output_dict['flow'] = flows\r\n\r\n        if self.training:\r\n            return output_dict\r\n        else:\r\n            output_dict_eval = {}\r\n            output_dict_eval['flow'] = upsample2d_as(flow_f, x1_raw, mode=\"bilinear\") * (1.0 / self._div_flow)                \r\n            return output_dict_eval\r\n"
  },
  {
    "path": "models/pwcnet_irr_occ.py",
    "content": "from __future__ import absolute_import, division, print_function\r\n\r\nimport torch\r\nimport torch.nn as nn\r\n\r\nfrom .pwc_modules import conv, rescale_flow, upsample2d_as, initialize_msra, compute_cost_volume\r\nfrom .pwc_modules import WarpingLayer, FeatureExtractor, FlowEstimatorDense, ContextNetwork, OccEstimatorDense, OccContextNetwork\r\n\r\nclass PWCNet(nn.Module):\r\n    def __init__(self, args, div_flow=0.05):\r\n        super(PWCNet, self).__init__()\r\n        self.args = args\r\n        self._div_flow = div_flow\r\n        self.search_range = 4\r\n        self.num_chs = [3, 16, 32, 64, 96, 128, 196]\r\n        self.output_level = 4\r\n        self.num_levels = 7\r\n        self.leakyRELU = nn.LeakyReLU(0.1, inplace=True)\r\n\r\n        self.feature_pyramid_extractor = FeatureExtractor(self.num_chs)\r\n        self.warping_layer = WarpingLayer()\r\n\r\n        self.dim_corr = (self.search_range * 2 + 1) ** 2\r\n        self.num_ch_in_flo = self.dim_corr + 32 + 2\r\n        self.num_ch_in_occ = self.dim_corr + 32 + 1\r\n\r\n        self.flow_estimators = FlowEstimatorDense(self.num_ch_in_flo)\r\n        self.context_networks = ContextNetwork(self.num_ch_in_flo + 448 + 2)\r\n\r\n        self.occ_estimators = OccEstimatorDense(self.num_ch_in_occ)\r\n        self.occ_context_networks = OccContextNetwork(self.num_ch_in_occ + 448 + 1)\r\n\r\n        self.conv_1x1 = nn.ModuleList([conv(196, 32, kernel_size=1, stride=1, dilation=1),\r\n                                       conv(128, 32, kernel_size=1, stride=1, dilation=1),\r\n                                       conv(96, 32, kernel_size=1, stride=1, dilation=1),\r\n                                       conv(64, 32, kernel_size=1, stride=1, dilation=1),\r\n                                       conv(32, 32, kernel_size=1, stride=1, dilation=1)])\r\n        \r\n        self.corr_params = {\"pad_size\": self.search_range, \"kernel_size\": 1, \"max_disp\": self.search_range, \"stride1\": 1, 
\"stride2\": 1, \"corr_multiply\": 1}\r\n        \r\n        initialize_msra(self.modules())\r\n\r\n    def forward(self, input_dict):\r\n\r\n        x1_raw = input_dict['input1']\r\n        x2_raw = input_dict['input2']\r\n        _, _, height_im, width_im = x1_raw.size()\r\n\r\n        # on the bottom level are original images\r\n        x1_pyramid = self.feature_pyramid_extractor(x1_raw) + [x1_raw]\r\n        x2_pyramid = self.feature_pyramid_extractor(x2_raw) + [x2_raw]\r\n\r\n        # outputs\r\n        output_dict = {}\r\n        flows = []\r\n        occs = []\r\n\r\n        # init\r\n        b_size, _, h_x1, w_x1, = x1_pyramid[0].size()\r\n        init_dtype = x1_pyramid[0].dtype\r\n        init_device = x1_pyramid[0].device\r\n        flow = torch.zeros(b_size, 2, h_x1, w_x1, dtype=init_dtype, device=init_device).float()\r\n        occ = torch.zeros(b_size, 1, h_x1, w_x1, dtype=init_dtype, device=init_device).float()\r\n\r\n        for l, (x1, x2) in enumerate(zip(x1_pyramid, x2_pyramid)):\r\n\r\n            # warping\r\n            if l == 0:\r\n                x2_warp = x2\r\n            else:\r\n                flow = upsample2d_as(flow, x1, mode=\"bilinear\")\r\n                occ = upsample2d_as(occ, x1, mode=\"bilinear\")\r\n                x2_warp = self.warping_layer(x2, flow, height_im, width_im, self._div_flow)\r\n\r\n            # correlation\r\n            out_corr = compute_cost_volume(x1, x2_warp, self.corr_params)\r\n            out_corr_relu = self.leakyRELU(out_corr)\r\n\r\n            # concat and estimate flow\r\n            flow = rescale_flow(flow, self._div_flow, width_im, height_im, to_local=True)\r\n\r\n            x1_1by1 = self.conv_1x1[l](x1)\r\n            x_intm, flow_res = self.flow_estimators(torch.cat([out_corr_relu, x1_1by1, flow], dim=1))\r\n            flow = flow + flow_res\r\n\r\n            flow_fine = self.context_networks(torch.cat([x_intm, flow], dim=1))\r\n            flow = flow + flow_fine\r\n\r\n            
flow = rescale_flow(flow, self._div_flow, width_im, height_im, to_local=False)\r\n            flows.append(flow)\r\n\r\n            x_intm_occ, occ_res = self.occ_estimators(torch.cat([out_corr_relu, x1_1by1, occ], dim=1))\r\n            occ = occ + occ_res\r\n\r\n            occ_fine = self.occ_context_networks(torch.cat([x_intm_occ, occ], dim=1))\r\n            occ = occ + occ_fine\r\n            occs.append(occ)\r\n\r\n            # upsampling or post-processing\r\n            if l == self.output_level:\r\n                break\r\n\r\n        output_dict['flow'] = flows\r\n        output_dict['occ'] = occs\r\n\r\n        if self.training:\r\n            return output_dict\r\n        else:\r\n            output_dict_eval = {}\r\n            output_dict_eval['flow'] = upsample2d_as(flow, x1_raw, mode=\"bilinear\") * (1.0 / self._div_flow)\r\n            output_dict_eval['occ'] = upsample2d_as(occ, x1_raw, mode=\"bilinear\")\r\n            return output_dict_eval\r\n"
  },
  {
    "path": "models/pwcnet_irr_occ_bi.py",
    "content": "from __future__ import absolute_import, division, print_function\r\n\r\nimport torch\r\nimport torch.nn as nn\r\n\r\nfrom .pwc_modules import conv, rescale_flow, upsample2d_as, initialize_msra, compute_cost_volume\r\nfrom .pwc_modules import WarpingLayer, FeatureExtractor, FlowEstimatorDense, ContextNetwork, OccEstimatorDense, OccContextNetwork\r\n\r\nclass PWCNet(nn.Module):\r\n    def __init__(self, args, div_flow=0.05):\r\n        super(PWCNet, self).__init__()\r\n        self.args = args\r\n        self._div_flow = div_flow\r\n        self.search_range = 4\r\n        self.num_chs = [3, 16, 32, 64, 96, 128, 196]\r\n        self.output_level = 4\r\n        self.num_levels = 7\r\n        self.leakyRELU = nn.LeakyReLU(0.1, inplace=True)\r\n\r\n        self.feature_pyramid_extractor = FeatureExtractor(self.num_chs)\r\n        self.warping_layer = WarpingLayer()\r\n\r\n        self.dim_corr = (self.search_range * 2 + 1) ** 2\r\n        self.num_ch_in_flo = self.dim_corr + 32 + 2\r\n        self.num_ch_in_occ = self.dim_corr + 32 + 1\r\n\r\n        self.flow_estimators = FlowEstimatorDense(self.num_ch_in_flo)\r\n        self.context_networks = ContextNetwork(self.num_ch_in_flo + 448 + 2)\r\n        \r\n        self.occ_estimators = OccEstimatorDense(self.num_ch_in_occ)\r\n        self.occ_context_networks = OccContextNetwork(self.num_ch_in_occ + 448 + 1)\r\n        \r\n        self.conv_1x1 = nn.ModuleList([conv(196, 32, kernel_size=1, stride=1, dilation=1),\r\n                                       conv(128, 32, kernel_size=1, stride=1, dilation=1),\r\n                                       conv(96, 32, kernel_size=1, stride=1, dilation=1),\r\n                                       conv(64, 32, kernel_size=1, stride=1, dilation=1),\r\n                                       conv(32, 32, kernel_size=1, stride=1, dilation=1)])\r\n        \r\n        self.corr_params = {\"pad_size\": self.search_range, \"kernel_size\": 1, \"max_disp\": self.search_range, 
\"stride1\": 1, \"stride2\": 1, \"corr_multiply\": 1}\r\n        \r\n        initialize_msra(self.modules())\r\n\r\n    def forward(self, input_dict):\r\n\r\n        x1_raw = input_dict['input1']\r\n        x2_raw = input_dict['input2']\r\n        _, _, height_im, width_im = x1_raw.size()\r\n\r\n        # on the bottom level are original images\r\n        x1_pyramid = self.feature_pyramid_extractor(x1_raw) + [x1_raw]\r\n        x2_pyramid = self.feature_pyramid_extractor(x2_raw) + [x2_raw]\r\n\r\n        # outputs\r\n        output_dict = {}\r\n        flows = []\r\n        occs = []\r\n\r\n        # init\r\n        b_size, _, h_x1, w_x1, = x1_pyramid[0].size()\r\n        init_dtype = x1_pyramid[0].dtype\r\n        init_device = x1_pyramid[0].device\r\n        flow_f = torch.zeros(b_size, 2, h_x1, w_x1, dtype=init_dtype, device=init_device).float()\r\n        flow_b = torch.zeros(b_size, 2, h_x1, w_x1, dtype=init_dtype, device=init_device).float()\r\n        occ_f = torch.zeros(b_size, 1, h_x1, w_x1, dtype=init_dtype, device=init_device).float()\r\n        occ_b = torch.zeros(b_size, 1, h_x1, w_x1, dtype=init_dtype, device=init_device).float()\r\n\r\n        for l, (x1, x2) in enumerate(zip(x1_pyramid, x2_pyramid)):\r\n\r\n            # warping\r\n            if l == 0:\r\n                x2_warp = x2\r\n                x1_warp = x1\r\n            else:\r\n                flow_f = upsample2d_as(flow_f, x1, mode=\"bilinear\")\r\n                flow_b = upsample2d_as(flow_b, x2, mode=\"bilinear\")\r\n                occ_f = upsample2d_as(occ_f, x1, mode=\"bilinear\")\r\n                occ_b = upsample2d_as(occ_b, x2, mode=\"bilinear\")\r\n                x2_warp = self.warping_layer(x2, flow_f, height_im, width_im, self._div_flow)\r\n                x1_warp = self.warping_layer(x1, flow_b, height_im, width_im, self._div_flow)\r\n\r\n            # correlation\r\n            out_corr_f = compute_cost_volume(x1, x2_warp, self.corr_params)\r\n            out_corr_b = 
compute_cost_volume(x2, x1_warp, self.corr_params)\r\n            out_corr_relu_f = self.leakyRELU(out_corr_f)\r\n            out_corr_relu_b = self.leakyRELU(out_corr_b)\r\n\r\n            # concat and estimate flow\r\n            flow_f = rescale_flow(flow_f, self._div_flow, width_im, height_im, to_local=True)\r\n            flow_b = rescale_flow(flow_b, self._div_flow, width_im, height_im, to_local=True)\r\n\r\n            x1_1by1 = self.conv_1x1[l](x1)\r\n            x2_1by1 = self.conv_1x1[l](x2)\r\n            x_intm_f, flow_res_f = self.flow_estimators(torch.cat([out_corr_relu_f, x1_1by1, flow_f], dim=1))\r\n            x_intm_b, flow_res_b = self.flow_estimators(torch.cat([out_corr_relu_b, x2_1by1, flow_b], dim=1))\r\n            flow_f = flow_f + flow_res_f\r\n            flow_b = flow_b + flow_res_b\r\n\r\n            flow_fine_f = self.context_networks(torch.cat([x_intm_f, flow_f], dim=1))\r\n            flow_fine_b = self.context_networks(torch.cat([x_intm_b, flow_b], dim=1))\r\n            flow_f = flow_f + flow_fine_f\r\n            flow_b = flow_b + flow_fine_b\r\n\r\n            flow_f = rescale_flow(flow_f, self._div_flow, width_im, height_im, to_local=False)\r\n            flow_b = rescale_flow(flow_b, self._div_flow, width_im, height_im, to_local=False)\r\n\r\n            flows.append([flow_f, flow_b])\r\n\r\n            # occ estimation\r\n            x_intm_occ_f, occ_res_f = self.occ_estimators(torch.cat([out_corr_relu_f, x1_1by1, occ_f], dim=1))\r\n            x_intm_occ_b, occ_res_b = self.occ_estimators(torch.cat([out_corr_relu_b, x2_1by1, occ_b], dim=1))\r\n            occ_f = occ_f + occ_res_f\r\n            occ_b = occ_b + occ_res_b\r\n\r\n            occ_fine_f = self.occ_context_networks(torch.cat([x_intm_occ_f, occ_f], dim=1))\r\n            occ_fine_b = self.occ_context_networks(torch.cat([x_intm_occ_b, occ_b], dim=1))\r\n            occ_f = occ_f + occ_fine_f\r\n            occ_b = occ_b + occ_fine_b\r\n            
occs.append([occ_f, occ_b])\r\n\r\n            # upsampling or post-processing\r\n            if l == self.output_level:\r\n                break\r\n\r\n        output_dict['flow'] = flows\r\n        output_dict['occ'] = occs\r\n\r\n        if self.training:\r\n            return output_dict\r\n        else:\r\n            output_dict_eval = {}\r\n            output_dict_eval['flow'] = upsample2d_as(flow_f, x1_raw, mode=\"bilinear\") * (1.0 / self._div_flow)\r\n            output_dict_eval['occ'] = upsample2d_as(occ_f, x1_raw, mode=\"bilinear\")\r\n            return output_dict_eval\r\n"
  },
  {
    "path": "models/pwcnet_occ.py",
    "content": "from __future__ import absolute_import, division, print_function\r\n\r\nimport torch\r\nimport torch.nn as nn\r\n\r\nfrom .pwc_modules import upsample2d_as, initialize_msra, compute_cost_volume\r\nfrom .pwc_modules import WarpingLayer, FeatureExtractor, FlowEstimatorDense, ContextNetwork, OccEstimatorDense, OccContextNetwork\r\n\r\nclass PWCNet(nn.Module):\r\n    def __init__(self, args, div_flow=0.05):\r\n        super(PWCNet, self).__init__()\r\n        self.args = args\r\n        self._div_flow = div_flow\r\n        self.search_range = 4\r\n        self.num_chs = [3, 16, 32, 64, 96, 128, 196]\r\n        self.output_level = 4\r\n        self.num_levels = 7\r\n        self.leakyRELU = nn.LeakyReLU(0.1, inplace=True)\r\n\r\n        self.feature_pyramid_extractor = FeatureExtractor(self.num_chs)\r\n        self.warping_layer = WarpingLayer()\r\n\r\n        self.flow_estimators = nn.ModuleList()\r\n        self.occ_estimators = nn.ModuleList()\r\n        self.dim_corr = (self.search_range * 2 + 1) ** 2\r\n        for l, ch in enumerate(self.num_chs[::-1]):\r\n            if l > self.output_level:\r\n                break\r\n\r\n            if l == 0:\r\n                num_ch_in = self.dim_corr\r\n                num_ch_in_occ = self.dim_corr\r\n            else:\r\n                num_ch_in = self.dim_corr + ch + 2\r\n                num_ch_in_occ = self.dim_corr + ch + 1\r\n\r\n            layer = FlowEstimatorDense(num_ch_in)\r\n            layer_occ = OccEstimatorDense(num_ch_in_occ)\r\n            self.flow_estimators.append(layer)\r\n            self.occ_estimators.append(layer_occ)\r\n\r\n        self.context_networks = ContextNetwork(self.dim_corr + 32 + 2 + 448 + 2)\r\n        self.context_networks_occ = OccContextNetwork(self.dim_corr + 32 + 1 + 448 + 1)\r\n        \r\n        self.corr_params = {\"pad_size\": self.search_range, \"kernel_size\": 1, \"max_disp\": self.search_range, \"stride1\": 1, \"stride2\": 1, \"corr_multiply\": 1}\r\n     
   \r\n        initialize_msra(self.modules())\r\n\r\n    def forward(self, input_dict):\r\n\r\n        x1_raw = input_dict['input1']\r\n        x2_raw = input_dict['input2']\r\n        _, _, height_im, width_im = x1_raw.size()\r\n\r\n        # on the bottom level are original images\r\n        x1_pyramid = self.feature_pyramid_extractor(x1_raw) + [x1_raw]\r\n        x2_pyramid = self.feature_pyramid_extractor(x2_raw) + [x2_raw]\r\n\r\n        # outputs\r\n        output_dict = {}\r\n        flows = []\r\n        occs = []\r\n\r\n        # init\r\n        b_size, _, h_x1, w_x1, = x1_pyramid[0].size()\r\n        init_dtype = x1_pyramid[0].dtype\r\n        init_device = x1_pyramid[0].device\r\n        flow = torch.zeros(b_size, 2, h_x1, w_x1, dtype=init_dtype, device=init_device).float()\r\n        occ = torch.zeros(b_size, 1, h_x1, w_x1, dtype=init_dtype, device=init_device).float()\r\n\r\n        for l, (x1, x2) in enumerate(zip(x1_pyramid, x2_pyramid)):\r\n\r\n            # warping\r\n            if l == 0:\r\n                x2_warp = x2\r\n            else:\r\n                flow = upsample2d_as(flow, x1, mode=\"bilinear\")\r\n                occ = upsample2d_as(occ, x1, mode=\"bilinear\")\r\n                x2_warp = self.warping_layer(x2, flow, height_im, width_im, self._div_flow)\r\n\r\n            # correlation  \r\n            out_corr = compute_cost_volume(x1, x2_warp, self.corr_params)\r\n            out_corr_relu = self.leakyRELU(out_corr)\r\n\r\n            # flow estimator\r\n            if l == 0:\r\n                x_intm, flow = self.flow_estimators[l](out_corr_relu)\r\n                x_intm_occ, occ= self.occ_estimators[l](out_corr_relu)\r\n\r\n            else:\r\n                x_intm, flow = self.flow_estimators[l](torch.cat([out_corr_relu, x1, flow], dim=1))\r\n                x_intm_occ, occ = self.occ_estimators[l](torch.cat([out_corr_relu, x1, occ], dim=1))\r\n\r\n            # upsampling or post-processing\r\n            if l != 
self.output_level:\r\n                flows.append(flow)\r\n                occs.append(occ)\r\n            else:\r\n                flow_fine = self.context_networks(torch.cat([x_intm, flow], dim=1))\r\n                flow = flow + flow_fine\r\n                flows.append(flow)\r\n\r\n                occ_fine = self.context_networks_occ(torch.cat([x_intm_occ, occ], dim=1))\r\n                occ = occ + occ_fine\r\n                occs.append(occ)\r\n                break\r\n\r\n        output_dict['flow'] = flows\r\n        output_dict['occ'] = occs\r\n\r\n        if self.training:\r\n            return output_dict\r\n        else:\r\n            output_dict_eval = {}\r\n            output_dict_eval['flow'] = upsample2d_as(flow, x1_raw, mode=\"bilinear\") * (1.0 / self._div_flow)\r\n            output_dict_eval['occ'] = upsample2d_as(occ, x1_raw, mode=\"bilinear\")\r\n            return output_dict_eval\r\n"
  },
  {
    "path": "models/pwcnet_occ_bi.py",
    "content": "from __future__ import absolute_import, division, print_function\r\n\r\nimport torch\r\nimport torch.nn as nn\r\n\r\nfrom .pwc_modules import upsample2d_as, initialize_msra, compute_cost_volume\r\nfrom .pwc_modules import WarpingLayer, FeatureExtractor, FlowEstimatorDense, ContextNetwork, OccEstimatorDense, OccContextNetwork\r\n\r\nclass PWCNet(nn.Module):\r\n    def __init__(self, args, div_flow=0.05):\r\n        super(PWCNet, self).__init__()\r\n        self.args = args\r\n        self._div_flow = div_flow\r\n        self.search_range = 4\r\n        self.num_chs = [3, 16, 32, 64, 96, 128, 196]\r\n        self.output_level = 4\r\n        self.num_levels = 7\r\n        self.leakyRELU = nn.LeakyReLU(0.1, inplace=True)\r\n\r\n        self.feature_pyramid_extractor = FeatureExtractor(self.num_chs)\r\n        self.warping_layer = WarpingLayer()\r\n\r\n        self.flow_estimators = nn.ModuleList()\r\n        self.occ_estimators = nn.ModuleList()\r\n        self.dim_corr = (self.search_range * 2 + 1) ** 2\r\n        for l, ch in enumerate(self.num_chs[::-1]):\r\n            if l > self.output_level:\r\n                break\r\n\r\n            if l == 0:\r\n                num_ch_in = self.dim_corr\r\n                num_ch_in_occ = self.dim_corr\r\n            else:\r\n                num_ch_in = self.dim_corr + ch + 2\r\n                num_ch_in_occ = self.dim_corr + ch + 1\r\n\r\n            layer = FlowEstimatorDense(num_ch_in)\r\n            layer_occ = OccEstimatorDense(num_ch_in_occ)\r\n            self.flow_estimators.append(layer)\r\n            self.occ_estimators.append(layer_occ)\r\n\r\n        self.context_networks = ContextNetwork(self.dim_corr + 32 + 2 + 448 + 2)\r\n        self.context_networks_occ = OccContextNetwork(self.dim_corr + 32 + 1 + 448 + 1)\r\n        \r\n        self.corr_params = {\"pad_size\": self.search_range, \"kernel_size\": 1, \"max_disp\": self.search_range, \"stride1\": 1, \"stride2\": 1, \"corr_multiply\": 1}\r\n     
   \r\n        initialize_msra(self.modules())\r\n\r\n    def forward(self, input_dict):\r\n\r\n        x1_raw = input_dict['input1']\r\n        x2_raw = input_dict['input2']\r\n        _, _, height_im, width_im = x1_raw.size()\r\n\r\n        # on the bottom level are original images\r\n        x1_pyramid = self.feature_pyramid_extractor(x1_raw) + [x1_raw]\r\n        x2_pyramid = self.feature_pyramid_extractor(x2_raw) + [x2_raw]\r\n\r\n        # outputs\r\n        output_dict = {}\r\n        flows = []\r\n        occs = []\r\n\r\n        # init\r\n        b_size, _, h_x1, w_x1, = x1_pyramid[0].size()\r\n        init_dtype = x1_pyramid[0].dtype\r\n        init_device = x1_pyramid[0].device\r\n        flow_f = torch.zeros(b_size, 2, h_x1, w_x1, dtype=init_dtype, device=init_device).float()\r\n        flow_b = torch.zeros(b_size, 2, h_x1, w_x1, dtype=init_dtype, device=init_device).float()\r\n        occ_f = torch.zeros(b_size, 1, h_x1, w_x1, dtype=init_dtype, device=init_device).float()\r\n        occ_b = torch.zeros(b_size, 1, h_x1, w_x1, dtype=init_dtype, device=init_device).float()\r\n\r\n        for l, (x1, x2) in enumerate(zip(x1_pyramid, x2_pyramid)):\r\n\r\n            # warping\r\n            if l == 0:\r\n                x2_warp = x2\r\n                x1_warp = x1\r\n            else:\r\n                flow_f = upsample2d_as(flow_f, x1, mode=\"bilinear\")\r\n                flow_b = upsample2d_as(flow_b, x2, mode=\"bilinear\")\r\n                occ_f = upsample2d_as(occ_f, x1, mode=\"bilinear\")\r\n                occ_b = upsample2d_as(occ_b, x2, mode=\"bilinear\")\r\n                x2_warp = self.warping_layer(x2, flow_f, height_im, width_im, self._div_flow)\r\n                x1_warp = self.warping_layer(x1, flow_b, height_im, width_im, self._div_flow)\r\n\r\n            # correlation\r\n            out_corr_f = compute_cost_volume(x1, x2_warp, self.corr_params)\r\n            out_corr_b = compute_cost_volume(x2, x1_warp, self.corr_params)\r\n          
  out_corr_relu_f = self.leakyRELU(out_corr_f)\r\n            out_corr_relu_b = self.leakyRELU(out_corr_b)\r\n\r\n            # flow estimator\r\n            if l == 0:\r\n                x_intm_f, flow_f = self.flow_estimators[l](out_corr_relu_f)\r\n                x_intm_b, flow_b = self.flow_estimators[l](out_corr_relu_b)\r\n                x_intm_occ_f, occ_f = self.occ_estimators[l](out_corr_relu_f)\r\n                x_intm_occ_b, occ_b = self.occ_estimators[l](out_corr_relu_b)\r\n            else:\r\n                x_intm_f, flow_f = self.flow_estimators[l](torch.cat([out_corr_relu_f, x1, flow_f], dim=1))\r\n                x_intm_b, flow_b = self.flow_estimators[l](torch.cat([out_corr_relu_b, x2, flow_b], dim=1))\r\n                x_intm_occ_f, occ_f = self.occ_estimators[l](torch.cat([out_corr_relu_f, x1, occ_f], dim=1))\r\n                x_intm_occ_b, occ_b = self.occ_estimators[l](torch.cat([out_corr_relu_b, x2, occ_b], dim=1))\r\n\r\n            # upsampling or post-processing\r\n            if l != self.output_level:\r\n                flows.append([flow_f, flow_b])\r\n                occs.append([occ_f, occ_b])\r\n            else:\r\n                flow_fine_f = self.context_networks(torch.cat([x_intm_f, flow_f], dim=1))\r\n                flow_fine_b = self.context_networks(torch.cat([x_intm_b, flow_b], dim=1))\r\n                flow_f = flow_f + flow_fine_f\r\n                flow_b = flow_b + flow_fine_b\r\n                flows.append([flow_f, flow_b])\r\n\r\n                occ_fine_f = self.context_networks_occ(torch.cat([x_intm_occ_f, occ_f], dim=1))\r\n                occ_fine_b = self.context_networks_occ(torch.cat([x_intm_occ_b, occ_b], dim=1))\r\n                occ_f = occ_f + occ_fine_f\r\n                occ_b = occ_b + occ_fine_b\r\n                occs.append([occ_f, occ_b])                \r\n                break\r\n\r\n        output_dict['flow'] = flows\r\n        output_dict['occ'] = occs\r\n\r\n        if self.training:\r\n 
           return output_dict\r\n        else:\r\n            output_dict_eval = {}\r\n            output_dict_eval['flow'] = upsample2d_as(flow_f, x1_raw, mode=\"bilinear\") * (1.0 / self._div_flow)\r\n            output_dict_eval['occ'] = upsample2d_as(occ_f, x1_raw, mode=\"bilinear\")\r\n            return output_dict_eval\r\n"
  },
  {
    "path": "optim/__init__.py",
    "content": "import torch\nimport sys\nfrom tools import module_classes_to_dict\n\n# ------------------------------------------------------------------------------------\n# Export PyTorch optimizer\n# ------------------------------------------------------------------------------------\n_this = sys.modules[__name__]\n_optimizer_classes = module_classes_to_dict(torch.optim, exclude_classes=\"Optimizer\")\nfor name, constructor in _optimizer_classes.items():\n    setattr(_this, name, constructor)\n__all__ = _optimizer_classes.keys()\n\n"
  },
  {
    "path": "runtime.py",
    "content": "## Portions of Code from, copyright 2018 Jochen Gast\n\nfrom __future__ import absolute_import, division, print_function\n\nimport numpy as np\nimport colorama\nimport logging\nimport logger\nimport tools\nfrom tools import MovingAverage\nimport collections\n\nimport scipy.misc\nimport torch\nimport torch.nn as nn\nimport os\n\n# for evaluation\nfrom utils.flow import flow_to_png, flow_to_png_middlebury\nfrom utils.flow import write_flow, write_flow_png\n\n# --------------------------------------------------------------------------------\n# Exponential moving average smoothing factor for speed estimates\n# Ranges from 0 (average speed) to 1 (current/instantaneous speed) [default: 0.3].\n# --------------------------------------------------------------------------------\nTQDM_SMOOTHING = 0\n\n\n# -------------------------------------------------------------------------------------------\n# Magic progressbar for inputs of type 'iterable'\n# -------------------------------------------------------------------------------------------\ndef create_progressbar(iterable,\n                       desc=\"\",\n                       train=False,\n                       unit=\"it\",\n                       initial=0,\n                       offset=0,\n                       invert_iterations=False,\n                       logging_on_update=False,\n                       logging_on_close=True,\n                       postfix=False):\n\n    # ---------------------------------------------------------------\n    # Pick colors\n    # ---------------------------------------------------------------\n    reset = colorama.Style.RESET_ALL\n    bright = colorama.Style.BRIGHT\n    cyan = colorama.Fore.CYAN\n    dim = colorama.Style.DIM\n    green = colorama.Fore.GREEN\n\n    # ---------------------------------------------------------------\n    # Specify progressbar layout:\n    #   l_bar, bar, r_bar, n, n_fmt, total, total_fmt, percentage,\n    #   rate, rate_fmt, 
rate_noinv, rate_noinv_fmt, rate_inv,\n    #   rate_inv_fmt, elapsed, remaining, desc, postfix.\n    # ---------------------------------------------------------------\n    bar_format = \"\"\n    bar_format += \"%s==>%s%s {desc}:%s \" % (cyan, reset, bright, reset)     # description\n    bar_format += \"{percentage:3.0f}%\"                                      # percentage\n    bar_format += \"%s|{bar}|%s \" % (dim, reset)                             # bar\n    bar_format += \" {n_fmt}/{total_fmt}  \"                                  # i/n counter\n    bar_format += \"{elapsed}<{remaining}\"                                   # eta\n    if invert_iterations:\n        bar_format += \" {rate_inv_fmt}  \"                                   # iteration timings\n    else:\n        bar_format += \" {rate_noinv_fmt}  \"\n    bar_format += \"%s{postfix}%s\" % (green, reset)                          # postfix\n\n    # ---------------------------------------------------------------\n    # Specify TQDM arguments\n    # ---------------------------------------------------------------\n    tqdm_args = {\n        \"iterable\": iterable,\n        \"desc\": desc,                          # Prefix for the progress bar\n        \"total\": len(iterable),                # The number of expected iterations\n        \"leave\": True,                         # Leave progress bar when done\n        \"miniters\": 1 if train else None,      # Minimum display update interval in iterations\n        \"unit\": unit,                          # String be used to define the unit of each iteration\n        \"initial\": initial,                    # The initial counter value.\n        \"dynamic_ncols\": True,                 # Allow window resizes\n        \"smoothing\": TQDM_SMOOTHING,           # Moving average smoothing factor for speed estimates\n        \"bar_format\": bar_format,              # Specify a custom bar string formatting\n        \"position\": offset,                    # Specify 
vertical line offset\n        \"ascii\": True,\n        \"logging_on_update\": logging_on_update,\n        \"logging_on_close\": logging_on_close\n    }\n\n    return tools.tqdm_with_logging(**tqdm_args)\n\n\ndef tensor2float_dict(tensor_dict):\n    return {key: tensor.item() for key, tensor in tensor_dict.items()}\n\n\ndef format_moving_averages_as_progress_dict(moving_averages_dict={},\n                                            moving_averages_postfix=\"avg\"):\n    progress_dict = collections.OrderedDict([\n        (key + moving_averages_postfix, \"%1.4f\" % moving_averages_dict[key].mean())\n        for key in sorted(moving_averages_dict.keys())\n    ])\n    return progress_dict\n\n\ndef format_learning_rate(lr):\n    if np.isscalar(lr):\n        return \"{}\".format(lr)\n    else:\n        return \"{}\".format(str(lr[0]) if len(lr) == 1 else lr)\n\n\nclass TrainingEpoch:\n    def __init__(self,\n                 args,\n                 model_and_loss,\n                 loader,\n                 optimizer,\n                 augmentation=None,\n                 add_progress_stats={},\n                 desc=\"Training Epoch\"):\n\n        self._args = args\n        self._desc = desc\n        self._loader = loader\n        self._model_and_loss = model_and_loss\n        self._optimizer = optimizer\n        self._augmentation = augmentation\n        self._add_progress_stats = add_progress_stats\n\n    def _step(self, example_dict):\n\n        # -------------------------------------------------------------\n        # Get input and target tensor keys\n        # -------------------------------------------------------------\n        input_keys = list(filter(lambda x: \"input\" in x, example_dict.keys()))\n        target_keys = list(filter(lambda x: \"target\" in x, example_dict.keys()))\n        tensor_keys = input_keys + target_keys\n\n        # -------------------------------------------------------------\n        # Possibly transfer to Cuda\n        # 
-------------------------------------------------------------\n        if self._args.cuda:\n            for key, value in example_dict.items():\n                if key in tensor_keys:\n                    example_dict[key] = value.cuda(non_blocking=False)\n\n        # -------------------------------------------------------------\n        # Optionally perform augmentations\n        # -------------------------------------------------------------\n        if self._augmentation is not None:\n            with torch.no_grad():\n                example_dict = self._augmentation(example_dict)\n\n        # -------------------------------------------------------------\n        # Convert inputs/targets to variables that require gradients\n        # -------------------------------------------------------------\n        for key, tensor in example_dict.items():\n            if key in input_keys:\n                example_dict[key] = tensor.requires_grad_(True)\n            elif key in target_keys:\n                example_dict[key] = tensor.requires_grad_(False)\n\n        # -------------------------------------------------------------\n        # Extract batch size from first input\n        # -------------------------------------------------------------\n        batch_size = example_dict[\"input1\"].size()[0]\n\n        # -------------------------------------------------------------\n        # Reset gradients\n        # -------------------------------------------------------------\n        self._optimizer.zero_grad()\n\n        # -------------------------------------------------------------\n        # Run forward pass to get losses and outputs.\n        # -------------------------------------------------------------\n        loss_dict, output_dict = self._model_and_loss(example_dict)\n\n        # -------------------------------------------------------------\n        # Check total_loss for NaNs\n        # -------------------------------------------------------------\n        
training_loss = loss_dict[self._args.training_key]\n        assert (not np.isnan(training_loss.item())), \"training_loss is NaN\"\n\n        # -------------------------------------------------------------\n        # Back propagation\n        # -------------------------------------------------------------\n        training_loss.backward()\n        self._optimizer.step()\n\n        # -------------------------------------------------------------\n        # Return success flag, loss and output dictionary\n        # -------------------------------------------------------------\n        return loss_dict, output_dict, batch_size\n\n    def run(self, offset=0):\n        # ---------------------------------------\n        # Tell model that we want to train\n        # ---------------------------------------\n        self._model_and_loss.train()\n\n        # ---------------------------------------\n        # Keep track of moving averages\n        # ---------------------------------------\n        moving_averages_dict = None\n\n        # ---------------------------------------\n        # Progress bar arguments\n        # ---------------------------------------\n        progressbar_args = {\n            \"iterable\": self._loader,\n            \"desc\": self._desc,\n            \"train\": True,\n            \"offset\": offset,\n            \"logging_on_update\": False,\n            \"logging_on_close\": True,\n            \"postfix\": True\n        }\n\n        # ---------------------------------------\n        # Perform training steps\n        # ---------------------------------------\n        with create_progressbar(**progressbar_args) as progress:\n            for example_dict in progress:\n                # perform step\n                loss_dict_per_step, output_dict, batch_size = self._step(example_dict)\n                # convert\n                loss_dict_per_step = tensor2float_dict(loss_dict_per_step)\n\n                # 
--------------------------------------------------------\n                # Possibly initialize moving averages\n                # --------------------------------------------------------\n                if moving_averages_dict is None:\n                    moving_averages_dict = {\n                        key: MovingAverage() for key in loss_dict_per_step.keys()\n                    }\n\n                # --------------------------------------------------------\n                # Add moving averages\n                # --------------------------------------------------------\n                for key, loss in loss_dict_per_step.items():\n                    moving_averages_dict[key].add_average(loss, addcount=batch_size)\n\n                # view statistics in progress bar\n                progress_stats = format_moving_averages_as_progress_dict(\n                    moving_averages_dict=moving_averages_dict,\n                    moving_averages_postfix=\"_ema\")\n\n                progress.set_postfix(progress_stats)\n\n        # -------------------------------------------------------------\n        # Return loss and output dictionary\n        # -------------------------------------------------------------\n        ema_loss_dict = { key: ma.mean() for key, ma in moving_averages_dict.items() }\n        return ema_loss_dict\n\n\nclass EvaluationEpoch:\n    def __init__(self,\n                 args,\n                 model_and_loss,\n                 loader,\n                 augmentation=None,\n                 add_progress_stats={},\n                 desc=\"Evaluation Epoch\"):\n        self._args = args\n        self._desc = desc\n        self._loader = loader\n        self._model_and_loss = model_and_loss\n        self._add_progress_stats = add_progress_stats\n        self._augmentation = augmentation\n        self._save_output = False\n        if self._args.save_result_img or self._args.save_result_flo or self._args.save_result_png:\n            
self._save_output = True\n\n    def save_outputs(self, example_dict, output_dict):\n\n        # save occ\n        save_root_img = self._args.save + '/img/'\n        save_root_flo = self._args.save + '/flo/'\n\n        if self._args.save_result_bidirection:\n            flow_f = output_dict[\"flow\"].data.cpu().numpy()\n            flow_b = output_dict[\"flow_b\"].data.cpu().numpy()\n            b_size = output_dict[\"flow\"].data.size(0)\n        else:\n            flow_f = output_dict[\"flow\"].data.cpu().numpy()\n            b_size = output_dict[\"flow\"].data.size(0)\n\n        if self._args.save_result_occ:\n            if self._args.save_result_bidirection:\n                output_occ = np.round(\n                    nn.Sigmoid()(output_dict[\"occ\"]).expand(-1, 3, -1, -1).data.cpu().numpy().transpose(\n                        [0, 2, 3, 1])) * 255\n                output_occ_b = np.round(\n                    nn.Sigmoid()(output_dict[\"occ_b\"]).expand(-1, 3, -1, -1).data.cpu().numpy().transpose(\n                        [0, 2, 3, 1])) * 255\n            else:\n                output_occ = np.round(\n                    nn.Sigmoid()(output_dict[\"occ\"]).expand(-1, 3, -1, -1).data.cpu().numpy().transpose(\n                        [0, 2, 3, 1])) * 255\n\n        # file names\n        file_names_img = []\n        file_names_flo = []\n        for ii in range(0, b_size):\n            if \"basedir\" in  example_dict.keys():\n                file_name_img = save_root_img + example_dict[\"basedir\"][ii] + '/' + str(example_dict[\"basename\"][ii])\n                file_name_flo = save_root_flo + example_dict[\"basedir\"][ii] + '/' + str(example_dict[\"basename\"][ii])\n                file_names_img.append(file_name_img)\n                file_names_flo.append(file_name_flo)\n            else:\n                file_name_img = save_root_img + '/' + str(example_dict[\"basename\"][ii])\n                file_name_flo = save_root_flo + '/' + 
str(example_dict[\"basename\"][ii])\n                file_names_img.append(file_name_img)\n                file_names_flo.append(file_name_flo)\n\n            directory_img = os.path.dirname(file_name_img)\n            if not os.path.exists(directory_img):\n                os.makedirs(directory_img)\n            directory_flo = os.path.dirname(file_name_flo)\n            if not os.path.exists(directory_flo):\n                os.makedirs(directory_flo)\n\n        if self._args.save_result_img:\n            for ii in range(0, b_size):\n                if self._args.save_result_occ:\n                    file_name_occ = file_names_img[ii] + '_occ.png'\n                    scipy.misc.imsave(file_name_occ, output_occ[ii])\n\n                    if self._args.save_result_bidirection:\n                        scipy.misc.imsave(file_names_img[ii] + '_occ_b.png', output_occ_b[ii])\n\n                # flow vis\n                flow_f_rgb = flow_to_png_middlebury(flow_f[ii, ...])\n                file_name_flo_vis = file_names_img[ii] + '_flow.png'\n                scipy.misc.imsave(file_name_flo_vis, flow_f_rgb)\n\n                if self._args.save_result_bidirection:\n                    flow_b_rgb = flow_to_png_middlebury(flow_b[ii, ...])\n                    file_name_flo_vis = file_names_img[ii] + '_flow_b.png'\n                    scipy.misc.imsave(file_name_flo_vis, flow_b_rgb)\n\n        if self._args.save_result_flo or self._args.save_result_png:\n            for ii in range(0, b_size):\n                if self._args.save_result_flo:\n                    file_name = file_names_flo[ii] + '.flo'\n                    write_flow(file_name, flow_f[ii, ...].swapaxes(0, 1).swapaxes(1, 2))\n                if self._args.save_result_png:\n                    file_name = file_names_flo[ii] + '.png'\n                    write_flow_png(file_name, flow_f[ii, ...].swapaxes(0, 1).swapaxes(1, 2))\n\n\n    def _step(self, example_dict):\n        # 
-------------------------------------------------------------\n        # Get input and target tensor keys\n        # -------------------------------------------------------------\n        input_keys = list(filter(lambda x: \"input\" in x, example_dict.keys()))\n        target_keys = list(filter(lambda x: \"target\" in x, example_dict.keys()))\n        tensor_keys = input_keys + target_keys\n\n        # -------------------------------------------------------------\n        # Possibly transfer to Cuda\n        # -------------------------------------------------------------\n        if self._args.cuda:\n            for key, value in example_dict.items():\n                if key in tensor_keys:\n                    example_dict[key] = value.cuda(non_blocking=False)\n\n        # -------------------------------------------------------------\n        # Optionally perform augmentations\n        # -------------------------------------------------------------\n        if self._augmentation is not None:\n            example_dict = self._augmentation(example_dict)\n\n        # -------------------------------------------------------------\n        # Extract batch size from first input\n        # -------------------------------------------------------------\n        batch_size = example_dict[\"input1\"].size()[0]\n\n        # -------------------------------------------------------------\n        # Run forward pass to get losses and outputs.\n        # -------------------------------------------------------------\n        loss_dict, output_dict = self._model_and_loss(example_dict)\n\n        # -------------------------------------------------------------\n        # Return loss and output dictionary\n        # -------------------------------------------------------------\n        return loss_dict, output_dict, batch_size\n\n    def run(self, offset=0):\n\n        with torch.no_grad():\n\n            # ---------------------------------------\n            # Tell model that we want 
to evaluate\n            # ---------------------------------------\n            self._model_and_loss.eval()\n\n            # ---------------------------------------\n            # Keep track of moving averages\n            # ---------------------------------------\n            moving_averages_dict = None\n\n            # ---------------------------------------\n            # Progress bar arguments\n            # ---------------------------------------\n            progressbar_args = {\n                \"iterable\": self._loader,\n                \"desc\": self._desc,\n                \"train\": False,\n                \"offset\": offset,\n                \"logging_on_update\": False,\n                \"logging_on_close\": True,\n                \"postfix\": True\n            }\n\n            # ---------------------------------------\n            # Perform evaluation steps\n            # ---------------------------------------\n            with create_progressbar(**progressbar_args) as progress:\n                for example_dict in progress:\n\n                    # ---------------------------------------\n                    # Perform forward evaluation step\n                    # ---------------------------------------\n                    loss_dict_per_step, output_dict, batch_size = self._step(example_dict)\n\n                    # --------------------------------------------------------\n                    # Save results\n                    # --------------------------------------------------------\n                    if self._save_output:\n                        self.save_outputs(example_dict, output_dict)\n\n                    # ---------------------------------------\n                    # Convert loss dictionary to float\n                    # ---------------------------------------\n                    loss_dict_per_step = tensor2float_dict(loss_dict_per_step)\n\n                    # --------------------------------------------------------\n          
          # Possibly initialize moving averages\n                    # --------------------------------------------------------\n                    if moving_averages_dict is None:\n                        moving_averages_dict = {\n                            key: MovingAverage() for key in loss_dict_per_step.keys()\n                        }\n\n                    # --------------------------------------------------------\n                    # Add moving averages\n                    # --------------------------------------------------------\n                    for key, loss in loss_dict_per_step.items():\n                        moving_averages_dict[key].add_average(loss, addcount=batch_size)\n\n                    # view statistics in progress bar\n                    progress_stats = format_moving_averages_as_progress_dict(\n                        moving_averages_dict=moving_averages_dict,\n                        moving_averages_postfix=\"_avg\")\n\n                    progress.set_postfix(progress_stats)\n\n            # -------------------------------------------------------------\n            # Record average losses\n            # -------------------------------------------------------------\n            avg_loss_dict = { key: ma.mean() for key, ma in moving_averages_dict.items() }\n\n            # -------------------------------------------------------------\n            # Return average losses and output dictionary\n            # -------------------------------------------------------------\n            return avg_loss_dict\n\n\ndef exec_runtime(args,\n                 checkpoint_saver,\n                 model_and_loss,\n                 optimizer,\n                 lr_scheduler,\n                 train_loader,\n                 validation_loader,\n                 inference_loader,\n                 training_augmentation,\n                 validation_augmentation):\n\n    # 
----------------------------------------------------------------------------------------------\n    # Validation schedulers are a bit special:\n    # They want to be called with a validation loss..\n    # ----------------------------------------------------------------------------------------------\n    validation_scheduler = (lr_scheduler is not None and args.lr_scheduler == \"ReduceLROnPlateau\")\n\n    # --------------------------------------------------------\n    # Log some runtime info\n    # --------------------------------------------------------\n    with logger.LoggingBlock(\"Runtime\", emph=True):\n        logging.info(\"start_epoch: %i\" % args.start_epoch)\n        logging.info(\"total_epochs: %i\" % args.total_epochs)\n\n    # ---------------------------------------\n    # Total progress bar arguments\n    # ---------------------------------------\n    progressbar_args = {\n        \"desc\": \"Progress\",\n        \"initial\": args.start_epoch - 1,\n        \"invert_iterations\": True,\n        \"iterable\": range(1, args.total_epochs + 1),\n        \"logging_on_close\": True,\n        \"logging_on_update\": True,\n        \"postfix\": False,\n        \"unit\": \"ep\"\n    }\n\n    # --------------------------------------------------------\n    # Total progress bar\n    # --------------------------------------------------------\n    print(''), logging.logbook('')\n    total_progress = create_progressbar(**progressbar_args)\n    print(\"\\n\")\n\n    # --------------------------------------------------------\n    # Remember validation loss\n    # --------------------------------------------------------\n    best_validation_loss = float(\"inf\") if args.validation_key_minimize else -float(\"inf\")\n    store_as_best = False\n\n    for epoch in range(args.start_epoch, args.total_epochs + 1):\n        with logger.LoggingBlock(\"Epoch %i/%i\" % (epoch, args.total_epochs), emph=True):\n\n            # Always report learning rate\n            if lr_scheduler 
is not None:\n                logging.info(\"lr: %s\" % format_learning_rate(lr_scheduler.get_last_lr()))\n\n            # -------------------------------------------\n            # Create and run a training epoch\n            # -------------------------------------------\n            if train_loader is not None:\n                avg_loss_dict = TrainingEpoch(\n                    args,\n                    desc=\"   Train\",\n                    model_and_loss=model_and_loss,\n                    optimizer=optimizer,\n                    loader=train_loader,\n                    augmentation=training_augmentation).run()\n\n            # -------------------------------------------\n            # Create and run a validation epoch\n            # -------------------------------------------\n            if validation_loader is not None:\n\n                # ---------------------------------------------------\n                # Construct holistic recorder for epoch\n                # ---------------------------------------------------\n                avg_loss_dict = EvaluationEpoch(\n                    args,\n                    desc=\"Validate\",\n                    model_and_loss=model_and_loss,\n                    loader=validation_loader,\n                    augmentation=validation_augmentation).run()\n\n                # ----------------------------------------------------------------\n                # Evaluate whether this is the best validation_loss\n                # ----------------------------------------------------------------\n                validation_loss = avg_loss_dict[args.validation_key]\n                if args.validation_key_minimize:\n                    store_as_best = validation_loss < best_validation_loss\n                else:\n                    store_as_best = validation_loss > best_validation_loss\n                if store_as_best:\n                    best_validation_loss = validation_loss\n\n            # Update standard learning 
scheduler\n            if lr_scheduler is not None:\n                lr_scheduler.step()\n\n            # ----------------------------------------------------------------\n            # Also show best loss on total_progress\n            # ----------------------------------------------------------------\n            total_progress_stats = {\n                \"best_\" + args.validation_key + \"_avg\": \"%1.4f\" % best_validation_loss\n            }\n            total_progress.set_postfix(total_progress_stats)\n\n            # ----------------------------------------------------------------\n            # Bump total progress\n            # ----------------------------------------------------------------\n            total_progress.update()\n            print('')\n\n            # ----------------------------------------------------------------\n            # Store checkpoint\n            # ----------------------------------------------------------------\n            if checkpoint_saver is not None:\n                checkpoint_saver.save_latest(\n                    directory=args.save,\n                    model_and_loss=model_and_loss,\n                    stats_dict=dict(avg_loss_dict, epoch=epoch),\n                    store_as_best=store_as_best)\n\n            # ----------------------------------------------------------------\n            # Vertical space between epochs\n            # ----------------------------------------------------------------\n            print(''), logging.logbook('')\n            \n    # ----------------------------------------------------------------\n    # Finish\n    # ----------------------------------------------------------------\n    total_progress.close()\n    logging.info(\"Finished.\")\n"
  },
  {
    "path": "scripts/IRR-FlowNet_flyingChairsOcc.sh",
    "content": "#!/bin/bash\n\n# experiments and datasets meta\nEXPERIMENTS_HOME=\"experiments\"\n\n# datasets\nFLYINGCHAIRS_OCC_HOME=(YOUR PATH)/flow_occ_v5/data\n\n# model and checkpoint\nMODEL=IRR_FlowNet\nEVAL_LOSS=MultiScaleEPE_FlowNet_IRR_Bi_Occ_upsample\nCHECKPOINT=None\nSIZE_OF_BATCH=4\n\n# save path\nTIME=$(date +\"%Y%m%d-%H%M%S\")\nSAVE_PATH=\"$EXPERIMENTS_HOME/$MODEL-$TIME\"\n\n# training configuration\npython ../main.py \\\n--batch_size=$SIZE_OF_BATCH \\\n--batch_size_val=$SIZE_OF_BATCH \\\n--checkpoint=$CHECKPOINT \\\n--lr_scheduler=MultiStepLR \\\n--lr_scheduler_gamma=0.5 \\\n--lr_scheduler_milestones=\"[54, 72, 90]\" \\\n--model=$MODEL \\\n--num_workers=4 \\\n--num_iters=2 \\\n--optimizer=Adam \\\n--optimizer_lr=1e-4 \\\n--optimizer_weight_decay=4e-4 \\\n--save=$SAVE_PATH \\\n--total_epochs=108 \\\n--training_augmentation=RandomAffineFlowOcc \\\n--training_dataset=FlyingChairsOccTrain \\\n--training_dataset_photometric_augmentations=True \\\n--training_dataset_root=$FLYINGCHAIRS_OCC_HOME \\\n--training_key=total_loss \\\n--training_loss=$EVAL_LOSS \\\n--validation_dataset=FlyingChairsOccValid  \\\n--validation_dataset_photometric_augmentations=False \\\n--validation_dataset_root=$FLYINGCHAIRS_OCC_HOME \\\n--validation_key=epe \\\n--validation_loss=$EVAL_LOSS"
  },
  {
    "path": "scripts/IRR-PWC_flyingChairsOcc.sh",
    "content": "#!/bin/bash\n\n# experiments and datasets meta\nEXPERIMENTS_HOME=\"experiments\"\n\n# datasets\nFLYINGCHAIRS_OCC_HOME=(YOUR PATH)/flow_occ_v5/data\n\n# model and checkpoint\nMODEL=IRR_PWC\nEVAL_LOSS=MultiScaleEPE_PWC_Bi_Occ_upsample\nCHECKPOINT=None\nSIZE_OF_BATCH=4\n\n# save path\nTIME=$(date +\"%Y%m%d-%H%M%S\")\nSAVE_PATH=\"$EXPERIMENTS_HOME/$MODEL-$TIME\"\n\n# training configuration\npython ../main.py \\\n--batch_size=$SIZE_OF_BATCH \\\n--batch_size_val=$SIZE_OF_BATCH \\\n--checkpoint=$CHECKPOINT \\\n--lr_scheduler=MultiStepLR \\\n--lr_scheduler_gamma=0.5 \\\n--lr_scheduler_milestones=\"[54, 72, 90]\" \\\n--model=$MODEL \\\n--num_workers=4 \\\n--optimizer=Adam \\\n--optimizer_lr=1e-4 \\\n--optimizer_weight_decay=4e-4 \\\n--save=$SAVE_PATH \\\n--total_epochs=108 \\\n--training_augmentation=RandomAffineFlowOcc \\\n--training_dataset=FlyingChairsOccTrain \\\n--training_dataset_photometric_augmentations=True \\\n--training_dataset_root=$FLYINGCHAIRS_OCC_HOME \\\n--training_key=total_loss \\\n--training_loss=$EVAL_LOSS \\\n--validation_dataset=FlyingChairsOccValid  \\\n--validation_dataset_photometric_augmentations=False \\\n--validation_dataset_root=$FLYINGCHAIRS_OCC_HOME \\\n--validation_key=epe \\\n--validation_loss=$EVAL_LOSS"
  },
  {
    "path": "scripts/IRR-PWC_kitti_train.sh",
    "content": "#!/bin/bash\n\n# experiments and datasets meta\nEXPERIMENTS_HOME=\"experiments\"\n\n# datasets\nKITTI_HOME=(YOUR PATH)/KITTI_flow/\n\n# model and checkpoint\nMODEL=IRR_PWC\nEVAL_LOSS=MultiScaleEPE_PWC_Bi_Occ_upsample_KITTI\nCHECKPOINT=\"saved_check_point/IRR-PWC_things3d/checkpoint_latest.ckpt\"\nSIZE_OF_BATCH=4\n\n# save path\nTIME=$(date +\"%Y%m%d-%H%M%S\")\nSAVE_PATH=\"$EXPERIMENTS_HOME/$MODEL-$TIME\"\n\n# training configuration\npython ../main.py \\\n--batch_size=$SIZE_OF_BATCH \\\n--batch_size_val=1 \\\n--checkpoint=$CHECKPOINT \\\n--lr_scheduler=MultiStepLR \\\n--lr_scheduler_gamma=0.5 \\\n--lr_scheduler_milestones=\"[730, 984, 1238, 1365, 1397, 1429, 1556, 1683, 1810, 1937]\" \\\n--model=$MODEL \\\n--num_workers=4 \\\n--optimizer=Adam \\\n--optimizer_lr=3e-05 \\\n--optimizer_weight_decay=4e-4 \\\n--save=$SAVE_PATH \\\n--start_epoch=160 \\\n--total_epochs=2064 \\\n--training_augmentation=RandomAffineFlowOccKITTI \\\n--training_augmentation_crop=\"[320,896]\" \\\n--training_dataset=KittiCombTrain \\\n--training_dataset_photometric_augmentations=True \\\n--training_dataset_root=$KITTI_HOME \\\n--training_dataset_preprocessing_crop=True \\\n--training_key=total_loss \\\n--training_loss=$EVAL_LOSS \\\n--validation_dataset=KittiCombVal  \\\n--validation_dataset_photometric_augmentations=False \\\n--validation_dataset_root=$KITTI_HOME \\\n--validation_dataset_preprocessing_crop=False \\\n--validation_key=epe \\\n--validation_loss=$EVAL_LOSS\n"
  },
  {
    "path": "scripts/IRR-PWC_kitti_train_full.sh",
    "content": "#!/bin/bash\n\n# experiments and datasets meta\nEXPERIMENTS_HOME=\"experiments\"\n\n# datasets\nKITTI_HOME=(YOUR PATH)/KITTI_flow/\n\n# model and checkpoint\nMODEL=IRR_PWC\nEVAL_LOSS=MultiScaleEPE_PWC_Bi_Occ_upsample_KITTI\nCHECKPOINT=\"saved_check_point/IRR-PWC_things3d/checkpoint_latest.ckpt\"\nSIZE_OF_BATCH=4\n\n# save path\nTIME=$(date +\"%Y%m%d-%H%M%S\")\nSAVE_PATH=\"$EXPERIMENTS_HOME/$MODEL-$TIME\"\n\n# training configuration\npython ../main.py \\\n--batch_size=$SIZE_OF_BATCH \\\n--batch_size_val=1 \\\n--checkpoint=$CHECKPOINT \\\n--lr_scheduler=MultiStepLR \\\n--lr_scheduler_gamma=0.5 \\\n--lr_scheduler_milestones=\"[616, 819, 1022, 1123, 1149, 1174, 1276, 1377, 1479, 1580]\" \\\n--model=$MODEL \\\n--num_workers=4 \\\n--optimizer=Adam \\\n--optimizer_lr=3e-05 \\\n--optimizer_weight_decay=4e-4 \\\n--save=$SAVE_PATH \\\n--start_epoch=160 \\\n--total_epochs=710 \\\n--training_augmentation=RandomAffineFlowOccKITTI \\\n--training_augmentation_crop=\"[320,896]\" \\\n--training_dataset=KittiCombFull \\\n--training_dataset_photometric_augmentations=True \\\n--training_dataset_root=$KITTI_HOME \\\n--training_dataset_preprocessing_crop=True \\\n--training_key=total_loss \\\n--training_loss=$EVAL_LOSS \\\n--validation_dataset=KittiCombVal  \\\n--validation_dataset_photometric_augmentations=False \\\n--validation_dataset_root=$KITTI_HOME \\\n--validation_dataset_preprocessing_crop=False \\\n--validation_key=epe \\\n--validation_loss=$EVAL_LOSS"
  },
  {
    "path": "scripts/IRR-PWC_sintel_train.sh",
    "content": "#!/bin/bash\n\n# experiments and datasets meta\nEXPERIMENTS_HOME=\"experiments\"\n\n# datasets\nSINTEL_HOME=(YOUR PATH)/MPI-Sintel-complete/\n\n# model and checkpoint\nMODEL=IRR_PWC\nEVAL_LOSS=MultiScaleEPE_PWC_Bi_Occ_upsample_Sintel\nCHECKPOINT=\"saved_check_point/IRR-PWC_things3d/checkpoint_latest.ckpt\"\nSIZE_OF_BATCH=4\n\n# save path\nTIME=$(date +\"%Y%m%d-%H%M%S\")\nSAVE_PATH=\"$EXPERIMENTS_HOME/$MODEL-$TIME\"\n\n# training configuration\npython ../main.py \\\n--batch_size=$SIZE_OF_BATCH \\\n--batch_size_val=$SIZE_OF_BATCH \\\n--checkpoint=$CHECKPOINT \\\n--lr_scheduler=MultiStepLR \\\n--lr_scheduler_gamma=0.5 \\\n--lr_scheduler_milestones=\"[258, 302, 346, 368, 374, 379, 401, 423, 445, 467]\" \\\n--model=$MODEL \\\n--num_workers=4 \\\n--optimizer=Adam \\\n--optimizer_lr=1.5e-05 \\\n--optimizer_weight_decay=4e-4 \\\n--save=$SAVE_PATH \\\n--start_epoch=160 \\\n--total_epochs=489 \\\n--training_augmentation=RandomAffineFlowOccSintel \\\n--training_augmentation_crop=\"[384,768]\" \\\n--training_dataset=SintelTrainingCombTrain \\\n--training_dataset_photometric_augmentations=True \\\n--training_dataset_root=$SINTEL_HOME \\\n--training_key=total_loss \\\n--training_loss=$EVAL_LOSS \\\n--validation_dataset=SintelTrainingCombValid  \\\n--validation_dataset_photometric_augmentations=False \\\n--validation_dataset_root=$SINTEL_HOME \\\n--validation_key=epe \\\n--validation_loss=$EVAL_LOSS\n\n# training configuration\npython ../main.py \\\n--batch_size=$SIZE_OF_BATCH \\\n--batch_size_val=$SIZE_OF_BATCH \\\n--checkpoint=$CHECKPOINT \\\n--lr_scheduler=MultiStepLR \\\n--lr_scheduler_gamma=0.5 \\\n--lr_scheduler_milestones=\"[687, 775, 863, 908, 919, 930, 974, 1018, 1062, 1106]\" \\\n--model=$MODEL \\\n--num_workers=4 \\\n--optimizer=Adam \\\n--optimizer_lr=1e-05 \\\n--optimizer_weight_decay=4e-4 \\\n--save=$SAVE_PATH \\\n--start_epoch=490 \\\n--total_epochs=1150 \\\n--training_augmentation=RandomAffineFlowOccSintel 
\\\n--training_augmentation_crop=\"[384,768]\" \\\n--training_dataset=SintelTrainingFinalTrain \\\n--training_dataset_photometric_augmentations=True \\\n--training_dataset_root=$SINTEL_HOME \\\n--training_key=total_loss \\\n--training_loss=$EVAL_LOSS \\\n--validation_dataset=SintelTrainingFinalValid  \\\n--validation_dataset_photometric_augmentations=False \\\n--validation_dataset_root=$SINTEL_HOME \\\n--validation_key=epe \\\n--validation_loss=$EVAL_LOSS"
  },
  {
    "path": "scripts/IRR-PWC_sintel_train_full.sh",
    "content": "#!/bin/bash\n\n# experiments and datasets meta\nEXPERIMENTS_HOME=\"experiments\"\n\n# datasets\nSINTEL_HOME=(YOUR PATH)/MPI-Sintel-complete/\n\n# model and checkpoint\nMODEL=IRR_PWC\nEVAL_LOSS=MultiScaleEPE_PWC_Bi_Occ_upsample_Sintel\nCHECKPOINT=\"saved_check_point/IRR-PWC_things3d/checkpoint_latest.ckpt\"\nSIZE_OF_BATCH=4\n\n# save path\nTIME=$(date +\"%Y%m%d-%H%M%S\")\nSAVE_PATH=\"$EXPERIMENTS_HOME/$MODEL-$TIME\"\n\n# training configuration\npython ../main.py \\\n--batch_size=$SIZE_OF_BATCH \\\n--batch_size_val=$SIZE_OF_BATCH \\\n--checkpoint=$CHECKPOINT \\\n--lr_scheduler=MultiStepLR \\\n--lr_scheduler_gamma=0.5 \\\n--lr_scheduler_milestones=\"[245, 284, 322, 342, 346, 351, 370, 390, 409, 428]\" \\\n--model=$MODEL \\\n--num_workers=4 \\\n--optimizer=Adam \\\n--optimizer_lr=1.5e-05 \\\n--optimizer_weight_decay=4e-4 \\\n--save=$SAVE_PATH \\\n--start_epoch=160 \\\n--total_epochs=447 \\\n--training_augmentation=RandomAffineFlowOccSintel \\\n--training_augmentation_crop=\"[384,768]\" \\\n--training_dataset=SintelTrainingCombFull \\\n--training_dataset_photometric_augmentations=True \\\n--training_dataset_root=$SINTEL_HOME \\\n--training_key=total_loss \\\n--training_loss=$EVAL_LOSS \\\n--validation_dataset=SintelTrainingCombValid  \\\n--validation_dataset_photometric_augmentations=False \\\n--validation_dataset_root=$SINTEL_HOME \\\n--validation_key=epe \\\n--validation_loss=$EVAL_LOSS\n\n# training configuration\npython ../main.py \\\n--batch_size=$SIZE_OF_BATCH \\\n--batch_size_val=$SIZE_OF_BATCH \\\n--checkpoint=$CHECKPOINT \\\n--lr_scheduler=MultiStepLR \\\n--lr_scheduler_gamma=0.5 \\\n--lr_scheduler_milestones=\"[620, 697, 774, 812, 822, 831, 870, 908, 947, 985]\" \\\n--model=$MODEL \\\n--num_workers=4 \\\n--optimizer=Adam \\\n--optimizer_lr=1e-05 \\\n--optimizer_weight_decay=4e-4 \\\n--save=$SAVE_PATH \\\n--start_epoch=448 \\\n--total_epochs=591 \\\n--training_augmentation=RandomAffineFlowOccSintel \\\n--training_augmentation_crop=\"[384,768]\" 
\\\n--training_dataset=SintelTrainingFinalFull \\\n--training_dataset_photometric_augmentations=True \\\n--training_dataset_root=$SINTEL_HOME \\\n--training_key=total_loss \\\n--training_loss=$EVAL_LOSS \\\n--validation_dataset=SintelTrainingFinalValid  \\\n--validation_dataset_photometric_augmentations=False \\\n--validation_dataset_root=$SINTEL_HOME \\\n--validation_key=epe \\\n--validation_loss=$EVAL_LOSS"
  },
  {
    "path": "scripts/IRR-PWC_things3d.sh",
    "content": "#!/bin/bash\n\n# experiments and datasets meta\nEXPERIMENTS_HOME=\"experiments\"\n\n# datasets\nFLYINGTHINGS_HOME=(YOUR PATH)/things3d/FlyingThings3D_subset/\nSINTEL_HOME=(YOUR PATH)/MPI-Sintel-complete/\n\n# model and checkpoint\nMODEL=IRR_PWC\nEVAL_LOSS=MultiScaleEPE_PWC_Bi_Occ_upsample\nCHECKPOINT=\"saved_check_point/IRR-PWC_flyingchairsOcc/checkpoint_latest.ckpt\"\nSIZE_OF_BATCH=4\n\n# save path\nTIME=$(date +\"%Y%m%d-%H%M%S\")\nSAVE_PATH=\"$EXPERIMENTS_HOME/$MODEL-$TIME\"\n\n# training configuration\npython ../main.py \\\n--batch_size=$SIZE_OF_BATCH \\\n--batch_size_val=$SIZE_OF_BATCH \\\n--checkpoint=$CHECKPOINT \\\n--lr_scheduler=MultiStepLR \\\n--lr_scheduler_gamma=0.5 \\\n--lr_scheduler_milestones=\"[128, 139, 149]\" \\\n--model=$MODEL \\\n--num_workers=4 \\\n--optimizer=Adam \\\n--optimizer_lr=1e-5 \\\n--optimizer_weight_decay=4e-4 \\\n--save=$SAVE_PATH \\\n--start_epoch=109 \\\n--total_epochs=159 \\\n--training_augmentation=RandomAffineFlowOcc \\\n--training_augmentation_crop=\"[384,768]\" \\\n--training_dataset=FlyingThings3dCleanTrain \\\n--training_dataset_photometric_augmentations=True \\\n--training_dataset_root=$FLYINGTHINGS_HOME \\\n--training_key=total_loss \\\n--training_loss=$EVAL_LOSS \\\n--validation_dataset=SintelTrainingCleanFull  \\\n--validation_dataset_photometric_augmentations=False \\\n--validation_dataset_root=$SINTEL_HOME \\\n--validation_key=epe \\\n--validation_loss=$EVAL_LOSS"
  },
  {
    "path": "scripts/flownet1s.sh",
    "content": "#!/bin/bash\n\n# experiments and datasets meta\nEXPERIMENTS_HOME=\"experiments\"\n\n# datasets\nFLYINGCHAIRS_OCC_HOME=(YOUR PATH)/flow_occ_v5/data\n\n# model and checkpoint\nMODEL=FlowNet1S\nEVAL_LOSS=MultiScaleEPE_FlowNet\nCHECKPOINT=None\nSIZE_OF_BATCH=8\n\n# save path\nTIME=$(date +\"%Y%m%d-%H%M%S\")\nSAVE_PATH=\"$EXPERIMENTS_HOME/$MODEL-$TIME\"\n\n# training configuration\npython ../main.py \\\n--batch_size=$SIZE_OF_BATCH \\\n--batch_size_val=$SIZE_OF_BATCH \\\n--checkpoint=$CHECKPOINT \\\n--lr_scheduler=MultiStepLR \\\n--lr_scheduler_gamma=0.5 \\\n--lr_scheduler_milestones=\"[108, 144, 180]\" \\\n--model=$MODEL \\\n--num_workers=4 \\\n--num_iters=1 \\\n--optimizer=Adam \\\n--optimizer_lr=1e-4 \\\n--optimizer_weight_decay=4e-4 \\\n--save=$SAVE_PATH \\\n--total_epochs=216 \\\n--training_augmentation=RandomAffineFlowOcc \\\n--training_dataset=FlyingChairsOccTrain \\\n--training_dataset_photometric_augmentations=True \\\n--training_dataset_root=$FLYINGCHAIRS_OCC_HOME \\\n--training_key=total_loss \\\n--training_loss=$EVAL_LOSS \\\n--validation_dataset=FlyingChairsOccValid  \\\n--validation_dataset_photometric_augmentations=False \\\n--validation_dataset_root=$FLYINGCHAIRS_OCC_HOME \\\n--validation_key=epe \\\n--validation_loss=$EVAL_LOSS"
  },
  {
    "path": "scripts/flownet1s_irr1.sh",
    "content": "#!/bin/bash\n\n# experiments and datasets meta\nEXPERIMENTS_HOME=\"experiments\"\n\n# datasets\nFLYINGCHAIRS_OCC_HOME=(YOUR PATH)/flow_occ_v5/data\n\n# model and checkpoint\nMODEL=FlowNet1S_irr\nEVAL_LOSS=MultiScaleEPE_FlowNet_IRR\nCHECKPOINT=None\nSIZE_OF_BATCH=8\n\n# save path\nTIME=$(date +\"%Y%m%d-%H%M%S\")\nSAVE_PATH=\"$EXPERIMENTS_HOME/$MODEL-$TIME\"\n\n# training configuration\npython ../main.py \\\n--batch_size=$SIZE_OF_BATCH \\\n--batch_size_val=$SIZE_OF_BATCH \\\n--checkpoint=$CHECKPOINT \\\n--lr_scheduler=MultiStepLR \\\n--lr_scheduler_gamma=0.5 \\\n--lr_scheduler_milestones=\"[108, 144, 180]\" \\\n--model=$MODEL \\\n--num_workers=4 \\\n--num_iters=1 \\\n--optimizer=Adam \\\n--optimizer_lr=1e-4 \\\n--optimizer_weight_decay=4e-4 \\\n--save=$SAVE_PATH \\\n--total_epochs=216 \\\n--training_augmentation=RandomAffineFlowOcc \\\n--training_dataset=FlyingChairsOccTrain \\\n--training_dataset_photometric_augmentations=True \\\n--training_dataset_root=$FLYINGCHAIRS_OCC_HOME \\\n--training_key=total_loss \\\n--training_loss=$EVAL_LOSS \\\n--validation_dataset=FlyingChairsOccValid  \\\n--validation_dataset_photometric_augmentations=False \\\n--validation_dataset_root=$FLYINGCHAIRS_OCC_HOME \\\n--validation_key=epe \\\n--validation_loss=$EVAL_LOSS"
  },
  {
    "path": "scripts/flownet1s_irr2.sh",
    "content": "#!/bin/bash\n\n# experiments and datasets meta\nEXPERIMENTS_HOME=\"experiments\"\n\n# datasets\nFLYINGCHAIRS_OCC_HOME=(YOUR PATH)/flow_occ_v5/data\n\n# model and checkpoint\nMODEL=FlowNet1S_irr\nEVAL_LOSS=MultiScaleEPE_FlowNet_IRR\nCHECKPOINT=None\nSIZE_OF_BATCH=4\n\n# save path\nTIME=$(date +\"%Y%m%d-%H%M%S\")\nSAVE_PATH=\"$EXPERIMENTS_HOME/$MODEL-$TIME\"\n\n# training configuration\npython ../main.py \\\n--batch_size=$SIZE_OF_BATCH \\\n--batch_size_val=$SIZE_OF_BATCH \\\n--checkpoint=$CHECKPOINT \\\n--lr_scheduler=MultiStepLR \\\n--lr_scheduler_gamma=0.5 \\\n--lr_scheduler_milestones=\"[54, 72, 90]\" \\\n--model=$MODEL \\\n--num_workers=4 \\\n--num_iters=2 \\\n--optimizer=Adam \\\n--optimizer_lr=1e-4 \\\n--optimizer_weight_decay=4e-4 \\\n--save=$SAVE_PATH \\\n--total_epochs=108 \\\n--training_augmentation=RandomAffineFlowOcc \\\n--training_dataset=FlyingChairsOccTrain \\\n--training_dataset_photometric_augmentations=True \\\n--training_dataset_root=$FLYINGCHAIRS_OCC_HOME \\\n--training_key=total_loss \\\n--training_loss=$EVAL_LOSS \\\n--validation_dataset=FlyingChairsOccValid  \\\n--validation_dataset_photometric_augmentations=False \\\n--validation_dataset_root=$FLYINGCHAIRS_OCC_HOME \\\n--validation_key=epe \\\n--validation_loss=$EVAL_LOSS"
  },
  {
    "path": "scripts/pwcnet.sh",
    "content": "#!/bin/bash\n\n# experiments and datasets meta\nEXPERIMENTS_HOME=\"experiments\"\n\n# datasets\nFLYINGCHAIRS_OCC_HOME=(YOUR PATH)/flow_occ_v5/data\n\n# model and checkpoint\nMODEL=PWCNet\nEVAL_LOSS=MultiScaleEPE_PWC\nCHECKPOINT=None\nSIZE_OF_BATCH=8\n\n# save path\nTIME=$(date +\"%Y%m%d-%H%M%S\")\nSAVE_PATH=\"$EXPERIMENTS_HOME/$MODEL-$TIME\"\n\n# training configuration\npython ../main.py \\\n--batch_size=$SIZE_OF_BATCH \\\n--batch_size_val=$SIZE_OF_BATCH \\\n--checkpoint=$CHECKPOINT \\\n--lr_scheduler=MultiStepLR \\\n--lr_scheduler_gamma=0.5 \\\n--lr_scheduler_milestones=\"[108, 144, 180]\" \\\n--model=$MODEL \\\n--num_workers=4 \\\n--optimizer=Adam \\\n--optimizer_lr=1e-4 \\\n--optimizer_weight_decay=4e-4 \\\n--save=$SAVE_PATH \\\n--total_epochs=216 \\\n--training_augmentation=RandomAffineFlowOcc \\\n--training_dataset=FlyingChairsOccTrain \\\n--training_dataset_photometric_augmentations=True \\\n--training_dataset_root=$FLYINGCHAIRS_OCC_HOME \\\n--training_key=total_loss \\\n--training_loss=$EVAL_LOSS \\\n--validation_dataset=FlyingChairsOccValid  \\\n--validation_dataset_photometric_augmentations=False \\\n--validation_dataset_root=$FLYINGCHAIRS_OCC_HOME \\\n--validation_key=epe \\\n--validation_loss=$EVAL_LOSS"
  },
  {
    "path": "scripts/pwcnet_irr.sh",
    "content": "#!/bin/bash\n\n# experiments and datasets meta\nEXPERIMENTS_HOME=\"experiments\"\n\n# datasets\nFLYINGCHAIRS_OCC_HOME=(YOUR PATH)/flow_occ_v5/data\n\n# model and checkpoint\nMODEL=PWCNet_irr\nEVAL_LOSS=MultiScaleEPE_PWC\nCHECKPOINT=None\nSIZE_OF_BATCH=8\n\n# save path\nTIME=$(date +\"%Y%m%d-%H%M%S\")\nSAVE_PATH=\"$EXPERIMENTS_HOME/$MODEL-$TIME\"\n\n# training configuration\npython ../main.py \\\n--batch_size=$SIZE_OF_BATCH \\\n--batch_size_val=$SIZE_OF_BATCH \\\n--checkpoint=$CHECKPOINT \\\n--lr_scheduler=MultiStepLR \\\n--lr_scheduler_gamma=0.5 \\\n--lr_scheduler_milestones=\"[108, 144, 180]\" \\\n--model=$MODEL \\\n--num_workers=4 \\\n--optimizer=Adam \\\n--optimizer_lr=1e-4 \\\n--optimizer_weight_decay=4e-4 \\\n--save=$SAVE_PATH \\\n--total_epochs=216 \\\n--training_augmentation=RandomAffineFlowOcc \\\n--training_dataset=FlyingChairsOccTrain \\\n--training_dataset_photometric_augmentations=True \\\n--training_dataset_root=$FLYINGCHAIRS_OCC_HOME \\\n--training_key=total_loss \\\n--training_loss=$EVAL_LOSS \\\n--validation_dataset=FlyingChairsOccValid  \\\n--validation_dataset_photometric_augmentations=False \\\n--validation_dataset_root=$FLYINGCHAIRS_OCC_HOME \\\n--validation_key=epe \\\n--validation_loss=$EVAL_LOSS"
  },
  {
    "path": "scripts/validation/IRR-FlowNet_flyingChairs.sh",
    "content": "#!/bin/bash\n\n# experiments and datasets meta\nEXPERIMENTS_HOME=\"saved_check_point/flownet\"\n\n# datasets\nSINTEL_HOME=(YOUR PATH)/MPI-Sintel-complete/\n\n# model and checkpoint\nMODEL=IRR_FlowNet\nCHECKPOINT=\"$EXPERIMENTS_HOME/IRR-FlowNet_flyingChairs/checkpoint_best.ckpt\"\nEVAL_LOSS=MultiScaleEPE_FlowNet_IRR_Bi_Occ_upsample\n\nSIZE_OF_BATCH=4\n\n# validate clean configuration\nSAVE_PATH=\"$EXPERIMENTS_HOME/eval_temp/$MODEL\"\npython ../../main.py \\\n--batch_size=$SIZE_OF_BATCH \\\n--batch_size_val=$SIZE_OF_BATCH \\\n--checkpoint=$CHECKPOINT \\\n--evaluation=True \\\n--model=$MODEL \\\n--num_workers=4 \\\n--num_iters=2 \\\n--save=$SAVE_PATH \\\n--validation_dataset=SintelTrainingCleanFull  \\\n--validation_dataset_photometric_augmentations=False \\\n--validation_dataset_root=$SINTEL_HOME \\\n--validation_key=epe \\\n--validation_loss=$EVAL_LOSS"
  },
  {
    "path": "scripts/validation/IRR-PWC_flyingChairs.sh",
    "content": "#!/bin/bash\n\n# experiments and datasets meta\nEXPERIMENTS_HOME=\"saved_check_point/pwcnet\"\n\n# datasets\nSINTEL_HOME=(YOUR PATH)/MPI-Sintel-complete/\n\n# model and checkpoint\nMODEL=IRR_PWC\nCHECKPOINT=\"$EXPERIMENTS_HOME/IRR-PWC_flyingchairsOcc/checkpoint_best.ckpt\"\nEVAL_LOSS=MultiScaleEPE_PWC_Bi_Occ_upsample\n\nSIZE_OF_BATCH=4\n\n# validate clean configuration\nSAVE_PATH=\"$EXPERIMENTS_HOME/eval_temp/$MODEL\"\npython ../../main.py \\\n--batch_size=$SIZE_OF_BATCH \\\n--batch_size_val=$SIZE_OF_BATCH \\\n--checkpoint=$CHECKPOINT \\\n--evaluation=True \\\n--model=$MODEL \\\n--num_workers=4 \\\n--save=$SAVE_PATH \\\n--validation_dataset=SintelTrainingCleanFull  \\\n--validation_dataset_photometric_augmentations=False \\\n--validation_dataset_root=$SINTEL_HOME \\\n--validation_key=epe \\\n--validation_loss=$EVAL_LOSS"
  },
  {
    "path": "scripts/validation/IRR-PWC_kitti.sh",
    "content": "#!/bin/bash\n\n# experiments and datasets meta\nEXPERIMENTS_HOME=\"saved_check_point/pwcnet\"\n\n# datasets\nKITTI_HOME=(YOUR PATH)/KITTI_flow/\n\n# model and checkpoint\nMODEL=IRR_PWC\nCHECKPOINT=\"$EXPERIMENTS_HOME/IRR-PWC_kitti/checkpoint_latest.ckpt\"\nEVAL_LOSS=MultiScaleEPE_PWC_Bi_Occ_upsample_KITTI\n\nSIZE_OF_BATCH=1\n\n# validate clean configuration\nSAVE_PATH=\"$EXPERIMENTS_HOME/eval_temp/$MODEL\"\npython ../../main.py \\\n--batch_size=$SIZE_OF_BATCH \\\n--batch_size_val=$SIZE_OF_BATCH \\\n--checkpoint=$CHECKPOINT \\\n--evaluation=True \\\n--model=$MODEL \\\n--num_workers=4 \\\n--save=$SAVE_PATH \\\n--validation_dataset=KittiCombVal  \\\n--validation_dataset_photometric_augmentations=False \\\n--validation_dataset_root=$KITTI_HOME \\\n--validation_key=epe \\\n--validation_loss=$EVAL_LOSS"
  },
  {
    "path": "scripts/validation/IRR-PWC_sintel.sh",
    "content": "#!/bin/bash\n\n# experiments and datasets meta\nEXPERIMENTS_HOME=\"saved_check_point/pwcnet\"\n\n# datasets\nSINTEL_HOME=(YOUR PATH)/MPI-Sintel-complete/\n\n# model and checkpoint\nMODEL=IRR_PWC\nCHECKPOINT=\"$EXPERIMENTS_HOME/IRR-PWC_sintel/checkpoint_latest.ckpt\"\nEVAL_LOSS=MultiScaleEPE_PWC_Bi_Occ_upsample_Sintel\n\nSIZE_OF_BATCH=4\n\n# validate clean configuration\nSAVE_PATH=\"$EXPERIMENTS_HOME/eval_temp/$MODEL\"\npython ../../main.py \\\n--batch_size=$SIZE_OF_BATCH \\\n--batch_size_val=$SIZE_OF_BATCH \\\n--checkpoint=$CHECKPOINT \\\n--evaluation=True \\\n--model=$MODEL \\\n--num_workers=4 \\\n--save=$SAVE_PATH \\\n--validation_dataset=SintelTrainingFinalValid  \\\n--validation_dataset_photometric_augmentations=False \\\n--validation_dataset_root=$SINTEL_HOME \\\n--validation_key=epe \\\n--validation_loss=$EVAL_LOSS"
  },
  {
    "path": "scripts/validation/IRR-PWC_things3d.sh",
    "content": "#!/bin/bash\n\n# experiments and datasets meta\nEXPERIMENTS_HOME=\"saved_check_point/pwcnet\"\n\n# datasets\nSINTEL_HOME=(YOUR PATH)/MPI-Sintel-complete/\n\n# model and checkpoint\nMODEL=IRR_PWC\nCHECKPOINT=\"$EXPERIMENTS_HOME/IRR-PWC_things3d/checkpoint_latest.ckpt\"\nEVAL_LOSS=MultiScaleEPE_PWC_Bi_Occ_upsample\n\nSIZE_OF_BATCH=4\n\n# validate clean configuration\nSAVE_PATH=\"$EXPERIMENTS_HOME/eval_temp/$MODEL\"\npython ../../main.py \\\n--batch_size=$SIZE_OF_BATCH \\\n--batch_size_val=$SIZE_OF_BATCH \\\n--checkpoint=$CHECKPOINT \\\n--evaluation=True \\\n--model=$MODEL \\\n--num_workers=4 \\\n--save=$SAVE_PATH \\\n--validation_dataset=SintelTrainingCleanFull  \\\n--validation_dataset_photometric_augmentations=False \\\n--validation_dataset_root=$SINTEL_HOME \\\n--validation_key=epe \\\n--validation_loss=$EVAL_LOSS"
  },
  {
    "path": "scripts/validation/flownet1s.sh",
    "content": "#!/bin/bash\n\n# experiments and datasets meta\nEXPERIMENTS_HOME=\"saved_check_point/flownet\"\n\n# datasets\nSINTEL_HOME=(YOUR PATH)/MPI-Sintel-complete/\n\n# model and checkpoint\nMODEL=FlowNet1S\nCHECKPOINT=\"$EXPERIMENTS_HOME/FlowNet1S/checkpoint_best.ckpt\"\nEVAL_LOSS=MultiScaleEPE_FlowNet\n\nSIZE_OF_BATCH=4\n\n# validate clean configuration\nSAVE_PATH=\"$EXPERIMENTS_HOME/eval_temp/$MODEL\"\npython ../../main.py \\\n--batch_size=$SIZE_OF_BATCH \\\n--batch_size_val=$SIZE_OF_BATCH \\\n--checkpoint=$CHECKPOINT \\\n--evaluation=True \\\n--model=$MODEL \\\n--num_workers=4 \\\n--num_iters=1 \\\n--save=$SAVE_PATH \\\n--validation_dataset=SintelTrainingCleanFull  \\\n--validation_dataset_photometric_augmentations=False \\\n--validation_dataset_root=$SINTEL_HOME \\\n--validation_key=epe \\\n--validation_loss=$EVAL_LOSS"
  },
  {
    "path": "scripts/validation/flownet1s_irr1.sh",
    "content": "#!/bin/bash\n\n# experiments and datasets meta\nEXPERIMENTS_HOME=\"saved_check_point/flownet\"\n\n# datasets\nSINTEL_HOME=(YOUR PATH)/MPI-Sintel-complete/\n\n# model and checkpoint\nMODEL=FlowNet1S_irr\nCHECKPOINT=\"$EXPERIMENTS_HOME/FlowNet1S-irr1/checkpoint_best.ckpt\"\nEVAL_LOSS=MultiScaleEPE_FlowNet_IRR\n\nSIZE_OF_BATCH=4\n\n# validate clean configuration\nSAVE_PATH=\"$EXPERIMENTS_HOME/eval_temp/$MODEL\"\npython ../../main.py \\\n--batch_size=$SIZE_OF_BATCH \\\n--batch_size_val=$SIZE_OF_BATCH \\\n--checkpoint=$CHECKPOINT \\\n--evaluation=True \\\n--model=$MODEL \\\n--num_workers=4 \\\n--num_iters=1 \\\n--save=$SAVE_PATH \\\n--validation_dataset=SintelTrainingCleanFull  \\\n--validation_dataset_photometric_augmentations=False \\\n--validation_dataset_root=$SINTEL_HOME \\\n--validation_key=epe \\\n--validation_loss=$EVAL_LOSS"
  },
  {
    "path": "scripts/validation/flownet1s_irr2.sh",
    "content": "#!/bin/bash\n\n# experiments and datasets meta\nEXPERIMENTS_HOME=\"saved_check_point/flownet\"\n\n# datasets\nSINTEL_HOME=(YOUR PATH)/MPI-Sintel-complete/\n\n# model and checkpoint\nMODEL=FlowNet1S_irr\nCHECKPOINT=\"$EXPERIMENTS_HOME/FlowNet1S-irr2/checkpoint_best.ckpt\"\nEVAL_LOSS=MultiScaleEPE_FlowNet_IRR\n\nSIZE_OF_BATCH=4\n\n# validate clean configuration\nSAVE_PATH=\"$EXPERIMENTS_HOME/eval_temp/$MODEL\"\npython ../../main.py \\\n--batch_size=$SIZE_OF_BATCH \\\n--batch_size_val=$SIZE_OF_BATCH \\\n--checkpoint=$CHECKPOINT \\\n--evaluation=True \\\n--model=$MODEL \\\n--num_workers=4 \\\n--num_iters=2 \\\n--save=$SAVE_PATH \\\n--validation_dataset=SintelTrainingCleanFull  \\\n--validation_dataset_photometric_augmentations=False \\\n--validation_dataset_root=$SINTEL_HOME \\\n--validation_key=epe \\\n--validation_loss=$EVAL_LOSS"
  },
  {
    "path": "scripts/validation/pwcnet.sh",
    "content": "#!/bin/bash\n\n# experiments and datasets meta\nEXPERIMENTS_HOME=\"saved_check_point/pwcnet\"\n\n# datasets\nSINTEL_HOME=(YOUR PATH)/MPI-Sintel-complete/\n\n# model and checkpoint\nMODEL=PWCNet\nCHECKPOINT=\"$EXPERIMENTS_HOME/PWCNet/checkpoint_best.ckpt\"\nEVAL_LOSS=MultiScaleEPE_PWC\n\nSIZE_OF_BATCH=1\n\n# validate clean configuration\nSAVE_PATH=\"$EXPERIMENTS_HOME/eval_temp/$MODEL\"\npython ../../main.py \\\n--batch_size=$SIZE_OF_BATCH \\\n--batch_size_val=$SIZE_OF_BATCH \\\n--checkpoint=$CHECKPOINT \\\n--evaluation=True \\\n--model=$MODEL \\\n--num_workers=4 \\\n--save=$SAVE_PATH \\\n--validation_dataset=SintelTrainingCleanFull  \\\n--validation_dataset_photometric_augmentations=False \\\n--validation_dataset_root=$SINTEL_HOME \\\n--validation_key=epe \\\n--validation_loss=$EVAL_LOSS"
  },
  {
    "path": "scripts/validation/pwcnet_irr.sh",
    "content": "#!/bin/bash\n\n# experiments and datasets meta\nEXPERIMENTS_HOME=\"saved_check_point/pwcnet\"\n\n# datasets\nSINTEL_HOME=(YOUR PATH)/MPI-Sintel-complete/\n\n# model and checkpoint\nMODEL=PWCNet_irr\nCHECKPOINT=\"$EXPERIMENTS_HOME/PWCNet-irr/checkpoint_best.ckpt\"\nEVAL_LOSS=MultiScaleEPE_PWC\n\nSIZE_OF_BATCH=4\n\n# validate clean configuration\nSAVE_PATH=\"$EXPERIMENTS_HOME/eval_temp/$MODEL\"\npython ../../main.py \\\n--batch_size=$SIZE_OF_BATCH \\\n--batch_size_val=$SIZE_OF_BATCH \\\n--checkpoint=$CHECKPOINT \\\n--evaluation=True \\\n--model=$MODEL \\\n--num_workers=4 \\\n--save=$SAVE_PATH \\\n--validation_dataset=SintelTrainingCleanFull  \\\n--validation_dataset_photometric_augmentations=False \\\n--validation_dataset_root=$SINTEL_HOME \\\n--validation_key=epe \\\n--validation_loss=$EVAL_LOSS"
  },
  {
    "path": "tools.py",
    "content": "## Portions of Code from, copyright 2018 Jochen Gast\n\nfrom __future__ import absolute_import, division, print_function\n\nimport os\nimport socket\nimport re\nfrom pytz import timezone\nfrom datetime import datetime\nimport fnmatch\nimport itertools\nimport argparse\nimport sys\nimport six\nimport unicodedata\nimport json\nimport inspect\nimport tqdm\nimport logging\nimport torch\nimport ast\nimport numpy as np\n\n\ndef x2module(module_or_data_parallel):\n    if isinstance(module_or_data_parallel, torch.nn.DataParallel):\n        return module_or_data_parallel.module\n    else:\n        return module_or_data_parallel\n\n\n# ----------------------------------------------------------------------------------------\n# Comprehensively adds a new logging level to the `logging` module and the\n# currently configured logging class.\n# e.g. addLoggingLevel('TRACE', logging.DEBUG - 5)\n# ----------------------------------------------------------------------------------------\ndef addLoggingLevel(level_name, level_num, method_name=None):\n    if not method_name:\n        method_name = level_name.lower()\n    if hasattr(logging, level_name):\n        raise AttributeError('{} already defined in logging module'.format(level_name))\n    if hasattr(logging, method_name):\n        raise AttributeError('{} already defined in logging module'.format(method_name))\n    if hasattr(logging.getLoggerClass(), method_name):\n        raise AttributeError('{} already defined in logger class'.format(method_name))\n\n    # This method was inspired by the answers to Stack Overflow post\n    # http://stackoverflow.com/q/2183233/2988730, especially\n    # http://stackoverflow.com/a/13638084/2988730\n    def logForLevel(self, message, *args, **kwargs):\n        if self.isEnabledFor(level_num):\n            self._log(level_num, message, args, **kwargs)\n\n    def logToRoot(message, *args, **kwargs):\n        logging.log(level_num, message, *args, **kwargs)\n\n    
logging.addLevelName(level_num, level_name)\n    setattr(logging, level_name, level_num)\n    setattr(logging.getLoggerClass(), method_name, logForLevel)\n    setattr(logging, method_name, logToRoot)\n\n\n# -------------------------------------------------------------------------------------------------\n# Looks for sub arguments in the argument structure.\n# Retrieve sub arguments for modules such as optimizer_*\n# -------------------------------------------------------------------------------------------------\ndef kwargs_from_args(args, name, exclude=[]):\n    if isinstance(exclude, str):\n        exclude = [exclude]\n    exclude += [\"class\"]\n    args_dict = vars(args)\n    name += \"_\"\n    subargs_dict = {\n        key[len(name):]: value for key, value in args_dict.items()\n        if name in key and all([key != name + x for x in exclude])\n    }\n    return subargs_dict\n\n\n# -------------------------------------------------------------------------------------------------\n# Create class instance from kwargs dictionary.\n# Filters out keys that are not in the constructor\n# -------------------------------------------------------------------------------------------------\ndef instance_from_kwargs(class_constructor, kwargs):\n    argspec = inspect.getargspec(class_constructor.__init__)\n    full_args = argspec.args\n    filtered_args = dict([(k,v) for k,v in kwargs.items() if k in full_args])\n    instance = class_constructor(**filtered_args)\n    return instance\n\n\ndef module_classes_to_dict(module, include_classes=\"*\", exclude_classes=()):\n    # -------------------------------------------------------------------------\n    # If arguments are strings, convert them to a list\n    # -------------------------------------------------------------------------\n    if include_classes is not None:\n        if isinstance(include_classes, str):\n            include_classes = [include_classes]\n\n    if exclude_classes is not None:\n        if 
isinstance(exclude_classes, str):\n            exclude_classes = [exclude_classes]\n\n    # -------------------------------------------------------------------------\n    # Obtain dictionary from given module\n    # -------------------------------------------------------------------------\n    item_dict = dict([(name, getattr(module, name)) for name in dir(module)])\n\n    # -------------------------------------------------------------------------\n    # Filter classes\n    # -------------------------------------------------------------------------\n    item_dict = dict([\n        (name,value) for name, value in item_dict.items() if inspect.isclass(getattr(module, name))\n    ])\n\n    filtered_keys = filter_list_of_strings(\n        item_dict.keys(), include=include_classes, exclude=exclude_classes)\n\n    # -------------------------------------------------------------------------\n    # Construct dictionary from matched results\n    # -------------------------------------------------------------------------\n    result_dict = dict([(name, value) for name, value in item_dict.items() if name in filtered_keys])\n\n    return result_dict\n\n\ndef ensure_dir(file_path):\n    directory = os.path.dirname(file_path)\n    if not os.path.exists(directory):\n        os.makedirs(directory)\n\n\ndef search_and_replace(string, regex, replace):\n    while True:\n        match = re.search(regex, string)\n        if match:\n            string = string.replace(match.group(0), replace)\n        else:\n            break\n    return string\n\n\ndef hostname():\n    name = socket.gethostname()\n    n = name.find('.')\n    if n > 0:\n        name = name[:n]\n    return name\n\n\ndef get_filenames(directory, match='*.*', not_match=()):\n    if match is not None:\n        if isinstance(match, str):\n            match = [match]\n    if not_match is not None:\n        if isinstance(not_match, str):\n            not_match = [not_match]\n\n    result = []\n    for dirpath, _, filenames in 
os.walk(directory):\n        filtered_matches = list(itertools.chain.from_iterable(\n            [fnmatch.filter(filenames, x) for x in match]))\n        filtered_nomatch = list(itertools.chain.from_iterable(\n            [fnmatch.filter(filenames, x) for x in not_match]))\n        matched = list(set(filtered_matches) - set(filtered_nomatch))\n        result += [os.path.join(dirpath, x) for x in matched]\n    return result\n\n\ndef str2bool(v):\n    if v.lower() in ('yes', 'true', 't', 'y', '1'):\n        return True\n    elif v.lower() in ('no', 'false', 'f', 'n', '0'):\n        return False\n    else:\n        raise argparse.ArgumentTypeError('Boolean value expected.')\n\n\ndef str2str_or_none(v):\n    if v.lower() == \"none\":\n        return None\n    return v\n\n\ndef str2dict(v):\n    return ast.literal_eval(v)\n\n\ndef str2intlist(v):\n    return [int(x.strip()) for x in v.strip()[1:-1].split(',')]\n\n\ndef str2list(v):\n    return [str(x.strip()) for x in v.strip()[1:-1].split(',')]\n\n\ndef read_json(filename):\n\n    def _convert_from_unicode(data):\n        new_data = dict()\n        for name, value in six.iteritems(data):\n            if isinstance(name, six.string_types):\n                name = unicodedata.normalize('NFKD', name).encode(\n                    'ascii', 'ignore')\n            if isinstance(value, six.string_types):\n                value = unicodedata.normalize('NFKD', value).encode(\n                    'ascii', 'ignore')\n            if isinstance(value, dict):\n                value = _convert_from_unicode(value)\n            new_data[name] = value\n        return new_data\n\n    output_dict = None\n    with open(filename, \"r\") as f:\n        lines = f.readlines()\n        try:\n            output_dict = json.loads(''.join(lines), encoding='utf-8')\n        except:\n            raise ValueError('Could not read %s. 
%s' % (filename, sys.exc_info()[1]))\n        output_dict = _convert_from_unicode(output_dict)\n    return output_dict\n\n\ndef write_json(data_dict, filename):\n    with open(filename, \"w\") as file:\n        json.dump(data_dict, file)\n\n\ndef datestr():\n    pacific = timezone('US/Pacific')\n    now = datetime.now(pacific)\n    return '{}{:02}{:02}_{:02}{:02}'.format(now.year, now.month, now.day, now.hour, now.minute)\n\n\ndef filter_list_of_strings(lst, include=\"*\", exclude=()):\n    filtered_matches = list(itertools.chain.from_iterable([fnmatch.filter(lst, x) for x in include]))\n    filtered_nomatch = list(itertools.chain.from_iterable([fnmatch.filter(lst, x) for x in exclude]))\n    matched = list(set(filtered_matches) - set(filtered_nomatch))\n    return matched\n\n\n# ----------------------------------------------------------------------------\n# Writes all pairs to a filename for book keeping\n# Either .txt or .json\n# ----------------------------------------------------------------------------\ndef write_dictionary_to_file(arguments_dict, filename):\n    # ensure dir\n    d = os.path.dirname(filename)\n    if not os.path.exists(d):\n        os.makedirs(d)\n\n    # check for json extension\n    ext = os.path.splitext(filename)[1]\n    if ext == \".json\":\n\n        def replace_quotes(x):\n            return x.replace(\"\\'\", \"\\\"\")\n\n        with open(filename, 'w') as file:\n            file.write(\"{\\n\")\n            for i, (key, value) in enumerate(arguments_dict):\n                if isinstance(value, tuple):\n                    value = list(value)\n                if value is None:\n                    file.write(\"  \\\"%s\\\": null\" % key)\n                elif isinstance(value, str):\n                    value = value.replace(\"\\'\", \"\\\"\")\n                    file.write(\"  \\\"%s\\\": \\\"%s\\\"\" % (key, replace_quotes(str( value))))\n                elif isinstance(value, bool):\n                    file.write(\"  \\\"%s\\\": 
%s\" % (key, str(value).lower()))\n                else:\n                    file.write(\"  \\\"%s\\\": %s\" % (key, replace_quotes(str(value))))\n                if i < len(arguments_dict) - 1:\n                    file.write(',\\n')\n                else:\n                    file.write('\\n')\n            file.write(\"}\\n\")\n    else:\n        with open(filename, 'w') as file:\n            for key, value in arguments_dict:\n                file.write('%s: %s\\n' % (key, value))\n\n\nclass MovingAverage:\n    postfix = \"avg\"\n\n    def __init__(self):\n        self._sum = 0.0\n        self._count = 0\n\n    def add_value(self, sigma, addcount=1):\n        self._sum += sigma\n        self._count += addcount\n\n    def add_average(self, avg, addcount):\n        self._sum += avg*addcount\n        self._count += addcount\n\n    def mean(self):\n        return self._sum / self._count\n\n\nclass ExponentialMovingAverage:\n    postfix = \"ema\"\n\n    def __init__(self, alpha=0.7):\n        self._weighted_sum = 0.0\n        self._weighted_count = 0\n        self._alpha = alpha\n\n    def add_value(self, sigma, addcount=1):\n        self._weighted_sum = sigma + (1.0 - self._alpha)*self._weighted_sum\n        self._weighted_count = 1 + (1.0 - self._alpha)*self._weighted_count\n\n    def add_average(self, avg, addcount):\n        self._weighted_sum = avg*addcount + (1.0 - self._alpha)*self._weighted_sum\n        self._weighted_count = addcount + (1.0 - self._alpha)*self._weighted_count\n\n    def mean(self):\n        return self._weighted_sum / self._weighted_count\n\n\n# -----------------------------------------------------------------\n# Subclass tqdm to achieve two things:\n#   1) Output the progress bar into the logbook.\n#   2) Remove the comma before {postfix} because it's annoying.\n# -----------------------------------------------------------------\nclass TqdmToLogger(tqdm.tqdm):\n    def __init__(self, iterable=None, desc=None, total=None, leave=True,\n       
          file=None, ncols=None, mininterval=0.1,\n                 maxinterval=10.0, miniters=None, ascii=None, disable=False,\n                 unit='it', unit_scale=False, dynamic_ncols=False,\n                 smoothing=0.3, bar_format=None, initial=0, position=None,\n                 postfix=None,\n                 logging_on_close=True,\n                 logging_on_update=False):\n\n        super(TqdmToLogger, self).__init__(\n            iterable=iterable, desc=desc, total=total, leave=leave,\n            file=file, ncols=ncols, mininterval=mininterval,\n            maxinterval=maxinterval, miniters=miniters, ascii=ascii, disable=disable,\n            unit=unit, unit_scale=unit_scale, dynamic_ncols=dynamic_ncols,\n            smoothing=smoothing, bar_format=bar_format, initial=initial, position=position,\n            postfix=postfix)\n\n        self._logging_on_close = logging_on_close\n        self._logging_on_update = logging_on_update\n        self._closed = False\n\n    @staticmethod\n    def format_meter(n, total, elapsed, ncols=None, prefix='', ascii=False,\n                     unit='it', unit_scale=False, rate=None, bar_format=None,\n                     postfix=None, unit_divisor=1000):\n\n        meter = tqdm.tqdm.format_meter(\n            n=n, total=total, elapsed=elapsed, ncols=ncols, prefix=prefix, ascii=ascii,\n            unit=unit, unit_scale=unit_scale, rate=rate, bar_format=bar_format,\n            postfix=postfix, unit_divisor=unit_divisor)\n\n        # get rid of that stupid comma before the postfix\n        if postfix is not None:\n            postfix_with_comma = \", %s\" % postfix\n            meter = meter.replace(postfix_with_comma, postfix)\n\n        return meter\n\n    def update(self, n=1):\n        if self._logging_on_update:\n            msg = self.__repr__()\n            logging.logbook(msg)\n        return super(TqdmToLogger, self).update(n=n)\n\n    def close(self):\n        if self._logging_on_close and not self._closed:\n 
           msg = self.__repr__()\n            logging.logbook(msg)\n            self._closed = True\n        return super(TqdmToLogger, self).close()\n\n\ndef tqdm_with_logging(iterable=None, desc=None, total=None, leave=True,\n                      ncols=None, mininterval=0.1,\n                      maxinterval=10.0, miniters=None, ascii=None, disable=False,\n                      unit=\"it\", unit_scale=False, dynamic_ncols=False,\n                      smoothing=0.3, bar_format=None, initial=0, position=None,\n                      postfix=None,\n                      logging_on_close=True,\n                      logging_on_update=False):\n\n    return TqdmToLogger(\n        iterable=iterable, desc=desc, total=total, leave=leave,\n        ncols=ncols, mininterval=mininterval,\n        maxinterval=maxinterval, miniters=miniters, ascii=ascii, disable=disable,\n        unit=unit, unit_scale=unit_scale, dynamic_ncols=dynamic_ncols,\n        smoothing=smoothing, bar_format=bar_format, initial=initial, position=position,\n        postfix=postfix,\n        logging_on_close=logging_on_close,\n        logging_on_update=logging_on_update)\n\n\ndef cd_dotdot(path_or_filename):\n    return os.path.abspath(os.path.join(os.path.dirname(path_or_filename), \"..\"))\n\n\ndef cd_dotdotdot(path_or_filename):\n    return os.path.abspath(os.path.join(os.path.dirname(path_or_filename), \"../..\"))\n\n\ndef cd_dotdotdotdot(path_or_filename):\n    return os.path.abspath(os.path.join(os.path.dirname(path_or_filename), \"../../..\"))\n\n\ndef tensor2numpy(tensor):\n    if isinstance(tensor, np.ndarray):\n        return tensor\n    else:\n        if isinstance(tensor, torch.autograd.Variable):\n            tensor = tensor.data\n        if tensor.dim() == 3:\n            return tensor.cpu().numpy().transpose([1,2,0])\n        else:\n            return tensor.cpu().numpy().transpose([0,2,3,1])\n"
  },
  {
    "path": "utils/__init__.py",
    "content": ""
  },
  {
    "path": "utils/flow.py",
    "content": "from __future__ import absolute_import, division, print_function\n\nimport numpy as np\nimport png\nimport matplotlib.colors as cl\n\nTAG_CHAR = np.array([202021.25], np.float32)\nUNKNOWN_FLOW_THRESH = 1e7\n\n\ndef write_flow(filename, uv, v=None):\n    nBands = 2\n\n    if v is None:\n        assert (uv.ndim == 3)\n        assert (uv.shape[2] == 2)\n        u = uv[:, :, 0]\n        v = uv[:, :, 1]\n    else:\n        u = uv\n\n    assert (u.shape == v.shape)\n    height, width = u.shape\n    f = open(filename, 'wb')\n    # write the header\n    f.write(TAG_CHAR)\n    np.array(width).astype(np.int32).tofile(f)\n    np.array(height).astype(np.int32).tofile(f)\n    # arrange into matrix form\n    tmp = np.zeros((height, width * nBands))\n    tmp[:, np.arange(width) * 2] = u\n    tmp[:, np.arange(width) * 2 + 1] = v\n    tmp.astype(np.float32).tofile(f)\n    f.close()\n\n\ndef write_flow_png(filename, uv, v=None, mask=None):\n\n    if v is None:\n        assert (uv.ndim == 3)\n        assert (uv.shape[2] == 2)\n        u = uv[:, :, 0]\n        v = uv[:, :, 1]\n    else:\n        u = uv\n\n    assert (u.shape == v.shape)\n\n    height_img, width_img = u.shape\n    if mask is None:\n        valid_mask = np.ones([height_img, width_img])\n    else:\n        valid_mask = mask\n\n    flow_u = np.clip((u * 64 + 2 ** 15), 0.0, 65535.0).astype(np.uint16)\n    flow_v = np.clip((v * 64 + 2 ** 15), 0.0, 65535.0).astype(np.uint16)\n    \n    output = np.stack((flow_u, flow_v, valid_mask), axis=-1)\n\n    with open(filename, 'wb') as f:\n        writer = png.Writer(width=width_img, height=height_img, bitdepth=16)\n        writer.write(f, np.reshape(output, (-1, width_img*3)))\n\n\ndef flow_to_png(flow_map, max_value=None):\n    _, h, w = flow_map.shape    \n    rgb_map = np.ones((h, w, 3)).astype(np.float32)\n    if max_value is not None:\n        normalized_flow_map = flow_map / max_value\n    else:\n        normalized_flow_map = flow_map / 
(np.abs(flow_map).max())\n    rgb_map[:, :, 0] += normalized_flow_map[0]\n    rgb_map[:, :, 1] -= 0.5 * (normalized_flow_map[0] + normalized_flow_map[1])\n    rgb_map[:, :, 2] += normalized_flow_map[1]\n    return rgb_map.clip(0, 1)\n\n\n\ndef compute_color(u, v):\n    \"\"\"\n    compute optical flow color map\n    :param u: optical flow horizontal map\n    :param v: optical flow vertical map\n    :return: optical flow in color code\n    \"\"\"\n    [h, w] = u.shape\n    img = np.zeros([h, w, 3])\n    nanIdx = np.isnan(u) | np.isnan(v)\n    u[nanIdx] = 0\n    v[nanIdx] = 0\n\n    colorwheel = make_color_wheel()\n    ncols = np.size(colorwheel, 0)\n\n    rad = np.sqrt(u ** 2 + v ** 2)\n\n    a = np.arctan2(-v, -u) / np.pi\n\n    fk = (a + 1) / 2 * (ncols - 1) + 1\n\n    k0 = np.floor(fk).astype(int)\n\n    k1 = k0 + 1\n    k1[k1 == ncols + 1] = 1\n    f = fk - k0\n\n    for i in range(0, np.size(colorwheel, 1)):\n        tmp = colorwheel[:, i]\n        col0 = tmp[k0 - 1] / 255\n        col1 = tmp[k1 - 1] / 255\n        col = (1 - f) * col0 + f * col1\n\n        idx = rad <= 1\n        col[idx] = 1 - rad[idx] * (1 - col[idx])\n        notidx = np.logical_not(idx)\n\n        col[notidx] *= 0.75\n        img[:, :, i] = np.uint8(np.floor(255 * col * (1 - nanIdx)))\n\n    return img\n\n\ndef make_color_wheel():\n    \"\"\"\n    Generate color wheel according Middlebury color code\n    :return: Color wheel\n    \"\"\"\n    RY = 15\n    YG = 6\n    GC = 4\n    CB = 11\n    BM = 13\n    MR = 6\n\n    ncols = RY + YG + GC + CB + BM + MR\n\n    colorwheel = np.zeros([ncols, 3])\n\n    col = 0\n\n    # RY\n    colorwheel[0:RY, 0] = 255\n    colorwheel[0:RY, 1] = np.transpose(np.floor(255 * np.arange(0, RY) / RY))\n    col += RY\n\n    # YG\n    colorwheel[col:col + YG, 0] = 255 - np.transpose(np.floor(255 * np.arange(0, YG) / YG))\n    colorwheel[col:col + YG, 1] = 255\n    col += YG\n\n    # GC\n    colorwheel[col:col + GC, 1] = 255\n    colorwheel[col:col + GC, 2] = 
np.transpose(np.floor(255 * np.arange(0, GC) / GC))\n    col += GC\n\n    # CB\n    colorwheel[col:col + CB, 1] = 255 - np.transpose(np.floor(255 * np.arange(0, CB) / CB))\n    colorwheel[col:col + CB, 2] = 255\n    col += CB\n\n    # BM\n    colorwheel[col:col + BM, 2] = 255\n    colorwheel[col:col + BM, 0] = np.transpose(np.floor(255 * np.arange(0, BM) / BM))\n    col += + BM\n\n    # MR\n    colorwheel[col:col + MR, 2] = 255 - np.transpose(np.floor(255 * np.arange(0, MR) / MR))\n    colorwheel[col:col + MR, 0] = 255\n\n    return colorwheel\n\n\ndef flow_to_png_middlebury(flow):\n    \"\"\"\n    Convert flow into middlebury color code image\n    :param flow: optical flow map\n    :return: optical flow image in middlebury color\n    \"\"\"\n\n    flow = flow.transpose([1, 2, 0])\n    u = flow[:, :, 0]\n    v = flow[:, :, 1]\n\n    maxu = -999.\n    maxv = -999.\n    minu = 999.\n    minv = 999.\n\n    idxUnknow = (abs(u) > UNKNOWN_FLOW_THRESH) | (abs(v) > UNKNOWN_FLOW_THRESH)\n    u[idxUnknow] = 0\n    v[idxUnknow] = 0\n\n    maxu = max(maxu, np.max(u))\n    minu = min(minu, np.min(u))\n\n    maxv = max(maxv, np.max(v))\n    minv = min(minv, np.min(v))\n\n    rad = np.sqrt(u ** 2 + v ** 2)\n    maxrad = max(-1, np.max(rad))\n\n    u = u / (maxrad + np.finfo(float).eps)\n    v = v / (maxrad + np.finfo(float).eps)\n\n    img = compute_color(u, v)\n\n    idx = np.repeat(idxUnknow[:, :, np.newaxis], 3, axis=2)\n    img[idx] = 0\n\n    return np.uint8(img)"
  },
  {
    "path": "utils/interpolation.py",
    "content": "## Portions of Code from, copyright 2018 Jochen Gast\n\nfrom __future__ import absolute_import, division, print_function\n\nimport torch\nfrom torch import nn\nimport torch.nn.functional as tf\n\n\ndef _bchw2bhwc(tensor):\n    return tensor.transpose(1,2).transpose(2,3)\n\n\ndef _bhwc2bchw(tensor):\n    return tensor.transpose(2,3).transpose(1,2)\n\n\nclass Meshgrid(nn.Module):\n    def __init__(self):\n        super(Meshgrid, self).__init__()\n        self.width = 0\n        self.height = 0\n        self.xx = None\n        self.yy = None\n\n    def _compute_meshgrid(self, width, height):\n        rangex = torch.arange(0, width)\n        rangey = torch.arange(0, height)\n        self.xx = rangex.repeat(height, 1).contiguous()\n        self.yy = rangey.repeat(width, 1).t().contiguous()\n        \n    def forward(self, width, height, device=None, dtype=None):\n        if self.width != width or self.height != height:\n            self._compute_meshgrid(width=width, height=height)\n            self.width = width\n            self.height = height\n        self.xx = self.xx.to(device=device, dtype=dtype)\n        self.yy = self.yy.to(device=device, dtype=dtype)\n        return self.xx, self.yy\n\n\nclass BatchSub2Ind(nn.Module):\n    def __init__(self):\n        super(BatchSub2Ind, self).__init__()\n        self.register_buffer(\"_offsets\", torch.LongTensor())\n\n    def forward(self, shape, row_sub, col_sub, out=None):\n        batch_size = row_sub.size(0)\n        height, width = shape\n        ind = row_sub*width + col_sub\n        torch.arange(batch_size, out=self._offsets)\n        self._offsets *= (height*width)\n\n        if out is None:\n            return torch.add(ind, self._offsets.view(-1,1,1))\n        else:\n            torch.add(ind, self._offsets.view(-1,1,1), out=out)\n\n\nclass Interp2(nn.Module):\n    def __init__(self, clamp=False):\n        super(Interp2, self).__init__()\n        self._clamp = clamp\n        self._batch_sub2ind = 
BatchSub2Ind()\n        self.register_buffer(\"_x0\", torch.LongTensor())\n        self.register_buffer(\"_x1\", torch.LongTensor())\n        self.register_buffer(\"_y0\", torch.LongTensor())\n        self.register_buffer(\"_y1\", torch.LongTensor())\n        self.register_buffer(\"_i00\", torch.LongTensor())\n        self.register_buffer(\"_i01\", torch.LongTensor())\n        self.register_buffer(\"_i10\", torch.LongTensor())\n        self.register_buffer(\"_i11\", torch.LongTensor())\n        self.register_buffer(\"_v00\", torch.FloatTensor())\n        self.register_buffer(\"_v01\", torch.FloatTensor())\n        self.register_buffer(\"_v10\", torch.FloatTensor())\n        self.register_buffer(\"_v11\", torch.FloatTensor())\n        self.register_buffer(\"_x\", torch.FloatTensor())\n        self.register_buffer(\"_y\", torch.FloatTensor())\n\n    def forward(self, v, xq, yq):\n        batch_size, channels, height, width = v.size()\n\n        # clamp if wanted\n        if self._clamp:\n            xq.clamp_(0, width - 1)\n            yq.clamp_(0, height - 1)\n\n        # ------------------------------------------------------------------\n        # Find neighbors\n        #\n        # x0 = torch.floor(xq).long(),          x0.clamp_(0, width - 1)\n        # x1 = x0 + 1,                          x1.clamp_(0, width - 1)\n        # y0 = torch.floor(yq).long(),          y0.clamp_(0, height - 1)\n        # y1 = y0 + 1,                          y1.clamp_(0, height - 1)\n        #\n        # ------------------------------------------------------------------\n        self._x0 = torch.floor(xq).long().clamp(0, width - 1)\n        self._y0 = torch.floor(yq).long().clamp(0, height - 1)\n\n        self._x1 = torch.add(self._x0, 1).clamp(0, width - 1)\n        self._y1 = torch.add(self._y0, 1).clamp(0, height - 1)\n\n        # batch_sub2ind\n        self._batch_sub2ind([height, width], self._y0, self._x0, out=self._i00)\n        self._batch_sub2ind([height, width], self._y0, 
self._x1, out=self._i01)\n        self._batch_sub2ind([height, width], self._y1, self._x0, out=self._i10)\n        self._batch_sub2ind([height, width], self._y1, self._x1, out=self._i11)\n\n        # reshape\n        v_flat = _bchw2bhwc(v).contiguous().view(-1, channels)\n        torch.index_select(v_flat, dim=0, index=self._i00.view(-1), out=self._v00)\n        torch.index_select(v_flat, dim=0, index=self._i01.view(-1), out=self._v01)\n        torch.index_select(v_flat, dim=0, index=self._i10.view(-1), out=self._v10)\n        torch.index_select(v_flat, dim=0, index=self._i11.view(-1), out=self._v11)\n\n        # local_coords\n        torch.add(xq, - self._x0.float(), out=self._x)\n        torch.add(yq, - self._y0.float(), out=self._y)\n\n        # weights\n        w00 = torch.unsqueeze((1.0 - self._y) * (1.0 - self._x), dim=1)\n        w01 = torch.unsqueeze((1.0 - self._y) * self._x, dim=1)\n        w10 = torch.unsqueeze(self._y * (1.0 - self._x), dim=1)\n        w11 = torch.unsqueeze(self._y * self._x, dim=1)\n\n        def _reshape(u):\n            return _bhwc2bchw(u.view(batch_size, height, width, channels))\n\n        # values\n        values = _reshape(self._v00)*w00 + _reshape(self._v01)*w01 \\\n            + _reshape(self._v10)*w10 + _reshape(self._v11)*w11\n\n        if self._clamp:\n            return values\n        else:\n            #  find_invalid\n            invalid = ((xq < 0) | (xq >= width) | (yq < 0) | (yq >= height)).unsqueeze(dim=1).float()\n            # maskout invalid\n            transformed = invalid * torch.zeros_like(values) + (1.0 - invalid)*values\n\n        return transformed\n\n\nclass Interp2MaskBinary(nn.Module):\n    def __init__(self, clamp=False):\n        super(Interp2MaskBinary, self).__init__()\n        self._clamp = clamp\n        self._batch_sub2ind = BatchSub2Ind()\n        self.register_buffer(\"_x0\", torch.LongTensor())\n        self.register_buffer(\"_x1\", torch.LongTensor())\n        self.register_buffer(\"_y0\", 
torch.LongTensor())\n        self.register_buffer(\"_y1\", torch.LongTensor())\n        self.register_buffer(\"_i00\", torch.LongTensor())\n        self.register_buffer(\"_i01\", torch.LongTensor())\n        self.register_buffer(\"_i10\", torch.LongTensor())\n        self.register_buffer(\"_i11\", torch.LongTensor())\n        self.register_buffer(\"_v00\", torch.FloatTensor())\n        self.register_buffer(\"_v01\", torch.FloatTensor())\n        self.register_buffer(\"_v10\", torch.FloatTensor())\n        self.register_buffer(\"_v11\", torch.FloatTensor())\n        self.register_buffer(\"_m00\", torch.FloatTensor())\n        self.register_buffer(\"_m01\", torch.FloatTensor())\n        self.register_buffer(\"_m10\", torch.FloatTensor())\n        self.register_buffer(\"_m11\", torch.FloatTensor())\n        self.register_buffer(\"_x\", torch.FloatTensor())\n        self.register_buffer(\"_y\", torch.FloatTensor())\n\n    def forward(self, v, xq, yq, mask):\n        batch_size, channels, height, width = v.size()\n        _, channels_mask, _, _ = mask.size()\n\n        if channels_mask != channels:\n            mask = mask.repeat(1, int(channels/channels_mask), 1, 1)\n\n        # clamp if wanted\n        if self._clamp:\n            xq.clamp_(0, width - 1)\n            yq.clamp_(0, height - 1)\n\n        # ------------------------------------------------------------------\n        # Find neighbors\n        #\n        # x0 = torch.floor(xq).long(),          x0.clamp_(0, width - 1)\n        # x1 = x0 + 1,                          x1.clamp_(0, width - 1)\n        # y0 = torch.floor(yq).long(),          y0.clamp_(0, height - 1)\n        # y1 = y0 + 1,                          y1.clamp_(0, height - 1)\n        #\n        # ------------------------------------------------------------------\n        self._x0 = torch.floor(xq).long().clamp(0, width - 1)\n        self._y0 = torch.floor(yq).long().clamp(0, height - 1)\n\n        self._x1 = torch.add(self._x0, 1).clamp(0, width - 
1)\n        self._y1 = torch.add(self._y0, 1).clamp(0, height - 1)\n\n        # batch_sub2ind\n        self._batch_sub2ind([height, width], self._y0, self._x0, out=self._i00)\n        self._batch_sub2ind([height, width], self._y0, self._x1, out=self._i01)\n        self._batch_sub2ind([height, width], self._y1, self._x0, out=self._i10)\n        self._batch_sub2ind([height, width], self._y1, self._x1, out=self._i11)\n\n        # reshape\n        v_flat = _bchw2bhwc(v).contiguous().view(-1, channels)\n        torch.index_select(v_flat, dim=0, index=self._i00.view(-1), out=self._v00)\n        torch.index_select(v_flat, dim=0, index=self._i01.view(-1), out=self._v01)\n        torch.index_select(v_flat, dim=0, index=self._i10.view(-1), out=self._v10)\n        torch.index_select(v_flat, dim=0, index=self._i11.view(-1), out=self._v11)\n\n        # reshape\n        m_flat = _bchw2bhwc(mask).contiguous().view(-1, channels)\n        torch.index_select(m_flat, dim=0, index=self._i00.view(-1), out=self._m00)\n        torch.index_select(m_flat, dim=0, index=self._i01.view(-1), out=self._m01)\n        torch.index_select(m_flat, dim=0, index=self._i10.view(-1), out=self._m10)\n        torch.index_select(m_flat, dim=0, index=self._i11.view(-1), out=self._m11)\n\n        # local_coords\n        torch.add(xq, - self._x0.float(), out=self._x)\n        torch.add(yq, - self._y0.float(), out=self._y)\n\n        # weights\n        w00 = torch.unsqueeze((1.0 - self._y) * (1.0 - self._x), dim=1)\n        w01 = torch.unsqueeze((1.0 - self._y) * self._x, dim=1)\n        w10 = torch.unsqueeze(self._y * (1.0 - self._x), dim=1)\n        w11 = torch.unsqueeze(self._y * self._x, dim=1)\n\n        def _reshape(u):\n            return _bhwc2bchw(u.view(batch_size, height, width, channels))\n\n        # values\n        values = _reshape(self._m00) * _reshape(self._v00) * w00 + _reshape(self._m01) * _reshape(\n            self._v01) * w01 + _reshape(self._m10) * _reshape(self._v10) * w10 + 
_reshape(self._m11) * _reshape(\n            self._v11) * w11\n        m_weights = _reshape(self._m00) * w00 + _reshape(self._m01) * w01 + _reshape(self._m10) * w10 + _reshape(\n            self._m11) * w11\n        values = values / (m_weights + 1e-12)\n        invalid_mask = (((1 - m_weights) / (m_weights + 1e-12)) > 0.5)[:, 0:1, :, :]\n\n        if self._clamp:\n            return values\n        else:\n            #  find_invalid\n            invalid = ((xq < 0) | (xq >= width) | (yq < 0) | (yq >= height) | invalid_mask.squeeze(dim=1)).unsqueeze(dim=1).float()\n            transformed = invalid * torch.zeros_like(values) + (1.0 - invalid) * values\n\n        return transformed, (1 - invalid_mask).float()\n\n\ndef resize2D(inputs, size_targets, mode=\"bilinear\"):\n    size_inputs = [inputs.size(2), inputs.size(3)]\n\n    if all([size_inputs == size_targets]):\n        return inputs  # nothing to do\n    elif any([size_targets < size_inputs]):\n        resized = tf.adaptive_avg_pool2d(inputs, size_targets)  # downscaling\n    else:\n        resized = tf.upsample(inputs, size=size_targets, mode=mode)  # upsampling\n\n    # correct scaling\n    return resized\n\n\ndef resize2D_as(inputs, output_as, mode=\"bilinear\"):\n    size_targets = [output_as.size(2), output_as.size(3)]\n    return resize2D(inputs, size_targets, mode=mode)\n"
  }
]