[
  {
    "path": "LICENSE",
    "content": "                                 Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative Works\" shall mean any work, whether in Source or Object\n      form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. 
For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. 
Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n\n   APPENDIX: How to apply the Apache License to your work.\n\n      To apply the Apache License to your work, attach the following\n      boilerplate notice, with the fields enclosed by brackets \"[]\"\n      replaced with your own identifying information. (Don't include\n      the brackets!)  The text should be enclosed in the appropriate\n      comment syntax for the file format. We also recommend that a\n      file or class name and description of purpose be included on the\n      same \"printed page\" as the copyright notice for easier\n      identification within third-party archives.\n\n   Copyright [yyyy] [name of copyright owner]\n\n   Licensed under the Apache License, Version 2.0 (the \"License\");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       http://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License.\n"
  },
  {
    "path": "README.md",
    "content": "<h1 align='center' style=\"text-align:center; font-weight:bold; font-size:2.0em;letter-spacing:2.0px;\">\nStereo Any Video: <br> Temporally Consistent Stereo Matching<h1>      \n\n<div align=\"center\">\n  <a href=\"https://arxiv.org/abs/2503.05549\" target=\"_blank\" rel=\"external nofollow noopener\">\n  <img src=\"https://img.shields.io/badge/Paper-arXiv-deepgreen\" alt=\"Paper arXiv\"></a>\n  <a href=\"https://tomtomtommi.github.io/StereoAnyVideo/\" target=\"_blank\" rel=\"external nofollow noopener\">\n  <img src=\"https://img.shields.io/badge/Project-Page-9cf\" alt=\"Project Page\"></a>\n</div>\n</p>\n\n![Demo](./assets/stereoanyvideo.gif)\n\n## Installation\n\nInstallation with cuda 12.2\n\n<details>\n  <summary>Setup the root for all source files</summary>\n  <pre><code>\n    git clone https://github.com/tomtomtommi/stereoanyvideo\n    cd stereoanyvideo\n    export PYTHONPATH=`(cd ../ && pwd)`:`pwd`:$PYTHONPATH\n  </code></pre>\n</details>\n\n<details>\n  <summary>Create a conda env</summary>\n  <pre><code>\n    conda create -n sav python=3.10\n    conda activate sav\n  </code></pre>\n</details>\n\n<details>\n  <summary>Install requirements</summary>\n  <pre><code>\n    conda install pytorch==2.1.1 torchvision==0.16.1 torchaudio==2.1.1 pytorch-cuda=12.1 -c pytorch -c nvidia\n    pip install pip==24.0\n    pip install pytorch_lightning==1.6.0\n    pip install iopath\n    conda install -c bottler nvidiacub\n    pip install scikit-image matplotlib imageio plotly opencv-python\n    conda install -c fvcore -c conda-forge fvcore\n    pip install black usort flake8 flake8-bugbear flake8-comprehensions\n    conda install pytorch3d -c pytorch3d\n    pip install -r requirements.txt\n    pip install timm\n  </code></pre>\n</details>\n\n<details>\n  <summary>Download VDA checkpoints</summary>\n  <pre><code>\n    cd models/Video-Depth-Anything\n    sh get_weights.sh\n  </code></pre>\n</details>\n\n## Inference a stereo video\n\n```\nsh demo.sh\n```\nBefore running, download the checkpoints on [google drive](https://drive.google.com/drive/folders/1c7L065dcBWhCYYjWYo2edGOG605PnpXv?usp=sharing) . \nCopy the checkpoints to `./checkpoints/`\n\nIn default, left and right camera videos are supposed to be structured like this:\n```none\n./demo_video/\n        ├── left\n            ├── left000000.png\n            ├── left000001.png\n            ├── left000002.png\n            ...\n        ├── right\n            ├── right000000.png\n            ├── right000001.png\n            ├── right000002.png\n            ...\n```\n\nA simple way to run the demo is using SouthKensingtonSV.\n\nTo test on your own data, modify `--path ./demo_video/`. 
## Dataset\n\nDownload the following datasets and put them in `./data/datasets/`:\n - [SceneFlow](https://lmb.informatik.uni-freiburg.de/resources/datasets/SceneFlowDatasets.en.html)\n - [Sintel](http://sintel.is.tue.mpg.de/stereo)\n - [Dynamic_Replica](https://dynamic-stereo.github.io/)\n - [KITTI Depth](https://www.cvlibs.net/datasets/kitti/eval_depth_all.php)\n - [Infinigen SV](https://tomtomtommi.github.io/BiDAVideo/)\n - [Virtual KITTI2](https://europe.naverlabs.com/proxy-virtual-worlds-vkitti-2/)\n - [SouthKensington SV](https://tomtomtommi.github.io/BiDAVideo/)\n\n\n## Evaluation\n```\nsh evaluate_stereoanyvideo.sh\n```\n\n## Training\n```\nsh train_stereoanyvideo.sh\n```\n\n## Citation\nIf you use our method in your research, please consider citing:\n```\n@inproceedings{jing2025stereo,\n  title={Stereo any video: Temporally consistent stereo matching},\n  author={Jing, Junpeng and Luo, Weixun and Mao, Ye and Mikolajczyk, Krystian},\n  booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},\n  pages={20836--20846},\n  year={2025}\n}\n```\n"
  },
  {
    "path": "checkpoints/checkpoints here.txt",
    "content": ""
  },
  {
    "path": "data/datasets/dataset here.txt",
    "content": ""
  },
  {
    "path": "datasets/augmentor.py",
    "content": "import numpy as np\nimport random\nfrom PIL import Image\n\nimport cv2\n\ncv2.setNumThreads(0)\ncv2.ocl.setUseOpenCL(False)\n\nfrom torchvision.transforms import ColorJitter, functional, Compose\n\n\nclass AdjustGamma(object):\n    def __init__(self, gamma_min, gamma_max, gain_min=1.0, gain_max=1.0):\n        self.gamma_min, self.gamma_max, self.gain_min, self.gain_max = (\n            gamma_min,\n            gamma_max,\n            gain_min,\n            gain_max,\n        )\n\n    def __call__(self, sample):\n        gain = random.uniform(self.gain_min, self.gain_max)\n        gamma = random.uniform(self.gamma_min, self.gamma_max)\n        return functional.adjust_gamma(sample, gamma, gain)\n\n    def __repr__(self):\n        return f\"Adjust Gamma {self.gamma_min}, ({self.gamma_max}) and Gain ({self.gain_min}, {self.gain_max})\"\n\n\nclass SequenceDispFlowAugmentor:\n    def __init__(\n        self,\n        crop_size,\n        min_scale=-0.2,\n        max_scale=0.5,\n        do_flip=True,\n        yjitter=False,\n        saturation_range=[0.6, 1.4],\n        gamma=[1, 1, 1, 1],\n    ):\n        # spatial augmentation params\n        self.crop_size = crop_size\n        self.min_scale = min_scale\n        self.max_scale = max_scale\n        self.spatial_aug_prob = 1.0\n        self.stretch_prob = 0.8\n        self.max_stretch = 0.2\n\n        # flip augmentation params\n        self.yjitter = yjitter\n        self.do_flip = do_flip\n        self.h_flip_prob = 0.5\n        self.v_flip_prob = 0.1\n\n        # photometric augmentation params\n        self.photo_aug = Compose(\n            [\n                ColorJitter(\n                    brightness=0.4,\n                    contrast=0.4,\n                    saturation=saturation_range,\n                    hue=0.5 / 3.14,\n                ),\n                AdjustGamma(*gamma),\n            ]\n        )\n        self.asymmetric_color_aug_prob = 0.2\n        self.eraser_aug_prob = 0.5\n\n    def color_transform(self, seq):\n        \"\"\"Photometric augmentation\"\"\"\n\n        # asymmetric\n        if np.random.rand() < self.asymmetric_color_aug_prob:\n            for i in range(len(seq)):\n                for cam in (0, 1):\n                    seq[i][cam] = np.array(\n                        self.photo_aug(Image.fromarray(seq[i][cam])), dtype=np.uint8\n                    )\n        # symmetric\n        else:\n            image_stack = np.concatenate(\n                [seq[i][cam] for i in range(len(seq)) for cam in (0, 1)], axis=0\n            )\n            image_stack = np.array(\n                self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8\n            )\n            split = np.split(image_stack, len(seq) * 2, axis=0)\n            for i in range(len(seq)):\n                seq[i][0] = split[2 * i]\n                seq[i][1] = split[2 * i + 1]\n        return seq\n\n    def eraser_transform(self, seq, bounds=[50, 100]):\n        \"\"\"Occlusion augmentation\"\"\"\n        ht, wd = seq[0][0].shape[:2]\n        for i in range(len(seq)):\n            for cam in (0, 1):\n                if np.random.rand() < self.eraser_aug_prob:\n                    mean_color = np.mean(seq[0][0].reshape(-1, 3), axis=0)\n                    for _ in range(np.random.randint(1, 3)):\n                        x0 = np.random.randint(0, wd)\n                        y0 = np.random.randint(0, ht)\n                        dx = np.random.randint(bounds[0], bounds[1])\n                        dy = np.random.randint(bounds[0], 
bounds[1])\n                        seq[i][cam][y0 : y0 + dy, x0 : x0 + dx, :] = mean_color\n\n        return seq\n\n    def spatial_transform(self, img, disp):\n        # randomly sample scale\n        ht, wd = img[0][0].shape[:2]\n        min_scale = np.maximum(\n            (self.crop_size[0] + 8) / float(ht), (self.crop_size[1] + 8) / float(wd)\n        )\n\n        scale = 2 ** np.random.uniform(self.min_scale, self.max_scale)\n        scale_x = scale\n        scale_y = scale\n        if np.random.rand() < self.stretch_prob:\n            scale_x *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch)\n            scale_y *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch)\n\n        scale_x = np.clip(scale_x, min_scale, None)\n        scale_y = np.clip(scale_y, min_scale, None)\n\n        if np.random.rand() < self.spatial_aug_prob:\n            # rescale the images\n            for i in range(len(img)):\n                for cam in (0, 1):\n                    img[i][cam] = cv2.resize(\n                        img[i][cam],\n                        None,\n                        fx=scale_x,\n                        fy=scale_y,\n                        interpolation=cv2.INTER_LINEAR,\n                    )\n                    if len(disp[i]) > 0:\n                        disp[i][cam] = cv2.resize(\n                            disp[i][cam],\n                            None,\n                            fx=scale_x,\n                            fy=scale_y,\n                            interpolation=cv2.INTER_LINEAR,\n                        )\n                        disp[i][cam] = disp[i][cam] * [scale_x, scale_y]\n\n        if self.yjitter:\n            y0 = np.random.randint(2, img[0][0].shape[0] - self.crop_size[0] - 2)\n            x0 = np.random.randint(2, img[0][0].shape[1] - self.crop_size[1] - 2)\n\n            for i in range(len(img)):\n                y1 = y0 + np.random.randint(-2, 2 + 1)\n                img[i][0] = img[i][0][\n                    y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]\n                ]\n                img[i][1] = img[i][1][\n                    y1 : y1 + self.crop_size[0], x0 : x0 + self.crop_size[1]\n                ]\n                if len(disp[i]) > 0:\n                    disp[i][0] = disp[i][0][\n                        y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]\n                    ]\n                    disp[i][1] = disp[i][1][\n                        y1 : y1 + self.crop_size[0], x0 : x0 + self.crop_size[1]\n                    ]\n        else:\n            y0 = np.random.randint(0, img[0][0].shape[0] - self.crop_size[0])\n            x0 = np.random.randint(0, img[0][0].shape[1] - self.crop_size[1])\n            for i in range(len(img)):\n                for cam in (0, 1):\n                    img[i][cam] = img[i][cam][\n                        y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]\n                    ]\n                    if len(disp[i]) > 0:\n                        disp[i][cam] = disp[i][cam][\n                            y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]\n                        ]\n\n        return img, disp\n\n    def __call__(self, img, disp):\n        img = self.color_transform(img)\n        img = self.eraser_transform(img)\n        img, disp = self.spatial_transform(img, disp)\n\n        for i in range(len(img)):\n            for cam in (0, 1):\n                img[i][cam] = np.ascontiguousarray(img[i][cam])\n                if len(disp[i]) > 
0:\n                    disp[i][cam] = np.ascontiguousarray(disp[i][cam])\n\n        return img, disp\n\n\nclass SequenceDispSparseFlowAugmentor:\n    def __init__(\n        self,\n        crop_size,\n        min_scale=-0.2,\n        max_scale=0.5,\n        do_flip=True,\n        yjitter=False,\n        saturation_range=[0.6, 1.4],\n        gamma=[1, 1, 1, 1],\n    ):\n        # spatial augmentation params\n        self.crop_size = crop_size\n        self.min_scale = min_scale\n        self.max_scale = max_scale\n        self.spatial_aug_prob = 1.0\n        self.stretch_prob = 0.8\n        self.max_stretch = 0.2\n\n        # flip augmentation params\n        self.yjitter = yjitter\n        self.do_flip = do_flip\n        self.h_flip_prob = 0.5\n        self.v_flip_prob = 0.1\n\n        # photometric augmentation params\n        self.photo_aug = Compose(\n            [\n                ColorJitter(\n                    brightness=0.4,\n                    contrast=0.4,\n                    saturation=saturation_range,\n                    hue=0.5 / 3.14,\n                ),\n                AdjustGamma(*gamma),\n            ]\n        )\n        self.asymmetric_color_aug_prob = 0.2\n        self.eraser_aug_prob = 0.5\n\n    def color_transform(self, seq):\n        \"\"\"Photometric augmentation\"\"\"\n        # symmetric\n        image_stack = np.concatenate(\n            [seq[i][cam] for i in range(len(seq)) for cam in (0, 1)], axis=0\n        )\n        image_stack = np.array(\n            self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8\n        )\n        split = np.split(image_stack, len(seq) * 2, axis=0)\n        for i in range(len(seq)):\n            seq[i][0] = split[2 * i]\n            seq[i][1] = split[2 * i + 1]\n        return seq\n\n    def eraser_transform(self, seq, bounds=[50, 100]):\n        \"\"\"Occlusion augmentation\"\"\"\n        ht, wd = seq[0][0].shape[:2]\n        for i in range(len(seq)):\n            for cam in (0, 1):\n                if np.random.rand() < self.eraser_aug_prob:\n                    mean_color = np.mean(seq[0][0].reshape(-1, 3), axis=0)\n                    for _ in range(np.random.randint(1, 3)):\n                        x0 = np.random.randint(0, wd)\n                        y0 = np.random.randint(0, ht)\n                        dx = np.random.randint(bounds[0], bounds[1])\n                        dy = np.random.randint(bounds[0], bounds[1])\n                        seq[i][cam][y0 : y0 + dy, x0 : x0 + dx, :] = mean_color\n\n        return seq\n\n    def resize_sparse_flow_map(self, flow, valid, fx=1.0, fy=1.0):\n        ht, wd = flow.shape[:2]\n        coords = np.meshgrid(np.arange(wd), np.arange(ht))\n        coords = np.stack(coords, axis=-1)\n\n        coords = coords.reshape(-1, 2).astype(np.float32)\n        flow = flow.reshape(-1, 2).astype(np.float32)\n        valid = valid.reshape(-1).astype(np.float32)\n\n        coords0 = coords[valid>=1]\n        flow0 = flow[valid>=1]\n\n        ht1 = int(round(ht * fy))\n        wd1 = int(round(wd * fx))\n\n        coords1 = coords0 * [fx, fy]\n        flow1 = flow0 * [fx, fy]\n\n        xx = np.round(coords1[:,0]).astype(np.int32)\n        yy = np.round(coords1[:,1]).astype(np.int32)\n\n        v = (xx > 0) & (xx < wd1) & (yy > 0) & (yy < ht1)\n        xx = xx[v]\n        yy = yy[v]\n        flow1 = flow1[v]\n\n        flow_img = np.zeros([ht1, wd1, 2], dtype=np.float32)\n        valid_img = np.zeros([ht1, wd1], dtype=np.int32)\n\n        flow_img[yy, xx] = flow1\n        valid_img[yy, 
xx] = 1\n\n        return flow_img, valid_img\n\n    def spatial_transform(self, img, disp, valid):\n        # randomly sample scale\n        ht, wd = img[0][0].shape[:2]\n        min_scale = np.maximum(\n            (self.crop_size[0] + 8) / float(ht), (self.crop_size[1] + 8) / float(wd)\n        )\n\n        scale = 2 ** np.random.uniform(self.min_scale, self.max_scale)\n        scale_x = scale\n        scale_y = scale\n        if np.random.rand() < self.stretch_prob:\n            scale_x *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch)\n            scale_y *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch)\n\n        scale_x = np.clip(scale_x, min_scale, None)\n        scale_y = np.clip(scale_y, min_scale, None)\n\n        if np.random.rand() < self.spatial_aug_prob:\n            # rescale the images\n            for i in range(len(img)):\n                for cam in (0, 1):\n                    img[i][cam] = cv2.resize(\n                        img[i][cam],\n                        None,\n                        fx=scale_x,\n                        fy=scale_y,\n                        interpolation=cv2.INTER_LINEAR,\n                    )\n                    if len(disp[i]) > 0:\n                        disp[i][cam], valid[i][cam] = self.resize_sparse_flow_map(disp[i][cam], valid[i][cam], fx=scale_x, fy=scale_y)\n\n        margin_y = 20\n        margin_x = 50\n\n        y0 = np.random.randint(0, img[0][0].shape[0] - self.crop_size[0])\n        x0 = np.random.randint(0, img[0][0].shape[1] - self.crop_size[1])\n        for i in range(len(img)):\n            for cam in (0, 1):\n                img[i][cam] = img[i][cam][\n                    y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]\n                ]\n                if len(disp[i]) > 0:\n                    disp[i][cam] = disp[i][cam][\n                        y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]\n                    ]\n                    valid[i][cam] = valid[i][cam][\n                        y0: y0 + self.crop_size[0], x0: x0 + self.crop_size[1]\n                    ]\n        return img, disp, valid\n\n    def __call__(self, img, disp, valid):\n        img = self.color_transform(img)\n        img = self.eraser_transform(img)\n        img, disp, valid = self.spatial_transform(img, disp, valid)\n\n        for i in range(len(img)):\n            for cam in (0, 1):\n                img[i][cam] = np.ascontiguousarray(img[i][cam])\n                if len(disp[i]) > 0:\n                    disp[i][cam] = np.ascontiguousarray(disp[i][cam])\n                    valid[i][cam] = np.ascontiguousarray(valid[i][cam])\n\n        return img, disp, valid\n"
  },
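  {
    "path": "datasets/augmentor_example.py",
    "content": "\"\"\"Hypothetical usage sketch for SequenceDispFlowAugmentor.\n\nAdded for illustration only: this file is not part of the original repository,\nand the toy shapes and crop size below are assumptions. It demonstrates the\ncall convention implied by datasets/augmentor.py: a length-T list of\n[left, right] uint8 HxWx3 images plus a matching list of [left, right]\nfloat32 HxWx2 disparity/flow maps, augmented jointly.\n\"\"\"\n\nimport numpy as np\n\nfrom stereoanyvideo.datasets.augmentor import SequenceDispFlowAugmentor\n\nT, H, W = 4, 480, 640\naug = SequenceDispFlowAugmentor(crop_size=(256, 320))\n\n# Toy stereo sequence: img[i] = [left, right], disp[i] = [left, right].\nimg = [[np.zeros((H, W, 3), np.uint8) for _ in range(2)] for _ in range(T)]\ndisp = [[np.zeros((H, W, 2), np.float32) for _ in range(2)] for _ in range(T)]\n\n# Photometric, eraser, and spatial augmentations applied across the sequence.\nimg, disp = aug(img, disp)\nprint(img[0][0].shape, disp[0][0].shape)  # (256, 320, 3) (256, 320, 2)\n"
  },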
  {
    "path": "datasets/frame_utils.py",
    "content": "import numpy as np\nfrom PIL import Image\nfrom os.path import *\nimport re\nimport imageio\nimport cv2\n\ncv2.setNumThreads(0)\ncv2.ocl.setUseOpenCL(False)\n\nTAG_CHAR = np.array([202021.25], np.float32)\n\n\ndef readFlow(fn):\n    \"\"\"Read .flo file in Middlebury format\"\"\"\n    # Code adapted from:\n    # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy\n\n    # WARNING: this will work on little-endian architectures (eg Intel x86) only!\n    # print 'fn = %s'%(fn)\n    with open(fn, \"rb\") as f:\n        magic = np.fromfile(f, np.float32, count=1)\n        if 202021.25 != magic:\n            print(\"Magic number incorrect. Invalid .flo file\")\n            return None\n        else:\n            w = np.fromfile(f, np.int32, count=1)\n            h = np.fromfile(f, np.int32, count=1)\n            # print 'Reading %d x %d flo file\\n' % (w, h)\n            data = np.fromfile(f, np.float32, count=2 * int(w) * int(h))\n            # Reshape data into 3D array (columns, rows, bands)\n            # The reshape here is for visualization, the original code is (w,h,2)\n            return np.resize(data, (int(h), int(w), 2))\n\n\ndef readPFM(file):\n    file = open(file, \"rb\")\n\n    color = None\n    width = None\n    height = None\n    scale = None\n    endian = None\n\n    header = file.readline().rstrip()\n    if header == b\"PF\":\n        color = True\n    elif header == b\"Pf\":\n        color = False\n    else:\n        raise Exception(\"Not a PFM file.\")\n\n    dim_match = re.match(rb\"^(\\d+)\\s(\\d+)\\s$\", file.readline())\n    if dim_match:\n        width, height = map(int, dim_match.groups())\n    else:\n        raise Exception(\"Malformed PFM header.\")\n\n    scale = float(file.readline().rstrip())\n    if scale < 0:  # little-endian\n        endian = \"<\"\n        scale = -scale\n    else:\n        endian = \">\"  # big-endian\n\n    data = np.fromfile(file, endian + \"f\")\n    shape = (height, width, 3) if color else (height, width)\n\n    data = np.reshape(data, shape)\n    data = np.flipud(data)\n    return data\n\n\ndef readDispSintelStereo(file_name):\n    \"\"\"Return disparity read from filename.\"\"\"\n    f_in = np.array(Image.open(file_name))\n    d_r = f_in[:, :, 0].astype(\"float64\")\n    d_g = f_in[:, :, 1].astype(\"float64\")\n    d_b = f_in[:, :, 2].astype(\"float64\")\n\n    disp = d_r * 4 + d_g / (2 ** 6) + d_b / (2 ** 14)\n    mask = np.array(Image.open(file_name.replace(\"disparities\", \"occlusions\")))\n    valid = (mask == 0) & (disp > 0)\n    return disp, valid\n\n\ndef readDispMiddlebury(file_name):\n    assert basename(file_name) == \"disp0GT.pfm\"\n    disp = readPFM(file_name).astype(np.float32)\n    assert len(disp.shape) == 2\n    nocc_pix = file_name.replace(\"disp0GT.pfm\", \"mask0nocc.png\")\n    assert exists(nocc_pix)\n    nocc_pix = imageio.imread(nocc_pix) == 255\n    assert np.any(nocc_pix)\n    return disp, nocc_pix\n\n\ndef read_gen(file_name, pil=False):\n    ext = splitext(file_name)[-1]\n    if ext == \".png\" or ext == \".jpeg\" or ext == \".ppm\" or ext == \".jpg\":\n        return Image.open(file_name)\n    elif ext == \".bin\" or ext == \".raw\":\n        return np.load(file_name)\n    elif ext == \".flo\":\n        return readFlow(file_name).astype(np.float32)\n    elif ext == \".pfm\":\n        flow = readPFM(file_name).astype(np.float32)\n        if len(flow.shape) == 2:\n            return flow\n        else:\n            return flow[:, :, :-1]\n    
return []\n"
  },
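  {
    "path": "datasets/frame_utils_example.py",
    "content": "\"\"\"Hypothetical round-trip check for frame_utils.readPFM.\n\nAdded for illustration only: this file is not part of the original repository.\nIt writes a tiny grayscale PFM by hand and reads it back. A PFM header is\n'Pf' (single channel) or 'PF' (color), then 'width height', then a negative\nscale for little-endian data, with scanlines stored bottom-to-top.\n\"\"\"\n\nimport numpy as np\n\nfrom stereoanyvideo.datasets.frame_utils import readPFM\n\ndata = np.arange(6, dtype=np.float32).reshape(2, 3)\nwith open(\"toy.pfm\", \"wb\") as f:\n    f.write(b\"Pf\\n3 2\\n-1.0\\n\")  # grayscale, 3x2, little-endian\n    f.write(np.flipud(data).astype(\"<f\").tobytes())  # rows bottom-to-top\n\nassert np.allclose(readPFM(\"toy.pfm\"), data)\nprint(\"readPFM round-trip OK\")\n"
  },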
  {
    "path": "datasets/video_datasets.py",
    "content": "import os\nimport copy\nimport gzip\nimport logging\nimport torch\nimport numpy as np\nimport torch.utils.data as data\nimport torch.nn.functional as F\nimport os.path as osp\nfrom glob import glob\nimport cv2\nimport re\nfrom scipy.spatial.transform import Rotation as R\n\nfrom collections import defaultdict\nfrom PIL import Image\nfrom dataclasses import dataclass\nfrom typing import List, Optional\nfrom pytorch3d.renderer.cameras import PerspectiveCameras\nfrom pytorch3d.implicitron.dataset.types import (\n    FrameAnnotation as ImplicitronFrameAnnotation,\n    load_dataclass,\n)\n\nfrom stereoanyvideo.datasets import frame_utils\nfrom stereoanyvideo.evaluation.utils.eval_utils import depth2disparity_scale\nfrom stereoanyvideo.datasets.augmentor import SequenceDispFlowAugmentor, SequenceDispSparseFlowAugmentor\n\n\n@dataclass\nclass DynamicReplicaFrameAnnotation(ImplicitronFrameAnnotation):\n    \"\"\"A dataclass used to load annotations from json.\"\"\"\n\n    camera_name: Optional[str] = None\n\n\nclass StereoSequenceDataset(data.Dataset):\n    def __init__(self, aug_params=None, sparse=False, reader=None):\n        self.augmentor = None\n        self.sparse = sparse\n        self.img_pad = (\n            aug_params.pop(\"img_pad\", None) if aug_params is not None else None\n        )\n        if aug_params is not None and \"crop_size\" in aug_params:\n            if sparse:\n                self.augmentor = SequenceDispSparseFlowAugmentor(**aug_params)\n            else:\n                self.augmentor = SequenceDispFlowAugmentor(**aug_params)\n\n        if reader is None:\n            self.disparity_reader = frame_utils.read_gen\n        else:\n            self.disparity_reader = reader\n        self.depth_reader = self._load_depth\n        self.is_test = False\n        self.sample_list = []\n        self.extra_info = []\n        self.depth_eps = 1e-5\n\n    def _load_depth(self, depth_path):\n        if depth_path[-3:] == \"npy\":\n            return self._load_npy_depth(depth_path)\n        elif depth_path[-3:] == \"png\":\n            if \"kitti_depth\" in depth_path:\n                return self._load_kitti_depth(depth_path)\n            elif \"vkitti2\" in depth_path:\n                return self._load_vkitti2(depth_path)\n            else:\n                return self._load_16big_png_depth(depth_path)\n        else:\n            raise ValueError(\"Other format depth is not implemented\")\n\n    def _load_npy_depth(self, depth_npy):\n        depth = np.load(depth_npy)\n        return depth\n\n    def _load_vkitti2(self, depth_png):\n        depth_image = cv2.imread(depth_png, cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH)\n        depth_in_meters = depth_image.astype(np.float32) / 100.0\n        depth_in_meters[depth_image == 0] = -1.\n\n        return depth_in_meters\n\n    def _load_kitti_depth(self, depth_png):\n        # depth_image = cv2.imread(depth_png, cv2.IMREAD_UNCHANGED)\n        # depth_in_meters = depth_image.astype(np.float32) / 256.0\n        depth_image = np.array(Image.open(depth_png), dtype=int)\n        # make sure we have a proper 16bit depth map here.. 
not 8bit!\n        assert (np.max(depth_image) > 255)\n\n        depth_in_meters = depth_image.astype(np.float32) / 256.\n        depth_in_meters[depth_image == 0] = -1.\n\n        return depth_in_meters\n\n    def _load_16big_png_depth(self, depth_png):\n        with Image.open(depth_png) as depth_pil:\n            # the image is stored with 16-bit depth but PIL reads it as I (32 bit).\n            # we cast it to uint16, then reinterpret as float16, then cast to float32\n            depth = (\n                np.frombuffer(np.array(depth_pil, dtype=np.uint16), dtype=np.float16)\n                .astype(np.float32)\n                .reshape((depth_pil.size[1], depth_pil.size[0]))\n            )\n        return depth\n\n    def load_tartanair_pose(self, filepath, index=0):\n        poses = np.loadtxt(filepath)\n        tx, ty, tz, qx, qy, qz, qw = poses[index]\n\n        # Quaternion to rotation matrix\n        r = R.from_quat([qx, qy, qz, qw])\n        R_mat = r.as_matrix()\n\n        # Assemble 4x4 pose matrix\n        T = np.eye(4)\n        T[:3, :3] = R_mat\n        T[:3, 3] = [tx, ty, tz]\n\n        return T\n\n    def parse_txt_file(self, file_path):\n        with open(file_path, 'r') as file:\n            data = file.read()\n\n        # Regex patterns\n        intrinsic_pattern = re.compile(r\"Intrinsic:\\s*\\[\\[([^\\]]+)\\]\\s*\\[\\s*([^\\]]+)\\]\\s*\\[\\s*([^\\]]+)\\]\\]\")\n        frame_pattern = re.compile(r\"Frame (\\d+): Pose: ([\\w\\d]+)\\n([\\s\\S]+?)(?=Frame|\\Z)\")\n\n        # Extract intrinsic matrix (K)\n        intrinsic_match = intrinsic_pattern.search(data)\n        if intrinsic_match:\n            K = np.array([list(map(float, row.split())) for row in intrinsic_match.groups()])\n        else:\n            raise ValueError(\"Intrinsic matrix not found in the file\")\n\n        # Extract frames and compute R and T\n        frames = []\n        for frame_match in frame_pattern.finditer(data):\n            frame_number = int(frame_match.group(1))\n            pose_id = frame_match.group(2)\n            pose_matrix = np.array([list(map(float, row.split())) for row in frame_match.group(3).strip().split('\\n')])\n\n            # Decompose pose matrix into R and T\n            R = pose_matrix[:3, :3]  # The upper-left 3x3 part is the rotation matrix\n            T = pose_matrix[:3, 3]  # The first three elements of the fourth column is the translation vector\n\n            frames.append({\n                'frame_number': frame_number,\n                'pose_id': pose_id,\n                'pose_matrix': pose_matrix,\n                'R': R,\n                'T': T\n            })\n\n        return K, frames\n\n    def _get_pytorch3d_camera(\n        self, entry_viewpoint, image_size, scale: float\n    ) -> PerspectiveCameras:\n        assert entry_viewpoint is not None\n        # principal point and focal length\n        principal_point = torch.tensor(\n            entry_viewpoint.principal_point, dtype=torch.float\n        )\n        focal_length = torch.tensor(entry_viewpoint.focal_length, dtype=torch.float)\n        half_image_size_wh_orig = (\n            torch.tensor(list(reversed(image_size)), dtype=torch.float) / 2.0\n        )\n\n        # first, we convert from the dataset's NDC convention to pixels\n        format = entry_viewpoint.intrinsics_format\n        if format.lower() == \"ndc_norm_image_bounds\":\n            # this is e.g. 
currently used in CO3D for storing intrinsics\n            rescale = half_image_size_wh_orig\n        elif format.lower() == \"ndc_isotropic\":\n            rescale = half_image_size_wh_orig.min()\n        else:\n            raise ValueError(f\"Unknown intrinsics format: {format}\")\n\n        # principal point and focal length in pixels\n        principal_point_px = half_image_size_wh_orig - principal_point * rescale\n        focal_length_px = focal_length * rescale\n\n        # now, convert from pixels to PyTorch3D v0.5+ NDC convention\n        # if self.image_height is None or self.image_width is None:\n        out_size = list(reversed(image_size))\n\n        half_image_size_output = torch.tensor(out_size, dtype=torch.float) / 2.0\n        half_min_image_size_output = half_image_size_output.min()\n\n        # rescaled principal point and focal length in ndc\n        principal_point = (\n            half_image_size_output - principal_point_px * scale\n        ) / half_min_image_size_output\n        focal_length = focal_length_px * scale / half_min_image_size_output\n        return PerspectiveCameras(\n            focal_length=focal_length[None],\n            principal_point=principal_point[None],\n            R=torch.tensor(entry_viewpoint.R, dtype=torch.float)[None],\n            T=torch.tensor(entry_viewpoint.T, dtype=torch.float)[None],\n        )\n\n    def _get_pytorch3d_camera_from_blender(self, R, T, K, image_size, scale: float) -> PerspectiveCameras:\n        assert R is not None and T is not None and K is not None\n        assert R.shape == (3, 3), f\"Expected R to be 3x3, but got {R.shape}\"\n        assert T.shape == (3,), f\"Expected T to be a 3-element vector, but got {T.shape}\"\n        assert K.shape == (3, 3), f\"Expected K to be 3x3, but got {K.shape}\"\n\n        # Extract principal point and focal length from K\n        fx = K[0, 0]\n        fy = K[1, 1]\n        cx = K[0, 2]\n        cy = K[1, 2]\n\n        principal_point = torch.tensor([cx, cy], dtype=torch.float)\n        focal_length = torch.tensor([fx, fy], dtype=torch.float)\n\n        half_image_size_wh_orig = (\n                torch.tensor(list(reversed(image_size)), dtype=torch.float) / 2.0\n        )\n\n        # Adjust principal point and focal length in pixels\n        principal_point_px = principal_point * scale\n        focal_length_px = focal_length * scale\n\n        # Convert from pixels to PyTorch3D NDC convention\n        principal_point = (principal_point_px - half_image_size_wh_orig) / half_image_size_wh_orig\n        half_min_image_size_output = half_image_size_wh_orig.min()\n        focal_length = focal_length_px / half_min_image_size_output\n\n        R = R.T @ np.array([[-1, 0, 0], [0, -1, 0], [0, 0, 1]], dtype=np.float64)\n        T = T @ np.array([[-1, 0, 0], [0, -1, 0], [0, 0, 1]], dtype=np.float64)\n\n        # Convert R and T to PyTorch tensors\n        R_tensor = torch.tensor(R, dtype=torch.float).unsqueeze(0)  # Add batch dimension\n        T_tensor = torch.tensor(T, dtype=torch.float).view(1, 3)  # Ensure T is a 1x3 tensor\n\n        # Return PerspectiveCameras object\n        return PerspectiveCameras(\n            focal_length=focal_length.unsqueeze(0),  # Add batch dimension\n            principal_point=principal_point.unsqueeze(0),  # Add batch dimension\n            R=R_tensor,\n            T=T_tensor,\n        )\n\n    def _get_output_tensor(self, sample):\n        output_tensor = defaultdict(list)\n        sample_size = len(sample[\"image\"][\"left\"])\n        
output_tensor_keys = [\"img\", \"disp\", \"valid_disp\", \"mask\"]\n        add_keys = [\"viewpoint\", \"metadata\"]\n        for add_key in add_keys:\n            if add_key in sample:\n                output_tensor_keys.append(add_key)\n\n        for key in output_tensor_keys:\n            output_tensor[key] = [[] for _ in range(sample_size)]\n\n        if \"viewpoint\" in sample:\n            viewpoint_left = self._get_pytorch3d_camera(\n                sample[\"viewpoint\"][\"left\"][0],\n                sample[\"metadata\"][\"left\"][0][1],\n                scale=1.0,\n            )\n            viewpoint_right = self._get_pytorch3d_camera(\n                sample[\"viewpoint\"][\"right\"][0],\n                sample[\"metadata\"][\"right\"][0][1],\n                scale=1.0,\n            )\n            depth2disp_scale = depth2disparity_scale(\n                viewpoint_left,\n                viewpoint_right,\n                torch.Tensor(sample[\"metadata\"][\"left\"][0][1])[None],\n            )\n            output_tensor[\"depth2disp_scale\"] = [depth2disp_scale]\n\n        if \"camera\" in sample:\n            output_tensor[\"viewpoint\"] = [[] for _ in range(sample_size)]\n            # InfinigenSV\n            if sample[\"camera\"][\"left\"][0][-3:] == \"npz\":\n                # Note that the K, R, T is based on Blender world Matrix\n                camera_left = np.load(sample[\"camera\"][\"left\"][0])\n                camera_right = np.load(sample[\"camera\"][\"right\"][0])\n                camera_left_RT = camera_left['T']\n                camera_right_RT = camera_right['T']\n                camera_left_K = camera_left['K']\n                camera_right_K = camera_right['K']\n                camera_left_T = camera_left['T'][:3, 3]\n                camera_left_R = camera_left['T'][:3, :3]\n                fix_baseline = np.linalg.norm(camera_left_RT[:3, 3] - camera_right_RT[:3, 3])\n                focal_length_px = camera_left_K[0][0]\n                depth2disp_scale = focal_length_px * fix_baseline\n\n            # Sintel\n            elif sample[\"camera\"][\"left\"][0][-3:] == \"cam\":\n                TAG_FLOAT = 202021.25\n                f = open(sample[\"camera\"][\"left\"][0], 'rb')\n                check = np.fromfile(f, dtype=np.float32, count=1)[0]\n                assert check == TAG_FLOAT, ' cam_read:: Wrong tag in flow file (should be: {0}, is: {1}). Big-endian machine? 
'.format(\n                    TAG_FLOAT, check)\n                camera_left_K = np.fromfile(f, dtype='float64', count=9).reshape((3, 3))\n                camera_left_RT = np.fromfile(f, dtype='float64', count=12).reshape((3, 4))\n                fix_baseline = 0.1 # From the MPI Sintel dataset website, the baseline of the cameras = 10cm = 0.1m\n                focal_length_px = camera_left_K[0][0]\n                depth2disp_scale = focal_length_px * fix_baseline\n                camera_left_T = camera_left_RT[:3, 3]\n                camera_left_R = camera_left_RT[:3, :3]\n\n            # Spring\n            elif any(filename in path for path in sample[\"camera\"][\"left\"] for filename in [\"focaldistance.txt\", \"extrinsics.txt\", \"intrinsics.txt\"]):\n                for path in sample[\"camera\"][\"left\"]:\n                    if \"intrinsics.txt\" in path:\n                        intrinsics_path = path\n                    elif \"extrinsics.txt\" in path:\n                        extrinsics_path = path\n\n                fx, fy, cx, cy = np.loadtxt(intrinsics_path)[0]\n                # Build the 3x3 intrinsic matrix\n                camera_left_K = np.array([\n                    [fx, 0, cx],\n                    [0, fy, cy],\n                    [0, 0, 1]\n                ])\n                focal_length_px = camera_left_K[0][0]\n                fix_baseline = 0.065  # From the dataset website, the baseline of the cameras = 6.5cm = 0.065m\n                depth2disp_scale = focal_length_px * fix_baseline\n                camera_left_RT = np.loadtxt(extrinsics_path).reshape(-1, 4, 4)[0]\n                camera_left_T = camera_left_RT[:3, 3]\n                camera_left_R = camera_left_RT[:3, :3]\n\n            # TartanAir\n            elif sample[\"camera\"][\"left\"][0][-13:] == \"pose_left.txt\":\n                fx, fy, cx, cy = 320.0, 320.0, 320.0, 240.0\n                # Build the 3x3 intrinsic matrix\n                camera_left_K = np.array([\n                    [fx, 0, cx],\n                    [0, fy, cy],\n                    [0, 0, 1]\n                ])\n                focal_length_px = camera_left_K[0][0]\n                fix_baseline = 0.25\n                depth2disp_scale = focal_length_px * fix_baseline\n                camera_left_RT = self.load_tartanair_pose(sample[\"camera\"][\"left\"][0], index=0)\n                camera_left_T = camera_left_RT[:3, 3]\n                camera_left_R = camera_left_RT[:3, :3]\n\n            # KITTI Depth\n            elif sample[\"camera\"][\"left\"][0][-20:] == \"calib_cam_to_cam.txt\":\n                calib_data = {}\n                with open(sample[\"camera\"][\"left\"][0], 'r') as f:\n                    for line in f:\n                        key, value = line.split(':', 1)\n                        calib_data[key.strip()] = value.strip()\n\n                P_key = 'P_rect_02'\n                if P_key in calib_data:\n                    P_values = np.array([float(x) for x in calib_data[P_key].split()])\n                    P_matrix = P_values.reshape(3, 4)\n                else:\n                    raise KeyError(f\"Projection matrix for camera not found in calibration data\")\n                focal_length_px = P_matrix[0, 0]\n\n                T_key1 = 'T_02'\n                T_key2 = 'T_03'\n                if T_key1 in calib_data and T_key2 in calib_data:\n                    T1 = np.array([float(x) for x in calib_data[T_key1].split()])\n                    T2 = np.array([float(x) for x in 
calib_data[T_key2].split()])\n                    baseline = np.linalg.norm(T1 - T2)\n                else:\n                    raise KeyError(f\"Translation vectors for cameras not found in calibration data\")\n\n                R_key1 = 'R_rect_02'\n                R_key2 = 'R_rect_03'\n                if R_key1 in calib_data and R_key2 in calib_data:\n                    R1 = np.array([float(x) for x in calib_data[R_key1].split()]).reshape(3, 3)\n                    R2 = np.array([float(x) for x in calib_data[R_key2].split()]).reshape(3, 3)\n                else:\n                    raise KeyError(f\"Rotation vectors for cameras not found in calibration data\")\n\n                depth2disp_scale = focal_length_px * baseline\n                camera_left_K = P_matrix[:, :3]\n                camera_left_T = T1\n                camera_left_R = R1\n\n            # VKITTI2\n            elif sample[\"camera\"][\"left\"][0][-13:] == \"intrinsic.txt\":\n                baseline = 0.532725\n                with open(sample[\"camera\"][\"left\"][0], 'r') as f:\n                    line = f.readlines()[1]\n                    values = line.strip().split()\n                    frame = int(values[0])\n                    camera_id = int(values[1])\n                    fx = float(values[2])\n                    fy = float(values[3])\n                    cx = float(values[4])\n                    cy = float(values[5])\n                    # Construct the intrinsic matrix\n                    camera_left_K = torch.tensor([[fx, 0, cx],\n                                      [0, fy, cy],\n                                      [0, 0, 1]], dtype=torch.float32)\n                depth2disp_scale = camera_left_K[0, 0] * baseline\n\n                with open(sample[\"camera\"][\"left\"][0].replace(\"intrinsic.txt\", \"extrinsic.txt\"), 'r') as f:\n                    line = f.readlines()[1]\n                    values = line.strip().split()\n                    frame = int(values[0])\n                    camera_id = int(values[1])\n                    # Extract rotation (3x3) and translation (3x1)\n                    camera_left_R = np.array([\n                        [float(values[2]), float(values[3]), float(values[4])],\n                        [float(values[6]), float(values[7]), float(values[8])],\n                        [float(values[10]), float(values[11]), float(values[12])]\n                    ], dtype=np.float32)\n\n                    camera_left_T = np.array([\n                        float(values[5]),\n                        float(values[9]),\n                        float(values[13])\n                    ], dtype=np.float32)\n\n            # SouthKensington\n            elif sample[\"camera\"][\"left\"][0][-3:] == \"txt\":\n                camera_left_K, frames = self.parse_txt_file(sample[\"camera\"][\"left\"][0])\n                fix_baseline = 0.12\n                camera_left_R = frames[0]['R']\n                camera_left_T = frames[0]['T']\n                focal_length_px = camera_left_K[0][0]\n                depth2disp_scale = focal_length_px * fix_baseline\n            else:\n                raise ValueError(\"Other format camera is not implemented\")\n\n            output_tensor[\"depth2disp_scale\"] = [depth2disp_scale]\n            output_tensor[\"RTK\"] = [camera_left_R, camera_left_T, camera_left_K]\n\n        for i in range(sample_size):\n            for cam in [\"left\", \"right\"]:\n                if \"mask\" in sample and cam in sample[\"mask\"]:\n                    mask 
= frame_utils.read_gen(sample[\"mask\"][cam][i])\n                    mask = np.array(mask) / 255.0\n                    output_tensor[\"mask\"][i].append(mask)\n\n                if \"viewpoint\" in sample and cam in sample[\"viewpoint\"]:\n                    viewpoint = self._get_pytorch3d_camera(\n                        sample[\"viewpoint\"][cam][i],\n                        sample[\"metadata\"][cam][i][1],\n                        scale=1.0,\n                    )\n                    output_tensor[\"viewpoint\"][i].append(viewpoint)\n                if \"camera\" in sample:\n                    # InfinigenSV\n                    if sample[\"camera\"][\"left\"][0][-3:] == \"npz\" and cam in sample[\"camera\"]:\n                        # Note that the K, R, T is based on Blender world Matrix\n                        camera = np.load(sample[\"camera\"][cam][i])\n                        camera_K = camera['K']\n                        camera_T = camera['T'][:3, 3]\n                        camera_R = camera['T'][:3, :3]\n                        viewpoint = self._get_pytorch3d_camera_from_blender(\n                            camera_R, camera_T, camera_K,\n                            sample[\"metadata\"][cam][i][1],\n                            scale=1.0,\n                        )\n                        output_tensor[\"viewpoint\"][i].append(viewpoint)\n\n                    # Sintel\n                    elif sample[\"camera\"][\"left\"][0][-3:] == \"cam\" and cam in sample[\"camera\"]:\n                        f = open(sample[\"camera\"][\"left\"][0], 'rb')\n                        check = np.fromfile(f, dtype=np.float32, count=1)[0]\n                        assert check == TAG_FLOAT, ' cam_read:: Wrong tag in flow file (should be: {0}, is: {1}). Big-endian machine? 
'.format(\n                            TAG_FLOAT, check)\n                        camera_K = np.fromfile(f, dtype='float64', count=9).reshape((3, 3))\n                        camera_RT = np.fromfile(f, dtype='float64', count=12).reshape((3, 4))\n                        camera_T = camera_RT[:3, 3]\n                        camera_R = camera_RT[:3, :3]\n                        viewpoint = self._get_pytorch3d_camera_from_blender(\n                            camera_R, camera_T, camera_K,\n                            sample[\"metadata\"][cam][i][1],\n                            scale=1.0,\n                        )\n                        output_tensor[\"viewpoint\"][i].append(viewpoint)\n\n                    # TartanAir\n                    elif sample[\"camera\"][\"left\"][0][-13:] == \"pose_left.txt\":\n                        fx, fy, cx, cy = 320.0, 320.0, 320.0, 240.0\n                        # Build the 3x3 intrinsic matrix\n                        camera_left_K = np.array([\n                            [fx, 0, cx],\n                            [0, fy, cy],\n                            [0, 0, 1]\n                        ])\n                        focal_length_px = camera_left_K[0][0]\n                        fix_baseline = 0.25\n                        depth2disp_scale = focal_length_px * fix_baseline\n                        camera_left_RT = self.load_tartanair_pose(sample[\"camera\"][\"left\"][0], index=i)\n                        camera_left_T = camera_left_RT[:3, 3]\n                        camera_left_R = camera_left_RT[:3, :3]\n\n                    # Spring\n                    elif any(filename in path for path in sample[\"camera\"][\"left\"] for filename\n                             in [\"focaldistance.txt\", \"extrinsics.txt\", \"intrinsics.txt\"]) and cam in sample[\"camera\"]:\n\n                        for path in sample[\"camera\"][\"left\"]:\n                            if \"intrinsics.txt\" in path:\n                                intrinsics_path = path\n                            elif \"extrinsics.txt\" in path:\n                                extrinsics_path = path\n\n                        fx, fy, cx, cy = np.loadtxt(intrinsics_path)[0]\n                        # Build the 3x3 intrinsic matrix\n                        camera_K = np.array([\n                            [fx, 0, cx],\n                            [0, fy, cy],\n                            [0, 0, 1]\n                        ])\n                        focal_length_px = camera_left_K[0][0]\n                        fix_baseline = 0.065  # From the dataset website, the baseline of the cameras = 6.5cm = 0.065m\n                        depth2disp_scale = focal_length_px * fix_baseline\n                        camera_RT = np.loadtxt(extrinsics_path).reshape(-1, 4, 4)[i]\n                        camera_T = camera_RT[:3, 3]\n                        camera_R = camera_RT[:3, :3]\n                        viewpoint = self._get_pytorch3d_camera_from_blender(\n                            camera_R, camera_T, camera_K,\n                            sample[\"metadata\"][cam][i][1],\n                            scale=1.0,\n                        )\n                        output_tensor[\"viewpoint\"][i].append(viewpoint)\n\n                    # KITTI Depth\n                    elif sample[\"camera\"][\"left\"][0][-20:] == \"calib_cam_to_cam.txt\":\n                        calib_data = {}\n                        with open(sample[\"camera\"][\"left\"][0], 'r') as f:\n                            for line in f:\n         
                                line = line.strip()\n                                if not line:\n                                    continue\n                                key, value = line.split(':', 1)\n                                calib_data[key.strip()] = value.strip()\n\n                        P_key = 'P_rect_02'\n                        if P_key in calib_data:\n                            P_values = np.array([float(x) for x in calib_data[P_key].split()])\n                            P_matrix = P_values.reshape(3, 4)\n                        else:\n                            raise KeyError(\"Projection matrix for camera not found in calibration data\")\n                        focal_length_px = P_matrix[0, 0]\n\n                        T_key1 = 'T_02'\n                        T_key2 = 'T_03'\n                        if T_key1 in calib_data and T_key2 in calib_data:\n                            T1 = np.array([float(x) for x in calib_data[T_key1].split()])\n                            T2 = np.array([float(x) for x in calib_data[T_key2].split()])\n                            baseline = np.linalg.norm(T1 - T2)\n                        else:\n                            raise KeyError(\"Translation vectors for cameras not found in calibration data\")\n\n                        R_key1 = 'R_rect_02'\n                        R_key2 = 'R_rect_03'\n                        if R_key1 in calib_data and R_key2 in calib_data:\n                            R1 = np.array([float(x) for x in calib_data[R_key1].split()]).reshape(3, 3)\n                            R2 = np.array([float(x) for x in calib_data[R_key2].split()]).reshape(3, 3)\n                        else:\n                            raise KeyError(\"Rotation matrices for cameras not found in calibration data\")\n\n                        depth2disp_scale = focal_length_px * baseline\n                        camera_K = P_matrix[:, :3]\n                        camera_T = T1\n                        camera_R = R1\n                        viewpoint = self._get_pytorch3d_camera_from_blender(\n                            camera_R, camera_T, camera_K,\n                            sample[\"metadata\"][cam][i][1],\n                            scale=1.0,\n                        )\n                        output_tensor[\"viewpoint\"][i].append(viewpoint)\n\n                    # VKITTI2\n                    elif sample[\"camera\"][\"left\"][0][-13:] == \"intrinsic.txt\":\n                        with open(sample[\"camera\"][\"left\"][0], 'r') as f:\n                            line = f.readlines()[1+i]\n                            values = line.strip().split()\n                            frame = int(values[0])\n                            camera_id = int(values[1])\n                            fx = float(values[2])\n                            fy = float(values[3])\n                            cx = float(values[4])\n                            cy = float(values[5])\n                            # Construct the intrinsic matrix\n                            camera_K = torch.tensor([[fx, 0, cx],\n                                                          [0, fy, cy],\n                                                          [0, 0, 1]], dtype=torch.float32)\n                        with open(sample[\"camera\"][\"left\"][0].replace(\"intrinsic.txt\", \"extrinsic.txt\"), 'r') as f:\n                            line = f.readlines()[1+i]\n                            values = line.strip().split()\n                            frame = int(values[0])\n                            camera_id = int(values[1])\n                            # Extract rotation (3x3) and translation (3x1)\n
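                            # extrinsic.txt row: frame, cameraID, then a row-major 4x4 world-to-camera matrix [R|t]\n                      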
      camera_R = np.array([\n                                [float(values[2]), float(values[3]), float(values[4])],\n                                [float(values[6]), float(values[7]), float(values[8])],\n                                [float(values[10]), float(values[11]), float(values[12])]\n                            ], dtype=np.float32)\n\n                            camera_T = np.array([\n                                float(values[5]),\n                                float(values[9]),\n                                float(values[13])\n                            ], dtype=np.float32)\n                        viewpoint = self._get_pytorch3d_camera_from_blender(\n                            camera_R, camera_T, camera_K,\n                            sample[\"metadata\"][cam][i][1],\n                            scale=1.0,\n                        )\n                        output_tensor[\"viewpoint\"][i].append(viewpoint)\n\n                    # SouthKensington\n                    elif sample[\"camera\"][\"left\"][0][-3:] == \"txt\" and cam in sample[\"camera\"]:\n                        camera_left_K, frames = self.parse_txt_file(sample[\"camera\"][\"left\"][0])\n\n                        camera_K = camera_left_K\n                        camera_R = frames[i]['R']\n                        camera_T = frames[i]['T']\n                        viewpoint = self._get_pytorch3d_camera_from_blender(\n                            camera_R, camera_T, camera_K,\n                            sample[\"metadata\"][cam][i][1],\n                            scale=1.0,\n                        )\n                        output_tensor[\"viewpoint\"][i].append(viewpoint)\n\n                if \"metadata\" in sample and cam in sample[\"metadata\"]:\n                    metadata = sample[\"metadata\"][cam][i]\n                    output_tensor[\"metadata\"][i].append(metadata)\n\n                if cam in sample[\"image\"]:\n                    img = frame_utils.read_gen(sample[\"image\"][cam][i])\n                    img = np.array(img).astype(np.uint8)\n\n                    # grayscale images\n                    if len(img.shape) == 2:\n                        img = np.tile(img[..., None], (1, 1, 3))\n                    else:\n                        img = img[..., :3]\n                    output_tensor[\"img\"][i].append(img)\n\n                if cam in sample[\"disparity\"]:\n                    disp = self.disparity_reader(sample[\"disparity\"][cam][i])\n                    if isinstance(disp, tuple):\n                        disp, valid_disp = disp\n                    else:\n                        valid_disp = disp < 512\n                    disp = np.array(disp).astype(np.float32)\n                    disp = np.stack([-disp, np.zeros_like(disp)], axis=-1)\n                    # disp = np.stack([disp, np.zeros_like(disp)], axis=-1)\n\n                    output_tensor[\"disp\"][i].append(disp)\n                    output_tensor[\"valid_disp\"][i].append(valid_disp)\n\n                elif \"depth\" in sample and cam in sample[\"depth\"]:\n                    depth = self.depth_reader(sample[\"depth\"][cam][i])\n                    depth_mask = depth < self.depth_eps\n                    depth[depth_mask] = self.depth_eps\n                    disp = depth2disp_scale / depth\n                    disp[depth_mask] = 0\n                    valid_disp = (disp < 512) * (1 - depth_mask)\n                    disp = np.array(disp).astype(np.float32)\n                    disp = np.stack([-disp, 
np.zeros_like(disp)], axis=-1)\n                    output_tensor[\"disp\"][i].append(disp)\n                    output_tensor[\"valid_disp\"][i].append(valid_disp)\n\n        return output_tensor\n\n    def __getitem__(self, index):\n        im_tensor = {}\n        sample = self.sample_list[index]\n        if self.is_test:\n            sample_size = len(sample[\"image\"][\"left\"])\n            im_tensor[\"img\"] = [[] for _ in range(sample_size)]\n            for i in range(sample_size):\n                for cam in [\"left\", \"right\"]:\n                    img = frame_utils.read_gen(sample[\"image\"][cam][i])\n                    img = np.array(img).astype(np.uint8)[..., :3]\n                    img = torch.from_numpy(img).permute(2, 0, 1).float()\n                    im_tensor[\"img\"][i].append(img)\n            im_tensor[\"img\"] = torch.stack([torch.stack(frames) for frames in im_tensor[\"img\"]])\n            return im_tensor, self.extra_info[index]\n\n        index = index % len(self.sample_list)\n        try:\n            output_tensor = self._get_output_tensor(sample)\n        except Exception:\n            logging.warning(f\"Exception in loading sample {index}!\")\n            index = np.random.randint(len(self.sample_list))\n            logging.info(f\"New index is {index}\")\n            sample = self.sample_list[index]\n            output_tensor = self._get_output_tensor(sample)\n\n        sample_size = len(sample[\"image\"][\"left\"])\n        if self.augmentor is not None:\n            if self.sparse:\n                output_tensor[\"img\"], output_tensor[\"disp\"], output_tensor[\"valid_disp\"] = self.augmentor(\n                    output_tensor[\"img\"], output_tensor[\"disp\"], output_tensor[\"valid_disp\"]\n                )\n            else:\n                output_tensor[\"img\"], output_tensor[\"disp\"] = self.augmentor(\n                    output_tensor[\"img\"], output_tensor[\"disp\"]\n                )\n        for i in range(sample_size):\n            for cam in (0, 1):\n                if cam < len(output_tensor[\"img\"][i]):\n                    img = (\n                        torch.from_numpy(output_tensor[\"img\"][i][cam])\n                        .permute(2, 0, 1)\n                        .float()\n                    )\n                    if self.img_pad is not None:\n                        padH, padW = self.img_pad\n                        img = F.pad(img, [padW] * 2 + [padH] * 2)\n                    output_tensor[\"img\"][i][cam] = img\n\n                if cam < len(output_tensor[\"disp\"][i]):\n                    disp = (\n                        torch.from_numpy(output_tensor[\"disp\"][i][cam])\n                        .permute(2, 0, 1)\n                        .float()\n                    )\n\n                    if self.sparse:\n                        valid_disp = torch.from_numpy(\n                            output_tensor[\"valid_disp\"][i][cam]\n                        )\n                    else:\n                        valid_disp = (\n                            (disp[0].abs() < 512)\n                            & (disp[1].abs() < 512)\n                            & (disp[0].abs() != 0)\n                        )\n                    disp = disp[:1]\n\n                    output_tensor[\"disp\"][i][cam] = disp\n                    output_tensor[\"valid_disp\"][i][cam] = valid_disp.float()\n\n                if \"mask\" in output_tensor and cam < len(output_tensor[\"mask\"][i]):\n                    mask = torch.from_numpy(output_tensor[\"mask\"][i][cam]).float()\n       
             output_tensor[\"mask\"][i][cam] = mask\n\n                if \"viewpoint\" in output_tensor and cam < len(\n                    output_tensor[\"viewpoint\"][i]\n                ):\n\n                    viewpoint = output_tensor[\"viewpoint\"][i][cam]\n                    output_tensor[\"viewpoint\"][i][cam] = viewpoint\n\n        res = {}\n        if \"viewpoint\" in output_tensor and self.split != \"train\":\n            res[\"viewpoint\"] = output_tensor[\"viewpoint\"]\n        if \"metadata\" in output_tensor and self.split != \"train\":\n            res[\"metadata\"] = output_tensor[\"metadata\"]\n        if \"depth2disp_scale\" in output_tensor and self.split != \"train\":\n            res[\"depth2disp_scale\"] = output_tensor[\"depth2disp_scale\"]\n        if \"RTK\" in output_tensor and self.split != \"train\":\n            res[\"RTK\"] = output_tensor[\"RTK\"]\n\n        for k, v in output_tensor.items():\n            if k != \"viewpoint\" and k != \"metadata\" and k != \"depth2disp_scale\" and k != \"RTK\":\n                for i in range(len(v)):\n                    if len(v[i]) > 0:\n                        v[i] = torch.stack(v[i])\n                if len(v) > 0 and (len(v[0]) > 0):\n                    res[k] = torch.stack(v)\n        return res\n\n    def __mul__(self, v):\n        copy_of_self = copy.deepcopy(self)\n        copy_of_self.sample_list = v * copy_of_self.sample_list\n        copy_of_self.extra_info = v * copy_of_self.extra_info\n        return copy_of_self\n\n    def __len__(self):\n        return len(self.sample_list)\n\n\nclass DynamicReplicaDataset(StereoSequenceDataset):\n    def __init__(\n        self,\n        aug_params=None,\n        root=\"./data/datasets/dynamic_replica_data\",\n        split=\"train\",\n        sample_len=-1,\n        only_first_n_samples=-1,\n    ):\n        super(DynamicReplicaDataset, self).__init__(aug_params)\n        self.root = root\n        self.sample_len = sample_len\n        self.split = split\n\n        frame_annotations_file = f\"frame_annotations_{split}.jgz\"\n\n        with gzip.open(\n            osp.join(root, split, frame_annotations_file), \"rt\", encoding=\"utf8\"\n        ) as zipfile:\n            frame_annots_list = load_dataclass(\n                zipfile, List[DynamicReplicaFrameAnnotation]\n            )\n        seq_annot = defaultdict(lambda: defaultdict(list))\n        for frame_annot in frame_annots_list:\n            seq_annot[frame_annot.sequence_name][frame_annot.camera_name].append(\n                frame_annot\n            )\n        for seq_name in seq_annot.keys():\n            try:\n                filenames = defaultdict(lambda: defaultdict(list))\n                for cam in [\"left\", \"right\"]:\n                    for framedata in seq_annot[seq_name][cam]:\n                        im_path = osp.join(root, split, framedata.image.path)\n                        depth_path = osp.join(root, split, framedata.depth.path)\n                        mask_path = osp.join(root, split, framedata.mask.path)\n\n                        assert os.path.isfile(im_path), im_path\n                        if self.split == 'train':\n                            assert os.path.isfile(depth_path), depth_path\n                        assert os.path.isfile(mask_path), mask_path\n\n                        filenames[\"image\"][cam].append(im_path)\n                        if os.path.isfile(depth_path):\n                            filenames[\"depth\"][cam].append(depth_path)\n                        
filenames[\"mask\"][cam].append(mask_path)\n\n                        filenames[\"viewpoint\"][cam].append(framedata.viewpoint)\n                        filenames[\"metadata\"][cam].append(\n                            [framedata.sequence_name, framedata.image.size]\n                        )\n\n                        for k in filenames.keys():\n                            assert (\n                                len(filenames[k][cam])\n                                == len(filenames[\"image\"][cam])\n                                > 0\n                            ), framedata.sequence_name\n\n                seq_len = len(filenames[\"image\"][cam])\n\n                print(\"seq_len\", seq_name, seq_len)\n                if split == \"train\":\n                    for ref_idx in range(0, seq_len, 3):\n                        step = 1 if self.sample_len == 1 else np.random.randint(1, 6)\n                        if ref_idx + step * self.sample_len < seq_len:\n                            sample = defaultdict(lambda: defaultdict(list))\n                            for cam in [\"left\", \"right\"]:\n                                for idx in range(\n                                    ref_idx, ref_idx + step * self.sample_len, step\n                                ):\n                                    for k in filenames.keys():\n                                        if \"mask\" not in k:\n                                            sample[k][cam].append(\n                                                filenames[k][cam][idx]\n                                            )\n\n                            self.sample_list.append(sample)\n                else:\n                    step = self.sample_len if self.sample_len > 0 else seq_len\n                    counter = 0\n                    for ref_idx in range(0, seq_len, step):\n                        sample = defaultdict(lambda: defaultdict(list))\n                        for cam in [\"left\", \"right\"]:\n                            for idx in range(ref_idx, ref_idx + step):\n                                for k in filenames.keys():\n                                    sample[k][cam].append(filenames[k][cam][idx])\n\n                        self.sample_list.append(sample)\n                        counter += 1\n                        if only_first_n_samples > 0 and counter >= only_first_n_samples:\n                            break\n            except Exception as e:\n                print(e)\n                print(\"Skipping sequence\", seq_name)\n\n        assert len(self.sample_list) > 0, \"No samples found\"\n        print(f\"Added {len(self.sample_list)} from Dynamic Replica {split}\")\n        logging.info(f\"Added {len(self.sample_list)} from Dynamic Replica {split}\")\n\n\nclass InfinigenStereoVideoDataset(StereoSequenceDataset):\n    def __init__(\n        self,\n        aug_params=None,\n        root=\"./data/datasets/InfinigenStereo\",\n        split=\"train\",\n        sample_len=-1,\n        only_first_n_samples=-1,\n    ):\n        super(InfinigenStereoVideoDataset, self).__init__(aug_params)\n        self.root = root\n        self.sample_len = sample_len\n        self.split = split\n\n        sequence = sorted(\n            glob(osp.join(root, self.split, \"*\"))\n        )\n        for i in range(len(sequence)):\n            sequence_name = os.path.basename(sequence[i])\n            try:\n                filenames = defaultdict(lambda: defaultdict(list))\n                for cam in [\"left\", \"right\"]:\n             
       suffix = \"camera_0/\" if cam == \"left\" else \"camera_1/\"\n                    im_path_list = sorted(glob(osp.join(sequence[i], \"frames/Image/\", suffix, \"*.png\")))\n                    depth_path_list = sorted(glob(osp.join(sequence[i], \"frames/Depth/\", suffix, \"*.npy\")))\n                    camera_list = sorted(glob(osp.join(sequence[i], \"frames/camview/\", suffix, \"*.npz\")))\n                    for j in range(len(im_path_list)):\n                        im_path = im_path_list[j]\n                        depth_path = depth_path_list[j]\n                        camera_path = camera_list[j]\n                        assert os.path.isfile(im_path), im_path\n                        assert os.path.isfile(depth_path), depth_path\n                        filenames[\"image\"][cam].append(im_path)\n                        filenames[\"depth\"][cam].append(depth_path)\n                        filenames[\"camera\"][cam].append(camera_path)\n                        filenames[\"metadata\"][cam].append([sequence_name , (720,1280)])\n\n                        for k in filenames.keys():\n                            assert (\n                                    len(filenames[k][cam])\n                                    == len(filenames[\"image\"][cam])\n                                    > 0\n                            ), sequence_name\n                seq_len = len(filenames[\"image\"][cam])\n\n                print(\"seq_len\", sequence_name, seq_len)\n                if self.split == \"train\":\n                    for ref_idx in range(0, seq_len, 3):\n                        step = 1 if self.sample_len == 1 else np.random.randint(1, 6)\n                        if ref_idx + step * self.sample_len < seq_len:\n                            sample = defaultdict(lambda: defaultdict(list))\n                            for cam in [\"left\", \"right\"]:\n                                for idx in range(\n                                    ref_idx, ref_idx + step * self.sample_len, step\n                                ):\n                                    for k in filenames.keys():\n                                        if \"mask\" not in k:\n                                            sample[k][cam].append(\n                                                filenames[k][cam][idx]\n                                            )\n\n                            self.sample_list.append(sample)\n                else:\n                    step = self.sample_len if (self.sample_len > 0) and (self.sample_len < seq_len) else seq_len\n                    print(\"sample_step\", step)\n                    counter = 0\n                    for ref_idx in range(0, seq_len, step):\n                        sample = defaultdict(lambda: defaultdict(list))\n                        for cam in [\"left\", \"right\"]:\n                            for idx in range(ref_idx, ref_idx + step):\n                                for k in filenames.keys():\n                                    sample[k][cam].append(filenames[k][cam][idx])\n\n                        self.sample_list.append(sample)\n                        counter += 1\n                        if only_first_n_samples > 0 and counter >= only_first_n_samples:\n                            break\n            except Exception as e:\n                print(e)\n                print(\"Skipping sequence\", sequence_name)\n        assert len(self.sample_list) > 0, \"No samples found\"\n        print(f\"Added {len(self.sample_list)} from Infinigen Stereo Video 
{split}\")\n        logging.info(f\"Added {len(self.sample_list)} from Infinigen Stereo Video {split}\")\n\n\nclass SouthKensingtonStereoVideoDataset(StereoSequenceDataset):\n    def __init__(\n        self,\n        aug_params=None,\n        root=\"./data/datasets/SouthKensington/data/\",\n        split=\"test\",\n        subroot=\"\",\n        sample_len=-1,\n        only_first_n_samples=-1,\n    ):\n        super(SouthKensingtonStereoVideoDataset, self).__init__(aug_params)\n        self.root = root\n        self.split = split\n        self.sample_len = sample_len\n\n        sequence = sorted(\n            glob(osp.join(root, \"*\"))\n        )\n        for i in range(len(sequence)):\n            sequence_name = os.path.basename(sequence[i])\n            try:\n                filenames = defaultdict(lambda: defaultdict(list))\n                for cam in [\"left\", \"right\"]:\n                    im_path_list = sorted(glob(osp.join(sequence[i], \"images\", cam, \"*.png\")))\n                    camera_path = glob(osp.join(sequence[i], \"*.txt\"))[0]\n\n                    for j in range(len(im_path_list)):\n                        im_path = im_path_list[j]\n                        assert os.path.isfile(im_path), im_path\n                        filenames[\"image\"][cam].append(im_path)\n                        filenames[\"camera\"][cam].append(camera_path)\n                        filenames[\"metadata\"][cam].append([sequence_name , (720,1280)])\n\n                        for k in filenames.keys():\n                            assert (\n                                    len(filenames[k][cam])\n                                    == len(filenames[\"image\"][cam])\n                                    > 0\n                            ), sequence_name\n                seq_len = len(filenames[\"image\"][cam])\n                print(\"seq_len\", sequence_name, seq_len)\n\n                step = self.sample_len if (self.sample_len > 0) and (self.sample_len < seq_len) else seq_len\n                print(\"sample_step\", step)\n                counter = 0\n                for ref_idx in range(0, seq_len, step):\n                    sample = defaultdict(lambda: defaultdict(list))\n                    for cam in [\"left\", \"right\"]:\n                        for idx in range(ref_idx, ref_idx + step):\n                            for k in filenames.keys():\n                                sample[k][cam].append(filenames[k][cam][idx])\n\n                    self.sample_list.append(sample)\n                    counter += 1\n                    if only_first_n_samples > 0 and counter >= only_first_n_samples:\n                        break\n            except Exception as e:\n                print(e)\n                print(\"Skipping sequence\", sequence_name)\n        assert len(self.sample_list) > 0, \"No samples found\"\n        print(f\"Added {len(self.sample_list)} from SouthKensington Stereo Video\")\n        logging.info(f\"Added {len(self.sample_list)} from SouthKensington Stereo Video\")\n\n\nclass KITTIDepthDataset(StereoSequenceDataset):\n    def __init__(\n        self,\n        aug_params=None,\n        root=\"./data/datasets/\",\n        split=\"train\",\n        sample_len=-1,\n        only_first_n_samples=-1,\n    ):\n        super().__init__(aug_params, sparse=True)\n        # super(KITTIDepthDataset, self).__init__(aug_params)\n        image_root = osp.join(root, \"kitti_depth\", \"input\")\n        gt_root = osp.join(root, \"kitti_depth\", \"gt_depth\")\n        self.sample_len = 
        self.split = split\n        # Following CODD: https://github.com/facebookresearch/CODD\n        val_split = ['2011_10_03_drive_0042_sync']  # 1 scene\n        test_split = ['2011_09_26_drive_0002_sync', '2011_09_26_drive_0005_sync',\n                      '2011_09_26_drive_0013_sync', '2011_09_26_drive_0020_sync',\n                      '2011_09_26_drive_0023_sync', '2011_09_26_drive_0036_sync',\n                      '2011_09_26_drive_0079_sync', '2011_09_26_drive_0095_sync',\n                      '2011_09_26_drive_0113_sync', '2011_09_28_drive_0037_sync',\n                      '2011_09_29_drive_0026_sync', '2011_09_30_drive_0016_sync',\n                      '2011_10_03_drive_0047_sync']  # 13 scenes\n\n        sequence_root = sorted(glob(osp.join(gt_root, \"*\")))\n        train_list = []\n        val_list = []\n        test_list = []\n        for i in range(len(sequence_root)):\n            sequence_name = os.path.basename(os.path.normpath(sequence_root[i]))\n            if sequence_name in test_split:\n                test_list.append(sequence_root[i])\n            elif sequence_name in val_split:\n                val_list.append(sequence_root[i])\n            else:\n                train_list.append(sequence_root[i])\n\n        if self.split == \"train\":\n            sequence_split = train_list\n        elif self.split == \"val\":\n            sequence_split = val_list\n        elif self.split == \"test\":\n            sequence_split = test_list\n        else:\n            raise ValueError(f\"Wrong split: {self.split}\")\n\n        for i in range(len(sequence_split)):\n            sequence_name = os.path.basename(os.path.normpath(sequence_split[i]))\n            sequence_camera = osp.join(image_root, sequence_name[:10], \"calib_cam_to_cam.txt\")\n            try:\n                filenames = defaultdict(lambda: defaultdict(list))\n                for cam in [\"left\", \"right\"]:\n                    suffix = \"image_02/\" if cam == \"left\" else \"image_03/\"\n                    depth_path_list = sorted(\n                        glob(osp.join(gt_root, sequence_name, \"proj_depth\", \"groundtruth\", suffix, \"*.png\")))\n                    for j in range(len(depth_path_list)):\n                        depth_path = depth_path_list[j]\n                        assert os.path.isfile(depth_path), depth_path\n                        filenames[\"depth\"][cam].append(depth_path)\n\n                        # find the corresponding images\n                        im_name = os.path.basename(os.path.normpath(depth_path))\n                        im_path = osp.join(image_root, sequence_name[:10], sequence_name, suffix, \"data\", im_name)\n                        assert os.path.isfile(im_path), im_path\n                        filenames[\"image\"][cam].append(im_path)\n                        filenames[\"camera\"][cam].append(sequence_camera)\n                        filenames[\"metadata\"][cam].append([sequence_name, (370, 1224)])\n                        for k in filenames.keys():\n                            assert (\n                                    len(filenames[k][cam])\n                                    == len(filenames[\"depth\"][cam])\n                                    > 0\n                            ), sequence_name\n                seq_len = len(filenames[\"image\"][cam])\n                print(\"seq_len\", sequence_name, seq_len)\n                if self.split == \"train\":\n                    for ref_idx in range(0, seq_len, 3):\n
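                        # draw a random temporal stride in [1, 5] so training clips span several time scales\n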
                        step = 1 if self.sample_len == 1 else np.random.randint(1, 6)\n                        if ref_idx + step * self.sample_len < seq_len:\n                            sample = defaultdict(lambda: defaultdict(list))\n                            for cam in [\"left\", \"right\"]:\n                                for idx in range(\n                                        ref_idx, ref_idx + step * self.sample_len, step\n                                ):\n                                    for k in filenames.keys():\n                                        if \"mask\" not in k:\n                                            sample[k][cam].append(\n                                                filenames[k][cam][idx]\n                                            )\n                            self.sample_list.append(sample)\n                else:\n                    step = self.sample_len if (self.sample_len > 0) and (self.sample_len < seq_len) else seq_len\n                    print(\"sample_step\", step)\n                    counter = 0\n                    for ref_idx in range(0, seq_len, step):\n                        sample = defaultdict(lambda: defaultdict(list))\n                        for cam in [\"left\", \"right\"]:\n                            for idx in range(ref_idx, ref_idx + step):\n                                for k in filenames.keys():\n                                    sample[k][cam].append(filenames[k][cam][idx])\n\n                        self.sample_list.append(sample)\n                        counter += 1\n                        if only_first_n_samples > 0 and counter >= only_first_n_samples:\n                            break\n            except Exception as e:\n                print(e)\n                print(\"Skipping sequence\", sequence_name)\n        assert len(self.sample_list) > 0, \"No samples found\"\n        print(f\"Added {len(self.sample_list)} from KITTI Depth {split}\")\n        logging.info(f\"Added {len(self.sample_list)} from KITTI Depth {split}\")\n\n\ndef split_train_valid(path_list, valid_keywords):\n    path_list_init = path_list\n    for kw in valid_keywords:\n        path_list = list(filter(lambda s: kw not in s, path_list))\n    train_path_list = sorted(path_list)\n    valid_path_list = sorted(list(set(path_list_init) - set(train_path_list)))\n    return train_path_list, valid_path_list\n\n\nclass TartanAirDataset(StereoSequenceDataset):\n    def __init__(\n        self,\n        aug_params=None,\n        root=\"./data/datasets/TartanAir/\",\n        split=\"train\",\n        sample_len=-1,\n        only_first_n_samples=-1,\n    ):\n        super().__init__(aug_params, sparse=False)\n        self.sample_len = sample_len\n        self.split = split\n\n        # Each entry is (scene, motion, part)\n        test_entries = [\n            (\"abandonedfactory\", \"Easy\", \"P002\"),\n            (\"abandonedfactory\", \"Hard\", \"P002\"),\n            (\"amusement\", \"Easy\", \"P007\"),\n            (\"amusement\", \"Hard\", \"P007\"),\n            (\"carwelding\", \"Hard\", \"P003\"),\n            (\"endofworld\", \"Easy\", \"P006\"),\n            (\"endofworld\", \"Hard\", \"P006\"),\n            (\"gascola\", \"Easy\", \"P001\"),\n            (\"gascola\", \"Hard\", \"P001\"),\n            (\"hospital\", \"Hard\", \"P042\"),\n            (\"office\", \"Easy\", \"P006\"),\n            (\"office\", \"Hard\", \"P006\"),\n            (\"office2\", \"Easy\", \"P004\"),\n            (\"office2\", \"Hard\", \"P004\"),\n            (\"oldtown\", \"Hard\", \"P006\"),\n
\"Hard\", \"P006\"),\n            (\"soulcity\", \"Easy\", \"P008\"),\n            (\"soulcity\", \"Hard\", \"P008\"),\n        ]\n\n        scene_root = sorted(glob(osp.join(root, \"*\")))\n\n        sequence_root_list = []\n        test_set = []\n        train_set = []\n        for i in range(len(scene_root)):\n            sequence_root_list += sorted(glob(osp.join(scene_root[i], \"Easy\", \"*\"))) + sorted(glob(osp.join(scene_root[i], \"Hard\", \"*\")))\n\n        for path in sequence_root_list:\n            parts = path.split(\"/\")\n            if len(parts) < 5:\n                continue  # skip malformed paths\n\n            scene, motion, part = parts[-3], parts[-2], parts[-1]\n            if (scene, motion, part) in test_entries:\n                test_set.append(path)\n            else:\n                train_set.append(path)\n\n        if self.split == \"train\":\n            sequence_root_list = train_set\n        elif self.split == \"test\":\n            sequence_root_list = test_set\n        else:\n            raise KeyError(f\"Wrong Split!\")\n\n        for i in range(len(sequence_root_list)):\n            filenames = defaultdict(lambda: defaultdict(list))\n            sequence_root = sequence_root_list[i]\n            parts = os.path.normpath(sequence_root).split(os.sep)\n            sequence_name = \"_\".join(parts[-3:])\n            try:\n                for cam in ['left', 'right']:\n                    depth_path_list = sorted(glob(osp.join(sequence_root, \"depth_left/\", \"*.npy\")))\n                    im_path_list = sorted(glob(osp.join(sequence_root, f\"image_{cam}/\", \"*.png\")))\n                    pose_path = os.path.join(sequence_root, f\"pose_{cam}.txt\")\n                    assert len(depth_path_list) == len(im_path_list), [len(depth_path_list), len(im_path_list)]\n                    for j in range(len(depth_path_list)):\n                        depth_path = depth_path_list[j]\n                        assert os.path.isfile(depth_path), depth_path\n                        filenames[\"depth\"][cam].append(depth_path)\n                        im_path = im_path_list[j]\n                        assert os.path.isfile(im_path), im_path\n                        filenames[\"image\"][cam].append(im_path)\n                        filenames[\"camera\"][cam].append(pose_path)\n                        filenames[\"metadata\"][cam].append([sequence_name, (480,640)])\n                        for k in filenames.keys():\n                            assert (\n                                    len(filenames[k][cam])\n                                    == len(filenames[\"depth\"][cam])\n                                    > 0\n                            ), sequence_name\n                seq_len = len(filenames[\"image\"][cam])\n                print(\"seq_len\", sequence_name, seq_len)\n\n                if self.split == \"train\":\n                    for ref_idx in range(0, seq_len, 3):\n                        step = 1 if self.sample_len == 1 else np.random.randint(1, 6)\n                        if ref_idx + step * self.sample_len < seq_len:\n                            sample = defaultdict(lambda: defaultdict(list))\n                            for cam in [\"left\", \"right\"]:\n                                for idx in range(\n                                        ref_idx, ref_idx + step * self.sample_len, step\n                                ):\n                                    for k in filenames.keys():\n                                        if \"mask\" not 
                                            sample[k][cam].append(\n                                                filenames[k][cam][idx]\n                                            )\n                            self.sample_list.append(sample)\n                else:\n                    step = self.sample_len if (self.sample_len > 0) and (self.sample_len < seq_len) else seq_len\n                    print(\"sample_step\", step)\n                    counter = 0\n                    for ref_idx in range(0, seq_len, step):\n                        sample = defaultdict(lambda: defaultdict(list))\n                        for cam in [\"left\", \"right\"]:\n                            for idx in range(ref_idx, ref_idx + step):\n                                for k in filenames.keys():\n                                    sample[k][cam].append(filenames[k][cam][idx])\n\n                        self.sample_list.append(sample)\n                        counter += 1\n                        if only_first_n_samples > 0 and counter >= only_first_n_samples:\n                            break\n            except Exception as e:\n                print(e)\n                print(\"Skipping sequence\", sequence_name)\n        assert len(self.sample_list) > 0, \"No samples found\"\n        print(f\"Added {len(self.sample_list)} from TartanAir {split}\")\n        logging.info(f\"Added {len(self.sample_list)} from TartanAir {split}\")\n\n\nclass VKITTI2Dataset(StereoSequenceDataset):\n    def __init__(\n        self,\n        aug_params=None,\n        root=\"./data/datasets/vkitti2/\",\n        split=\"train\",\n        sample_len=-1,\n        only_first_n_samples=-1,\n    ):\n        super().__init__(aug_params, sparse=False)\n        self.sample_len = sample_len\n        self.split = split\n        if self.split == \"train\":\n            sequence_name_list = []\n            scenes = ['Scene01', 'Scene02', 'Scene06', 'Scene18', 'Scene20']\n            variations = ['15-deg-left', '15-deg-right', '30-deg-left', '30-deg-right',\n                          'clone', 'fog', 'morning', 'overcast', 'rain', 'sunset']\n            for scene in scenes:\n                for variation in variations:\n                    sequence_name = f\"{scene}_{variation}\"\n                    sequence_name_list.append(sequence_name)\n        elif self.split == \"test\":\n            sequence_name_list = [\"Scene01_15-deg-left\", \"Scene02_30-deg-right\", \"Scene06_fog\", \"Scene18_morning\", \"Scene20_rain\"]\n        else:\n            raise KeyError(f\"Wrong split: {self.split}\")\n\n        for i in range(len(sequence_name_list)):\n            filenames = defaultdict(lambda: defaultdict(list))\n            sequence_name = sequence_name_list[i]\n            scene, variation = sequence_name.split(\"_\")\n            try:\n                for cam in [('left', 0), ('right', 1)]:\n                    depth_path_list = sorted(glob(osp.join(root, f\"{scene}/{variation}/frames/depth/Camera_{cam[1]}/\", \"*.png\")))\n                    im_path_list = sorted(glob(osp.join(root, f\"{scene}/{variation}/frames/rgb/Camera_{cam[1]}/\", \"*.jpg\")))\n                    intrinsic_path = os.path.join(root, f\"{scene}/{variation}/intrinsic.txt\")\n                    assert len(depth_path_list) == len(im_path_list), [len(depth_path_list), len(im_path_list)]\n                    for j in range(len(depth_path_list)):\n                        depth_path = depth_path_list[j]\n                        assert os.path.isfile(depth_path), depth_path\n
                filenames[\"depth\"][cam[0]].append(depth_path)\n                        im_path = im_path_list[j]\n                        assert os.path.isfile(im_path), im_path\n                        filenames[\"image\"][cam[0]].append(im_path)\n                        filenames[\"camera\"][cam[0]].append(intrinsic_path)\n                        filenames[\"metadata\"][cam[0]].append([sequence_name, (375,1242)])\n                        for k in filenames.keys():\n                            assert (\n                                    len(filenames[k][cam[0]])\n                                    == len(filenames[\"depth\"][cam[0]])\n                                    > 0\n                            ), sequence_name\n                seq_len = len(filenames[\"image\"][cam[0]])\n                print(\"seq_len\", sequence_name, seq_len)\n\n                if self.split == \"train\":\n                    for ref_idx in range(0, seq_len, 3):\n                        step = 1 if self.sample_len == 1 else np.random.randint(1, 6)\n                        if ref_idx + step * self.sample_len < seq_len:\n                            sample = defaultdict(lambda: defaultdict(list))\n                            for cam in [\"left\", \"right\"]:\n                                for idx in range(\n                                        ref_idx, ref_idx + step * self.sample_len, step\n                                ):\n                                    for k in filenames.keys():\n                                        if \"mask\" not in k:\n                                            sample[k][cam].append(\n                                                filenames[k][cam][idx]\n                                            )\n                            self.sample_list.append(sample)\n                else:\n                    step = self.sample_len if (self.sample_len > 0) and (self.sample_len < seq_len) else seq_len\n                    print(\"sample_step\", step)\n                    counter = 0\n                    for ref_idx in range(0, seq_len, step):\n                        sample = defaultdict(lambda: defaultdict(list))\n                        for cam in [\"left\", \"right\"]:\n                            for idx in range(ref_idx, ref_idx + step):\n                                for k in filenames.keys():\n                                    sample[k][cam].append(filenames[k][cam][idx])\n\n                        self.sample_list.append(sample)\n                        counter += 1\n                        if only_first_n_samples > 0 and counter >= only_first_n_samples:\n                            break\n            except Exception as e:\n                print(e)\n                print(\"Skipping sequence\", sequence_name)\n        assert len(self.sample_list) > 0, \"No samples found\"\n        print(f\"Added {len(self.sample_list)} from  VKITTI2  {split}\")\n        logging.info(f\"Added {len(self.sample_list)} from VKITTI2 {split}\")\n\n\nclass SequenceSpringDataset(StereoSequenceDataset):\n    def __init__(\n        self,\n        aug_params=None,\n        sample_len=-1,\n        root=\"./data/datasets/Spring\",\n    ):\n        super(SequenceSpringDataset, self).__init__(aug_params)\n        self.split = \"test\"\n        self.sample_len = sample_len\n        original_length = len(self.sample_list)\n        image_paths = defaultdict(list)\n        disparity_paths = defaultdict(list)\n        camera_paths = defaultdict(list)\n\n        for cam in [\"left\", \"right\"]:\n      
      image_paths[cam] = sorted(\n                glob(osp.join(root, f\"train_frame_{cam}/*\"))\n            )\n\n        cam = \"left\"\n        disparity_paths[cam] = sorted(\n                glob(osp.join(root, f\"train_disp1_{cam}/*\"))\n            )\n\n        camera_paths[cam] = sorted(\n                glob(osp.join(root, \"train_cam_data/*\"))\n            )\n\n        num_seq = len(image_paths[\"left\"])\n        # for each sequence\n        for seq_idx in range(num_seq):\n            sequence_name = os.path.basename(image_paths[cam][seq_idx])\n            sample = defaultdict(lambda: defaultdict(list))\n            for cam in [\"left\", \"right\"]:\n                sample[\"image\"][cam] = sorted(\n                    glob(osp.join(image_paths[cam][seq_idx], f\"frame_{cam}\", \"*.png\"))\n                )[:sample_len]\n                # for _ in range(len(sample[\"image\"][cam])):\n                for _ in range(sample_len):\n                    sample[\"metadata\"][cam].append([sequence_name, (1080, 1920)])\n\n            cam = \"left\"\n            sample[\"disparity\"][cam] = sorted(\n                glob(osp.join(disparity_paths[cam][seq_idx], f\"disp1_{cam}\", \"*.dsp5\"))\n            )[:sample_len]\n            sample[\"camera\"][cam] = sorted(\n                glob(osp.join(camera_paths[cam][seq_idx], \"cam_data\", \"*.txt\"))\n            )\n            self.sample_list.append(sample)\n            seq_len = len(sample[\"image\"][cam])\n            print(\"seq_len\", sequence_name, seq_len)\n        logging.info(\n            f\"Added {len(self.sample_list) - original_length} from Spring Dataset\"\n        )\n\n\nclass SequenceSceneFlowDataset(StereoSequenceDataset):\n    def __init__(\n        self,\n        aug_params=None,\n        root=\"./data/datasets\",\n        dstype=\"frames_cleanpass\",\n        sample_len=1,\n        things_test=False,\n        add_things=True,\n        add_monkaa=True,\n        add_driving=True,\n    ):\n        super(SequenceSceneFlowDataset, self).__init__(aug_params)\n        self.root = root\n        self.dstype = dstype\n        self.sample_len = sample_len\n        if things_test:\n            self._add_things(\"TEST\")\n        else:\n            if add_things:\n                self._add_things(\"TRAIN\")\n            if add_monkaa:\n                self._add_monkaa()\n            if add_driving:\n                self._add_driving()\n\n    def _add_things(self, split=\"TRAIN\"):\n        \"\"\"Add FlyingThings3D data\"\"\"\n\n        original_length = len(self.sample_list)\n        root = osp.join(self.root, \"FlyingThings3D\")\n        image_paths = defaultdict(list)\n        disparity_paths = defaultdict(list)\n\n        for cam in [\"left\", \"right\"]:\n            image_paths[cam] = sorted(\n                glob(osp.join(root, self.dstype, split, f\"*/*/{cam}/\"))\n            )\n            disparity_paths[cam] = [\n                path.replace(self.dstype, \"disparity\") for path in image_paths[cam]\n            ]\n        # Choose a random subset of 400 images for validation\n        # state = np.random.get_state()\n        # np.random.seed(1000)\n        # val_idxs = set(np.random.permutation(len(image_paths[\"left\"]))[:40])\n        # np.random.set_state(state)\n        # np.random.seed(0)\n        num_seq = len(image_paths[\"left\"])\n        num = 0\n        for seq_idx in range(num_seq):\n            # if (split == \"TEST\" and seq_idx in val_idxs) or (\n            #     split == \"TRAIN\" and not seq_idx in 
val_idxs\n            # ):\n            images, disparities = defaultdict(list), defaultdict(list)\n            for cam in [\"left\", \"right\"]:\n                images[cam] = sorted(\n                    glob(osp.join(image_paths[cam][seq_idx], \"*.png\"))\n                )\n                disparities[cam] = sorted(\n                    glob(osp.join(disparity_paths[cam][seq_idx], \"*.pfm\"))\n                )\n            num = num + len(images[\"left\"])\n            self._append_sample(images, disparities)\n        print(\"num_frames\", num)\n        assert len(self.sample_list) > 0, \"No samples found\"\n        print(\n            f\"Added {len(self.sample_list) - original_length} from FlyingThings {self.dstype}\"\n        )\n        logging.info(\n            f\"Added {len(self.sample_list) - original_length} from FlyingThings {self.dstype}\"\n        )\n\n    def _add_monkaa(self):\n        \"\"\"Add Monkaa data\"\"\"\n\n        original_length = len(self.sample_list)\n        root = osp.join(self.root, \"Monkaa\")\n        image_paths = defaultdict(list)\n        disparity_paths = defaultdict(list)\n\n        for cam in [\"left\", \"right\"]:\n            image_paths[cam] = sorted(glob(osp.join(root, self.dstype, f\"*/{cam}/\")))\n            disparity_paths[cam] = [\n                path.replace(self.dstype, \"disparity\") for path in image_paths[cam]\n            ]\n\n        num_seq = len(image_paths[\"left\"])\n\n        for seq_idx in range(num_seq):\n            images, disparities = defaultdict(list), defaultdict(list)\n            for cam in [\"left\", \"right\"]:\n                images[cam] = sorted(glob(osp.join(image_paths[cam][seq_idx], \"*.png\")))\n                disparities[cam] = sorted(\n                    glob(osp.join(disparity_paths[cam][seq_idx], \"*.pfm\"))\n                )\n\n            self._append_sample(images, disparities)\n\n        assert len(self.sample_list) > 0, \"No samples found\"\n        print(\n            f\"Added {len(self.sample_list) - original_length} from Monkaa {self.dstype}\"\n        )\n        logging.info(\n            f\"Added {len(self.sample_list) - original_length} from Monkaa {self.dstype}\"\n        )\n\n    def _add_driving(self):\n        \"\"\"Add Driving data\"\"\"\n\n        original_length = len(self.sample_list)\n        root = osp.join(self.root, \"Driving\")\n        image_paths = defaultdict(list)\n        disparity_paths = defaultdict(list)\n\n        for cam in [\"left\", \"right\"]:\n            image_paths[cam] = sorted(\n                glob(osp.join(root, self.dstype, f\"*/*/*/{cam}/\"))\n            )\n            disparity_paths[cam] = [\n                path.replace(self.dstype, \"disparity\") for path in image_paths[cam]\n            ]\n\n        num_seq = len(image_paths[\"left\"])\n        for seq_idx in range(num_seq):\n            images, disparities = defaultdict(list), defaultdict(list)\n            for cam in [\"left\", \"right\"]:\n                images[cam] = sorted(glob(osp.join(image_paths[cam][seq_idx], \"*.png\")))\n                disparities[cam] = sorted(\n                    glob(osp.join(disparity_paths[cam][seq_idx], \"*.pfm\"))\n                )\n\n            self._append_sample(images, disparities)\n\n        assert len(self.sample_list) > 0, \"No samples found\"\n        print(\n            f\"Added {len(self.sample_list) - original_length} from Driving {self.dstype}\"\n        )\n        logging.info(\n            f\"Added {len(self.sample_list) - original_length} from 
Driving {self.dstype}\"\n        )\n\n    def _append_sample(self, images, disparities):\n        seq_len = len(images[\"left\"])\n        for ref_idx in range(0, seq_len - self.sample_len):\n            sample = defaultdict(lambda: defaultdict(list))\n            for cam in [\"left\", \"right\"]:\n                for idx in range(ref_idx, ref_idx + self.sample_len):\n                    sample[\"image\"][cam].append(images[cam][idx])\n                    sample[\"disparity\"][cam].append(disparities[cam][idx])\n            self.sample_list.append(sample)\n\n            sample = defaultdict(lambda: defaultdict(list))\n            for cam in [\"left\", \"right\"]:\n                for idx in range(ref_idx, ref_idx + self.sample_len):\n                    sample[\"image\"][cam].append(images[cam][seq_len - idx - 1])\n                    sample[\"disparity\"][cam].append(disparities[cam][seq_len - idx - 1])\n            self.sample_list.append(sample)\n\n\nclass SequenceSintelStereo(StereoSequenceDataset):\n    def __init__(\n        self,\n        dstype=\"clean\",\n        aug_params=None,\n        root=\"./data/datasets\",\n    ):\n        super().__init__(\n            aug_params, sparse=True, reader=frame_utils.readDispSintelStereo\n        )\n        self.dstype = dstype\n        self.split = \"test\"\n        original_length = len(self.sample_list)\n        image_root = osp.join(root, \"sintel_stereo\", \"training\")\n        image_paths = defaultdict(list)\n        disparity_paths = defaultdict(list)\n        camera_paths = defaultdict(list)\n\n        for cam in [\"left\", \"right\"]:\n            image_paths[cam] = sorted(\n                glob(osp.join(image_root, f\"{self.dstype}_{cam}/*\"))\n            )\n\n        cam = \"left\"\n        disparity_paths[cam] = [\n            path.replace(f\"{self.dstype}_{cam}\", \"disparities\")\n            for path in image_paths[cam]\n        ]\n        camera_paths[cam] = [\n            path.replace(f\"{self.dstype}_{cam}\", \"camdata_left\")\n            for path in image_paths[cam]\n        ]\n\n        num_seq = len(image_paths[\"left\"])\n        # for each sequence\n        for seq_idx in range(num_seq):\n            sequence_name = os.path.basename(image_paths[cam][seq_idx])\n            sample = defaultdict(lambda: defaultdict(list))\n            for cam in [\"left\", \"right\"]:\n                sample[\"image\"][cam] = sorted(\n                    glob(osp.join(image_paths[cam][seq_idx], \"*.png\"))\n                )\n                for _ in range(len(sample[\"image\"][cam])):\n                    sample[\"metadata\"][cam].append([sequence_name, (436, 1024)])\n\n            cam = \"left\"\n            sample[\"disparity\"][cam] = sorted(\n                glob(osp.join(disparity_paths[cam][seq_idx], \"*.png\"))\n            )\n            sample[\"camera\"][cam] = sorted(\n                glob(osp.join(camera_paths[cam][seq_idx], \"*.cam\"))\n            )\n\n            for im1, disp, camera in zip(sample[\"image\"][cam], sample[\"disparity\"][cam], sample[\"camera\"][cam]):\n                assert (\n                    im1.split(\"/\")[-1].split(\".\")[0]\n                    == disp.split(\"/\")[-1].split(\".\")[0]\n                    == camera.split(\"/\")[-1].split(\".\")[0]\n                ), (im1.split(\"/\")[-1].split(\".\")[0], disp.split(\"/\")[-1].split(\".\")[0], camera.split(\"/\")[-1].split(\".\")[0])\n            self.sample_list.append(sample)\n\n        logging.info(\n            f\"Added 
{len(self.sample_list) - original_length} from SintelStereo {self.dstype}\"\n        )\n\n\ndef fetch_dataloader(args):\n    \"\"\"Create the data loader for the corresponding training set\"\"\"\n\n    aug_params = {\n        \"crop_size\": args.image_size,\n        \"min_scale\": args.spatial_scale[0],\n        \"max_scale\": args.spatial_scale[1],\n        \"do_flip\": False,\n        \"yjitter\": not args.noyjitter,\n    }\n    if hasattr(args, \"saturation_range\") and args.saturation_range is not None:\n        aug_params[\"saturation_range\"] = args.saturation_range\n    if hasattr(args, \"img_gamma\") and args.img_gamma is not None:\n        aug_params[\"gamma\"] = args.img_gamma\n    if hasattr(args, \"do_flip\") and args.do_flip is not None:\n        aug_params[\"do_flip\"] = args.do_flip\n\n    train_dataset = None\n\n    add_monkaa = \"monkaa\" in args.train_datasets\n    add_driving = \"driving\" in args.train_datasets\n    add_things = \"things\" in args.train_datasets\n    add_dynamic_replica = \"dynamic_replica\" in args.train_datasets\n    add_infinigensv = \"infinigen_stereovideo\" in args.train_datasets\n    add_kittidepth = \"kitti_depth\" in args.train_datasets\n    add_vkitti2 = \"vkitti2\" in args.train_datasets\n    add_tartanair = \"tartanair\" in args.train_datasets\n    new_dataset = None\n\n    if add_monkaa or add_driving or add_things:\n        # clean_dataset = SequenceSceneFlowDataset(\n        #     aug_params,\n        #     dstype=\"frames_cleanpass\",\n        #     sample_len=args.sample_len,\n        #     add_monkaa=add_monkaa,\n        #     add_driving=add_driving,\n        #     add_things=add_things,\n        # )\n\n        final_dataset = SequenceSceneFlowDataset(\n            aug_params,\n            dstype=\"frames_finalpass\",\n            sample_len=args.sample_len,\n            add_monkaa=add_monkaa,\n            add_driving=add_driving,\n            add_things=add_things,\n        )\n        # new_dataset = clean_dataset + final_dataset\n\n        new_dataset = final_dataset\n\n    if add_dynamic_replica:\n        dr_dataset = DynamicReplicaDataset(\n            aug_params, split=\"train\", sample_len=args.sample_len\n        )\n        if new_dataset is None:\n            new_dataset = dr_dataset\n        else:\n            new_dataset = new_dataset + dr_dataset\n\n    if add_infinigensv:\n        infinigensv_dataset = InfinigenStereoVideoDataset(\n            aug_params, split=\"train\", sample_len=args.sample_len\n        )\n        if new_dataset is None:\n            new_dataset = infinigensv_dataset\n        else:\n            new_dataset = new_dataset + infinigensv_dataset + infinigensv_dataset + infinigensv_dataset\n\n    if add_kittidepth:\n        kittidepth_dataset = KITTIDepthDataset(\n            aug_params, split=\"train\", sample_len=args.sample_len\n        )\n        if new_dataset is None:\n            new_dataset = kittidepth_dataset\n        else:\n            new_dataset = new_dataset + kittidepth_dataset\n\n    if add_vkitti2:\n        vkitti2_dataset = VKITTI2Dataset(\n            aug_params, split=\"train\", sample_len=args.sample_len\n        )\n        if new_dataset is None:\n            new_dataset = vkitti2_dataset\n        else:\n            new_dataset = new_dataset + vkitti2_dataset\n\n    if add_tartanair:\n        tartanair_dataset = TartanAirDataset(\n            aug_params, split=\"train\", sample_len=args.sample_len\n        )\n        if new_dataset is None:\n            new_dataset = tartanair_dataset\n 
       else:\n            new_dataset = new_dataset + tartanair_dataset\n\n    assert new_dataset is not None, \"No training datasets selected\"\n    logging.info(f\"Adding {len(new_dataset)} samples in total\")\n    train_dataset = (\n        new_dataset if train_dataset is None else train_dataset + new_dataset\n    )\n\n    train_loader = data.DataLoader(\n        train_dataset,\n        batch_size=args.batch_size,\n        pin_memory=True,\n        shuffle=True,\n        num_workers=args.num_workers,\n        drop_last=True,\n    )\n\n    logging.info(\"Training with %d image pairs\" % len(train_dataset))\n    return train_loader\n
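\n\n# A minimal usage sketch (commented out; not part of the training code). The\n# attribute values below are hypothetical; fetch_dataloader only requires that\n# the names it reads (image_size, spatial_scale, noyjitter, train_datasets,\n# sample_len, batch_size, num_workers) exist on the args object.\n#\n#   import argparse\n#   args = argparse.Namespace(\n#       image_size=[384, 512], spatial_scale=[-0.2, 0.4], noyjitter=False,\n#       train_datasets=[\"things\", \"monkaa\", \"driving\"], sample_len=5,\n#       batch_size=4, num_workers=8,\n#   )\n#   loader = fetch_dataloader(args)\n#   for batch in loader:\n#       print(batch[\"img\"].shape)  # expected (batch, time, 2 views, 3, H, W)\n#       break\n"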
  },
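  {
    "path": "examples/fetch_dataloader_sketch.py",
    "content": "# Hypothetical usage sketch, not part of the original codebase: fetch_dataloader()\n# pulls many attributes off an argparse Namespace, so this spells out the minimal\n# set of fields it reads. The import path is an assumption based on evaluate.py;\n# the values are illustrative.\nfrom argparse import Namespace\n\nfrom stereoanyvideo.datasets.video_datasets import fetch_dataloader\n\nargs = Namespace(\n    image_size=[256, 256],      # aug_params['crop_size']\n    spatial_scale=[-0.2, 0.4],  # aug_params['min_scale'] / ['max_scale']\n    noyjitter=False,            # keeps the y-jitter augmentation enabled\n    train_datasets=['things'],  # any of: monkaa, driving, things, dynamic_replica,\n                                # infinigen_stereovideo, kitti_depth, vkitti2, tartanair\n    sample_len=5,               # frames per training sample\n    batch_size=1,\n    num_workers=4,\n)\n\ntrain_loader = fetch_dataloader(args)\nbatch = next(iter(train_loader))  # one batch as a smoke test\n"
  },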
  {
    "path": "demo.py",
    "content": "import sys\n\nimport argparse\nimport os\nimport cv2\nimport glob\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nfrom collections import defaultdict\n\nfrom PIL import Image\nfrom matplotlib import pyplot as plt\nfrom pathlib import Path\n\nDEVICE = 'cuda'\n\n\ndef load_image(imfile):\n    img = np.array(Image.open(imfile).convert('RGB')).astype(np.uint8)\n    img = torch.from_numpy(img).permute(2, 0, 1).float()\n    return img.to(DEVICE)\n\n\ndef viz(img, flo):\n    img = img[0].permute(1, 2, 0).cpu().numpy()\n    flo = flo[0].permute(1, 2, 0).cpu().numpy()\n\n    # map flow to rgb image\n    flo = flow_viz.flow_to_image(flo)\n    img_flo = np.concatenate([img, flo], axis=0)\n\n    cv2.imshow('image', img_flo[:, :, [2, 1, 0]] / 255.0)\n    cv2.waitKey()\n\n\ndef demo(args):\n    from stereoanyvideo.models.stereoanyvideo_model import StereoAnyVideoModel\n    model = StereoAnyVideoModel()\n\n    if args.ckpt is not None:\n        assert args.ckpt.endswith(\".pth\") or args.ckpt.endswith(\n            \".pt\"\n        )\n        strict = True\n        state_dict = torch.load(args.ckpt)\n        if \"model\" in state_dict:\n            state_dict = state_dict[\"model\"]\n        if list(state_dict.keys())[0].startswith(\"module.\"):\n            state_dict = {\n                k.replace(\"module.\", \"\"): v for k, v in state_dict.items()\n            }\n        model.model.load_state_dict(state_dict, strict=strict)\n        print(\"Done loading model checkpoint\", args.ckpt)\n\n    model.to(DEVICE)\n    model.eval()\n\n    output_directory = args.output_path\n    parent_directory = os.path.dirname(output_directory)\n    if not os.path.exists(parent_directory):\n        os.makedirs(parent_directory)\n    if not os.path.isdir(output_directory):\n        os.mkdir(output_directory)\n\n    with torch.no_grad():\n        images_left = sorted(glob.glob(os.path.join(args.path, 'left/*.png')) + glob.glob(os.path.join(args.path, 'left/*.jpg')))\n        images_right = sorted(glob.glob(os.path.join(args.path, 'right/*.png')) + glob.glob(os.path.join(args.path, 'right/*.jpg')))\n        assert len(images_left) == len(images_right), [len(images_left), len(images_right)]\n        assert len(images_left) > 0, args.path\n        print(f\"Found {len(images_left)} frames. 
Saving files to {args.output_path}\")\n\n        num_frames = len(images_left)\n        frame_size = args.frame_size\n\n        disparities_ori_all = []\n\n        for start_idx in range(0, num_frames, frame_size):\n            end_idx = min(start_idx + frame_size, num_frames)\n\n            image_left_list = []\n            image_right_list = []\n\n            for imfile1, imfile2 in zip(images_left[start_idx:end_idx], images_right[start_idx:end_idx]):\n                image_left = load_image(imfile1)\n                image_right = load_image(imfile2)\n                image_left = F.interpolate(image_left[None], size=args.resize, mode=\"bilinear\", align_corners=True)\n                image_right = F.interpolate(image_right[None], size=args.resize, mode=\"bilinear\", align_corners=True)\n                image_left_list.append(image_left[0])\n                image_right_list.append(image_right[0])\n\n            video_left = torch.stack(image_left_list, dim=0)\n            video_right = torch.stack(image_right_list, dim=0)\n\n            batch_dict = defaultdict(list)\n            batch_dict[\"stereo_video\"] = torch.stack([video_left, video_right], dim=1)\n\n            predictions = model(batch_dict)\n\n            assert \"disparity\" in predictions\n            disparities = predictions[\"disparity\"][:, :1].clone().data.cpu().abs().numpy()\n            epsilon = 1e-5  # smallest allowable disparity; clamp before the uint8 cast, while it can still take effect\n            disparities[disparities < epsilon] = epsilon\n            disparities_ori = disparities.astype(np.uint8)\n            disparities_ori_all.extend(disparities_ori)\n\n        disparities_ori_all = np.array(disparities_ori_all)\n\n        disparities_all = ((disparities_ori_all - disparities_ori_all.min()) / (disparities_ori_all.max() - disparities_ori_all.min()) * 255).astype(np.uint8)\n\n        video_ori_disparity = cv2.VideoWriter(\n            os.path.join(args.output_path, \"disparity.mp4\"),\n            cv2.VideoWriter_fourcc(*\"mp4v\"),\n            fps=args.fps,\n            frameSize=(disparities_all.shape[3], disparities_all.shape[2]),\n            isColor=True,\n        )\n        video_disparity = cv2.VideoWriter(\n            os.path.join(args.output_path, \"disparity_norm.mp4\"),\n            cv2.VideoWriter_fourcc(*\"mp4v\"),\n            fps=args.fps,\n            frameSize=(disparities_all.shape[3], disparities_all.shape[2]),\n            isColor=True,\n        )\n\n        for i in range(num_frames):\n            imfile1 = images_left[i]\n\n            disparity_norm = disparities_all[i]\n            disparity_norm = disparity_norm.transpose(1, 2, 0)\n            disparity_norm_vis = cv2.applyColorMap(disparity_norm, cv2.COLORMAP_INFERNO)\n            video_disparity.write(disparity_norm_vis)\n\n            disparity_ori = disparities_ori_all[i]\n            disparity_ori = disparity_ori.transpose(1, 2, 0)\n            disparity_ori_vis = cv2.applyColorMap(disparity_ori, cv2.COLORMAP_INFERNO)\n            video_ori_disparity.write(disparity_ori_vis)\n\n            if args.save_png:\n                filename_temp = args.output_path + '/disparity_norm_' + str(i).zfill(3) + '.png'\n                cv2.imwrite(filename_temp, disparity_norm_vis)\n                filename_temp = args.output_path + '/disparity_ori_' + str(i).zfill(3) + '.png'\n                cv2.imwrite(filename_temp, disparity_ori_vis)\n\n        video_ori_disparity.release()\n        video_disparity.release()\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser()\n    
parser.add_argument('--model_name', default=\"stereoanyvideo\", help=\"name to specify model\")\n    parser.add_argument('--ckpt', default=None, help=\"checkpoint of stereo model\")\n    parser.add_argument('--resize', type=int, nargs=2, default=[720, 1280], help=\"(height, width) that input frames are resized to\")\n    parser.add_argument(\"--fps\", type=int, default=30, help=\"frame rate for video visualization\")\n    parser.add_argument('--path', help=\"directory containing left/ and right/ input frames\")\n    parser.add_argument(\"--save_png\", action=\"store_true\")\n    parser.add_argument(\"--frame_size\", type=int, default=150, help=\"number of frames processed per chunk\")\n    parser.add_argument(\"--iters\", type=int, default=20, help=\"number of updates in each forward pass\")\n    parser.add_argument(\"--kernel_size\", type=int, default=20, help=\"number of frames in each forward pass\")\n    parser.add_argument('--output_path', help=\"directory to save output\", default=\"demo_output\")\n    args = parser.parse_args()\n\n    demo(args)\n"
  },
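  {
    "path": "examples/demo_inputs_sketch.py",
    "content": "# Hypothetical sketch, not part of the original codebase: demo.py expects\n# args.path to contain left/ and right/ subfolders with equally many sorted\n# .png or .jpg frames, processed in chunks of --frame_size frames. This\n# helper builds that layout from two frame lists (extensions only matter\n# for globbing; files are copied as-is).\nimport os\nimport shutil\n\n\ndef prepare_demo_dir(left_frames, right_frames, out_dir='demo_video'):\n    # demo.py globs <out_dir>/left/*.png|*.jpg and <out_dir>/right/*.png|*.jpg,\n    # sorts both lists, and asserts they have the same length.\n    assert len(left_frames) == len(right_frames)\n    for side, frames in (('left', left_frames), ('right', right_frames)):\n        os.makedirs(os.path.join(out_dir, side), exist_ok=True)\n        for i, src in enumerate(frames):\n            # zero-padded names keep sorted() in frame order\n            shutil.copy(src, os.path.join(out_dir, side, str(i).zfill(6) + '.png'))\n    return out_dir\n"
  },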
  {
    "path": "demo.sh",
    "content": "#!/bin/bash\n\nexport PYTHONPATH=`(cd ../ && pwd)`:`pwd`:$PYTHONPATH\n\npython demo.py --ckpt ./checkpoints/StereoAnyVideo_MIX.pth --path ./demo_video/ --output_path ./demo_output/ --save_png"
  },
  {
    "path": "evaluate_stereoanyvideo.sh",
    "content": "#!/bin/bash\n\nexport PYTHONPATH=`(cd ../ && pwd)`:`pwd`:$PYTHONPATH\n\n# evaluate on [sintel, dynamicreplica， infinigensv, vkitti2] using sceneflow checkpoint\n\npython ./evaluation/evaluate.py --config-name eval_sintel_final \\\nMODEL.model_name=StereoAnyVideoModel \\\nMODEL.StereoAnyVideoModel.model_weights=./checkpoints/StereoAnyVideo_SF.pth\n\npython ./evaluation/evaluate.py --config-name eval_dynamic_replica \\\nMODEL.model_name=StereoAnyVideoModel \\\nMODEL.StereoAnyVideoModel.model_weights=./checkpoints/StereoAnyVideo_SF.pth\n\npython ./evaluation/evaluate.py --config-name eval_infinigensv \\\nMODEL.model_name=StereoAnyVideoModel \\\nMODEL.StereoAnyVideoModel.model_weights=./checkpoints/StereoAnyVideo_SF.pth\n\npython ./evaluation/evaluate.py --config-name eval_vkitti2 \\\nMODEL.model_name=StereoAnyVideoModel \\\nMODEL.StereoAnyVideoModel.model_weights=./checkpoints/StereoAnyVideo_SF.pth\n\n\n\n# evaluate on [sintel, kittidepth, southkensingtonSV] using mixed checkpoint\n\npython ./evaluation/evaluate.py --config-name eval_sintel_final \\\nMODEL.model_name=StereoAnyVideoModel \\\nMODEL.StereoAnyVideoModel.model_weights=./checkpoints/StereoAnyVideo_MIX.pth\n\npython ./evaluation/evaluate.py --config-name eval_kittidepth \\\nMODEL.model_name=StereoAnyVideoModel \\\nMODEL.StereoAnyVideoModel.model_weights=./checkpoints/StereoAnyVideo_MIX.pth\n\npython ./evaluation/evaluate.py --config-name eval_southkensington \\\nMODEL.model_name=StereoAnyVideoModel \\\nMODEL.StereoAnyVideoModel.model_weights=./checkpoints/StereoAnyVideo_SF.pth"
  },
  {
    "path": "evaluation/configs/eval_dynamic_replica.yaml",
    "content": "defaults:\n  - default_config_eval\nvisualize_interval: -1\nexp_dir: ./outputs/stereoanyvideo_DynamicReplica\nsample_len: 150\nMODEL:\n    model_name: StereoAnyVideoModel\n"
  },
  {
    "path": "evaluation/configs/eval_infinigensv.yaml",
    "content": "defaults:\n  - default_config_eval\nvisualize_interval: -1\nrender_bin_size: 0\nexp_dir: ./outputs/stereoanyvideo_InfinigenSV\nsample_len: 150\ndataset_name: infinigensv\nMODEL:\n    model_name: StereoAnyVideoModel\n"
  },
  {
    "path": "evaluation/configs/eval_kittidepth.yaml",
    "content": "defaults:\n  - default_config_eval\nvisualize_interval: -1\nrender_bin_size: 0\nexp_dir: ./outputs/stereoanyvideo_KITTIDepth\nsample_len: 300\ndataset_name: kitti_depth\nMODEL:\n    model_name: StereoAnyVideoModel\n"
  },
  {
    "path": "evaluation/configs/eval_sintel_clean.yaml",
    "content": "defaults:\n  - default_config_eval\nvisualize_interval: -1\nrender_bin_size: 0\nexp_dir: ./outputs/stereoanyvideo_sintel_clean\ndataset_name: sintel\ndstype: clean\nMODEL:\n    model_name: StereoAnyVideoModel\n"
  },
  {
    "path": "evaluation/configs/eval_sintel_final.yaml",
    "content": "defaults:\n  - default_config_eval\nvisualize_interval: -1\nrender_bin_size: 0\nexp_dir: ./outputs/stereoanyvideo_sintel_final\ndataset_name: sintel\ndstype: final\nMODEL:\n    model_name: StereoAnyVideoModel"
  },
  {
    "path": "evaluation/configs/eval_southkensington.yaml",
    "content": "defaults:\n  - default_config_eval\nvisualize_interval: 1\nexp_dir: ./outputs/stereoanyvideo_SouthKensingtonIndoor\nsample_len: 300\ndataset_name: southkensingtonsv\nMODEL:\n    model_name: StereoAnyVideoModel\n"
  },
  {
    "path": "evaluation/configs/eval_vkitti2.yaml",
    "content": "defaults:\n  - default_config_eval\nvisualize_interval: -1\nrender_bin_size: 0\nexp_dir: ./outputs/stereoanyvideo_VKITTI2\nsample_len: 300\ndataset_name: vkitti2\nMODEL:\n    model_name: StereoAnyVideoModel\n"
  },
  {
    "path": "evaluation/core/evaluator.py",
    "content": "import os\nimport numpy as np\nimport cv2\nfrom collections import defaultdict\nimport torch.nn.functional as F\nimport torch\nimport matplotlib.pyplot as plt\nfrom tqdm import tqdm\nfrom omegaconf import DictConfig\nfrom pytorch3d.implicitron.tools.config import Configurable\nfrom stereoanyvideo.evaluation.utils.eval_utils import depth2disparity_scale, eval_batch\nfrom stereoanyvideo.evaluation.utils.utils import (\n    PerceptionPrediction,\n    pretty_print_perception_metrics,\n    visualize_batch,\n)\n\n\ndef depth_to_colormap(depth, colormap='jet', eps=1e-3, scale_vmin=1.0):\n    valid = (depth > eps) & (depth < 1e4)\n    vmin = depth[valid].min() * scale_vmin\n    vmax = depth[valid].max()\n    if colormap=='jet':\n        cmap = plt.cm.jet\n    else:\n        cmap = plt.cm.inferno\n    norm = plt.Normalize(vmin=vmin, vmax=vmax)\n    depth = cmap(norm(depth))\n    depth[~valid] = 1\n    return np.ascontiguousarray(depth[...,:3] * 255, dtype=np.uint8)\n\n\nclass Evaluator(Configurable):\n    \"\"\"\n    A class defining the DynamicStereo evaluator.\n\n    Args:\n        eps: Threshold for converting disparity to depth.\n    \"\"\"\n\n    eps = 1e-5\n\n    def setup_visualization(self, cfg: DictConfig) -> None:\n        # Visualization\n        self.visualize_interval = cfg.visualize_interval\n        self.render_bin_size = cfg.render_bin_size\n        self.exp_dir = cfg.exp_dir\n        if self.visualize_interval > 0:\n            self.visualize_dir = os.path.join(cfg.exp_dir, \"visualisations\")\n\n    @torch.no_grad()\n    def evaluate_sequence(\n        self,\n        model,\n        model_stabilizer,\n        test_dataloader: torch.utils.data.DataLoader,\n        is_real_data: bool = False,\n        step=None,\n        writer=None,\n        train_mode=False,\n        interp_shape=None,\n        exp_dir=None,\n    ):\n        model.eval()\n        per_batch_eval_results = []\n\n        if self.visualize_interval > 0:\n            os.makedirs(self.visualize_dir, exist_ok=True)\n\n        for batch_idx, sequence in enumerate(tqdm(test_dataloader)):\n            batch_dict = defaultdict(list)\n            batch_dict[\"stereo_video\"] = sequence[\"img\"]\n            if not is_real_data:\n                batch_dict[\"disparity\"] = sequence[\"disp\"][:, 0].abs()\n                batch_dict[\"disparity_mask\"] = sequence[\"valid_disp\"][:, :1]\n\n                if \"mask\" in sequence:\n                    batch_dict[\"fg_mask\"] = sequence[\"mask\"][:, :1]\n                else:\n                    batch_dict[\"fg_mask\"] = torch.ones_like(\n                        batch_dict[\"disparity_mask\"]\n                    )\n\n            elif interp_shape is not None:\n                left_video = batch_dict[\"stereo_video\"][:, 0]\n                left_video = F.interpolate(\n                    left_video, tuple(interp_shape), mode=\"bilinear\"\n                )\n                right_video = batch_dict[\"stereo_video\"][:, 1]\n                right_video = F.interpolate(\n                    right_video, tuple(interp_shape), mode=\"bilinear\"\n                )\n                batch_dict[\"stereo_video\"] = torch.stack([left_video, right_video], 1)\n\n            if model_stabilizer is not None:\n                predictions = model.forward_stabilizer(batch_dict, model_stabilizer)\n            elif train_mode:\n                predictions = model.forward_batch_test(batch_dict)\n            else:\n                predictions = model(batch_dict)\n\n            assert 
\"disparity\" in predictions\n            predictions[\"disparity\"] = predictions[\"disparity\"][:, :1].clone().cpu()\n            if not is_real_data:\n                predictions[\"disparity\"] = predictions[\"disparity\"] * (\n                    batch_dict[\"disparity_mask\"].round()\n                )\n\n                batch_eval_result, seq_length = eval_batch(batch_dict, predictions, sequence[\"depth2disp_scale\"][0])\n                per_batch_eval_results.append((batch_eval_result, seq_length))\n                pretty_print_perception_metrics(batch_eval_result)\n\n            if (self.visualize_interval > 0) and (\n                batch_idx % self.visualize_interval == 0\n            ):\n                perception_prediction = PerceptionPrediction()\n\n                pred_disp = predictions[\"disparity\"]\n                pred_disp[pred_disp < self.eps] = self.eps\n\n                scale = sequence[\"depth2disp_scale\"][0]\n                perception_prediction.depth_map = (scale / pred_disp).cuda()\n\n                perspective_cameras = []\n                if \"viewpoint\" in sequence:\n                    for cam in sequence[\"viewpoint\"]:\n                        perspective_cameras.append(cam[0])\n                        perception_prediction.perspective_cameras = perspective_cameras\n\n                if \"stereo_original_video\" in batch_dict:\n                    batch_dict[\"stereo_video\"] = batch_dict[\n                        \"stereo_original_video\"\n                    ].clone()\n\n                for k, v in batch_dict.items():\n                    if isinstance(v, torch.Tensor):\n                        batch_dict[k] = v.cuda()\n\n                visualize_batch(\n                    batch_dict,\n                    perception_prediction,\n                    output_dir=self.visualize_dir,\n                    sequence_name=sequence[\"metadata\"][0][0][0],\n                    step=step,\n                    writer=writer,\n                    render_bin_size=self.render_bin_size,\n                )\n                filename = os.path.join(self.visualize_dir, sequence[\"metadata\"][0][0][0])\n                if not os.path.isdir(filename):\n                    os.mkdir(filename)\n                disparity_list = pred_disp.data.cpu().numpy()\n                depth_list = perception_prediction.depth_map.data.cpu().numpy()\n                np.save(f\"{filename}_depth_list.npy\", depth_list)\n\n                video_disparity = cv2.VideoWriter(\n                    f\"{filename}_disparity.mp4\",\n                    cv2.VideoWriter_fourcc(*\"mp4v\"),\n                    fps=30,\n                    frameSize=(\n                        batch_dict[\"stereo_video\"][:, 0][0].shape[2], batch_dict[\"stereo_video\"][:, 0][0].shape[1]),\n                    isColor=True,\n                )\n\n                disparity_vis = depth_to_colormap(disparity_list[:, 0], eps=self.eps, colormap='inferno')\n                for i in range(disparity_list.shape[0]):\n                    filename_temp = filename + '/disparity_' + str(i).zfill(3) + '.png'\n                    disparity_vis[i] = cv2.cvtColor(disparity_vis[i], cv2.COLOR_RGB2BGR)\n                    cv2.imwrite(filename_temp, disparity_vis[i])\n                    video_disparity.write(disparity_vis[i])\n                video_disparity.release()\n\n        return per_batch_eval_results\n"
  },
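  {
    "path": "examples/depth_colormap_sketch.py",
    "content": "# Hypothetical usage sketch, not part of the original codebase:\n# depth_to_colormap() (module-level in evaluation/core/evaluator.py) maps a\n# depth array to an RGB uint8 image, painting values outside (eps, 1e4) white.\nimport numpy as np\n\nfrom stereoanyvideo.evaluation.core.evaluator import depth_to_colormap\n\ndepth = np.random.uniform(0.5, 10.0, size=(240, 320))\ndepth[:10] = 0.0  # below eps -> treated as invalid and rendered white\n\nvis = depth_to_colormap(depth, colormap='inferno', eps=1e-3)\nassert vis.shape == (240, 320, 3) and vis.dtype == np.uint8\n"
  },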
  {
    "path": "evaluation/evaluate.py",
    "content": "import json\nimport os\nfrom dataclasses import dataclass, field\nfrom typing import Any, Dict, Optional\n\nimport hydra\nimport numpy as np\n\nimport torch\nfrom omegaconf import OmegaConf\n\nfrom stereoanyvideo.evaluation.utils.utils import aggregate_and_print_results\n\nimport stereoanyvideo.datasets.video_datasets as datasets\n\nfrom stereoanyvideo.models.core.model_zoo import (\n    get_all_model_default_configs,\n    model_zoo,\n)\nfrom pytorch3d.implicitron.tools.config import get_default_args_field\nfrom stereoanyvideo.evaluation.core.evaluator import Evaluator\n\n\n@dataclass(eq=False)\nclass DefaultConfig:\n    exp_dir: str = \"./outputs\"\n    stabilizer_ckpt: Optional[str] = None\n\n    # one of [sintel, dynamicreplica, things, kitti_depth, infinigensv, southkensingtonsv]\n    dataset_name: str = \"dynamicreplica\"\n\n    sample_len: int = -1\n    dstype: Optional[str] = None\n    # clean, final\n    MODEL: Dict[str, Any] = field(\n        default_factory=lambda: get_all_model_default_configs()\n    )\n    EVALUATOR: Dict[str, Any] = get_default_args_field(Evaluator)\n\n    seed: int = 42\n    gpu_idx: int = 0\n\n    visualize_interval: int = 1  # Use 0 for no visualization\n\n    render_bin_size: Optional[int] = None\n\n    # Override hydra's working directory to current working dir,\n    # also disable storing the .hydra logs:\n    hydra: dict = field(\n        default_factory=lambda: {\n            \"run\": {\"dir\": \".\"},\n            \"output_subdir\": None,\n        }\n    )\n\n\ndef run_eval(cfg: DefaultConfig):\n    \"\"\"\n    Evaluates new view synthesis metrics of a specified model\n    on a benchmark dataset.\n    \"\"\"\n    # make the experiment directory\n    os.makedirs(cfg.exp_dir, exist_ok=True)\n\n    # dump the exp cofig to the exp_dir\n    cfg_file = os.path.join(cfg.exp_dir, \"expconfig.yaml\")\n    with open(cfg_file, \"w\") as f:\n        OmegaConf.save(config=cfg, f=f)\n\n    torch.manual_seed(cfg.seed)\n    np.random.seed(cfg.seed)\n    evaluator = Evaluator(**cfg.EVALUATOR)\n\n    model = model_zoo(**cfg.MODEL)\n    model.cuda(0)\n    evaluator.setup_visualization(cfg)\n\n    if cfg.dataset_name == \"dynamicreplica\":\n        test_dataloader = datasets.DynamicReplicaDataset(\n            split=\"test\", sample_len=cfg.sample_len, only_first_n_samples=1\n        )\n    elif cfg.dataset_name == \"infinigensv\":\n        test_dataloader = datasets.InfinigenStereoVideoDataset(\n            split=\"test\", sample_len=cfg.sample_len, only_first_n_samples=1\n        )\n    elif cfg.dataset_name == \"southkensingtonsv\":\n        test_dataloader = datasets.SouthKensingtonStereoVideoDataset(\n            sample_len=cfg.sample_len, only_first_n_samples=1\n        )\n        evaluator.evaluate_sequence(\n            model,\n            None,\n            test_dataloader,\n            is_real_data=True,\n            exp_dir=cfg.exp_dir\n        )\n        return\n\n    elif cfg.dataset_name == \"kitti_depth\":\n        test_dataloader = datasets.KITTIDepthDataset(\n            split=\"test\", sample_len=cfg.sample_len, only_first_n_samples=1\n        )\n    elif cfg.dataset_name == \"vkitti2\":\n        test_dataloader = datasets.VKITTI2Dataset(\n            split=\"test\", sample_len=cfg.sample_len, only_first_n_samples=1\n        )\n    elif cfg.dataset_name == \"sintel\":\n        test_dataloader = datasets.SequenceSintelStereo(dstype=cfg.dstype)\n    elif cfg.dataset_name == \"things\":\n        test_dataloader = 
datasets.SequenceSceneFlowDatasets(\n            {},\n            dstype=cfg.dstype,\n            sample_len=cfg.sample_len,\n            add_monkaa=False,\n            add_driving=False,\n            things_test=True,\n        )\n\n    evaluate_result = evaluator.evaluate_sequence(\n        model,\n        None,\n        test_dataloader,\n        is_real_data=False,\n        exp_dir=cfg.exp_dir\n    )\n\n    aggregate_result = aggregate_and_print_results(evaluate_result)\n\n    result_file = os.path.join(cfg.exp_dir, \"result_eval.json\")\n\n    print(f\"Dumping eval results to {result_file}.\")\n    with open(result_file, \"w\") as f:\n        json.dump(aggregate_result, f)\n\n\ncs = hydra.core.config_store.ConfigStore.instance()\ncs.store(name=\"default_config_eval\", node=DefaultConfig)\n\n\n@hydra.main(config_path=\"./configs/\", config_name=\"default_config_eval\")\ndef evaluate(cfg: DefaultConfig) -> None:\n    os.environ[\"CUDA_DEVICE_ORDER\"] = \"PCI_BUS_ID\"\n    os.environ[\"CUDA_VISIBLE_DEVICES\"] = str(cfg.gpu_idx)\n    run_eval(cfg)\n\n\nif __name__ == \"__main__\":\n    evaluate()\n"
  },
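  {
    "path": "examples/aggregate_results_sketch.py",
    "content": "# Hypothetical sketch, not part of the original codebase: evaluate.py feeds\n# aggregate_and_print_results() a list of (metric_dict, sequence_length)\n# tuples, so the aggregation is length-weighted. Minimal fake input:\nimport torch\n\nfrom stereoanyvideo.evaluation.utils.eval_utils import PerceptionMetric\nfrom stereoanyvideo.evaluation.utils.utils import aggregate_and_print_results\n\nper_batch = [\n    ({PerceptionMetric('disp_epe_mean'): torch.tensor([1.0])}, 10),\n    ({PerceptionMetric('disp_epe_mean'): torch.tensor([2.0])}, 30),\n]\n\nresult = aggregate_and_print_results(per_batch)\n# length-weighted mean: (1.0 * 10 + 2.0 * 30) / 40 = 1.75\nassert abs(result['disp_epe_mean'] - 1.75) < 1e-6\n"
  },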
  {
    "path": "evaluation/utils/eval_utils.py",
    "content": "from dataclasses import dataclass\nfrom typing import Dict, Optional, Union\nfrom stereoanyvideo.evaluation.utils.ssim import SSIM\nimport torch\nimport torch.nn.functional as F\nimport numpy as np\nimport math\nimport cv2\nfrom pytorch3d.utils import opencv_from_cameras_projection\nfrom stereoanyvideo.models.raft_model import RAFTModel\n\n\n@dataclass(eq=True, frozen=True)\nclass PerceptionMetric:\n    metric: str\n    depth_scaling_norm: Optional[str] = None\n    suffix: str = \"\"\n    index: str = \"\"\n\n    def __str__(self):\n        return (\n            self.metric\n            + self.index\n            + (\n                (\"_norm_\" + self.depth_scaling_norm)\n                if self.depth_scaling_norm is not None\n                else \"\"\n            )\n            + self.suffix\n        )\n\n\ndef compute_flow(seq, is_seq=True):\n    raft = RAFTModel().cuda()\n    raft.eval()\n    if is_seq:\n        t, c, h, w = seq.size()\n        flows_forward = []\n        for i in range(t-1):\n            flow_forward = raft.forward_fullres(seq[i][None], seq[i+1][None], iters=20)\n            flows_forward.append(flow_forward)\n        flows_forward = torch.cat(flows_forward, dim=0)\n        return flows_forward\n\n    else:\n        img1, img2 = seq\n        flow_forward = raft.forward_fullres(img1, img2, iters=20)\n        return flow_forward\n\ndef flow_warp(x, flow):\n    if flow.size(3) != 2:  # [B, H, W, 2]\n        flow = flow.permute(0, 2, 3, 1)\n    if x.size()[-2:] != flow.size()[1:3]:\n        raise ValueError(f'The spatial sizes of input ({x.size()[-2:]}) and '\n                         f'flow ({flow.size()[1:3]}) are not the same.')\n    _, _, h, w = x.size()\n    # create mesh grid\n    grid_y, grid_x = torch.meshgrid(torch.arange(0, h), torch.arange(0, w))\n    grid = torch.stack((grid_x, grid_y), 2).type_as(x)  # (h, w, 2)\n    grid.requires_grad = False\n\n    grid_flow = grid + flow\n    # scale grid_flow to [-1,1]\n    grid_flow_x = 2.0 * grid_flow[:, :, :, 0] / max(w - 1, 1) - 1.0\n    grid_flow_y = 2.0 * grid_flow[:, :, :, 1] / max(h - 1, 1) - 1.0\n    grid_flow = torch.stack((grid_flow_x, grid_flow_y), dim=3)\n    output = F.grid_sample(\n        x,\n        grid_flow,\n        mode='bilinear',\n        padding_mode='zeros',\n        align_corners=True)\n    return output\n\ndef eval_endpoint_error_sequence(\n    x: torch.Tensor,\n    y: torch.Tensor,\n    mask: torch.Tensor,\n    crop: int = 0,\n    mask_thr: float = 0.5,\n    clamp_thr: float = 1e-5,\n) -> Dict[str, torch.Tensor]:\n\n    assert len(x.shape) == len(y.shape) == len(mask.shape) == 4, (\n        x.shape,\n        y.shape,\n        mask.shape,\n    )\n    assert x.shape[0] == y.shape[0] == mask.shape[0], (x.shape, y.shape, mask.shape)\n\n    # chuck out the border\n    if crop > 0:\n        if crop > min(y.shape[2:]) - crop:\n            raise ValueError(\"Incorrect crop size.\")\n        y = y[:, :, crop:-crop, crop:-crop]\n        x = x[:, :, crop:-crop, crop:-crop]\n        mask = mask[:, :, crop:-crop, crop:-crop]\n\n    y = y * (mask > mask_thr).float()\n    x = x * (mask > mask_thr).float()\n    y[torch.isnan(y)] = 0\n\n    results = {}\n    for epe_name in (\"epe\", \"temp_epe\"):\n        if epe_name == \"epe\":\n            endpoint_error = (mask * (x - y) ** 2).sum(dim=1).sqrt()\n        elif epe_name == \"temp_epe\":\n            delta_mask = mask[:-1] * mask[1:]\n            # endpoint_error = (\n            #     (delta_mask * ((x[:-1] - x[1:]) - (y[:-1] - y[1:])) ** 2)\n 
           #     .sum(dim=1)\n            #     .sqrt()\n            # )\n            endpoint_error = (\n                (delta_mask * ((x[:-1] - x[1:]).abs() - (y[:-1] - y[1:]).abs()) ** 2)\n                .sum(dim=1)\n                .sqrt()\n            )\n\n        # epe_nonzero = endpoint_error != 0\n        nonzero = torch.count_nonzero(endpoint_error)\n        epe_mean = endpoint_error.sum() / torch.clamp(\n            nonzero, clamp_thr\n        )  # average error for all the sequence pixels\n        epe_inv_accuracy_05px = (endpoint_error > 0.5).sum() / torch.clamp(\n            nonzero, clamp_thr\n        )\n        epe_inv_accuracy_1px = (endpoint_error > 1).sum() / torch.clamp(\n            nonzero, clamp_thr\n        )\n        epe_inv_accuracy_2px = (endpoint_error > 2).sum() / torch.clamp(\n            nonzero, clamp_thr\n        )\n        epe_inv_accuracy_3px = (endpoint_error > 3).sum() / torch.clamp(\n            nonzero, clamp_thr\n        )\n\n        results[f\"{epe_name}_mean\"] = epe_mean[None]\n        results[f\"{epe_name}_bad_0.5px\"] = epe_inv_accuracy_05px[None] * 100\n        results[f\"{epe_name}_bad_1px\"] = epe_inv_accuracy_1px[None] * 100\n        results[f\"{epe_name}_bad_2px\"] = epe_inv_accuracy_2px[None] * 100\n        results[f\"{epe_name}_bad_3px\"] = epe_inv_accuracy_3px[None] * 100\n    return results\n\n\ndef eval_TCC_sequence(\n    x: torch.Tensor,\n    y: torch.Tensor,\n    mask: torch.Tensor,\n    crop: int = 0,\n    mask_thr: float = 0.5,\n) -> Dict[str, torch.Tensor]:\n\n    assert len(x.shape) == len(y.shape) == len(mask.shape) == 4, (\n        x.shape,\n        y.shape,\n        mask.shape,\n    )\n    assert x.shape[0] == y.shape[0] == mask.shape[0], (x.shape, y.shape, mask.shape)\n    t, c, h, w = x.shape\n    # chuck out the border\n    if crop > 0:\n        if crop > min(y.shape[2:]) - crop:\n            raise ValueError(\"Incorrect crop size.\")\n        y = y[:, :, crop:-crop, crop:-crop]\n        x = x[:, :, crop:-crop, crop:-crop]\n        mask = mask[:, :, crop:-crop, crop:-crop]\n\n    y = y * (mask > mask_thr).float()\n    x = x * (mask > mask_thr).float()\n    x[torch.isnan(x)] = 0\n    y[torch.isnan(y)] = 0\n\n    ssim_loss = SSIM(1.0, nonnegative_ssim=True)\n    delta_mask = mask[:-1] * mask[1:]\n\n    tcc = 0\n    for i in range(t-1):\n        tcc += ssim_loss((torch.abs(x[i][None] - x[i+1][None]) * delta_mask[i]).expand(-1, 3, -1, -1),\n                          (torch.abs(y[i][None] - y[i+1][None]) * delta_mask[i]).expand(-1, 3, -1, -1))\n    tcc = tcc / (t-1)\n\n    return tcc\n\ndef eval_TCM_sequence(\n    x: torch.Tensor,\n    y: torch.Tensor,\n    mask: torch.Tensor,\n    crop: int = 0,\n    mask_thr: float = 0.5,\n) -> Dict[str, torch.Tensor]:\n\n    assert len(x.shape) == len(y.shape) == len(mask.shape) == 4, (\n        x.shape,\n        y.shape,\n        mask.shape,\n    )\n    assert x.shape[0] == y.shape[0] == mask.shape[0], (x.shape, y.shape, mask.shape)\n\n    t, c, h, w = x.shape\n    # chuck out the border\n    if crop > 0:\n        if crop > min(y.shape[2:]) - crop:\n            raise ValueError(\"Incorrect crop size.\")\n        y = y[:, :, crop:-crop, crop:-crop]\n        x = x[:, :, crop:-crop, crop:-crop]\n        mask = mask[:, :, crop:-crop, crop:-crop]\n\n    y = y * (mask > mask_thr).float()\n    x = x * (mask > mask_thr).float()\n    y[torch.isnan(y)] = 0\n\n    ssim_loss = SSIM(1.0, nonnegative_ssim=True, size_average=False)\n    delta_mask = mask[:-1] * mask[1:]\n\n    tcm = 0\n    for i in 
range(t-1):\n        dmax = torch.max(y[i][None].view(1, -1), -1)[0].view(1, 1, 1, 1).expand(-1, 3, -1, -1)\n        dmin = torch.min(y[i][None].view(1, -1), -1)[0].view(1, 1, 1, 1).expand(-1, 3, -1, -1)\n\n        x_norm = (x[i][None].expand(-1, 3, -1, -1) - dmin) / (dmax - dmin) * 255.\n        x2_norm = (x[i+1][None].expand(-1, 3, -1, -1) - dmin) / (dmax - dmin) * 255.\n        x_flow = compute_flow([x_norm.cuda(), x2_norm.cuda()], is_seq=False).cpu()\n\n        y_norm = (y[i][None].expand(-1, 3, -1, -1) - dmin) / (dmax - dmin) * 255.\n        y2_norm = (y[i+1][None].expand(-1, 3, -1, -1) - dmin) / (dmax - dmin) * 255.\n        y_flow = compute_flow([y_norm.cuda(), y2_norm.cuda()], is_seq=False).cpu()\n\n        flow_mask = torch.sum(y_flow > 250, 1, keepdim=True) == 0\n\n        mask = delta_mask[i][None] * flow_mask\n        mask = mask.expand(-1, 3, -1, -1)\n        if torch.sum(mask) > 0:\n            tcm += torch.mean(ssim_loss(\n                torch.cat((x_flow, torch.ones_like(x_flow[:, 0, None, ...])), 1) * mask,\n                torch.cat((y_flow, torch.ones_like(x_flow[:, 0, None, ...])), 1) * mask)[:, :2])\n    tcm = tcm / (t-1)\n\n    return tcm\n\n\ndef eval_OPW_sequence(\n    img: torch.Tensor,\n    x: torch.Tensor,\n    y: torch.Tensor,\n    mask: torch.Tensor,\n    crop: int = 0,\n    mask_thr: float = 0.5,\n    clamp_thr: float = 1e-5,\n) -> Dict[str, torch.Tensor]:\n\n    assert len(x.shape) == len(y.shape) == len(mask.shape) == 4, (\n        x.shape,\n        y.shape,\n        mask.shape,\n    ) # T, 1, H, W\n    assert x.shape[0] == y.shape[0] == mask.shape[0], (x.shape, y.shape, mask.shape)\n\n    t, c, h, w = img[:, 0].shape\n    # chuck out the border\n    if crop > 0:\n        if crop > min(y.shape[2:]) - crop:\n            raise ValueError(\"Incorrect crop size.\")\n        y = y[:, :, crop:-crop, crop:-crop]\n        x = x[:, :, crop:-crop, crop:-crop]\n        mask = mask[:, :, crop:-crop, crop:-crop]\n\n    y = y * (mask > mask_thr).float()\n    x = x * (mask > mask_thr).float()\n    y[torch.isnan(y)] = 0\n    delta_mask = mask[:-1] * mask[1:]\n    depth_mask_30 = torch.sum(y > 30, 1, keepdim=True) == 0\n    depth_mask_30 = depth_mask_30[:-1] * depth_mask_30[1:]\n    depth_mask_50 = torch.sum(y > 50, 1, keepdim=True) == 0\n    depth_mask_50 = depth_mask_50[:-1] * depth_mask_50[1:]\n    depth_mask_100 = torch.sum(y > 100, 1, keepdim=True) == 0\n    depth_mask_100 = depth_mask_100[:-1] * depth_mask_100[1:]\n\n    flow = compute_flow(img[:, 0].cuda()).cpu()\n    warped_disp = flow_warp(x[1:], flow)\n    warped_img = flow_warp(img[:, 0][1:].float(), flow)\n\n    flow_mask = torch.sum(flow > 250, 1, keepdim=True) == 0\n\n    delta_mask = delta_mask * torch.exp(-50. * torch.sqrt(\n        ((warped_img / 255. - img[:, 0][:-1].float() / 255.) 
** 2).sum(dim=1, keepdim=True))) * flow_mask * (\n                          warped_disp > 0) > 1e-2\n    opw_err = torch.abs(warped_disp - x[:-1]) * delta_mask\n    opw_err_30 = torch.abs(warped_disp - x[:-1]) * delta_mask * depth_mask_30\n    opw_err_50 = torch.abs(warped_disp - x[:-1]) * delta_mask * depth_mask_50\n    opw_err_100 = torch.abs(warped_disp - x[:-1]) * delta_mask * depth_mask_100\n\n    opw = 0\n    opw_30 = 0\n    opw_50 = 0\n    opw_100 = 0\n    for i in range(t-1):\n        if torch.sum(delta_mask[i]) > 0:\n            opw += torch.sum(opw_err[i]) / torch.sum(delta_mask[i])\n        if torch.sum(delta_mask[i] * depth_mask_30[i]) > 0:\n            opw_30 += torch.sum(opw_err_30[i]) / torch.sum(delta_mask[i] * depth_mask_30[i])\n        if torch.sum(delta_mask[i] * depth_mask_50[i]) > 0:\n            opw_50 += torch.sum(opw_err_50[i]) / torch.sum(delta_mask[i] * depth_mask_50[i])\n        if torch.sum(delta_mask[i] * depth_mask_100[i]) > 0:\n            opw_100 += torch.sum(opw_err_100[i]) / torch.sum(delta_mask[i] * depth_mask_100[i])\n    opw = opw / (t - 1)\n    opw_30 = opw_30 / (t - 1)\n    opw_50 = opw_50 / (t - 1)\n    opw_100 = opw_100 / (t - 1)\n    return opw, opw_30, opw_50, opw_100\n\n\ndef eval_RTC_sequence(\n    img: torch.Tensor,\n    x: torch.Tensor,\n    y: torch.Tensor,\n    mask: torch.Tensor,\n    crop: int = 0,\n    mask_thr: float = 0.5,\n    clamp_thr: float = 1e-5,\n) -> Dict[str, torch.Tensor]:\n\n    assert len(x.shape) == len(y.shape) == len(mask.shape) == 4, (\n        x.shape,\n        y.shape,\n        mask.shape,\n    ) # T, 1, H, W\n    assert x.shape[0] == y.shape[0] == mask.shape[0], (x.shape, y.shape, mask.shape)\n\n    t, c, h, w = img[:, 0].shape\n    # chuck out the border\n    if crop > 0:\n        if crop > min(y.shape[2:]) - crop:\n            raise ValueError(\"Incorrect crop size.\")\n        y = y[:, :, crop:-crop, crop:-crop]\n        x = x[:, :, crop:-crop, crop:-crop]\n        mask = mask[:, :, crop:-crop, crop:-crop]\n\n    y = y * (mask > mask_thr).float()\n    x = x * (mask > mask_thr).float()\n    y[torch.isnan(y)] = 0\n\n    flow = compute_flow(img[:, 0].cuda()).cpu()\n    delta_mask = mask[:-1] * mask[1:]\n\n    warped_disp = flow_warp(x[1:], flow)\n    warped_img = flow_warp(img[:, 0][1:], flow)\n\n    flow_mask = torch.sum(flow > 250, 1, keepdim=True) == 0\n    depth_mask = torch.sum(y > 30, 1, keepdim=True) == 0\n    depth_mask = depth_mask[:-1] * depth_mask[1:]\n\n    delta_mask = delta_mask * torch.exp(-50. * torch.sqrt(\n        ((warped_img / 255. - img[:, 0][:-1] / 255.) 
** 2).sum(dim=1, keepdim=True))) * flow_mask * (\n                         warped_disp > 0) > 1e-2\n    tau = 1.01\n\n    x1 = x[:-1]  / warped_disp\n    x2 = warped_disp / x[:-1]\n\n    x1[torch.isinf(x1)] = -1e10\n    x2[torch.isinf(x2)] = -1e10\n    x = torch.max(torch.cat((x1, x2), 1), 1)[0] < tau\n\n    rtc_err = x[:, None] * delta_mask\n    rtc_err_30 = x[:, None] * delta_mask * depth_mask\n    rtc = 0\n    rtc_30 = 0\n    for i in range(t-1):\n        if torch.sum(delta_mask[i]) > 0:\n            rtc += torch.sum(rtc_err[i]) / torch.sum(delta_mask[i])\n        if torch.sum(delta_mask[i] * depth_mask[i]) > 0:\n            rtc_30 += torch.sum(rtc_err_30[i]) / torch.sum(delta_mask[i] * depth_mask[i])\n    rtc = rtc / (t-1)\n    rtc_30 = rtc_30 / (t - 1)\n    return rtc, rtc_30\n\n\ndef depth2disparity_scale(left_camera, right_camera, image_size_tensor):\n    # # opencv camera matrices\n    (_, T1, K1), (_, T2, _) = [\n        opencv_from_cameras_projection(\n            f,\n            image_size_tensor,\n        )\n        for f in (left_camera, right_camera)\n    ]\n    fix_baseline = T1[0][0] - T2[0][0]\n    focal_length_px = K1[0][0][0]\n    # following this https://github.com/princeton-vl/RAFT-Stereo#converting-disparity-to-depth\n    return focal_length_px * fix_baseline\n\n\ndef depth_to_pcd(\n    depth_map,\n    img,\n    focal_length,\n    cx,\n    cy,\n    step: int = None,\n    inv_extrinsic=None,\n    mask=None,\n    filter=False,\n):\n    __, w, __ = img.shape\n    if step is None:\n        step = int(w / 100)\n    Z = depth_map[::step, ::step]\n    colors = img[::step, ::step, :]\n\n    Pixels_Y = torch.arange(Z.shape[0]).to(Z.device) * step\n    Pixels_X = torch.arange(Z.shape[1]).to(Z.device) * step\n\n    X = (Pixels_X[None, :] - cx) * Z / focal_length\n    Y = (Pixels_Y[:, None] - cy) * Z / focal_length\n\n    inds = Z > 0\n\n    if mask is not None:\n        inds = inds * (mask[::step, ::step] > 0)\n\n    X = X[inds].reshape(-1)\n    Y = Y[inds].reshape(-1)\n    Z = Z[inds].reshape(-1)\n    colors = colors[inds]\n    pcd = torch.stack([X, Y, Z]).T\n\n    if inv_extrinsic is not None:\n        pcd_ext = torch.vstack([pcd.T, torch.ones((1, pcd.shape[0])).to(Z.device)])\n        pcd = (inv_extrinsic @ pcd_ext)[:3, :].T\n\n    if filter:\n        pcd, filt_inds = filter_outliers(pcd)\n        colors = colors[filt_inds]\n    return pcd, colors\n\n\ndef filter_outliers(pcd, sigma=3):\n    mean = pcd.mean(0)\n    std = pcd.std(0)\n    inds = ((pcd - mean).abs() < sigma * std)[:, 2]\n    pcd = pcd[inds]\n    return pcd, inds\n\n\ndef eval_batch(batch_dict, predictions, scale) -> Dict[str, Union[float, torch.Tensor]]:\n    \"\"\"\n    Produce performance metrics for a single batch of perception\n    predictions.\n    Args:\n        frame_data: A PixarFrameData object containing the input to the new view\n            synthesis method.\n        preds: A PerceptionPrediction object with the predicted data.\n    Returns:\n        results: A dictionary holding evaluation metrics.\n    \"\"\"\n    results = {}\n\n    assert \"disparity\" in predictions\n    mask_now = torch.ones_like(batch_dict[\"fg_mask\"])\n\n    mask_now = mask_now * batch_dict[\"disparity_mask\"]\n\n    eval_flow_traj_output = eval_endpoint_error_sequence(\n        predictions[\"disparity\"], batch_dict[\"disparity\"], mask_now\n    )\n    for epe_name in (\"epe\", \"temp_epe\"):\n        results[PerceptionMetric(f\"disp_{epe_name}_mean\")] = eval_flow_traj_output[\n            f\"{epe_name}_mean\"\n        
]\n\n        results[PerceptionMetric(f\"disp_{epe_name}_bad_3px\")] = eval_flow_traj_output[\n            f\"{epe_name}_bad_3px\"\n        ]\n\n        results[PerceptionMetric(f\"disp_{epe_name}_bad_2px\")] = eval_flow_traj_output[\n            f\"{epe_name}_bad_2px\"\n        ]\n\n        results[PerceptionMetric(f\"disp_{epe_name}_bad_1px\")] = eval_flow_traj_output[\n            f\"{epe_name}_bad_1px\"\n        ]\n\n        results[PerceptionMetric(f\"disp_{epe_name}_bad_0.5px\")] = eval_flow_traj_output[\n            f\"{epe_name}_bad_0.5px\"\n        ]\n    if \"endpoint_error_per_pixel\" in eval_flow_traj_output:\n        results[\"disp_endpoint_error_per_pixel\"] = eval_flow_traj_output[\n            \"endpoint_error_per_pixel\"\n        ]\n\n    # disparity to depth\n    depth = scale / predictions[\"disparity\"].clamp(min=1e-10)\n\n    eval_TCC_output = eval_TCC_sequence(\n        depth, scale / batch_dict[\"disparity\"].clamp(min=1e-10), mask_now\n    )\n    results[PerceptionMetric(\"disp_TCC\")] = eval_TCC_output[None]\n\n    eval_TCM_output = eval_TCM_sequence(\n        depth, scale / batch_dict[\"disparity\"].clamp(min=1e-10), mask_now\n    )\n    results[PerceptionMetric(\"disp_TCM\")] = eval_TCM_output[None]\n\n    eval_OPW_output, eval_OPW_30_output, eval_OPW_50_output, eval_OPW_100_output = eval_OPW_sequence(\n        batch_dict[\"stereo_video\"], depth, scale / batch_dict[\"disparity\"].clamp(min=1e-10), mask_now\n    )\n    results[PerceptionMetric(\"disp_OPW\")] = eval_OPW_output[None]\n    results[PerceptionMetric(\"disp_OPW_100\")] = eval_OPW_100_output[None]\n    results[PerceptionMetric(\"disp_OPW_50\")] = eval_OPW_50_output[None]\n    if eval_OPW_30_output > 0:\n        results[PerceptionMetric(\"disp_OPW_30\")] = eval_OPW_30_output[None]\n    else:\n        results[PerceptionMetric(\"disp_OPW_30\")] = torch.tensor([0.0])\n\n    eval_RTC_output, eval_RTC_30_output = eval_RTC_sequence(\n        batch_dict[\"stereo_video\"], depth, scale / batch_dict[\"disparity\"].clamp(min=1e-10), mask_now\n    )\n    results[PerceptionMetric(\"disp_RTC\")] = eval_RTC_output[None]\n    if eval_RTC_30_output > 0:\n        results[PerceptionMetric(\"disp_RTC_30\")] = eval_RTC_30_output[None]\n    else:\n        results[PerceptionMetric(\"disp_RTC_30\")] = torch.tensor([0.0])\n\n    return (results, len(predictions[\"disparity\"]))\n"
  },
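  {
    "path": "examples/disparity_metrics_sketch.py",
    "content": "# Hypothetical sketch, not part of the original codebase:\n# eval_endpoint_error_sequence() takes (T, 1, H, W) disparity tensors plus a\n# validity mask, and depth2disparity_scale() returns s = focal_px * baseline\n# so that depth = s / disparity. Random tensors stand in for real predictions.\nimport torch\n\nfrom stereoanyvideo.evaluation.utils.eval_utils import eval_endpoint_error_sequence\n\nT, H, W = 4, 64, 96\ngt = torch.rand(T, 1, H, W) * 50 + 1.0\npred = gt + 0.5 * torch.randn(T, 1, H, W)  # perturbed copy of the ground truth\nmask = torch.ones(T, 1, H, W)              # every pixel valid\n\nres = eval_endpoint_error_sequence(pred, gt, mask)\nprint(res['epe_mean'], res['temp_epe_bad_1px'])\n\n# disparity -> depth with the scale convention of depth2disparity_scale():\nscale = 100.0  # illustrative focal_px * baseline\ndepth = scale / gt.clamp(min=1e-10)\n"
  },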
  {
    "path": "evaluation/utils/ssim.py",
    "content": "# Copyright 2020 by Gongfan Fang, Zhejiang University.\n# All rights reserved.\nimport warnings\n\nimport torch\nimport torch.nn.functional as F\n\n\ndef _fspecial_gauss_1d(size, sigma):\n    r\"\"\"Create 1-D gauss kernel\n    Args:\n        size (int): the size of gauss kernel\n        sigma (float): sigma of normal distribution\n    Returns:\n        torch.Tensor: 1D kernel (1 x 1 x size)\n    \"\"\"\n    coords = torch.arange(size, dtype=torch.float)\n    coords -= size // 2\n\n    g = torch.exp(-(coords ** 2) / (2 * sigma ** 2))\n    g /= g.sum()\n\n    return g.unsqueeze(0).unsqueeze(0)\n\n\ndef gaussian_filter(input, win):\n    r\"\"\" Blur input with 1-D kernel\n    Args:\n        input (torch.Tensor): a batch of tensors to be blurred\n        window (torch.Tensor): 1-D gauss kernel\n    Returns:\n        torch.Tensor: blurred tensors\n    \"\"\"\n    assert all([ws == 1 for ws in win.shape[1:-1]]), win.shape\n    if len(input.shape) == 4:\n        conv = F.conv2d\n    elif len(input.shape) == 5:\n        conv = F.conv3d\n    else:\n        raise NotImplementedError(input.shape)\n\n    C = input.shape[1]\n    out = input\n    for i, s in enumerate(input.shape[2:]):\n        if s >= win.shape[-1]:\n            out = conv(out, weight=win.transpose(2 + i, -1), stride=1, padding=0, groups=C)\n        else:\n            warnings.warn(\n                f\"Skipping Gaussian Smoothing at dimension 2+{i} for input: {input.shape} and win size: {win.shape[-1]}\"\n            )\n    return out\n\n\ndef _ssim(X, Y, data_range, win, size_average=True, K=(0.01, 0.03)):\n\n    r\"\"\" Calculate ssim index for X and Y\n    Args:\n        X (torch.Tensor): images\n        Y (torch.Tensor): images\n        win (torch.Tensor): 1-D gauss kernel\n        data_range (float or int, optional): value range of input images. (usually 1.0 or 255)\n        size_average (bool, optional): if size_average=True, ssim of all images will be averaged as a scalar\n    Returns:\n        torch.Tensor: ssim results.\n    \"\"\"\n    K1, K2 = K\n    # batch, channel, [depth,] height, width = X.shape\n    compensation = 1.0\n\n    C1 = (K1 * data_range) ** 2\n    C2 = (K2 * data_range) ** 2\n\n    win = win.to(X.device, dtype=X.dtype)\n\n    mu1 = gaussian_filter(X, win)\n    mu2 = gaussian_filter(Y, win)\n\n    mu1_sq = mu1.pow(2)\n    mu2_sq = mu2.pow(2)\n    mu1_mu2 = mu1 * mu2\n\n    sigma1_sq = compensation * (gaussian_filter(X * X, win) - mu1_sq)\n    sigma2_sq = compensation * (gaussian_filter(Y * Y, win) - mu2_sq)\n    sigma12 = compensation * (gaussian_filter(X * Y, win) - mu1_mu2)\n\n    cs_map = (2 * sigma12 + C2) / (sigma1_sq + sigma2_sq + C2)  # set alpha=beta=gamma=1\n    ssim_map = ((2 * mu1_mu2 + C1) / (mu1_sq + mu2_sq + C1)) * cs_map\n\n    ssim_per_channel = torch.flatten(ssim_map, 2).mean(-1)\n    cs = torch.flatten(cs_map, 2).mean(-1)\n    return ssim_per_channel, cs\n\n\ndef ssim(\n    X,\n    Y,\n    data_range=255,\n    size_average=True,\n    win_size=11,\n    win_sigma=1.5,\n    win=None,\n    K=(0.01, 0.03),\n    nonnegative_ssim=False,\n):\n    r\"\"\" interface of ssim\n    Args:\n        X (torch.Tensor): a batch of images, (N,C,H,W)\n        Y (torch.Tensor): a batch of images, (N,C,H,W)\n        data_range (float or int, optional): value range of input images. 
(usually 1.0 or 255)\n        size_average (bool, optional): if size_average=True, ssim of all images will be averaged as a scalar\n        win_size: (int, optional): the size of gauss kernel\n        win_sigma: (float, optional): sigma of normal distribution\n        win (torch.Tensor, optional): 1-D gauss kernel. if None, a new kernel will be created according to win_size and win_sigma\n        K (list or tuple, optional): scalar constants (K1, K2). Try a larger K2 constant (e.g. 0.4) if you get a negative or NaN results.\n        nonnegative_ssim (bool, optional): force the ssim response to be nonnegative with relu\n    Returns:\n        torch.Tensor: ssim results\n    \"\"\"\n    if not X.shape == Y.shape:\n        raise ValueError(f\"Input images should have the same dimensions, but got {X.shape} and {Y.shape}.\")\n\n    for d in range(len(X.shape) - 1, 1, -1):\n        X = X.squeeze(dim=d)\n        Y = Y.squeeze(dim=d)\n\n    if len(X.shape) not in (4, 5):\n        raise ValueError(f\"Input images should be 4-d or 5-d tensors, but got {X.shape}\")\n\n    if not X.type() == Y.type():\n        raise ValueError(f\"Input images should have the same dtype, but got {X.type()} and {Y.type()}.\")\n\n    if win is not None:  # set win_size\n        win_size = win.shape[-1]\n\n    if not (win_size % 2 == 1):\n        raise ValueError(\"Window size should be odd.\")\n\n    if win is None:\n        win = _fspecial_gauss_1d(win_size, win_sigma)\n        win = win.repeat([X.shape[1]] + [1] * (len(X.shape) - 1))\n\n    ssim_per_channel, cs = _ssim(X, Y, data_range=data_range, win=win, size_average=False, K=K)\n    if nonnegative_ssim:\n        ssim_per_channel = torch.relu(ssim_per_channel)\n\n    if size_average:\n        return ssim_per_channel.mean()\n    else:\n        return ssim_per_channel #.mean(1)\n\n\ndef ms_ssim(\n    X, Y, data_range=255, size_average=True, win_size=11, win_sigma=1.5, win=None, weights=None, K=(0.01, 0.03)\n):\n\n    r\"\"\" interface of ms-ssim\n    Args:\n        X (torch.Tensor): a batch of images, (N,C,[T,]H,W)\n        Y (torch.Tensor): a batch of images, (N,C,[T,]H,W)\n        data_range (float or int, optional): value range of input images. (usually 1.0 or 255)\n        size_average (bool, optional): if size_average=True, ssim of all images will be averaged as a scalar\n        win_size: (int, optional): the size of gauss kernel\n        win_sigma: (float, optional): sigma of normal distribution\n        win (torch.Tensor, optional): 1-D gauss kernel. if None, a new kernel will be created according to win_size and win_sigma\n        weights (list, optional): weights for different levels\n        K (list or tuple, optional): scalar constants (K1, K2). Try a larger K2 constant (e.g. 
0.4) if you get a negative or NaN results.\n    Returns:\n        torch.Tensor: ms-ssim results\n    \"\"\"\n    if not X.shape == Y.shape:\n        raise ValueError(f\"Input images should have the same dimensions, but got {X.shape} and {Y.shape}.\")\n\n    for d in range(len(X.shape) - 1, 1, -1):\n        X = X.squeeze(dim=d)\n        Y = Y.squeeze(dim=d)\n\n    if not X.type() == Y.type():\n        raise ValueError(f\"Input images should have the same dtype, but got {X.type()} and {Y.type()}.\")\n\n    if len(X.shape) == 4:\n        avg_pool = F.avg_pool2d\n    elif len(X.shape) == 5:\n        avg_pool = F.avg_pool3d\n    else:\n        raise ValueError(f\"Input images should be 4-d or 5-d tensors, but got {X.shape}\")\n\n    if win is not None:  # set win_size\n        win_size = win.shape[-1]\n\n    if not (win_size % 2 == 1):\n        raise ValueError(\"Window size should be odd.\")\n\n    smaller_side = min(X.shape[-2:])\n    assert smaller_side > (win_size - 1) * (\n        2 ** 4\n    ), \"Image size should be larger than %d due to the 4 downsamplings in ms-ssim\" % ((win_size - 1) * (2 ** 4))\n\n    if weights is None:\n        weights = [0.0448, 0.2856, 0.3001, 0.2363, 0.1333]\n    weights = X.new_tensor(weights)\n\n    if win is None:\n        win = _fspecial_gauss_1d(win_size, win_sigma)\n        win = win.repeat([X.shape[1]] + [1] * (len(X.shape) - 1))\n\n    levels = weights.shape[0]\n    mcs = []\n    for i in range(levels):\n        ssim_per_channel, cs = _ssim(X, Y, win=win, data_range=data_range, size_average=False, K=K)\n\n        if i < levels - 1:\n            mcs.append(torch.relu(cs))\n            padding = [s % 2 for s in X.shape[2:]]\n            X = avg_pool(X, kernel_size=2, padding=padding)\n            Y = avg_pool(Y, kernel_size=2, padding=padding)\n\n    ssim_per_channel = torch.relu(ssim_per_channel)  # (batch, channel)\n    mcs_and_ssim = torch.stack(mcs + [ssim_per_channel], dim=0)  # (level, batch, channel)\n    ms_ssim_val = torch.prod(mcs_and_ssim ** weights.view(-1, 1, 1), dim=0)\n\n    if size_average:\n        return ms_ssim_val.mean()\n    else:\n        return ms_ssim_val.mean(1)\n\n\nclass SSIM(torch.nn.Module):\n    def __init__(\n        self,\n        data_range=255,\n        size_average=True,\n        win_size=11,\n        win_sigma=1.5,\n        channel=3,\n        spatial_dims=2,\n        K=(0.01, 0.03),\n        nonnegative_ssim=False,\n    ):\n        r\"\"\" class for ssim\n        Args:\n            data_range (float or int, optional): value range of input images. (usually 1.0 or 255)\n            size_average (bool, optional): if size_average=True, ssim of all images will be averaged as a scalar\n            win_size: (int, optional): the size of gauss kernel\n            win_sigma: (float, optional): sigma of normal distribution\n            channel (int, optional): input channels (default: 3)\n            K (list or tuple, optional): scalar constants (K1, K2). Try a larger K2 constant (e.g. 
0.4) if you get a negative or NaN results.\n            nonnegative_ssim (bool, optional): force the ssim response to be nonnegative with relu.\n        \"\"\"\n\n        super(SSIM, self).__init__()\n        self.win_size = win_size\n        self.win = _fspecial_gauss_1d(win_size, win_sigma).repeat([channel, 1] + [1] * spatial_dims)\n        self.size_average = size_average\n        self.data_range = data_range\n        self.K = K\n        self.nonnegative_ssim = nonnegative_ssim\n\n    def forward(self, X, Y):\n        return ssim(\n            X,\n            Y,\n            data_range=self.data_range,\n            size_average=self.size_average,\n            win=self.win,\n            K=self.K,\n            nonnegative_ssim=self.nonnegative_ssim,\n        )\n\n\nclass MS_SSIM(torch.nn.Module):\n    def __init__(\n        self,\n        data_range=255,\n        size_average=True,\n        win_size=11,\n        win_sigma=1.5,\n        channel=3,\n        spatial_dims=2,\n        weights=None,\n        K=(0.01, 0.03),\n    ):\n        r\"\"\" class for ms-ssim\n        Args:\n            data_range (float or int, optional): value range of input images. (usually 1.0 or 255)\n            size_average (bool, optional): if size_average=True, ssim of all images will be averaged as a scalar\n            win_size: (int, optional): the size of gauss kernel\n            win_sigma: (float, optional): sigma of normal distribution\n            channel (int, optional): input channels (default: 3)\n            weights (list, optional): weights for different levels\n            K (list or tuple, optional): scalar constants (K1, K2). Try a larger K2 constant (e.g. 0.4) if you get a negative or NaN results.\n        \"\"\"\n\n        super(MS_SSIM, self).__init__()\n        self.win_size = win_size\n        self.win = _fspecial_gauss_1d(win_size, win_sigma).repeat([channel, 1] + [1] * spatial_dims)\n        self.size_average = size_average\n        self.data_range = data_range\n        self.weights = weights\n        self.K = K\n\n    def forward(self, X, Y):\n        return ms_ssim(\n            X,\n            Y,\n            data_range=self.data_range,\n            size_average=self.size_average,\n            win=self.win,\n            weights=self.weights,\n            K=self.K,\n        )\n"
  },
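  {
    "path": "examples/ssim_sketch.py",
    "content": "# Hypothetical usage sketch, not part of the original codebase: the SSIM\n# module compares two (N, C, H, W) batches within a given data range;\n# eval_TCC_sequence builds it as SSIM(1.0, nonnegative_ssim=True), i.e. for\n# inputs scaled to [0, 1].\nimport torch\n\nfrom stereoanyvideo.evaluation.utils.ssim import SSIM\n\nx = torch.rand(2, 3, 64, 64)                      # values in [0, 1]\ny = (x + 0.05 * torch.randn_like(x)).clamp(0, 1)  # noisy copy\n\nssim_fn = SSIM(data_range=1.0, nonnegative_ssim=True)\nscore = ssim_fn(x, y)  # scalar; close to 1 for near-identical inputs\nprint(float(score))\n"
  },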
  {
    "path": "evaluation/utils/utils.py",
    "content": "from collections import defaultdict\nimport configparser\nimport os\nimport math\nfrom typing import Optional, List\nimport torch\nimport cv2\nimport numpy as np\nfrom dataclasses import dataclass\nfrom tabulate import tabulate\nimport logging\n\nfrom pytorch3d.structures import Pointclouds\nfrom pytorch3d.transforms import RotateAxisAngle\nfrom pytorch3d.utils import (\n    opencv_from_cameras_projection,\n)\nfrom pytorch3d.renderer import (\n    AlphaCompositor,\n    PointsRasterizationSettings,\n    PointsRasterizer,\n    PointsRenderer,\n)\nfrom stereoanyvideo.evaluation.utils.eval_utils import depth_to_pcd\n\n\n@dataclass\nclass PerceptionPrediction:\n    \"\"\"\n    Holds the tensors that describe a result of any perception module.\n    \"\"\"\n\n    depth_map: Optional[torch.Tensor] = None\n    disparity: Optional[torch.Tensor] = None\n    image_rgb: Optional[torch.Tensor] = None\n    fg_probability: Optional[torch.Tensor] = None\n\n\ndef aggregate_eval_results(per_batch_eval_results, reduction=\"mean\"):\n\n    total_length = 0\n    aggregate_results = defaultdict(list)\n    for result in per_batch_eval_results:\n        if isinstance(result, tuple):\n            reduction = \"sum\"\n            length = result[1]\n            total_length += length\n            result = result[0]\n        for metric, val in result.items():\n            if reduction == \"sum\":\n                aggregate_results[metric].append(val * length)\n\n    if reduction == \"mean\":\n        return {k: torch.cat(v).mean().item() for k, v in aggregate_results.items()}\n    elif reduction == \"sum\":\n        return {\n            k: torch.cat(v).sum().item() / float(total_length)\n            for k, v in aggregate_results.items()\n        }\n\n\ndef aggregate_and_print_results(\n    per_batch_eval_results: List[dict],\n):\n    print(\"\")\n    result = aggregate_eval_results(\n        per_batch_eval_results,\n    )\n    pretty_print_perception_metrics(result)\n    result = {str(k): v for k, v in result.items()}\n\n    print(\"\")\n    return result\n\n\ndef pretty_print_perception_metrics(results):\n\n    metrics = sorted(list(results.keys()), key=lambda x: x.metric)\n\n    print(\"===== Perception results =====\")\n    print(\n        tabulate(\n            [[metric, results[metric]] for metric in metrics],\n        )\n    )\n    logging.info(\"===== Perception results =====\")\n    logging.info(tabulate(\n            [[metric, results[metric]] for metric in metrics],\n        ))\n\n\ndef read_calibration(calibration_file, resolution_string):\n    # ported from https://github.com/stereolabs/zed-open-capture/\n    # blob/dfa0aee51ccd2297782230a05ca59e697df496b2/examples/include/calibration.hpp#L4172\n\n    zed_resolutions = {\n        \"2K\": (1242, 2208),\n        \"FHD\": (1080, 1920),\n        \"HD\": (720, 1280),\n        # \"qHD\": (540, 960),\n        \"VGA\": (376, 672),\n    }\n    assert resolution_string in zed_resolutions.keys()\n    image_height, image_width = zed_resolutions[resolution_string]\n\n    # Open camera configuration file\n    assert os.path.isfile(calibration_file)\n    calib = configparser.ConfigParser()\n    calib.read(calibration_file)\n\n    # Get translations\n    T = np.zeros((3, 1))\n    T[0] = float(calib[\"STEREO\"][\"baseline\"])\n    T[1] = float(calib[\"STEREO\"][\"ty\"])\n    T[2] = float(calib[\"STEREO\"][\"tz\"])\n\n    baseline = T[0]\n\n    # Get left parameters\n    left_cam_cx = float(calib[f\"LEFT_CAM_{resolution_string}\"][\"cx\"])\n    left_cam_cy 
= float(calib[f\"LEFT_CAM_{resolution_string}\"][\"cy\"])\n    left_cam_fx = float(calib[f\"LEFT_CAM_{resolution_string}\"][\"fx\"])\n    left_cam_fy = float(calib[f\"LEFT_CAM_{resolution_string}\"][\"fy\"])\n    left_cam_k1 = float(calib[f\"LEFT_CAM_{resolution_string}\"][\"k1\"])\n    left_cam_k2 = float(calib[f\"LEFT_CAM_{resolution_string}\"][\"k2\"])\n    left_cam_p1 = float(calib[f\"LEFT_CAM_{resolution_string}\"][\"p1\"])\n    left_cam_p2 = float(calib[f\"LEFT_CAM_{resolution_string}\"][\"p2\"])\n    left_cam_k3 = float(calib[f\"LEFT_CAM_{resolution_string}\"][\"k3\"])\n\n    # Get right parameters\n    right_cam_cx = float(calib[f\"RIGHT_CAM_{resolution_string}\"][\"cx\"])\n    right_cam_cy = float(calib[f\"RIGHT_CAM_{resolution_string}\"][\"cy\"])\n    right_cam_fx = float(calib[f\"RIGHT_CAM_{resolution_string}\"][\"fx\"])\n    right_cam_fy = float(calib[f\"RIGHT_CAM_{resolution_string}\"][\"fy\"])\n    right_cam_k1 = float(calib[f\"RIGHT_CAM_{resolution_string}\"][\"k1\"])\n    right_cam_k2 = float(calib[f\"RIGHT_CAM_{resolution_string}\"][\"k2\"])\n    right_cam_p1 = float(calib[f\"RIGHT_CAM_{resolution_string}\"][\"p1\"])\n    right_cam_p2 = float(calib[f\"RIGHT_CAM_{resolution_string}\"][\"p2\"])\n    right_cam_k3 = float(calib[f\"RIGHT_CAM_{resolution_string}\"][\"k3\"])\n\n    # Get rotations\n    R_zed = np.zeros(3)\n    R_zed[0] = float(calib[\"STEREO\"][f\"rx_{resolution_string.lower()}\"])\n    R_zed[1] = float(calib[\"STEREO\"][f\"cv_{resolution_string.lower()}\"])\n    R_zed[2] = float(calib[\"STEREO\"][f\"rz_{resolution_string.lower()}\"])\n\n    R = cv2.Rodrigues(R_zed)[0]\n\n    # Left\n    cameraMatrix_left = np.array(\n        [[left_cam_fx, 0, left_cam_cx], [0, left_cam_fy, left_cam_cy], [0, 0, 1]]\n    )\n    distCoeffs_left = np.array(\n        [left_cam_k1, left_cam_k2, left_cam_p1, left_cam_p2, left_cam_k3]\n    )\n\n    # Right\n    cameraMatrix_right = np.array(\n        [\n            [right_cam_fx, 0, right_cam_cx],\n            [0, right_cam_fy, right_cam_cy],\n            [0, 0, 1],\n        ]\n    )\n    distCoeffs_right = np.array(\n        [right_cam_k1, right_cam_k2, right_cam_p1, right_cam_p2, right_cam_k3]\n    )\n\n    # Stereo\n    R1, R2, P1, P2, Q = cv2.stereoRectify(\n        cameraMatrix1=cameraMatrix_left,\n        distCoeffs1=distCoeffs_left,\n        cameraMatrix2=cameraMatrix_right,\n        distCoeffs2=distCoeffs_right,\n        imageSize=(image_width, image_height),\n        R=R,\n        T=T,\n        flags=cv2.CALIB_ZERO_DISPARITY,\n        newImageSize=(image_width, image_height),\n        alpha=0,\n    )[:5]\n\n    # Precompute maps for cv::remap()\n    map_left_x, map_left_y = cv2.initUndistortRectifyMap(\n        cameraMatrix_left,\n        distCoeffs_left,\n        R1,\n        P1,\n        (image_width, image_height),\n        cv2.CV_32FC1,\n    )\n    map_right_x, map_right_y = cv2.initUndistortRectifyMap(\n        cameraMatrix_right,\n        distCoeffs_right,\n        R2,\n        P2,\n        (image_width, image_height),\n        cv2.CV_32FC1,\n    )\n\n    zed_calib = {\n        \"map_left_x\": map_left_x,\n        \"map_left_y\": map_left_y,\n        \"map_right_x\": map_right_x,\n        \"map_right_y\": map_right_y,\n        \"pose_left\": P1,\n        \"pose_right\": P2,\n        \"baseline\": baseline,\n        \"image_width\": image_width,\n        \"image_height\": image_height,\n    }\n\n    return zed_calib\n\n\ndef filter_depth_discontinuities(pcd, depth_map, threshold=5):\n    \"\"\"\n    Removes points that 
belong to high-depth discontinuity regions.\n\n    Args:\n        pcd (torch.Tensor): Nx3 point cloud tensor.\n        depth_map (torch.Tensor): HxW depth map.\n        threshold (float): Depth change threshold.\n\n    Returns:\n        torch.Tensor: Filtered point cloud.\n    \"\"\"\n    # Compute depth differences in x and y directions\n    depth_diff_x = torch.abs(depth_map[:, 1:] - depth_map[:, :-1])  # Shape (H, W-1)\n    depth_diff_y = torch.abs(depth_map[1:, :] - depth_map[:-1, :])  # Shape (H-1, W)\n\n    # Initialize mask with all True (valid points)\n    mask = torch.ones_like(depth_map, dtype=torch.bool)  # Shape (H, W)\n\n    # Apply filtering: set False where depth difference is too large\n    mask[:, :-1] &= depth_diff_x <= threshold  # X-direction filtering\n    mask[:-1, :] &= depth_diff_y <= threshold  # Y-direction filtering\n\n    # Flatten mask to match point cloud size (assumes row-major, unfiltered points)\n    mask_flat = mask.flatten()[: pcd.shape[0]]\n\n    return pcd[mask_flat]  # Return only valid points\n\n\ndef visualize_batch(\n    batch_dict: dict,\n    preds: PerceptionPrediction,\n    output_dir: str,\n    ref_frame: int = 0,\n    only_foreground=False,\n    step=0,\n    sequence_name=None,\n    writer=None,\n    render_bin_size=None\n):\n    os.makedirs(output_dir, exist_ok=True)\n\n    outputs = {}\n\n    if preds.depth_map is not None:\n        device = preds.depth_map.device\n\n        pcd_global_seq = []\n        H, W = batch_dict[\"stereo_video\"].shape[3:]\n\n        for i in range(len(batch_dict[\"stereo_video\"])):\n            if hasattr(preds, 'perspective_cameras'):\n                R, T, K = opencv_from_cameras_projection(\n                    preds.perspective_cameras[i],\n                    torch.tensor([H, W])[None].to(device),\n                )  # 1x3x3, 1x3, 1x3x3\n            else:\n                raise KeyError(\"R, T, K not found: preds has no perspective_cameras\")\n            extrinsic_3x4_0 = torch.cat([R[0], T[0, :, None]], dim=1)\n            extr_matrix = torch.cat(\n                [\n                    extrinsic_3x4_0,\n                    torch.Tensor([[0, 0, 0, 1]]).to(extrinsic_3x4_0.device),\n                ],\n                dim=0,\n            )\n            inv_extr_matrix = extr_matrix.inverse().to(device)\n            pcd, colors = depth_to_pcd(\n                preds.depth_map[i, 0],\n                batch_dict[\"stereo_video\"][i][0].permute(1, 2, 0),\n                K[0][0][0],\n                K[0][0][2],\n                K[0][1][2],\n                step=1,\n                inv_extrinsic=inv_extr_matrix,\n                mask=batch_dict[\"fg_mask\"][i, 0] if only_foreground else None,\n                filter=False,\n            )\n            R, T = inv_extr_matrix[None, :3, :3], inv_extr_matrix[None, :3, 3]\n            pcd_global_seq.append((pcd, colors, (R, T, preds.perspective_cameras[i])))\n\n        raster_settings = PointsRasterizationSettings(\n            image_size=[H, W],\n            radius=0.003,\n            points_per_pixel=10,\n        )\n        R, T, cam_ = pcd_global_seq[ref_frame][2]\n        median_depth = preds.depth_map.median()\n        cam_.cuda()\n\n        for mode in [\"angle_15\", \"angle_-15\", \"angle_0\", \"changing_angle\"]:\n            res = []\n            for t, (pcd, color, __) in enumerate(pcd_global_seq):\n\n                if mode == \"changing_angle\":\n                    angle = math.cos((math.pi) * (t / 60)) * 15\n                elif mode == \"angle_15\":\n                    angle = 15\n                elif mode == 
\"angle_-15\":\n                    angle = -15\n                elif mode == \"angle_0\":\n                    angle = 0\n\n                delta_x = median_depth * math.sin(math.radians(angle))\n                delta_z = median_depth * (1 - math.cos(math.radians(angle)))\n\n                cam = cam_.clone()\n                cam.R = torch.bmm(\n                    cam.R,\n                    RotateAxisAngle(angle=angle, axis=\"Y\", device=device).get_matrix()[\n                        :, :3, :3\n                    ],\n                )\n                cam.T[0, 0] = cam.T[0, 0] - delta_x\n                cam.T[0, 2] = cam.T[0, 2] - delta_z + median_depth / 2.0\n\n                rasterizer = PointsRasterizer(\n                    cameras=cam, raster_settings=raster_settings\n                )\n                renderer = PointsRenderer(\n                    rasterizer=rasterizer,\n                    compositor=AlphaCompositor(background_color=(1, 1, 1)),\n                )\n                pcd_copy = pcd.clone()\n                point_cloud = Pointclouds(points=[pcd_copy], features=[color / 255.0])\n\n                images = renderer(point_cloud)\n                res.append(images[0, ..., :3].cpu())\n            res = torch.stack(res)\n            video = (res * 255).numpy().astype(np.uint8)\n            save_name = f\"{sequence_name}_reconstruction_{step}_mode_{mode}_\"\n\n            if writer is None:\n                outputs[mode] = video\n            if only_foreground:\n                save_name += \"fg_only\"\n            else:\n                save_name += \"full_scene\"\n            video_out = cv2.VideoWriter(\n                os.path.join(\n                    output_dir,\n                    f\"{save_name}.mp4\",\n                ),\n                cv2.VideoWriter_fourcc(*\"mp4v\"),\n                fps=30,\n                frameSize=(res.shape[2], res.shape[1]),\n                isColor=True,\n            )\n            filename = os.path.join(output_dir, sequence_name + '_img_')\n            if not os.path.isdir(filename + str(mode)):\n                os.mkdir(filename + str(mode))\n            for i in range(len(video)):\n                filename_temp = filename + str(mode) + '/' + str(i).zfill(3) + '.png'\n                cv2.imwrite(filename_temp, cv2.cvtColor(video[i], cv2.COLOR_BGR2RGB))\n                video_out.write(cv2.cvtColor(video[i], cv2.COLOR_BGR2RGB))\n            video_out.release()\n\n            if writer is not None:\n                writer.add_video(\n                    f\"{sequence_name}_reconstruction_mode_{mode}\",\n                    (res * 255).permute(0, 3, 1, 2).to(torch.uint8)[None],\n                    global_step=step,\n                    fps=30,\n                )\n\n    return outputs\n"
  },
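A note on the reduction logic in `aggregate_eval_results` above: when any per-batch result arrives as a `(metrics_dict, num_samples)` tuple, the function switches to a length-weighted mean. A minimal sketch of that arithmetic with hypothetical metric values (the `"epe"` key and the numbers are illustrative, not from the repo):

```python
import torch

# Hypothetical per-batch results: (metrics dict, number of samples) tuples,
# mirroring the tuple branch of aggregate_eval_results.
per_batch = [
    ({"epe": torch.tensor([0.5])}, 2),  # batch of 2 samples, mean EPE 0.5
    ({"epe": torch.tensor([1.0])}, 6),  # batch of 6 samples, mean EPE 1.0
]

# Each value is scaled by its batch length, concatenated, then normalised by
# the total sample count: (0.5 * 2 + 1.0 * 6) / 8 = 0.875.
total = sum(n for _, n in per_batch)
weighted = torch.cat([d["epe"] * n for d, n in per_batch])
print(weighted.sum().item() / total)  # 0.875
```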
  {
    "path": "models/Video-Depth-Anything/app.py",
    "content": "# Copyright (2025) Bytedance Ltd. and/or its affiliates \n\n# Licensed under the Apache License, Version 2.0 (the \"License\"); \n# you may not use this file except in compliance with the License. \n# You may obtain a copy of the License at \n\n#     http://www.apache.org/licenses/LICENSE-2.0 \n\n# Unless required by applicable law or agreed to in writing, software \n# distributed under the License is distributed on an \"AS IS\" BASIS, \n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. \n# See the License for the specific language governing permissions and \n# limitations under the License. \nimport gradio as gr\n\n\nimport numpy as np\nimport os\nimport torch\n\nfrom video_depth_anything.video_depth import VideoDepthAnything\nfrom utils.dc_utils import read_video_frames, vis_sequence_depth, save_video\n\nexamples = [\n    ['assets/example_videos/davis_rollercoaster.mp4'],\n]\n\nmodel_configs = {\n    'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},\n    'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},\n}\n\nencoder='vitl'\n\nvideo_depth_anything = VideoDepthAnything(**model_configs[encoder])\nvideo_depth_anything.load_state_dict(torch.load(f'./checkpoints/video_depth_anything_{encoder}.pth', map_location='cpu'), strict=True)\nvideo_depth_anything = video_depth_anything.to('cuda').eval()\n\n\ndef infer_video_depth(\n    input_video: str,\n    max_len: int = -1,\n    target_fps: int = -1,\n    max_res: int = 1280,\n    output_dir: str = './outputs',\n    input_size: int = 518,\n):\n    frames, target_fps = read_video_frames(input_video, max_len, target_fps, max_res)\n    depth_list, fps = video_depth_anything.infer_video_depth(frames, target_fps, input_size=input_size, device='cuda')\n    depth_list = np.stack(depth_list, axis=0)\n    vis = vis_sequence_depth(depth_list)\n    video_name = os.path.basename(input_video)\n    if not os.path.exists(output_dir):\n        os.makedirs(output_dir)\n\n    processed_video_path = os.path.join(output_dir, os.path.splitext(video_name)[0]+'_src.mp4')\n    depth_vis_path = os.path.join(output_dir, os.path.splitext(video_name)[0]+'_vis.mp4')\n    save_video(frames, processed_video_path, fps=fps)\n    save_video(vis, depth_vis_path, fps=fps)\n\n    return [processed_video_path, depth_vis_path]\n\n\ndef construct_demo():\n    with gr.Blocks(analytics_enabled=False) as demo:\n        gr.Markdown(\n            f\"\"\"\n            blablabla\n            \"\"\"\n        )\n\n        with gr.Row(equal_height=True):\n            with gr.Column(scale=1):\n                input_video = gr.Video(label=\"Input Video\")\n\n            # with gr.Tab(label=\"Output\"):\n            with gr.Column(scale=2):\n                with gr.Row(equal_height=True):\n                    processed_video = gr.Video(\n                        label=\"Preprocessed video\",\n                        interactive=False,\n                        autoplay=True,\n                        loop=True,\n                        show_share_button=True,\n                        scale=5,\n                    )\n                    depth_vis_video = gr.Video(\n                        label=\"Generated Depth Video\",\n                        interactive=False,\n                        autoplay=True,\n                        loop=True,\n                        show_share_button=True,\n                        scale=5,\n                    )\n\n        with gr.Row(equal_height=True):\n            
with gr.Column(scale=1):\n                with gr.Row(equal_height=False):\n                    with gr.Accordion(\"Advanced Settings\", open=False):\n                        max_len = gr.Slider(\n                            label=\"max process length\",\n                            minimum=-1,\n                            maximum=1000,\n                            value=-1,\n                            step=1,\n                        )\n                        target_fps = gr.Slider(\n                            label=\"target FPS\",\n                            minimum=-1,\n                            maximum=30,\n                            value=15,\n                            step=1,\n                        )\n                        max_res = gr.Slider(\n                            label=\"max side resolution\",\n                            minimum=480,\n                            maximum=1920,\n                            value=1280,\n                            step=1,\n                        )\n                    generate_btn = gr.Button(\"Generate\")\n            with gr.Column(scale=2):\n                pass\n\n        gr.Examples(\n            examples=examples,\n            inputs=[\n                input_video,\n                max_len,\n                target_fps,\n                max_res\n            ],\n            outputs=[processed_video, depth_vis_video],\n            fn=infer_video_depth,\n            cache_examples=\"lazy\",\n        )\n\n        generate_btn.click(\n            fn=infer_video_depth,\n            inputs=[\n                input_video,\n                max_len,\n                target_fps,\n                max_res\n            ],\n            outputs=[processed_video, depth_vis_video],\n        )\n\n    return demo\n\nif __name__ == \"__main__\":\n    demo = construct_demo()\n    demo.queue()\n    demo.launch(server_name=\"0.0.0.0\")"
  },
  {
    "path": "models/Video-Depth-Anything/get_weights.sh",
    "content": "#!/bin/bash\n\nmkdir checkpoints\ncd checkpoints\nwget https://huggingface.co/depth-anything/Video-Depth-Anything-Small/resolve/main/video_depth_anything_vits.pth\nwget https://huggingface.co/depth-anything/Video-Depth-Anything-Large/resolve/main/video_depth_anything_vitl.pth"
  },
  {
    "path": "models/Video-Depth-Anything/run.py",
    "content": "# Copyright (2025) Bytedance Ltd. and/or its affiliates \n\n# Licensed under the Apache License, Version 2.0 (the \"License\"); \n# you may not use this file except in compliance with the License. \n# You may obtain a copy of the License at \n\n#     http://www.apache.org/licenses/LICENSE-2.0 \n\n# Unless required by applicable law or agreed to in writing, software \n# distributed under the License is distributed on an \"AS IS\" BASIS, \n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. \n# See the License for the specific language governing permissions and \n# limitations under the License. \nimport argparse\nimport numpy as np\nimport os\nimport torch\n\nfrom video_depth_anything.video_depth import VideoDepthAnything\nfrom utils.dc_utils import read_video_frames, vis_sequence_depth, save_video\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser(description='Video Depth Anything')\n    parser.add_argument('--input_video', type=str, default='./assets/example_videos/davis_rollercoaster.mp4')\n    parser.add_argument('--output_dir', type=str, default='./outputs')\n    parser.add_argument('--input_size', type=int, default=518)\n    parser.add_argument('--max_res', type=int, default=1280)\n    parser.add_argument('--encoder', type=str, default='vitl', choices=['vits', 'vitl'])\n    parser.add_argument('--max_len', type=int, default=-1, help='maximum length of the input video, -1 means no limit')\n    parser.add_argument('--target_fps', type=int, default=-1, help='target fps of the input video, -1 means the original fps')\n\n    args = parser.parse_args()\n\n    DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'\n\n    model_configs = {\n        'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},\n        'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},\n    }\n\n    video_depth_anything = VideoDepthAnything(**model_configs[args.encoder])\n    video_depth_anything.load_state_dict(torch.load(f'./checkpoints/video_depth_anything_{args.encoder}.pth', map_location='cpu'), strict=True)\n    video_depth_anything = video_depth_anything.to(DEVICE).eval()\n\n    frames, target_fps = read_video_frames(args.input_video, args.max_len, args.target_fps, args.max_res)\n    depth_list, fps = video_depth_anything.infer_video_depth(frames, target_fps, input_size=args.input_size, device=DEVICE)\n    depth_list = np.stack(depth_list, axis=0)\n    vis = vis_sequence_depth(depth_list)\n    video_name = os.path.basename(args.input_video)\n    if not os.path.exists(args.output_dir):\n        os.makedirs(args.output_dir)\n\n    processed_video_path = os.path.join(args.output_dir, os.path.splitext(video_name)[0]+'_src.mp4')\n    depth_vis_path = os.path.join(args.output_dir, os.path.splitext(video_name)[0]+'_vis.mp4')\n    save_video(frames, processed_video_path, fps=fps)\n    save_video(vis, depth_vis_path, fps=fps)\n\n    \n\n\n"
  },
  {
    "path": "models/Video-Depth-Anything/utils/dc_utils.py",
    "content": "# This file is originally from DepthCrafter/depthcrafter/utils.py at main · Tencent/DepthCrafter\n# SPDX-License-Identifier: MIT License license\n#\n# This file may have been modified by ByteDance Ltd. and/or its affiliates on [date of modification]\n# Original file is released under [ MIT License license], with the full license text available at [https://github.com/Tencent/DepthCrafter?tab=License-1-ov-file].\nfrom typing import Union, List\nimport tempfile\nimport numpy as np\nimport PIL.Image\nimport matplotlib.cm as cm\nimport mediapy\nimport torch\ntry:\n    from decord import VideoReader, cpu\n    DECORD_AVAILABLE = True\nexcept:\n    import cv2\n    DECORD_AVAILABLE = False\n\n\ndef read_video_frames(video_path, process_length, target_fps=-1, max_res=-1, dataset=\"open\"):\n    if DECORD_AVAILABLE:\n        vid = VideoReader(video_path, ctx=cpu(0))\n        original_height, original_width = vid.get_batch([0]).shape[1:3]\n        height = original_height\n        width = original_width\n        if max_res > 0 and max(height, width) > max_res:\n            scale = max_res / max(original_height, original_width)\n            height = round(original_height * scale)\n            width = round(original_width * scale)\n\n        vid = VideoReader(video_path, ctx=cpu(0), width=width, height=height)\n\n        fps = vid.get_avg_fps() if target_fps == -1 else target_fps\n        stride = round(vid.get_avg_fps() / fps)\n        stride = max(stride, 1)\n        frames_idx = list(range(0, len(vid), stride))\n        if process_length != -1 and process_length < len(frames_idx):\n            frames_idx = frames_idx[:process_length]\n        frames = vid.get_batch(frames_idx).asnumpy()\n    else:\n        cap = cv2.VideoCapture(video_path)\n        original_fps = cap.get(cv2.CAP_PROP_FPS)\n        original_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))\n        original_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))\n\n        if max_res > 0 and max(original_height, original_width) > max_res:\n            scale = max_res / max(original_height, original_width)\n            height = round(original_height * scale)\n            width = round(original_width * scale)\n\n        fps = original_fps if target_fps < 0 else target_fps\n\n        stride = max(round(original_fps / fps), 1)\n\n        frames = []\n        frame_count = 0\n        while cap.isOpened():\n            ret, frame = cap.read()\n            if not ret or (process_length > 0 and frame_count >= process_length):\n                break\n            if frame_count % stride == 0:\n                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)  # Convert BGR to RGB\n                if max_res > 0 and max(original_height, original_width) > max_res:\n                    frame = cv2.resize(frame, (width, height))  # Resize frame\n                frames.append(frame)\n            frame_count += 1\n        cap.release()\n        frames = np.stack(frames, axis=0)\n\n    return frames, fps\n\n\ndef save_video(\n    video_frames: Union[List[np.ndarray], List[PIL.Image.Image]],\n    output_video_path: str = None,\n    fps: int = 10,\n    crf: int = 18,\n) -> str:\n    if output_video_path is None:\n        output_video_path = tempfile.NamedTemporaryFile(suffix=\".mp4\").name\n\n    if isinstance(video_frames[0], np.ndarray):\n        video_frames = [frame.astype(np.uint8) for frame in video_frames]\n\n    elif isinstance(video_frames[0], PIL.Image.Image):\n        video_frames = [np.array(frame) for frame in video_frames]\n    
mediapy.write_video(output_video_path, video_frames, fps=fps, crf=crf)\n    return output_video_path\n\n\nclass ColorMapper:\n    # a color mapper to map depth values to a certain colormap\n    def __init__(self, colormap: str = \"inferno\"):\n        self.colormap = torch.tensor(cm.get_cmap(colormap).colors)\n\n    def apply(self, image: torch.Tensor, v_min=None, v_max=None):\n        # assert len(image.shape) == 2\n        if v_min is None:\n            v_min = image.min()\n        if v_max is None:\n            v_max = image.max()\n        image = (image - v_min) / (v_max - v_min)\n        image = (image * 255).long()\n        image = self.colormap[image] * 255\n        return image\n\n\ndef vis_sequence_depth(depths: np.ndarray, v_min=None, v_max=None):\n    visualizer = ColorMapper()\n    if v_min is None:\n        v_min = depths.min()\n    if v_max is None:\n        v_max = depths.max()\n    res = visualizer.apply(torch.tensor(depths), v_min=v_min, v_max=v_max).numpy()\n    return res\n"
  },
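For orientation, a hedged usage sketch of the two helpers above. It assumes the script runs from `models/Video-Depth-Anything/` (so `utils.dc_utils` is importable) and that `mediapy` plus either `decord` or OpenCV are installed; `"input.mp4"` is a placeholder path.

```python
from utils.dc_utils import read_video_frames, save_video

# Decode at most 64 frames, long side capped at 640 px, keeping the source
# frame rate (target_fps=-1); frames come back as uint8 RGB, (N, H, W, 3).
frames, fps = read_video_frames("input.mp4", process_length=64, target_fps=-1, max_res=640)
print(frames.shape, fps)

# Re-encode the decoded frames at the detected fps.
out_path = save_video(list(frames), "roundtrip.mp4", fps=int(round(fps)))
print(out_path)
```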
  {
    "path": "models/Video-Depth-Anything/utils/util.py",
    "content": "# Copyright (2025) Bytedance Ltd. and/or its affiliates \n\n# Licensed under the Apache License, Version 2.0 (the \"License\"); \n# you may not use this file except in compliance with the License. \n# You may obtain a copy of the License at \n\n#     http://www.apache.org/licenses/LICENSE-2.0 \n\n# Unless required by applicable law or agreed to in writing, software \n# distributed under the License is distributed on an \"AS IS\" BASIS, \n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. \n# See the License for the specific language governing permissions and \n# limitations under the License. \nimport numpy as np\n\ndef compute_scale_and_shift(prediction, target, mask, scale_only=False):\n    if scale_only:\n        return compute_scale(prediction, target, mask), 0\n    else:\n        return compute_scale_and_shift_full(prediction, target, mask)\n\n\ndef compute_scale(prediction, target, mask):\n    # system matrix: A = [[a_00, a_01], [a_10, a_11]]\n    prediction = prediction.astype(np.float32)\n    target = target.astype(np.float32)\n    mask = mask.astype(np.float32)\n\n    a_00 = np.sum(mask * prediction * prediction)\n    a_01 = np.sum(mask * prediction)\n    a_11 = np.sum(mask)\n\n    # right hand side: b = [b_0, b_1]\n    b_0 = np.sum(mask * prediction * target)\n\n    x_0 = b_0 / (a_00 + 1e-6)\n\n    return x_0\n\ndef compute_scale_and_shift_full(prediction, target, mask):\n    # system matrix: A = [[a_00, a_01], [a_10, a_11]]\n    prediction = prediction.astype(np.float32)\n    target = target.astype(np.float32)\n    mask = mask.astype(np.float32)\n\n    a_00 = np.sum(mask * prediction * prediction)\n    a_01 = np.sum(mask * prediction)\n    a_11 = np.sum(mask)\n\n    b_0 = np.sum(mask * prediction * target)\n    b_1 = np.sum(mask * target)\n\n    x_0 = 1\n    x_1 = 0\n\n    det = a_00 * a_11 - a_01 * a_01\n\n    if det != 0:\n        x_0 = (a_11 * b_0 - a_01 * b_1) / det\n        x_1 = (-a_01 * b_0 + a_00 * b_1) / det\n\n    return x_0, x_1\n\n\ndef get_interpolate_frames(frame_list_pre, frame_list_post):\n    assert len(frame_list_pre) == len(frame_list_post)\n    min_w = 0.0\n    max_w = 1.0\n    step = (max_w - min_w) / (len(frame_list_pre)-1)\n    post_w_list = [min_w] + [i * step for i in range(1,len(frame_list_pre)-1)] + [max_w]\n    interpolated_frames = []\n    for i in range(len(frame_list_pre)):\n        interpolated_frames.append(frame_list_pre[i] * (1-post_w_list[i]) + frame_list_post[i] * post_w_list[i])\n    return interpolated_frames"
  },
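A quick synthetic check of the closed-form alignment above: if the target is an affine function of the prediction, `compute_scale_and_shift` should recover the coefficients to float precision. The import assumes the repo layout shown (run from `models/Video-Depth-Anything/`); the numbers are made up for the test.

```python
import numpy as np
from utils.util import compute_scale_and_shift

# target = 2 * prediction + 3 on all valid pixels, so solving the 2x2 normal
# equations should return scale ~2 and shift ~3.
rng = np.random.default_rng(0)
prediction = rng.random((4, 5), dtype=np.float32)
target = 2.0 * prediction + 3.0
mask = np.ones_like(prediction)

scale, shift = compute_scale_and_shift(prediction, target, mask)
print(round(float(scale), 3), round(float(shift), 3))  # ~2.0 ~3.0

# scale_only=True instead solves s = sum(m*p*t) / sum(m*p*p), with shift fixed at 0.
scale_only, zero_shift = compute_scale_and_shift(prediction, target, mask, scale_only=True)
print(round(float(scale_only), 3), zero_shift)
```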
  {
    "path": "models/Video-Depth-Anything/video_depth_anything/dinov2.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# This source code is licensed under the Apache License, Version 2.0\n# found in the LICENSE file in the root directory of this source tree.\n\n# References:\n#   https://github.com/facebookresearch/dino/blob/main/vision_transformer.py\n#   https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py\n\nfrom functools import partial\nimport math\nimport logging\nfrom typing import Sequence, Tuple, Union, Callable\n\nimport torch\nimport torch.nn as nn\nimport torch.utils.checkpoint\nfrom torch.nn.init import trunc_normal_\n\nfrom .dinov2_layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block\n\n\nlogger = logging.getLogger(\"dinov2\")\n\n\ndef named_apply(fn: Callable, module: nn.Module, name=\"\", depth_first=True, include_root=False) -> nn.Module:\n    if not depth_first and include_root:\n        fn(module=module, name=name)\n    for child_name, child_module in module.named_children():\n        child_name = \".\".join((name, child_name)) if name else child_name\n        named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)\n    if depth_first and include_root:\n        fn(module=module, name=name)\n    return module\n\n\nclass BlockChunk(nn.ModuleList):\n    def forward(self, x):\n        for b in self:\n            x = b(x)\n        return x\n\n\nclass DinoVisionTransformer(nn.Module):\n    def __init__(\n        self,\n        img_size=224,\n        patch_size=16,\n        in_chans=3,\n        embed_dim=768,\n        depth=12,\n        num_heads=12,\n        mlp_ratio=4.0,\n        qkv_bias=True,\n        ffn_bias=True,\n        proj_bias=True,\n        drop_path_rate=0.0,\n        drop_path_uniform=False,\n        init_values=None,  # for layerscale: None or 0 => no layerscale\n        embed_layer=PatchEmbed,\n        act_layer=nn.GELU,\n        block_fn=Block,\n        ffn_layer=\"mlp\",\n        block_chunks=1,\n        num_register_tokens=0,\n        interpolate_antialias=False,\n        interpolate_offset=0.1,\n    ):\n        \"\"\"\n        Args:\n            img_size (int, tuple): input image size\n            patch_size (int, tuple): patch size\n            in_chans (int): number of input channels\n            embed_dim (int): embedding dimension\n            depth (int): depth of transformer\n            num_heads (int): number of attention heads\n            mlp_ratio (int): ratio of mlp hidden dim to embedding dim\n            qkv_bias (bool): enable bias for qkv if True\n            proj_bias (bool): enable bias for proj in attn if True\n            ffn_bias (bool): enable bias for ffn if True\n            drop_path_rate (float): stochastic depth rate\n            drop_path_uniform (bool): apply uniform drop rate across blocks\n            weight_init (str): weight init scheme\n            init_values (float): layer-scale init values\n            embed_layer (nn.Module): patch embedding layer\n            act_layer (nn.Module): MLP activation layer\n            block_fn (nn.Module): transformer block class\n            ffn_layer (str): \"mlp\", \"swiglu\", \"swiglufused\" or \"identity\"\n            block_chunks: (int) split block sequence into block_chunks units for FSDP wrap\n            num_register_tokens: (int) number of extra cls tokens (so-called \"registers\")\n            interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings\n      
      interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings\n        \"\"\"\n        super().__init__()\n        norm_layer = partial(nn.LayerNorm, eps=1e-6)\n\n        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models\n        self.num_tokens = 1\n        self.n_blocks = depth\n        self.num_heads = num_heads\n        self.patch_size = patch_size\n        self.num_register_tokens = num_register_tokens\n        self.interpolate_antialias = interpolate_antialias\n        self.interpolate_offset = interpolate_offset\n\n        self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)\n        num_patches = self.patch_embed.num_patches\n\n        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))\n        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))\n        assert num_register_tokens >= 0\n        self.register_tokens = (\n            nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None\n        )\n\n        if drop_path_uniform is True:\n            dpr = [drop_path_rate] * depth\n        else:\n            dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule\n\n        if ffn_layer == \"mlp\":\n            logger.info(\"using MLP layer as FFN\")\n            ffn_layer = Mlp\n        elif ffn_layer == \"swiglufused\" or ffn_layer == \"swiglu\":\n            logger.info(\"using SwiGLU layer as FFN\")\n            ffn_layer = SwiGLUFFNFused\n        elif ffn_layer == \"identity\":\n            logger.info(\"using Identity layer as FFN\")\n\n            def f(*args, **kwargs):\n                return nn.Identity()\n\n            ffn_layer = f\n        else:\n            raise NotImplementedError\n\n        blocks_list = [\n            block_fn(\n                dim=embed_dim,\n                num_heads=num_heads,\n                mlp_ratio=mlp_ratio,\n                qkv_bias=qkv_bias,\n                proj_bias=proj_bias,\n                ffn_bias=ffn_bias,\n                drop_path=dpr[i],\n                norm_layer=norm_layer,\n                act_layer=act_layer,\n                ffn_layer=ffn_layer,\n                init_values=init_values,\n            )\n            for i in range(depth)\n        ]\n        if block_chunks > 0:\n            self.chunked_blocks = True\n            chunked_blocks = []\n            chunksize = depth // block_chunks\n            for i in range(0, depth, chunksize):\n                # this is to keep the block index consistent if we chunk the block list\n                chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])\n            self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])\n        else:\n            self.chunked_blocks = False\n            self.blocks = nn.ModuleList(blocks_list)\n\n        self.norm = norm_layer(embed_dim)\n        self.head = nn.Identity()\n\n        self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))\n\n        self.init_weights()\n\n    def init_weights(self):\n        trunc_normal_(self.pos_embed, std=0.02)\n        nn.init.normal_(self.cls_token, std=1e-6)\n        if self.register_tokens is not None:\n            nn.init.normal_(self.register_tokens, std=1e-6)\n        named_apply(init_weights_vit_timm, self)\n\n    def interpolate_pos_encoding(self, x, w, h):\n        previous_dtype = 
x.dtype\n        npatch = x.shape[1] - 1\n        N = self.pos_embed.shape[1] - 1\n        if npatch == N and w == h:\n            return self.pos_embed\n        pos_embed = self.pos_embed.float()\n        class_pos_embed = pos_embed[:, 0]\n        patch_pos_embed = pos_embed[:, 1:]\n        dim = x.shape[-1]\n        w0 = w // self.patch_size\n        h0 = h // self.patch_size\n        # we add a small number to avoid floating point error in the interpolation\n        # see discussion at https://github.com/facebookresearch/dino/issues/8\n        # DINOv2 with register modify the interpolate_offset from 0.1 to 0.0\n        w0, h0 = w0 + self.interpolate_offset, h0 + self.interpolate_offset\n        # w0, h0 = w0 + 0.1, h0 + 0.1\n        \n        sqrt_N = math.sqrt(N)\n        sx, sy = float(w0) / sqrt_N, float(h0) / sqrt_N\n        patch_pos_embed = nn.functional.interpolate(\n            patch_pos_embed.reshape(1, int(sqrt_N), int(sqrt_N), dim).permute(0, 3, 1, 2),\n            scale_factor=(sx, sy),\n            # (int(w0), int(h0)), # to solve the upsampling shape issue\n            mode=\"bicubic\",\n            antialias=self.interpolate_antialias\n        )\n        \n        assert int(w0) == patch_pos_embed.shape[-2]\n        assert int(h0) == patch_pos_embed.shape[-1]\n        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)\n        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)\n\n    def prepare_tokens_with_masks(self, x, masks=None):\n        B, nc, w, h = x.shape\n        x = self.patch_embed(x)\n        if masks is not None:\n            x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)\n\n        x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)\n        x = x + self.interpolate_pos_encoding(x, w, h)\n\n        if self.register_tokens is not None:\n            x = torch.cat(\n                (\n                    x[:, :1],\n                    self.register_tokens.expand(x.shape[0], -1, -1),\n                    x[:, 1:],\n                ),\n                dim=1,\n            )\n\n        return x\n\n    def forward_features_list(self, x_list, masks_list):\n        x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]\n        for blk in self.blocks:\n            x = blk(x)\n\n        all_x = x\n        output = []\n        for x, masks in zip(all_x, masks_list):\n            x_norm = self.norm(x)\n            output.append(\n                {\n                    \"x_norm_clstoken\": x_norm[:, 0],\n                    \"x_norm_regtokens\": x_norm[:, 1 : self.num_register_tokens + 1],\n                    \"x_norm_patchtokens\": x_norm[:, self.num_register_tokens + 1 :],\n                    \"x_prenorm\": x,\n                    \"masks\": masks,\n                }\n            )\n        return output\n\n    def forward_features(self, x, masks=None):\n        if isinstance(x, list):\n            return self.forward_features_list(x, masks)\n\n        x = self.prepare_tokens_with_masks(x, masks)\n\n        for blk in self.blocks:\n            x = blk(x)\n\n        x_norm = self.norm(x)\n        return {\n            \"x_norm_clstoken\": x_norm[:, 0],\n            \"x_norm_regtokens\": x_norm[:, 1 : self.num_register_tokens + 1],\n            \"x_norm_patchtokens\": x_norm[:, self.num_register_tokens + 1 :],\n            \"x_prenorm\": x,\n            \"masks\": masks,\n        }\n\n    def 
_get_intermediate_layers_not_chunked(self, x, n=1):\n        x = self.prepare_tokens_with_masks(x)\n        # If n is an int, take the n last blocks. If it's a list, take them\n        output, total_block_len = [], len(self.blocks)\n        blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n\n        for i, blk in enumerate(self.blocks):\n            x = blk(x)\n            if i in blocks_to_take:\n                output.append(x)\n        assert len(output) == len(blocks_to_take), f\"only {len(output)} / {len(blocks_to_take)} blocks found\"\n        return output\n\n    def _get_intermediate_layers_chunked(self, x, n=1):\n        x = self.prepare_tokens_with_masks(x)\n        output, i, total_block_len = [], 0, len(self.blocks[-1])\n        # If n is an int, take the n last blocks. If it's a list, take them\n        blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n\n        for block_chunk in self.blocks:\n            for blk in block_chunk[i:]:  # Passing the nn.Identity()\n                x = blk(x)\n                if i in blocks_to_take:\n                    output.append(x)\n                i += 1\n        assert len(output) == len(blocks_to_take), f\"only {len(output)} / {len(blocks_to_take)} blocks found\"\n        return output\n\n    def get_intermediate_layers(\n        self,\n        x: torch.Tensor,\n        n: Union[int, Sequence] = 1,  # Layers or n last layers to take\n        reshape: bool = False,\n        return_class_token: bool = False,\n        norm=True\n    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:\n        if self.chunked_blocks:\n            outputs = self._get_intermediate_layers_chunked(x, n)\n        else:\n            outputs = self._get_intermediate_layers_not_chunked(x, n)\n        if norm:\n            outputs = [self.norm(out) for out in outputs]\n        class_tokens = [out[:, 0] for out in outputs]\n        outputs = [out[:, 1 + self.num_register_tokens:] for out in outputs]\n        if reshape:\n            B, _, w, h = x.shape\n            outputs = [\n                out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()\n                for out in outputs\n            ]\n        if return_class_token:\n            return tuple(zip(outputs, class_tokens))\n        return tuple(outputs)\n\n    def forward(self, *args, is_training=False, **kwargs):\n        ret = self.forward_features(*args, **kwargs)\n        if is_training:\n            return ret\n        else:\n            return self.head(ret[\"x_norm_clstoken\"])\n\n\ndef init_weights_vit_timm(module: nn.Module, name: str = \"\"):\n    \"\"\"ViT weight initialization, original timm impl (for reproducibility)\"\"\"\n    if isinstance(module, nn.Linear):\n        trunc_normal_(module.weight, std=0.02)\n        if module.bias is not None:\n            nn.init.zeros_(module.bias)\n\n\ndef vit_small(patch_size=16, num_register_tokens=0, **kwargs):\n    model = DinoVisionTransformer(\n        patch_size=patch_size,\n        embed_dim=384,\n        depth=12,\n        num_heads=6,\n        mlp_ratio=4,\n        block_fn=partial(Block, attn_class=MemEffAttention),\n        num_register_tokens=num_register_tokens,\n        **kwargs,\n    )\n    return model\n\n\ndef vit_base(patch_size=16, num_register_tokens=0, **kwargs):\n    model = DinoVisionTransformer(\n        patch_size=patch_size,\n        embed_dim=768,\n        depth=12,\n        num_heads=12,\n        mlp_ratio=4,\n  
      block_fn=partial(Block, attn_class=MemEffAttention),\n        num_register_tokens=num_register_tokens,\n        **kwargs,\n    )\n    return model\n\n\ndef vit_large(patch_size=16, num_register_tokens=0, **kwargs):\n    model = DinoVisionTransformer(\n        patch_size=patch_size,\n        embed_dim=1024,\n        depth=24,\n        num_heads=16,\n        mlp_ratio=4,\n        block_fn=partial(Block, attn_class=MemEffAttention),\n        num_register_tokens=num_register_tokens,\n        **kwargs,\n    )\n    return model\n\n\ndef vit_giant2(patch_size=16, num_register_tokens=0, **kwargs):\n    \"\"\"\n    Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64\n    \"\"\"\n    model = DinoVisionTransformer(\n        patch_size=patch_size,\n        embed_dim=1536,\n        depth=40,\n        num_heads=24,\n        mlp_ratio=4,\n        block_fn=partial(Block, attn_class=MemEffAttention),\n        num_register_tokens=num_register_tokens,\n        **kwargs,\n    )\n    return model\n\n\ndef DINOv2(model_name):\n    model_zoo = {\n        \"vits\": vit_small, \n        \"vitb\": vit_base, \n        \"vitl\": vit_large, \n        \"vitg\": vit_giant2\n    }\n    \n    return model_zoo[model_name](\n        img_size=518,\n        patch_size=14,\n        init_values=1.0,\n        ffn_layer=\"mlp\" if model_name != \"vitg\" else \"swiglufused\",\n        block_chunks=0,\n        num_register_tokens=0,\n        interpolate_antialias=False,\n        interpolate_offset=0.1\n    )\n"
  },
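A shape sanity check for the `DINOv2` factory above, using randomly initialised weights so no checkpoint is needed: the factory fixes `img_size=518` and `patch_size=14`, giving a 37x37 token grid (518 // 14 = 37). Run from `models/Video-Depth-Anything/`; if xFormers is installed, `MemEffAttention` would want CUDA tensors, otherwise it falls back to the plain attention path on CPU.

```python
import torch
from video_depth_anything.dinov2 import DINOv2

# vit_small has embed_dim 384; reshape=True returns (B, C, H/14, W/14) maps.
model = DINOv2("vits").eval()
x = torch.randn(1, 3, 518, 518)
with torch.no_grad():
    (feat,) = model.get_intermediate_layers(x, n=1, reshape=True)
print(feat.shape)  # torch.Size([1, 384, 37, 37])
```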
  {
    "path": "models/Video-Depth-Anything/video_depth_anything/dinov2_layers/__init__.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the license found in the\n# LICENSE file in the root directory of this source tree.\n\nfrom .mlp import Mlp\nfrom .patch_embed import PatchEmbed\nfrom .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused\nfrom .block import NestedTensorBlock\nfrom .attention import MemEffAttention\n"
  },
  {
    "path": "models/Video-Depth-Anything/video_depth_anything/dinov2_layers/attention.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the license found in the\n# LICENSE file in the root directory of this source tree.\n\n# References:\n#   https://github.com/facebookresearch/dino/blob/master/vision_transformer.py\n#   https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py\n\nimport logging\n\nfrom torch import Tensor\nfrom torch import nn\n\n\nlogger = logging.getLogger(\"dinov2\")\n\n\ntry:\n    from xformers.ops import memory_efficient_attention, unbind, fmha\n\n    XFORMERS_AVAILABLE = True\nexcept ImportError:\n    logger.warning(\"xFormers not available\")\n    XFORMERS_AVAILABLE = False\n\n\nclass Attention(nn.Module):\n    def __init__(\n        self,\n        dim: int,\n        num_heads: int = 8,\n        qkv_bias: bool = False,\n        proj_bias: bool = True,\n        attn_drop: float = 0.0,\n        proj_drop: float = 0.0,\n    ) -> None:\n        super().__init__()\n        self.num_heads = num_heads\n        head_dim = dim // num_heads\n        self.scale = head_dim**-0.5\n\n        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)\n        self.attn_drop = nn.Dropout(attn_drop)\n        self.proj = nn.Linear(dim, dim, bias=proj_bias)\n        self.proj_drop = nn.Dropout(proj_drop)\n\n    def forward(self, x: Tensor) -> Tensor:\n        B, N, C = x.shape\n        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)\n\n        q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]\n        attn = q @ k.transpose(-2, -1)\n\n        attn = attn.softmax(dim=-1)\n        attn = self.attn_drop(attn)\n\n        x = (attn @ v).transpose(1, 2).reshape(B, N, C)\n        x = self.proj(x)\n        x = self.proj_drop(x)\n        return x\n\n\nclass MemEffAttention(Attention):\n    def forward(self, x: Tensor, attn_bias=None) -> Tensor:\n        if not XFORMERS_AVAILABLE:\n            assert attn_bias is None, \"xFormers is required for nested tensors usage\"\n            return super().forward(x)\n\n        B, N, C = x.shape\n        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)\n\n        q, k, v = unbind(qkv, 2)\n\n        x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)\n        x = x.reshape([B, N, C])\n\n        x = self.proj(x)\n        x = self.proj_drop(x)\n        return x\n\n        "
  },
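Both attention variants above are shape-preserving on (B, N, C) token tensors, and `MemEffAttention` simply reuses the plain implementation when xFormers is missing (with xFormers present it swaps in the fused kernel, which expects CUDA tensors). A minimal smoke test, assuming the same repo-relative import path:

```python
import torch
from video_depth_anything.dinov2_layers.attention import Attention, MemEffAttention

x = torch.randn(2, 16, 64)
for cls in (Attention, MemEffAttention):
    attn = cls(dim=64, num_heads=8).eval()
    with torch.no_grad():
        y = attn(x)
    print(cls.__name__, y.shape)  # both: torch.Size([2, 16, 64])
```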
  {
    "path": "models/Video-Depth-Anything/video_depth_anything/dinov2_layers/block.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the license found in the\n# LICENSE file in the root directory of this source tree.\n\n# References:\n#   https://github.com/facebookresearch/dino/blob/master/vision_transformer.py\n#   https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py\n\nimport logging\nfrom typing import Callable, List, Any, Tuple, Dict\n\nimport torch\nfrom torch import nn, Tensor\n\nfrom .attention import Attention, MemEffAttention\nfrom .drop_path import DropPath\nfrom .layer_scale import LayerScale\nfrom .mlp import Mlp\n\n\nlogger = logging.getLogger(\"dinov2\")\n\n\ntry:\n    from xformers.ops import fmha\n    from xformers.ops import scaled_index_add, index_select_cat\n\n    XFORMERS_AVAILABLE = True\nexcept ImportError:\n    logger.warning(\"xFormers not available\")\n    XFORMERS_AVAILABLE = False\n\n\nclass Block(nn.Module):\n    def __init__(\n        self,\n        dim: int,\n        num_heads: int,\n        mlp_ratio: float = 4.0,\n        qkv_bias: bool = False,\n        proj_bias: bool = True,\n        ffn_bias: bool = True,\n        drop: float = 0.0,\n        attn_drop: float = 0.0,\n        init_values=None,\n        drop_path: float = 0.0,\n        act_layer: Callable[..., nn.Module] = nn.GELU,\n        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,\n        attn_class: Callable[..., nn.Module] = Attention,\n        ffn_layer: Callable[..., nn.Module] = Mlp,\n    ) -> None:\n        super().__init__()\n        # print(f\"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}\")\n        self.norm1 = norm_layer(dim)\n        self.attn = attn_class(\n            dim,\n            num_heads=num_heads,\n            qkv_bias=qkv_bias,\n            proj_bias=proj_bias,\n            attn_drop=attn_drop,\n            proj_drop=drop,\n        )\n        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()\n        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()\n\n        self.norm2 = norm_layer(dim)\n        mlp_hidden_dim = int(dim * mlp_ratio)\n        self.mlp = ffn_layer(\n            in_features=dim,\n            hidden_features=mlp_hidden_dim,\n            act_layer=act_layer,\n            drop=drop,\n            bias=ffn_bias,\n        )\n        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()\n        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()\n\n        self.sample_drop_ratio = drop_path\n\n    def forward(self, x: Tensor) -> Tensor:\n        def attn_residual_func(x: Tensor) -> Tensor:\n            return self.ls1(self.attn(self.norm1(x)))\n\n        def ffn_residual_func(x: Tensor) -> Tensor:\n            return self.ls2(self.mlp(self.norm2(x)))\n\n        if self.training and self.sample_drop_ratio > 0.1:\n            # the overhead is compensated only for a drop path rate larger than 0.1\n            x = drop_add_residual_stochastic_depth(\n                x,\n                residual_func=attn_residual_func,\n                sample_drop_ratio=self.sample_drop_ratio,\n            )\n            x = drop_add_residual_stochastic_depth(\n                x,\n                residual_func=ffn_residual_func,\n                sample_drop_ratio=self.sample_drop_ratio,\n            )\n        elif self.training and self.sample_drop_ratio > 0.0:\n            x = x + 
self.drop_path1(attn_residual_func(x))\n            x = x + self.drop_path1(ffn_residual_func(x))  # FIXME: drop_path2\n        else:\n            x = x + attn_residual_func(x)\n            x = x + ffn_residual_func(x)\n        return x\n\n\ndef drop_add_residual_stochastic_depth(\n    x: Tensor,\n    residual_func: Callable[[Tensor], Tensor],\n    sample_drop_ratio: float = 0.0,\n) -> Tensor:\n    # 1) extract subset using permutation\n    b, n, d = x.shape\n    sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)\n    brange = (torch.randperm(b, device=x.device))[:sample_subset_size]\n    x_subset = x[brange]\n\n    # 2) apply residual_func to get residual\n    residual = residual_func(x_subset)\n\n    x_flat = x.flatten(1)\n    residual = residual.flatten(1)\n\n    residual_scale_factor = b / sample_subset_size\n\n    # 3) add the residual\n    x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)\n    return x_plus_residual.view_as(x)\n\n\ndef get_branges_scales(x, sample_drop_ratio=0.0):\n    b, n, d = x.shape\n    sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)\n    brange = (torch.randperm(b, device=x.device))[:sample_subset_size]\n    residual_scale_factor = b / sample_subset_size\n    return brange, residual_scale_factor\n\n\ndef add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):\n    if scaling_vector is None:\n        x_flat = x.flatten(1)\n        residual = residual.flatten(1)\n        x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)\n    else:\n        x_plus_residual = scaled_index_add(\n            x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor\n        )\n    return x_plus_residual\n\n\nattn_bias_cache: Dict[Tuple, Any] = {}\n\n\ndef get_attn_bias_and_cat(x_list, branges=None):\n    \"\"\"\n    this will perform the index select, cat the tensors, and provide the attn_bias from cache\n    \"\"\"\n    batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]\n    all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))\n    if all_shapes not in attn_bias_cache.keys():\n        seqlens = []\n        for b, x in zip(batch_sizes, x_list):\n            for _ in range(b):\n                seqlens.append(x.shape[1])\n        attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)\n        attn_bias._batch_sizes = batch_sizes\n        attn_bias_cache[all_shapes] = attn_bias\n\n    if branges is not None:\n        cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])\n    else:\n        tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)\n        cat_tensors = torch.cat(tensors_bs1, dim=1)\n\n    return attn_bias_cache[all_shapes], cat_tensors\n\n\ndef drop_add_residual_stochastic_depth_list(\n    x_list: List[Tensor],\n    residual_func: Callable[[Tensor, Any], Tensor],\n    sample_drop_ratio: float = 0.0,\n    scaling_vector=None,\n) -> Tensor:\n    # 1) generate random set of indices for dropping samples in the batch\n    branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]\n    branges = [s[0] for s in branges_scales]\n    residual_scale_factors = [s[1] for s in branges_scales]\n\n    # 2) get attention bias and index+concat the tensors\n    attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)\n\n   
 # 3) apply residual_func to get residual, and split the result\n    residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias))  # type: ignore\n\n    outputs = []\n    for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):\n        outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))\n    return outputs\n\n\nclass NestedTensorBlock(Block):\n    def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:\n        \"\"\"\n        x_list contains a list of tensors to nest together and run\n        \"\"\"\n        assert isinstance(self.attn, MemEffAttention)\n\n        if self.training and self.sample_drop_ratio > 0.0:\n\n            def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:\n                return self.attn(self.norm1(x), attn_bias=attn_bias)\n\n            def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:\n                return self.mlp(self.norm2(x))\n\n            x_list = drop_add_residual_stochastic_depth_list(\n                x_list,\n                residual_func=attn_residual_func,\n                sample_drop_ratio=self.sample_drop_ratio,\n                scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,\n            )\n            x_list = drop_add_residual_stochastic_depth_list(\n                x_list,\n                residual_func=ffn_residual_func,\n                sample_drop_ratio=self.sample_drop_ratio,\n                scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None,\n            )\n            return x_list\n        else:\n\n            def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:\n                return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))\n\n            def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:\n                return self.ls2(self.mlp(self.norm2(x)))\n\n            attn_bias, x = get_attn_bias_and_cat(x_list)\n            x = x + attn_residual_func(x, attn_bias=attn_bias)\n            x = x + ffn_residual_func(x)\n            return attn_bias.split(x)\n\n    def forward(self, x_or_x_list):\n        if isinstance(x_or_x_list, Tensor):\n            return super().forward(x_or_x_list)\n        elif isinstance(x_or_x_list, list):\n            assert XFORMERS_AVAILABLE, \"Please install xFormers for nested tensors usage\"\n            return self.forward_nested(x_or_x_list)\n        else:\n            raise AssertionError\n"
  },
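The block above is a standard pre-norm residual transformer layer: attention and MLP sub-layers, each wrapped in LayerNorm with optional LayerScale and DropPath. On a plain tensor `NestedTensorBlock` behaves exactly like `Block`. A small smoke test with LayerScale enabled (import path assumed as in the examples above):

```python
import torch
from video_depth_anything.dinov2_layers.block import Block

# init_values=1.0 activates LayerScale on both residual branches;
# the (B, N, C) shape is preserved end to end.
blk = Block(dim=64, num_heads=8, init_values=1.0).eval()
x = torch.randn(2, 16, 64)
with torch.no_grad():
    print(blk(x).shape)  # torch.Size([2, 16, 64])
```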
  {
    "path": "models/Video-Depth-Anything/video_depth_anything/dinov2_layers/drop_path.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the license found in the\n# LICENSE file in the root directory of this source tree.\n\n# References:\n#   https://github.com/facebookresearch/dino/blob/master/vision_transformer.py\n#   https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py\n\n\nfrom torch import nn\n\n\ndef drop_path(x, drop_prob: float = 0.0, training: bool = False):\n    if drop_prob == 0.0 or not training:\n        return x\n    keep_prob = 1 - drop_prob\n    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets\n    random_tensor = x.new_empty(shape).bernoulli_(keep_prob)\n    if keep_prob > 0.0:\n        random_tensor.div_(keep_prob)\n    output = x * random_tensor\n    return output\n\n\nclass DropPath(nn.Module):\n    \"\"\"Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).\"\"\"\n\n    def __init__(self, drop_prob=None):\n        super(DropPath, self).__init__()\n        self.drop_prob = drop_prob\n\n    def forward(self, x):\n        return drop_path(x, self.drop_prob, self.training)\n"
  },
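The `1 / keep_prob` rescaling in `drop_path` above is what keeps the expected activation unchanged: each sample's residual path is either zeroed or scaled up, so the mean over many samples matches the input. A quick numerical check:

```python
import torch
from video_depth_anything.dinov2_layers.drop_path import drop_path

torch.manual_seed(0)
x = torch.ones(10000, 8)
out = drop_path(x, drop_prob=0.3, training=True)
# Each row is either 0 or 1 / 0.7 ~= 1.4286; the mean stays close to 1.0.
print(sorted(out.unique().tolist()), round(out.mean().item(), 3))
```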
  {
    "path": "models/Video-Depth-Anything/video_depth_anything/dinov2_layers/layer_scale.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the license found in the\n# LICENSE file in the root directory of this source tree.\n\n# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110\n\nfrom typing import Union\n\nimport torch\nfrom torch import Tensor\nfrom torch import nn\n\n\nclass LayerScale(nn.Module):\n    def __init__(\n        self,\n        dim: int,\n        init_values: Union[float, Tensor] = 1e-5,\n        inplace: bool = False,\n    ) -> None:\n        super().__init__()\n        self.inplace = inplace\n        self.gamma = nn.Parameter(init_values * torch.ones(dim))\n\n    def forward(self, x: Tensor) -> Tensor:\n        return x.mul_(self.gamma) if self.inplace else x * self.gamma\n"
  },
  {
    "path": "models/Video-Depth-Anything/video_depth_anything/dinov2_layers/mlp.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the license found in the\n# LICENSE file in the root directory of this source tree.\n\n# References:\n#   https://github.com/facebookresearch/dino/blob/master/vision_transformer.py\n#   https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py\n\n\nfrom typing import Callable, Optional\n\nfrom torch import Tensor, nn\n\n\nclass Mlp(nn.Module):\n    def __init__(\n        self,\n        in_features: int,\n        hidden_features: Optional[int] = None,\n        out_features: Optional[int] = None,\n        act_layer: Callable[..., nn.Module] = nn.GELU,\n        drop: float = 0.0,\n        bias: bool = True,\n    ) -> None:\n        super().__init__()\n        out_features = out_features or in_features\n        hidden_features = hidden_features or in_features\n        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)\n        self.act = act_layer()\n        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)\n        self.drop = nn.Dropout(drop)\n\n    def forward(self, x: Tensor) -> Tensor:\n        x = self.fc1(x)\n        x = self.act(x)\n        x = self.drop(x)\n        x = self.fc2(x)\n        x = self.drop(x)\n        return x\n"
  },
  {
    "path": "models/Video-Depth-Anything/video_depth_anything/dinov2_layers/patch_embed.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the license found in the\n# LICENSE file in the root directory of this source tree.\n\n# References:\n#   https://github.com/facebookresearch/dino/blob/master/vision_transformer.py\n#   https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py\n\nfrom typing import Callable, Optional, Tuple, Union\n\nfrom torch import Tensor\nimport torch.nn as nn\n\n\ndef make_2tuple(x):\n    if isinstance(x, tuple):\n        assert len(x) == 2\n        return x\n\n    assert isinstance(x, int)\n    return (x, x)\n\n\nclass PatchEmbed(nn.Module):\n    \"\"\"\n    2D image to patch embedding: (B,C,H,W) -> (B,N,D)\n\n    Args:\n        img_size: Image size.\n        patch_size: Patch token size.\n        in_chans: Number of input image channels.\n        embed_dim: Number of linear projection output channels.\n        norm_layer: Normalization layer.\n    \"\"\"\n\n    def __init__(\n        self,\n        img_size: Union[int, Tuple[int, int]] = 224,\n        patch_size: Union[int, Tuple[int, int]] = 16,\n        in_chans: int = 3,\n        embed_dim: int = 768,\n        norm_layer: Optional[Callable] = None,\n        flatten_embedding: bool = True,\n    ) -> None:\n        super().__init__()\n\n        image_HW = make_2tuple(img_size)\n        patch_HW = make_2tuple(patch_size)\n        patch_grid_size = (\n            image_HW[0] // patch_HW[0],\n            image_HW[1] // patch_HW[1],\n        )\n\n        self.img_size = image_HW\n        self.patch_size = patch_HW\n        self.patches_resolution = patch_grid_size\n        self.num_patches = patch_grid_size[0] * patch_grid_size[1]\n\n        self.in_chans = in_chans\n        self.embed_dim = embed_dim\n\n        self.flatten_embedding = flatten_embedding\n\n        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)\n        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()\n\n    def forward(self, x: Tensor) -> Tensor:\n        _, _, H, W = x.shape\n        patch_H, patch_W = self.patch_size\n\n        assert H % patch_H == 0, f\"Input image height {H} is not a multiple of patch height {patch_H}\"\n        assert W % patch_W == 0, f\"Input image width {W} is not a multiple of patch width: {patch_W}\"\n\n        x = self.proj(x)  # B C H W\n        H, W = x.size(2), x.size(3)\n        x = x.flatten(2).transpose(1, 2)  # B HW C\n        x = self.norm(x)\n        if not self.flatten_embedding:\n            x = x.reshape(-1, H, W, self.embed_dim)  # B H W C\n        return x\n\n    def flops(self) -> float:\n        Ho, Wo = self.patches_resolution\n        flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])\n        if self.norm is not None:\n            flops += Ho * Wo * self.embed_dim\n        return flops\n"
  },
  {
    "path": "models/Video-Depth-Anything/video_depth_anything/dinov2_layers/swiglu_ffn.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the license found in the\n# LICENSE file in the root directory of this source tree.\n\nfrom typing import Callable, Optional\n\nfrom torch import Tensor, nn\nimport torch.nn.functional as F\n\n\nclass SwiGLUFFN(nn.Module):\n    def __init__(\n        self,\n        in_features: int,\n        hidden_features: Optional[int] = None,\n        out_features: Optional[int] = None,\n        act_layer: Callable[..., nn.Module] = None,\n        drop: float = 0.0,\n        bias: bool = True,\n    ) -> None:\n        super().__init__()\n        out_features = out_features or in_features\n        hidden_features = hidden_features or in_features\n        self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)\n        self.w3 = nn.Linear(hidden_features, out_features, bias=bias)\n\n    def forward(self, x: Tensor) -> Tensor:\n        x12 = self.w12(x)\n        x1, x2 = x12.chunk(2, dim=-1)\n        hidden = F.silu(x1) * x2\n        return self.w3(hidden)\n\n\ntry:\n    from xformers.ops import SwiGLU\n\n    XFORMERS_AVAILABLE = True\nexcept ImportError:\n    SwiGLU = SwiGLUFFN\n    XFORMERS_AVAILABLE = False\n\n\nclass SwiGLUFFNFused(SwiGLU):\n    def __init__(\n        self,\n        in_features: int,\n        hidden_features: Optional[int] = None,\n        out_features: Optional[int] = None,\n        act_layer: Callable[..., nn.Module] = None,\n        drop: float = 0.0,\n        bias: bool = True,\n    ) -> None:\n        out_features = out_features or in_features\n        hidden_features = hidden_features or in_features\n        hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8\n        super().__init__(\n            in_features=in_features,\n            hidden_features=hidden_features,\n            out_features=out_features,\n            bias=bias,\n        )\n"
  },
  {
    "path": "models/Video-Depth-Anything/video_depth_anything/dpt.py",
    "content": "# Copyright (2025) Bytedance Ltd. and/or its affiliates \n\n# Licensed under the Apache License, Version 2.0 (the \"License\"); \n# you may not use this file except in compliance with the License. \n# You may obtain a copy of the License at \n\n#     http://www.apache.org/licenses/LICENSE-2.0 \n\n# Unless required by applicable law or agreed to in writing, software \n# distributed under the License is distributed on an \"AS IS\" BASIS, \n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. \n# See the License for the specific language governing permissions and \n# limitations under the License. \nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom .util.blocks import FeatureFusionBlock, _make_scratch\n\n\ndef _make_fusion_block(features, use_bn, size=None):\n    return FeatureFusionBlock(\n        features,\n        nn.ReLU(False),\n        deconv=False,\n        bn=use_bn,\n        expand=False,\n        align_corners=True,\n        size=size,\n    )\n\n\nclass ConvBlock(nn.Module):\n    def __init__(self, in_feature, out_feature):\n        super().__init__()\n        \n        self.conv_block = nn.Sequential(\n            nn.Conv2d(in_feature, out_feature, kernel_size=3, stride=1, padding=1),\n            nn.BatchNorm2d(out_feature),\n            nn.ReLU(True)\n        )\n    \n    def forward(self, x):\n        return self.conv_block(x)\n\n\nclass DPTHead(nn.Module):\n    def __init__(\n        self, \n        in_channels, \n        features=256, \n        use_bn=False, \n        out_channels=[256, 512, 1024, 1024], \n        use_clstoken=False\n    ):\n        super(DPTHead, self).__init__()\n        \n        self.use_clstoken = use_clstoken\n        \n        self.projects = nn.ModuleList([\n            nn.Conv2d(\n                in_channels=in_channels,\n                out_channels=out_channel,\n                kernel_size=1,\n                stride=1,\n                padding=0,\n            ) for out_channel in out_channels\n        ])\n        \n        self.resize_layers = nn.ModuleList([\n            nn.ConvTranspose2d(\n                in_channels=out_channels[0],\n                out_channels=out_channels[0],\n                kernel_size=4,\n                stride=4,\n                padding=0),\n            nn.ConvTranspose2d(\n                in_channels=out_channels[1],\n                out_channels=out_channels[1],\n                kernel_size=2,\n                stride=2,\n                padding=0),\n            nn.Identity(),\n            nn.Conv2d(\n                in_channels=out_channels[3],\n                out_channels=out_channels[3],\n                kernel_size=3,\n                stride=2,\n                padding=1)\n        ])\n        \n        if use_clstoken:\n            self.readout_projects = nn.ModuleList()\n            for _ in range(len(self.projects)):\n                self.readout_projects.append(\n                    nn.Sequential(\n                        nn.Linear(2 * in_channels, in_channels),\n                        nn.GELU()))\n        \n        self.scratch = _make_scratch(\n            out_channels,\n            features,\n            groups=1,\n            expand=False,\n        )\n        \n        self.scratch.stem_transpose = None\n        \n        self.scratch.refinenet1 = _make_fusion_block(features, use_bn)\n        self.scratch.refinenet2 = _make_fusion_block(features, use_bn)\n        self.scratch.refinenet3 = _make_fusion_block(features, use_bn)\n        
self.scratch.refinenet4 = _make_fusion_block(features, use_bn)\n        \n        head_features_1 = features\n        head_features_2 = 32\n        \n        self.scratch.output_conv1 = nn.Conv2d(head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1)\n        self.scratch.output_conv2 = nn.Sequential(\n            nn.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1),\n            nn.ReLU(True),\n            nn.Conv2d(head_features_2, 1, kernel_size=1, stride=1, padding=0),\n            nn.ReLU(True),\n            nn.Identity(),\n        )\n    \n    def forward(self, out_features, patch_h, patch_w):\n        out = []\n        for i, x in enumerate(out_features):\n            if self.use_clstoken:\n                x, cls_token = x[0], x[1]\n                readout = cls_token.unsqueeze(1).expand_as(x)\n                x = self.readout_projects[i](torch.cat((x, readout), -1))\n            else:\n                x = x[0]\n            \n            x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w))\n            \n            x = self.projects[i](x)\n            x = self.resize_layers[i](x)\n            \n            out.append(x)\n        \n        layer_1, layer_2, layer_3, layer_4 = out\n        \n        layer_1_rn = self.scratch.layer1_rn(layer_1)\n        layer_2_rn = self.scratch.layer2_rn(layer_2)\n        layer_3_rn = self.scratch.layer3_rn(layer_3)\n        layer_4_rn = self.scratch.layer4_rn(layer_4)\n        \n        path_4 = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])        \n        path_3 = self.scratch.refinenet3(path_4, layer_3_rn, size=layer_2_rn.shape[2:])\n        path_2 = self.scratch.refinenet2(path_3, layer_2_rn, size=layer_1_rn.shape[2:])\n        path_1 = self.scratch.refinenet1(path_2, layer_1_rn)\n        \n        out = self.scratch.output_conv1(path_1)\n        out = F.interpolate(out, (int(patch_h * 14), int(patch_w * 14)), mode=\"bilinear\", align_corners=True)\n        out = self.scratch.output_conv2(out)\n        \n        return out\n        "
  },
  {
    "path": "models/Video-Depth-Anything/video_depth_anything/dpt_temporal.py",
    "content": "# Copyright (2025) Bytedance Ltd. and/or its affiliates \n\n# Licensed under the Apache License, Version 2.0 (the \"License\"); \n# you may not use this file except in compliance with the License. \n# You may obtain a copy of the License at \n\n#     http://www.apache.org/licenses/LICENSE-2.0 \n\n# Unless required by applicable law or agreed to in writing, software \n# distributed under the License is distributed on an \"AS IS\" BASIS, \n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. \n# See the License for the specific language governing permissions and \n# limitations under the License. \nimport torch\nimport torch.nn.functional as F\nimport torch.nn as nn\nfrom .dpt import DPTHead\nfrom .motion_module.motion_module import TemporalModule\nfrom easydict import EasyDict\n\n\nclass DPTHeadTemporal(DPTHead):\n    def __init__(self, \n        in_channels, \n        features=256, \n        use_bn=False, \n        out_channels=[256, 512, 1024, 1024], \n        use_clstoken=False,\n        num_frames=32,\n        pe='ape'\n    ):\n        super().__init__(in_channels, features, use_bn, out_channels, use_clstoken)\n\n        assert num_frames > 0\n        motion_module_kwargs = EasyDict(num_attention_heads                = 8,\n                                        num_transformer_block              = 1,\n                                        num_attention_blocks               = 2,\n                                        temporal_max_len                   = num_frames,\n                                        zero_initialize                    = True,\n                                        pos_embedding_type                 = pe)\n\n        self.motion_modules = nn.ModuleList([\n            TemporalModule(in_channels=out_channels[2], \n                           **motion_module_kwargs),\n            TemporalModule(in_channels=out_channels[3],\n                           **motion_module_kwargs),\n            TemporalModule(in_channels=features,\n                           **motion_module_kwargs),\n            TemporalModule(in_channels=features,\n                           **motion_module_kwargs)\n        ])\n\n    def forward(self, out_features, patch_h, patch_w, frame_length):\n        out = []\n        for i, x in enumerate(out_features):\n            if self.use_clstoken:\n                x, cls_token = x[0], x[1]\n                readout = cls_token.unsqueeze(1).expand_as(x)\n                x = self.readout_projects[i](torch.cat((x, readout), -1))\n            else:\n                x = x[0]\n\n            x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w)).contiguous()\n\n            B, T = x.shape[0] // frame_length, frame_length\n            x = self.projects[i](x)\n            x = self.resize_layers[i](x)\n\n            out.append(x)\n\n        layer_1, layer_2, layer_3, layer_4 = out\n\n        B, T = layer_1.shape[0] // frame_length, frame_length\n\n        layer_3 = self.motion_modules[0](layer_3.unflatten(0, (B, T)).permute(0, 2, 1, 3, 4), None, None).permute(0, 2, 1, 3, 4).flatten(0, 1)\n        layer_4 = self.motion_modules[1](layer_4.unflatten(0, (B, T)).permute(0, 2, 1, 3, 4), None, None).permute(0, 2, 1, 3, 4).flatten(0, 1)\n\n        layer_1_rn = self.scratch.layer1_rn(layer_1)\n        layer_2_rn = self.scratch.layer2_rn(layer_2)\n        layer_3_rn = self.scratch.layer3_rn(layer_3)\n        layer_4_rn = self.scratch.layer4_rn(layer_4)\n\n        path_4 = self.scratch.refinenet4(layer_4_rn, 
size=layer_3_rn.shape[2:])\n        path_4 = self.motion_modules[2](path_4.unflatten(0, (B, T)).permute(0, 2, 1, 3, 4), None, None).permute(0, 2, 1, 3, 4).flatten(0, 1)\n        path_3 = self.scratch.refinenet3(path_4, layer_3_rn, size=layer_2_rn.shape[2:])\n        path_3 = self.motion_modules[3](path_3.unflatten(0, (B, T)).permute(0, 2, 1, 3, 4), None, None).permute(0, 2, 1, 3, 4).flatten(0, 1)\n        path_2 = self.scratch.refinenet2(path_3, layer_2_rn, size=layer_1_rn.shape[2:])\n        path_1 = self.scratch.refinenet1(path_2, layer_1_rn)\n\n        out = self.scratch.output_conv1(path_1)\n        out = F.interpolate(\n            out, (int(patch_h * 14), int(patch_w * 14)), mode=\"bilinear\", align_corners=True\n        )\n        # out = self.scratch.output_conv2(out)\n        return out"
  },
  {
    "path": "models/Video-Depth-Anything/video_depth_anything/motion_module/attention.py",
    "content": "# Copyright 2022 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nfrom typing import Optional, Tuple\n\nimport torch\nimport torch.nn.functional as F\nfrom torch import nn\n\ntry:\n    import xformers\n    import xformers.ops\n\n    XFORMERS_AVAILABLE = True\nexcept ImportError:\n    print(\"xFormers not available\")\n    XFORMERS_AVAILABLE = False\n\n\nclass CrossAttention(nn.Module):\n    r\"\"\"\n    A cross attention layer.\n\n    Parameters:\n        query_dim (`int`): The number of channels in the query.\n        cross_attention_dim (`int`, *optional*):\n            The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.\n        heads (`int`,  *optional*, defaults to 8): The number of heads to use for multi-head attention.\n        dim_head (`int`,  *optional*, defaults to 64): The number of channels in each head.\n        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.\n        bias (`bool`, *optional*, defaults to False):\n            Set to `True` for the query, key, and value linear layers to contain a bias parameter.\n    \"\"\"\n\n    def __init__(\n        self,\n        query_dim: int,\n        cross_attention_dim: Optional[int] = None,\n        heads: int = 8,\n        dim_head: int = 64,\n        dropout: float = 0.0,\n        bias=False,\n        upcast_attention: bool = False,\n        upcast_softmax: bool = False,\n        added_kv_proj_dim: Optional[int] = None,\n        norm_num_groups: Optional[int] = None,\n    ):\n        super().__init__()\n        inner_dim = dim_head * heads\n        cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim\n        self.upcast_attention = upcast_attention\n        self.upcast_softmax = upcast_softmax\n        self.upcast_efficient_attention = False\n\n        self.scale = dim_head**-0.5\n\n        self.heads = heads\n        # for slice_size > 0 the attention score computation\n        # is split across the batch axis to save memory\n        # You can set slice_size with `set_attention_slice`\n        self.sliceable_head_dim = heads\n        self._slice_size = None\n        self._use_memory_efficient_attention_xformers = False\n        self.added_kv_proj_dim = added_kv_proj_dim\n\n        if norm_num_groups is not None:\n            self.group_norm = nn.GroupNorm(num_channels=inner_dim, num_groups=norm_num_groups, eps=1e-5, affine=True)\n        else:\n            self.group_norm = None\n\n        self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)\n        self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)\n        self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)\n\n        if self.added_kv_proj_dim is not None:\n            self.add_k_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)\n            self.add_v_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)\n\n        self.to_out = nn.ModuleList([])\n        
self.to_out.append(nn.Linear(inner_dim, query_dim))\n        self.to_out.append(nn.Dropout(dropout))\n\n    def reshape_heads_to_batch_dim(self, tensor):\n        batch_size, seq_len, dim = tensor.shape\n        head_size = self.heads\n        tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size).contiguous()\n        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size).contiguous()\n        return tensor\n\n    def reshape_heads_to_4d(self, tensor):\n        batch_size, seq_len, dim = tensor.shape\n        head_size = self.heads\n        tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size).contiguous()\n        return tensor\n\n    def reshape_batch_dim_to_heads(self, tensor):\n        batch_size, seq_len, dim = tensor.shape\n        head_size = self.heads\n        tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim).contiguous()\n        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size).contiguous()\n        return tensor\n\n    def reshape_4d_to_heads(self, tensor):\n        batch_size, seq_len, _, dim = tensor.shape\n        head_size = self.heads\n        tensor = tensor.reshape(batch_size, seq_len, dim * head_size).contiguous()\n        return tensor\n\n    def set_attention_slice(self, slice_size):\n        if slice_size is not None and slice_size > self.sliceable_head_dim:\n            raise ValueError(f\"slice_size {slice_size} has to be smaller than or equal to {self.sliceable_head_dim}.\")\n\n        self._slice_size = slice_size\n\n    def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):\n        batch_size, sequence_length, _ = hidden_states.shape\n\n        if self.group_norm is not None:\n            hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)\n\n        query = self.to_q(hidden_states)\n        dim = query.shape[-1]\n        query = self.reshape_heads_to_batch_dim(query)\n\n        if self.added_kv_proj_dim is not None:\n            key = self.to_k(hidden_states)\n            value = self.to_v(hidden_states)\n            encoder_hidden_states_key_proj = self.add_k_proj(encoder_hidden_states)\n            encoder_hidden_states_value_proj = self.add_v_proj(encoder_hidden_states)\n\n            key = self.reshape_heads_to_batch_dim(key)\n            value = self.reshape_heads_to_batch_dim(value)\n            encoder_hidden_states_key_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_key_proj)\n            encoder_hidden_states_value_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_value_proj)\n\n            key = torch.concat([encoder_hidden_states_key_proj, key], dim=1)\n            value = torch.concat([encoder_hidden_states_value_proj, value], dim=1)\n        else:\n            encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states\n            key = self.to_k(encoder_hidden_states)\n            value = self.to_v(encoder_hidden_states)\n\n            key = self.reshape_heads_to_batch_dim(key)\n            value = self.reshape_heads_to_batch_dim(value)\n\n        if attention_mask is not None:\n            if attention_mask.shape[-1] != query.shape[1]:\n                target_length = query.shape[1]\n                attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)\n                attention_mask = 
attention_mask.repeat_interleave(self.heads, dim=0)\n\n        # attention, what we cannot get enough of\n        if XFORMERS_AVAILABLE and self._use_memory_efficient_attention_xformers:\n            hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)\n            # Some versions of xformers return output in fp32, cast it back to the dtype of the input\n            hidden_states = hidden_states.to(query.dtype)\n        else:\n            if self._slice_size is None or query.shape[0] // self._slice_size == 1:\n                hidden_states = self._attention(query, key, value, attention_mask)\n            else:\n                hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)\n\n        # linear proj\n        hidden_states = self.to_out[0](hidden_states)\n\n        # dropout\n        hidden_states = self.to_out[1](hidden_states)\n        return hidden_states\n\n    def _attention(self, query, key, value, attention_mask=None):\n        if self.upcast_attention:\n            query = query.float()\n            key = key.float()\n\n        attention_scores = torch.baddbmm(\n            torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),\n            query,\n            key.transpose(-1, -2),\n            beta=0,\n            alpha=self.scale,\n        )\n\n        if attention_mask is not None:\n            attention_scores = attention_scores + attention_mask\n\n        if self.upcast_softmax:\n            attention_scores = attention_scores.float()\n\n        attention_probs = attention_scores.softmax(dim=-1)\n\n        # cast back to the original dtype\n        attention_probs = attention_probs.to(value.dtype)\n\n        # compute attention output\n        hidden_states = torch.bmm(attention_probs, value)\n\n        # reshape hidden_states\n        hidden_states = self.reshape_batch_dim_to_heads(hidden_states)\n        return hidden_states\n\n    def _sliced_attention(self, query, key, value, sequence_length, dim, attention_mask):\n        batch_size_attention = query.shape[0]\n        hidden_states = torch.zeros(\n            (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype\n        )\n        slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0]\n        for i in range(hidden_states.shape[0] // slice_size):\n            start_idx = i * slice_size\n            end_idx = (i + 1) * slice_size\n\n            query_slice = query[start_idx:end_idx]\n            key_slice = key[start_idx:end_idx]\n\n            if self.upcast_attention:\n                query_slice = query_slice.float()\n                key_slice = key_slice.float()\n\n            attn_slice = torch.baddbmm(\n                torch.empty(slice_size, query.shape[1], key.shape[1], dtype=query_slice.dtype, device=query.device),\n                query_slice,\n                key_slice.transpose(-1, -2),\n                beta=0,\n                alpha=self.scale,\n            )\n\n            if attention_mask is not None:\n                attn_slice = attn_slice + attention_mask[start_idx:end_idx]\n\n            if self.upcast_softmax:\n                attn_slice = attn_slice.float()\n\n            attn_slice = attn_slice.softmax(dim=-1)\n\n            # cast back to the original dtype\n            attn_slice = attn_slice.to(value.dtype)\n            attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])\n\n       
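     # write this slice of the attention output into the preallocated buffer\n       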
     hidden_states[start_idx:end_idx] = attn_slice\n\n        # reshape hidden_states\n        hidden_states = self.reshape_batch_dim_to_heads(hidden_states)\n        return hidden_states\n\n    def _memory_efficient_attention_xformers(self, query, key, value, attention_mask):\n        if self.upcast_efficient_attention:\n            org_dtype = query.dtype\n            query = query.float()\n            key = key.float()\n            value = value.float()\n            if attention_mask is not None:\n                attention_mask = attention_mask.float()\n        hidden_states = self._memory_efficient_attention_split(query, key, value, attention_mask)\n\n        if self.upcast_efficient_attention:\n            hidden_states = hidden_states.to(org_dtype)\n\n        hidden_states = self.reshape_4d_to_heads(hidden_states)\n        return hidden_states\n\n        # print(\"Errror: no xformers\")\n        # raise NotImplementedError\n\n    def _memory_efficient_attention_split(self, query, key, value, attention_mask):\n        batch_size = query.shape[0]\n        max_batch_size = 65535\n        num_batches = (batch_size + max_batch_size - 1) // max_batch_size\n        results = []\n        for i in range(num_batches):\n            start_idx = i * max_batch_size\n            end_idx = min((i + 1) * max_batch_size, batch_size)\n            query_batch = query[start_idx:end_idx]\n            key_batch = key[start_idx:end_idx]\n            value_batch = value[start_idx:end_idx]\n            if attention_mask is not None:\n                attention_mask_batch = attention_mask[start_idx:end_idx]\n            else:\n                attention_mask_batch = None\n            result = xformers.ops.memory_efficient_attention(query_batch, key_batch, value_batch, attn_bias=attention_mask_batch)\n            results.append(result)\n        full_result = torch.cat(results, dim=0)\n        return full_result\n\n\nclass FeedForward(nn.Module):\n    r\"\"\"\n    A feed-forward layer.\n\n    Parameters:\n        dim (`int`): The number of channels in the input.\n        dim_out (`int`, *optional*): The number of channels in the output. 
If not given, defaults to `dim`.\n        mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.\n        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.\n        activation_fn (`str`, *optional*, defaults to `\"geglu\"`): Activation function to be used in feed-forward.\n    \"\"\"\n\n    def __init__(\n        self,\n        dim: int,\n        dim_out: Optional[int] = None,\n        mult: int = 4,\n        dropout: float = 0.0,\n        activation_fn: str = \"geglu\",\n    ):\n        super().__init__()\n        inner_dim = int(dim * mult)\n        dim_out = dim_out if dim_out is not None else dim\n\n        if activation_fn == \"gelu\":\n            act_fn = GELU(dim, inner_dim)\n        elif activation_fn == \"geglu\":\n            act_fn = GEGLU(dim, inner_dim)\n        elif activation_fn == \"geglu-approximate\":\n            act_fn = ApproximateGELU(dim, inner_dim)\n\n        self.net = nn.ModuleList([])\n        # project in\n        self.net.append(act_fn)\n        # project dropout\n        self.net.append(nn.Dropout(dropout))\n        # project out\n        self.net.append(nn.Linear(inner_dim, dim_out))\n\n    def forward(self, hidden_states):\n        for module in self.net:\n            hidden_states = module(hidden_states)\n        return hidden_states\n\n\nclass GELU(nn.Module):\n    r\"\"\"\n    GELU activation function\n    \"\"\"\n\n    def __init__(self, dim_in: int, dim_out: int):\n        super().__init__()\n        self.proj = nn.Linear(dim_in, dim_out)\n\n    def gelu(self, gate):\n        if gate.device.type != \"mps\":\n            return F.gelu(gate)\n        # mps: gelu is not implemented for float16\n        return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)\n\n    def forward(self, hidden_states):\n        hidden_states = self.proj(hidden_states)\n        hidden_states = self.gelu(hidden_states)\n        return hidden_states\n\n\n# feedforward\nclass GEGLU(nn.Module):\n    r\"\"\"\n    A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.\n\n    Parameters:\n        dim_in (`int`): The number of channels in the input.\n        dim_out (`int`): The number of channels in the output.\n    \"\"\"\n\n    def __init__(self, dim_in: int, dim_out: int):\n        super().__init__()\n        self.proj = nn.Linear(dim_in, dim_out * 2)\n\n    def gelu(self, gate):\n        if gate.device.type != \"mps\":\n            return F.gelu(gate)\n        # mps: gelu is not implemented for float16\n        return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)\n\n    def forward(self, hidden_states):\n        hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)\n        return hidden_states * self.gelu(gate)\n\n\nclass ApproximateGELU(nn.Module):\n    \"\"\"\n    The approximate form of Gaussian Error Linear Unit (GELU)\n\n    For more details, see section 2: https://arxiv.org/abs/1606.08415\n    \"\"\"\n\n    def __init__(self, dim_in: int, dim_out: int):\n        super().__init__()\n        self.proj = nn.Linear(dim_in, dim_out)\n\n    def forward(self, x):\n        x = self.proj(x)\n        return x * torch.sigmoid(1.702 * x)\n\n\ndef precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):\n    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))\n    t = torch.arange(end, device=freqs.device, dtype=torch.float32)\n    freqs = torch.outer(t, freqs)\n    freqs_cis = torch.polar(torch.ones_like(freqs), freqs) 
 # complex64\n    return freqs_cis\n\n\ndef reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):\n    ndim = x.ndim\n    assert 0 <= 1 < ndim\n    assert freqs_cis.shape == (x.shape[1], x.shape[-1])\n    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]\n    return freqs_cis.view(*shape)\n\n\ndef apply_rotary_emb(\n    xq: torch.Tensor,\n    xk: torch.Tensor,\n    freqs_cis: torch.Tensor,\n) -> Tuple[torch.Tensor, torch.Tensor]:\n    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2).contiguous())\n    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2).contiguous())\n    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)\n    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(2)\n    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(2)\n    return xq_out.type_as(xq), xk_out.type_as(xk)\n"
  },
  {
    "path": "models/Video-Depth-Anything/video_depth_anything/motion_module/motion_module.py",
    "content": "# This file is originally from AnimateDiff/animatediff/models/motion_module.py at main · guoyww/AnimateDiff\n# SPDX-License-Identifier: Apache-2.0 license\n#\n# This file may have been modified by ByteDance Ltd. and/or its affiliates on [date of modification]\n# Original file was released under [ Apache-2.0 license], with the full license text available at [https://github.com/guoyww/AnimateDiff?tab=Apache-2.0-1-ov-file#readme].\nimport torch\nimport torch.nn.functional as F\nfrom torch import nn\n\nfrom .attention import CrossAttention, FeedForward, apply_rotary_emb, precompute_freqs_cis\n\nfrom einops import rearrange, repeat\nimport math\n\ntry:\n    import xformers\n    import xformers.ops\n\n    XFORMERS_AVAILABLE = True\nexcept ImportError:\n    print(\"xFormers not available\")\n    XFORMERS_AVAILABLE = False\n\n\ndef zero_module(module):\n    # Zero out the parameters of a module and return it.\n    for p in module.parameters():\n        p.detach().zero_()\n    return module\n\n\nclass TemporalModule(nn.Module):\n    def __init__(\n        self,\n        in_channels,\n        num_attention_heads                = 8,\n        num_transformer_block              = 2,\n        num_attention_blocks               = 2,\n        norm_num_groups                    = 32,\n        temporal_max_len                   = 32,\n        zero_initialize                    = True,\n        pos_embedding_type                 = \"ape\",\n    ):\n        super().__init__()\n\n        self.temporal_transformer = TemporalTransformer3DModel(\n            in_channels=in_channels,\n            num_attention_heads=num_attention_heads,\n            attention_head_dim=in_channels // num_attention_heads,\n            num_layers=num_transformer_block,\n            num_attention_blocks=num_attention_blocks,\n            norm_num_groups=norm_num_groups,\n            temporal_max_len=temporal_max_len,\n            pos_embedding_type=pos_embedding_type,\n        )\n\n        if zero_initialize:\n            self.temporal_transformer.proj_out = zero_module(self.temporal_transformer.proj_out)\n\n    def forward(self, input_tensor, encoder_hidden_states, attention_mask=None):\n        hidden_states = input_tensor\n        hidden_states = self.temporal_transformer(hidden_states, encoder_hidden_states, attention_mask)\n\n        output = hidden_states\n        return output\n\n\nclass TemporalTransformer3DModel(nn.Module):\n    def __init__(\n        self,\n        in_channels,\n        num_attention_heads,\n        attention_head_dim,\n        num_layers,\n        num_attention_blocks               = 2,\n        norm_num_groups                    = 32,\n        temporal_max_len                   = 32,\n        pos_embedding_type                 = \"ape\",\n    ):\n        super().__init__()\n\n        inner_dim = num_attention_heads * attention_head_dim\n\n        self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)\n        self.proj_in = nn.Linear(in_channels, inner_dim)\n\n        self.transformer_blocks = nn.ModuleList(\n            [\n                TemporalTransformerBlock(\n                    dim=inner_dim,\n                    num_attention_heads=num_attention_heads,\n                    attention_head_dim=attention_head_dim,\n                    num_attention_blocks=num_attention_blocks,\n                    temporal_max_len=temporal_max_len,\n                    pos_embedding_type=pos_embedding_type,\n                )\n                for d 
in range(num_layers)\n            ]\n        )\n        self.proj_out = nn.Linear(inner_dim, in_channels)\n\n    def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):\n        assert hidden_states.dim() == 5, f\"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}.\"\n        video_length = hidden_states.shape[2]\n        hidden_states = rearrange(hidden_states, \"b c f h w -> (b f) c h w\")\n\n        batch, channel, height, width = hidden_states.shape\n        residual = hidden_states\n\n        hidden_states = self.norm(hidden_states)\n        inner_dim = hidden_states.shape[1]\n        hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim).contiguous()\n        hidden_states = self.proj_in(hidden_states)\n\n        # Transformer Blocks\n        for block in self.transformer_blocks:\n            hidden_states = block(hidden_states, encoder_hidden_states=encoder_hidden_states, video_length=video_length, attention_mask=attention_mask)\n\n        # output\n        hidden_states = self.proj_out(hidden_states)\n        hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()\n\n        output = hidden_states + residual\n        output = rearrange(output, \"(b f) c h w -> b c f h w\", f=video_length)\n\n        return output\n\n\nclass TemporalTransformerBlock(nn.Module):\n    def __init__(\n        self,\n        dim,\n        num_attention_heads,\n        attention_head_dim,\n        num_attention_blocks               = 2,\n        temporal_max_len                   = 32,\n        pos_embedding_type                 = \"ape\",\n    ):\n        super().__init__()\n\n        self.attention_blocks = nn.ModuleList(\n            [\n                TemporalAttention(\n                        query_dim=dim,\n                        heads=num_attention_heads,\n                        dim_head=attention_head_dim,\n                        temporal_max_len=temporal_max_len,\n                        pos_embedding_type=pos_embedding_type,\n                )\n                for i in range(num_attention_blocks)\n            ]\n        )\n        self.norms = nn.ModuleList(\n            [\n                nn.LayerNorm(dim)\n                for i in range(num_attention_blocks)\n            ]\n        )\n\n        self.ff = FeedForward(dim, dropout=0.0, activation_fn=\"geglu\")\n        self.ff_norm = nn.LayerNorm(dim)\n\n\n    def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None):\n        for attention_block, norm in zip(self.attention_blocks, self.norms):\n            norm_hidden_states = norm(hidden_states)\n            hidden_states = attention_block(\n                norm_hidden_states,\n                encoder_hidden_states=encoder_hidden_states,\n                video_length=video_length,\n                attention_mask=attention_mask,\n            ) + hidden_states\n\n        hidden_states = self.ff(self.ff_norm(hidden_states)) + hidden_states\n\n        output = hidden_states\n        return output\n\n\nclass PositionalEncoding(nn.Module):\n    def __init__(\n        self,\n        d_model,\n        dropout = 0.,\n        max_len = 32\n    ):\n        super().__init__()\n        self.dropout = nn.Dropout(p=dropout)\n        position = torch.arange(max_len).unsqueeze(1)\n        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))\n        pe = torch.zeros(1, max_len, d_model)\n        pe[0, 
:, 0::2] = torch.sin(position * div_term)\n        pe[0, :, 1::2] = torch.cos(position * div_term)\n        self.register_buffer('pe', pe)\n\n    def forward(self, x):\n        x = x + self.pe[:, :x.size(1)].to(x.dtype)\n        return self.dropout(x)\n\nclass TemporalAttention(CrossAttention):\n    def __init__(\n            self,\n            temporal_max_len                   = 32,\n            pos_embedding_type                 = \"ape\",\n            *args, **kwargs\n        ):\n        super().__init__(*args, **kwargs)\n\n        self.pos_embedding_type = pos_embedding_type\n        self._use_memory_efficient_attention_xformers = True\n\n        self.pos_encoder = None\n        self.freqs_cis = None\n        if self.pos_embedding_type == \"ape\":\n            self.pos_encoder = PositionalEncoding(\n                kwargs[\"query_dim\"],\n                dropout=0.,\n                max_len=temporal_max_len\n            )\n\n        elif self.pos_embedding_type == \"rope\":\n            self.freqs_cis = precompute_freqs_cis(\n                kwargs[\"query_dim\"],\n                temporal_max_len\n            )\n\n        else:\n            raise NotImplementedError\n\n    def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None):\n        d = hidden_states.shape[1]\n        hidden_states = rearrange(hidden_states, \"(b f) d c -> (b d) f c\", f=video_length)\n\n        if self.pos_encoder is not None:\n            hidden_states = self.pos_encoder(hidden_states)\n\n        encoder_hidden_states = repeat(encoder_hidden_states, \"b n c -> (b d) n c\", d=d) if encoder_hidden_states is not None else encoder_hidden_states\n\n        if self.group_norm is not None:\n            hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)\n\n        query = self.to_q(hidden_states)\n        dim = query.shape[-1]\n\n        if self.added_kv_proj_dim is not None:\n            raise NotImplementedError\n\n        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states\n        key = self.to_k(encoder_hidden_states)\n        value = self.to_v(encoder_hidden_states)\n\n        if self.freqs_cis is not None:\n            seq_len = query.shape[1]\n            freqs_cis = self.freqs_cis[:seq_len].to(query.device)\n            query, key = apply_rotary_emb(query, key, freqs_cis)\n\n        if attention_mask is not None:\n            if attention_mask.shape[-1] != query.shape[1]:\n                target_length = query.shape[1]\n                attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)\n                attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)\n\n\n        use_memory_efficient = XFORMERS_AVAILABLE and self._use_memory_efficient_attention_xformers\n        if use_memory_efficient and (dim // self.heads) % 8 != 0:\n            # print('Warning: the dim {} cannot be divided by 8. 
Fall into normal attention'.format(dim // self.heads))\n            use_memory_efficient = False\n\n        # attention, what we cannot get enough of\n        if use_memory_efficient:\n            query = self.reshape_heads_to_4d(query)\n            key = self.reshape_heads_to_4d(key)\n            value = self.reshape_heads_to_4d(value)\n\n            hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)\n            # Some versions of xformers return output in fp32, cast it back to the dtype of the input\n            hidden_states = hidden_states.to(query.dtype)\n        else:\n            query = self.reshape_heads_to_batch_dim(query)\n            key = self.reshape_heads_to_batch_dim(key)\n            value = self.reshape_heads_to_batch_dim(value)\n\n            if self._slice_size is None or query.shape[0] // self._slice_size == 1:\n                hidden_states = self._attention(query, key, value, attention_mask)\n            else:\n                raise NotImplementedError\n                # hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)\n\n        # linear proj\n        hidden_states = self.to_out[0](hidden_states)\n\n        # dropout\n        hidden_states = self.to_out[1](hidden_states)\n\n        hidden_states = rearrange(hidden_states, \"(b d) f c -> (b f) d c\", d=d)\n\n        return hidden_states\n"
  },
  {
    "path": "models/Video-Depth-Anything/video_depth_anything/util/blocks.py",
    "content": "import torch.nn as nn\n\n\ndef _make_scratch(in_shape, out_shape, groups=1, expand=False):\n    scratch = nn.Module()\n\n    out_shape1 = out_shape\n    out_shape2 = out_shape\n    out_shape3 = out_shape\n    if len(in_shape) >= 4:\n        out_shape4 = out_shape\n\n    if expand:\n        out_shape1 = out_shape\n        out_shape2 = out_shape * 2\n        out_shape3 = out_shape * 4\n        if len(in_shape) >= 4:\n            out_shape4 = out_shape * 8\n\n    scratch.layer1_rn = nn.Conv2d(\n        in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups\n    )\n    scratch.layer2_rn = nn.Conv2d(\n        in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups\n    )\n    scratch.layer3_rn = nn.Conv2d(\n        in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups\n    )\n    if len(in_shape) >= 4:\n        scratch.layer4_rn = nn.Conv2d(\n            in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups\n        )\n\n    return scratch\n\n\nclass ResidualConvUnit(nn.Module):\n    \"\"\"Residual convolution module.\"\"\"\n\n    def __init__(self, features, activation, bn):\n        \"\"\"Init.\n\n        Args:\n            features (int): number of features\n        \"\"\"\n        super().__init__()\n\n        self.bn = bn\n\n        self.groups = 1\n\n        self.conv1 = nn.Conv2d(\n            features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups\n        )\n\n        self.conv2 = nn.Conv2d(\n            features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups\n        )\n\n        if self.bn is True:\n            self.bn1 = nn.BatchNorm2d(features)\n            self.bn2 = nn.BatchNorm2d(features)\n\n        self.activation = activation\n\n        self.skip_add = nn.quantized.FloatFunctional()\n\n    def forward(self, x):\n        \"\"\"Forward pass.\n\n        Args:\n            x (tensor): input\n\n        Returns:\n            tensor: output\n        \"\"\"\n\n        out = self.activation(x)\n        out = self.conv1(out)\n        if self.bn is True:\n            out = self.bn1(out)\n\n        out = self.activation(out)\n        out = self.conv2(out)\n        if self.bn is True:\n            out = self.bn2(out)\n\n        if self.groups > 1:\n            out = self.conv_merge(out)\n\n        return self.skip_add.add(out, x)\n\n\nclass FeatureFusionBlock(nn.Module):\n    \"\"\"Feature fusion block.\"\"\"\n\n    def __init__(\n        self,\n        features,\n        activation,\n        deconv=False,\n        bn=False,\n        expand=False,\n        align_corners=True,\n        size=None,\n    ):\n        \"\"\"Init.\n\n        Args:\n            features (int): number of features\n        \"\"\"\n        super().__init__()\n\n        self.deconv = deconv\n        self.align_corners = align_corners\n\n        self.groups = 1\n\n        self.expand = expand\n        out_features = features\n        if self.expand is True:\n            out_features = features // 2\n\n        self.out_conv = nn.Conv2d(\n            features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1\n        )\n\n        self.resConfUnit1 = ResidualConvUnit(features, activation, bn)\n        self.resConfUnit2 = ResidualConvUnit(features, activation, bn)\n\n        self.skip_add = nn.quantized.FloatFunctional()\n\n        self.size = size\n\n    def forward(self, *xs, size=None):\n        
\"\"\"Forward pass.\n\n        Returns:\n            tensor: output\n        \"\"\"\n        output = xs[0]\n\n        if len(xs) == 2:\n            res = self.resConfUnit1(xs[1])\n            output = self.skip_add.add(output, res)\n\n        output = self.resConfUnit2(output)\n\n        if (size is None) and (self.size is None):\n            modifier = {\"scale_factor\": 2}\n        elif size is None:\n            modifier = {\"size\": self.size}\n        else:\n            modifier = {\"size\": size}\n\n        output = nn.functional.interpolate(\n            output.contiguous(), **modifier, mode=\"bilinear\", align_corners=self.align_corners\n        )\n\n        output = self.out_conv(output)\n\n        return output\n"
  },
  {
    "path": "models/Video-Depth-Anything/video_depth_anything/util/transform.py",
    "content": "import numpy as np\nimport cv2\n\n\nclass Resize(object):\n    \"\"\"Resize sample to given size (width, height).\n    \"\"\"\n\n    def __init__(\n        self,\n        width,\n        height,\n        resize_target=True,\n        keep_aspect_ratio=False,\n        ensure_multiple_of=1,\n        resize_method=\"lower_bound\",\n        image_interpolation_method=cv2.INTER_AREA,\n    ):\n        \"\"\"Init.\n\n        Args:\n            width (int): desired output width\n            height (int): desired output height\n            resize_target (bool, optional):\n                True: Resize the full sample (image, mask, target).\n                False: Resize image only.\n                Defaults to True.\n            keep_aspect_ratio (bool, optional):\n                True: Keep the aspect ratio of the input sample.\n                Output sample might not have the given width and height, and\n                resize behaviour depends on the parameter 'resize_method'.\n                Defaults to False.\n            ensure_multiple_of (int, optional):\n                Output width and height is constrained to be multiple of this parameter.\n                Defaults to 1.\n            resize_method (str, optional):\n                \"lower_bound\": Output will be at least as large as the given size.\n                \"upper_bound\": Output will be at max as large as the given size. (Output size might be smaller than given size.)\n                \"minimal\": Scale as least as possible.  (Output size might be smaller than given size.)\n                Defaults to \"lower_bound\".\n        \"\"\"\n        self.__width = width\n        self.__height = height\n\n        self.__resize_target = resize_target\n        self.__keep_aspect_ratio = keep_aspect_ratio\n        self.__multiple_of = ensure_multiple_of\n        self.__resize_method = resize_method\n        self.__image_interpolation_method = image_interpolation_method\n\n    def constrain_to_multiple_of(self, x, min_val=0, max_val=None):\n        y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)\n\n        if max_val is not None and y > max_val:\n            y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)\n\n        if y < min_val:\n            y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)\n\n        return y\n\n    def get_size(self, width, height):\n        # determine new height and width\n        scale_height = self.__height / height\n        scale_width = self.__width / width\n\n        if self.__keep_aspect_ratio:\n            if self.__resize_method == \"lower_bound\":\n                # scale such that output size is lower bound\n                if scale_width > scale_height:\n                    # fit width\n                    scale_height = scale_width\n                else:\n                    # fit height\n                    scale_width = scale_height\n            elif self.__resize_method == \"upper_bound\":\n                # scale such that output size is upper bound\n                if scale_width < scale_height:\n                    # fit width\n                    scale_height = scale_width\n                else:\n                    # fit height\n                    scale_width = scale_height\n            elif self.__resize_method == \"minimal\":\n                # scale as least as possbile\n                if abs(1 - scale_width) < abs(1 - scale_height):\n                    # fit width\n                    scale_height = 
scale_width\n                else:\n                    # fit height\n                    scale_width = scale_height\n            else:\n                raise ValueError(f\"resize_method {self.__resize_method} not implemented\")\n\n        if self.__resize_method == \"lower_bound\":\n            new_height = self.constrain_to_multiple_of(scale_height * height, min_val=self.__height)\n            new_width = self.constrain_to_multiple_of(scale_width * width, min_val=self.__width)\n        elif self.__resize_method == \"upper_bound\":\n            new_height = self.constrain_to_multiple_of(scale_height * height, max_val=self.__height)\n            new_width = self.constrain_to_multiple_of(scale_width * width, max_val=self.__width)\n        elif self.__resize_method == \"minimal\":\n            new_height = self.constrain_to_multiple_of(scale_height * height)\n            new_width = self.constrain_to_multiple_of(scale_width * width)\n        else:\n            raise ValueError(f\"resize_method {self.__resize_method} not implemented\")\n\n        return (new_width, new_height)\n\n    def __call__(self, sample):\n        width, height = self.get_size(sample[\"image\"].shape[1], sample[\"image\"].shape[0])\n        \n        # resize sample\n        sample[\"image\"] = cv2.resize(sample[\"image\"], (width, height), interpolation=self.__image_interpolation_method)\n\n        if self.__resize_target:\n            if \"depth\" in sample:\n                sample[\"depth\"] = cv2.resize(sample[\"depth\"], (width, height), interpolation=cv2.INTER_NEAREST)\n                \n            if \"mask\" in sample:\n                sample[\"mask\"] = cv2.resize(sample[\"mask\"].astype(np.float32), (width, height), interpolation=cv2.INTER_NEAREST)\n        \n        return sample\n\n\nclass NormalizeImage(object):\n    \"\"\"Normalize image by given mean and std.\n    \"\"\"\n\n    def __init__(self, mean, std):\n        self.__mean = mean\n        self.__std = std\n\n    def __call__(self, sample):\n        sample[\"image\"] = (sample[\"image\"] - self.__mean) / self.__std\n\n        return sample\n\n\nclass PrepareForNet(object):\n    \"\"\"Prepare sample for usage as network input.\n    \"\"\"\n\n    def __init__(self):\n        pass\n\n    def __call__(self, sample):\n        image = np.transpose(sample[\"image\"], (2, 0, 1))\n        sample[\"image\"] = np.ascontiguousarray(image).astype(np.float32)\n\n        if \"depth\" in sample:\n            depth = sample[\"depth\"].astype(np.float32)\n            sample[\"depth\"] = np.ascontiguousarray(depth)\n        \n        if \"mask\" in sample:\n            sample[\"mask\"] = sample[\"mask\"].astype(np.float32)\n            sample[\"mask\"] = np.ascontiguousarray(sample[\"mask\"])\n        \n        return sample"
  },
  {
    "path": "models/Video-Depth-Anything/video_depth_anything/video_depth.py",
    "content": "# Copyright (2025) Bytedance Ltd. and/or its affiliates \n\n# Licensed under the Apache License, Version 2.0 (the \"License\"); \n# you may not use this file except in compliance with the License. \n# You may obtain a copy of the License at \n\n#     http://www.apache.org/licenses/LICENSE-2.0 \n\n# Unless required by applicable law or agreed to in writing, software \n# distributed under the License is distributed on an \"AS IS\" BASIS, \n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. \n# See the License for the specific language governing permissions and \n# limitations under the License. \nimport torch\nimport torch.nn.functional as F\nimport torch.nn as nn\nfrom torchvision.transforms import Compose\nimport cv2\nfrom tqdm import tqdm\nimport numpy as np\nimport gc\n\nfrom .dinov2 import DINOv2\nfrom .dpt_temporal import DPTHeadTemporal\nfrom .util.transform import Resize, NormalizeImage, PrepareForNet\n\nfrom ..utils.util import compute_scale_and_shift, get_interpolate_frames\n\n# infer settings, do not change\nINFER_LEN = 32\nOVERLAP = 10\nKEYFRAMES = [0,12,24,25,26,27,28,29,30,31]\nINTERP_LEN = 8\n\nclass VideoDepthAnything(nn.Module):\n    def __init__(\n        self,\n        encoder='vitl',\n        features=256, \n        out_channels=[256, 512, 1024, 1024], \n        use_bn=False, \n        use_clstoken=False,\n        num_frames=32,\n        pe='ape'\n    ):\n        super(VideoDepthAnything, self).__init__()\n\n        self.intermediate_layer_idx = {\n            'vits': [2, 5, 8, 11],\n            'vitl': [4, 11, 17, 23]\n        }\n        \n        self.encoder = encoder\n        self.pretrained = DINOv2(model_name=encoder)\n        self.head = DPTHeadTemporal(self.pretrained.embed_dim, features, use_bn, out_channels=out_channels, use_clstoken=use_clstoken, num_frames=num_frames, pe=pe)\n\n    def forward(self, x):\n        B, T, C, H, W = x.shape\n        patch_h, patch_w = H // 14, W // 14\n        features = self.pretrained.get_intermediate_layers(x.flatten(0,1), self.intermediate_layer_idx[self.encoder], return_class_token=True)\n        depth = self.head(features, patch_h, patch_w, T)\n        # depth = F.interpolate(depth, size=(H, W), mode=\"bilinear\", align_corners=True)\n        # depth = F.relu(depth)\n        return depth\n    \n    def infer_video_depth(self, frames, target_fps, input_size=518, device='cuda'):\n        frame_height, frame_width = frames[0].shape[:2]\n        ratio = max(frame_height, frame_width) / min(frame_height, frame_width)\n        if ratio > 1.78:  # we recommend to process video with ratio smaller than 16:9 due to memory limitation\n            input_size = int(input_size * 1.777 / ratio)\n            input_size = round(input_size / 14) * 14\n\n        transform = Compose([\n            Resize(\n                width=input_size,\n                height=input_size,\n                resize_target=False,\n                keep_aspect_ratio=True,\n                ensure_multiple_of=14,\n                resize_method='lower_bound',\n                image_interpolation_method=cv2.INTER_CUBIC,\n            ),\n            NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n            PrepareForNet(),\n        ])\n\n        frame_list = [frames[i] for i in range(frames.shape[0])]\n        frame_step = INFER_LEN - OVERLAP\n        org_video_len = len(frame_list)\n        append_frame_len = (frame_step - (org_video_len % frame_step)) % frame_step + (INFER_LEN - frame_step)\n        
frame_list = frame_list + [frame_list[-1].copy()] * append_frame_len\n        \n        depth_list = []\n        pre_input = None\n        for frame_id in tqdm(range(0, org_video_len, frame_step)):\n            cur_list = []\n            for i in range(INFER_LEN):\n                cur_list.append(torch.from_numpy(transform({'image': frame_list[frame_id+i].astype(np.float32) / 255.0})['image']).unsqueeze(0).unsqueeze(0))\n            cur_input = torch.cat(cur_list, dim=1).to(device)\n            if pre_input is not None:\n                cur_input[:, :OVERLAP, ...] = pre_input[:, KEYFRAMES, ...]\n\n            with torch.no_grad():\n                depth = self.forward(cur_input) # depth shape: [1, T, H, W]\n\n            depth = F.interpolate(depth.flatten(0,1).unsqueeze(1), size=(frame_height, frame_width), mode='bilinear', align_corners=True)\n            depth_list += [depth[i][0].cpu().numpy() for i in range(depth.shape[0])]\n\n            pre_input = cur_input\n\n        del frame_list\n        gc.collect()\n\n        depth_list_aligned = []\n        ref_align = []\n        align_len = OVERLAP - INTERP_LEN\n        kf_align_list = KEYFRAMES[:align_len]\n\n        for frame_id in range(0, len(depth_list), INFER_LEN):\n            if len(depth_list_aligned) == 0:\n                depth_list_aligned += depth_list[:INFER_LEN]\n                for kf_id in kf_align_list:\n                    ref_align.append(depth_list[frame_id+kf_id])\n            else:\n                curr_align = []\n                for i in range(len(kf_align_list)):\n                    curr_align.append(depth_list[frame_id+i])\n                scale, shift = compute_scale_and_shift(np.concatenate(curr_align),\n                                                       np.concatenate(ref_align),\n                                                       np.concatenate(np.ones_like(ref_align)==1))\n\n                pre_depth_list = depth_list_aligned[-INTERP_LEN:]\n                post_depth_list = depth_list[frame_id+align_len:frame_id+OVERLAP]\n                for i in range(len(post_depth_list)):\n                    post_depth_list[i] = post_depth_list[i] * scale + shift\n                    post_depth_list[i][post_depth_list[i]<0] = 0\n                depth_list_aligned[-INTERP_LEN:] = get_interpolate_frames(pre_depth_list, post_depth_list)\n\n                for i in range(OVERLAP, INFER_LEN):\n                    new_depth = depth_list[frame_id+i] * scale + shift\n                    new_depth[new_depth<0] = 0\n                    depth_list_aligned.append(new_depth)\n\n                ref_align = ref_align[:1]\n                for kf_id in kf_align_list[1:]:\n                    new_depth = depth_list[frame_id+kf_id] * scale + shift\n                    new_depth[new_depth<0] = 0\n                    ref_align.append(new_depth)\n            \n        depth_list = depth_list_aligned\n            \n        return depth_list[:org_video_len], target_fps\n        "
  },
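  {
    "path": "examples/video_depth_infer_demo.py",
    "content": "# Hypothetical usage sketch (illustrative file, not part of the training/eval pipeline):\n# how VideoDepthAnything.infer_video_depth is called end-to-end. The module path,\n# checkpoint location and the synthetic clip are assumptions that mirror how\n# DepthExtractor loads this model.\nimport importlib\n\nimport numpy as np\nimport torch\n\nvideo_depth = importlib.import_module(\n    'stereoanyvideo.models.Video-Depth-Anything.video_depth_anything.video_depth'\n)\n\nmodel = video_depth.VideoDepthAnything(encoder='vits', features=64, out_channels=[48, 96, 192, 384])\nmodel.load_state_dict(torch.load('./models/Video-Depth-Anything/checkpoints/video_depth_anything_vits.pth'))\nmodel = model.to('cuda').eval()\n\n# 40 synthetic RGB frames (T, H, W, 3) standing in for a decoded video clip.\nframes = np.random.randint(0, 255, size=(40, 392, 518, 3), dtype=np.uint8)\ndepths, fps = model.infer_video_depth(frames, target_fps=30, input_size=518, device='cuda')\nprint(len(depths), depths[0].shape)  # one (392, 518) depth map per input frame\n"
  },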
  {
    "path": "models/core/attention.py",
    "content": "import math\nimport copy\nimport torch\nimport torch.nn as nn\nfrom torch.nn import Module, Dropout\n\n\"\"\"\nLinear Transformer proposed in \"Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention\"\nModified from: https://github.com/idiap/fast-transformers/blob/master/fast_transformers/attention/linear_attention.py\n\"\"\"\n\n\ndef elu_feature_map(x):\n    return torch.nn.functional.elu(x) + 1\n\n\nclass PositionEncodingSine(nn.Module):\n    \"\"\"\n    This is a sinusoidal position encoding that generalized to 2-dimensional images\n    \"\"\"\n\n    def __init__(self, d_model, max_shape=(256, 256), temp_bug_fix=True):\n        \"\"\"\n        Args:\n            max_shape (tuple): for 1/8 featmap, the max length of 256 corresponds to 2048 pixels\n            temp_bug_fix (bool): As noted in this [issue](https://github.com/zju3dv/LoFTR/issues/41),\n                the original implementation of LoFTR includes a bug in the pos-enc impl, which has little impact\n                on the final performance. For now, we keep both impls for backward compatability.\n                We will remove the buggy impl after re-training all variants of our released models.\n        \"\"\"\n        super().__init__()\n        pe = torch.zeros((d_model, *max_shape))\n        y_position = torch.ones(max_shape).cumsum(0).float().unsqueeze(0)\n        x_position = torch.ones(max_shape).cumsum(1).float().unsqueeze(0)\n        if temp_bug_fix:\n            div_term = torch.exp(\n                torch.arange(0, d_model // 2, 2).float()\n                * (-math.log(10000.0) / (d_model // 2))\n            )\n        else:  # a buggy implementation (for backward compatability only)\n            div_term = torch.exp(\n                torch.arange(0, d_model // 2, 2).float()\n                * (-math.log(10000.0) / d_model // 2)\n            )\n        div_term = div_term[:, None, None]  # [C//4, 1, 1]\n        pe[0::4, :, :] = torch.sin(x_position * div_term)\n        pe[1::4, :, :] = torch.cos(x_position * div_term)\n        pe[2::4, :, :] = torch.sin(y_position * div_term)\n        pe[3::4, :, :] = torch.cos(y_position * div_term)\n\n        self.register_buffer(\"pe\", pe.unsqueeze(0), persistent=False)  # [1, C, H, W]\n\n    def forward(self, x):\n        \"\"\"\n        Args:\n            x: [N, C, H, W]\n        \"\"\"\n        return x + self.pe[:, :, : x.size(2), : x.size(3)].to(x.device)\n\n\nclass LinearAttention(Module):\n    def __init__(self, eps=1e-6):\n        super().__init__()\n        self.feature_map = elu_feature_map\n        self.eps = eps\n\n    def forward(self, queries, keys, values, q_mask=None, kv_mask=None):\n        \"\"\"Multi-Head linear attention proposed in \"Transformers are RNNs\"\n        Args:\n            queries: [N, L, H, D]\n            keys: [N, S, H, D]\n            values: [N, S, H, D]\n            q_mask: [N, L]\n            kv_mask: [N, S]\n        Returns:\n            queried_values: (N, L, H, D)\n        \"\"\"\n        Q = self.feature_map(queries)\n        K = self.feature_map(keys)\n\n        # set padded position to zero\n        if q_mask is not None:\n            Q = Q * q_mask[:, :, None, None]\n        if kv_mask is not None:\n            K = K * kv_mask[:, :, None, None]\n            values = values * kv_mask[:, :, None, None]\n\n        v_length = values.size(1)\n        values = values / v_length  # prevent fp16 overflow\n        KV = torch.einsum(\"nshd,nshv->nhdv\", K, values)  # (S,D)' @ S,V\n        Z = 1 / 
(torch.einsum(\"nlhd,nhd->nlh\", Q, K.sum(dim=1)) + self.eps)\n        queried_values = torch.einsum(\"nlhd,nhdv,nlh->nlhv\", Q, KV, Z) * v_length\n\n        return queried_values.contiguous()\n\n\nclass FullAttention(Module):\n    def __init__(self, use_dropout=False, attention_dropout=0.1):\n        super().__init__()\n        self.use_dropout = use_dropout\n        self.dropout = Dropout(attention_dropout)\n\n    def forward(self, queries, keys, values, q_mask=None, kv_mask=None):\n        \"\"\"Multi-head scaled dot-product attention, a.k.a full attention.\n        Args:\n            queries: [N, L, H, D]\n            keys: [N, S, H, D]\n            values: [N, S, H, D]\n            q_mask: [N, L]\n            kv_mask: [N, S]\n        Returns:\n            queried_values: (N, L, H, D)\n        \"\"\"\n\n        # Compute the unnormalized attention and apply the masks\n        QK = torch.einsum(\"nlhd,nshd->nlsh\", queries, keys)\n        if kv_mask is not None:\n            QK.masked_fill_(\n                ~(q_mask[:, :, None, None] * kv_mask[:, None, :, None]), float(\"-inf\")\n            )\n\n        # Compute the attention and the weighted average\n        softmax_temp = 1.0 / queries.size(3) ** 0.5  # sqrt(D)\n        A = torch.softmax(softmax_temp * QK, dim=2)\n        if self.use_dropout:\n            A = self.dropout(A)\n\n        queried_values = torch.einsum(\"nlsh,nshd->nlhd\", A, values)\n\n        return queried_values.contiguous()\n\n\n# Ref: https://github.com/zju3dv/LoFTR/blob/master/src/loftr/loftr_module/transformer.py\nclass LoFTREncoderLayer(nn.Module):\n    def __init__(self, d_model, nhead, attention=\"linear\"):\n        super(LoFTREncoderLayer, self).__init__()\n\n        self.dim = d_model // nhead\n        self.nhead = nhead\n\n        # multi-head attention\n        self.q_proj = nn.Linear(d_model, d_model, bias=False)\n        self.k_proj = nn.Linear(d_model, d_model, bias=False)\n        self.v_proj = nn.Linear(d_model, d_model, bias=False)\n        self.attention = LinearAttention() if attention == \"linear\" else FullAttention()\n        self.merge = nn.Linear(d_model, d_model, bias=False)\n\n        # feed-forward network\n        self.mlp = nn.Sequential(\n            nn.Linear(d_model * 2, d_model * 2, bias=False),\n            nn.ReLU(),\n            nn.Linear(d_model * 2, d_model, bias=False),\n        )\n\n        # norm and dropout\n        self.norm1 = nn.LayerNorm(d_model)\n        self.norm2 = nn.LayerNorm(d_model)\n\n    def forward(self, x, source, x_mask=None, source_mask=None):\n        \"\"\"\n        Args:\n            x (torch.Tensor): [N, L, C]\n            source (torch.Tensor): [N, S, C]\n            x_mask (torch.Tensor): [N, L] (optional)\n            source_mask (torch.Tensor): [N, S] (optional)\n        \"\"\"\n        bs = x.size(0)\n        query, key, value = x, source, source\n\n        # multi-head attention\n        query = self.q_proj(query).view(bs, -1, self.nhead, self.dim)  # [N, L, (H, D)]\n        key = self.k_proj(key).view(bs, -1, self.nhead, self.dim)  # [N, S, (H, D)]\n        value = self.v_proj(value).view(bs, -1, self.nhead, self.dim)\n        message = self.attention(\n            query, key, value, q_mask=x_mask, kv_mask=source_mask\n        )  # [N, L, (H, D)]\n        message = self.merge(message.view(bs, -1, self.nhead * self.dim))  # [N, L, C]\n        message = self.norm1(message)\n\n        # feed-forward network\n        message = self.mlp(torch.cat([x, message], dim=2))\n        message = 
self.norm2(message)\n\n        return x + message\n\n\nclass LocalFeatureTransformer(nn.Module):\n    \"\"\"A Local Feature Transformer (LoFTR) module.\"\"\"\n\n    def __init__(self, d_model, nhead, layer_names, attention):\n        super(LocalFeatureTransformer, self).__init__()\n\n        self.d_model = d_model\n        self.nhead = nhead\n        self.layer_names = layer_names\n        encoder_layer = LoFTREncoderLayer(d_model, nhead, attention)\n        self.layers = nn.ModuleList(\n            [copy.deepcopy(encoder_layer) for _ in range(len(self.layer_names))]\n        )\n        self._reset_parameters()\n\n    def _reset_parameters(self):\n        for p in self.parameters():\n            if p.dim() > 1:\n                nn.init.xavier_uniform_(p)\n\n    def forward(self, feat0, feat1, mask0=None, mask1=None):\n        \"\"\"\n        Args:\n            feat0 (torch.Tensor): [N, L, C]\n            feat1 (torch.Tensor): [N, S, C]\n            mask0 (torch.Tensor): [N, L] (optional)\n            mask1 (torch.Tensor): [N, S] (optional)\n        \"\"\"\n        assert self.d_model == feat0.size(\n            2\n        ), \"the feature number of src and transformer must be equal\"\n\n        for layer, name in zip(self.layers, self.layer_names):\n\n            if name == \"self\":\n                feat0 = layer(feat0, feat0, mask0, mask0)\n                feat1 = layer(feat1, feat1, mask1, mask1)\n            elif name == \"cross\":\n                feat0 = layer(feat0, feat1, mask0, mask1)\n                feat1 = layer(feat1, feat0, mask1, mask0)\n            else:\n                raise KeyError\n\n        return feat0, feat1\n"
  },
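  {
    "path": "examples/attention_demo.py",
    "content": "# Hypothetical shape-check sketch for the attention primitives in models/core/attention.py;\n# the tensor sizes are arbitrary assumptions.\nimport torch\n\nfrom stereoanyvideo.models.core.attention import LinearAttention, LoFTREncoderLayer\n\nN, L, S, H, D = 2, 100, 80, 8, 32\nattn = LinearAttention()\nout = attn(torch.randn(N, L, H, D), torch.randn(N, S, H, D), torch.randn(N, S, H, D))\nprint(out.shape)  # torch.Size([2, 100, 8, 32]): one D-dim value per query and head\n\n# LoFTREncoderLayer wraps the same attention with projections, an MLP and LayerNorms.\nlayer = LoFTREncoderLayer(d_model=H * D, nhead=H, attention='linear')\nfeat0 = layer(torch.randn(N, L, H * D), torch.randn(N, S, H * D))\nprint(feat0.shape)  # torch.Size([2, 100, 256])\n"
  },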
  {
    "path": "models/core/corr.py",
    "content": "import torch\nimport torch.nn.functional as F\nfrom einops import rearrange\n\n\ndef bilinear_sampler(img, coords, mode=\"bilinear\", mask=False):\n    \"\"\"Wrapper for grid_sample, uses pixel coordinates\"\"\"\n    H, W = img.shape[-2:]\n    xgrid, ygrid = coords.split([1, 1], dim=-1)\n    xgrid = 2 * xgrid / (W - 1) - 1\n    if H > 1:\n        ygrid = 2 * ygrid/(H - 1) - 1\n    img = img.contiguous()\n    grid = torch.cat([xgrid, ygrid], dim=-1).contiguous()\n    img = F.grid_sample(img, grid, align_corners=True)\n\n    if mask:\n        mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)\n        return img, mask.float()\n\n    return img\n\n\ndef coords_grid(batch, ht, wd, device):\n    coords = torch.meshgrid(\n        torch.arange(ht, device=device), torch.arange(wd, device=device), indexing=\"ij\"\n    )\n    coords = torch.stack(coords[::-1], dim=0).float()\n    return coords[None].repeat(batch, 1, 1, 1)\n\n\nclass AAPC:\n    \"\"\"\n    Implementation of All-in-All-Pair Correlation.\n    \"\"\"\n    def __init__(self, fmap1, fmap2, att=None):\n        self.fmap1 = fmap1\n        self.fmap2 = fmap2\n\n        self.att = att\n        self.coords = coords_grid(fmap1.shape[0], fmap1.shape[2], fmap1.shape[3], fmap1.device)\n\n    def __call__(self, flow, extra_offset, small_patch=False):\n\n        corr = self.correlation(self.fmap1, self.fmap2, flow, small_patch)\n\n        return corr\n\n    def correlation(self, left_feature, right_feature, flow, small_patch):\n        flow[:, 1:] = 0\n        coords = self.coords - flow\n        coords = coords.permute(0, 2, 3, 1)\n        right_feature = bilinear_sampler(right_feature, coords)\n\n        if small_patch:\n            psize_list = [(3, 3), (3, 3), (3, 3), (3, 3)]\n            dilate_list = [(1, 1), (1, 1), (1, 1), (1, 1)]\n        else:\n            psize_list = [(1, 9), (1, 9), (1, 9), (1, 9)]\n            dilate_list = [(1, 1), (1, 1), (1, 1), (1, 1)]\n\n        N, C, H, W = left_feature.size()\n        lefts = torch.split(left_feature, [C // 4] * 4, dim=1)\n        rights = torch.split(right_feature, [C // 4] * 4, dim=1)\n        corrs = []\n        for i in range(len(psize_list)):\n            corr = self.get_correlation(lefts[i], rights[i], psize_list[i], dilate_list[i])\n            corrs.append(corr)\n\n        final_corr = torch.cat(corrs, dim=1)\n        return final_corr\n\n    def get_correlation(self, left_feature, right_feature, psize=(3, 3), dilate=(1, 1)):\n\n        N, C, H, W = left_feature.size()\n\n        di_y, di_x = dilate[0], dilate[1]\n        pady, padx = psize[0] // 2 * di_y, psize[1] // 2 * di_x\n\n        left_pad = F.pad(left_feature, [padx, padx, pady, pady], mode='replicate')\n        right_pad = F.pad(right_feature, [padx, padx, pady, pady], mode='replicate')\n\n        corr_list = []\n        for dy1 in range(0, pady * 2 + 1, di_y):\n            for dx1 in range(0, padx * 2 + 1, di_x):\n                left_crop = left_pad[:, :, dy1:dy1 + H, dx1:dx1 + W]\n\n                for dy2 in range(0, pady * 2 + 1, di_y):\n                    for dx2 in range(0, padx * 2 + 1, di_x):\n                        right_crop = right_pad[:, :, dy2:dy2 + H, dx2:dx2 + W]\n                        assert right_crop.size() == left_crop.size()\n                        corr = (left_crop * right_crop).sum(dim=1, keepdim=True)  # Sum over channels\n                        corr_list.append(corr)\n\n        corr_final = torch.cat(corr_list, dim=1)\n\n        return corr_final"
  },
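  {
    "path": "examples/corr_demo.py",
    "content": "# Hypothetical sketch for models/core/corr.py: sampling a feature map at its own\n# coordinate grid (zero flow) is an identity warp; sizes are arbitrary assumptions.\nimport torch\n\nfrom stereoanyvideo.models.core.corr import AAPC, bilinear_sampler, coords_grid\n\nfmap = torch.randn(1, 4, 6, 8)\ncoords = coords_grid(1, 6, 8, fmap.device)                   # (N, 2, H, W) in pixels\nwarped = bilinear_sampler(fmap, coords.permute(0, 2, 3, 1))  # sampler expects (N, H, W, 2)\nprint(torch.allclose(warped, fmap, atol=1e-5))               # True\n\n# AAPC correlates left/right features after warping the right map by the flow.\nleft = torch.randn(1, 64, 6, 8)   # channels divisible by 4 (split into 4 groups)\nright = torch.randn(1, 64, 6, 8)\ncorr_fn = AAPC(left, right)\ncorr = corr_fn(torch.zeros(1, 2, 6, 8), None, small_patch=True)\nprint(corr.shape)  # torch.Size([1, 324, 6, 8]): 4 groups x (9 x 9) shift pairs,\n# matching the 4 * 9 * 9 input expected by the corr_mlp in StereoAnyVideo\n"
  },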
  {
    "path": "models/core/extractor.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport numpy as np\nimport os\nimport sys\nimport importlib\nimport timm\nfrom einops import rearrange\n\n\nclass ResidualBlock(nn.Module):\n    def __init__(self, in_planes, planes, norm_fn=\"group\", stride=1):\n        super(ResidualBlock, self).__init__()\n\n        self.conv1 = nn.Conv2d(\n            in_planes, planes, kernel_size=3, padding=1, stride=stride\n        )\n        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1)\n        self.relu = nn.ReLU(inplace=True)\n\n        num_groups = planes // 8\n\n        if norm_fn == \"group\":\n            self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)\n            self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)\n            self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)\n\n        elif norm_fn == \"batch\":\n            self.norm1 = nn.BatchNorm2d(planes)\n            self.norm2 = nn.BatchNorm2d(planes)\n            self.norm3 = nn.BatchNorm2d(planes)\n\n        elif norm_fn == \"instance\":\n            self.norm1 = nn.InstanceNorm2d(planes, affine=False)\n            self.norm2 = nn.InstanceNorm2d(planes, affine=False)\n            self.norm3 = nn.InstanceNorm2d(planes, affine=False)\n\n        elif norm_fn == \"none\":\n            self.norm1 = nn.Sequential()\n            self.norm2 = nn.Sequential()\n            self.norm3 = nn.Sequential()\n\n        self.downsample = nn.Sequential(\n            nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3\n        )\n\n    def forward(self, x):\n        y = x\n        y = self.relu(self.norm1(self.conv1(y)))\n        y = self.relu(self.norm2(self.conv2(y)))\n\n        x = self.downsample(x)\n\n        return self.relu(x + y)\n\n\nclass BasicEncoder(nn.Module):\n    def __init__(self, input_dim=3, output_dim=128, norm_fn=\"batch\", dropout=0.0):\n        super(BasicEncoder, self).__init__()\n        self.norm_fn = norm_fn\n\n        if self.norm_fn == \"group\":\n            self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64)\n\n        elif self.norm_fn == \"batch\":\n            self.norm1 = nn.BatchNorm2d(64)\n\n        elif self.norm_fn == \"instance\":\n            self.norm1 = nn.InstanceNorm2d(64, affine=False)\n\n        elif self.norm_fn == \"none\":\n            self.norm1 = nn.Sequential()\n\n        self.conv1 = nn.Conv2d(input_dim, 64, kernel_size=7, stride=2, padding=3)\n        self.relu1 = nn.ReLU(inplace=True)\n\n        self.in_planes = 64\n        self.layer1 = self._make_layer(64, stride=1)\n        self.layer2 = self._make_layer(96, stride=2)\n        self.layer3 = self._make_layer(128, stride=1)\n\n        # output convolution\n        self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1)\n\n        self.dropout = None\n        if dropout > 0:\n            self.dropout = nn.Dropout2d(p=dropout)\n\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                nn.init.kaiming_normal_(m.weight, mode=\"fan_out\", nonlinearity=\"relu\")\n            elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):\n                if m.weight is not None:\n                    nn.init.constant_(m.weight, 1)\n                if m.bias is not None:\n                    nn.init.constant_(m.bias, 0)\n\n    def _make_layer(self, dim, stride=1):\n        layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)\n        layer2 = 
ResidualBlock(dim, dim, self.norm_fn, stride=1)\n        layers = (layer1, layer2)\n\n        self.in_planes = dim\n        return nn.Sequential(*layers)\n\n    def forward(self, x):\n\n        # if input is list, combine batch dimension\n        is_list = isinstance(x, tuple) or isinstance(x, list)\n        if is_list:\n            batch_dim = x[0].shape[0]\n            x = torch.cat(x, dim=0)\n\n        x = self.conv1(x)\n        x = self.norm1(x)\n        x = self.relu1(x)\n\n        x = self.layer1(x)\n        x = self.layer2(x)\n        x = self.layer3(x)\n\n        x = self.conv2(x)\n\n        if self.dropout is not None:\n            x = self.dropout(x)\n\n        if is_list:\n            x = torch.split(x, x.shape[0] // 2, dim=0)\n\n        return x\n\n\nclass MultiBasicEncoder(nn.Module):\n    def __init__(self, output_dim=[128], norm_fn='batch', dropout=0.0, downsample=3):\n        super(MultiBasicEncoder, self).__init__()\n        self.norm_fn = norm_fn\n        self.downsample = downsample\n\n        if self.norm_fn == 'group':\n            self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64)\n\n        elif self.norm_fn == 'batch':\n            self.norm1 = nn.BatchNorm2d(64)\n\n        elif self.norm_fn == 'instance':\n            self.norm1 = nn.InstanceNorm2d(64)\n\n        elif self.norm_fn == 'none':\n            self.norm1 = nn.Sequential()\n\n        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=1 + (downsample > 2), padding=3)\n        self.relu1 = nn.ReLU(inplace=True)\n\n        self.in_planes = 64\n        self.layer1 = self._make_layer(64, stride=1)\n        self.layer2 = self._make_layer(96, stride=1 + (downsample > 1))\n        self.layer3 = self._make_layer(128, stride=1 + (downsample > 0))\n        self.layer4 = self._make_layer(128, stride=2)\n        self.layer5 = self._make_layer(128, stride=2)\n\n        output_list = []\n        for dim in output_dim:\n            conv_out = nn.Sequential(\n                ResidualBlock(128, 128, self.norm_fn, stride=1),\n                nn.Conv2d(128, dim[2], 3, padding=1))\n            output_list.append(conv_out)\n\n        self.outputs08 = nn.ModuleList(output_list)\n\n        output_list = []\n        for dim in output_dim:\n            conv_out = nn.Sequential(\n                ResidualBlock(128, 128, self.norm_fn, stride=1),\n                nn.Conv2d(128, dim[1], 3, padding=1))\n            output_list.append(conv_out)\n\n        self.outputs16 = nn.ModuleList(output_list)\n\n        output_list = []\n        for dim in output_dim:\n            conv_out = nn.Conv2d(128, dim[0], 3, padding=1)\n            output_list.append(conv_out)\n\n        self.outputs32 = nn.ModuleList(output_list)\n\n        if dropout > 0:\n            self.dropout = nn.Dropout2d(p=dropout)\n        else:\n            self.dropout = None\n\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')\n            elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):\n                if m.weight is not None:\n                    nn.init.constant_(m.weight, 1)\n                if m.bias is not None:\n                    nn.init.constant_(m.bias, 0)\n\n    def _make_layer(self, dim, stride=1):\n        layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)\n        layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)\n        layers = (layer1, layer2)\n\n        self.in_planes = dim\n        
return nn.Sequential(*layers)\n\n    def forward(self, x, dual_inp=False, num_layers=3):\n\n        x = self.conv1(x)\n        x = self.norm1(x)\n        x = self.relu1(x)\n\n        x = self.layer1(x)\n        x = self.layer2(x)\n        x = self.layer3(x)\n        if dual_inp:\n            v = x\n            x = x[:(x.shape[0]//2)]\n\n        outputs08 = [f(x) for f in self.outputs08]\n        if num_layers == 1:\n            return (outputs08, v) if dual_inp else (outputs08,)\n\n        y = self.layer4(x)\n        outputs16 = [f(y) for f in self.outputs16]\n\n        if num_layers == 2:\n            return (outputs08, outputs16, v) if dual_inp else (outputs08, outputs16)\n\n        z = self.layer5(y)\n        outputs32 = [f(z) for f in self.outputs32]\n\n        return (outputs08, outputs16, outputs32, v) if dual_inp else (outputs08, outputs16, outputs32)\n\n\nclass DepthExtractor(nn.Module):\n    def __init__(self):\n        super(DepthExtractor, self).__init__()\n\n        thirdparty_path = os.path.abspath(os.path.join(os.path.dirname(__file__), \"./models/Video-Depth-Anything\"))\n        sys.path.append(thirdparty_path)\n        videodepthanything_ppl = importlib.import_module(\n            \"stereoanyvideo.models.Video-Depth-Anything.video_depth_anything.video_depth\"\n        )\n        model_configs = {\n            'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},\n            'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},\n        }\n        encoder = 'vits'  # or 'vitl',\n\n        self.depthanything = videodepthanything_ppl.VideoDepthAnything(**model_configs[encoder])\n        self.depthanything.load_state_dict(torch.load(f'./models/Video-Depth-Anything/checkpoints/video_depth_anything_{encoder}.pth'))\n        self.depthanything.eval()\n        self.conv = nn.Conv2d(32, 32, kernel_size=4, stride=4)\n\n    def forward(self, x):\n        # Store original height and width\n        B, T, C, orig_h, orig_w = x.shape\n\n        # Calculate new height and width divisible by 14\n        new_h = (orig_h // 14) * 14\n        new_w = (orig_w // 14) * 14\n\n        # Resize input to be divisible by 14 for depthanything\n        resized_input = F.interpolate(\n            x.flatten(0, 1),\n            size=(new_h, new_w),\n            mode='bilinear',\n            align_corners=False\n        ).unflatten(0, (B, T))\n\n        # Pass through depthanything\n        depth_features_resized = self.depthanything(resized_input).contiguous()\n\n        # Resize depth features back to the original resolution\n        depth_features = F.interpolate(\n            depth_features_resized,\n            size=(orig_h, orig_w),\n            mode='bilinear',\n            align_corners=False\n        )\n\n        # Apply convolution to the depth features\n        depth_features = self.conv(depth_features).unflatten(0, (B, T))\n\n        return depth_features\n"
  },
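  {
    "path": "examples/extractor_demo.py",
    "content": "# Hypothetical shape-check sketch for BasicEncoder in models/core/extractor.py:\n# conv1 (stride 2) and layer2 (stride 2) give an overall 1/4-resolution output.\n# Sizes mirror the 96-channel usage in StereoAnyVideo but are otherwise arbitrary.\nimport torch\n\nfrom stereoanyvideo.models.core.extractor import BasicEncoder\n\nenc = BasicEncoder(input_dim=3, output_dim=96, norm_fn='instance')\nx = torch.randn(2, 3, 64, 96)\nprint(enc(x).shape)  # torch.Size([2, 96, 16, 24])\n\n# Passing a (left, right) pair as a list batches both through one forward pass.\nleft, right = torch.randn(1, 3, 64, 96), torch.randn(1, 3, 64, 96)\nf_left, f_right = enc([left, right])\nprint(f_left.shape, f_right.shape)  # both torch.Size([1, 96, 16, 24])\n"
  },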
  {
    "path": "models/core/model_zoo.py",
    "content": "import copy\nfrom pytorch3d.implicitron.tools.config import get_default_args\nfrom stereoanyvideo.models.stereoanyvideo_model import StereoAnyVideoModel\n\nMODELS = [StereoAnyVideoModel]\n\n\n_MODEL_NAME_TO_MODEL = {model_cls.__name__: model_cls for model_cls in MODELS}\n_MODEL_CONFIG_NAME_TO_DEFAULT_CONFIG = {}\nfor model_cls in MODELS:\n    _MODEL_CONFIG_NAME_TO_DEFAULT_CONFIG[\n        model_cls.MODEL_CONFIG_NAME\n    ] = get_default_args(model_cls)\n\nMODEL_NAME_NONE = \"NONE\"\n\n\ndef model_zoo(model_name: str, **kwargs):\n    if model_name.upper() == MODEL_NAME_NONE:\n        return None\n\n    model_cls = _MODEL_NAME_TO_MODEL.get(model_name)\n\n    if model_cls is None:\n        raise ValueError(f\"No such model name: {model_name}\")\n\n    model_cls_params = {}\n    if \"model_zoo\" in getattr(model_cls, \"__dataclass_fields__\", []):\n        model_cls_params[\"model_zoo\"] = model_zoo\n    print(\n        f\"{model_cls.MODEL_CONFIG_NAME} model configs:\",\n        kwargs.get(model_cls.MODEL_CONFIG_NAME),\n    )\n    return model_cls(**model_cls_params, **kwargs.get(model_cls.MODEL_CONFIG_NAME, {}))\n\n\ndef get_all_model_default_configs():\n    return copy.deepcopy(_MODEL_CONFIG_NAME_TO_DEFAULT_CONFIG)\n"
  },
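  {
    "path": "examples/model_zoo_demo.py",
    "content": "# Hypothetical sketch of the model_zoo entry point; the config key is whatever\n# StereoAnyVideoModel.MODEL_CONFIG_NAME resolves to at runtime, and constructing\n# the model may load checkpoints from disk.\nfrom stereoanyvideo.models.core.model_zoo import get_all_model_default_configs, model_zoo\n\ndefaults = get_all_model_default_configs()\ncfg_name = next(iter(defaults))  # the MODEL_CONFIG_NAME of StereoAnyVideoModel\nmodel = model_zoo('StereoAnyVideoModel', **{cfg_name: defaults[cfg_name]})\nprint(type(model).__name__)  # StereoAnyVideoModel\n"
  },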
  {
    "path": "models/core/stereoanyvideo.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom typing import Dict, List\nfrom einops import rearrange\nimport collections\nfrom collections import defaultdict\nfrom itertools import repeat\nimport unfoldNd\n\nfrom stereoanyvideo.models.core.update import SequenceUpdateBlock3D\nfrom stereoanyvideo.models.core.extractor import BasicEncoder, MultiBasicEncoder, DepthExtractor\nfrom stereoanyvideo.models.core.corr import AAPC\nfrom stereoanyvideo.models.core.utils.utils import InputPadder, interp\n\nautocast = torch.cuda.amp.autocast\n\ndef _ntuple(n):\n    def parse(x):\n        if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):\n            return tuple(x)\n        return tuple(repeat(x, n))\n\n    return parse\n\n\ndef exists(val):\n    return val is not None\n\n\ndef default(val, d):\n    return val if exists(val) else d\n\n\nto_2tuple = _ntuple(2)\n\nclass Mlp(nn.Module):\n    def __init__(\n        self,\n        in_features,\n        hidden_features=None,\n        out_features=None,\n        act_layer=nn.GELU,\n        norm_layer=None,\n        bias=True,\n        drop=0.0,\n        use_conv=False,\n    ):\n        super().__init__()\n        out_features = out_features or in_features\n        hidden_features = hidden_features or in_features\n        bias = to_2tuple(bias)\n        drop_probs = to_2tuple(drop)\n        linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear\n\n        self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])\n        self.act = act_layer()\n        self.drop1 = nn.Dropout(drop_probs[0])\n        self.norm = (\n            norm_layer(hidden_features) if norm_layer is not None else nn.Identity()\n        )\n        self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])\n        self.drop2 = nn.Dropout(drop_probs[1])\n\n    def forward(self, x):\n        x = self.fc1(x)\n        x = self.act(x)\n        x = self.drop1(x)\n        x = self.fc2(x)\n        x = self.drop2(x)\n        return x\n\n\nclass StereoAnyVideo(nn.Module):\n    def __init__(self, mixed_precision=False):\n        super(StereoAnyVideo, self).__init__()\n\n        self.mixed_precision = mixed_precision\n\n        self.hidden_dim = 128\n        self.context_dim = 128\n        self.dropout = 0\n\n        # feature network and update block\n        self.cnet = BasicEncoder(output_dim=96, norm_fn='instance', dropout=self.dropout)\n        self.fnet = BasicEncoder(output_dim=96, norm_fn='instance', dropout=self.dropout)\n        self.depthnet = DepthExtractor()\n        self.corr_mlp = Mlp(in_features=4 * 9 * 9, hidden_features=256, out_features=128)\n        self.update_block = SequenceUpdateBlock3D(hidden_dim=self.hidden_dim, cor_planes=128, mask_size=4)\n\n    @torch.jit.ignore\n    def no_weight_decay(self):\n        return {\"time_embed\"}\n\n    def freeze_bn(self):\n        for m in self.modules():\n            if isinstance(m, nn.BatchNorm2d):\n                m.eval()\n\n    def convex_upsample(self, flow, mask, rate=4):\n        \"\"\" Upsample flow field [H/rate, W/rate, 2] -> [H, W, 2] using convex combination \"\"\"\n        N, _, H, W = flow.shape\n        mask = mask.view(N, 1, 9, rate, rate, H, W)\n        mask = torch.softmax(mask, dim=2)\n\n        up_flow = F.unfold(rate * flow, [3, 3], padding=1)\n        up_flow = up_flow.view(N, 2, 9, 1, 1, H, W)\n\n        up_flow = torch.sum(mask * up_flow, dim=2)\n        up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)\n        
return up_flow.reshape(N, 2, rate * H, rate * W)\n\n    def convex_upsample_3D(self, flow, mask, b, T, rate=4):\n        \"\"\"Upsample flow field from [T, H/rate, W/rate, 2] to [T, H, W, 2] using convex combination.\n\n        unfoldNd repo: https://github.com/f-dangel/unfoldNd\n        Run: pip install --user unfoldNd\n\n        Args:\n            flow: (N*T, C_flow, H, W)\n            mask: (N*T, C_mask, H, W) or (N, 1, 27, 1, rate, rate, T, H, W)\n            rate: int\n        \"\"\"\n        flow = rearrange(flow, \"(b t) c h w -> b c t h w\", b=b, t=T)\n        mask = rearrange(mask, \"(b t) c h w -> b c t h w\", b=b, t=T)\n\n        N, _, T, H, W = flow.shape\n\n        mask = mask.view(N, 1, 27, 1, rate, rate, T, H, W)  # (N, 1, 27, rate, rate, rate, T, H, W) if upsample T\n        mask = torch.softmax(mask, dim=2)\n\n        upsample = unfoldNd.UnfoldNd([3, 3, 3], padding=1)\n        flow_upsampled = upsample(rate * flow)\n        flow_upsampled = flow_upsampled.view(N, 2, 27, 1, 1, 1, T, H, W)\n        flow_upsampled = torch.sum(mask * flow_upsampled, dim=2)\n        flow_upsampled = flow_upsampled.permute(0, 1, 5, 2, 6, 3, 7, 4)\n        flow_upsampled = flow_upsampled.reshape(N, 2, T, rate * H,\n                                                rate * W)  # (N, 2, rate*T, rate*H, rate*W) if upsample T\n        up_flow = rearrange(flow_upsampled, \"b c t h w -> (b t) c h w\")\n\n        return up_flow\n\n    def zero_init(self, fmap):\n        N, C, H, W = fmap.shape\n        flow_u = torch.zeros([N, 1, H, W], dtype=torch.float)\n        flow_v = torch.zeros([N, 1, H, W], dtype=torch.float)\n        flow = torch.cat([flow_u, flow_v], dim=1).to(fmap.device)\n        return flow\n\n    def forward_batch_test(\n        self, batch_dict, iters = 24, flow_init=None,\n    ):\n        kernel_size = 20\n        stride = kernel_size // 2\n        predictions = defaultdict(list)\n\n        disp_preds = []\n        video = batch_dict[\"stereo_video\"]\n\n        num_ims = len(video)\n        print(\"video\", video.shape)\n\n        for i in range(0, num_ims, stride):\n            left_ims = video[i : min(i + kernel_size, num_ims), 0]\n            padder = InputPadder(left_ims.shape, divis_by=32)\n            right_ims = video[i : min(i + kernel_size, num_ims), 1]\n            left_ims, right_ims = padder.pad(left_ims, right_ims)\n            if flow_init is not None:\n                flow_init_ims = flow_init[i: min(i + kernel_size, num_ims)]\n                flow_init_ims = padder.pad(flow_init_ims)[0]\n                with autocast(enabled=self.mixed_precision):\n                    disparities_forw = self.forward(\n                        left_ims[None].cuda(),\n                        right_ims[None].cuda(),\n                        flow_init=flow_init_ims,\n                        iters=iters,\n                        test_mode=True,\n                    )\n            else:\n                with autocast(enabled=self.mixed_precision):\n                    disparities_forw = self.forward(\n                        left_ims[None].cuda(),\n                        right_ims[None].cuda(),\n                        iters=iters,\n                        test_mode=True,\n                    )\n\n            disparities_forw = padder.unpad(disparities_forw[:, 0])[:, None].cpu()\n\n            if len(disp_preds) > 0 and len(disparities_forw) >= stride:\n\n                if len(disparities_forw) < kernel_size:\n                    disp_preds.append(disparities_forw[stride // 2 :])\n               
 else:\n                    disp_preds.append(disparities_forw[stride // 2 : -stride // 2])\n\n            elif len(disp_preds) == 0:\n                disp_preds.append(disparities_forw[: -stride // 2])\n\n        predictions[\"disparity\"] = (torch.cat(disp_preds).squeeze(1).abs())[:, :1]\n\n        return predictions\n\n    def forward(self, image1, image2, flow_init=None, iters=12, test_mode=False):\n        b, T, c, h, w = image1.shape\n\n        image1 = image1 / 255.0\n        image2 = image2 / 255.0\n\n        # Normalize using mean and std for ImageNet pre-trained models\n        mean = torch.tensor([0.485, 0.456, 0.406], device=image1.device).view(1, 1, 3, 1, 1)\n        std = torch.tensor([0.229, 0.224, 0.225], device=image1.device).view(1, 1, 3, 1, 1)\n\n        image1 = (image1 - mean) / std\n        image2 = (image2 - mean) / std\n        image1 = image1.float()\n        image2 = image2.float()\n\n        # feature network\n        with autocast(enabled=self.mixed_precision):\n            fmap1_depth_feature = self.depthnet(image1)\n            fmap2_depth_feature = self.depthnet(image2)\n            fmap1_cnet_feature = self.cnet(image1.flatten(0, 1)).unflatten(0, (b, T))\n            fmap1_fnet_feature = self.fnet(image1.flatten(0, 1)).unflatten(0, (b, T))\n            fmap2_fnet_feature = self.fnet(image2.flatten(0, 1)).unflatten(0, (b, T))\n\n        fmap1 = torch.cat((fmap1_depth_feature, fmap1_fnet_feature), dim=2).flatten(0, 1)\n        fmap2 = torch.cat((fmap2_depth_feature, fmap2_fnet_feature), dim=2).flatten(0, 1)\n\n        context = torch.cat((fmap1_depth_feature, fmap1_cnet_feature), dim=2).flatten(0, 1)\n\n        with autocast(enabled=self.mixed_precision):\n            net = torch.tanh(context)\n            inp = torch.relu(context)\n\n            s_net = F.avg_pool2d(net, 2, stride=2)\n            s_inp = F.avg_pool2d(inp, 2, stride=2)\n\n            # 1/4 -> 1/8\n            # feature\n            s_fmap1 = F.avg_pool2d(fmap1, 2, stride=2)\n            s_fmap2 = F.avg_pool2d(fmap2, 2, stride=2)\n\n            # 1/4 -> 1/16\n            # feature\n            ss_fmap1 = F.avg_pool2d(fmap1, 4, stride=4)\n            ss_fmap2 = F.avg_pool2d(fmap2, 4, stride=4)\n\n            ss_net = F.avg_pool2d(net, 4, stride=4)\n            ss_inp = F.avg_pool2d(inp, 4, stride=4)\n\n        # Correlation\n        corr_fn = AAPC(fmap1, fmap2)\n        s_corr_fn = AAPC(s_fmap1, s_fmap2)\n        ss_corr_fn = AAPC(ss_fmap1, ss_fmap2)\n\n        # cascaded refinement (1/16 + 1/8 + 1/4)\n        flow_predictions = []\n        flow = None\n        flow_up = None\n\n        if flow_init is not None:\n            flow_init = flow_init.cuda()\n            scale = fmap1.shape[2] / flow_init.shape[2]\n            flow = scale * interp(flow_init, size=(fmap1.shape[2], fmap1.shape[3]))\n        else:\n            # init flow\n            ss_flow = self.zero_init(ss_fmap1)\n\n            # 1/16\n            for itr in range(iters // 2):\n                if itr % 2 == 0:\n                    small_patch = False\n                else:\n                    small_patch = True\n\n                ss_flow = ss_flow.detach()\n                out_corrs = ss_corr_fn(ss_flow, None, small_patch=small_patch)  # 36 * H/16 * W/16\n                out_corrs = self.corr_mlp(out_corrs.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)\n                with autocast(enabled=self.mixed_precision):\n                    ss_net, up_mask, delta_flow = self.update_block(ss_net, ss_inp, out_corrs, ss_flow, t=T)\n\n           
     ss_flow = ss_flow + delta_flow\n                flow = self.convex_upsample_3D(ss_flow, up_mask, b, T, rate=4)  # 2 * H/4 * W/4\n                flow_up = 4 * F.interpolate(flow, size=(4 * flow.shape[2], 4 * flow.shape[3]), mode='bilinear',\n                                            align_corners=True)  # 2 * H/2 * W/2\n                flow_predictions.append(flow_up[:, :1])\n\n            scale = s_fmap1.shape[2] / flow.shape[2]\n            s_flow = scale * interp(flow, size=(s_fmap1.shape[2], s_fmap1.shape[3]))\n\n            # 1/8\n            for itr in range(iters // 2):\n                if itr % 2 == 0:\n                    small_patch = False\n                else:\n                    small_patch = True\n\n                s_flow = s_flow.detach()\n                out_corrs = s_corr_fn(s_flow, None, small_patch=small_patch)\n                out_corrs = self.corr_mlp(out_corrs.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)\n                with autocast(enabled=self.mixed_precision):\n                    s_net, up_mask, delta_flow = self.update_block(s_net, s_inp, out_corrs, s_flow, t=T)\n\n                s_flow = s_flow + delta_flow\n                flow = self.convex_upsample_3D(s_flow, up_mask, b, T, rate=4)\n                flow_up = 2 * F.interpolate(flow, size=(2 * flow.shape[2], 2 * flow.shape[3]), mode='bilinear',\n                                            align_corners=True)\n                flow_predictions.append(flow_up[:, :1])\n\n            scale = fmap1.shape[2] / flow.shape[2]\n            flow = scale * interp(flow, size=(fmap1.shape[2], fmap1.shape[3]))\n\n        # 1/4\n        for itr in range(iters):\n            if itr % 2 == 0:\n                small_patch = False\n            else:\n                small_patch = True\n\n            flow = flow.detach()\n            out_corrs = corr_fn(flow, None, small_patch=small_patch)\n            out_corrs = self.corr_mlp(out_corrs.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)\n            with autocast(enabled=self.mixed_precision):\n                net, up_mask, delta_flow = self.update_block(net, inp, out_corrs, flow, t=T)\n\n            flow = flow + delta_flow\n            flow_up = self.convex_upsample_3D(flow, up_mask, b, T, rate=4)\n            flow_predictions.append(flow_up[:, :1])\n\n        predictions = torch.stack(flow_predictions)\n        predictions = rearrange(predictions, \"d (b t) c h w -> d t b c h w\", b=b, t=T)\n        flow_up = predictions[-1]\n\n        if test_mode:\n            return flow_up\n\n        return predictions\n\n\n"
  },
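  {
    "path": "examples/stereoanyvideo_forward_demo.py",
    "content": "# Hypothetical end-to-end sketch: StereoAnyVideo on a short synthetic stereo clip.\n# Requires the vendored Video-Depth-Anything checkpoint on disk (DepthExtractor\n# loads it in its constructor); clip sizes are arbitrary assumptions.\nimport torch\n\nfrom stereoanyvideo.models.core.stereoanyvideo import StereoAnyVideo\n\nmodel = StereoAnyVideo(mixed_precision=False).cuda().eval()\n\nb, T, c, h, w = 1, 4, 3, 256, 320  # spatial size divisible by 32\nleft = torch.rand(b, T, c, h, w, device='cuda') * 255   # forward expects 0-255 inputs\nright = torch.rand(b, T, c, h, w, device='cuda') * 255\n\nwith torch.no_grad():\n    disp = model(left, right, iters=12, test_mode=True)\nprint(disp.shape)  # (T, b, 1, h, w): disparity from the final refinement iteration\n"
  },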
  {
    "path": "models/core/update.py",
    "content": "from einops import rearrange\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom opt_einsum import contract\nfrom stereoanyvideo.models.core.attention import LoFTREncoderLayer\n\n\ndef pool2x(x):\n    return F.avg_pool2d(x, 3, stride=2, padding=1)\n\ndef pool4x(x):\n    return F.avg_pool2d(x, 5, stride=4, padding=1)\n\ndef interp(x, dest):\n    interp_args = {'mode': 'bilinear', 'align_corners': True}\n    return F.interpolate(x, dest.shape[2:], **interp_args)\n\n\nclass FlowHead(nn.Module):\n    def __init__(self, input_dim=128, hidden_dim=256, output_dim=2):\n        super(FlowHead, self).__init__()\n        self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1)\n        self.conv2 = nn.Conv2d(hidden_dim, output_dim, 3, padding=1)\n        self.relu = nn.ReLU(inplace=True)\n\n    def forward(self, x):\n        return self.conv2(self.relu(self.conv1(x)))\n\n\nclass FlowHead3D(nn.Module):\n    def __init__(self, input_dim=128, hidden_dim=256):\n        super(FlowHead3D, self).__init__()\n        self.conv1 = nn.Conv3d(input_dim, hidden_dim, 3, padding=1)\n        self.conv2 = nn.Conv3d(hidden_dim, 2, 3, padding=1)\n        self.relu = nn.ReLU(inplace=True)\n\n    def forward(self, x):\n        return self.conv2(self.relu(self.conv1(x)))\n\n\nclass ConvGRU(nn.Module):\n    def __init__(self, hidden_dim, input_dim, kernel_size=3):\n        super(ConvGRU, self).__init__()\n        self.convz = nn.Conv2d(hidden_dim+input_dim, hidden_dim, kernel_size, padding=kernel_size//2)\n        self.convr = nn.Conv2d(hidden_dim+input_dim, hidden_dim, kernel_size, padding=kernel_size//2)\n        self.convq = nn.Conv2d(hidden_dim+input_dim, hidden_dim, kernel_size, padding=kernel_size//2)\n\n    def forward(self, h, cz, cr, cq, *x_list):\n        x = torch.cat(x_list, dim=1)\n        hx = torch.cat([h, x], dim=1)\n\n        z = torch.sigmoid(self.convz(hx) + cz)\n        r = torch.sigmoid(self.convr(hx) + cr)\n        q = torch.tanh(self.convq(torch.cat([r*h, x], dim=1)) + cq)\n\n        h = (1-z) * h + z * q\n        return h\n\n\nclass SepConvGRU(nn.Module):\n    def __init__(self, hidden_dim=128, input_dim=192+128):\n        super(SepConvGRU, self).__init__()\n        self.convz1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))\n        self.convr1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))\n        self.convq1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))\n\n        self.convz2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))\n        self.convr2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))\n        self.convq2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))\n\n\n    def forward(self, h, *x):\n        # horizontal\n        x = torch.cat(x, dim=1)\n        hx = torch.cat([h, x], dim=1)\n        z = torch.sigmoid(self.convz1(hx))\n        r = torch.sigmoid(self.convr1(hx))\n        q = torch.tanh(self.convq1(torch.cat([r*h, x], dim=1)))\n        h = (1-z) * h + z * q\n\n        # vertical\n        hx = torch.cat([h, x], dim=1)\n        z = torch.sigmoid(self.convz2(hx))\n        r = torch.sigmoid(self.convr2(hx))\n        q = torch.tanh(self.convq2(torch.cat([r*h, x], dim=1)))\n        h = (1-z) * h + z * q\n\n        return h\n\n\nclass BasicMotionEncoder(nn.Module):\n    def __init__(self, cor_planes):\n        super(BasicMotionEncoder, self).__init__()\n\n        self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0)\n        
self.convc2 = nn.Conv2d(256, 192, 3, padding=1)\n        self.convf1 = nn.Conv2d(2, 128, 7, padding=3)\n        self.convf2 = nn.Conv2d(128, 64, 3, padding=1)\n        self.conv = nn.Conv2d(64 + 192, 128 - 2, 3, padding=1)\n\n    def forward(self, flow, corr):\n        cor = F.relu(self.convc1(corr))\n        cor = F.relu(self.convc2(cor))\n        flo = F.relu(self.convf1(flow))\n        flo = F.relu(self.convf2(flo))\n\n        cor_flo = torch.cat([cor, flo], dim=1)\n        out = F.relu(self.conv(cor_flo))\n        return torch.cat([out, flow], dim=1)\n\n\nclass BasicMotionEncoder3D(nn.Module):\n    def __init__(self, cor_planes):\n        super(BasicMotionEncoder3D, self).__init__()\n\n        self.convc1 = nn.Conv3d(cor_planes, 256, 1, padding=0)\n        self.convc2 = nn.Conv3d(256, 192, 3, padding=1)\n        self.convf1 = nn.Conv3d(2, 128, 5, padding=2)\n        self.convf2 = nn.Conv3d(128, 64, 3, padding=1)\n        self.conv = nn.Conv3d(64 + 192, 128 - 2, 3, padding=1)\n\n    def forward(self, flow, corr):\n        cor = F.relu(self.convc1(corr))\n        cor = F.relu(self.convc2(cor))\n        flo = F.relu(self.convf1(flow))\n        flo = F.relu(self.convf2(flo))\n\n        cor_flo = torch.cat([cor, flo], dim=1)\n        out = F.relu(self.conv(cor_flo))\n        return torch.cat([out, flow], dim=1)\n\n\nclass SepConvGRU3D(nn.Module):\n    def __init__(self, hidden_dim=128, input_dim=192 + 128):\n        super(SepConvGRU3D, self).__init__()\n        self.convz1 = nn.Conv3d(\n            hidden_dim + input_dim, hidden_dim, (1, 1, 5), padding=(0, 0, 2)\n        )\n        self.convr1 = nn.Conv3d(\n            hidden_dim + input_dim, hidden_dim, (1, 1, 5), padding=(0, 0, 2)\n        )\n        self.convq1 = nn.Conv3d(\n            hidden_dim + input_dim, hidden_dim, (1, 1, 5), padding=(0, 0, 2)\n        )\n\n        self.convz2 = nn.Conv3d(\n            hidden_dim + input_dim, hidden_dim, (1, 5, 1), padding=(0, 2, 0)\n        )\n        self.convr2 = nn.Conv3d(\n            hidden_dim + input_dim, hidden_dim, (1, 5, 1), padding=(0, 2, 0)\n        )\n        self.convq2 = nn.Conv3d(\n            hidden_dim + input_dim, hidden_dim, (1, 5, 1), padding=(0, 2, 0)\n        )\n\n        self.convz3 = nn.Conv3d(\n            hidden_dim + input_dim, hidden_dim, (5, 1, 1), padding=(2, 0, 0)\n        )\n        self.convr3 = nn.Conv3d(\n            hidden_dim + input_dim, hidden_dim, (5, 1, 1), padding=(2, 0, 0)\n        )\n        self.convq3 = nn.Conv3d(\n            hidden_dim + input_dim, hidden_dim, (5, 1, 1), padding=(2, 0, 0)\n        )\n\n    def forward(self, h, x):\n        hx = torch.cat([h, x], dim=1)\n        z = torch.sigmoid(self.convz1(hx))\n        r = torch.sigmoid(self.convr1(hx))\n        q = torch.tanh(self.convq1(torch.cat([r * h, x], dim=1)))\n        h = (1 - z) * h + z * q\n\n        # vertical\n        hx = torch.cat([h, x], dim=1)\n        z = torch.sigmoid(self.convz2(hx))\n        r = torch.sigmoid(self.convr2(hx))\n        q = torch.tanh(self.convq2(torch.cat([r * h, x], dim=1)))\n        h = (1 - z) * h + z * q\n\n        # time\n        hx = torch.cat([h, x], dim=1)\n        z = torch.sigmoid(self.convz3(hx))\n        r = torch.sigmoid(self.convr3(hx))\n        q = torch.tanh(self.convq3(torch.cat([r * h, x], dim=1)))\n        h = (1 - z) * h + z * q\n\n        return h\n\n\nclass SKSepConvGRU3D(nn.Module):\n    def __init__(self, hidden_dim=128, input_dim=192 + 128):\n        super(SKSepConvGRU3D, self).__init__()\n        self.convz1 = nn.Sequential(\n        
    nn.Conv3d(hidden_dim+input_dim, hidden_dim, (1, 1, 15), padding=(0, 0, 7)),\n            nn.GELU(),\n            nn.Conv3d(hidden_dim, hidden_dim, (1, 1, 5), padding=(0, 0, 2)),\n        )\n        self.convr1 = nn.Sequential(\n            nn.Conv3d(hidden_dim+input_dim, hidden_dim, (1, 1, 15), padding=(0, 0, 7)),\n            nn.GELU(),\n            nn.Conv3d(hidden_dim, hidden_dim, (1, 1, 5), padding=(0, 0, 2)),\n        )\n        self.convq1 = nn.Conv3d(\n            hidden_dim + input_dim, hidden_dim, (1, 1, 5), padding=(0, 0, 2)\n        )\n\n        self.convz2 = nn.Conv3d(\n            hidden_dim + input_dim, hidden_dim, (1, 5, 1), padding=(0, 2, 0)\n        )\n        self.convr2 = nn.Conv3d(\n            hidden_dim + input_dim, hidden_dim, (1, 5, 1), padding=(0, 2, 0)\n        )\n        self.convq2 = nn.Conv3d(\n            hidden_dim + input_dim, hidden_dim, (1, 5, 1), padding=(0, 2, 0)\n        )\n\n        self.convz3 = nn.Conv3d(\n            hidden_dim + input_dim, hidden_dim, (5, 1, 1), padding=(2, 0, 0)\n        )\n        self.convr3 = nn.Conv3d(\n            hidden_dim + input_dim, hidden_dim, (5, 1, 1), padding=(2, 0, 0)\n        )\n        self.convq3 = nn.Conv3d(\n            hidden_dim + input_dim, hidden_dim, (5, 1, 1), padding=(2, 0, 0)\n        )\n\n    def forward(self, h, x):\n        hx = torch.cat([h, x], dim=1)\n        z = torch.sigmoid(self.convz1(hx))\n        r = torch.sigmoid(self.convr1(hx))\n        q = torch.tanh(self.convq1(torch.cat([r * h, x], dim=1)))\n        h = (1 - z) * h + z * q\n\n        # vertical\n        hx = torch.cat([h, x], dim=1)\n        z = torch.sigmoid(self.convz2(hx))\n        r = torch.sigmoid(self.convr2(hx))\n        q = torch.tanh(self.convq2(torch.cat([r * h, x], dim=1)))\n        h = (1 - z) * h + z * q\n\n        # time\n        hx = torch.cat([h, x], dim=1)\n        z = torch.sigmoid(self.convz3(hx))\n        r = torch.sigmoid(self.convr3(hx))\n        q = torch.tanh(self.convq3(torch.cat([r * h, x], dim=1)))\n        h = (1 - z) * h + z * q\n\n        return h\n\n\nclass BasicUpdateBlock(nn.Module):\n    def __init__(self, hidden_dim, cor_planes, mask_size=8, attention_type=None):\n        super(BasicUpdateBlock, self).__init__()\n        self.attention_type = attention_type\n        if attention_type is not None:\n            if \"update_time\" in attention_type:\n                self.time_attn = TimeAttnBlock(dim=256, num_heads=8)\n\n            if \"update_space\" in attention_type:\n                self.space_attn = SpaceAttnBlock(dim=256, num_heads=8)\n\n        self.encoder = BasicMotionEncoder(cor_planes)\n        self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128 + hidden_dim)\n        self.flow_head = FlowHead(hidden_dim, hidden_dim=256)\n\n        self.mask = nn.Sequential(\n            nn.Conv2d(128, 256, 3, padding=1),\n            nn.ReLU(inplace=True),\n            nn.Conv2d(256, mask_size ** 2 * 9, 1, padding=0),\n        )\n\n    def forward(self, net, inp, corr, flow, upsample=True, t=1):\n        motion_features = self.encoder(flow, corr)\n        inp = torch.cat((inp, motion_features), dim=1)\n        if self.attention_type is not None:\n            if \"update_time\" in self.attention_type:\n                inp = self.time_attn(inp, T=t)\n            if \"update_space\" in self.attention_type:\n                inp = self.space_attn(inp, T=t)\n        net = self.gru(net, inp)\n        delta_flow = self.flow_head(net)\n\n        # scale mask to balance gradients\n        mask = 0.25 *
self.mask(net)\n        return net, mask, delta_flow\n\n\nclass Attention(nn.Module):\n    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None):\n        super().__init__()\n        self.num_heads = num_heads\n        head_dim = dim // num_heads\n        self.scale = qk_scale or head_dim ** -0.5\n        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)\n        self.proj = nn.Linear(dim, dim)\n\n    def forward(self, x):\n        B, N, C = x.shape\n        qkv = x.reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)\n        q, k, v = qkv, qkv, qkv\n\n        attn = (q @ k.transpose(-2, -1)) * self.scale\n\n        attn = attn.softmax(dim=-1)\n\n        x = (attn @ v).transpose(1, 2).reshape(B, N, C).contiguous()\n        x = self.proj(x)\n        return x\n\n\nclass TimeAttnBlock(nn.Module):\n    def __init__(self, dim=256, num_heads=8):\n        super(TimeAttnBlock, self).__init__()\n        self.temporal_attn = Attention(dim, num_heads=8, qkv_bias=False, qk_scale=None)\n        self.temporal_fc = nn.Linear(dim, dim)\n        self.temporal_norm1 = nn.LayerNorm(dim)\n\n        nn.init.constant_(self.temporal_fc.weight, 0)\n        nn.init.constant_(self.temporal_fc.bias, 0)\n\n    def forward(self, x, T=1):\n        _, _, h, w = x.shape\n\n        x = rearrange(x, \"(b t) m h w -> (b h w) t m\", h=h, w=w, t=T)\n        res_temporal1 = self.temporal_attn(self.temporal_norm1(x))\n        res_temporal1 = rearrange(\n            res_temporal1, \"(b h w) t m -> b (h w t) m\", h=h, w=w, t=T\n        )\n        res_temporal1 = self.temporal_fc(res_temporal1)\n        res_temporal1 = rearrange(\n            res_temporal1, \" b (h w t) m -> b t m h w\", h=h, w=w, t=T\n        )\n        x = rearrange(x, \"(b h w) t m -> b t m h w\", h=h, w=w, t=T)\n        x = x + res_temporal1\n        x = rearrange(x, \"b t m h w -> (b t) m h w\", h=h, w=w, t=T)\n        return x\n\n\nclass SpaceAttnBlock(nn.Module):\n    def __init__(self, dim=256, num_heads=8):\n        super(SpaceAttnBlock, self).__init__()\n        self.encoder_layer = LoFTREncoderLayer(dim, nhead=num_heads, attention=\"linear\")\n\n    def forward(self, x, T=1):\n        _, _, h, w = x.shape\n        x = rearrange(x, \"(b t) m h w -> (b t) (h w) m\", h=h, w=w, t=T)\n        x = self.encoder_layer(x, x)\n        x = rearrange(x, \"(b t) (h w) m -> (b t) m h w\", h=h, w=w, t=T)\n        return x\n\n\nclass SequenceUpdateBlock3D(nn.Module):\n    def __init__(self, hidden_dim, cor_planes, mask_size=8):\n        super(SequenceUpdateBlock3D, self).__init__()\n\n        self.encoder = BasicMotionEncoder(cor_planes)\n        self.gru = SKSepConvGRU3D(hidden_dim=hidden_dim, input_dim=128 + hidden_dim)\n        self.flow_head = FlowHead3D(hidden_dim, hidden_dim=256)\n        self.mask3d = nn.Sequential(\n            nn.Conv3d(hidden_dim, hidden_dim + 128, 3, padding=1),\n            nn.ReLU(inplace=True),\n            nn.Conv3d(hidden_dim + 128, (mask_size ** 2) * (3 * 3 * 3), 1, padding=0),\n        )\n        self.time_attn = TimeAttnBlock(dim=256, num_heads=8)\n        self.space_attn = SpaceAttnBlock(dim=256, num_heads=8)\n\n    def forward(self, net, inp, corrs, flows, t):\n        motion_features = self.encoder(flows, corrs)\n        inp_tensor = torch.cat([inp, motion_features], dim=1)\n\n        inp_tensor = self.time_attn(inp_tensor, T=t)\n        inp_tensor = self.space_attn(inp_tensor, T=t)\n\n        net = rearrange(net, \"(b t) c h w -> b c t h w\", t=t)\n        inp_tensor = rearrange(inp_tensor, \"(b 
t) c h w -> b c t h w\", t=t)\n\n        net = self.gru(net, inp_tensor)\n\n        delta_flow = self.flow_head(net)\n\n        # scale mask to balance gradients\n        mask = 0.25 * self.mask3d(net)\n        net = rearrange(net, \"b c t h w -> (b t) c h w\")\n        mask = rearrange(mask, \"b c t h w -> (b t) c h w\")\n        delta_flow = rearrange(delta_flow, \"b c t h w -> (b t) c h w\")\n        return net, mask, delta_flow"
  },
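  {
    "path": "examples/update_block_demo.py",
    "content": "# Hypothetical shape-check sketch for SequenceUpdateBlock3D in models/core/update.py:\n# it consumes hidden state, context, correlation features and flow for a T-frame clip\n# and returns the updated state, the 3D convex-upsampling mask logits and a flow\n# increment. All sizes are arbitrary assumptions.\nimport torch\n\nfrom stereoanyvideo.models.core.update import SequenceUpdateBlock3D\n\nb, t, h, w = 1, 4, 16, 20\nblock = SequenceUpdateBlock3D(hidden_dim=128, cor_planes=128, mask_size=4)\nnet = torch.randn(b * t, 128, h, w)\ninp = torch.randn(b * t, 128, h, w)\ncorrs = torch.randn(b * t, 128, h, w)\nflows = torch.randn(b * t, 2, h, w)\nnet, mask, delta = block(net, inp, corrs, flows, t=t)\nprint(net.shape, mask.shape, delta.shape)\n# (4, 128, 16, 20), (4, 432, 16, 20) with 432 = 4**2 * 27 mask logits, (4, 2, 16, 20)\n"
  },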
  {
    "path": "models/core/utils/config.py",
    "content": "import dataclasses\nimport inspect\nimport itertools\nimport sys\nimport warnings\nfrom collections import Counter, defaultdict\nfrom enum import Enum\nfrom typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union\n\nfrom omegaconf import DictConfig, OmegaConf, open_dict\nfrom pytorch3d.common.datatypes import get_args, get_origin\n\n\n\"\"\"\nThis functionality allows a configurable system to be determined in a dataclass-type\nway. It is a generalization of omegaconf's \"structured\", in the dataclass case.\nCore functionality:\n\n- Configurable -- A base class used to label a class as being one which uses this\n                    system. Uses class members and __post_init__ like a dataclass.\n\n- expand_args_fields -- Expands a class like `dataclasses.dataclass`. Runs automatically.\n\n- get_default_args -- gets an omegaconf.DictConfig for initializing a given class.\n\n- run_auto_creation -- Initialises nested members. To be called in __post_init__.\n\n\nIn addition, a Configurable may contain members whose type is decided at runtime.\n\n- ReplaceableBase -- As a base instead of Configurable, labels a class to say that\n                     any child class can be used instead.\n\n- registry -- A global store of named child classes of  ReplaceableBase classes.\n              Used as `@registry.register` decorator on class definition.\n\n\nAdditional utility functions:\n\n- remove_unused_components -- used for simplifying a DictConfig instance.\n- get_default_args_field -- default for DictConfig member of another configurable.\n- enable_get_default_args -- Allows get_default_args on a function or plain class.\n\n\n1. The simplest usage of this functionality is as follows. First a schema is defined\nin dataclass style.\n\n    class A(Configurable):\n        n: int = 9\n\n    class B(Configurable):\n        a: A\n\n        def __post_init__(self):\n            run_auto_creation(self)\n\nThen it can be used like\n\n    b_args = get_default_args(B)\n    b = B(**b_args)\n\nIn this case, get_default_args(B) returns an omegaconf.DictConfig with the right\nmembers {\"a_args\": {\"n\": 9}}. It also modifies the definitions of the classes to\nsomething like the following. (The modification itself is done by the function\n`expand_args_fields`, which is called inside `get_default_args`.)\n\n    @dataclasses.dataclass\n    class A:\n        n: int = 9\n\n    @dataclasses.dataclass\n    class B:\n        a_args: DictConfig = dataclasses.field(default_factory=lambda: DictConfig({\"n\": 9}))\n\n        def __post_init__(self):\n            self.a = A(**self.a_args)\n\n2. Pluggability. Instead of a dataclass-style member being given a concrete class,\nit can be given a base class and the implementation will be looked up by name in the\nglobal `registry` in this module. 
E.g.\n\n    class A(ReplaceableBase):\n        k: int = 1\n\n    @registry.register\n    class A1(A):\n        m: int = 3\n\n    @registry.register\n    class A2(A):\n        n: str = \"2\"\n\n    class B(Configurable):\n        a: A\n        a_class_type: str = \"A2\"\n        b: Optional[A]\n        b_class_type: Optional[str] = \"A2\"\n\n        def __post_init__(self):\n            run_auto_creation(self)\n\nwill expand to\n\n    @dataclasses.dataclass\n    class A:\n        k: int = 1\n\n    @dataclasses.dataclass\n    class A1(A):\n        m: int = 3\n\n    @dataclasses.dataclass\n    class A2(A):\n        n: str = \"2\"\n\n    @dataclasses.dataclass\n    class B:\n        a_class_type: str = \"A2\"\n        a_A1_args: DictConfig = dataclasses.field(\n            default_factory=lambda: DictConfig({\"k\": 1, \"m\": 3}\n        )\n        a_A2_args: DictConfig = dataclasses.field(\n            default_factory=lambda: DictConfig({\"k\": 1, \"n\": 2}\n        )\n        b_class_type: Optional[str] = \"A2\"\n        b_A1_args: DictConfig = dataclasses.field(\n            default_factory=lambda: DictConfig({\"k\": 1, \"m\": 3}\n        )\n        b_A2_args: DictConfig = dataclasses.field(\n            default_factory=lambda: DictConfig({\"k\": 1, \"n\": 2}\n        )\n\n        def __post_init__(self):\n            if self.a_class_type == \"A1\":\n                self.a = A1(**self.a_A1_args)\n            elif self.a_class_type == \"A2\":\n                self.a = A2(**self.a_A2_args)\n            else:\n                raise ValueError(...)\n\n            if self.b_class_type is None:\n                self.b = None\n            elif self.b_class_type == \"A1\":\n                self.b = A1(**self.b_A1_args)\n            elif self.b_class_type == \"A2\":\n                self.b = A2(**self.b_A2_args)\n            else:\n                raise ValueError(...)\n\n3. Aside from these classes, the members of these classes should be things\nwhich DictConfig is happy with: e.g. (bool, int, str, None, float) and what\ncan be built from them with `DictConfig`s and lists of them.\n\nIn addition, you can call `get_default_args` on a function or class to get\nthe `DictConfig` of its defaulted arguments, assuming those are all things\nwhich `DictConfig` is happy with, so long as you add a call to\n`enable_get_default_args` after its definition. 
If you want to use such a\nthing as the default for a member of another configured class,\n`get_default_args_field` is a helper.\n\"\"\"\n\n\n_unprocessed_warning: str = (\n    \" must be processed before it can be used.\"\n    + \" This is done by calling expand_args_fields \"\n    + \"or get_default_args on it.\"\n)\n\nTYPE_SUFFIX: str = \"_class_type\"\nARGS_SUFFIX: str = \"_args\"\nENABLED_SUFFIX: str = \"_enabled\"\n\n\nclass ReplaceableBase:\n    \"\"\"\n    Base class for dataclass-style classes which\n    can be stored in the registry.\n    \"\"\"\n\n    def __new__(cls, *args, **kwargs):\n        \"\"\"\n        This function only exists to raise a\n        warning if class construction is attempted\n        without processing.\n        \"\"\"\n        obj = super().__new__(cls)\n        if cls is not ReplaceableBase and not _is_actually_dataclass(cls):\n            warnings.warn(cls.__name__ + _unprocessed_warning)\n        return obj\n\n\nclass Configurable:\n    \"\"\"\n    This indicates a class which is not ReplaceableBase\n    but still needs to be\n    expanded into a dataclass with expand_args_fields.\n    This expansion is delayed.\n    \"\"\"\n\n    def __new__(cls, *args, **kwargs):\n        \"\"\"\n        This function only exists to raise a\n        warning if class construction is attempted\n        without processing.\n        \"\"\"\n        obj = super().__new__(cls)\n        if cls is not Configurable and not _is_actually_dataclass(cls):\n            warnings.warn(cls.__name__ + _unprocessed_warning)\n        return obj\n\n\n_X = TypeVar(\"X\", bound=ReplaceableBase)\n\n\nclass _Registry:\n    \"\"\"\n    A registry from names to classes. In particular, we say that direct subclasses of\n    ReplaceableBase are \"base classes\" and we register subclasses of each base class\n    in a separate namespace.\n    \"\"\"\n\n    def __init__(self) -> None:\n        self._mapping: Dict[\n            Type[ReplaceableBase], Dict[str, Type[ReplaceableBase]]\n        ] = defaultdict(dict)\n\n    def register(self, some_class: Type[_X]) -> Type[_X]:\n        \"\"\"\n        A class decorator, to register a class in self.\n        \"\"\"\n        name = some_class.__name__\n        self._register(some_class, name=name)\n        return some_class\n\n    def _register(\n        self,\n        some_class: Type[ReplaceableBase],\n        *,\n        base_class: Optional[Type[ReplaceableBase]] = None,\n        name: str,\n    ) -> None:\n        \"\"\"\n        Register a new member.\n\n        Args:\n            some_class: the new member\n            base_class: (optional) the base class under which to register the new member\n            name: name for the new member\n        \"\"\"\n        if base_class is None:\n            base_class = self._base_class_from_class(some_class)\n            if base_class is None:\n                raise ValueError(\n                    f\"Cannot register {some_class}.
Cannot tell what it is.\"\n                )\n        if some_class is base_class:\n            raise ValueError(f\"Attempted to register the base class {some_class}\")\n        self._mapping[base_class][name] = some_class\n\n    def get(\n        self, base_class_wanted: Type[ReplaceableBase], name: str\n    ) -> Type[ReplaceableBase]:\n        \"\"\"\n        Retrieve a class from the registry by name.\n\n        Args:\n            base_class_wanted: parent type of the type we are looking for.\n                        It determines the namespace.\n                        This will typically be a direct subclass of ReplaceableBase.\n            name: what to look for\n\n        Returns:\n            class type\n        \"\"\"\n        if self._is_base_class(base_class_wanted):\n            base_class = base_class_wanted\n        else:\n            base_class = self._base_class_from_class(base_class_wanted)\n            if base_class is None:\n                raise ValueError(\n                    f\"Cannot look up {base_class_wanted}. Cannot tell what it is.\"\n                )\n        result = self._mapping[base_class].get(name)\n        if result is None:\n            raise ValueError(f\"{name} has not been registered.\")\n        if not issubclass(result, base_class_wanted):\n            raise ValueError(\n                f\"{name} resolves to {result} which does not subclass {base_class_wanted}\"\n            )\n        return result\n\n    def get_all(\n        self, base_class_wanted: Type[ReplaceableBase]\n    ) -> List[Type[ReplaceableBase]]:\n        \"\"\"\n        Retrieve all registered implementations from the registry.\n\n        Args:\n            base_class_wanted: parent type of the type we are looking for.\n                        It determines the namespace.\n                        This will typically be a direct subclass of ReplaceableBase.\n        Returns:\n            list of class types\n        \"\"\"\n        if self._is_base_class(base_class_wanted):\n            return list(self._mapping[base_class_wanted].values())\n\n        base_class = self._base_class_from_class(base_class_wanted)\n        if base_class is None:\n            raise ValueError(\n                f\"Cannot look up {base_class_wanted}.
Cannot tell what it is.\"\n            )\n        return [\n            class_\n            for class_ in self._mapping[base_class].values()\n            if issubclass(class_, base_class_wanted) and class_ is not base_class_wanted\n        ]\n\n    @staticmethod\n    def _is_base_class(some_class: Type[ReplaceableBase]) -> bool:\n        \"\"\"\n        Return whether the given type is a direct subclass of ReplaceableBase\n        and so gets used as a namespace.\n        \"\"\"\n        return ReplaceableBase in some_class.__bases__\n\n    @staticmethod\n    def _base_class_from_class(\n        some_class: Type[ReplaceableBase],\n    ) -> Optional[Type[ReplaceableBase]]:\n        \"\"\"\n        Find the parent class of some_class which inherits ReplaceableBase, or None\n        \"\"\"\n        for base in some_class.mro()[-3::-1]:\n            if base is not ReplaceableBase and issubclass(base, ReplaceableBase):\n                return base\n        return None\n\n\n# Global instance of the registry\nregistry = _Registry()\n\n\nclass _ProcessType(Enum):\n    \"\"\"\n    Type of member which gets rewritten by expand_args_fields.\n    \"\"\"\n\n    CONFIGURABLE = 1\n    REPLACEABLE = 2\n    OPTIONAL_CONFIGURABLE = 3\n    OPTIONAL_REPLACEABLE = 4\n\n\ndef _default_create(\n    name: str, type_: Type, process_type: _ProcessType\n) -> Callable[[Any], None]:\n    \"\"\"\n    Return the default creation function for a member. This is a function which\n    could be called in __post_init__ to initialise the member, and will be called\n    from run_auto_creation.\n\n    Args:\n        name: name of the member\n        type_: type of the member (with any Optional removed)\n        process_type: Shows whether member's declared type inherits ReplaceableBase,\n                    in which case the actual type to be created is decided at\n                    runtime.\n\n    Returns:\n        Function taking one argument, the object whose member should be\n            initialized.\n    \"\"\"\n\n    def inner(self):\n        expand_args_fields(type_)\n        args = getattr(self, name + ARGS_SUFFIX)\n        setattr(self, name, type_(**args))\n\n    def inner_optional(self):\n        expand_args_fields(type_)\n        enabled = getattr(self, name + ENABLED_SUFFIX)\n        if enabled:\n            args = getattr(self, name + ARGS_SUFFIX)\n            setattr(self, name, type_(**args))\n        else:\n            setattr(self, name, None)\n\n    def inner_pluggable(self):\n        type_name = getattr(self, name + TYPE_SUFFIX)\n        if type_name is None:\n            setattr(self, name, None)\n            return\n\n        chosen_class = registry.get(type_, type_name)\n        if self._known_implementations.get(type_name, chosen_class) is not chosen_class:\n            # If this warning is raised, it means that a new definition of\n            # the chosen class has been registered since our class was processed\n            # (i.e. expanded). A DictConfig which comes from our get_default_args\n            # (which might have triggered the processing) will contain the old default\n            # values for the members of the chosen class. 
Changes to those defaults which\n            # were made in the redefinition will not be reflected here.\n            warnings.warn(f\"New implementation of {type_name} is being chosen.\")\n        expand_args_fields(chosen_class)\n        args = getattr(self, f\"{name}_{type_name}{ARGS_SUFFIX}\")\n        setattr(self, name, chosen_class(**args))\n\n    if process_type == _ProcessType.OPTIONAL_CONFIGURABLE:\n        return inner_optional\n    return inner if process_type == _ProcessType.CONFIGURABLE else inner_pluggable\n\n\ndef run_auto_creation(self: Any) -> None:\n    \"\"\"\n    Run all the functions named in self._creation_functions.\n    \"\"\"\n    for create_function in self._creation_functions:\n        getattr(self, create_function)()\n\n\ndef _is_configurable_class(C) -> bool:\n    return isinstance(C, type) and issubclass(C, (Configurable, ReplaceableBase))\n\n\ndef get_default_args(C, *, _do_not_process: Tuple[type, ...] = ()) -> DictConfig:\n    \"\"\"\n    Get the DictConfig corresponding to the defaults in a dataclass or\n    configurable. Normal use is to provide a dataclass as C.\n    If enable_get_default_args has been called on a function or plain class,\n    then that function or class can also be provided as C.\n\n    If C is a subclass of Configurable or ReplaceableBase, we make sure\n    it has been processed with expand_args_fields.\n\n    Args:\n        C: the class or function to be processed\n        _do_not_process: (internal use) When this function is called from\n                    expand_args_fields, we specify any class currently being\n                    processed, to make sure we don't try to process a class\n                    while it is already being processed.\n\n    Returns:\n        new DictConfig object, which is typed.\n    \"\"\"\n    if C is None:\n        return DictConfig({})\n\n    if _is_configurable_class(C):\n        if C in _do_not_process:\n            raise ValueError(\n                f\"Internal recursion error. Need processed {C},\"\n                f\" but cannot get it. _do_not_process={_do_not_process}\"\n            )\n        # This is safe to run multiple times. It will return\n        # straight away if C has already been processed.\n        expand_args_fields(C, _do_not_process=_do_not_process)\n\n    if dataclasses.is_dataclass(C):\n        # Note that if get_default_args_field is used somewhere in C,\n        # this call is recursive. No special care is needed,\n        # because in practice get_default_args_field is used for\n        # types separate from the outer type.\n\n        out: DictConfig = OmegaConf.structured(C)\n        exclude = getattr(C, \"_processed_members\", ())\n        with open_dict(out):\n            for field in exclude:\n                out.pop(field, None)\n        return out\n\n    if _is_configurable_class(C):\n        raise ValueError(f\"Failed to process {C}\")\n\n    if not inspect.isfunction(C) and not inspect.isclass(C):\n        raise ValueError(f\"Unexpected {C}\")\n\n    dataclass_name = _dataclass_name_for_function(C)\n    dataclass = getattr(sys.modules[C.__module__], dataclass_name, None)\n    if dataclass is None:\n        raise ValueError(\n            f\"Cannot get args for {C}.
Was enable_get_default_args forgotten?\"\n        )\n\n    return OmegaConf.structured(dataclass)\n\n\ndef _dataclass_name_for_function(C: Any) -> str:\n    \"\"\"\n    Returns the name of the dataclass which enable_get_default_args(C)\n    creates.\n    \"\"\"\n    name = f\"_{C.__name__}_default_args_\"\n    return name\n\n\ndef enable_get_default_args(C: Any, *, overwrite: bool = True) -> None:\n    \"\"\"\n    If C is a function or a plain class with an __init__ function,\n    and you want get_default_args(C) to work, then add\n    `enable_get_default_args(C)` straight after the definition of C.\n    This makes a dataclass corresponding to the default arguments of C\n    and stores it in the same module as C.\n\n    Args:\n        C: a function, or a class with an __init__ function. Must\n            have types for all its defaulted args.\n        overwrite: whether to allow calling this a second time on\n            the same function.\n    \"\"\"\n    if not inspect.isfunction(C) and not inspect.isclass(C):\n        raise ValueError(f\"Unexpected {C}\")\n\n    field_annotations = []\n    for pname, defval in _params_iter(C):\n        default = defval.default\n        if default == inspect.Parameter.empty:\n            # we do not have a default value for the parameter\n            continue\n\n        if defval.annotation == inspect._empty:\n            raise ValueError(\n                \"All arguments of the input callable have to be typed.\"\n                + f\" Argument '{pname}' does not have a type annotation.\"\n            )\n\n        _, annotation = _resolve_optional(defval.annotation)\n\n        if isinstance(default, set):  # force OmegaConf to convert it to ListConfig\n            default = tuple(default)\n\n        if isinstance(default, (list, dict)):\n            # OmegaConf will convert to [Dict|List]Config, so it is safe to reuse the value\n            field_ = dataclasses.field(default_factory=lambda default=default: default)\n        elif not _is_immutable_type(annotation, default):\n            continue\n        else:\n            # we can use a simple default argument for dataclass.field\n            field_ = dataclasses.field(default=default)\n        field_annotations.append((pname, defval.annotation, field_))\n\n    name = _dataclass_name_for_function(C)\n    module = sys.modules[C.__module__]\n    if hasattr(module, name):\n        if overwrite:\n            warnings.warn(f\"Overwriting {name} in {C.__module__}.\")\n        else:\n            raise ValueError(f\"Cannot overwrite {name} in {C.__module__}.\")\n    dc = dataclasses.make_dataclass(name, field_annotations)\n    dc.__module__ = C.__module__\n    setattr(module, name, dc)\n\n\ndef _params_iter(C):\n    \"\"\"Returns an iterator over the (name, Parameter) pairs of the\n    keyword args of a class or function C.\"\"\"\n    if inspect.isclass(C):\n        return itertools.islice(  # exclude `self`\n            inspect.signature(C.__init__).parameters.items(), 1, None\n        )\n\n    return inspect.signature(C).parameters.items()\n\n\ndef _is_immutable_type(type_: Type, val: Any) -> bool:\n    PRIMITIVE_TYPES = (int, float, bool, str, bytes, tuple)\n    # sometimes type can be too relaxed (e.g.
Any), so we also check values\n    if isinstance(val, PRIMITIVE_TYPES):\n        return True\n\n    return type_ in PRIMITIVE_TYPES or (\n        inspect.isclass(type_) and issubclass(type_, Enum)\n    )\n\n\n# copied from OmegaConf\ndef _resolve_optional(type_: Any) -> Tuple[bool, Any]:\n    \"\"\"Check whether `type_` is equivalent to `typing.Optional[T]` for some T.\"\"\"\n    if get_origin(type_) is Union:\n        args = get_args(type_)\n        if len(args) == 2 and args[1] == type(None):  # noqa E721\n            return True, args[0]\n    if type_ is Any:\n        return True, Any\n\n    return False, type_\n\n\ndef _is_actually_dataclass(some_class) -> bool:\n    # Return whether the class some_class has been processed with\n    # the dataclass annotation. This is more specific than\n    # dataclasses.is_dataclass which returns True on anything\n    # deriving from a dataclass.\n\n    # Checking for __init__ would also work for our purpose.\n    return \"__dataclass_fields__\" in some_class.__dict__\n\n\ndef expand_args_fields(\n    some_class: Type[_X], *, _do_not_process: Tuple[type, ...] = ()\n) -> Type[_X]:\n    \"\"\"\n    This expands a class which inherits Configurable or ReplaceableBase classes,\n    including dataclass processing. some_class is modified in place by this function.\n    For classes of type ReplaceableBase, you can add some_class to the registry before\n    or after calling this function. But potential inner classes need to be registered\n    before this function is run on the outer class.\n\n    The transformations this function makes, before the concluding\n    dataclasses.dataclass, are as follows. If X is a base class with registered\n    subclasses Y and Z, replace a class member\n\n        x: X\n\n    and optionally\n\n        x_class_type: str = \"Y\"\n        def create_x(self):...\n\n    with\n\n        x_Y_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(Y))\n        x_Z_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(Z))\n        def create_x(self):\n            self.x = registry.get(X, self.x_class_type)(\n                **getattr(self, f\"x_{self.x_class_type}_args\")\n            )\n        x_class_type: str = \"UNDEFAULTED\"\n\n    without adding the optional attributes if they are already there.\n\n    Similarly, replace\n\n        x: Optional[X]\n\n    and optionally\n\n        x_class_type: Optional[str] = \"Y\"\n        def create_x(self):...\n\n    with\n\n        x_Y_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(Y))\n        x_Z_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(Z))\n        def create_x(self):\n            if self.x_class_type is None:\n                self.x = None\n                return\n\n            self.x = registry.get(X, self.x_class_type)(\n                **getattr(self, f\"x_{self.x_class_type}_args\")\n            )\n        x_class_type: Optional[str] = \"UNDEFAULTED\"\n\n    without adding the optional attributes if they are already there.\n\n    Similarly, if X is a subclass of Configurable,\n\n        x: X\n\n    and optionally\n\n        def create_x(self):...\n\n    will be replaced with\n\n        x_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(X))\n        def create_x(self):\n            self.x = X(**self.x_args)\n\n    Similarly, replace\n\n        x: Optional[X]\n\n    and optionally\n\n        def create_x(self):...\n        x_enabled: bool = ...\n\n
    with\n\n        x_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(X))\n        x_enabled: bool = False\n        def create_x(self):\n            if self.x_enabled:\n                self.x = X(**self.x_args)\n            else:\n                self.x = None\n\n\n    Also adds the following class members, unannotated so that dataclass\n    ignores them.\n        - _creation_functions: Tuple[str] of all the create_ functions,\n            including those from base classes.\n        - _known_implementations: Dict[str, Type] containing the classes which\n            have been found from the registry.\n            (used only to raise a warning if one has been overwritten)\n        - _processed_members: a Dict[str, Any] of all the members which have been\n            transformed, with values giving the types they were declared to have.\n            (E.g. {\"x\": X} or {\"x\": Optional[X]} in the cases above.)\n\n    Args:\n        some_class: the class to be processed\n        _do_not_process: Internal use for get_default_args: Because get_default_args calls\n                        and is called by this function, we let it specify any class currently\n                        being processed, to make sure we don't try to process a class while\n                        it is already being processed.\n\n\n    Returns:\n        some_class itself, which has been modified in place. This\n        allows this function to be used as a class decorator.\n    \"\"\"\n    if _is_actually_dataclass(some_class):\n        return some_class\n\n    # The functions this class's run_auto_creation will run.\n    creation_functions: List[str] = []\n    # The classes which this type knows about from the registry.\n    # We could use a weakref.WeakValueDictionary here, which would mean\n    # that we don't warn if the class we should have expected is elsewhere\n    # unused.\n    known_implementations: Dict[str, Type] = {}\n    # Names of members which have been processed.\n    processed_members: Dict[str, Any] = {}\n\n    # For all bases except ReplaceableBase and Configurable and object,\n    # we need to process them before our own processing.
This is\n    # because a dataclass expects to inherit from dataclasses, not from\n    # unprocessed classes.\n    for base in some_class.mro()[-3:0:-1]:\n        if base is ReplaceableBase:\n            continue\n        if base is Configurable:\n            continue\n        if not issubclass(base, (Configurable, ReplaceableBase)):\n            continue\n        expand_args_fields(base, _do_not_process=_do_not_process)\n        if \"_creation_functions\" in base.__dict__:\n            creation_functions.extend(base._creation_functions)\n        if \"_known_implementations\" in base.__dict__:\n            known_implementations.update(base._known_implementations)\n        if \"_processed_members\" in base.__dict__:\n            processed_members.update(base._processed_members)\n\n    to_process: List[Tuple[str, Type, _ProcessType]] = []\n    if \"__annotations__\" in some_class.__dict__:\n        for name, type_ in some_class.__annotations__.items():\n            underlying_and_process_type = _get_type_to_process(type_)\n            if underlying_and_process_type is None:\n                continue\n            underlying_type, process_type = underlying_and_process_type\n            to_process.append((name, underlying_type, process_type))\n\n    for name, underlying_type, process_type in to_process:\n        processed_members[name] = some_class.__annotations__[name]\n        _process_member(\n            name=name,\n            type_=underlying_type,\n            process_type=process_type,\n            some_class=some_class,\n            creation_functions=creation_functions,\n            _do_not_process=_do_not_process,\n            known_implementations=known_implementations,\n        )\n\n    for key, count in Counter(creation_functions).items():\n        if count > 1:\n            warnings.warn(f\"Clash with {key} in a base class.\")\n    some_class._creation_functions = tuple(creation_functions)\n    some_class._processed_members = processed_members\n    some_class._known_implementations = known_implementations\n\n    dataclasses.dataclass(eq=False)(some_class)\n    return some_class\n\n\ndef get_default_args_field(C, *, _do_not_process: Tuple[type, ...]
= ()):\n    \"\"\"\n    Get a dataclass field which defaults to get_default_args(...)\n\n    Args:\n        As for get_default_args.\n\n    Returns:\n        a dataclass field whose default_factory returns a new DictConfig object\n    \"\"\"\n\n    def create():\n        return get_default_args(C, _do_not_process=_do_not_process)\n\n    return dataclasses.field(default_factory=create)\n\n\ndef _get_type_to_process(type_) -> Optional[Tuple[Type, _ProcessType]]:\n    \"\"\"\n    If a member is annotated as `type_`, and that should be expanded in\n    expand_args_fields, return how it should be expanded.\n    \"\"\"\n    if get_origin(type_) == Union:\n        # We look for Optional[X] which is a Union of X with None.\n        args = get_args(type_)\n        if len(args) != 2 or all(a is not type(None) for a in args):  # noqa: E721\n            return\n        underlying = args[0] if args[1] is type(None) else args[1]  # noqa: E721\n        if (\n            isinstance(underlying, type)\n            and issubclass(underlying, ReplaceableBase)\n            and ReplaceableBase in underlying.__bases__\n        ):\n            return underlying, _ProcessType.OPTIONAL_REPLACEABLE\n\n        if isinstance(underlying, type) and issubclass(underlying, Configurable):\n            return underlying, _ProcessType.OPTIONAL_CONFIGURABLE\n\n    if not isinstance(type_, type):\n        # e.g. any other Union or Tuple\n        return\n\n    if issubclass(type_, ReplaceableBase) and ReplaceableBase in type_.__bases__:\n        return type_, _ProcessType.REPLACEABLE\n\n    if issubclass(type_, Configurable):\n        return type_, _ProcessType.CONFIGURABLE\n\n\ndef _process_member(\n    *,\n    name: str,\n    type_: Type,\n    process_type: _ProcessType,\n    some_class: Type,\n    creation_functions: List[str],\n    _do_not_process: Tuple[type, ...],\n    known_implementations: Dict[str, Type],\n) -> None:\n    \"\"\"\n    Make the modification (of expand_args_fields) to some_class for a single member.\n\n    Args:\n        name: member name\n        type_: member type (with Optional removed if needed)\n        process_type: whether the member has a dynamic type\n        some_class: (MODIFIED IN PLACE) the class being processed\n        creation_functions: (MODIFIED IN PLACE) the names of the create functions\n        _do_not_process: as for expand_args_fields.\n        known_implementations: (MODIFIED IN PLACE) known types from the registry\n    \"\"\"\n    # Because we are adding defaultable members, make\n    # sure they go at the end of __annotations__ in case\n    # there are non-defaulted standard class members.\n    del some_class.__annotations__[name]\n\n    if process_type in (_ProcessType.REPLACEABLE, _ProcessType.OPTIONAL_REPLACEABLE):\n        type_name = name + TYPE_SUFFIX\n        if type_name not in some_class.__annotations__:\n            if process_type == _ProcessType.OPTIONAL_REPLACEABLE:\n                some_class.__annotations__[type_name] = Optional[str]\n            else:\n                some_class.__annotations__[type_name] = str\n            setattr(some_class, type_name, \"UNDEFAULTED\")\n\n        for derived_type in registry.get_all(type_):\n            if derived_type in _do_not_process:\n                continue\n            if issubclass(derived_type, some_class):\n                # When derived_type is some_class we have a simple\n                # recursion to avoid.
When it's a strict subclass the\n                # situation is even worse.\n                continue\n            known_implementations[derived_type.__name__] = derived_type\n            args_name = f\"{name}_{derived_type.__name__}{ARGS_SUFFIX}\"\n            if args_name in some_class.__annotations__:\n                raise ValueError(\n                    f\"Cannot generate {args_name} because it is already present.\"\n                )\n            some_class.__annotations__[args_name] = DictConfig\n            setattr(\n                some_class,\n                args_name,\n                get_default_args_field(\n                    derived_type, _do_not_process=_do_not_process + (some_class,)\n                ),\n            )\n    else:\n        args_name = name + ARGS_SUFFIX\n        if args_name in some_class.__annotations__:\n            raise ValueError(\n                f\"Cannot generate {args_name} because it is already present.\"\n            )\n        if issubclass(type_, some_class) or type_ in _do_not_process:\n            raise ValueError(f\"Cannot process {type_} inside {some_class}\")\n\n        some_class.__annotations__[args_name] = DictConfig\n        setattr(\n            some_class,\n            args_name,\n            get_default_args_field(\n                type_,\n                _do_not_process=_do_not_process + (some_class,),\n            ),\n        )\n        if process_type == _ProcessType.OPTIONAL_CONFIGURABLE:\n            enabled_name = name + ENABLED_SUFFIX\n            if enabled_name not in some_class.__annotations__:\n                some_class.__annotations__[enabled_name] = bool\n                setattr(some_class, enabled_name, False)\n\n    creation_function_name = f\"create_{name}\"\n    if not hasattr(some_class, creation_function_name):\n        setattr(\n            some_class,\n            creation_function_name,\n            _default_create(name, type_, process_type),\n        )\n    creation_functions.append(creation_function_name)\n\n\ndef remove_unused_components(dict_: DictConfig) -> None:\n    \"\"\"\n    Assuming dict_ represents the state of a configurable,\n    modify it to remove all the portions corresponding to\n    pluggable parts which are not in use.\n    For example, if renderer_class_type is SignedDistanceFunctionRenderer,\n    the renderer_MultiPassEmissionAbsorptionRenderer_args will be\n    removed. 
Also, if chocolate_enabled is False, then chocolate_args will\n    be removed.\n\n    Args:\n        dict_: (MODIFIED IN PLACE) a DictConfig instance\n    \"\"\"\n    keys = [key for key in dict_ if isinstance(key, str)]\n    suffix_length = len(TYPE_SUFFIX)\n    replaceables = [key[:-suffix_length] for key in keys if key.endswith(TYPE_SUFFIX)]\n    args_keys = [key for key in keys if key.endswith(ARGS_SUFFIX)]\n    for replaceable in replaceables:\n        selected_type = dict_[replaceable + TYPE_SUFFIX]\n        if selected_type is None:\n            expect = \"\"\n        else:\n            expect = replaceable + \"_\" + selected_type + ARGS_SUFFIX\n        with open_dict(dict_):\n            for key in args_keys:\n                if key.startswith(replaceable + \"_\") and key != expect:\n                    del dict_[key]\n\n    suffix_length = len(ENABLED_SUFFIX)\n    enableables = [key[:-suffix_length] for key in keys if key.endswith(ENABLED_SUFFIX)]\n    for enableable in enableables:\n        enabled = dict_[enableable + ENABLED_SUFFIX]\n        if not enabled:\n            with open_dict(dict_):\n                dict_.pop(enableable + ARGS_SUFFIX, None)\n\n    for key in dict_:\n        if isinstance(dict_.get(key), DictConfig):\n            remove_unused_components(dict_[key])\n"
  },
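The registry round trip above is easiest to see end to end. The following is a minimal sketch, assuming the `pytorch3d.implicitron.tools.config` import path used by the model files below; `Renderer`, `FastRenderer` and `Pipeline` are hypothetical names invented here, mirroring the A/B example in the module docstring.

```python
from pytorch3d.implicitron.tools.config import (
    Configurable,
    ReplaceableBase,
    get_default_args,
    registry,
    run_auto_creation,
)


class Renderer(ReplaceableBase):  # direct subclass: acts as a registry namespace
    image_size: int = 256


@registry.register
class FastRenderer(Renderer):  # looked up by name at creation time
    n_samples: int = 8


class Pipeline(Configurable):
    renderer: Renderer
    renderer_class_type: str = "FastRenderer"

    def __post_init__(self):
        run_auto_creation(self)  # builds self.renderer from the *_args fields


cfg = get_default_args(Pipeline)
# expand_args_fields has removed the `renderer` annotation and added a
# renderer_FastRenderer_args DictConfig holding FastRenderer's defaults
cfg.renderer_FastRenderer_args.n_samples = 16
pipeline = Pipeline(**cfg)
assert isinstance(pipeline.renderer, FastRenderer)
assert pipeline.renderer.n_samples == 16
```

`remove_unused_components(cfg)` would then prune the `*_args` branches of any registered implementations that `renderer_class_type` does not select.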
  {
    "path": "models/core/utils/utils.py",
    "content": "import torch\nimport torch.nn.functional as F\nimport numpy as np\nfrom scipy import interpolate\n\n\ndef interp(tensor, size):\n    return F.interpolate(\n        tensor,\n        size=size,\n        mode=\"bilinear\",\n        align_corners=True,\n    )\n\n\nclass InputPadder:\n    \"\"\"Pads images such that dimensions are divisible by 8\"\"\"\n\n    def __init__(self, dims, mode=\"sintel\", divis_by=8):\n        self.ht, self.wd = dims[-2:]\n        pad_ht = (((self.ht // divis_by) + 1) * divis_by - self.ht) % divis_by\n        pad_wd = (((self.wd // divis_by) + 1) * divis_by - self.wd) % divis_by\n        if mode == \"sintel\":\n            self._pad = [\n                pad_wd // 2,\n                pad_wd - pad_wd // 2,\n                pad_ht // 2,\n                pad_ht - pad_ht // 2,\n            ]\n        else:\n            self._pad = [pad_wd // 2, pad_wd - pad_wd // 2, 0, pad_ht]\n\n    def pad(self, *inputs):\n        assert all((x.ndim == 4) for x in inputs)\n        return [F.pad(x, self._pad, mode=\"replicate\") for x in inputs]\n\n    def unpad(self, x):\n        assert x.ndim == 4\n        ht, wd = x.shape[-2:]\n        c = [self._pad[2], ht - self._pad[3], self._pad[0], wd - self._pad[1]]\n        return x[..., c[0] : c[1], c[2] : c[3]]\n\n\ndef coords_grid(batch, ht, wd):\n    coords = torch.meshgrid(torch.arange(ht), torch.arange(wd))\n    coords = torch.stack(coords[::-1], dim=0).float()\n    return coords[None].repeat(batch, 1, 1, 1)\n\n\ndef upflow8(flow, mode='bilinear'):\n    new_size = (8 * flow.shape[2], 8 * flow.shape[3])\n    return  8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True)"
  },
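A quick sketch of how `InputPadder` round-trips a tensor; the shapes are arbitrary (436x1024 happens to be Sintel's resolution) and the import path is assumed from this repository's layout.

```python
import torch

from stereoanyvideo.models.core.utils.utils import InputPadder  # assumed path

left = torch.randn(1, 3, 436, 1024)   # (B, C, H, W)
right = torch.randn(1, 3, 436, 1024)

padder = InputPadder(left.shape)           # "sintel" mode: pad H and W symmetrically
left_p, right_p = padder.pad(left, right)  # both now divisible by 8 in H and W
assert left_p.shape[-2] % 8 == 0 and left_p.shape[-1] % 8 == 0

restored = padder.unpad(left_p)            # crop back to the original size
assert restored.shape == left.shape
```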
  {
    "path": "models/raft_model.py",
    "content": "from types import SimpleNamespace\nfrom typing import ClassVar\nimport torch.nn.functional as F\n\nfrom pytorch3d.implicitron.tools.config import Configurable\nimport torch\nimport importlib\nimport sys\nimport os\n\nautocast = torch.cuda.amp.autocast\n\nclass RAFTModel(Configurable, torch.nn.Module):\n    MODEL_CONFIG_NAME: ClassVar[str] = \"RAFTModel\"\n\n    def __post_init__(self):\n        super().__init__()\n        thirdparty_raft_path = os.path.abspath(os.path.join(os.path.dirname(__file__), \"../third_party/RAFT\"))\n        sys.path.append(thirdparty_raft_path)\n        raft = importlib.import_module(\n            \"stereoanyvideo.third_party.RAFT.core.raft\"\n        )\n        self.raft_utils = importlib.import_module(\n            \"stereoanyvideo.third_party.RAFT.core.utils.utils\"\n        )\n\n        self.model_weights: str = \"./third_party/RAFT/models/raft-things.pth\"\n\n        model_args = SimpleNamespace(\n            mixed_precision=False,\n            small=False,\n            dropout=0.0,\n        )\n        self.args = model_args\n        self.model = raft.RAFT(model_args).cuda()\n\n        state_dict = torch.load(self.model_weights, map_location=\"cpu\")\n        weight_dict = {}\n        for k,v in state_dict.items():\n            temp_k = k.replace('module.', '') if 'module' in k else k\n            weight_dict[temp_k] = v\n        self.model.load_state_dict(weight_dict, strict=True)\n\n\n    def forward(self, image1, image2, iters=10):\n        left_image_rgb = image1.cuda()\n        right_image_rgb = image2.cuda()\n        padder = self.raft_utils.InputPadder(left_image_rgb.shape)\n        left_image_rgb, right_image_rgb = padder.pad(\n            left_image_rgb, right_image_rgb\n        )\n        with autocast(enabled=self.args.mixed_precision):\n            flow, flow_up = self.model(left_image_rgb, right_image_rgb, iters=iters, test_mode=True)\n\n        flow_up = padder.unpad(flow_up)\n        return 0.25 * F.interpolate(flow_up, size=(flow_up.shape[2] // 4, flow_up.shape[3] // 4), mode=\"bilinear\",\n        align_corners=True)\n\n    def forward_fullres(self, image1, image2, iters=20):\n        left_image_rgb = image1.cuda()\n        right_image_rgb = image2.cuda()\n        padder = self.raft_utils.InputPadder(left_image_rgb.shape)\n        left_image_rgb, right_image_rgb = padder.pad(\n            left_image_rgb, right_image_rgb\n        )\n        with autocast(enabled=self.args.mixed_precision):\n            flow, flow_up = self.model(left_image_rgb.contiguous(), right_image_rgb.contiguous(), iters=iters, test_mode=True)\n\n        flow_up = padder.unpad(flow_up)\n        return flow_up"
  },
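Because `RAFTModel` is a `Configurable`, it has to be expanded into a dataclass before construction; otherwise `__post_init__` never runs and instantiation only emits the "must be processed" warning. A hedged sketch, assuming a CUDA device, the hard-coded `raft-things.pth` checkpoint on disk, and this repository's package layout:

```python
import torch
from pytorch3d.implicitron.tools.config import expand_args_fields

from stereoanyvideo.models.raft_model import RAFTModel  # assumed path

expand_args_fields(RAFTModel)  # gives the class its dataclass __init__/__post_init__
model = RAFTModel()            # __post_init__ loads the checkpoint onto the GPU

left = torch.rand(1, 3, 480, 640).cuda()   # dummy frames; RAFT expects RGB images
right = torch.rand(1, 3, 480, 640).cuda()
with torch.no_grad():
    flow_quarter = model(left, right)               # 1/4-resolution flow, scaled by 0.25
    flow_full = model.forward_fullres(left, right)  # full-resolution flow
```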
  {
    "path": "models/stereoanyvideo_model.py",
    "content": "from typing import ClassVar\n\nimport torch\nimport torch.nn.functional as F\nfrom pytorch3d.implicitron.tools.config import Configurable\nfrom stereoanyvideo.models.core.stereoanyvideo import StereoAnyVideo\n\n\nclass StereoAnyVideoModel(Configurable, torch.nn.Module):\n\n    MODEL_CONFIG_NAME: ClassVar[str] = \"StereoAnyVideoModel\"\n    model_weights: str = \"./checkpoints/StereoAnyVideo_MIX.pth\"\n\n    def __post_init__(self):\n        super().__init__()\n\n        self.mixed_precision = False\n        model = StereoAnyVideo(mixed_precision=self.mixed_precision)\n\n        state_dict = torch.load(self.model_weights, map_location=\"cpu\")\n        if \"model\" in state_dict:\n            state_dict = state_dict[\"model\"]\n        if \"state_dict\" in state_dict:\n            state_dict = state_dict[\"state_dict\"]\n            state_dict = {\"module.\" + k: v for k, v in state_dict.items()}\n        model.load_state_dict(state_dict, strict=True)\n\n        self.model = model\n        self.model.to(\"cuda\")\n        self.model.eval()\n\n    def forward(self, batch_dict, iters=20):\n\n        return self.model.forward_batch_test(batch_dict, iters=iters)"
  },
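`StereoAnyVideoModel` follows the same pattern but exposes `model_weights` as a configurable field, so the checkpoint path can be overridden through the config instead of by editing the class. A sketch under the same assumptions (CUDA device, checkpoint on disk):

```python
from pytorch3d.implicitron.tools.config import get_default_args

from stereoanyvideo.models.stereoanyvideo_model import StereoAnyVideoModel  # assumed path

cfg = get_default_args(StereoAnyVideoModel)  # contains the default model_weights path
cfg.model_weights = "./checkpoints/StereoAnyVideo_MIX.pth"  # or any other checkpoint
model = StereoAnyVideoModel(**cfg)  # __post_init__ loads the weights and moves to CUDA
```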
  {
    "path": "requirements.txt",
    "content": "hydra-core==1.1\nnumpy==1.23.5\nmunch==2.5.0\nomegaconf==2.1.0\nflow_vis==0.1\neinops==0.4.1\nopt_einsum==3.3.0\nrequests\nmoviepy\n"
  },
  {
    "path": "third_party/RAFT/LICENSE",
    "content": "BSD 3-Clause License\n\nCopyright (c) 2020, princeton-vl\nAll rights reserved.\n\nRedistribution and use in source and binary forms, with or without\nmodification, are permitted provided that the following conditions are met:\n\n* Redistributions of source code must retain the above copyright notice, this\n  list of conditions and the following disclaimer.\n\n* Redistributions in binary form must reproduce the above copyright notice,\n  this list of conditions and the following disclaimer in the documentation\n  and/or other materials provided with the distribution.\n\n* Neither the name of the copyright holder nor the names of its\n  contributors may be used to endorse or promote products derived from\n  this software without specific prior written permission.\n\nTHIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\nAND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\nIMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\nDISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE\nFOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL\nDAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR\nSERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER\nCAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,\nOR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\nOF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n"
  },
  {
    "path": "third_party/RAFT/README.md",
    "content": "# RAFT\nThis repository contains the source code for our paper:\n\n[RAFT: Recurrent All Pairs Field Transforms for Optical Flow](https://arxiv.org/pdf/2003.12039.pdf)<br/>\nECCV 2020 <br/>\nZachary Teed and Jia Deng<br/>\n\n<img src=\"RAFT.png\">\n\n## Requirements\nThe code has been tested with PyTorch 1.6 and Cuda 10.1.\n```Shell\nconda create --name raft\nconda activate raft\nconda install pytorch=1.6.0 torchvision=0.7.0 cudatoolkit=10.1 matplotlib tensorboard scipy opencv -c pytorch\n```\n\n## Demos\nPretrained models can be downloaded by running\n```Shell\n./download_models.sh\n```\nor downloaded from [google drive](https://drive.google.com/drive/folders/1sWDsfuZ3Up38EUQt7-JDTT1HcGHuJgvT?usp=sharing)\n\nYou can demo a trained model on a sequence of frames\n```Shell\npython demo.py --model=models/raft-things.pth --path=demo-frames\n```\n\n## Required Data\nTo evaluate/train RAFT, you will need to download the required datasets. \n* [FlyingChairs](https://lmb.informatik.uni-freiburg.de/resources/datasets/FlyingChairs.en.html#flyingchairs)\n* [FlyingThings3D](https://lmb.informatik.uni-freiburg.de/resources/datasets/SceneFlowDatasets.en.html)\n* [Sintel](http://sintel.is.tue.mpg.de/)\n* [KITTI](http://www.cvlibs.net/datasets/kitti/eval_scene_flow.php?benchmark=flow)\n* [HD1K](http://hci-benchmark.iwr.uni-heidelberg.de/) (optional)\n\n\nBy default `datasets.py` will search for the datasets in these locations. You can create symbolic links to wherever the datasets were downloaded in the `datasets` folder\n\n```Shell\n├── datasets\n    ├── Sintel\n        ├── test\n        ├── training\n    ├── KITTI\n        ├── testing\n        ├── training\n        ├── devkit\n    ├── FlyingChairs_release\n        ├── data\n    ├── FlyingThings3D\n        ├── frames_cleanpass\n        ├── frames_finalpass\n        ├── optical_flow\n```\n\n## Evaluation\nYou can evaluate a trained model using `evaluate.py`\n```Shell\npython evaluate.py --model=models/raft-things.pth --dataset=sintel --mixed_precision\n```\n\n## Training\nWe used the following training schedule in our paper (2 GPUs). Training logs will be written to the `runs` which can be visualized using tensorboard\n```Shell\n./train_standard.sh\n```\n\nIf you have a RTX GPU, training can be accelerated using mixed precision. You can expect similiar results in this setting (1 GPU)\n```Shell\n./train_mixed.sh\n```\n\n## (Optional) Efficent Implementation\nYou can optionally use our alternate (efficent) implementation by compiling the provided cuda extension\n```Shell\ncd alt_cuda_corr && python setup.py install && cd ..\n```\nand running `demo.py` and `evaluate.py` with the `--alternate_corr` flag Note, this implementation is somewhat slower than all-pairs, but uses significantly less GPU memory during the forward pass.\n"
  },
  {
    "path": "third_party/RAFT/alt_cuda_corr/correlation.cpp",
    "content": "#include <torch/extension.h>\n#include <vector>\n\n// CUDA forward declarations\nstd::vector<torch::Tensor> corr_cuda_forward(\n    torch::Tensor fmap1,\n    torch::Tensor fmap2,\n    torch::Tensor coords,\n    int radius);\n\nstd::vector<torch::Tensor> corr_cuda_backward(\n  torch::Tensor fmap1,\n  torch::Tensor fmap2,\n  torch::Tensor coords,\n  torch::Tensor corr_grad,\n  int radius);\n\n// C++ interface\n#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x \" must be a CUDA tensor\")\n#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x \" must be contiguous\")\n#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)\n\nstd::vector<torch::Tensor> corr_forward(\n    torch::Tensor fmap1,\n    torch::Tensor fmap2,\n    torch::Tensor coords,\n    int radius) {\n  CHECK_INPUT(fmap1);\n  CHECK_INPUT(fmap2);\n  CHECK_INPUT(coords);\n\n  return corr_cuda_forward(fmap1, fmap2, coords, radius);\n}\n\n\nstd::vector<torch::Tensor> corr_backward(\n    torch::Tensor fmap1,\n    torch::Tensor fmap2,\n    torch::Tensor coords,\n    torch::Tensor corr_grad,\n    int radius) {\n  CHECK_INPUT(fmap1);\n  CHECK_INPUT(fmap2);\n  CHECK_INPUT(coords);\n  CHECK_INPUT(corr_grad);\n\n  return corr_cuda_backward(fmap1, fmap2, coords, corr_grad, radius);\n}\n\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n  m.def(\"forward\", &corr_forward, \"CORR forward\");\n  m.def(\"backward\", &corr_backward, \"CORR backward\");\n}"
  },
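Once built (see `setup.py` below), the extension is importable from Python. The following is a minimal sketch of the tensor layout the kernels expect, read off the accessor shapes: channels-last float32 feature maps `(B, H, W, C)`, lookup centres `(B, N, H, W, 2)` in `(x, y)` order, and a returned correlation volume `(B, N, (2*radius+1)**2, H, W)`; the `CHECK_INPUT` macros require contiguous CUDA tensors.

```python
import torch

import alt_cuda_corr  # extension name from setup.py

B, H, W, C, radius = 1, 32, 64, 256, 4
fmap1 = torch.randn(B, H, W, C, device="cuda").contiguous()
fmap2 = torch.randn(B, H, W, C, device="cuda").contiguous()

# One lookup per pixel (N=1), centred at that pixel's own coordinate.
ys = torch.arange(H, device="cuda").view(H, 1).expand(H, W).float()
xs = torch.arange(W, device="cuda").view(1, W).expand(H, W).float()
coords = torch.stack([xs, ys], dim=-1)[None, None].contiguous()  # (1, 1, H, W, 2)

(corr,) = alt_cuda_corr.forward(fmap1, fmap2, coords, radius)
print(corr.shape)  # torch.Size([1, 1, 81, 32, 64]) since (2*4+1)**2 == 81
```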
  {
    "path": "third_party/RAFT/alt_cuda_corr/correlation_kernel.cu",
    "content": "#include <torch/extension.h>\n#include <cuda.h>\n#include <cuda_runtime.h>\n#include <vector>\n\n\n#define BLOCK_H 4\n#define BLOCK_W 8\n#define BLOCK_HW BLOCK_H * BLOCK_W\n#define CHANNEL_STRIDE 32\n\n\n__forceinline__ __device__\nbool within_bounds(int h, int w, int H, int W) {\n  return h >= 0 && h < H && w >= 0 && w < W;\n}\n\ntemplate <typename scalar_t>\n__global__ void corr_forward_kernel(\n    const torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> fmap1,\n    const torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> fmap2,\n    const torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> coords,\n    torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> corr,\n    int r)\n{\n  const int b = blockIdx.x;\n  const int h0 = blockIdx.y * blockDim.x;\n  const int w0 = blockIdx.z * blockDim.y;\n  const int tid = threadIdx.x * blockDim.y + threadIdx.y;\n\n  const int H1 = fmap1.size(1);\n  const int W1 = fmap1.size(2);\n  const int H2 = fmap2.size(1);\n  const int W2 = fmap2.size(2);\n  const int N = coords.size(1);\n  const int C = fmap1.size(3);\n\n  __shared__ scalar_t f1[CHANNEL_STRIDE][BLOCK_HW+1];\n  __shared__ scalar_t f2[CHANNEL_STRIDE][BLOCK_HW+1];\n  __shared__ scalar_t x2s[BLOCK_HW];\n  __shared__ scalar_t y2s[BLOCK_HW];\n\n  for (int c=0; c<C; c+=CHANNEL_STRIDE) {\n    for (int k=0; k<BLOCK_HW; k+=BLOCK_HW/CHANNEL_STRIDE) {\n      int k1 = k + tid / CHANNEL_STRIDE;\n      int h1 = h0 + k1 / BLOCK_W;\n      int w1 = w0 + k1 % BLOCK_W;\n      int c1 = tid % CHANNEL_STRIDE;\n\n      auto fptr = fmap1[b][h1][w1];\n      if (within_bounds(h1, w1, H1, W1))\n        f1[c1][k1] = fptr[c+c1];\n      else\n        f1[c1][k1] = 0.0;\n    }\n\n    __syncthreads();\n\n    for (int n=0; n<N; n++) {\n      int h1 = h0 + threadIdx.x;\n      int w1 = w0 + threadIdx.y;\n      if (within_bounds(h1, w1, H1, W1)) {\n        x2s[tid] = coords[b][n][h1][w1][0];\n        y2s[tid] = coords[b][n][h1][w1][1];\n      }\n\n      scalar_t dx = x2s[tid] - floor(x2s[tid]);\n      scalar_t dy = y2s[tid] - floor(y2s[tid]);\n\n      int rd = 2*r + 1;\n      for (int iy=0; iy<rd+1; iy++) {\n        for (int ix=0; ix<rd+1; ix++) {\n          for (int k=0; k<BLOCK_HW; k+=BLOCK_HW/CHANNEL_STRIDE) {\n            int k1 = k + tid / CHANNEL_STRIDE;\n            int h2 = static_cast<int>(floor(y2s[k1]))-r+iy;\n            int w2 = static_cast<int>(floor(x2s[k1]))-r+ix;\n            int c2 = tid % CHANNEL_STRIDE;\n\n            auto fptr = fmap2[b][h2][w2];\n            if (within_bounds(h2, w2, H2, W2))\n              f2[c2][k1] = fptr[c+c2];\n            else\n              f2[c2][k1] = 0.0;\n          }\n\n          __syncthreads();\n      \n          scalar_t s = 0.0;\n          for (int k=0; k<CHANNEL_STRIDE; k++)\n            s += f1[k][tid] * f2[k][tid];\n\n          int ix_nw = H1*W1*((iy-1) + rd*(ix-1));\n          int ix_ne = H1*W1*((iy-1) + rd*ix);\n          int ix_sw = H1*W1*(iy + rd*(ix-1));\n          int ix_se = H1*W1*(iy + rd*ix);\n\n          scalar_t nw = s * (dy) * (dx);\n          scalar_t ne = s * (dy) * (1-dx);\n          scalar_t sw = s * (1-dy) * (dx);\n          scalar_t se = s * (1-dy) * (1-dx);\n\n          scalar_t* corr_ptr = &corr[b][n][0][h1][w1];\n\n          if (iy > 0 && ix > 0 && within_bounds(h1, w1, H1, W1))\n            *(corr_ptr + ix_nw) += nw;\n\n          if (iy > 0 && ix < rd && within_bounds(h1, w1, H1, W1))\n            *(corr_ptr + ix_ne) += ne;\n\n          if (iy < rd && ix > 0 && within_bounds(h1, 
w1, H1, W1))\n            *(corr_ptr + ix_sw) += sw;\n\n          if (iy < rd && ix < rd && within_bounds(h1, w1, H1, W1))\n            *(corr_ptr + ix_se) += se;\n        }\n      } \n    }\n  }\n}\n\n\ntemplate <typename scalar_t>\n__global__ void corr_backward_kernel(\n    const torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> fmap1,\n    const torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> fmap2,\n    const torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> coords,\n    const torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> corr_grad,\n    torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> fmap1_grad,\n    torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> fmap2_grad,\n    torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> coords_grad,\n    int r)\n{\n\n  const int b = blockIdx.x;\n  const int h0 = blockIdx.y * blockDim.x;\n  const int w0 = blockIdx.z * blockDim.y;\n  const int tid = threadIdx.x * blockDim.y + threadIdx.y;\n\n  const int H1 = fmap1.size(1);\n  const int W1 = fmap1.size(2);\n  const int H2 = fmap2.size(1);\n  const int W2 = fmap2.size(2);\n  const int N = coords.size(1);\n  const int C = fmap1.size(3);\n\n  __shared__ scalar_t f1[CHANNEL_STRIDE][BLOCK_HW+1];\n  __shared__ scalar_t f2[CHANNEL_STRIDE][BLOCK_HW+1];\n\n  __shared__ scalar_t f1_grad[CHANNEL_STRIDE][BLOCK_HW+1];\n  __shared__ scalar_t f2_grad[CHANNEL_STRIDE][BLOCK_HW+1];\n\n  __shared__ scalar_t x2s[BLOCK_HW];\n  __shared__ scalar_t y2s[BLOCK_HW];\n\n  for (int c=0; c<C; c+=CHANNEL_STRIDE) {\n\n    for (int k=0; k<BLOCK_HW; k+=BLOCK_HW/CHANNEL_STRIDE) {\n      int k1 = k + tid / CHANNEL_STRIDE;\n      int h1 = h0 + k1 / BLOCK_W;\n      int w1 = w0 + k1 % BLOCK_W;\n      int c1 = tid % CHANNEL_STRIDE;\n\n      auto fptr = fmap1[b][h1][w1];\n      if (within_bounds(h1, w1, H1, W1))\n        f1[c1][k1] = fptr[c+c1];\n      else\n        f1[c1][k1] = 0.0;\n\n      f1_grad[c1][k1] = 0.0;\n    }\n\n    __syncthreads();\n\n    int h1 = h0 + threadIdx.x;\n    int w1 = w0 + threadIdx.y;\n\n    for (int n=0; n<N; n++) {  \n      x2s[tid] = coords[b][n][h1][w1][0];\n      y2s[tid] = coords[b][n][h1][w1][1];\n\n      scalar_t dx = x2s[tid] - floor(x2s[tid]);\n      scalar_t dy = y2s[tid] - floor(y2s[tid]);\n\n      int rd = 2*r + 1;\n      for (int iy=0; iy<rd+1; iy++) {\n        for (int ix=0; ix<rd+1; ix++) {\n          for (int k=0; k<BLOCK_HW; k+=BLOCK_HW/CHANNEL_STRIDE) {\n            int k1 = k + tid / CHANNEL_STRIDE;\n            int h2 = static_cast<int>(floor(y2s[k1]))-r+iy;\n            int w2 = static_cast<int>(floor(x2s[k1]))-r+ix;\n            int c2 = tid % CHANNEL_STRIDE;\n\n            auto fptr = fmap2[b][h2][w2];\n            if (within_bounds(h2, w2, H2, W2))\n              f2[c2][k1] = fptr[c+c2];\n            else\n              f2[c2][k1] = 0.0;\n\n            f2_grad[c2][k1] = 0.0;\n          }\n\n          __syncthreads();\n      \n          const scalar_t* grad_ptr = &corr_grad[b][n][0][h1][w1];\n          scalar_t g = 0.0;\n\n          int ix_nw = H1*W1*((iy-1) + rd*(ix-1));\n          int ix_ne = H1*W1*((iy-1) + rd*ix);\n          int ix_sw = H1*W1*(iy + rd*(ix-1));\n          int ix_se = H1*W1*(iy + rd*ix);\n\n          if (iy > 0 && ix > 0 && within_bounds(h1, w1, H1, W1))\n            g +=  *(grad_ptr + ix_nw) * dy * dx;\n\n          if (iy > 0 && ix < rd && within_bounds(h1, w1, H1, W1))\n            g += *(grad_ptr + ix_ne) * dy * (1-dx);\n\n          if (iy < rd && ix > 
0 && within_bounds(h1, w1, H1, W1))\n            g += *(grad_ptr + ix_sw) * (1-dy) * dx;\n\n          if (iy < rd && ix < rd && within_bounds(h1, w1, H1, W1))\n            g += *(grad_ptr + ix_se) * (1-dy) * (1-dx);\n            \n          for (int k=0; k<CHANNEL_STRIDE; k++) {\n            f1_grad[k][tid] += g * f2[k][tid];\n            f2_grad[k][tid] += g * f1[k][tid];\n          }\n\n          for (int k=0; k<BLOCK_HW; k+=BLOCK_HW/CHANNEL_STRIDE) {\n            int k1 = k + tid / CHANNEL_STRIDE;\n            int h2 = static_cast<int>(floor(y2s[k1]))-r+iy;\n            int w2 = static_cast<int>(floor(x2s[k1]))-r+ix;\n            int c2 = tid % CHANNEL_STRIDE;\n\n            scalar_t* fptr = &fmap2_grad[b][h2][w2][0];\n            if (within_bounds(h2, w2, H2, W2))\n              atomicAdd(fptr+c+c2, f2_grad[c2][k1]);\n          }\n        }\n      } \n    }\n    __syncthreads();\n\n\n    for (int k=0; k<BLOCK_HW; k+=BLOCK_HW/CHANNEL_STRIDE) {\n      int k1 = k + tid / CHANNEL_STRIDE;\n      int h1 = h0 + k1 / BLOCK_W;\n      int w1 = w0 + k1 % BLOCK_W;\n      int c1 = tid % CHANNEL_STRIDE;\n\n      scalar_t* fptr = &fmap1_grad[b][h1][w1][0];\n      if (within_bounds(h1, w1, H1, W1))\n        fptr[c+c1] += f1_grad[c1][k1];\n    }\n  }\n}\n\n\n\nstd::vector<torch::Tensor> corr_cuda_forward(\n  torch::Tensor fmap1,\n  torch::Tensor fmap2,\n  torch::Tensor coords,\n  int radius)\n{\n  const auto B = coords.size(0);\n  const auto N = coords.size(1);\n  const auto H = coords.size(2);\n  const auto W = coords.size(3);\n\n  const auto rd = 2 * radius + 1;\n  auto opts = fmap1.options();\n  auto corr = torch::zeros({B, N, rd*rd, H, W}, opts);\n  \n  const dim3 blocks(B, (H+BLOCK_H-1)/BLOCK_H, (W+BLOCK_W-1)/BLOCK_W);\n  const dim3 threads(BLOCK_H, BLOCK_W);\n\n  corr_forward_kernel<float><<<blocks, threads>>>(\n    fmap1.packed_accessor32<float,4,torch::RestrictPtrTraits>(),\n    fmap2.packed_accessor32<float,4,torch::RestrictPtrTraits>(),\n    coords.packed_accessor32<float,5,torch::RestrictPtrTraits>(),\n    corr.packed_accessor32<float,5,torch::RestrictPtrTraits>(),\n    radius);\n\n  return {corr};\n}\n\nstd::vector<torch::Tensor> corr_cuda_backward(\n  torch::Tensor fmap1,\n  torch::Tensor fmap2,\n  torch::Tensor coords,\n  torch::Tensor corr_grad,\n  int radius)\n{\n  const auto B = coords.size(0);\n  const auto N = coords.size(1);\n\n  const auto H1 = fmap1.size(1);\n  const auto W1 = fmap1.size(2);\n  const auto H2 = fmap2.size(1);\n  const auto W2 = fmap2.size(2);\n  const auto C = fmap1.size(3);\n\n  auto opts = fmap1.options();\n  auto fmap1_grad = torch::zeros({B, H1, W1, C}, opts);\n  auto fmap2_grad = torch::zeros({B, H2, W2, C}, opts);\n  auto coords_grad = torch::zeros({B, N, H1, W1, 2}, opts);\n    \n  const dim3 blocks(B, (H1+BLOCK_H-1)/BLOCK_H, (W1+BLOCK_W-1)/BLOCK_W);\n  const dim3 threads(BLOCK_H, BLOCK_W);\n\n\n  corr_backward_kernel<float><<<blocks, threads>>>(\n    fmap1.packed_accessor32<float,4,torch::RestrictPtrTraits>(),\n    fmap2.packed_accessor32<float,4,torch::RestrictPtrTraits>(),\n    coords.packed_accessor32<float,5,torch::RestrictPtrTraits>(),\n    corr_grad.packed_accessor32<float,5,torch::RestrictPtrTraits>(),\n    fmap1_grad.packed_accessor32<float,4,torch::RestrictPtrTraits>(),\n    fmap2_grad.packed_accessor32<float,4,torch::RestrictPtrTraits>(),\n    coords_grad.packed_accessor32<float,5,torch::RestrictPtrTraits>(),\n    radius);\n\n  return {fmap1_grad, fmap2_grad, coords_grad};\n}"
  },
  {
    "path": "third_party/RAFT/alt_cuda_corr/setup.py",
    "content": "from setuptools import setup\nfrom torch.utils.cpp_extension import BuildExtension, CUDAExtension\n\n\nsetup(\n    name='correlation',\n    ext_modules=[\n        CUDAExtension('alt_cuda_corr',\n            sources=['correlation.cpp', 'correlation_kernel.cu'],\n            extra_compile_args={'cxx': [], 'nvcc': ['-O3']}),\n    ],\n    cmdclass={\n        'build_ext': BuildExtension\n    })\n\n"
  },
  {
    "path": "third_party/RAFT/chairs_split.txt",
    "content": "1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n2\n1\n1\n2\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n2\n1\n2\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n2\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n2\n1\n1\n1\n1\n2\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n2\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n2\n1\n1\n2\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\
n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\
n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n2\n1\n1\n2\n2\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n2\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n2\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n2\n1\n1\n1\n1\n2\n1\n1\n2\n1\n1\n1\n1\n2\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\
n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n2\n1\n1\n2\n1\n1\n2\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n2\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n2\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n2\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n2\n2\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n2\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n2\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\
n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n2\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n2\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n2\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n2\n1\n1\n2\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n2\n1\n1\n1\n1\
n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n2\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n2\n1\n1\n2\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n2\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n2\n1\n1\
n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n2\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n2\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\
n2\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n2\n2\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n2\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n2\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\
n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n2\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n2\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n2\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n2\n2\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n2\n2\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n2\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n2\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n2\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n2\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n2\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n2\n1\n2\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n2\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\
n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\
n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n2\n1\n1\n2\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\
n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\
n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\
n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\
n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\
n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n2\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\
n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\
n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\
n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n2\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\
n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n2\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1"
  },
  {
    "path": "third_party/RAFT/core/__init__.py",
    "content": ""
  },
  {
    "path": "third_party/RAFT/core/corr.py",
    "content": "import torch\nimport torch.nn.functional as F\nfrom .utils.utils import bilinear_sampler, coords_grid\n\ntry:\n    import alt_cuda_corr\nexcept:\n    # alt_cuda_corr is not compiled\n    pass\n\n\nclass CorrBlock:\n    def __init__(self, fmap1, fmap2, num_levels=4, radius=4):\n        self.num_levels = num_levels\n        self.radius = radius\n        self.corr_pyramid = []\n\n        # all pairs correlation\n        corr = CorrBlock.corr(fmap1, fmap2)\n\n        batch, h1, w1, dim, h2, w2 = corr.shape\n        corr = corr.reshape(batch*h1*w1, dim, h2, w2)\n        \n        self.corr_pyramid.append(corr)\n        for i in range(self.num_levels-1):\n            corr = F.avg_pool2d(corr, 2, stride=2)\n            self.corr_pyramid.append(corr)\n\n    def __call__(self, coords):\n        r = self.radius\n        coords = coords.permute(0, 2, 3, 1)\n        batch, h1, w1, _ = coords.shape\n\n        out_pyramid = []\n        for i in range(self.num_levels):\n            corr = self.corr_pyramid[i]\n            dx = torch.linspace(-r, r, 2*r+1, device=coords.device)\n            dy = torch.linspace(-r, r, 2*r+1, device=coords.device)\n            delta = torch.stack(torch.meshgrid(dy, dx), axis=-1)\n\n            centroid_lvl = coords.reshape(batch*h1*w1, 1, 1, 2) / 2**i\n            delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2)\n            coords_lvl = centroid_lvl + delta_lvl\n\n            corr = bilinear_sampler(corr, coords_lvl)\n            corr = corr.view(batch, h1, w1, -1)\n            out_pyramid.append(corr)\n\n        out = torch.cat(out_pyramid, dim=-1)\n        return out.permute(0, 3, 1, 2).contiguous().float()\n\n    @staticmethod\n    def corr(fmap1, fmap2):\n        batch, dim, ht, wd = fmap1.shape\n        fmap1 = fmap1.view(batch, dim, ht*wd)\n        fmap2 = fmap2.view(batch, dim, ht*wd) \n        \n        corr = torch.matmul(fmap1.transpose(1,2), fmap2)\n        corr = corr.view(batch, ht, wd, 1, ht, wd)\n        return corr  / torch.sqrt(torch.tensor(dim).float())\n\n\nclass AlternateCorrBlock:\n    def __init__(self, fmap1, fmap2, num_levels=4, radius=4):\n        self.num_levels = num_levels\n        self.radius = radius\n\n        self.pyramid = [(fmap1, fmap2)]\n        for i in range(self.num_levels):\n            fmap1 = F.avg_pool2d(fmap1, 2, stride=2)\n            fmap2 = F.avg_pool2d(fmap2, 2, stride=2)\n            self.pyramid.append((fmap1, fmap2))\n\n    def __call__(self, coords):\n        coords = coords.permute(0, 2, 3, 1)\n        B, H, W, _ = coords.shape\n        dim = self.pyramid[0][0].shape[1]\n\n        corr_list = []\n        for i in range(self.num_levels):\n            r = self.radius\n            fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1).contiguous()\n            fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1).contiguous()\n\n            coords_i = (coords / 2**i).reshape(B, 1, H, W, 2).contiguous()\n            corr, = alt_cuda_corr.forward(fmap1_i, fmap2_i, coords_i, r)\n            corr_list.append(corr.squeeze(1))\n\n        corr = torch.stack(corr_list, dim=1)\n        corr = corr.reshape(B, -1, H, W)\n        return corr / torch.sqrt(torch.tensor(dim).float())\n"
  },
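  {
    "path": "third_party/RAFT/examples/corr_block_example.py",
    "content": "# Hypothetical usage sketch for CorrBlock; this file is not part of upstream\n# RAFT, and the examples/ path is invented for illustration. It assumes the\n# script is run from third_party/RAFT so that `core` is an importable package.\nimport torch\n\nfrom core.corr import CorrBlock\nfrom core.utils.utils import coords_grid\n\n# Two random 256-channel feature maps at 1/8 resolution (batch=1, 46x62).\nfmap1 = torch.randn(1, 256, 46, 62)\nfmap2 = torch.randn(1, 256, 46, 62)\n\n# Build the 4-level all-pairs correlation pyramid once per image pair.\ncorr_fn = CorrBlock(fmap1, fmap2, num_levels=4, radius=4)\n\n# Index the pyramid at the identity coordinate grid (i.e. zero flow).\ncoords = coords_grid(1, 46, 62, device=fmap1.device)\ncorr = corr_fn(coords)\n\n# Each level contributes (2*4+1)**2 = 81 channels, so 4 levels give 324.\nprint(corr.shape)  # torch.Size([1, 324, 46, 62])\n"
  },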
  {
    "path": "third_party/RAFT/core/datasets.py",
    "content": "# Data loading based on https://github.com/NVIDIA/flownet2-pytorch\n\nimport numpy as np\nimport torch\nimport torch.utils.data as data\nimport torch.nn.functional as F\n\nimport os\nimport math\nimport random\nfrom glob import glob\nimport os.path as osp\n\nfrom utils import frame_utils\nfrom utils.augmentor import FlowAugmentor, SparseFlowAugmentor\n\n\nclass FlowDataset(data.Dataset):\n    def __init__(self, aug_params=None, sparse=False):\n        self.augmentor = None\n        self.sparse = sparse\n        if aug_params is not None:\n            if sparse:\n                self.augmentor = SparseFlowAugmentor(**aug_params)\n            else:\n                self.augmentor = FlowAugmentor(**aug_params)\n\n        self.is_test = False\n        self.init_seed = False\n        self.flow_list = []\n        self.image_list = []\n        self.extra_info = []\n\n    def __getitem__(self, index):\n\n        if self.is_test:\n            img1 = frame_utils.read_gen(self.image_list[index][0])\n            img2 = frame_utils.read_gen(self.image_list[index][1])\n            img1 = np.array(img1).astype(np.uint8)[..., :3]\n            img2 = np.array(img2).astype(np.uint8)[..., :3]\n            img1 = torch.from_numpy(img1).permute(2, 0, 1).float()\n            img2 = torch.from_numpy(img2).permute(2, 0, 1).float()\n            return img1, img2, self.extra_info[index]\n\n        if not self.init_seed:\n            worker_info = torch.utils.data.get_worker_info()\n            if worker_info is not None:\n                torch.manual_seed(worker_info.id)\n                np.random.seed(worker_info.id)\n                random.seed(worker_info.id)\n                self.init_seed = True\n\n        index = index % len(self.image_list)\n        valid = None\n        if self.sparse:\n            flow, valid = frame_utils.readFlowKITTI(self.flow_list[index])\n        else:\n            flow = frame_utils.read_gen(self.flow_list[index])\n\n        img1 = frame_utils.read_gen(self.image_list[index][0])\n        img2 = frame_utils.read_gen(self.image_list[index][1])\n\n        flow = np.array(flow).astype(np.float32)\n        img1 = np.array(img1).astype(np.uint8)\n        img2 = np.array(img2).astype(np.uint8)\n\n        # grayscale images\n        if len(img1.shape) == 2:\n            img1 = np.tile(img1[...,None], (1, 1, 3))\n            img2 = np.tile(img2[...,None], (1, 1, 3))\n        else:\n            img1 = img1[..., :3]\n            img2 = img2[..., :3]\n\n        if self.augmentor is not None:\n            if self.sparse:\n                img1, img2, flow, valid = self.augmentor(img1, img2, flow, valid)\n            else:\n                img1, img2, flow = self.augmentor(img1, img2, flow)\n\n        img1 = torch.from_numpy(img1).permute(2, 0, 1).float()\n        img2 = torch.from_numpy(img2).permute(2, 0, 1).float()\n        flow = torch.from_numpy(flow).permute(2, 0, 1).float()\n\n        if valid is not None:\n            valid = torch.from_numpy(valid)\n        else:\n            valid = (flow[0].abs() < 1000) & (flow[1].abs() < 1000)\n\n        return img1, img2, flow, valid.float()\n\n\n    def __rmul__(self, v):\n        self.flow_list = v * self.flow_list\n        self.image_list = v * self.image_list\n        return self\n        \n    def __len__(self):\n        return len(self.image_list)\n        \n\nclass MpiSintel(FlowDataset):\n    def __init__(self, aug_params=None, split='training', root='datasets/Sintel', dstype='clean'):\n        super(MpiSintel, 
self).__init__(aug_params)\n        flow_root = osp.join(root, split, 'flow')\n        image_root = osp.join(root, split, dstype)\n\n        if split == 'test':\n            self.is_test = True\n\n        for scene in os.listdir(image_root):\n            image_list = sorted(glob(osp.join(image_root, scene, '*.png')))\n            for i in range(len(image_list)-1):\n                self.image_list += [ [image_list[i], image_list[i+1]] ]\n                self.extra_info += [ (scene, i) ] # scene and frame_id\n\n            if split != 'test':\n                self.flow_list += sorted(glob(osp.join(flow_root, scene, '*.flo')))\n\n\nclass FlyingChairs(FlowDataset):\n    def __init__(self, aug_params=None, split='train', root='datasets/FlyingChairs_release/data'):\n        super(FlyingChairs, self).__init__(aug_params)\n\n        images = sorted(glob(osp.join(root, '*.ppm')))\n        flows = sorted(glob(osp.join(root, '*.flo')))\n        assert (len(images)//2 == len(flows))\n\n        split_list = np.loadtxt('chairs_split.txt', dtype=np.int32)\n        for i in range(len(flows)):\n            xid = split_list[i]\n            if (split=='training' and xid==1) or (split=='validation' and xid==2):\n                self.flow_list += [ flows[i] ]\n                self.image_list += [ [images[2*i], images[2*i+1]] ]\n\n\nclass FlyingThings3D(FlowDataset):\n    def __init__(self, aug_params=None, root='datasets/FlyingThings3D', dstype='frames_cleanpass'):\n        super(FlyingThings3D, self).__init__(aug_params)\n\n        for cam in ['left']:\n            for direction in ['into_future', 'into_past']:\n                image_dirs = sorted(glob(osp.join(root, dstype, 'TRAIN/*/*')))\n                image_dirs = sorted([osp.join(f, cam) for f in image_dirs])\n\n                flow_dirs = sorted(glob(osp.join(root, 'optical_flow/TRAIN/*/*')))\n                flow_dirs = sorted([osp.join(f, direction, cam) for f in flow_dirs])\n\n                for idir, fdir in zip(image_dirs, flow_dirs):\n                    images = sorted(glob(osp.join(idir, '*.png')) )\n                    flows = sorted(glob(osp.join(fdir, '*.pfm')) )\n                    for i in range(len(flows)-1):\n                        if direction == 'into_future':\n                            self.image_list += [ [images[i], images[i+1]] ]\n                            self.flow_list += [ flows[i] ]\n                        elif direction == 'into_past':\n                            self.image_list += [ [images[i+1], images[i]] ]\n                            self.flow_list += [ flows[i+1] ]\n      \n\nclass KITTI(FlowDataset):\n    def __init__(self, aug_params=None, split='training', root='datasets/KITTI'):\n        super(KITTI, self).__init__(aug_params, sparse=True)\n        if split == 'testing':\n            self.is_test = True\n\n        root = osp.join(root, split)\n        images1 = sorted(glob(osp.join(root, 'image_2/*_10.png')))\n        images2 = sorted(glob(osp.join(root, 'image_2/*_11.png')))\n\n        for img1, img2 in zip(images1, images2):\n            frame_id = img1.split('/')[-1]\n            self.extra_info += [ [frame_id] ]\n            self.image_list += [ [img1, img2] ]\n\n        if split == 'training':\n            self.flow_list = sorted(glob(osp.join(root, 'flow_occ/*_10.png')))\n\n\nclass HD1K(FlowDataset):\n    def __init__(self, aug_params=None, root='datasets/HD1k'):\n        super(HD1K, self).__init__(aug_params, sparse=True)\n\n        seq_ix = 0\n        while 1:\n            flows = 
sorted(glob(os.path.join(root, 'hd1k_flow_gt', 'flow_occ/%06d_*.png' % seq_ix)))\n            images = sorted(glob(os.path.join(root, 'hd1k_input', 'image_2/%06d_*.png' % seq_ix)))\n\n            if len(flows) == 0:\n                break\n\n            for i in range(len(flows)-1):\n                self.flow_list += [flows[i]]\n                self.image_list += [ [images[i], images[i+1]] ]\n\n            seq_ix += 1\n\n\ndef fetch_dataloader(args, TRAIN_DS='C+T+K+S+H'):\n    \"\"\" Create the data loader for the corresponding training set \"\"\"\n\n    if args.stage == 'chairs':\n        aug_params = {'crop_size': args.image_size, 'min_scale': -0.1, 'max_scale': 1.0, 'do_flip': True}\n        train_dataset = FlyingChairs(aug_params, split='training')\n    \n    elif args.stage == 'things':\n        aug_params = {'crop_size': args.image_size, 'min_scale': -0.4, 'max_scale': 0.8, 'do_flip': True}\n        clean_dataset = FlyingThings3D(aug_params, dstype='frames_cleanpass')\n        final_dataset = FlyingThings3D(aug_params, dstype='frames_finalpass')\n        train_dataset = clean_dataset + final_dataset\n\n    elif args.stage == 'sintel':\n        aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.6, 'do_flip': True}\n        things = FlyingThings3D(aug_params, dstype='frames_cleanpass')\n        sintel_clean = MpiSintel(aug_params, split='training', dstype='clean')\n        sintel_final = MpiSintel(aug_params, split='training', dstype='final')        \n\n        if TRAIN_DS == 'C+T+K+S+H':\n            kitti = KITTI({'crop_size': args.image_size, 'min_scale': -0.3, 'max_scale': 0.5, 'do_flip': True})\n            hd1k = HD1K({'crop_size': args.image_size, 'min_scale': -0.5, 'max_scale': 0.2, 'do_flip': True})\n            train_dataset = 100*sintel_clean + 100*sintel_final + 200*kitti + 5*hd1k + things\n\n        elif TRAIN_DS == 'C+T+K/S':\n            train_dataset = 100*sintel_clean + 100*sintel_final + things\n\n    elif args.stage == 'kitti':\n        aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.4, 'do_flip': False}\n        train_dataset = KITTI(aug_params, split='training')\n\n    train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size, \n        pin_memory=False, shuffle=True, num_workers=4, drop_last=True)\n\n    print('Training with %d image pairs' % len(train_dataset))\n    return train_loader\n\n"
  },
  {
    "path": "third_party/RAFT/core/extractor.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\nclass ResidualBlock(nn.Module):\n    def __init__(self, in_planes, planes, norm_fn='group', stride=1):\n        super(ResidualBlock, self).__init__()\n  \n        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride)\n        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1)\n        self.relu = nn.ReLU(inplace=True)\n\n        num_groups = planes // 8\n\n        if norm_fn == 'group':\n            self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)\n            self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)\n            if not stride == 1:\n                self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)\n        \n        elif norm_fn == 'batch':\n            self.norm1 = nn.BatchNorm2d(planes)\n            self.norm2 = nn.BatchNorm2d(planes)\n            if not stride == 1:\n                self.norm3 = nn.BatchNorm2d(planes)\n        \n        elif norm_fn == 'instance':\n            self.norm1 = nn.InstanceNorm2d(planes)\n            self.norm2 = nn.InstanceNorm2d(planes)\n            if not stride == 1:\n                self.norm3 = nn.InstanceNorm2d(planes)\n\n        elif norm_fn == 'none':\n            self.norm1 = nn.Sequential()\n            self.norm2 = nn.Sequential()\n            if not stride == 1:\n                self.norm3 = nn.Sequential()\n\n        if stride == 1:\n            self.downsample = None\n        \n        else:    \n            self.downsample = nn.Sequential(\n                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3)\n\n\n    def forward(self, x):\n        y = x\n        y = self.relu(self.norm1(self.conv1(y)))\n        y = self.relu(self.norm2(self.conv2(y)))\n\n        if self.downsample is not None:\n            x = self.downsample(x)\n\n        return self.relu(x+y)\n\n\n\nclass BottleneckBlock(nn.Module):\n    def __init__(self, in_planes, planes, norm_fn='group', stride=1):\n        super(BottleneckBlock, self).__init__()\n  \n        self.conv1 = nn.Conv2d(in_planes, planes//4, kernel_size=1, padding=0)\n        self.conv2 = nn.Conv2d(planes//4, planes//4, kernel_size=3, padding=1, stride=stride)\n        self.conv3 = nn.Conv2d(planes//4, planes, kernel_size=1, padding=0)\n        self.relu = nn.ReLU(inplace=True)\n\n        num_groups = planes // 8\n\n        if norm_fn == 'group':\n            self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4)\n            self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4)\n            self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)\n            if not stride == 1:\n                self.norm4 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)\n        \n        elif norm_fn == 'batch':\n            self.norm1 = nn.BatchNorm2d(planes//4)\n            self.norm2 = nn.BatchNorm2d(planes//4)\n            self.norm3 = nn.BatchNorm2d(planes)\n            if not stride == 1:\n                self.norm4 = nn.BatchNorm2d(planes)\n        \n        elif norm_fn == 'instance':\n            self.norm1 = nn.InstanceNorm2d(planes//4)\n            self.norm2 = nn.InstanceNorm2d(planes//4)\n            self.norm3 = nn.InstanceNorm2d(planes)\n            if not stride == 1:\n                self.norm4 = nn.InstanceNorm2d(planes)\n\n        elif norm_fn == 'none':\n            self.norm1 = nn.Sequential()\n            self.norm2 = 
nn.Sequential()\n            self.norm3 = nn.Sequential()\n            if not stride == 1:\n                self.norm4 = nn.Sequential()\n\n        if stride == 1:\n            self.downsample = None\n        \n        else:    \n            self.downsample = nn.Sequential(\n                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4)\n\n\n    def forward(self, x):\n        y = x\n        y = self.relu(self.norm1(self.conv1(y)))\n        y = self.relu(self.norm2(self.conv2(y)))\n        y = self.relu(self.norm3(self.conv3(y)))\n\n        if self.downsample is not None:\n            x = self.downsample(x)\n\n        return self.relu(x+y)\n\nclass BasicEncoder(nn.Module):\n    def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):\n        super(BasicEncoder, self).__init__()\n        self.norm_fn = norm_fn\n\n        if self.norm_fn == 'group':\n            self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64)\n            \n        elif self.norm_fn == 'batch':\n            self.norm1 = nn.BatchNorm2d(64)\n\n        elif self.norm_fn == 'instance':\n            self.norm1 = nn.InstanceNorm2d(64)\n\n        elif self.norm_fn == 'none':\n            self.norm1 = nn.Sequential()\n\n        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)\n        self.relu1 = nn.ReLU(inplace=True)\n\n        self.in_planes = 64\n        self.layer1 = self._make_layer(64,  stride=1)\n        self.layer2 = self._make_layer(96, stride=2)\n        self.layer3 = self._make_layer(128, stride=2)\n\n        # output convolution\n        self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1)\n\n        self.dropout = None\n        if dropout > 0:\n            self.dropout = nn.Dropout2d(p=dropout)\n\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')\n            elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):\n                if m.weight is not None:\n                    nn.init.constant_(m.weight, 1)\n                if m.bias is not None:\n                    nn.init.constant_(m.bias, 0)\n\n    def _make_layer(self, dim, stride=1):\n        layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)\n        layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)\n        layers = (layer1, layer2)\n        \n        self.in_planes = dim\n        return nn.Sequential(*layers)\n\n\n    def forward(self, x):\n\n        # if input is list, combine batch dimension\n        is_list = isinstance(x, tuple) or isinstance(x, list)\n        if is_list:\n            batch_dim = x[0].shape[0]\n            x = torch.cat(x, dim=0)\n\n        x = self.conv1(x)\n        x = self.norm1(x)\n        x = self.relu1(x)\n\n        x = self.layer1(x)\n        x = self.layer2(x)\n        x = self.layer3(x)\n\n        x = self.conv2(x)\n\n        if self.training and self.dropout is not None:\n            x = self.dropout(x)\n\n        if is_list:\n            x = torch.split(x, [batch_dim, batch_dim], dim=0)\n\n        return x\n\n\nclass SmallEncoder(nn.Module):\n    def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):\n        super(SmallEncoder, self).__init__()\n        self.norm_fn = norm_fn\n\n        if self.norm_fn == 'group':\n            self.norm1 = nn.GroupNorm(num_groups=8, num_channels=32)\n            \n        elif self.norm_fn == 'batch':\n            self.norm1 = nn.BatchNorm2d(32)\n\n        elif 
self.norm_fn == 'instance':\n            self.norm1 = nn.InstanceNorm2d(32)\n\n        elif self.norm_fn == 'none':\n            self.norm1 = nn.Sequential()\n\n        self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3)\n        self.relu1 = nn.ReLU(inplace=True)\n\n        self.in_planes = 32\n        self.layer1 = self._make_layer(32,  stride=1)\n        self.layer2 = self._make_layer(64, stride=2)\n        self.layer3 = self._make_layer(96, stride=2)\n\n        self.dropout = None\n        if dropout > 0:\n            self.dropout = nn.Dropout2d(p=dropout)\n        \n        self.conv2 = nn.Conv2d(96, output_dim, kernel_size=1)\n\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')\n            elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):\n                if m.weight is not None:\n                    nn.init.constant_(m.weight, 1)\n                if m.bias is not None:\n                    nn.init.constant_(m.bias, 0)\n\n    def _make_layer(self, dim, stride=1):\n        layer1 = BottleneckBlock(self.in_planes, dim, self.norm_fn, stride=stride)\n        layer2 = BottleneckBlock(dim, dim, self.norm_fn, stride=1)\n        layers = (layer1, layer2)\n    \n        self.in_planes = dim\n        return nn.Sequential(*layers)\n\n\n    def forward(self, x):\n\n        # if input is list, combine batch dimension\n        is_list = isinstance(x, tuple) or isinstance(x, list)\n        if is_list:\n            batch_dim = x[0].shape[0]\n            x = torch.cat(x, dim=0)\n\n        x = self.conv1(x)\n        x = self.norm1(x)\n        x = self.relu1(x)\n\n        x = self.layer1(x)\n        x = self.layer2(x)\n        x = self.layer3(x)\n        x = self.conv2(x)\n\n        if self.training and self.dropout is not None:\n            x = self.dropout(x)\n\n        if is_list:\n            x = torch.split(x, [batch_dim, batch_dim], dim=0)\n\n        return x\n"
  },
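  {
    "path": "third_party/RAFT/examples/encoder_example.py",
    "content": "# Hypothetical usage sketch for BasicEncoder; this file is not part of\n# upstream RAFT, and the examples/ path is invented for illustration.\nimport torch\n\nfrom core.extractor import BasicEncoder\n\n# The feature network maps RGB to 256-d features at 1/8 resolution (one\n# stride-2 stem conv plus two stride-2 residual stages).\nfnet = BasicEncoder(output_dim=256, norm_fn='instance', dropout=0.0)\nfnet.eval()\n\nimage = torch.randn(1, 3, 368, 496)\nwith torch.no_grad():\n    fmap = fnet(image)\nprint(fmap.shape)  # torch.Size([1, 256, 46, 62])\n\n# Passing a list concatenates along the batch dim and splits the result,\n# which is how RAFT encodes both frames in one pass.\nwith torch.no_grad():\n    fmap1, fmap2 = fnet([image, image])\nprint(fmap1.shape, fmap2.shape)\n"
  },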
  {
    "path": "third_party/RAFT/core/raft.py",
    "content": "import numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom .update import BasicUpdateBlock, SmallUpdateBlock\nfrom .extractor import BasicEncoder, SmallEncoder\nfrom .corr import CorrBlock, AlternateCorrBlock\nfrom .utils.utils import bilinear_sampler, coords_grid, upflow8\n\ntry:\n    autocast = torch.cuda.amp.autocast\nexcept:\n    # dummy autocast for PyTorch < 1.6\n    class autocast:\n        def __init__(self, enabled):\n            pass\n        def __enter__(self):\n            pass\n        def __exit__(self, *args):\n            pass\n\n\nclass RAFT(nn.Module):\n    def __init__(self, args):\n        super(RAFT, self).__init__()\n        self.args = args\n\n        if args.small:\n            self.hidden_dim = hdim = 96\n            self.context_dim = cdim = 64\n            args.corr_levels = 4\n            args.corr_radius = 3\n        \n        else:\n            self.hidden_dim = hdim = 128\n            self.context_dim = cdim = 128\n            args.corr_levels = 4\n            args.corr_radius = 4\n\n        # if 'dropout' not in self.args:\n        self.args.dropout = 0\n\n        # if 'alternate_corr' not in self.args:\n        self.args.alternate_corr = False\n\n        # feature network, context network, and update block\n        if args.small:\n            self.fnet = SmallEncoder(output_dim=128, norm_fn='instance', dropout=args.dropout)        \n            self.cnet = SmallEncoder(output_dim=hdim+cdim, norm_fn='none', dropout=args.dropout)\n            self.update_block = SmallUpdateBlock(self.args, hidden_dim=hdim)\n\n        else:\n            self.fnet = BasicEncoder(output_dim=256, norm_fn='instance', dropout=args.dropout)        \n            self.cnet = BasicEncoder(output_dim=hdim+cdim, norm_fn='batch', dropout=args.dropout)\n            self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim)\n\n    def freeze_bn(self):\n        for m in self.modules():\n            if isinstance(m, nn.BatchNorm2d):\n                m.eval()\n\n    def initialize_flow(self, img):\n        \"\"\" Flow is represented as difference between two coordinate grids flow = coords1 - coords0\"\"\"\n        N, C, H, W = img.shape\n        coords0 = coords_grid(N, H//8, W//8, device=img.device)\n        coords1 = coords_grid(N, H//8, W//8, device=img.device)\n\n        # optical flow computed as difference: flow = coords1 - coords0\n        return coords0, coords1\n\n    def upsample_flow(self, flow, mask):\n        \"\"\" Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination \"\"\"\n        N, _, H, W = flow.shape\n        mask = mask.view(N, 1, 9, 8, 8, H, W)\n        mask = torch.softmax(mask, dim=2)\n\n        up_flow = F.unfold(8 * flow, [3,3], padding=1)\n        up_flow = up_flow.view(N, 2, 9, 1, 1, H, W)\n\n        up_flow = torch.sum(mask * up_flow, dim=2)\n        up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)\n        return up_flow.reshape(N, 2, 8*H, 8*W)\n\n\n    def forward(self, image1, image2, iters=12, flow_init=None, upsample=True, test_mode=False):\n        \"\"\" Estimate optical flow between pair of frames \"\"\"\n\n        image1 = 2 * (image1 / 255.0) - 1.0\n        image2 = 2 * (image2 / 255.0) - 1.0\n\n        image1 = image1.contiguous()\n        image2 = image2.contiguous()\n\n        hdim = self.hidden_dim\n        cdim = self.context_dim\n\n        # run the feature network\n        with autocast(enabled=self.args.mixed_precision):\n            fmap1, fmap2 = self.fnet([image1, 
image2])        \n        \n        fmap1 = fmap1.float()\n        fmap2 = fmap2.float()\n        if self.args.alternate_corr:\n            corr_fn = AlternateCorrBlock(fmap1, fmap2, radius=self.args.corr_radius)\n        else:\n            corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius)\n\n        # run the context network\n        with autocast(enabled=self.args.mixed_precision):\n            cnet = self.cnet(image1)\n            net, inp = torch.split(cnet, [hdim, cdim], dim=1)\n            net = torch.tanh(net)\n            inp = torch.relu(inp)\n\n        coords0, coords1 = self.initialize_flow(image1)\n\n        if flow_init is not None:\n            coords1 = coords1 + flow_init\n\n        flow_predictions = []\n        for itr in range(iters):\n            coords1 = coords1.detach()\n            corr = corr_fn(coords1) # index correlation volume\n\n            flow = coords1 - coords0\n            with autocast(enabled=self.args.mixed_precision):\n                net, up_mask, delta_flow = self.update_block(net, inp, corr, flow)\n\n            # F(t+1) = F(t) + \\Delta(t)\n            coords1 = coords1 + delta_flow\n\n            # upsample predictions\n            if up_mask is None:\n                flow_up = upflow8(coords1 - coords0)\n            else:\n                flow_up = self.upsample_flow(coords1 - coords0, up_mask)\n            \n            flow_predictions.append(flow_up)\n\n        if test_mode:\n            return coords1 - coords0, flow_up\n            \n        return flow_predictions\n"
  },
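  {
    "path": "third_party/RAFT/examples/raft_forward_example.py",
    "content": "# Hypothetical end-to-end sketch for the RAFT module; this file is not part\n# of upstream RAFT, and the examples/ path is invented for illustration. It\n# assumes the script is run from third_party/RAFT so that `core` (which uses\n# relative imports internally) is importable as a package.\nfrom argparse import Namespace\n\nimport torch\n\nfrom core.raft import RAFT\n\n# RAFT reads its configuration off an args object; these three flags are the\n# ones demo.py exposes on the command line.\nargs = Namespace(small=False, mixed_precision=False, alternate_corr=False)\nmodel = RAFT(args)\nmodel.eval()\n\n# Inputs are expected in [0, 255] with H and W divisible by 8 (see InputPadder).\nimage1 = torch.randint(0, 256, (1, 3, 368, 496)).float()\nimage2 = torch.randint(0, 256, (1, 3, 368, 496)).float()\n\nwith torch.no_grad():\n    # test_mode returns the final 1/8-resolution flow and its upsampled version;\n    # with mixed_precision disabled this should also run on CPU.\n    flow_low, flow_up = model(image1, image2, iters=12, test_mode=True)\n\nprint(flow_low.shape)  # torch.Size([1, 2, 46, 62])\nprint(flow_up.shape)   # torch.Size([1, 2, 368, 496])\n"
  },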
  {
    "path": "third_party/RAFT/core/update.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\nclass FlowHead(nn.Module):\n    def __init__(self, input_dim=128, hidden_dim=256):\n        super(FlowHead, self).__init__()\n        self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1)\n        self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1)\n        self.relu = nn.ReLU(inplace=True)\n\n    def forward(self, x):\n        return self.conv2(self.relu(self.conv1(x)))\n\nclass ConvGRU(nn.Module):\n    def __init__(self, hidden_dim=128, input_dim=192+128):\n        super(ConvGRU, self).__init__()\n        self.convz = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)\n        self.convr = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)\n        self.convq = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)\n\n    def forward(self, h, x):\n        hx = torch.cat([h, x], dim=1)\n\n        z = torch.sigmoid(self.convz(hx))\n        r = torch.sigmoid(self.convr(hx))\n        q = torch.tanh(self.convq(torch.cat([r*h, x], dim=1)))\n\n        h = (1-z) * h + z * q\n        return h\n\nclass SepConvGRU(nn.Module):\n    def __init__(self, hidden_dim=128, input_dim=192+128):\n        super(SepConvGRU, self).__init__()\n        self.convz1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))\n        self.convr1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))\n        self.convq1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))\n\n        self.convz2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))\n        self.convr2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))\n        self.convq2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))\n\n\n    def forward(self, h, x):\n        # horizontal\n        hx = torch.cat([h, x], dim=1)\n        z = torch.sigmoid(self.convz1(hx))\n        r = torch.sigmoid(self.convr1(hx))\n        q = torch.tanh(self.convq1(torch.cat([r*h, x], dim=1)))        \n        h = (1-z) * h + z * q\n\n        # vertical\n        hx = torch.cat([h, x], dim=1)\n        z = torch.sigmoid(self.convz2(hx))\n        r = torch.sigmoid(self.convr2(hx))\n        q = torch.tanh(self.convq2(torch.cat([r*h, x], dim=1)))       \n        h = (1-z) * h + z * q\n\n        return h\n\nclass SmallMotionEncoder(nn.Module):\n    def __init__(self, args):\n        super(SmallMotionEncoder, self).__init__()\n        cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2\n        self.convc1 = nn.Conv2d(cor_planes, 96, 1, padding=0)\n        self.convf1 = nn.Conv2d(2, 64, 7, padding=3)\n        self.convf2 = nn.Conv2d(64, 32, 3, padding=1)\n        self.conv = nn.Conv2d(128, 80, 3, padding=1)\n\n    def forward(self, flow, corr):\n        cor = F.relu(self.convc1(corr))\n        flo = F.relu(self.convf1(flow))\n        flo = F.relu(self.convf2(flo))\n        cor_flo = torch.cat([cor, flo], dim=1)\n        out = F.relu(self.conv(cor_flo))\n        return torch.cat([out, flow], dim=1)\n\nclass BasicMotionEncoder(nn.Module):\n    def __init__(self, args):\n        super(BasicMotionEncoder, self).__init__()\n        cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2\n        self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0)\n        self.convc2 = nn.Conv2d(256, 192, 3, padding=1)\n        self.convf1 = nn.Conv2d(2, 128, 7, padding=3)\n        self.convf2 = nn.Conv2d(128, 64, 3, padding=1)\n        self.conv = nn.Conv2d(64+192, 128-2, 3, padding=1)\n\n    def 
forward(self, flow, corr):\n        cor = F.relu(self.convc1(corr))\n        cor = F.relu(self.convc2(cor))\n        flo = F.relu(self.convf1(flow))\n        flo = F.relu(self.convf2(flo))\n\n        cor_flo = torch.cat([cor, flo], dim=1)\n        out = F.relu(self.conv(cor_flo))\n        return torch.cat([out, flow], dim=1)\n\nclass SmallUpdateBlock(nn.Module):\n    def __init__(self, args, hidden_dim=96):\n        super(SmallUpdateBlock, self).__init__()\n        self.encoder = SmallMotionEncoder(args)\n        self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=82+64)\n        self.flow_head = FlowHead(hidden_dim, hidden_dim=128)\n\n    def forward(self, net, inp, corr, flow):\n        motion_features = self.encoder(flow, corr)\n        inp = torch.cat([inp, motion_features], dim=1)\n        net = self.gru(net, inp)\n        delta_flow = self.flow_head(net)\n\n        return net, None, delta_flow\n\nclass BasicUpdateBlock(nn.Module):\n    def __init__(self, args, hidden_dim=128, input_dim=128):\n        super(BasicUpdateBlock, self).__init__()\n        self.args = args\n        self.encoder = BasicMotionEncoder(args)\n        self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128+hidden_dim)\n        self.flow_head = FlowHead(hidden_dim, hidden_dim=256)\n\n        self.mask = nn.Sequential(\n            nn.Conv2d(128, 256, 3, padding=1),\n            nn.ReLU(inplace=True),\n            nn.Conv2d(256, 64*9, 1, padding=0))\n\n    def forward(self, net, inp, corr, flow, upsample=True):\n        motion_features = self.encoder(flow, corr)\n        inp = torch.cat([inp, motion_features], dim=1)\n\n        net = self.gru(net, inp)\n        delta_flow = self.flow_head(net)\n\n        # scale mask to balance gradients\n        mask = .25 * self.mask(net)\n        return net, mask, delta_flow\n\n\n\n"
  },
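  {
    "path": "third_party/RAFT/examples/update_block_example.py",
    "content": "# Hypothetical usage sketch for BasicUpdateBlock; this file is not part of\n# upstream RAFT, and the examples/ path is invented for illustration.\nfrom argparse import Namespace\n\nimport torch\n\nfrom core.update import BasicUpdateBlock\n\n# corr_levels/corr_radius determine the correlation feature size:\n# 4 * (2*4+1)**2 = 324 channels.\nargs = Namespace(corr_levels=4, corr_radius=4)\nblock = BasicUpdateBlock(args, hidden_dim=128)\n\nnet = torch.randn(1, 128, 46, 62)   # GRU hidden state (tanh half of cnet)\ninp = torch.randn(1, 128, 46, 62)   # context features (relu half of cnet)\ncorr = torch.randn(1, 324, 46, 62)  # sampled correlation volume\nflow = torch.zeros(1, 2, 46, 62)    # current flow estimate\n\nwith torch.no_grad():\n    net, up_mask, delta_flow = block(net, inp, corr, flow)\nprint(up_mask.shape, delta_flow.shape)  # 1x576x46x62 and 1x2x46x62\n"
  },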
  {
    "path": "third_party/RAFT/core/utils/__init__.py",
    "content": ""
  },
  {
    "path": "third_party/RAFT/core/utils/augmentor.py",
    "content": "import numpy as np\nimport random\nimport math\nfrom PIL import Image\n\nimport cv2\ncv2.setNumThreads(0)\ncv2.ocl.setUseOpenCL(False)\n\nimport torch\nfrom torchvision.transforms import ColorJitter\nimport torch.nn.functional as F\n\n\nclass FlowAugmentor:\n    def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=True):\n        \n        # spatial augmentation params\n        self.crop_size = crop_size\n        self.min_scale = min_scale\n        self.max_scale = max_scale\n        self.spatial_aug_prob = 0.8\n        self.stretch_prob = 0.8\n        self.max_stretch = 0.2\n\n        # flip augmentation params\n        self.do_flip = do_flip\n        self.h_flip_prob = 0.5\n        self.v_flip_prob = 0.1\n\n        # photometric augmentation params\n        self.photo_aug = ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.5/3.14)\n        self.asymmetric_color_aug_prob = 0.2\n        self.eraser_aug_prob = 0.5\n\n    def color_transform(self, img1, img2):\n        \"\"\" Photometric augmentation \"\"\"\n\n        # asymmetric\n        if np.random.rand() < self.asymmetric_color_aug_prob:\n            img1 = np.array(self.photo_aug(Image.fromarray(img1)), dtype=np.uint8)\n            img2 = np.array(self.photo_aug(Image.fromarray(img2)), dtype=np.uint8)\n\n        # symmetric\n        else:\n            image_stack = np.concatenate([img1, img2], axis=0)\n            image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8)\n            img1, img2 = np.split(image_stack, 2, axis=0)\n\n        return img1, img2\n\n    def eraser_transform(self, img1, img2, bounds=[50, 100]):\n        \"\"\" Occlusion augmentation \"\"\"\n\n        ht, wd = img1.shape[:2]\n        if np.random.rand() < self.eraser_aug_prob:\n            mean_color = np.mean(img2.reshape(-1, 3), axis=0)\n            for _ in range(np.random.randint(1, 3)):\n                x0 = np.random.randint(0, wd)\n                y0 = np.random.randint(0, ht)\n                dx = np.random.randint(bounds[0], bounds[1])\n                dy = np.random.randint(bounds[0], bounds[1])\n                img2[y0:y0+dy, x0:x0+dx, :] = mean_color\n\n        return img1, img2\n\n    def spatial_transform(self, img1, img2, flow):\n        # randomly sample scale\n        ht, wd = img1.shape[:2]\n        min_scale = np.maximum(\n            (self.crop_size[0] + 8) / float(ht), \n            (self.crop_size[1] + 8) / float(wd))\n\n        scale = 2 ** np.random.uniform(self.min_scale, self.max_scale)\n        scale_x = scale\n        scale_y = scale\n        if np.random.rand() < self.stretch_prob:\n            scale_x *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch)\n            scale_y *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch)\n        \n        scale_x = np.clip(scale_x, min_scale, None)\n        scale_y = np.clip(scale_y, min_scale, None)\n\n        if np.random.rand() < self.spatial_aug_prob:\n            # rescale the images\n            img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)\n            img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)\n            flow = cv2.resize(flow, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)\n            flow = flow * [scale_x, scale_y]\n\n        if self.do_flip:\n            if np.random.rand() < self.h_flip_prob: # h-flip\n                img1 = img1[:, ::-1]\n                img2 = img2[:, ::-1]\n            
    flow = flow[:, ::-1] * [-1.0, 1.0]\n\n            if np.random.rand() < self.v_flip_prob: # v-flip\n                img1 = img1[::-1, :]\n                img2 = img2[::-1, :]\n                flow = flow[::-1, :] * [1.0, -1.0]\n\n        y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0])\n        x0 = np.random.randint(0, img1.shape[1] - self.crop_size[1])\n        \n        img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]\n        img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]\n        flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]\n\n        return img1, img2, flow\n\n    def __call__(self, img1, img2, flow):\n        img1, img2 = self.color_transform(img1, img2)\n        img1, img2 = self.eraser_transform(img1, img2)\n        img1, img2, flow = self.spatial_transform(img1, img2, flow)\n\n        img1 = np.ascontiguousarray(img1)\n        img2 = np.ascontiguousarray(img2)\n        flow = np.ascontiguousarray(flow)\n\n        return img1, img2, flow\n\nclass SparseFlowAugmentor:\n    def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=False):\n        # spatial augmentation params\n        self.crop_size = crop_size\n        self.min_scale = min_scale\n        self.max_scale = max_scale\n        self.spatial_aug_prob = 0.8\n        self.stretch_prob = 0.8\n        self.max_stretch = 0.2\n\n        # flip augmentation params\n        self.do_flip = do_flip\n        self.h_flip_prob = 0.5\n        self.v_flip_prob = 0.1\n\n        # photometric augmentation params\n        self.photo_aug = ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3/3.14)\n        self.asymmetric_color_aug_prob = 0.2\n        self.eraser_aug_prob = 0.5\n        \n    def color_transform(self, img1, img2):\n        image_stack = np.concatenate([img1, img2], axis=0)\n        image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8)\n        img1, img2 = np.split(image_stack, 2, axis=0)\n        return img1, img2\n\n    def eraser_transform(self, img1, img2):\n        ht, wd = img1.shape[:2]\n        if np.random.rand() < self.eraser_aug_prob:\n            mean_color = np.mean(img2.reshape(-1, 3), axis=0)\n            for _ in range(np.random.randint(1, 3)):\n                x0 = np.random.randint(0, wd)\n                y0 = np.random.randint(0, ht)\n                dx = np.random.randint(50, 100)\n                dy = np.random.randint(50, 100)\n                img2[y0:y0+dy, x0:x0+dx, :] = mean_color\n\n        return img1, img2\n\n    def resize_sparse_flow_map(self, flow, valid, fx=1.0, fy=1.0):\n        ht, wd = flow.shape[:2]\n        coords = np.meshgrid(np.arange(wd), np.arange(ht))\n        coords = np.stack(coords, axis=-1)\n\n        coords = coords.reshape(-1, 2).astype(np.float32)\n        flow = flow.reshape(-1, 2).astype(np.float32)\n        valid = valid.reshape(-1).astype(np.float32)\n\n        coords0 = coords[valid>=1]\n        flow0 = flow[valid>=1]\n\n        ht1 = int(round(ht * fy))\n        wd1 = int(round(wd * fx))\n\n        coords1 = coords0 * [fx, fy]\n        flow1 = flow0 * [fx, fy]\n\n        xx = np.round(coords1[:,0]).astype(np.int32)\n        yy = np.round(coords1[:,1]).astype(np.int32)\n\n        v = (xx > 0) & (xx < wd1) & (yy > 0) & (yy < ht1)\n        xx = xx[v]\n        yy = yy[v]\n        flow1 = flow1[v]\n\n        flow_img = np.zeros([ht1, wd1, 2], dtype=np.float32)\n        valid_img = np.zeros([ht1, wd1], dtype=np.int32)\n\n        flow_img[yy, xx] = flow1\n     
   valid_img[yy, xx] = 1\n\n        return flow_img, valid_img\n\n    def spatial_transform(self, img1, img2, flow, valid):\n        # randomly sample scale\n\n        ht, wd = img1.shape[:2]\n        min_scale = np.maximum(\n            (self.crop_size[0] + 1) / float(ht), \n            (self.crop_size[1] + 1) / float(wd))\n\n        scale = 2 ** np.random.uniform(self.min_scale, self.max_scale)\n        scale_x = np.clip(scale, min_scale, None)\n        scale_y = np.clip(scale, min_scale, None)\n\n        if np.random.rand() < self.spatial_aug_prob:\n            # rescale the images\n            img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)\n            img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)\n            flow, valid = self.resize_sparse_flow_map(flow, valid, fx=scale_x, fy=scale_y)\n\n        if self.do_flip:\n            if np.random.rand() < 0.5: # h-flip\n                img1 = img1[:, ::-1]\n                img2 = img2[:, ::-1]\n                flow = flow[:, ::-1] * [-1.0, 1.0]\n                valid = valid[:, ::-1]\n\n        margin_y = 20\n        margin_x = 50\n\n        y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0] + margin_y)\n        x0 = np.random.randint(-margin_x, img1.shape[1] - self.crop_size[1] + margin_x)\n\n        y0 = np.clip(y0, 0, img1.shape[0] - self.crop_size[0])\n        x0 = np.clip(x0, 0, img1.shape[1] - self.crop_size[1])\n\n        img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]\n        img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]\n        flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]\n        valid = valid[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]\n        return img1, img2, flow, valid\n\n\n    def __call__(self, img1, img2, flow, valid):\n        img1, img2 = self.color_transform(img1, img2)\n        img1, img2 = self.eraser_transform(img1, img2)\n        img1, img2, flow, valid = self.spatial_transform(img1, img2, flow, valid)\n\n        img1 = np.ascontiguousarray(img1)\n        img2 = np.ascontiguousarray(img2)\n        flow = np.ascontiguousarray(flow)\n        valid = np.ascontiguousarray(valid)\n\n        return img1, img2, flow, valid\n"
  },
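  {
    "path": "third_party/RAFT/examples/augmentor_example.py",
    "content": "# Hypothetical usage sketch for FlowAugmentor; this file is not part of\n# upstream RAFT, and the examples/ path is invented for illustration.\nimport numpy as np\n\nfrom core.utils.augmentor import FlowAugmentor\n\n# Random uint8 image pair and float32 flow at Sintel resolution (436x1024).\nimg1 = np.random.randint(0, 255, (436, 1024, 3), dtype=np.uint8)\nimg2 = np.random.randint(0, 255, (436, 1024, 3), dtype=np.uint8)\nflow = np.random.randn(436, 1024, 2).astype(np.float32)\n\n# Photometric jitter, random occlusion erasing, then scale/flip/crop; the\n# flow values are rescaled and sign-flipped consistently with the images.\naug = FlowAugmentor(crop_size=(368, 768), min_scale=-0.2, max_scale=0.6)\nimg1, img2, flow = aug(img1, img2, flow)\nprint(img1.shape, img2.shape, flow.shape)  # (368, 768, 3) x2 and (368, 768, 2)\n"
  },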
  {
    "path": "third_party/RAFT/core/utils/flow_viz.py",
    "content": "# Flow visualization code used from https://github.com/tomrunia/OpticalFlow_Visualization\n\n\n# MIT License\n#\n# Copyright (c) 2018 Tom Runia\n#\n# Permission is hereby granted, free of charge, to any person obtaining a copy\n# of this software and associated documentation files (the \"Software\"), to deal\n# in the Software without restriction, including without limitation the rights\n# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n# copies of the Software, and to permit persons to whom the Software is\n# furnished to do so, subject to conditions.\n#\n# Author: Tom Runia\n# Date Created: 2018-08-03\n\nimport numpy as np\n\ndef make_colorwheel():\n    \"\"\"\n    Generates a color wheel for optical flow visualization as presented in:\n        Baker et al. \"A Database and Evaluation Methodology for Optical Flow\" (ICCV, 2007)\n        URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf\n\n    Code follows the original C++ source code of Daniel Scharstein.\n    Code follows the the Matlab source code of Deqing Sun.\n\n    Returns:\n        np.ndarray: Color wheel\n    \"\"\"\n\n    RY = 15\n    YG = 6\n    GC = 4\n    CB = 11\n    BM = 13\n    MR = 6\n\n    ncols = RY + YG + GC + CB + BM + MR\n    colorwheel = np.zeros((ncols, 3))\n    col = 0\n\n    # RY\n    colorwheel[0:RY, 0] = 255\n    colorwheel[0:RY, 1] = np.floor(255*np.arange(0,RY)/RY)\n    col = col+RY\n    # YG\n    colorwheel[col:col+YG, 0] = 255 - np.floor(255*np.arange(0,YG)/YG)\n    colorwheel[col:col+YG, 1] = 255\n    col = col+YG\n    # GC\n    colorwheel[col:col+GC, 1] = 255\n    colorwheel[col:col+GC, 2] = np.floor(255*np.arange(0,GC)/GC)\n    col = col+GC\n    # CB\n    colorwheel[col:col+CB, 1] = 255 - np.floor(255*np.arange(CB)/CB)\n    colorwheel[col:col+CB, 2] = 255\n    col = col+CB\n    # BM\n    colorwheel[col:col+BM, 2] = 255\n    colorwheel[col:col+BM, 0] = np.floor(255*np.arange(0,BM)/BM)\n    col = col+BM\n    # MR\n    colorwheel[col:col+MR, 2] = 255 - np.floor(255*np.arange(MR)/MR)\n    colorwheel[col:col+MR, 0] = 255\n    return colorwheel\n\n\ndef flow_uv_to_colors(u, v, convert_to_bgr=False):\n    \"\"\"\n    Applies the flow color wheel to (possibly clipped) flow components u and v.\n\n    According to the C++ source code of Daniel Scharstein\n    According to the Matlab source code of Deqing Sun\n\n    Args:\n        u (np.ndarray): Input horizontal flow of shape [H,W]\n        v (np.ndarray): Input vertical flow of shape [H,W]\n        convert_to_bgr (bool, optional): Convert output image to BGR. 
Defaults to False.\n\n    Returns:\n        np.ndarray: Flow visualization image of shape [H,W,3]\n    \"\"\"\n    flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8)\n    colorwheel = make_colorwheel()  # shape [55x3]\n    ncols = colorwheel.shape[0]\n    rad = np.sqrt(np.square(u) + np.square(v))\n    a = np.arctan2(-v, -u)/np.pi\n    fk = (a+1) / 2*(ncols-1)\n    k0 = np.floor(fk).astype(np.int32)\n    k1 = k0 + 1\n    k1[k1 == ncols] = 0\n    f = fk - k0\n    for i in range(colorwheel.shape[1]):\n        tmp = colorwheel[:,i]\n        col0 = tmp[k0] / 255.0\n        col1 = tmp[k1] / 255.0\n        col = (1-f)*col0 + f*col1\n        idx = (rad <= 1)\n        col[idx]  = 1 - rad[idx] * (1-col[idx])\n        col[~idx] = col[~idx] * 0.75   # out of range\n        # Note the 2-i => BGR instead of RGB\n        ch_idx = 2-i if convert_to_bgr else i\n        flow_image[:,:,ch_idx] = np.floor(255 * col)\n    return flow_image\n\n\ndef flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False):\n    \"\"\"\n    Expects a two dimensional flow image of shape.\n\n    Args:\n        flow_uv (np.ndarray): Flow UV image of shape [H,W,2]\n        clip_flow (float, optional): Clip maximum of flow values. Defaults to None.\n        convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.\n\n    Returns:\n        np.ndarray: Flow visualization image of shape [H,W,3]\n    \"\"\"\n    assert flow_uv.ndim == 3, 'input flow must have three dimensions'\n    assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]'\n    if clip_flow is not None:\n        flow_uv = np.clip(flow_uv, 0, clip_flow)\n    u = flow_uv[:,:,0]\n    v = flow_uv[:,:,1]\n    rad = np.sqrt(np.square(u) + np.square(v))\n    rad_max = np.max(rad)\n    epsilon = 1e-5\n    u = u / (rad_max + epsilon)\n    v = v / (rad_max + epsilon)\n    return flow_uv_to_colors(u, v, convert_to_bgr)"
  },
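  {
    "path": "third_party/RAFT/examples/flow_viz_example.py",
    "content": "# Hypothetical usage sketch for flow_to_image; this file is not part of\n# upstream RAFT, and the examples/ path is invented for illustration.\nimport numpy as np\n\nfrom core.utils.flow_viz import flow_to_image\n\n# A synthetic [H, W, 2] flow field: uniform motion 10px right and 5px down.\nflow = np.zeros((100, 200, 2), dtype=np.float32)\nflow[..., 0] = 10.0  # u, horizontal component\nflow[..., 1] = 5.0   # v, vertical component\n\n# flow_to_image normalizes by the maximum magnitude, then maps direction to\n# color and magnitude to saturation via the Middlebury color wheel.\nrgb = flow_to_image(flow)\nprint(rgb.shape, rgb.dtype)  # (100, 200, 3) uint8\n"
  },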
  {
    "path": "third_party/RAFT/core/utils/frame_utils.py",
    "content": "import numpy as np\nfrom PIL import Image\nfrom os.path import *\nimport re\n\nimport cv2\ncv2.setNumThreads(0)\ncv2.ocl.setUseOpenCL(False)\n\nTAG_CHAR = np.array([202021.25], np.float32)\n\ndef readFlow(fn):\n    \"\"\" Read .flo file in Middlebury format\"\"\"\n    # Code adapted from:\n    # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy\n\n    # WARNING: this will work on little-endian architectures (eg Intel x86) only!\n    # print 'fn = %s'%(fn)\n    with open(fn, 'rb') as f:\n        magic = np.fromfile(f, np.float32, count=1)\n        if 202021.25 != magic:\n            print('Magic number incorrect. Invalid .flo file')\n            return None\n        else:\n            w = np.fromfile(f, np.int32, count=1)\n            h = np.fromfile(f, np.int32, count=1)\n            # print 'Reading %d x %d flo file\\n' % (w, h)\n            data = np.fromfile(f, np.float32, count=2*int(w)*int(h))\n            # Reshape data into 3D array (columns, rows, bands)\n            # The reshape here is for visualization, the original code is (w,h,2)\n            return np.resize(data, (int(h), int(w), 2))\n\ndef readPFM(file):\n    file = open(file, 'rb')\n\n    color = None\n    width = None\n    height = None\n    scale = None\n    endian = None\n\n    header = file.readline().rstrip()\n    if header == b'PF':\n        color = True\n    elif header == b'Pf':\n        color = False\n    else:\n        raise Exception('Not a PFM file.')\n\n    dim_match = re.match(rb'^(\\d+)\\s(\\d+)\\s$', file.readline())\n    if dim_match:\n        width, height = map(int, dim_match.groups())\n    else:\n        raise Exception('Malformed PFM header.')\n\n    scale = float(file.readline().rstrip())\n    if scale < 0: # little-endian\n        endian = '<'\n        scale = -scale\n    else:\n        endian = '>' # big-endian\n\n    data = np.fromfile(file, endian + 'f')\n    shape = (height, width, 3) if color else (height, width)\n\n    data = np.reshape(data, shape)\n    data = np.flipud(data)\n    return data\n\ndef writeFlow(filename,uv,v=None):\n    \"\"\" Write optical flow to file.\n    \n    If v is None, uv is assumed to contain both u and v channels,\n    stacked in depth.\n    Original code by Deqing Sun, adapted from Daniel Scharstein.\n    \"\"\"\n    nBands = 2\n\n    if v is None:\n        assert(uv.ndim == 3)\n        assert(uv.shape[2] == 2)\n        u = uv[:,:,0]\n        v = uv[:,:,1]\n    else:\n        u = uv\n\n    assert(u.shape == v.shape)\n    height,width = u.shape\n    f = open(filename,'wb')\n    # write the header\n    f.write(TAG_CHAR)\n    np.array(width).astype(np.int32).tofile(f)\n    np.array(height).astype(np.int32).tofile(f)\n    # arrange into matrix form\n    tmp = np.zeros((height, width*nBands))\n    tmp[:,np.arange(width)*2] = u\n    tmp[:,np.arange(width)*2 + 1] = v\n    tmp.astype(np.float32).tofile(f)\n    f.close()\n\n\ndef readFlowKITTI(filename):\n    flow = cv2.imread(filename, cv2.IMREAD_ANYDEPTH|cv2.IMREAD_COLOR)\n    flow = flow[:,:,::-1].astype(np.float32)\n    flow, valid = flow[:, :, :2], flow[:, :, 2]\n    flow = (flow - 2**15) / 64.0\n    return flow, valid\n\ndef readDispKITTI(filename):\n    disp = cv2.imread(filename, cv2.IMREAD_ANYDEPTH) / 256.0\n    valid = disp > 0.0\n    flow = np.stack([-disp, np.zeros_like(disp)], -1)\n    return flow, valid\n\n\ndef writeFlowKITTI(filename, uv):\n    uv = 64.0 * uv + 2**15\n    valid = np.ones([uv.shape[0], uv.shape[1], 1])\n    uv = 
np.concatenate([uv, valid], axis=-1).astype(np.uint16)\n    cv2.imwrite(filename, uv[..., ::-1])\n    \n\ndef read_gen(file_name, pil=False):\n    ext = splitext(file_name)[-1]\n    if ext == '.png' or ext == '.jpeg' or ext == '.ppm' or ext == '.jpg':\n        return Image.open(file_name)\n    elif ext == '.bin' or ext == '.raw':\n        return np.load(file_name)\n    elif ext == '.flo':\n        return readFlow(file_name).astype(np.float32)\n    elif ext == '.pfm':\n        flow = readPFM(file_name).astype(np.float32)\n        if len(flow.shape) == 2:\n            return flow\n        else:\n            return flow[:, :, :-1]\n    return []"
  },
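  {
    "path": "third_party/RAFT/examples/flo_roundtrip_example.py",
    "content": "# Hypothetical round-trip sketch for the Middlebury .flo helpers; this file\n# is not part of upstream RAFT, and the examples/ path is invented for\n# illustration. Note readFlow assumes a little-endian machine.\nimport numpy as np\n\nfrom core.utils.frame_utils import readFlow, writeFlow\n\n# writeFlow stores a float32 tag, int32 width/height header, then\n# interleaved u/v values.\nflow = np.random.randn(120, 160, 2).astype(np.float32)\nwriteFlow('/tmp/example.flo', flow)\n\n# readFlow returns the field as (H, W, 2); values survive the round trip.\nloaded = readFlow('/tmp/example.flo')\nassert loaded.shape == (120, 160, 2)\nassert np.allclose(loaded, flow)\n"
  },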
  {
    "path": "third_party/RAFT/core/utils/utils.py",
    "content": "import torch\nimport torch.nn.functional as F\nimport numpy as np\nfrom scipy import interpolate\n\n\nclass InputPadder:\n    \"\"\" Pads images such that dimensions are divisible by 8 \"\"\"\n    def __init__(self, dims, mode='sintel'):\n        self.ht, self.wd = dims[-2:]\n        pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8\n        pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8\n        if mode == 'sintel':\n            self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2]\n        else:\n            self._pad = [pad_wd//2, pad_wd - pad_wd//2, 0, pad_ht]\n\n    def pad(self, *inputs):\n        return [F.pad(x, self._pad, mode='replicate') for x in inputs]\n\n    def unpad(self,x):\n        ht, wd = x.shape[-2:]\n        c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]]\n        return x[..., c[0]:c[1], c[2]:c[3]]\n\ndef forward_interpolate(flow):\n    flow = flow.detach().cpu().numpy()\n    dx, dy = flow[0], flow[1]\n\n    ht, wd = dx.shape\n    x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht))\n\n    x1 = x0 + dx\n    y1 = y0 + dy\n    \n    x1 = x1.reshape(-1)\n    y1 = y1.reshape(-1)\n    dx = dx.reshape(-1)\n    dy = dy.reshape(-1)\n\n    valid = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht)\n    x1 = x1[valid]\n    y1 = y1[valid]\n    dx = dx[valid]\n    dy = dy[valid]\n\n    flow_x = interpolate.griddata(\n        (x1, y1), dx, (x0, y0), method='nearest', fill_value=0)\n\n    flow_y = interpolate.griddata(\n        (x1, y1), dy, (x0, y0), method='nearest', fill_value=0)\n\n    flow = np.stack([flow_x, flow_y], axis=0)\n    return torch.from_numpy(flow).float()\n\n\ndef bilinear_sampler(img, coords, mode='bilinear', mask=False):\n    \"\"\" Wrapper for grid_sample, uses pixel coordinates \"\"\"\n    H, W = img.shape[-2:]\n    xgrid, ygrid = coords.split([1,1], dim=-1)\n    xgrid = 2*xgrid/(W-1) - 1\n    ygrid = 2*ygrid/(H-1) - 1\n\n    grid = torch.cat([xgrid, ygrid], dim=-1)\n    img = F.grid_sample(img, grid, align_corners=True)\n\n    if mask:\n        mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)\n        return img, mask.float()\n\n    return img\n\n\ndef coords_grid(batch, ht, wd, device):\n    coords = torch.meshgrid(torch.arange(ht, device=device), torch.arange(wd, device=device))\n    coords = torch.stack(coords[::-1], dim=0).float()\n    return coords[None].repeat(batch, 1, 1, 1)\n\n\ndef upflow8(flow, mode='bilinear'):\n    new_size = (8 * flow.shape[2], 8 * flow.shape[3])\n    return  8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True)\n"
  },
  {
    "path": "third_party/RAFT/demo.py",
    "content": "import sys\nsys.path.append('core')\n\nimport argparse\nimport os\nimport cv2\nimport glob\nimport numpy as np\nimport torch\nfrom PIL import Image\n\nfrom raft import RAFT\nfrom utils import flow_viz\nfrom utils.utils import InputPadder\n\n\n\nDEVICE = 'cuda'\n\ndef load_image(imfile):\n    img = np.array(Image.open(imfile)).astype(np.uint8)\n    img = torch.from_numpy(img).permute(2, 0, 1).float()\n    return img[None].to(DEVICE)\n\n\ndef viz(img, flo):\n    img = img[0].permute(1,2,0).cpu().numpy()\n    flo = flo[0].permute(1,2,0).cpu().numpy()\n    \n    # map flow to rgb image\n    flo = flow_viz.flow_to_image(flo)\n    img_flo = np.concatenate([img, flo], axis=0)\n\n    # import matplotlib.pyplot as plt\n    # plt.imshow(img_flo / 255.0)\n    # plt.show()\n\n    cv2.imshow('image', img_flo[:, :, [2,1,0]]/255.0)\n    cv2.waitKey()\n\n\ndef demo(args):\n    model = torch.nn.DataParallel(RAFT(args))\n    model.load_state_dict(torch.load(args.model))\n\n    model = model.module\n    model.to(DEVICE)\n    model.eval()\n\n    with torch.no_grad():\n        images = glob.glob(os.path.join(args.path, '*.png')) + \\\n                 glob.glob(os.path.join(args.path, '*.jpg'))\n        \n        images = sorted(images)\n        for imfile1, imfile2 in zip(images[:-1], images[1:]):\n            image1 = load_image(imfile1)\n            image2 = load_image(imfile2)\n\n            padder = InputPadder(image1.shape)\n            image1, image2 = padder.pad(image1, image2)\n\n            flow_low, flow_up = model(image1, image2, iters=20, test_mode=True)\n            viz(image1, flow_up)\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--model', help=\"restore checkpoint\")\n    parser.add_argument('--path', help=\"dataset for evaluation\")\n    parser.add_argument('--small', action='store_true', help='use small model')\n    parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision')\n    parser.add_argument('--alternate_corr', action='store_true', help='use efficent correlation implementation')\n    args = parser.parse_args()\n\n    demo(args)\n"
  },
  {
    "path": "third_party/RAFT/download_models.sh",
    "content": "#!/bin/bash\nwget https://dl.dropboxusercontent.com/s/4j4z58wuv8o0mfz/models.zip\nunzip models.zip\n"
  },
  {
    "path": "third_party/RAFT/evaluate.py",
    "content": "import sys\nsys.path.append('core')\n\nfrom PIL import Image\nimport argparse\nimport os\nimport time\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nimport matplotlib.pyplot as plt\n\nimport datasets\nfrom utils import flow_viz\nfrom utils import frame_utils\n\nfrom raft import RAFT\nfrom utils.utils import InputPadder, forward_interpolate\n\n\n@torch.no_grad()\ndef create_sintel_submission(model, iters=32, warm_start=False, output_path='sintel_submission'):\n    \"\"\" Create submission for the Sintel leaderboard \"\"\"\n    model.eval()\n    for dstype in ['clean', 'final']:\n        test_dataset = datasets.MpiSintel(split='test', aug_params=None, dstype=dstype)\n        \n        flow_prev, sequence_prev = None, None\n        for test_id in range(len(test_dataset)):\n            image1, image2, (sequence, frame) = test_dataset[test_id]\n            if sequence != sequence_prev:\n                flow_prev = None\n            \n            padder = InputPadder(image1.shape)\n            image1, image2 = padder.pad(image1[None].cuda(), image2[None].cuda())\n\n            flow_low, flow_pr = model(image1, image2, iters=iters, flow_init=flow_prev, test_mode=True)\n            flow = padder.unpad(flow_pr[0]).permute(1, 2, 0).cpu().numpy()\n\n            if warm_start:\n                flow_prev = forward_interpolate(flow_low[0])[None].cuda()\n            \n            output_dir = os.path.join(output_path, dstype, sequence)\n            output_file = os.path.join(output_dir, 'frame%04d.flo' % (frame+1))\n\n            if not os.path.exists(output_dir):\n                os.makedirs(output_dir)\n\n            frame_utils.writeFlow(output_file, flow)\n            sequence_prev = sequence\n\n\n@torch.no_grad()\ndef create_kitti_submission(model, iters=24, output_path='kitti_submission'):\n    \"\"\" Create submission for the Sintel leaderboard \"\"\"\n    model.eval()\n    test_dataset = datasets.KITTI(split='testing', aug_params=None)\n\n    if not os.path.exists(output_path):\n        os.makedirs(output_path)\n\n    for test_id in range(len(test_dataset)):\n        image1, image2, (frame_id, ) = test_dataset[test_id]\n        padder = InputPadder(image1.shape, mode='kitti')\n        image1, image2 = padder.pad(image1[None].cuda(), image2[None].cuda())\n\n        _, flow_pr = model(image1, image2, iters=iters, test_mode=True)\n        flow = padder.unpad(flow_pr[0]).permute(1, 2, 0).cpu().numpy()\n\n        output_filename = os.path.join(output_path, frame_id)\n        frame_utils.writeFlowKITTI(output_filename, flow)\n\n\n@torch.no_grad()\ndef validate_chairs(model, iters=24):\n    \"\"\" Perform evaluation on the FlyingChairs (test) split \"\"\"\n    model.eval()\n    epe_list = []\n\n    val_dataset = datasets.FlyingChairs(split='validation')\n    for val_id in range(len(val_dataset)):\n        image1, image2, flow_gt, _ = val_dataset[val_id]\n        image1 = image1[None].cuda()\n        image2 = image2[None].cuda()\n\n        _, flow_pr = model(image1, image2, iters=iters, test_mode=True)\n        epe = torch.sum((flow_pr[0].cpu() - flow_gt)**2, dim=0).sqrt()\n        epe_list.append(epe.view(-1).numpy())\n\n    epe = np.mean(np.concatenate(epe_list))\n    print(\"Validation Chairs EPE: %f\" % epe)\n    return {'chairs': epe}\n\n\n@torch.no_grad()\ndef validate_sintel(model, iters=32):\n    \"\"\" Peform validation using the Sintel (train) split \"\"\"\n    model.eval()\n    results = {}\n    for dstype in ['clean', 'final']:\n        val_dataset = 
val_dataset = datasets.MpiSintel(split='training', dstype=dstype)\n        epe_list = []\n\n        for val_id in range(len(val_dataset)):\n            image1, image2, flow_gt, _ = val_dataset[val_id]\n            image1 = image1[None].cuda()\n            image2 = image2[None].cuda()\n\n            padder = InputPadder(image1.shape)\n            image1, image2 = padder.pad(image1, image2)\n\n            flow_low, flow_pr = model(image1, image2, iters=iters, test_mode=True)\n            flow = padder.unpad(flow_pr[0]).cpu()\n\n            epe = torch.sum((flow - flow_gt)**2, dim=0).sqrt()\n            epe_list.append(epe.view(-1).numpy())\n\n        epe_all = np.concatenate(epe_list)\n        epe = np.mean(epe_all)\n        px1 = np.mean(epe_all<1)\n        px3 = np.mean(epe_all<3)\n        px5 = np.mean(epe_all<5)\n\n        print(\"Validation (%s) EPE: %f, 1px: %f, 3px: %f, 5px: %f\" % (dstype, epe, px1, px3, px5))\n        results[dstype] = epe\n\n    return results\n\n\n@torch.no_grad()\ndef validate_kitti(model, iters=24):\n    \"\"\" Perform validation using the KITTI-2015 (train) split \"\"\"\n    model.eval()\n    val_dataset = datasets.KITTI(split='training')\n\n    out_list, epe_list = [], []\n    for val_id in range(len(val_dataset)):\n        image1, image2, flow_gt, valid_gt = val_dataset[val_id]\n        image1 = image1[None].cuda()\n        image2 = image2[None].cuda()\n\n        padder = InputPadder(image1.shape, mode='kitti')\n        image1, image2 = padder.pad(image1, image2)\n\n        flow_low, flow_pr = model(image1, image2, iters=iters, test_mode=True)\n        flow = padder.unpad(flow_pr[0]).cpu()\n\n        epe = torch.sum((flow - flow_gt)**2, dim=0).sqrt()\n        mag = torch.sum(flow_gt**2, dim=0).sqrt()\n\n        epe = epe.view(-1)\n        mag = mag.view(-1)\n        val = valid_gt.view(-1) >= 0.5\n\n        out = ((epe > 3.0) & ((epe/mag) > 0.05)).float()\n        epe_list.append(epe[val].mean().item())\n        out_list.append(out[val].cpu().numpy())\n\n    epe_list = np.array(epe_list)\n    out_list = np.concatenate(out_list)\n\n    epe = np.mean(epe_list)\n    f1 = 100 * np.mean(out_list)\n\n    print(\"Validation KITTI: %f, %f\" % (epe, f1))\n    return {'kitti-epe': epe, 'kitti-f1': f1}\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--model', help=\"restore checkpoint\")\n    parser.add_argument('--dataset', help=\"dataset for evaluation\")\n    parser.add_argument('--small', action='store_true', help='use small model')\n    parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision')\n    parser.add_argument('--alternate_corr', action='store_true', help='use efficient correlation implementation')\n    args = parser.parse_args()\n\n    model = torch.nn.DataParallel(RAFT(args))\n    model.load_state_dict(torch.load(args.model))\n\n    model.cuda()\n    model.eval()\n\n    # create_sintel_submission(model.module, warm_start=True)\n    # create_kitti_submission(model.module)\n\n    with torch.no_grad():\n        if args.dataset == 'chairs':\n            validate_chairs(model.module)\n\n        elif args.dataset == 'sintel':\n            validate_sintel(model.module)\n\n        elif args.dataset == 'kitti':\n            validate_kitti(model.module)\n\n\n"
  },
  {
    "path": "third_party/RAFT/train.py",
    "content": "from __future__ import print_function, division\nimport sys\nsys.path.append('core')\n\nimport argparse\nimport os\nimport cv2\nimport time\nimport numpy as np\nimport matplotlib.pyplot as plt\n\nimport torch\nimport torch.nn as nn\nimport torch.optim as optim\nimport torch.nn.functional as F\n\nfrom torch.utils.data import DataLoader\nfrom raft import RAFT\nimport evaluate\nimport datasets\n\nfrom torch.utils.tensorboard import SummaryWriter\n\ntry:\n    from torch.cuda.amp import GradScaler\nexcept:\n    # dummy GradScaler for PyTorch < 1.6\n    class GradScaler:\n        def __init__(self):\n            pass\n        def scale(self, loss):\n            return loss\n        def unscale_(self, optimizer):\n            pass\n        def step(self, optimizer):\n            optimizer.step()\n        def update(self):\n            pass\n\n\n# exclude extremly large displacements\nMAX_FLOW = 400\nSUM_FREQ = 100\nVAL_FREQ = 5000\n\n\ndef sequence_loss(flow_preds, flow_gt, valid, gamma=0.8, max_flow=MAX_FLOW):\n    \"\"\" Loss function defined over sequence of flow predictions \"\"\"\n\n    n_predictions = len(flow_preds)    \n    flow_loss = 0.0\n\n    # exlude invalid pixels and extremely large diplacements\n    mag = torch.sum(flow_gt**2, dim=1).sqrt()\n    valid = (valid >= 0.5) & (mag < max_flow)\n\n    for i in range(n_predictions):\n        i_weight = gamma**(n_predictions - i - 1)\n        i_loss = (flow_preds[i] - flow_gt).abs()\n        flow_loss += i_weight * (valid[:, None] * i_loss).mean()\n\n    epe = torch.sum((flow_preds[-1] - flow_gt)**2, dim=1).sqrt()\n    epe = epe.view(-1)[valid.view(-1)]\n\n    metrics = {\n        'epe': epe.mean().item(),\n        '1px': (epe < 1).float().mean().item(),\n        '3px': (epe < 3).float().mean().item(),\n        '5px': (epe < 5).float().mean().item(),\n    }\n\n    return flow_loss, metrics\n\n\ndef count_parameters(model):\n    return sum(p.numel() for p in model.parameters() if p.requires_grad)\n\n\ndef fetch_optimizer(args, model):\n    \"\"\" Create the optimizer and learning rate scheduler \"\"\"\n    optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.wdecay, eps=args.epsilon)\n\n    scheduler = optim.lr_scheduler.OneCycleLR(optimizer, args.lr, args.num_steps+100,\n        pct_start=0.05, cycle_momentum=False, anneal_strategy='linear')\n\n    return optimizer, scheduler\n    \n\nclass Logger:\n    def __init__(self, model, scheduler):\n        self.model = model\n        self.scheduler = scheduler\n        self.total_steps = 0\n        self.running_loss = {}\n        self.writer = None\n\n    def _print_training_status(self):\n        metrics_data = [self.running_loss[k]/SUM_FREQ for k in sorted(self.running_loss.keys())]\n        training_str = \"[{:6d}, {:10.7f}] \".format(self.total_steps+1, self.scheduler.get_last_lr()[0])\n        metrics_str = (\"{:10.4f}, \"*len(metrics_data)).format(*metrics_data)\n        \n        # print the training status\n        print(training_str + metrics_str)\n\n        if self.writer is None:\n            self.writer = SummaryWriter()\n\n        for k in self.running_loss:\n            self.writer.add_scalar(k, self.running_loss[k]/SUM_FREQ, self.total_steps)\n            self.running_loss[k] = 0.0\n\n    def push(self, metrics):\n        self.total_steps += 1\n\n        for key in metrics:\n            if key not in self.running_loss:\n                self.running_loss[key] = 0.0\n\n            self.running_loss[key] += metrics[key]\n\n        if 
if self.total_steps % SUM_FREQ == SUM_FREQ-1:\n            self._print_training_status()\n            self.running_loss = {}\n\n    def write_dict(self, results):\n        if self.writer is None:\n            self.writer = SummaryWriter()\n\n        for key in results:\n            self.writer.add_scalar(key, results[key], self.total_steps)\n\n    def close(self):\n        self.writer.close()\n\n\ndef train(args):\n\n    model = nn.DataParallel(RAFT(args), device_ids=args.gpus)\n    print(\"Parameter Count: %d\" % count_parameters(model))\n\n    if args.restore_ckpt is not None:\n        model.load_state_dict(torch.load(args.restore_ckpt), strict=False)\n\n    model.cuda()\n    model.train()\n\n    if args.stage != 'chairs':\n        model.module.freeze_bn()\n\n    train_loader = datasets.fetch_dataloader(args)\n    optimizer, scheduler = fetch_optimizer(args, model)\n\n    total_steps = 0\n    scaler = GradScaler(enabled=args.mixed_precision)\n    logger = Logger(model, scheduler)\n\n    should_keep_training = True\n    while should_keep_training:\n\n        for i_batch, data_blob in enumerate(train_loader):\n            optimizer.zero_grad()\n            image1, image2, flow, valid = [x.cuda() for x in data_blob]\n\n            if args.add_noise:\n                stdv = np.random.uniform(0.0, 5.0)\n                image1 = (image1 + stdv * torch.randn(*image1.shape).cuda()).clamp(0.0, 255.0)\n                image2 = (image2 + stdv * torch.randn(*image2.shape).cuda()).clamp(0.0, 255.0)\n\n            flow_predictions = model(image1, image2, iters=args.iters)            \n\n            loss, metrics = sequence_loss(flow_predictions, flow, valid, args.gamma)\n            scaler.scale(loss).backward()\n            scaler.unscale_(optimizer)                \n            torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)\n            \n            scaler.step(optimizer)\n            scheduler.step()\n            scaler.update()\n\n            logger.push(metrics)\n\n            if total_steps % VAL_FREQ == VAL_FREQ - 1:\n                PATH = 'checkpoints/%d_%s.pth' % (total_steps+1, args.name)\n                torch.save(model.state_dict(), PATH)\n\n                results = {}\n                for val_dataset in args.validation:\n                    if val_dataset == 'chairs':\n                        results.update(evaluate.validate_chairs(model.module))\n                    elif val_dataset == 'sintel':\n                        results.update(evaluate.validate_sintel(model.module))\n                    elif val_dataset == 'kitti':\n                        results.update(evaluate.validate_kitti(model.module))\n\n                logger.write_dict(results)\n                \n                model.train()\n                if args.stage != 'chairs':\n                    model.module.freeze_bn()\n            \n            total_steps += 1\n\n            if total_steps > args.num_steps:\n                should_keep_training = False\n                break\n\n    logger.close()\n    PATH = 'checkpoints/%s.pth' % args.name\n    torch.save(model.state_dict(), PATH)\n\n    return PATH\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--name', default='raft', help=\"name your experiment\")\n    parser.add_argument('--stage', help=\"determines which dataset to use for training\") \n    parser.add_argument('--restore_ckpt', help=\"restore checkpoint\")\n    parser.add_argument('--small', 
action='store_true', help='use small model')\n    parser.add_argument('--validation', type=str, nargs='+')\n\n    parser.add_argument('--lr', type=float, default=0.00002)\n    parser.add_argument('--num_steps', type=int, default=100000)\n    parser.add_argument('--batch_size', type=int, default=6)\n    parser.add_argument('--image_size', type=int, nargs='+', default=[384, 512])\n    parser.add_argument('--gpus', type=int, nargs='+', default=[0,1])\n    parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision')\n\n    parser.add_argument('--iters', type=int, default=12)\n    parser.add_argument('--wdecay', type=float, default=.00005)\n    parser.add_argument('--epsilon', type=float, default=1e-8)\n    parser.add_argument('--clip', type=float, default=1.0)\n    parser.add_argument('--dropout', type=float, default=0.0)\n    parser.add_argument('--gamma', type=float, default=0.8, help='exponential weighting')\n    parser.add_argument('--add_noise', action='store_true')\n    args = parser.parse_args()\n\n    torch.manual_seed(1234)\n    np.random.seed(1234)\n\n    if not os.path.isdir('checkpoints'):\n        os.mkdir('checkpoints')\n\n    train(args)"
  },
  {
    "path": "third_party/RAFT/train_mixed.sh",
    "content": "#!/bin/bash\nmkdir -p checkpoints\npython -u train.py --name raft-chairs --stage chairs --validation chairs --gpus 0 --num_steps 120000 --batch_size 8 --lr 0.00025 --image_size 368 496 --wdecay 0.0001 --mixed_precision \npython -u train.py --name raft-things --stage things --validation sintel --restore_ckpt checkpoints/raft-chairs.pth --gpus 0 --num_steps 120000 --batch_size 5 --lr 0.0001 --image_size 400 720 --wdecay 0.0001 --mixed_precision\npython -u train.py --name raft-sintel --stage sintel --validation sintel --restore_ckpt checkpoints/raft-things.pth --gpus 0 --num_steps 120000 --batch_size 5 --lr 0.0001 --image_size 368 768 --wdecay 0.00001 --gamma=0.85 --mixed_precision\npython -u train.py --name raft-kitti  --stage kitti --validation kitti --restore_ckpt checkpoints/raft-sintel.pth --gpus 0 --num_steps 50000 --batch_size 5 --lr 0.0001 --image_size 288 960 --wdecay 0.00001 --gamma=0.85 --mixed_precision\n"
  },
  {
    "path": "third_party/RAFT/train_standard.sh",
    "content": "#!/bin/bash\nmkdir -p checkpoints\npython -u train.py --name raft-chairs --stage chairs --validation chairs --gpus 0 1 --num_steps 100000 --batch_size 10 --lr 0.0004 --image_size 368 496 --wdecay 0.0001\npython -u train.py --name raft-things --stage things --validation sintel --restore_ckpt checkpoints/raft-chairs.pth --gpus 0 1 --num_steps 100000 --batch_size 6 --lr 0.000125 --image_size 400 720 --wdecay 0.0001\npython -u train.py --name raft-sintel --stage sintel --validation sintel --restore_ckpt checkpoints/raft-things.pth --gpus 0 1 --num_steps 100000 --batch_size 6 --lr 0.000125 --image_size 368 768 --wdecay 0.00001 --gamma=0.85\npython -u train.py --name raft-kitti  --stage kitti --validation kitti --restore_ckpt checkpoints/raft-sintel.pth --gpus 0 1 --num_steps 50000 --batch_size 6 --lr 0.0001 --image_size 288 960 --wdecay 0.00001 --gamma=0.85\n"
  },
  {
    "path": "train_stereoanyvideo.py",
    "content": "import argparse\nimport logging\nfrom pathlib import Path\nfrom tqdm import tqdm\nimport os\nimport cv2\nimport numpy as np\nimport torch\nimport torch.optim as optim\nimport torch.nn.functional as F\nfrom munch import DefaultMunch\nimport json\nfrom pytorch_lightning.lite import LightningLite\nfrom torch.cuda.amp import GradScaler\n\nfrom stereoanyvideo.train_utils.utils import (\n    run_test_eval,\n    save_ims_to_tb,\n    count_parameters,\n)\nfrom stereoanyvideo.train_utils.logger import Logger\n\nfrom stereoanyvideo.evaluation.core.evaluator import Evaluator\nfrom stereoanyvideo.train_utils.losses import sequence_loss, temporal_loss\nimport stereoanyvideo.datasets.video_datasets as datasets\nfrom stereoanyvideo.models.core.stereoanyvideo import StereoAnyVideo\n\n\ndef fetch_optimizer(args, model):\n    \"\"\"Create the optimizer and learning rate scheduler\"\"\"\n    for name, param in model.named_parameters():\n        if any([key in name for key in ['depthanything']]):\n            param.requires_grad_(False)\n    optimizer = optim.AdamW(\n        model.parameters(), lr=args.lr, weight_decay=args.wdecay, eps=1e-8\n    )\n    scheduler = optim.lr_scheduler.OneCycleLR(\n        optimizer,\n        args.lr,\n        args.num_steps + 100,\n        pct_start=0.01,\n        cycle_momentum=False,\n        anneal_strategy=\"linear\",\n    )\n\n    return optimizer, scheduler\n\n\ndef forward_batch(batch, model, args):\n    output = {}\n    disparities = model(\n        batch[\"img\"][:, :, 0],\n        batch[\"img\"][:, :, 1],\n        iters=args.train_iters,\n        test_mode=False,\n    )\n    num_traj = len(batch[\"disp\"][0])\n    for i in range(num_traj):\n        seq_loss, metrics = sequence_loss(\n            disparities[:, i], -batch[\"disp\"][:, i, 0], batch[\"valid_disp\"][:, i, 0])\n        output[f\"disp_{i}\"] = {\"loss\": seq_loss / num_traj, \"metrics\": metrics}\n\n    output[\"disparity\"] = {\n        \"predictions\": torch.cat(\n            [disparities[-1, i] for i in range(num_traj)], dim=1).detach(),\n    }\n    return output\n\nclass Lite(LightningLite):\n    def run(self, args):\n        self.seed_everything(0)\n\n        evaluator = Evaluator()\n\n        eval_vis_cfg = {\n            \"visualize_interval\": 0,  # Use 0 for no visualization\n            \"exp_dir\": args.ckpt_path,\n        }\n        eval_vis_cfg = DefaultMunch.fromDict(eval_vis_cfg, object())\n        evaluator.setup_visualization(eval_vis_cfg)\n\n        model = StereoAnyVideo()\n        model.cuda()\n\n        with open(args.ckpt_path + \"/meta.json\", \"w\") as file:\n            json.dump(vars(args), file, sort_keys=True, indent=4)\n\n        train_loader = datasets.fetch_dataloader(args)\n        train_loader = self.setup_dataloaders(train_loader, move_to_device=False)\n\n        logging.info(f\"Train loader size:  {len(train_loader)}\")\n\n        optimizer, scheduler = fetch_optimizer(args, model)\n        print(\"Parameter Count:\", {count_parameters(model)})\n        logging.info(f\"Parameter Count:  {count_parameters(model)}\")\n        total_steps = 0\n        logger = Logger(model, scheduler, args.ckpt_path)\n\n        folder_ckpts = [\n            f\n            for f in os.listdir(args.ckpt_path)\n            if not os.path.isdir(f) and f.endswith(\".pth\") and not \"final\" in f\n        ]\n        if len(folder_ckpts) > 0:\n            ckpt_path = sorted(folder_ckpts)[-1]\n            ckpt = self.load(os.path.join(args.ckpt_path, ckpt_path))\n            
logging.info(f\"Loading checkpoint {ckpt_path}\")\n            if \"model\" in ckpt:\n                model.load_state_dict(ckpt[\"model\"])\n            else:\n                model.load_state_dict(ckpt)\n            if \"optimizer\" in ckpt:\n                logging.info(\"Load optimizer\")\n                optimizer.load_state_dict(ckpt[\"optimizer\"])\n            if \"scheduler\" in ckpt:\n                logging.info(\"Load scheduler\")\n                scheduler.load_state_dict(ckpt[\"scheduler\"])\n            if \"total_steps\" in ckpt:\n                total_steps = ckpt[\"total_steps\"]\n                logging.info(f\"Load total_steps {total_steps}\")\n\n        elif args.restore_ckpt is not None:\n            assert args.restore_ckpt.endswith(\".pth\") or args.restore_ckpt.endswith(\n                \".pt\"\n            )\n            logging.info(\"Loading checkpoint...\")\n            strict = True\n\n            state_dict = self.load(args.restore_ckpt)\n            if \"model\" in state_dict:\n                state_dict = state_dict[\"model\"]\n            if list(state_dict.keys())[0].startswith(\"module.\"):\n                state_dict = {\n                    k.replace(\"module.\", \"\"): v for k, v in state_dict.items()\n                }\n            model.load_state_dict(state_dict, strict=strict)\n\n            logging.info(f\"Done loading checkpoint\")\n        model, optimizer = self.setup(model, optimizer, move_to_device=False)\n        model.cuda()\n        model.train()\n        model.module.module.freeze_bn()  # We keep BatchNorm frozen\n\n        scaler = GradScaler(enabled=args.mixed_precision)\n\n        should_keep_training = True\n        global_batch_num = 0\n        epoch = -1\n        while should_keep_training:\n            epoch += 1\n\n            for i_batch, batch in enumerate(tqdm(train_loader)):\n                optimizer.zero_grad()\n                if batch is None:\n                    print(\"batch is None\")\n                    continue\n                for k, v in batch.items():\n                    batch[k] = v.cuda()\n\n                assert model.training\n\n                output = forward_batch(batch, model, args)\n\n                loss = 0\n                logger.update()\n                for k, v in output.items():\n                    if \"loss\" in v:\n                        loss += v[\"loss\"]\n                        logger.writer.add_scalar(\n                            f\"live_{k}_loss\", v[\"loss\"].item(), total_steps\n                        )\n                    if \"metrics\" in v:\n                        logger.push(v[\"metrics\"], k)\n\n                if self.global_rank == 0:\n                    if len(output) > 1:\n                        logger.writer.add_scalar(\n                            f\"live_total_loss\", loss.item(), total_steps\n                        )\n                    logger.writer.add_scalar(\n                        f\"learning_rate\", optimizer.param_groups[0][\"lr\"], total_steps\n                    )\n                    global_batch_num += 1\n                self.barrier()\n                self.backward(scaler.scale(loss))\n                scaler.unscale_(optimizer)\n                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)\n\n                scaler.step(optimizer)\n                if total_steps < args.num_steps:\n                    scheduler.step()\n                scaler.update()\n                total_steps += 1\n\n                if self.global_rank == 0:\n\n        
            if (total_steps % args.save_steps == 0) or (total_steps == 1 and args.validate_at_start):\n                        ckpt_iter = \"0\" * (6 - len(str(total_steps))) + str(total_steps)\n                        save_path = Path(\n                            f\"{args.ckpt_path}/model_{args.name}_{ckpt_iter}.pth\"\n                        )\n\n                        save_dict = {\n                            \"model\": model.module.module.state_dict(),\n                            \"optimizer\": optimizer.state_dict(),\n                            \"scheduler\": scheduler.state_dict(),\n                            \"total_steps\": total_steps,\n                        }\n\n                        logging.info(f\"Saving file {save_path}\")\n                        self.save(save_dict, save_path)\n\n                self.barrier()\n\n                if total_steps > args.num_steps:\n                    should_keep_training = False\n                    break\n\n        logger.close()\n        PATH = f\"{args.ckpt_path}/{args.name}_final.pth\"\n        torch.save(model.module.module.state_dict(), PATH)\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--name\", default=\"StereoAnyVideo\", help=\"name your experiment\")\n    parser.add_argument(\"--restore_ckpt\", help=\"restore checkpoint\")\n    parser.add_argument(\"--ckpt_path\", help=\"path to save checkpoints\")\n    parser.add_argument(\n        \"--mixed_precision\", action=\"store_true\", help=\"use mixed precision\"\n    )\n\n    # Training parameters\n    parser.add_argument(\n        \"--batch_size\", type=int, default=8, help=\"batch size used during training.\"\n    )\n    parser.add_argument(\n        \"--train_datasets\",\n        nargs=\"+\",\n        default=[\"things\", \"monkaa\", \"driving\"],\n        help=\"training datasets.\",\n    )\n    parser.add_argument(\"--lr\", type=float, default=0.0001, help=\"max learning rate.\")\n\n    parser.add_argument(\n        \"--num_steps\", type=int, default=80000, help=\"length of training schedule.\"\n    )\n    parser.add_argument(\n        \"--save_steps\", type=int, default=3000, help=\"save a checkpoint every n steps.\"\n    )\n    parser.add_argument(\n        \"--image_size\",\n        type=int,\n        nargs=\"+\",\n        default=[320, 720],\n        help=\"size of the random image crops used during training.\",\n    )\n    parser.add_argument(\n        \"--train_iters\",\n        type=int,\n        default=12,\n        help=\"number of updates to the disparity field in each forward pass.\",\n    )\n    parser.add_argument(\n        \"--wdecay\", type=float, default=0.00001, help=\"Weight decay in optimizer.\"\n    )\n\n    parser.add_argument(\n        \"--sample_len\", type=int, default=5, help=\"length of training video samples\"\n    )\n    parser.add_argument(\n        \"--validate_at_start\", action=\"store_true\", help=\"validate the model at start\"\n    )\n    parser.add_argument(\n        \"--evaluate_every_n_epoch\",\n        type=int,\n        default=1,\n        help=\"evaluate every n epoch\",\n    )\n\n    parser.add_argument(\n        \"--num_workers\", type=int, default=6, help=\"number of dataloader workers.\"\n    )\n    # Validation parameters\n    parser.add_argument(\n        \"--valid_iters\",\n        type=int,\n        default=32,\n        help=\"number of updates to the disparity field in each forward pass during validation.\",\n    )\n    # Data augmentation\n    parser.add_argument(\n        
\"--img_gamma\", type=float, nargs=\"+\", default=None, help=\"gamma range\"\n    )\n    parser.add_argument(\n        \"--saturation_range\",\n        type=float,\n        nargs=\"+\",\n        default=None,\n        help=\"color saturation\",\n    )\n    parser.add_argument(\n        \"--do_flip\",\n        default=False,\n        choices=[\"h\", \"v\"],\n        help=\"flip the images horizontally or vertically\",\n    )\n    parser.add_argument(\n        \"--spatial_scale\",\n        type=float,\n        nargs=\"+\",\n        default=[0, 0],\n        help=\"re-scale the images randomly\",\n    )\n    parser.add_argument(\n        \"--noyjitter\",\n        action=\"store_true\",\n        help=\"don't simulate imperfect rectification\",\n    )\n    args = parser.parse_args()\n\n    Path(args.ckpt_path).mkdir(exist_ok=True, parents=True)\n\n    logging.basicConfig(\n        level=logging.INFO,\n        filename=args.ckpt_path + '/' + args.name + '.log',\n        filemode='a',\n        format=\"%(asctime)s %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s\",\n    )\n\n    from pytorch_lightning.strategies import DDPStrategy\n\n    Lite(\n        strategy=DDPStrategy(find_unused_parameters=True),\n        devices=\"auto\",\n        accelerator=\"gpu\",\n        precision=32,\n    ).run(args)\n"
  },
  {
    "path": "train_stereoanyvideo.sh",
    "content": "#!/bin/bash\n\nexport PYTHONPATH=`(cd ../ && pwd)`:`pwd`:$PYTHONPATH\n\npython train_stereoanyvideo.py --batch_size 1 \\\n --spatial_scale -0.2 0.4 --image_size 256 512 --saturation_range 0 1.4 --num_steps 80000  \\\n --ckpt_path logging/StereoAnyVideo_SF \\\n --sample_len 5 --train_iters 10 --lr 0.0001 \\\n --num_workers 8 --save_steps 3000 --train_datasets things monkaa driving\n"
  },
  {
    "path": "train_utils/logger.py",
    "content": "import logging\nimport os\n\nfrom torch.utils.tensorboard import SummaryWriter\n\n\nclass Logger:\n\n    SUM_FREQ = 100\n\n    def __init__(self, model, scheduler, ckpt_path):\n        self.model = model\n        self.scheduler = scheduler\n        self.total_steps = 0\n        self.running_loss = {}\n        self.ckpt_path = ckpt_path\n        self.writer = SummaryWriter(log_dir=os.path.join(self.ckpt_path, \"runs\"))\n\n        logging.info(\n            f\"Training Metrics: 1px_disp_0...5, 3px_disp_0...5, 5px_disp_0...5, epe_disp_0...5\"\n        )\n    def _print_training_status(self):\n        metrics_data = [\n            self.running_loss[k] / Logger.SUM_FREQ\n            for k in sorted(self.running_loss.keys())\n        ]\n        training_str = \"[{:6d}] \".format(self.total_steps + 1)\n        metrics_str = (\"{:10.4f}, \" * len(metrics_data)).format(*metrics_data)\n\n        # print the training status\n        logging.info(\n            f\"Training Metrics ({self.total_steps}): {training_str + metrics_str}\"\n        )\n\n        if self.writer is None:\n            self.writer = SummaryWriter(log_dir=os.path.join(self.ckpt_path, \"runs\"))\n        for k in self.running_loss:\n            self.writer.add_scalar(\n                k, self.running_loss[k] / Logger.SUM_FREQ, self.total_steps\n            )\n            self.running_loss[k] = 0.0\n\n    def push(self, metrics, task):\n        for key in metrics:\n            task_key = str(key) + \"_\" + task\n            if task_key not in self.running_loss:\n                self.running_loss[task_key] = 0.0\n            self.running_loss[task_key] += metrics[key]\n\n    def update(self):\n        self.total_steps += 1\n        if self.total_steps % Logger.SUM_FREQ == Logger.SUM_FREQ - 1:\n            print(self.running_loss)\n            self._print_training_status()\n            self.running_loss = {}\n\n    def write_dict(self, results):\n        if self.writer is None:\n            self.writer = SummaryWriter(log_dir=os.path.join(self.ckpt_path, \"runs\"))\n\n        for key in results:\n            self.writer.add_scalar(key, results[key], self.total_steps)\n\n    def close(self):\n        self.writer.close()\n"
  },
  {
    "path": "train_utils/losses.py",
    "content": "import torch\nfrom einops import rearrange\nimport torch.nn.functional as F\n\n\ndef sequence_loss(flow_preds, flow_gt, valid, loss_gamma=0.9, max_flow=700):\n    \"\"\"Loss function defined over sequence of flow predictions\"\"\"\n    n_predictions = len(flow_preds)\n    assert n_predictions >= 1\n    flow_loss = 0.0\n    # exlude invalid pixels and extremely large diplacements\n    mag = torch.sum(flow_gt ** 2, dim=1).sqrt().unsqueeze(1)\n\n    if len(valid.shape) != len(flow_gt.shape):\n        valid = valid.unsqueeze(1)\n\n    valid = (valid >= 0.5) & (mag < max_flow)\n\n    if valid.shape != flow_gt.shape:\n        valid = torch.cat([valid, valid], dim=1)\n    assert valid.shape == flow_gt.shape, [valid.shape, flow_gt.shape]\n    assert not torch.isinf(flow_gt[valid.bool()]).any()\n\n    for i in range(n_predictions):\n        assert (\n            not torch.isnan(flow_preds[i]).any()\n            and not torch.isinf(flow_preds[i]).any()\n        )\n        if n_predictions == 1:\n            i_weight = 1\n        else:\n            # We adjust the loss_gamma so it is consistent for any number of iterations\n            adjusted_loss_gamma = loss_gamma ** (15 / (n_predictions - 1))\n            i_weight = adjusted_loss_gamma ** (n_predictions - i - 1)\n\n        flow_pred = flow_preds[i].clone()\n        if valid.shape[1] == 1 and flow_preds[i].shape[1] == 2:\n            flow_pred = flow_pred[:, :1]\n\n        i_loss = (flow_pred - flow_gt).abs()\n\n        assert i_loss.shape == valid.shape, [\n            i_loss.shape,\n            valid.shape,\n            flow_gt.shape,\n            flow_pred.shape,\n        ]\n        flow_loss += i_weight * i_loss[valid.bool()].mean()\n\n    epe = torch.sum((flow_preds[-1] - flow_gt) ** 2, dim=1).sqrt()\n\n    valid = valid[:, 0]\n    epe = epe.view(-1)\n    epe = epe[valid.reshape(epe.shape)]\n\n    metrics = {\n        \"epe\": epe.mean().item(),\n        \"1px\": (epe < 1).float().mean().item(),\n        \"3px\": (epe < 3).float().mean().item(),\n        \"5px\": (epe < 5).float().mean().item(),\n    }\n    return flow_loss, metrics\n\n\ndef temporal_loss(flow_preds, flow_preds2, flow_gt, flow_gt2, valid, loss_gamma=0.9, max_flow=700):\n    assert len(flow_preds) == len(flow_preds2)\n    n_predictions = len(flow_preds)\n    assert n_predictions >= 1\n    flow_loss = 0.0\n    # exlude invalid pixels and extremely large diplacements\n    mag = torch.sum(flow_gt ** 2, dim=1).sqrt().unsqueeze(1)\n\n    if len(valid.shape) != len(flow_gt.shape):\n        valid = valid.unsqueeze(1)\n\n    valid = (valid >= 0.5) & (mag < max_flow)\n    if valid.shape != flow_gt.shape:\n        valid = torch.cat([valid, valid], dim=1)\n    assert valid.shape == flow_gt.shape, [valid.shape, flow_gt.shape]\n    assert not torch.isinf(flow_gt[valid.bool()]).any()\n\n    for i in range(n_predictions):\n        assert (\n            not torch.isnan(flow_preds[i]).any()\n            and not torch.isinf(flow_preds[i]).any()\n        )\n        assert (\n                not torch.isnan(flow_preds2[i]).any()\n                and not torch.isinf(flow_preds2[i]).any()\n        )\n        if n_predictions == 1:\n            i_weight = 1\n        else:\n            # We adjust the loss_gamma so it is consistent for any number of iterations\n            adjusted_loss_gamma = loss_gamma ** (15 / (n_predictions - 1))\n            i_weight = adjusted_loss_gamma ** (n_predictions - i - 1)\n\n        flow_pred = flow_preds[i].clone()\n        flow_pred2 = 
flow_pred = flow_preds[i].clone()\n        flow_pred2 = flow_preds2[i].clone()\n        if valid.shape[1] == 1 and flow_preds[i].shape[1] == 2:\n            flow_pred = flow_pred[:, :1]\n            flow_pred2 = flow_pred2[:, :1]\n        i_loss = ((flow_pred2 - flow_pred).abs() - (flow_gt2 - flow_gt).abs()).abs()\n\n        assert i_loss.shape == valid.shape, [\n            i_loss.shape,\n            valid.shape,\n            flow_gt.shape,\n            flow_pred.shape,\n        ]\n        flow_loss += i_weight * i_loss[valid.bool()].mean()\n\n    tepe = torch.sum(((flow_preds2[-1] - flow_preds[-1]) - (flow_gt2 - flow_gt)) ** 2, dim=1).sqrt()\n\n    mask = (flow_gt2 - flow_gt) < 5\n    valid = mask * valid\n    valid = valid[:, 0]\n    tepe = tepe.view(-1)\n    tepe = tepe[valid.reshape(tepe.shape)]\n\n    metrics = {\n        \"tepe\": tepe.mean().item(),\n    }\n    return flow_loss, metrics\n\ndef compute_flow(Flow_Model, seq):\n    n, t, c, h, w = seq.size()\n    flows_forward = []\n    flows_backward = []\n    for i in range(t-1):\n        # i-th flow_backward denotes seq[i+1] towards seq[i]\n        flow_backward = Flow_Model.forward_fullres(seq[:,i], seq[:,i+1])\n        # i-th flow_forward denotes seq[i] towards seq[i+1]\n        flow_forward = Flow_Model.forward_fullres(seq[:,i+1], seq[:,i])\n        flows_backward.append(flow_backward)\n        flows_forward.append(flow_forward)\n    flows_forward = torch.stack(flows_forward, dim=1)\n    flows_backward = torch.stack(flows_backward, dim=1)\n\n    return flows_forward, flows_backward\n\n\ndef flow_warp(x, flow):\n    if flow.size(3) != 2:  # [B, H, W, 2]\n        flow = flow.permute(0, 2, 3, 1)\n    if x.size()[-2:] != flow.size()[1:3]:\n        raise ValueError(f'The spatial sizes of input ({x.size()[-2:]}) and '\n                         f'flow ({flow.size()[1:3]}) are not the same.')\n    _, _, h, w = x.size()\n    # create mesh grid\n    grid_y, grid_x = torch.meshgrid(torch.arange(0, h), torch.arange(0, w))\n    grid = torch.stack((grid_x, grid_y), 2).type_as(x)  # (h, w, 2)\n    grid.requires_grad = False\n\n    grid_flow = grid + flow\n    # scale grid_flow to [-1,1]\n    grid_flow_x = 2.0 * grid_flow[:, :, :, 0] / max(w - 1, 1) - 1.0\n    grid_flow_y = 2.0 * grid_flow[:, :, :, 1] / max(h - 1, 1) - 1.0\n    grid_flow = torch.stack((grid_flow_x, grid_flow_y), dim=3)\n    output = F.grid_sample(\n        x,\n        grid_flow,\n        mode='bilinear',\n        padding_mode='zeros',\n        align_corners=True)\n    return output\n\n\ndef bidirectional_alignment(seq, flows_backward, flows_forward):\n    b, T, *_ = seq.shape\n\n    # seq_backward = seq[:, 1:, ...]\n    # seq_forward = seq[:, :T - 1, ...]\n    # seq_backward = rearrange(seq_backward, \"b t c h w -> (b t) c h w\")\n    # seq_forward = rearrange(seq_forward, \"b t c h w -> (b t) c h w\")\n    # flows_forward = rearrange(flows_forward, \"b t c h w -> (b t) c h w\")\n    # flows_backward = rearrange(flows_backward, \"b t c h w -> (b t) c h w\")\n    # seq_backward = flow_warp(seq_backward, flows_backward)\n    # seq_forward = flow_warp(seq_forward, flows_forward)\n    # seq_backward = rearrange(seq_backward, \"(b t) c h w -> b t c h w\", b=b, t=T - 1)\n    # seq_forward = rearrange(seq_forward, \"(b t) c h w -> b t c h w\", b=b, t=T - 1)\n    # output_backward = torch.cat((seq_backward, seq[:, -1:]), dim=1)\n    # output_forward = torch.cat((seq[:, :1], seq_forward), dim=1)\n\n    output_backward = []\n    for i in range(1, T):\n        feat_prop = flow_warp(seq[:, i], flows_backward[:, i-1])\n        
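# frame i warped onto frame i-1 via the backward flow\n        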
output_backward.append(feat_prop)\n    output_backward.append(seq[:, T - 1])\n    output_backward = torch.stack(output_backward, dim=1)\n\n    # forward-time process\n    output_forward = [seq[:, 0]]\n    for i in range(T - 1):\n        feat_prop = flow_warp(seq[:, i], flows_forward[:, i])\n        output_forward.append(feat_prop)\n    output_forward = torch.stack(output_forward, dim=1)\n\n    return output_backward, output_forward\n\n\ndef consistency_loss(seq, disparities, Flow_Model, alpha=50):\n    b, T, *_ = seq.shape\n    # compute optical flow\n    flows_forward, flows_backward = compute_flow(Flow_Model, seq)\n\n    seq_backward, seq_forward = bidirectional_alignment(seq, flows_backward, flows_forward)\n    disparities_backward, disparities_forward = bidirectional_alignment(disparities, flows_backward, flows_forward)\n\n    diff_disparities_back = torch.abs(disparities_backward - disparities)\n    diff_disparities_for = torch.abs(disparities_forward - disparities)\n    diff_seq_back = (seq_backward - seq) ** 2\n    diff_seq_for = (seq_forward - seq) ** 2\n\n    mask_seq_back = torch.exp(-(alpha * diff_seq_back))\n    mask_seq_for = torch.exp(-(alpha * diff_seq_for))\n    mask_seq_back = torch.sum(mask_seq_back, dim=2, keepdim=True)\n    mask_seq_for = torch.sum(mask_seq_for, dim=2, keepdim=True)\n    temporal_loss_back = torch.mul(mask_seq_back, diff_disparities_back)\n    temporal_loss_for = torch.mul(mask_seq_for, diff_disparities_for)\n    temporal_loss = torch.mean(temporal_loss_back) + torch.mean(temporal_loss_for)\n\n    return temporal_loss\n"
  },
  {
    "path": "train_utils/utils.py",
    "content": "import numpy as np\nimport os\nimport torch\n\nimport json\nimport flow_vis\nimport matplotlib.pyplot as plt\n\nimport stereoanyvideo.datasets.video_datasets as datasets\nfrom stereoanyvideo.evaluation.utils.utils import aggregate_and_print_results\n\n\ndef count_parameters(model):\n    return sum(p.numel() for p in model.parameters() if p.requires_grad)\n\n\ndef run_test_eval(ckpt_path, eval_type, evaluator, model, dataloaders, writer, step):\n    for ds_name, dataloader in dataloaders:\n        # evaluator.visualize_interval = 1 if not \"sintel\" in ds_name else 0\n\n        evaluate_result = evaluator.evaluate_sequence(\n            model=model.module.module,\n            test_dataloader=dataloader,\n            writer=writer if not \"sintel\" in ds_name else None,\n            step=step,\n            train_mode=True,\n        )\n\n        aggregate_result = aggregate_and_print_results(\n            evaluate_result,\n        )\n\n        save_metrics = [\n            \"flow_mean_accuracy_5px\",\n            \"flow_mean_accuracy_3px\",\n            \"flow_mean_accuracy_1px\",\n            \"flow_epe_traj_mean\",\n        ]\n        for epe_name in (\"epe\", \"temp_epe\", \"temp_epe_r\"):\n            for m in [\n                f\"disp_{epe_name}_bad_0.5px\",\n                f\"disp_{epe_name}_bad_1px\",\n                f\"disp_{epe_name}_bad_2px\",\n                f\"disp_{epe_name}_bad_3px\",\n                f\"disp_{epe_name}_mean\",\n            ]:\n                save_metrics.append(m)\n\n        for k, v in aggregate_result.items():\n            if k in save_metrics:\n                writer.add_scalars(\n                    f\"{ds_name}_{k.rsplit('_', 1)[0]}\",\n                    {f\"{ds_name}_{k}\": v},\n                    step,\n                )\n\n        result_file = os.path.join(\n            ckpt_path,\n            f\"result_{ds_name}_{eval_type}_{step}_mimo.json\",\n        )\n        print(f\"Dumping {eval_type} results to {result_file}.\")\n        with open(result_file, \"w\") as f:\n            json.dump(aggregate_result, f)\n\n\ndef fig2data(fig):\n    \"\"\"\n    fig = plt.figure()\n    image = fig2data(fig)\n    @brief Convert a Matplotlib figure to a 4D numpy array with RGBA channels and return it\n    @param fig a matplotlib figure\n    @return a numpy 3D array of RGBA values\n    \"\"\"\n    import PIL.Image as Image\n\n    # draw the renderer\n    fig.canvas.draw()\n\n    # Get the RGBA buffer from the figure\n    w, h = fig.canvas.get_width_height()\n    buf = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)\n    buf.shape = (w, h, 3)\n\n    image = Image.frombytes(\"RGB\", (w, h), buf.tobytes())\n    image = np.asarray(image)\n    return image\n\n\ndef save_ims_to_tb(writer, batch, output, total_steps):\n    writer.add_image(\n        \"train_im\",\n        torch.cat([torch.cat([im[0], im[1]], dim=-1) for im in batch[\"img\"][0]], dim=-2)\n        / 255.0,\n        total_steps,\n        dataformats=\"CHW\",\n    )\n    if \"disp\" in batch and len(batch[\"disp\"]) > 0:\n        disp_im = [\n            (torch.cat([im[0], im[1]], dim=-1) * torch.cat([val[0], val[1]], dim=-1))\n            for im, val in zip(batch[\"disp\"][0], batch[\"valid_disp\"][0])\n        ]\n\n        disp_im = torch.cat(disp_im, dim=1)\n\n        figure = plt.figure()\n        plt.imshow(disp_im.cpu()[0])\n        disp_im = fig2data(figure).copy()\n\n        writer.add_image(\n            \"train_disp\",\n            disp_im,\n            total_steps,\n    
        dataformats=\"HWC\",\n        )\n\n    for k, v in output.items():\n        if \"predictions\" in v:\n            pred = v[\"predictions\"]\n            if k == \"disparity\":\n                figure = plt.figure()\n                plt.imshow(pred.cpu()[0])\n                pred = fig2data(figure).copy()\n                dataformat = \"HWC\"\n            else:\n                pred = torch.tensor(\n                    flow_vis.flow_to_color(\n                        pred.permute(1, 2, 0).cpu().numpy(), convert_to_bgr=False\n                    )\n                    / 255.0\n                )\n                dataformat = \"HWC\"\n            writer.add_image(\n                f\"pred_{k}\",\n                pred,\n                total_steps,\n                dataformats=dataformat,\n            )\n        if \"gt\" in v:\n            gt = v[\"gt\"]\n            gt = torch.tensor(\n                flow_vis.flow_to_color(\n                    gt.permute(1, 2, 0).cpu().numpy(), convert_to_bgr=False\n                )\n                / 255.0\n            )\n            dataformat = \"HWC\"\n            writer.add_image(\n                f\"gt_{k}\",\n                gt,\n                total_steps,\n                dataformats=dataformat,\n            )\n"
  }
]