Repository: TomTomTommi/stereoanyvideo
Branch: main
Commit: 3a8beddc470c
Files: 83
Total size: 513.1 KB

Directory structure:
gitextract_751vu01n/

├── LICENSE
├── README.md
├── assets/
│   └── 1
├── checkpoints/
│   └── checkpoints here.txt
├── data/
│   └── datasets/
│       └── dataset here.txt
├── datasets/
│   ├── augmentor.py
│   ├── frame_utils.py
│   └── video_datasets.py
├── demo.py
├── demo.sh
├── evaluate_stereoanyvideo.sh
├── evaluation/
│   ├── configs/
│   │   ├── eval_dynamic_replica.yaml
│   │   ├── eval_infinigensv.yaml
│   │   ├── eval_kittidepth.yaml
│   │   ├── eval_sintel_clean.yaml
│   │   ├── eval_sintel_final.yaml
│   │   ├── eval_southkensington.yaml
│   │   └── eval_vkitti2.yaml
│   ├── core/
│   │   └── evaluator.py
│   ├── evaluate.py
│   └── utils/
│       ├── eval_utils.py
│       ├── ssim.py
│       └── utils.py
├── models/
│   ├── Video-Depth-Anything/
│   │   ├── app.py
│   │   ├── get_weights.sh
│   │   ├── run.py
│   │   ├── utils/
│   │   │   ├── dc_utils.py
│   │   │   └── util.py
│   │   └── video_depth_anything/
│   │       ├── dinov2.py
│   │       ├── dinov2_layers/
│   │       │   ├── __init__.py
│   │       │   ├── attention.py
│   │       │   ├── block.py
│   │       │   ├── drop_path.py
│   │       │   ├── layer_scale.py
│   │       │   ├── mlp.py
│   │       │   ├── patch_embed.py
│   │       │   └── swiglu_ffn.py
│   │       ├── dpt.py
│   │       ├── dpt_temporal.py
│   │       ├── motion_module/
│   │       │   ├── attention.py
│   │       │   └── motion_module.py
│   │       ├── util/
│   │       │   ├── blocks.py
│   │       │   └── transform.py
│   │       └── video_depth.py
│   ├── core/
│   │   ├── attention.py
│   │   ├── corr.py
│   │   ├── extractor.py
│   │   ├── model_zoo.py
│   │   ├── stereoanyvideo.py
│   │   ├── update.py
│   │   └── utils/
│   │       ├── config.py
│   │       └── utils.py
│   ├── raft_model.py
│   └── stereoanyvideo_model.py
├── requirements.txt
├── third_party/
│   └── RAFT/
│       ├── LICENSE
│       ├── README.md
│       ├── alt_cuda_corr/
│       │   ├── correlation.cpp
│       │   ├── correlation_kernel.cu
│       │   └── setup.py
│       ├── chairs_split.txt
│       ├── core/
│       │   ├── __init__.py
│       │   ├── corr.py
│       │   ├── datasets.py
│       │   ├── extractor.py
│       │   ├── raft.py
│       │   ├── update.py
│       │   └── utils/
│       │       ├── __init__.py
│       │       ├── augmentor.py
│       │       ├── flow_viz.py
│       │       ├── frame_utils.py
│       │       └── utils.py
│       ├── demo.py
│       ├── download_models.sh
│       ├── evaluate.py
│       ├── train.py
│       ├── train_mixed.sh
│       └── train_standard.sh
├── train_stereoanyvideo.py
├── train_stereoanyvideo.sh
└── train_utils/
    ├── logger.py
    ├── losses.py
    └── utils.py

================================================
FILE CONTENTS
================================================

================================================
FILE: LICENSE
================================================
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!)  The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.


================================================
FILE: README.md
================================================
<h1 align='center' style="text-align:center; font-weight:bold; font-size:2.0em;letter-spacing:2.0px;">
Stereo Any Video: <br> Temporally Consistent Stereo Matching</h1>

<div align="center">
  <a href="https://arxiv.org/abs/2503.05549" target="_blank" rel="external nofollow noopener">
  <img src="https://img.shields.io/badge/Paper-arXiv-deepgreen" alt="Paper arXiv"></a>
  <a href="https://tomtomtommi.github.io/StereoAnyVideo/" target="_blank" rel="external nofollow noopener">
  <img src="https://img.shields.io/badge/Project-Page-9cf" alt="Project Page"></a>
</div>

![Demo](./assets/stereoanyvideo.gif)

## Installation

Installation with CUDA 12.2

<details>
  <summary>Set up the root for all source files</summary>
  <pre><code>
    git clone https://github.com/tomtomtommi/stereoanyvideo
    cd stereoanyvideo
    export PYTHONPATH=`(cd ../ && pwd)`:`pwd`:$PYTHONPATH
  </code></pre>
</details>

<details>
  <summary>Create a conda env</summary>
  <pre><code>
    conda create -n sav python=3.10
    conda activate sav
  </code></pre>
</details>

<details>
  <summary>Install requirements</summary>
  <pre><code>
    conda install pytorch==2.1.1 torchvision==0.16.1 torchaudio==2.1.1 pytorch-cuda=12.1 -c pytorch -c nvidia
    pip install pip==24.0
    pip install pytorch_lightning==1.6.0
    pip install iopath
    conda install -c bottler nvidiacub
    pip install scikit-image matplotlib imageio plotly opencv-python
    conda install -c fvcore -c conda-forge fvcore
    pip install black usort flake8 flake8-bugbear flake8-comprehensions
    conda install pytorch3d -c pytorch3d
    pip install -r requirements.txt
    pip install timm
  </code></pre>
</details>

<details>
  <summary>Download VDA checkpoints</summary>
  <pre><code>
    cd models/Video-Depth-Anything
    sh get_weights.sh
  </code></pre>
</details>
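
As an optional sanity check (a minimal sketch, assuming only the packages installed above), verify that PyTorch sees the GPU and that PyTorch3D imports cleanly:
```
python -c "import torch, pytorch3d; print(torch.__version__, torch.cuda.is_available())"
```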

## Inference on a stereo video

```
sh demo.sh
```
Before running, download the checkpoints from [Google Drive](https://drive.google.com/drive/folders/1c7L065dcBWhCYYjWYo2edGOG605PnpXv?usp=sharing) and copy them to `./checkpoints/`.

By default, the left and right camera frames are expected to be structured like this:
```none
./demo_video/
        ├── left
            ├── left000000.png
            ├── left000001.png
            ├── left000002.png
            ...
        ├── right
            ├── right000000.png
            ├── right000001.png
            ├── right000002.png
            ...
```

A simple way to run the demo is to use the SouthKensingtonSV dataset.

To test on your own data, modify `--path ./demo_video/`. More arguments can be found and modified in `demo.py`.
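
A minimal sketch of such a run, assuming the frame layout shown above (only `--path` is documented here; see `demo.py` for the remaining arguments):
```
python demo.py --path ./demo_video/
```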

## Dataset

Download the following datasets and put them in `./data/datasets/`:
 - [SceneFlow](https://lmb.informatik.uni-freiburg.de/resources/datasets/SceneFlowDatasets.en.html)
 - [Sintel](http://sintel.is.tue.mpg.de/stereo)
 - [Dynamic_Replica](https://dynamic-stereo.github.io/)
 - [KITTI Depth](https://www.cvlibs.net/datasets/kitti/eval_depth_all.php)
 - [Infinigen SV](https://tomtomtommi.github.io/BiDAVideo/)
 - [Virtual KITTI2](https://europe.naverlabs.com/proxy-virtual-worlds-vkitti-2/)
 - [SouthKensington SV](https://tomtomtommi.github.io/BiDAVideo/)


## Evaluation
```
sh evaluate_stereoanyvideo.sh
```

## Training
```
sh train_stereoanyvideo.sh
```

## Citation 
If you use our method in your research, please consider citing:
```
@inproceedings{jing2025stereo,
  title={Stereo any video: Temporally consistent stereo matching},
  author={Jing, Junpeng and Luo, Weixun and Mao, Ye and Mikolajczyk, Krystian},
  booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},
  pages={20836--20846},
  year={2025}
}
```


================================================
FILE: assets/1
================================================



================================================
FILE: checkpoints/checkpoints here.txt
================================================


================================================
FILE: data/datasets/dataset here.txt
================================================


================================================
FILE: datasets/augmentor.py
================================================
import numpy as np
import random
from PIL import Image

import cv2

cv2.setNumThreads(0)
cv2.ocl.setUseOpenCL(False)

from torchvision.transforms import ColorJitter, functional, Compose


class AdjustGamma(object):
    def __init__(self, gamma_min, gamma_max, gain_min=1.0, gain_max=1.0):
        self.gamma_min, self.gamma_max, self.gain_min, self.gain_max = (
            gamma_min,
            gamma_max,
            gain_min,
            gain_max,
        )

    def __call__(self, sample):
        gain = random.uniform(self.gain_min, self.gain_max)
        gamma = random.uniform(self.gamma_min, self.gamma_max)
        return functional.adjust_gamma(sample, gamma, gain)

    def __repr__(self):
        return f"Adjust Gamma {self.gamma_min}, ({self.gamma_max}) and Gain ({self.gain_min}, {self.gain_max})"


class SequenceDispFlowAugmentor:
    def __init__(
        self,
        crop_size,
        min_scale=-0.2,
        max_scale=0.5,
        do_flip=True,
        yjitter=False,
        saturation_range=[0.6, 1.4],
        gamma=[1, 1, 1, 1],
    ):
        # spatial augmentation params
        self.crop_size = crop_size
        self.min_scale = min_scale
        self.max_scale = max_scale
        self.spatial_aug_prob = 1.0
        self.stretch_prob = 0.8
        self.max_stretch = 0.2

        # flip augmentation params
        self.yjitter = yjitter
        self.do_flip = do_flip
        self.h_flip_prob = 0.5
        self.v_flip_prob = 0.1

        # photometric augmentation params
        self.photo_aug = Compose(
            [
                ColorJitter(
                    brightness=0.4,
                    contrast=0.4,
                    saturation=saturation_range,
                    hue=0.5 / 3.14,
                ),
                AdjustGamma(*gamma),
            ]
        )
        self.asymmetric_color_aug_prob = 0.2
        self.eraser_aug_prob = 0.5

    def color_transform(self, seq):
        """Photometric augmentation"""

        # asymmetric
        if np.random.rand() < self.asymmetric_color_aug_prob:
            for i in range(len(seq)):
                for cam in (0, 1):
                    seq[i][cam] = np.array(
                        self.photo_aug(Image.fromarray(seq[i][cam])), dtype=np.uint8
                    )
        # symmetric
        else:
            image_stack = np.concatenate(
                [seq[i][cam] for i in range(len(seq)) for cam in (0, 1)], axis=0
            )
            image_stack = np.array(
                self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8
            )
            split = np.split(image_stack, len(seq) * 2, axis=0)
            for i in range(len(seq)):
                seq[i][0] = split[2 * i]
                seq[i][1] = split[2 * i + 1]
        return seq

    def eraser_transform(self, seq, bounds=[50, 100]):
        """Occlusion augmentation"""
        ht, wd = seq[0][0].shape[:2]
        for i in range(len(seq)):
            for cam in (0, 1):
                if np.random.rand() < self.eraser_aug_prob:
                    mean_color = np.mean(seq[0][0].reshape(-1, 3), axis=0)
                    for _ in range(np.random.randint(1, 3)):
                        x0 = np.random.randint(0, wd)
                        y0 = np.random.randint(0, ht)
                        dx = np.random.randint(bounds[0], bounds[1])
                        dy = np.random.randint(bounds[0], bounds[1])
                        seq[i][cam][y0 : y0 + dy, x0 : x0 + dx, :] = mean_color

        return seq

    def spatial_transform(self, img, disp):
        # randomly sample scale
        ht, wd = img[0][0].shape[:2]
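        # Lower-bound the rescale factor so the resized frames stay larger than the crop window (+8 px margin)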
        min_scale = np.maximum(
            (self.crop_size[0] + 8) / float(ht), (self.crop_size[1] + 8) / float(wd)
        )

        scale = 2 ** np.random.uniform(self.min_scale, self.max_scale)
        scale_x = scale
        scale_y = scale
        if np.random.rand() < self.stretch_prob:
            scale_x *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch)
            scale_y *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch)

        scale_x = np.clip(scale_x, min_scale, None)
        scale_y = np.clip(scale_y, min_scale, None)

        if np.random.rand() < self.spatial_aug_prob:
            # rescale the images
            for i in range(len(img)):
                for cam in (0, 1):
                    img[i][cam] = cv2.resize(
                        img[i][cam],
                        None,
                        fx=scale_x,
                        fy=scale_y,
                        interpolation=cv2.INTER_LINEAR,
                    )
                    if len(disp[i]) > 0:
                        disp[i][cam] = cv2.resize(
                            disp[i][cam],
                            None,
                            fx=scale_x,
                            fy=scale_y,
                            interpolation=cv2.INTER_LINEAR,
                        )
                        disp[i][cam] = disp[i][cam] * [scale_x, scale_y]

        if self.yjitter:
            y0 = np.random.randint(2, img[0][0].shape[0] - self.crop_size[0] - 2)
            x0 = np.random.randint(2, img[0][0].shape[1] - self.crop_size[1] - 2)

            for i in range(len(img)):
                y1 = y0 + np.random.randint(-2, 2 + 1)
                img[i][0] = img[i][0][
                    y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]
                ]
                img[i][1] = img[i][1][
                    y1 : y1 + self.crop_size[0], x0 : x0 + self.crop_size[1]
                ]
                if len(disp[i]) > 0:
                    disp[i][0] = disp[i][0][
                        y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]
                    ]
                    disp[i][1] = disp[i][1][
                        y1 : y1 + self.crop_size[0], x0 : x0 + self.crop_size[1]
                    ]
        else:
            y0 = np.random.randint(0, img[0][0].shape[0] - self.crop_size[0])
            x0 = np.random.randint(0, img[0][0].shape[1] - self.crop_size[1])
            for i in range(len(img)):
                for cam in (0, 1):
                    img[i][cam] = img[i][cam][
                        y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]
                    ]
                    if len(disp[i]) > 0:
                        disp[i][cam] = disp[i][cam][
                            y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]
                        ]

        return img, disp

    def __call__(self, img, disp):
        img = self.color_transform(img)
        img = self.eraser_transform(img)
        img, disp = self.spatial_transform(img, disp)

        for i in range(len(img)):
            for cam in (0, 1):
                img[i][cam] = np.ascontiguousarray(img[i][cam])
                if len(disp[i]) > 0:
                    disp[i][cam] = np.ascontiguousarray(disp[i][cam])

        return img, disp


class SequenceDispSparseFlowAugmentor:
    def __init__(
        self,
        crop_size,
        min_scale=-0.2,
        max_scale=0.5,
        do_flip=True,
        yjitter=False,
        saturation_range=[0.6, 1.4],
        gamma=[1, 1, 1, 1],
    ):
        # spatial augmentation params
        self.crop_size = crop_size
        self.min_scale = min_scale
        self.max_scale = max_scale
        self.spatial_aug_prob = 1.0
        self.stretch_prob = 0.8
        self.max_stretch = 0.2

        # flip augmentation params
        self.yjitter = yjitter
        self.do_flip = do_flip
        self.h_flip_prob = 0.5
        self.v_flip_prob = 0.1

        # photometric augmentation params
        self.photo_aug = Compose(
            [
                ColorJitter(
                    brightness=0.4,
                    contrast=0.4,
                    saturation=saturation_range,
                    hue=0.5 / 3.14,
                ),
                AdjustGamma(*gamma),
            ]
        )
        self.asymmetric_color_aug_prob = 0.2
        self.eraser_aug_prob = 0.5

    def color_transform(self, seq):
        """Photometric augmentation"""
        # symmetric
        image_stack = np.concatenate(
            [seq[i][cam] for i in range(len(seq)) for cam in (0, 1)], axis=0
        )
        image_stack = np.array(
            self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8
        )
        split = np.split(image_stack, len(seq) * 2, axis=0)
        for i in range(len(seq)):
            seq[i][0] = split[2 * i]
            seq[i][1] = split[2 * i + 1]
        return seq

    def eraser_transform(self, seq, bounds=[50, 100]):
        """Occlusion augmentation"""
        ht, wd = seq[0][0].shape[:2]
        for i in range(len(seq)):
            for cam in (0, 1):
                if np.random.rand() < self.eraser_aug_prob:
                    mean_color = np.mean(seq[0][0].reshape(-1, 3), axis=0)
                    for _ in range(np.random.randint(1, 3)):
                        x0 = np.random.randint(0, wd)
                        y0 = np.random.randint(0, ht)
                        dx = np.random.randint(bounds[0], bounds[1])
                        dy = np.random.randint(bounds[0], bounds[1])
                        seq[i][cam][y0 : y0 + dy, x0 : x0 + dx, :] = mean_color

        return seq

    def resize_sparse_flow_map(self, flow, valid, fx=1.0, fy=1.0):
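        # Resize a sparse flow/disparity map: rescale the coordinates and values of valid pixels only,
        # then scatter them onto the grid at the new resolution.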
        ht, wd = flow.shape[:2]
        coords = np.meshgrid(np.arange(wd), np.arange(ht))
        coords = np.stack(coords, axis=-1)

        coords = coords.reshape(-1, 2).astype(np.float32)
        flow = flow.reshape(-1, 2).astype(np.float32)
        valid = valid.reshape(-1).astype(np.float32)

        coords0 = coords[valid>=1]
        flow0 = flow[valid>=1]

        ht1 = int(round(ht * fy))
        wd1 = int(round(wd * fx))

        coords1 = coords0 * [fx, fy]
        flow1 = flow0 * [fx, fy]

        xx = np.round(coords1[:,0]).astype(np.int32)
        yy = np.round(coords1[:,1]).astype(np.int32)

        v = (xx > 0) & (xx < wd1) & (yy > 0) & (yy < ht1)
        xx = xx[v]
        yy = yy[v]
        flow1 = flow1[v]

        flow_img = np.zeros([ht1, wd1, 2], dtype=np.float32)
        valid_img = np.zeros([ht1, wd1], dtype=np.int32)

        flow_img[yy, xx] = flow1
        valid_img[yy, xx] = 1

        return flow_img, valid_img

    def spatial_transform(self, img, disp, valid):
        # randomly sample scale
        ht, wd = img[0][0].shape[:2]
        min_scale = np.maximum(
            (self.crop_size[0] + 8) / float(ht), (self.crop_size[1] + 8) / float(wd)
        )

        scale = 2 ** np.random.uniform(self.min_scale, self.max_scale)
        scale_x = scale
        scale_y = scale
        if np.random.rand() < self.stretch_prob:
            scale_x *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch)
            scale_y *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch)

        scale_x = np.clip(scale_x, min_scale, None)
        scale_y = np.clip(scale_y, min_scale, None)

        if np.random.rand() < self.spatial_aug_prob:
            # rescale the images
            for i in range(len(img)):
                for cam in (0, 1):
                    img[i][cam] = cv2.resize(
                        img[i][cam],
                        None,
                        fx=scale_x,
                        fy=scale_y,
                        interpolation=cv2.INTER_LINEAR,
                    )
                    if len(disp[i]) > 0:
                        disp[i][cam], valid[i][cam] = self.resize_sparse_flow_map(disp[i][cam], valid[i][cam], fx=scale_x, fy=scale_y)

        margin_y = 20
        margin_x = 50

        y0 = np.random.randint(0, img[0][0].shape[0] - self.crop_size[0])
        x0 = np.random.randint(0, img[0][0].shape[1] - self.crop_size[1])
        for i in range(len(img)):
            for cam in (0, 1):
                img[i][cam] = img[i][cam][
                    y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]
                ]
                if len(disp[i]) > 0:
                    disp[i][cam] = disp[i][cam][
                        y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]
                    ]
                    valid[i][cam] = valid[i][cam][
                        y0: y0 + self.crop_size[0], x0: x0 + self.crop_size[1]
                    ]
        return img, disp, valid

    def __call__(self, img, disp, valid):
        img = self.color_transform(img)
        img = self.eraser_transform(img)
        img, disp, valid = self.spatial_transform(img, disp, valid)

        for i in range(len(img)):
            for cam in (0, 1):
                img[i][cam] = np.ascontiguousarray(img[i][cam])
                if len(disp[i]) > 0:
                    disp[i][cam] = np.ascontiguousarray(disp[i][cam])
                    valid[i][cam] = np.ascontiguousarray(valid[i][cam])

        return img, disp, valid


================================================
FILE: datasets/frame_utils.py
================================================
import numpy as np
from PIL import Image
from os.path import *
import re
import imageio
import cv2

cv2.setNumThreads(0)
cv2.ocl.setUseOpenCL(False)

TAG_CHAR = np.array([202021.25], np.float32)


def readFlow(fn):
    """Read .flo file in Middlebury format"""
    # Code adapted from:
    # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy

    # WARNING: this will work on little-endian architectures (eg Intel x86) only!
    # print 'fn = %s'%(fn)
    with open(fn, "rb") as f:
        magic = np.fromfile(f, np.float32, count=1)
        if 202021.25 != magic:
            print("Magic number incorrect. Invalid .flo file")
            return None
        else:
            w = np.fromfile(f, np.int32, count=1)
            h = np.fromfile(f, np.int32, count=1)
            # print 'Reading %d x %d flo file\n' % (w, h)
            data = np.fromfile(f, np.float32, count=2 * int(w) * int(h))
            # Reshape data into 3D array (columns, rows, bands)
            # The reshape here is for visualization, the original code is (w,h,2)
            return np.resize(data, (int(h), int(w), 2))


def readPFM(file):
    file = open(file, "rb")

    color = None
    width = None
    height = None
    scale = None
    endian = None

    header = file.readline().rstrip()
    if header == b"PF":
        color = True
    elif header == b"Pf":
        color = False
    else:
        raise Exception("Not a PFM file.")

    dim_match = re.match(rb"^(\d+)\s(\d+)\s$", file.readline())
    if dim_match:
        width, height = map(int, dim_match.groups())
    else:
        raise Exception("Malformed PFM header.")

    scale = float(file.readline().rstrip())
    if scale < 0:  # little-endian
        endian = "<"
        scale = -scale
    else:
        endian = ">"  # big-endian

    data = np.fromfile(file, endian + "f")
    shape = (height, width, 3) if color else (height, width)

    data = np.reshape(data, shape)
    data = np.flipud(data)
    return data


def readDispSintelStereo(file_name):
    """Return disparity read from filename."""
    f_in = np.array(Image.open(file_name))
    d_r = f_in[:, :, 0].astype("float64")
    d_g = f_in[:, :, 1].astype("float64")
    d_b = f_in[:, :, 2].astype("float64")

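    # Sintel packs disparity into the RGB channels: d = R*4 + G/64 + B/16384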
    disp = d_r * 4 + d_g / (2 ** 6) + d_b / (2 ** 14)
    mask = np.array(Image.open(file_name.replace("disparities", "occlusions")))
    valid = (mask == 0) & (disp > 0)
    return disp, valid


def readDispMiddlebury(file_name):
    assert basename(file_name) == "disp0GT.pfm"
    disp = readPFM(file_name).astype(np.float32)
    assert len(disp.shape) == 2
    nocc_pix = file_name.replace("disp0GT.pfm", "mask0nocc.png")
    assert exists(nocc_pix)
    nocc_pix = imageio.imread(nocc_pix) == 255
    assert np.any(nocc_pix)
    return disp, nocc_pix


def read_gen(file_name, pil=False):
    ext = splitext(file_name)[-1]
    if ext == ".png" or ext == ".jpeg" or ext == ".ppm" or ext == ".jpg":
        return Image.open(file_name)
    elif ext == ".bin" or ext == ".raw":
        return np.load(file_name)
    elif ext == ".flo":
        return readFlow(file_name).astype(np.float32)
    elif ext == ".pfm":
        flow = readPFM(file_name).astype(np.float32)
        if len(flow.shape) == 2:
            return flow
        else:
            return flow[:, :, :-1]
    return []


================================================
FILE: datasets/video_datasets.py
================================================
import os
import copy
import gzip
import logging
import torch
import numpy as np
import torch.utils.data as data
import torch.nn.functional as F
import os.path as osp
from glob import glob
import cv2
import re
from scipy.spatial.transform import Rotation as R

from collections import defaultdict
from PIL import Image
from dataclasses import dataclass
from typing import List, Optional
from pytorch3d.renderer.cameras import PerspectiveCameras
from pytorch3d.implicitron.dataset.types import (
    FrameAnnotation as ImplicitronFrameAnnotation,
    load_dataclass,
)

from stereoanyvideo.datasets import frame_utils
from stereoanyvideo.evaluation.utils.eval_utils import depth2disparity_scale
from stereoanyvideo.datasets.augmentor import SequenceDispFlowAugmentor, SequenceDispSparseFlowAugmentor


@dataclass
class DynamicReplicaFrameAnnotation(ImplicitronFrameAnnotation):
    """A dataclass used to load annotations from json."""

    camera_name: Optional[str] = None


class StereoSequenceDataset(data.Dataset):
    def __init__(self, aug_params=None, sparse=False, reader=None):
        self.augmentor = None
        self.sparse = sparse
        self.img_pad = (
            aug_params.pop("img_pad", None) if aug_params is not None else None
        )
        if aug_params is not None and "crop_size" in aug_params:
            if sparse:
                self.augmentor = SequenceDispSparseFlowAugmentor(**aug_params)
            else:
                self.augmentor = SequenceDispFlowAugmentor(**aug_params)

        if reader is None:
            self.disparity_reader = frame_utils.read_gen
        else:
            self.disparity_reader = reader
        self.depth_reader = self._load_depth
        self.is_test = False
        self.sample_list = []
        self.extra_info = []
        self.depth_eps = 1e-5

    def _load_depth(self, depth_path):
        if depth_path[-3:] == "npy":
            return self._load_npy_depth(depth_path)
        elif depth_path[-3:] == "png":
            if "kitti_depth" in depth_path:
                return self._load_kitti_depth(depth_path)
            elif "vkitti2" in depth_path:
                return self._load_vkitti2(depth_path)
            else:
                return self._load_16big_png_depth(depth_path)
        else:
            raise ValueError("Other format depth is not implemented")

    def _load_npy_depth(self, depth_npy):
        depth = np.load(depth_npy)
        return depth

    def _load_vkitti2(self, depth_png):
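        # VKITTI2 stores depth in centimeters as 16-bit PNGs; convert to meters and mark zero-depth pixels as invalid (-1)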
        depth_image = cv2.imread(depth_png, cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH)
        depth_in_meters = depth_image.astype(np.float32) / 100.0
        depth_in_meters[depth_image == 0] = -1.

        return depth_in_meters

    def _load_kitti_depth(self, depth_png):
        # depth_image = cv2.imread(depth_png, cv2.IMREAD_UNCHANGED)
        # depth_in_meters = depth_image.astype(np.float32) / 256.0
        depth_image = np.array(Image.open(depth_png), dtype=int)
        # make sure we have a proper 16bit depth map here.. not 8bit!
        assert (np.max(depth_image) > 255)

        depth_in_meters = depth_image.astype(np.float32) / 256.
        depth_in_meters[depth_image == 0] = -1.

        return depth_in_meters

    def _load_16big_png_depth(self, depth_png):
        with Image.open(depth_png) as depth_pil:
            # the image is stored with 16-bit depth but PIL reads it as I (32 bit).
            # we cast it to uint16, then reinterpret as float16, then cast to float32
            depth = (
                np.frombuffer(np.array(depth_pil, dtype=np.uint16), dtype=np.float16)
                .astype(np.float32)
                .reshape((depth_pil.size[1], depth_pil.size[0]))
            )
        return depth

    def load_tartanair_pose(self, filepath, index=0):
        poses = np.loadtxt(filepath)
        tx, ty, tz, qx, qy, qz, qw = poses[index]

        # Quaternion to rotation matrix
        r = R.from_quat([qx, qy, qz, qw])
        R_mat = r.as_matrix()

        # Assemble 4x4 pose matrix
        T = np.eye(4)
        T[:3, :3] = R_mat
        T[:3, 3] = [tx, ty, tz]

        return T

    def parse_txt_file(self, file_path):
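        # Parse a camera .txt file (SouthKensington SV style): an "Intrinsic:" 3x3 block followed by per-frame pose matrices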
        with open(file_path, 'r') as file:
            data = file.read()

        # Regex patterns
        intrinsic_pattern = re.compile(r"Intrinsic:\s*\[\[([^\]]+)\]\s*\[\s*([^\]]+)\]\s*\[\s*([^\]]+)\]\]")
        frame_pattern = re.compile(r"Frame (\d+): Pose: ([\w\d]+)\n([\s\S]+?)(?=Frame|\Z)")

        # Extract intrinsic matrix (K)
        intrinsic_match = intrinsic_pattern.search(data)
        if intrinsic_match:
            K = np.array([list(map(float, row.split())) for row in intrinsic_match.groups()])
        else:
            raise ValueError("Intrinsic matrix not found in the file")

        # Extract frames and compute R and T
        frames = []
        for frame_match in frame_pattern.finditer(data):
            frame_number = int(frame_match.group(1))
            pose_id = frame_match.group(2)
            pose_matrix = np.array([list(map(float, row.split())) for row in frame_match.group(3).strip().split('\n')])

            # Decompose pose matrix into R and T
            R = pose_matrix[:3, :3]  # The upper-left 3x3 part is the rotation matrix
            T = pose_matrix[:3, 3]  # The first three elements of the fourth column are the translation vector

            frames.append({
                'frame_number': frame_number,
                'pose_id': pose_id,
                'pose_matrix': pose_matrix,
                'R': R,
                'T': T
            })

        return K, frames

    def _get_pytorch3d_camera(
        self, entry_viewpoint, image_size, scale: float
    ) -> PerspectiveCameras:
        assert entry_viewpoint is not None
        # principal point and focal length
        principal_point = torch.tensor(
            entry_viewpoint.principal_point, dtype=torch.float
        )
        focal_length = torch.tensor(entry_viewpoint.focal_length, dtype=torch.float)
        half_image_size_wh_orig = (
            torch.tensor(list(reversed(image_size)), dtype=torch.float) / 2.0
        )

        # first, we convert from the dataset's NDC convention to pixels
        format = entry_viewpoint.intrinsics_format
        if format.lower() == "ndc_norm_image_bounds":
            # this is e.g. currently used in CO3D for storing intrinsics
            rescale = half_image_size_wh_orig
        elif format.lower() == "ndc_isotropic":
            rescale = half_image_size_wh_orig.min()
        else:
            raise ValueError(f"Unknown intrinsics format: {format}")

        # principal point and focal length in pixels
        principal_point_px = half_image_size_wh_orig - principal_point * rescale
        focal_length_px = focal_length * rescale

        # now, convert from pixels to PyTorch3D v0.5+ NDC convention
        # if self.image_height is None or self.image_width is None:
        out_size = list(reversed(image_size))

        half_image_size_output = torch.tensor(out_size, dtype=torch.float) / 2.0
        half_min_image_size_output = half_image_size_output.min()

        # rescaled principal point and focal length in ndc
        principal_point = (
            half_image_size_output - principal_point_px * scale
        ) / half_min_image_size_output
        focal_length = focal_length_px * scale / half_min_image_size_output
        return PerspectiveCameras(
            focal_length=focal_length[None],
            principal_point=principal_point[None],
            R=torch.tensor(entry_viewpoint.R, dtype=torch.float)[None],
            T=torch.tensor(entry_viewpoint.T, dtype=torch.float)[None],
        )

    def _get_pytorch3d_camera_from_blender(self, R, T, K, image_size, scale: float) -> PerspectiveCameras:
        assert R is not None and T is not None and K is not None
        assert R.shape == (3, 3), f"Expected R to be 3x3, but got {R.shape}"
        assert T.shape == (3,), f"Expected T to be a 3-element vector, but got {T.shape}"
        assert K.shape == (3, 3), f"Expected K to be 3x3, but got {K.shape}"

        # Extract principal point and focal length from K
        fx = K[0, 0]
        fy = K[1, 1]
        cx = K[0, 2]
        cy = K[1, 2]

        principal_point = torch.tensor([cx, cy], dtype=torch.float)
        focal_length = torch.tensor([fx, fy], dtype=torch.float)

        half_image_size_wh_orig = (
                torch.tensor(list(reversed(image_size)), dtype=torch.float) / 2.0
        )

        # Adjust principal point and focal length in pixels
        principal_point_px = principal_point * scale
        focal_length_px = focal_length * scale

        # Convert from pixels to PyTorch3D NDC convention
        principal_point = (principal_point_px - half_image_size_wh_orig) / half_image_size_wh_orig
        half_min_image_size_output = half_image_size_wh_orig.min()
        focal_length = focal_length_px / half_min_image_size_output

        R = R.T @ np.array([[-1, 0, 0], [0, -1, 0], [0, 0, 1]], dtype=np.float64)
        T = T @ np.array([[-1, 0, 0], [0, -1, 0], [0, 0, 1]], dtype=np.float64)

        # Convert R and T to PyTorch tensors
        R_tensor = torch.tensor(R, dtype=torch.float).unsqueeze(0)  # Add batch dimension
        T_tensor = torch.tensor(T, dtype=torch.float).view(1, 3)  # Ensure T is a 1x3 tensor

        # Return PerspectiveCameras object
        return PerspectiveCameras(
            focal_length=focal_length.unsqueeze(0),  # Add batch dimension
            principal_point=principal_point.unsqueeze(0),  # Add batch dimension
            R=R_tensor,
            T=T_tensor,
        )

    def _get_output_tensor(self, sample):
        output_tensor = defaultdict(list)
        sample_size = len(sample["image"]["left"])
        output_tensor_keys = ["img", "disp", "valid_disp", "mask"]
        add_keys = ["viewpoint", "metadata"]
        for add_key in add_keys:
            if add_key in sample:
                output_tensor_keys.append(add_key)

        for key in output_tensor_keys:
            output_tensor[key] = [[] for _ in range(sample_size)]

        if "viewpoint" in sample:
            viewpoint_left = self._get_pytorch3d_camera(
                sample["viewpoint"]["left"][0],
                sample["metadata"]["left"][0][1],
                scale=1.0,
            )
            viewpoint_right = self._get_pytorch3d_camera(
                sample["viewpoint"]["right"][0],
                sample["metadata"]["right"][0][1],
                scale=1.0,
            )
            depth2disp_scale = depth2disparity_scale(
                viewpoint_left,
                viewpoint_right,
                torch.Tensor(sample["metadata"]["left"][0][1])[None],
            )
            output_tensor["depth2disp_scale"] = [depth2disp_scale]

        if "camera" in sample:
            output_tensor["viewpoint"] = [[] for _ in range(sample_size)]
            # InfinigenSV
            if sample["camera"]["left"][0][-3:] == "npz":
                # Note that K, R, T are based on the Blender world matrix
                camera_left = np.load(sample["camera"]["left"][0])
                camera_right = np.load(sample["camera"]["right"][0])
                camera_left_RT = camera_left['T']
                camera_right_RT = camera_right['T']
                camera_left_K = camera_left['K']
                camera_right_K = camera_right['K']
                camera_left_T = camera_left['T'][:3, 3]
                camera_left_R = camera_left['T'][:3, :3]
                fix_baseline = np.linalg.norm(camera_left_RT[:3, 3] - camera_right_RT[:3, 3])
                focal_length_px = camera_left_K[0][0]
                depth2disp_scale = focal_length_px * fix_baseline

            # Sintel
            elif sample["camera"]["left"][0][-3:] == "cam":
                TAG_FLOAT = 202021.25
                f = open(sample["camera"]["left"][0], 'rb')
                check = np.fromfile(f, dtype=np.float32, count=1)[0]
                assert check == TAG_FLOAT, ' cam_read:: Wrong tag in flow file (should be: {0}, is: {1}). Big-endian machine? '.format(
                    TAG_FLOAT, check)
                camera_left_K = np.fromfile(f, dtype='float64', count=9).reshape((3, 3))
                camera_left_RT = np.fromfile(f, dtype='float64', count=12).reshape((3, 4))
                fix_baseline = 0.1 # From the MPI Sintel dataset website, the baseline of the cameras = 10cm = 0.1m
                focal_length_px = camera_left_K[0][0]
                depth2disp_scale = focal_length_px * fix_baseline
                camera_left_T = camera_left_RT[:3, 3]
                camera_left_R = camera_left_RT[:3, :3]

            # Spring
            elif any(filename in path for path in sample["camera"]["left"] for filename in ["focaldistance.txt", "extrinsics.txt", "intrinsics.txt"]):
                for path in sample["camera"]["left"]:
                    if "intrinsics.txt" in path:
                        intrinsics_path = path
                    elif "extrinsics.txt" in path:
                        extrinsics_path = path

                fx, fy, cx, cy = np.loadtxt(intrinsics_path)[0]
                # Build the 3x3 intrinsic matrix
                camera_left_K = np.array([
                    [fx, 0, cx],
                    [0, fy, cy],
                    [0, 0, 1]
                ])
                focal_length_px = camera_left_K[0][0]
                fix_baseline = 0.065  # From the dataset website, the baseline of the cameras = 6.5cm = 0.065m
                depth2disp_scale = focal_length_px * fix_baseline
                camera_left_RT = np.loadtxt(extrinsics_path).reshape(-1, 4, 4)[0]
                camera_left_T = camera_left_RT[:3, 3]
                camera_left_R = camera_left_RT[:3, :3]

            # TartanAir
            elif sample["camera"]["left"][0][-13:] == "pose_left.txt":
                fx, fy, cx, cy = 320.0, 320.0, 320.0, 240.0
                # Build the 3x3 intrinsic matrix
                camera_left_K = np.array([
                    [fx, 0, cx],
                    [0, fy, cy],
                    [0, 0, 1]
                ])
                focal_length_px = camera_left_K[0][0]
                fix_baseline = 0.25
                depth2disp_scale = focal_length_px * fix_baseline
                camera_left_RT = self.load_tartanair_pose(sample["camera"]["left"][0], index=0)
                camera_left_T = camera_left_RT[:3, 3]
                camera_left_R = camera_left_RT[:3, :3]

            # KITTI Depth
            elif sample["camera"]["left"][0][-20:] == "calib_cam_to_cam.txt":
                calib_data = {}
                with open(sample["camera"]["left"][0], 'r') as f:
                    for line in f:
                        key, value = line.split(':', 1)
                        calib_data[key.strip()] = value.strip()

                P_key = 'P_rect_02'
                if P_key in calib_data:
                    P_values = np.array([float(x) for x in calib_data[P_key].split()])
                    P_matrix = P_values.reshape(3, 4)
                else:
                    raise KeyError(f"Projection matrix for camera not found in calibration data")
                focal_length_px = P_matrix[0, 0]

                T_key1 = 'T_02'
                T_key2 = 'T_03'
                if T_key1 in calib_data and T_key2 in calib_data:
                    T1 = np.array([float(x) for x in calib_data[T_key1].split()])
                    T2 = np.array([float(x) for x in calib_data[T_key2].split()])
                    baseline = np.linalg.norm(T1 - T2)
                else:
                    raise KeyError(f"Translation vectors for cameras not found in calibration data")

                R_key1 = 'R_rect_02'
                R_key2 = 'R_rect_03'
                if R_key1 in calib_data and R_key2 in calib_data:
                    R1 = np.array([float(x) for x in calib_data[R_key1].split()]).reshape(3, 3)
                    R2 = np.array([float(x) for x in calib_data[R_key2].split()]).reshape(3, 3)
                else:
                    raise KeyError(f"Rotation vectors for cameras not found in calibration data")

                depth2disp_scale = focal_length_px * baseline
                camera_left_K = P_matrix[:, :3]
                camera_left_T = T1
                camera_left_R = R1

            # VKITTI2
            elif sample["camera"]["left"][0][-13:] == "intrinsic.txt":
                baseline = 0.532725
                with open(sample["camera"]["left"][0], 'r') as f:
                    line = f.readlines()[1]
                    values = line.strip().split()
                    frame = int(values[0])
                    camera_id = int(values[1])
                    fx = float(values[2])
                    fy = float(values[3])
                    cx = float(values[4])
                    cy = float(values[5])
                    # Construct the intrinsic matrix
                    camera_left_K = torch.tensor([[fx, 0, cx],
                                      [0, fy, cy],
                                      [0, 0, 1]], dtype=torch.float32)
                depth2disp_scale = camera_left_K[0, 0] * baseline

                with open(sample["camera"]["left"][0].replace("intrinsic.txt", "extrinsic.txt"), 'r') as f:
                    line = f.readlines()[1]
                    values = line.strip().split()
                    frame = int(values[0])
                    camera_id = int(values[1])
                    # Extract rotation (3x3) and translation (3x1)
                    camera_left_R = np.array([
                        [float(values[2]), float(values[3]), float(values[4])],
                        [float(values[6]), float(values[7]), float(values[8])],
                        [float(values[10]), float(values[11]), float(values[12])]
                    ], dtype=np.float32)

                    camera_left_T = np.array([
                        float(values[5]),
                        float(values[9]),
                        float(values[13])
                    ], dtype=np.float32)

            # SouthKensington
            elif sample["camera"]["left"][0][-3:] == "txt":
                camera_left_K, frames = self.parse_txt_file(sample["camera"]["left"][0])
                fix_baseline = 0.12
                camera_left_R = frames[0]['R']
                camera_left_T = frames[0]['T']
                focal_length_px = camera_left_K[0][0]
                depth2disp_scale = focal_length_px * fix_baseline
            else:
                raise ValueError("Other format camera is not implemented")

            output_tensor["depth2disp_scale"] = [depth2disp_scale]
            output_tensor["RTK"] = [camera_left_R, camera_left_T, camera_left_K]

        for i in range(sample_size):
            for cam in ["left", "right"]:
                if "mask" in sample and cam in sample["mask"]:
                    mask = frame_utils.read_gen(sample["mask"][cam][i])
                    mask = np.array(mask) / 255.0
                    output_tensor["mask"][i].append(mask)

                if "viewpoint" in sample and cam in sample["viewpoint"]:
                    viewpoint = self._get_pytorch3d_camera(
                        sample["viewpoint"][cam][i],
                        sample["metadata"][cam][i][1],
                        scale=1.0,
                    )
                    output_tensor["viewpoint"][i].append(viewpoint)
                if "camera" in sample:
                    # InfinigenSV
                    if sample["camera"]["left"][0][-3:] == "npz" and cam in sample["camera"]:
                        # Note that K, R, T are based on the Blender world matrix
                        camera = np.load(sample["camera"][cam][i])
                        camera_K = camera['K']
                        camera_T = camera['T'][:3, 3]
                        camera_R = camera['T'][:3, :3]
                        viewpoint = self._get_pytorch3d_camera_from_blender(
                            camera_R, camera_T, camera_K,
                            sample["metadata"][cam][i][1],
                            scale=1.0,
                        )
                        output_tensor["viewpoint"][i].append(viewpoint)

                    # Sintel
                    elif sample["camera"]["left"][0][-3:] == "cam" and cam in sample["camera"]:
                        f = open(sample["camera"]["left"][0], 'rb')
                        check = np.fromfile(f, dtype=np.float32, count=1)[0]
                        assert check == TAG_FLOAT, ' cam_read:: Wrong tag in flow file (should be: {0}, is: {1}). Big-endian machine? '.format(
                            TAG_FLOAT, check)
                        camera_K = np.fromfile(f, dtype='float64', count=9).reshape((3, 3))
                        camera_RT = np.fromfile(f, dtype='float64', count=12).reshape((3, 4))
                        camera_T = camera_RT[:3, 3]
                        camera_R = camera_RT[:3, :3]
                        viewpoint = self._get_pytorch3d_camera_from_blender(
                            camera_R, camera_T, camera_K,
                            sample["metadata"][cam][i][1],
                            scale=1.0,
                        )
                        output_tensor["viewpoint"][i].append(viewpoint)

                    # TartanAir
                    elif sample["camera"]["left"][0][-13:] == "pose_left.txt":
                        fx, fy, cx, cy = 320.0, 320.0, 320.0, 240.0
                        # Build the 3x3 intrinsic matrix
                        camera_left_K = np.array([
                            [fx, 0, cx],
                            [0, fy, cy],
                            [0, 0, 1]
                        ])
                        focal_length_px = camera_left_K[0][0]
                        fix_baseline = 0.25
                        depth2disp_scale = focal_length_px * fix_baseline
                        camera_left_RT = self.load_tartanair_pose(sample["camera"]["left"][0], index=i)
                        camera_left_T = camera_left_RT[:3, 3]
                        camera_left_R = camera_left_RT[:3, :3]

                    # Spring
                    elif any(filename in path for path in sample["camera"]["left"] for filename
                             in ["focaldistance.txt", "extrinsics.txt", "intrinsics.txt"]) and cam in sample["camera"]:

                        for path in sample["camera"]["left"]:
                            if "intrinsics.txt" in path:
                                intrinsics_path = path
                            elif "extrinsics.txt" in path:
                                extrinsics_path = path

                        fx, fy, cx, cy = np.loadtxt(intrinsics_path)[0]
                        # Build the 3x3 intrinsic matrix
                        camera_K = np.array([
                            [fx, 0, cx],
                            [0, fy, cy],
                            [0, 0, 1]
                        ])
                        focal_length_px = camera_K[0][0]
                        fix_baseline = 0.065  # From the dataset website, the baseline of the cameras = 6.5cm = 0.065m
                        depth2disp_scale = focal_length_px * fix_baseline
                        camera_RT = np.loadtxt(extrinsics_path).reshape(-1, 4, 4)[i]
                        camera_T = camera_RT[:3, 3]
                        camera_R = camera_RT[:3, :3]
                        viewpoint = self._get_pytorch3d_camera_from_blender(
                            camera_R, camera_T, camera_K,
                            sample["metadata"][cam][i][1],
                            scale=1.0,
                        )
                        output_tensor["viewpoint"][i].append(viewpoint)

                    # KITTI Depth
                    elif sample["camera"]["left"][0][-20:] == "calib_cam_to_cam.txt":
                        calib_data = {}
                        with open(sample["camera"]["left"][0], 'r') as f:
                            for line in f:
                                key, value = line.split(':', 1)
                                calib_data[key.strip()] = value.strip()

                        P_key = 'P_rect_02'
                        if P_key in calib_data:
                            P_values = np.array([float(x) for x in calib_data[P_key].split()])
                            P_matrix = P_values.reshape(3, 4)
                        else:
                            raise KeyError(f"Projection matrix for camera not found in calibration data")
                        focal_length_px = P_matrix[0, 0]

                        T_key1 = 'T_02'
                        T_key2 = 'T_03'
                        if T_key1 in calib_data and T_key2 in calib_data:
                            T1 = np.array([float(x) for x in calib_data[T_key1].split()])
                            T2 = np.array([float(x) for x in calib_data[T_key2].split()])
                            baseline = np.linalg.norm(T1 - T2)
                        else:
                            raise KeyError(f"Translation vectors for cameras not found in calibration data")

                        R_key1 = 'R_rect_02'
                        R_key2 = 'R_rect_03'
                        if R_key1 in calib_data and R_key2 in calib_data:
                            R1 = np.array([float(x) for x in calib_data[R_key1].split()]).reshape(3, 3)
                            R2 = np.array([float(x) for x in calib_data[R_key2].split()]).reshape(3, 3)
                        else:
                            raise KeyError(f"Rotation vectors for cameras not found in calibration data")

                        depth2disp_scale = focal_length_px * baseline
                        camera_K = P_matrix[:, :3]
                        camera_T = T1
                        camera_R = R1
                        viewpoint = self._get_pytorch3d_camera_from_blender(
                            camera_R, camera_T, camera_K,
                            sample["metadata"][cam][i][1],
                            scale=1.0,
                        )
                        output_tensor["viewpoint"][i].append(viewpoint)

                    # VKITTI2
                    elif sample["camera"]["left"][0][-13:] == "intrinsic.txt":
                        with open(sample["camera"]["left"][0], 'r') as f:
                            line = f.readlines()[1+i]
                            values = line.strip().split()
                            frame = int(values[0])
                            camera_id = int(values[1])
                            fx = float(values[2])
                            fy = float(values[3])
                            cx = float(values[4])
                            cy = float(values[5])
                            # Construct the intrinsic matrix
                            camera_K = torch.tensor([[fx, 0, cx],
                                                          [0, fy, cy],
                                                          [0, 0, 1]], dtype=torch.float32)
                        with open(sample["camera"]["left"][0].replace("intrinsic.txt", "extrinsic.txt"), 'r') as f:
                            line = f.readlines()[1+i]
                            values = line.strip().split()
                            frame = int(values[0])
                            camera_id = int(values[1])
                            # Extract rotation (3x3) and translation (3x1)
                            camera_R = np.array([
                                [float(values[2]), float(values[3]), float(values[4])],
                                [float(values[6]), float(values[7]), float(values[8])],
                                [float(values[10]), float(values[11]), float(values[12])]
                            ], dtype=np.float32)

                            camera_T = np.array([
                                float(values[5]),
                                float(values[9]),
                                float(values[13])
                            ], dtype=np.float32)
                        viewpoint = self._get_pytorch3d_camera_from_blender(
                            camera_R, camera_T, camera_K,
                            sample["metadata"][cam][i][1],
                            scale=1.0,
                        )
                        output_tensor["viewpoint"][i].append(viewpoint)

                    # SouthKensington
                    elif sample["camera"]["left"][0][-3:] == "txt" and cam in sample["camera"]:
                        camera_left_K, frames = self.parse_txt_file(sample["camera"]["left"][0])

                        camera_K = camera_left_K
                        camera_R = frames[i]['R']
                        camera_T = frames[i]['T']
                        viewpoint = self._get_pytorch3d_camera_from_blender(
                            camera_R, camera_T, camera_K,
                            sample["metadata"][cam][i][1],
                            scale=1.0,
                        )
                        output_tensor["viewpoint"][i].append(viewpoint)

                if "metadata" in sample and cam in sample["metadata"]:
                    metadata = sample["metadata"][cam][i]
                    output_tensor["metadata"][i].append(metadata)

                if cam in sample["image"]:
                    img = frame_utils.read_gen(sample["image"][cam][i])
                    img = np.array(img).astype(np.uint8)

                    # grayscale images
                    if len(img.shape) == 2:
                        img = np.tile(img[..., None], (1, 1, 3))
                    else:
                        img = img[..., :3]
                    output_tensor["img"][i].append(img)

                if cam in sample["disparity"]:
                    disp = self.disparity_reader(sample["disparity"][cam][i])
                    if isinstance(disp, tuple):
                        disp, valid_disp = disp
                    else:
                        valid_disp = disp < 512
                    disp = np.array(disp).astype(np.float32)
                    disp = np.stack([-disp, np.zeros_like(disp)], axis=-1)
                    # disp = np.stack([disp, np.zeros_like(disp)], axis=-1)

                    output_tensor["disp"][i].append(disp)
                    output_tensor["valid_disp"][i].append(valid_disp)

                elif "depth" in sample and cam in sample["depth"]:
                    depth = self.depth_reader(sample["depth"][cam][i])
                    depth_mask = depth < self.depth_eps
                    depth[depth_mask] = self.depth_eps
                    disp = depth2disp_scale / depth
                    disp[depth_mask] = 0
                    valid_disp = (disp < 512) * (1 - depth_mask)
                    disp = np.array(disp).astype(np.float32)
                    disp = np.stack([-disp, np.zeros_like(disp)], axis=-1)
                    output_tensor["disp"][i].append(disp)
                    output_tensor["valid_disp"][i].append(valid_disp)

        return output_tensor

    def __getitem__(self, index):
        im_tensor = {"img"}
        sample = self.sample_list[index]
        if self.is_test:
            sample_size = len(sample["image"]["left"])
            im_tensor["img"] = [[] for _ in range(sample_size)]
            for i in range(sample_size):
                for cam in ["left", "right"]:
                    img = frame_utils.read_gen(sample["image"][cam][i])
                    img = np.array(img).astype(np.uint8)[..., :3]
                    img = torch.from_numpy(img).permute(2, 0, 1).float()
                    im_tensor["img"][i].append(img)
            im_tensor["img"] = torch.stack(im_tensor["img"])
            return im_tensor, self.extra_info[index]

        index = index % len(self.sample_list)
        try:
            output_tensor = self._get_output_tensor(sample)
        except Exception as e:
            logging.warning(f"Exception in loading sample {index}: {e}")
            index = np.random.randint(len(self.sample_list))
            logging.info(f"New index is {index}")
            sample = self.sample_list[index]
            output_tensor = self._get_output_tensor(sample)

        sample_size = len(sample["image"]["left"])
        if self.augmentor is not None:
            if self.sparse:
                output_tensor["img"], output_tensor["disp"], output_tensor["valid_disp"] = self.augmentor(
                    output_tensor["img"], output_tensor["disp"], output_tensor["valid_disp"]
                )
            else:
                output_tensor["img"], output_tensor["disp"] = self.augmentor(
                    output_tensor["img"], output_tensor["disp"]
                )
        for i in range(sample_size):
            for cam in (0, 1):
                if cam < len(output_tensor["img"][i]):
                    img = (
                        torch.from_numpy(output_tensor["img"][i][cam])
                        .permute(2, 0, 1)
                        .float()
                    )
                    if self.img_pad is not None:
                        padH, padW = self.img_pad
                        img = F.pad(img, [padW] * 2 + [padH] * 2)
                    output_tensor["img"][i][cam] = img

                if cam < len(output_tensor["disp"][i]):
                    disp = (
                        torch.from_numpy(output_tensor["disp"][i][cam])
                        .permute(2, 0, 1)
                        .float()
                    )

                    if self.sparse:
                        valid_disp = torch.from_numpy(
                            output_tensor["valid_disp"][i][cam]
                        )
                    else:
                        valid_disp = (
                            (disp[0].abs() < 512)
                            & (disp[1].abs() < 512)
                            & (disp[0].abs() != 0)
                        )
                    disp = disp[:1]

                    output_tensor["disp"][i][cam] = disp
                    output_tensor["valid_disp"][i][cam] = valid_disp.float()

                if "mask" in output_tensor and cam < len(output_tensor["mask"][i]):
                    mask = torch.from_numpy(output_tensor["mask"][i][cam]).float()
                    output_tensor["mask"][i][cam] = mask

                if "viewpoint" in output_tensor and cam < len(
                    output_tensor["viewpoint"][i]
                ):

                    viewpoint = output_tensor["viewpoint"][i][cam]
                    output_tensor["viewpoint"][i][cam] = viewpoint

        res = {}
        if "viewpoint" in output_tensor and self.split != "train":
            res["viewpoint"] = output_tensor["viewpoint"]
        if "metadata" in output_tensor and self.split != "train":
            res["metadata"] = output_tensor["metadata"]
        if "depth2disp_scale" in output_tensor and self.split != "train":
            res["depth2disp_scale"] = output_tensor["depth2disp_scale"]
        if "RTK" in output_tensor and self.split != "train":
            res["RTK"] = output_tensor["RTK"]

        for k, v in output_tensor.items():
            if k != "viewpoint" and k != "metadata" and k != "depth2disp_scale" and k != "RTK":
                for i in range(len(v)):
                    if len(v[i]) > 0:
                        v[i] = torch.stack(v[i])
                if len(v) > 0 and (len(v[0]) > 0):
                    res[k] = torch.stack(v)
        return res
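
    # A rough sketch (an assumption read off the stacking logic above, not a formal
    # spec) of what __getitem__ returns for a training sample:
    #   res["img"]        -> (T, 2, 3, H, W)  float tensor, left/right frames per time step
    #   res["disp"]       -> (T, C, 1, H, W)  float tensor, C = views with ground-truth disparity
    #   res["valid_disp"] -> (T, C, H, W)     float mask of valid disparity pixels
    # For non-train splits, "viewpoint", "metadata", "depth2disp_scale" and "RTK"
    # are passed through as plain Python lists rather than stacked tensors.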

    def __mul__(self, v):
        copy_of_self = copy.deepcopy(self)
        copy_of_self.sample_list = v * copy_of_self.sample_list
        copy_of_self.extra_info = v * copy_of_self.extra_info
        return copy_of_self

    def __len__(self):
        return len(self.sample_list)


class DynamicReplicaDataset(StereoSequenceDataset):
    def __init__(
        self,
        aug_params=None,
        root="./data/datasets/dynamic_replica_data",
        split="train",
        sample_len=-1,
        only_first_n_samples=-1,
    ):
        super(DynamicReplicaDataset, self).__init__(aug_params)
        self.root = root
        self.sample_len = sample_len
        self.split = split

        frame_annotations_file = f"frame_annotations_{split}.jgz"

        with gzip.open(
            osp.join(root, split, frame_annotations_file), "rt", encoding="utf8"
        ) as zipfile:
            frame_annots_list = load_dataclass(
                zipfile, List[DynamicReplicaFrameAnnotation]
            )
        seq_annot = defaultdict(lambda: defaultdict(list))
        for frame_annot in frame_annots_list:
            seq_annot[frame_annot.sequence_name][frame_annot.camera_name].append(
                frame_annot
            )
        for seq_name in seq_annot.keys():
            try:
                filenames = defaultdict(lambda: defaultdict(list))
                for cam in ["left", "right"]:
                    for framedata in seq_annot[seq_name][cam]:
                        im_path = osp.join(root, split, framedata.image.path)
                        depth_path = osp.join(root, split, framedata.depth.path)
                        mask_path = osp.join(root, split, framedata.mask.path)

                        assert os.path.isfile(im_path), im_path
                        if self.split == 'train':
                            assert os.path.isfile(depth_path), depth_path
                        assert os.path.isfile(mask_path), mask_path

                        filenames["image"][cam].append(im_path)
                        if os.path.isfile(depth_path):
                            filenames["depth"][cam].append(depth_path)
                        filenames["mask"][cam].append(mask_path)

                        filenames["viewpoint"][cam].append(framedata.viewpoint)
                        filenames["metadata"][cam].append(
                            [framedata.sequence_name, framedata.image.size]
                        )

                        for k in filenames.keys():
                            assert (
                                len(filenames[k][cam])
                                == len(filenames["image"][cam])
                                > 0
                            ), framedata.sequence_name

                seq_len = len(filenames["image"][cam])

                print("seq_len", seq_name, seq_len)
                if split == "train":
                    for ref_idx in range(0, seq_len, 3):
                        step = 1 if self.sample_len == 1 else np.random.randint(1, 6)
                        if ref_idx + step * self.sample_len < seq_len:
                            sample = defaultdict(lambda: defaultdict(list))
                            for cam in ["left", "right"]:
                                for idx in range(
                                    ref_idx, ref_idx + step * self.sample_len, step
                                ):
                                    for k in filenames.keys():
                                        if "mask" not in k:
                                            sample[k][cam].append(
                                                filenames[k][cam][idx]
                                            )

                            self.sample_list.append(sample)
                else:
                    step = self.sample_len if self.sample_len > 0 else seq_len
                    counter = 0
                    for ref_idx in range(0, seq_len, step):
                        sample = defaultdict(lambda: defaultdict(list))
                        for cam in ["left", "right"]:
                            for idx in range(ref_idx, ref_idx + step):
                                for k in filenames.keys():
                                    sample[k][cam].append(filenames[k][cam][idx])

                        self.sample_list.append(sample)
                        counter += 1
                        if only_first_n_samples > 0 and counter >= only_first_n_samples:
                            break
            except Exception as e:
                print(e)
                print("Skipping sequence", seq_name)

        assert len(self.sample_list) > 0, "No samples found"
        print(f"Added {len(self.sample_list)} from Dynamic Replica {split}")
        logging.info(f"Added {len(self.sample_list)} from Dynamic Replica {split}")


class InfinigenStereoVideoDataset(StereoSequenceDataset):
    def __init__(
        self,
        aug_params=None,
        root="./data/datasets/InfinigenStereo",
        split="train",
        sample_len=-1,
        only_first_n_samples=-1,
    ):
        super(InfinigenStereoVideoDataset, self).__init__(aug_params)
        self.root = root
        self.sample_len = sample_len
        self.split = split

        sequence = sorted(
            glob(osp.join(root, self.split, "*"))
        )
        for i in range(len(sequence)):
            sequence_name = os.path.basename(sequence[i])
            try:
                filenames = defaultdict(lambda: defaultdict(list))
                for cam in ["left", "right"]:
                    suffix = "camera_0/" if cam == "left" else "camera_1/"
                    im_path_list = sorted(glob(osp.join(sequence[i], "frames/Image/", suffix, "*.png")))
                    depth_path_list = sorted(glob(osp.join(sequence[i], "frames/Depth/", suffix, "*.npy")))
                    camera_list = sorted(glob(osp.join(sequence[i], "frames/camview/", suffix, "*.npz")))
                    for j in range(len(im_path_list)):
                        im_path = im_path_list[j]
                        depth_path = depth_path_list[j]
                        camera_path = camera_list[j]
                        assert os.path.isfile(im_path), im_path
                        assert os.path.isfile(depth_path), depth_path
                        filenames["image"][cam].append(im_path)
                        filenames["depth"][cam].append(depth_path)
                        filenames["camera"][cam].append(camera_path)
                        filenames["metadata"][cam].append([sequence_name , (720,1280)])

                        for k in filenames.keys():
                            assert (
                                    len(filenames[k][cam])
                                    == len(filenames["image"][cam])
                                    > 0
                            ), sequence_name
                seq_len = len(filenames["image"][cam])

                print("seq_len", sequence_name, seq_len)
                if self.split == "train":
                    for ref_idx in range(0, seq_len, 3):
                        step = 1 if self.sample_len == 1 else np.random.randint(1, 6)
                        if ref_idx + step * self.sample_len < seq_len:
                            sample = defaultdict(lambda: defaultdict(list))
                            for cam in ["left", "right"]:
                                for idx in range(
                                    ref_idx, ref_idx + step * self.sample_len, step
                                ):
                                    for k in filenames.keys():
                                        if "mask" not in k:
                                            sample[k][cam].append(
                                                filenames[k][cam][idx]
                                            )

                            self.sample_list.append(sample)
                else:
                    step = self.sample_len if (self.sample_len > 0) and (self.sample_len < seq_len) else seq_len
                    print("sample_step", step)
                    counter = 0
                    for ref_idx in range(0, seq_len, step):
                        sample = defaultdict(lambda: defaultdict(list))
                        for cam in ["left", "right"]:
                            for idx in range(ref_idx, ref_idx + step):
                                for k in filenames.keys():
                                    sample[k][cam].append(filenames[k][cam][idx])

                        self.sample_list.append(sample)
                        counter += 1
                        if only_first_n_samples > 0 and counter >= only_first_n_samples:
                            break
            except Exception as e:
                print(e)
                print("Skipping sequence", sequence_name)
        assert len(self.sample_list) > 0, "No samples found"
        print(f"Added {len(self.sample_list)} from Infinigen Stereo Video {split}")
        logging.info(f"Added {len(self.sample_list)} from Infinigen Stereo Video {split}")


class SouthKensingtonStereoVideoDataset(StereoSequenceDataset):
    def __init__(
        self,
        aug_params=None,
        root="./data/datasets/SouthKensington/data/",
        split="test",
        subroot="",
        sample_len=-1,
        only_first_n_samples=-1,
    ):
        super(SouthKensingtonStereoVideoDataset, self).__init__(aug_params)
        self.root = root
        self.split = split
        self.sample_len = sample_len

        sequence = sorted(
            glob(osp.join(root, "*"))
        )
        for i in range(len(sequence)):
            sequence_name = os.path.basename(sequence[i])
            try:
                filenames = defaultdict(lambda: defaultdict(list))
                for cam in ["left", "right"]:
                    im_path_list = sorted(glob(osp.join(sequence[i], "images", cam, "*.png")))
                    camera_path = glob(osp.join(sequence[i], "*.txt"))[0]

                    for j in range(len(im_path_list)):
                        im_path = im_path_list[j]
                        assert os.path.isfile(im_path), im_path
                        filenames["image"][cam].append(im_path)
                        filenames["camera"][cam].append(camera_path)
                        filenames["metadata"][cam].append([sequence_name , (720,1280)])

                        for k in filenames.keys():
                            assert (
                                    len(filenames[k][cam])
                                    == len(filenames["image"][cam])
                                    > 0
                            ), sequence_name
                seq_len = len(filenames["image"][cam])
                print("seq_len", sequence_name, seq_len)

                step = self.sample_len if (self.sample_len > 0) and (self.sample_len < seq_len) else seq_len
                print("sample_step", step)
                counter = 0
                for ref_idx in range(0, seq_len, step):
                    sample = defaultdict(lambda: defaultdict(list))
                    for cam in ["left", "right"]:
                        for idx in range(ref_idx, ref_idx + step):
                            for k in filenames.keys():
                                sample[k][cam].append(filenames[k][cam][idx])

                    self.sample_list.append(sample)
                    counter += 1
                    if only_first_n_samples > 0 and counter >= only_first_n_samples:
                        break
            except Exception as e:
                print(e)
                print("Skipping sequence", sequence_name)
        assert len(self.sample_list) > 0, "No samples found"
        print(f"Added {len(self.sample_list)} from SouthKensington Stereo Video")
        logging.info(f"Added {len(self.sample_list)} from SouthKensington Stereo Video")


class KITTIDepthDataset(StereoSequenceDataset):
    def __init__(
        self,
        aug_params=None,
        root="./data/datasets/",
        split="train",
        sample_len=-1,
        only_first_n_samples=-1,
    ):
        super().__init__(aug_params, sparse=True)
        # super(KITTIDepthDataset, self).__init__(aug_params)
        image_root = osp.join(root, "kitti_depth", "input")
        gt_root = osp.join(root, "kitti_depth", "gt_depth")
        self.sample_len = sample_len
        self.split = split
        # Following CODD: https://github.com/facebookresearch/CODD
        val_split = ['2011_10_03_drive_0042_sync']  # 1 scene
        test_split = ['2011_09_26_drive_0002_sync', '2011_09_26_drive_0005_sync',
                      '2011_09_26_drive_0013_sync', '2011_09_26_drive_0020_sync',
                      '2011_09_26_drive_0023_sync', '2011_09_26_drive_0036_sync',
                      '2011_09_26_drive_0079_sync', '2011_09_26_drive_0095_sync',
                      '2011_09_26_drive_0113_sync', '2011_09_28_drive_0037_sync',
                      '2011_09_29_drive_0026_sync', '2011_09_30_drive_0016_sync',
                      '2011_10_03_drive_0047_sync']  # 13 scenes

        sequence_root = sorted(glob(osp.join(gt_root, "*")))
        train_list = []
        val_list = []
        test_list = []
        for i in range(len(sequence_root)):
            sequence_name = os.path.basename(os.path.normpath(sequence_root[i]))
            if sequence_name in test_split:
                test_list.append(sequence_root[i])
            elif sequence_name in val_split:
                val_list.append(sequence_root[i])
            else:
                train_list.append(sequence_root[i])

        if self.split == "train":
            sequence_split = train_list
        elif self.split == "val":
            sequence_split = val_list
        elif self.split == "test":
            sequence_split = test_list
        else:
            raise ValueError("Wrong Split: ", self.split)

        for i in range(len(sequence_split)):
            sequence_name = os.path.basename(os.path.normpath(sequence_split[i]))
            sequence_camera = osp.join(image_root, sequence_name[:10], "calib_cam_to_cam.txt")
            try:
                filenames = defaultdict(lambda: defaultdict(list))
                for cam in ["left", "right"]:
                    suffix = "image_02/" if cam == "left" else "image_03/"
                    depth_path_list = sorted(
                        glob(osp.join(gt_root, sequence_name, "proj_depth", "groundtruth", suffix, "*.png")))
                    for j in range(len(depth_path_list)):
                        depth_path = depth_path_list[j]
                        assert os.path.isfile(depth_path), depth_path
                        filenames["depth"][cam].append(depth_path)

                        # find the corresponding images
                        im_name = os.path.basename(os.path.normpath(depth_path))
                        im_path = osp.join(image_root, sequence_name[:10], sequence_name, suffix, "data", im_name)
                        assert os.path.isfile(im_path), im_path
                        filenames["image"][cam].append(im_path)
                        filenames["camera"][cam].append(sequence_camera)
                        filenames["metadata"][cam].append([sequence_name, (370,1224)])
                        for k in filenames.keys():
                            assert (
                                    len(filenames[k][cam])
                                    == len(filenames["depth"][cam])
                                    > 0
                            ), sequence_name
                seq_len = len(filenames["image"][cam])
                print("seq_len", sequence_name, seq_len)
                if self.split == "train":
                    for ref_idx in range(0, seq_len, 3):
                        step = 1 if self.sample_len == 1 else np.random.randint(1, 6)
                        if ref_idx + step * self.sample_len < seq_len:
                            sample = defaultdict(lambda: defaultdict(list))
                            for cam in ["left", "right"]:
                                for idx in range(
                                        ref_idx, ref_idx + step * self.sample_len, step
                                ):
                                    for k in filenames.keys():
                                        if "mask" not in k:
                                            sample[k][cam].append(
                                                filenames[k][cam][idx]
                                            )
                            self.sample_list.append(sample)
                else:
                    step = self.sample_len if (self.sample_len > 0) and (self.sample_len < seq_len) else seq_len
                    print("sample_step", step)
                    counter = 0
                    for ref_idx in range(0, seq_len, step):
                        sample = defaultdict(lambda: defaultdict(list))
                        for cam in ["left", "right"]:
                            for idx in range(ref_idx, ref_idx + step):
                                for k in filenames.keys():
                                    sample[k][cam].append(filenames[k][cam][idx])

                        self.sample_list.append(sample)
                        counter += 1
                        if only_first_n_samples > 0 and counter >= only_first_n_samples:
                            break
            except Exception as e:
                print(e)
                print("Skipping sequence", sequence_name)
        assert len(self.sample_list) > 0, "No samples found"
        print(f"Added {len(self.sample_list)} from  KITTI Depth {split}")
        logging.info(f"Added {len(self.sample_list)} from KITTI Depth {split}")


def split_train_valid(path_list, valid_keywords):
    path_list_init = path_list
    for kw in valid_keywords:
        path_list = list(filter(lambda s: kw not in s, path_list))
    train_path_list = sorted(path_list)
    valid_path_list = sorted(list(set(path_list_init) - set(train_path_list)))
    return train_path_list, valid_path_list
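
# Minimal illustration of split_train_valid (paths are made-up examples):
#   split_train_valid(["scenes/a", "scenes/b_val"], valid_keywords=["val"])
#   -> (["scenes/a"], ["scenes/b_val"])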


class TartanAirDataset(StereoSequenceDataset):
    def __init__(
        self,
        aug_params=None,
        root="./data/datasets/TartanAir/",
        split="train",
        sample_len=-1,
        only_first_n_samples=-1,
    ):
        super().__init__(aug_params, sparse=False)
        self.sample_len = sample_len
        self.split = split

        # Each entry is (scene, motion, part)
        test_entries = [
            ("abandonedfactory", "Easy", "P002"),
            ("abandonedfactory", "Hard", "P002"),
            ("amusement", "Easy", "P007"),
            ("amusement", "Hard", "P007"),
            ("carwelding", "Hard", "P003"),
            ("endofworld", "Easy", "P006"),
            ("endofworld", "Hard", "P006"),
            ("gascola", "Easy", "P001"),
            ("gascola", "Hard", "P001"),
            ("hospital", "Hard", "P042"),
            ("office", "Easy", "P006"),
            ("office", "Hard", "P006"),
            ("office2", "Easy", "P004"),
            ("office2", "Hard", "P004"),
            ("oldtown", "Hard", "P006"),
            ("soulcity", "Easy", "P008"),
            ("soulcity", "Hard", "P008"),
        ]

        scene_root = sorted(glob(osp.join(root, "*")))

        sequence_root_list = []
        test_set = []
        train_set = []
        for i in range(len(scene_root)):
            sequence_root_list += sorted(glob(osp.join(scene_root[i], "Easy", "*"))) + sorted(glob(osp.join(scene_root[i], "Hard", "*")))

        for path in sequence_root_list:
            parts = path.split("/")
            if len(parts) < 5:
                continue  # skip malformed paths

            scene, motion, part = parts[-3], parts[-2], parts[-1]
            if (scene, motion, part) in test_entries:
                test_set.append(path)
            else:
                train_set.append(path)

        if self.split == "train":
            sequence_root_list = train_set
        elif self.split == "test":
            sequence_root_list = test_set
        else:
            raise KeyError(f"Wrong Split!")

        for i in range(len(sequence_root_list)):
            filenames = defaultdict(lambda: defaultdict(list))
            sequence_root = sequence_root_list[i]
            parts = os.path.normpath(sequence_root).split(os.sep)
            sequence_name = "_".join(parts[-3:])
            try:
                for cam in ['left', 'right']:
                    depth_path_list = sorted(glob(osp.join(sequence_root, "depth_left/", "*.npy")))
                    im_path_list = sorted(glob(osp.join(sequence_root, f"image_{cam}/", "*.png")))
                    pose_path = os.path.join(sequence_root, f"pose_{cam}.txt")
                    assert len(depth_path_list) == len(im_path_list), [len(depth_path_list), len(im_path_list)]
                    for j in range(len(depth_path_list)):
                        depth_path = depth_path_list[j]
                        assert os.path.isfile(depth_path), depth_path
                        filenames["depth"][cam].append(depth_path)
                        im_path = im_path_list[j]
                        assert os.path.isfile(im_path), im_path
                        filenames["image"][cam].append(im_path)
                        filenames["camera"][cam].append(pose_path)
                        filenames["metadata"][cam].append([sequence_name, (480,640)])
                        for k in filenames.keys():
                            assert (
                                    len(filenames[k][cam])
                                    == len(filenames["depth"][cam])
                                    > 0
                            ), sequence_name
                seq_len = len(filenames["image"][cam])
                print("seq_len", sequence_name, seq_len)

                if self.split == "train":
                    for ref_idx in range(0, seq_len, 3):
                        step = 1 if self.sample_len == 1 else np.random.randint(1, 6)
                        if ref_idx + step * self.sample_len < seq_len:
                            sample = defaultdict(lambda: defaultdict(list))
                            for cam in ["left", "right"]:
                                for idx in range(
                                        ref_idx, ref_idx + step * self.sample_len, step
                                ):
                                    for k in filenames.keys():
                                        if "mask" not in k:
                                            sample[k][cam].append(
                                                filenames[k][cam][idx]
                                            )
                            self.sample_list.append(sample)
                else:
                    step = self.sample_len if (self.sample_len > 0) and (self.sample_len < seq_len) else seq_len
                    print("sample_step", step)
                    counter = 0
                    for ref_idx in range(0, seq_len, step):
                        sample = defaultdict(lambda: defaultdict(list))
                        for cam in ["left", "right"]:
                            for idx in range(ref_idx, ref_idx + step):
                                for k in filenames.keys():
                                    sample[k][cam].append(filenames[k][cam][idx])

                        self.sample_list.append(sample)
                        counter += 1
                        if only_first_n_samples > 0 and counter >= only_first_n_samples:
                            break
            except Exception as e:
                print(e)
                print("Skipping sequence", sequence_name)
        assert len(self.sample_list) > 0, "No samples found"
        print(f"Added {len(self.sample_list)} from  TarTanAir  {split}")
        logging.info(f"Added {len(self.sample_list)} from TarTanAir {split}")


class VKITTI2Dataset(StereoSequenceDataset):
    def __init__(
        self,
        aug_params=None,
        root="./data/datasets/vkitti2/",
        split="train",
        sample_len=-1,
        only_first_n_samples=-1,
    ):
        super().__init__(aug_params, sparse=False)
        self.sample_len = sample_len
        self.split = split
        if self.split == "train":
            sequence_name_list = []
            scenes = ['Scene01', 'Scene02', 'Scene06', 'Scene18', 'Scene20']
            variations = ['15-deg-left', '15-deg-right', '30-deg-left', '30-deg-right',
                          'clone', 'fog', 'morning', 'overcast', 'rain', 'sunset']
            for scene in scenes:
                for variation in variations:
                    sequence_name = f"{scene}_{variation}"
                    sequence_name_list.append(sequence_name)
        elif self.split == "test":
            sequence_name_list = ["Scene01_15-deg-left", "Scene02_30-deg-right", "Scene06_fog", "Scene18_morning", "Scene20_rain"]
        else:
            raise KeyError(f"Wrong Split!")

        for i in range(len(sequence_name_list)):
            filenames = defaultdict(lambda: defaultdict(list))
            sequence_name = sequence_name_list[i]
            scene, variation = sequence_name.split("_")
            try:
                for cam in [('left', 0), ('right', 1)]:
                    depth_path_list = sorted(glob(osp.join(root, f"{scene}/{variation}/frames/depth/Camera_{cam[1]}/", "*.png")))
                    im_path_list = sorted(glob(osp.join(root, f"{scene}/{variation}/frames/rgb/Camera_{cam[1]}/", "*.jpg")))
                    intrinsic_path = os.path.join(root, f"{scene}/{variation}/intrinsic.txt")
                    assert len(depth_path_list) == len(im_path_list), [len(depth_path_list), len(im_path_list)]
                    for j in range(len(depth_path_list)):
                        depth_path = depth_path_list[j]
                        assert os.path.isfile(depth_path), depth_path
                        filenames["depth"][cam[0]].append(depth_path)
                        im_path = im_path_list[j]
                        assert os.path.isfile(im_path), im_path
                        filenames["image"][cam[0]].append(im_path)
                        filenames["camera"][cam[0]].append(intrinsic_path)
                        filenames["metadata"][cam[0]].append([sequence_name, (375,1242)])
                        for k in filenames.keys():
                            assert (
                                    len(filenames[k][cam[0]])
                                    == len(filenames["depth"][cam[0]])
                                    > 0
                            ), sequence_name
                seq_len = len(filenames["image"][cam[0]])
                print("seq_len", sequence_name, seq_len)

                if self.split == "train":
                    for ref_idx in range(0, seq_len, 3):
                        step = 1 if self.sample_len == 1 else np.random.randint(1, 6)
                        if ref_idx + step * self.sample_len < seq_len:
                            sample = defaultdict(lambda: defaultdict(list))
                            for cam in ["left", "right"]:
                                for idx in range(
                                        ref_idx, ref_idx + step * self.sample_len, step
                                ):
                                    for k in filenames.keys():
                                        if "mask" not in k:
                                            sample[k][cam].append(
                                                filenames[k][cam][idx]
                                            )
                            self.sample_list.append(sample)
                else:
                    step = self.sample_len if (self.sample_len > 0) and (self.sample_len < seq_len) else seq_len
                    print("sample_step", step)
                    counter = 0
                    for ref_idx in range(0, seq_len, step):
                        sample = defaultdict(lambda: defaultdict(list))
                        for cam in ["left", "right"]:
                            for idx in range(ref_idx, ref_idx + step):
                                for k in filenames.keys():
                                    sample[k][cam].append(filenames[k][cam][idx])

                        self.sample_list.append(sample)
                        counter += 1
                        if only_first_n_samples > 0 and counter >= only_first_n_samples:
                            break
            except Exception as e:
                print(e)
                print("Skipping sequence", sequence_name)
        assert len(self.sample_list) > 0, "No samples found"
        print(f"Added {len(self.sample_list)} from  VKITTI2  {split}")
        logging.info(f"Added {len(self.sample_list)} from VKITTI2 {split}")


class SequenceSpringDataset(StereoSequenceDataset):
    def __init__(
        self,
        aug_params=None,
        sample_len=-1,
        root="./data/datasets/Spring",
    ):
        super(SequenceSpringDataset, self).__init__(aug_params)
        self.split = "test"
        self.sample_len = sample_len
        original_length = len(self.sample_list)
        image_paths = defaultdict(list)
        disparity_paths = defaultdict(list)
        camera_paths = defaultdict(list)

        for cam in ["left", "right"]:
            image_paths[cam] = sorted(
                glob(osp.join(root, f"train_frame_{cam}/*"))
            )

        cam = "left"
        disparity_paths[cam] = sorted(
                glob(osp.join(root, f"train_disp1_{cam}/*"))
            )

        camera_paths[cam] = sorted(
                glob(osp.join(root, "train_cam_data/*"))
            )

        num_seq = len(image_paths["left"])
        # for each sequence
        for seq_idx in range(num_seq):
            sequence_name = os.path.basename(image_paths[cam][seq_idx])
            sample = defaultdict(lambda: defaultdict(list))
            for cam in ["left", "right"]:
                sample["image"][cam] = sorted(
                    glob(osp.join(image_paths[cam][seq_idx], f"frame_{cam}", "*.png"))
                )[:sample_len]
                # for _ in range(len(sample["image"][cam])):
                for _ in range(sample_len):
                    sample["metadata"][cam].append([sequence_name, (1080, 1920)])

            cam = "left"
            sample["disparity"][cam] = sorted(
                glob(osp.join(disparity_paths[cam][seq_idx], f"disp1_{cam}", "*.dsp5"))
            )[:sample_len]
            sample["camera"][cam] = sorted(
                glob(osp.join(camera_paths[cam][seq_idx], "cam_data", "*.txt"))
            )
            self.sample_list.append(sample)
            seq_len = len(sample["image"][cam])
            print("seq_len", sequence_name, seq_len)
        logging.info(
            f"Added {len(self.sample_list) - original_length} from Spring Dataset"
        )


class SequenceSceneFlowDataset(StereoSequenceDataset):
    def __init__(
        self,
        aug_params=None,
        root="./data/datasets",
        dstype="frames_cleanpass",
        sample_len=1,
        things_test=False,
        add_things=True,
        add_monkaa=True,
        add_driving=True,
    ):
        super(SequenceSceneFlowDataset, self).__init__(aug_params)
        self.root = root
        self.dstype = dstype
        self.sample_len = sample_len
        if things_test:
            self._add_things("TEST")
        else:
            if add_things:
                self._add_things("TRAIN")
            if add_monkaa:
                self._add_monkaa()
            if add_driving:
                self._add_driving()

    def _add_things(self, split="TRAIN"):
        """Add FlyingThings3D data"""

        original_length = len(self.sample_list)
        root = osp.join(self.root, "FlyingThings3D")
        image_paths = defaultdict(list)
        disparity_paths = defaultdict(list)

        for cam in ["left", "right"]:
            image_paths[cam] = sorted(
                glob(osp.join(root, self.dstype, split, f"*/*/{cam}/"))
            )
            disparity_paths[cam] = [
                path.replace(self.dstype, "disparity") for path in image_paths[cam]
            ]
        # Choose a random subset of 400 images for validation
        # state = np.random.get_state()
        # np.random.seed(1000)
        # val_idxs = set(np.random.permutation(len(image_paths["left"]))[:40])
        # np.random.set_state(state)
        # np.random.seed(0)
        num_seq = len(image_paths["left"])
        num = 0
        for seq_idx in range(num_seq):
            # if (split == "TEST" and seq_idx in val_idxs) or (
            #     split == "TRAIN" and not seq_idx in val_idxs
            # ):
            images, disparities = defaultdict(list), defaultdict(list)
            for cam in ["left", "right"]:
                images[cam] = sorted(
                    glob(osp.join(image_paths[cam][seq_idx], "*.png"))
                )
                disparities[cam] = sorted(
                    glob(osp.join(disparity_paths[cam][seq_idx], "*.pfm"))
                )
            num = num + len(images["left"])
            self._append_sample(images, disparities)
        print("Total FlyingThings frames:", num)
        assert len(self.sample_list) > 0, "No samples found"
        print(
            f"Added {len(self.sample_list) - original_length} from FlyingThings {self.dstype}"
        )
        logging.info(
            f"Added {len(self.sample_list) - original_length} from FlyingThings {self.dstype}"
        )

    def _add_monkaa(self):
        """Add FlyingThings3D data"""

        original_length = len(self.sample_list)
        root = osp.join(self.root, "Monkaa")
        image_paths = defaultdict(list)
        disparity_paths = defaultdict(list)

        for cam in ["left", "right"]:
            image_paths[cam] = sorted(glob(osp.join(root, self.dstype, f"*/{cam}/")))
            disparity_paths[cam] = [
                path.replace(self.dstype, "disparity") for path in image_paths[cam]
            ]

        num_seq = len(image_paths["left"])

        for seq_idx in range(num_seq):
            images, disparities = defaultdict(list), defaultdict(list)
            for cam in ["left", "right"]:
                images[cam] = sorted(glob(osp.join(image_paths[cam][seq_idx], "*.png")))
                disparities[cam] = sorted(
                    glob(osp.join(disparity_paths[cam][seq_idx], "*.pfm"))
                )

            self._append_sample(images, disparities)

        assert len(self.sample_list) > 0, "No samples found"
        print(
            f"Added {len(self.sample_list) - original_length} from Monkaa {self.dstype}"
        )
        logging.info(
            f"Added {len(self.sample_list) - original_length} from Monkaa {self.dstype}"
        )

    def _add_driving(self):
        """Add FlyingThings3D data"""

        original_length = len(self.sample_list)
        root = osp.join(self.root, "Driving")
        image_paths = defaultdict(list)
        disparity_paths = defaultdict(list)

        for cam in ["left", "right"]:
            image_paths[cam] = sorted(
                glob(osp.join(root, self.dstype, f"*/*/*/{cam}/"))
            )
            disparity_paths[cam] = [
                path.replace(self.dstype, "disparity") for path in image_paths[cam]
            ]

        num_seq = len(image_paths["left"])
        for seq_idx in range(num_seq):
            images, disparities = defaultdict(list), defaultdict(list)
            for cam in ["left", "right"]:
                images[cam] = sorted(glob(osp.join(image_paths[cam][seq_idx], "*.png")))
                disparities[cam] = sorted(
                    glob(osp.join(disparity_paths[cam][seq_idx], "*.pfm"))
                )

            self._append_sample(images, disparities)

        assert len(self.sample_list) > 0, "No samples found"
        print(
            f"Added {len(self.sample_list) - original_length} from Driving {self.dstype}"
        )
        logging.info(
            f"Added {len(self.sample_list) - original_length} from Driving {self.dstype}"
        )

    def _append_sample(self, images, disparities):
        """Append both the forward and the time-reversed window as training samples."""
        seq_len = len(images["left"])
        for ref_idx in range(0, seq_len - self.sample_len):
            sample = defaultdict(lambda: defaultdict(list))
            for cam in ["left", "right"]:
                for idx in range(ref_idx, ref_idx + self.sample_len):
                    sample["image"][cam].append(images[cam][idx])
                    sample["disparity"][cam].append(disparities[cam][idx])
            self.sample_list.append(sample)

            sample = defaultdict(lambda: defaultdict(list))
            for cam in ["left", "right"]:
                for idx in range(ref_idx, ref_idx + self.sample_len):
                    sample["image"][cam].append(images[cam][seq_len - idx - 1])
                    sample["disparity"][cam].append(disparities[cam][seq_len - idx - 1])
            self.sample_list.append(sample)


class SequenceSintelStereo(StereoSequenceDataset):
    def __init__(
        self,
        dstype="clean",
        aug_params=None,
        root="./data/datasets",
    ):
        super().__init__(
            aug_params, sparse=True, reader=frame_utils.readDispSintelStereo
        )
        self.dstype = dstype
        self.split = "test"
        original_length = len(self.sample_list)
        image_root = osp.join(root, "sintel_stereo", "training")
        image_paths = defaultdict(list)
        disparity_paths = defaultdict(list)
        camera_paths = defaultdict(list)

        for cam in ["left", "right"]:
            image_paths[cam] = sorted(
                glob(osp.join(image_root, f"{self.dstype}_{cam}/*"))
            )

        cam = "left"
        disparity_paths[cam] = [
            path.replace(f"{self.dstype}_{cam}", "disparities")
            for path in image_paths[cam]
        ]
        camera_paths[cam] = [
            path.replace(f"{self.dstype}_{cam}", "camdata_left")
            for path in image_paths[cam]
        ]

        num_seq = len(image_paths["left"])
        # for each sequence
        for seq_idx in range(num_seq):
            sequence_name = os.path.basename(image_paths[cam][seq_idx])
            sample = defaultdict(lambda: defaultdict(list))
            for cam in ["left", "right"]:
                sample["image"][cam] = sorted(
                    glob(osp.join(image_paths[cam][seq_idx], "*.png"))
                )
                for _ in range(len(sample["image"][cam])):
                    sample["metadata"][cam].append([sequence_name, (436, 1024)])

            cam = "left"
            sample["disparity"][cam] = sorted(
                glob(osp.join(disparity_paths[cam][seq_idx], "*.png"))
            )
            sample["camera"][cam] = sorted(
                glob(osp.join(camera_paths[cam][seq_idx], "*.cam"))
            )

            for im1, disp, camera in zip(sample["image"][cam], sample["disparity"][cam], sample["camera"][cam]):
                assert (
                    im1.split("/")[-1].split(".")[0]
                    == disp.split("/")[-1].split(".")[0]
                    == camera.split("/")[-1].split(".")[0]
                ), (im1.split("/")[-1].split(".")[0], disp.split("/")[-1].split(".")[0], camera.split("/")[-1].split(".")[0])
            self.sample_list.append(sample)

        logging.info(
            f"Added {len(self.sample_list) - original_length} from SintelStereo {self.dstype}"
        )


def fetch_dataloader(args):
    """Create the data loader for the corresponding training set"""

    aug_params = {
        "crop_size": args.image_size,
        "min_scale": args.spatial_scale[0],
        "max_scale": args.spatial_scale[1],
        "do_flip": False,
        "yjitter": not args.noyjitter,
    }
    if hasattr(args, "saturation_range") and args.saturation_range is not None:
        aug_params["saturation_range"] = args.saturation_range
    if hasattr(args, "img_gamma") and args.img_gamma is not None:
        aug_params["gamma"] = args.img_gamma
    if hasattr(args, "do_flip") and args.do_flip is not None:
        aug_params["do_flip"] = args.do_flip

    train_dataset = None

    add_monkaa = "monkaa" in args.train_datasets
    add_driving = "driving" in args.train_datasets
    add_things = "things" in args.train_datasets
    add_dynamic_replica = "dynamic_replica" in args.train_datasets
    add_infinigensv = "infinigen_stereovideo" in args.train_datasets
    add_kittidepth = "kitti_depth" in args.train_datasets
    add_vkitti2 = "vkitti2" in args.train_datasets
    add_tartanair = "tartanair" in args.train_datasets
    new_dataset = None

    if add_monkaa or add_driving or add_things:
        # clean_dataset = SequenceSceneFlowDataset(
        #     aug_params,
        #     dstype="frames_cleanpass",
        #     sample_len=args.sample_len,
        #     add_monkaa=add_monkaa,
        #     add_driving=add_driving,
        #     add_things=add_things,
        # )

        final_dataset = SequenceSceneFlowDataset(
            aug_params,
            dstype="frames_finalpass",
            sample_len=args.sample_len,
            add_monkaa=add_monkaa,
            add_driving=add_driving,
            add_things=add_things,
        )
        # new_dataset = clean_dataset + final_dataset

        new_dataset = final_dataset

    if add_dynamic_replica:
        dr_dataset = DynamicReplicaDataset(
            aug_params, split="train", sample_len=args.sample_len
        )
        if new_dataset is None:
            new_dataset = dr_dataset
        else:
            new_dataset = new_dataset + dr_dataset

    if add_infinigensv:
        infinigensv_dataset = InfinigenStereoVideoDataset(
            aug_params, split="train", sample_len=args.sample_len
        )
        if new_dataset is None:
            new_dataset = infinigensv_dataset
        else:
            # oversample InfinigenSV (added three times) when mixed with other datasets
            new_dataset = new_dataset + infinigensv_dataset + infinigensv_dataset + infinigensv_dataset

    if add_kittidepth:
        kittidepth_dataset = KITTIDepthDataset(
            aug_params, split="train", sample_len=args.sample_len
        )
        if new_dataset is None:
            new_dataset = kittidepth_dataset
        else:
            new_dataset = new_dataset + kittidepth_dataset

    if add_vkitti2:
        vkitti2_dataset = VKITTI2Dataset(
            aug_params, split="train", sample_len=args.sample_len
        )
        if new_dataset is None:
            new_dataset = vkitti2_dataset
        else:
            new_dataset = new_dataset + vkitti2_dataset

    if add_tartanair:
        tartanair_dataset = TartanAirDataset(
            aug_params, split="train", sample_len=args.sample_len
        )
        if new_dataset is None:
            new_dataset = tartanair_dataset
        else:
            new_dataset = new_dataset + tartanair_dataset

    logging.info(f"Adding {len(new_dataset)} samples in total")
    train_dataset = (
        new_dataset if train_dataset is None else train_dataset + new_dataset
    )

    train_loader = data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        pin_memory=True,
        shuffle=True,
        num_workers=args.num_workers,
        drop_last=True,
    )

    logging.info("Training with %d image pairs" % len(train_dataset))
    return train_loader
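

# --- Hedged usage sketch (not part of the original repository) ---
# A minimal example of the argument namespace fetch_dataloader() expects,
# assuming the SceneFlow "things" split is available under ./data/datasets.
# All values below are illustrative.
if __name__ == "__main__":
    from argparse import Namespace

    example_args = Namespace(
        image_size=[256, 512],      # crop size (H, W)
        spatial_scale=[-0.2, 0.4],  # min/max random spatial scaling
        noyjitter=False,            # keep y-jitter augmentation enabled
        train_datasets=["things"],  # any of: monkaa, driving, things, dynamic_replica, ...
        sample_len=5,               # frames per training clip
        batch_size=1,
        num_workers=4,
    )
    example_loader = fetch_dataloader(example_args)
    print(f"{len(example_loader)} batches per epoch")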


================================================
FILE: demo.py
================================================
import sys

import argparse
import os
import cv2
import glob
import numpy as np
import torch
import torch.nn.functional as F
from collections import defaultdict

from PIL import Image
from matplotlib import pyplot as plt
from pathlib import Path

DEVICE = 'cuda'


def load_image(imfile):
    img = np.array(Image.open(imfile).convert('RGB')).astype(np.uint8)
    img = torch.from_numpy(img).permute(2, 0, 1).float()
    return img.to(DEVICE)


def viz(img, flo):
    # NOTE: unused helper carried over from the RAFT demo; it requires a `flow_viz`
    # module (e.g. third_party/RAFT/core/utils/flow_viz) to be imported before use.
    img = img[0].permute(1, 2, 0).cpu().numpy()
    flo = flo[0].permute(1, 2, 0).cpu().numpy()

    # map flow to rgb image
    flo = flow_viz.flow_to_image(flo)
    img_flo = np.concatenate([img, flo], axis=0)

    cv2.imshow('image', img_flo[:, :, [2, 1, 0]] / 255.0)
    cv2.waitKey()


def demo(args):
    from stereoanyvideo.models.stereoanyvideo_model import StereoAnyVideoModel
    model = StereoAnyVideoModel()

    if args.ckpt is not None:
        assert args.ckpt.endswith(".pth") or args.ckpt.endswith(
            ".pt"
        )
        strict = True
        state_dict = torch.load(args.ckpt)
        if "model" in state_dict:
            state_dict = state_dict["model"]
        if list(state_dict.keys())[0].startswith("module."):
            state_dict = {
                k.replace("module.", ""): v for k, v in state_dict.items()
            }
        model.model.load_state_dict(state_dict, strict=strict)
        print("Done loading model checkpoint", args.ckpt)

    model.to(DEVICE)
    model.eval()

    output_directory = args.output_path
    parent_directory = os.path.dirname(output_directory)
    if not os.path.exists(parent_directory):
        os.makedirs(parent_directory)
    if not os.path.isdir(output_directory):
        os.mkdir(output_directory)

    with torch.no_grad():
        images_left = sorted(glob.glob(os.path.join(args.path, 'left/*.png')) + glob.glob(os.path.join(args.path, 'left/*.jpg')))
        images_right = sorted(glob.glob(os.path.join(args.path, 'right/*.png')) + glob.glob(os.path.join(args.path, 'right/*.jpg')))
        assert len(images_left) == len(images_right), [len(images_left), len(images_right)]
        assert len(images_left) > 0, args.path
        print(f"Found {len(images_left)} frames. Saving files to {args.output_path}")

        num_frames = len(images_left)
        frame_size = args.frame_size

        disparities_ori_all = []

        for start_idx in range(0, num_frames, frame_size):
            end_idx = min(start_idx + frame_size, num_frames)

            image_left_list = []
            image_right_list = []

            for imfile1, imfile2 in zip(images_left[start_idx:end_idx], images_right[start_idx:end_idx]):
                image_left = load_image(imfile1)
                image_right = load_image(imfile2)
                image_left = F.interpolate(image_left[None], size=args.resize, mode="bilinear", align_corners=True)
                image_right = F.interpolate(image_right[None], size=args.resize, mode="bilinear", align_corners=True)
                image_left_list.append(image_left[0])
                image_right_list.append(image_right[0])

            video_left = torch.stack(image_left_list, dim=0)
            video_right = torch.stack(image_right_list, dim=0)

            batch_dict = defaultdict(list)
            batch_dict["stereo_video"] = torch.stack([video_left, video_right], dim=1)

            predictions = model(batch_dict)

            assert "disparity" in predictions
            disparities = predictions["disparity"][:, :1].clone().data.cpu().abs().numpy()
            # quantize to uint8 so the maps can be colorized with cv2.applyColorMap
            disparities_ori = disparities.astype(np.uint8)
            disparities_ori_all.extend(disparities_ori)

        disparities_ori_all = np.array(disparities_ori_all)

        epsilon = 1e-5  # Smallest allowable disparity
        disparities_ori_all[disparities_ori_all < epsilon] = epsilon

        disparities_all = ((disparities_ori_all - disparities_ori_all.min()) / (disparities_ori_all.max() - disparities_ori_all.min()) * 255).astype(np.uint8)

        video_ori_disparity = cv2.VideoWriter(
            os.path.join(args.output_path, "disparity.mp4"),
            cv2.VideoWriter_fourcc(*"mp4v"),
            fps=args.fps,
            frameSize=(disparities_all.shape[3], disparities_all.shape[2]),
            isColor=True,
        )
        video_disparity = cv2.VideoWriter(
            os.path.join(args.output_path, "disparity_norm.mp4"),
            cv2.VideoWriter_fourcc(*"mp4v"),
            fps=args.fps,
            frameSize=(disparities_all.shape[3], disparities_all.shape[2]),
            isColor=True,
        )

        for i in range(num_frames):
            imfile1 = images_left[i]

            disparity_norm = disparities_all[i]
            disparity_norm = disparity_norm.transpose(1, 2, 0)
            disparity_norm_vis = cv2.applyColorMap(disparity_norm, cv2.COLORMAP_INFERNO)
            video_disparity.write(disparity_norm_vis)

            disparity_ori = disparities_ori_all[i]
            disparity_ori = disparity_ori.transpose(1, 2, 0)
            disparity_ori_vis = cv2.applyColorMap(disparity_ori, cv2.COLORMAP_INFERNO)
            video_ori_disparity.write(disparity_ori_vis)

            if args.save_png:
                filename_temp = args.output_path + '/disparity_norm_' + str(i).zfill(3) + '.png'
                cv2.imwrite(filename_temp, disparity_norm_vis)
                filename_temp = args.output_path + '/disparity_ori_' + str(i).zfill(3) + '.png'
                cv2.imwrite(filename_temp, disparity_ori_vis)

        video_ori_disparity.release()
        video_disparity.release()

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_name', default="stereoanyvideo", help="name to specify model")
    parser.add_argument('--ckpt', default=None, help="checkpoint of stereo model")
    parser.add_argument('--resize', default=(720, 1280), help="image size input to the model")
    parser.add_argument("--fps", type=int, default=30, help="frame rate for video visualization")
    parser.add_argument('--path', help="dataset for evaluation")
    parser.add_argument("--save_png", action="store_true")
    parser.add_argument("--frame_size", type=int, default=150, help="number of updates in each forward pass.")
    parser.add_argument("--iters",type=int, default=20, help="number of updates in each forward pass.")
    parser.add_argument("--kernel_size", type=int, default=20, help="number of frames in each forward pass.")
    parser.add_argument('--output_path', help="directory to save output", default="demo_output")
    args = parser.parse_args()

    demo(args)
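
# --- Hedged usage note (not part of the original repository) ---
# The demo expects a directory of synchronized stereo frames, e.g.:
#   demo_video/
#       left/000.png, 001.png, ...
#       right/000.png, 001.png, ...
# and can be launched as in demo.sh, for example:
#   python demo.py --ckpt ./checkpoints/StereoAnyVideo_MIX.pth \
#       --path ./demo_video/ --output_path ./demo_output/ --save_png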


================================================
FILE: demo.sh
================================================
#!/bin/bash

export PYTHONPATH=`(cd ../ && pwd)`:`pwd`:$PYTHONPATH

python demo.py --ckpt ./checkpoints/StereoAnyVideo_MIX.pth --path ./demo_video/ --output_path ./demo_output/ --save_png

================================================
FILE: evaluate_stereoanyvideo.sh
================================================
#!/bin/bash

export PYTHONPATH=`(cd ../ && pwd)`:`pwd`:$PYTHONPATH

# evaluate on [sintel, dynamicreplica, infinigensv, vkitti2] using sceneflow checkpoint

python ./evaluation/evaluate.py --config-name eval_sintel_final \
MODEL.model_name=StereoAnyVideoModel \
MODEL.StereoAnyVideoModel.model_weights=./checkpoints/StereoAnyVideo_SF.pth

python ./evaluation/evaluate.py --config-name eval_dynamic_replica \
MODEL.model_name=StereoAnyVideoModel \
MODEL.StereoAnyVideoModel.model_weights=./checkpoints/StereoAnyVideo_SF.pth

python ./evaluation/evaluate.py --config-name eval_infinigensv \
MODEL.model_name=StereoAnyVideoModel \
MODEL.StereoAnyVideoModel.model_weights=./checkpoints/StereoAnyVideo_SF.pth

python ./evaluation/evaluate.py --config-name eval_vkitti2 \
MODEL.model_name=StereoAnyVideoModel \
MODEL.StereoAnyVideoModel.model_weights=./checkpoints/StereoAnyVideo_SF.pth



# evaluate on [sintel, kittidepth, southkensingtonSV] using mixed checkpoint

python ./evaluation/evaluate.py --config-name eval_sintel_final \
MODEL.model_name=StereoAnyVideoModel \
MODEL.StereoAnyVideoModel.model_weights=./checkpoints/StereoAnyVideo_MIX.pth

python ./evaluation/evaluate.py --config-name eval_kittidepth \
MODEL.model_name=StereoAnyVideoModel \
MODEL.StereoAnyVideoModel.model_weights=./checkpoints/StereoAnyVideo_MIX.pth

python ./evaluation/evaluate.py --config-name eval_southkensington \
MODEL.model_name=StereoAnyVideoModel \
MODEL.StereoAnyVideoModel.model_weights=./checkpoints/StereoAnyVideo_MIX.pth

================================================
FILE: evaluation/configs/eval_dynamic_replica.yaml
================================================
defaults:
  - default_config_eval
visualize_interval: -1
exp_dir: ./outputs/stereoanyvideo_DynamicReplica
sample_len: 150
MODEL:
    model_name: StereoAnyVideoModel


================================================
FILE: evaluation/configs/eval_infinigensv.yaml
================================================
defaults:
  - default_config_eval
visualize_interval: -1
render_bin_size: 0
exp_dir: ./outputs/stereoanyvideo_InfinigenSV
sample_len: 150
dataset_name: infinigensv
MODEL:
    model_name: StereoAnyVideoModel


================================================
FILE: evaluation/configs/eval_kittidepth.yaml
================================================
defaults:
  - default_config_eval
visualize_interval: -1
render_bin_size: 0
exp_dir: ./outputs/stereoanyvideo_KITTIDepth
sample_len: 300
dataset_name: kitti_depth
MODEL:
    model_name: StereoAnyVideoModel


================================================
FILE: evaluation/configs/eval_sintel_clean.yaml
================================================
defaults:
  - default_config_eval
visualize_interval: -1
render_bin_size: 0
exp_dir: ./outputs/stereoanyvideo_sintel_clean
dataset_name: sintel
dstype: clean
MODEL:
    model_name: StereoAnyVideoModel


================================================
FILE: evaluation/configs/eval_sintel_final.yaml
================================================
defaults:
  - default_config_eval
visualize_interval: -1
render_bin_size: 0
exp_dir: ./outputs/stereoanyvideo_sintel_final
dataset_name: sintel
dstype: final
MODEL:
    model_name: StereoAnyVideoModel

================================================
FILE: evaluation/configs/eval_southkensington.yaml
================================================
defaults:
  - default_config_eval
visualize_interval: 1
exp_dir: ./outputs/stereoanyvideo_SouthKensingtonIndoor
sample_len: 300
dataset_name: southkensingtonsv
MODEL:
    model_name: StereoAnyVideoModel


================================================
FILE: evaluation/configs/eval_vkitti2.yaml
================================================
defaults:
  - default_config_eval
visualize_interval: -1
render_bin_size: 0
exp_dir: ./outputs/stereoanyvideo_VKITTI2
sample_len: 300
dataset_name: vkitti2
MODEL:
    model_name: StereoAnyVideoModel


================================================
FILE: evaluation/core/evaluator.py
================================================
import os
import numpy as np
import cv2
from collections import defaultdict
import torch.nn.functional as F
import torch
import matplotlib.pyplot as plt
from tqdm import tqdm
from omegaconf import DictConfig
from pytorch3d.implicitron.tools.config import Configurable
from stereoanyvideo.evaluation.utils.eval_utils import depth2disparity_scale, eval_batch
from stereoanyvideo.evaluation.utils.utils import (
    PerceptionPrediction,
    pretty_print_perception_metrics,
    visualize_batch,
)


def depth_to_colormap(depth, colormap='jet', eps=1e-3, scale_vmin=1.0):
    valid = (depth > eps) & (depth < 1e4)
    vmin = depth[valid].min() * scale_vmin
    vmax = depth[valid].max()
    if colormap=='jet':
        cmap = plt.cm.jet
    else:
        cmap = plt.cm.inferno
    norm = plt.Normalize(vmin=vmin, vmax=vmax)
    depth = cmap(norm(depth))
    depth[~valid] = 1
    return np.ascontiguousarray(depth[...,:3] * 255, dtype=np.uint8)


class Evaluator(Configurable):
    """
    A class defining the StereoAnyVideo evaluator.

    Args:
        eps: Minimum disparity value used when converting disparity to depth.
    """

    eps = 1e-5

    def setup_visualization(self, cfg: DictConfig) -> None:
        # Visualization
        self.visualize_interval = cfg.visualize_interval
        self.render_bin_size = cfg.render_bin_size
        self.exp_dir = cfg.exp_dir
        if self.visualize_interval > 0:
            self.visualize_dir = os.path.join(cfg.exp_dir, "visualisations")

    @torch.no_grad()
    def evaluate_sequence(
        self,
        model,
        model_stabilizer,
        test_dataloader: torch.utils.data.DataLoader,
        is_real_data: bool = False,
        step=None,
        writer=None,
        train_mode=False,
        interp_shape=None,
        exp_dir=None,
    ):
        model.eval()
        per_batch_eval_results = []

        if self.visualize_interval > 0:
            os.makedirs(self.visualize_dir, exist_ok=True)

        for batch_idx, sequence in enumerate(tqdm(test_dataloader)):
            batch_dict = defaultdict(list)
            batch_dict["stereo_video"] = sequence["img"]
            if not is_real_data:
                batch_dict["disparity"] = sequence["disp"][:, 0].abs()
                batch_dict["disparity_mask"] = sequence["valid_disp"][:, :1]

                if "mask" in sequence:
                    batch_dict["fg_mask"] = sequence["mask"][:, :1]
                else:
                    batch_dict["fg_mask"] = torch.ones_like(
                        batch_dict["disparity_mask"]
                    )

            elif interp_shape is not None:
                left_video = batch_dict["stereo_video"][:, 0]
                left_video = F.interpolate(
                    left_video, tuple(interp_shape), mode="bilinear"
                )
                right_video = batch_dict["stereo_video"][:, 1]
                right_video = F.interpolate(
                    right_video, tuple(interp_shape), mode="bilinear"
                )
                batch_dict["stereo_video"] = torch.stack([left_video, right_video], 1)

            if model_stabilizer is not None:
                predictions = model.forward_stabilizer(batch_dict, model_stabilizer)
            elif train_mode:
                predictions = model.forward_batch_test(batch_dict)
            else:
                predictions = model(batch_dict)

            assert "disparity" in predictions
            predictions["disparity"] = predictions["disparity"][:, :1].clone().cpu()
            if not is_real_data:
                predictions["disparity"] = predictions["disparity"] * (
                    batch_dict["disparity_mask"].round()
                )

                batch_eval_result, seq_length = eval_batch(batch_dict, predictions, sequence["depth2disp_scale"][0])
                per_batch_eval_results.append((batch_eval_result, seq_length))
                pretty_print_perception_metrics(batch_eval_result)

            if (self.visualize_interval > 0) and (
                batch_idx % self.visualize_interval == 0
            ):
                perception_prediction = PerceptionPrediction()

                pred_disp = predictions["disparity"]
                pred_disp[pred_disp < self.eps] = self.eps

                scale = sequence["depth2disp_scale"][0]
                perception_prediction.depth_map = (scale / pred_disp).cuda()

                perspective_cameras = []
                if "viewpoint" in sequence:
                    for cam in sequence["viewpoint"]:
                        perspective_cameras.append(cam[0])
                        perception_prediction.perspective_cameras = perspective_cameras

                if "stereo_original_video" in batch_dict:
                    batch_dict["stereo_video"] = batch_dict[
                        "stereo_original_video"
                    ].clone()

                for k, v in batch_dict.items():
                    if isinstance(v, torch.Tensor):
                        batch_dict[k] = v.cuda()

                visualize_batch(
                    batch_dict,
                    perception_prediction,
                    output_dir=self.visualize_dir,
                    sequence_name=sequence["metadata"][0][0][0],
                    step=step,
                    writer=writer,
                    render_bin_size=self.render_bin_size,
                )
                filename = os.path.join(self.visualize_dir, sequence["metadata"][0][0][0])
                if not os.path.isdir(filename):
                    os.mkdir(filename)
                disparity_list = pred_disp.data.cpu().numpy()
                depth_list = perception_prediction.depth_map.data.cpu().numpy()
                np.save(f"{filename}_depth_list.npy", depth_list)

                video_disparity = cv2.VideoWriter(
                    f"{filename}_disparity.mp4",
                    cv2.VideoWriter_fourcc(*"mp4v"),
                    fps=30,
                    frameSize=(
                        batch_dict["stereo_video"][:, 0][0].shape[2], batch_dict["stereo_video"][:, 0][0].shape[1]),
                    isColor=True,
                )

                disparity_vis = depth_to_colormap(disparity_list[:, 0], eps=self.eps, colormap='inferno')
                for i in range(disparity_list.shape[0]):
                    filename_temp = filename + '/disparity_' + str(i).zfill(3) + '.png'
                    disparity_vis[i] = cv2.cvtColor(disparity_vis[i], cv2.COLOR_RGB2BGR)
                    cv2.imwrite(filename_temp, disparity_vis[i])
                    video_disparity.write(disparity_vis[i])
                video_disparity.release()

        return per_batch_eval_results


================================================
FILE: evaluation/evaluate.py
================================================
import json
import os
from dataclasses import dataclass, field
from typing import Any, Dict, Optional

import hydra
import numpy as np

import torch
from omegaconf import OmegaConf

from stereoanyvideo.evaluation.utils.utils import aggregate_and_print_results

import stereoanyvideo.datasets.video_datasets as datasets

from stereoanyvideo.models.core.model_zoo import (
    get_all_model_default_configs,
    model_zoo,
)
from pytorch3d.implicitron.tools.config import get_default_args_field
from stereoanyvideo.evaluation.core.evaluator import Evaluator


@dataclass(eq=False)
class DefaultConfig:
    exp_dir: str = "./outputs"
    stabilizer_ckpt: Optional[str] = None

    # one of [sintel, dynamicreplica, things, kitti_depth, vkitti2, infinigensv, southkensingtonsv]
    dataset_name: str = "dynamicreplica"

    sample_len: int = -1
    dstype: Optional[str] = None
    # clean, final
    MODEL: Dict[str, Any] = field(
        default_factory=lambda: get_all_model_default_configs()
    )
    EVALUATOR: Dict[str, Any] = get_default_args_field(Evaluator)

    seed: int = 42
    gpu_idx: int = 0

    visualize_interval: int = 1  # Use 0 for no visualization

    render_bin_size: Optional[int] = None

    # Override hydra's working directory to current working dir,
    # also disable storing the .hydra logs:
    hydra: dict = field(
        default_factory=lambda: {
            "run": {"dir": "."},
            "output_subdir": None,
        }
    )


def run_eval(cfg: DefaultConfig):
    """
    Evaluates stereo video disparity and depth metrics of a specified model
    on a benchmark dataset.
    """
    # make the experiment directory
    os.makedirs(cfg.exp_dir, exist_ok=True)

    # dump the experiment config to the exp_dir
    cfg_file = os.path.join(cfg.exp_dir, "expconfig.yaml")
    with open(cfg_file, "w") as f:
        OmegaConf.save(config=cfg, f=f)

    torch.manual_seed(cfg.seed)
    np.random.seed(cfg.seed)
    evaluator = Evaluator(**cfg.EVALUATOR)

    model = model_zoo(**cfg.MODEL)
    model.cuda(0)
    evaluator.setup_visualization(cfg)

    if cfg.dataset_name == "dynamicreplica":
        test_dataloader = datasets.DynamicReplicaDataset(
            split="test", sample_len=cfg.sample_len, only_first_n_samples=1
        )
    elif cfg.dataset_name == "infinigensv":
        test_dataloader = datasets.InfinigenStereoVideoDataset(
            split="test", sample_len=cfg.sample_len, only_first_n_samples=1
        )
    elif cfg.dataset_name == "southkensingtonsv":
        test_dataloader = datasets.SouthKensingtonStereoVideoDataset(
            sample_len=cfg.sample_len, only_first_n_samples=1
        )
        evaluator.evaluate_sequence(
            model,
            None,
            test_dataloader,
            is_real_data=True,
            exp_dir=cfg.exp_dir
        )
        return

    elif cfg.dataset_name == "kitti_depth":
        test_dataloader = datasets.KITTIDepthDataset(
            split="test", sample_len=cfg.sample_len, only_first_n_samples=1
        )
    elif cfg.dataset_name == "vkitti2":
        test_dataloader = datasets.VKITTI2Dataset(
            split="test", sample_len=cfg.sample_len, only_first_n_samples=1
        )
    elif cfg.dataset_name == "sintel":
        test_dataloader = datasets.SequenceSintelStereo(dstype=cfg.dstype)
    elif cfg.dataset_name == "things":
        test_dataloader = datasets.SequenceSceneFlowDatasets(
            {},
            dstype=cfg.dstype,
            sample_len=cfg.sample_len,
            add_monkaa=False,
            add_driving=False,
            things_test=True,
        )

    evaluate_result = evaluator.evaluate_sequence(
        model,
        None,
        test_dataloader,
        is_real_data=False,
        exp_dir=cfg.exp_dir
    )

    aggregate_result = aggregate_and_print_results(evaluate_result)

    result_file = os.path.join(cfg.exp_dir, "result_eval.json")

    print(f"Dumping eval results to {result_file}.")
    with open(result_file, "w") as f:
        json.dump(aggregate_result, f)


cs = hydra.core.config_store.ConfigStore.instance()
cs.store(name="default_config_eval", node=DefaultConfig)


@hydra.main(config_path="./configs/", config_name="default_config_eval")
def evaluate(cfg: DefaultConfig) -> None:
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = str(cfg.gpu_idx)
    run_eval(cfg)


if __name__ == "__main__":
    evaluate()


================================================
FILE: evaluation/utils/eval_utils.py
================================================
from dataclasses import dataclass
from typing import Dict, Optional, Union
from stereoanyvideo.evaluation.utils.ssim import SSIM
import torch
import torch.nn.functional as F
import numpy as np
import math
import cv2
from pytorch3d.utils import opencv_from_cameras_projection
from stereoanyvideo.models.raft_model import RAFTModel


@dataclass(eq=True, frozen=True)
class PerceptionMetric:
    metric: str
    depth_scaling_norm: Optional[str] = None
    suffix: str = ""
    index: str = ""

    def __str__(self):
        return (
            self.metric
            + self.index
            + (
                ("_norm_" + self.depth_scaling_norm)
                if self.depth_scaling_norm is not None
                else ""
            )
            + self.suffix
        )


def compute_flow(seq, is_seq=True):
    raft = RAFTModel().cuda()
    raft.eval()
    if is_seq:
        t, c, h, w = seq.size()
        flows_forward = []
        for i in range(t-1):
            flow_forward = raft.forward_fullres(seq[i][None], seq[i+1][None], iters=20)
            flows_forward.append(flow_forward)
        flows_forward = torch.cat(flows_forward, dim=0)
        return flows_forward

    else:
        img1, img2 = seq
        flow_forward = raft.forward_fullres(img1, img2, iters=20)
        return flow_forward

def flow_warp(x, flow):
    if flow.size(3) != 2:  # [B, H, W, 2]
        flow = flow.permute(0, 2, 3, 1)
    if x.size()[-2:] != flow.size()[1:3]:
        raise ValueError(f'The spatial sizes of input ({x.size()[-2:]}) and '
                         f'flow ({flow.size()[1:3]}) are not the same.')
    _, _, h, w = x.size()
    # create mesh grid
    grid_y, grid_x = torch.meshgrid(torch.arange(0, h), torch.arange(0, w))
    grid = torch.stack((grid_x, grid_y), 2).type_as(x)  # (h, w, 2)
    grid.requires_grad = False

    grid_flow = grid + flow
    # scale grid_flow to [-1,1]
    grid_flow_x = 2.0 * grid_flow[:, :, :, 0] / max(w - 1, 1) - 1.0
    grid_flow_y = 2.0 * grid_flow[:, :, :, 1] / max(h - 1, 1) - 1.0
    grid_flow = torch.stack((grid_flow_x, grid_flow_y), dim=3)
    output = F.grid_sample(
        x,
        grid_flow,
        mode='bilinear',
        padding_mode='zeros',
        align_corners=True)
    return output
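
def _flow_warp_sanity_check():
    # Hedged sanity-check sketch (not part of the original repository): with zero
    # flow and align_corners=True, flow_warp() should reproduce its input exactly.
    x = torch.randn(1, 1, 8, 8)
    zero_flow = torch.zeros(1, 2, 8, 8)  # (B, 2, H, W); permuted to (B, H, W, 2) inside flow_warp
    warped = flow_warp(x, zero_flow)
    assert torch.allclose(warped, x, atol=1e-5)
    return warped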

def eval_endpoint_error_sequence(
    x: torch.Tensor,
    y: torch.Tensor,
    mask: torch.Tensor,
    crop: int = 0,
    mask_thr: float = 0.5,
    clamp_thr: float = 1e-5,
) -> Dict[str, torch.Tensor]:

    assert len(x.shape) == len(y.shape) == len(mask.shape) == 4, (
        x.shape,
        y.shape,
        mask.shape,
    )
    assert x.shape[0] == y.shape[0] == mask.shape[0], (x.shape, y.shape, mask.shape)

    # chuck out the border
    if crop > 0:
        if crop > min(y.shape[2:]) - crop:
            raise ValueError("Incorrect crop size.")
        y = y[:, :, crop:-crop, crop:-crop]
        x = x[:, :, crop:-crop, crop:-crop]
        mask = mask[:, :, crop:-crop, crop:-crop]

    y = y * (mask > mask_thr).float()
    x = x * (mask > mask_thr).float()
    y[torch.isnan(y)] = 0

    results = {}
    for epe_name in ("epe", "temp_epe"):
        if epe_name == "epe":
            endpoint_error = (mask * (x - y) ** 2).sum(dim=1).sqrt()
        elif epe_name == "temp_epe":
            delta_mask = mask[:-1] * mask[1:]
            # endpoint_error = (
            #     (delta_mask * ((x[:-1] - x[1:]) - (y[:-1] - y[1:])) ** 2)
            #     .sum(dim=1)
            #     .sqrt()
            # )
            endpoint_error = (
                (delta_mask * ((x[:-1] - x[1:]).abs() - (y[:-1] - y[1:]).abs()) ** 2)
                .sum(dim=1)
                .sqrt()
            )

        # epe_nonzero = endpoint_error != 0
        nonzero = torch.count_nonzero(endpoint_error)
        epe_mean = endpoint_error.sum() / torch.clamp(
            nonzero, clamp_thr
        )  # average error for all the sequence pixels
        epe_inv_accuracy_05px = (endpoint_error > 0.5).sum() / torch.clamp(
            nonzero, clamp_thr
        )
        epe_inv_accuracy_1px = (endpoint_error > 1).sum() / torch.clamp(
            nonzero, clamp_thr
        )
        epe_inv_accuracy_2px = (endpoint_error > 2).sum() / torch.clamp(
            nonzero, clamp_thr
        )
        epe_inv_accuracy_3px = (endpoint_error > 3).sum() / torch.clamp(
            nonzero, clamp_thr
        )

        results[f"{epe_name}_mean"] = epe_mean[None]
        results[f"{epe_name}_bad_0.5px"] = epe_inv_accuracy_05px[None] * 100
        results[f"{epe_name}_bad_1px"] = epe_inv_accuracy_1px[None] * 100
        results[f"{epe_name}_bad_2px"] = epe_inv_accuracy_2px[None] * 100
        results[f"{epe_name}_bad_3px"] = epe_inv_accuracy_3px[None] * 100
    return results


def eval_TCC_sequence(
    x: torch.Tensor,
    y: torch.Tensor,
    mask: torch.Tensor,
    crop: int = 0,
    mask_thr: float = 0.5,
) -> Dict[str, torch.Tensor]:

    assert len(x.shape) == len(y.shape) == len(mask.shape) == 4, (
        x.shape,
        y.shape,
        mask.shape,
    )
    assert x.shape[0] == y.shape[0] == mask.shape[0], (x.shape, y.shape, mask.shape)
    t, c, h, w = x.shape
    # chuck out the border
    if crop > 0:
        if crop > min(y.shape[2:]) - crop:
            raise ValueError("Incorrect crop size.")
        y = y[:, :, crop:-crop, crop:-crop]
        x = x[:, :, crop:-crop, crop:-crop]
        mask = mask[:, :, crop:-crop, crop:-crop]

    y = y * (mask > mask_thr).float()
    x = x * (mask > mask_thr).float()
    x[torch.isnan(x)] = 0
    y[torch.isnan(y)] = 0

    ssim_loss = SSIM(1.0, nonnegative_ssim=True)
    delta_mask = mask[:-1] * mask[1:]

    tcc = 0
    for i in range(t-1):
        tcc += ssim_loss((torch.abs(x[i][None] - x[i+1][None]) * delta_mask[i]).expand(-1, 3, -1, -1),
                          (torch.abs(y[i][None] - y[i+1][None]) * delta_mask[i]).expand(-1, 3, -1, -1))
    tcc = tcc / (t-1)

    return tcc

def eval_TCM_sequence(
    x: torch.Tensor,
    y: torch.Tensor,
    mask: torch.Tensor,
    crop: int = 0,
    mask_thr: float = 0.5,
) -> Dict[str, torch.Tensor]:

    assert len(x.shape) == len(y.shape) == len(mask.shape) == 4, (
        x.shape,
        y.shape,
        mask.shape,
    )
    assert x.shape[0] == y.shape[0] == mask.shape[0], (x.shape, y.shape, mask.shape)

    t, c, h, w = x.shape
    # chuck out the border
    if crop > 0:
        if crop > min(y.shape[2:]) - crop:
            raise ValueError("Incorrect crop size.")
        y = y[:, :, crop:-crop, crop:-crop]
        x = x[:, :, crop:-crop, crop:-crop]
        mask = mask[:, :, crop:-crop, crop:-crop]

    y = y * (mask > mask_thr).float()
    x = x * (mask > mask_thr).float()
    y[torch.isnan(y)] = 0

    ssim_loss = SSIM(1.0, nonnegative_ssim=True, size_average=False)
    delta_mask = mask[:-1] * mask[1:]

    tcm = 0
    for i in range(t-1):
        dmax = torch.max(y[i][None].view(1, -1), -1)[0].view(1, 1, 1, 1).expand(-1, 3, -1, -1)
        dmin = torch.min(y[i][None].view(1, -1), -1)[0].view(1, 1, 1, 1).expand(-1, 3, -1, -1)

        x_norm = (x[i][None].expand(-1, 3, -1, -1) - dmin) / (dmax - dmin) * 255.
        x2_norm = (x[i+1][None].expand(-1, 3, -1, -1) - dmin) / (dmax - dmin) * 255.
        x_flow = compute_flow([x_norm.cuda(), x2_norm.cuda()], is_seq=False).cpu()

        y_norm = (y[i][None].expand(-1, 3, -1, -1) - dmin) / (dmax - dmin) * 255.
        y2_norm = (y[i+1][None].expand(-1, 3, -1, -1) - dmin) / (dmax - dmin) * 255.
        y_flow = compute_flow([y_norm.cuda(), y2_norm.cuda()], is_seq=False).cpu()

        flow_mask = torch.sum(y_flow > 250, 1, keepdim=True) == 0

        mask = delta_mask[i][None] * flow_mask
        mask = mask.expand(-1, 3, -1, -1)
        if torch.sum(mask) > 0:
            tcm += torch.mean(ssim_loss(
                torch.cat((x_flow, torch.ones_like(x_flow[:, 0, None, ...])), 1) * mask,
                torch.cat((y_flow, torch.ones_like(x_flow[:, 0, None, ...])), 1) * mask)[:, :2])
    tcm = tcm / (t-1)

    return tcm


def eval_OPW_sequence(
    img: torch.Tensor,
    x: torch.Tensor,
    y: torch.Tensor,
    mask: torch.Tensor,
    crop: int = 0,
    mask_thr: float = 0.5,
    clamp_thr: float = 1e-5,
) -> Dict[str, torch.Tensor]:

    assert len(x.shape) == len(y.shape) == len(mask.shape) == 4, (
        x.shape,
        y.shape,
        mask.shape,
    ) # T, 1, H, W
    assert x.shape[0] == y.shape[0] == mask.shape[0], (x.shape, y.shape, mask.shape)

    t, c, h, w = img[:, 0].shape
    # chuck out the border
    if crop > 0:
        if crop > min(y.shape[2:]) - crop:
            raise ValueError("Incorrect crop size.")
        y = y[:, :, crop:-crop, crop:-crop]
        x = x[:, :, crop:-crop, crop:-crop]
        mask = mask[:, :, crop:-crop, crop:-crop]

    y = y * (mask > mask_thr).float()
    x = x * (mask > mask_thr).float()
    y[torch.isnan(y)] = 0
    delta_mask = mask[:-1] * mask[1:]
    depth_mask_30 = torch.sum(y > 30, 1, keepdim=True) == 0
    depth_mask_30 = depth_mask_30[:-1] * depth_mask_30[1:]
    depth_mask_50 = torch.sum(y > 50, 1, keepdim=True) == 0
    depth_mask_50 = depth_mask_50[:-1] * depth_mask_50[1:]
    depth_mask_100 = torch.sum(y > 100, 1, keepdim=True) == 0
    depth_mask_100 = depth_mask_100[:-1] * depth_mask_100[1:]

    flow = compute_flow(img[:, 0].cuda()).cpu()
    warped_disp = flow_warp(x[1:], flow)
    warped_img = flow_warp(img[:, 0][1:].float(), flow)

    flow_mask = torch.sum(flow > 250, 1, keepdim=True) == 0

    delta_mask = delta_mask * torch.exp(-50. * torch.sqrt(
        ((warped_img / 255. - img[:, 0][:-1].float() / 255.) ** 2).sum(dim=1, keepdim=True))) * flow_mask * (
                          warped_disp > 0) > 1e-2
    opw_err = torch.abs(warped_disp - x[:-1]) * delta_mask
    opw_err_30 = torch.abs(warped_disp - x[:-1]) * delta_mask * depth_mask_30
    opw_err_50 = torch.abs(warped_disp - x[:-1]) * delta_mask * depth_mask_50
    opw_err_100 = torch.abs(warped_disp - x[:-1]) * delta_mask * depth_mask_100

    opw = 0
    opw_30 = 0
    opw_50 = 0
    opw_100 = 0
    for i in range(t-1):
        if torch.sum(delta_mask[i]) > 0:
            opw += torch.sum(opw_err[i]) / torch.sum(delta_mask[i])
        if torch.sum(delta_mask[i] * depth_mask_30[i]) > 0:
            opw_30 += torch.sum(opw_err_30[i]) / torch.sum(delta_mask[i] * depth_mask_30[i])
        if torch.sum(delta_mask[i] * depth_mask_50[i]) > 0:
            opw_50 += torch.sum(opw_err_50[i]) / torch.sum(delta_mask[i] * depth_mask_50[i])
        if torch.sum(delta_mask[i] * depth_mask_100[i]) > 0:
            opw_100 += torch.sum(opw_err_100[i]) / torch.sum(delta_mask[i] * depth_mask_100[i])
    opw = opw / (t - 1)
    opw_30 = opw_30 / (t - 1)
    opw_50 = opw_50 / (t - 1)
    opw_100 = opw_100 / (t - 1)
    return opw, opw_30, opw_50, opw_100


def eval_RTC_sequence(
    img: torch.Tensor,
    x: torch.Tensor,
    y: torch.Tensor,
    mask: torch.Tensor,
    crop: int = 0,
    mask_thr: float = 0.5,
    clamp_thr: float = 1e-5,
) -> Dict[str, torch.Tensor]:

    assert len(x.shape) == len(y.shape) == len(mask.shape) == 4, (
        x.shape,
        y.shape,
        mask.shape,
    ) # T, 1, H, W
    assert x.shape[0] == y.shape[0] == mask.shape[0], (x.shape, y.shape, mask.shape)

    t, c, h, w = img[:, 0].shape
    # chuck out the border
    if crop > 0:
        if crop > min(y.shape[2:]) - crop:
            raise ValueError("Incorrect crop size.")
        y = y[:, :, crop:-crop, crop:-crop]
        x = x[:, :, crop:-crop, crop:-crop]
        mask = mask[:, :, crop:-crop, crop:-crop]

    y = y * (mask > mask_thr).float()
    x = x * (mask > mask_thr).float()
    y[torch.isnan(y)] = 0

    flow = compute_flow(img[:, 0].cuda()).cpu()
    delta_mask = mask[:-1] * mask[1:]

    warped_disp = flow_warp(x[1:], flow)
    warped_img = flow_warp(img[:, 0][1:], flow)

    flow_mask = torch.sum(flow > 250, 1, keepdim=True) == 0
    depth_mask = torch.sum(y > 30, 1, keepdim=True) == 0
    depth_mask = depth_mask[:-1] * depth_mask[1:]

    delta_mask = delta_mask * torch.exp(-50. * torch.sqrt(
        ((warped_img / 255. - img[:, 0][:-1] / 255.) ** 2).sum(dim=1, keepdim=True))) * flow_mask * (
                         warped_disp > 0) > 1e-2
    tau = 1.01

    x1 = x[:-1]  / warped_disp
    x2 = warped_disp / x[:-1]

    x1[torch.isinf(x1)] = -1e10
    x2[torch.isinf(x2)] = -1e10
    x = torch.max(torch.cat((x1, x2), 1), 1)[0] < tau

    rtc_err = x[:, None] * delta_mask
    rtc_err_30 = x[:, None] * delta_mask * depth_mask
    rtc = 0
    rtc_30 = 0
    for i in range(t-1):
        if torch.sum(delta_mask[i]) > 0:
            rtc += torch.sum(rtc_err[i]) / torch.sum(delta_mask[i])
        if torch.sum(delta_mask[i] * depth_mask[i]) > 0:
            rtc_30 += torch.sum(rtc_err_30[i]) / torch.sum(delta_mask[i] * depth_mask[i])
    rtc = rtc / (t-1)
    rtc_30 = rtc_30 / (t - 1)
    return rtc, rtc_30


def depth2disparity_scale(left_camera, right_camera, image_size_tensor):
    # opencv camera matrices
    (_, T1, K1), (_, T2, _) = [
        opencv_from_cameras_projection(
            f,
            image_size_tensor,
        )
        for f in (left_camera, right_camera)
    ]
    fix_baseline = T1[0][0] - T2[0][0]
    focal_length_px = K1[0][0][0]
    # following this https://github.com/princeton-vl/RAFT-Stereo#converting-disparity-to-depth
    return focal_length_px * fix_baseline
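
def _disparity_to_depth_example():
    # Hedged example (not part of the original repository): the scale returned by
    # depth2disparity_scale() converts disparity to depth as depth = scale / disparity,
    # i.e. depth = focal_length_px * baseline / disparity (see the RAFT-Stereo link above).
    focal_length_px, baseline = 720.0, 0.12  # illustrative values only
    scale = focal_length_px * baseline       # 86.4
    disparity_px = torch.tensor([36.0, 72.0])
    return scale / disparity_px              # tensor([2.4000, 1.2000])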


def depth_to_pcd(
    depth_map,
    img,
    focal_length,
    cx,
    cy,
    step: int = None,
    inv_extrinsic=None,
    mask=None,
    filter=False,
):
    __, w, __ = img.shape
    if step is None:
        step = int(w / 100)
    Z = depth_map[::step, ::step]
    colors = img[::step, ::step, :]

    Pixels_Y = torch.arange(Z.shape[0]).to(Z.device) * step
    Pixels_X = torch.arange(Z.shape[1]).to(Z.device) * step

    X = (Pixels_X[None, :] - cx) * Z / focal_length
    Y = (Pixels_Y[:, None] - cy) * Z / focal_length

    inds = Z > 0

    if mask is not None:
        inds = inds * (mask[::step, ::step] > 0)

    X = X[inds].reshape(-1)
    Y = Y[inds].reshape(-1)
    Z = Z[inds].reshape(-1)
    colors = colors[inds]
    pcd = torch.stack([X, Y, Z]).T

    if inv_extrinsic is not None:
        pcd_ext = torch.vstack([pcd.T, torch.ones((1, pcd.shape[0])).to(Z.device)])
        pcd = (inv_extrinsic @ pcd_ext)[:3, :].T

    if filter:
        pcd, filt_inds = filter_outliers(pcd)
        colors = colors[filt_inds]
    return pcd, colors
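
# Hedged note (not part of the original repository): depth_to_pcd() back-projects
# pixels with the standard pinhole model, X = (u - cx) * Z / f and Y = (v - cy) * Z / f.
# For example, a pixel 100 px right of the principal point at depth Z = 2 m with
# f = 500 px maps to X = 100 * 2 / 500 = 0.4 m.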


def filter_outliers(pcd, sigma=3):
    mean = pcd.mean(0)
    std = pcd.std(0)
    inds = ((pcd - mean).abs() < sigma * std)[:, 2]
    pcd = pcd[inds]
    return pcd, inds


def eval_batch(batch_dict, predictions, scale) -> Dict[str, Union[float, torch.Tensor]]:
    """
    Produce performance metrics for a single batch of perception
    predictions.
    Args:
        batch_dict: A dictionary holding the input stereo video, the ground-truth
            disparity and the validity masks.
        predictions: A dictionary holding the predicted disparity.
        scale: Factor converting disparity to depth (focal length times baseline).
    Returns:
        A tuple of (results dictionary with evaluation metrics, sequence length).
    """
    results = {}

    assert "disparity" in predictions
    mask_now = torch.ones_like(batch_dict["fg_mask"])

    mask_now = mask_now * batch_dict["disparity_mask"]

    eval_flow_traj_output = eval_endpoint_error_sequence(
        predictions["disparity"], batch_dict["disparity"], mask_now
    )
    for epe_name in ("epe", "temp_epe"):
        results[PerceptionMetric(f"disp_{epe_name}_mean")] = eval_flow_traj_output[
            f"{epe_name}_mean"
        ]

        results[PerceptionMetric(f"disp_{epe_name}_bad_3px")] = eval_flow_traj_output[
            f"{epe_name}_bad_3px"
        ]

        results[PerceptionMetric(f"disp_{epe_name}_bad_2px")] = eval_flow_traj_output[
            f"{epe_name}_bad_2px"
        ]

        results[PerceptionMetric(f"disp_{epe_name}_bad_1px")] = eval_flow_traj_output[
            f"{epe_name}_bad_1px"
        ]

        results[PerceptionMetric(f"disp_{epe_name}_bad_0.5px")] = eval_flow_traj_output[
            f"{epe_name}_bad_0.5px"
        ]
    if "endpoint_error_per_pixel" in eval_flow_traj_output:
        results["disp_endpoint_error_per_pixel"] = eval_flow_traj_output[
            "endpoint_error_per_pixel"
        ]

    # disparity to depth
    depth = scale / predictions["disparity"].clamp(min=1e-10)

    eval_TCC_output = eval_TCC_sequence(
        depth, scale / batch_dict["disparity"].clamp(min=1e-10), mask_now
    )
    results[PerceptionMetric("disp_TCC")] = eval_TCC_output[None]

    eval_TCM_output = eval_TCM_sequence(
        depth, scale / batch_dict["disparity"].clamp(min=1e-10), mask_now
    )
    results[PerceptionMetric("disp_TCM")] = eval_TCM_output[None]

    eval_OPW_output, eval_OPW_30_output, eval_OPW_50_output, eval_OPW_100_output = eval_OPW_sequence(
        batch_dict["stereo_video"], depth, scale / batch_dict["disparity"].clamp(min=1e-10), mask_now
    )
    results[PerceptionMetric("disp_OPW")] = eval_OPW_output[None]
    results[PerceptionMetric("disp_OPW_100")] = eval_OPW_100_output[None]
    results[PerceptionMetric("disp_OPW_50")] = eval_OPW_50_output[None]
    if eval_OPW_30_output > 0:
        results[PerceptionMetric("disp_OPW_30")] = eval_OPW_30_output[None]
    else:
        results[PerceptionMetric("disp_OPW_30")] = torch.tensor([0.0])

    eval_RTC_output, eval_RTC_30_output = eval_RTC_sequence(
        batch_dict["stereo_video"], depth, scale / batch_dict["disparity"].clamp(min=1e-10), mask_now
    )
    results[PerceptionMetric("disp_RTC")] = eval_RTC_output[None]
    if eval_RTC_30_output > 0:
        results[PerceptionMetric("disp_RTC_30")] = eval_RTC_30_output[None]
    else:
        results[PerceptionMetric("disp_RTC_30")] = torch.tensor([0.0])

    return (results, len(predictions["disparity"]))


================================================
FILE: evaluation/utils/ssim.py
================================================
# Copyright 2020 by Gongfan Fang, Zhejiang University.
# All rights reserved.
import warnings

import torch
import torch.nn.functional as F


def _fspecial_gauss_1d(size, sigma):
    r"""Create 1-D gauss kernel
    Args:
        size (int): the size of gauss kernel
        sigma (float): sigma of normal distribution
    Returns:
        torch.Tensor: 1D kernel (1 x 1 x size)
    """
    coords = torch.arange(size, dtype=torch.float)
    coords -= size // 2

    g = torch.exp(-(coords ** 2) / (2 * sigma ** 2))
    g /= g.sum()

    return g.unsqueeze(0).unsqueeze(0)


def gaussian_filter(input, win):
    r""" Blur input with 1-D kernel
    Args:
        input (torch.Tensor): a batch of tensors to be blurred
        win (torch.Tensor): 1-D gauss kernel
    Returns:
        torch.Tensor: blurred tensors
    """
    assert all([ws == 1 for ws in win.shape[1:-1]]), win.shape
    if len(input.shape) == 4:
        conv = F.conv2d
    elif len(input.shape) == 5:
        conv = F.conv3d
    else:
        raise NotImplementedError(input.shape)

    C = input.shape[1]
    out = input
    for i, s in enumerate(input.shape[2:]):
        if s >= win.shape[-1]:
            out = conv(out, weight=win.transpose(2 + i, -1), stride=1, padding=0, groups=C)
        else:
            warnings.warn(
                f"Skipping Gaussian Smoothing at dimension 2+{i} for input: {input.shape} and win size: {win.shape[-1]}"
            )
    return out


def _ssim(X, Y, data_range, win, size_average=True, K=(0.01, 0.03)):

    r""" Calculate ssim index for X and Y
    Args:
        X (torch.Tensor): images
        Y (torch.Tensor): images
        win (torch.Tensor): 1-D gauss kernel
        data_range (float or int, optional): value range of input images. (usually 1.0 or 255)
        size_average (bool, optional): if size_average=True, ssim of all images will be averaged as a scalar
    Returns:
        torch.Tensor: ssim results.
    """
    K1, K2 = K
    # batch, channel, [depth,] height, width = X.shape
    compensation = 1.0

    C1 = (K1 * data_range) ** 2
    C2 = (K2 * data_range) ** 2

    win = win.to(X.device, dtype=X.dtype)

    mu1 = gaussian_filter(X, win)
    mu2 = gaussian_filter(Y, win)

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1 * mu2

    sigma1_sq = compensation * (gaussian_filter(X * X, win) - mu1_sq)
    sigma2_sq = compensation * (gaussian_filter(Y * Y, win) - mu2_sq)
    sigma12 = compensation * (gaussian_filter(X * Y, win) - mu1_mu2)

    cs_map = (2 * sigma12 + C2) / (sigma1_sq + sigma2_sq + C2)  # set alpha=beta=gamma=1
    ssim_map = ((2 * mu1_mu2 + C1) / (mu1_sq + mu2_sq + C1)) * cs_map

    ssim_per_channel = torch.flatten(ssim_map, 2).mean(-1)
    cs = torch.flatten(cs_map, 2).mean(-1)
    return ssim_per_channel, cs


def ssim(
    X,
    Y,
    data_range=255,
    size_average=True,
    win_size=11,
    win_sigma=1.5,
    win=None,
    K=(0.01, 0.03),
    nonnegative_ssim=False,
):
    r""" interface of ssim
    Args:
        X (torch.Tensor): a batch of images, (N,C,H,W)
        Y (torch.Tensor): a batch of images, (N,C,H,W)
        data_range (float or int, optional): value range of input images. (usually 1.0 or 255)
        size_average (bool, optional): if size_average=True, ssim of all images will be averaged as a scalar
        win_size: (int, optional): the size of gauss kernel
        win_sigma: (float, optional): sigma of normal distribution
        win (torch.Tensor, optional): 1-D gauss kernel. if None, a new kernel will be created according to win_size and win_sigma
        K (list or tuple, optional): scalar constants (K1, K2). Try a larger K2 constant (e.g. 0.4) if you get negative or NaN results.
        nonnegative_ssim (bool, optional): force the ssim response to be nonnegative with relu
    Returns:
        torch.Tensor: ssim results
    """
    if not X.shape == Y.shape:
        raise ValueError(f"Input images should have the same dimensions, but got {X.shape} and {Y.shape}.")

    for d in range(len(X.shape) - 1, 1, -1):
        X = X.squeeze(dim=d)
        Y = Y.squeeze(dim=d)

    if len(X.shape) not in (4, 5):
        raise ValueError(f"Input images should be 4-d or 5-d tensors, but got {X.shape}")

    if not X.type() == Y.type():
        raise ValueError(f"Input images should have the same dtype, but got {X.type()} and {Y.type()}.")

    if win is not None:  # set win_size
        win_size = win.shape[-1]

    if not (win_size % 2 == 1):
        raise ValueError("Window size should be odd.")

    if win is None:
        win = _fspecial_gauss_1d(win_size, win_sigma)
        win = win.repeat([X.shape[1]] + [1] * (len(X.shape) - 1))

    ssim_per_channel, cs = _ssim(X, Y, data_range=data_range, win=win, size_average=False, K=K)
    if nonnegative_ssim:
        ssim_per_channel = torch.relu(ssim_per_channel)

    if size_average:
        return ssim_per_channel.mean()
    else:
        return ssim_per_channel #.mean(1)


def ms_ssim(
    X, Y, data_range=255, size_average=True, win_size=11, win_sigma=1.5, win=None, weights=None, K=(0.01, 0.03)
):

    r""" interface of ms-ssim
    Args:
        X (torch.Tensor): a batch of images, (N,C,[T,]H,W)
        Y (torch.Tensor): a batch of images, (N,C,[T,]H,W)
        data_range (float or int, optional): value range of input images. (usually 1.0 or 255)
        size_average (bool, optional): if size_average=True, ssim of all images will be averaged as a scalar
        win_size: (int, optional): the size of gauss kernel
        win_sigma: (float, optional): sigma of normal distribution
        win (torch.Tensor, optional): 1-D gauss kernel. if None, a new kernel will be created according to win_size and win_sigma
        weights (list, optional): weights for different levels
        K (list or tuple, optional): scalar constants (K1, K2). Try a larger K2 constant (e.g. 0.4) if you get negative or NaN results.
    Returns:
        torch.Tensor: ms-ssim results
    """
    if not X.shape == Y.shape:
        raise ValueError(f"Input images should have the same dimensions, but got {X.shape} and {Y.shape}.")

    for d in range(len(X.shape) - 1, 1, -1):
        X = X.squeeze(dim=d)
        Y = Y.squeeze(dim=d)

    if not X.type() == Y.type():
        raise ValueError(f"Input images should have the same dtype, but got {X.type()} and {Y.type()}.")

    if len(X.shape) == 4:
        avg_pool = F.avg_pool2d
    elif len(X.shape) == 5:
        avg_pool = F.avg_pool3d
    else:
        raise ValueError(f"Input images should be 4-d or 5-d tensors, but got {X.shape}")

    if win is not None:  # set win_size
        win_size = win.shape[-1]

    if not (win_size % 2 == 1):
        raise ValueError("Window size should be odd.")

    smaller_side = min(X.shape[-2:])
    assert smaller_side > (win_size - 1) * (
        2 ** 4
    ), "Image size should be larger than %d due to the 4 downsamplings in ms-ssim" % ((win_size - 1) * (2 ** 4))

    if weights is None:
        weights = [0.0448, 0.2856, 0.3001, 0.2363, 0.1333]
    weights = X.new_tensor(weights)

    if win is None:
        win = _fspecial_gauss_1d(win_size, win_sigma)
        win = win.repeat([X.shape[1]] + [1] * (len(X.shape) - 1))

    levels = weights.shape[0]
    mcs = []
    for i in range(levels):
        ssim_per_channel, cs = _ssim(X, Y, win=win, data_range=data_range, size_average=False, K=K)

        if i < levels - 1:
            mcs.append(torch.relu(cs))
            padding = [s % 2 for s in X.shape[2:]]
            X = avg_pool(X, kernel_size=2, padding=padding)
            Y = avg_pool(Y, kernel_size=2, padding=padding)

    ssim_per_channel = torch.relu(ssim_per_channel)  # (batch, channel)
    mcs_and_ssim = torch.stack(mcs + [ssim_per_channel], dim=0)  # (level, batch, channel)
    ms_ssim_val = torch.prod(mcs_and_ssim ** weights.view(-1, 1, 1), dim=0)

    if size_average:
        return ms_ssim_val.mean()
    else:
        return ms_ssim_val.mean(1)


class SSIM(torch.nn.Module):
    def __init__(
        self,
        data_range=255,
        size_average=True,
        win_size=11,
        win_sigma=1.5,
        channel=3,
        spatial_dims=2,
        K=(0.01, 0.03),
        nonnegative_ssim=False,
    ):
        r""" class for ssim
        Args:
            data_range (float or int, optional): value range of input images. (usually 1.0 or 255)
            size_average (bool, optional): if size_average=True, ssim of all images will be averaged as a scalar
            win_size: (int, optional): the size of gauss kernel
            win_sigma: (float, optional): sigma of normal distribution
            channel (int, optional): input channels (default: 3)
            K (list or tuple, optional): scalar constants (K1, K2). Try a larger K2 constant (e.g. 0.4) if you get negative or NaN results.
            nonnegative_ssim (bool, optional): force the ssim response to be nonnegative with relu.
        """

        super(SSIM, self).__init__()
        self.win_size = win_size
        self.win = _fspecial_gauss_1d(win_size, win_sigma).repeat([channel, 1] + [1] * spatial_dims)
        self.size_average = size_average
        self.data_range = data_range
        self.K = K
        self.nonnegative_ssim = nonnegative_ssim

    def forward(self, X, Y):
        return ssim(
            X,
            Y,
            data_range=self.data_range,
            size_average=self.size_average,
            win=self.win,
            K=self.K,
            nonnegative_ssim=self.nonnegative_ssim,
        )
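

def _ssim_module_example():
    # Hedged usage sketch (not part of the original repository): the SSIM module as
    # instantiated in evaluation/utils/eval_utils.py, i.e. data_range=1.0 with
    # nonnegative_ssim=True. Identical inputs yield an SSIM of 1.
    metric = SSIM(1.0, nonnegative_ssim=True)
    x = torch.rand(1, 3, 64, 64)
    return metric(x, x)  # ~1.0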


class MS_SSIM(torch.nn.Module):
    def __init__(
        self,
        data_range=255,
        size_average=True,
        win_size=11,
        win_sigma=1.5,
        channel=3,
        spatial_dims=2,
        weights=None,
        K=(0.01, 0.03),
    ):
        r""" class for ms-ssim
        Args:
            data_range (float or int, optional): value range of input images. (usually 1.0 or 255)
            size_average (bool, optional): if size_average=True, ssim of all images will be averaged as a scalar
            win_size: (int, optional): the size of gauss kernel
            win_sigma: (float, optional): sigma of normal distribution
            channel (int, optional): input channels (default: 3)
            weights (list, optional): weights for different levels
            K (list or tuple, optional): scalar constants (K1, K2). Try a larger K2 constant (e.g. 0.4) if you get negative or NaN results.
        """

        super(MS_SSIM, self).__init__()
        self.win_size = win_size
        self.win = _fspecial_gauss_1d(win_size, win_sigma).repeat([channel, 1] + [1] * spatial_dims)
        self.size_average = size_average
        self.data_range = data_range
        self.weights = weights
        self.K = K

    def forward(self, X, Y):
        return ms_ssim(
            X,
            Y,
            data_range=self.data_range,
            size_average=self.size_average,
            win=self.win,
            weights=self.weights,
            K=self.K,
        )


================================================
FILE: evaluation/utils/utils.py
================================================
from collections import defaultdict
import configparser
import os
import math
from typing import Optional, List
import torch
import cv2
import numpy as np
from dataclasses import dataclass
from tabulate import tabulate
import logging

from pytorch3d.structures import Pointclouds
from pytorch3d.transforms import RotateAxisAngle
from pytorch3d.utils import (
    opencv_from_cameras_projection,
)
from pytorch3d.renderer import (
    AlphaCompositor,
    PointsRasterizationSettings,
    PointsRasterizer,
    PointsRenderer,
)
from stereoanyvideo.evaluation.utils.eval_utils import depth_to_pcd


@dataclass
class PerceptionPrediction:
    """
    Holds the tensors that describe a result of any perception module.
    """

    depth_map: Optional[torch.Tensor] = None
    disparity: Optional[torch.Tensor] = None
    image_rgb: Optional[torch.Tensor] = None
    fg_probability: Optional[torch.Tensor] = None


def aggregate_eval_results(per_batch_eval_results, reduction="mean"):

    total_length = 0
    aggregate_results = defaultdict(list)
    for result in per_batch_eval_results:
        if isinstance(result, tuple):
            reduction = "sum"
            length = result[1]
            total_length += length
            result = result[0]
        for metric, val in result.items():
            if reduction == "sum":
                aggregate_results[metric].append(val * length)

    if reduction == "mean":
        return {k: torch.cat(v).mean().item() for k, v in aggregate_results.items()}
    elif reduction == "sum":
        return {
            k: torch.cat(v).sum().item() / float(total_length)
            for k, v in aggregate_results.items()
        }


def aggregate_and_print_results(
    per_batch_eval_results: List[dict],
):
    print("")
    result = aggregate_eval_results(
        per_batch_eval_results,
    )
    pretty_print_perception_metrics(result)
    result = {str(k): v for k, v in result.items()}

    print("")
    return result


def pretty_print_perception_metrics(results):

    metrics = sorted(list(results.keys()), key=lambda x: x.metric)

    print("===== Perception results =====")
    print(
        tabulate(
            [[metric, results[metric]] for metric in metrics],
        )
    )
    logging.info("===== Perception results =====")
    logging.info(tabulate(
            [[metric, results[metric]] for metric in metrics],
        ))


def read_calibration(calibration_file, resolution_string):
    # ported from https://github.com/stereolabs/zed-open-capture/
    # blob/dfa0aee51ccd2297782230a05ca59e697df496b2/examples/include/calibration.hpp#L4172

    zed_resolutions = {
        "2K": (1242, 2208),
        "FHD": (1080, 1920),
        "HD": (720, 1280),
        # "qHD": (540, 960),
        "VGA": (376, 672),
    }
    assert resolution_string in zed_resolutions.keys()
    image_height, image_width = zed_resolutions[resolution_string]

    # Open camera configuration file
    assert os.path.isfile(calibration_file)
    calib = configparser.ConfigParser()
    calib.read(calibration_file)

    # Get translations
    T = np.zeros((3, 1))
    T[0] = float(calib["STEREO"]["baseline"])
    T[1] = float(calib["STEREO"]["ty"])
    T[2] = float(calib["STEREO"]["tz"])

    baseline = T[0]

    # Get left parameters
    left_cam_cx = float(calib[f"LEFT_CAM_{resolution_string}"]["cx"])
    left_cam_cy = float(calib[f"LEFT_CAM_{resolution_string}"]["cy"])
    left_cam_fx = float(calib[f"LEFT_CAM_{resolution_string}"]["fx"])
    left_cam_fy = float(calib[f"LEFT_CAM_{resolution_string}"]["fy"])
    left_cam_k1 = float(calib[f"LEFT_CAM_{resolution_string}"]["k1"])
    left_cam_k2 = float(calib[f"LEFT_CAM_{resolution_string}"]["k2"])
    left_cam_p1 = float(calib[f"LEFT_CAM_{resolution_string}"]["p1"])
    left_cam_p2 = float(calib[f"LEFT_CAM_{resolution_string}"]["p2"])
    left_cam_k3 = float(calib[f"LEFT_CAM_{resolution_string}"]["k3"])

    # Get right parameters
    right_cam_cx = float(calib[f"RIGHT_CAM_{resolution_string}"]["cx"])
    right_cam_cy = float(calib[f"RIGHT_CAM_{resolution_string}"]["cy"])
    right_cam_fx = float(calib[f"RIGHT_CAM_{resolution_string}"]["fx"])
    right_cam_fy = float(calib[f"RIGHT_CAM_{resolution_string}"]["fy"])
    right_cam_k1 = float(calib[f"RIGHT_CAM_{resolution_string}"]["k1"])
    right_cam_k2 = float(calib[f"RIGHT_CAM_{resolution_string}"]["k2"])
    right_cam_p1 = float(calib[f"RIGHT_CAM_{resolution_string}"]["p1"])
    right_cam_p2 = float(calib[f"RIGHT_CAM_{resolution_string}"]["p2"])
    right_cam_k3 = float(calib[f"RIGHT_CAM_{resolution_string}"]["k3"])

    # Get rotations
    R_zed = np.zeros(3)
    R_zed[0] = float(calib["STEREO"][f"rx_{resolution_string.lower()}"])
    R_zed[1] = float(calib["STEREO"][f"cv_{resolution_string.lower()}"])
    R_zed[2] = float(calib["STEREO"][f"rz_{resolution_string.lower()}"])

    R = cv2.Rodrigues(R_zed)[0]

    # Left
    cameraMatrix_left = np.array(
        [[left_cam_fx, 0, left_cam_cx], [0, left_cam_fy, left_cam_cy], [0, 0, 1]]
    )
    distCoeffs_left = np.array(
        [left_cam_k1, left_cam_k2, left_cam_p1, left_cam_p2, left_cam_k3]
    )

    # Right
    cameraMatrix_right = np.array(
        [
            [right_cam_fx, 0, right_cam_cx],
            [0, right_cam_fy, right_cam_cy],
            [0, 0, 1],
        ]
    )
    distCoeffs_right = np.array(
        [right_cam_k1, right_cam_k2, right_cam_p1, right_cam_p2, right_cam_k3]
    )

    # Stereo
    R1, R2, P1, P2, Q = cv2.stereoRectify(
        cameraMatrix1=cameraMatrix_left,
        distCoeffs1=distCoeffs_left,
        cameraMatrix2=cameraMatrix_right,
        distCoeffs2=distCoeffs_right,
        imageSize=(image_width, image_height),
        R=R,
        T=T,
        flags=cv2.CALIB_ZERO_DISPARITY,
        newImageSize=(image_width, image_height),
        alpha=0,
    )[:5]

    # Precompute maps for cv::remap()
    map_left_x, map_left_y = cv2.initUndistortRectifyMap(
        cameraMatrix_left,
        distCoeffs_left,
        R1,
        P1,
        (image_width, image_height),
        cv2.CV_32FC1,
    )
    map_right_x, map_right_y = cv2.initUndistortRectifyMap(
        cameraMatrix_right,
        distCoeffs_right,
        R2,
        P2,
        (image_width, image_height),
        cv2.CV_32FC1,
    )

    zed_calib = {
        "map_left_x": map_left_x,
        "map_left_y": map_left_y,
        "map_right_x": map_right_x,
        "map_right_y": map_right_y,
        "pose_left": P1,
        "pose_right": P2,
        "baseline": baseline,
        "image_width": image_width,
        "image_height": image_height,
    }

    return zed_calib
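

# Minimal usage sketch for the calibration helper above. The .conf path and the
# "HD" resolution string are placeholder assumptions; left_raw/right_raw are
# unrectified frames captured at that resolution.
def _example_rectify_stereo_pair(left_raw, right_raw, conf_path="zed_calibration.conf"):
    calib = read_calibration(conf_path, "HD")
    left_rect = cv2.remap(
        left_raw, calib["map_left_x"], calib["map_left_y"], cv2.INTER_LINEAR
    )
    right_rect = cv2.remap(
        right_raw, calib["map_right_x"], calib["map_right_y"], cv2.INTER_LINEAR
    )
    # With rectified images, depth follows from disparity as fx * baseline / disparity,
    # where fx = calib["pose_left"][0, 0].
    return left_rect, right_rect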


def filter_depth_discontinuities(pcd, depth_map, threshold=5):
    """
    Removes points that belong to high-depth discontinuity regions.

    Args:
        pcd (torch.Tensor): Nx3 point cloud tensor.
        depth_map (torch.Tensor): HxW depth map.
        threshold (float): Depth change threshold.

    Returns:
        torch.Tensor: Filtered point cloud.
    """
    # Compute depth differences in x and y directions
    depth_diff_x = torch.abs(depth_map[:, 1:] - depth_map[:, :-1])  # Shape (H, W-1)
    depth_diff_y = torch.abs(depth_map[1:, :] - depth_map[:-1, :])  # Shape (H-1, W)

    # Initialize mask with all True (valid points)
    mask = torch.ones_like(depth_map, dtype=torch.bool)  # Shape (H, W)

    # Apply filtering: set False where depth difference is too large
    mask[:, :-1] &= depth_diff_x <= threshold  # X-direction filtering
    mask[:-1, :] &= depth_diff_y <= threshold  # Y-direction filtering

    # Flatten mask to match point cloud size
    mask_flat = mask.flatten()[: pcd.shape[0]]

    return pcd[mask_flat]  # Return only valid points
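

# Quick sanity check for the discontinuity filter above on a toy 4x4 depth map with
# a sharp jump between columns 1 and 2 (shapes follow the docstring: an HxW depth
# map and an (H*W)x3 point cloud in row-major pixel order).
def _example_filter_depth_discontinuities():
    depth = torch.ones(4, 4)
    depth[:, 2:] = 20.0
    pcd = torch.rand(16, 3)
    kept = filter_depth_discontinuities(pcd, depth, threshold=5)
    return kept.shape[0]  # 12: the 4 points bordering the jump are removed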



def visualize_batch(
    batch_dict: dict,
    preds: PerceptionPrediction,
    output_dir: str,
    ref_frame: int = 0,
    only_foreground=False,
    step=0,
    sequence_name=None,
    writer=None,
    render_bin_size=None
):
    os.makedirs(output_dir, exist_ok=True)

    outputs = {}

    if preds.depth_map is not None:
        device = preds.depth_map.device

        pcd_global_seq = []
        H, W = batch_dict["stereo_video"].shape[3:]

        for i in range(len(batch_dict["stereo_video"])):
            if hasattr(preds, 'perspective_cameras'):
                R, T, K = opencv_from_cameras_projection(
                    preds.perspective_cameras[i],
                    torch.tensor([H, W])[None].to(device),
                )  # 1x3x3, 1x3, 1x3x3
            else:
                raise KeyError("R, T, K not found: preds has no perspective_cameras attribute")
            extrinsic_3x4_0 = torch.cat([R[0], T[0, :, None]], dim=1)
            extr_matrix = torch.cat(
                [
                    extrinsic_3x4_0,
                    torch.Tensor([[0, 0, 0, 1]]).to(extrinsic_3x4_0.device),
                ],
                dim=0,
            )
            inv_extr_matrix = extr_matrix.inverse().to(device)
            pcd, colors = depth_to_pcd(
                preds.depth_map[i, 0],
                batch_dict["stereo_video"][i][0].permute(1, 2, 0),
                K[0][0][0],
                K[0][0][2],
                K[0][1][2],
                step=1,
                inv_extrinsic=inv_extr_matrix,
                mask=batch_dict["fg_mask"][i, 0] if only_foreground else None,
                filter=False,
            )
            R, T = inv_extr_matrix[None, :3, :3], inv_extr_matrix[None, :3, 3]
            pcd_global_seq.append((pcd, colors, (R, T, preds.perspective_cameras[i])))

        raster_settings = PointsRasterizationSettings(
            image_size=[H, W],
            radius=0.003,
            points_per_pixel=10,
        )
        R, T, cam_ = pcd_global_seq[ref_frame][2]
        median_depth = preds.depth_map.median()
        cam_.cuda()

        for mode in ["angle_15", "angle_-15", "angle_0", "changing_angle"]:
            res = []
            for t, (pcd, color, __) in enumerate(pcd_global_seq):

                if mode == "changing_angle":
                    angle = math.cos((math.pi) * (t / 60)) * 15
                elif mode == "angle_15":
                    angle = 15
                elif mode == "angle_-15":
                    angle = -15
                elif mode == "angle_0":
                    angle = 0

                delta_x = median_depth * math.sin(math.radians(angle))
                delta_z = median_depth * (1 - math.cos(math.radians(angle)))

                cam = cam_.clone()
                cam.R = torch.bmm(
                    cam.R,
                    RotateAxisAngle(angle=angle, axis="Y", device=device).get_matrix()[
                        :, :3, :3
                    ],
                )
                cam.T[0, 0] = cam.T[0, 0] - delta_x
                cam.T[0, 2] = cam.T[0, 2] - delta_z + median_depth / 2.0

                rasterizer = PointsRasterizer(
                    cameras=cam, raster_settings=raster_settings
                )
                renderer = PointsRenderer(
                    rasterizer=rasterizer,
                    compositor=AlphaCompositor(background_color=(1, 1, 1)),
                )
                pcd_copy = pcd.clone()
                point_cloud = Pointclouds(points=[pcd_copy], features=[color / 255.0])

                images = renderer(point_cloud)
                res.append(images[0, ..., :3].cpu())
            res = torch.stack(res)
            video = (res * 255).numpy().astype(np.uint8)
            save_name = f"{sequence_name}_reconstruction_{step}_mode_{mode}_"

            if writer is None:
                outputs[mode] = video
            if only_foreground:
                save_name += "fg_only"
            else:
                save_name += "full_scene"
            video_out = cv2.VideoWriter(
                os.path.join(
                    output_dir,
                    f"{save_name}.mp4",
                ),
                cv2.VideoWriter_fourcc(*"mp4v"),
                fps=30,
                frameSize=(res.shape[2], res.shape[1]),
                isColor=True,
            )
            filename = os.path.join(output_dir, sequence_name + '_img_')
            if not os.path.isdir(filename + str(mode)):
                os.mkdir(filename + str(mode))
            for i in range(len(video)):
                filename_temp = filename + str(mode) + '/' + str(i).zfill(3) + '.png'
                cv2.imwrite(filename_temp, cv2.cvtColor(video[i], cv2.COLOR_BGR2RGB))
                video_out.write(cv2.cvtColor(video[i], cv2.COLOR_BGR2RGB))
            video_out.release()

            if writer is not None:
                writer.add_video(
                    f"{sequence_name}_reconstruction_mode_{mode}",
                    (res * 255).permute(0, 3, 1, 2).to(torch.uint8)[None],
                    global_step=step,
                    fps=30,
                )

    return outputs


================================================
FILE: models/Video-Depth-Anything/app.py
================================================
# Copyright (2025) Bytedance Ltd. and/or its affiliates 

# Licensed under the Apache License, Version 2.0 (the "License"); 
# you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at 

#     http://www.apache.org/licenses/LICENSE-2.0 

# Unless required by applicable law or agreed to in writing, software 
# distributed under the License is distributed on an "AS IS" BASIS, 
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and 
# limitations under the License. 
import gradio as gr


import numpy as np
import os
import torch

from video_depth_anything.video_depth import VideoDepthAnything
from utils.dc_utils import read_video_frames, vis_sequence_depth, save_video

examples = [
    ['assets/example_videos/davis_rollercoaster.mp4'],
]

model_configs = {
    'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
    'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
}

encoder='vitl'

video_depth_anything = VideoDepthAnything(**model_configs[encoder])
video_depth_anything.load_state_dict(torch.load(f'./checkpoints/video_depth_anything_{encoder}.pth', map_location='cpu'), strict=True)
video_depth_anything = video_depth_anything.to('cuda').eval()


def infer_video_depth(
    input_video: str,
    max_len: int = -1,
    target_fps: int = -1,
    max_res: int = 1280,
    output_dir: str = './outputs',
    input_size: int = 518,
):
    frames, target_fps = read_video_frames(input_video, max_len, target_fps, max_res)
    depth_list, fps = video_depth_anything.infer_video_depth(frames, target_fps, input_size=input_size, device='cuda')
    depth_list = np.stack(depth_list, axis=0)
    vis = vis_sequence_depth(depth_list)
    video_name = os.path.basename(input_video)
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    processed_video_path = os.path.join(output_dir, os.path.splitext(video_name)[0]+'_src.mp4')
    depth_vis_path = os.path.join(output_dir, os.path.splitext(video_name)[0]+'_vis.mp4')
    save_video(frames, processed_video_path, fps=fps)
    save_video(vis, depth_vis_path, fps=fps)

    return [processed_video_path, depth_vis_path]


def construct_demo():
    with gr.Blocks(analytics_enabled=False) as demo:
        gr.Markdown(
            f"""
            Video Depth Anything: video depth estimation demo
            """
        )

        with gr.Row(equal_height=True):
            with gr.Column(scale=1):
                input_video = gr.Video(label="Input Video")

            # with gr.Tab(label="Output"):
            with gr.Column(scale=2):
                with gr.Row(equal_height=True):
                    processed_video = gr.Video(
                        label="Preprocessed video",
                        interactive=False,
                        autoplay=True,
                        loop=True,
                        show_share_button=True,
                        scale=5,
                    )
                    depth_vis_video = gr.Video(
                        label="Generated Depth Video",
                        interactive=False,
                        autoplay=True,
                        loop=True,
                        show_share_button=True,
                        scale=5,
                    )

        with gr.Row(equal_height=True):
            with gr.Column(scale=1):
                with gr.Row(equal_height=False):
                    with gr.Accordion("Advanced Settings", open=False):
                        max_len = gr.Slider(
                            label="max process length",
                            minimum=-1,
                            maximum=1000,
                            value=-1,
                            step=1,
                        )
                        target_fps = gr.Slider(
                            label="target FPS",
                            minimum=-1,
                            maximum=30,
                            value=15,
                            step=1,
                        )
                        max_res = gr.Slider(
                            label="max side resolution",
                            minimum=480,
                            maximum=1920,
                            value=1280,
                            step=1,
                        )
                    generate_btn = gr.Button("Generate")
            with gr.Column(scale=2):
                pass

        gr.Examples(
            examples=examples,
            inputs=[
                input_video,
                max_len,
                target_fps,
                max_res
            ],
            outputs=[processed_video, depth_vis_video],
            fn=infer_video_depth,
            cache_examples="lazy",
        )

        generate_btn.click(
            fn=infer_video_depth,
            inputs=[
                input_video,
                max_len,
                target_fps,
                max_res
            ],
            outputs=[processed_video, depth_vis_video],
        )

    return demo

if __name__ == "__main__":
    demo = construct_demo()
    demo.queue()
    demo.launch(server_name="0.0.0.0")

================================================
FILE: models/Video-Depth-Anything/get_weights.sh
================================================
#!/bin/bash

mkdir checkpoints
cd checkpoints
wget https://huggingface.co/depth-anything/Video-Depth-Anything-Small/resolve/main/video_depth_anything_vits.pth
wget https://huggingface.co/depth-anything/Video-Depth-Anything-Large/resolve/main/video_depth_anything_vitl.pth

================================================
FILE: models/Video-Depth-Anything/run.py
================================================
# Copyright (2025) Bytedance Ltd. and/or its affiliates 

# Licensed under the Apache License, Version 2.0 (the "License"); 
# you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at 

#     http://www.apache.org/licenses/LICENSE-2.0 

# Unless required by applicable law or agreed to in writing, software 
# distributed under the License is distributed on an "AS IS" BASIS, 
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and 
# limitations under the License. 
import argparse
import numpy as np
import os
import torch

from video_depth_anything.video_depth import VideoDepthAnything
from utils.dc_utils import read_video_frames, vis_sequence_depth, save_video

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Video Depth Anything')
    parser.add_argument('--input_video', type=str, default='./assets/example_videos/davis_rollercoaster.mp4')
    parser.add_argument('--output_dir', type=str, default='./outputs')
    parser.add_argument('--input_size', type=int, default=518)
    parser.add_argument('--max_res', type=int, default=1280)
    parser.add_argument('--encoder', type=str, default='vitl', choices=['vits', 'vitl'])
    parser.add_argument('--max_len', type=int, default=-1, help='maximum length of the input video, -1 means no limit')
    parser.add_argument('--target_fps', type=int, default=-1, help='target fps of the input video, -1 means the original fps')

    args = parser.parse_args()

    DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

    model_configs = {
        'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
        'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
    }

    video_depth_anything = VideoDepthAnything(**model_configs[args.encoder])
    video_depth_anything.load_state_dict(torch.load(f'./checkpoints/video_depth_anything_{args.encoder}.pth', map_location='cpu'), strict=True)
    video_depth_anything = video_depth_anything.to(DEVICE).eval()

    frames, target_fps = read_video_frames(args.input_video, args.max_len, args.target_fps, args.max_res)
    depth_list, fps = video_depth_anything.infer_video_depth(frames, target_fps, input_size=args.input_size, device=DEVICE)
    depth_list = np.stack(depth_list, axis=0)
    vis = vis_sequence_depth(depth_list)
    video_name = os.path.basename(args.input_video)
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    processed_video_path = os.path.join(args.output_dir, os.path.splitext(video_name)[0]+'_src.mp4')
    depth_vis_path = os.path.join(args.output_dir, os.path.splitext(video_name)[0]+'_vis.mp4')
    save_video(frames, processed_video_path, fps=fps)
    save_video(vis, depth_vis_path, fps=fps)

    




================================================
FILE: models/Video-Depth-Anything/utils/dc_utils.py
================================================
# This file is originally from DepthCrafter/depthcrafter/utils.py at main · Tencent/DepthCrafter
# SPDX-License-Identifier: MIT License license
#
# This file may have been modified by ByteDance Ltd. and/or its affiliates on [date of modification]
# Original file is released under [ MIT License license], with the full license text available at [https://github.com/Tencent/DepthCrafter?tab=License-1-ov-file].
from typing import Union, List
import tempfile
import numpy as np
import PIL.Image
import matplotlib.cm as cm
import mediapy
import torch
try:
    from decord import VideoReader, cpu
    DECORD_AVAILABLE = True
except:
    import cv2
    DECORD_AVAILABLE = False


def read_video_frames(video_path, process_length, target_fps=-1, max_res=-1, dataset="open"):
    if DECORD_AVAILABLE:
        vid = VideoReader(video_path, ctx=cpu(0))
        original_height, original_width = vid.get_batch([0]).shape[1:3]
        height = original_height
        width = original_width
        if max_res > 0 and max(height, width) > max_res:
            scale = max_res / max(original_height, original_width)
            height = round(original_height * scale)
            width = round(original_width * scale)

        vid = VideoReader(video_path, ctx=cpu(0), width=width, height=height)

        fps = vid.get_avg_fps() if target_fps == -1 else target_fps
        stride = round(vid.get_avg_fps() / fps)
        stride = max(stride, 1)
        frames_idx = list(range(0, len(vid), stride))
        if process_length != -1 and process_length < len(frames_idx):
            frames_idx = frames_idx[:process_length]
        frames = vid.get_batch(frames_idx).asnumpy()
    else:
        cap = cv2.VideoCapture(video_path)
        original_fps = cap.get(cv2.CAP_PROP_FPS)
        original_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        original_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))

        if max_res > 0 and max(original_height, original_width) > max_res:
            scale = max_res / max(original_height, original_width)
            height = round(original_height * scale)
            width = round(original_width * scale)

        fps = original_fps if target_fps < 0 else target_fps

        stride = max(round(original_fps / fps), 1)

        frames = []
        frame_count = 0
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret or (process_length > 0 and frame_count >= process_length):
                break
            if frame_count % stride == 0:
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)  # Convert BGR to RGB
                if max_res > 0 and max(original_height, original_width) > max_res:
                    frame = cv2.resize(frame, (width, height))  # Resize frame
                frames.append(frame)
            frame_count += 1
        cap.release()
        frames = np.stack(frames, axis=0)

    return frames, fps


def save_video(
    video_frames: Union[List[np.ndarray], List[PIL.Image.Image]],
    output_video_path: str = None,
    fps: int = 10,
    crf: int = 18,
) -> str:
    if output_video_path is None:
        output_video_path = tempfile.NamedTemporaryFile(suffix=".mp4").name

    if isinstance(video_frames[0], np.ndarray):
        video_frames = [frame.astype(np.uint8) for frame in video_frames]

    elif isinstance(video_frames[0], PIL.Image.Image):
        video_frames = [np.array(frame) for frame in video_frames]
    mediapy.write_video(output_video_path, video_frames, fps=fps, crf=crf)
    return output_video_path


class ColorMapper:
    # a color mapper to map depth values to a certain colormap
    def __init__(self, colormap: str = "inferno"):
        self.colormap = torch.tensor(cm.get_cmap(colormap).colors)

    def apply(self, image: torch.Tensor, v_min=None, v_max=None):
        # assert len(image.shape) == 2
        if v_min is None:
            v_min = image.min()
        if v_max is None:
            v_max = image.max()
        image = (image - v_min) / (v_max - v_min)
        image = (image * 255).long()
        image = self.colormap[image] * 255
        return image


def vis_sequence_depth(depths: np.ndarray, v_min=None, v_max=None):
    visualizer = ColorMapper()
    if v_min is None:
        v_min = depths.min()
    if v_max is None:
        v_max = depths.max()
    res = visualizer.apply(torch.tensor(depths), v_min=v_min, v_max=v_max).numpy()
    return res
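

# Illustrative end-to-end use of the helpers in this file. The video path is a
# placeholder and the "depth" is a per-pixel intensity stand-in for a real model
# prediction, so the sketch stays runnable without any checkpoint.
def _example_visualise_depth(video_path="input.mp4", out_path="depth_vis.mp4"):
    frames, fps = read_video_frames(video_path, process_length=64, target_fps=-1, max_res=1280)
    pseudo_depth = frames.mean(axis=-1).astype(np.float32)  # (T, H, W)
    vis = vis_sequence_depth(pseudo_depth)                  # (T, H, W, 3) colour-mapped
    return save_video(list(vis), out_path, fps=int(fps))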


================================================
FILE: models/Video-Depth-Anything/utils/util.py
================================================
# Copyright (2025) Bytedance Ltd. and/or its affiliates 

# Licensed under the Apache License, Version 2.0 (the "License"); 
# you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at 

#     http://www.apache.org/licenses/LICENSE-2.0 

# Unless required by applicable law or agreed to in writing, software 
# distributed under the License is distributed on an "AS IS" BASIS, 
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and 
# limitations under the License. 
import numpy as np

def compute_scale_and_shift(prediction, target, mask, scale_only=False):
    if scale_only:
        return compute_scale(prediction, target, mask), 0
    else:
        return compute_scale_and_shift_full(prediction, target, mask)


def compute_scale(prediction, target, mask):
    # system matrix: A = [[a_00, a_01], [a_10, a_11]]
    prediction = prediction.astype(np.float32)
    target = target.astype(np.float32)
    mask = mask.astype(np.float32)

    a_00 = np.sum(mask * prediction * prediction)
    a_01 = np.sum(mask * prediction)
    a_11 = np.sum(mask)

    # right hand side: b = [b_0, b_1]
    b_0 = np.sum(mask * prediction * target)

    x_0 = b_0 / (a_00 + 1e-6)

    return x_0

def compute_scale_and_shift_full(prediction, target, mask):
    # system matrix: A = [[a_00, a_01], [a_10, a_11]]
    prediction = prediction.astype(np.float32)
    target = target.astype(np.float32)
    mask = mask.astype(np.float32)

    a_00 = np.sum(mask * prediction * prediction)
    a_01 = np.sum(mask * prediction)
    a_11 = np.sum(mask)

    b_0 = np.sum(mask * prediction * target)
    b_1 = np.sum(mask * target)

    x_0 = 1
    x_1 = 0

    det = a_00 * a_11 - a_01 * a_01

    if det != 0:
        x_0 = (a_11 * b_0 - a_01 * b_1) / det
        x_1 = (-a_01 * b_0 + a_00 * b_1) / det

    return x_0, x_1
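

# Worked example on synthetic data: the closed-form solution above recovers a known
# scale and shift between a prediction and a target, which is how these helpers are
# used for scale-and-shift-invariant depth alignment.
def _example_scale_and_shift():
    prediction = np.linspace(0.1, 1.0, 100)
    target = 2.5 * prediction + 0.3
    mask = np.ones_like(prediction)
    scale, shift = compute_scale_and_shift(prediction, target, mask)
    return scale, shift  # ~ (2.5, 0.3)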


def get_interpolate_frames(frame_list_pre, frame_list_post):
    assert len(frame_list_pre) == len(frame_list_post)
    min_w = 0.0
    max_w = 1.0
    step = (max_w - min_w) / (len(frame_list_pre)-1)
    post_w_list = [min_w] + [i * step for i in range(1,len(frame_list_pre)-1)] + [max_w]
    interpolated_frames = []
    for i in range(len(frame_list_pre)):
        interpolated_frames.append(frame_list_pre[i] * (1-post_w_list[i]) + frame_list_post[i] * post_w_list[i])
    return interpolated_frames
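

# Small usage sketch: blend two equally long frame lists with a weight on the second
# list that ramps linearly from 0 to 1, keeping the original frames at both ends.
def _example_interpolate_frames():
    pre = [np.full((2, 2), 0.0) for _ in range(5)]
    post = [np.full((2, 2), 1.0) for _ in range(5)]
    blended = get_interpolate_frames(pre, post)
    return [float(f.mean()) for f in blended]  # [0.0, 0.25, 0.5, 0.75, 1.0]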

================================================
FILE: models/Video-Depth-Anything/video_depth_anything/dinov2.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

# References:
#   https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
#   https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py

from functools import partial
import math
import logging
from typing import Sequence, Tuple, Union, Callable

import torch
import torch.nn as nn
import torch.utils.checkpoint
from torch.nn.init import trunc_normal_

from .dinov2_layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block


logger = logging.getLogger("dinov2")


def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
    if not depth_first and include_root:
        fn(module=module, name=name)
    for child_name, child_module in module.named_children():
        child_name = ".".join((name, child_name)) if name else child_name
        named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
    if depth_first and include_root:
        fn(module=module, name=name)
    return module


class BlockChunk(nn.ModuleList):
    def forward(self, x):
        for b in self:
            x = b(x)
        return x


class DinoVisionTransformer(nn.Module):
    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        qkv_bias=True,
        ffn_bias=True,
        proj_bias=True,
        drop_path_rate=0.0,
        drop_path_uniform=False,
        init_values=None,  # for layerscale: None or 0 => no layerscale
        embed_layer=PatchEmbed,
        act_layer=nn.GELU,
        block_fn=Block,
        ffn_layer="mlp",
        block_chunks=1,
        num_register_tokens=0,
        interpolate_antialias=False,
        interpolate_offset=0.1,
    ):
        """
        Args:
            img_size (int, tuple): input image size
            patch_size (int, tuple): patch size
            in_chans (int): number of input channels
            embed_dim (int): embedding dimension
            depth (int): depth of transformer
            num_heads (int): number of attention heads
            mlp_ratio (int): ratio of mlp hidden dim to embedding dim
            qkv_bias (bool): enable bias for qkv if True
            proj_bias (bool): enable bias for proj in attn if True
            ffn_bias (bool): enable bias for ffn if True
            drop_path_rate (float): stochastic depth rate
            drop_path_uniform (bool): apply uniform drop rate across blocks
            weight_init (str): weight init scheme
            init_values (float): layer-scale init values
            embed_layer (nn.Module): patch embedding layer
            act_layer (nn.Module): MLP activation layer
            block_fn (nn.Module): transformer block class
            ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
            block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
            num_register_tokens: (int) number of extra cls tokens (so-called "registers")
            interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings
            interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings
        """
        super().__init__()
        norm_layer = partial(nn.LayerNorm, eps=1e-6)

        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        self.num_tokens = 1
        self.n_blocks = depth
        self.num_heads = num_heads
        self.patch_size = patch_size
        self.num_register_tokens = num_register_tokens
        self.interpolate_antialias = interpolate_antialias
        self.interpolate_offset = interpolate_offset

        self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
        assert num_register_tokens >= 0
        self.register_tokens = (
            nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None
        )

        if drop_path_uniform is True:
            dpr = [drop_path_rate] * depth
        else:
            dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule

        if ffn_layer == "mlp":
            logger.info("using MLP layer as FFN")
            ffn_layer = Mlp
        elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
            logger.info("using SwiGLU layer as FFN")
            ffn_layer = SwiGLUFFNFused
        elif ffn_layer == "identity":
            logger.info("using Identity layer as FFN")

            def f(*args, **kwargs):
                return nn.Identity()

            ffn_layer = f
        else:
            raise NotImplementedError

        blocks_list = [
            block_fn(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                proj_bias=proj_bias,
                ffn_bias=ffn_bias,
                drop_path=dpr[i],
                norm_layer=norm_layer,
                act_layer=act_layer,
                ffn_layer=ffn_layer,
                init_values=init_values,
            )
            for i in range(depth)
        ]
        if block_chunks > 0:
            self.chunked_blocks = True
            chunked_blocks = []
            chunksize = depth // block_chunks
            for i in range(0, depth, chunksize):
                # this is to keep the block index consistent if we chunk the block list
                chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
            self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
        else:
            self.chunked_blocks = False
            self.blocks = nn.ModuleList(blocks_list)

        self.norm = norm_layer(embed_dim)
        self.head = nn.Identity()

        self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))

        self.init_weights()

    def init_weights(self):
        trunc_normal_(self.pos_embed, std=0.02)
        nn.init.normal_(self.cls_token, std=1e-6)
        if self.register_tokens is not None:
            nn.init.normal_(self.register_tokens, std=1e-6)
        named_apply(init_weights_vit_timm, self)

    def interpolate_pos_encoding(self, x, w, h):
        previous_dtype = x.dtype
        npatch = x.shape[1] - 1
        N = self.pos_embed.shape[1] - 1
        if npatch == N and w == h:
            return self.pos_embed
        pos_embed = self.pos_embed.float()
        class_pos_embed = pos_embed[:, 0]
        patch_pos_embed = pos_embed[:, 1:]
        dim = x.shape[-1]
        w0 = w // self.patch_size
        h0 = h // self.patch_size
        # we add a small number to avoid floating point error in the interpolation
        # see discussion at https://github.com/facebookresearch/dino/issues/8
        # DINOv2 with register modify the interpolate_offset from 0.1 to 0.0
        w0, h0 = w0 + self.interpolate_offset, h0 + self.interpolate_offset
        # w0, h0 = w0 + 0.1, h0 + 0.1
        
        # ... (remainder of dinov2.py truncated in this preview)

================================================
SYMBOL INDEX (534 symbols across 54 files)
================================================

FILE: datasets/augmentor.py
  class AdjustGamma (line 13) | class AdjustGamma(object):
    method __init__ (line 14) | def __init__(self, gamma_min, gamma_max, gain_min=1.0, gain_max=1.0):
    method __call__ (line 22) | def __call__(self, sample):
    method __repr__ (line 27) | def __repr__(self):
  class SequenceDispFlowAugmentor (line 31) | class SequenceDispFlowAugmentor:
    method __init__ (line 32) | def __init__(
    method color_transform (line 71) | def color_transform(self, seq):
    method eraser_transform (line 95) | def eraser_transform(self, seq, bounds=[50, 100]):
    method spatial_transform (line 111) | def spatial_transform(self, img, disp):
    method __call__ (line 183) | def __call__(self, img, disp):
  class SequenceDispSparseFlowAugmentor (line 197) | class SequenceDispSparseFlowAugmentor:
    method __init__ (line 198) | def __init__(
    method color_transform (line 237) | def color_transform(self, seq):
    method eraser_transform (line 252) | def eraser_transform(self, seq, bounds=[50, 100]):
    method resize_sparse_flow_map (line 268) | def resize_sparse_flow_map(self, flow, valid, fx=1.0, fy=1.0):
    method spatial_transform (line 302) | def spatial_transform(self, img, disp, valid):
    method __call__ (line 352) | def __call__(self, img, disp, valid):

FILE: datasets/frame_utils.py
  function readFlow (line 14) | def readFlow(fn):
  function readPFM (line 36) | def readPFM(file):
  function readDispSintelStereo (line 74) | def readDispSintelStereo(file_name):
  function readDispMiddlebury (line 87) | def readDispMiddlebury(file_name):
  function read_gen (line 98) | def read_gen(file_name, pil=False):

FILE: datasets/video_datasets.py
  class DynamicReplicaFrameAnnotation (line 31) | class DynamicReplicaFrameAnnotation(ImplicitronFrameAnnotation):
  class StereoSequenceDataset (line 37) | class StereoSequenceDataset(data.Dataset):
    method __init__ (line 38) | def __init__(self, aug_params=None, sparse=False, reader=None):
    method _load_depth (line 60) | def _load_depth(self, depth_path):
    method _load_npy_depth (line 73) | def _load_npy_depth(self, depth_npy):
    method _load_vkitti2 (line 77) | def _load_vkitti2(self, depth_png):
    method _load_kitti_depth (line 84) | def _load_kitti_depth(self, depth_png):
    method _load_16big_png_depth (line 96) | def _load_16big_png_depth(self, depth_png):
    method load_tartanair_pose (line 107) | def load_tartanair_pose(self, filepath, index=0):
    method parse_txt_file (line 122) | def parse_txt_file(self, file_path):
    method _get_pytorch3d_camera (line 158) | def _get_pytorch3d_camera(
    method _get_pytorch3d_camera_from_blender (line 204) | def _get_pytorch3d_camera_from_blender(self, R, T, K, image_size, scal...
    method _get_output_tensor (line 247) | def _get_output_tensor(self, sample):
    method __getitem__ (line 666) | def __getitem__(self, index):
    method __mul__ (line 766) | def __mul__(self, v):
    method __len__ (line 772) | def __len__(self):
  class DynamicReplicaDataset (line 776) | class DynamicReplicaDataset(StereoSequenceDataset):
    method __init__ (line 777) | def __init__(
  class InfinigenStereoVideoDataset (line 876) | class InfinigenStereoVideoDataset(StereoSequenceDataset):
    method __init__ (line 877) | def __init__(
  class SouthKensingtonStereoVideoDataset (line 961) | class SouthKensingtonStereoVideoDataset(StereoSequenceDataset):
    method __init__ (line 962) | def __init__(
  class KITTIDepthDataset (line 1025) | class KITTIDepthDataset(StereoSequenceDataset):
    method __init__ (line 1026) | def __init__(
  function split_train_valid (line 1139) | def split_train_valid(path_list, valid_keywords):
  class TartanAirDataset (line 1148) | class TartanAirDataset(StereoSequenceDataset):
    method __init__ (line 1149) | def __init__(
  class VKITTI2Dataset (line 1275) | class VKITTI2Dataset(StereoSequenceDataset):
    method __init__ (line 1276) | def __init__(
  class SequenceSpringDataset (line 1367) | class SequenceSpringDataset(StereoSequenceDataset):
    method __init__ (line 1368) | def __init__(
  class SequenceSceneFlowDataset (line 1424) | class SequenceSceneFlowDataset(StereoSequenceDataset):
    method __init__ (line 1425) | def __init__(
    method _add_things (line 1450) | def _add_things(self, split="TRAIN"):
    method _add_monkaa (line 1496) | def _add_monkaa(self):
    method _add_driving (line 1530) | def _add_driving(self):
    method _append_sample (line 1565) | def _append_sample(self, images, disparities):
  class SequenceSintelStereo (line 1583) | class SequenceSintelStereo(StereoSequenceDataset):
    method __init__ (line 1584) | def __init__(
  function fetch_dataloader (line 1649) | def fetch_dataloader(args):

FILE: demo.py
  function load_image (line 19) | def load_image(imfile):
  function viz (line 25) | def viz(img, flo):
  function demo (line 37) | def demo(args):

FILE: evaluation/core/evaluator.py
  function depth_to_colormap (line 19) | def depth_to_colormap(depth, colormap='jet', eps=1e-3, scale_vmin=1.0):
  class Evaluator (line 33) | class Evaluator(Configurable):
    method setup_visualization (line 43) | def setup_visualization(self, cfg: DictConfig) -> None:
    method evaluate_sequence (line 52) | def evaluate_sequence(

FILE: evaluation/evaluate.py
  class DefaultConfig (line 25) | class DefaultConfig:
  function run_eval (line 57) | def run_eval(cfg: DefaultConfig):
  function evaluate (line 141) | def evaluate(cfg: DefaultConfig) -> None:

FILE: evaluation/utils/eval_utils.py
  class PerceptionMetric (line 14) | class PerceptionMetric:
    method __str__ (line 20) | def __str__(self):
  function compute_flow (line 33) | def compute_flow(seq, is_seq=True):
  function flow_warp (line 50) | def flow_warp(x, flow):
  function eval_endpoint_error_sequence (line 75) | def eval_endpoint_error_sequence(
  function eval_TCC_sequence (line 146) | def eval_TCC_sequence(
  function eval_TCM_sequence (line 185) | def eval_TCM_sequence(
  function eval_OPW_sequence (line 242) | def eval_OPW_sequence(
  function eval_RTC_sequence (line 313) | def eval_RTC_sequence(
  function depth2disparity_scale (line 379) | def depth2disparity_scale(left_camera, right_camera, image_size_tensor):
  function depth_to_pcd (line 394) | def depth_to_pcd(
  function filter_outliers (line 438) | def filter_outliers(pcd, sigma=3):
  function eval_batch (line 446) | def eval_batch(batch_dict, predictions, scale) -> Dict[str, Union[float,...

FILE: evaluation/utils/ssim.py
  function _fspecial_gauss_1d (line 9) | def _fspecial_gauss_1d(size, sigma):
  function gaussian_filter (line 26) | def gaussian_filter(input, win):
  function _ssim (line 54) | def _ssim(X, Y, data_range, win, size_average=True, K=(0.01, 0.03)):
  function ssim (line 94) | def ssim(
  function ms_ssim (line 152) | def ms_ssim(
  class SSIM (line 227) | class SSIM(torch.nn.Module):
    method __init__ (line 228) | def __init__(
    method forward (line 258) | def forward(self, X, Y):
  class MS_SSIM (line 270) | class MS_SSIM(torch.nn.Module):
    method __init__ (line 271) | def __init__(
    method forward (line 301) | def forward(self, X, Y):

FILE: evaluation/utils/utils.py
  class PerceptionPrediction (line 28) | class PerceptionPrediction:
  function aggregate_eval_results (line 39) | def aggregate_eval_results(per_batch_eval_results, reduction="mean"):
  function aggregate_and_print_results (line 62) | def aggregate_and_print_results(
  function pretty_print_perception_metrics (line 76) | def pretty_print_perception_metrics(results):
  function read_calibration (line 92) | def read_calibration(calibration_file, resolution_string):
  function filter_depth_discontinuities (line 216) | def filter_depth_discontinuities(pcd, depth_map, threshold=5):
  function visualize_batch (line 246) | def visualize_batch(

FILE: models/Video-Depth-Anything/app.py
  function infer_video_depth (line 40) | def infer_video_depth(
  function construct_demo (line 64) | def construct_demo():

FILE: models/Video-Depth-Anything/utils/dc_utils.py
  function read_video_frames (line 21) | def read_video_frames(video_path, process_length, target_fps=-1, max_res...
  function save_video (line 74) | def save_video(
  class ColorMapper (line 92) | class ColorMapper:
    method __init__ (line 94) | def __init__(self, colormap: str = "inferno"):
    method apply (line 97) | def apply(self, image: torch.Tensor, v_min=None, v_max=None):
  function vis_sequence_depth (line 109) | def vis_sequence_depth(depths: np.ndarray, v_min=None, v_max=None):

FILE: models/Video-Depth-Anything/utils/util.py
  function compute_scale_and_shift (line 16) | def compute_scale_and_shift(prediction, target, mask, scale_only=False):
  function compute_scale (line 23) | def compute_scale(prediction, target, mask):
  function compute_scale_and_shift_full (line 40) | def compute_scale_and_shift_full(prediction, target, mask):
  function get_interpolate_frames (line 65) | def get_interpolate_frames(frame_list_pre, frame_list_post):

FILE: models/Video-Depth-Anything/video_depth_anything/dinov2.py
  function named_apply (line 26) | def named_apply(fn: Callable, module: nn.Module, name="", depth_first=Tr...
  class BlockChunk (line 37) | class BlockChunk(nn.ModuleList):
    method forward (line 38) | def forward(self, x):
  class DinoVisionTransformer (line 44) | class DinoVisionTransformer(nn.Module):
    method __init__ (line 45) | def __init__(
    method init_weights (line 172) | def init_weights(self):
    method interpolate_pos_encoding (line 179) | def interpolate_pos_encoding(self, x, w, h):
    method prepare_tokens_with_masks (line 212) | def prepare_tokens_with_masks(self, x, masks=None):
    method forward_features_list (line 233) | def forward_features_list(self, x_list, masks_list):
    method forward_features (line 253) | def forward_features(self, x, masks=None):
    method _get_intermediate_layers_not_chunked (line 271) | def _get_intermediate_layers_not_chunked(self, x, n=1):
    method _get_intermediate_layers_chunked (line 283) | def _get_intermediate_layers_chunked(self, x, n=1):
    method get_intermediate_layers (line 297) | def get_intermediate_layers(
    method forward (line 323) | def forward(self, *args, is_training=False, **kwargs):
  function init_weights_vit_timm (line 331) | def init_weights_vit_timm(module: nn.Module, name: str = ""):
  function vit_small (line 339) | def vit_small(patch_size=16, num_register_tokens=0, **kwargs):
  function vit_base (line 353) | def vit_base(patch_size=16, num_register_tokens=0, **kwargs):
  function vit_large (line 367) | def vit_large(patch_size=16, num_register_tokens=0, **kwargs):
  function vit_giant2 (line 381) | def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs):
  function DINOv2 (line 398) | def DINOv2(model_name):

FILE: models/Video-Depth-Anything/video_depth_anything/dinov2_layers/attention.py
  class Attention (line 29) | class Attention(nn.Module):
    method __init__ (line 30) | def __init__(
    method forward (line 49) | def forward(self, x: Tensor) -> Tensor:
  class MemEffAttention (line 65) | class MemEffAttention(Attention):
    method forward (line 66) | def forward(self, x: Tensor, attn_bias=None) -> Tensor:

FILE: models/Video-Depth-Anything/video_depth_anything/dinov2_layers/block.py
  class Block (line 36) | class Block(nn.Module):
    method __init__ (line 37) | def __init__(
    method forward (line 82) | def forward(self, x: Tensor) -> Tensor:
  function drop_add_residual_stochastic_depth (line 110) | def drop_add_residual_stochastic_depth(
  function get_branges_scales (line 134) | def get_branges_scales(x, sample_drop_ratio=0.0):
  function add_residual (line 142) | def add_residual(x, brange, residual, residual_scale_factor, scaling_vec...
  function get_attn_bias_and_cat (line 157) | def get_attn_bias_and_cat(x_list, branges=None):
  function drop_add_residual_stochastic_depth_list (line 181) | def drop_add_residual_stochastic_depth_list(
  class NestedTensorBlock (line 204) | class NestedTensorBlock(Block):
    method forward_nested (line 205) | def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
    method forward (line 245) | def forward(self, x_or_x_list):

FILE: models/Video-Depth-Anything/video_depth_anything/dinov2_layers/drop_path.py
  function drop_path (line 15) | def drop_path(x, drop_prob: float = 0.0, training: bool = False):
  class DropPath (line 27) | class DropPath(nn.Module):
    method __init__ (line 30) | def __init__(self, drop_prob=None):
    method forward (line 34) | def forward(self, x):

FILE: models/Video-Depth-Anything/video_depth_anything/dinov2_layers/layer_scale.py
  class LayerScale (line 16) | class LayerScale(nn.Module):
    method __init__ (line 17) | def __init__(
    method forward (line 27) | def forward(self, x: Tensor) -> Tensor:

FILE: models/Video-Depth-Anything/video_depth_anything/dinov2_layers/mlp.py
  class Mlp (line 17) | class Mlp(nn.Module):
    method __init__ (line 18) | def __init__(
    method forward (line 35) | def forward(self, x: Tensor) -> Tensor:

FILE: models/Video-Depth-Anything/video_depth_anything/dinov2_layers/patch_embed.py
  function make_2tuple (line 17) | def make_2tuple(x):
  class PatchEmbed (line 26) | class PatchEmbed(nn.Module):
    method __init__ (line 38) | def __init__(
    method forward (line 69) | def forward(self, x: Tensor) -> Tensor:
    method flops (line 84) | def flops(self) -> float:

FILE: models/Video-Depth-Anything/video_depth_anything/dinov2_layers/swiglu_ffn.py
  class SwiGLUFFN (line 13) | class SwiGLUFFN(nn.Module):
    method __init__ (line 14) | def __init__(
    method forward (line 29) | def forward(self, x: Tensor) -> Tensor:
  class SwiGLUFFNFused (line 45) | class SwiGLUFFNFused(SwiGLU):
    method __init__ (line 46) | def __init__(

FILE: models/Video-Depth-Anything/video_depth_anything/dpt.py
  function _make_fusion_block (line 21) | def _make_fusion_block(features, use_bn, size=None):
  class ConvBlock (line 33) | class ConvBlock(nn.Module):
    method __init__ (line 34) | def __init__(self, in_feature, out_feature):
    method forward (line 43) | def forward(self, x):
  class DPTHead (line 47) | class DPTHead(nn.Module):
    method __init__ (line 48) | def __init__(
    method forward (line 126) | def forward(self, out_features, patch_h, patch_w):

FILE: models/Video-Depth-Anything/video_depth_anything/dpt_temporal.py
  class DPTHeadTemporal (line 22) | class DPTHeadTemporal(DPTHead):
    method __init__ (line 23) | def __init__(self,
    method forward (line 53) | def forward(self, out_features, patch_h, patch_w, frame_length):

FILE: models/Video-Depth-Anything/video_depth_anything/motion_module/attention.py
  class CrossAttention (line 30) | class CrossAttention(nn.Module):
    method __init__ (line 45) | def __init__(
    method reshape_heads_to_batch_dim (line 93) | def reshape_heads_to_batch_dim(self, tensor):
    method reshape_heads_to_4d (line 100) | def reshape_heads_to_4d(self, tensor):
    method reshape_batch_dim_to_heads (line 106) | def reshape_batch_dim_to_heads(self, tensor):
    method reshape_4d_to_heads (line 113) | def reshape_4d_to_heads(self, tensor):
    method set_attention_slice (line 119) | def set_attention_slice(self, slice_size):
    method forward (line 125) | def forward(self, hidden_states, encoder_hidden_states=None, attention...
    method _attention (line 182) | def _attention(self, query, key, value, attention_mask=None):
    method _sliced_attention (line 213) | def _sliced_attention(self, query, key, value, sequence_length, dim, a...
    method _memory_efficient_attention_xformers (line 256) | def _memory_efficient_attention_xformers(self, query, key, value, atte...
    method _memory_efficient_attention_split (line 275) | def _memory_efficient_attention_split(self, query, key, value, attenti...
  class FeedForward (line 296) | class FeedForward(nn.Module):
    method __init__ (line 308) | def __init__(
    method forward (line 335) | def forward(self, hidden_states):
  class GELU (line 341) | class GELU(nn.Module):
    method __init__ (line 346) | def __init__(self, dim_in: int, dim_out: int):
    method gelu (line 350) | def gelu(self, gate):
    method forward (line 356) | def forward(self, hidden_states):
  class GEGLU (line 363) | class GEGLU(nn.Module):
    method __init__ (line 372) | def __init__(self, dim_in: int, dim_out: int):
    method gelu (line 376) | def gelu(self, gate):
    method forward (line 382) | def forward(self, hidden_states):
  class ApproximateGELU (line 387) | class ApproximateGELU(nn.Module):
    method __init__ (line 394) | def __init__(self, dim_in: int, dim_out: int):
    method forward (line 398) | def forward(self, x):
  function precompute_freqs_cis (line 403) | def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
  function reshape_for_broadcast (line 411) | def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
  function apply_rotary_emb (line 419) | def apply_rotary_emb(

FILE: models/Video-Depth-Anything/video_depth_anything/motion_module/motion_module.py
  function zero_module (line 25) | def zero_module(module):
  class TemporalModule (line 32) | class TemporalModule(nn.Module):
    method __init__ (line 33) | def __init__(
    method forward (line 60) | def forward(self, input_tensor, encoder_hidden_states, attention_mask=...
  class TemporalTransformer3DModel (line 68) | class TemporalTransformer3DModel(nn.Module):
    method __init__ (line 69) | def __init__(
    method forward (line 102) | def forward(self, hidden_states, encoder_hidden_states=None, attention...
  class TemporalTransformerBlock (line 129) | class TemporalTransformerBlock(nn.Module):
    method __init__ (line 130) | def __init__(
    method forward (line 164) | def forward(self, hidden_states, encoder_hidden_states=None, attention...
  class PositionalEncoding (line 180) | class PositionalEncoding(nn.Module):
    method __init__ (line 181) | def __init__(
    method forward (line 196) | def forward(self, x):
  class TemporalAttention (line 200) | class TemporalAttention(CrossAttention):
    method __init__ (line 201) | def __init__(
    method forward (line 230) | def forward(self, hidden_states, encoder_hidden_states=None, attention...

FILE: models/Video-Depth-Anything/video_depth_anything/util/blocks.py
  function _make_scratch (line 4) | def _make_scratch(in_shape, out_shape, groups=1, expand=False):
  class ResidualConvUnit (line 37) | class ResidualConvUnit(nn.Module):
    method __init__ (line 40) | def __init__(self, features, activation, bn):
    method forward (line 68) | def forward(self, x):
  class FeatureFusionBlock (line 94) | class FeatureFusionBlock(nn.Module):
    method __init__ (line 97) | def __init__(
    method forward (line 135) | def forward(self, *xs, size=None):

FILE: models/Video-Depth-Anything/video_depth_anything/util/transform.py
  class Resize (line 5) | class Resize(object):
    method __init__ (line 9) | def __init__(
    method constrain_to_multiple_of (line 51) | def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
    method get_size (line 62) | def get_size(self, width, height):
    method __call__ (line 109) | def __call__(self, sample):
  class NormalizeImage (line 125) | class NormalizeImage(object):
    method __init__ (line 129) | def __init__(self, mean, std):
    method __call__ (line 133) | def __call__(self, sample):
  class PrepareForNet (line 139) | class PrepareForNet(object):
    method __init__ (line 143) | def __init__(self):
    method __call__ (line 146) | def __call__(self, sample):

FILE: models/Video-Depth-Anything/video_depth_anything/video_depth.py
  class VideoDepthAnything (line 35) | class VideoDepthAnything(nn.Module):
    method __init__ (line 36) | def __init__(
    method forward (line 57) | def forward(self, x):
    method infer_video_depth (line 66) | def infer_video_depth(self, frames, target_fps, input_size=518, device...

FILE: models/core/attention.py
  function elu_feature_map (line 13) | def elu_feature_map(x):
  class PositionEncodingSine (line 17) | class PositionEncodingSine(nn.Module):
    method __init__ (line 22) | def __init__(self, d_model, max_shape=(256, 256), temp_bug_fix=True):
    method forward (line 53) | def forward(self, x):
  class LinearAttention (line 61) | class LinearAttention(Module):
    method __init__ (line 62) | def __init__(self, eps=1e-6):
    method forward (line 67) | def forward(self, queries, keys, values, q_mask=None, kv_mask=None):
  class FullAttention (line 97) | class FullAttention(Module):
    method __init__ (line 98) | def __init__(self, use_dropout=False, attention_dropout=0.1):
    method forward (line 103) | def forward(self, queries, keys, values, q_mask=None, kv_mask=None):
  class LoFTREncoderLayer (line 134) | class LoFTREncoderLayer(nn.Module):
    method __init__ (line 135) | def __init__(self, d_model, nhead, attention="linear"):
    method forward (line 159) | def forward(self, x, source, x_mask=None, source_mask=None):
  class LocalFeatureTransformer (line 187) | class LocalFeatureTransformer(nn.Module):
    method __init__ (line 190) | def __init__(self, d_model, nhead, layer_names, attention):
    method _reset_parameters (line 202) | def _reset_parameters(self):
    method forward (line 207) | def forward(self, feat0, feat1, mask0=None, mask1=None):

FILE: models/core/corr.py
  function bilinear_sampler (line 6) | def bilinear_sampler(img, coords, mode="bilinear", mask=False):
  function coords_grid (line 24) | def coords_grid(batch, ht, wd, device):
  class AAPC (line 32) | class AAPC:
    method __init__ (line 36) | def __init__(self, fmap1, fmap2, att=None):
    method __call__ (line 43) | def __call__(self, flow, extra_offset, small_patch=False):
    method correlation (line 49) | def correlation(self, left_feature, right_feature, flow, small_patch):
    method get_correlation (line 73) | def get_correlation(self, left_feature, right_feature, psize=(3, 3), d...
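
bilinear_sampler and coords_grid are the usual RAFT-style helpers: a pixel-coordinate grid is built per image and torch.nn.functional.grid_sample looks features up at fractional coordinates. A sketch under that assumption:

    # Hedged sketch: coords hold pixel-space (x, y) and are normalized to [-1, 1] for grid_sample.
    import torch
    import torch.nn.functional as F

    def coords_grid(batch, ht, wd, device):
        ys, xs = torch.meshgrid(torch.arange(ht, device=device),
                                torch.arange(wd, device=device), indexing="ij")
        coords = torch.stack([xs, ys], dim=0).float()            # (2, H, W), channel order (x, y)
        return coords[None].repeat(batch, 1, 1, 1)               # (B, 2, H, W)

    def bilinear_sampler(img, coords, mode="bilinear", mask=False):
        # img: (B, C, H, W); coords: (B, H', W', 2) with the last dim holding (x, y)
        H, W = img.shape[-2:]
        xgrid, ygrid = coords.split(1, dim=-1)
        xgrid = 2 * xgrid / (W - 1) - 1
        ygrid = 2 * ygrid / (H - 1) - 1
        grid = torch.cat([xgrid, ygrid], dim=-1)
        out = F.grid_sample(img, grid, mode=mode, align_corners=True)
        if mask:
            valid = (xgrid > -1) & (xgrid < 1) & (ygrid > -1) & (ygrid < 1)
            return out, valid.float()
        return out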

FILE: models/core/extractor.py
  class ResidualBlock (line 12) | class ResidualBlock(nn.Module):
    method __init__ (line 13) | def __init__(self, in_planes, planes, norm_fn="group", stride=1):
    method forward (line 48) | def forward(self, x):
  class BasicEncoder (line 58) | class BasicEncoder(nn.Module):
    method __init__ (line 59) | def __init__(self, input_dim=3, output_dim=128, norm_fn="batch", dropo...
    method _make_layer (line 99) | def _make_layer(self, dim, stride=1):
    method forward (line 107) | def forward(self, x):
  class MultiBasicEncoder (line 134) | class MultiBasicEncoder(nn.Module):
    method __init__ (line 135) | def __init__(self, output_dim=[128], norm_fn='batch', dropout=0.0, dow...
    method _make_layer (line 201) | def _make_layer(self, dim, stride=1):
    method forward (line 209) | def forward(self, x, dual_inp=False, num_layers=3):
  class DepthExtractor (line 238) | class DepthExtractor(nn.Module):
    method __init__ (line 239) | def __init__(self):
    method forward (line 258) | def forward(self, x):

FILE: models/core/model_zoo.py
  function model_zoo (line 18) | def model_zoo(model_name: str, **kwargs):
  function get_all_model_default_configs (line 37) | def get_all_model_default_configs():

FILE: models/core/stereoanyvideo.py
  function _ntuple (line 19) | def _ntuple(n):
  function exists (line 28) | def exists(val):
  function default (line 32) | def default(val, d):
  class Mlp (line 38) | class Mlp(nn.Module):
    method __init__ (line 39) | def __init__(
    method forward (line 66) | def forward(self, x):
  class StereoAnyVideo (line 75) | class StereoAnyVideo(nn.Module):
    method __init__ (line 76) | def __init__(self, mixed_precision=False):
    method no_weight_decay (line 93) | def no_weight_decay(self):
    method freeze_bn (line 96) | def freeze_bn(self):
    method convex_upsample (line 101) | def convex_upsample(self, flow, mask, rate=4):
    method convex_upsample_3D (line 114) | def convex_upsample_3D(self, flow, mask, b, T, rate=4):
    method zero_init (line 144) | def zero_init(self, fmap):
    method forward_batch_test (line 151) | def forward_batch_test(
    method forward (line 205) | def forward(self, image1, image2, flow_init=None, iters=12, test_mode=...
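
convex_upsample follows RAFT's learned convex upsampling: each fine-resolution value is a softmax-weighted combination of a 3x3 neighborhood of coarse values, with the weights predicted as a mask tensor. A sketch of the 2D case under the usual RAFT shape conventions (convex_upsample_3D additionally carries the time dimension):

    # Hedged sketch of RAFT-style convex upsampling.
    # flow: (B, 2, H, W) at 1/rate resolution; mask: (B, 9 * rate * rate, H, W).
    import torch
    import torch.nn.functional as F

    def convex_upsample(flow, mask, rate=4):
        B, _, H, W = flow.shape
        mask = mask.view(B, 1, 9, rate, rate, H, W)
        mask = torch.softmax(mask, dim=2)                        # convex weights over the 3x3 window

        up_flow = F.unfold(rate * flow, kernel_size=3, padding=1)   # flow values scaled by the rate
        up_flow = up_flow.view(B, 2, 9, 1, 1, H, W)

        up_flow = torch.sum(mask * up_flow, dim=2)               # (B, 2, rate, rate, H, W)
        up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)              # interleave sub-pixel offsets
        return up_flow.reshape(B, 2, rate * H, rate * W)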

FILE: models/core/update.py
  function pool2x (line 9) | def pool2x(x):
  function pool4x (line 12) | def pool4x(x):
  function interp (line 15) | def interp(x, dest):
  class FlowHead (line 20) | class FlowHead(nn.Module):
    method __init__ (line 21) | def __init__(self, input_dim=128, hidden_dim=256, output_dim=2):
    method forward (line 27) | def forward(self, x):
  class FlowHead3D (line 31) | class FlowHead3D(nn.Module):
    method __init__ (line 32) | def __init__(self, input_dim=128, hidden_dim=256):
    method forward (line 38) | def forward(self, x):
  class ConvGRU (line 42) | class ConvGRU(nn.Module):
    method __init__ (line 43) | def __init__(self, hidden_dim, input_dim, kernel_size=3):
    method forward (line 49) | def forward(self, h, cz, cr, cq, *x_list):
  class SepConvGRU (line 61) | class SepConvGRU(nn.Module):
    method __init__ (line 62) | def __init__(self, hidden_dim=128, input_dim=192+128):
    method forward (line 73) | def forward(self, h, *x):
  class BasicMotionEncoder (line 92) | class BasicMotionEncoder(nn.Module):
    method __init__ (line 93) | def __init__(self, cor_planes):
    method forward (line 102) | def forward(self, flow, corr):
  class BasicMotionEncoder3D (line 113) | class BasicMotionEncoder3D(nn.Module):
    method __init__ (line 114) | def __init__(self, cor_planes):
    method forward (line 123) | def forward(self, flow, corr):
  class SepConvGRU3D (line 134) | class SepConvGRU3D(nn.Module):
    method __init__ (line 135) | def __init__(self, hidden_dim=128, input_dim=192 + 128):
    method forward (line 167) | def forward(self, h, x):
  class SKSepConvGRU3D (line 191) | class SKSepConvGRU3D(nn.Module):
    method __init__ (line 192) | def __init__(self, hidden_dim=128, input_dim=192 + 128):
    method forward (line 228) | def forward(self, h, x):
  class BasicUpdateBlock (line 252) | class BasicUpdateBlock(nn.Module):
    method __init__ (line 253) | def __init__(self, hidden_dim, cor_planes, mask_size=8, attention_type...
    method forward (line 273) | def forward(self, net, inp, corr, flow, upsample=True, t=1):
  class Attention (line 289) | class Attention(nn.Module):
    method __init__ (line 290) | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None):
    method forward (line 298) | def forward(self, x):
  class TimeAttnBlock (line 312) | class TimeAttnBlock(nn.Module):
    method __init__ (line 313) | def __init__(self, dim=256, num_heads=8):
    method forward (line 322) | def forward(self, x, T=1):
  class SpaceAttnBlock (line 340) | class SpaceAttnBlock(nn.Module):
    method __init__ (line 341) | def __init__(self, dim=256, num_heads=8):
    method forward (line 345) | def forward(self, x, T=1):
  class SequenceUpdateBlock3D (line 353) | class SequenceUpdateBlock3D(nn.Module):
    method __init__ (line 354) | def __init__(self, hidden_dim, cor_planes, mask_size=8):
    method forward (line 368) | def forward(self, net, inp, corrs, flows, t):
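
TimeAttnBlock and SpaceAttnBlock interleave attention over the temporal and spatial axes of the hidden state inside the update block. A hedged sketch of the temporal case, assuming the input is a stack of per-frame feature maps of shape (B*T, C, H, W); the head count, norm placement, and residual wiring are assumptions:

    # Hedged sketch of per-pixel temporal attention: every spatial location attends over T frames.
    import torch.nn as nn
    from einops import rearrange

    class TimeAttnSketch(nn.Module):
        def __init__(self, dim=256, num_heads=8):
            super().__init__()
            self.norm = nn.LayerNorm(dim)
            self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)

        def forward(self, x, T=1):
            BT, C, H, W = x.shape                                # x: (B*T, C, H, W)
            tokens = rearrange(x, "(b t) c h w -> (b h w) t c", t=T)
            q = self.norm(tokens)
            y, _ = self.attn(q, q, q, need_weights=False)
            tokens = tokens + y                                  # residual connection
            return rearrange(tokens, "(b h w) t c -> (b t) c h w", h=H, w=W, t=T)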

FILE: models/core/utils/config.py
  class ReplaceableBase (line 174) | class ReplaceableBase:
    method __new__ (line 180) | def __new__(cls, *args, **kwargs):
  class Configurable (line 192) | class Configurable:
    method __new__ (line 200) | def __new__(cls, *args, **kwargs):
  class _Registry (line 215) | class _Registry:
    method __init__ (line 222) | def __init__(self) -> None:
    method register (line 227) | def register(self, some_class: Type[_X]) -> Type[_X]:
    method _register (line 235) | def _register(
    method get (line 260) | def get(
    method get_all (line 292) | def get_all(
    method _is_base_class (line 320) | def _is_base_class(some_class: Type[ReplaceableBase]) -> bool:
    method _base_class_from_class (line 328) | def _base_class_from_class(
  class _ProcessType (line 344) | class _ProcessType(Enum):
  function _default_create (line 355) | def _default_create(
  function run_auto_creation (line 413) | def run_auto_creation(self: Any) -> None:
  function _is_configurable_class (line 421) | def _is_configurable_class(C) -> bool:
  function get_default_args (line 425) | def get_default_args(C, *, _do_not_process: Tuple[type, ...] = ()) -> Di...
  function _dataclass_name_for_function (line 487) | def _dataclass_name_for_function(C: Any) -> str:
  function enable_get_default_args (line 496) | def enable_get_default_args(C: Any, *, overwrite: bool = True) -> None:
  function _params_iter (line 553) | def _params_iter(C):
  function _is_immutable_type (line 563) | def _is_immutable_type(type_: Type, val: Any) -> bool:
  function _resolve_optional (line 575) | def _resolve_optional(type_: Any) -> Tuple[bool, Any]:
  function _is_actually_dataclass (line 587) | def _is_actually_dataclass(some_class) -> bool:
  function expand_args_fields (line 597) | def expand_args_fields(
  function get_default_args_field (line 776) | def get_default_args_field(C, *, _do_not_process: Tuple[type, ...] = ()):
  function _get_type_to_process (line 793) | def _get_type_to_process(type_) -> Optional[Tuple[Type, _ProcessType]]:
  function _process_member (line 825) | def _process_member(
  function remove_unused_components (line 917) | def remove_unused_components(dict_: DictConfig) -> None:

FILE: models/core/utils/utils.py
  function interp (line 7) | def interp(tensor, size):
  class InputPadder (line 16) | class InputPadder:
    method __init__ (line 19) | def __init__(self, dims, mode="sintel", divis_by=8):
    method pad (line 33) | def pad(self, *inputs):
    method unpad (line 37) | def unpad(self, x):
  function coords_grid (line 44) | def coords_grid(batch, ht, wd):
  function upflow8 (line 50) | def upflow8(flow, mode='bilinear'):
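
InputPadder pads a batch of images so that height and width become divisible by divis_by (8 here), and unpad crops predictions back to the original size. A sketch of that behaviour; 4D (B, C, H, W) inputs are assumed, and the exact split of padding across the image edges per mode is an assumption:

    # Hedged sketch of a RAFT-style InputPadder.
    import torch.nn.functional as F

    class InputPadderSketch:
        def __init__(self, dims, divis_by=8):
            self.ht, self.wd = dims[-2:]
            pad_ht = (-self.ht) % divis_by
            pad_wd = (-self.wd) % divis_by
            # F.pad order: (left, right, top, bottom)
            self._pad = [pad_wd // 2, pad_wd - pad_wd // 2, pad_ht // 2, pad_ht - pad_ht // 2]

        def pad(self, *inputs):
            return [F.pad(x, self._pad, mode="replicate") for x in inputs]

        def unpad(self, x):
            ht, wd = x.shape[-2:]
            return x[..., self._pad[2]:ht - self._pad[3], self._pad[0]:wd - self._pad[1]]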

FILE: models/raft_model.py
  class RAFTModel (line 13) | class RAFTModel(Configurable, torch.nn.Module):
    method __post_init__ (line 16) | def __post_init__(self):
    method forward (line 45) | def forward(self, image1, image2, iters=10):
    method forward_fullres (line 59) | def forward_fullres(self, image1, image2, iters=20):

FILE: models/stereoanyvideo_model.py
  class StereoAnyVideoModel (line 9) | class StereoAnyVideoModel(Configurable, torch.nn.Module):
    method __post_init__ (line 14) | def __post_init__(self):
    method forward (line 32) | def forward(self, batch_dict, iters=20):

FILE: third_party/RAFT/alt_cuda_corr/correlation.cpp
  function corr_forward (line 23) | std::vector<torch::Tensor> corr_forward(
  function corr_backward (line 36) | std::vector<torch::Tensor> corr_backward(
  function PYBIND11_MODULE (line 51) | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {

FILE: third_party/RAFT/core/corr.py
  class CorrBlock (line 12) | class CorrBlock:
    method __init__ (line 13) | def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
    method __call__ (line 29) | def __call__(self, coords):
    method corr (line 53) | def corr(fmap1, fmap2):
  class AlternateCorrBlock (line 63) | class AlternateCorrBlock:
    method __init__ (line 64) | def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
    method __call__ (line 74) | def __call__(self, coords):
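
CorrBlock builds RAFT's all-pairs correlation volume: the dot product between every pair of feature vectors in the two frames, scaled by the feature dimension, then average-pooled into a pyramid that __call__ samples around the current estimate. The static corr helper reduces to one batched matmul:

    # Hedged sketch of the all-pairs correlation computed by CorrBlock.corr.
    import torch

    def corr(fmap1, fmap2):
        # fmap1, fmap2: (B, D, H, W) feature maps from the shared encoder
        B, D, H, W = fmap1.shape
        f1 = fmap1.view(B, D, H * W)
        f2 = fmap2.view(B, D, H * W)
        volume = torch.matmul(f1.transpose(1, 2), f2)            # (B, H*W, H*W)
        volume = volume.view(B, H, W, 1, H, W)
        return volume / torch.sqrt(torch.tensor(D).float())      # scale by sqrt(feature dim)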

FILE: third_party/RAFT/core/datasets.py
  class FlowDataset (line 18) | class FlowDataset(data.Dataset):
    method __init__ (line 19) | def __init__(self, aug_params=None, sparse=False):
    method __getitem__ (line 34) | def __getitem__(self, index):
    method __rmul__ (line 93) | def __rmul__(self, v):
    method __len__ (line 98) | def __len__(self):
  class MpiSintel (line 102) | class MpiSintel(FlowDataset):
    method __init__ (line 103) | def __init__(self, aug_params=None, split='training', root='datasets/S...
  class FlyingChairs (line 121) | class FlyingChairs(FlowDataset):
    method __init__ (line 122) | def __init__(self, aug_params=None, split='train', root='datasets/Flyi...
  class FlyingThings3D (line 137) | class FlyingThings3D(FlowDataset):
    method __init__ (line 138) | def __init__(self, aug_params=None, root='datasets/FlyingThings3D', ds...
  class KITTI (line 161) | class KITTI(FlowDataset):
    method __init__ (line 162) | def __init__(self, aug_params=None, split='training', root='datasets/K...
  class HD1K (line 180) | class HD1K(FlowDataset):
    method __init__ (line 181) | def __init__(self, aug_params=None, root='datasets/HD1k'):
  function fetch_dataloader (line 199) | def fetch_dataloader(args, TRAIN_DS='C+T+K+S+H'):

FILE: third_party/RAFT/core/extractor.py
  class ResidualBlock (line 6) | class ResidualBlock(nn.Module):
    method __init__ (line 7) | def __init__(self, in_planes, planes, norm_fn='group', stride=1):
    method forward (line 48) | def forward(self, x):
  class BottleneckBlock (line 60) | class BottleneckBlock(nn.Module):
    method __init__ (line 61) | def __init__(self, in_planes, planes, norm_fn='group', stride=1):
    method forward (line 107) | def forward(self, x):
  class BasicEncoder (line 118) | class BasicEncoder(nn.Module):
    method __init__ (line 119) | def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):
    method _make_layer (line 159) | def _make_layer(self, dim, stride=1):
    method forward (line 168) | def forward(self, x):
  class SmallEncoder (line 195) | class SmallEncoder(nn.Module):
    method __init__ (line 196) | def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):
    method _make_layer (line 235) | def _make_layer(self, dim, stride=1):
    method forward (line 244) | def forward(self, x):

FILE: third_party/RAFT/core/raft.py
  class autocast (line 15) | class autocast:
    method __init__ (line 16) | def __init__(self, enabled):
    method __enter__ (line 18) | def __enter__(self):
    method __exit__ (line 20) | def __exit__(self, *args):
  class RAFT (line 24) | class RAFT(nn.Module):
    method __init__ (line 25) | def __init__(self, args):
    method freeze_bn (line 58) | def freeze_bn(self):
    method initialize_flow (line 63) | def initialize_flow(self, img):
    method upsample_flow (line 72) | def upsample_flow(self, flow, mask):
    method forward (line 86) | def forward(self, image1, image2, iters=12, flow_init=None, upsample=T...

FILE: third_party/RAFT/core/update.py
  class FlowHead (line 6) | class FlowHead(nn.Module):
    method __init__ (line 7) | def __init__(self, input_dim=128, hidden_dim=256):
    method forward (line 13) | def forward(self, x):
  class ConvGRU (line 16) | class ConvGRU(nn.Module):
    method __init__ (line 17) | def __init__(self, hidden_dim=128, input_dim=192+128):
    method forward (line 23) | def forward(self, h, x):
  class SepConvGRU (line 33) | class SepConvGRU(nn.Module):
    method __init__ (line 34) | def __init__(self, hidden_dim=128, input_dim=192+128):
    method forward (line 45) | def forward(self, h, x):
  class SmallMotionEncoder (line 62) | class SmallMotionEncoder(nn.Module):
    method __init__ (line 63) | def __init__(self, args):
    method forward (line 71) | def forward(self, flow, corr):
  class BasicMotionEncoder (line 79) | class BasicMotionEncoder(nn.Module):
    method __init__ (line 80) | def __init__(self, args):
    method forward (line 89) | def forward(self, flow, corr):
  class SmallUpdateBlock (line 99) | class SmallUpdateBlock(nn.Module):
    method __init__ (line 100) | def __init__(self, args, hidden_dim=96):
    method forward (line 106) | def forward(self, net, inp, corr, flow):
  class BasicUpdateBlock (line 114) | class BasicUpdateBlock(nn.Module):
    method __init__ (line 115) | def __init__(self, args, hidden_dim=128, input_dim=128):
    method forward (line 127) | def forward(self, net, inp, corr, flow, upsample=True):
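
ConvGRU is the convolutional gated recurrent unit that carries the hidden state across refinement iterations: the update gate, reset gate, and candidate state are each produced by a 3x3 convolution over the concatenated hidden state and input. A sketch of the standard formulation:

    # Hedged sketch of the ConvGRU used in RAFT's update block.
    import torch
    import torch.nn as nn

    class ConvGRUSketch(nn.Module):
        def __init__(self, hidden_dim=128, input_dim=192 + 128):
            super().__init__()
            self.convz = nn.Conv2d(hidden_dim + input_dim, hidden_dim, 3, padding=1)
            self.convr = nn.Conv2d(hidden_dim + input_dim, hidden_dim, 3, padding=1)
            self.convq = nn.Conv2d(hidden_dim + input_dim, hidden_dim, 3, padding=1)

        def forward(self, h, x):
            hx = torch.cat([h, x], dim=1)
            z = torch.sigmoid(self.convz(hx))                            # update gate
            r = torch.sigmoid(self.convr(hx))                            # reset gate
            q = torch.tanh(self.convq(torch.cat([r * h, x], dim=1)))     # candidate state
            return (1 - z) * h + z * q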

FILE: third_party/RAFT/core/utils/augmentor.py
  class FlowAugmentor (line 15) | class FlowAugmentor:
    method __init__ (line 16) | def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=T...
    method color_transform (line 36) | def color_transform(self, img1, img2):
    method eraser_transform (line 52) | def eraser_transform(self, img1, img2, bounds=[50, 100]):
    method spatial_transform (line 67) | def spatial_transform(self, img1, img2, flow):
    method __call__ (line 111) | def __call__(self, img1, img2, flow):
  class SparseFlowAugmentor (line 122) | class SparseFlowAugmentor:
    method __init__ (line 123) | def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=F...
    method color_transform (line 142) | def color_transform(self, img1, img2):
    method eraser_transform (line 148) | def eraser_transform(self, img1, img2):
    method resize_sparse_flow_map (line 161) | def resize_sparse_flow_map(self, flow, valid, fx=1.0, fy=1.0):
    method spatial_transform (line 195) | def spatial_transform(self, img1, img2, flow, valid):
    method __call__ (line 236) | def __call__(self, img1, img2, flow, valid):

FILE: third_party/RAFT/core/utils/flow_viz.py
  function make_colorwheel (line 20) | def make_colorwheel():
  function flow_uv_to_colors (line 70) | def flow_uv_to_colors(u, v, convert_to_bgr=False):
  function flow_to_image (line 109) | def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False):

FILE: third_party/RAFT/core/utils/frame_utils.py
  function readFlow (line 12) | def readFlow(fn):
  function readPFM (line 33) | def readPFM(file):
  function writeFlow (line 70) | def writeFlow(filename,uv,v=None):
  function readFlowKITTI (line 102) | def readFlowKITTI(filename):
  function readDispKITTI (line 109) | def readDispKITTI(filename):
  function writeFlowKITTI (line 116) | def writeFlowKITTI(filename, uv):
  function read_gen (line 123) | def read_gen(file_name, pil=False):
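
readFlow parses the Middlebury .flo format: a float32 magic number (202021.25), int32 width and height, then width * height * 2 float32 values in (u, v) order. A hedged reader sketch:

    # Hedged sketch of a Middlebury .flo reader, as used by readFlow.
    import numpy as np

    def read_flo(path):
        with open(path, "rb") as f:
            magic = np.fromfile(f, np.float32, count=1)[0]
            if magic != 202021.25:
                raise ValueError(f"{path} is not a valid .flo file (bad magic number)")
            w = int(np.fromfile(f, np.int32, count=1)[0])
            h = int(np.fromfile(f, np.int32, count=1)[0])
            data = np.fromfile(f, np.float32, count=2 * w * h)
        return data.reshape(h, w, 2)              # flow[..., 0] = u (x), flow[..., 1] = v (y)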

FILE: third_party/RAFT/core/utils/utils.py
  class InputPadder (line 7) | class InputPadder:
    method __init__ (line 9) | def __init__(self, dims, mode='sintel'):
    method pad (line 18) | def pad(self, *inputs):
    method unpad (line 21) | def unpad(self,x):
  function forward_interpolate (line 26) | def forward_interpolate(flow):
  function bilinear_sampler (line 57) | def bilinear_sampler(img, coords, mode='bilinear', mask=False):
  function coords_grid (line 74) | def coords_grid(batch, ht, wd, device):
  function upflow8 (line 80) | def upflow8(flow, mode='bilinear'):

FILE: third_party/RAFT/demo.py
  function load_image (line 20) | def load_image(imfile):
  function viz (line 26) | def viz(img, flo):
  function demo (line 42) | def demo(args):

FILE: third_party/RAFT/evaluate.py
  function create_sintel_submission (line 22) | def create_sintel_submission(model, iters=32, warm_start=False, output_p...
  function create_kitti_submission (line 54) | def create_kitti_submission(model, iters=24, output_path='kitti_submissi...
  function validate_chairs (line 75) | def validate_chairs(model, iters=24):
  function validate_sintel (line 96) | def validate_sintel(model, iters=32):
  function validate_kitti (line 131) | def validate_kitti(model, iters=24):

FILE: third_party/RAFT/train.py
  class GradScaler (line 28) | class GradScaler:
    method __init__ (line 29) | def __init__(self):
    method scale (line 31) | def scale(self, loss):
    method unscale_ (line 33) | def unscale_(self, optimizer):
    method step (line 35) | def step(self, optimizer):
    method update (line 37) | def update(self):
  function sequence_loss (line 47) | def sequence_loss(flow_preds, flow_gt, valid, gamma=0.8, max_flow=MAX_FL...
  function count_parameters (line 75) | def count_parameters(model):
  function fetch_optimizer (line 79) | def fetch_optimizer(args, model):
  class Logger (line 89) | class Logger:
    method __init__ (line 90) | def __init__(self, model, scheduler):
    method _print_training_status (line 97) | def _print_training_status(self):
    method push (line 112) | def push(self, metrics):
    method write_dict (line 125) | def write_dict(self, results):
    method close (line 132) | def close(self):
  function train (line 136) | def train(args):

FILE: train_stereoanyvideo.py
  function fetch_optimizer (line 29) | def fetch_optimizer(args, model):
  function forward_batch (line 49) | def forward_batch(batch, model, args):
  class Lite (line 69) | class Lite(LightningLite):
    method run (line 70) | def run(self, args):

FILE: train_utils/logger.py
  class Logger (line 7) | class Logger:
    method __init__ (line 11) | def __init__(self, model, scheduler, ckpt_path):
    method _print_training_status (line 22) | def _print_training_status(self):
    method push (line 43) | def push(self, metrics, task):
    method update (line 50) | def update(self):
    method write_dict (line 57) | def write_dict(self, results):
    method close (line 64) | def close(self):

FILE: train_utils/losses.py
  function sequence_loss (line 6) | def sequence_loss(flow_preds, flow_gt, valid, loss_gamma=0.9, max_flow=7...
  function temporal_loss (line 65) | def temporal_loss(flow_preds, flow_preds2, flow_gt, flow_gt2, valid, los...
  function compute_flow (line 126) | def compute_flow(Flow_Model, seq):
  function flow_warp (line 143) | def flow_warp(x, flow):
  function bidirectional_alignment (line 169) | def bidirectional_alignment(seq, flows_backward, flows_forward):
  function consistency_loss (line 202) | def consistency_loss(seq, disparities, Flow_Model, alpha=50):
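
sequence_loss supervises every intermediate prediction from the iterative refinement, weighting later iterations more heavily with an exponential schedule and ignoring pixels that are marked invalid or whose ground-truth magnitude exceeds max_flow. A sketch of that weighting; loss_gamma=0.9 comes from the signature above, while the max_flow default and the L1 penalty are assumptions in the usual RAFT style:

    # Hedged sketch of a gamma-weighted sequence loss over iterative predictions.
    import torch

    def sequence_loss(flow_preds, flow_gt, valid, loss_gamma=0.9, max_flow=700):
        # max_flow default is assumed; the listed signature is truncated at "max_flow=7..."
        n_predictions = len(flow_preds)
        mag = torch.sum(flow_gt ** 2, dim=1).sqrt()              # ground-truth magnitude per pixel
        mask = (valid >= 0.5) & (mag < max_flow)                 # keep valid, in-range pixels

        total = 0.0
        for i, pred in enumerate(flow_preds):
            weight = loss_gamma ** (n_predictions - i - 1)       # later iterations count more
            i_loss = (pred - flow_gt).abs()                      # per-pixel L1
            total = total + weight * (mask[:, None] * i_loss).mean()
        return total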

FILE: train_utils/utils.py
  function count_parameters (line 13) | def count_parameters(model):
  function run_test_eval (line 17) | def run_test_eval(ckpt_path, eval_type, evaluator, model, dataloaders, w...
  function fig2data (line 66) | def fig2data(fig):
  function save_ims_to_tb (line 89) | def save_ims_to_tb(writer, batch, output, total_steps):

Condensed preview: 83 files, each showing path, character count, and a content snippet (the full structured content runs to 571K chars).
[
  {
    "path": "LICENSE",
    "chars": 11357,
    "preview": "                                 Apache License\n                           Version 2.0, January 2004\n                   "
  },
  {
    "path": "README.md",
    "chars": 3671,
    "preview": "<h1 align='center' style=\"text-align:center; font-weight:bold; font-size:2.0em;letter-spacing:2.0px;\">\nStereo Any Video:"
  },
  {
    "path": "assets/1",
    "chars": 1,
    "preview": "\n"
  },
  {
    "path": "checkpoints/checkpoints here.txt",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "data/datasets/dataset here.txt",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "datasets/augmentor.py",
    "chars": 12953,
    "preview": "import numpy as np\nimport random\nfrom PIL import Image\n\nimport cv2\n\ncv2.setNumThreads(0)\ncv2.ocl.setUseOpenCL(False)\n\nfr"
  },
  {
    "path": "datasets/frame_utils.py",
    "chars": 3366,
    "preview": "import numpy as np\nfrom PIL import Image\nfrom os.path import *\nimport re\nimport imageio\nimport cv2\n\ncv2.setNumThreads(0)"
  },
  {
    "path": "datasets/video_datasets.py",
    "chars": 79294,
    "preview": "import os\nimport copy\nimport gzip\nimport logging\nimport torch\nimport numpy as np\nimport torch.utils.data as data\nimport "
  },
  {
    "path": "demo.py",
    "chars": 6638,
    "preview": "import sys\n\nimport argparse\nimport os\nimport cv2\nimport glob\nimport numpy as np\nimport torch\nimport torch.nn.functional "
  },
  {
    "path": "demo.sh",
    "chars": 187,
    "preview": "#!/bin/bash\n\nexport PYTHONPATH=`(cd ../ && pwd)`:`pwd`:$PYTHONPATH\n\npython demo.py --ckpt ./checkpoints/StereoAnyVideo_M"
  },
  {
    "path": "evaluate_stereoanyvideo.sh",
    "chars": 1509,
    "preview": "#!/bin/bash\n\nexport PYTHONPATH=`(cd ../ && pwd)`:`pwd`:$PYTHONPATH\n\n# evaluate on [sintel, dynamicreplica, infinigensv, "
  },
  {
    "path": "evaluation/configs/eval_dynamic_replica.yaml",
    "chars": 165,
    "preview": "defaults:\n  - default_config_eval\nvisualize_interval: -1\nexp_dir: ./outputs/stereoanyvideo_DynamicReplica\nsample_len: 15"
  },
  {
    "path": "evaluation/configs/eval_infinigensv.yaml",
    "chars": 207,
    "preview": "defaults:\n  - default_config_eval\nvisualize_interval: -1\nrender_bin_size: 0\nexp_dir: ./outputs/stereoanyvideo_InfinigenS"
  },
  {
    "path": "evaluation/configs/eval_kittidepth.yaml",
    "chars": 206,
    "preview": "defaults:\n  - default_config_eval\nvisualize_interval: -1\nrender_bin_size: 0\nexp_dir: ./outputs/stereoanyvideo_KITTIDepth"
  },
  {
    "path": "evaluation/configs/eval_sintel_clean.yaml",
    "chars": 201,
    "preview": "defaults:\n  - default_config_eval\nvisualize_interval: -1\nrender_bin_size: 0\nexp_dir: ./outputs/stereoanyvideo_sintel_cle"
  },
  {
    "path": "evaluation/configs/eval_sintel_final.yaml",
    "chars": 200,
    "preview": "defaults:\n  - default_config_eval\nvisualize_interval: -1\nrender_bin_size: 0\nexp_dir: ./outputs/stereoanyvideo_sintel_fin"
  },
  {
    "path": "evaluation/configs/eval_southkensington.yaml",
    "chars": 203,
    "preview": "defaults:\n  - default_config_eval\nvisualize_interval: 1\nexp_dir: ./outputs/stereoanyvideo_SouthKensingtonIndoor\nsample_l"
  },
  {
    "path": "evaluation/configs/eval_vkitti2.yaml",
    "chars": 199,
    "preview": "defaults:\n  - default_config_eval\nvisualize_interval: -1\nrender_bin_size: 0\nexp_dir: ./outputs/stereoanyvideo_VKITTI2\nsa"
  },
  {
    "path": "evaluation/core/evaluator.py",
    "chars": 6702,
    "preview": "import os\nimport numpy as np\nimport cv2\nfrom collections import defaultdict\nimport torch.nn.functional as F\nimport torch"
  },
  {
    "path": "evaluation/evaluate.py",
    "chars": 4416,
    "preview": "import json\nimport os\nfrom dataclasses import dataclass, field\nfrom typing import Any, Dict, Optional\n\nimport hydra\nimpo"
  },
  {
    "path": "evaluation/utils/eval_utils.py",
    "chars": 17789,
    "preview": "from dataclasses import dataclass\nfrom typing import Dict, Optional, Union\nfrom stereoanyvideo.evaluation.utils.ssim imp"
  },
  {
    "path": "evaluation/utils/ssim.py",
    "chars": 11023,
    "preview": "# Copyright 2020 by Gongfan Fang, Zhejiang University.\n# All rights reserved.\nimport warnings\n\nimport torch\nimport torch"
  },
  {
    "path": "evaluation/utils/utils.py",
    "chars": 12844,
    "preview": "from collections import defaultdict\nimport configparser\nimport os\nimport math\nfrom typing import Optional, List\nimport t"
  },
  {
    "path": "models/Video-Depth-Anything/app.py",
    "chars": 5259,
    "preview": "# Copyright (2025) Bytedance Ltd. and/or its affiliates \n\n# Licensed under the Apache License, Version 2.0 (the \"License"
  },
  {
    "path": "models/Video-Depth-Anything/get_weights.sh",
    "chars": 271,
    "preview": "#!/bin/bash\n\nmkdir checkpoints\ncd checkpoints\nwget https://huggingface.co/depth-anything/Video-Depth-Anything-Small/reso"
  },
  {
    "path": "models/Video-Depth-Anything/run.py",
    "chars": 2899,
    "preview": "# Copyright (2025) Bytedance Ltd. and/or its affiliates \n\n# Licensed under the Apache License, Version 2.0 (the \"License"
  },
  {
    "path": "models/Video-Depth-Anything/utils/dc_utils.py",
    "chars": 4416,
    "preview": "# This file is originally from DepthCrafter/depthcrafter/utils.py at main · Tencent/DepthCrafter\n# SPDX-License-Identifi"
  },
  {
    "path": "models/Video-Depth-Anything/utils/util.py",
    "chars": 2449,
    "preview": "# Copyright (2025) Bytedance Ltd. and/or its affiliates \n\n# Licensed under the Apache License, Version 2.0 (the \"License"
  },
  {
    "path": "models/Video-Depth-Anything/video_depth_anything/dinov2.py",
    "chars": 15178,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# This source code is licensed under the Apache License, Version "
  },
  {
    "path": "models/Video-Depth-Anything/video_depth_anything/dinov2_layers/__init__.py",
    "chars": 382,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "models/Video-Depth-Anything/video_depth_anything/dinov2_layers/attention.py",
    "chars": 2343,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "models/Video-Depth-Anything/video_depth_anything/dinov2_layers/block.py",
    "chars": 9332,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "models/Video-Depth-Anything/video_depth_anything/dinov2_layers/drop_path.py",
    "chars": 1160,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "models/Video-Depth-Anything/video_depth_anything/dinov2_layers/layer_scale.py",
    "chars": 823,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "models/Video-Depth-Anything/video_depth_anything/dinov2_layers/mlp.py",
    "chars": 1272,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "models/Video-Depth-Anything/video_depth_anything/dinov2_layers/patch_embed.py",
    "chars": 2832,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "models/Video-Depth-Anything/video_depth_anything/dinov2_layers/swiglu_ffn.py",
    "chars": 1859,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "models/Video-Depth-Anything/video_depth_anything/dpt.py",
    "chars": 5467,
    "preview": "# Copyright (2025) Bytedance Ltd. and/or its affiliates \n\n# Licensed under the Apache License, Version 2.0 (the \"License"
  },
  {
    "path": "models/Video-Depth-Anything/video_depth_anything/dpt_temporal.py",
    "chars": 4231,
    "preview": "# Copyright (2025) Bytedance Ltd. and/or its affiliates \n\n# Licensed under the Apache License, Version 2.0 (the \"License"
  },
  {
    "path": "models/Video-Depth-Anything/video_depth_anything/motion_module/attention.py",
    "chars": 16719,
    "preview": "# Copyright 2022 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
  },
  {
    "path": "models/Video-Depth-Anything/video_depth_anything/motion_module/motion_module.py",
    "chars": 11162,
    "preview": "# This file is originally from AnimateDiff/animatediff/models/motion_module.py at main · guoyww/AnimateDiff\n# SPDX-Licen"
  },
  {
    "path": "models/Video-Depth-Anything/video_depth_anything/util/blocks.py",
    "chars": 4064,
    "preview": "import torch.nn as nn\n\n\ndef _make_scratch(in_shape, out_shape, groups=1, expand=False):\n    scratch = nn.Module()\n\n    o"
  },
  {
    "path": "models/Video-Depth-Anything/video_depth_anything/util/transform.py",
    "chars": 6075,
    "preview": "import numpy as np\nimport cv2\n\n\nclass Resize(object):\n    \"\"\"Resize sample to given size (width, height).\n    \"\"\"\n\n    d"
  },
  {
    "path": "models/Video-Depth-Anything/video_depth_anything/video_depth.py",
    "chars": 6363,
    "preview": "# Copyright (2025) Bytedance Ltd. and/or its affiliates \n\n# Licensed under the Apache License, Version 2.0 (the \"License"
  },
  {
    "path": "models/core/attention.py",
    "chars": 8421,
    "preview": "import math\nimport copy\nimport torch\nimport torch.nn as nn\nfrom torch.nn import Module, Dropout\n\n\"\"\"\nLinear Transformer "
  },
  {
    "path": "models/core/corr.py",
    "chars": 3401,
    "preview": "import torch\nimport torch.nn.functional as F\nfrom einops import rearrange\n\n\ndef bilinear_sampler(img, coords, mode=\"bili"
  },
  {
    "path": "models/core/extractor.py",
    "chars": 9725,
    "preview": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport numpy as np\nimport os\nimport sys\nimport import"
  },
  {
    "path": "models/core/model_zoo.py",
    "chars": 1184,
    "preview": "import copy\nfrom pytorch3d.implicitron.tools.config import get_default_args\nfrom stereoanyvideo.models.stereoanyvideo_mo"
  },
  {
    "path": "models/core/stereoanyvideo.py",
    "chars": 12842,
    "preview": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom typing import Dict, List\nfrom einops import rea"
  },
  {
    "path": "models/core/update.py",
    "chars": 14208,
    "preview": "from einops import rearrange\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom opt_einsum import c"
  },
  {
    "path": "models/core/utils/config.py",
    "chars": 33854,
    "preview": "import dataclasses\nimport inspect\nimport itertools\nimport sys\nimport warnings\nfrom collections import Counter, defaultdi"
  },
  {
    "path": "models/core/utils/utils.py",
    "chars": 1597,
    "preview": "import torch\nimport torch.nn.functional as F\nimport numpy as np\nfrom scipy import interpolate\n\n\ndef interp(tensor, size)"
  },
  {
    "path": "models/raft_model.py",
    "chars": 2592,
    "preview": "from types import SimpleNamespace\nfrom typing import ClassVar\nimport torch.nn.functional as F\n\nfrom pytorch3d.implicitro"
  },
  {
    "path": "models/stereoanyvideo_model.py",
    "chars": 1124,
    "preview": "from typing import ClassVar\n\nimport torch\nimport torch.nn.functional as F\nfrom pytorch3d.implicitron.tools.config import"
  },
  {
    "path": "requirements.txt",
    "chars": 123,
    "preview": "hydra-core==1.1\nnumpy==1.23.5\nmunch==2.5.0\nomegaconf==2.1.0\nflow_vis==0.1\neinops==0.4.1\nopt_einsum==3.3.0\nrequests\nmovie"
  },
  {
    "path": "third_party/RAFT/LICENSE",
    "chars": 1512,
    "preview": "BSD 3-Clause License\n\nCopyright (c) 2020, princeton-vl\nAll rights reserved.\n\nRedistribution and use in source and binary"
  },
  {
    "path": "third_party/RAFT/README.md",
    "chars": 2725,
    "preview": "# RAFT\nThis repository contains the source code for our paper:\n\n[RAFT: Recurrent All Pairs Field Transforms for Optical "
  },
  {
    "path": "third_party/RAFT/alt_cuda_corr/correlation.cpp",
    "chars": 1368,
    "preview": "#include <torch/extension.h>\n#include <vector>\n\n// CUDA forward declarations\nstd::vector<torch::Tensor> corr_cuda_forwar"
  },
  {
    "path": "third_party/RAFT/alt_cuda_corr/correlation_kernel.cu",
    "chars": 10249,
    "preview": "#include <torch/extension.h>\n#include <cuda.h>\n#include <cuda_runtime.h>\n#include <vector>\n\n\n#define BLOCK_H 4\n#define B"
  },
  {
    "path": "third_party/RAFT/alt_cuda_corr/setup.py",
    "chars": 381,
    "preview": "from setuptools import setup\nfrom torch.utils.cpp_extension import BuildExtension, CUDAExtension\n\n\nsetup(\n    name='corr"
  },
  {
    "path": "third_party/RAFT/chairs_split.txt",
    "chars": 45743,
    "preview": "1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n1\n2\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n1\n2\n1\n"
  },
  {
    "path": "third_party/RAFT/core/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "third_party/RAFT/core/corr.py",
    "chars": 3086,
    "preview": "import torch\nimport torch.nn.functional as F\nfrom .utils.utils import bilinear_sampler, coords_grid\n\ntry:\n    import alt"
  },
  {
    "path": "third_party/RAFT/core/datasets.py",
    "chars": 9245,
    "preview": "# Data loading based on https://github.com/NVIDIA/flownet2-pytorch\n\nimport numpy as np\nimport torch\nimport torch.utils.d"
  },
  {
    "path": "third_party/RAFT/core/extractor.py",
    "chars": 8847,
    "preview": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\nclass ResidualBlock(nn.Module):\n    def __init__(se"
  },
  {
    "path": "third_party/RAFT/core/raft.py",
    "chars": 4924,
    "preview": "import numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom .update import BasicUpdateBl"
  },
  {
    "path": "third_party/RAFT/core/update.py",
    "chars": 5227,
    "preview": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\nclass FlowHead(nn.Module):\n    def __init__(self, i"
  },
  {
    "path": "third_party/RAFT/core/utils/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "third_party/RAFT/core/utils/augmentor.py",
    "chars": 9108,
    "preview": "import numpy as np\nimport random\nimport math\nfrom PIL import Image\n\nimport cv2\ncv2.setNumThreads(0)\ncv2.ocl.setUseOpenCL"
  },
  {
    "path": "third_party/RAFT/core/utils/flow_viz.py",
    "chars": 4318,
    "preview": "# Flow visualization code used from https://github.com/tomrunia/OpticalFlow_Visualization\n\n\n# MIT License\n#\n# Copyright "
  },
  {
    "path": "third_party/RAFT/core/utils/frame_utils.py",
    "chars": 4024,
    "preview": "import numpy as np\nfrom PIL import Image\nfrom os.path import *\nimport re\n\nimport cv2\ncv2.setNumThreads(0)\ncv2.ocl.setUse"
  },
  {
    "path": "third_party/RAFT/core/utils/utils.py",
    "chars": 2489,
    "preview": "import torch\nimport torch.nn.functional as F\nimport numpy as np\nfrom scipy import interpolate\n\n\nclass InputPadder:\n    \""
  },
  {
    "path": "third_party/RAFT/demo.py",
    "chars": 2073,
    "preview": "import sys\nsys.path.append('core')\n\nimport argparse\nimport os\nimport cv2\nimport glob\nimport numpy as np\nimport torch\nfro"
  },
  {
    "path": "third_party/RAFT/download_models.sh",
    "chars": 97,
    "preview": "#!/bin/bash\nwget https://dl.dropboxusercontent.com/s/4j4z58wuv8o0mfz/models.zip\nunzip models.zip\n"
  },
  {
    "path": "third_party/RAFT/evaluate.py",
    "chars": 6618,
    "preview": "import sys\nsys.path.append('core')\n\nfrom PIL import Image\nimport argparse\nimport os\nimport time\nimport numpy as np\nimpor"
  },
  {
    "path": "third_party/RAFT/train.py",
    "chars": 7987,
    "preview": "from __future__ import print_function, division\nimport sys\nsys.path.append('core')\n\nimport argparse\nimport os\nimport cv2"
  },
  {
    "path": "third_party/RAFT/train_mixed.sh",
    "chars": 921,
    "preview": "#!/bin/bash\nmkdir -p checkpoints\npython -u train.py --name raft-chairs --stage chairs --validation chairs --gpus 0 --num"
  },
  {
    "path": "third_party/RAFT/train_standard.sh",
    "chars": 860,
    "preview": "#!/bin/bash\nmkdir -p checkpoints\npython -u train.py --name raft-chairs --stage chairs --validation chairs --gpus 0 1 --n"
  },
  {
    "path": "train_stereoanyvideo.py",
    "chars": 11369,
    "preview": "import argparse\nimport logging\nfrom pathlib import Path\nfrom tqdm import tqdm\nimport os\nimport cv2\nimport numpy as np\nim"
  },
  {
    "path": "train_stereoanyvideo.sh",
    "chars": 371,
    "preview": "#!/bin/bash\n\nexport PYTHONPATH=`(cd ../ && pwd)`:`pwd`:$PYTHONPATH\n\npython train_stereoanyvideo.py --batch_size 1 \\\n --s"
  },
  {
    "path": "train_utils/logger.py",
    "chars": 2133,
    "preview": "import logging\nimport os\n\nfrom torch.utils.tensorboard import SummaryWriter\n\n\nclass Logger:\n\n    SUM_FREQ = 100\n\n    def"
  },
  {
    "path": "train_utils/losses.py",
    "chars": 8395,
    "preview": "import torch\nfrom einops import rearrange\nimport torch.nn.functional as F\n\n\ndef sequence_loss(flow_preds, flow_gt, valid"
  },
  {
    "path": "train_utils/utils.py",
    "chars": 4586,
    "preview": "import numpy as np\nimport os\nimport torch\n\nimport json\nimport flow_vis\nimport matplotlib.pyplot as plt\n\nimport stereoany"
  }
]

About this extraction

This page contains the full source code of the TomTomTommi/stereoanyvideo GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 83 files (513.1 KB), approximately 168.4k tokens, and a symbol index with 534 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input.

Extracted by GitExtract, a free GitHub-repo-to-text converter for AI, built by Nikandr Surkov.
