Repository: TomTomTommi/stereoanyvideo
Branch: main
Commit: 3a8beddc470c
Files: 83
Total size: 513.1 KB
Directory structure:
stereoanyvideo/
├── LICENSE
├── README.md
├── assets/
│ └── 1
├── checkpoints/
│ └── checkpoints here.txt
├── data/
│ └── datasets/
│ └── dataset here.txt
├── datasets/
│ ├── augmentor.py
│ ├── frame_utils.py
│ └── video_datasets.py
├── demo.py
├── demo.sh
├── evaluate_stereoanyvideo.sh
├── evaluation/
│ ├── configs/
│ │ ├── eval_dynamic_replica.yaml
│ │ ├── eval_infinigensv.yaml
│ │ ├── eval_kittidepth.yaml
│ │ ├── eval_sintel_clean.yaml
│ │ ├── eval_sintel_final.yaml
│ │ ├── eval_southkensington.yaml
│ │ └── eval_vkitti2.yaml
│ ├── core/
│ │ └── evaluator.py
│ ├── evaluate.py
│ └── utils/
│ ├── eval_utils.py
│ ├── ssim.py
│ └── utils.py
├── models/
│ ├── Video-Depth-Anything/
│ │ ├── app.py
│ │ ├── get_weights.sh
│ │ ├── run.py
│ │ ├── utils/
│ │ │ ├── dc_utils.py
│ │ │ └── util.py
│ │ └── video_depth_anything/
│ │ ├── dinov2.py
│ │ ├── dinov2_layers/
│ │ │ ├── __init__.py
│ │ │ ├── attention.py
│ │ │ ├── block.py
│ │ │ ├── drop_path.py
│ │ │ ├── layer_scale.py
│ │ │ ├── mlp.py
│ │ │ ├── patch_embed.py
│ │ │ └── swiglu_ffn.py
│ │ ├── dpt.py
│ │ ├── dpt_temporal.py
│ │ ├── motion_module/
│ │ │ ├── attention.py
│ │ │ └── motion_module.py
│ │ ├── util/
│ │ │ ├── blocks.py
│ │ │ └── transform.py
│ │ └── video_depth.py
│ ├── core/
│ │ ├── attention.py
│ │ ├── corr.py
│ │ ├── extractor.py
│ │ ├── model_zoo.py
│ │ ├── stereoanyvideo.py
│ │ ├── update.py
│ │ └── utils/
│ │ ├── config.py
│ │ └── utils.py
│ ├── raft_model.py
│ └── stereoanyvideo_model.py
├── requirements.txt
├── third_party/
│ └── RAFT/
│ ├── LICENSE
│ ├── README.md
│ ├── alt_cuda_corr/
│ │ ├── correlation.cpp
│ │ ├── correlation_kernel.cu
│ │ └── setup.py
│ ├── chairs_split.txt
│ ├── core/
│ │ ├── __init__.py
│ │ ├── corr.py
│ │ ├── datasets.py
│ │ ├── extractor.py
│ │ ├── raft.py
│ │ ├── update.py
│ │ └── utils/
│ │ ├── __init__.py
│ │ ├── augmentor.py
│ │ ├── flow_viz.py
│ │ ├── frame_utils.py
│ │ └── utils.py
│ ├── demo.py
│ ├── download_models.sh
│ ├── evaluate.py
│ ├── train.py
│ ├── train_mixed.sh
│ └── train_standard.sh
├── train_stereoanyvideo.py
├── train_stereoanyvideo.sh
└── train_utils/
├── logger.py
├── losses.py
└── utils.py
================================================
FILE CONTENTS
================================================
================================================
FILE: LICENSE
================================================
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
================================================
FILE: README.md
================================================
# Stereo Any Video: Temporally Consistent Stereo Matching

## Installation
Tested with CUDA 12.2.

Set up the root for all source files:
```
git clone https://github.com/tomtomtommi/stereoanyvideo
cd stereoanyvideo
export PYTHONPATH=`(cd ../ && pwd)`:`pwd`:$PYTHONPATH
```
Create a conda environment:
```
conda create -n sav python=3.10
conda activate sav
```
Install the requirements:
```
conda install pytorch==2.1.1 torchvision==0.16.1 torchaudio==2.1.1 pytorch-cuda=12.1 -c pytorch -c nvidia
pip install pip==24.0
pip install pytorch_lightning==1.6.0
pip install iopath
conda install -c bottler nvidiacub
pip install scikit-image matplotlib imageio plotly opencv-python
conda install -c fvcore -c conda-forge fvcore
pip install black usort flake8 flake8-bugbear flake8-comprehensions
conda install pytorch3d -c pytorch3d
pip install -r requirements.txt
pip install timm
```
Download the Video-Depth-Anything (VDA) checkpoints:
```
cd models/Video-Depth-Anything
sh get_weights.sh
```
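As a quick sanity check (a minimal sketch, not part of the repo), you can verify that PyTorch, PyTorch3D, and the `stereoanyvideo` package resolve correctly once `PYTHONPATH` is set as above:
```python
import torch
import pytorch3d
# importable as a package because PYTHONPATH includes the parent of the repo root
from stereoanyvideo.datasets import frame_utils

print(torch.__version__, "CUDA available:", torch.cuda.is_available())
print("pytorch3d", pytorch3d.__version__)
```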
## Inference on a stereo video
```
sh demo.sh
```
Before running, download the checkpoints from [Google Drive](https://drive.google.com/drive/folders/1c7L065dcBWhCYYjWYo2edGOG605PnpXv?usp=sharing) and copy them to `./checkpoints/`.
By default, the left and right camera frames are expected to be structured like this:
```none
./demo_video/
├── left
│   ├── left000000.png
│   ├── left000001.png
│   ├── left000002.png
│   └── ...
└── right
    ├── right000000.png
    ├── right000001.png
    ├── right000002.png
    └── ...
```
A simple way to try the demo is to use SouthKensingtonSV.
To test on your own data, modify `--path ./demo_video/`. More arguments can be found and modified in `demo.py`.
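For reference, here is a minimal sketch (not part of the repo, assuming the `left*.png` / `right*.png` naming shown above) of how the frame pairs under this layout can be enumerated:
```python
from pathlib import Path

root = Path("./demo_video")
left_frames = sorted((root / "left").glob("left*.png"))
right_frames = sorted((root / "right").glob("right*.png"))
assert len(left_frames) == len(right_frames), "left/right frame counts must match"

# each pair is one rectified stereo frame of the input video
for left, right in zip(left_frames, right_frames):
    print(left.name, "<->", right.name)
```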
## Dataset
Download the following datasets and put them in `./data/datasets/` (see the loader sketch after this list):
- [SceneFlow](https://lmb.informatik.uni-freiburg.de/resources/datasets/SceneFlowDatasets.en.html)
- [Sintel](http://sintel.is.tue.mpg.de/stereo)
- [Dynamic_Replica](https://dynamic-stereo.github.io/)
- [KITTI Depth](https://www.cvlibs.net/datasets/kitti/eval_depth_all.php)
- [Infinigen SV](https://tomtomtommi.github.io/BiDAVideo/)
- [Virtual KITTI2](https://europe.naverlabs.com/proxy-virtual-worlds-vkitti-2/)
- [SouthKensington SV](https://tomtomtommi.github.io/BiDAVideo/)
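A minimal sketch (assuming you keep the default roots used by the loaders in `datasets/video_datasets.py`) of how two of the dataset classes can be instantiated:
```python
from stereoanyvideo.datasets.video_datasets import (
    DynamicReplicaDataset,              # default root: ./data/datasets/dynamic_replica_data
    SouthKensingtonStereoVideoDataset,  # default root: ./data/datasets/SouthKensington/data/
)

# sample_len controls the clip length drawn from each sequence
dynamic_replica = DynamicReplicaDataset(split="train", sample_len=5)
south_kensington = SouthKensingtonStereoVideoDataset(sample_len=20)
print(len(dynamic_replica), len(south_kensington))
```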
## Evaluation
```
sh evaluate_stereoanyvideo.sh
```
## Training
```
sh train_stereoanyvideo.sh
```
## Citation
If you use our method in your research, please consider citing:
```
@inproceedings{jing2025stereo,
  title={Stereo any video: Temporally consistent stereo matching},
  author={Jing, Junpeng and Luo, Weixun and Mao, Ye and Mikolajczyk, Krystian},
  booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},
  pages={20836--20846},
  year={2025}
}
```
================================================
FILE: assets/1
================================================
================================================
FILE: checkpoints/checkpoints here.txt
================================================
================================================
FILE: data/datasets/dataset here.txt
================================================
================================================
FILE: datasets/augmentor.py
================================================
import numpy as np
import random
from PIL import Image
import cv2
cv2.setNumThreads(0)
cv2.ocl.setUseOpenCL(False)
from torchvision.transforms import ColorJitter, functional, Compose
class AdjustGamma(object):
def __init__(self, gamma_min, gamma_max, gain_min=1.0, gain_max=1.0):
self.gamma_min, self.gamma_max, self.gain_min, self.gain_max = (
gamma_min,
gamma_max,
gain_min,
gain_max,
)
def __call__(self, sample):
gain = random.uniform(self.gain_min, self.gain_max)
gamma = random.uniform(self.gamma_min, self.gamma_max)
return functional.adjust_gamma(sample, gamma, gain)
def __repr__(self):
return f"Adjust Gamma {self.gamma_min}, ({self.gamma_max}) and Gain ({self.gain_min}, {self.gain_max})"
class SequenceDispFlowAugmentor:
def __init__(
self,
crop_size,
min_scale=-0.2,
max_scale=0.5,
do_flip=True,
yjitter=False,
saturation_range=[0.6, 1.4],
gamma=[1, 1, 1, 1],
):
# spatial augmentation params
self.crop_size = crop_size
self.min_scale = min_scale
self.max_scale = max_scale
self.spatial_aug_prob = 1.0
self.stretch_prob = 0.8
self.max_stretch = 0.2
# flip augmentation params
self.yjitter = yjitter
self.do_flip = do_flip
self.h_flip_prob = 0.5
self.v_flip_prob = 0.1
# photometric augmentation params
self.photo_aug = Compose(
[
ColorJitter(
brightness=0.4,
contrast=0.4,
saturation=saturation_range,
hue=0.5 / 3.14,
),
AdjustGamma(*gamma),
]
)
self.asymmetric_color_aug_prob = 0.2
self.eraser_aug_prob = 0.5
def color_transform(self, seq):
"""Photometric augmentation"""
# asymmetric
if np.random.rand() < self.asymmetric_color_aug_prob:
for i in range(len(seq)):
for cam in (0, 1):
seq[i][cam] = np.array(
self.photo_aug(Image.fromarray(seq[i][cam])), dtype=np.uint8
)
# symmetric
else:
image_stack = np.concatenate(
[seq[i][cam] for i in range(len(seq)) for cam in (0, 1)], axis=0
)
image_stack = np.array(
self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8
)
split = np.split(image_stack, len(seq) * 2, axis=0)
for i in range(len(seq)):
seq[i][0] = split[2 * i]
seq[i][1] = split[2 * i + 1]
return seq
def eraser_transform(self, seq, bounds=[50, 100]):
"""Occlusion augmentation"""
ht, wd = seq[0][0].shape[:2]
for i in range(len(seq)):
for cam in (0, 1):
if np.random.rand() < self.eraser_aug_prob:
mean_color = np.mean(seq[0][0].reshape(-1, 3), axis=0)
for _ in range(np.random.randint(1, 3)):
x0 = np.random.randint(0, wd)
y0 = np.random.randint(0, ht)
dx = np.random.randint(bounds[0], bounds[1])
dy = np.random.randint(bounds[0], bounds[1])
seq[i][cam][y0 : y0 + dy, x0 : x0 + dx, :] = mean_color
return seq
def spatial_transform(self, img, disp):
# randomly sample scale
ht, wd = img[0][0].shape[:2]
min_scale = np.maximum(
(self.crop_size[0] + 8) / float(ht), (self.crop_size[1] + 8) / float(wd)
)
scale = 2 ** np.random.uniform(self.min_scale, self.max_scale)
scale_x = scale
scale_y = scale
if np.random.rand() < self.stretch_prob:
scale_x *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch)
scale_y *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch)
scale_x = np.clip(scale_x, min_scale, None)
scale_y = np.clip(scale_y, min_scale, None)
if np.random.rand() < self.spatial_aug_prob:
# rescale the images
for i in range(len(img)):
for cam in (0, 1):
img[i][cam] = cv2.resize(
img[i][cam],
None,
fx=scale_x,
fy=scale_y,
interpolation=cv2.INTER_LINEAR,
)
if len(disp[i]) > 0:
disp[i][cam] = cv2.resize(
disp[i][cam],
None,
fx=scale_x,
fy=scale_y,
interpolation=cv2.INTER_LINEAR,
)
disp[i][cam] = disp[i][cam] * [scale_x, scale_y]
if self.yjitter:
y0 = np.random.randint(2, img[0][0].shape[0] - self.crop_size[0] - 2)
x0 = np.random.randint(2, img[0][0].shape[1] - self.crop_size[1] - 2)
for i in range(len(img)):
y1 = y0 + np.random.randint(-2, 2 + 1)
img[i][0] = img[i][0][
y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]
]
img[i][1] = img[i][1][
y1 : y1 + self.crop_size[0], x0 : x0 + self.crop_size[1]
]
if len(disp[i]) > 0:
disp[i][0] = disp[i][0][
y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]
]
disp[i][1] = disp[i][1][
y1 : y1 + self.crop_size[0], x0 : x0 + self.crop_size[1]
]
else:
y0 = np.random.randint(0, img[0][0].shape[0] - self.crop_size[0])
x0 = np.random.randint(0, img[0][0].shape[1] - self.crop_size[1])
for i in range(len(img)):
for cam in (0, 1):
img[i][cam] = img[i][cam][
y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]
]
if len(disp[i]) > 0:
disp[i][cam] = disp[i][cam][
y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]
]
return img, disp
def __call__(self, img, disp):
img = self.color_transform(img)
img = self.eraser_transform(img)
img, disp = self.spatial_transform(img, disp)
for i in range(len(img)):
for cam in (0, 1):
img[i][cam] = np.ascontiguousarray(img[i][cam])
if len(disp[i]) > 0:
disp[i][cam] = np.ascontiguousarray(disp[i][cam])
return img, disp
class SequenceDispSparseFlowAugmentor:
def __init__(
self,
crop_size,
min_scale=-0.2,
max_scale=0.5,
do_flip=True,
yjitter=False,
saturation_range=[0.6, 1.4],
gamma=[1, 1, 1, 1],
):
# spatial augmentation params
self.crop_size = crop_size
self.min_scale = min_scale
self.max_scale = max_scale
self.spatial_aug_prob = 1.0
self.stretch_prob = 0.8
self.max_stretch = 0.2
# flip augmentation params
self.yjitter = yjitter
self.do_flip = do_flip
self.h_flip_prob = 0.5
self.v_flip_prob = 0.1
# photometric augmentation params
self.photo_aug = Compose(
[
ColorJitter(
brightness=0.4,
contrast=0.4,
saturation=saturation_range,
hue=0.5 / 3.14,
),
AdjustGamma(*gamma),
]
)
self.asymmetric_color_aug_prob = 0.2
self.eraser_aug_prob = 0.5
def color_transform(self, seq):
"""Photometric augmentation"""
# symmetric
image_stack = np.concatenate(
[seq[i][cam] for i in range(len(seq)) for cam in (0, 1)], axis=0
)
image_stack = np.array(
self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8
)
split = np.split(image_stack, len(seq) * 2, axis=0)
for i in range(len(seq)):
seq[i][0] = split[2 * i]
seq[i][1] = split[2 * i + 1]
return seq
def eraser_transform(self, seq, bounds=[50, 100]):
"""Occlusion augmentation"""
ht, wd = seq[0][0].shape[:2]
for i in range(len(seq)):
for cam in (0, 1):
if np.random.rand() < self.eraser_aug_prob:
mean_color = np.mean(seq[0][0].reshape(-1, 3), axis=0)
for _ in range(np.random.randint(1, 3)):
x0 = np.random.randint(0, wd)
y0 = np.random.randint(0, ht)
dx = np.random.randint(bounds[0], bounds[1])
dy = np.random.randint(bounds[0], bounds[1])
seq[i][cam][y0 : y0 + dy, x0 : x0 + dx, :] = mean_color
return seq
def resize_sparse_flow_map(self, flow, valid, fx=1.0, fy=1.0):
ht, wd = flow.shape[:2]
coords = np.meshgrid(np.arange(wd), np.arange(ht))
coords = np.stack(coords, axis=-1)
coords = coords.reshape(-1, 2).astype(np.float32)
flow = flow.reshape(-1, 2).astype(np.float32)
valid = valid.reshape(-1).astype(np.float32)
coords0 = coords[valid>=1]
flow0 = flow[valid>=1]
ht1 = int(round(ht * fy))
wd1 = int(round(wd * fx))
coords1 = coords0 * [fx, fy]
flow1 = flow0 * [fx, fy]
xx = np.round(coords1[:,0]).astype(np.int32)
yy = np.round(coords1[:,1]).astype(np.int32)
v = (xx > 0) & (xx < wd1) & (yy > 0) & (yy < ht1)
xx = xx[v]
yy = yy[v]
flow1 = flow1[v]
flow_img = np.zeros([ht1, wd1, 2], dtype=np.float32)
valid_img = np.zeros([ht1, wd1], dtype=np.int32)
flow_img[yy, xx] = flow1
valid_img[yy, xx] = 1
return flow_img, valid_img
def spatial_transform(self, img, disp, valid):
# randomly sample scale
ht, wd = img[0][0].shape[:2]
min_scale = np.maximum(
(self.crop_size[0] + 8) / float(ht), (self.crop_size[1] + 8) / float(wd)
)
scale = 2 ** np.random.uniform(self.min_scale, self.max_scale)
scale_x = scale
scale_y = scale
if np.random.rand() < self.stretch_prob:
scale_x *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch)
scale_y *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch)
scale_x = np.clip(scale_x, min_scale, None)
scale_y = np.clip(scale_y, min_scale, None)
if np.random.rand() < self.spatial_aug_prob:
# rescale the images
for i in range(len(img)):
for cam in (0, 1):
img[i][cam] = cv2.resize(
img[i][cam],
None,
fx=scale_x,
fy=scale_y,
interpolation=cv2.INTER_LINEAR,
)
if len(disp[i]) > 0:
disp[i][cam], valid[i][cam] = self.resize_sparse_flow_map(disp[i][cam], valid[i][cam], fx=scale_x, fy=scale_y)
margin_y = 20
margin_x = 50
y0 = np.random.randint(0, img[0][0].shape[0] - self.crop_size[0])
x0 = np.random.randint(0, img[0][0].shape[1] - self.crop_size[1])
for i in range(len(img)):
for cam in (0, 1):
img[i][cam] = img[i][cam][
y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]
]
if len(disp[i]) > 0:
disp[i][cam] = disp[i][cam][
y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]
]
valid[i][cam] = valid[i][cam][
y0: y0 + self.crop_size[0], x0: x0 + self.crop_size[1]
]
return img, disp, valid
def __call__(self, img, disp, valid):
img = self.color_transform(img)
img = self.eraser_transform(img)
img, disp, valid = self.spatial_transform(img, disp, valid)
for i in range(len(img)):
for cam in (0, 1):
img[i][cam] = np.ascontiguousarray(img[i][cam])
if len(disp[i]) > 0:
disp[i][cam] = np.ascontiguousarray(disp[i][cam])
valid[i][cam] = np.ascontiguousarray(valid[i][cam])
return img, disp, valid
================================================
FILE: datasets/frame_utils.py
================================================
import numpy as np
from PIL import Image
from os.path import *
import re
import imageio
import cv2
cv2.setNumThreads(0)
cv2.ocl.setUseOpenCL(False)
TAG_CHAR = np.array([202021.25], np.float32)
def readFlow(fn):
"""Read .flo file in Middlebury format"""
# Code adapted from:
# http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy
# WARNING: this will work on little-endian architectures (eg Intel x86) only!
# print 'fn = %s'%(fn)
with open(fn, "rb") as f:
magic = np.fromfile(f, np.float32, count=1)
if 202021.25 != magic:
print("Magic number incorrect. Invalid .flo file")
return None
else:
w = np.fromfile(f, np.int32, count=1)
h = np.fromfile(f, np.int32, count=1)
# print 'Reading %d x %d flo file\n' % (w, h)
data = np.fromfile(f, np.float32, count=2 * int(w) * int(h))
# Reshape data into 3D array (columns, rows, bands)
# The reshape here is for visualization, the original code is (w,h,2)
return np.resize(data, (int(h), int(w), 2))
def readPFM(file):
file = open(file, "rb")
color = None
width = None
height = None
scale = None
endian = None
header = file.readline().rstrip()
if header == b"PF":
color = True
elif header == b"Pf":
color = False
else:
raise Exception("Not a PFM file.")
dim_match = re.match(rb"^(\d+)\s(\d+)\s$", file.readline())
if dim_match:
width, height = map(int, dim_match.groups())
else:
raise Exception("Malformed PFM header.")
scale = float(file.readline().rstrip())
if scale < 0: # little-endian
endian = "<"
scale = -scale
else:
endian = ">" # big-endian
data = np.fromfile(file, endian + "f")
shape = (height, width, 3) if color else (height, width)
data = np.reshape(data, shape)
data = np.flipud(data)
return data
def readDispSintelStereo(file_name):
"""Return disparity read from filename."""
f_in = np.array(Image.open(file_name))
d_r = f_in[:, :, 0].astype("float64")
d_g = f_in[:, :, 1].astype("float64")
d_b = f_in[:, :, 2].astype("float64")
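# Sintel stereo packs disparity across the RGB channels: R carries the coarse disparity (in steps of 4 px),
# while G and B refine it with 1/2**6 and 1/2**14 pixel precision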
disp = d_r * 4 + d_g / (2 ** 6) + d_b / (2 ** 14)
mask = np.array(Image.open(file_name.replace("disparities", "occlusions")))
valid = (mask == 0) & (disp > 0)
return disp, valid
def readDispMiddlebury(file_name):
assert basename(file_name) == "disp0GT.pfm"
disp = readPFM(file_name).astype(np.float32)
assert len(disp.shape) == 2
nocc_pix = file_name.replace("disp0GT.pfm", "mask0nocc.png")
assert exists(nocc_pix)
nocc_pix = imageio.imread(nocc_pix) == 255
assert np.any(nocc_pix)
return disp, nocc_pix
def read_gen(file_name, pil=False):
ext = splitext(file_name)[-1]
if ext == ".png" or ext == ".jpeg" or ext == ".ppm" or ext == ".jpg":
return Image.open(file_name)
elif ext == ".bin" or ext == ".raw":
return np.load(file_name)
elif ext == ".flo":
return readFlow(file_name).astype(np.float32)
elif ext == ".pfm":
flow = readPFM(file_name).astype(np.float32)
if len(flow.shape) == 2:
return flow
else:
return flow[:, :, :-1]
return []
================================================
FILE: datasets/video_datasets.py
================================================
import os
import copy
import gzip
import logging
import torch
import numpy as np
import torch.utils.data as data
import torch.nn.functional as F
import os.path as osp
from glob import glob
import cv2
import re
from scipy.spatial.transform import Rotation as R
from collections import defaultdict
from PIL import Image
from dataclasses import dataclass
from typing import List, Optional
from pytorch3d.renderer.cameras import PerspectiveCameras
from pytorch3d.implicitron.dataset.types import (
FrameAnnotation as ImplicitronFrameAnnotation,
load_dataclass,
)
from stereoanyvideo.datasets import frame_utils
from stereoanyvideo.evaluation.utils.eval_utils import depth2disparity_scale
from stereoanyvideo.datasets.augmentor import SequenceDispFlowAugmentor, SequenceDispSparseFlowAugmentor
@dataclass
class DynamicReplicaFrameAnnotation(ImplicitronFrameAnnotation):
"""A dataclass used to load annotations from json."""
camera_name: Optional[str] = None
class StereoSequenceDataset(data.Dataset):
def __init__(self, aug_params=None, sparse=False, reader=None):
self.augmentor = None
self.sparse = sparse
self.img_pad = (
aug_params.pop("img_pad", None) if aug_params is not None else None
)
if aug_params is not None and "crop_size" in aug_params:
if sparse:
self.augmentor = SequenceDispSparseFlowAugmentor(**aug_params)
else:
self.augmentor = SequenceDispFlowAugmentor(**aug_params)
if reader is None:
self.disparity_reader = frame_utils.read_gen
else:
self.disparity_reader = reader
self.depth_reader = self._load_depth
self.is_test = False
self.sample_list = []
self.extra_info = []
self.depth_eps = 1e-5
def _load_depth(self, depth_path):
if depth_path[-3:] == "npy":
return self._load_npy_depth(depth_path)
elif depth_path[-3:] == "png":
if "kitti_depth" in depth_path:
return self._load_kitti_depth(depth_path)
elif "vkitti2" in depth_path:
return self._load_vkitti2(depth_path)
else:
return self._load_16big_png_depth(depth_path)
else:
raise ValueError("Other format depth is not implemented")
def _load_npy_depth(self, depth_npy):
depth = np.load(depth_npy)
return depth
def _load_vkitti2(self, depth_png):
depth_image = cv2.imread(depth_png, cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH)
depth_in_meters = depth_image.astype(np.float32) / 100.0
depth_in_meters[depth_image == 0] = -1.
return depth_in_meters
def _load_kitti_depth(self, depth_png):
# depth_image = cv2.imread(depth_png, cv2.IMREAD_UNCHANGED)
# depth_in_meters = depth_image.astype(np.float32) / 256.0
depth_image = np.array(Image.open(depth_png), dtype=int)
# make sure we have a proper 16bit depth map here.. not 8bit!
assert (np.max(depth_image) > 255)
depth_in_meters = depth_image.astype(np.float32) / 256.
depth_in_meters[depth_image == 0] = -1.
return depth_in_meters
def _load_16big_png_depth(self, depth_png):
with Image.open(depth_png) as depth_pil:
# the image is stored with 16-bit depth but PIL reads it as I (32 bit).
# we cast it to uint16, then reinterpret as float16, then cast to float32
depth = (
np.frombuffer(np.array(depth_pil, dtype=np.uint16), dtype=np.float16)
.astype(np.float32)
.reshape((depth_pil.size[1], depth_pil.size[0]))
)
return depth
def load_tartanair_pose(self, filepath, index=0):
poses = np.loadtxt(filepath)
tx, ty, tz, qx, qy, qz, qw = poses[index]
# Quaternion to rotation matrix
r = R.from_quat([qx, qy, qz, qw])
R_mat = r.as_matrix()
# Assemble 4x4 pose matrix
T = np.eye(4)
T[:3, :3] = R_mat
T[:3, 3] = [tx, ty, tz]
return T
def parse_txt_file(self, file_path):
with open(file_path, 'r') as file:
data = file.read()
# Regex patterns
intrinsic_pattern = re.compile(r"Intrinsic:\s*\[\[([^\]]+)\]\s*\[\s*([^\]]+)\]\s*\[\s*([^\]]+)\]\]")
frame_pattern = re.compile(r"Frame (\d+): Pose: ([\w\d]+)\n([\s\S]+?)(?=Frame|\Z)")
# Extract intrinsic matrix (K)
intrinsic_match = intrinsic_pattern.search(data)
if intrinsic_match:
K = np.array([list(map(float, row.split())) for row in intrinsic_match.groups()])
else:
raise ValueError("Intrinsic matrix not found in the file")
# Extract frames and compute R and T
frames = []
for frame_match in frame_pattern.finditer(data):
frame_number = int(frame_match.group(1))
pose_id = frame_match.group(2)
pose_matrix = np.array([list(map(float, row.split())) for row in frame_match.group(3).strip().split('\n')])
# Decompose pose matrix into R and T
R = pose_matrix[:3, :3] # The upper-left 3x3 part is the rotation matrix
T = pose_matrix[:3, 3] # The first three elements of the fourth column is the translation vector
frames.append({
'frame_number': frame_number,
'pose_id': pose_id,
'pose_matrix': pose_matrix,
'R': R,
'T': T
})
return K, frames
def _get_pytorch3d_camera(
self, entry_viewpoint, image_size, scale: float
) -> PerspectiveCameras:
assert entry_viewpoint is not None
# principal point and focal length
principal_point = torch.tensor(
entry_viewpoint.principal_point, dtype=torch.float
)
focal_length = torch.tensor(entry_viewpoint.focal_length, dtype=torch.float)
half_image_size_wh_orig = (
torch.tensor(list(reversed(image_size)), dtype=torch.float) / 2.0
)
# first, we convert from the dataset's NDC convention to pixels
format = entry_viewpoint.intrinsics_format
if format.lower() == "ndc_norm_image_bounds":
# this is e.g. currently used in CO3D for storing intrinsics
rescale = half_image_size_wh_orig
elif format.lower() == "ndc_isotropic":
rescale = half_image_size_wh_orig.min()
else:
raise ValueError(f"Unknown intrinsics format: {format}")
# principal point and focal length in pixels
principal_point_px = half_image_size_wh_orig - principal_point * rescale
focal_length_px = focal_length * rescale
# now, convert from pixels to PyTorch3D v0.5+ NDC convention
# if self.image_height is None or self.image_width is None:
out_size = list(reversed(image_size))
half_image_size_output = torch.tensor(out_size, dtype=torch.float) / 2.0
half_min_image_size_output = half_image_size_output.min()
# rescaled principal point and focal length in ndc
principal_point = (
half_image_size_output - principal_point_px * scale
) / half_min_image_size_output
focal_length = focal_length_px * scale / half_min_image_size_output
return PerspectiveCameras(
focal_length=focal_length[None],
principal_point=principal_point[None],
R=torch.tensor(entry_viewpoint.R, dtype=torch.float)[None],
T=torch.tensor(entry_viewpoint.T, dtype=torch.float)[None],
)
def _get_pytorch3d_camera_from_blender(self, R, T, K, image_size, scale: float) -> PerspectiveCameras:
assert R is not None and T is not None and K is not None
assert R.shape == (3, 3), f"Expected R to be 3x3, but got {R.shape}"
assert T.shape == (3,), f"Expected T to be a 3-element vector, but got {T.shape}"
assert K.shape == (3, 3), f"Expected K to be 3x3, but got {K.shape}"
# Extract principal point and focal length from K
fx = K[0, 0]
fy = K[1, 1]
cx = K[0, 2]
cy = K[1, 2]
principal_point = torch.tensor([cx, cy], dtype=torch.float)
focal_length = torch.tensor([fx, fy], dtype=torch.float)
half_image_size_wh_orig = (
torch.tensor(list(reversed(image_size)), dtype=torch.float) / 2.0
)
# Adjust principal point and focal length in pixels
principal_point_px = principal_point * scale
focal_length_px = focal_length * scale
# Convert from pixels to PyTorch3D NDC convention
principal_point = (principal_point_px - half_image_size_wh_orig) / half_image_size_wh_orig
half_min_image_size_output = half_image_size_wh_orig.min()
focal_length = focal_length_px / half_min_image_size_output
R = R.T @ np.array([[-1, 0, 0], [0, -1, 0], [0, 0, 1]], dtype=np.float64)
T = T @ np.array([[-1, 0, 0], [0, -1, 0], [0, 0, 1]], dtype=np.float64)
# Convert R and T to PyTorch tensors
R_tensor = torch.tensor(R, dtype=torch.float).unsqueeze(0) # Add batch dimension
T_tensor = torch.tensor(T, dtype=torch.float).view(1, 3) # Ensure T is a 1x3 tensor
# Return PerspectiveCameras object
return PerspectiveCameras(
focal_length=focal_length.unsqueeze(0), # Add batch dimension
principal_point=principal_point.unsqueeze(0), # Add batch dimension
R=R_tensor,
T=T_tensor,
)
def _get_output_tensor(self, sample):
output_tensor = defaultdict(list)
sample_size = len(sample["image"]["left"])
output_tensor_keys = ["img", "disp", "valid_disp", "mask"]
add_keys = ["viewpoint", "metadata"]
for add_key in add_keys:
if add_key in sample:
output_tensor_keys.append(add_key)
for key in output_tensor_keys:
output_tensor[key] = [[] for _ in range(sample_size)]
if "viewpoint" in sample:
viewpoint_left = self._get_pytorch3d_camera(
sample["viewpoint"]["left"][0],
sample["metadata"]["left"][0][1],
scale=1.0,
)
viewpoint_right = self._get_pytorch3d_camera(
sample["viewpoint"]["right"][0],
sample["metadata"]["right"][0][1],
scale=1.0,
)
depth2disp_scale = depth2disparity_scale(
viewpoint_left,
viewpoint_right,
torch.Tensor(sample["metadata"]["left"][0][1])[None],
)
output_tensor["depth2disp_scale"] = [depth2disp_scale]
if "camera" in sample:
output_tensor["viewpoint"] = [[] for _ in range(sample_size)]
# InfinigenSV
if sample["camera"]["left"][0][-3:] == "npz":
# Note that the K, R, T is based on Blender world Matrix
camera_left = np.load(sample["camera"]["left"][0])
camera_right = np.load(sample["camera"]["right"][0])
camera_left_RT = camera_left['T']
camera_right_RT = camera_right['T']
camera_left_K = camera_left['K']
camera_right_K = camera_right['K']
camera_left_T = camera_left['T'][:3, 3]
camera_left_R = camera_left['T'][:3, :3]
fix_baseline = np.linalg.norm(camera_left_RT[:3, 3] - camera_right_RT[:3, 3])
focal_length_px = camera_left_K[0][0]
depth2disp_scale = focal_length_px * fix_baseline
# Sintel
elif sample["camera"]["left"][0][-3:] == "cam":
TAG_FLOAT = 202021.25
f = open(sample["camera"]["left"][0], 'rb')
check = np.fromfile(f, dtype=np.float32, count=1)[0]
assert check == TAG_FLOAT, ' cam_read:: Wrong tag in flow file (should be: {0}, is: {1}). Big-endian machine? '.format(
TAG_FLOAT, check)
camera_left_K = np.fromfile(f, dtype='float64', count=9).reshape((3, 3))
camera_left_RT = np.fromfile(f, dtype='float64', count=12).reshape((3, 4))
fix_baseline = 0.1 # From the MPI Sintel dataset website, the baseline of the cameras = 10cm = 0.1m
focal_length_px = camera_left_K[0][0]
depth2disp_scale = focal_length_px * fix_baseline
camera_left_T = camera_left_RT[:3, 3]
camera_left_R = camera_left_RT[:3, :3]
# Spring
elif any(filename in path for path in sample["camera"]["left"] for filename in ["focaldistance.txt", "extrinsics.txt", "intrinsics.txt"]):
for path in sample["camera"]["left"]:
if "intrinsics.txt" in path:
intrinsics_path = path
elif "extrinsics.txt" in path:
extrinsics_path = path
fx, fy, cx, cy = np.loadtxt(intrinsics_path)[0]
# Build the 3x3 intrinsic matrix
camera_left_K = np.array([
[fx, 0, cx],
[0, fy, cy],
[0, 0, 1]
])
focal_length_px = camera_left_K[0][0]
fix_baseline = 0.065 # From the dataset website, the baseline of the cameras = 6.5cm = 0.065m
depth2disp_scale = focal_length_px * fix_baseline
camera_left_RT = np.loadtxt(extrinsics_path).reshape(-1, 4, 4)[0]
camera_left_T = camera_left_RT[:3, 3]
camera_left_R = camera_left_RT[:3, :3]
# TartanAir
elif sample["camera"]["left"][0][-13:] == "pose_left.txt":
fx, fy, cx, cy = 320.0, 320.0, 320.0, 240.0
# Build the 3x3 intrinsic matrix
camera_left_K = np.array([
[fx, 0, cx],
[0, fy, cy],
[0, 0, 1]
])
focal_length_px = camera_left_K[0][0]
fix_baseline = 0.25
depth2disp_scale = focal_length_px * fix_baseline
camera_left_RT = self.load_tartanair_pose(sample["camera"]["left"][0], index=0)
camera_left_T = camera_left_RT[:3, 3]
camera_left_R = camera_left_RT[:3, :3]
# KITTI Depth
elif sample["camera"]["left"][0][-20:] == "calib_cam_to_cam.txt":
calib_data = {}
with open(sample["camera"]["left"][0], 'r') as f:
for line in f:
key, value = line.split(':', 1)
calib_data[key.strip()] = value.strip()
P_key = 'P_rect_02'
if P_key in calib_data:
P_values = np.array([float(x) for x in calib_data[P_key].split()])
P_matrix = P_values.reshape(3, 4)
else:
raise KeyError(f"Projection matrix for camera not found in calibration data")
focal_length_px = P_matrix[0, 0]
T_key1 = 'T_02'
T_key2 = 'T_03'
if T_key1 in calib_data and T_key2 in calib_data:
T1 = np.array([float(x) for x in calib_data[T_key1].split()])
T2 = np.array([float(x) for x in calib_data[T_key2].split()])
baseline = np.linalg.norm(T1 - T2)
else:
raise KeyError(f"Translation vectors for cameras not found in calibration data")
R_key1 = 'R_rect_02'
R_key2 = 'R_rect_03'
if R_key1 in calib_data and R_key2 in calib_data:
R1 = np.array([float(x) for x in calib_data[R_key1].split()]).reshape(3, 3)
R2 = np.array([float(x) for x in calib_data[R_key2].split()]).reshape(3, 3)
else:
raise KeyError(f"Rotation vectors for cameras not found in calibration data")
depth2disp_scale = focal_length_px * baseline
camera_left_K = P_matrix[:, :3]
camera_left_T = T1
camera_left_R = R1
# VKITTI2
elif sample["camera"]["left"][0][-13:] == "intrinsic.txt":
baseline = 0.532725
with open(sample["camera"]["left"][0], 'r') as f:
line = f.readlines()[1]
values = line.strip().split()
frame = int(values[0])
camera_id = int(values[1])
fx = float(values[2])
fy = float(values[3])
cx = float(values[4])
cy = float(values[5])
# Construct the intrinsic matrix
camera_left_K = torch.tensor([[fx, 0, cx],
[0, fy, cy],
[0, 0, 1]], dtype=torch.float32)
depth2disp_scale = camera_left_K[0, 0] * baseline
with open(sample["camera"]["left"][0].replace("intrinsic.txt", "extrinsic.txt"), 'r') as f:
line = f.readlines()[1]
values = line.strip().split()
frame = int(values[0])
camera_id = int(values[1])
# Extract rotation (3x3) and translation (3x1)
camera_left_R = np.array([
[float(values[2]), float(values[3]), float(values[4])],
[float(values[6]), float(values[7]), float(values[8])],
[float(values[10]), float(values[11]), float(values[12])]
], dtype=np.float32)
camera_left_T = np.array([
float(values[5]),
float(values[9]),
float(values[13])
], dtype=np.float32)
# SouthKensington
elif sample["camera"]["left"][0][-3:] == "txt":
camera_left_K, frames = self.parse_txt_file(sample["camera"]["left"][0])
fix_baseline = 0.12
camera_left_R = frames[0]['R']
camera_left_T = frames[0]['T']
focal_length_px = camera_left_K[0][0]
depth2disp_scale = focal_length_px * fix_baseline
else:
raise ValueError("Other format camera is not implemented")
output_tensor["depth2disp_scale"] = [depth2disp_scale]
output_tensor["RTK"] = [camera_left_R, camera_left_T, camera_left_K]
for i in range(sample_size):
for cam in ["left", "right"]:
if "mask" in sample and cam in sample["mask"]:
mask = frame_utils.read_gen(sample["mask"][cam][i])
mask = np.array(mask) / 255.0
output_tensor["mask"][i].append(mask)
if "viewpoint" in sample and cam in sample["viewpoint"]:
viewpoint = self._get_pytorch3d_camera(
sample["viewpoint"][cam][i],
sample["metadata"][cam][i][1],
scale=1.0,
)
output_tensor["viewpoint"][i].append(viewpoint)
if "camera" in sample:
# InfinigenSV
if sample["camera"]["left"][0][-3:] == "npz" and cam in sample["camera"]:
# Note that the K, R, T is based on Blender world Matrix
camera = np.load(sample["camera"][cam][i])
camera_K = camera['K']
camera_T = camera['T'][:3, 3]
camera_R = camera['T'][:3, :3]
viewpoint = self._get_pytorch3d_camera_from_blender(
camera_R, camera_T, camera_K,
sample["metadata"][cam][i][1],
scale=1.0,
)
output_tensor["viewpoint"][i].append(viewpoint)
# Sintel
elif sample["camera"]["left"][0][-3:] == "cam" and cam in sample["camera"]:
f = open(sample["camera"]["left"][0], 'rb')
check = np.fromfile(f, dtype=np.float32, count=1)[0]
assert check == TAG_FLOAT, ' cam_read:: Wrong tag in flow file (should be: {0}, is: {1}). Big-endian machine? '.format(
TAG_FLOAT, check)
camera_K = np.fromfile(f, dtype='float64', count=9).reshape((3, 3))
camera_RT = np.fromfile(f, dtype='float64', count=12).reshape((3, 4))
camera_T = camera_RT[:3, 3]
camera_R = camera_RT[:3, :3]
viewpoint = self._get_pytorch3d_camera_from_blender(
camera_R, camera_T, camera_K,
sample["metadata"][cam][i][1],
scale=1.0,
)
output_tensor["viewpoint"][i].append(viewpoint)
# TartanAir
elif sample["camera"]["left"][0][-13:] == "pose_left.txt":
fx, fy, cx, cy = 320.0, 320.0, 320.0, 240.0
# Build the 3x3 intrinsic matrix
camera_left_K = np.array([
[fx, 0, cx],
[0, fy, cy],
[0, 0, 1]
])
focal_length_px = camera_left_K[0][0]
fix_baseline = 0.25
depth2disp_scale = focal_length_px * fix_baseline
camera_left_RT = self.load_tartanair_pose(sample["camera"]["left"][0], index=i)
camera_left_T = camera_left_RT[:3, 3]
camera_left_R = camera_left_RT[:3, :3]
# Spring
elif any(filename in path for path in sample["camera"]["left"] for filename
in ["focaldistance.txt", "extrinsics.txt", "intrinsics.txt"]) and cam in sample["camera"]:
for path in sample["camera"]["left"]:
if "intrinsics.txt" in path:
intrinsics_path = path
elif "extrinsics.txt" in path:
extrinsics_path = path
fx, fy, cx, cy = np.loadtxt(intrinsics_path)[0]
# Build the 3x3 intrinsic matrix
camera_K = np.array([
[fx, 0, cx],
[0, fy, cy],
[0, 0, 1]
])
focal_length_px = camera_left_K[0][0]
fix_baseline = 0.065 # From the dataset website, the baseline of the cameras = 6.5cm = 0.065m
depth2disp_scale = focal_length_px * fix_baseline
camera_RT = np.loadtxt(extrinsics_path).reshape(-1, 4, 4)[i]
camera_T = camera_RT[:3, 3]
camera_R = camera_RT[:3, :3]
viewpoint = self._get_pytorch3d_camera_from_blender(
camera_R, camera_T, camera_K,
sample["metadata"][cam][i][1],
scale=1.0,
)
output_tensor["viewpoint"][i].append(viewpoint)
# KITTI Depth
elif sample["camera"]["left"][0][-20:] == "calib_cam_to_cam.txt":
calib_data = {}
with open(sample["camera"]["left"][0], 'r') as f:
for line in f:
key, value = line.split(':', 1)
calib_data[key.strip()] = value.strip()
P_key = 'P_rect_02'
if P_key in calib_data:
P_values = np.array([float(x) for x in calib_data[P_key].split()])
P_matrix = P_values.reshape(3, 4)
else:
raise KeyError(f"Projection matrix for camera not found in calibration data")
focal_length_px = P_matrix[0, 0]
T_key1 = 'T_02'
T_key2 = 'T_03'
if T_key1 in calib_data and T_key2 in calib_data:
T1 = np.array([float(x) for x in calib_data[T_key1].split()])
T2 = np.array([float(x) for x in calib_data[T_key2].split()])
baseline = np.linalg.norm(T1 - T2)
else:
raise KeyError(f"Translation vectors for cameras not found in calibration data")
R_key1 = 'R_rect_02'
R_key2 = 'R_rect_03'
if R_key1 in calib_data and R_key2 in calib_data:
R1 = np.array([float(x) for x in calib_data[R_key1].split()]).reshape(3, 3)
R2 = np.array([float(x) for x in calib_data[R_key2].split()]).reshape(3, 3)
else:
raise KeyError(f"Rotation vectors for cameras not found in calibration data")
depth2disp_scale = focal_length_px * baseline
camera_K = P_matrix[:, :3]
camera_T = T1
camera_R = R1
viewpoint = self._get_pytorch3d_camera_from_blender(
camera_R, camera_T, camera_K,
sample["metadata"][cam][i][1],
scale=1.0,
)
output_tensor["viewpoint"][i].append(viewpoint)
# VKITTI2
elif sample["camera"]["left"][0][-13:] == "intrinsic.txt":
with open(sample["camera"]["left"][0], 'r') as f:
line = f.readlines()[1+i]
values = line.strip().split()
frame = int(values[0])
camera_id = int(values[1])
fx = float(values[2])
fy = float(values[3])
cx = float(values[4])
cy = float(values[5])
# Construct the intrinsic matrix
camera_K = torch.tensor([[fx, 0, cx],
[0, fy, cy],
[0, 0, 1]], dtype=torch.float32)
with open(sample["camera"]["left"][0].replace("intrinsic.txt", "extrinsic.txt"), 'r') as f:
line = f.readlines()[1+i]
values = line.strip().split()
frame = int(values[0])
camera_id = int(values[1])
# Extract rotation (3x3) and translation (3x1)
camera_R = np.array([
[float(values[2]), float(values[3]), float(values[4])],
[float(values[6]), float(values[7]), float(values[8])],
[float(values[10]), float(values[11]), float(values[12])]
], dtype=np.float32)
camera_T = np.array([
float(values[5]),
float(values[9]),
float(values[13])
], dtype=np.float32)
viewpoint = self._get_pytorch3d_camera_from_blender(
camera_R, camera_T, camera_K,
sample["metadata"][cam][i][1],
scale=1.0,
)
output_tensor["viewpoint"][i].append(viewpoint)
# SouthKensington
elif sample["camera"]["left"][0][-3:] == "txt" and cam in sample["camera"]:
camera_left_K, frames = self.parse_txt_file(sample["camera"]["left"][0])
camera_K = camera_left_K
camera_R = frames[i]['R']
camera_T = frames[i]['T']
viewpoint = self._get_pytorch3d_camera_from_blender(
camera_R, camera_T, camera_K,
sample["metadata"][cam][i][1],
scale=1.0,
)
output_tensor["viewpoint"][i].append(viewpoint)
if "metadata" in sample and cam in sample["metadata"]:
metadata = sample["metadata"][cam][i]
output_tensor["metadata"][i].append(metadata)
if cam in sample["image"]:
img = frame_utils.read_gen(sample["image"][cam][i])
img = np.array(img).astype(np.uint8)
# grayscale images
if len(img.shape) == 2:
img = np.tile(img[..., None], (1, 1, 3))
else:
img = img[..., :3]
output_tensor["img"][i].append(img)
if cam in sample["disparity"]:
disp = self.disparity_reader(sample["disparity"][cam][i])
if isinstance(disp, tuple):
disp, valid_disp = disp
else:
valid_disp = disp < 512
disp = np.array(disp).astype(np.float32)
disp = np.stack([-disp, np.zeros_like(disp)], axis=-1)
# disp = np.stack([disp, np.zeros_like(disp)], axis=-1)
output_tensor["disp"][i].append(disp)
output_tensor["valid_disp"][i].append(valid_disp)
elif "depth" in sample and cam in sample["depth"]:
depth = self.depth_reader(sample["depth"][cam][i])
depth_mask = depth < self.depth_eps
depth[depth_mask] = self.depth_eps
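# for rectified stereo: disparity = focal_length_px * baseline / depth, where depth2disp_scale = focal_length_px * baseline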
disp = depth2disp_scale / depth
disp[depth_mask] = 0
valid_disp = (disp < 512) * (1 - depth_mask)
disp = np.array(disp).astype(np.float32)
disp = np.stack([-disp, np.zeros_like(disp)], axis=-1)
output_tensor["disp"][i].append(disp)
output_tensor["valid_disp"][i].append(valid_disp)
return output_tensor
def __getitem__(self, index):
im_tensor = {"img"}
sample = self.sample_list[index]
if self.is_test:
sample_size = len(sample["image"]["left"])
im_tensor["img"] = [[] for _ in range(sample_size)]
for i in range(sample_size):
for cam in ["left", "right"]:
img = frame_utils.read_gen(sample["image"][cam][i])
img = np.array(img).astype(np.uint8)[..., :3]
img = torch.from_numpy(img).permute(2, 0, 1).float()
im_tensor["img"][i].append(img)
im_tensor["img"] = torch.stack(im_tensor["img"])
return im_tensor, self.extra_info[index]
index = index % len(self.sample_list)
try:
output_tensor = self._get_output_tensor(sample)
except:
logging.warning(f"Exception in loading sample {index}!")
index = np.random.randint(len(self.sample_list))
logging.info(f"New index is {index}")
sample = self.sample_list[index]
output_tensor = self._get_output_tensor(sample)
sample_size = len(sample["image"]["left"])
if self.augmentor is not None:
if self.sparse:
output_tensor["img"], output_tensor["disp"], output_tensor["valid_disp"] = self.augmentor(
output_tensor["img"], output_tensor["disp"], output_tensor["valid_disp"]
)
else:
output_tensor["img"], output_tensor["disp"] = self.augmentor(
output_tensor["img"], output_tensor["disp"]
)
for i in range(sample_size):
for cam in (0, 1):
if cam < len(output_tensor["img"][i]):
img = (
torch.from_numpy(output_tensor["img"][i][cam])
.permute(2, 0, 1)
.float()
)
if self.img_pad is not None:
padH, padW = self.img_pad
img = F.pad(img, [padW] * 2 + [padH] * 2)
output_tensor["img"][i][cam] = img
if cam < len(output_tensor["disp"][i]):
disp = (
torch.from_numpy(output_tensor["disp"][i][cam])
.permute(2, 0, 1)
.float()
)
if self.sparse:
valid_disp = torch.from_numpy(
output_tensor["valid_disp"][i][cam]
)
else:
valid_disp = (
(disp[0].abs() < 512)
& (disp[1].abs() < 512)
& (disp[0].abs() != 0)
)
disp = disp[:1]
output_tensor["disp"][i][cam] = disp
output_tensor["valid_disp"][i][cam] = valid_disp.float()
if "mask" in output_tensor and cam < len(output_tensor["mask"][i]):
mask = torch.from_numpy(output_tensor["mask"][i][cam]).float()
output_tensor["mask"][i][cam] = mask
if "viewpoint" in output_tensor and cam < len(
output_tensor["viewpoint"][i]
):
viewpoint = output_tensor["viewpoint"][i][cam]
output_tensor["viewpoint"][i][cam] = viewpoint
res = {}
if "viewpoint" in output_tensor and self.split != "train":
res["viewpoint"] = output_tensor["viewpoint"]
if "metadata" in output_tensor and self.split != "train":
res["metadata"] = output_tensor["metadata"]
if "depth2disp_scale" in output_tensor and self.split != "train":
res["depth2disp_scale"] = output_tensor["depth2disp_scale"]
if "RTK" in output_tensor and self.split != "train":
res["RTK"] = output_tensor["RTK"]
for k, v in output_tensor.items():
if k != "viewpoint" and k != "metadata" and k != "depth2disp_scale" and k != "RTK":
for i in range(len(v)):
if len(v[i]) > 0:
v[i] = torch.stack(v[i])
if len(v) > 0 and (len(v[0]) > 0):
res[k] = torch.stack(v)
return res
def __mul__(self, v):
copy_of_self = copy.deepcopy(self)
copy_of_self.sample_list = v * copy_of_self.sample_list
copy_of_self.extra_info = v * copy_of_self.extra_info
return copy_of_self
def __len__(self):
return len(self.sample_list)
class DynamicReplicaDataset(StereoSequenceDataset):
def __init__(
self,
aug_params=None,
root="./data/datasets/dynamic_replica_data",
split="train",
sample_len=-1,
only_first_n_samples=-1,
):
super(DynamicReplicaDataset, self).__init__(aug_params)
self.root = root
self.sample_len = sample_len
self.split = split
frame_annotations_file = f"frame_annotations_{split}.jgz"
with gzip.open(
osp.join(root, split, frame_annotations_file), "rt", encoding="utf8"
) as zipfile:
frame_annots_list = load_dataclass(
zipfile, List[DynamicReplicaFrameAnnotation]
)
seq_annot = defaultdict(lambda: defaultdict(list))
for frame_annot in frame_annots_list:
seq_annot[frame_annot.sequence_name][frame_annot.camera_name].append(
frame_annot
)
for seq_name in seq_annot.keys():
try:
filenames = defaultdict(lambda: defaultdict(list))
for cam in ["left", "right"]:
for framedata in seq_annot[seq_name][cam]:
im_path = osp.join(root, split, framedata.image.path)
depth_path = osp.join(root, split, framedata.depth.path)
mask_path = osp.join(root, split, framedata.mask.path)
assert os.path.isfile(im_path), im_path
if self.split == 'train':
assert os.path.isfile(depth_path), depth_path
assert os.path.isfile(mask_path), mask_path
filenames["image"][cam].append(im_path)
if os.path.isfile(depth_path):
filenames["depth"][cam].append(depth_path)
filenames["mask"][cam].append(mask_path)
filenames["viewpoint"][cam].append(framedata.viewpoint)
filenames["metadata"][cam].append(
[framedata.sequence_name, framedata.image.size]
)
for k in filenames.keys():
assert (
len(filenames[k][cam])
== len(filenames["image"][cam])
> 0
), framedata.sequence_name
seq_len = len(filenames["image"][cam])
print("seq_len", seq_name, seq_len)
if split == "train":
for ref_idx in range(0, seq_len, 3):
step = 1 if self.sample_len == 1 else np.random.randint(1, 6)
if ref_idx + step * self.sample_len < seq_len:
sample = defaultdict(lambda: defaultdict(list))
for cam in ["left", "right"]:
for idx in range(
ref_idx, ref_idx + step * self.sample_len, step
):
for k in filenames.keys():
if "mask" not in k:
sample[k][cam].append(
filenames[k][cam][idx]
)
self.sample_list.append(sample)
else:
step = self.sample_len if self.sample_len > 0 else seq_len
counter = 0
for ref_idx in range(0, seq_len, step):
sample = defaultdict(lambda: defaultdict(list))
for cam in ["left", "right"]:
for idx in range(ref_idx, ref_idx + step):
for k in filenames.keys():
sample[k][cam].append(filenames[k][cam][idx])
self.sample_list.append(sample)
counter += 1
if only_first_n_samples > 0 and counter >= only_first_n_samples:
break
except Exception as e:
print(e)
print("Skipping sequence", seq_name)
assert len(self.sample_list) > 0, "No samples found"
print(f"Added {len(self.sample_list)} from Dynamic Replica {split}")
logging.info(f"Added {len(self.sample_list)} from Dynamic Replica {split}")
class InfinigenStereoVideoDataset(StereoSequenceDataset):
def __init__(
self,
aug_params=None,
root="./data/datasets/InfinigenStereo",
split="train",
sample_len=-1,
only_first_n_samples=-1,
):
super(InfinigenStereoVideoDataset, self).__init__(aug_params)
self.root = root
self.sample_len = sample_len
self.split = split
sequence = sorted(
glob(osp.join(root, self.split, "*"))
)
for i in range(len(sequence)):
sequence_name = os.path.basename(sequence[i])
try:
filenames = defaultdict(lambda: defaultdict(list))
for cam in ["left", "right"]:
suffix = "camera_0/" if cam == "left" else "camera_1/"
im_path_list = sorted(glob(osp.join(sequence[i], "frames/Image/", suffix, "*.png")))
depth_path_list = sorted(glob(osp.join(sequence[i], "frames/Depth/", suffix, "*.npy")))
camera_list = sorted(glob(osp.join(sequence[i], "frames/camview/", suffix, "*.npz")))
for j in range(len(im_path_list)):
im_path = im_path_list[j]
depth_path = depth_path_list[j]
camera_path = camera_list[j]
assert os.path.isfile(im_path), im_path
assert os.path.isfile(depth_path), depth_path
filenames["image"][cam].append(im_path)
filenames["depth"][cam].append(depth_path)
filenames["camera"][cam].append(camera_path)
filenames["metadata"][cam].append([sequence_name , (720,1280)])
for k in filenames.keys():
assert (
len(filenames[k][cam])
== len(filenames["image"][cam])
> 0
), sequence_name
seq_len = len(filenames["image"][cam])
print("seq_len", sequence_name, seq_len)
if self.split == "train":
for ref_idx in range(0, seq_len, 3):
step = 1 if self.sample_len == 1 else np.random.randint(1, 6)
if ref_idx + step * self.sample_len < seq_len:
sample = defaultdict(lambda: defaultdict(list))
for cam in ["left", "right"]:
for idx in range(
ref_idx, ref_idx + step * self.sample_len, step
):
for k in filenames.keys():
if "mask" not in k:
sample[k][cam].append(
filenames[k][cam][idx]
)
self.sample_list.append(sample)
else:
step = self.sample_len if (self.sample_len > 0) and (self.sample_len < seq_len) else seq_len
print("sample_step", step)
counter = 0
for ref_idx in range(0, seq_len, step):
sample = defaultdict(lambda: defaultdict(list))
for cam in ["left", "right"]:
for idx in range(ref_idx, ref_idx + step):
for k in filenames.keys():
sample[k][cam].append(filenames[k][cam][idx])
self.sample_list.append(sample)
counter += 1
if only_first_n_samples > 0 and counter >= only_first_n_samples:
break
except Exception as e:
print(e)
print("Skipping sequence", sequence_name)
assert len(self.sample_list) > 0, "No samples found"
print(f"Added {len(self.sample_list)} from Infinigen Stereo Video {split}")
logging.info(f"Added {len(self.sample_list)} from Infinigen Stereo Video {split}")
class SouthKensingtonStereoVideoDataset(StereoSequenceDataset):
def __init__(
self,
aug_params=None,
root="./data/datasets/SouthKensington/data/",
split="test",
subroot="",
sample_len=-1,
only_first_n_samples=-1,
):
super(SouthKensingtonStereoVideoDataset, self).__init__(aug_params)
self.root = root
self.split = split
self.sample_len = sample_len
sequence = sorted(
glob(osp.join(root, "*"))
)
for i in range(len(sequence)):
sequence_name = os.path.basename(sequence[i])
try:
filenames = defaultdict(lambda: defaultdict(list))
for cam in ["left", "right"]:
im_path_list = sorted(glob(osp.join(sequence[i], "images", cam, "*.png")))
camera_path = glob(osp.join(sequence[i], "*.txt"))[0]
for j in range(len(im_path_list)):
im_path = im_path_list[j]
assert os.path.isfile(im_path), im_path
filenames["image"][cam].append(im_path)
filenames["camera"][cam].append(camera_path)
filenames["metadata"][cam].append([sequence_name , (720,1280)])
for k in filenames.keys():
assert (
len(filenames[k][cam])
== len(filenames["image"][cam])
> 0
), sequence_name
seq_len = len(filenames["image"][cam])
print("seq_len", sequence_name, seq_len)
step = self.sample_len if (self.sample_len > 0) and (self.sample_len < seq_len) else seq_len
print("sample_step", step)
counter = 0
for ref_idx in range(0, seq_len, step):
sample = defaultdict(lambda: defaultdict(list))
for cam in ["left", "right"]:
for idx in range(ref_idx, ref_idx + step):
for k in filenames.keys():
sample[k][cam].append(filenames[k][cam][idx])
self.sample_list.append(sample)
counter += 1
if only_first_n_samples > 0 and counter >= only_first_n_samples:
break
except Exception as e:
print(e)
print("Skipping sequence", sequence_name)
assert len(self.sample_list) > 0, "No samples found"
print(f"Added {len(self.sample_list)} from SouthKensington Stereo Video")
logging.info(f"Added {len(self.sample_list)} from SouthKensington Stereo Video")
class KITTIDepthDataset(StereoSequenceDataset):
def __init__(
self,
aug_params=None,
root="./data/datasets/",
split="train",
sample_len=-1,
only_first_n_samples=-1,
):
super().__init__(aug_params, sparse=True)
# super(KITTIDepthDataset, self).__init__(aug_params)
image_root = osp.join(root, "kitti_depth", "input")
gt_root = osp.join(root, "kitti_depth", "gt_depth")
self.sample_len = sample_len
self.split = split
# Following CODD: https://github.com/facebookresearch/CODD
val_split = ['2011_10_03_drive_0042_sync'] # 1 scene
test_split = ['2011_09_26_drive_0002_sync', '2011_09_26_drive_0005_sync',
'2011_09_26_drive_0013_sync', '2011_09_26_drive_0020_sync',
'2011_09_26_drive_0023_sync', '2011_09_26_drive_0036_sync',
'2011_09_26_drive_0079_sync', '2011_09_26_drive_0095_sync',
'2011_09_26_drive_0113_sync', '2011_09_28_drive_0037_sync',
'2011_09_29_drive_0026_sync', '2011_09_30_drive_0016_sync',
'2011_10_03_drive_0047_sync'] # 13 scenes
sequence_root = sorted(glob(osp.join(gt_root, "*")))
train_list = []
val_list = []
test_list = []
for i in range(len(sequence_root)):
sequence_name = os.path.basename(os.path.normpath(sequence_root[i]))
if sequence_name in test_split:
test_list.append(sequence_root[i])
elif sequence_name in val_split:
val_list.append(sequence_root[i])
else:
train_list.append(sequence_root[i])
if self.split == "train":
sequence_split = train_list
elif self.split == "val":
sequence_split = val_list
elif self.split == "test":
sequence_split = test_list
else:
raise ValueError(f"Wrong split: {self.split}")
for i in range(len(sequence_split)):
sequence_name = os.path.basename(os.path.normpath(sequence_split[i]))
sequence_camera = osp.join(image_root, sequence_name[:10], "calib_cam_to_cam.txt")
try:
filenames = defaultdict(lambda: defaultdict(list))
for cam in ["left", "right"]:
suffix = "image_02/" if cam == "left" else "image_03/"
depth_path_list = sorted(
glob(osp.join(gt_root, sequence_name, "proj_depth", "groundtruth", suffix, "*.png")))
for j in range(len(depth_path_list)):
depth_path = depth_path_list[j]
assert os.path.isfile(depth_path), depth_path
filenames["depth"][cam].append(depth_path)
# find the corresponding images
im_name = os.path.basename(os.path.normpath(depth_path))
im_path = osp.join(image_root, sequence_name[:10], sequence_name, suffix, "data", im_name)
assert os.path.isfile(im_path), im_path
filenames["image"][cam].append(im_path)
filenames["camera"][cam].append(sequence_camera)
filenames["metadata"][cam].append([sequence_name, (370,1224)])
for k in filenames.keys():
assert (
len(filenames[k][cam])
== len(filenames["depth"][cam])
> 0
), sequence_name
seq_len = len(filenames["image"][cam])
print("seq_len", sequence_name, seq_len)
if self.split == "train":
for ref_idx in range(0, seq_len, 3):
step = 1 if self.sample_len == 1 else np.random.randint(1, 6)
if ref_idx + step * self.sample_len < seq_len:
sample = defaultdict(lambda: defaultdict(list))
for cam in ["left", "right"]:
for idx in range(
ref_idx, ref_idx + step * self.sample_len, step
):
for k in filenames.keys():
if "mask" not in k:
sample[k][cam].append(
filenames[k][cam][idx]
)
self.sample_list.append(sample)
else:
step = self.sample_len if (self.sample_len > 0) and (self.sample_len < seq_len) else seq_len
print("sample_step", step)
counter = 0
for ref_idx in range(0, seq_len, step):
sample = defaultdict(lambda: defaultdict(list))
for cam in ["left", "right"]:
for idx in range(ref_idx, ref_idx + step):
for k in filenames.keys():
sample[k][cam].append(filenames[k][cam][idx])
self.sample_list.append(sample)
counter += 1
if only_first_n_samples > 0 and counter >= only_first_n_samples:
break
except Exception as e:
print(e)
print("Skipping sequence", sequence_name)
assert len(self.sample_list) > 0, "No samples found"
print(f"Added {len(self.sample_list)} from KITTI Depth {split}")
logging.info(f"Added {len(self.sample_list)} from KITTI Depth {split}")
def split_train_valid(path_list, valid_keywords):
path_list_init = path_list
for kw in valid_keywords:
path_list = list(filter(lambda s: kw not in s, path_list))
train_path_list = sorted(path_list)
valid_path_list = sorted(list(set(path_list_init) - set(train_path_list)))
return train_path_list, valid_path_list
class TartanAirDataset(StereoSequenceDataset):
def __init__(
self,
aug_params=None,
root="./data/datasets/TartanAir/",
split="train",
sample_len=-1,
only_first_n_samples=-1,
):
super().__init__(aug_params, sparse=False)
self.sample_len = sample_len
self.split = split
# Each entry is (scene, motion, part)
test_entries = [
("abandonedfactory", "Easy", "P002"),
("abandonedfactory", "Hard", "P002"),
("amusement", "Easy", "P007"),
("amusement", "Hard", "P007"),
("carwelding", "Hard", "P003"),
("endofworld", "Easy", "P006"),
("endofworld", "Hard", "P006"),
("gascola", "Easy", "P001"),
("gascola", "Hard", "P001"),
("hospital", "Hard", "P042"),
("office", "Easy", "P006"),
("office", "Hard", "P006"),
("office2", "Easy", "P004"),
("office2", "Hard", "P004"),
("oldtown", "Hard", "P006"),
("soulcity", "Easy", "P008"),
("soulcity", "Hard", "P008"),
]
scene_root = sorted(glob(osp.join(root, "*")))
sequence_root_list = []
test_set = []
train_set = []
for i in range(len(scene_root)):
sequence_root_list += sorted(glob(osp.join(scene_root[i], "Easy", "*"))) + sorted(glob(osp.join(scene_root[i], "Hard", "*")))
for path in sequence_root_list:
parts = path.split("/")
if len(parts) < 5:
continue # skip malformed paths
scene, motion, part = parts[-3], parts[-2], parts[-1]
if (scene, motion, part) in test_entries:
test_set.append(path)
else:
train_set.append(path)
if self.split == "train":
sequence_root_list = train_set
elif self.split == "test":
sequence_root_list = test_set
else:
raise KeyError(f"Wrong split: {self.split}")
for i in range(len(sequence_root_list)):
filenames = defaultdict(lambda: defaultdict(list))
sequence_root = sequence_root_list[i]
parts = os.path.normpath(sequence_root).split(os.sep)
sequence_name = "_".join(parts[-3:])
try:
for cam in ['left', 'right']:
depth_path_list = sorted(glob(osp.join(sequence_root, "depth_left/", "*.npy")))
im_path_list = sorted(glob(osp.join(sequence_root, f"image_{cam}/", "*.png")))
pose_path = os.path.join(sequence_root, f"pose_{cam}.txt")
assert len(depth_path_list) == len(im_path_list), [len(depth_path_list), len(im_path_list)]
for j in range(len(depth_path_list)):
depth_path = depth_path_list[j]
assert os.path.isfile(depth_path), depth_path
filenames["depth"][cam].append(depth_path)
im_path = im_path_list[j]
assert os.path.isfile(im_path), im_path
filenames["image"][cam].append(im_path)
filenames["camera"][cam].append(pose_path)
filenames["metadata"][cam].append([sequence_name, (480,640)])
for k in filenames.keys():
assert (
len(filenames[k][cam])
== len(filenames["depth"][cam])
> 0
), sequence_name
seq_len = len(filenames["image"][cam])
print("seq_len", sequence_name, seq_len)
if self.split == "train":
for ref_idx in range(0, seq_len, 3):
step = 1 if self.sample_len == 1 else np.random.randint(1, 6)
if ref_idx + step * self.sample_len < seq_len:
sample = defaultdict(lambda: defaultdict(list))
for cam in ["left", "right"]:
for idx in range(
ref_idx, ref_idx + step * self.sample_len, step
):
for k in filenames.keys():
if "mask" not in k:
sample[k][cam].append(
filenames[k][cam][idx]
)
self.sample_list.append(sample)
else:
step = self.sample_len if (self.sample_len > 0) and (self.sample_len < seq_len) else seq_len
print("sample_step", step)
counter = 0
for ref_idx in range(0, seq_len, step):
sample = defaultdict(lambda: defaultdict(list))
for cam in ["left", "right"]:
for idx in range(ref_idx, ref_idx + step):
for k in filenames.keys():
sample[k][cam].append(filenames[k][cam][idx])
self.sample_list.append(sample)
counter += 1
if only_first_n_samples > 0 and counter >= only_first_n_samples:
break
except Exception as e:
print(e)
print("Skipping sequence", sequence_name)
assert len(self.sample_list) > 0, "No samples found"
print(f"Added {len(self.sample_list)} from TarTanAir {split}")
logging.info(f"Added {len(self.sample_list)} from TarTanAir {split}")
class VKITTI2Dataset(StereoSequenceDataset):
def __init__(
self,
aug_params=None,
root="./data/datasets/vkitti2/",
split="train",
sample_len=-1,
only_first_n_samples=-1,
):
super().__init__(aug_params, sparse=False)
self.sample_len = sample_len
self.split = split
if self.split == "train":
sequence_name_list = []
scenes = ['Scene01', 'Scene02', 'Scene06', 'Scene18', 'Scene20']
variations = ['15-deg-left', '15-deg-right', '30-deg-left', '30-deg-right',
'clone', 'fog', 'morning', 'overcast', 'rain', 'sunset']
for scene in scenes:
for variation in variations:
sequence_name = f"{scene}_{variation}"
sequence_name_list.append(sequence_name)
elif self.split == "test":
sequence_name_list = ["Scene01_15-deg-left", "Scene02_30-deg-right", "Scene06_fog", "Scene18_morning", "Scene20_rain"]
else:
raise KeyError(f"Wrong split: {self.split}")
for i in range(len(sequence_name_list)):
filenames = defaultdict(lambda: defaultdict(list))
sequence_name = sequence_name_list[i]
scene, variation = sequence_name.split("_")
try:
for cam in [('left', 0), ('right', 1)]:
depth_path_list = sorted(glob(osp.join(root, f"{scene}/{variation}/frames/depth/Camera_{cam[1]}/", "*.png")))
im_path_list = sorted(glob(osp.join(root, f"{scene}/{variation}/frames/rgb/Camera_{cam[1]}/", "*.jpg")))
intrinsic_path = os.path.join(root, f"{scene}/{variation}/intrinsic.txt")
assert len(depth_path_list) == len(im_path_list), [len(depth_path_list), len(im_path_list)]
for j in range(len(depth_path_list)):
depth_path = depth_path_list[j]
assert os.path.isfile(depth_path), depth_path
filenames["depth"][cam[0]].append(depth_path)
im_path = im_path_list[j]
assert os.path.isfile(im_path), im_path
filenames["image"][cam[0]].append(im_path)
filenames["camera"][cam[0]].append(intrinsic_path)
filenames["metadata"][cam[0]].append([sequence_name, (375,1242)])
for k in filenames.keys():
assert (
len(filenames[k][cam[0]])
== len(filenames["depth"][cam[0]])
> 0
), sequence_name
seq_len = len(filenames["image"][cam[0]])
print("seq_len", sequence_name, seq_len)
if self.split == "train":
for ref_idx in range(0, seq_len, 3):
step = 1 if self.sample_len == 1 else np.random.randint(1, 6)
if ref_idx + step * self.sample_len < seq_len:
sample = defaultdict(lambda: defaultdict(list))
for cam in ["left", "right"]:
for idx in range(
ref_idx, ref_idx + step * self.sample_len, step
):
for k in filenames.keys():
if "mask" not in k:
sample[k][cam].append(
filenames[k][cam][idx]
)
self.sample_list.append(sample)
else:
step = self.sample_len if (self.sample_len > 0) and (self.sample_len < seq_len) else seq_len
print("sample_step", step)
counter = 0
for ref_idx in range(0, seq_len, step):
sample = defaultdict(lambda: defaultdict(list))
for cam in ["left", "right"]:
for idx in range(ref_idx, ref_idx + step):
for k in filenames.keys():
sample[k][cam].append(filenames[k][cam][idx])
self.sample_list.append(sample)
counter += 1
if only_first_n_samples > 0 and counter >= only_first_n_samples:
break
except Exception as e:
print(e)
print("Skipping sequence", sequence_name)
assert len(self.sample_list) > 0, "No samples found"
print(f"Added {len(self.sample_list)} from VKITTI2 {split}")
logging.info(f"Added {len(self.sample_list)} from VKITTI2 {split}")
class SequenceSpringDataset(StereoSequenceDataset):
def __init__(
self,
aug_params=None,
sample_len=-1,
root="./data/datasets/Spring",
):
super(SequenceSpringDataset, self).__init__(aug_params)
self.split = "test"
self.sample_len = sample_len
original_length = len(self.sample_list)
image_paths = defaultdict(list)
disparity_paths = defaultdict(list)
camera_paths = defaultdict(list)
for cam in ["left", "right"]:
image_paths[cam] = sorted(
glob(osp.join(root, f"train_frame_{cam}/*"))
)
cam = "left"
disparity_paths[cam] = sorted(
glob(osp.join(root, f"train_disp1_{cam}/*"))
)
camera_paths[cam] = sorted(
glob(osp.join(root, "train_cam_data/*"))
)
num_seq = len(image_paths["left"])
# for each sequence
for seq_idx in range(num_seq):
sequence_name = os.path.basename(image_paths[cam][seq_idx])
sample = defaultdict(lambda: defaultdict(list))
for cam in ["left", "right"]:
sample["image"][cam] = sorted(
glob(osp.join(image_paths[cam][seq_idx], f"frame_{cam}", "*.png"))
)[:sample_len]
# for _ in range(len(sample["image"][cam])):
for _ in range(sample_len):
sample["metadata"][cam].append([sequence_name, (1080, 1920)])
cam = "left"
sample["disparity"][cam] = sorted(
glob(osp.join(disparity_paths[cam][seq_idx], f"disp1_{cam}", "*.dsp5"))
)[:sample_len]
sample["camera"][cam] = sorted(
glob(osp.join(camera_paths[cam][seq_idx], "cam_data", "*.txt"))
)
self.sample_list.append(sample)
seq_len = len(sample["image"][cam])
print("seq_len", sequence_name, seq_len)
logging.info(
f"Added {len(self.sample_list) - original_length} from Spring Dataset"
)
class SequenceSceneFlowDataset(StereoSequenceDataset):
def __init__(
self,
aug_params=None,
root="./data/datasets",
dstype="frames_cleanpass",
sample_len=1,
things_test=False,
add_things=True,
add_monkaa=True,
add_driving=True,
):
super(SequenceSceneFlowDataset, self).__init__(aug_params)
self.root = root
self.dstype = dstype
self.sample_len = sample_len
if things_test:
self._add_things("TEST")
else:
if add_things:
self._add_things("TRAIN")
if add_monkaa:
self._add_monkaa()
if add_driving:
self._add_driving()
def _add_things(self, split="TRAIN"):
"""Add FlyingThings3D data"""
original_length = len(self.sample_list)
root = osp.join(self.root, "FlyingThings3D")
image_paths = defaultdict(list)
disparity_paths = defaultdict(list)
for cam in ["left", "right"]:
image_paths[cam] = sorted(
glob(osp.join(root, self.dstype, split, f"*/*/{cam}/"))
)
disparity_paths[cam] = [
path.replace(self.dstype, "disparity") for path in image_paths[cam]
]
# Choose a random subset of 400 images for validation
# state = np.random.get_state()
# np.random.seed(1000)
# val_idxs = set(np.random.permutation(len(image_paths["left"]))[:40])
# np.random.set_state(state)
# np.random.seed(0)
num_seq = len(image_paths["left"])
num = 0
for seq_idx in range(num_seq):
# if (split == "TEST" and seq_idx in val_idxs) or (
# split == "TRAIN" and not seq_idx in val_idxs
# ):
images, disparities = defaultdict(list), defaultdict(list)
for cam in ["left", "right"]:
images[cam] = sorted(
glob(osp.join(image_paths[cam][seq_idx], "*.png"))
)
disparities[cam] = sorted(
glob(osp.join(disparity_paths[cam][seq_idx], "*.pfm"))
)
num = num + len(images["left"])
self._append_sample(images, disparities)
print("FlyingThings total frames:", num)
assert len(self.sample_list) > 0, "No samples found"
print(
f"Added {len(self.sample_list) - original_length} from FlyingThings {self.dstype}"
)
logging.info(
f"Added {len(self.sample_list) - original_length} from FlyingThings {self.dstype}"
)
def _add_monkaa(self):
"""Add FlyingThings3D data"""
original_length = len(self.sample_list)
root = osp.join(self.root, "Monkaa")
image_paths = defaultdict(list)
disparity_paths = defaultdict(list)
for cam in ["left", "right"]:
image_paths[cam] = sorted(glob(osp.join(root, self.dstype, f"*/{cam}/")))
disparity_paths[cam] = [
path.replace(self.dstype, "disparity") for path in image_paths[cam]
]
num_seq = len(image_paths["left"])
for seq_idx in range(num_seq):
images, disparities = defaultdict(list), defaultdict(list)
for cam in ["left", "right"]:
images[cam] = sorted(glob(osp.join(image_paths[cam][seq_idx], "*.png")))
disparities[cam] = sorted(
glob(osp.join(disparity_paths[cam][seq_idx], "*.pfm"))
)
self._append_sample(images, disparities)
assert len(self.sample_list) > 0, "No samples found"
print(
f"Added {len(self.sample_list) - original_length} from Monkaa {self.dstype}"
)
logging.info(
f"Added {len(self.sample_list) - original_length} from Monkaa {self.dstype}"
)
def _add_driving(self):
"""Add FlyingThings3D data"""
original_length = len(self.sample_list)
root = osp.join(self.root, "Driving")
image_paths = defaultdict(list)
disparity_paths = defaultdict(list)
for cam in ["left", "right"]:
image_paths[cam] = sorted(
glob(osp.join(root, self.dstype, f"*/*/*/{cam}/"))
)
disparity_paths[cam] = [
path.replace(self.dstype, "disparity") for path in image_paths[cam]
]
num_seq = len(image_paths["left"])
for seq_idx in range(num_seq):
images, disparities = defaultdict(list), defaultdict(list)
for cam in ["left", "right"]:
images[cam] = sorted(glob(osp.join(image_paths[cam][seq_idx], "*.png")))
disparities[cam] = sorted(
glob(osp.join(disparity_paths[cam][seq_idx], "*.pfm"))
)
self._append_sample(images, disparities)
assert len(self.sample_list) > 0, "No samples found"
print(
f"Added {len(self.sample_list) - original_length} from Driving {self.dstype}"
)
logging.info(
f"Added {len(self.sample_list) - original_length} from Driving {self.dstype}"
)
def _append_sample(self, images, disparities):
seq_len = len(images["left"])
for ref_idx in range(0, seq_len - self.sample_len):
sample = defaultdict(lambda: defaultdict(list))
for cam in ["left", "right"]:
for idx in range(ref_idx, ref_idx + self.sample_len):
sample["image"][cam].append(images[cam][idx])
sample["disparity"][cam].append(disparities[cam][idx])
self.sample_list.append(sample)
sample = defaultdict(lambda: defaultdict(list))
for cam in ["left", "right"]:
for idx in range(ref_idx, ref_idx + self.sample_len):
sample["image"][cam].append(images[cam][seq_len - idx - 1])
sample["disparity"][cam].append(disparities[cam][seq_len - idx - 1])
self.sample_list.append(sample)
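# Note: each window start yields two clips — the forward window beginning at ref_idx,
# and a second clip taken from the mirrored position at the end of the sequence and
# traversed in reverse time order — doubling the number of SceneFlow training clips.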
class SequenceSintelStereo(StereoSequenceDataset):
def __init__(
self,
dstype="clean",
aug_params=None,
root="./data/datasets",
):
super().__init__(
aug_params, sparse=True, reader=frame_utils.readDispSintelStereo
)
self.dstype = dstype
self.split = "test"
original_length = len(self.sample_list)
image_root = osp.join(root, "sintel_stereo", "training")
image_paths = defaultdict(list)
disparity_paths = defaultdict(list)
camera_paths = defaultdict(list)
for cam in ["left", "right"]:
image_paths[cam] = sorted(
glob(osp.join(image_root, f"{self.dstype}_{cam}/*"))
)
cam = "left"
disparity_paths[cam] = [
path.replace(f"{self.dstype}_{cam}", "disparities")
for path in image_paths[cam]
]
camera_paths[cam] = [
path.replace(f"{self.dstype}_{cam}", "camdata_left")
for path in image_paths[cam]
]
num_seq = len(image_paths["left"])
# for each sequence
for seq_idx in range(num_seq):
sequence_name = os.path.basename(image_paths[cam][seq_idx])
sample = defaultdict(lambda: defaultdict(list))
for cam in ["left", "right"]:
sample["image"][cam] = sorted(
glob(osp.join(image_paths[cam][seq_idx], "*.png"))
)
for _ in range(len(sample["image"][cam])):
sample["metadata"][cam].append([sequence_name, (436, 1024)])
cam = "left"
sample["disparity"][cam] = sorted(
glob(osp.join(disparity_paths[cam][seq_idx], "*.png"))
)
sample["camera"][cam] = sorted(
glob(osp.join(camera_paths[cam][seq_idx], "*.cam"))
)
for im1, disp, camera in zip(sample["image"][cam], sample["disparity"][cam], sample["camera"][cam]):
assert (
im1.split("/")[-1].split(".")[0]
== disp.split("/")[-1].split(".")[0]
== camera.split("/")[-1].split(".")[0]
), (im1.split("/")[-1].split(".")[0], disp.split("/")[-1].split(".")[0], camera.split("/")[-1].split(".")[0])
self.sample_list.append(sample)
logging.info(
f"Added {len(self.sample_list) - original_length} from SintelStereo {self.dstype}"
)
def fetch_dataloader(args):
"""Create the data loader for the corresponding training set"""
aug_params = {
"crop_size": args.image_size,
"min_scale": args.spatial_scale[0],
"max_scale": args.spatial_scale[1],
"do_flip": False,
"yjitter": not args.noyjitter,
}
if hasattr(args, "saturation_range") and args.saturation_range is not None:
aug_params["saturation_range"] = args.saturation_range
if hasattr(args, "img_gamma") and args.img_gamma is not None:
aug_params["gamma"] = args.img_gamma
if hasattr(args, "do_flip") and args.do_flip is not None:
aug_params["do_flip"] = args.do_flip
train_dataset = None
add_monkaa = "monkaa" in args.train_datasets
add_driving = "driving" in args.train_datasets
add_things = "things" in args.train_datasets
add_dynamic_replica = "dynamic_replica" in args.train_datasets
add_infinigensv = "infinigen_stereovideo" in args.train_datasets
add_kittidepth = "kitti_depth" in args.train_datasets
add_vkitti2 = "vkitti2" in args.train_datasets
add_tartanair = "tartanair" in args.train_datasets
new_dataset = None
if add_monkaa or add_driving or add_things:
# clean_dataset = SequenceSceneFlowDataset(
# aug_params,
# dstype="frames_cleanpass",
# sample_len=args.sample_len,
# add_monkaa=add_monkaa,
# add_driving=add_driving,
# add_things=add_things,
# )
final_dataset = SequenceSceneFlowDataset(
aug_params,
dstype="frames_finalpass",
sample_len=args.sample_len,
add_monkaa=add_monkaa,
add_driving=add_driving,
add_things=add_things,
)
# new_dataset = clean_dataset + final_dataset
new_dataset = final_dataset
if add_dynamic_replica:
dr_dataset = DynamicReplicaDataset(
aug_params, split="train", sample_len=args.sample_len
)
if new_dataset is None:
new_dataset = dr_dataset
else:
new_dataset = new_dataset + dr_dataset
if add_infinigensv:
infinigensv_dataset = InfinigenStereoVideoDataset(
aug_params, split="train", sample_len=args.sample_len
)
if new_dataset is None:
new_dataset = infinigensv_dataset
else:
new_dataset = new_dataset + infinigensv_dataset + infinigensv_dataset + infinigensv_dataset
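# The Infinigen dataset is concatenated three times here, i.e. it is oversampled 3x
# relative to the other training datasets.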
if add_kittidepth:
kittidepth_dataset = KITTIDepthDataset(
aug_params, split="train", sample_len=args.sample_len
)
if new_dataset is None:
new_dataset = kittidepth_dataset
else:
new_dataset = new_dataset + kittidepth_dataset
if add_vkitti2:
vkitti2_dataset = VKITTI2Dataset(
aug_params, split="train", sample_len=args.sample_len
)
if new_dataset is None:
new_dataset = vkitti2_dataset
else:
new_dataset = new_dataset + vkitti2_dataset
if add_tartanair:
tartanair_dataset = TartanAirDataset(
aug_params, split="train", sample_len=args.sample_len
)
if new_dataset is None:
new_dataset = tartanair_dataset
else:
new_dataset = new_dataset + tartanair_dataset
logging.info(f"Adding {len(new_dataset)} samples in total")
train_dataset = (
new_dataset if train_dataset is None else train_dataset + new_dataset
)
train_loader = data.DataLoader(
train_dataset,
batch_size=args.batch_size,
pin_memory=True,
shuffle=True,
num_workers=args.num_workers,
drop_last=True,
)
logging.info("Training with %d image pairs" % len(train_dataset))
return train_loader
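# Usage sketch (hypothetical values, not defaults from this repo): fetch_dataloader
# expects an argparse-style namespace exposing the fields read above.
#
#   from types import SimpleNamespace
#   args = SimpleNamespace(
#       image_size=[256, 512], spatial_scale=[-0.2, 0.4], noyjitter=False,
#       saturation_range=None, img_gamma=None, do_flip=None,
#       train_datasets=["things", "monkaa", "driving"], sample_len=5,
#       batch_size=1, num_workers=4,
#   )
#   train_loader = fetch_dataloader(args)
#   for batch in train_loader:
#       ...  # batches of stereo frame sequences with ground-truth disparity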
================================================
FILE: demo.py
================================================
import sys
import argparse
import os
import cv2
import glob
import numpy as np
import torch
import torch.nn.functional as F
from collections import defaultdict
from PIL import Image
from matplotlib import pyplot as plt
from pathlib import Path
DEVICE = 'cuda'
def load_image(imfile):
img = np.array(Image.open(imfile).convert('RGB')).astype(np.uint8)
img = torch.from_numpy(img).permute(2, 0, 1).float()
return img.to(DEVICE)
def viz(img, flo):
img = img[0].permute(1, 2, 0).cpu().numpy()
flo = flo[0].permute(1, 2, 0).cpu().numpy()
# map flow to rgb image
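# NOTE: flow_viz is not imported in this file; this helper appears to be carried over
# from the RAFT demo and assumes third_party/RAFT/core/utils/flow_viz.py is importable.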
flo = flow_viz.flow_to_image(flo)
img_flo = np.concatenate([img, flo], axis=0)
cv2.imshow('image', img_flo[:, :, [2, 1, 0]] / 255.0)
cv2.waitKey()
def demo(args):
from stereoanyvideo.models.stereoanyvideo_model import StereoAnyVideoModel
model = StereoAnyVideoModel()
if args.ckpt is not None:
assert args.ckpt.endswith(".pth") or args.ckpt.endswith(
".pt"
)
strict = True
state_dict = torch.load(args.ckpt)
if "model" in state_dict:
state_dict = state_dict["model"]
if list(state_dict.keys())[0].startswith("module."):
state_dict = {
k.replace("module.", ""): v for k, v in state_dict.items()
}
model.model.load_state_dict(state_dict, strict=strict)
print("Done loading model checkpoint", args.ckpt)
model.to(DEVICE)
model.eval()
output_directory = args.output_path
parent_directory = os.path.dirname(output_directory)
if not os.path.exists(parent_directory):
os.makedirs(parent_directory)
if not os.path.isdir(output_directory):
os.mkdir(output_directory)
with torch.no_grad():
images_left = sorted(glob.glob(os.path.join(args.path, 'left/*.png')) + glob.glob(os.path.join(args.path, 'left/*.jpg')))
images_right = sorted(glob.glob(os.path.join(args.path, 'right/*.png')) + glob.glob(os.path.join(args.path, 'right/*.jpg')))
assert len(images_left) == len(images_right), [len(images_left), len(images_right)]
assert len(images_left) > 0, args.path
print(f"Found {len(images_left)} frames. Saving files to {args.output_path}")
num_frames = len(images_left)
frame_size = args.frame_size
disparities_ori_all = []
for start_idx in range(0, num_frames, frame_size):
end_idx = min(start_idx + frame_size, num_frames)
image_left_list = []
image_right_list = []
for imfile1, imfile2 in zip(images_left[start_idx:end_idx], images_right[start_idx:end_idx]):
image_left = load_image(imfile1)
image_right = load_image(imfile2)
image_left = F.interpolate(image_left[None], size=args.resize, mode="bilinear", align_corners=True)
image_right = F.interpolate(image_right[None], size=args.resize, mode="bilinear", align_corners=True)
image_left_list.append(image_left[0])
image_right_list.append(image_right[0])
video_left = torch.stack(image_left_list, dim=0)
video_right = torch.stack(image_right_list, dim=0)
batch_dict = defaultdict(list)
batch_dict["stereo_video"] = torch.stack([video_left, video_right], dim=1)
predictions = model(batch_dict)
assert "disparity" in predictions
disparities = predictions["disparity"][:, :1].clone().data.cpu().abs().numpy()
disparities_ori = disparities.astype(np.uint8)
disparities_ori_all.extend(disparities_ori)
disparities_ori_all = np.array(disparities_ori_all)
epsilon = 1e-5 # Smallest allowable disparity
disparities_ori_all[disparities_ori_all < epsilon] = epsilon
disparities_all = ((disparities_ori_all - disparities_ori_all.min()) / (disparities_ori_all.max() - disparities_ori_all.min()) * 255).astype(np.uint8)
video_ori_disparity = cv2.VideoWriter(
os.path.join(args.output_path, "disparity.mp4"),
cv2.VideoWriter_fourcc(*"mp4v"),
fps=args.fps,
frameSize=(disparities_all.shape[3], disparities_all.shape[2]),
isColor=True,
)
video_disparity = cv2.VideoWriter(
os.path.join(args.output_path, "disparity_norm.mp4"),
cv2.VideoWriter_fourcc(*"mp4v"),
fps=args.fps,
frameSize=(disparities_all.shape[3], disparities_all.shape[2]),
isColor=True,
)
for i in range(num_frames):
imfile1 = images_left[i]
disparity_norm = disparities_all[i]
disparity_norm = disparity_norm.transpose(1, 2, 0)
disparity_norm_vis = cv2.applyColorMap(disparity_norm, cv2.COLORMAP_INFERNO)
video_disparity.write(disparity_norm_vis)
disparity_ori = disparities_ori_all[i]
disparity_ori = disparity_ori.transpose(1, 2, 0)
disparity_ori_vis = cv2.applyColorMap(disparity_ori, cv2.COLORMAP_INFERNO)
video_ori_disparity.write(disparity_ori_vis)
if args.save_png:
filename_temp = args.output_path + '/disparity_norm_' + str(i).zfill(3) + '.png'
cv2.imwrite(filename_temp, disparity_norm_vis)
filename_temp = args.output_path + '/disparity_ori_' + str(i).zfill(3) + '.png'
cv2.imwrite(filename_temp, disparity_ori_vis)
video_ori_disparity.release()
video_disparity.release()
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--model_name', default="stereoanyvideo", help="name to specify model")
parser.add_argument('--ckpt', default=None, help="checkpoint of stereo model")
parser.add_argument('--resize', default=(720, 1280), help="image size input to the model")
parser.add_argument("--fps", type=int, default=30, help="frame rate for video visualization")
parser.add_argument('--path', help="dataset for evaluation")
parser.add_argument("--save_png", action="store_true")
parser.add_argument("--frame_size", type=int, default=150, help="number of updates in each forward pass.")
parser.add_argument("--iters",type=int, default=20, help="number of updates in each forward pass.")
parser.add_argument("--kernel_size", type=int, default=20, help="number of frames in each forward pass.")
parser.add_argument('--output_path', help="directory to save output", default="demo_output")
args = parser.parse_args()
demo(args)
================================================
FILE: demo.sh
================================================
#!/bin/bash
export PYTHONPATH=`(cd ../ && pwd)`:`pwd`:$PYTHONPATH
python demo.py --ckpt ./checkpoints/StereoAnyVideo_MIX.pth --path ./demo_video/ --output_path ./demo_output/ --save_png
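# Optional flags (defined in demo.py): the chunk length and output frame rate can be
# overridden, e.g. (illustrative values):
# python demo.py --ckpt ./checkpoints/StereoAnyVideo_MIX.pth --path ./demo_video/ --output_path ./demo_output/ --save_png --frame_size 100 --fps 24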
================================================
FILE: evaluate_stereoanyvideo.sh
================================================
#!/bin/bash
export PYTHONPATH=`(cd ../ && pwd)`:`pwd`:$PYTHONPATH
# evaluate on [sintel, dynamicreplica, infinigensv, vkitti2] using sceneflow checkpoint
python ./evaluation/evaluate.py --config-name eval_sintel_final \
MODEL.model_name=StereoAnyVideoModel \
MODEL.StereoAnyVideoModel.model_weights=./checkpoints/StereoAnyVideo_SF.pth
python ./evaluation/evaluate.py --config-name eval_dynamic_replica \
MODEL.model_name=StereoAnyVideoModel \
MODEL.StereoAnyVideoModel.model_weights=./checkpoints/StereoAnyVideo_SF.pth
python ./evaluation/evaluate.py --config-name eval_infinigensv \
MODEL.model_name=StereoAnyVideoModel \
MODEL.StereoAnyVideoModel.model_weights=./checkpoints/StereoAnyVideo_SF.pth
python ./evaluation/evaluate.py --config-name eval_vkitti2 \
MODEL.model_name=StereoAnyVideoModel \
MODEL.StereoAnyVideoModel.model_weights=./checkpoints/StereoAnyVideo_SF.pth
# evaluate on [sintel, kittidepth, southkensingtonSV] using mixed checkpoint
python ./evaluation/evaluate.py --config-name eval_sintel_final \
MODEL.model_name=StereoAnyVideoModel \
MODEL.StereoAnyVideoModel.model_weights=./checkpoints/StereoAnyVideo_MIX.pth
python ./evaluation/evaluate.py --config-name eval_kittidepth \
MODEL.model_name=StereoAnyVideoModel \
MODEL.StereoAnyVideoModel.model_weights=./checkpoints/StereoAnyVideo_MIX.pth
python ./evaluation/evaluate.py --config-name eval_southkensington \
MODEL.model_name=StereoAnyVideoModel \
MODEL.StereoAnyVideoModel.model_weights=./checkpoints/StereoAnyVideo_SF.pth
================================================
FILE: evaluation/configs/eval_dynamic_replica.yaml
================================================
defaults:
- default_config_eval
visualize_interval: -1
exp_dir: ./outputs/stereoanyvideo_DynamicReplica
sample_len: 150
MODEL:
model_name: StereoAnyVideoModel
================================================
FILE: evaluation/configs/eval_infinigensv.yaml
================================================
defaults:
- default_config_eval
visualize_interval: -1
render_bin_size: 0
exp_dir: ./outputs/stereoanyvideo_InfinigenSV
sample_len: 150
dataset_name: infinigensv
MODEL:
model_name: StereoAnyVideoModel
================================================
FILE: evaluation/configs/eval_kittidepth.yaml
================================================
defaults:
- default_config_eval
visualize_interval: -1
render_bin_size: 0
exp_dir: ./outputs/stereoanyvideo_KITTIDepth
sample_len: 300
dataset_name: kitti_depth
MODEL:
model_name: StereoAnyVideoModel
================================================
FILE: evaluation/configs/eval_sintel_clean.yaml
================================================
defaults:
- default_config_eval
visualize_interval: -1
render_bin_size: 0
exp_dir: ./outputs/stereoanyvideo_sintel_clean
dataset_name: sintel
dstype: clean
MODEL:
model_name: StereoAnyVideoModel
================================================
FILE: evaluation/configs/eval_sintel_final.yaml
================================================
defaults:
- default_config_eval
visualize_interval: -1
render_bin_size: 0
exp_dir: ./outputs/stereoanyvideo_sintel_final
dataset_name: sintel
dstype: final
MODEL:
model_name: StereoAnyVideoModel
================================================
FILE: evaluation/configs/eval_southkensington.yaml
================================================
defaults:
- default_config_eval
visualize_interval: 1
exp_dir: ./outputs/stereoanyvideo_SouthKensingtonIndoor
sample_len: 300
dataset_name: southkensingtonsv
MODEL:
model_name: StereoAnyVideoModel
================================================
FILE: evaluation/configs/eval_vkitti2.yaml
================================================
defaults:
- default_config_eval
visualize_interval: -1
render_bin_size: 0
exp_dir: ./outputs/stereoanyvideo_VKITTI2
sample_len: 300
dataset_name: vkitti2
MODEL:
model_name: StereoAnyVideoModel
================================================
FILE: evaluation/core/evaluator.py
================================================
import os
import numpy as np
import cv2
from collections import defaultdict
import torch.nn.functional as F
import torch
import matplotlib.pyplot as plt
from tqdm import tqdm
from omegaconf import DictConfig
from pytorch3d.implicitron.tools.config import Configurable
from stereoanyvideo.evaluation.utils.eval_utils import depth2disparity_scale, eval_batch
from stereoanyvideo.evaluation.utils.utils import (
PerceptionPrediction,
pretty_print_perception_metrics,
visualize_batch,
)
def depth_to_colormap(depth, colormap='jet', eps=1e-3, scale_vmin=1.0):
valid = (depth > eps) & (depth < 1e4)
vmin = depth[valid].min() * scale_vmin
vmax = depth[valid].max()
if colormap=='jet':
cmap = plt.cm.jet
else:
cmap = plt.cm.inferno
norm = plt.Normalize(vmin=vmin, vmax=vmax)
depth = cmap(norm(depth))
depth[~valid] = 1
return np.ascontiguousarray(depth[...,:3] * 255, dtype=np.uint8)
class Evaluator(Configurable):
"""
A class defining the StereoAnyVideo evaluator.
Args:
eps: Threshold for converting disparity to depth.
"""
eps = 1e-5
def setup_visualization(self, cfg: DictConfig) -> None:
# Visualization
self.visualize_interval = cfg.visualize_interval
self.render_bin_size = cfg.render_bin_size
self.exp_dir = cfg.exp_dir
if self.visualize_interval > 0:
self.visualize_dir = os.path.join(cfg.exp_dir, "visualisations")
@torch.no_grad()
def evaluate_sequence(
self,
model,
model_stabilizer,
test_dataloader: torch.utils.data.DataLoader,
is_real_data: bool = False,
step=None,
writer=None,
train_mode=False,
interp_shape=None,
exp_dir=None,
):
model.eval()
per_batch_eval_results = []
if self.visualize_interval > 0:
os.makedirs(self.visualize_dir, exist_ok=True)
for batch_idx, sequence in enumerate(tqdm(test_dataloader)):
batch_dict = defaultdict(list)
batch_dict["stereo_video"] = sequence["img"]
if not is_real_data:
batch_dict["disparity"] = sequence["disp"][:, 0].abs()
batch_dict["disparity_mask"] = sequence["valid_disp"][:, :1]
if "mask" in sequence:
batch_dict["fg_mask"] = sequence["mask"][:, :1]
else:
batch_dict["fg_mask"] = torch.ones_like(
batch_dict["disparity_mask"]
)
elif interp_shape is not None:
left_video = batch_dict["stereo_video"][:, 0]
left_video = F.interpolate(
left_video, tuple(interp_shape), mode="bilinear"
)
right_video = batch_dict["stereo_video"][:, 1]
right_video = F.interpolate(
right_video, tuple(interp_shape), mode="bilinear"
)
batch_dict["stereo_video"] = torch.stack([left_video, right_video], 1)
if model_stabilizer is not None:
predictions = model.forward_stabilizer(batch_dict, model_stabilizer)
elif train_mode:
predictions = model.forward_batch_test(batch_dict)
else:
predictions = model(batch_dict)
assert "disparity" in predictions
predictions["disparity"] = predictions["disparity"][:, :1].clone().cpu()
if not is_real_data:
predictions["disparity"] = predictions["disparity"] * (
batch_dict["disparity_mask"].round()
)
batch_eval_result, seq_length = eval_batch(batch_dict, predictions, sequence["depth2disp_scale"][0])
per_batch_eval_results.append((batch_eval_result, seq_length))
pretty_print_perception_metrics(batch_eval_result)
if (self.visualize_interval > 0) and (
batch_idx % self.visualize_interval == 0
):
perception_prediction = PerceptionPrediction()
pred_disp = predictions["disparity"]
pred_disp[pred_disp < self.eps] = self.eps
scale = sequence["depth2disp_scale"][0]
perception_prediction.depth_map = (scale / pred_disp).cuda()
perspective_cameras = []
if "viewpoint" in sequence:
for cam in sequence["viewpoint"]:
perspective_cameras.append(cam[0])
perception_prediction.perspective_cameras = perspective_cameras
if "stereo_original_video" in batch_dict:
batch_dict["stereo_video"] = batch_dict[
"stereo_original_video"
].clone()
for k, v in batch_dict.items():
if isinstance(v, torch.Tensor):
batch_dict[k] = v.cuda()
visualize_batch(
batch_dict,
perception_prediction,
output_dir=self.visualize_dir,
sequence_name=sequence["metadata"][0][0][0],
step=step,
writer=writer,
render_bin_size=self.render_bin_size,
)
filename = os.path.join(self.visualize_dir, sequence["metadata"][0][0][0])
if not os.path.isdir(filename):
os.mkdir(filename)
disparity_list = pred_disp.data.cpu().numpy()
depth_list = perception_prediction.depth_map.data.cpu().numpy()
np.save(f"{filename}_depth_list.npy", depth_list)
video_disparity = cv2.VideoWriter(
f"{filename}_disparity.mp4",
cv2.VideoWriter_fourcc(*"mp4v"),
fps=30,
frameSize=(
batch_dict["stereo_video"][:, 0][0].shape[2], batch_dict["stereo_video"][:, 0][0].shape[1]),
isColor=True,
)
disparity_vis = depth_to_colormap(disparity_list[:, 0], eps=self.eps, colormap='inferno')
for i in range(disparity_list.shape[0]):
filename_temp = filename + '/disparity_' + str(i).zfill(3) + '.png'
disparity_vis[i] = cv2.cvtColor(disparity_vis[i], cv2.COLOR_RGB2BGR)
cv2.imwrite(filename_temp, disparity_vis[i])
video_disparity.write(disparity_vis[i])
video_disparity.release()
return per_batch_eval_results
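# Each entry of per_batch_eval_results is a (metrics_dict, sequence_length) pair as
# returned by eval_batch; evaluate.py aggregates these with aggregate_and_print_results
# before dumping them to result_eval.json.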
================================================
FILE: evaluation/evaluate.py
================================================
import json
import os
from dataclasses import dataclass, field
from typing import Any, Dict, Optional
import hydra
import numpy as np
import torch
from omegaconf import OmegaConf
from stereoanyvideo.evaluation.utils.utils import aggregate_and_print_results
import stereoanyvideo.datasets.video_datasets as datasets
from stereoanyvideo.models.core.model_zoo import (
get_all_model_default_configs,
model_zoo,
)
from pytorch3d.implicitron.tools.config import get_default_args_field
from stereoanyvideo.evaluation.core.evaluator import Evaluator
@dataclass(eq=False)
class DefaultConfig:
exp_dir: str = "./outputs"
stabilizer_ckpt: Optional[str] = None
# one of [sintel, dynamicreplica, things, kitti_depth, vkitti2, infinigensv, southkensingtonsv]
dataset_name: str = "dynamicreplica"
sample_len: int = -1
dstype: Optional[str] = None
# clean, final
MODEL: Dict[str, Any] = field(
default_factory=lambda: get_all_model_default_configs()
)
EVALUATOR: Dict[str, Any] = get_default_args_field(Evaluator)
seed: int = 42
gpu_idx: int = 0
visualize_interval: int = 1 # Use 0 or a negative value to disable visualization
render_bin_size: Optional[int] = None
# Override hydra's working directory to current working dir,
# also disable storing the .hydra logs:
hydra: dict = field(
default_factory=lambda: {
"run": {"dir": "."},
"output_subdir": None,
}
)
def run_eval(cfg: DefaultConfig):
"""
Evaluates stereo video disparity estimation metrics of a specified model
on a benchmark dataset.
"""
# make the experiment directory
os.makedirs(cfg.exp_dir, exist_ok=True)
# dump the experiment config to the exp_dir
cfg_file = os.path.join(cfg.exp_dir, "expconfig.yaml")
with open(cfg_file, "w") as f:
OmegaConf.save(config=cfg, f=f)
torch.manual_seed(cfg.seed)
np.random.seed(cfg.seed)
evaluator = Evaluator(**cfg.EVALUATOR)
model = model_zoo(**cfg.MODEL)
model.cuda(0)
evaluator.setup_visualization(cfg)
if cfg.dataset_name == "dynamicreplica":
test_dataloader = datasets.DynamicReplicaDataset(
split="test", sample_len=cfg.sample_len, only_first_n_samples=1
)
elif cfg.dataset_name == "infinigensv":
test_dataloader = datasets.InfinigenStereoVideoDataset(
split="test", sample_len=cfg.sample_len, only_first_n_samples=1
)
elif cfg.dataset_name == "southkensingtonsv":
test_dataloader = datasets.SouthKensingtonStereoVideoDataset(
sample_len=cfg.sample_len, only_first_n_samples=1
)
evaluator.evaluate_sequence(
model,
None,
test_dataloader,
is_real_data=True,
exp_dir=cfg.exp_dir
)
return
elif cfg.dataset_name == "kitti_depth":
test_dataloader = datasets.KITTIDepthDataset(
split="test", sample_len=cfg.sample_len, only_first_n_samples=1
)
elif cfg.dataset_name == "vkitti2":
test_dataloader = datasets.VKITTI2Dataset(
split="test", sample_len=cfg.sample_len, only_first_n_samples=1
)
elif cfg.dataset_name == "sintel":
test_dataloader = datasets.SequenceSintelStereo(dstype=cfg.dstype)
elif cfg.dataset_name == "things":
test_dataloader = datasets.SequenceSceneFlowDataset(
{},
dstype=cfg.dstype,
sample_len=cfg.sample_len,
add_monkaa=False,
add_driving=False,
things_test=True,
)
evaluate_result = evaluator.evaluate_sequence(
model,
None,
test_dataloader,
is_real_data=False,
exp_dir=cfg.exp_dir
)
aggregate_result = aggregate_and_print_results(evaluate_result)
result_file = os.path.join(cfg.exp_dir, "result_eval.json")
print(f"Dumping eval results to {result_file}.")
with open(result_file, "w") as f:
json.dump(aggregate_result, f)
cs = hydra.core.config_store.ConfigStore.instance()
cs.store(name="default_config_eval", node=DefaultConfig)
@hydra.main(config_path="./configs/", config_name="default_config_eval")
def evaluate(cfg: DefaultConfig) -> None:
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = str(cfg.gpu_idx)
run_eval(cfg)
if __name__ == "__main__":
evaluate()
================================================
FILE: evaluation/utils/eval_utils.py
================================================
from dataclasses import dataclass
from typing import Dict, Optional, Union
from stereoanyvideo.evaluation.utils.ssim import SSIM
import torch
import torch.nn.functional as F
import numpy as np
import math
import cv2
from pytorch3d.utils import opencv_from_cameras_projection
from stereoanyvideo.models.raft_model import RAFTModel
@dataclass(eq=True, frozen=True)
class PerceptionMetric:
metric: str
depth_scaling_norm: Optional[str] = None
suffix: str = ""
index: str = ""
def __str__(self):
return (
self.metric
+ self.index
+ (
("_norm_" + self.depth_scaling_norm)
if self.depth_scaling_norm is not None
else ""
)
+ self.suffix
)
def compute_flow(seq, is_seq=True):
raft = RAFTModel().cuda()
raft.eval()
if is_seq:
t, c, h, w = seq.size()
flows_forward = []
for i in range(t-1):
flow_forward = raft.forward_fullres(seq[i][None], seq[i+1][None], iters=20)
flows_forward.append(flow_forward)
flows_forward = torch.cat(flows_forward, dim=0)
return flows_forward
else:
img1, img2 = seq
flow_forward = raft.forward_fullres(img1, img2, iters=20)
return flow_forward
def flow_warp(x, flow):
if flow.size(3) != 2: # [B, H, W, 2]
flow = flow.permute(0, 2, 3, 1)
if x.size()[-2:] != flow.size()[1:3]:
raise ValueError(f'The spatial sizes of input ({x.size()[-2:]}) and '
f'flow ({flow.size()[1:3]}) are not the same.')
_, _, h, w = x.size()
# create mesh grid
grid_y, grid_x = torch.meshgrid(torch.arange(0, h), torch.arange(0, w))
grid = torch.stack((grid_x, grid_y), 2).type_as(x) # (h, w, 2)
grid.requires_grad = False
grid_flow = grid + flow
# scale grid_flow to [-1,1]
grid_flow_x = 2.0 * grid_flow[:, :, :, 0] / max(w - 1, 1) - 1.0
grid_flow_y = 2.0 * grid_flow[:, :, :, 1] / max(h - 1, 1) - 1.0
grid_flow = torch.stack((grid_flow_x, grid_flow_y), dim=3)
output = F.grid_sample(
x,
grid_flow,
mode='bilinear',
padding_mode='zeros',
align_corners=True)
return output
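# Shape note (inferred from the checks above): x is [N, C, H, W] and flow is either
# [N, 2, H, W] or [N, H, W, 2]; the flow is added to a pixel grid, normalised to
# [-1, 1], and passed to F.grid_sample, so the output keeps x's [N, C, H, W] shape.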
def eval_endpoint_error_sequence(
x: torch.Tensor,
y: torch.Tensor,
mask: torch.Tensor,
crop: int = 0,
mask_thr: float = 0.5,
clamp_thr: float = 1e-5,
) -> Dict[str, torch.Tensor]:
assert len(x.shape) == len(y.shape) == len(mask.shape) == 4, (
x.shape,
y.shape,
mask.shape,
)
assert x.shape[0] == y.shape[0] == mask.shape[0], (x.shape, y.shape, mask.shape)
# chuck out the border
if crop > 0:
if crop > min(y.shape[2:]) - crop:
raise ValueError("Incorrect crop size.")
y = y[:, :, crop:-crop, crop:-crop]
x = x[:, :, crop:-crop, crop:-crop]
mask = mask[:, :, crop:-crop, crop:-crop]
y = y * (mask > mask_thr).float()
x = x * (mask > mask_thr).float()
y[torch.isnan(y)] = 0
results = {}
for epe_name in ("epe", "temp_epe"):
if epe_name == "epe":
endpoint_error = (mask * (x - y) ** 2).sum(dim=1).sqrt()
elif epe_name == "temp_epe":
delta_mask = mask[:-1] * mask[1:]
# endpoint_error = (
# (delta_mask * ((x[:-1] - x[1:]) - (y[:-1] - y[1:])) ** 2)
# .sum(dim=1)
# .sqrt()
# )
endpoint_error = (
(delta_mask * ((x[:-1] - x[1:]).abs() - (y[:-1] - y[1:]).abs()) ** 2)
.sum(dim=1)
.sqrt()
)
# epe_nonzero = endpoint_error != 0
nonzero = torch.count_nonzero(endpoint_error)
epe_mean = endpoint_error.sum() / torch.clamp(
nonzero, clamp_thr
) # average error for all the sequence pixels
epe_inv_accuracy_05px = (endpoint_error > 0.5).sum() / torch.clamp(
nonzero, clamp_thr
)
epe_inv_accuracy_1px = (endpoint_error > 1).sum() / torch.clamp(
nonzero, clamp_thr
)
epe_inv_accuracy_2px = (endpoint_error > 2).sum() / torch.clamp(
nonzero, clamp_thr
)
epe_inv_accuracy_3px = (endpoint_error > 3).sum() / torch.clamp(
nonzero, clamp_thr
)
results[f"{epe_name}_mean"] = epe_mean[None]
results[f"{epe_name}_bad_0.5px"] = epe_inv_accuracy_05px[None] * 100
results[f"{epe_name}_bad_1px"] = epe_inv_accuracy_1px[None] * 100
results[f"{epe_name}_bad_2px"] = epe_inv_accuracy_2px[None] * 100
results[f"{epe_name}_bad_3px"] = epe_inv_accuracy_3px[None] * 100
return results
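# The "bad_Npx" entries are percentages of pixels whose endpoint error exceeds N px,
# normalised by the count of pixels with nonzero error; "temp_epe" applies the same
# statistics to the discrepancy between the magnitudes of frame-to-frame changes in
# the prediction and in the ground truth.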
def eval_TCC_sequence(
x: torch.Tensor,
y: torch.Tensor,
mask: torch.Tensor,
crop: int = 0,
mask_thr: float = 0.5,
) -> Dict[str, torch.Tensor]:
assert len(x.shape) == len(y.shape) == len(mask.shape) == 4, (
x.shape,
y.shape,
mask.shape,
)
assert x.shape[0] == y.shape[0] == mask.shape[0], (x.shape, y.shape, mask.shape)
t, c, h, w = x.shape
# chuck out the border
if crop > 0:
if crop > min(y.shape[2:]) - crop:
raise ValueError("Incorrect crop size.")
y = y[:, :, crop:-crop, crop:-crop]
x = x[:, :, crop:-crop, crop:-crop]
mask = mask[:, :, crop:-crop, crop:-crop]
y = y * (mask > mask_thr).float()
x = x * (mask > mask_thr).float()
x[torch.isnan(x)] = 0
y[torch.isnan(y)] = 0
ssim_loss = SSIM(1.0, nonnegative_ssim=True)
delta_mask = mask[:-1] * mask[1:]
tcc = 0
for i in range(t-1):
tcc += ssim_loss((torch.abs(x[i][None] - x[i+1][None]) * delta_mask[i]).expand(-1, 3, -1, -1),
(torch.abs(y[i][None] - y[i+1][None]) * delta_mask[i]).expand(-1, 3, -1, -1))
tcc = tcc / (t-1)
return tcc
def eval_TCM_sequence(
x: torch.Tensor,
y: torch.Tensor,
mask: torch.Tensor,
crop: int = 0,
mask_thr: float = 0.5,
) -> Dict[str, torch.Tensor]:
assert len(x.shape) == len(y.shape) == len(mask.shape) == 4, (
x.shape,
y.shape,
mask.shape,
)
assert x.shape[0] == y.shape[0] == mask.shape[0], (x.shape, y.shape, mask.shape)
t, c, h, w = x.shape
# chuck out the border
if crop > 0:
if crop > min(y.shape[2:]) - crop:
raise ValueError("Incorrect crop size.")
y = y[:, :, crop:-crop, crop:-crop]
x = x[:, :, crop:-crop, crop:-crop]
mask = mask[:, :, crop:-crop, crop:-crop]
y = y * (mask > mask_thr).float()
x = x * (mask > mask_thr).float()
y[torch.isnan(y)] = 0
ssim_loss = SSIM(1.0, nonnegative_ssim=True, size_average=False)
delta_mask = mask[:-1] * mask[1:]
tcm = 0
for i in range(t-1):
dmax = torch.max(y[i][None].view(1, -1), -1)[0].view(1, 1, 1, 1).expand(-1, 3, -1, -1)
dmin = torch.min(y[i][None].view(1, -1), -1)[0].view(1, 1, 1, 1).expand(-1, 3, -1, -1)
x_norm = (x[i][None].expand(-1, 3, -1, -1) - dmin) / (dmax - dmin) * 255.
x2_norm = (x[i+1][None].expand(-1, 3, -1, -1) - dmin) / (dmax - dmin) * 255.
x_flow = compute_flow([x_norm.cuda(), x2_norm.cuda()], is_seq=False).cpu()
y_norm = (y[i][None].expand(-1, 3, -1, -1) - dmin) / (dmax - dmin) * 255.
y2_norm = (y[i+1][None].expand(-1, 3, -1, -1) - dmin) / (dmax - dmin) * 255.
y_flow = compute_flow([y_norm.cuda(), y2_norm.cuda()], is_seq=False).cpu()
flow_mask = torch.sum(y_flow > 250, 1, keepdim=True) == 0
mask = delta_mask[i][None] * flow_mask
mask = mask.expand(-1, 3, -1, -1)
if torch.sum(mask) > 0:
tcm += torch.mean(ssim_loss(
torch.cat((x_flow, torch.ones_like(x_flow[:, 0, None, ...])), 1) * mask,
torch.cat((y_flow, torch.ones_like(x_flow[:, 0, None, ...])), 1) * mask)[:, :2])
tcm = tcm / (t-1)
return tcm
def eval_OPW_sequence(
img: torch.Tensor,
x: torch.Tensor,
y: torch.Tensor,
mask: torch.Tensor,
crop: int = 0,
mask_thr: float = 0.5,
clamp_thr: float = 1e-5,
) -> Dict[str, torch.Tensor]:
assert len(x.shape) == len(y.shape) == len(mask.shape) == 4, (
x.shape,
y.shape,
mask.shape,
) # T, 1, H, W
assert x.shape[0] == y.shape[0] == mask.shape[0], (x.shape, y.shape, mask.shape)
t, c, h, w = img[:, 0].shape
# chuck out the border
if crop > 0:
if crop > min(y.shape[2:]) - crop:
raise ValueError("Incorrect crop size.")
y = y[:, :, crop:-crop, crop:-crop]
x = x[:, :, crop:-crop, crop:-crop]
mask = mask[:, :, crop:-crop, crop:-crop]
y = y * (mask > mask_thr).float()
x = x * (mask > mask_thr).float()
y[torch.isnan(y)] = 0
delta_mask = mask[:-1] * mask[1:]
depth_mask_30 = torch.sum(y > 30, 1, keepdim=True) == 0
depth_mask_30 = depth_mask_30[:-1] * depth_mask_30[1:]
depth_mask_50 = torch.sum(y > 50, 1, keepdim=True) == 0
depth_mask_50 = depth_mask_50[:-1] * depth_mask_50[1:]
depth_mask_100 = torch.sum(y > 100, 1, keepdim=True) == 0
depth_mask_100 = depth_mask_100[:-1] * depth_mask_100[1:]
flow = compute_flow(img[:, 0].cuda()).cpu()
warped_disp = flow_warp(x[1:], flow)
warped_img = flow_warp(img[:, 0][1:].float(), flow)
flow_mask = torch.sum(flow > 250, 1, keepdim=True) == 0
delta_mask = delta_mask * torch.exp(-50. * torch.sqrt(
((warped_img / 255. - img[:, 0][:-1].float() / 255.) ** 2).sum(dim=1, keepdim=True))) * flow_mask * (
warped_disp > 0) > 1e-2
opw_err = torch.abs(warped_disp - x[:-1]) * delta_mask
opw_err_30 = torch.abs(warped_disp - x[:-1]) * delta_mask * depth_mask_30
opw_err_50 = torch.abs(warped_disp - x[:-1]) * delta_mask * depth_mask_50
opw_err_100 = torch.abs(warped_disp - x[:-1]) * delta_mask * depth_mask_100
opw = 0
opw_30 = 0
opw_50 = 0
opw_100 = 0
for i in range(t-1):
if torch.sum(delta_mask[i]) > 0:
opw += torch.sum(opw_err[i]) / torch.sum(delta_mask[i])
if torch.sum(delta_mask[i] * depth_mask_30[i]) > 0:
opw_30 += torch.sum(opw_err_30[i]) / torch.sum(delta_mask[i] * depth_mask_30[i])
if torch.sum(delta_mask[i] * depth_mask_50[i]) > 0:
opw_50 += torch.sum(opw_err_50[i]) / torch.sum(delta_mask[i] * depth_mask_50[i])
if torch.sum(delta_mask[i] * depth_mask_100[i]) > 0:
opw_100 += torch.sum(opw_err_100[i]) / torch.sum(delta_mask[i] * depth_mask_100[i])
opw = opw / (t - 1)
opw_30 = opw_30 / (t - 1)
opw_50 = opw_50 / (t - 1)
opw_100 = opw_100 / (t - 1)
return opw, opw_30, opw_50, opw_100
def eval_RTC_sequence(
img: torch.Tensor,
x: torch.Tensor,
y: torch.Tensor,
mask: torch.Tensor,
crop: int = 0,
mask_thr: float = 0.5,
clamp_thr: float = 1e-5,
) -> Dict[str, torch.Tensor]:
assert len(x.shape) == len(y.shape) == len(mask.shape) == 4, (
x.shape,
y.shape,
mask.shape,
) # T, 1, H, W
assert x.shape[0] == y.shape[0] == mask.shape[0], (x.shape, y.shape, mask.shape)
t, c, h, w = img[:, 0].shape
# chuck out the border
if crop > 0:
if crop > min(y.shape[2:]) - crop:
raise ValueError("Incorrect crop size.")
y = y[:, :, crop:-crop, crop:-crop]
x = x[:, :, crop:-crop, crop:-crop]
mask = mask[:, :, crop:-crop, crop:-crop]
y = y * (mask > mask_thr).float()
x = x * (mask > mask_thr).float()
y[torch.isnan(y)] = 0
flow = compute_flow(img[:, 0].cuda()).cpu()
delta_mask = mask[:-1] * mask[1:]
warped_disp = flow_warp(x[1:], flow)
warped_img = flow_warp(img[:, 0][1:], flow)
flow_mask = torch.sum(flow > 250, 1, keepdim=True) == 0
depth_mask = torch.sum(y > 30, 1, keepdim=True) == 0
depth_mask = depth_mask[:-1] * depth_mask[1:]
delta_mask = delta_mask * torch.exp(-50. * torch.sqrt(
((warped_img / 255. - img[:, 0][:-1] / 255.) ** 2).sum(dim=1, keepdim=True))) * flow_mask * (
warped_disp > 0) > 1e-2
tau = 1.01
x1 = x[:-1] / warped_disp
x2 = warped_disp / x[:-1]
x1[torch.isinf(x1)] = -1e10
x2[torch.isinf(x2)] = -1e10
x = torch.max(torch.cat((x1, x2), 1), 1)[0] < tau
rtc_err = x[:, None] * delta_mask
rtc_err_30 = x[:, None] * delta_mask * depth_mask
rtc = 0
rtc_30 = 0
for i in range(t-1):
if torch.sum(delta_mask[i]) > 0:
rtc += torch.sum(rtc_err[i]) / torch.sum(delta_mask[i])
if torch.sum(delta_mask[i] * depth_mask[i]) > 0:
rtc_30 += torch.sum(rtc_err_30[i]) / torch.sum(delta_mask[i] * depth_mask[i])
rtc = rtc / (t-1)
rtc_30 = rtc_30 / (t - 1)
return rtc, rtc_30
def depth2disparity_scale(left_camera, right_camera, image_size_tensor):
# opencv camera matrices
(_, T1, K1), (_, T2, _) = [
opencv_from_cameras_projection(
f,
image_size_tensor,
)
for f in (left_camera, right_camera)
]
fix_baseline = T1[0][0] - T2[0][0]
focal_length_px = K1[0][0][0]
# following this https://github.com/princeton-vl/RAFT-Stereo#converting-disparity-to-depth
return focal_length_px * fix_baseline
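# With the scale returned above, depth and disparity are related by
#   depth = (focal_length_px * baseline) / disparity = scale / disparity,
# which is how predicted disparities are converted to depth elsewhere
# (e.g. depth = scale / predictions["disparity"] in eval_batch below).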
def depth_to_pcd(
depth_map,
img,
focal_length,
cx,
cy,
step: int = None,
inv_extrinsic=None,
mask=None,
filter=False,
):
__, w, __ = img.shape
if step is None:
step = int(w / 100)
Z = depth_map[::step, ::step]
colors = img[::step, ::step, :]
Pixels_Y = torch.arange(Z.shape[0]).to(Z.device) * step
Pixels_X = torch.arange(Z.shape[1]).to(Z.device) * step
X = (Pixels_X[None, :] - cx) * Z / focal_length
Y = (Pixels_Y[:, None] - cy) * Z / focal_length
inds = Z > 0
if mask is not None:
inds = inds * (mask[::step, ::step] > 0)
X = X[inds].reshape(-1)
Y = Y[inds].reshape(-1)
Z = Z[inds].reshape(-1)
colors = colors[inds]
pcd = torch.stack([X, Y, Z]).T
if inv_extrinsic is not None:
pcd_ext = torch.vstack([pcd.T, torch.ones((1, pcd.shape[0])).to(Z.device)])
pcd = (inv_extrinsic @ pcd_ext)[:3, :].T
if filter:
pcd, filt_inds = filter_outliers(pcd)
colors = colors[filt_inds]
return pcd, colors
def filter_outliers(pcd, sigma=3):
mean = pcd.mean(0)
std = pcd.std(0)
inds = ((pcd - mean).abs() < sigma * std)[:, 2]
pcd = pcd[inds]
return pcd, inds
def eval_batch(batch_dict, predictions, scale) -> Dict[str, Union[float, torch.Tensor]]:
"""
Produce performance metrics for a single batch of perception
predictions.
Args:
        batch_dict: A dictionary with the ground-truth disparity, masks and
            the stereo video for the evaluated sequence.
        predictions: A dictionary holding the predicted disparity.
        scale: Factor (focal length in pixels times baseline) used to convert
            disparity to depth.
Returns:
results: A dictionary holding evaluation metrics.
"""
results = {}
assert "disparity" in predictions
mask_now = torch.ones_like(batch_dict["fg_mask"])
mask_now = mask_now * batch_dict["disparity_mask"]
eval_flow_traj_output = eval_endpoint_error_sequence(
predictions["disparity"], batch_dict["disparity"], mask_now
)
for epe_name in ("epe", "temp_epe"):
results[PerceptionMetric(f"disp_{epe_name}_mean")] = eval_flow_traj_output[
f"{epe_name}_mean"
]
results[PerceptionMetric(f"disp_{epe_name}_bad_3px")] = eval_flow_traj_output[
f"{epe_name}_bad_3px"
]
results[PerceptionMetric(f"disp_{epe_name}_bad_2px")] = eval_flow_traj_output[
f"{epe_name}_bad_2px"
]
results[PerceptionMetric(f"disp_{epe_name}_bad_1px")] = eval_flow_traj_output[
f"{epe_name}_bad_1px"
]
results[PerceptionMetric(f"disp_{epe_name}_bad_0.5px")] = eval_flow_traj_output[
f"{epe_name}_bad_0.5px"
]
if "endpoint_error_per_pixel" in eval_flow_traj_output:
results["disp_endpoint_error_per_pixel"] = eval_flow_traj_output[
"endpoint_error_per_pixel"
]
# disparity to depth
depth = scale / predictions["disparity"].clamp(min=1e-10)
eval_TCC_output = eval_TCC_sequence(
depth, scale / batch_dict["disparity"].clamp(min=1e-10), mask_now
)
results[PerceptionMetric("disp_TCC")] = eval_TCC_output[None]
eval_TCM_output = eval_TCM_sequence(
depth, scale / batch_dict["disparity"].clamp(min=1e-10), mask_now
)
results[PerceptionMetric("disp_TCM")] = eval_TCM_output[None]
eval_OPW_output, eval_OPW_30_output, eval_OPW_50_output, eval_OPW_100_output = eval_OPW_sequence(
batch_dict["stereo_video"], depth, scale / batch_dict["disparity"].clamp(min=1e-10), mask_now
)
results[PerceptionMetric("disp_OPW")] = eval_OPW_output[None]
results[PerceptionMetric("disp_OPW_100")] = eval_OPW_100_output[None]
results[PerceptionMetric("disp_OPW_50")] = eval_OPW_50_output[None]
if eval_OPW_30_output > 0:
results[PerceptionMetric("disp_OPW_30")] = eval_OPW_30_output[None]
else:
results[PerceptionMetric("disp_OPW_30")] = torch.tensor([0.0])
eval_RTC_output, eval_RTC_30_output = eval_RTC_sequence(
batch_dict["stereo_video"], depth, scale / batch_dict["disparity"].clamp(min=1e-10), mask_now
)
results[PerceptionMetric("disp_RTC")] = eval_RTC_output[None]
if eval_RTC_30_output > 0:
results[PerceptionMetric("disp_RTC_30")] = eval_RTC_30_output[None]
else:
results[PerceptionMetric("disp_RTC_30")] = torch.tensor([0.0])
return (results, len(predictions["disparity"]))
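# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the original file: the disparity/depth
# relation that `eval_batch` relies on (depth = scale / disparity, with the
# scale coming from `depth2disparity_scale`), plus a toy call to
# `filter_outliers`. The focal length and baseline numbers are made up, and
# this assumes the `torch` import at the top of this module.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    focal_length_px, baseline = 720.0, 0.12
    scale = focal_length_px * baseline             # what depth2disparity_scale returns
    disparity = torch.tensor([36.0, 18.0, 9.0])    # pixels
    depth = scale / disparity.clamp(min=1e-10)     # 2.4, 4.8, 9.6
    print(depth)
    # Outlier removal on a toy point cloud: a single far-away z value is
    # typically dropped by the 3-sigma filter above.
    pcd = torch.randn(1000, 3)
    pcd[0, 2] = 100.0
    filtered, _ = filter_outliers(pcd)
    print(pcd.shape[0], filtered.shape[0])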
================================================
FILE: evaluation/utils/ssim.py
================================================
# Copyright 2020 by Gongfan Fang, Zhejiang University.
# All rights reserved.
import warnings
import torch
import torch.nn.functional as F
def _fspecial_gauss_1d(size, sigma):
r"""Create 1-D gauss kernel
Args:
size (int): the size of gauss kernel
sigma (float): sigma of normal distribution
Returns:
torch.Tensor: 1D kernel (1 x 1 x size)
"""
coords = torch.arange(size, dtype=torch.float)
coords -= size // 2
g = torch.exp(-(coords ** 2) / (2 * sigma ** 2))
g /= g.sum()
return g.unsqueeze(0).unsqueeze(0)
def gaussian_filter(input, win):
r""" Blur input with 1-D kernel
Args:
input (torch.Tensor): a batch of tensors to be blurred
        win (torch.Tensor): 1-D gauss kernel
Returns:
torch.Tensor: blurred tensors
"""
assert all([ws == 1 for ws in win.shape[1:-1]]), win.shape
if len(input.shape) == 4:
conv = F.conv2d
elif len(input.shape) == 5:
conv = F.conv3d
else:
raise NotImplementedError(input.shape)
C = input.shape[1]
out = input
for i, s in enumerate(input.shape[2:]):
if s >= win.shape[-1]:
out = conv(out, weight=win.transpose(2 + i, -1), stride=1, padding=0, groups=C)
else:
warnings.warn(
f"Skipping Gaussian Smoothing at dimension 2+{i} for input: {input.shape} and win size: {win.shape[-1]}"
)
return out
def _ssim(X, Y, data_range, win, size_average=True, K=(0.01, 0.03)):
r""" Calculate ssim index for X and Y
Args:
X (torch.Tensor): images
Y (torch.Tensor): images
win (torch.Tensor): 1-D gauss kernel
data_range (float or int, optional): value range of input images. (usually 1.0 or 255)
size_average (bool, optional): if size_average=True, ssim of all images will be averaged as a scalar
Returns:
torch.Tensor: ssim results.
"""
K1, K2 = K
# batch, channel, [depth,] height, width = X.shape
compensation = 1.0
C1 = (K1 * data_range) ** 2
C2 = (K2 * data_range) ** 2
win = win.to(X.device, dtype=X.dtype)
mu1 = gaussian_filter(X, win)
mu2 = gaussian_filter(Y, win)
mu1_sq = mu1.pow(2)
mu2_sq = mu2.pow(2)
mu1_mu2 = mu1 * mu2
sigma1_sq = compensation * (gaussian_filter(X * X, win) - mu1_sq)
sigma2_sq = compensation * (gaussian_filter(Y * Y, win) - mu2_sq)
sigma12 = compensation * (gaussian_filter(X * Y, win) - mu1_mu2)
cs_map = (2 * sigma12 + C2) / (sigma1_sq + sigma2_sq + C2) # set alpha=beta=gamma=1
ssim_map = ((2 * mu1_mu2 + C1) / (mu1_sq + mu2_sq + C1)) * cs_map
ssim_per_channel = torch.flatten(ssim_map, 2).mean(-1)
cs = torch.flatten(cs_map, 2).mean(-1)
return ssim_per_channel, cs
def ssim(
X,
Y,
data_range=255,
size_average=True,
win_size=11,
win_sigma=1.5,
win=None,
K=(0.01, 0.03),
nonnegative_ssim=False,
):
r""" interface of ssim
Args:
X (torch.Tensor): a batch of images, (N,C,H,W)
Y (torch.Tensor): a batch of images, (N,C,H,W)
data_range (float or int, optional): value range of input images. (usually 1.0 or 255)
size_average (bool, optional): if size_average=True, ssim of all images will be averaged as a scalar
win_size: (int, optional): the size of gauss kernel
win_sigma: (float, optional): sigma of normal distribution
win (torch.Tensor, optional): 1-D gauss kernel. if None, a new kernel will be created according to win_size and win_sigma
        K (list or tuple, optional): scalar constants (K1, K2). Try a larger K2 constant (e.g. 0.4) if you get negative or NaN results.
nonnegative_ssim (bool, optional): force the ssim response to be nonnegative with relu
Returns:
torch.Tensor: ssim results
"""
if not X.shape == Y.shape:
raise ValueError(f"Input images should have the same dimensions, but got {X.shape} and {Y.shape}.")
for d in range(len(X.shape) - 1, 1, -1):
X = X.squeeze(dim=d)
Y = Y.squeeze(dim=d)
if len(X.shape) not in (4, 5):
raise ValueError(f"Input images should be 4-d or 5-d tensors, but got {X.shape}")
if not X.type() == Y.type():
raise ValueError(f"Input images should have the same dtype, but got {X.type()} and {Y.type()}.")
if win is not None: # set win_size
win_size = win.shape[-1]
if not (win_size % 2 == 1):
raise ValueError("Window size should be odd.")
if win is None:
win = _fspecial_gauss_1d(win_size, win_sigma)
win = win.repeat([X.shape[1]] + [1] * (len(X.shape) - 1))
ssim_per_channel, cs = _ssim(X, Y, data_range=data_range, win=win, size_average=False, K=K)
if nonnegative_ssim:
ssim_per_channel = torch.relu(ssim_per_channel)
if size_average:
return ssim_per_channel.mean()
else:
return ssim_per_channel #.mean(1)
def ms_ssim(
X, Y, data_range=255, size_average=True, win_size=11, win_sigma=1.5, win=None, weights=None, K=(0.01, 0.03)
):
r""" interface of ms-ssim
Args:
X (torch.Tensor): a batch of images, (N,C,[T,]H,W)
Y (torch.Tensor): a batch of images, (N,C,[T,]H,W)
data_range (float or int, optional): value range of input images. (usually 1.0 or 255)
size_average (bool, optional): if size_average=True, ssim of all images will be averaged as a scalar
win_size: (int, optional): the size of gauss kernel
win_sigma: (float, optional): sigma of normal distribution
win (torch.Tensor, optional): 1-D gauss kernel. if None, a new kernel will be created according to win_size and win_sigma
weights (list, optional): weights for different levels
        K (list or tuple, optional): scalar constants (K1, K2). Try a larger K2 constant (e.g. 0.4) if you get negative or NaN results.
Returns:
torch.Tensor: ms-ssim results
"""
if not X.shape == Y.shape:
raise ValueError(f"Input images should have the same dimensions, but got {X.shape} and {Y.shape}.")
for d in range(len(X.shape) - 1, 1, -1):
X = X.squeeze(dim=d)
Y = Y.squeeze(dim=d)
if not X.type() == Y.type():
raise ValueError(f"Input images should have the same dtype, but got {X.type()} and {Y.type()}.")
if len(X.shape) == 4:
avg_pool = F.avg_pool2d
elif len(X.shape) == 5:
avg_pool = F.avg_pool3d
else:
raise ValueError(f"Input images should be 4-d or 5-d tensors, but got {X.shape}")
if win is not None: # set win_size
win_size = win.shape[-1]
if not (win_size % 2 == 1):
raise ValueError("Window size should be odd.")
smaller_side = min(X.shape[-2:])
assert smaller_side > (win_size - 1) * (
2 ** 4
), "Image size should be larger than %d due to the 4 downsamplings in ms-ssim" % ((win_size - 1) * (2 ** 4))
if weights is None:
weights = [0.0448, 0.2856, 0.3001, 0.2363, 0.1333]
weights = X.new_tensor(weights)
if win is None:
win = _fspecial_gauss_1d(win_size, win_sigma)
win = win.repeat([X.shape[1]] + [1] * (len(X.shape) - 1))
levels = weights.shape[0]
mcs = []
for i in range(levels):
ssim_per_channel, cs = _ssim(X, Y, win=win, data_range=data_range, size_average=False, K=K)
if i < levels - 1:
mcs.append(torch.relu(cs))
padding = [s % 2 for s in X.shape[2:]]
X = avg_pool(X, kernel_size=2, padding=padding)
Y = avg_pool(Y, kernel_size=2, padding=padding)
ssim_per_channel = torch.relu(ssim_per_channel) # (batch, channel)
mcs_and_ssim = torch.stack(mcs + [ssim_per_channel], dim=0) # (level, batch, channel)
ms_ssim_val = torch.prod(mcs_and_ssim ** weights.view(-1, 1, 1), dim=0)
if size_average:
return ms_ssim_val.mean()
else:
return ms_ssim_val.mean(1)
class SSIM(torch.nn.Module):
def __init__(
self,
data_range=255,
size_average=True,
win_size=11,
win_sigma=1.5,
channel=3,
spatial_dims=2,
K=(0.01, 0.03),
nonnegative_ssim=False,
):
r""" class for ssim
Args:
data_range (float or int, optional): value range of input images. (usually 1.0 or 255)
size_average (bool, optional): if size_average=True, ssim of all images will be averaged as a scalar
win_size: (int, optional): the size of gauss kernel
win_sigma: (float, optional): sigma of normal distribution
channel (int, optional): input channels (default: 3)
            K (list or tuple, optional): scalar constants (K1, K2). Try a larger K2 constant (e.g. 0.4) if you get negative or NaN results.
nonnegative_ssim (bool, optional): force the ssim response to be nonnegative with relu.
"""
super(SSIM, self).__init__()
self.win_size = win_size
self.win = _fspecial_gauss_1d(win_size, win_sigma).repeat([channel, 1] + [1] * spatial_dims)
self.size_average = size_average
self.data_range = data_range
self.K = K
self.nonnegative_ssim = nonnegative_ssim
def forward(self, X, Y):
return ssim(
X,
Y,
data_range=self.data_range,
size_average=self.size_average,
win=self.win,
K=self.K,
nonnegative_ssim=self.nonnegative_ssim,
)
class MS_SSIM(torch.nn.Module):
def __init__(
self,
data_range=255,
size_average=True,
win_size=11,
win_sigma=1.5,
channel=3,
spatial_dims=2,
weights=None,
K=(0.01, 0.03),
):
r""" class for ms-ssim
Args:
data_range (float or int, optional): value range of input images. (usually 1.0 or 255)
size_average (bool, optional): if size_average=True, ssim of all images will be averaged as a scalar
win_size: (int, optional): the size of gauss kernel
win_sigma: (float, optional): sigma of normal distribution
channel (int, optional): input channels (default: 3)
weights (list, optional): weights for different levels
            K (list or tuple, optional): scalar constants (K1, K2). Try a larger K2 constant (e.g. 0.4) if you get negative or NaN results.
"""
super(MS_SSIM, self).__init__()
self.win_size = win_size
self.win = _fspecial_gauss_1d(win_size, win_sigma).repeat([channel, 1] + [1] * spatial_dims)
self.size_average = size_average
self.data_range = data_range
self.weights = weights
self.K = K
def forward(self, X, Y):
return ms_ssim(
X,
Y,
data_range=self.data_range,
size_average=self.size_average,
win=self.win,
weights=self.weights,
K=self.K,
)
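# ---------------------------------------------------------------------------
# Illustrative usage sketch, not part of the original file: comparing the
# functional `ssim` interface with the `SSIM` module defined above on random
# image batches. Shapes and values are invented for demonstration only.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    x = torch.rand(2, 3, 64, 64) * 255        # (N, C, H, W) images in [0, 255]
    y = x + 5.0 * torch.randn_like(x)          # a noisy copy of x
    # Functional interface: returns a scalar because size_average=True.
    score = ssim(x, y, data_range=255, size_average=True)
    # Module interface: precomputes the Gaussian window once and reuses it.
    ssim_module = SSIM(data_range=255, channel=3)
    print(float(score), float(ssim_module(x, y)))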
================================================
FILE: evaluation/utils/utils.py
================================================
from collections import defaultdict
import configparser
import os
import math
from typing import Optional, List
import torch
import cv2
import numpy as np
from dataclasses import dataclass
from tabulate import tabulate
import logging
from pytorch3d.structures import Pointclouds
from pytorch3d.transforms import RotateAxisAngle
from pytorch3d.utils import (
opencv_from_cameras_projection,
)
from pytorch3d.renderer import (
AlphaCompositor,
PointsRasterizationSettings,
PointsRasterizer,
PointsRenderer,
)
from stereoanyvideo.evaluation.utils.eval_utils import depth_to_pcd
@dataclass
class PerceptionPrediction:
"""
Holds the tensors that describe a result of any perception module.
"""
depth_map: Optional[torch.Tensor] = None
disparity: Optional[torch.Tensor] = None
image_rgb: Optional[torch.Tensor] = None
fg_probability: Optional[torch.Tensor] = None
def aggregate_eval_results(per_batch_eval_results, reduction="mean"):
total_length = 0
aggregate_results = defaultdict(list)
for result in per_batch_eval_results:
if isinstance(result, tuple):
reduction = "sum"
length = result[1]
total_length += length
result = result[0]
for metric, val in result.items():
if reduction == "sum":
aggregate_results[metric].append(val * length)
if reduction == "mean":
return {k: torch.cat(v).mean().item() for k, v in aggregate_results.items()}
elif reduction == "sum":
return {
k: torch.cat(v).sum().item() / float(total_length)
for k, v in aggregate_results.items()
}
def aggregate_and_print_results(
per_batch_eval_results: List[dict],
):
print("")
result = aggregate_eval_results(
per_batch_eval_results,
)
pretty_print_perception_metrics(result)
result = {str(k): v for k, v in result.items()}
print("")
return result
def pretty_print_perception_metrics(results):
metrics = sorted(list(results.keys()), key=lambda x: x.metric)
print("===== Perception results =====")
print(
tabulate(
[[metric, results[metric]] for metric in metrics],
)
)
logging.info("===== Perception results =====")
logging.info(tabulate(
[[metric, results[metric]] for metric in metrics],
))
def read_calibration(calibration_file, resolution_string):
# ported from https://github.com/stereolabs/zed-open-capture/
# blob/dfa0aee51ccd2297782230a05ca59e697df496b2/examples/include/calibration.hpp#L4172
zed_resolutions = {
"2K": (1242, 2208),
"FHD": (1080, 1920),
"HD": (720, 1280),
# "qHD": (540, 960),
"VGA": (376, 672),
}
assert resolution_string in zed_resolutions.keys()
image_height, image_width = zed_resolutions[resolution_string]
# Open camera configuration file
assert os.path.isfile(calibration_file)
calib = configparser.ConfigParser()
calib.read(calibration_file)
# Get translations
T = np.zeros((3, 1))
T[0] = float(calib["STEREO"]["baseline"])
T[1] = float(calib["STEREO"]["ty"])
T[2] = float(calib["STEREO"]["tz"])
baseline = T[0]
# Get left parameters
left_cam_cx = float(calib[f"LEFT_CAM_{resolution_string}"]["cx"])
left_cam_cy = float(calib[f"LEFT_CAM_{resolution_string}"]["cy"])
left_cam_fx = float(calib[f"LEFT_CAM_{resolution_string}"]["fx"])
left_cam_fy = float(calib[f"LEFT_CAM_{resolution_string}"]["fy"])
left_cam_k1 = float(calib[f"LEFT_CAM_{resolution_string}"]["k1"])
left_cam_k2 = float(calib[f"LEFT_CAM_{resolution_string}"]["k2"])
left_cam_p1 = float(calib[f"LEFT_CAM_{resolution_string}"]["p1"])
left_cam_p2 = float(calib[f"LEFT_CAM_{resolution_string}"]["p2"])
left_cam_k3 = float(calib[f"LEFT_CAM_{resolution_string}"]["k3"])
# Get right parameters
right_cam_cx = float(calib[f"RIGHT_CAM_{resolution_string}"]["cx"])
right_cam_cy = float(calib[f"RIGHT_CAM_{resolution_string}"]["cy"])
right_cam_fx = float(calib[f"RIGHT_CAM_{resolution_string}"]["fx"])
right_cam_fy = float(calib[f"RIGHT_CAM_{resolution_string}"]["fy"])
right_cam_k1 = float(calib[f"RIGHT_CAM_{resolution_string}"]["k1"])
right_cam_k2 = float(calib[f"RIGHT_CAM_{resolution_string}"]["k2"])
right_cam_p1 = float(calib[f"RIGHT_CAM_{resolution_string}"]["p1"])
right_cam_p2 = float(calib[f"RIGHT_CAM_{resolution_string}"]["p2"])
right_cam_k3 = float(calib[f"RIGHT_CAM_{resolution_string}"]["k3"])
# Get rotations
R_zed = np.zeros(3)
R_zed[0] = float(calib["STEREO"][f"rx_{resolution_string.lower()}"])
R_zed[1] = float(calib["STEREO"][f"cv_{resolution_string.lower()}"])
R_zed[2] = float(calib["STEREO"][f"rz_{resolution_string.lower()}"])
R = cv2.Rodrigues(R_zed)[0]
# Left
cameraMatrix_left = np.array(
[[left_cam_fx, 0, left_cam_cx], [0, left_cam_fy, left_cam_cy], [0, 0, 1]]
)
distCoeffs_left = np.array(
[left_cam_k1, left_cam_k2, left_cam_p1, left_cam_p2, left_cam_k3]
)
# Right
cameraMatrix_right = np.array(
[
[right_cam_fx, 0, right_cam_cx],
[0, right_cam_fy, right_cam_cy],
[0, 0, 1],
]
)
distCoeffs_right = np.array(
[right_cam_k1, right_cam_k2, right_cam_p1, right_cam_p2, right_cam_k3]
)
# Stereo
R1, R2, P1, P2, Q = cv2.stereoRectify(
cameraMatrix1=cameraMatrix_left,
distCoeffs1=distCoeffs_left,
cameraMatrix2=cameraMatrix_right,
distCoeffs2=distCoeffs_right,
imageSize=(image_width, image_height),
R=R,
T=T,
flags=cv2.CALIB_ZERO_DISPARITY,
newImageSize=(image_width, image_height),
alpha=0,
)[:5]
# Precompute maps for cv::remap()
map_left_x, map_left_y = cv2.initUndistortRectifyMap(
cameraMatrix_left,
distCoeffs_left,
R1,
P1,
(image_width, image_height),
cv2.CV_32FC1,
)
map_right_x, map_right_y = cv2.initUndistortRectifyMap(
cameraMatrix_right,
distCoeffs_right,
R2,
P2,
(image_width, image_height),
cv2.CV_32FC1,
)
zed_calib = {
"map_left_x": map_left_x,
"map_left_y": map_left_y,
"map_right_x": map_right_x,
"map_right_y": map_right_y,
"pose_left": P1,
"pose_right": P2,
"baseline": baseline,
"image_width": image_width,
"image_height": image_height,
}
return zed_calib
def filter_depth_discontinuities(pcd, depth_map, threshold=5):
"""
Removes points that belong to high-depth discontinuity regions.
Args:
pcd (torch.Tensor): Nx3 point cloud tensor.
depth_map (torch.Tensor): HxW depth map.
threshold (float): Depth change threshold.
Returns:
torch.Tensor: Filtered point cloud.
"""
# Compute depth differences in x and y directions
depth_diff_x = torch.abs(depth_map[:, 1:] - depth_map[:, :-1]) # Shape (H, W-1)
depth_diff_y = torch.abs(depth_map[1:, :] - depth_map[:-1, :]) # Shape (H-1, W)
# Initialize mask with all True (valid points)
mask = torch.ones_like(depth_map, dtype=torch.bool) # Shape (H, W)
# Apply filtering: set False where depth difference is too large
mask[:, :-1] &= depth_diff_x <= threshold # X-direction filtering
mask[:-1, :] &= depth_diff_y <= threshold # Y-direction filtering
# Flatten mask to match point cloud size
mask_flat = mask.flatten()[: pcd.shape[0]]
return pcd[mask_flat] # Return only valid points
def visualize_batch(
batch_dict: dict,
preds: PerceptionPrediction,
output_dir: str,
ref_frame: int = 0,
only_foreground=False,
step=0,
sequence_name=None,
writer=None,
render_bin_size=None
):
os.makedirs(output_dir, exist_ok=True)
outputs = {}
if preds.depth_map is not None:
device = preds.depth_map.device
pcd_global_seq = []
H, W = batch_dict["stereo_video"].shape[3:]
for i in range(len(batch_dict["stereo_video"])):
if hasattr(preds, 'perspective_cameras'):
R, T, K = opencv_from_cameras_projection(
preds.perspective_cameras[i],
torch.tensor([H, W])[None].to(device),
) # 1x3x3, 1x3, 1x3x3
else:
raise KeyError(f"R T K not found!")
extrinsic_3x4_0 = torch.cat([R[0], T[0, :, None]], dim=1)
extr_matrix = torch.cat(
[
extrinsic_3x4_0,
torch.Tensor([[0, 0, 0, 1]]).to(extrinsic_3x4_0.device),
],
dim=0,
)
inv_extr_matrix = extr_matrix.inverse().to(device)
pcd, colors = depth_to_pcd(
preds.depth_map[i, 0],
batch_dict["stereo_video"][i][0].permute(1, 2, 0),
K[0][0][0],
K[0][0][2],
K[0][1][2],
step=1,
inv_extrinsic=inv_extr_matrix,
mask=batch_dict["fg_mask"][i, 0] if only_foreground else None,
filter=False,
)
R, T = inv_extr_matrix[None, :3, :3], inv_extr_matrix[None, :3, 3]
pcd_global_seq.append((pcd, colors, (R, T, preds.perspective_cameras[i])))
raster_settings = PointsRasterizationSettings(
image_size=[H, W],
radius=0.003,
points_per_pixel=10,
)
R, T, cam_ = pcd_global_seq[ref_frame][2]
median_depth = preds.depth_map.median()
cam_.cuda()
for mode in ["angle_15", "angle_-15", "angle_0", "changing_angle"]:
res = []
for t, (pcd, color, __) in enumerate(pcd_global_seq):
if mode == "changing_angle":
angle = math.cos((math.pi) * (t / 60)) * 15
elif mode == "angle_15":
angle = 15
elif mode == "angle_-15":
angle = -15
elif mode == "angle_0":
angle = 0
delta_x = median_depth * math.sin(math.radians(angle))
delta_z = median_depth * (1 - math.cos(math.radians(angle)))
cam = cam_.clone()
cam.R = torch.bmm(
cam.R,
RotateAxisAngle(angle=angle, axis="Y", device=device).get_matrix()[
:, :3, :3
],
)
cam.T[0, 0] = cam.T[0, 0] - delta_x
cam.T[0, 2] = cam.T[0, 2] - delta_z + median_depth / 2.0
rasterizer = PointsRasterizer(
cameras=cam, raster_settings=raster_settings
)
renderer = PointsRenderer(
rasterizer=rasterizer,
compositor=AlphaCompositor(background_color=(1, 1, 1)),
)
pcd_copy = pcd.clone()
point_cloud = Pointclouds(points=[pcd_copy], features=[color / 255.0])
images = renderer(point_cloud)
res.append(images[0, ..., :3].cpu())
res = torch.stack(res)
video = (res * 255).numpy().astype(np.uint8)
save_name = f"{sequence_name}_reconstruction_{step}_mode_{mode}_"
if writer is None:
outputs[mode] = video
if only_foreground:
save_name += "fg_only"
else:
save_name += "full_scene"
video_out = cv2.VideoWriter(
os.path.join(
output_dir,
f"{save_name}.mp4",
),
cv2.VideoWriter_fourcc(*"mp4v"),
fps=30,
frameSize=(res.shape[2], res.shape[1]),
isColor=True,
)
filename = os.path.join(output_dir, sequence_name + '_img_')
if not os.path.isdir(filename + str(mode)):
os.mkdir(filename + str(mode))
for i in range(len(video)):
filename_temp = filename + str(mode) + '/' + str(i).zfill(3) + '.png'
cv2.imwrite(filename_temp, cv2.cvtColor(video[i], cv2.COLOR_BGR2RGB))
video_out.write(cv2.cvtColor(video[i], cv2.COLOR_BGR2RGB))
video_out.release()
if writer is not None:
writer.add_video(
f"{sequence_name}_reconstruction_mode_{mode}",
(res * 255).permute(0, 3, 1, 2).to(torch.uint8)[None],
global_step=step,
fps=30,
)
return outputs
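# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the original file: aggregating two fake
# per-batch results with `aggregate_eval_results`. The metric name and batch
# lengths are invented; real callers pass (metrics_dict, length) tuples such
# as those returned by `eval_batch`, keyed by PerceptionMetric objects.
# Running this module also requires the pytorch3d imports above to resolve.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    batch_a = ({"disp_epe_mean": torch.tensor([1.0])}, 4)    # 4 frames
    batch_b = ({"disp_epe_mean": torch.tensor([2.0])}, 12)   # 12 frames
    aggregated = aggregate_eval_results([batch_a, batch_b])
    # Length-weighted mean: (1.0 * 4 + 2.0 * 12) / 16 = 1.75
    print(aggregated)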
================================================
FILE: models/Video-Depth-Anything/app.py
================================================
# Copyright (2025) Bytedance Ltd. and/or its affiliates
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gradio as gr
import numpy as np
import os
import torch
from video_depth_anything.video_depth import VideoDepthAnything
from utils.dc_utils import read_video_frames, vis_sequence_depth, save_video
examples = [
['assets/example_videos/davis_rollercoaster.mp4'],
]
model_configs = {
'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
}
encoder='vitl'
video_depth_anything = VideoDepthAnything(**model_configs[encoder])
video_depth_anything.load_state_dict(torch.load(f'./checkpoints/video_depth_anything_{encoder}.pth', map_location='cpu'), strict=True)
video_depth_anything = video_depth_anything.to('cuda').eval()
def infer_video_depth(
input_video: str,
max_len: int = -1,
target_fps: int = -1,
max_res: int = 1280,
output_dir: str = './outputs',
input_size: int = 518,
):
frames, target_fps = read_video_frames(input_video, max_len, target_fps, max_res)
depth_list, fps = video_depth_anything.infer_video_depth(frames, target_fps, input_size=input_size, device='cuda')
depth_list = np.stack(depth_list, axis=0)
vis = vis_sequence_depth(depth_list)
video_name = os.path.basename(input_video)
if not os.path.exists(output_dir):
os.makedirs(output_dir)
processed_video_path = os.path.join(output_dir, os.path.splitext(video_name)[0]+'_src.mp4')
depth_vis_path = os.path.join(output_dir, os.path.splitext(video_name)[0]+'_vis.mp4')
save_video(frames, processed_video_path, fps=fps)
save_video(vis, depth_vis_path, fps=fps)
return [processed_video_path, depth_vis_path]
def construct_demo():
with gr.Blocks(analytics_enabled=False) as demo:
gr.Markdown(
f"""
blablabla
"""
)
with gr.Row(equal_height=True):
with gr.Column(scale=1):
input_video = gr.Video(label="Input Video")
# with gr.Tab(label="Output"):
with gr.Column(scale=2):
with gr.Row(equal_height=True):
processed_video = gr.Video(
label="Preprocessed video",
interactive=False,
autoplay=True,
loop=True,
show_share_button=True,
scale=5,
)
depth_vis_video = gr.Video(
label="Generated Depth Video",
interactive=False,
autoplay=True,
loop=True,
show_share_button=True,
scale=5,
)
with gr.Row(equal_height=True):
with gr.Column(scale=1):
with gr.Row(equal_height=False):
with gr.Accordion("Advanced Settings", open=False):
max_len = gr.Slider(
label="max process length",
minimum=-1,
maximum=1000,
value=-1,
step=1,
)
target_fps = gr.Slider(
label="target FPS",
minimum=-1,
maximum=30,
value=15,
step=1,
)
max_res = gr.Slider(
label="max side resolution",
minimum=480,
maximum=1920,
value=1280,
step=1,
)
generate_btn = gr.Button("Generate")
with gr.Column(scale=2):
pass
gr.Examples(
examples=examples,
inputs=[
input_video,
max_len,
target_fps,
max_res
],
outputs=[processed_video, depth_vis_video],
fn=infer_video_depth,
cache_examples="lazy",
)
generate_btn.click(
fn=infer_video_depth,
inputs=[
input_video,
max_len,
target_fps,
max_res
],
outputs=[processed_video, depth_vis_video],
)
return demo
if __name__ == "__main__":
demo = construct_demo()
demo.queue()
demo.launch(server_name="0.0.0.0")
================================================
FILE: models/Video-Depth-Anything/get_weights.sh
================================================
#!/bin/bash
mkdir -p checkpoints
cd checkpoints || exit 1
wget https://huggingface.co/depth-anything/Video-Depth-Anything-Small/resolve/main/video_depth_anything_vits.pth
wget https://huggingface.co/depth-anything/Video-Depth-Anything-Large/resolve/main/video_depth_anything_vitl.pth
================================================
FILE: models/Video-Depth-Anything/run.py
================================================
# Copyright (2025) Bytedance Ltd. and/or its affiliates
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import numpy as np
import os
import torch
from video_depth_anything.video_depth import VideoDepthAnything
from utils.dc_utils import read_video_frames, vis_sequence_depth, save_video
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Video Depth Anything')
parser.add_argument('--input_video', type=str, default='./assets/example_videos/davis_rollercoaster.mp4')
parser.add_argument('--output_dir', type=str, default='./outputs')
parser.add_argument('--input_size', type=int, default=518)
parser.add_argument('--max_res', type=int, default=1280)
parser.add_argument('--encoder', type=str, default='vitl', choices=['vits', 'vitl'])
parser.add_argument('--max_len', type=int, default=-1, help='maximum length of the input video, -1 means no limit')
parser.add_argument('--target_fps', type=int, default=-1, help='target fps of the input video, -1 means the original fps')
args = parser.parse_args()
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
model_configs = {
'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
}
video_depth_anything = VideoDepthAnything(**model_configs[args.encoder])
video_depth_anything.load_state_dict(torch.load(f'./checkpoints/video_depth_anything_{args.encoder}.pth', map_location='cpu'), strict=True)
video_depth_anything = video_depth_anything.to(DEVICE).eval()
frames, target_fps = read_video_frames(args.input_video, args.max_len, args.target_fps, args.max_res)
depth_list, fps = video_depth_anything.infer_video_depth(frames, target_fps, input_size=args.input_size, device=DEVICE)
depth_list = np.stack(depth_list, axis=0)
vis = vis_sequence_depth(depth_list)
video_name = os.path.basename(args.input_video)
if not os.path.exists(args.output_dir):
os.makedirs(args.output_dir)
processed_video_path = os.path.join(args.output_dir, os.path.splitext(video_name)[0]+'_src.mp4')
depth_vis_path = os.path.join(args.output_dir, os.path.splitext(video_name)[0]+'_vis.mp4')
save_video(frames, processed_video_path, fps=fps)
save_video(vis, depth_vis_path, fps=fps)
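# Illustrative invocation, not part of the original file, assuming the vitl
# checkpoint has already been downloaded by get_weights.sh into ./checkpoints:
#   python run.py --input_video ./assets/example_videos/davis_rollercoaster.mp4 \
#                 --output_dir ./outputs --encoder vitl --max_res 1280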
================================================
FILE: models/Video-Depth-Anything/utils/dc_utils.py
================================================
# This file is originally from DepthCrafter/depthcrafter/utils.py at main · Tencent/DepthCrafter
# SPDX-License-Identifier: MIT License license
#
# This file may have been modified by ByteDance Ltd. and/or its affiliates on [date of modification]
# Original file is released under [ MIT License license], with the full license text available at [https://github.com/Tencent/DepthCrafter?tab=License-1-ov-file].
from typing import Union, List
import tempfile
import numpy as np
import PIL.Image
import matplotlib.cm as cm
import mediapy
import torch
try:
from decord import VideoReader, cpu
DECORD_AVAILABLE = True
except Exception:
import cv2
DECORD_AVAILABLE = False
def read_video_frames(video_path, process_length, target_fps=-1, max_res=-1, dataset="open"):
if DECORD_AVAILABLE:
vid = VideoReader(video_path, ctx=cpu(0))
original_height, original_width = vid.get_batch([0]).shape[1:3]
height = original_height
width = original_width
if max_res > 0 and max(height, width) > max_res:
scale = max_res / max(original_height, original_width)
height = round(original_height * scale)
width = round(original_width * scale)
vid = VideoReader(video_path, ctx=cpu(0), width=width, height=height)
fps = vid.get_avg_fps() if target_fps == -1 else target_fps
stride = round(vid.get_avg_fps() / fps)
stride = max(stride, 1)
frames_idx = list(range(0, len(vid), stride))
if process_length != -1 and process_length < len(frames_idx):
frames_idx = frames_idx[:process_length]
frames = vid.get_batch(frames_idx).asnumpy()
else:
cap = cv2.VideoCapture(video_path)
original_fps = cap.get(cv2.CAP_PROP_FPS)
original_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
original_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
if max_res > 0 and max(original_height, original_width) > max_res:
scale = max_res / max(original_height, original_width)
height = round(original_height * scale)
width = round(original_width * scale)
fps = original_fps if target_fps < 0 else target_fps
stride = max(round(original_fps / fps), 1)
frames = []
frame_count = 0
while cap.isOpened():
ret, frame = cap.read()
if not ret or (process_length > 0 and frame_count >= process_length):
break
if frame_count % stride == 0:
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) # Convert BGR to RGB
if max_res > 0 and max(original_height, original_width) > max_res:
frame = cv2.resize(frame, (width, height)) # Resize frame
frames.append(frame)
frame_count += 1
cap.release()
frames = np.stack(frames, axis=0)
return frames, fps
def save_video(
video_frames: Union[List[np.ndarray], List[PIL.Image.Image]],
output_video_path: str = None,
fps: int = 10,
crf: int = 18,
) -> str:
if output_video_path is None:
output_video_path = tempfile.NamedTemporaryFile(suffix=".mp4").name
if isinstance(video_frames[0], np.ndarray):
video_frames = [frame.astype(np.uint8) for frame in video_frames]
elif isinstance(video_frames[0], PIL.Image.Image):
video_frames = [np.array(frame) for frame in video_frames]
mediapy.write_video(output_video_path, video_frames, fps=fps, crf=crf)
return output_video_path
class ColorMapper:
# a color mapper to map depth values to a certain colormap
def __init__(self, colormap: str = "inferno"):
self.colormap = torch.tensor(cm.get_cmap(colormap).colors)
def apply(self, image: torch.Tensor, v_min=None, v_max=None):
# assert len(image.shape) == 2
if v_min is None:
v_min = image.min()
if v_max is None:
v_max = image.max()
image = (image - v_min) / (v_max - v_min)
image = (image * 255).long()
image = self.colormap[image] * 255
return image
def vis_sequence_depth(depths: np.ndarray, v_min=None, v_max=None):
visualizer = ColorMapper()
if v_min is None:
v_min = depths.min()
if v_max is None:
v_max = depths.max()
res = visualizer.apply(torch.tensor(depths), v_min=v_min, v_max=v_max).numpy()
return res
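# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the original file: colour-mapping a
# synthetic depth sequence with `vis_sequence_depth` and writing it out with
# `save_video`. The ramp stands in for real depth predictions, and writing
# the mp4 assumes ffmpeg is available to mediapy.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    ramp = np.tile(np.linspace(0.0, 1.0, 128, dtype=np.float32), (72, 1))
    depths = np.stack([ramp + 0.01 * i for i in range(8)], axis=0)  # (T, H, W)
    vis = vis_sequence_depth(depths)         # (T, H, W, 3), float values in [0, 255]
    out_path = save_video(list(vis), fps=8)  # temporary .mp4 when no path is given
    print(out_path, vis.shape)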
================================================
FILE: models/Video-Depth-Anything/utils/util.py
================================================
# Copyright (2025) Bytedance Ltd. and/or its affiliates
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
def compute_scale_and_shift(prediction, target, mask, scale_only=False):
if scale_only:
return compute_scale(prediction, target, mask), 0
else:
return compute_scale_and_shift_full(prediction, target, mask)
def compute_scale(prediction, target, mask):
# system matrix: A = [[a_00, a_01], [a_10, a_11]]
prediction = prediction.astype(np.float32)
target = target.astype(np.float32)
mask = mask.astype(np.float32)
a_00 = np.sum(mask * prediction * prediction)
a_01 = np.sum(mask * prediction)
a_11 = np.sum(mask)
# right hand side: b = [b_0, b_1]
b_0 = np.sum(mask * prediction * target)
x_0 = b_0 / (a_00 + 1e-6)
return x_0
def compute_scale_and_shift_full(prediction, target, mask):
# system matrix: A = [[a_00, a_01], [a_10, a_11]]
prediction = prediction.astype(np.float32)
target = target.astype(np.float32)
mask = mask.astype(np.float32)
a_00 = np.sum(mask * prediction * prediction)
a_01 = np.sum(mask * prediction)
a_11 = np.sum(mask)
b_0 = np.sum(mask * prediction * target)
b_1 = np.sum(mask * target)
x_0 = 1
x_1 = 0
det = a_00 * a_11 - a_01 * a_01
if det != 0:
x_0 = (a_11 * b_0 - a_01 * b_1) / det
x_1 = (-a_01 * b_0 + a_00 * b_1) / det
return x_0, x_1
def get_interpolate_frames(frame_list_pre, frame_list_post):
assert len(frame_list_pre) == len(frame_list_post)
min_w = 0.0
max_w = 1.0
step = (max_w - min_w) / (len(frame_list_pre)-1)
post_w_list = [min_w] + [i * step for i in range(1,len(frame_list_pre)-1)] + [max_w]
interpolated_frames = []
for i in range(len(frame_list_pre)):
interpolated_frames.append(frame_list_pre[i] * (1-post_w_list[i]) + frame_list_post[i] * post_w_list[i])
return interpolated_frames
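# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the original file: recovering a known
# scale/shift with the least-squares solver above. The target is constructed
# as target = 2.5 * prediction + 0.3, so the full solver should return
# approximately those two numbers wherever the mask is valid.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    rng = np.random.default_rng(0)
    prediction = rng.random((64, 64)).astype(np.float32)
    target = 2.5 * prediction + 0.3
    mask = np.ones_like(prediction)
    scale, shift = compute_scale_and_shift(prediction, target, mask)
    print(scale, shift)  # ~2.5, ~0.3
    scale_only, _ = compute_scale_and_shift(prediction, target, mask, scale_only=True)
    print(scale_only)    # least-squares scale with the shift forced to zero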
================================================
FILE: models/Video-Depth-Anything/video_depth_anything/dinov2.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
# References:
# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
from functools import partial
import math
import logging
from typing import Sequence, Tuple, Union, Callable
import torch
import torch.nn as nn
import torch.utils.checkpoint
from torch.nn.init import trunc_normal_
from .dinov2_layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block
logger = logging.getLogger("dinov2")
def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
if not depth_first and include_root:
fn(module=module, name=name)
for child_name, child_module in module.named_children():
child_name = ".".join((name, child_name)) if name else child_name
named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
if depth_first and include_root:
fn(module=module, name=name)
return module
class BlockChunk(nn.ModuleList):
def forward(self, x):
for b in self:
x = b(x)
return x
class DinoVisionTransformer(nn.Module):
def __init__(
self,
img_size=224,
patch_size=16,
in_chans=3,
embed_dim=768,
depth=12,
num_heads=12,
mlp_ratio=4.0,
qkv_bias=True,
ffn_bias=True,
proj_bias=True,
drop_path_rate=0.0,
drop_path_uniform=False,
init_values=None, # for layerscale: None or 0 => no layerscale
embed_layer=PatchEmbed,
act_layer=nn.GELU,
block_fn=Block,
ffn_layer="mlp",
block_chunks=1,
num_register_tokens=0,
interpolate_antialias=False,
interpolate_offset=0.1,
):
"""
Args:
img_size (int, tuple): input image size
patch_size (int, tuple): patch size
in_chans (int): number of input channels
embed_dim (int): embedding dimension
depth (int): depth of transformer
num_heads (int): number of attention heads
mlp_ratio (int): ratio of mlp hidden dim to embedding dim
qkv_bias (bool): enable bias for qkv if True
proj_bias (bool): enable bias for proj in attn if True
ffn_bias (bool): enable bias for ffn if True
drop_path_rate (float): stochastic depth rate
drop_path_uniform (bool): apply uniform drop rate across blocks
weight_init (str): weight init scheme
init_values (float): layer-scale init values
embed_layer (nn.Module): patch embedding layer
act_layer (nn.Module): MLP activation layer
block_fn (nn.Module): transformer block class
ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
num_register_tokens: (int) number of extra cls tokens (so-called "registers")
            interpolate_antialias: (bool) flag to apply anti-aliasing when interpolating positional embeddings
interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings
"""
super().__init__()
norm_layer = partial(nn.LayerNorm, eps=1e-6)
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
self.num_tokens = 1
self.n_blocks = depth
self.num_heads = num_heads
self.patch_size = patch_size
self.num_register_tokens = num_register_tokens
self.interpolate_antialias = interpolate_antialias
self.interpolate_offset = interpolate_offset
self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
num_patches = self.patch_embed.num_patches
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
assert num_register_tokens >= 0
self.register_tokens = (
nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None
)
if drop_path_uniform is True:
dpr = [drop_path_rate] * depth
else:
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
if ffn_layer == "mlp":
logger.info("using MLP layer as FFN")
ffn_layer = Mlp
elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
logger.info("using SwiGLU layer as FFN")
ffn_layer = SwiGLUFFNFused
elif ffn_layer == "identity":
logger.info("using Identity layer as FFN")
def f(*args, **kwargs):
return nn.Identity()
ffn_layer = f
else:
raise NotImplementedError
blocks_list = [
block_fn(
dim=embed_dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
proj_bias=proj_bias,
ffn_bias=ffn_bias,
drop_path=dpr[i],
norm_layer=norm_layer,
act_layer=act_layer,
ffn_layer=ffn_layer,
init_values=init_values,
)
for i in range(depth)
]
if block_chunks > 0:
self.chunked_blocks = True
chunked_blocks = []
chunksize = depth // block_chunks
for i in range(0, depth, chunksize):
# this is to keep the block index consistent if we chunk the block list
chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
else:
self.chunked_blocks = False
self.blocks = nn.ModuleList(blocks_list)
self.norm = norm_layer(embed_dim)
self.head = nn.Identity()
self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))
self.init_weights()
def init_weights(self):
trunc_normal_(self.pos_embed, std=0.02)
nn.init.normal_(self.cls_token, std=1e-6)
if self.register_tokens is not None:
nn.init.normal_(self.register_tokens, std=1e-6)
named_apply(init_weights_vit_timm, self)
def interpolate_pos_encoding(self, x, w, h):
previous_dtype = x.dtype
npatch = x.shape[1] - 1
N = self.pos_embed.shape[1] - 1
if npatch == N and w == h:
return self.pos_embed
pos_embed = self.pos_embed.float()
class_pos_embed = pos_embed[:, 0]
patch_pos_embed = pos_embed[:, 1:]
dim = x.shape[-1]
w0 = w // self.patch_size
h0 = h // self.patch_size
# we add a small number to avoid floating point error in the interpolation
# see discussion at https://github.com/facebookresearch/dino/issues/8
# DINOv2 with register modify the interpolate_offset from 0.1 to 0.0
w0, h0 = w0 + self.interpolate_offset, h0 + self.interpolate_offset
# w0, h0 = w0 + 0.1, h0 + 0.1
sqrt_N = math.sqrt(N)
sx, sy = float(w0) / sqrt_N, float(h0) / sqrt_N
patch_pos_embed = nn.functional.interpolate(
patch_pos_embed.reshape(1, int(sqrt_N), int(sqrt_N), dim).permute(0, 3, 1, 2),
scale_factor=(sx, sy),
# (int(w0), int(h0)), # to solve the upsampling shape issue
mode="bicubic",
antialias=self.interpolate_antialias
)
assert int(w0) == patch_pos_embed.shape[-2]
assert int(h0) == patch_pos_embed.shape[-1]
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)
def prepare_tokens_with_masks(self, x, masks=None):
B, nc, w, h = x.shape
x = self.patch_embed(x)
if masks is not None:
x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)
x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
x = x + self.interpolate_pos_encoding(x, w, h)
if self.register_tokens is not None:
x = torch.cat(
(
x[:, :1],
self.register_tokens.expand(x.shape[0], -1, -1),
x[:, 1:],
),
dim=1,
)
return x
def forward_features_list(self, x_list, masks_list):
x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
for blk in self.blocks:
x = blk(x)
all_x = x
output = []
for x, masks in zip(all_x, masks_list):
x_norm = self.norm(x)
output.append(
{
"x_norm_clstoken": x_norm[:, 0],
"x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
"x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
"x_prenorm": x,
"masks": masks,
}
)
return output
def forward_features(self, x, masks=None):
if isinstance(x, list):
return self.forward_features_list(x, masks)
x = self.prepare_tokens_with_masks(x, masks)
for blk in self.blocks:
x = blk(x)
x_norm = self.norm(x)
return {
"x_norm_clstoken": x_norm[:, 0],
"x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
"x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
"x_prenorm": x,
"masks": masks,
}
def _get_intermediate_layers_not_chunked(self, x, n=1):
x = self.prepare_tokens_with_masks(x)
# If n is an int, take the n last blocks. If it's a list, take them
output, total_block_len = [], len(self.blocks)
blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
for i, blk in enumerate(self.blocks):
x = blk(x)
if i in blocks_to_take:
output.append(x)
assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
return output
def _get_intermediate_layers_chunked(self, x, n=1):
x = self.prepare_tokens_with_masks(x)
output, i, total_block_len = [], 0, len(self.blocks[-1])
# If n is an int, take the n last blocks. If it's a list, take them
blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
for block_chunk in self.blocks:
for blk in block_chunk[i:]: # Passing the nn.Identity()
x = blk(x)
if i in blocks_to_take:
output.append(x)
i += 1
assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
return output
def get_intermediate_layers(
self,
x: torch.Tensor,
n: Union[int, Sequence] = 1, # Layers or n last layers to take
reshape: bool = False,
return_class_token: bool = False,
norm=True
) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
if self.chunked_blocks:
outputs = self._get_intermediate_layers_chunked(x, n)
else:
outputs = self._get_intermediate_layers_not_chunked(x, n)
if norm:
outputs = [self.norm(out) for out in outputs]
class_tokens = [out[:, 0] for out in outputs]
outputs = [out[:, 1 + self.num_register_tokens:] for out in outputs]
if reshape:
B, _, w, h = x.shape
outputs = [
out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
for out in outputs
]
if return_class_token:
return tuple(zip(outputs, class_tokens))
return tuple(outputs)
def forward(self, *args, is_training=False, **kwargs):
ret = self.forward_features(*args, **kwargs)
if is_training:
return ret
else:
return self.head(ret["x_norm_clstoken"])
def init_weights_vit_timm(module: nn.Module, name: str = ""):
"""ViT weight initialization, original timm impl (for reproducibility)"""
if isinstance(module, nn.Linear):
trunc_normal_(module.weight, std=0.02)
if module.bias is not None:
nn.init.zeros_(module.bias)
def vit_small(patch_size=16, num_register_tokens=0, **kwargs):
model = DinoVisionTransformer(
patch_size=patch_size,
embed_dim=384,
depth=12,
num_heads=6,
mlp_ratio=4,
block_fn=partial(Block, attn_class=MemEffAttention),
num_register_tokens=num_register_tokens,
**kwargs,
)
return model
def vit_base(patch_size=16, num_register_tokens=0, **kwargs):
model = DinoVisionTransformer(
patch_size=patch_size,
embed_dim=768,
depth=12,
num_heads=12,
mlp_ratio=4,
block_fn=partial(Block, attn_class=MemEffAttention),
num_register_tokens=num_register_tokens,
**kwargs,
)
return model
def vit_large(patch_size=16, num_register_tokens=0, **kwargs):
model = DinoVisionTransformer(
patch_size=patch_size,
embed_dim=1024,
depth=24,
num_heads=16,
mlp_ratio=4,
block_fn=partial(Block, attn_class=MemEffAttention),
num_register_tokens=num_register_tokens,
**kwargs,
)
return model
def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs):
"""
Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
"""
model = DinoVisionTransformer(
patch_size=patch_size,
embed_dim=1536,
depth=40,
num_heads=24,
mlp_ratio=4,
block_fn=partial(Block, attn_class=MemEffAttention),
num_register_tokens=num_register_tokens,
**kwargs,
)
return model
def DINOv2(model_name):
model_zoo = {
"vits": vit_small,
"vitb": vit_base,
"vitl": vit_large,
"vitg": vit_giant2
}
return model_zoo[model_name](
img_size=518,
patch_size=14,
init_values=1.0,
ffn_layer="mlp" if model_name != "vitg" else "swiglufused",
block_chunks=0,
num_register_tokens=0,
interpolate_antialias=False,
interpolate_offset=0.1
)
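# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the original file: building the ViT-S
# backbone with the `DINOv2` factory above and extracting reshaped feature
# maps from its last four blocks. Weights are randomly initialised here; the
# real pipeline loads a pretrained checkpoint. If xFormers is installed,
# MemEffAttention may require a CUDA device instead of CPU.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    backbone = DINOv2("vits").eval()
    dummy = torch.randn(1, 3, 518, 518)  # spatial size is a multiple of patch_size=14
    with torch.no_grad():
        feats = backbone.get_intermediate_layers(dummy, n=4, reshape=True)
    for f in feats:
        print(tuple(f.shape))  # (1, 384, 37, 37) for each of the four levels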
================================================
FILE: models/Video-Depth-Anything/video_depth_anything/dinov2_layers/__init__.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from .mlp import Mlp
from .patch_embed import PatchEmbed
from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
from .block import NestedTensorBlock
from .attention import MemEffAttention
================================================
FILE: models/Video-Depth-Anything/video_depth_anything/dinov2_layers/attention.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# References:
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
import logging
from torch import Tensor
from torch import nn
logger = logging.getLogger("dinov2")
try:
from xformers.ops import memory_efficient_attention, unbind, fmha
XFORMERS_AVAILABLE = True
except ImportError:
logger.warning("xFormers not available")
XFORMERS_AVAILABLE = False
class Attention(nn.Module):
def __init__(
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = False,
proj_bias: bool = True,
attn_drop: float = 0.0,
proj_drop: float = 0.0,
) -> None:
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = head_dim**-0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim, bias=proj_bias)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x: Tensor) -> Tensor:
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
attn = q @ k.transpose(-2, -1)
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class MemEffAttention(Attention):
def forward(self, x: Tensor, attn_bias=None) -> Tensor:
if not XFORMERS_AVAILABLE:
assert attn_bias is None, "xFormers is required for nested tensors usage"
return super().forward(x)
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
q, k, v = unbind(qkv, 2)
x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
x = x.reshape([B, N, C])
x = self.proj(x)
x = self.proj_drop(x)
return x
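# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the original file: the plain Attention
# layer above applied to a dummy token sequence. MemEffAttention falls back
# to exactly this computation when xFormers is unavailable.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torch
    attn = Attention(dim=64, num_heads=8)
    tokens = torch.randn(2, 16, 64)   # (batch, tokens, dim)
    print(attn(tokens).shape)         # torch.Size([2, 16, 64])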
================================================
FILE: models/Video-Depth-Anything/video_depth_anything/dinov2_layers/block.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# References:
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
import logging
from typing import Callable, List, Any, Tuple, Dict
import torch
from torch import nn, Tensor
from .attention import Attention, MemEffAttention
from .drop_path import DropPath
from .layer_scale import LayerScale
from .mlp import Mlp
logger = logging.getLogger("dinov2")
try:
from xformers.ops import fmha
from xformers.ops import scaled_index_add, index_select_cat
XFORMERS_AVAILABLE = True
except ImportError:
logger.warning("xFormers not available")
XFORMERS_AVAILABLE = False
class Block(nn.Module):
def __init__(
self,
dim: int,
num_heads: int,
mlp_ratio: float = 4.0,
qkv_bias: bool = False,
proj_bias: bool = True,
ffn_bias: bool = True,
drop: float = 0.0,
attn_drop: float = 0.0,
init_values=None,
drop_path: float = 0.0,
act_layer: Callable[..., nn.Module] = nn.GELU,
norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
attn_class: Callable[..., nn.Module] = Attention,
ffn_layer: Callable[..., nn.Module] = Mlp,
) -> None:
super().__init__()
# print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
self.norm1 = norm_layer(dim)
self.attn = attn_class(
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
proj_bias=proj_bias,
attn_drop=attn_drop,
proj_drop=drop,
)
self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = ffn_layer(
in_features=dim,
hidden_features=mlp_hidden_dim,
act_layer=act_layer,
drop=drop,
bias=ffn_bias,
)
self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.sample_drop_ratio = drop_path
def forward(self, x: Tensor) -> Tensor:
def attn_residual_func(x: Tensor) -> Tensor:
return self.ls1(self.attn(self.norm1(x)))
def ffn_residual_func(x: Tensor) -> Tensor:
return self.ls2(self.mlp(self.norm2(x)))
if self.training and self.sample_drop_ratio > 0.1:
# the overhead is compensated only for a drop path rate larger than 0.1
x = drop_add_residual_stochastic_depth(
x,
residual_func=attn_residual_func,
sample_drop_ratio=self.sample_drop_ratio,
)
x = drop_add_residual_stochastic_depth(
x,
residual_func=ffn_residual_func,
sample_drop_ratio=self.sample_drop_ratio,
)
elif self.training and self.sample_drop_ratio > 0.0:
x = x + self.drop_path1(attn_residual_func(x))
x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2
else:
x = x + attn_residual_func(x)
x = x + ffn_residual_func(x)
return x
def drop_add_residual_stochastic_depth(
x: Tensor,
residual_func: Callable[[Tensor], Tensor],
sample_drop_ratio: float = 0.0,
) -> Tensor:
# 1) extract subset using permutation
b, n, d = x.shape
sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
x_subset = x[brange]
# 2) apply residual_func to get residual
residual = residual_func(x_subset)
x_flat = x.flatten(1)
residual = residual.flatten(1)
residual_scale_factor = b / sample_subset_size
# 3) add the residual
x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
return x_plus_residual.view_as(x)
def get_branges_scales(x, sample_drop_ratio=0.0):
b, n, d = x.shape
sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
residual_scale_factor = b / sample_subset_size
return brange, residual_scale_factor
def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
if scaling_vector is None:
x_flat = x.flatten(1)
residual = residual.flatten(1)
x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
else:
x_plus_residual = scaled_index_add(
x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
)
return x_plus_residual
attn_bias_cache: Dict[Tuple, Any] = {}
def get_attn_bias_and_cat(x_list, branges=None):
"""
this will perform the index select, cat the tensors, and provide the attn_bias from cache
"""
batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
if all_shapes not in attn_bias_cache.keys():
seqlens = []
for b, x in zip(batch_sizes, x_list):
for _ in range(b):
seqlens.append(x.shape[1])
attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
attn_bias._batch_sizes = batch_sizes
attn_bias_cache[all_shapes] = attn_bias
if branges is not None:
cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
else:
tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
cat_tensors = torch.cat(tensors_bs1, dim=1)
return attn_bias_cache[all_shapes], cat_tensors
def drop_add_residual_stochastic_depth_list(
x_list: List[Tensor],
residual_func: Callable[[Tensor, Any], Tensor],
sample_drop_ratio: float = 0.0,
scaling_vector=None,
) -> Tensor:
# 1) generate random set of indices for dropping samples in the batch
branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
branges = [s[0] for s in branges_scales]
residual_scale_factors = [s[1] for s in branges_scales]
# 2) get attention bias and index+concat the tensors
attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)
# 3) apply residual_func to get residual, and split the result
residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore
outputs = []
for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
return outputs
class NestedTensorBlock(Block):
def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
"""
x_list contains a list of tensors to nest together and run
"""
assert isinstance(self.attn, MemEffAttention)
if self.training and self.sample_drop_ratio > 0.0:
def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
return self.attn(self.norm1(x), attn_bias=attn_bias)
def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
return self.mlp(self.norm2(x))
x_list = drop_add_residual_stochastic_depth_list(
x_list,
residual_func=attn_residual_func,
sample_drop_ratio=self.sample_drop_ratio,
scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
)
x_list = drop_add_residual_stochastic_depth_list(
x_list,
residual_func=ffn_residual_func,
sample_drop_ratio=self.sample_drop_ratio,
                scaling_vector=self.ls2.gamma if isinstance(self.ls2, LayerScale) else None,
)
return x_list
else:
def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))
def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
return self.ls2(self.mlp(self.norm2(x)))
attn_bias, x = get_attn_bias_and_cat(x_list)
x = x + attn_residual_func(x, attn_bias=attn_bias)
x = x + ffn_residual_func(x)
return attn_bias.split(x)
def forward(self, x_or_x_list):
if isinstance(x_or_x_list, Tensor):
return super().forward(x_or_x_list)
elif isinstance(x_or_x_list, list):
assert XFORMERS_AVAILABLE, "Please install xFormers for nested tensors usage"
return self.forward_nested(x_or_x_list)
else:
raise AssertionError
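if __name__ == "__main__":
    # Minimal sanity sketch for drop_add_residual_stochastic_depth (illustrative
    # only; the batch size, token count, dim and the zero residual_func below are
    # assumptions for the demo, not values used elsewhere in this repo).
    import torch

    x = torch.randn(8, 16, 32)  # (batch, tokens, dim)
    out = drop_add_residual_stochastic_depth(
        x,
        residual_func=lambda t: torch.zeros_like(t),  # zero residual: output should equal input
        sample_drop_ratio=0.5,  # only roughly half of the batch gets its residual computed
    )
    # With a zero residual every sample is unchanged, whether or not it was kept,
    # so this exercises only the subset selection, index_add and rescaling logic.
    assert out.shape == x.shape and torch.allclose(out, x)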
================================================
FILE: models/Video-Depth-Anything/video_depth_anything/dinov2_layers/drop_path.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# References:
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py
from torch import nn
def drop_path(x, drop_prob: float = 0.0, training: bool = False):
if drop_prob == 0.0 or not training:
return x
keep_prob = 1 - drop_prob
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
if keep_prob > 0.0:
random_tensor.div_(keep_prob)
output = x * random_tensor
return output
class DropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
def __init__(self, drop_prob=None):
super(DropPath, self).__init__()
self.drop_prob = drop_prob
def forward(self, x):
return drop_path(x, self.drop_prob, self.training)
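if __name__ == "__main__":
    # Minimal sketch of per-sample stochastic depth (illustrative only; the
    # tensor shape and drop probability below are arbitrary assumptions).
    import torch

    x = torch.ones(4, 3, 8)
    # Outside training the input passes through untouched.
    assert torch.equal(drop_path(x, drop_prob=0.5, training=False), x)
    # During training each sample is either zeroed or rescaled by 1/keep_prob,
    # keeping the expected value of the output equal to the input.
    y = drop_path(x, drop_prob=0.5, training=True)
    assert set(y.unique().tolist()) <= {0.0, 2.0}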
================================================
FILE: models/Video-Depth-Anything/video_depth_anything/dinov2_layers/layer_scale.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110
from typing import Union
import torch
from torch import Tensor
from torch import nn
class LayerScale(nn.Module):
def __init__(
self,
dim: int,
init_values: Union[float, Tensor] = 1e-5,
inplace: bool = False,
) -> None:
super().__init__()
self.inplace = inplace
self.gamma = nn.Parameter(init_values * torch.ones(dim))
def forward(self, x: Tensor) -> Tensor:
return x.mul_(self.gamma) if self.inplace else x * self.gamma
================================================
FILE: models/Video-Depth-Anything/video_depth_anything/dinov2_layers/mlp.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# References:
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py
from typing import Callable, Optional
from torch import Tensor, nn
class Mlp(nn.Module):
def __init__(
self,
in_features: int,
hidden_features: Optional[int] = None,
out_features: Optional[int] = None,
act_layer: Callable[..., nn.Module] = nn.GELU,
drop: float = 0.0,
bias: bool = True,
) -> None:
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
self.drop = nn.Dropout(drop)
def forward(self, x: Tensor) -> Tensor:
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
================================================
FILE: models/Video-Depth-Anything/video_depth_anything/dinov2_layers/patch_embed.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# References:
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
from typing import Callable, Optional, Tuple, Union
from torch import Tensor
import torch.nn as nn
def make_2tuple(x):
if isinstance(x, tuple):
assert len(x) == 2
return x
assert isinstance(x, int)
return (x, x)
class PatchEmbed(nn.Module):
"""
2D image to patch embedding: (B,C,H,W) -> (B,N,D)
Args:
img_size: Image size.
patch_size: Patch token size.
in_chans: Number of input image channels.
embed_dim: Number of linear projection output channels.
norm_layer: Normalization layer.
"""
def __init__(
self,
img_size: Union[int, Tuple[int, int]] = 224,
patch_size: Union[int, Tuple[int, int]] = 16,
in_chans: int = 3,
embed_dim: int = 768,
norm_layer: Optional[Callable] = None,
flatten_embedding: bool = True,
) -> None:
super().__init__()
image_HW = make_2tuple(img_size)
patch_HW = make_2tuple(patch_size)
patch_grid_size = (
image_HW[0] // patch_HW[0],
image_HW[1] // patch_HW[1],
)
self.img_size = image_HW
self.patch_size = patch_HW
self.patches_resolution = patch_grid_size
self.num_patches = patch_grid_size[0] * patch_grid_size[1]
self.in_chans = in_chans
self.embed_dim = embed_dim
self.flatten_embedding = flatten_embedding
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
def forward(self, x: Tensor) -> Tensor:
_, _, H, W = x.shape
patch_H, patch_W = self.patch_size
        assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
        assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width {patch_W}"
x = self.proj(x) # B C H W
H, W = x.size(2), x.size(3)
x = x.flatten(2).transpose(1, 2) # B HW C
x = self.norm(x)
if not self.flatten_embedding:
x = x.reshape(-1, H, W, self.embed_dim) # B H W C
return x
def flops(self) -> float:
Ho, Wo = self.patches_resolution
flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
if self.norm is not None:
flops += Ho * Wo * self.embed_dim
return flops
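if __name__ == "__main__":
    # Minimal usage sketch: (B, C, H, W) image -> (B, N, D) patch tokens
    # (illustrative only; the image size, patch size and embed_dim below are
    # assumptions for the demo, not the values used by the DINOv2 backbone).
    import torch

    embed = PatchEmbed(img_size=224, patch_size=14, in_chans=3, embed_dim=64)
    img = torch.randn(2, 3, 224, 224)
    tokens = embed(img)
    # 224 / 14 = 16 patches per side -> 256 tokens of dimension 64.
    assert tokens.shape == (2, 256, 64)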
================================================
FILE: models/Video-Depth-Anything/video_depth_anything/dinov2_layers/swiglu_ffn.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from typing import Callable, Optional
from torch import Tensor, nn
import torch.nn.functional as F
class SwiGLUFFN(nn.Module):
def __init__(
self,
in_features: int,
hidden_features: Optional[int] = None,
out_features: Optional[int] = None,
act_layer: Callable[..., nn.Module] = None,
drop: float = 0.0,
bias: bool = True,
) -> None:
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
def forward(self, x: Tensor) -> Tensor:
x12 = self.w12(x)
x1, x2 = x12.chunk(2, dim=-1)
hidden = F.silu(x1) * x2
return self.w3(hidden)
try:
from xformers.ops import SwiGLU
XFORMERS_AVAILABLE = True
except ImportError:
SwiGLU = SwiGLUFFN
XFORMERS_AVAILABLE = False
class SwiGLUFFNFused(SwiGLU):
def __init__(
self,
in_features: int,
hidden_features: Optional[int] = None,
out_features: Optional[int] = None,
act_layer: Callable[..., nn.Module] = None,
drop: float = 0.0,
bias: bool = True,
) -> None:
out_features = out_features or in_features
hidden_features = hidden_features or in_features
hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
super().__init__(
in_features=in_features,
hidden_features=hidden_features,
out_features=out_features,
bias=bias,
)
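if __name__ == "__main__":
    # Minimal sketch of the SwiGLU gating (illustrative only; the feature sizes
    # below are arbitrary assumptions): w12 produces two halves, the first is
    # passed through SiLU and gates the second before the w3 projection.
    import torch

    ffn = SwiGLUFFN(in_features=32, hidden_features=64)
    x = torch.randn(2, 5, 32)
    assert ffn(x).shape == (2, 5, 32)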
================================================
FILE: models/Video-Depth-Anything/video_depth_anything/dpt.py
================================================
# Copyright (2025) Bytedance Ltd. and/or its affiliates
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
import torch.nn.functional as F
from .util.blocks import FeatureFusionBlock, _make_scratch
def _make_fusion_block(features, use_bn, size=None):
return FeatureFusionBlock(
features,
nn.ReLU(False),
deconv=False,
bn=use_bn,
expand=False,
align_corners=True,
size=size,
)
class ConvBlock(nn.Module):
def __init__(self, in_feature, out_feature):
super().__init__()
self.conv_block = nn.Sequential(
nn.Conv2d(in_feature, out_feature, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(out_feature),
nn.ReLU(True)
)
def forward(self, x):
return self.conv_block(x)
class DPTHead(nn.Module):
def __init__(
self,
in_channels,
features=256,
use_bn=False,
out_channels=[256, 512, 1024, 1024],
use_clstoken=False
):
super(DPTHead, self).__init__()
self.use_clstoken = use_clstoken
self.projects = nn.ModuleList([
nn.Conv2d(
in_channels=in_channels,
out_channels=out_channel,
kernel_size=1,
stride=1,
padding=0,
) for out_channel in out_channels
])
self.resize_layers = nn.ModuleList([
nn.ConvTranspose2d(
in_channels=out_channels[0],
out_channels=out_channels[0],
kernel_size=4,
stride=4,
padding=0),
nn.ConvTranspose2d(
in_channels=out_channels[1],
out_channels=out_channels[1],
kernel_size=2,
stride=2,
padding=0),
nn.Identity(),
nn.Conv2d(
in_channels=out_channels[3],
out_channels=out_channels[3],
kernel_size=3,
stride=2,
padding=1)
])
if use_clstoken:
self.readout_projects = nn.ModuleList()
for _ in range(len(self.projects)):
self.readout_projects.append(
nn.Sequential(
nn.Linear(2 * in_channels, in_channels),
nn.GELU()))
self.scratch = _make_scratch(
out_channels,
features,
groups=1,
expand=False,
)
self.scratch.stem_transpose = None
self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
head_features_1 = features
head_features_2 = 32
self.scratch.output_conv1 = nn.Conv2d(head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1)
self.scratch.output_conv2 = nn.Sequential(
nn.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1),
nn.ReLU(True),
nn.Conv2d(head_features_2, 1, kernel_size=1, stride=1, padding=0),
nn.ReLU(True),
nn.Identity(),
)
def forward(self, out_features, patch_h, patch_w):
out = []
for i, x in enumerate(out_features):
if self.use_clstoken:
x, cls_token = x[0], x[1]
readout = cls_token.unsqueeze(1).expand_as(x)
x = self.readout_projects[i](torch.cat((x, readout), -1))
else:
x = x[0]
x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w))
x = self.projects[i](x)
x = self.resize_layers[i](x)
out.append(x)
layer_1, layer_2, layer_3, layer_4 = out
layer_1_rn = self.scratch.layer1_rn(layer_1)
layer_2_rn = self.scratch.layer2_rn(layer_2)
layer_3_rn = self.scratch.layer3_rn(layer_3)
layer_4_rn = self.scratch.layer4_rn(layer_4)
path_4 = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
path_3 = self.scratch.refinenet3(path_4, layer_3_rn, size=layer_2_rn.shape[2:])
path_2 = self.scratch.refinenet2(path_3, layer_2_rn, size=layer_1_rn.shape[2:])
path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
out = self.scratch.output_conv1(path_1)
out = F.interpolate(out, (int(patch_h * 14), int(patch_w * 14)), mode="bilinear", align_corners=True)
out = self.scratch.output_conv2(out)
return out
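if __name__ == "__main__":
    # Minimal shape sketch for DPTHead (illustrative only; the channel sizes,
    # patch grid and random stand-in backbone features below are assumptions
    # chosen to keep the demo small, not the configuration used in this repo).
    in_channels, patch_h, patch_w, batch = 64, 8, 8, 1
    head = DPTHead(in_channels, features=32, out_channels=[16, 32, 64, 64])
    # Each entry mimics one intermediate backbone layer: (patch tokens, cls token).
    feats = [
        (torch.randn(batch, patch_h * patch_w, in_channels),
         torch.randn(batch, in_channels))
        for _ in range(4)
    ]
    depth = head(feats, patch_h, patch_w)
    # The head fuses the four levels and upsamples to 14x the patch grid,
    # predicting a single-channel map.
    assert depth.shape == (batch, 1, patch_h * 14, patch_w * 14)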
================================================
FILE: models/Video-Depth-Anything/video_depth_anything/dpt_temporal.py
================================================
# Copyright (2025) Bytedance Ltd. and/or its affiliates
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn.functional as F
import torch.nn as nn
from .dpt import DPTHead
from .motion_module.motion_module import TemporalModule
from easydict import EasyDict
class DPTHeadTemporal(DPTHead):
def __init__(self,
in_channels,
features=256,
use_bn=False,
out_channels=[256, 512, 1024, 1024],
use_clstoken=False,
num_frames=32,
pe='ape'
):
super().__init__(in_channels, features, use_bn, out_channels, use_clstoken)
assert num_frames > 0
motion_module_kwargs = EasyDict(num_attention_heads = 8,
num_transformer_block = 1,
num_attention_blocks = 2,
temporal_max_len = num_frames,
zero_initialize = True,
pos_embedding_type = pe)
self.motion_modules = nn.ModuleList([
TemporalModule(in_channels=out_channels[2],
**motion_module_kwargs),
TemporalModule(in_channels=out_channels[3],
**motion_module_kwargs),
TemporalModule(in_channels=features,
**motion_module_kwargs),
TemporalModule(in_channels=features,
**motion_module_kwargs)
])
def forward(self, out_features, patch_h, patch_w, frame_length):
out = []
for i, x in enumerate(out_features):
if self.use_clstoken:
x, cls_token = x[0], x[1]
readout = cls_token.unsqueeze(1).expand_as(x)
x = self.readout_projects[i](torch.cat((x, readout), -1))
else:
x = x[0]
x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w)).contiguous()
B, T = x.shape[0] // frame_length, frame_length
x = self.projects[i](x)
x = self.resize_layers[i](x)
out.append(x)
layer_1, layer_2, layer_3, layer_4 = out
B, T = layer_1.shape[0] // frame_length, frame_length
layer_3 = self.motion_modules[0](layer_3.unflatten(0, (B, T)).permute(0, 2, 1, 3, 4), None, None).permute(0, 2, 1, 3, 4).flatten(0, 1)
layer_4 = self.motion_modules[1](layer_4.unflatten(0, (B, T)).permute(0, 2, 1, 3, 4), None, None).permute(0, 2, 1, 3, 4).flatten(0, 1)
layer_1_rn = self.scratch.layer1_rn(layer_1)
layer_2_rn = self.scratch.layer2_rn(layer_2)
layer_3_rn = self.scratch.layer3_rn(layer_3)
layer_4_rn = self.scratch.layer4_rn(layer_4)
path_4 = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
path_4 = self.motion_modules[2](path_4.unflatten(0, (B, T)).permute(0, 2, 1, 3, 4), None, None).permute(0, 2, 1, 3, 4).flatten(0, 1)
path_3 = self.scratch.refinenet3(path_4, layer_3_rn, size=layer_2_rn.shape[2:])
path_3 = self.motion_modules[3](path_3.unflatten(0, (B, T)).permute(0, 2, 1, 3, 4), None, None).permute(0, 2, 1, 3, 4).flatten(0, 1)
path_2 = self.scratch.refinenet2(path_3, layer_2_rn, size=layer_1_rn.shape[2:])
path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
out = self.scratch.output_conv1(path_1)
out = F.interpolate(
out, (int(patch_h * 14), int(patch_w * 14)), mode="bilinear", align_corners=True
)
# out = self.scratch.output_conv2(out)
return out
================================================
FILE: models/Video-Depth-Anything/video_depth_anything/motion_module/attention.py
================================================
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple
import torch
import torch.nn.functional as F
from torch import nn
try:
import xformers
import xformers.ops
XFORMERS_AVAILABLE = True
except ImportError:
print("xFormers not available")
XFORMERS_AVAILABLE = False
class CrossAttention(nn.Module):
r"""
A cross attention layer.
Parameters:
query_dim (`int`): The number of channels in the query.
cross_attention_dim (`int`, *optional*):
The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
dim_head (`int`, *optional*, defaults to 64): The number of channels in each head.
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
bias (`bool`, *optional*, defaults to False):
Set to `True` for the query, key, and value linear layers to contain a bias parameter.
"""
def __init__(
self,
query_dim: int,
cross_attention_dim: Optional[int] = None,
heads: int = 8,
dim_head: int = 64,
dropout: float = 0.0,
bias=False,
upcast_attention: bool = False,
upcast_softmax: bool = False,
added_kv_proj_dim: Optional[int] = None,
norm_num_groups: Optional[int] = None,
):
super().__init__()
inner_dim = dim_head * heads
cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
self.upcast_attention = upcast_attention
self.upcast_softmax = upcast_softmax
self.upcast_efficient_attention = False
self.scale = dim_head**-0.5
self.heads = heads
# for slice_size > 0 the attention score computation
# is split across the batch axis to save memory
# You can set slice_size with `set_attention_slice`
self.sliceable_head_dim = heads
self._slice_size = None
self._use_memory_efficient_attention_xformers = False
self.added_kv_proj_dim = added_kv_proj_dim
if norm_num_groups is not None:
self.group_norm = nn.GroupNorm(num_channels=inner_dim, num_groups=norm_num_groups, eps=1e-5, affine=True)
else:
self.group_norm = None
self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
if self.added_kv_proj_dim is not None:
self.add_k_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
self.add_v_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
self.to_out = nn.ModuleList([])
self.to_out.append(nn.Linear(inner_dim, query_dim))
self.to_out.append(nn.Dropout(dropout))
def reshape_heads_to_batch_dim(self, tensor):
batch_size, seq_len, dim = tensor.shape
head_size = self.heads
tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size).contiguous()
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size).contiguous()
return tensor
def reshape_heads_to_4d(self, tensor):
batch_size, seq_len, dim = tensor.shape
head_size = self.heads
tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size).contiguous()
return tensor
def reshape_batch_dim_to_heads(self, tensor):
batch_size, seq_len, dim = tensor.shape
head_size = self.heads
tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim).contiguous()
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size).contiguous()
return tensor
def reshape_4d_to_heads(self, tensor):
batch_size, seq_len, head_size, dim = tensor.shape
head_size = self.heads
tensor = tensor.reshape(batch_size, seq_len, dim * head_size).contiguous()
return tensor
def set_attention_slice(self, slice_size):
if slice_size is not None and slice_size > self.sliceable_head_dim:
raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")
self._slice_size = slice_size
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
batch_size, sequence_length, _ = hidden_states.shape
encoder_hidden_states = encoder_hidden_states
if self.group_norm is not None:
hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = self.to_q(hidden_states)
dim = query.shape[-1]
query = self.reshape_heads_to_batch_dim(query)
if self.added_kv_proj_dim is not None:
key = self.to_k(hidden_states)
value = self.to_v(hidden_states)
encoder_hidden_states_key_proj = self.add_k_proj(encoder_hidden_states)
encoder_hidden_states_value_proj = self.add_v_proj(encoder_hidden_states)
key = self.reshape_heads_to_batch_dim(key)
value = self.reshape_heads_to_batch_dim(value)
encoder_hidden_states_key_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_key_proj)
encoder_hidden_states_value_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_value_proj)
key = torch.concat([encoder_hidden_states_key_proj, key], dim=1)
value = torch.concat([encoder_hidden_states_value_proj, value], dim=1)
else:
encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
key = self.to_k(encoder_hidden_states)
value = self.to_v(encoder_hidden_states)
key = self.reshape_heads_to_batch_dim(key)
value = self.reshape_heads_to_batch_dim(value)
if attention_mask is not None:
if attention_mask.shape[-1] != query.shape[1]:
target_length = query.shape[1]
attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
# attention, what we cannot get enough of
if XFORMERS_AVAILABLE and self._use_memory_efficient_attention_xformers:
hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
# Some versions of xformers return output in fp32, cast it back to the dtype of the input
hidden_states = hidden_states.to(query.dtype)
else:
if self._slice_size is None or query.shape[0] // self._slice_size == 1:
hidden_states = self._attention(query, key, value, attention_mask)
else:
hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
# linear proj
hidden_states = self.to_out[0](hidden_states)
# dropout
hidden_states = self.to_out[1](hidden_states)
return hidden_states
def _attention(self, query, key, value, attention_mask=None):
if self.upcast_attention:
query = query.float()
key = key.float()
attention_scores = torch.baddbmm(
torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
query,
key.transpose(-1, -2),
beta=0,
alpha=self.scale,
)
if attention_mask is not None:
attention_scores = attention_scores + attention_mask
if self.upcast_softmax:
attention_scores = attention_scores.float()
attention_probs = attention_scores.softmax(dim=-1)
# cast back to the original dtype
attention_probs = attention_probs.to(value.dtype)
# compute attention output
hidden_states = torch.bmm(attention_probs, value)
# reshape hidden_states
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
return hidden_states
def _sliced_attention(self, query, key, value, sequence_length, dim, attention_mask):
batch_size_attention = query.shape[0]
hidden_states = torch.zeros(
(batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype
)
slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0]
for i in range(hidden_states.shape[0] // slice_size):
start_idx = i * slice_size
end_idx = (i + 1) * slice_size
query_slice = query[start_idx:end_idx]
key_slice = key[start_idx:end_idx]
if self.upcast_attention:
query_slice = query_slice.float()
key_slice = key_slice.float()
attn_slice = torch.baddbmm(
torch.empty(slice_size, query.shape[1], key.shape[1], dtype=query_slice.dtype, device=query.device),
query_slice,
key_slice.transpose(-1, -2),
beta=0,
alpha=self.scale,
)
if attention_mask is not None:
attn_slice = attn_slice + attention_mask[start_idx:end_idx]
if self.upcast_softmax:
attn_slice = attn_slice.float()
attn_slice = attn_slice.softmax(dim=-1)
# cast back to the original dtype
attn_slice = attn_slice.to(value.dtype)
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
hidden_states[start_idx:end_idx] = attn_slice
# reshape hidden_states
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
return hidden_states
def _memory_efficient_attention_xformers(self, query, key, value, attention_mask):
if self.upcast_efficient_attention:
org_dtype = query.dtype
query = query.float()
key = key.float()
value = value.float()
if attention_mask is not None:
attention_mask = attention_mask.float()
hidden_states = self._memory_efficient_attention_split(query, key, value, attention_mask)
if self.upcast_efficient_attention:
hidden_states = hidden_states.to(org_dtype)
hidden_states = self.reshape_4d_to_heads(hidden_states)
return hidden_states
# print("Errror: no xformers")
# raise NotImplementedError
def _memory_efficient_attention_split(self, query, key, value, attention_mask):
batch_size = query.shape[0]
max_batch_size = 65535
num_batches = (batch_size + max_batch_size - 1) // max_batch_size
results = []
for i in range(num_batches):
start_idx = i * max_batch_size
end_idx = min((i + 1) * max_batch_size, batch_size)
query_batch = query[start_idx:end_idx]
key_batch = key[start_idx:end_idx]
value_batch = value[start_idx:end_idx]
if attention_mask is not None:
attention_mask_batch = attention_mask[start_idx:end_idx]
else:
attention_mask_batch = None
result = xformers.ops.memory_efficient_attention(query_batch, key_batch, value_batch, attn_bias=attention_mask_batch)
results.append(result)
full_result = torch.cat(results, dim=0)
return full_result
class FeedForward(nn.Module):
r"""
A feed-forward layer.
Parameters:
dim (`int`): The number of channels in the input.
dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
"""
def __init__(
self,
dim: int,
dim_out: Optional[int] = None,
mult: int = 4,
dropout: float = 0.0,
activation_fn: str = "geglu",
):
super().__init__()
inner_dim = int(dim * mult)
dim_out = dim_out if dim_out is not None else dim
if activation_fn == "gelu":
act_fn = GELU(dim, inner_dim)
elif activation_fn == "geglu":
act_fn = GEGLU(dim, inner_dim)
elif activation_fn == "geglu-approximate":
act_fn = ApproximateGELU(dim, inner_dim)
self.net = nn.ModuleList([])
# project in
self.net.append(act_fn)
# project dropout
self.net.append(nn.Dropout(dropout))
# project out
self.net.append(nn.Linear(inner_dim, dim_out))
def forward(self, hidden_states):
for module in self.net:
hidden_states = module(hidden_states)
return hidden_states
class GELU(nn.Module):
r"""
GELU activation function
"""
def __init__(self, dim_in: int, dim_out: int):
super().__init__()
self.proj = nn.Linear(dim_in, dim_out)
def gelu(self, gate):
if gate.device.type != "mps":
return F.gelu(gate)
# mps: gelu is not implemented for float16
return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
def forward(self, hidden_states):
hidden_states = self.proj(hidden_states)
hidden_states = self.gelu(hidden_states)
return hidden_states
# feedforward
class GEGLU(nn.Module):
r"""
A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.
Parameters:
dim_in (`int`): The number of channels in the input.
dim_out (`int`): The number of channels in the output.
"""
def __init__(self, dim_in: int, dim_out: int):
super().__init__()
self.proj = nn.Linear(dim_in, dim_out * 2)
def gelu(self, gate):
if gate.device.type != "mps":
return F.gelu(gate)
# mps: gelu is not implemented for float16
return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
def forward(self, hidden_states):
hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
return hidden_states * self.gelu(gate)
class ApproximateGELU(nn.Module):
"""
The approximate form of Gaussian Error Linear Unit (GELU)
For more details, see section 2: https://arxiv.org/abs/1606.08415
"""
def __init__(self, dim_in: int, dim_out: int):
super().__init__()
self.proj = nn.Linear(dim_in, dim_out)
def forward(self, x):
x = self.proj(x)
return x * torch.sigmoid(1.702 * x)
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
t = torch.arange(end, device=freqs.device, dtype=torch.float32)
freqs = torch.outer(t, freqs)
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
return freqs_cis
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
ndim = x.ndim
assert 0 <= 1 < ndim
assert freqs_cis.shape == (x.shape[1], x.shape[-1])
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
return freqs_cis.view(*shape)
def apply_rotary_emb(
xq: torch.Tensor,
xk: torch.Tensor,
freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2).contiguous())
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2).contiguous())
freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(2)
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(2)
return xq_out.type_as(xq), xk_out.type_as(xk)
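if __name__ == "__main__":
    # Minimal sketch of the rotary position-embedding helpers (illustrative
    # only; the head_dim and sequence length below are arbitrary assumptions).
    head_dim, max_len = 64, 16
    freqs_cis = precompute_freqs_cis(head_dim, max_len)  # (max_len, head_dim // 2), complex64
    q = torch.randn(2, max_len, head_dim)
    k = torch.randn(2, max_len, head_dim)
    q_rot, k_rot = apply_rotary_emb(q, k, freqs_cis)
    # The rotation mixes adjacent channel pairs per position but keeps shapes.
    assert q_rot.shape == q.shape and k_rot.shape == k.shape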
================================================
FILE: models/Video-Depth-Anything/video_depth_anything/motion_module/motion_module.py
================================================
# This file is originally from AnimateDiff/animatediff/models/motion_module.py at main · guoyww/AnimateDiff
# SPDX-License-Identifier: Apache-2.0 license
#
# This file may have been modified by ByteDance Ltd. and/or its affiliates on [date of modification]
# Original file was released under [ Apache-2.0 license], with the full license text available at [https://github.com/guoyww/AnimateDiff?tab=Apache-2.0-1-ov-file#readme].
import torch
import torch.nn.functional as F
from torch import nn
from .attention import CrossAttention, FeedForward, apply_rotary_emb, precompute_freqs_cis
from einops import rearrange, repeat
import math
try:
import xformers
import xformers.ops
XFORMERS_AVAILABLE = True
except ImportError:
print("xFormers not available")
XFORMERS_AVAILABLE = False
def zero_module(module):
# Zero out the parameters of a module and return it.
for p in module.parameters():
p.detach().zero_()
return module
class TemporalModule(nn.Module):
def __init__(
self,
in_channels,
num_attention_heads = 8,
num_transformer_block = 2,
num_attention_blocks = 2,
norm_num_groups = 32,
temporal_max_len = 32,
zero_initialize = True,
pos_embedding_type = "ape",
):
super().__init__()
self.temporal_transformer = TemporalTransformer3DModel(
in_channels=in_channels,
num_attention_heads=num_attention_heads,
attention_head_dim=in_channels // num_attention_heads,
num_layers=num_transformer_block,
num_attention_blocks=num_attention_blocks,
norm_num_groups=norm_num_groups,
temporal_max_len=temporal_max_len,
pos_embedding_type=pos_embedding_type,
)
if zero_initialize:
self.temporal_transformer.proj_out = zero_module(self.temporal_transformer.proj_out)
def forward(self, input_tensor, encoder_hidden_states, attention_mask=None):
hidden_states = input_tensor
hidden_states = self.temporal_transformer(hidden_states, encoder_hidden_states, attention_mask)
output = hidden_states
return output
class TemporalTransformer3DModel(nn.Module):
def __init__(
self,
in_channels,
num_attention_heads,
attention_head_dim,
num_layers,
num_attention_blocks = 2,
norm_num_groups = 32,
temporal_max_len = 32,
pos_embedding_type = "ape",
):
super().__init__()
inner_dim = num_attention_heads * attention_head_dim
self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
self.proj_in = nn.Linear(in_channels, inner_dim)
self.transformer_blocks = nn.ModuleList(
[
TemporalTransformerBlock(
dim=inner_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
num_attention_blocks=num_attention_blocks,
temporal_max_len=temporal_max_len,
pos_embedding_type=pos_embedding_type,
)
for d in range(num_layers)
]
)
self.proj_out = nn.Linear(inner_dim, in_channels)
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
video_length = hidden_states.shape[2]
hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
batch, channel, height, width = hidden_states.shape
residual = hidden_states
hidden_states = self.norm(hidden_states)
inner_dim = hidden_states.shape[1]
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim).contiguous()
hidden_states = self.proj_in(hidden_states)
# Transformer Blocks
for block in self.transformer_blocks:
hidden_states = block(hidden_states, encoder_hidden_states=encoder_hidden_states, video_length=video_length, attention_mask=attention_mask)
# output
hidden_states = self.proj_out(hidden_states)
hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
output = hidden_states + residual
output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
return output
class TemporalTransformerBlock(nn.Module):
def __init__(
self,
dim,
num_attention_heads,
attention_head_dim,
num_attention_blocks = 2,
temporal_max_len = 32,
pos_embedding_type = "ape",
):
super().__init__()
self.attention_blocks = nn.ModuleList(
[
TemporalAttention(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
temporal_max_len=temporal_max_len,
pos_embedding_type=pos_embedding_type,
)
for i in range(num_attention_blocks)
]
)
self.norms = nn.ModuleList(
[
nn.LayerNorm(dim)
for i in range(num_attention_blocks)
]
)
self.ff = FeedForward(dim, dropout=0.0, activation_fn="geglu")
self.ff_norm = nn.LayerNorm(dim)
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None):
for attention_block, norm in zip(self.attention_blocks, self.norms):
norm_hidden_states = norm(hidden_states)
hidden_states = attention_block(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states,
video_length=video_length,
attention_mask=attention_mask,
) + hidden_states
hidden_states = self.ff(self.ff_norm(hidden_states)) + hidden_states
output = hidden_states
return output
class PositionalEncoding(nn.Module):
def __init__(
self,
d_model,
dropout = 0.,
max_len = 32
):
super().__init__()
self.dropout = nn.Dropout(p=dropout)
position = torch.arange(max_len).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
pe = torch.zeros(1, max_len, d_model)
pe[0, :, 0::2] = torch.sin(position * div_term)
pe[0, :, 1::2] = torch.cos(position * div_term)
self.register_buffer('pe', pe)
def forward(self, x):
x = x + self.pe[:, :x.size(1)].to(x.dtype)
return self.dropout(x)
class TemporalAttention(CrossAttention):
def __init__(
self,
temporal_max_len = 32,
pos_embedding_type = "ape",
*args, **kwargs
):
super().__init__(*args, **kwargs)
self.pos_embedding_type = pos_embedding_type
self._use_memory_efficient_attention_xformers = True
self.pos_encoder = None
self.freqs_cis = None
if self.pos_embedding_type == "ape":
self.pos_encoder = PositionalEncoding(
kwargs["query_dim"],
dropout=0.,
max_len=temporal_max_len
)
elif self.pos_embedding_type == "rope":
self.freqs_cis = precompute_freqs_cis(
kwargs["query_dim"],
temporal_max_len
)
else:
raise NotImplementedError
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None):
d = hidden_states.shape[1]
hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length)
if self.pos_encoder is not None:
hidden_states = self.pos_encoder(hidden_states)
encoder_hidden_states = repeat(encoder_hidden_states, "b n c -> (b d) n c", d=d) if encoder_hidden_states is not None else encoder_hidden_states
if self.group_norm is not None:
hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = self.to_q(hidden_states)
dim = query.shape[-1]
if self.added_kv_proj_dim is not None:
raise NotImplementedError
encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
key = self.to_k(encoder_hidden_states)
value = self.to_v(encoder_hidden_states)
if self.freqs_cis is not None:
seq_len = query.shape[1]
freqs_cis = self.freqs_cis[:seq_len].to(query.device)
query, key = apply_rotary_emb(query, key, freqs_cis)
if attention_mask is not None:
if attention_mask.shape[-1] != query.shape[1]:
target_length = query.shape[1]
attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
use_memory_efficient = XFORMERS_AVAILABLE and self._use_memory_efficient_attention_xformers
if use_memory_efficient and (dim // self.heads) % 8 != 0:
# print('Warning: the dim {} cannot be divided by 8. Fall into normal attention'.format(dim // self.heads))
use_memory_efficient = False
# attention, what we cannot get enough of
if use_memory_efficient:
query = self.reshape_heads_to_4d(query)
key = self.reshape_heads_to_4d(key)
value = self.reshape_heads_to_4d(value)
hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
# Some versions of xformers return output in fp32, cast it back to the dtype of the input
hidden_states = hidden_states.to(query.dtype)
else:
query = self.reshape_heads_to_batch_dim(query)
key = self.reshape_heads_to_batch_dim(key)
value = self.reshape_heads_to_batch_dim(value)
if self._slice_size is None or query.shape[0] // self._slice_size == 1:
hidden_states = self._attention(query, key, value, attention_mask)
else:
raise NotImplementedError
# hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
# linear proj
hidden_states = self.to_out[0](hidden_states)
# dropout
hidden_states = self.to_out[1](hidden_states)
hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
return hidden_states
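if __name__ == "__main__":
    # Minimal shape sketch for TemporalModule (illustrative only; the channel
    # count, frame count and spatial size below are small assumptions, not the
    # configuration used by DPTHeadTemporal).
    feats = torch.randn(1, 64, 4, 8, 8)  # (B, C, T, H, W)
    module = TemporalModule(in_channels=64, num_attention_heads=8,
                            num_transformer_block=1, temporal_max_len=4)
    out = module(feats, encoder_hidden_states=None)
    # proj_out is zero-initialised, so a freshly built module acts as an identity.
    assert out.shape == feats.shape and torch.allclose(out, feats)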
================================================
FILE: models/Video-Depth-Anything/video_depth_anything/util/blocks.py
================================================
import torch.nn as nn
def _make_scratch(in_shape, out_shape, groups=1, expand=False):
scratch = nn.Module()
out_shape1 = out_shape
out_shape2 = out_shape
out_shape3 = out_shape
if len(in_shape) >= 4:
out_shape4 = out_shape
if expand:
out_shape1 = out_shape
out_shape2 = out_shape * 2
out_shape3 = out_shape * 4
if len(in_shape) >= 4:
out_shape4 = out_shape * 8
scratch.layer1_rn = nn.Conv2d(
in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
)
scratch.layer2_rn = nn.Conv2d(
in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
)
scratch.layer3_rn = nn.Conv2d(
in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
)
if len(in_shape) >= 4:
scratch.layer4_rn = nn.Conv2d(
in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
)
return scratch
class ResidualConvUnit(nn.Module):
"""Residual convolution module."""
def __init__(self, features, activation, bn):
"""Init.
Args:
features (int): number of features
"""
super().__init__()
self.bn = bn
self.groups = 1
self.conv1 = nn.Conv2d(
features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
)
self.conv2 = nn.Conv2d(
features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
)
if self.bn is True:
self.bn1 = nn.BatchNorm2d(features)
self.bn2 = nn.BatchNorm2d(features)
self.activation = activation
self.skip_add = nn.quantized.FloatFunctional()
def forward(self, x):
"""Forward pass.
Args:
x (tensor): input
Returns:
tensor: output
"""
out = self.activation(x)
out = self.conv1(out)
if self.bn is True:
out = self.bn1(out)
out = self.activation(out)
out = self.conv2(out)
if self.bn is True:
out = self.bn2(out)
if self.groups > 1:
out = self.conv_merge(out)
return self.skip_add.add(out, x)
class FeatureFusionBlock(nn.Module):
"""Feature fusion block."""
def __init__(
self,
features,
activation,
deconv=False,
bn=False,
expand=False,
align_corners=True,
size=None,
):
"""Init.
Args:
features (int): number of features
"""
super().__init__()
self.deconv = deconv
self.align_corners = align_corners
self.groups = 1
self.expand = expand
out_features = features
if self.expand is True:
out_features = features // 2
self.out_conv = nn.Conv2d(
features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1
)
self.resConfUnit1 = ResidualConvUnit(features, activation, bn)
self.resConfUnit2 = ResidualConvUnit(features, activation, bn)
self.skip_add = nn.quantized.FloatFunctional()
self.size = size
def forward(self, *xs, size=None):
"""Forward pass.
Returns:
tensor: output
"""
output = xs[0]
if len(xs) == 2:
res = self.resConfUnit1(xs[1])
output = self.skip_add.add(output, res)
output = self.resConfUnit2(output)
if (size is None) and (self.size is None):
modifier = {"scale_factor": 2}
elif size is None:
modifier = {"size": self.size}
else:
modifier = {"size": size}
output = nn.functional.interpolate(
output.contiguous(), **modifier, mode="bilinear", align_corners=self.align_corners
)
output = self.out_conv(output)
return output
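if __name__ == "__main__":
    # Minimal shape sketch (illustrative only; the feature width and spatial
    # sizes below are arbitrary assumptions): fusing a skip feature into the
    # decoder path and upsampling by the default scale factor of 2.
    import torch

    fuse = FeatureFusionBlock(features=32, activation=nn.ReLU(False))
    deep = torch.randn(1, 32, 16, 16)  # coarser decoder feature
    skip = torch.randn(1, 32, 16, 16)  # encoder skip connection at the same scale
    out = fuse(deep, skip)
    assert out.shape == (1, 32, 32, 32)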
================================================
FILE: models/Video-Depth-Anything/video_depth_anything/util/transform.py
================================================
import numpy as np
import cv2
class Resize(object):
"""Resize sample to given size (width, height).
"""
def __init__(
self,
width,
height,
resize_target=True,
keep_aspect_ratio=False,
ensure_multiple_of=1,
resize_method="lower_bound",
image_interpolation_method=cv2.INTER_AREA,
):
"""Init.
Args:
width (int): desired output width
height (int): desired output height
resize_target (bool, optional):
True: Resize the full sample (image, mask, target).
False: Resize image only.
Defaults to True.
keep_aspect_ratio (bool, optional):
True: Keep the aspect ratio of the input sample.
Output sample might not have the given width and height, and
resize behaviour depends on the parameter 'resize_method'.
Defaults to False.
ensure_multiple_of (int, optional):
Output width and height is constrained to be multiple of this parameter.
Defaults to 1.
resize_method (str, optional):
"lower_bound": Output will be at least as large as the given size.
"upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.)
"minimal": Scale as least as possible. (Output size might be smaller than given size.)
Defaults to "lower_bound".
"""
self.__width = width
self.__height = height
self.__resize_target = resize_target
self.__keep_aspect_ratio = keep_aspect_ratio
self.__multiple_of = ensure_multiple_of
self.__resize_method = resize_method
self.__image_interpolation_method = image_interpolation_method
def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)
if max_val is not None and y > max_val:
y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)
if y < min_val:
y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)
return y
def get_size(self, width, height):
# determine new height and width
scale_height = self.__height / height
scale_width = self.__width / width
if self.__keep_aspect_ratio:
if self.__resize_method == "lower_bound":
# scale such that output size is lower bound
if scale_width > scale_height:
# fit width
scale_height = scale_width
else:
# fit height
scale_width = scale_height
elif self.__resize_method == "upper_bound":
# scale such that output size is upper bound
if scale_width < scale_height:
# fit width
scale_height = scale_width
else:
# fit height
scale_width = scale_height
elif self.__resize_method == "minimal":
                # scale as little as possible
if abs(1 - scale_width) < abs(1 - scale_height):
# fit width
scale_height = scale_width
else:
# fit height
scale_width = scale_height
else:
raise ValueError(f"resize_method {self.__resize_method} not implemented")
if self.__resize_method == "lower_bound":
new_height = self.constrain_to_multiple_of(scale_height * height, min_val=self.__height)
new_width = self.constrain_to_multiple_of(scale_width * width, min_val=self.__width)
elif self.__resize_method == "upper_bound":
new_height = self.constrain_to_multiple_of(scale_height * height, max_val=self.__height)
new_width = self.constrain_to_multiple_of(scale_width * width, max_val=self.__width)
elif self.__resize_method == "minimal":
new_height = self.constrain_to_multiple_of(scale_height * height)
new_width = self.constrain_to_multiple_of(scale_width * width)
else:
raise ValueError(f"resize_method {self.__resize_method} not implemented")
return (new_width, new_height)
def __call__(self, sample):
width, height = self.get_size(sample["image"].shape[1], sample["image"].shape[0])
# resize sample
sample["image"] = cv2.resize(sample["image"], (width, height), interpolation=self.__image_interpolation_method)
if self.__resize_target:
if "depth" in sample:
sample["depth"] = cv2.resize(sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST)
if "mask" in sample:
sample["mask"] = cv2.resize(sample["mask"].astype(np.float32), (width, height), interpolation=cv2.INTER_NEAREST)
return sample
class NormalizeImage(object):
"""Normlize image by given mean and std.
"""
def __init__(self, mean, std):
self.__mean = mean
self.__std = std
def __call__(self, sample):
sample["image"] = (sample["image"] - self.__mean) / self.__std
return sample
class PrepareForNet(object):
"""Prepare sample for usage as network input.
"""
def __init__(self):
pass
def __call__(self, sample):
image = np.transpose(sample["image"], (2, 0, 1))
sample["image"] = np.ascontiguousarray(image).astype(np.float32)
if "depth" in sample:
depth = sample["depth"].astype(np.float32)
sample["depth"] = np.ascontiguousarray(depth)
if "mask" in sample:
sample["mask"] = sample["mask"].astype(np.float32)
sample["mask"] = np.ascontiguousarray(sample["mask"])
return sample
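if __name__ == "__main__":
    # Minimal sketch of the preprocessing pipeline (illustrative only; the input
    # resolution, target size and normalisation statistics below are assumptions
    # mirroring common ImageNet values, not constants defined in this file).
    dummy = {"image": np.random.rand(480, 640, 3).astype(np.float32)}
    for t in (
        Resize(width=518, height=518, resize_target=False, keep_aspect_ratio=True,
               ensure_multiple_of=14, resize_method="lower_bound"),
        NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        PrepareForNet(),
    ):
        dummy = t(dummy)
    c, h, w = dummy["image"].shape
    # Channels-first output whose sides are multiples of 14 and at least 518.
    assert c == 3 and h % 14 == 0 and w % 14 == 0 and min(h, w) >= 518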
================================================
FILE: models/Video-Depth-Anything/video_depth_anything/video_depth.py
================================================
# Copyright (2025) Bytedance Ltd. and/or its affiliates
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn.functional as F
import torch.nn as nn
from torchvision.transforms import Compose
import cv2
from tqdm import tqdm
import numpy as np
import gc
from .dinov2 import DINOv2
from .dpt_temporal import DPTHeadTemporal
from .util.transform import Resize, NormalizeImage, PrepareForNet
from ..utils.util import compute_scale_and_shift, get_interpolate_frames
# infer settings, do not change
INFER_LEN = 32
OVERLAP = 10
KEYFRAMES = [0,12,24,25,26,27,28,29,30,31]
INTERP_LEN = 8
class VideoDepthAnything(nn.Module):
def __init__(
self,
encoder='vitl',
features=256,
out_channels=[256, 512, 1024, 1024],
use_bn=False,
use_clstoken=False,
num_frames=32,
pe='ape'
):
super(VideoDepthAnything, self).__init__()
self.intermediate_layer_idx = {
'vits': [2, 5, 8, 11],
'vitl': [4, 11, 17, 23]
}
self.encoder = encoder
self.pretrained = DINOv2(model_name=encoder)
self.head = DPTHeadTemporal(self.pretrained.embed_dim, features, use_bn, out_channels=out_channels, use_clstoken=use_clstoken, num_frames=num_frames, pe=pe)
def forward(self, x):
B, T, C, H, W = x.shape
patch_h, patch_w = H // 14, W // 14
features = self.pretrained.get_intermediate_layers(x.flatten(0,1), self.intermediate_layer_idx[self.encoder], return_class_token=True)
depth = self.head(features, patch_h, patch_w, T)
# depth = F.interpolate(depth, size=(H, W), mode="bilinear", align_corners=True)
# depth = F.relu(depth)
return depth
def infer_video_depth(self, frames, target_fps, input_size=518, device='cuda'):
frame_height, frame_width = frames[0].shape[:2]
ratio = max(frame_height, frame_width) / min(frame_height, frame_width)
        if ratio > 1.78:  # we recommend processing videos with an aspect ratio no wider than 16:9 due to memory limitations
input_size = int(input_size * 1.777 / ratio)
input_size = round(input_size / 14) * 14
transform = Compose([
Resize(
width=input_size,
height=input_size,
resize_target=False,
keep_aspect_ratio=True,
ensure_multiple_of=14,
resize_method='lower_bound',
image_interpolation_method=cv2.INTER_CUBIC,
),
NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
PrepareForNet(),
])
frame_list = [frames[i] for i in range(frames.shape[0])]
frame_step = INFER_LEN - OVERLAP
org_video_len = len(frame_list)
append_frame_len = (frame_step - (org_video_len % frame_step)) % frame_step + (INFER_LEN - frame_step)
frame_list = frame_list + [frame_list[-1].copy()] * append_frame_len
depth_list = []
pre_input = None
for frame_id in tqdm(range(0, org_video_len, frame_step)):
cur_list = []
for i in range(INFER_LEN):
cur_list.append(torch.from_numpy(transform({'image': frame_list[frame_id+i].astype(np.float32) / 255.0})['image']).unsqueeze(0).unsqueeze(0))
cur_input = torch.cat(cur_list, dim=1).to(device)
if pre_input is not None:
cur_input[:, :OVERLAP, ...] = pre_input[:, KEYFRAMES, ...]
with torch.no_grad():
depth = self.forward(cur_input) # depth shape: [1, T, H, W]
depth = F.interpolate(depth.flatten(0,1).unsqueeze(1), size=(frame_height, frame_width), mode='bilinear', align_corners=True)
depth_list += [depth[i][0].cpu().numpy() for i in range(depth.shape[0])]
pre_input = cur_input
del frame_list
gc.collect()
depth_list_aligned = []
ref_align = []
align_len = OVERLAP - INTERP_LEN
kf_align_list = KEYFRAMES[:align_len]
for frame_id in range(0, len(depth_list), INFER_LEN):
if len(depth_list_aligned) == 0:
depth_list_aligned += depth_list[:INFER_LEN]
for kf_id in kf_align_list:
ref_align.append(depth_list[frame_id+kf_id])
else:
curr_align = []
for i in range(len(kf_align_list)):
curr_align.append(depth_list[frame_id+i])
scale, shift = compute_scale_and_shift(np.concatenate(curr_align),
np.concatenate(ref_align),
np.concatenate(np.ones_like(ref_align)==1))
pre_depth_list = depth_list_aligned[-INTERP_LEN:]
post_depth_list = depth_list[frame_id+align_len:frame_id+OVERLAP]
for i in range(len(post_depth_list)):
post_depth_list[i] = post_depth_list[i] * scale + shift
post_depth_list[i][post_depth_list[i]<0] = 0
depth_list_aligned[-INTERP_LEN:] = get_interpolate_frames(pre_depth_list, post_depth_list)
for i in range(OVERLAP, INFER_LEN):
new_depth = depth_list[frame_id+i] * scale + shift
new_depth[new_depth<0] = 0
depth_list_aligned.append(new_depth)
ref_align = ref_align[:1]
for kf_id in kf_align_list[1:]:
new_depth = depth_list[frame_id+kf_id] * scale + shift
new_depth[new_depth<0] = 0
ref_align.append(new_depth)
depth_list = depth_list_aligned
return depth_list[:org_video_len], target_fps
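if __name__ == "__main__":
    # Minimal sketch of the sliding-window bookkeeping used by infer_video_depth
    # (illustrative only; the 100-frame clip length is an arbitrary assumption).
    # Each window covers INFER_LEN frames, consecutive windows advance by
    # INFER_LEN - OVERLAP frames, and the clip is padded so the last window is full.
    org_video_len = 100
    frame_step = INFER_LEN - OVERLAP  # 22 new frames per window
    append_frame_len = (frame_step - (org_video_len % frame_step)) % frame_step + (INFER_LEN - frame_step)
    padded_len = org_video_len + append_frame_len
    starts = list(range(0, org_video_len, frame_step))
    assert all(s + INFER_LEN <= padded_len for s in starts)
    print(f"{len(starts)} windows of {INFER_LEN} frames over {padded_len} padded frames")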
================================================
FILE: models/core/attention.py
================================================
import math
import copy
import torch
import torch.nn as nn
from torch.nn import Module, Dropout
"""
Linear Transformer proposed in "Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention"
Modified from: https://github.com/idiap/fast-transformers/blob/master/fast_transformers/attention/linear_attention.py
"""
def elu_feature_map(x):
return torch.nn.functional.elu(x) + 1
class PositionEncodingSine(nn.Module):
"""
    This is a sinusoidal position encoding generalized to 2-dimensional images
"""
def __init__(self, d_model, max_shape=(256, 256), temp_bug_fix=True):
"""
Args:
max_shape (tuple): for 1/8 featmap, the max length of 256 corresponds to 2048 pixels
temp_bug_fix (bool): As noted in this [issue](https://github.com/zju3dv/LoFTR/issues/41),
the original implementation of LoFTR includes a bug in the pos-enc impl, which has little impact
                on the final performance. For now, we keep both impls for backward compatibility.
We will remove the buggy impl after re-training all variants of our released models.
"""
super().__init__()
pe = torch.zeros((d_model, *max_shape))
y_position = torch.ones(max_shape).cumsum(0).float().unsqueeze(0)
x_position = torch.ones(max_shape).cumsum(1).float().unsqueeze(0)
if temp_bug_fix:
div_term = torch.exp(
torch.arange(0, d_model // 2, 2).float()
* (-math.log(10000.0) / (d_model // 2))
)
        else:  # a buggy implementation (for backward compatibility only)
div_term = torch.exp(
torch.arange(0, d_model // 2, 2).float()
* (-math.log(10000.0) / d_model // 2)
)
div_term = div_term[:, None, None] # [C//4, 1, 1]
pe[0::4, :, :] = torch.sin(x_position * div_term)
pe[1::4, :, :] = torch.cos(x_position * div_term)
pe[2::4, :, :] = torch.sin(y_position * div_term)
pe[3::4, :, :] = torch.cos(y_position * div_term)
self.register_buffer("pe", pe.unsqueeze(0), persistent=False) # [1, C, H, W]
def forward(self, x):
"""
Args:
x: [N, C, H, W]
"""
return x + self.pe[:, :, : x.size(2), : x.size(3)].to(x.device)
class LinearAttention(Module):
def __init__(self, eps=1e-6):
super().__init__()
self.feature_map = elu_feature_map
self.eps = eps
def forward(self, queries, keys, values, q_mask=None, kv_mask=None):
"""Multi-Head linear attention proposed in "Transformers are RNNs"
Args:
queries: [N, L, H, D]
keys: [N, S, H, D]
values: [N, S, H, D]
q_mask: [N, L]
kv_mask: [N, S]
Returns:
queried_values: (N, L, H, D)
"""
Q = self.feature_map(queries)
K = self.feature_map(keys)
# set padded position to zero
if q_mask is not None:
Q = Q * q_mask[:, :, None, None]
if kv_mask is not None:
K = K * kv_mask[:, :, None, None]
values = values * kv_mask[:, :, None, None]
v_length = values.size(1)
values = values / v_length # prevent fp16 overflow
KV = torch.einsum("nshd,nshv->nhdv", K, values) # (S,D)' @ S,V
Z = 1 / (torch.einsum("nlhd,nhd->nlh", Q, K.sum(dim=1)) + self.eps)
queried_values = torch.einsum("nlhd,nhdv,nlh->nlhv", Q, KV, Z) * v_length
return queried_values.contiguous()
class FullAttention(Module):
def __init__(self, use_dropout=False, attention_dropout=0.1):
super().__init__()
self.use_dropout = use_dropout
self.dropout = Dropout(attention_dropout)
def forward(self, queries, keys, values, q_mask=None, kv_mask=None):
"""Multi-head scaled dot-product attention, a.k.a full attention.
Args:
queries: [N, L, H, D]
keys: [N, S, H, D]
values: [N, S, H, D]
q_mask: [N, L]
kv_mask: [N, S]
Returns:
queried_values: (N, L, H, D)
"""
# Compute the unnormalized attention and apply the masks
QK = torch.einsum("nlhd,nshd->nlsh", queries, keys)
if kv_mask is not None:
QK.masked_fill_(
~(q_mask[:, :, None, None] * kv_mask[:, None, :, None]), float("-inf")
)
# Compute the attention and the weighted average
softmax_temp = 1.0 / queries.size(3) ** 0.5 # 1 / sqrt(D)
A = torch.softmax(softmax_temp * QK, dim=2)
if self.use_dropout:
A = self.dropout(A)
queried_values = torch.einsum("nlsh,nshd->nlhd", A, values)
return queried_values.contiguous()
# Ref: https://github.com/zju3dv/LoFTR/blob/master/src/loftr/loftr_module/transformer.py
class LoFTREncoderLayer(nn.Module):
def __init__(self, d_model, nhead, attention="linear"):
super(LoFTREncoderLayer, self).__init__()
self.dim = d_model // nhead
self.nhead = nhead
# multi-head attention
self.q_proj = nn.Linear(d_model, d_model, bias=False)
self.k_proj = nn.Linear(d_model, d_model, bias=False)
self.v_proj = nn.Linear(d_model, d_model, bias=False)
self.attention = LinearAttention() if attention == "linear" else FullAttention()
self.merge = nn.Linear(d_model, d_model, bias=False)
# feed-forward network
self.mlp = nn.Sequential(
nn.Linear(d_model * 2, d_model * 2, bias=False),
nn.ReLU(),
nn.Linear(d_model * 2, d_model, bias=False),
)
# norm and dropout
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
def forward(self, x, source, x_mask=None, source_mask=None):
"""
Args:
x (torch.Tensor): [N, L, C]
source (torch.Tensor): [N, S, C]
x_mask (torch.Tensor): [N, L] (optional)
source_mask (torch.Tensor): [N, S] (optional)
"""
bs = x.size(0)
query, key, value = x, source, source
# multi-head attention
query = self.q_proj(query).view(bs, -1, self.nhead, self.dim) # [N, L, (H, D)]
key = self.k_proj(key).view(bs, -1, self.nhead, self.dim) # [N, S, (H, D)]
value = self.v_proj(value).view(bs, -1, self.nhead, self.dim)
message = self.attention(
query, key, value, q_mask=x_mask, kv_mask=source_mask
) # [N, L, (H, D)]
message = self.merge(message.view(bs, -1, self.nhead * self.dim)) # [N, L, C]
message = self.norm1(message)
# feed-forward network
message = self.mlp(torch.cat([x, message], dim=2))
message = self.norm2(message)
return x + message
class LocalFeatureTransformer(nn.Module):
"""A Local Feature Transformer (LoFTR) module."""
def __init__(self, d_model, nhead, layer_names, attention):
super(LocalFeatureTransformer, self).__init__()
self.d_model = d_model
self.nhead = nhead
self.layer_names = layer_names
encoder_layer = LoFTREncoderLayer(d_model, nhead, attention)
self.layers = nn.ModuleList(
[copy.deepcopy(encoder_layer) for _ in range(len(self.layer_names))]
)
self._reset_parameters()
def _reset_parameters(self):
for p in self.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
def forward(self, feat0, feat1, mask0=None, mask1=None):
"""
Args:
feat0 (torch.Tensor): [N, L, C]
feat1 (torch.Tensor): [N, S, C]
mask0 (torch.Tensor): [N, L] (optional)
mask1 (torch.Tensor): [N, S] (optional)
"""
assert self.d_model == feat0.size(2), "the feature dimension of the inputs and the transformer must be equal"
for layer, name in zip(self.layers, self.layer_names):
if name == "self":
feat0 = layer(feat0, feat0, mask0, mask0)
feat1 = layer(feat1, feat1, mask1, mask1)
elif name == "cross":
feat0 = layer(feat0, feat1, mask0, mask1)
feat1 = layer(feat1, feat0, mask1, mask0)
else:
raise KeyError
return feat0, feat1
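# Illustrative sketch (not part of the original file; the helper name and the layer
# configuration are example choices, not the settings used by StereoAnyVideo): wiring a
# LocalFeatureTransformer with alternating self- and cross-attention layers, as in
# LoFTR. Each "self" layer attends within one feature map, each "cross" layer attends
# from one map into the other.
def _example_local_feature_transformer():
    transformer = LocalFeatureTransformer(
        d_model=128, nhead=8, layer_names=["self", "cross"] * 2, attention="linear"
    )
    feat0 = torch.randn(1, 60 * 80, 128)   # left features, flattened to [N, L, C]
    feat1 = torch.randn(1, 60 * 80, 128)   # right features, [N, S, C]
    feat0, feat1 = transformer(feat0, feat1)
    return feat0.shape, feat1.shape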
================================================
FILE: models/core/corr.py
================================================
import torch
import torch.nn.functional as F
from einops import rearrange
def bilinear_sampler(img, coords, mode="bilinear", mask=False):
"""Wrapper for grid_sample, uses pixel coordinates"""
H, W = img.shape[-2:]
xgrid, ygrid = coords.split([1, 1], dim=-1)
xgrid = 2 * xgrid / (W - 1) - 1
if H > 1:
ygrid = 2 * ygrid/(H - 1) - 1
img = img.contiguous()
grid = torch.cat([xgrid, ygrid], dim=-1).contiguous()
img = F.grid_sample(img, grid, align_corners=True)
if mask:
mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
return img, mask.float()
return img
def coords_grid(batch, ht, wd, device):
coords = torch.meshgrid(
torch.arange(ht, device=device), torch.arange(wd, device=device), indexing="ij"
)
coords = torch.stack(coords[::-1], dim=0).float()
return coords[None].repeat(batch, 1, 1, 1)
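# Illustrative sketch (not part of the original file; the helper name is hypothetical):
# bilinear_sampler converts pixel coordinates to the normalized [-1, 1] range expected
# by grid_sample, so sampling an image at the identity grid returned by coords_grid
# reproduces the image (up to floating-point error).
def _example_identity_sampling():
    img = torch.randn(1, 3, 32, 48)
    coords = coords_grid(1, 32, 48, img.device)                   # (1, 2, H, W) pixel coords
    sampled = bilinear_sampler(img, coords.permute(0, 2, 3, 1))   # coords as (N, H, W, 2)
    assert torch.allclose(sampled, img, atol=1e-5)
    return sampled.shape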
class AAPC:
"""
Implementation of All-in-All-Pair Correlation.
"""
def __init__(self, fmap1, fmap2, att=None):
self.fmap1 = fmap1
self.fmap2 = fmap2
self.att = att
self.coords = coords_grid(fmap1.shape[0], fmap1.shape[2], fmap1.shape[3], fmap1.device)
def __call__(self, flow, extra_offset, small_patch=False):
corr = self.correlation(self.fmap1, self.fmap2, flow, small_patch)
return corr
def correlation(self, left_feature, right_feature, flow, small_patch):
flow[:, 1:] = 0  # rectified stereo: zero the vertical component, only horizontal disparity is searched
coords = self.coords - flow
coords = coords.permute(0, 2, 3, 1)
right_feature = bilinear_sampler(right_feature, coords)
if small_patch:
psize_list = [(3, 3), (3, 3), (3, 3), (3, 3)]
dilate_list = [(1, 1), (1, 1), (1, 1), (1, 1)]
else:
psize_list = [(1, 9), (1, 9), (1, 9), (1, 9)]
dilate_list = [(1, 1), (1, 1), (1, 1), (1, 1)]
N, C, H, W = left_feature.size()
lefts = torch.split(left_feature, [C // 4] * 4, dim=1)
rights = torch.split(right_feature, [C // 4] * 4, dim=1)
corrs = []
for i in range(len(psize_list)):
corr = self.get_correlation(lefts[i], rights[i], psize_list[i], dilate_list[i])
corrs.append(corr)
final_corr = torch.cat(corrs, dim=1)
return final_corr
def get_correlation(self, left_feature, right_feature, psize=(3, 3), dilate=(1, 1)):
N, C, H, W = left_feature.size()
di_y, di_x = dilate[0], dilate[1]
pady, padx = psize[0] // 2 * di_y, psize[1] // 2 * di_x
left_pad = F.pad(left_feature, [padx, padx, pady, pady], mode='replicate')
right_pad = F.pad(right_feature, [padx, padx, pady, pady], mode='replicate')
corr_list = []
for dy1 in range(0, pady * 2 + 1, di_y):
for dx1 in range(0, padx * 2 + 1, di_x):
left_crop = left_pad[:, :, dy1:dy1 + H, dx1:dx1 + W]
for dy2 in range(0, pady * 2 + 1, di_y):
for dx2 in range(0, padx * 2 + 1, di_x):
right_crop = right_pad[:, :, dy2:dy2 + H, dx2:dx2 + W]
assert right_crop.size() == left_crop.size()
corr = (left_crop * right_crop).sum(dim=1, keepdim=True) # Sum over channels
corr_list.append(corr)
corr_final = torch.cat(corr_list, dim=1)
return corr_final
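# Illustrative sketch (not part of the original file; the helper name and shapes are
# example values): the All-in-All-Pair Correlation volume built by AAPC. With 4 feature
# groups and a 9-tap neighbourhood on each side, every spatial location receives
# 4 * 9 * 9 = 324 correlation channels, which matches the corr_mlp input size declared
# in stereoanyvideo.py.
def _example_aapc():
    fmap1 = torch.randn(1, 128, 30, 40)    # left features (C divisible by 4)
    fmap2 = torch.randn(1, 128, 30, 40)    # right features
    flow = torch.zeros(1, 2, 30, 40)       # zero-disparity initialisation
    corr_fn = AAPC(fmap1, fmap2)
    corr = corr_fn(flow, None, small_patch=False)
    return corr.shape                      # torch.Size([1, 324, 30, 40])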
================================================
FILE: models/core/extractor.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import os
import sys
import importlib
import timm
from einops import rearrange
class ResidualBlock(nn.Module):
def __init__(self, in_planes, planes, norm_fn="group", stride=1):
super(ResidualBlock, self).__init__()
self.conv1 = nn.Conv2d(
in_planes, planes, kernel_size=3, padding=1, stride=stride
)
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1)
self.relu = nn.ReLU(inplace=True)
num_groups = planes // 8
if norm_fn == "group":
self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
elif norm_fn == "batch":
self.norm1 = nn.BatchNorm2d(planes)
self.norm2 = nn.BatchNorm2d(planes)
self.norm3 = nn.BatchNorm2d(planes)
elif norm_fn == "instance":
self.norm1 = nn.InstanceNorm2d(planes, affine=False)
self.norm2 = nn.InstanceNorm2d(planes, affine=False)
self.norm3 = nn.InstanceNorm2d(planes, affine=False)
elif norm_fn == "none":
self.norm1 = nn.Sequential()
self.norm2 = nn.Sequential()
self.norm3 = nn.Sequential()
self.downsample = nn.Sequential(
nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3
)
def forward(self, x):
y = x
y = self.relu(self.norm1(self.conv1(y)))
y = self.relu(self.norm2(self.conv2(y)))
x = self.downsample(x)
return self.relu(x + y)
class BasicEncoder(nn.Module):
def __init__(self, input_dim=3, output_dim=128, norm_fn="batch", dropout=0.0):
super(BasicEncoder, self).__init__()
self.norm_fn = norm_fn
if self.norm_fn == "group":
self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64)
elif self.norm_fn == "batch":
self.norm1 = nn.BatchNorm2d(64)
elif self.norm_fn == "instance":
self.norm1 = nn.InstanceNorm2d(64, affine=False)
elif self.norm_fn == "none":
self.norm1 = nn.Sequential()
self.conv1 = nn.Conv2d(input_dim, 64, kernel_size=7, stride=2, padding=3)
self.relu1 = nn.ReLU(inplace=True)
self.in_planes = 64
self.layer1 = self._make_layer(64, stride=1)
self.layer2 = self._make_layer(96, stride=2)
self.layer3 = self._make_layer(128, stride=1)
# output convolution
self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1)
self.dropout = None
if dropout > 0:
self.dropout = nn.Dropout2d(p=dropout)
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
if m.weight is not None:
nn.init.constant_(m.weight, 1)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
def _make_layer(self, dim, stride=1):
layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
layers = (layer1, layer2)
self.in_planes = dim
return nn.Sequential(*layers)
def forward(self, x):
# if input is list, combine batch dimension
is_list = isinstance(x, tuple) or isinstance(x, list)
if is_list:
batch_dim = x[0].shape[0]
x = torch.cat(x, dim=0)
x = self.conv1(x)
x = self.norm1(x)
x = self.relu1(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.conv2(x)
if self.dropout is not None:
x = self.dropout(x)
if is_list:
x = torch.split(x, x.shape[0] // 2, dim=0)
return x
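# Illustrative sketch (not part of the original file; the helper name is hypothetical):
# when BasicEncoder receives a (left, right) pair it concatenates the two along the
# batch dimension, runs a single shared-weight forward pass, and splits the result back
# into two feature maps at 1/4 resolution.
def _example_shared_encoder():
    encoder = BasicEncoder(input_dim=3, output_dim=96, norm_fn="instance")
    left = torch.randn(2, 3, 128, 160)
    right = torch.randn(2, 3, 128, 160)
    feat_left, feat_right = encoder([left, right])
    return feat_left.shape                 # torch.Size([2, 96, 32, 40])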
class MultiBasicEncoder(nn.Module):
def __init__(self, output_dim=[128], norm_fn='batch', dropout=0.0, downsample=3):
super(MultiBasicEncoder, self).__init__()
self.norm_fn = norm_fn
self.downsample = downsample
if self.norm_fn == 'group':
self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64)
elif self.norm_fn == 'batch':
self.norm1 = nn.BatchNorm2d(64)
elif self.norm_fn == 'instance':
self.norm1 = nn.InstanceNorm2d(64)
elif self.norm_fn == 'none':
self.norm1 = nn.Sequential()
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=1 + (downsample > 2), padding=3)
self.relu1 = nn.ReLU(inplace=True)
self.in_planes = 64
self.layer1 = self._make_layer(64, stride=1)
self.layer2 = self._make_layer(96, stride=1 + (downsample > 1))
self.layer3 = self._make_layer(128, stride=1 + (downsample > 0))
self.layer4 = self._make_layer(128, stride=2)
self.layer5 = self._make_layer(128, stride=2)
output_list = []
for dim in output_dim:
conv_out = nn.Sequential(
ResidualBlock(128, 128, self.norm_fn, stride=1),
nn.Conv2d(128, dim[2], 3, padding=1))
output_list.append(conv_out)
self.outputs08 = nn.ModuleList(output_list)
output_list = []
for dim in output_dim:
conv_out = nn.Sequential(
ResidualBlock(128, 128, self.norm_fn, stride=1),
nn.Conv2d(128, dim[1], 3, padding=1))
output_list.append(conv_out)
self.outputs16 = nn.ModuleList(output_list)
output_list = []
for dim in output_dim:
conv_out = nn.Conv2d(128, dim[0], 3, padding=1)
output_list.append(conv_out)
self.outputs32 = nn.ModuleList(output_list)
if dropout > 0:
self.dropout = nn.Dropout2d(p=dropout)
else:
self.dropout = None
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
if m.weight is not None:
nn.init.constant_(m.weight, 1)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
def _make_layer(self, dim, stride=1):
layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
layers = (layer1, layer2)
self.in_planes = dim
return nn.Sequential(*layers)
def forward(self, x, dual_inp=False, num_layers=3):
x = self.conv1(x)
x = self.norm1(x)
x = self.relu1(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
if dual_inp:
v = x
x = x[:(x.shape[0]//2)]
outputs08 = [f(x) for f in self.outputs08]
if num_layers == 1:
return (outputs08, v) if dual_inp else (outputs08,)
y = self.layer4(x)
outputs16 = [f(y) for f in self.outputs16]
if num_layers == 2:
return (outputs08, outputs16, v) if dual_inp else (outputs08, outputs16)
z = self.layer5(y)
outputs32 = [f(z) for f in self.outputs32]
return (outputs08, outputs16, outputs32, v) if dual_inp else (outputs08, outputs16, outputs32)
class DepthExtractor(nn.Module):
def __init__(self):
super(DepthExtractor, self).__init__()
thirdparty_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "./models/Video-Depth-Anything"))
sys.path.append(thirdparty_path)
videodepthanything_ppl = importlib.import_module(
"stereoanyvideo.models.Video-Depth-Anything.video_depth_anything.video_depth"
)
model_configs = {
'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
}
encoder = 'vits' # or 'vitl',
self.depthanything = videodepthanything_ppl.VideoDepthAnything(**model_configs[encoder])
self.depthanything.load_state_dict(torch.load(f'./models/Video-Depth-Anything/checkpoints/video_depth_anything_{encoder}.pth'))
self.depthanything.eval()
self.conv = nn.Conv2d(32, 32, kernel_size=4, stride=4)
def forward(self, x):
# Store original height and width
B, T, C, orig_h, orig_w = x.shape
# Calculate new height and width divisible by 14
new_h = (orig_h // 14) * 14
new_w = (orig_w // 14) * 14
# Resize input to be divisible by 14 for depthanything
resized_input = F.interpolate(
x.flatten(0, 1),
size=(new_h, new_w),
mode='bilinear',
align_corners=False
).unflatten(0, (B, T))
# Pass through depthanything
depth_features_resized = self.depthanything(resized_input).contiguous()
# Resize depth features back to the original resolution
depth_features = F.interpolate(
depth_features_resized,
size=(orig_h, orig_w),
mode='bilinear',
align_corners=False
)
# Apply convolution to the depth features
depth_features = self.conv(depth_features).unflatten(0, (B, T))
return depth_features
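# Illustrative note (not part of the original file; the helper name is hypothetical):
# the resize in DepthExtractor.forward exists because the Video-Depth-Anything backbone
# is a ViT with 14 x 14 patches, so its input height and width must be multiples of 14.
# A 720 x 1280 frame, for example, is processed at 714 x 1274 and the resulting depth
# features are interpolated back to 720 x 1280.
def _example_round_to_patch_multiple(h: int = 720, w: int = 1280, patch: int = 14):
    new_h, new_w = (h // patch) * patch, (w // patch) * patch
    return new_h, new_w                    # (714, 1274)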
================================================
FILE: models/core/model_zoo.py
================================================
import copy
from pytorch3d.implicitron.tools.config import get_default_args
from stereoanyvideo.models.stereoanyvideo_model import StereoAnyVideoModel
MODELS = [StereoAnyVideoModel]
_MODEL_NAME_TO_MODEL = {model_cls.__name__: model_cls for model_cls in MODELS}
_MODEL_CONFIG_NAME_TO_DEFAULT_CONFIG = {}
for model_cls in MODELS:
_MODEL_CONFIG_NAME_TO_DEFAULT_CONFIG[
model_cls.MODEL_CONFIG_NAME
] = get_default_args(model_cls)
MODEL_NAME_NONE = "NONE"
def model_zoo(model_name: str, **kwargs):
if model_name.upper() == MODEL_NAME_NONE:
return None
model_cls = _MODEL_NAME_TO_MODEL.get(model_name)
if model_cls is None:
raise ValueError(f"No such model name: {model_name}")
model_cls_params = {}
if "model_zoo" in getattr(model_cls, "__dataclass_fields__", []):
model_cls_params["model_zoo"] = model_zoo
print(
f"{model_cls.MODEL_CONFIG_NAME} model configs:",
kwargs.get(model_cls.MODEL_CONFIG_NAME),
)
return model_cls(**model_cls_params, **kwargs.get(model_cls.MODEL_CONFIG_NAME, {}))
def get_all_model_default_configs():
return copy.deepcopy(_MODEL_CONFIG_NAME_TO_DEFAULT_CONFIG)
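# Illustrative sketch (not part of the original file; the helper name is hypothetical):
# resolving a model through the zoo. get_all_model_default_configs() returns one
# DictConfig entry per registered model, keyed by its MODEL_CONFIG_NAME, and model_zoo
# lets the selected class pick out its own configuration from those kwargs. "NONE" is
# the sentinel for no model. Note that actually constructing StereoAnyVideoModel loads
# the Video-Depth-Anything checkpoint, so this assumes the checkpoint files are present.
def _example_model_zoo():
    assert model_zoo("NONE") is None
    defaults = get_all_model_default_configs()
    model = model_zoo("StereoAnyVideoModel", **defaults)
    return model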
================================================
FILE: models/core/stereoanyvideo.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, List
from einops import rearrange
import collections
from collections import defaultdict
from functools import partial
from itertools import repeat
import unfoldNd
from stereoanyvideo.models.core.update import SequenceUpdateBlock3D
from stereoanyvideo.models.core.extractor import BasicEncoder, MultiBasicEncoder, DepthExtractor
from stereoanyvideo.models.core.corr import AAPC
from stereoanyvideo.models.core.utils.utils import InputPadder, interp
autocast = torch.cuda.amp.autocast
def _ntuple(n):
def parse(x):
if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
return tuple(x)
return tuple(repeat(x, n))
return parse
def exists(val):
return val is not None
def default(val, d):
return val if exists(val) else d
to_2tuple = _ntuple(2)
class Mlp(nn.Module):
def __init__(
self,
in_features,
hidden_features=None,
out_features=None,
act_layer=nn.GELU,
norm_layer=None,
bias=True,
drop=0.0,
use_conv=False,
):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
bias = to_2tuple(bias)
drop_probs = to_2tuple(drop)
linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
self.act = act_layer()
self.drop1 = nn.Dropout(drop_probs[0])
self.norm = (
norm_layer(hidden_features) if norm_layer is not None else nn.Identity()
)
self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
self.drop2 = nn.Dropout(drop_probs[1])
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop1(x)
x = self.fc2(x)
x = self.drop2(x)
return x
class StereoAnyVideo(nn.Module):
def __init__(self, mixed_precision=False):
super(StereoAnyVideo, self).__init__()
self.mixed_precision = mixed_precision
self.hidden_dim = 128
self.context_dim = 128
self.dropout = 0
# feature network and update block
self.cnet = BasicEncoder(output_dim=96, norm_fn='instance', dropout=self.dropout)
self.fnet = BasicEncoder(output_dim=96, norm_fn='instance', dropout=self.dropout)
self.depthnet = DepthExtractor()
self.corr_mlp = Mlp(in_features=4 * 9 * 9, hidden_features=256, out_features=128)
self.update_block = SequenceUpdateBlock3D(hidden_dim=self.hidden_dim, cor_planes=128, mask_size=4)
@torch.jit.ignore
def no_weight_decay(self):
return {"time_embed"}
def freeze_bn(self):
for m in self.modules():
if isinstance(m, nn.BatchNorm2d):
m.eval()
def convex_upsample(self, flow, mask, rate=4):
""" Upsample flow field [H/rate, W/rate, 2] -> [H, W, 2] using convex combination """
N, _, H, W = flow.shape
mask = mask.view(N, 1, 9, rate, rate, H, W)
mask = torch.softmax(mask, dim=2)
up_flow = F.unfold(rate * flow, [3, 3], padding=1)
up_flow = up_flow.view(N, 2, 9, 1, 1, H, W)
up_flow = torch.sum(mask * up_flow, dim=2)
up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
return up_flow.reshape(N, 2, rate * H, rate * W)
def convex_upsample_3D(self, flow, mask, b, T, rate=4):
"""Upsample flow field from [T, H/rate, W/rate, 2] to [T, H, W, 2] using convex combination.
unfoldNd repo: https://github.com/f-dangel/unfoldNd
Run: pip install --user unfoldNd
Args:
flow: (N*T, C_flow, H, W)
mask: (N*T, C_mask, H, W) or (N, 1, 27, 1, rate, rate, T, H, W)
rate: int
"""
flow = rearrange(flow, "(b t) c h w -> b c t h w", b=b, t=T)
mask = rearrange(mask, "(b t) c h w -> b c t h w", b=b, t=T)
N, _, T, H, W = flow.shape
mask = mask.view(N, 1, 27, 1, rate, rate, T, H, W) # (N, 1, 27, rate, rate, rate, T, H, W) if upsample T
mask = torch.softmax(mask, dim=2)
upsample = unfoldNd.UnfoldNd([3, 3, 3], padding=1)
flow_upsampled = upsample(rate * flow)
flow_upsampled = flow_upsampled.view(N, 2, 27, 1, 1, 1, T, H, W)
flow_upsampled = torch.sum(mask * flow_upsampled, dim=2)
flow_upsampled = flow_upsampled.permute(0, 1, 5, 2, 6, 3, 7, 4)
flow_upsampled = flow_upsampled.reshape(N, 2, T, rate * H,
rate * W) # (N, 2, rate*T, rate*H, rate*W) if upsample T
up_flow = rearrange(flow_upsampled, "b c t h w -> (b t) c h w")
return up_flow
def zero_init(self, fmap):
N, C, H, W = fmap.shape
flow_u = torch.zeros([N, 1, H, W], dtype=torch.float)
flow_v = torch.zeros([N, 1, H, W], dtype=torch.float)
flow = torch.cat([flow_u, flow_v], dim=1).to(fmap.device)
return flow
def forward_batch_test(
self, batch_dict, iters=24, flow_init=None,
):
kernel_size = 20
stride = kernel_size // 2
predictions = defaultdict(list)
disp_preds = []
video = batch_dict["stereo_video"]
num_ims = len(video)
print("video", video.shape)
for i in range(0, num_ims, stride):
left_ims = video[i : min(i + kernel_size, num_ims), 0]
padder = InputPadder(left_ims.shape, divis_by=32)
right_ims = video[i : min(i + kernel_size, num_ims), 1]
left_ims, right_ims = padder.pad(left_ims, right_ims)
if flow_init is not None:
flow_init_ims = flow_init[i: min(i + kernel_size, num_ims)]
flow_init_ims = padder.pad(flow_init_ims)[0]
with autocast(enabled=self.mixed_precision):
disparities_forw = self.forward(
left_ims[None].cuda(),
right_ims[None].cuda(),
flow_init=flow_init_ims,
iters=iters,
test_mode=True,
)
else:
with autocast(enabled=self.mixed_precision):
disparities_forw = self.forward(
left_ims[None].cuda(),
right_ims[None].cuda(),
iters=iters,
test_mode=True,
)
disparities_forw = padder.unpad(disparities_forw[:, 0])[:, None].cpu()
if len(disp_preds) > 0 and len(disparities_forw) >= stride:
if len(disparities_forw) < kernel_size:
disp_preds.append(disparities_forw[stride // 2 :])
else:
disp_preds.append(disparities_forw[stride // 2 : -stride // 2])
elif len(disp_preds) == 0:
disp_preds.append(disparities_forw[: -stride // 2])
predictions["disparity"] = (torch.cat(disp_preds).squeeze(1).abs())[:, :1]
return predictions
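# Note on the sliding-window scheme above (descriptive comment, not original code):
# the video is processed in windows of kernel_size = 20 frames advanced by
# stride = 10 frames, so consecutive windows overlap by half. From each window the
# first and last stride // 2 = 5 frames are discarded (except at the sequence
# boundaries), so the trimmed chunks tile the timeline exactly once and most output
# frames come from the interior of a window:
#   window starting at i = 0  -> keeps frames 0 .. 14
#   window starting at i = 10 -> keeps frames 15 .. 24
#   window starting at i = 20 -> keeps frames 25 .. 34, and so on.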
def forward(self, image1, image2, flow_init=None, iters=12, test_mode=False):
b, T, c, h, w = image1.shape
image1 = image1 / 255.0
image2 = image2 / 255.0
# Normalize using mean and std for ImageNet pre-trained models
mean = torch.tensor([0.485, 0.456, 0.406], device=image1.device).view(1, 1, 3, 1, 1)
std = torch.tensor([0.229, 0.224, 0.225], device=image1.device).view(1, 1, 3, 1, 1)
image1 = (image1 - mean) / std
image2 = (image2 - mean) / std
image1 = image1.float()
image2 = image2.float()
# feature network
with autocast(enabled=self.mixed_precision):
fmap1_depth_feature = self.depthnet(image1)
fmap2_depth_feature = self.depthnet(image2)
fmap1_cnet_feature = self.cnet(image1.flatten(0, 1)).unflatten(0, (b, T))
fmap1_fnet_feature = self.fnet(image1.flatten(0, 1)).unflatten(0, (b, T))
fmap2_fnet_feature = self.fnet(image2.flatten(0, 1)).unflatten(0, (b, T))
fmap1 = torch.cat((fmap1_depth_feature, fmap1_fnet_feature), dim=2).flatten(0, 1)
fmap2 = torch.cat((fmap2_depth_feature, fmap2_fnet_feature), dim=2).flatten(0, 1)
context = torch.cat((fmap1_depth_feature, fmap1_cnet_feature), dim=2).flatten(0, 1)
with autocast(enabled=self.mixed_precision):
net = torch.tanh(context)
inp = torch.relu(context)
s_net = F.avg_pool2d(net, 2, stride=2)
s_inp = F.avg_pool2d(inp, 2, stride=2)
# 1/4 -> 1/8
# feature
s_fmap1 = F.avg_pool2d(fmap1, 2, stride=2)
s_fmap2 = F.avg_pool2d(fmap2, 2, stride=2)
# 1/4 -> 1/16
# feature
ss_fmap1 = F.avg_pool2d(fmap1, 4, stride=4)
ss_fmap2 = F.avg_pool2d(fmap2, 4, stride=4)
ss_net = F.avg_pool2d(net, 4, stride=4)
ss_inp = F.avg_pool2d(inp, 4, stride=4)
# Correlation
corr_fn = AAPC(fmap1, fmap2)
s_corr_fn = AAPC(s_fmap1, s_fmap2)
ss_corr_fn = AAPC(ss_fmap1, ss_fmap2)
# cascaded refinement (1/16 + 1/8 + 1/4)
flow_predictions = []
flow = None
flow_up = None
if flow_init is not None:
flow_init = flow_init.cuda()
scale = fmap1.shape[2] / flow_init.shape[2]
flow = scale * interp(flow_init, size=(fmap1.shape[2], fmap1.shape[3]))
else:
# init flow
ss_flow = self.zero_init(ss_fmap1)
# 1/16
for itr in range(iters // 2):
if itr % 2 == 0:
small_patch = False
else:
small_patch = True
ss_flow = ss_flow.detach()
out_corrs = ss_corr_fn(ss_flow, None, small_patch=small_patch) # 36 * H/16 * W/16
out_corrs = self.corr_mlp(out_corrs.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
with autocast(enabled=self.mixed_precision):
ss_net, up_mask, delta_flow = self.update_block(ss_net, ss_inp, out_corrs, ss_flow, t=T)
ss_flow = ss_flow + delta_flow
flow = self.convex_upsample_3D(ss_flow, up_mask, b, T, rate=4) # 2 * H/4 * W/4
flow_up = 4 * F.interpolate(flow, size=(4 * flow.shape[2], 4 * flow.shape[3]), mode='bilinear',
align_corners=True) # 2 * H/2 * W/2
flow_predictions.append(flow_up[:, :1])
scale = s_fmap1.shape[2] / flow.shape[2]
s_flow = scale * interp(flow, size=(s_fmap1.shape[2], s_fmap1.shape[3]))
# 1/8
for itr in range(iters // 2):
if itr % 2 == 0:
small_patch = False
else:
small_patch = True
s_flow = s_flow.detach()
out_corrs = s_corr_fn(s_flow, None, small_patch=small_patch)
out_corrs = self.corr_mlp(out_corrs.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
with autocast(enabled=self.mixed_precision):
s_net, up_mask, delta_flow = self.update_block(s_net, s_inp, out_corrs, s_flow, t=T)
s_flow = s_flow + delta_flow
flow = self.convex_upsample_3D(s_flow, up_mask, b, T, rate=4)
flow_up = 2 * F.interpolate(flow, size=(2 * flow.shape[2], 2 * flow.shape[3]), mode='bilinear',
align_corners=True)
flow_predictions.append(flow_up[:, :1])
scale = fmap1.shape[2] / flow.shape[2]
flow = scale * interp(flow, size=(fmap1.shape[2], fmap1.shape[3]))
# 1/4
for itr in range(iters):
if itr % 2 == 0:
small_patch = False
else:
small_patch = True
flow = flow.detach()
out_corrs = corr_fn(flow, None, small_patch=small_patch)
out_corrs = self.corr_mlp(out_corrs.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
with autocast(enabled=self.mixed_precision):
net, up_mask, delta_flow = self.update_block(net, inp, out_corrs, flow, t=T)
flow = flow + delta_flow
flow_up = self.convex_upsample_3D(flow, up_mask, b, T, rate=4)
flow_predictions.append(flow_up[:, :1])
predictions = torch.stack(flow_predictions)
predictions = rearrange(predictions, "d (b t) c h w -> d t b c h w", b=b, t=T)
flow_up = predictions[-1]
if test_mode:
return flow_up
return predictions
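# Illustrative sketch (not part of the original file; the helper name is hypothetical):
# the convex-upsampling idea used by convex_upsample / convex_upsample_3D above, written
# out for the simpler 2D case. Each fine-resolution flow value is a softmax-weighted
# (convex) combination of the 3x3 coarse-resolution neighbourhood around it, and the
# flow magnitudes are scaled by the upsampling rate.
def _example_convex_upsample_2d(rate: int = 4):
    N, H, W = 1, 6, 8
    flow = torch.randn(N, 2, H, W)
    mask = torch.randn(N, 9 * rate * rate, H, W)           # 9 weights per upsampled pixel
    mask = torch.softmax(mask.view(N, 1, 9, rate, rate, H, W), dim=2)
    up = F.unfold(rate * flow, [3, 3], padding=1)           # (N, 2 * 9, H * W)
    up = up.view(N, 2, 9, 1, 1, H, W)
    up = torch.sum(mask * up, dim=2)                        # convex combination over the 9 taps
    up = up.permute(0, 1, 4, 2, 5, 3).reshape(N, 2, rate * H, rate * W)
    return up.shape                                         # torch.Size([1, 2, 24, 32])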
================================================
FILE: models/core/update.py
================================================
from einops import rearrange
import torch
import torch.nn as nn
import torch.nn.functional as F
from opt_einsum import contract
from stereoanyvideo.models.core.attention import LoFTREncoderLayer
def pool2x(x):
return F.avg_pool2d(x, 3, stride=2, padding=1)
def pool4x(x):
return F.avg_pool2d(x, 5, stride=4, padding=1)
def interp(x, dest):
interp_args = {'mode': 'bilinear', 'align_corners': True}
return F.interpolate(x, dest.shape[2:], **interp_args)
class FlowHead(nn.Module):
def __init__(self, input_dim=128, hidden_dim=256, output_dim=2):
super(FlowHead, self).__init__()
self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1)
self.conv2 = nn.Conv2d(hidden_dim, output_dim, 3, padding=1)
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
return self.conv2(self.relu(self.conv1(x)))
class FlowHead3D(nn.Module):
def __init__(self, input_dim=128, hidden_dim=256):
super(FlowHead3D, self).__init__()
self.conv1 = nn.Conv3d(input_dim, hidden_dim, 3, padding=1)
self.conv2 = nn.Conv3d(hidden_dim, 2, 3, padding=1)
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
return self.conv2(self.relu(self.conv1(x)))
class ConvGRU(nn.Module):
def __init__(self, hidden_dim, input_dim, kernel_size=3):
super(ConvGRU, self).__init__()
self.convz = nn.Conv2d(hidden_dim+input_dim, hidden_dim, kernel_size, padding=kernel_size//2)
self.convr = nn.Conv2d(hidden_dim+input_dim, hidden_dim, kernel_size, padding=kernel_size//2)
self.convq = nn.Conv2d(hidden_dim+input_dim, hidden_dim, kernel_size, padding=kernel_size//2)
def forward(self, h, cz, cr, cq, *x_list):
x = torch.cat(x_list, dim=1)
hx = torch.cat([h, x], dim=1)
z = torch.sigmoid(self.convz(hx) + cz)
r = torch.sigmoid(self.convr(hx) + cr)
q = torch.tanh(self.convq(torch.cat([r*h, x], dim=1)) + cq)
h = (1-z) * h + z * q
return h
class SepConvGRU(nn.Module):
def __init__(self, hidden_dim=128, input_dim=192+128):
super(SepConvGRU, self).__init__()
self.convz1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
self.convr1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
self.convq1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
self.convz2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
self.convr2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
self.convq2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
def forward(self, h, *x):
# horizontal
x = torch.cat(x, dim=1)
hx = torch.cat([h, x], dim=1)
z = torch.sigmoid(self.convz1(hx))
r = torch.sigmoid(self.convr1(hx))
q = torch.tanh(self.convq1(torch.cat([r*h, x], dim=1)))
h = (1-z) * h + z * q
# vertical
hx = torch.cat([h, x], dim=1)
z = torch.sigmoid(self.convz2(hx))
r = torch.sigmoid(self.convr2(hx))
q = torch.tanh(self.convq2(torch.cat([r*h, x], dim=1)))
h = (1-z) * h + z * q
return h
class BasicMotionEncoder(nn.Module):
def __init__(self, cor_planes):
super(BasicMotionEncoder, self).__init__()
self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0)
self.convc2 = nn.Conv2d(256, 192, 3, padding=1)
self.convf1 = nn.Conv2d(2, 128, 7, padding=3)
self.convf2 = nn.Conv2d(128, 64, 3, padding=1)
self.conv = nn.Conv2d(64 + 192, 128 - 2, 3, padding=1)
def forward(self, flow, corr):
cor = F.relu(self.convc1(corr))
cor = F.relu(self.convc2(cor))
flo = F.relu(self.convf1(flow))
flo = F.relu(self.convf2(flo))
cor_flo = torch.cat([cor, flo], dim=1)
out = F.relu(self.conv(cor_flo))
return torch.cat([out, flow], dim=1)
class BasicMotionEncoder3D(nn.Module):
def __init__(self, cor_planes):
super(BasicMotionEncoder3D, self).__init__()
self.convc1 = nn.Conv3d(cor_planes, 256, 1, padding=0)
self.convc2 = nn.Conv3d(256, 192, 3, padding=1)
self.convf1 = nn.Conv3d(2, 128, 5, padding=2)
self.convf2 = nn.Conv3d(128, 64, 3, padding=1)
self.conv = nn.Conv3d(64 + 192, 128 - 2, 3, padding=1)
def forward(self, flow, corr):
cor = F.relu(self.convc1(corr))
cor = F.relu(self.convc2(cor))
flo = F.relu(self.convf1(flow))
flo = F.relu(self.convf2(flo))
cor_flo = torch.cat([cor, flo], dim=1)
out = F.relu(self.conv(cor_flo))
return torch.cat([out, flow], dim=1)
class SepConvGRU3D(nn.Module):
def __init__(self, hidden_dim=128, input_dim=192 + 128):
super(SepConvGRU3D, self).__init__()
self.convz1 = nn.Conv3d(
hidden_dim + input_dim, hidden_dim, (1, 1, 5), padding=(0, 0, 2)
)
self.convr1 = nn.Conv3d(
hidden_dim + input_dim, hidden_dim, (1, 1, 5), padding=(0, 0, 2)
)
self.convq1 = nn.Conv3d(
hidden_dim + input_dim, hidden_dim, (1, 1, 5), padding=(0, 0, 2)
)
self.convz2 = nn.Conv3d(
hidden_dim + input_dim, hidden_dim, (1, 5, 1), padding=(0, 2, 0)
)
self.convr2 = nn.Conv3d(
hidden_dim + input_dim, hidden_dim, (1, 5, 1), padding=(0, 2, 0)
)
self.convq2 = nn.Conv3d(
hidden_dim + input_dim, hidden_dim, (1, 5, 1), padding=(0, 2, 0)
)
self.convz3 = nn.Conv3d(
hidden_dim + input_dim, hidden_dim, (5, 1, 1), padding=(2, 0, 0)
)
self.convr3 = nn.Conv3d(
hidden_dim + input_dim, hidden_dim, (5, 1, 1), padding=(2, 0, 0)
)
self.convq3 = nn.Conv3d(
hidden_dim + input_dim, hidden_dim, (5, 1, 1), padding=(2, 0, 0)
)
def forward(self, h, x):
hx = torch.cat([h, x], dim=1)
z = torch.sigmoid(self.convz1(hx))
r = torch.sigmoid(self.convr1(hx))
q = torch.tanh(self.convq1(torch.cat([r * h, x], dim=1)))
h = (1 - z) * h + z * q
# vertical
hx = torch.cat([h, x], dim=1)
z = torch.sigmoid(self.convz2(hx))
r = torch.sigmoid(self.convr2(hx))
q = torch.tanh(self.convq2(torch.cat([r * h, x], dim=1)))
h = (1 - z) * h + z * q
# time
hx = torch.cat([h, x], dim=1)
z = torch.sigmoid(self.convz3(hx))
r = torch.sigmoid(self.convr3(hx))
q = torch.tanh(self.convq3(torch.cat([r * h, x], dim=1)))
h = (1 - z) * h + z * q
return h
class SKSepConvGRU3D(nn.Module):
def __init__(self, hidden_dim=128, input_dim=192 + 128):
super(SKSepConvGRU3D, self).__init__()
self.convz1 = nn.Sequential(
nn.Conv3d(hidden_dim+input_dim, hidden_dim, (1, 1, 15), padding=(0, 0, 7)),
nn.GELU(),
nn.Conv3d(hidden_dim, hidden_dim, (1, 1, 5), padding=(0, 0, 2)),
)
self.convr1 = nn.Sequential(
nn.Conv3d(hidden_dim+input_dim, hidden_dim, (1, 1, 15), padding=(0, 0, 7)),
nn.GELU(),
nn.Conv3d(hidden_dim, hidden_dim, (1, 1, 5), padding=(0, 0, 2)),
)
self.convq1 = nn.Conv3d(
hidden_dim + input_dim, hidden_dim, (1, 1, 5), padding=(0, 0, 2)
)
self.convz2 = nn.Conv3d(
hidden_dim + input_dim, hidden_dim, (1, 5, 1), padding=(0, 2, 0)
)
self.convr2 = nn.Conv3d(
hidden_dim + input_dim, hidden_dim, (1, 5, 1), padding=(0, 2, 0)
)
self.convq2 = nn.Conv3d(
hidden_dim + input_dim, hidden_dim, (1, 5, 1), padding=(0, 2, 0)
)
self.convz3 = nn.Conv3d(
hidden_dim + input_dim, hidden_dim, (5, 1, 1), padding=(2, 0, 0)
)
self.convr3 = nn.Conv3d(
hidden_dim + input_dim, hidden_dim, (5, 1, 1), padding=(2, 0, 0)
)
self.convq3 = nn.Conv3d(
hidden_dim + input_dim, hidden_dim, (5, 1, 1), padding=(2, 0, 0)
)
def forward(self, h, x):
hx = torch.cat([h, x], dim=1)
z = torch.sigmoid(self.convz1(hx))
r = torch.sigmoid(self.convr1(hx))
q = torch.tanh(self.convq1(torch.cat([r * h, x], dim=1)))
h = (1 - z) * h + z * q
# vertical
hx = torch.cat([h, x], dim=1)
z = torch.sigmoid(self.convz2(hx))
r = torch.sigmoid(self.convr2(hx))
q = torch.tanh(self.convq2(torch.cat([r * h, x], dim=1)))
h = (1 - z) * h + z * q
# time
hx = torch.cat([h, x], dim=1)
z = torch.sigmoid(self.convz3(hx))
r = torch.sigmoid(self.convr3(hx))
q = torch.tanh(self.convq3(torch.cat([r * h, x], dim=1)))
h = (1 - z) * h + z * q
return h
class BasicUpdateBlock(nn.Module):
def __init__(self, hidden_dim, cor_planes, mask_size=8, attention_type=None):
super(BasicUpdateBlock, self).__init__()
self.attention_type = attention_type
if attention_type is not None:
if "update_time" in attention_type:
self.time_attn = TimeAttnBlock(dim=256, num_heads=8)
if "update_space" in attention_type:
self.space_attn = SpaceAttnBlock(dim=256, num_heads=8)
self.encoder = BasicMotionEncoder(cor_planes)
self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128 + hidden_dim)
self.flow_head = FlowHead(hidden_dim, hidden_dim=256)
self.mask = nn.Sequential(
nn.Conv2d(128, 256, 3, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(256, mask_size ** 2 * 9, 1, padding=0),
)
def forward(self, net, inp, corr, flow, upsample=True, t=1):
motion_features = self.encoder(flow, corr)
inp = torch.cat((inp, motion_features), dim=1)
if self.attention_type is not None:
if "update_time" in self.attention_type:
inp = self.time_attn(inp, T=t)
if "update_space" in self.attention_type:
inp = self.space_attn(inp, T=t)
net = self.gru(net, inp)
delta_flow = self.flow_head(net)
# scale mask to balance gradients
mask = 0.25 * self.mask(net)
return net, mask, delta_flow
class Attention(nn.Module):
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim ** -0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.proj = nn.Linear(dim, dim)
def forward(self, x):
B, N, C = x.shape
qkv = x.reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
q, k, v = qkv, qkv, qkv
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
x = (attn @ v).transpose(1, 2).reshape(B, N, C).contiguous()
x = self.proj(x)
return x
class TimeAttnBlock(nn.Module):
def __init__(self, dim=256, num_heads=8):
super(TimeAttnBlock, self).__init__()
self.temporal_attn = Attention(dim, num_heads=8, qkv_bias=False, qk_scale=None)
self.temporal_fc = nn.Linear(dim, dim)
self.temporal_norm1 = nn.LayerNorm(dim)
nn.init.constant_(self.temporal_fc.weight, 0)
nn.init.constant_(self.temporal_fc.bias, 0)
def forward(self, x, T=1):
_, _, h, w = x.shape
x = rearrange(x, "(b t) m h w -> (b h w) t m", h=h, w=w, t=T)
res_temporal1 = self.temporal_attn(self.temporal_norm1(x))
res_temporal1 = rearrange(
res_temporal1, "(b h w) t m -> b (h w t) m", h=h, w=w, t=T
)
res_temporal1 = self.temporal_fc(res_temporal1)
res_temporal1 = rearrange(
res_temporal1, " b (h w t) m -> b t m h w", h=h, w=w, t=T
)
x = rearrange(x, "(b h w) t m -> b t m h w", h=h, w=w, t=T)
x = x + res_temporal1
x = rearrange(x, "b t m h w -> (b t) m h w", h=h, w=w, t=T)
return x
class SpaceAttnBlock(nn.Module):
def __init__(self, dim=256, num_heads=8):
super(SpaceAttnBlock, self).__init__()
self.encoder_layer = LoFTREncoderLayer(dim, nhead=num_heads, attention="linear")
def forward(self, x, T=1):
_, _, h, w = x.shape
x = rearrange(x, "(b t) m h w -> (b t) (h w) m", h=h, w=w, t=T)
x = self.encoder_layer(x, x)
x = rearrange(x, "(b t) (h w) m -> (b t) m h w", h=h, w=w, t=T)
return x
class SequenceUpdateBlock3D(nn.Module):
def __init__(self, hidden_dim, cor_planes, mask_size=8):
super(SequenceUpdateBlock3D, self).__init__()
self.encoder = BasicMotionEncoder(cor_planes)
self.gru = SKSepConvGRU3D(hidden_dim=hidden_dim, input_dim=128 + hidden_dim)
self.flow_head = FlowHead3D(hidden_dim, hidden_dim=256)
self.mask3d = nn.Sequential(
nn.Conv3d(hidden_dim, hidden_dim + 128, 3, padding=1),
nn.ReLU(inplace=True),
nn.Conv3d(hidden_dim + 128, (mask_size ** 2) * (3 * 3 * 3), 1, padding=0),
)
self.time_attn = TimeAttnBlock(dim=256, num_heads=8)
self.space_attn = SpaceAttnBlock(dim=256, num_heads=8)
def forward(self, net, inp, corrs, flows, t):
motion_features = self.encoder(flows, corrs)
inp_tensor = torch.cat([inp, motion_features], dim=1)
inp_tensor = self.time_attn(inp_tensor, T=t)
inp_tensor = self.space_attn(inp_tensor, T=t)
net = rearrange(net, "(b t) c h w -> b c t h w", t=t)
inp_tensor = rearrange(inp_tensor, "(b t) c h w -> b c t h w", t=t)
net = self.gru(net, inp_tensor)
delta_flow = self.flow_head(net)
# scale mask to balance gradients
mask = 0.25 * self.mask3d(net)
net = rearrange(net, "b c t h w -> (b t) c h w")
mask = rearrange(mask, "b c t h w -> (b t) c h w")
delta_flow = rearrange(delta_flow, "b c t h w -> (b t) c h w")
return net, mask, delta_flow
================================================
FILE: models/core/utils/config.py
================================================
import dataclasses
import inspect
import itertools
import sys
import warnings
from collections import Counter, defaultdict
from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union
from omegaconf import DictConfig, OmegaConf, open_dict
from pytorch3d.common.datatypes import get_args, get_origin
"""
This functionality allows a configurable system to be determined in a dataclass-type
way. It is a generalization of omegaconf's "structured", in the dataclass case.
Core functionality:
- Configurable -- A base class used to label a class as being one which uses this
system. Uses class members and __post_init__ like a dataclass.
- expand_args_fields -- Expands a class like `dataclasses.dataclass`. Runs automatically.
- get_default_args -- gets an omegaconf.DictConfig for initializing a given class.
- run_auto_creation -- Initialises nested members. To be called in __post_init__.
In addition, a Configurable may contain members whose type is decided at runtime.
- ReplaceableBase -- As a base instead of Configurable, labels a class to say that
any child class can be used instead.
- registry -- A global store of named child classes of ReplaceableBase classes.
Used as `@registry.register` decorator on class definition.
Additional utility functions:
- remove_unused_components -- used for simplifying a DictConfig instance.
- get_default_args_field -- default for DictConfig member of another configurable.
- enable_get_default_args -- Allows get_default_args on a function or plain class.
1. The simplest usage of this functionality is as follows. First a schema is defined
in dataclass style.
class A(Configurable):
n: int = 9
class B(Configurable):
a: A
def __post_init__(self):
run_auto_creation(self)
Then it can be used like
b_args = get_default_args(B)
b = B(**b_args)
In this case, get_default_args(B) returns an omegaconf.DictConfig with the right
members {"a_args": {"n": 9}}. It also modifies the definitions of the classes to
something like the following. (The modification itself is done by the function
`expand_args_fields`, which is called inside `get_default_args`.)
@dataclasses.dataclass
class A:
n: int = 9
@dataclasses.dataclass
class B:
a_args: DictConfig = dataclasses.field(default_factory=lambda: DictConfig({"n": 9}))
def __post_init__(self):
self.a = A(**self.a_args)
2. Pluggability. Instead of a dataclass-style member being given a concrete class,
it can be given a base class and the implementation will be looked up by name in the
global `registry` in this module. E.g.
class A(ReplaceableBase):
k: int = 1
@registry.register
class A1(A):
m: int = 3
@registry.register
class A2(A):
n: str = "2"
class B(Configurable):
a: A
a_class_type: str = "A2"
b: Optional[A]
b_class_type: Optional[str] = "A2"
def __post_init__(self):
run_auto_creation(self)
will expand to
@dataclasses.dataclass
class A:
k: int = 1
@dataclasses.dataclass
class A1(A):
m: int = 3
@dataclasses.dataclass
class A2(A):
n: str = "2"
@dataclasses.dataclass
class B:
a_class_type: str = "A2"
a_A1_args: DictConfig = dataclasses.field(
default_factory=lambda: DictConfig({"k": 1, "m": 3})
)
a_A2_args: DictConfig = dataclasses.field(
default_factory=lambda: DictConfig({"k": 1, "n": "2"})
)
b_class_type: Optional[str] = "A2"
b_A1_args: DictConfig = dataclasses.field(
default_factory=lambda: DictConfig({"k": 1, "m": 3})
)
b_A2_args: DictConfig = dataclasses.field(
default_factory=lambda: DictConfig({"k": 1, "n": "2"})
)
def __post_init__(self):
if self.a_class_type == "A1":
self.a = A1(**self.a_A1_args)
elif self.a_class_type == "A2":
self.a = A2(**self.a_A2_args)
else:
raise ValueError(...)
if self.b_class_type is None:
self.b = None
elif self.b_class_type == "A1":
self.b = A1(**self.b_A1_args)
elif self.b_class_type == "A2":
self.b = A2(**self.b_A2_args)
else:
raise ValueError(...)
3. Aside from these classes, the members of these classes should be things
which DictConfig is happy with: e.g. (bool, int, str, None, float) and what
can be built from them with `DictConfig`s and lists of them.
In addition, you can call `get_default_args` on a function or class to get
the `DictConfig` of its defaulted arguments, assuming those are all things
which `DictConfig` is happy with, so long as you add a call to
`enable_get_default_args` after its definition. If you want to use such a
thing as the default for a member of another configured class,
`get_default_args_field` is a helper.
"""
_unprocessed_warning: str = (
" must be processed before it can be used."
+ " This is done by calling expand_args_fields "
+ "or get_default_args on it."
)
TYPE_SUFFIX: str = "_class_type"
ARGS_SUFFIX: str = "_args"
ENABLED_SUFFIX: str = "_enabled"
class ReplaceableBase:
"""
Base class for dataclass-style classes which
can be stored in the registry.
"""
def __new__(cls, *args, **kwargs):
"""
This function only exists to raise a
warning if class construction is attempted
without processing.
"""
obj = super().__new__(cls)
if cls is not ReplaceableBase and not _is_actually_dataclass(cls):
warnings.warn(cls.__name__ + _unprocessed_warning)
return obj
class Configurable:
"""
This indicates a class which is not ReplaceableBase
but still needs to be
expanded into a dataclass with expand_args_fields.
This expansion is delayed.
"""
def __new__(cls, *args, **kwargs):
"""
This function only exists to raise a
warning if class construction is attempted
without processing.
"""
obj = super().__new__(cls)
if cls is not Configurable and not _is_actually_dataclass(cls):
warnings.warn(cls.__name__ + _unprocessed_warning)
return obj
_X = TypeVar("X", bound=ReplaceableBase)
class _Registry:
"""
Register from names to classes. In particular, we say that direct subclasses of
ReplaceableBase are "base classes" and we register subclasses of each base class
in a separate namespace.
"""
def __init__(self) -> None:
self._mapping: Dict[
Type[ReplaceableBase], Dict[str, Type[ReplaceableBase]]
] = defaultdict(dict)
def register(self, some_class: Type[_X]) -> Type[_X]:
"""
A class decorator, to register a class in self.
"""
name = some_class.__name__
self._register(some_class, name=name)
return some_class
def _register(
self,
some_class: Type[ReplaceableBase],
*,
base_class: Optional[Type[ReplaceableBase]] = None,
name: str,
) -> None:
"""
Register a new member.
Args:
cls: the new member
base_class: (optional) what the new member is a type for
name: name for the new member
"""
if base_class is None:
base_class = self._base_class_from_class(some_class)
if base_class is None:
raise ValueError(
f"Cannot register {some_class}. Cannot tell what it is."
)
if some_class is base_class:
raise ValueError(f"Attempted to register the base class {some_class}")
self._mapping[base_class][name] = some_class
def get(
self, base_class_wanted: Type[ReplaceableBase], name: str
) -> Type[ReplaceableBase]:
"""
Retrieve a class from the registry by name
Args:
base_class_wanted: parent type of type we are looking for.
It determines the namespace.
This will typically be a direct subclass of ReplaceableBase.
name: what to look for
Returns:
class type
"""
if self._is_base_class(base_class_wanted):
base_class = base_class_wanted
else:
base_class = self._base_class_from_class(base_class_wanted)
if base_class is None:
raise ValueError(
f"Cannot look up {base_class_wanted}. Cannot tell what it is."
)
result = self._mapping[base_class].get(name)
if result is None:
raise ValueError(f"{name} has not been registered.")
if not issubclass(result, base_class_wanted):
raise ValueError(
f"{name} resolves to {result} which does not subclass {base_class_wanted}"
)
return result
def get_all(
self, base_class_wanted: Type[ReplaceableBase]
) -> List[Type[ReplaceableBase]]:
"""
Retrieve all registered implementations from the registry
Args:
base_class_wanted: parent type of type we are looking for.
It determines the namespace.
This will typically be a direct subclass of ReplaceableBase.
Returns:
list of class types
"""
if self._is_base_class(base_class_wanted):
return list(self._mapping[base_class_wanted].values())
base_class = self._base_class_from_class(base_class_wanted)
if base_class is None:
raise ValueError(
f"Cannot look up {base_class_wanted}. Cannot tell what it is."
)
return [
class_
for class_ in self._mapping[base_class].values()
if issubclass(class_, base_class_wanted) and class_ is not base_class_wanted
]
@staticmethod
def _is_base_class(some_class: Type[ReplaceableBase]) -> bool:
"""
Return whether the given type is a direct subclass of ReplaceableBase
and so gets used as a namespace.
"""
return ReplaceableBase in some_class.__bases__
@staticmethod
def _base_class_from_class(
some_class: Type[ReplaceableBase],
) -> Optional[Type[ReplaceableBase]]:
"""
Find the parent class of some_class which inherits ReplaceableBase, or None
"""
for base in some_class.mro()[-3::-1]:
if base is not ReplaceableBase and issubclass(base, ReplaceableBase):
return base
return None
# Global instance of the registry
registry = _Registry()
class _ProcessType(Enum):
"""
Type of member which gets rewritten by expand_args_fields.
"""
CONFIGURABLE = 1
REPLACEABLE = 2
OPTIONAL_CONFIGURABLE = 3
OPTIONAL_REPLACEABLE = 4
def _default_create(
name: str, type_: Type, process_type: _ProcessType
) -> Callable[[Any], None]:
"""
Return the default creation function for a member. This is a function which
could be called in __post_init__ to initialise the member, and will be called
from run_auto_creation.
Args:
name: name of the member
type_: type of the member (with any Optional removed)
process_type: Shows whether member's declared type inherits ReplaceableBase,
in which case the actual type to be created is decided at
runtime.
Returns:
Function taking one argument, the object whose member should be
initialized.
"""
def inner(self):
expand_args_fields(type_)
args = getattr(self, name + ARGS_SUFFIX)
setattr(self, name, type_(**args))
def inner_optional(self):
expand_args_fields(type_)
enabled = getattr(self, name + ENABLED_SUFFIX)
if enabled:
args = getattr(self, name + ARGS_SUFFIX)
setattr(self, name, type_(**args))
else:
setattr(self, name, None)
def inner_pluggable(self):
type_name = getattr(self, name + TYPE_SUFFIX)
if type_name is None:
setattr(self, name, None)
return
chosen_class = registry.get(type_, type_name)
if self._known_implementations.get(type_name, chosen_class) is not chosen_class:
# If this warning is raised, it means that a new definition of
# the chosen class has been registered since our class was processed
# (i.e. expanded). A DictConfig which comes from our get_default_args
# (which might have triggered the processing) will contain the old default
# values for the members of the chosen class. Changes to those defaults which
# were made in the redefinition will not be reflected here.
warnings.warn(f"New implementation of {type_name} is being chosen.")
expand_args_fields(chosen_class)
args = getattr(self, f"{name}_{type_name}{ARGS_SUFFIX}")
setattr(self, name, chosen_class(**args))
if process_type == _ProcessType.OPTIONAL_CONFIGURABLE:
return inner_optional
return inner if process_type == _ProcessType.CONFIGURABLE else inner_pluggable
def run_auto_creation(self: Any) -> None:
"""
Run all the functions named in self._creation_functions.
"""
for create_function in self._creation_functions:
getattr(self, create_function)()
def _is_configurable_class(C) -> bool:
return isinstance(C, type) and issubclass(C, (Configurable, ReplaceableBase))
def get_default_args(C, *, _do_not_process: Tuple[type, ...] = ()) -> DictConfig:
"""
Get the DictConfig corresponding to the defaults in a dataclass or
configurable. Normal use is to provide a dataclass as C.
If enable_get_default_args has been called on a function or plain class,
then that function or class can be provided as C.
If C is a subclass of Configurable or ReplaceableBase, we make sure
it has been processed with expand_args_fields.
Args:
C: the class or function to be processed
_do_not_process: (internal use) When this function is called from
expand_args_fields, we specify any class currently being
processed, to make sure we don't try to process a class
while it is already being processed.
Returns:
new DictConfig object, which is typed.
"""
if C is None:
return DictConfig({})
if _is_configurable_class(C):
if C in _do_not_process:
raise ValueError(
f"Internal recursion error. Need processed {C},"
f" but cannot get it. _do_not_process={_do_not_process}"
)
# This is safe to run multiple times. It will return
# straight away if C has already been processed.
expand_args_fields(C, _do_not_process=_do_not_process)
if dataclasses.is_dataclass(C):
# Note that if get_default_args_field is used somewhere in C,
# this call is recursive. No special care is needed,
# because in practice get_default_args_field is used for
# separate types than the outer type.
out: DictConfig = OmegaConf.structured(C)
exclude = getattr(C, "_processed_members", ())
with open_dict(out):
for field in exclude:
out.pop(field, None)
return out
if _is_configurable_class(C):
raise ValueError(f"Failed to process {C}")
if not inspect.isfunction(C) and not inspect.isclass(C):
raise ValueError(f"Unexpected {C}")
dataclass_name = _dataclass_name_for_function(C)
dataclass = getattr(sys.modules[C.__module__], dataclass_name, None)
if dataclass is None:
raise ValueError(
f"Cannot get args for {C}. Was enable_get_default_args forgotten?"
)
return OmegaConf.structured(dataclass)
def _dataclass_name_for_function(C: Any) -> str:
"""
Returns the name of the dataclass which enable_get_default_args(C)
creates.
"""
name = f"_{C.__name__}_default_args_"
return name
def enable_get_default_args(C: Any, *, overwrite: bool = True) -> None:
"""
If C is a function or a plain class with an __init__ function,
and you want get_default_args(C) to work, then add
`enable_get_default_args(C)` straight after the definition of C.
This makes a dataclass corresponding to the default arguments of C
and stores it in the same module as C.
Args:
C: a function, or a class with an __init__ function. Must
have types for all its defaulted args.
overwrite: whether to allow calling this a second time on
the same function.
"""
if not inspect.isfunction(C) and not inspect.isclass(C):
raise ValueError(f"Unexpected {C}")
field_annotations = []
for pname, defval in _params_iter(C):
default = defval.default
if default == inspect.Parameter.empty:
# we do not have a default value for the parameter
continue
if defval.annotation == inspect._empty:
raise ValueError(
"All arguments of the input callable have to be typed."
+ f" Argument '{pname}' does not have a type annotation."
)
_, annotation = _resolve_optional(defval.annotation)
if isinstance(default, set): # force OmegaConf to convert it to ListConfig
default = tuple(default)
if isinstance(default, (list, dict)):
# OmegaConf will convert to [Dict|List]Config, so it is safe to reuse the value
field_ = dataclasses.field(default_factory=lambda default=default: default)
elif not _is_immutable_type(annotation, default):
continue
else:
# we can use a simple default argument for dataclass.field
field_ = dataclasses.field(default=default)
field_annotations.append((pname, defval.annotation, field_))
name = _dataclass_name_for_function(C)
module = sys.modules[C.__module__]
if hasattr(module, name):
if overwrite:
warnings.warn(f"Overwriting {name} in {C.__module__}.")
else:
raise ValueError(f"Cannot overwrite {name} in {C.__module__}.")
dc = dataclasses.make_dataclass(name, field_annotations)
dc.__module__ = C.__module__
setattr(module, name, dc)
def _params_iter(C):
"""Returns dict of keyword args of a class or function C."""
if inspect.isclass(C):
return itertools.islice( # exclude `self`
inspect.signature(C.__init__).parameters.items(), 1, None
)
return inspect.signature(C).parameters.items()
def _is_immutable_type(type_: Type, val: Any) -> bool:
PRIMITIVE_TYPES = (int, float, bool, str, bytes, tuple)
# sometimes type can be too relaxed (e.g. Any), so we also check values
if isinstance(val, PRIMITIVE_TYPES):
return True
return type_ in PRIMITIVE_TYPES or (
inspect.isclass(type_) and issubclass(type_, Enum)
)
# copied from OmegaConf
def _resolve_optional(type_: Any) -> Tuple[bool, Any]:
"""Check whether `type_` is equivalent to `typing.Optional[T]` for some T."""
if get_origin(type_) is Union:
args = get_args(type_)
if len(args) == 2 and args[1] == type(None): # noqa E721
return True, args[0]
if type_ is Any:
return True, Any
return False, type_
def _is_actually_dataclass(some_class) -> bool:
# Return whether the class some_class has been processed with
# the dataclass annotation. This is more specific than
# dataclasses.is_dataclass which returns True on anything
# deriving from a dataclass.
# Checking for __init__ would also work for our purpose.
return "__dataclass_fields__" in some_class.__dict__
def expand_args_fields(
some_class: Type[_X], *, _do_not_process: Tuple[type, ...] = ()
) -> Type[_X]:
"""
This expands a class which inherits Configurable or ReplaceableBase classes,
including dataclass processing. some_class is modified in place by this function.
For classes of type ReplaceableBase, you can add some_class to the registry before
or after calling this function. But potential inner classes need to be registered
before this function is run on the outer class.
The transformations this function makes, before the concluding
dataclasses.dataclass, are as follows. if X is a base class with registered
subclasses Y and Z, replace a class member
x: X
and optionally
x_class_type: str = "Y"
def create_x(self):...
with
x_Y_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(Y))
x_Z_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(Z))
def create_x(self):
self.x = registry.get(X, self.x_class_type)(
**getattr(self, f"x_{self.x_class_type}_args")
)
x_class_type: str = "UNDEFAULTED"
without adding the optional attributes if they are already there.
Similarly, replace
x: Optional[X]
and optionally
x_class_type: Optional[str] = "Y"
def create_x(self):...
with
x_Y_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(Y))
x_Z_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(Z))
def create_x(self):
if self.x_class_type is None:
self.x = None
return
self.x = registry.get(X, self.x_class_type)(
**getattr(self, f"x_{self.x_class_type}_args")
)
x_class_type: Optional[str] = "UNDEFAULTED"
without adding the optional attributes if they are already there.
Similarly, if X is a subclass of Configurable,
x: X
and optionally
def create_x(self):...
will be replaced with
x_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(X))
def create_x(self):
self.x = X(**self.x_args)
Similarly, replace,
x: Optional[X]
and optionally
def create_x(self):...
x_enabled: bool = ...
with
x_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(X))
x_enabled: bool = False
def create_x(self):
if self.x_enabled:
self.x = X(**self.x_args)
else:
self.x = None
Also adds the following class members, unannotated so that dataclass
ignores them.
- _creation_functions: Tuple[str] of all the create_ functions,
including those from base classes.
- _known_implementations: Dict[str, Type] containing the classes which
have been found from the registry.
(used only to raise a warning if one has been overwritten)
- _processed_members: a Dict[str, Any] of all the members which have been
transformed, with values giving the types they were declared to have.
(E.g. {"x": X} or {"x": Optional[X]} in the cases above.)
Args:
some_class: the class to be processed
_do_not_process: Internal use for get_default_args: Because get_default_args calls
and is called by this function, we let it specify any class currently
being processed, to make sure we don't try to process a class while
it is already being processed.
Returns:
some_class itself, which has been modified in place. This
allows this function to be used as a class decorator.
"""
if _is_actually_dataclass(some_class):
return some_class
# The functions this class's run_auto_creation will run.
creation_functions: List[str] = []
# The classes which this type knows about from the registry
# We could use a weakref.WeakValueDictionary here which would mean
# that we don't warn if the class we should have expected is elsewhere
# unused.
known_implementations: Dict[str, Type] = {}
# Names of members which have been processed.
processed_members: Dict[str, Any] = {}
# For all bases except ReplaceableBase and Configurable and object,
# we need to process them before our own processing. This is
# because dataclasses expect to inherit dataclasses and not unprocessed
# dataclasses.
for base in some_class.mro()[-3:0:-1]:
if base is ReplaceableBase:
continue
if base is Configurable:
continue
if not issubclass(base, (Configurable, ReplaceableBase)):
continue
expand_args_fields(base, _do_not_process=_do_not_process)
if "_creation_functions" in base.__dict__:
creation_functions.extend(base._creation_functions)
if "_known_implementations" in base.__dict__:
known_implementations.update(base._known_implementations)
if "_processed_members" in base.__dict__:
processed_members.update(base._processed_members)
to_process: List[Tuple[str, Type, _ProcessType]] = []
if "__annotations__" in some_class.__dict__:
for name, type_ in some_class.__annotations__.items():
underlying_and_process_type = _get_type_to_process(type_)
if underlying_and_process_type is None:
continue
underlying_type, process_type = underlying_and_process_type
to_process.append((name, underlying_type, process_type))
for name, underlying_type, process_type in to_process:
processed_members[name] = some_class.__annotations__[name]
_process_member(
name=name,
type_=underlying_type,
process_type=process_type,
some_class=some_class,
creation_functions=creation_functions,
_do_not_process=_do_not_process,
known_implementations=known_implementations,
)
for key, count in Counter(creation_functions).items():
if count > 1:
warnings.warn(f"Clash with {key} in a base class.")
some_class._creation_functions = tuple(creation_functions)
some_class._processed_members = processed_members
some_class._known_implementations = known_implementations
dataclasses.dataclass(eq=False)(some_class)
return some_class
def get_default_args_field(C, *, _do_not_process: Tuple[type, ...] = ()):
"""
Get a dataclass field which defaults to get_default_args(...)
Args:
As for get_default_args.
Returns:
function to return new DictConfig object
"""
def create():
return get_default_args(C, _do_not_process=_do_not_process)
return dataclasses.field(default_factory=create)
def _get_type_to_process(type_) -> Optional[Tuple[Type, _ProcessType]]:
"""
If a member is annotated as `type_`, and that should be expanded in
expand_args_fields, return how it should be expanded.
"""
if get_origin(type_) == Union:
# We look for Optional[X] which is a Union of X with None.
args = get_args(type_)
if len(args) != 2 or all(a is not type(None) for a in args): # noqa: E721
return
underlying = args[0] if args[1] is type(None) else args[1] # noqa: E721
if (
isinstance(underlying, type)
and issubclass(underlying, ReplaceableBase)
and ReplaceableBase in underlying.__bases__
):
return underlying, _ProcessType.OPTIONAL_REPLACEABLE
if isinstance(underlying, type) and issubclass(underlying, Configurable):
return underlying, _ProcessType.OPTIONAL_CONFIGURABLE
if not isinstance(type_, type):
# e.g. any other Union or Tuple
return
if issubclass(type_, ReplaceableBase) and ReplaceableBase in type_.__bases__:
return type_, _ProcessType.REPLACEABLE
if issubclass(type_, Configurable):
return type_, _ProcessType.CONFIGURABLE
def _process_member(
*,
name: str,
type_: Type,
process_type: _ProcessType,
some_class: Type,
creation_functions: List[str],
_do_not_process: Tuple[type, ...],
known_implementations: Dict[str, Type],
) -> None:
"""
Make the modification (of expand_args_fields) to some_class for a single member.
Args:
name: member name
type_: member type (with Optional removed if needed)
process_type: whether member has dynamic type
some_class: (MODIFIED IN PLACE) the class being processed
creation_functions: (MODIFIED IN PLACE) the names of the create functions
_do_not_process: as for expand_args_fields.
known_implementations: (MODIFIED IN PLACE) known types from the registry
"""
# Because we are adding defaultable members, make
# sure they go at the end of __annotations__ in case
# there are non-defaulted standard class members.
del some_class.__annotations__[name]
if process_type in (_ProcessType.REPLACEABLE, _ProcessType.OPTIONAL_REPLACEABLE):
type_name = name + TYPE_SUFFIX
if type_name not in some_class.__annotations__:
if process_type == _ProcessType.OPTIONAL_REPLACEABLE:
some_class.__annotations__[type_name] = Optional[str]
else:
some_class.__annotations__[type_name] = str
setattr(some_class, type_name, "UNDEFAULTED")
for derived_type in registry.get_all(type_):
if derived_type in _do_not_process:
continue
if issubclass(derived_type, some_class):
# When derived_type is some_class we have a simple
# recursion to avoid. When it's a strict subclass the
# situation is even worse.
continue
known_implementations[derived_type.__name__] = derived_type
args_name = f"{name}_{derived_type.__name__}{ARGS_SUFFIX}"
if args_name in some_class.__annotations__:
raise ValueError(
f"Cannot generate {args_name} because it is already present."
)
some_class.__annotations__[args_name] = DictConfig
setattr(
some_class,
args_name,
get_default_args_field(
derived_type, _do_not_process=_do_not_process + (some_class,)
),
)
else:
args_name = name + ARGS_SUFFIX
if args_name in some_class.__annotations__:
raise ValueError(
f"Cannot generate {args_name} because it is already present."
)
if issubclass(type_, some_class) or type_ in _do_not_process:
raise ValueError(f"Cannot process {type_} inside {some_class}")
some_class.__annotations__[args_name] = DictConfig
setattr(
some_class,
args_name,
get_default_args_field(
type_,
_do_not_process=_do_not_process + (some_class,),
),
)
if process_type == _ProcessType.OPTIONAL_CONFIGURABLE:
enabled_name = name + ENABLED_SUFFIX
if enabled_name not in some_class.__annotations__:
some_class.__annotations__[enabled_name] = bool
setattr(some_class, enabled_name, False)
creation_function_name = f"create_{name}"
if not hasattr(some_class, creation_function_name):
setattr(
some_class,
creation_function_name,
_default_create(name, type_, process_type),
)
creation_functions.append(creation_function_name)
def remove_unused_components(dict_: DictConfig) -> None:
"""
Assuming dict_ represents the state of a configurable,
modify it to remove all the portions corresponding to
pluggable parts which are not in use.
For example, if renderer_class_type is SignedDistanceFunctionRenderer,
the renderer_MultiPassEmissionAbsorptionRenderer_args will be
removed. Also, if chocolate_enabled is False, then chocolate_args will
be removed.
Args:
dict_: (MODIFIED IN PLACE) a DictConfig instance
"""
keys = [key for key in dict_ if isinstance(key, str)]
suffix_length = len(TYPE_SUFFIX)
replaceables = [key[:-suffix_length] for key in keys if key.endswith(TYPE_SUFFIX)]
args_keys = [key for key in keys if key.endswith(ARGS_SUFFIX)]
for replaceable in replaceables:
selected_type = dict_[replaceable + TYPE_SUFFIX]
if selected_type is None:
expect = ""
else:
expect = replaceable + "_" + selected_type + ARGS_SUFFIX
with open_dict(dict_):
for key in args_keys:
if key.startswith(replaceable + "_") and key != expect:
del dict_[key]
suffix_length = len(ENABLED_SUFFIX)
enableables = [key[:-suffix_length] for key in keys if key.endswith(ENABLED_SUFFIX)]
for enableable in enableables:
enabled = dict_[enableable + ENABLED_SUFFIX]
if not enabled:
with open_dict(dict_):
dict_.pop(enableable + ARGS_SUFFIX, None)
for key in dict_:
if isinstance(dict_.get(key), DictConfig):
remove_unused_components(dict_[key])
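
# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original file). The class names
# and values below are made up; the sketch assumes the Configurable,
# expand_args_fields, get_default_args and run_auto_creation definitions
# provided by this module, plus omegaconf's DictConfig imported above.
if __name__ == "__main__":
    class Inner(Configurable):
        width: int = 64

    class Outer(Configurable):
        inner: Inner

        def __post_init__(self):
            run_auto_creation(self)  # calls the generated create_inner()

    expand_args_fields(Outer)      # `inner: Inner` is replaced by `inner_args: DictConfig`
    cfg = get_default_args(Outer)  # defaults: {"inner_args": {"width": 64}}
    cfg.inner_args.width = 128     # nested defaults can be overridden before construction
    outer = Outer(**cfg)
    assert isinstance(outer.inner, Inner)  # built by create_inner from inner_args

    # remove_unused_components drops the *_args entries of implementations
    # that are not selected and of disabled optional members.
    cfg2 = DictConfig(
        {
            "renderer_class_type": "SignedDistanceFunctionRenderer",
            "renderer_SignedDistanceFunctionRenderer_args": {"steps": 8},
            "renderer_MultiPassEmissionAbsorptionRenderer_args": {"passes": 2},
            "chocolate_enabled": False,
            "chocolate_args": {"flavour": "dark"},
        }
    )
    remove_unused_components(cfg2)
    assert "renderer_MultiPassEmissionAbsorptionRenderer_args" not in cfg2
    assert "chocolate_args" not in cfg2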
================================================
FILE: models/core/utils/utils.py
================================================
import torch
import torch.nn.functional as F
import numpy as np
from scipy import interpolate
def interp(tensor, size):
return F.interpolate(
tensor,
size=size,
mode="bilinear",
align_corners=True,
)
class InputPadder:
"""Pads images such that dimensions are divisible by 8"""
def __init__(self, dims, mode="sintel", divis_by=8):
self.ht, self.wd = dims[-2:]
pad_ht = (((self.ht // divis_by) + 1) * divis_by - self.ht) % divis_by
pad_wd = (((self.wd // divis_by) + 1) * divis_by - self.wd) % divis_by
if mode == "sintel":
self._pad = [
pad_wd // 2,
pad_wd - pad_wd // 2,
pad_ht // 2,
pad_ht - pad_ht // 2,
]
else:
self._pad = [pad_wd // 2, pad_wd - pad_wd // 2, 0, pad_ht]
def pad(self, *inputs):
assert all((x.ndim == 4) for x in inputs)
return [F.pad(x, self._pad, mode="replicate") for x in inputs]
def unpad(self, x):
assert x.ndim == 4
ht, wd = x.shape[-2:]
c = [self._pad[2], ht - self._pad[3], self._pad[0], wd - self._pad[1]]
return x[..., c[0] : c[1], c[2] : c[3]]
def coords_grid(batch, ht, wd):
coords = torch.meshgrid(torch.arange(ht), torch.arange(wd))
coords = torch.stack(coords[::-1], dim=0).float()
return coords[None].repeat(batch, 1, 1, 1)
def upflow8(flow, mode='bilinear'):
new_size = (8 * flow.shape[2], 8 * flow.shape[3])
return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True)
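
# Illustrative round-trip sketch (not part of the original file); the tensor
# sizes are arbitrary and the "prediction" is a stand-in for a network output.
if __name__ == "__main__":
    left = torch.randn(1, 3, 375, 1242)
    right = torch.randn(1, 3, 375, 1242)
    padder = InputPadder(left.shape, divis_by=8)
    left_p, right_p = padder.pad(left, right)      # padded to (1, 3, 376, 1248)
    pred = torch.randn(1, 1, *left_p.shape[-2:])   # e.g. a disparity map
    pred = padder.unpad(pred)                      # cropped back to (1, 1, 375, 1242)
    assert pred.shape[-2:] == (375, 1242)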
================================================
FILE: models/raft_model.py
================================================
from types import SimpleNamespace
from typing import ClassVar
import torch.nn.functional as F
from pytorch3d.implicitron.tools.config import Configurable
import torch
import importlib
import sys
import os
autocast = torch.cuda.amp.autocast
class RAFTModel(Configurable, torch.nn.Module):
MODEL_CONFIG_NAME: ClassVar[str] = "RAFTModel"
def __post_init__(self):
super().__init__()
thirdparty_raft_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../third_party/RAFT"))
sys.path.append(thirdparty_raft_path)
raft = importlib.import_module(
"stereoanyvideo.third_party.RAFT.core.raft"
)
self.raft_utils = importlib.import_module(
"stereoanyvideo.third_party.RAFT.core.utils.utils"
)
self.model_weights: str = "./third_party/RAFT/models/raft-things.pth"
model_args = SimpleNamespace(
mixed_precision=False,
small=False,
dropout=0.0,
)
self.args = model_args
self.model = raft.RAFT(model_args).cuda()
state_dict = torch.load(self.model_weights, map_location="cpu")
weight_dict = {}
for k,v in state_dict.items():
temp_k = k.replace('module.', '') if 'module' in k else k
weight_dict[temp_k] = v
self.model.load_state_dict(weight_dict, strict=True)
def forward(self, image1, image2, iters=10):
left_image_rgb = image1.cuda()
right_image_rgb = image2.cuda()
padder = self.raft_utils.InputPadder(left_image_rgb.shape)
left_image_rgb, right_image_rgb = padder.pad(
left_image_rgb, right_image_rgb
)
with autocast(enabled=self.args.mixed_precision):
flow, flow_up = self.model(left_image_rgb, right_image_rgb, iters=iters, test_mode=True)
flow_up = padder.unpad(flow_up)
return 0.25 * F.interpolate(flow_up, size=(flow_up.shape[2] // 4, flow_up.shape[3] // 4), mode="bilinear",
align_corners=True)
def forward_fullres(self, image1, image2, iters=20):
left_image_rgb = image1.cuda()
right_image_rgb = image2.cuda()
padder = self.raft_utils.InputPadder(left_image_rgb.shape)
left_image_rgb, right_image_rgb = padder.pad(
left_image_rgb, right_image_rgb
)
with autocast(enabled=self.args.mixed_precision):
flow, flow_up = self.model(left_image_rgb.contiguous(), right_image_rgb.contiguous(), iters=iters, test_mode=True)
flow_up = padder.unpad(flow_up)
return flow_up
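
# Hypothetical usage sketch (not part of the original file): it assumes the
# repository is importable as the `stereoanyvideo` package, a CUDA device is
# available, and the RAFT checkpoint referenced above exists on disk.
if __name__ == "__main__":
    from pytorch3d.implicitron.tools.config import expand_args_fields

    expand_args_fields(RAFTModel)      # process the Configurable before instantiation
    model = RAFTModel()                # __post_init__ builds RAFT and loads the weights
    left = torch.rand(1, 3, 256, 512) * 255.0
    right = torch.rand(1, 3, 256, 512) * 255.0
    with torch.no_grad():
        flow_quarter = model(left, right, iters=10)      # ~1/4 resolution, scaled by 0.25
        flow_full = model.forward_fullres(left, right)   # full resolution
    print(flow_quarter.shape, flow_full.shape)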
================================================
FILE: models/stereoanyvideo_model.py
================================================
from typing import ClassVar
import torch
import torch.nn.functional as F
from pytorch3d.implicitron.tools.config import Configurable
from stereoanyvideo.models.core.stereoanyvideo import StereoAnyVideo
class StereoAnyVideoModel(Configurable, torch.nn.Module):
MODEL_CONFIG_NAME: ClassVar[str] = "StereoAnyVideoModel"
model_weights: str = "./checkpoints/StereoAnyVideo_MIX.pth"
def __post_init__(self):
super().__init__()
self.mixed_precision = False
model = StereoAnyVideo(mixed_precision=self.mixed_precision)
state_dict = torch.load(self.model_weights, map_location="cpu")
if "model" in state_dict:
state_dict = state_dict["model"]
if "state_dict" in state_dict:
state_dict = state_dict["state_dict"]
state_dict = {"module." + k: v for k, v in state_dict.items()}
model.load_state_dict(state_dict, strict=True)
self.model = model
self.model.to("cuda")
self.model.eval()
def forward(self, batch_dict, iters=20):
return self.model.forward_batch_test(batch_dict, iters=iters)
================================================
FILE: requirements.txt
================================================
hydra-core==1.1
numpy==1.23.5
munch==2.5.0
omegaconf==2.1.0
flow_vis==0.1
einops==0.4.1
opt_einsum==3.3.0
requests
moviepy
================================================
FILE: third_party/RAFT/LICENSE
================================================
BSD 3-Clause License
Copyright (c) 2020, princeton-vl
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
* Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
================================================
FILE: third_party/RAFT/README.md
================================================
# RAFT
This repository contains the source code for our paper:
[RAFT: Recurrent All Pairs Field Transforms for Optical Flow](https://arxiv.org/pdf/2003.12039.pdf)
ECCV 2020
Zachary Teed and Jia Deng
## Requirements
The code has been tested with PyTorch 1.6 and CUDA 10.1.
```Shell
conda create --name raft
conda activate raft
conda install pytorch=1.6.0 torchvision=0.7.0 cudatoolkit=10.1 matplotlib tensorboard scipy opencv -c pytorch
```
## Demos
Pretrained models can be downloaded by running
```Shell
./download_models.sh
```
or downloaded from [google drive](https://drive.google.com/drive/folders/1sWDsfuZ3Up38EUQt7-JDTT1HcGHuJgvT?usp=sharing)
You can demo a trained model on a sequence of frames
```Shell
python demo.py --model=models/raft-things.pth --path=demo-frames
```
## Required Data
To evaluate/train RAFT, you will need to download the required datasets.
* [FlyingChairs](https://lmb.informatik.uni-freiburg.de/resources/datasets/FlyingChairs.en.html#flyingchairs)
* [FlyingThings3D](https://lmb.informatik.uni-freiburg.de/resources/datasets/SceneFlowDatasets.en.html)
* [Sintel](http://sintel.is.tue.mpg.de/)
* [KITTI](http://www.cvlibs.net/datasets/kitti/eval_scene_flow.php?benchmark=flow)
* [HD1K](http://hci-benchmark.iwr.uni-heidelberg.de/) (optional)
By default `datasets.py` will search for the datasets in these locations. You can create symbolic links to wherever the datasets were downloaded in the `datasets` folder
```Shell
├── datasets
├── Sintel
├── test
├── training
├── KITTI
├── testing
├── training
├── devkit
├── FlyingChairs_release
├── data
├── FlyingThings3D
├── frames_cleanpass
├── frames_finalpass
├── optical_flow
```
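For example, the symbolic links can be created as follows (the source paths are placeholders; point them at your own download locations):
```Shell
mkdir -p datasets
ln -s /path/to/Sintel datasets/Sintel
ln -s /path/to/KITTI datasets/KITTI
ln -s /path/to/FlyingChairs_release datasets/FlyingChairs_release
ln -s /path/to/FlyingThings3D datasets/FlyingThings3D
```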
## Evaluation
You can evaluate a trained model using `evaluate.py`
```Shell
python evaluate.py --model=models/raft-things.pth --dataset=sintel --mixed_precision
```
## Training
We used the following training schedule in our paper (2 GPUs). Training logs will be written to the `runs` directory, which can be visualized using tensorboard
```Shell
./train_standard.sh
```
If you have an RTX GPU, training can be accelerated using mixed precision. You can expect similar results in this setting (1 GPU)
```Shell
./train_mixed.sh
```
## (Optional) Efficient Implementation
You can optionally use our alternate (efficient) implementation by compiling the provided CUDA extension
```Shell
cd alt_cuda_corr && python setup.py install && cd ..
```
and running `demo.py` and `evaluate.py` with the `--alternate_corr` flag. Note that this implementation is somewhat slower than all-pairs, but uses significantly less GPU memory during the forward pass.
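For example, after building the extension, the demo and evaluation commands above become (illustrative only, reusing the same model and paths as earlier):
```Shell
python demo.py --model=models/raft-things.pth --path=demo-frames --alternate_corr
python evaluate.py --model=models/raft-things.pth --dataset=sintel --mixed_precision --alternate_corr
```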
================================================
FILE: third_party/RAFT/alt_cuda_corr/correlation.cpp
================================================
#include <torch/extension.h>
#include <vector>
// CUDA forward declarations
std::vector<torch::Tensor> corr_cuda_forward(
torch::Tensor fmap1,
torch::Tensor fmap2,
torch::Tensor coords,
int radius);
std::vector<torch::Tensor> corr_cuda_backward(
torch::Tensor fmap1,
torch::Tensor fmap2,
torch::Tensor coords,
torch::Tensor corr_grad,
int radius);
// C++ interface
#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
std::vector<torch::Tensor> corr_forward(
torch::Tensor fmap1,
torch::Tensor fmap2,
torch::Tensor coords,
int radius) {
CHECK_INPUT(fmap1);
CHECK_INPUT(fmap2);
CHECK_INPUT(coords);
return corr_cuda_forward(fmap1, fmap2, coords, radius);
}
std::vector<torch::Tensor> corr_backward(
torch::Tensor fmap1,
torch::Tensor fmap2,
torch::Tensor coords,
torch::Tensor corr_grad,
int radius) {
CHECK_INPUT(fmap1);
CHECK_INPUT(fmap2);
CHECK_INPUT(coords);
CHECK_INPUT(corr_grad);
return corr_cuda_backward(fmap1, fmap2, coords, corr_grad, radius);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &corr_forward, "CORR forward");
m.def("backward", &corr_backward, "CORR backward");
}
================================================
FILE: third_party/RAFT/alt_cuda_corr/correlation_kernel.cu
================================================
#include <torch/extension.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <vector>
#define BLOCK_H 4
#define BLOCK_W 8
#define BLOCK_HW BLOCK_H * BLOCK_W
#define CHANNEL_STRIDE 32
__forceinline__ __device__
bool within_bounds(int h, int w, int H, int W) {
return h >= 0 && h < H && w >= 0 && w < W;
}
template <typename scalar_t>
__global__ void corr_forward_kernel(
const torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> fmap1,
const torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> fmap2,
const torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> coords,
torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> corr,
int r)
{
const int b = blockIdx.x;
const int h0 = blockIdx.y * blockDim.x;
const int w0 = blockIdx.z * blockDim.y;
const int tid = threadIdx.x * blockDim.y + threadIdx.y;
const int H1 = fmap1.size(1);
const int W1 = fmap1.size(2);
const int H2 = fmap2.size(1);
const int W2 = fmap2.size(2);
const int N = coords.size(1);
const int C = fmap1.size(3);
__shared__ scalar_t f1[CHANNEL_STRIDE][BLOCK_HW+1];
__shared__ scalar_t f2[CHANNEL_STRIDE][BLOCK_HW+1];
__shared__ scalar_t x2s[BLOCK_HW];
__shared__ scalar_t y2s[BLOCK_HW];
for (int c=0; c(floor(y2s[k1]))-r+iy;
int w2 = static_cast(floor(x2s[k1]))-r+ix;
int c2 = tid % CHANNEL_STRIDE;
auto fptr = fmap2[b][h2][w2];
if (within_bounds(h2, w2, H2, W2))
f2[c2][k1] = fptr[c+c2];
else
f2[c2][k1] = 0.0;
}
__syncthreads();
scalar_t s = 0.0;
for (int k=0; k 0 && ix > 0 && within_bounds(h1, w1, H1, W1))
*(corr_ptr + ix_nw) += nw;
if (iy > 0 && ix < rd && within_bounds(h1, w1, H1, W1))
*(corr_ptr + ix_ne) += ne;
if (iy < rd && ix > 0 && within_bounds(h1, w1, H1, W1))
*(corr_ptr + ix_sw) += sw;
if (iy < rd && ix < rd && within_bounds(h1, w1, H1, W1))
*(corr_ptr + ix_se) += se;
}
}
}
}
}
template <typename scalar_t>
__global__ void corr_backward_kernel(
const torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> fmap1,
const torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> fmap2,
const torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> coords,
const torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> corr_grad,
torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> fmap1_grad,
torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> fmap2_grad,
torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> coords_grad,
int r)
{
const int b = blockIdx.x;
const int h0 = blockIdx.y * blockDim.x;
const int w0 = blockIdx.z * blockDim.y;
const int tid = threadIdx.x * blockDim.y + threadIdx.y;
const int H1 = fmap1.size(1);
const int W1 = fmap1.size(2);
const int H2 = fmap2.size(1);
const int W2 = fmap2.size(2);
const int N = coords.size(1);
const int C = fmap1.size(3);
__shared__ scalar_t f1[CHANNEL_STRIDE][BLOCK_HW+1];
__shared__ scalar_t f2[CHANNEL_STRIDE][BLOCK_HW+1];
__shared__ scalar_t f1_grad[CHANNEL_STRIDE][BLOCK_HW+1];
__shared__ scalar_t f2_grad[CHANNEL_STRIDE][BLOCK_HW+1];
__shared__ scalar_t x2s[BLOCK_HW];
__shared__ scalar_t y2s[BLOCK_HW];
for (int c=0; c(floor(y2s[k1]))-r+iy;
int w2 = static_cast(floor(x2s[k1]))-r+ix;
int c2 = tid % CHANNEL_STRIDE;
auto fptr = fmap2[b][h2][w2];
if (within_bounds(h2, w2, H2, W2))
f2[c2][k1] = fptr[c+c2];
else
f2[c2][k1] = 0.0;
f2_grad[c2][k1] = 0.0;
}
__syncthreads();
const scalar_t* grad_ptr = &corr_grad[b][n][0][h1][w1];
scalar_t g = 0.0;
int ix_nw = H1*W1*((iy-1) + rd*(ix-1));
int ix_ne = H1*W1*((iy-1) + rd*ix);
int ix_sw = H1*W1*(iy + rd*(ix-1));
int ix_se = H1*W1*(iy + rd*ix);
if (iy > 0 && ix > 0 && within_bounds(h1, w1, H1, W1))
g += *(grad_ptr + ix_nw) * dy * dx;
if (iy > 0 && ix < rd && within_bounds(h1, w1, H1, W1))
g += *(grad_ptr + ix_ne) * dy * (1-dx);
if (iy < rd && ix > 0 && within_bounds(h1, w1, H1, W1))
g += *(grad_ptr + ix_sw) * (1-dy) * dx;
if (iy < rd && ix < rd && within_bounds(h1, w1, H1, W1))
g += *(grad_ptr + ix_se) * (1-dy) * (1-dx);
for (int k=0; k(floor(y2s[k1]))-r+iy;
int w2 = static_cast(floor(x2s[k1]))-r+ix;
int c2 = tid % CHANNEL_STRIDE;
scalar_t* fptr = &fmap2_grad[b][h2][w2][0];
if (within_bounds(h2, w2, H2, W2))
atomicAdd(fptr+c+c2, f2_grad[c2][k1]);
}
}
}
}
__syncthreads();
for (int k=0; k corr_cuda_forward(
torch::Tensor fmap1,
torch::Tensor fmap2,
torch::Tensor coords,
int radius)
{
const auto B = coords.size(0);
const auto N = coords.size(1);
const auto H = coords.size(2);
const auto W = coords.size(3);
const auto rd = 2 * radius + 1;
auto opts = fmap1.options();
auto corr = torch::zeros({B, N, rd*rd, H, W}, opts);
const dim3 blocks(B, (H+BLOCK_H-1)/BLOCK_H, (W+BLOCK_W-1)/BLOCK_W);
const dim3 threads(BLOCK_H, BLOCK_W);
corr_forward_kernel<<>>(
fmap1.packed_accessor32(),
fmap2.packed_accessor32(),
coords.packed_accessor32(),
corr.packed_accessor32(),
radius);
return {corr};
}
std::vector<torch::Tensor> corr_cuda_backward(
torch::Tensor fmap1,
torch::Tensor fmap2,
torch::Tensor coords,
torch::Tensor corr_grad,
int radius)
{
const auto B = coords.size(0);
const auto N = coords.size(1);
const auto H1 = fmap1.size(1);
const auto W1 = fmap1.size(2);
const auto H2 = fmap2.size(1);
const auto W2 = fmap2.size(2);
const auto C = fmap1.size(3);
auto opts = fmap1.options();
auto fmap1_grad = torch::zeros({B, H1, W1, C}, opts);
auto fmap2_grad = torch::zeros({B, H2, W2, C}, opts);
auto coords_grad = torch::zeros({B, N, H1, W1, 2}, opts);
const dim3 blocks(B, (H1+BLOCK_H-1)/BLOCK_H, (W1+BLOCK_W-1)/BLOCK_W);
const dim3 threads(BLOCK_H, BLOCK_W);
corr_backward_kernel<<>>(
fmap1.packed_accessor32(),
fmap2.packed_accessor32(),
coords.packed_accessor32(),
corr_grad.packed_accessor32(),
fmap1_grad.packed_accessor32(),
fmap2_grad.packed_accessor32(),
coords_grad.packed_accessor32(),
radius);
return {fmap1_grad, fmap2_grad, coords_grad};
}
================================================
FILE: third_party/RAFT/alt_cuda_corr/setup.py
================================================
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
setup(
name='correlation',
ext_modules=[
CUDAExtension('alt_cuda_corr',
sources=['correlation.cpp', 'correlation_kernel.cu'],
extra_compile_args={'cxx': [], 'nvcc': ['-O3']}),
],
cmdclass={
'build_ext': BuildExtension
})
================================================
FILE: third_party/RAFT/chairs_split.txt
================================================
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
2
1
1
2
2
1
1
1
1
1
1
1
1
1
2
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
2
1
2
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
2
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
2
1
2
1
1
1
1
2
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
2
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
2
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
2
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
2
1
1
2
2
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
2
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
2
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
2
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
2
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
2
1
1
2
2
1
1
1
1
2
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
2
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
2
1
1
1
1
1
1
1
1
2
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
2
1
1
1
1
1
1
1
2
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
2
1
1
1
1
2
1
1
1
1
2
1
1
2
1
1
1
1
2
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
2
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
2
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
2
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
2
1
1
1
1
1
1
2
1
1
2
1
1
2
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
2
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
2
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
2
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
2
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
2
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
2
2
2
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
2
1
1
1
1
1
2
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
2
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
2
1
1
1
1
1
1
1
1
1
1
2
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
2
1
2
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
2
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
2
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
2
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
2
1
1
2
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
2
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
2
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
2
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
2
1
2
1
1
1
1
1
1
1
1
1
2
1
1
1
2
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
2
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
2
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
2
2
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
2
1
1
2
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
2
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
2
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
2
1
1
1
1
1
1
1
2
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
2
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
2
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
2
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
2
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
2
2
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
2
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
2
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
2
2
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
2
1
2
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
2
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
2
1
1
1
1
2
1
1
1
1
1
1
1
2
2
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
2
2
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
2
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
2
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
2
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
2
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
2
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
2
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
2
1
2
1
1
1
1
2
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
2
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
2
1
1
2
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
2
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
2
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
2
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
2
1
1
1
1
1
1
1
2
1
1
1
1
1
================================================
FILE: third_party/RAFT/core/__init__.py
================================================
================================================
FILE: third_party/RAFT/core/corr.py
================================================
import torch
import torch.nn.functional as F
from .utils.utils import bilinear_sampler, coords_grid
try:
import alt_cuda_corr
except:
# alt_cuda_corr is not compiled
pass
class CorrBlock:
def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
self.num_levels = num_levels
self.radius = radius
self.corr_pyramid = []
# all pairs correlation
corr = CorrBlock.corr(fmap1, fmap2)
batch, h1, w1, dim, h2, w2 = corr.shape
corr = corr.reshape(batch*h1*w1, dim, h2, w2)
self.corr_pyramid.append(corr)
for i in range(self.num_levels-1):
corr = F.avg_pool2d(corr, 2, stride=2)
self.corr_pyramid.append(corr)
def __call__(self, coords):
r = self.radius
coords = coords.permute(0, 2, 3, 1)
batch, h1, w1, _ = coords.shape
out_pyramid = []
for i in range(self.num_levels):
corr = self.corr_pyramid[i]
dx = torch.linspace(-r, r, 2*r+1, device=coords.device)
dy = torch.linspace(-r, r, 2*r+1, device=coords.device)
delta = torch.stack(torch.meshgrid(dy, dx), axis=-1)
centroid_lvl = coords.reshape(batch*h1*w1, 1, 1, 2) / 2**i
delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2)
coords_lvl = centroid_lvl + delta_lvl
corr = bilinear_sampler(corr, coords_lvl)
corr = corr.view(batch, h1, w1, -1)
out_pyramid.append(corr)
out = torch.cat(out_pyramid, dim=-1)
return out.permute(0, 3, 1, 2).contiguous().float()
@staticmethod
def corr(fmap1, fmap2):
batch, dim, ht, wd = fmap1.shape
fmap1 = fmap1.view(batch, dim, ht*wd)
fmap2 = fmap2.view(batch, dim, ht*wd)
corr = torch.matmul(fmap1.transpose(1,2), fmap2)
corr = corr.view(batch, ht, wd, 1, ht, wd)
return corr / torch.sqrt(torch.tensor(dim).float())
class AlternateCorrBlock:
def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
self.num_levels = num_levels
self.radius = radius
self.pyramid = [(fmap1, fmap2)]
for i in range(self.num_levels):
fmap1 = F.avg_pool2d(fmap1, 2, stride=2)
fmap2 = F.avg_pool2d(fmap2, 2, stride=2)
self.pyramid.append((fmap1, fmap2))
def __call__(self, coords):
coords = coords.permute(0, 2, 3, 1)
B, H, W, _ = coords.shape
dim = self.pyramid[0][0].shape[1]
corr_list = []
for i in range(self.num_levels):
r = self.radius
fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1).contiguous()
fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1).contiguous()
coords_i = (coords / 2**i).reshape(B, 1, H, W, 2).contiguous()
corr, = alt_cuda_corr.forward(fmap1_i, fmap2_i, coords_i, r)
corr_list.append(corr.squeeze(1))
corr = torch.stack(corr_list, dim=1)
corr = corr.reshape(B, -1, H, W)
return corr / torch.sqrt(torch.tensor(dim).float())
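A minimal usage sketch for CorrBlock follows (illustrative only, not part of corr.py; the 1/8-resolution feature-map size and the package-style import are assumptions):
import torch
from core.corr import CorrBlock  # assumes third_party/RAFT is on sys.path
# two 256-channel feature maps at 1/8 input resolution, as produced by the feature encoder
fmap1 = torch.randn(1, 256, 46, 62)
fmap2 = torch.randn(1, 256, 46, 62)
corr_fn = CorrBlock(fmap1, fmap2, num_levels=4, radius=4)
# query the 4-level correlation pyramid at an identity coordinate grid (x, y order)
ys, xs = torch.meshgrid(torch.arange(46.0), torch.arange(62.0))
coords = torch.stack([xs, ys], dim=0)[None]        # (1, 2, 46, 62)
out = corr_fn(coords)                              # (1, 4 * (2*4+1)**2, 46, 62) = (1, 324, 46, 62)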
================================================
FILE: third_party/RAFT/core/datasets.py
================================================
# Data loading based on https://github.com/NVIDIA/flownet2-pytorch
import numpy as np
import torch
import torch.utils.data as data
import torch.nn.functional as F
import os
import math
import random
from glob import glob
import os.path as osp
from utils import frame_utils
from utils.augmentor import FlowAugmentor, SparseFlowAugmentor
class FlowDataset(data.Dataset):
def __init__(self, aug_params=None, sparse=False):
self.augmentor = None
self.sparse = sparse
if aug_params is not None:
if sparse:
self.augmentor = SparseFlowAugmentor(**aug_params)
else:
self.augmentor = FlowAugmentor(**aug_params)
self.is_test = False
self.init_seed = False
self.flow_list = []
self.image_list = []
self.extra_info = []
def __getitem__(self, index):
if self.is_test:
img1 = frame_utils.read_gen(self.image_list[index][0])
img2 = frame_utils.read_gen(self.image_list[index][1])
img1 = np.array(img1).astype(np.uint8)[..., :3]
img2 = np.array(img2).astype(np.uint8)[..., :3]
img1 = torch.from_numpy(img1).permute(2, 0, 1).float()
img2 = torch.from_numpy(img2).permute(2, 0, 1).float()
return img1, img2, self.extra_info[index]
if not self.init_seed:
worker_info = torch.utils.data.get_worker_info()
if worker_info is not None:
torch.manual_seed(worker_info.id)
np.random.seed(worker_info.id)
random.seed(worker_info.id)
self.init_seed = True
index = index % len(self.image_list)
valid = None
if self.sparse:
flow, valid = frame_utils.readFlowKITTI(self.flow_list[index])
else:
flow = frame_utils.read_gen(self.flow_list[index])
img1 = frame_utils.read_gen(self.image_list[index][0])
img2 = frame_utils.read_gen(self.image_list[index][1])
flow = np.array(flow).astype(np.float32)
img1 = np.array(img1).astype(np.uint8)
img2 = np.array(img2).astype(np.uint8)
# grayscale images
if len(img1.shape) == 2:
img1 = np.tile(img1[...,None], (1, 1, 3))
img2 = np.tile(img2[...,None], (1, 1, 3))
else:
img1 = img1[..., :3]
img2 = img2[..., :3]
if self.augmentor is not None:
if self.sparse:
img1, img2, flow, valid = self.augmentor(img1, img2, flow, valid)
else:
img1, img2, flow = self.augmentor(img1, img2, flow)
img1 = torch.from_numpy(img1).permute(2, 0, 1).float()
img2 = torch.from_numpy(img2).permute(2, 0, 1).float()
flow = torch.from_numpy(flow).permute(2, 0, 1).float()
if valid is not None:
valid = torch.from_numpy(valid)
else:
valid = (flow[0].abs() < 1000) & (flow[1].abs() < 1000)
return img1, img2, flow, valid.float()
def __rmul__(self, v):
self.flow_list = v * self.flow_list
self.image_list = v * self.image_list
return self
def __len__(self):
return len(self.image_list)
class MpiSintel(FlowDataset):
def __init__(self, aug_params=None, split='training', root='datasets/Sintel', dstype='clean'):
super(MpiSintel, self).__init__(aug_params)
flow_root = osp.join(root, split, 'flow')
image_root = osp.join(root, split, dstype)
if split == 'test':
self.is_test = True
for scene in os.listdir(image_root):
image_list = sorted(glob(osp.join(image_root, scene, '*.png')))
for i in range(len(image_list)-1):
self.image_list += [ [image_list[i], image_list[i+1]] ]
self.extra_info += [ (scene, i) ] # scene and frame_id
if split != 'test':
self.flow_list += sorted(glob(osp.join(flow_root, scene, '*.flo')))
class FlyingChairs(FlowDataset):
def __init__(self, aug_params=None, split='train', root='datasets/FlyingChairs_release/data'):
super(FlyingChairs, self).__init__(aug_params)
images = sorted(glob(osp.join(root, '*.ppm')))
flows = sorted(glob(osp.join(root, '*.flo')))
assert (len(images)//2 == len(flows))
split_list = np.loadtxt('chairs_split.txt', dtype=np.int32)
for i in range(len(flows)):
xid = split_list[i]
if (split=='training' and xid==1) or (split=='validation' and xid==2):
self.flow_list += [ flows[i] ]
self.image_list += [ [images[2*i], images[2*i+1]] ]
class FlyingThings3D(FlowDataset):
def __init__(self, aug_params=None, root='datasets/FlyingThings3D', dstype='frames_cleanpass'):
super(FlyingThings3D, self).__init__(aug_params)
for cam in ['left']:
for direction in ['into_future', 'into_past']:
image_dirs = sorted(glob(osp.join(root, dstype, 'TRAIN/*/*')))
image_dirs = sorted([osp.join(f, cam) for f in image_dirs])
flow_dirs = sorted(glob(osp.join(root, 'optical_flow/TRAIN/*/*')))
flow_dirs = sorted([osp.join(f, direction, cam) for f in flow_dirs])
for idir, fdir in zip(image_dirs, flow_dirs):
images = sorted(glob(osp.join(idir, '*.png')) )
flows = sorted(glob(osp.join(fdir, '*.pfm')) )
for i in range(len(flows)-1):
if direction == 'into_future':
self.image_list += [ [images[i], images[i+1]] ]
self.flow_list += [ flows[i] ]
elif direction == 'into_past':
self.image_list += [ [images[i+1], images[i]] ]
self.flow_list += [ flows[i+1] ]
class KITTI(FlowDataset):
def __init__(self, aug_params=None, split='training', root='datasets/KITTI'):
super(KITTI, self).__init__(aug_params, sparse=True)
if split == 'testing':
self.is_test = True
root = osp.join(root, split)
images1 = sorted(glob(osp.join(root, 'image_2/*_10.png')))
images2 = sorted(glob(osp.join(root, 'image_2/*_11.png')))
for img1, img2 in zip(images1, images2):
frame_id = img1.split('/')[-1]
self.extra_info += [ [frame_id] ]
self.image_list += [ [img1, img2] ]
if split == 'training':
self.flow_list = sorted(glob(osp.join(root, 'flow_occ/*_10.png')))
class HD1K(FlowDataset):
def __init__(self, aug_params=None, root='datasets/HD1k'):
super(HD1K, self).__init__(aug_params, sparse=True)
seq_ix = 0
while 1:
flows = sorted(glob(os.path.join(root, 'hd1k_flow_gt', 'flow_occ/%06d_*.png' % seq_ix)))
images = sorted(glob(os.path.join(root, 'hd1k_input', 'image_2/%06d_*.png' % seq_ix)))
if len(flows) == 0:
break
for i in range(len(flows)-1):
self.flow_list += [flows[i]]
self.image_list += [ [images[i], images[i+1]] ]
seq_ix += 1
def fetch_dataloader(args, TRAIN_DS='C+T+K+S+H'):
""" Create the data loader for the corresponding trainign set """
if args.stage == 'chairs':
aug_params = {'crop_size': args.image_size, 'min_scale': -0.1, 'max_scale': 1.0, 'do_flip': True}
train_dataset = FlyingChairs(aug_params, split='training')
elif args.stage == 'things':
aug_params = {'crop_size': args.image_size, 'min_scale': -0.4, 'max_scale': 0.8, 'do_flip': True}
clean_dataset = FlyingThings3D(aug_params, dstype='frames_cleanpass')
final_dataset = FlyingThings3D(aug_params, dstype='frames_finalpass')
train_dataset = clean_dataset + final_dataset
elif args.stage == 'sintel':
aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.6, 'do_flip': True}
things = FlyingThings3D(aug_params, dstype='frames_cleanpass')
sintel_clean = MpiSintel(aug_params, split='training', dstype='clean')
sintel_final = MpiSintel(aug_params, split='training', dstype='final')
if TRAIN_DS == 'C+T+K+S+H':
kitti = KITTI({'crop_size': args.image_size, 'min_scale': -0.3, 'max_scale': 0.5, 'do_flip': True})
hd1k = HD1K({'crop_size': args.image_size, 'min_scale': -0.5, 'max_scale': 0.2, 'do_flip': True})
train_dataset = 100*sintel_clean + 100*sintel_final + 200*kitti + 5*hd1k + things
elif TRAIN_DS == 'C+T+K/S':
train_dataset = 100*sintel_clean + 100*sintel_final + things
elif args.stage == 'kitti':
aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.4, 'do_flip': False}
train_dataset = KITTI(aug_params, split='training')
train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size,
pin_memory=False, shuffle=True, num_workers=4, drop_last=True)
print('Training with %d image pairs' % len(train_dataset))
return train_loader
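For orientation, a minimal sketch of how fetch_dataloader is typically driven (illustrative only; the Namespace fields are assumptions and the FlyingChairs data must already sit under datasets/FlyingChairs_release/data):
from argparse import Namespace
args = Namespace(stage='chairs', image_size=[368, 496], batch_size=6)
loader = fetch_dataloader(args)
img1, img2, flow, valid = next(iter(loader))       # images (B, 3, 368, 496), flow (B, 2, 368, 496)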
================================================
FILE: third_party/RAFT/core/extractor.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
class ResidualBlock(nn.Module):
def __init__(self, in_planes, planes, norm_fn='group', stride=1):
super(ResidualBlock, self).__init__()
self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride)
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1)
self.relu = nn.ReLU(inplace=True)
num_groups = planes // 8
if norm_fn == 'group':
self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
if not stride == 1:
self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
elif norm_fn == 'batch':
self.norm1 = nn.BatchNorm2d(planes)
self.norm2 = nn.BatchNorm2d(planes)
if not stride == 1:
self.norm3 = nn.BatchNorm2d(planes)
elif norm_fn == 'instance':
self.norm1 = nn.InstanceNorm2d(planes)
self.norm2 = nn.InstanceNorm2d(planes)
if not stride == 1:
self.norm3 = nn.InstanceNorm2d(planes)
elif norm_fn == 'none':
self.norm1 = nn.Sequential()
self.norm2 = nn.Sequential()
if not stride == 1:
self.norm3 = nn.Sequential()
if stride == 1:
self.downsample = None
else:
self.downsample = nn.Sequential(
nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3)
def forward(self, x):
y = x
y = self.relu(self.norm1(self.conv1(y)))
y = self.relu(self.norm2(self.conv2(y)))
if self.downsample is not None:
x = self.downsample(x)
return self.relu(x+y)
class BottleneckBlock(nn.Module):
def __init__(self, in_planes, planes, norm_fn='group', stride=1):
super(BottleneckBlock, self).__init__()
self.conv1 = nn.Conv2d(in_planes, planes//4, kernel_size=1, padding=0)
self.conv2 = nn.Conv2d(planes//4, planes//4, kernel_size=3, padding=1, stride=stride)
self.conv3 = nn.Conv2d(planes//4, planes, kernel_size=1, padding=0)
self.relu = nn.ReLU(inplace=True)
num_groups = planes // 8
if norm_fn == 'group':
self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4)
self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4)
self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
if not stride == 1:
self.norm4 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
elif norm_fn == 'batch':
self.norm1 = nn.BatchNorm2d(planes//4)
self.norm2 = nn.BatchNorm2d(planes//4)
self.norm3 = nn.BatchNorm2d(planes)
if not stride == 1:
self.norm4 = nn.BatchNorm2d(planes)
elif norm_fn == 'instance':
self.norm1 = nn.InstanceNorm2d(planes//4)
self.norm2 = nn.InstanceNorm2d(planes//4)
self.norm3 = nn.InstanceNorm2d(planes)
if not stride == 1:
self.norm4 = nn.InstanceNorm2d(planes)
elif norm_fn == 'none':
self.norm1 = nn.Sequential()
self.norm2 = nn.Sequential()
self.norm3 = nn.Sequential()
if not stride == 1:
self.norm4 = nn.Sequential()
if stride == 1:
self.downsample = None
else:
self.downsample = nn.Sequential(
nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4)
def forward(self, x):
y = x
y = self.relu(self.norm1(self.conv1(y)))
y = self.relu(self.norm2(self.conv2(y)))
y = self.relu(self.norm3(self.conv3(y)))
if self.downsample is not None:
x = self.downsample(x)
return self.relu(x+y)
class BasicEncoder(nn.Module):
def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):
super(BasicEncoder, self).__init__()
self.norm_fn = norm_fn
if self.norm_fn == 'group':
self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64)
elif self.norm_fn == 'batch':
self.norm1 = nn.BatchNorm2d(64)
elif self.norm_fn == 'instance':
self.norm1 = nn.InstanceNorm2d(64)
elif self.norm_fn == 'none':
self.norm1 = nn.Sequential()
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
self.relu1 = nn.ReLU(inplace=True)
self.in_planes = 64
self.layer1 = self._make_layer(64, stride=1)
self.layer2 = self._make_layer(96, stride=2)
self.layer3 = self._make_layer(128, stride=2)
# output convolution
self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1)
self.dropout = None
if dropout > 0:
self.dropout = nn.Dropout2d(p=dropout)
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
if m.weight is not None:
nn.init.constant_(m.weight, 1)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
def _make_layer(self, dim, stride=1):
layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
layers = (layer1, layer2)
self.in_planes = dim
return nn.Sequential(*layers)
def forward(self, x):
# if input is list, combine batch dimension
is_list = isinstance(x, tuple) or isinstance(x, list)
if is_list:
batch_dim = x[0].shape[0]
x = torch.cat(x, dim=0)
x = self.conv1(x)
x = self.norm1(x)
x = self.relu1(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.conv2(x)
if self.training and self.dropout is not None:
x = self.dropout(x)
if is_list:
x = torch.split(x, [batch_dim, batch_dim], dim=0)
return x
class SmallEncoder(nn.Module):
def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):
super(SmallEncoder, self).__init__()
self.norm_fn = norm_fn
if self.norm_fn == 'group':
self.norm1 = nn.GroupNorm(num_groups=8, num_channels=32)
elif self.norm_fn == 'batch':
self.norm1 = nn.BatchNorm2d(32)
elif self.norm_fn == 'instance':
self.norm1 = nn.InstanceNorm2d(32)
elif self.norm_fn == 'none':
self.norm1 = nn.Sequential()
self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3)
self.relu1 = nn.ReLU(inplace=True)
self.in_planes = 32
self.layer1 = self._make_layer(32, stride=1)
self.layer2 = self._make_layer(64, stride=2)
self.layer3 = self._make_layer(96, stride=2)
self.dropout = None
if dropout > 0:
self.dropout = nn.Dropout2d(p=dropout)
self.conv2 = nn.Conv2d(96, output_dim, kernel_size=1)
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
if m.weight is not None:
nn.init.constant_(m.weight, 1)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
def _make_layer(self, dim, stride=1):
layer1 = BottleneckBlock(self.in_planes, dim, self.norm_fn, stride=stride)
layer2 = BottleneckBlock(dim, dim, self.norm_fn, stride=1)
layers = (layer1, layer2)
self.in_planes = dim
return nn.Sequential(*layers)
def forward(self, x):
# if input is list, combine batch dimension
is_list = isinstance(x, tuple) or isinstance(x, list)
if is_list:
batch_dim = x[0].shape[0]
x = torch.cat(x, dim=0)
x = self.conv1(x)
x = self.norm1(x)
x = self.relu1(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.conv2(x)
if self.training and self.dropout is not None:
x = self.dropout(x)
if is_list:
x = torch.split(x, [batch_dim, batch_dim], dim=0)
return x
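An illustrative shape check for the encoders (not part of extractor.py; the input size is an assumption): both encoders downsample by a factor of 8, and a list input batches two images through a single forward pass.
import torch
enc = BasicEncoder(output_dim=256, norm_fn='instance', dropout=0.0)
x = torch.randn(2, 3, 368, 496)
print(enc(x).shape)                                # torch.Size([2, 256, 46, 62])
f1, f2 = enc([x, x.clone()])                       # shared weights, split back into two feature maps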
================================================
FILE: third_party/RAFT/core/raft.py
================================================
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from .update import BasicUpdateBlock, SmallUpdateBlock
from .extractor import BasicEncoder, SmallEncoder
from .corr import CorrBlock, AlternateCorrBlock
from .utils.utils import bilinear_sampler, coords_grid, upflow8
try:
autocast = torch.cuda.amp.autocast
except:
# dummy autocast for PyTorch < 1.6
class autocast:
def __init__(self, enabled):
pass
def __enter__(self):
pass
def __exit__(self, *args):
pass
class RAFT(nn.Module):
def __init__(self, args):
super(RAFT, self).__init__()
self.args = args
if args.small:
self.hidden_dim = hdim = 96
self.context_dim = cdim = 64
args.corr_levels = 4
args.corr_radius = 3
else:
self.hidden_dim = hdim = 128
self.context_dim = cdim = 128
args.corr_levels = 4
args.corr_radius = 4
# if 'dropout' not in self.args:
self.args.dropout = 0
# if 'alternate_corr' not in self.args:
self.args.alternate_corr = False
# feature network, context network, and update block
if args.small:
self.fnet = SmallEncoder(output_dim=128, norm_fn='instance', dropout=args.dropout)
self.cnet = SmallEncoder(output_dim=hdim+cdim, norm_fn='none', dropout=args.dropout)
self.update_block = SmallUpdateBlock(self.args, hidden_dim=hdim)
else:
self.fnet = BasicEncoder(output_dim=256, norm_fn='instance', dropout=args.dropout)
self.cnet = BasicEncoder(output_dim=hdim+cdim, norm_fn='batch', dropout=args.dropout)
self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim)
def freeze_bn(self):
for m in self.modules():
if isinstance(m, nn.BatchNorm2d):
m.eval()
def initialize_flow(self, img):
""" Flow is represented as difference between two coordinate grids flow = coords1 - coords0"""
N, C, H, W = img.shape
coords0 = coords_grid(N, H//8, W//8, device=img.device)
coords1 = coords_grid(N, H//8, W//8, device=img.device)
# optical flow computed as difference: flow = coords1 - coords0
return coords0, coords1
def upsample_flow(self, flow, mask):
""" Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """
N, _, H, W = flow.shape
mask = mask.view(N, 1, 9, 8, 8, H, W)
mask = torch.softmax(mask, dim=2)
up_flow = F.unfold(8 * flow, [3,3], padding=1)
up_flow = up_flow.view(N, 2, 9, 1, 1, H, W)
up_flow = torch.sum(mask * up_flow, dim=2)
up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
return up_flow.reshape(N, 2, 8*H, 8*W)
def forward(self, image1, image2, iters=12, flow_init=None, upsample=True, test_mode=False):
""" Estimate optical flow between pair of frames """
image1 = 2 * (image1 / 255.0) - 1.0
image2 = 2 * (image2 / 255.0) - 1.0
image1 = image1.contiguous()
image2 = image2.contiguous()
hdim = self.hidden_dim
cdim = self.context_dim
# run the feature network
with autocast(enabled=self.args.mixed_precision):
fmap1, fmap2 = self.fnet([image1, image2])
fmap1 = fmap1.float()
fmap2 = fmap2.float()
if self.args.alternate_corr:
corr_fn = AlternateCorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
else:
corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
# run the context network
with autocast(enabled=self.args.mixed_precision):
cnet = self.cnet(image1)
net, inp = torch.split(cnet, [hdim, cdim], dim=1)
net = torch.tanh(net)
inp = torch.relu(inp)
coords0, coords1 = self.initialize_flow(image1)
if flow_init is not None:
coords1 = coords1 + flow_init
flow_predictions = []
for itr in range(iters):
coords1 = coords1.detach()
corr = corr_fn(coords1) # index correlation volume
flow = coords1 - coords0
with autocast(enabled=self.args.mixed_precision):
net, up_mask, delta_flow = self.update_block(net, inp, corr, flow)
# F(t+1) = F(t) + \Delta(t)
coords1 = coords1 + delta_flow
# upsample predictions
if up_mask is None:
flow_up = upflow8(coords1 - coords0)
else:
flow_up = self.upsample_flow(coords1 - coords0, up_mask)
flow_predictions.append(flow_up)
if test_mode:
return coords1 - coords0, flow_up
return flow_predictions
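A minimal inference sketch for RAFT (illustrative only; the Namespace flags and image sizes are assumptions, and the weights below are randomly initialised rather than loaded from a checkpoint):
import torch
from argparse import Namespace
args = Namespace(small=False, mixed_precision=False)
model = RAFT(args).eval()
image1 = torch.randint(0, 256, (1, 3, 368, 496)).float()   # raw 0-255 frames
image2 = torch.randint(0, 256, (1, 3, 368, 496)).float()
with torch.no_grad():
    flow_low, flow_up = model(image1, image2, iters=12, test_mode=True)
print(flow_low.shape, flow_up.shape)               # (1, 2, 46, 62) and (1, 2, 368, 496)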
================================================
FILE: third_party/RAFT/core/update.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
class FlowHead(nn.Module):
def __init__(self, input_dim=128, hidden_dim=256):
super(FlowHead, self).__init__()
self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1)
self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1)
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
return self.conv2(self.relu(self.conv1(x)))
class ConvGRU(nn.Module):
def __init__(self, hidden_dim=128, input_dim=192+128):
super(ConvGRU, self).__init__()
self.convz = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)
self.convr = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)
self.convq = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)
def forward(self, h, x):
hx = torch.cat([h, x], dim=1)
z = torch.sigmoid(self.convz(hx))
r = torch.sigmoid(self.convr(hx))
q = torch.tanh(self.convq(torch.cat([r*h, x], dim=1)))
h = (1-z) * h + z * q
return h
class SepConvGRU(nn.Module):
def __init__(self, hidden_dim=128, input_dim=192+128):
super(SepConvGRU, self).__init__()
self.convz1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
self.convr1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
self.convq1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
self.convz2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
self.convr2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
self.convq2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
def forward(self, h, x):
# horizontal
hx = torch.cat([h, x], dim=1)
z = torch.sigmoid(self.convz1(hx))
r = torch.sigmoid(self.convr1(hx))
q = torch.tanh(self.convq1(torch.cat([r*h, x], dim=1)))
h = (1-z) * h + z * q
# vertical
hx = torch.cat([h, x], dim=1)
z = torch.sigmoid(self.convz2(hx))
r = torch.sigmoid(self.convr2(hx))
q = torch.tanh(self.convq2(torch.cat([r*h, x], dim=1)))
h = (1-z) * h + z * q
return h
class SmallMotionEncoder(nn.Module):
def __init__(self, args):
super(SmallMotionEncoder, self).__init__()
cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2
self.convc1 = nn.Conv2d(cor_planes, 96, 1, padding=0)
self.convf1 = nn.Conv2d(2, 64, 7, padding=3)
self.convf2 = nn.Conv2d(64, 32, 3, padding=1)
self.conv = nn.Conv2d(128, 80, 3, padding=1)
def forward(self, flow, corr):
cor = F.relu(self.convc1(corr))
flo = F.relu(self.convf1(flow))
flo = F.relu(self.convf2(flo))
cor_flo = torch.cat([cor, flo], dim=1)
out = F.relu(self.conv(cor_flo))
return torch.cat([out, flow], dim=1)
class BasicMotionEncoder(nn.Module):
def __init__(self, args):
super(BasicMotionEncoder, self).__init__()
cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2
self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0)
self.convc2 = nn.Conv2d(256, 192, 3, padding=1)
self.convf1 = nn.Conv2d(2, 128, 7, padding=3)
self.convf2 = nn.Conv2d(128, 64, 3, padding=1)
self.conv = nn.Conv2d(64+192, 128-2, 3, padding=1)
def forward(self, flow, corr):
cor = F.relu(self.convc1(corr))
cor = F.relu(self.convc2(cor))
flo = F.relu(self.convf1(flow))
flo = F.relu(self.convf2(flo))
cor_flo = torch.cat([cor, flo], dim=1)
out = F.relu(self.conv(cor_flo))
return torch.cat([out, flow], dim=1)
class SmallUpdateBlock(nn.Module):
def __init__(self, args, hidden_dim=96):
super(SmallUpdateBlock, self).__init__()
self.encoder = SmallMotionEncoder(args)
self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=82+64)
self.flow_head = FlowHead(hidden_dim, hidden_dim=128)
def forward(self, net, inp, corr, flow):
motion_features = self.encoder(flow, corr)
inp = torch.cat([inp, motion_features], dim=1)
net = self.gru(net, inp)
delta_flow = self.flow_head(net)
return net, None, delta_flow
class BasicUpdateBlock(nn.Module):
def __init__(self, args, hidden_dim=128, input_dim=128):
super(BasicUpdateBlock, self).__init__()
self.args = args
self.encoder = BasicMotionEncoder(args)
self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128+hidden_dim)
self.flow_head = FlowHead(hidden_dim, hidden_dim=256)
self.mask = nn.Sequential(
nn.Conv2d(128, 256, 3, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(256, 64*9, 1, padding=0))
def forward(self, net, inp, corr, flow, upsample=True):
motion_features = self.encoder(flow, corr)
inp = torch.cat([inp, motion_features], dim=1)
net = self.gru(net, inp)
delta_flow = self.flow_head(net)
# scale mask to balance gradients
mask = .25 * self.mask(net)
return net, mask, delta_flow
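A shape sketch for one iterative update step (illustrative only; the Namespace values mirror the non-small RAFT defaults and the tensors are synthetic):
import torch
from argparse import Namespace
args = Namespace(corr_levels=4, corr_radius=4)
block = BasicUpdateBlock(args, hidden_dim=128)
net = torch.randn(1, 128, 46, 62)                  # hidden state (tanh half of the context features)
inp = torch.randn(1, 128, 46, 62)                  # injected context (relu half)
corr = torch.randn(1, 4 * (2 * 4 + 1) ** 2, 46, 62)
flow = torch.zeros(1, 2, 46, 62)
net, up_mask, delta_flow = block(net, inp, corr, flow)
print(up_mask.shape, delta_flow.shape)             # (1, 576, 46, 62) and (1, 2, 46, 62)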
================================================
FILE: third_party/RAFT/core/utils/__init__.py
================================================
================================================
FILE: third_party/RAFT/core/utils/augmentor.py
================================================
import numpy as np
import random
import math
from PIL import Image
import cv2
cv2.setNumThreads(0)
cv2.ocl.setUseOpenCL(False)
import torch
from torchvision.transforms import ColorJitter
import torch.nn.functional as F
class FlowAugmentor:
def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=True):
# spatial augmentation params
self.crop_size = crop_size
self.min_scale = min_scale
self.max_scale = max_scale
self.spatial_aug_prob = 0.8
self.stretch_prob = 0.8
self.max_stretch = 0.2
# flip augmentation params
self.do_flip = do_flip
self.h_flip_prob = 0.5
self.v_flip_prob = 0.1
# photometric augmentation params
self.photo_aug = ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.5/3.14)
self.asymmetric_color_aug_prob = 0.2
self.eraser_aug_prob = 0.5
def color_transform(self, img1, img2):
""" Photometric augmentation """
# asymmetric
if np.random.rand() < self.asymmetric_color_aug_prob:
img1 = np.array(self.photo_aug(Image.fromarray(img1)), dtype=np.uint8)
img2 = np.array(self.photo_aug(Image.fromarray(img2)), dtype=np.uint8)
# symmetric
else:
image_stack = np.concatenate([img1, img2], axis=0)
image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8)
img1, img2 = np.split(image_stack, 2, axis=0)
return img1, img2
def eraser_transform(self, img1, img2, bounds=[50, 100]):
""" Occlusion augmentation """
ht, wd = img1.shape[:2]
if np.random.rand() < self.eraser_aug_prob:
mean_color = np.mean(img2.reshape(-1, 3), axis=0)
for _ in range(np.random.randint(1, 3)):
x0 = np.random.randint(0, wd)
y0 = np.random.randint(0, ht)
dx = np.random.randint(bounds[0], bounds[1])
dy = np.random.randint(bounds[0], bounds[1])
img2[y0:y0+dy, x0:x0+dx, :] = mean_color
return img1, img2
def spatial_transform(self, img1, img2, flow):
# randomly sample scale
ht, wd = img1.shape[:2]
min_scale = np.maximum(
(self.crop_size[0] + 8) / float(ht),
(self.crop_size[1] + 8) / float(wd))
scale = 2 ** np.random.uniform(self.min_scale, self.max_scale)
scale_x = scale
scale_y = scale
if np.random.rand() < self.stretch_prob:
scale_x *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch)
scale_y *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch)
scale_x = np.clip(scale_x, min_scale, None)
scale_y = np.clip(scale_y, min_scale, None)
if np.random.rand() < self.spatial_aug_prob:
# rescale the images
img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
flow = cv2.resize(flow, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
flow = flow * [scale_x, scale_y]
if self.do_flip:
if np.random.rand() < self.h_flip_prob: # h-flip
img1 = img1[:, ::-1]
img2 = img2[:, ::-1]
flow = flow[:, ::-1] * [-1.0, 1.0]
if np.random.rand() < self.v_flip_prob: # v-flip
img1 = img1[::-1, :]
img2 = img2[::-1, :]
flow = flow[::-1, :] * [1.0, -1.0]
y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0])
x0 = np.random.randint(0, img1.shape[1] - self.crop_size[1])
img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
return img1, img2, flow
def __call__(self, img1, img2, flow):
img1, img2 = self.color_transform(img1, img2)
img1, img2 = self.eraser_transform(img1, img2)
img1, img2, flow = self.spatial_transform(img1, img2, flow)
img1 = np.ascontiguousarray(img1)
img2 = np.ascontiguousarray(img2)
flow = np.ascontiguousarray(flow)
return img1, img2, flow
class SparseFlowAugmentor:
def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=False):
# spatial augmentation params
self.crop_size = crop_size
self.min_scale = min_scale
self.max_scale = max_scale
self.spatial_aug_prob = 0.8
self.stretch_prob = 0.8
self.max_stretch = 0.2
# flip augmentation params
self.do_flip = do_flip
self.h_flip_prob = 0.5
self.v_flip_prob = 0.1
# photometric augmentation params
self.photo_aug = ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3/3.14)
self.asymmetric_color_aug_prob = 0.2
self.eraser_aug_prob = 0.5
def color_transform(self, img1, img2):
image_stack = np.concatenate([img1, img2], axis=0)
image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8)
img1, img2 = np.split(image_stack, 2, axis=0)
return img1, img2
def eraser_transform(self, img1, img2):
ht, wd = img1.shape[:2]
if np.random.rand() < self.eraser_aug_prob:
mean_color = np.mean(img2.reshape(-1, 3), axis=0)
for _ in range(np.random.randint(1, 3)):
x0 = np.random.randint(0, wd)
y0 = np.random.randint(0, ht)
dx = np.random.randint(50, 100)
dy = np.random.randint(50, 100)
img2[y0:y0+dy, x0:x0+dx, :] = mean_color
return img1, img2
def resize_sparse_flow_map(self, flow, valid, fx=1.0, fy=1.0):
ht, wd = flow.shape[:2]
coords = np.meshgrid(np.arange(wd), np.arange(ht))
coords = np.stack(coords, axis=-1)
coords = coords.reshape(-1, 2).astype(np.float32)
flow = flow.reshape(-1, 2).astype(np.float32)
valid = valid.reshape(-1).astype(np.float32)
coords0 = coords[valid>=1]
flow0 = flow[valid>=1]
ht1 = int(round(ht * fy))
wd1 = int(round(wd * fx))
coords1 = coords0 * [fx, fy]
flow1 = flow0 * [fx, fy]
xx = np.round(coords1[:,0]).astype(np.int32)
yy = np.round(coords1[:,1]).astype(np.int32)
v = (xx > 0) & (xx < wd1) & (yy > 0) & (yy < ht1)
xx = xx[v]
yy = yy[v]
flow1 = flow1[v]
flow_img = np.zeros([ht1, wd1, 2], dtype=np.float32)
valid_img = np.zeros([ht1, wd1], dtype=np.int32)
flow_img[yy, xx] = flow1
valid_img[yy, xx] = 1
return flow_img, valid_img
def spatial_transform(self, img1, img2, flow, valid):
# randomly sample scale
ht, wd = img1.shape[:2]
min_scale = np.maximum(
(self.crop_size[0] + 1) / float(ht),
(self.crop_size[1] + 1) / float(wd))
scale = 2 ** np.random.uniform(self.min_scale, self.max_scale)
scale_x = np.clip(scale, min_scale, None)
scale_y = np.clip(scale, min_scale, None)
if np.random.rand() < self.spatial_aug_prob:
# rescale the images
img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
flow, valid = self.resize_sparse_flow_map(flow, valid, fx=scale_x, fy=scale_y)
if self.do_flip:
if np.random.rand() < 0.5: # h-flip
img1 = img1[:, ::-1]
img2 = img2[:, ::-1]
flow = flow[:, ::-1] * [-1.0, 1.0]
valid = valid[:, ::-1]
margin_y = 20
margin_x = 50
y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0] + margin_y)
x0 = np.random.randint(-margin_x, img1.shape[1] - self.crop_size[1] + margin_x)
y0 = np.clip(y0, 0, img1.shape[0] - self.crop_size[0])
x0 = np.clip(x0, 0, img1.shape[1] - self.crop_size[1])
img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
valid = valid[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
return img1, img2, flow, valid
def __call__(self, img1, img2, flow, valid):
img1, img2 = self.color_transform(img1, img2)
img1, img2 = self.eraser_transform(img1, img2)
img1, img2, flow, valid = self.spatial_transform(img1, img2, flow, valid)
img1 = np.ascontiguousarray(img1)
img2 = np.ascontiguousarray(img2)
flow = np.ascontiguousarray(flow)
valid = np.ascontiguousarray(valid)
return img1, img2, flow, valid
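An illustrative call to FlowAugmentor (not part of augmentor.py; the synthetic arrays stand in for real uint8 frames and a float32 flow field):
import numpy as np
aug = FlowAugmentor(crop_size=(368, 496), min_scale=-0.1, max_scale=1.0, do_flip=True)
img1 = np.random.randint(0, 255, (540, 960, 3), dtype=np.uint8)
img2 = np.random.randint(0, 255, (540, 960, 3), dtype=np.uint8)
flow = np.random.randn(540, 960, 2).astype(np.float32)
img1, img2, flow = aug(img1, img2, flow)           # all three come back cropped to (368, 496)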
================================================
FILE: third_party/RAFT/core/utils/flow_viz.py
================================================
# Flow visualization code used from https://github.com/tomrunia/OpticalFlow_Visualization
# MIT License
#
# Copyright (c) 2018 Tom Runia
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to conditions.
#
# Author: Tom Runia
# Date Created: 2018-08-03
import numpy as np
def make_colorwheel():
"""
Generates a color wheel for optical flow visualization as presented in:
Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)
URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf
Code follows the original C++ source code of Daniel Scharstein.
Code follows the Matlab source code of Deqing Sun.
Returns:
np.ndarray: Color wheel
"""
RY = 15
YG = 6
GC = 4
CB = 11
BM = 13
MR = 6
ncols = RY + YG + GC + CB + BM + MR
colorwheel = np.zeros((ncols, 3))
col = 0
# RY
colorwheel[0:RY, 0] = 255
colorwheel[0:RY, 1] = np.floor(255*np.arange(0,RY)/RY)
col = col+RY
# YG
colorwheel[col:col+YG, 0] = 255 - np.floor(255*np.arange(0,YG)/YG)
colorwheel[col:col+YG, 1] = 255
col = col+YG
# GC
colorwheel[col:col+GC, 1] = 255
colorwheel[col:col+GC, 2] = np.floor(255*np.arange(0,GC)/GC)
col = col+GC
# CB
colorwheel[col:col+CB, 1] = 255 - np.floor(255*np.arange(CB)/CB)
colorwheel[col:col+CB, 2] = 255
col = col+CB
# BM
colorwheel[col:col+BM, 2] = 255
colorwheel[col:col+BM, 0] = np.floor(255*np.arange(0,BM)/BM)
col = col+BM
# MR
colorwheel[col:col+MR, 2] = 255 - np.floor(255*np.arange(MR)/MR)
colorwheel[col:col+MR, 0] = 255
return colorwheel
def flow_uv_to_colors(u, v, convert_to_bgr=False):
"""
Applies the flow color wheel to (possibly clipped) flow components u and v.
According to the C++ source code of Daniel Scharstein
According to the Matlab source code of Deqing Sun
Args:
u (np.ndarray): Input horizontal flow of shape [H,W]
v (np.ndarray): Input vertical flow of shape [H,W]
convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
Returns:
np.ndarray: Flow visualization image of shape [H,W,3]
"""
flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8)
colorwheel = make_colorwheel() # shape [55x3]
ncols = colorwheel.shape[0]
rad = np.sqrt(np.square(u) + np.square(v))
a = np.arctan2(-v, -u)/np.pi
fk = (a+1) / 2*(ncols-1)
k0 = np.floor(fk).astype(np.int32)
k1 = k0 + 1
k1[k1 == ncols] = 0
f = fk - k0
for i in range(colorwheel.shape[1]):
tmp = colorwheel[:,i]
col0 = tmp[k0] / 255.0
col1 = tmp[k1] / 255.0
col = (1-f)*col0 + f*col1
idx = (rad <= 1)
col[idx] = 1 - rad[idx] * (1-col[idx])
col[~idx] = col[~idx] * 0.75 # out of range
# Note the 2-i => BGR instead of RGB
ch_idx = 2-i if convert_to_bgr else i
flow_image[:,:,ch_idx] = np.floor(255 * col)
return flow_image
def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False):
"""
Expects a two-dimensional flow image of shape [H,W,2].
Args:
flow_uv (np.ndarray): Flow UV image of shape [H,W,2]
clip_flow (float, optional): Clip maximum of flow values. Defaults to None.
convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
Returns:
np.ndarray: Flow visualization image of shape [H,W,3]
"""
assert flow_uv.ndim == 3, 'input flow must have three dimensions'
assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]'
if clip_flow is not None:
flow_uv = np.clip(flow_uv, 0, clip_flow)
u = flow_uv[:,:,0]
v = flow_uv[:,:,1]
rad = np.sqrt(np.square(u) + np.square(v))
rad_max = np.max(rad)
epsilon = 1e-5
u = u / (rad_max + epsilon)
v = v / (rad_max + epsilon)
return flow_uv_to_colors(u, v, convert_to_bgr)
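An illustrative conversion with flow_to_image (not part of flow_viz.py; the constant flow field is synthetic):
import numpy as np
flow = np.dstack([np.full((120, 160), 3.0), np.full((120, 160), -1.5)]).astype(np.float32)
rgb = flow_to_image(flow)                          # uint8 visualization of shape (120, 160, 3)
bgr = flow_to_image(flow, convert_to_bgr=True)     # channel order suitable for cv2.imwrite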
================================================
FILE: third_party/RAFT/core/utils/frame_utils.py
================================================
import numpy as np
from PIL import Image
from os.path import *
import re
import cv2
cv2.setNumThreads(0)
cv2.ocl.setUseOpenCL(False)
TAG_CHAR = np.array([202021.25], np.float32)
def readFlow(fn):
""" Read .flo file in Middlebury format"""
# Code adapted from:
# http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy
# WARNING: this will work on little-endian architectures (eg Intel x86) only!
# print 'fn = %s'%(fn)
with open(fn, 'rb') as f:
magic = np.fromfile(f, np.float32, count=1)
if 202021.25 != magic:
print('Magic number incorrect. Invalid .flo file')
return None
else:
w = np.fromfile(f, np.int32, count=1)
h = np.fromfile(f, np.int32, count=1)
# print 'Reading %d x %d flo file\n' % (w, h)
data = np.fromfile(f, np.float32, count=2*int(w)*int(h))
# Reshape data into 3D array (columns, rows, bands)
# The reshape here is for visualization, the original code is (w,h,2)
return np.resize(data, (int(h), int(w), 2))
def readPFM(file):
file = open(file, 'rb')
color = None
width = None
height = None
scale = None
endian = None
header = file.readline().rstrip()
if header == b'PF':
color = True
elif header == b'Pf':
color = False
else:
raise Exception('Not a PFM file.')
dim_match = re.match(rb'^(\d+)\s(\d+)\s$', file.readline())
if dim_match:
width, height = map(int, dim_match.groups())
else:
raise Exception('Malformed PFM header.')
scale = float(file.readline().rstrip())
if scale < 0: # little-endian
endian = '<'
scale = -scale
else:
endian = '>' # big-endian
data = np.fromfile(file, endian + 'f')
shape = (height, width, 3) if color else (height, width)
data = np.reshape(data, shape)
data = np.flipud(data)
return data
def writeFlow(filename,uv,v=None):
""" Write optical flow to file.
If v is None, uv is assumed to contain both u and v channels,
stacked in depth.
Original code by Deqing Sun, adapted from Daniel Scharstein.
"""
nBands = 2
if v is None:
assert(uv.ndim == 3)
assert(uv.shape[2] == 2)
u = uv[:,:,0]
v = uv[:,:,1]
else:
u = uv
assert(u.shape == v.shape)
height,width = u.shape
f = open(filename,'wb')
# write the header
f.write(TAG_CHAR)
np.array(width).astype(np.int32).tofile(f)
np.array(height).astype(np.int32).tofile(f)
# arrange into matrix form
tmp = np.zeros((height, width*nBands))
tmp[:,np.arange(width)*2] = u
tmp[:,np.arange(width)*2 + 1] = v
tmp.astype(np.float32).tofile(f)
f.close()
def readFlowKITTI(filename):
flow = cv2.imread(filename, cv2.IMREAD_ANYDEPTH|cv2.IMREAD_COLOR)
flow = flow[:,:,::-1].astype(np.float32)
flow, valid = flow[:, :, :2], flow[:, :, 2]
flow = (flow - 2**15) / 64.0
return flow, valid
def readDispKITTI(filename):
disp = cv2.imread(filename, cv2.IMREAD_ANYDEPTH) / 256.0
valid = disp > 0.0
flow = np.stack([-disp, np.zeros_like(disp)], -1)
return flow, valid
def writeFlowKITTI(filename, uv):
uv = 64.0 * uv + 2**15
valid = np.ones([uv.shape[0], uv.shape[1], 1])
uv = np.concatenate([uv, valid], axis=-1).astype(np.uint16)
cv2.imwrite(filename, uv[..., ::-1])
def read_gen(file_name, pil=False):
ext = splitext(file_name)[-1]
if ext == '.png' or ext == '.jpeg' or ext == '.ppm' or ext == '.jpg':
return Image.open(file_name)
elif ext == '.bin' or ext == '.raw':
return np.load(file_name)
elif ext == '.flo':
return readFlow(file_name).astype(np.float32)
elif ext == '.pfm':
flow = readPFM(file_name).astype(np.float32)
if len(flow.shape) == 2:
return flow
else:
return flow[:, :, :-1]
return []
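A round-trip sketch for the Middlebury .flo helpers (illustrative only; the file name and flow values are arbitrary):
import numpy as np
flow = np.random.randn(120, 160, 2).astype(np.float32)
writeFlow('example.flo', flow)
restored = readFlow('example.flo')                 # (120, 160, 2), identical to the input
assert np.allclose(flow, restored)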
================================================
FILE: third_party/RAFT/core/utils/utils.py
================================================
import torch
import torch.nn.functional as F
import numpy as np
from scipy import interpolate
class InputPadder:
""" Pads images such that dimensions are divisible by 8 """
def __init__(self, dims, mode='sintel'):
self.ht, self.wd = dims[-2:]
pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8
pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8
if mode == 'sintel':
self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2]
else:
self._pad = [pad_wd//2, pad_wd - pad_wd//2, 0, pad_ht]
def pad(self, *inputs):
return [F.pad(x, self._pad, mode='replicate') for x in inputs]
def unpad(self,x):
ht, wd = x.shape[-2:]
c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]]
return x[..., c[0]:c[1], c[2]:c[3]]
def forward_interpolate(flow):
flow = flow.detach().cpu().numpy()
dx, dy = flow[0], flow[1]
ht, wd = dx.shape
x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht))
x1 = x0 + dx
y1 = y0 + dy
x1 = x1.reshape(-1)
y1 = y1.reshape(-1)
dx = dx.reshape(-1)
dy = dy.reshape(-1)
valid = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht)
x1 = x1[valid]
y1 = y1[valid]
dx = dx[valid]
dy = dy[valid]
flow_x = interpolate.griddata(
(x1, y1), dx, (x0, y0), method='nearest', fill_value=0)
flow_y = interpolate.griddata(
(x1, y1), dy, (x0, y0), method='nearest', fill_value=0)
flow = np.stack([flow_x, flow_y], axis=0)
return torch.from_numpy(flow).float()
def bilinear_sampler(img, coords, mode='bilinear', mask=False):
""" Wrapper for grid_sample, uses pixel coordinates """
H, W = img.shape[-2:]
xgrid, ygrid = coords.split([1,1], dim=-1)
xgrid = 2*xgrid/(W-1) - 1
ygrid = 2*ygrid/(H-1) - 1
grid = torch.cat([xgrid, ygrid], dim=-1)
img = F.grid_sample(img, grid, align_corners=True)
if mask:
mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
return img, mask.float()
return img
def coords_grid(batch, ht, wd, device):
coords = torch.meshgrid(torch.arange(ht, device=device), torch.arange(wd, device=device))
coords = torch.stack(coords[::-1], dim=0).float()
return coords[None].repeat(batch, 1, 1, 1)
def upflow8(flow, mode='bilinear'):
new_size = (8 * flow.shape[2], 8 * flow.shape[3])
return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True)
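# --- Illustrative sketch (added; not part of the original file) ---
# Hedged usage example for the helpers above: InputPadder pads to the next
# multiple of 8 and unpad() is its exact inverse, and sampling an image at
# its own pixel grid with coords_grid + bilinear_sampler reproduces the
# image (up to floating-point error). Shapes below are arbitrary.
if __name__ == "__main__":
    img = torch.rand(1, 3, 100, 150)            # H and W not divisible by 8
    padder = InputPadder(img.shape)
    padded, = padder.pad(img)
    assert padded.shape[-2:] == (104, 152)
    assert torch.equal(padder.unpad(padded), img)

    coords = coords_grid(1, 100, 150, device=img.device)   # (1, 2, H, W), channels (x, y)
    resampled = bilinear_sampler(img, coords.permute(0, 2, 3, 1))
    assert torch.allclose(resampled, img, atol=1e-5)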
================================================
FILE: third_party/RAFT/demo.py
================================================
import sys
sys.path.append('core')
import argparse
import os
import cv2
import glob
import numpy as np
import torch
from PIL import Image
from raft import RAFT
from utils import flow_viz
from utils.utils import InputPadder
DEVICE = 'cuda'
def load_image(imfile):
img = np.array(Image.open(imfile)).astype(np.uint8)
img = torch.from_numpy(img).permute(2, 0, 1).float()
return img[None].to(DEVICE)
def viz(img, flo):
img = img[0].permute(1,2,0).cpu().numpy()
flo = flo[0].permute(1,2,0).cpu().numpy()
# map flow to rgb image
flo = flow_viz.flow_to_image(flo)
img_flo = np.concatenate([img, flo], axis=0)
# import matplotlib.pyplot as plt
# plt.imshow(img_flo / 255.0)
# plt.show()
cv2.imshow('image', img_flo[:, :, [2,1,0]]/255.0)
cv2.waitKey()
def demo(args):
model = torch.nn.DataParallel(RAFT(args))
model.load_state_dict(torch.load(args.model))
model = model.module
model.to(DEVICE)
model.eval()
with torch.no_grad():
images = glob.glob(os.path.join(args.path, '*.png')) + \
glob.glob(os.path.join(args.path, '*.jpg'))
images = sorted(images)
for imfile1, imfile2 in zip(images[:-1], images[1:]):
image1 = load_image(imfile1)
image2 = load_image(imfile2)
padder = InputPadder(image1.shape)
image1, image2 = padder.pad(image1, image2)
flow_low, flow_up = model(image1, image2, iters=20, test_mode=True)
viz(image1, flow_up)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--model', help="restore checkpoint")
parser.add_argument('--path', help="dataset for evaluation")
parser.add_argument('--small', action='store_true', help='use small model')
parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision')
parser.add_argument('--alternate_corr', action='store_true', help='use efficient correlation implementation')
args = parser.parse_args()
demo(args)
================================================
FILE: third_party/RAFT/download_models.sh
================================================
#!/bin/bash
wget https://dl.dropboxusercontent.com/s/4j4z58wuv8o0mfz/models.zip
unzip models.zip
================================================
FILE: third_party/RAFT/evaluate.py
================================================
import sys
sys.path.append('core')
from PIL import Image
import argparse
import os
import time
import numpy as np
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import datasets
from utils import flow_viz
from utils import frame_utils
from raft import RAFT
from utils.utils import InputPadder, forward_interpolate
@torch.no_grad()
def create_sintel_submission(model, iters=32, warm_start=False, output_path='sintel_submission'):
""" Create submission for the Sintel leaderboard """
model.eval()
for dstype in ['clean', 'final']:
test_dataset = datasets.MpiSintel(split='test', aug_params=None, dstype=dstype)
flow_prev, sequence_prev = None, None
for test_id in range(len(test_dataset)):
image1, image2, (sequence, frame) = test_dataset[test_id]
if sequence != sequence_prev:
flow_prev = None
padder = InputPadder(image1.shape)
image1, image2 = padder.pad(image1[None].cuda(), image2[None].cuda())
flow_low, flow_pr = model(image1, image2, iters=iters, flow_init=flow_prev, test_mode=True)
flow = padder.unpad(flow_pr[0]).permute(1, 2, 0).cpu().numpy()
if warm_start:
flow_prev = forward_interpolate(flow_low[0])[None].cuda()
output_dir = os.path.join(output_path, dstype, sequence)
output_file = os.path.join(output_dir, 'frame%04d.flo' % (frame+1))
if not os.path.exists(output_dir):
os.makedirs(output_dir)
frame_utils.writeFlow(output_file, flow)
sequence_prev = sequence
@torch.no_grad()
def create_kitti_submission(model, iters=24, output_path='kitti_submission'):
""" Create submission for the Sintel leaderboard """
model.eval()
test_dataset = datasets.KITTI(split='testing', aug_params=None)
if not os.path.exists(output_path):
os.makedirs(output_path)
for test_id in range(len(test_dataset)):
image1, image2, (frame_id, ) = test_dataset[test_id]
padder = InputPadder(image1.shape, mode='kitti')
image1, image2 = padder.pad(image1[None].cuda(), image2[None].cuda())
_, flow_pr = model(image1, image2, iters=iters, test_mode=True)
flow = padder.unpad(flow_pr[0]).permute(1, 2, 0).cpu().numpy()
output_filename = os.path.join(output_path, frame_id)
frame_utils.writeFlowKITTI(output_filename, flow)
@torch.no_grad()
def validate_chairs(model, iters=24):
""" Perform evaluation on the FlyingChairs (test) split """
model.eval()
epe_list = []
val_dataset = datasets.FlyingChairs(split='validation')
for val_id in range(len(val_dataset)):
image1, image2, flow_gt, _ = val_dataset[val_id]
image1 = image1[None].cuda()
image2 = image2[None].cuda()
_, flow_pr = model(image1, image2, iters=iters, test_mode=True)
epe = torch.sum((flow_pr[0].cpu() - flow_gt)**2, dim=0).sqrt()
epe_list.append(epe.view(-1).numpy())
epe = np.mean(np.concatenate(epe_list))
print("Validation Chairs EPE: %f" % epe)
return {'chairs': epe}
@torch.no_grad()
def validate_sintel(model, iters=32):
""" Peform validation using the Sintel (train) split """
model.eval()
results = {}
for dstype in ['clean', 'final']:
val_dataset = datasets.MpiSintel(split='training', dstype=dstype)
epe_list = []
for val_id in range(len(val_dataset)):
image1, image2, flow_gt, _ = val_dataset[val_id]
image1 = image1[None].cuda()
image2 = image2[None].cuda()
padder = InputPadder(image1.shape)
image1, image2 = padder.pad(image1, image2)
flow_low, flow_pr = model(image1, image2, iters=iters, test_mode=True)
flow = padder.unpad(flow_pr[0]).cpu()
epe = torch.sum((flow - flow_gt)**2, dim=0).sqrt()
epe_list.append(epe.view(-1).numpy())
epe_all = np.concatenate(epe_list)
epe = np.mean(epe_all)
px1 = np.mean(epe_all<1)
px3 = np.mean(epe_all<3)
px5 = np.mean(epe_all<5)
print("Validation (%s) EPE: %f, 1px: %f, 3px: %f, 5px: %f" % (dstype, epe, px1, px3, px5))
results[dstype] = np.mean(epe_list)
return results
@torch.no_grad()
def validate_kitti(model, iters=24):
""" Peform validation using the KITTI-2015 (train) split """
model.eval()
val_dataset = datasets.KITTI(split='training')
out_list, epe_list = [], []
for val_id in range(len(val_dataset)):
image1, image2, flow_gt, valid_gt = val_dataset[val_id]
image1 = image1[None].cuda()
image2 = image2[None].cuda()
padder = InputPadder(image1.shape, mode='kitti')
image1, image2 = padder.pad(image1, image2)
flow_low, flow_pr = model(image1, image2, iters=iters, test_mode=True)
flow = padder.unpad(flow_pr[0]).cpu()
epe = torch.sum((flow - flow_gt)**2, dim=0).sqrt()
mag = torch.sum(flow_gt**2, dim=0).sqrt()
epe = epe.view(-1)
mag = mag.view(-1)
val = valid_gt.view(-1) >= 0.5
out = ((epe > 3.0) & ((epe/mag) > 0.05)).float()
epe_list.append(epe[val].mean().item())
out_list.append(out[val].cpu().numpy())
epe_list = np.array(epe_list)
out_list = np.concatenate(out_list)
epe = np.mean(epe_list)
f1 = 100 * np.mean(out_list)
print("Validation KITTI: %f, %f" % (epe, f1))
return {'kitti-epe': epe, 'kitti-f1': f1}
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--model', help="restore checkpoint")
parser.add_argument('--dataset', help="dataset for evaluation")
parser.add_argument('--small', action='store_true', help='use small model')
parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision')
parser.add_argument('--alternate_corr', action='store_true', help='use efficient correlation implementation')
args = parser.parse_args()
model = torch.nn.DataParallel(RAFT(args))
model.load_state_dict(torch.load(args.model))
model.cuda()
model.eval()
# create_sintel_submission(model.module, warm_start=True)
# create_kitti_submission(model.module)
with torch.no_grad():
if args.dataset == 'chairs':
validate_chairs(model.module)
elif args.dataset == 'sintel':
validate_sintel(model.module)
elif args.dataset == 'kitti':
validate_kitti(model.module)
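# Example invocation (checkpoint path is illustrative; see download_models.sh):
#   python evaluate.py --model=models/raft-things.pth --dataset=sintel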
================================================
FILE: third_party/RAFT/train.py
================================================
from __future__ import print_function, division
import sys
sys.path.append('core')
import argparse
import os
import cv2
import time
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from raft import RAFT
import evaluate
import datasets
from torch.utils.tensorboard import SummaryWriter
try:
from torch.cuda.amp import GradScaler
except:
# dummy GradScaler for PyTorch < 1.6
class GradScaler:
def __init__(self):
pass
def scale(self, loss):
return loss
def unscale_(self, optimizer):
pass
def step(self, optimizer):
optimizer.step()
def update(self):
pass
# exclude extremely large displacements
MAX_FLOW = 400
SUM_FREQ = 100
VAL_FREQ = 5000
def sequence_loss(flow_preds, flow_gt, valid, gamma=0.8, max_flow=MAX_FLOW):
""" Loss function defined over sequence of flow predictions """
n_predictions = len(flow_preds)
flow_loss = 0.0
# exclude invalid pixels and extremely large displacements
mag = torch.sum(flow_gt**2, dim=1).sqrt()
valid = (valid >= 0.5) & (mag < max_flow)
for i in range(n_predictions):
i_weight = gamma**(n_predictions - i - 1)
i_loss = (flow_preds[i] - flow_gt).abs()
flow_loss += i_weight * (valid[:, None] * i_loss).mean()
epe = torch.sum((flow_preds[-1] - flow_gt)**2, dim=1).sqrt()
epe = epe.view(-1)[valid.view(-1)]
metrics = {
'epe': epe.mean().item(),
'1px': (epe < 1).float().mean().item(),
'3px': (epe < 3).float().mean().item(),
'5px': (epe < 5).float().mean().item(),
}
return flow_loss, metrics
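# Added note: with the default gamma=0.8 the i-th of n predictions is weighted
# by gamma**(n - i - 1), so for e.g. n=3 the weights are 0.64, 0.8 and 1.0 and
# the final refinement iteration dominates the loss:
#   >>> [0.8 ** (3 - i - 1) for i in range(3)]
#   [0.6400000000000001, 0.8, 1.0]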
def count_parameters(model):
return sum(p.numel() for p in model.parameters() if p.requires_grad)
def fetch_optimizer(args, model):
""" Create the optimizer and learning rate scheduler """
optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.wdecay, eps=args.epsilon)
scheduler = optim.lr_scheduler.OneCycleLR(optimizer, args.lr, args.num_steps+100,
pct_start=0.05, cycle_momentum=False, anneal_strategy='linear')
return optimizer, scheduler
class Logger:
def __init__(self, model, scheduler):
self.model = model
self.scheduler = scheduler
self.total_steps = 0
self.running_loss = {}
self.writer = None
def _print_training_status(self):
metrics_data = [self.running_loss[k]/SUM_FREQ for k in sorted(self.running_loss.keys())]
training_str = "[{:6d}, {:10.7f}] ".format(self.total_steps+1, self.scheduler.get_last_lr()[0])
metrics_str = ("{:10.4f}, "*len(metrics_data)).format(*metrics_data)
# print the training status
print(training_str + metrics_str)
if self.writer is None:
self.writer = SummaryWriter()
for k in self.running_loss:
self.writer.add_scalar(k, self.running_loss[k]/SUM_FREQ, self.total_steps)
self.running_loss[k] = 0.0
def push(self, metrics):
self.total_steps += 1
for key in metrics:
if key not in self.running_loss:
self.running_loss[key] = 0.0
self.running_loss[key] += metrics[key]
if self.total_steps % SUM_FREQ == SUM_FREQ-1:
self._print_training_status()
self.running_loss = {}
def write_dict(self, results):
if self.writer is None:
self.writer = SummaryWriter()
for key in results:
self.writer.add_scalar(key, results[key], self.total_steps)
def close(self):
self.writer.close()
def train(args):
model = nn.DataParallel(RAFT(args), device_ids=args.gpus)
print("Parameter Count: %d" % count_parameters(model))
if args.restore_ckpt is not None:
model.load_state_dict(torch.load(args.restore_ckpt), strict=False)
model.cuda()
model.train()
if args.stage != 'chairs':
model.module.freeze_bn()
train_loader = datasets.fetch_dataloader(args)
optimizer, scheduler = fetch_optimizer(args, model)
total_steps = 0
scaler = GradScaler(enabled=args.mixed_precision)
logger = Logger(model, scheduler)
VAL_FREQ = 5000
add_noise = True
should_keep_training = True
while should_keep_training:
for i_batch, data_blob in enumerate(train_loader):
optimizer.zero_grad()
image1, image2, flow, valid = [x.cuda() for x in data_blob]
if args.add_noise:
stdv = np.random.uniform(0.0, 5.0)
image1 = (image1 + stdv * torch.randn(*image1.shape).cuda()).clamp(0.0, 255.0)
image2 = (image2 + stdv * torch.randn(*image2.shape).cuda()).clamp(0.0, 255.0)
flow_predictions = model(image1, image2, iters=args.iters)
loss, metrics = sequence_loss(flow_predictions, flow, valid, args.gamma)
scaler.scale(loss).backward()
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
scaler.step(optimizer)
scheduler.step()
scaler.update()
logger.push(metrics)
if total_steps % VAL_FREQ == VAL_FREQ - 1:
PATH = 'checkpoints/%d_%s.pth' % (total_steps+1, args.name)
torch.save(model.state_dict(), PATH)
results = {}
for val_dataset in args.validation:
if val_dataset == 'chairs':
results.update(evaluate.validate_chairs(model.module))
elif val_dataset == 'sintel':
results.update(evaluate.validate_sintel(model.module))
elif val_dataset == 'kitti':
results.update(evaluate.validate_kitti(model.module))
logger.write_dict(results)
model.train()
if args.stage != 'chairs':
model.module.freeze_bn()
total_steps += 1
if total_steps > args.num_steps:
should_keep_training = False
break
logger.close()
PATH = 'checkpoints/%s.pth' % args.name
torch.save(model.state_dict(), PATH)
return PATH
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--name', default='raft', help="name your experiment")
parser.add_argument('--stage', help="determines which dataset to use for training")
parser.add_argument('--restore_ckpt', help="restore checkpoint")
parser.add_argument('--small', action='store_true', help='use small model')
parser.add_argument('--validation', type=str, nargs='+')
parser.add_argument('--lr', type=float, default=0.00002)
parser.add_argument('--num_steps', type=int, default=100000)
parser.add_argument('--batch_size', type=int, default=6)
parser.add_argument('--image_size', type=int, nargs='+', default=[384, 512])
parser.add_argument('--gpus', type=int, nargs='+', default=[0,1])
parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision')
parser.add_argument('--iters', type=int, default=12)
parser.add_argument('--wdecay', type=float, default=.00005)
parser.add_argument('--epsilon', type=float, default=1e-8)
parser.add_argument('--clip', type=float, default=1.0)
parser.add_argument('--dropout', type=float, default=0.0)
parser.add_argument('--gamma', type=float, default=0.8, help='exponential weighting')
parser.add_argument('--add_noise', action='store_true')
args = parser.parse_args()
torch.manual_seed(1234)
np.random.seed(1234)
if not os.path.isdir('checkpoints'):
os.mkdir('checkpoints')
train(args)
================================================
FILE: third_party/RAFT/train_mixed.sh
================================================
#!/bin/bash
mkdir -p checkpoints
python -u train.py --name raft-chairs --stage chairs --validation chairs --gpus 0 --num_steps 120000 --batch_size 8 --lr 0.00025 --image_size 368 496 --wdecay 0.0001 --mixed_precision
python -u train.py --name raft-things --stage things --validation sintel --restore_ckpt checkpoints/raft-chairs.pth --gpus 0 --num_steps 120000 --batch_size 5 --lr 0.0001 --image_size 400 720 --wdecay 0.0001 --mixed_precision
python -u train.py --name raft-sintel --stage sintel --validation sintel --restore_ckpt checkpoints/raft-things.pth --gpus 0 --num_steps 120000 --batch_size 5 --lr 0.0001 --image_size 368 768 --wdecay 0.00001 --gamma=0.85 --mixed_precision
python -u train.py --name raft-kitti --stage kitti --validation kitti --restore_ckpt checkpoints/raft-sintel.pth --gpus 0 --num_steps 50000 --batch_size 5 --lr 0.0001 --image_size 288 960 --wdecay 0.00001 --gamma=0.85 --mixed_precision
================================================
FILE: third_party/RAFT/train_standard.sh
================================================
#!/bin/bash
mkdir -p checkpoints
python -u train.py --name raft-chairs --stage chairs --validation chairs --gpus 0 1 --num_steps 100000 --batch_size 10 --lr 0.0004 --image_size 368 496 --wdecay 0.0001
python -u train.py --name raft-things --stage things --validation sintel --restore_ckpt checkpoints/raft-chairs.pth --gpus 0 1 --num_steps 100000 --batch_size 6 --lr 0.000125 --image_size 400 720 --wdecay 0.0001
python -u train.py --name raft-sintel --stage sintel --validation sintel --restore_ckpt checkpoints/raft-things.pth --gpus 0 1 --num_steps 100000 --batch_size 6 --lr 0.000125 --image_size 368 768 --wdecay 0.00001 --gamma=0.85
python -u train.py --name raft-kitti --stage kitti --validation kitti --restore_ckpt checkpoints/raft-sintel.pth --gpus 0 1 --num_steps 50000 --batch_size 6 --lr 0.0001 --image_size 288 960 --wdecay 0.00001 --gamma=0.85
================================================
FILE: train_stereoanyvideo.py
================================================
import argparse
import logging
from pathlib import Path
from tqdm import tqdm
import os
import cv2
import numpy as np
import torch
import torch.optim as optim
import torch.nn.functional as F
from munch import DefaultMunch
import json
from pytorch_lightning.lite import LightningLite
from torch.cuda.amp import GradScaler
from stereoanyvideo.train_utils.utils import (
run_test_eval,
save_ims_to_tb,
count_parameters,
)
from stereoanyvideo.train_utils.logger import Logger
from stereoanyvideo.evaluation.core.evaluator import Evaluator
from stereoanyvideo.train_utils.losses import sequence_loss, temporal_loss
import stereoanyvideo.datasets.video_datasets as datasets
from stereoanyvideo.models.core.stereoanyvideo import StereoAnyVideo
def fetch_optimizer(args, model):
"""Create the optimizer and learning rate scheduler"""
for name, param in model.named_parameters():
if any([key in name for key in ['depthanything']]):
param.requires_grad_(False)
optimizer = optim.AdamW(
model.parameters(), lr=args.lr, weight_decay=args.wdecay, eps=1e-8
)
scheduler = optim.lr_scheduler.OneCycleLR(
optimizer,
args.lr,
args.num_steps + 100,
pct_start=0.01,
cycle_momentum=False,
anneal_strategy="linear",
)
return optimizer, scheduler
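# Added note: fetch_optimizer freezes every parameter whose name contains
# "depthanything", so the video-depth backbone is not trained. A minimal
# sanity check (illustrative, not part of the original training script):
#   assert not any(p.requires_grad
#                  for n, p in model.named_parameters() if "depthanything" in n)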
def forward_batch(batch, model, args):
output = {}
disparities = model(
batch["img"][:, :, 0],
batch["img"][:, :, 1],
iters=args.train_iters,
test_mode=False,
)
num_traj = len(batch["disp"][0])
for i in range(num_traj):
seq_loss, metrics = sequence_loss(
disparities[:, i], -batch["disp"][:, i, 0], batch["valid_disp"][:, i, 0])
output[f"disp_{i}"] = {"loss": seq_loss / num_traj, "metrics": metrics}
output["disparity"] = {
"predictions": torch.cat(
[disparities[-1, i] for i in range(num_traj)], dim=1).detach(),
}
return output
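# Added note (assumptions, for readability): batch["img"] is indexed as
# [:, :, 0] / [:, :, 1] for the left and right view of each frame, and the loss
# is averaged over the num_traj disparity trajectories. The ground-truth
# disparity is negated before sequence_loss, which suggests the model predicts
# a signed horizontal displacement rather than a positive disparity magnitude.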
class Lite(LightningLite):
def run(self, args):
self.seed_everything(0)
evaluator = Evaluator()
eval_vis_cfg = {
"visualize_interval": 0, # Use 0 for no visualization
"exp_dir": args.ckpt_path,
}
eval_vis_cfg = DefaultMunch.fromDict(eval_vis_cfg, object())
evaluator.setup_visualization(eval_vis_cfg)
model = StereoAnyVideo()
model.cuda()
with open(args.ckpt_path + "/meta.json", "w") as file:
json.dump(vars(args), file, sort_keys=True, indent=4)
train_loader = datasets.fetch_dataloader(args)
train_loader = self.setup_dataloaders(train_loader, move_to_device=False)
logging.info(f"Train loader size: {len(train_loader)}")
optimizer, scheduler = fetch_optimizer(args, model)
print("Parameter Count:", {count_parameters(model)})
logging.info(f"Parameter Count: {count_parameters(model)}")
total_steps = 0
logger = Logger(model, scheduler, args.ckpt_path)
folder_ckpts = [
f
for f in os.listdir(args.ckpt_path)
if not os.path.isdir(f) and f.endswith(".pth") and not "final" in f
]
if len(folder_ckpts) > 0:
ckpt_path = sorted(folder_ckpts)[-1]
ckpt = self.load(os.path.join(args.ckpt_path, ckpt_path))
logging.info(f"Loading checkpoint {ckpt_path}")
if "model" in ckpt:
model.load_state_dict(ckpt["model"])
else:
model.load_state_dict(ckpt)
if "optimizer" in ckpt:
logging.info("Load optimizer")
optimizer.load_state_dict(ckpt["optimizer"])
if "scheduler" in ckpt:
logging.info("Load scheduler")
scheduler.load_state_dict(ckpt["scheduler"])
if "total_steps" in ckpt:
total_steps = ckpt["total_steps"]
logging.info(f"Load total_steps {total_steps}")
elif args.restore_ckpt is not None:
assert args.restore_ckpt.endswith(".pth") or args.restore_ckpt.endswith(
".pt"
)
logging.info("Loading checkpoint...")
strict = True
state_dict = self.load(args.restore_ckpt)
if "model" in state_dict:
state_dict = state_dict["model"]
if list(state_dict.keys())[0].startswith("module."):
state_dict = {
k.replace("module.", ""): v for k, v in state_dict.items()
}
model.load_state_dict(state_dict, strict=strict)
logging.info(f"Done loading checkpoint")
model, optimizer = self.setup(model, optimizer, move_to_device=False)
model.cuda()
model.train()
model.module.module.freeze_bn() # We keep BatchNorm frozen
scaler = GradScaler(enabled=args.mixed_precision)
should_keep_training = True
global_batch_num = 0
epoch = -1
while should_keep_training:
epoch += 1
for i_batch, batch in enumerate(tqdm(train_loader)):
optimizer.zero_grad()
if batch is None:
print("batch is None")
continue
for k, v in batch.items():
batch[k] = v.cuda()
assert model.training
output = forward_batch(batch, model, args)
loss = 0
logger.update()
for k, v in output.items():
if "loss" in v:
loss += v["loss"]
logger.writer.add_scalar(
f"live_{k}_loss", v["loss"].item(), total_steps
)
if "metrics" in v:
logger.push(v["metrics"], k)
if self.global_rank == 0:
if len(output) > 1:
logger.writer.add_scalar(
f"live_total_loss", loss.item(), total_steps
)
logger.writer.add_scalar(
f"learning_rate", optimizer.param_groups[0]["lr"], total_steps
)
global_batch_num += 1
self.barrier()
self.backward(scaler.scale(loss))
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
scaler.step(optimizer)
if total_steps < args.num_steps:
scheduler.step()
scaler.update()
total_steps += 1
if self.global_rank == 0:
if (total_steps % args.save_steps == 0) or (total_steps == 1 and args.validate_at_start):
ckpt_iter = "0" * (6 - len(str(total_steps))) + str(total_steps)
save_path = Path(
f"{args.ckpt_path}/model_{args.name}_{ckpt_iter}.pth"
)
save_dict = {
"model": model.module.module.state_dict(),
"optimizer": optimizer.state_dict(),
"scheduler": scheduler.state_dict(),
"total_steps": total_steps,
}
logging.info(f"Saving file {save_path}")
self.save(save_dict, save_path)
self.barrier()
if total_steps > args.num_steps:
should_keep_training = False
break
logger.close()
PATH = f"{args.ckpt_path}/{args.name}_final.pth"
torch.save(model.module.module.state_dict(), PATH)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--name", default="StereoAnyVideo", help="name your experiment")
parser.add_argument("--restore_ckpt", help="restore checkpoint")
parser.add_argument("--ckpt_path", help="path to save checkpoints")
parser.add_argument(
"--mixed_precision", action="store_true", help="use mixed precision"
)
# Training parameters
parser.add_argument(
"--batch_size", type=int, default=8, help="batch size used during training."
)
parser.add_argument(
"--train_datasets",
nargs="+",
default=["things", "monkaa", "driving"],
help="training datasets.",
)
parser.add_argument("--lr", type=float, default=0.0001, help="max learning rate.")
parser.add_argument(
"--num_steps", type=int, default=80000, help="length of training schedule."
)
parser.add_argument(
"--save_steps", type=int, default=3000, help="length of training schedule."
)
parser.add_argument(
"--image_size",
type=int,
nargs="+",
default=[320, 720],
help="size of the random image crops used during training.",
)
parser.add_argument(
"--train_iters",
type=int,
default=12,
help="number of updates to the disparity field in each forward pass.",
)
parser.add_argument(
"--wdecay", type=float, default=0.00001, help="Weight decay in optimizer."
)
parser.add_argument(
"--sample_len", type=int, default=5, help="length of training video samples"
)
parser.add_argument(
"--validate_at_start", action="store_true", help="validate the model at start"
)
parser.add_argument(
"--evaluate_every_n_epoch",
type=int,
default=1,
help="evaluate every n epoch",
)
parser.add_argument(
"--num_workers", type=int, default=6, help="number of dataloader workers."
)
# Validation parameters
parser.add_argument(
"--valid_iters",
type=int,
default=32,
help="number of updates to the disparity field in each forward pass during validation.",
)
# Data augmentation
parser.add_argument(
"--img_gamma", type=float, nargs="+", default=None, help="gamma range"
)
parser.add_argument(
"--saturation_range",
type=float,
nargs="+",
default=None,
help="color saturation",
)
parser.add_argument(
"--do_flip",
default=False,
choices=["h", "v"],
help="flip the images horizontally or vertically",
)
parser.add_argument(
"--spatial_scale",
type=float,
nargs="+",
default=[0, 0],
help="re-scale the images randomly",
)
parser.add_argument(
"--noyjitter",
action="store_true",
help="don't simulate imperfect rectification",
)
args = parser.parse_args()
Path(args.ckpt_path).mkdir(exist_ok=True, parents=True)
logging.basicConfig(
level=logging.INFO,
filename=args.ckpt_path + '/' + args.name + '.log',
filemode='a',
format="%(asctime)s %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s",
)
from pytorch_lightning.strategies import DDPStrategy
Lite(
strategy=DDPStrategy(find_unused_parameters=True),
devices="auto",
accelerator="gpu",
precision=32,
).run(args)
================================================
FILE: train_stereoanyvideo.sh
================================================
#!/bin/bash
export PYTHONPATH=`(cd ../ && pwd)`:`pwd`:$PYTHONPATH
python train_stereoanyvideo.py --batch_size 1 \
--spatial_scale -0.2 0.4 --image_size 256 512 --saturation_range 0 1.4 --num_steps 80000 \
--ckpt_path logging/StereoAnyVideo_SF \
--sample_len 5 --train_iters 10 --lr 0.0001 \
--num_workers 8 --save_steps 3000 --train_datasets things monkaa driving
================================================
FILE: train_utils/logger.py
================================================
import logging
import os
from torch.utils.tensorboard import SummaryWriter
class Logger:
SUM_FREQ = 100
def __init__(self, model, scheduler, ckpt_path):
self.model = model
self.scheduler = scheduler
self.total_steps = 0
self.running_loss = {}
self.ckpt_path = ckpt_path
self.writer = SummaryWriter(log_dir=os.path.join(self.ckpt_path, "runs"))
logging.info(
f"Training Metrics: 1px_disp_0...5, 3px_disp_0...5, 5px_disp_0...5, epe_disp_0...5"
)
def _print_training_status(self):
metrics_data = [
self.running_loss[k] / Logger.SUM_FREQ
for k in sorted(self.running_loss.keys())
]
training_str = "[{:6d}] ".format(self.total_steps + 1)
metrics_str = ("{:10.4f}, " * len(metrics_data)).format(*metrics_data)
# print the training status
logging.info(
f"Training Metrics ({self.total_steps}): {training_str + metrics_str}"
)
if self.writer is None:
self.writer = SummaryWriter(log_dir=os.path.join(self.ckpt_path, "runs"))
for k in self.running_loss:
self.writer.add_scalar(
k, self.running_loss[k] / Logger.SUM_FREQ, self.total_steps
)
self.running_loss[k] = 0.0
def push(self, metrics, task):
for key in metrics:
task_key = str(key) + "_" + task
if task_key not in self.running_loss:
self.running_loss[task_key] = 0.0
self.running_loss[task_key] += metrics[key]
def update(self):
self.total_steps += 1
if self.total_steps % Logger.SUM_FREQ == Logger.SUM_FREQ - 1:
print(self.running_loss)
self._print_training_status()
self.running_loss = {}
def write_dict(self, results):
if self.writer is None:
self.writer = SummaryWriter(log_dir=os.path.join(self.ckpt_path, "runs"))
for key in results:
self.writer.add_scalar(key, results[key], self.total_steps)
def close(self):
self.writer.close()
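# Illustrative usage (added; mirrors the calls made in train_stereoanyvideo.py):
#   logger = Logger(model, scheduler, ckpt_path)
#   logger.update()                                   # once per training step
#   logger.push({"epe": 1.23, "1px": 0.9}, "disp_0")  # once per loss term
# Running averages are flushed to TensorBoard every Logger.SUM_FREQ (100) steps.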
================================================
FILE: train_utils/losses.py
================================================
import torch
from einops import rearrange
import torch.nn.functional as F
def sequence_loss(flow_preds, flow_gt, valid, loss_gamma=0.9, max_flow=700):
"""Loss function defined over sequence of flow predictions"""
n_predictions = len(flow_preds)
assert n_predictions >= 1
flow_loss = 0.0
# exclude invalid pixels and extremely large displacements
mag = torch.sum(flow_gt ** 2, dim=1).sqrt().unsqueeze(1)
if len(valid.shape) != len(flow_gt.shape):
valid = valid.unsqueeze(1)
valid = (valid >= 0.5) & (mag < max_flow)
if valid.shape != flow_gt.shape:
valid = torch.cat([valid, valid], dim=1)
assert valid.shape == flow_gt.shape, [valid.shape, flow_gt.shape]
assert not torch.isinf(flow_gt[valid.bool()]).any()
for i in range(n_predictions):
assert (
not torch.isnan(flow_preds[i]).any()
and not torch.isinf(flow_preds[i]).any()
)
if n_predictions == 1:
i_weight = 1
else:
# We adjust the loss_gamma so it is consistent for any number of iterations
adjusted_loss_gamma = loss_gamma ** (15 / (n_predictions - 1))
i_weight = adjusted_loss_gamma ** (n_predictions - i - 1)
flow_pred = flow_preds[i].clone()
if valid.shape[1] == 1 and flow_preds[i].shape[1] == 2:
flow_pred = flow_pred[:, :1]
i_loss = (flow_pred - flow_gt).abs()
assert i_loss.shape == valid.shape, [
i_loss.shape,
valid.shape,
flow_gt.shape,
flow_pred.shape,
]
flow_loss += i_weight * i_loss[valid.bool()].mean()
epe = torch.sum((flow_preds[-1] - flow_gt) ** 2, dim=1).sqrt()
valid = valid[:, 0]
epe = epe.view(-1)
epe = epe[valid.reshape(epe.shape)]
metrics = {
"epe": epe.mean().item(),
"1px": (epe < 1).float().mean().item(),
"3px": (epe < 3).float().mean().item(),
"5px": (epe < 5).float().mean().item(),
}
return flow_loss, metrics
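# Added note: the exponent 15 / (n_predictions - 1) rescales loss_gamma so that
# the earliest prediction always receives weight loss_gamma**15 (about 0.206 for
# the default 0.9) and the last always receives 1.0, regardless of how many
# refinement iterations were run.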
def temporal_loss(flow_preds, flow_preds2, flow_gt, flow_gt2, valid, loss_gamma=0.9, max_flow=700):
assert len(flow_preds) == len(flow_preds2)
n_predictions = len(flow_preds)
assert n_predictions >= 1
flow_loss = 0.0
# exclude invalid pixels and extremely large displacements
mag = torch.sum(flow_gt ** 2, dim=1).sqrt().unsqueeze(1)
if len(valid.shape) != len(flow_gt.shape):
valid = valid.unsqueeze(1)
valid = (valid >= 0.5) & (mag < max_flow)
if valid.shape != flow_gt.shape:
valid = torch.cat([valid, valid], dim=1)
assert valid.shape == flow_gt.shape, [valid.shape, flow_gt.shape]
assert not torch.isinf(flow_gt[valid.bool()]).any()
for i in range(n_predictions):
assert (
not torch.isnan(flow_preds[i]).any()
and not torch.isinf(flow_preds[i]).any()
)
assert (
not torch.isnan(flow_preds2[i]).any()
and not torch.isinf(flow_preds2[i]).any()
)
if n_predictions == 1:
i_weight = 1
else:
# We adjust the loss_gamma so it is consistent for any number of iterations
adjusted_loss_gamma = loss_gamma ** (15 / (n_predictions - 1))
i_weight = adjusted_loss_gamma ** (n_predictions - i - 1)
flow_pred = flow_preds[i].clone()
flow_pred2 = flow_preds2[i].clone()
if valid.shape[1] == 1 and flow_preds[i].shape[1] == 2:
flow_pred = flow_pred[:, :1]
flow_pred2 = flow_pred2[:, :1]
i_loss = ((flow_pred2 - flow_pred).abs() - (flow_gt2 - flow_gt).abs()).abs()
assert i_loss.shape == valid.shape, [
i_loss.shape,
valid.shape,
flow_gt.shape,
flow_pred.shape,
]
flow_loss += i_weight * i_loss[valid.bool()].mean()
tepe = torch.sum(((flow_preds2[-1] - flow_preds[-1]) - (flow_gt2 - flow_gt)) ** 2, dim=1).sqrt()
mask = (flow_gt2 - flow_gt) < 5
valid = mask * valid
valid = valid[:, 0]
tepe = tepe.view(-1)
tepe = tepe[valid.reshape(tepe.shape)]
metrics = {
"tepe": tepe.mean().item(),
}
return flow_loss, metrics
def compute_flow(Flow_Model, seq):
n, t, c, h, w = seq.size()
flows_forward = []
flows_backward = []
for i in range(t-1):
# i-th flow_backward denotes seq[i+1] towards seq[i]
flow_backward = Flow_Model.forward_fullres(seq[:,i], seq[:,i+1])
# i-th flow_forward denotes seq[i] towards seq[i+1]
flow_forward = Flow_Model.forward_fullres(seq[:,i+1], seq[:,i])
flows_backward.append(flow_backward)
flows_forward.append(flow_forward)
flows_forward = torch.stack(flows_forward, dim=1)
flows_backward = torch.stack(flows_backward, dim=1)
return flows_forward, flows_backward
def flow_warp(x, flow):
if flow.size(3) != 2: # [B, H, W, 2]
flow = flow.permute(0, 2, 3, 1)
if x.size()[-2:] != flow.size()[1:3]:
raise ValueError(f'The spatial sizes of input ({x.size()[-2:]}) and '
f'flow ({flow.size()[1:3]}) are not the same.')
_, _, h, w = x.size()
# create mesh grid
grid_y, grid_x = torch.meshgrid(torch.arange(0, h), torch.arange(0, w))
grid = torch.stack((grid_x, grid_y), 2).type_as(x) # (h, w, 2)
grid.requires_grad = False
grid_flow = grid + flow
# scale grid_flow to [-1,1]
grid_flow_x = 2.0 * grid_flow[:, :, :, 0] / max(w - 1, 1) - 1.0
grid_flow_y = 2.0 * grid_flow[:, :, :, 1] / max(h - 1, 1) - 1.0
grid_flow = torch.stack((grid_flow_x, grid_flow_y), dim=3)
output = F.grid_sample(
x,
grid_flow,
mode='bilinear',
padding_mode='zeros',
align_corners=True)
return output
def bidirectional_alignment(seq, flows_backward, flows_forward):
b, T, *_ = seq.shape
# seq_backward = seq[:, 1:, ...]
# seq_forward = seq[:, :T - 1, ...]
# seq_backward = rearrange(seq_backward, "b t c h w -> (b t) c h w")
# seq_forward = rearrange(seq_forward, "b t c h w -> (b t) c h w")
# flows_forward = rearrange(flows_forward, "b t c h w -> (b t) c h w")
# flows_backward = rearrange(flows_backward, "b t c h w -> (b t) c h w")
# seq_backward = flow_warp(seq_backward, flows_backward)
# seq_forward = flow_warp(seq_forward, flows_forward)
# seq_backward = rearrange(seq_backward, "(b t) c h w -> b t c h w", b=b, t=T - 1)
# seq_forward = rearrange(seq_forward, "(b t) c h w -> b t c h w", b=b, t=T - 1)
# output_backward = torch.cat((seq_backward, seq[:, -1:]), dim=1)
# output_forward = torch.cat((seq[:, :1], seq_forward), dim=1)
output_backward = []
for i in range(1, T):
feat_prop = flow_warp(seq[:, i], flows_backward[:, i-1])
output_backward.append(feat_prop)
output_backward.append(seq[:, T - 1])
output_backward = torch.stack(output_backward, dim=1)
# forward-time process
output_forward = [seq[:, 0]]
for i in range(T - 1):
feat_prop = flow_warp(seq[:, i], flows_forward[:, i])
output_forward.append(feat_prop)
output_forward = torch.stack(output_forward, dim=1)
return output_backward, output_forward
def consistency_loss(seq, disparities, Flow_Model, alpha=50):
b, T, *_ = seq.shape
# compute optical flow
flows_forward, flows_backward = compute_flow(Flow_Model, seq)
seq_backward, seq_forward = bidirectional_alignment(seq, flows_backward, flows_forward)
disparities_backward, disparities_forward = bidirectional_alignment(disparities, flows_backward, flows_forward)
diff_disparities_back = torch.abs(disparities_backward - disparities)
diff_disparities_for = torch.abs(disparities_forward - disparities)
diff_seq_back = (seq_backward - seq) ** 2
diff_seq_for = (seq_forward - seq) ** 2
mask_seq_back = torch.exp(-(alpha * diff_seq_back))
mask_seq_for = torch.exp(-(alpha * diff_seq_for))
mask_seq_back = torch.sum(mask_seq_back, dim=2, keepdim=True)
mask_seq_for = torch.sum(mask_seq_for, dim=2, keepdim=True)
temporal_loss_back = torch.mul(mask_seq_back, diff_disparities_back)
temporal_loss_for = torch.mul(mask_seq_for, diff_disparities_for)
temporal_loss = torch.mean(temporal_loss_back) + torch.mean(temporal_loss_for)
return temporal_loss
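# --- Illustrative sketch (added; not part of the original file) ---
# Hedged sanity check for flow_warp: warping a tensor with an all-zero flow
# field is an identity operation (up to floating-point error). Shapes are
# arbitrary; the flow is given in the (B, H, W, 2) layout flow_warp expects.
if __name__ == "__main__":
    _x = torch.rand(2, 3, 16, 24)
    _zero_flow = torch.zeros(2, 16, 24, 2)
    assert torch.allclose(flow_warp(_x, _zero_flow), _x, atol=1e-5)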
================================================
FILE: train_utils/utils.py
================================================
import numpy as np
import os
import torch
import json
import flow_vis
import matplotlib.pyplot as plt
import stereoanyvideo.datasets.video_datasets as datasets
from stereoanyvideo.evaluation.utils.utils import aggregate_and_print_results
def count_parameters(model):
return sum(p.numel() for p in model.parameters() if p.requires_grad)
def run_test_eval(ckpt_path, eval_type, evaluator, model, dataloaders, writer, step):
for ds_name, dataloader in dataloaders:
# evaluator.visualize_interval = 1 if not "sintel" in ds_name else 0
evaluate_result = evaluator.evaluate_sequence(
model=model.module.module,
test_dataloader=dataloader,
writer=writer if not "sintel" in ds_name else None,
step=step,
train_mode=True,
)
aggregate_result = aggregate_and_print_results(
evaluate_result,
)
save_metrics = [
"flow_mean_accuracy_5px",
"flow_mean_accuracy_3px",
"flow_mean_accuracy_1px",
"flow_epe_traj_mean",
]
for epe_name in ("epe", "temp_epe", "temp_epe_r"):
for m in [
f"disp_{epe_name}_bad_0.5px",
f"disp_{epe_name}_bad_1px",
f"disp_{epe_name}_bad_2px",
f"disp_{epe_name}_bad_3px",
f"disp_{epe_name}_mean",
]:
save_metrics.append(m)
for k, v in aggregate_result.items():
if k in save_metrics:
writer.add_scalars(
f"{ds_name}_{k.rsplit('_', 1)[0]}",
{f"{ds_name}_{k}": v},
step,
)
result_file = os.path.join(
ckpt_path,
f"result_{ds_name}_{eval_type}_{step}_mimo.json",
)
print(f"Dumping {eval_type} results to {result_file}.")
with open(result_file, "w") as f:
json.dump(aggregate_result, f)
def fig2data(fig):
"""
fig = plt.figure()
image = fig2data(fig)
@brief Convert a Matplotlib figure to a numpy array with RGB channels and return it
@param fig a matplotlib figure
@return a numpy 3D array of RGB values
"""
import PIL.Image as Image
# draw the renderer
fig.canvas.draw()
# Get the RGBA buffer from the figure
w, h = fig.canvas.get_width_height()
buf = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
buf.shape = (h, w, 3)
image = Image.frombytes("RGB", (w, h), buf.tobytes())
image = np.asarray(image)
return image
def save_ims_to_tb(writer, batch, output, total_steps):
writer.add_image(
"train_im",
torch.cat([torch.cat([im[0], im[1]], dim=-1) for im in batch["img"][0]], dim=-2)
/ 255.0,
total_steps,
dataformats="CHW",
)
if "disp" in batch and len(batch["disp"]) > 0:
disp_im = [
(torch.cat([im[0], im[1]], dim=-1) * torch.cat([val[0], val[1]], dim=-1))
for im, val in zip(batch["disp"][0], batch["valid_disp"][0])
]
disp_im = torch.cat(disp_im, dim=1)
figure = plt.figure()
plt.imshow(disp_im.cpu()[0])
disp_im = fig2data(figure).copy()
writer.add_image(
"train_disp",
disp_im,
total_steps,
dataformats="HWC",
)
for k, v in output.items():
if "predictions" in v:
pred = v["predictions"]
if k == "disparity":
figure = plt.figure()
plt.imshow(pred.cpu()[0])
pred = fig2data(figure).copy()
dataformat = "HWC"
else:
pred = torch.tensor(
flow_vis.flow_to_color(
pred.permute(1, 2, 0).cpu().numpy(), convert_to_bgr=False
)
/ 255.0
)
dataformat = "HWC"
writer.add_image(
f"pred_{k}",
pred,
total_steps,
dataformats=dataformat,
)
if "gt" in v:
gt = v["gt"]
gt = torch.tensor(
flow_vis.flow_to_color(
gt.permute(1, 2, 0).cpu().numpy(), convert_to_bgr=False
)
/ 255.0
)
dataformat = "HWC"
writer.add_image(
f"gt_{k}",
gt,
total_steps,
dataformats=dataformat,
)