1Nanjing University 2Institute of Automation, Chinese Academy of Sciences
*Equal Contribution †Corresponding Author
CVPR 2026
## 🎉NEWS
+ [2026.02.21] 🎉 SpatialVID is accepted by CVPR 2026!
+ [2025.10.11] 🐳 Docker support is now available, featuring a pre-configured environment with NVIDIA GPU-accelerated FFmpeg.
+ [2025.09.29] 🚀 Depth data for the SpatialVID-HQ dataset is now officially available.
+ [2025.09.24] 🤗 Raw metadata access is now available via a [gated HuggingFace dataset](https://huggingface.co/datasets/SpatialVID/SpatialVID-RAW) to better support community research!!
+ [2025.09.24] 🔭 Enhanced instructions for better camera control are updated.
+ [2025.09.18] 🎆 SpatialVID dataset is now available on both HuggingFace and ModelScope.
+ [2025.09.14] 📢 We have also uploaded the SpatialVID-HQ dataset to ModelScope, offering more diverse download options.
+ [2025.09.11] 🔥 Our paper, code and SpatialVID-HQ dataset are released!
**[✍️ Note]** Each video clip is paired with a dedicated annotation folder (named after the video’s id). The folder contains 5 key files, and details regarding these files can be found in [Detailed Explanation of Annotation Files](https://huggingface.co/datasets/SpatialVID/SpatialVID#3-detailed-explanation-of-annotation-files).
## Abstract
Significant progress has been made in spatial intelligence, spanning both spatial reconstruction and world exploration. However, the scalability and real-world fidelity of current models remain severely constrained by the scarcity of large-scale, high-quality training data. While several datasets provide camera pose information, they are typically limited in scale, diversity, and annotation richness, particularly for real-world dynamic scenes with ground-truth camera motion. To this end, we collect **SpatialVID**, a dataset consisting of a large corpus of in-the-wild videos with diverse scenes, camera movements and dense 3D annotations such as per-frame camera poses, depth, and motion instructions. Specifically, we collect more than **21,000 hours** of raw videos, and process them into **2.7 million clips** through a hierarchical filtering pipeline, totaling **7,089 hours** of dynamic content. A subsequent annotation pipeline enriches these clips with detailed spatial and semantic information, including camera poses, depth maps, dynamic masks, structured captions, and serialized motion instructions. Analysis of SpatialVID's data statistics reveals a richness and diversity that directly foster improved model generalization and performance, establishing it as a key asset for the video and 3D vision research community.
## Preparation
This section describes how to set up the environment manually. For a simpler, containerized setup, please refer to the **[Docker Setup and Usage](#docker-setup-and-usage)** section.
### Environment
1. Necessary packages
```bash
git clone --recursive https://github.com/NJU-3DV/SpatialVID.git
cd SpatialVID
conda create -n SpatialVID python=3.10.13
conda activate SpatialVID
pip install -r requirements/requirements.txt
```
2. Packages needed for scoring
```bash
pip install paddlepaddle-gpu==3.0.0 -i https://www.paddlepaddle.org.cn/packages/stable/cu126/
pip install -r requirements/requirements_scoring.txt
```
Ignore any warnings about the `nvidia-nccl-cu12` and `numpy` versions; they are not a problem.
For FFmpeg, please refer to [`INSTALL.md`](scoring/motion/INSTALL.md) for detailed installation instructions. After installation, set the `FFMPEG_PATH` variable in [`scoring/motion/inference.py`](scoring/motion/inference.py) and [`utils/cut.py`](utils/cut.py) to the actual path of your ffmpeg executable (the default is `/usr/local/bin/ffmpeg`).
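For reference, the change is simply pointing the `FFMPEG_PATH` variable at your binary (an illustrative snippet; the exact surrounding code in those files may differ):
```python
# In scoring/motion/inference.py and utils/cut.py (illustrative snippet):
FFMPEG_PATH = "/usr/local/bin/ffmpeg"  # replace with the path to your ffmpeg executable
```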
⚠️ If your videos are encoded with the AV1 codec instead of H.264, you need to install ffmpeg (already covered by our requirements script) and then run the following so that your conda environment's OpenCV can decode AV1:
```bash
pip uninstall opencv-python
conda install -c conda-forge opencv==4.11.0
```
If your conda environment still cannot decode AV1, you can pass the `--backend av` option to the scoring scripts to use PyAV as the video-reading backend.
Note, however, that extracting frames with PyAV may lead to slight inaccuracies in frame positioning.
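If you are unsure whether your OpenCV build can decode AV1, a quick check such as the following can help (a minimal sketch; the file name is a placeholder for one of your own AV1-encoded clips):
```python
import cv2

# Try to open and decode one frame from an AV1-encoded clip (placeholder path).
cap = cv2.VideoCapture("sample_av1_clip.mp4")
ok, _ = cap.read()
cap.release()
print("AV1 decoding works" if ok else "AV1 decoding failed; consider the --backend av fallback")
```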
3. Packages needed for annotation
```bash
pip install -r requirements/requirements_annotation.txt
```
Compile the extensions for the camera tracking module:
```bash
cd camera_pose_annotation/base
python setup.py install
```
4. [Optional] Packages needed for visualization
```bash
pip install plotly
pip install -e viser
```
### Model Weights
Download the model weights used in our experiments:
```bash
bash scripts/download_checkpoints.sh
```
Or you can manually download the model weights from the following links and place them in the appropriate directories.
| Model | File Name | URL |
| ------------------- | ----------------------- | --------------------------------------------------------------------------------------------------------------- |
| Aesthetic Predictor | aesthetic | [🔗](https://github.com/christophschuhmann/improved-aesthetic-predictor/raw/main/sac+logos+ava1-l14-linearMSE.pth) |
| MegaSAM | megasam_final | [🔗](https://github.com/mega-sam/mega-sam/blob/main/checkpoints/megasam_final.pth) |
| RAFT | raft-things | [🔗](https://drive.google.com/uc?id=1MqDajR89k-xLV0HIrmJ0k-n8ZpG6_suM) |
| Depth Anything | Depth-Anything-V2-Large | [🔗](https://huggingface.co/depth-anything/Depth-Anything-V2-Large) |
| UniDepth            | unidepth-v2-vitl14      | [🔗](https://huggingface.co/lpiccinelli/unidepth-v2-vitl14) |
| SAM | sam2.1-hiera-large | [🔗](https://huggingface.co/facebook/sam2.1-hiera-large) |
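After downloading, a quick check such as the one below can confirm that the weights ended up where the pipeline expects them (a minimal sketch, assuming the default `checkpoints/` directory used by the commands in this repository; extend the list for the other weights as needed):
```python
from pathlib import Path

# Checkpoints referenced directly by the annotation commands in this repository.
expected = [
    "checkpoints/megasam_final.pth",  # MegaSAM camera tracking weights
    "checkpoints/raft-things.pth",    # RAFT optical flow weights
]
for path in expected:
    print(f"{path}: {'found' if Path(path).exists() else 'MISSING'}")
```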
## Quick Start
The pipeline consists of three main stages (scoring, annotation, and captioning), plus an optional visualization step:
1. Scoring
```bash
bash scripts/scoring.sh
```
Inside the [`scoring.sh`](scripts/scoring.sh) script, you need to set the following variables:
- `ROOT_VIDEO` is the directory containing the input video files.
- `OUTPUT_DIR` is the directory where the output files will be saved.
2. Annotation
```bash
bash scripts/annotation.sh
```
Inside the [`annotation.sh`](scripts/annotation.sh) script, you need to set the following variables:
- `CSV` is the CSV file generated by the scoring script, default is `$OUTPUT_DIR/results.csv`.
- `OUTPUT_DIR` is the directory where the output files will be saved.
3. Caption
```bash
bash scripts/caption.sh
```
Inside the [`caption.sh`](scripts/caption.sh) script, you need to set the following variables:
- `CSV` is the CSV file generated by the annotation script, default is `$OUTPUT_DIR/results.csv`.
- `SRC_DIR` is the annotation output directory, default is the same as the `OUTPUT_DIR` in the annotation step.
- `OUTPUT_DIR` is the directory where the output files will be saved.
- The API keys for the LLM models used in the captioning step; replace them with your own keys.
4. Visualization
- You can visualize the `poses.npy` in the `reconstruction` folder of each annotated clip using the [`visualize_pose.py`](viser/visualize_pose.py) script.
- You can visualize the final annotation result (`sgd_cvd_hr.npz`) using the [`visualize_megasam.py`](viser/visualize_megasam.py) script.
Note that to visualize any clip in our dataset, you first need to run [`pack_clip_assets.py`](utils/pack_clip_assets.py) to pack that clip's depth, RGB frames, intrinsics, extrinsics, etc. into a single npz file, and then run the visualization script on it.
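To sanity-check a clip before visualization, you can inspect the npz contents directly (a minimal sketch; the path is a placeholder). The `sgd_cvd_hr.npz` written by the CVD optimization step contains `images`, `depths`, `intrinsic`, and `cam_c2w` arrays:
```python
import numpy as np

# Load one clip's reconstruction and print the shape and dtype of each array.
data = np.load("path/to/clip/sgd_cvd_hr.npz")
for key in ("images", "depths", "intrinsic", "cam_c2w"):
    print(key, data[key].shape, data[key].dtype)
```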
## Docker Setup and Usage
We provide a Dockerfile to create a fully configured environment that includes all dependencies, including a custom-built FFmpeg with NVIDIA acceleration. This is the recommended way to ensure reproducibility and avoid environment-related issues.
Before you begin, make sure your system environment is close to the configuration described below; version matching is crucial for a successful compilation.
Your GPU also needs to support HEVC decoding; refer to the [NVIDIA NVDEC Support Matrix](https://en.wikipedia.org/wiki/NVIDIA_Video_Coding_Engine#NVDEC).
### Prerequisites: Setting up the Host Environment
Before building and running the Docker container, your host machine must be configured to support GPU access for Docker.
1. **NVIDIA Drivers**: Ensure you have the latest NVIDIA drivers installed. You can verify this by running `nvidia-smi`.
2. **Docker Engine**: Install Docker on your system. Follow the official instructions at [docs.docker.com/engine/install/](https://docs.docker.com/engine/install/).
3. **NVIDIA Container Toolkit**: This toolkit allows Docker containers to access the host's NVIDIA GPU. Install it using the following commands (for Debian/Ubuntu):
```bash
# Add the GPG key
curl -fsSL https://nvidia.github.io/libnvidia-container/gpgkey | sudo gpg --dearmor -o /usr/share/keyrings/nvidia-container-toolkit-keyring.gpg
# Add the repository
curl -s -L https://nvidia.github.io/libnvidia-container/stable/deb/nvidia-container-toolkit.list | \
sed 's#deb https://#deb [signed-by=/usr/share/keyrings/nvidia-container-toolkit-keyring.gpg] https://#g' | \
sudo tee /etc/apt/sources.list.d/nvidia-container-toolkit.list
# Update package lists and install the toolkit
# (set NVIDIA_CONTAINER_TOOLKIT_VERSION to the toolkit version you want to install)
sudo apt-get update
sudo apt-get install -y \
    nvidia-container-toolkit=${NVIDIA_CONTAINER_TOOLKIT_VERSION} \
    nvidia-container-toolkit-base=${NVIDIA_CONTAINER_TOOLKIT_VERSION} \
    libnvidia-container-tools=${NVIDIA_CONTAINER_TOOLKIT_VERSION} \
    libnvidia-container1=${NVIDIA_CONTAINER_TOOLKIT_VERSION}
# Configure Docker to use the NVIDIA runtime
sudo nvidia-ctk runtime configure --runtime=docker
# Restart the Docker daemon to apply the changes
sudo systemctl restart docker
```
For other operating systems, please refer to the [official NVIDIA documentation](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html).
4. **Docker Image Pre-pulls [optional]**: To accelerate the build process, we provide a script to pre-pull necessary Docker images from a mirror registry.
```bash
bash scripts/build_gpu_docker.sh
```
### Build and Run the Container
You can also build and run the image using standard Docker commands from the root of the repository.
1. **Build the GPU image**:
```bash
docker build -f Dockerfile.cuda \
--build-arg NUM_JOBS=8 \
-t spatialvid-gpu .
```
2. **Run the container**:
```bash
docker run --gpus all --rm -it \
-v $(pwd):/workspace \
-w /workspace \
-e NVIDIA_DRIVER_CAPABILITIES=compute,video,utility \
spatialvid-gpu bash
```
3. **Verify the environment (inside the container)**:
Once inside the container, you can verify that FFmpeg and PyTorch are correctly installed and can access the GPU.
```bash
# Check the custom FFmpeg build
/usr/local/bin/ffmpeg -version
# Check PyTorch and CUDA availability
python3 -c "import torch; print(f'PyTorch: {torch.__version__}, CUDA: {torch.version.cuda}, GPU Available: {torch.cuda.is_available()}')"
```
## Dataset Download
Our dataset is available on [HuggingFace](https://huggingface.co/SpatialVID) and [ModelScope](https://www.modelscope.cn/organization/SpatialVID).
In addition to downloading via terminal commands, we provide a script to fetch the SpatialVID / SpatialVID-HQ dataset from HuggingFace; see [`download_SpatialVID.py`](utils/download_SpatialVID.py) for details.
We also provide [`download_YouTube.py`](utils/download_YouTube.py) for downloading the raw source videos from YouTube.
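If you prefer not to use the provided scripts, the dataset can also be fetched with `huggingface_hub` (a sketch only; the repo id below is our best guess based on the organization name, so please confirm it on the dataset page):
```python
from huggingface_hub import snapshot_download

# Download the dataset files into a local directory (repo id assumed; verify on HuggingFace).
snapshot_download(
    repo_id="SpatialVID/SpatialVID-HQ",
    repo_type="dataset",
    local_dir="./SpatialVID-HQ",
)
```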
## License
Please refer to the [LICENSE](LICENSE) file for more details about the license of our code.
⚠️ The SpatialVID dataset is released under the [Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License](https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode) (CC-BY-NC-SA-4.0). Users must attribute the original source, use the resource only for non-commercial purposes, and release any modified or derived works under the same license. If you are the copyright owner of any video in our dataset and need it to be removed, please contact us, and we will remove the video samples from our dataset / GitHub / project webpage / technical presentation as soon as possible.
## References
Thanks to the developers and contributors of the following open-source repositories, whose invaluable work has greatly inspired our project:
- [Open-Sora](https://github.com/hpcaitech/Open-Sora): An initiative dedicated to efficiently producing high-quality video.
- [MegaSaM](https://github.com/mega-sam/mega-sam): Accurate, fast, and robust structure and motion estimation from casual dynamic videos.
- [Depth Anything V2](https://github.com/DepthAnything/Depth-Anything-V2): A model for monocular depth estimation.
- [UniDepthV2](https://github.com/lpiccinelli-eth/UniDepth): A model for universal monocular metric depth estimation.
- [SAM2](https://github.com/facebookresearch/sam2): A model towards solving promptable visual segmentation in images and videos.
- [Viser](https://viser.studio/latest/): A library for interactive 3D visualization in Python.
Our repository is licensed under the Apache 2.0 License. However, if you use MegaSaM or other components in your work, please follow their respective licenses.
## Citation
```bibtex
@article{wang2025spatialvid,
  title={SpatialVID: A Large-Scale Video Dataset with Spatial Annotations},
author={Wang, Jiahao and Yuan, Yufeng and Zheng, Rujie and Lin, Youtian and Gao, Jian and Chen, Lin-Zhuo and Bao, Yajie and Zhang, Yi and Zeng, Chang and Zhou, Yanxi and others},
journal={arXiv preprint arXiv:2509.09676},
year={2025}
}
```
================================================
FILE: camera_pose_annotation/.gitignore
================================================
# files
data/*
*.log
*.txt
*.bz2
*.zip
*.ipynb
data_videos
!requirements.txt
!requirements_megasam.txt
#python
*.pyc
__pycache__/
# dir
outputs/
outputs_303/
data_videos/
checkpoints/*
!checkpoints/megasam_final.pth
DROID-SLAM/
.vscode/
================================================
FILE: camera_pose_annotation/README.md
================================================
# Camera Pose Annotation
## Depth Estimation
Use both [Depth-Anything V2](depth_estimation/Depth-Anything) and [UniDepth V2](depth_estimation/UniDepth) to estimate depth maps from images.
Download the pre-trained models from the respective repositories. Skip this step if you have already followed the installation instructions in the [README](../README.md).
- [Depth-Anything V2](https://huggingface.co/depth-anything/Depth-Anything-V2-Large)
- [UniDepth V2](https://huggingface.co/lpiccinelli/unidepth-v2-vitl14)
To run depth inference with Depth-Anything V2, run the following command:
```bash
torchrun --standalone --nproc_per_node ${GPU_NUM} camera_pose_annotation/depth_estimation/Depth-Anything/inference_batch.py \
${CSV} \
--encoder vitl \
--checkpoints_path checkpoints \
--OUTPUT_DIR ${OUTPUT_DIR} \
--bs 16 \
--num_workers ${GPU_NUM}
```
To run depth inference with UniDepth V2, run the following command:
```bash
torchrun --standalone --nproc_per_node ${GPU_NUM} camera_pose_annotation/depth_estimation/UniDepth/inference_batch.py \
${CSV} \
--OUTPUT_DIR ${OUTPUT_DIR} \
--checkpoints_path checkpoints \
--bs 32 \
--num_workers ${GPU_NUM}
```
## Camera Tracking
We use a DROID-SLAM-based method to track camera poses from videos.
To run inference on a single video, run the following command:
```bash
python camera_pose_annotation/camera_tracking/camera_tracking.py \
--dir_path ${DIR_PATH} \
--weights checkpoints/megasam_final.pth \
--disable_vis
```
To run inference on videos in batch, run the following command:
```bash
python camera_pose_annotation/camera_tracking/inference_batch.py ${CSV} \
--OUTPUT_DIR ${OUTPUT_DIR} \
--checkpoints_path checkpoints --gpu_id ${CUDA_VISIBLE_DEVICES} \
--num_workers $((GPU_NUM * 2))
```
## CVD (Consistent Video Depth) Optimization
### Optical Flow
Infer optical flow using the RAFT model.
Download [`raft-things.pth`](https://drive.google.com/uc?id=1MqDajR89k-xLV0HIrmJ0k-n8ZpG6_suM).
To run inference on a single video, run the following command:
```bash
python camera_pose_annotation/cvd_opt/preprocess/preprocess_flow.py \
--dir_path ${DIR_PATH} \
--model checkpoints/raft-things.pth \
--mixed_precision
```
To run inference on videos in batch, run the following command:
```bash
python camera_pose_annotation/cvd_opt/preprocess/inference_batch.py ${CSV} \
--OUTPUT_DIR ${OUTPUT_DIR} \
--checkpoints_path checkpoints --gpu_id ${CUDA_VISIBLE_DEVICES} \
--num_workers $((GPU_NUM * 2))
```
### Optimization
Use the optical flow to optimize the estimated depth maps.
To run inference on a single video, run the following command:
```bash
python camera_pose_annotation/cvd_opt/cvd_opt.py \
--dir_path ${DIR_PATH} \
--w_grad 2.0 --w_normal 5.0
```
To run inference on videos in batch, run the following command:
```bash
python camera_pose_annotation/cvd_opt/inference_batch.py ${CSV} \
--OUTPUT_DIR ${OUTPUT_DIR} \
--gpu_id ${CUDA_VISIBLE_DEVICES} \
--num_workers $((GPU_NUM * 2))
```
## Dynamic Mask
Given the limitations of MegaSaM in predicting motion probabilities, we opt to enhance its performance using SAM2.
Specifically, an adaptive thresholding mechanism, calibrated to the system’s motion probability distribution, is first employed to generate initial masks. Subsequently, contour detection is performed to mitigate redundant segmentation of overlapping regions; for each identified contour, four evenly spaced anchor points are sampled along its perimeter to serve as dedicated prompts for the SAM2 model.
Download the pre-trained [SAM2 model](https://huggingface.co/facebook/sam2.1-hiera-large).
Run the following command:
```bash
python camera_pose_annotation/dynamic_mask/inference_batch.py ${CSV} \
--OUTPUT_DIR ${OUTPUT_DIR} \
--checkpoints_path checkpoints --gpu_num ${GPU_NUM} \
--num_workers $((GPU_NUM * 2))
```
================================================
FILE: camera_pose_annotation/__init__.py
================================================
================================================
FILE: camera_pose_annotation/camera_tracking/__init__.py
================================================
================================================
FILE: camera_pose_annotation/camera_tracking/camera_tracking.py
================================================
# Copyright 2025 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test camera tracking on a single scene."""
# pylint: disable=invalid-name
# pylint: disable=g-importing-member
# pylint: disable=g-bad-import-order
# pylint: disable=g-import-not-at-top
# pylint: disable=redefined-outer-name
# pylint: disable=undefined-variable
# pylint: disable=undefined-loop-variable
import sys
sys.path.append("camera_pose_annotation/base/droid_slam")
from droid import Droid
from lietorch import SE3
import argparse
import glob
import os
import cv2
import torch
import numpy as np
from tqdm import tqdm
import torch.nn.functional as F
def image_stream(
image_list,
mono_disp_list,
scene_name,
use_depth=False,
aligns=None,
K=None,
stride=1,
):
"""image generator."""
del scene_name, stride
fx, fy, cx, cy = (
K[0, 0],
K[1, 1],
K[0, 2],
K[1, 2],
) # np.loadtxt(os.path.join(dir_path, 'calibration.txt')).tolist()
for t, (image_file) in enumerate(image_list):
image = cv2.imread(image_file)
# depth = cv2.imread(depth_file, cv2.IMREAD_ANYDEPTH) / 5000.
# depth = np.float32(np.load(depth_file)) / 300.0
# depth = 1. / pt_data["depth"]
mono_disp = mono_disp_list[t]
# mono_disp = np.float32(np.load(disp_file)) #/ 300.0
depth = np.clip(
1.0 / ((1.0 / aligns[2]) * (aligns[0] * mono_disp + aligns[1])),
1e-4,
1e4,
)
depth[depth < 1e-2] = 0.0
# breakpoint()
h0, w0, _ = image.shape
h1 = int(h0 * np.sqrt((384 * 512) / (h0 * w0)))
w1 = int(w0 * np.sqrt((384 * 512) / (h0 * w0)))
image = cv2.resize(image, (w1, h1), interpolation=cv2.INTER_AREA)
image = image[: h1 - h1 % 8, : w1 - w1 % 8]
image = torch.as_tensor(image).permute(2, 0, 1)
depth = torch.as_tensor(depth)
depth = F.interpolate(
depth[None, None], (h1, w1), mode="nearest-exact"
).squeeze()
depth = depth[: h1 - h1 % 8, : w1 - w1 % 8]
mask = torch.ones_like(depth)
intrinsics = torch.as_tensor([fx, fy, cx, cy])
intrinsics[0::2] *= w1 / w0
intrinsics[1::2] *= h1 / h0
if use_depth:
yield t, image[None], depth, intrinsics, mask
else:
yield t, image[None], intrinsics, mask
def save_full_reconstruction(
droid, full_traj, rgb_list, senor_depth_list, motion_prob, scene_name, save_path
):
"""Save full reconstruction."""
from pathlib import Path
t = full_traj.shape[0]
images = np.array(rgb_list[:t]) # droid.video.images[:t].cpu().numpy()
disps = 1.0 / (np.array(senor_depth_list[:t]) + 1e-6)
poses = full_traj # .cpu().numpy()
intrinsics = droid.video.intrinsics[:t].cpu().numpy()
Path(f"{save_path}").mkdir(parents=True, exist_ok=True)
np.save(f"{save_path}/images.npy", images)
np.save(f"{save_path}/disps.npy", disps)
np.save(f"{save_path}/poses.npy", poses)
np.save(f"{save_path}/intrinsics.npy", intrinsics * 8.0)
np.save(f"{save_path}/motion_prob.npy", motion_prob)
intrinsics = intrinsics[0] * 8.0
poses_th = torch.as_tensor(poses, device="cpu")
cam_c2w = SE3(poses_th).inv().matrix().numpy()
K = np.eye(3)
K[0, 0] = intrinsics[0]
K[1, 1] = intrinsics[1]
K[0, 2] = intrinsics[2]
K[1, 2] = intrinsics[3]
max_frames = min(1000, images.shape[0])
if not os.path.exists(save_path):
os.makedirs(save_path)
np.savez(
os.path.join(save_path, f"{scene_name}_droid.npz"),
images=np.uint8(images[:max_frames, ::-1, ...].transpose(0, 2, 3, 1)),
depths=np.float32(1.0 / disps[:max_frames, ...]),
intrinsic=K,
cam_c2w=cam_c2w[:max_frames],
)
def parse_args():
"""Parse command line arguments."""
parser = argparse.ArgumentParser()
parser.add_argument("--dir_path", help="path to the dataset")
parser.add_argument("--weights", default="droid.pth")
parser.add_argument("--buffer", type=int, default=1024)
parser.add_argument("--image_size", default=[240, 320])
parser.add_argument("--disable_vis", action="store_true")
parser.add_argument("--beta", type=float, default=0.3)
parser.add_argument(
"--filter_thresh", type=float, default=2.0
    ) # motion threshold for keyframe
parser.add_argument("--warmup", type=int, default=8)
parser.add_argument("--keyframe_thresh", type=float, default=2.0)
parser.add_argument("--frontend_thresh", type=float, default=12.0)
parser.add_argument("--frontend_window", type=int, default=25)
parser.add_argument("--frontend_radius", type=int, default=2)
parser.add_argument("--frontend_nms", type=int, default=1)
parser.add_argument("--stereo", action="store_true")
parser.add_argument("--depth", action="store_true")
parser.add_argument("--upsample", action="store_true")
parser.add_argument("--scene_name", help="scene_name")
parser.add_argument("--backend_thresh", type=float, default=16.0)
parser.add_argument("--backend_radius", type=int, default=2)
parser.add_argument("--backend_nms", type=int, default=3)
return parser.parse_args()
def main():
args = parse_args()
scene_name = os.path.basename(args.dir_path)
rgb_list = []
senor_depth_list = []
img_path = os.path.join(args.dir_path, "img")
img_list = sorted(glob.glob(os.path.join(img_path, "*.jpg")))
img_list += sorted(glob.glob(os.path.join(img_path, "*.png")))
# NOTE Mono is inverse depth, but metric-depth is depth!
mono_disp_paths = sorted(
glob.glob(os.path.join(args.dir_path, "depth-anything", "*.npy"))
)
metric_depth_paths = sorted(
glob.glob(os.path.join(args.dir_path, "unidepth", "*.npz"))
)
img_0 = cv2.imread(img_list[0])
scales = []
shifts = []
mono_disp_list = []
fovs = []
for t, (mono_disp_file, metric_depth_file) in enumerate(
zip(mono_disp_paths, metric_depth_paths)
):
da_disp = np.float32(np.load(mono_disp_file)) # / 300.0
uni_data = np.load(metric_depth_file)
metric_depth = uni_data["depth"]
fovs.append(uni_data["fov"])
da_disp = cv2.resize(
da_disp,
(metric_depth.shape[1], metric_depth.shape[0]),
interpolation=cv2.INTER_NEAREST_EXACT,
)
mono_disp_list.append(da_disp)
gt_disp = 1.0 / (metric_depth + 1e-8)
# avoid some bug from UniDepth
valid_mask = (metric_depth < 2.0) & (da_disp < 0.02)
gt_disp[valid_mask] = 1e-2
# avoid cases sky dominate entire video
sky_ratio = np.sum(da_disp < 0.01) / (da_disp.shape[0] * da_disp.shape[1])
if sky_ratio > 0.5:
non_sky_mask = da_disp > 0.01
gt_disp_ms = gt_disp[non_sky_mask] - np.median(gt_disp[non_sky_mask]) + 1e-8
da_disp_ms = da_disp[non_sky_mask] - np.median(da_disp[non_sky_mask]) + 1e-8
scale = np.median(gt_disp_ms / da_disp_ms)
shift = np.median(gt_disp[non_sky_mask] - scale * da_disp[non_sky_mask])
else:
gt_disp_ms = gt_disp - np.median(gt_disp) + 1e-8
da_disp_ms = da_disp - np.median(da_disp) + 1e-8
scale = np.median(gt_disp_ms / da_disp_ms)
shift = np.median(gt_disp - scale * da_disp)
gt_disp_ms = gt_disp - np.median(gt_disp) + 1e-8
da_disp_ms = da_disp - np.median(da_disp) + 1e-8
scale = np.median(gt_disp_ms / da_disp_ms)
shift = np.median(gt_disp - scale * da_disp)
scales.append(scale)
shifts.append(shift)
print("************** UNIDEPTH FOV ", np.median(fovs))
ff = img_0.shape[1] / (2 * np.tan(np.radians(np.median(fovs) / 2.0)))
K = np.eye(3)
K[0, 0] = ff * 1.0 # pp_intrinsic[0] * (img_0.shape[1] / (pp_intrinsic[1] * 2))
K[1, 1] = ff * 1.0 # pp_intrinsic[0] * (img_0.shape[0] / (pp_intrinsic[2] * 2))
K[0, 2] = (
img_0.shape[1] / 2.0
) # pp_intrinsic[1]) * (img_0.shape[1] / (pp_intrinsic[1] * 2))
K[1, 2] = (
img_0.shape[0] / 2.0
) # (pp_intrinsic[2]) * (img_0.shape[0] / (pp_intrinsic[2] * 2))
ss_product = np.array(scales) * np.array(shifts)
med_idx = np.argmin(np.abs(ss_product - np.median(ss_product)))
align_scale = scales[med_idx] # np.median(np.array(scales))
align_shift = shifts[med_idx] # np.median(np.array(shifts))
normalize_scale = (
np.percentile((align_scale * np.array(mono_disp_list) + align_shift), 98) / 2.0
)
aligns = (align_scale, align_shift, normalize_scale)
for t, image, depth, intrinsics, mask in tqdm(
image_stream(
img_list,
mono_disp_list,
scene_name,
use_depth=True,
aligns=aligns,
K=K,
)
):
rgb_list.append(image[0])
senor_depth_list.append(depth)
# breakpoint()
if t == 0:
args.image_size = [image.shape[2], image.shape[3]]
droid = Droid(args, device=0)
droid.track(t, image, depth, intrinsics=intrinsics, mask=mask)
# last frame
droid.track_final(t, image, depth, intrinsics=intrinsics, mask=mask)
traj_est, depth_est, motion_prob = droid.terminate(
image_stream(
img_list,
mono_disp_list,
scene_name,
use_depth=True,
aligns=aligns,
K=K,
),
_opt_intr=True, # default is opt_focal
full_ba=True,
scene_name=scene_name,
)
save_full_reconstruction(
droid,
traj_est,
rgb_list,
senor_depth_list,
motion_prob,
args.scene_name,
os.path.join(args.dir_path, "reconstructions"),
)
if __name__ == "__main__":
main()
================================================
FILE: camera_pose_annotation/camera_tracking/inference_batch.py
================================================
"""
Batch inference for camera tracking using multiple GPUs.
This module provides functionality for:
- Parallel camera tracking processing across multiple videos
- Multi-GPU support with automatic device assignment
- Subprocess management for camera tracking pipeline
- Progress tracking and error handling
"""
import pandas as pd
import os
import argparse
import concurrent.futures
from multiprocessing import Manager
import subprocess
import queue
from tqdm import tqdm
def process_single_row(row, index, args, worker_id=0):
"""
Process a single video for camera tracking.
"""
dir_path = os.path.join(args.dir_path, row["id"])
device_id = worker_id % args.gpu_num
cmd = (
f"CUDA_VISIBLE_DEVICES={args.gpu_id[device_id]} python camera_pose_annotation/camera_tracking/camera_tracking.py "
f"--dir_path {dir_path} "
f"--weights {args.checkpoints_path}/megasam_final.pth "
f"--disable_vis"
)
process = subprocess.Popen(
cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE
)
stdout, stderr = process.communicate()
if process.returncode != 0:
print(f"Error tracking camera for {row['id']}: {stderr.decode()}")
def worker(task_queue, args, worker_id, pbar):
"""
Worker function for parallel camera tracking processing.
"""
while True:
try:
index, row = task_queue.get(timeout=1)
except queue.Empty:
break
process_single_row(row, index, args, worker_id)
task_queue.task_done()
pbar.update(1)
def parse_args():
"""Parse command line arguments for camera tracking batch inference."""
parser = argparse.ArgumentParser()
parser.add_argument("--csv_path", type=str, help="Path to the csv file")
parser.add_argument("--dir_path", type=str, default="./outputs")
parser.add_argument("--checkpoints_path", type=str, default="./checkpoints")
parser.add_argument(
"--gpu_id", type=str, default="0", help="Comma-separated list of GPU IDs to use"
)
parser.add_argument(
"--num_workers",
type=int,
default=4,
help="Number of workers for parallel processing",
)
parser.add_argument(
"--disable_parallel", action="store_true", help="Disable parallel processing"
)
return parser.parse_args()
def main():
args = parse_args()
# Parse GPU configuration
args.gpu_num = len(args.gpu_id.split(","))
args.gpu_id = [int(gpu) for gpu in args.gpu_id.split(",")]
df = pd.read_csv(args.csv_path)
if args.disable_parallel:
# Sequential processing
for index, row in tqdm(df.iterrows(), total=len(df), desc="Processing rows"):
process_single_row(row, index, args)
else:
# Parallel processing with multiple workers
manager = Manager()
task_queue = manager.Queue()
# Add all tasks to queue
for index, row in df.iterrows():
task_queue.put((index, row))
with tqdm(total=len(df), desc="Processing rows") as pbar:
with concurrent.futures.ThreadPoolExecutor(
max_workers=args.num_workers
) as executor:
futures = []
for id in range(args.num_workers):
futures.append(executor.submit(worker, task_queue, args, id, pbar))
for future in concurrent.futures.as_completed(futures):
future.result()
if __name__ == "__main__":
main()
================================================
FILE: camera_pose_annotation/cvd_opt/__init__.py
================================================
================================================
FILE: camera_pose_annotation/cvd_opt/cvd_opt.py
================================================
# Copyright 2025 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Consistent video depth optimization."""
# pylint: disable=invalid-name
# pylint: disable=g-importing-member
# pylint: disable=redefined-outer-name
import argparse
import os
from pathlib import Path
import pandas as pd
from geometry_utils import NormalGenerator
import kornia
from lietorch import SE3
import numpy as np
import torch
import zipfile
import tempfile
import OpenEXR
import Imath
def save_depth(path, depths):
with zipfile.ZipFile(path, "w", zipfile.ZIP_DEFLATED) as z:
for index, depth in enumerate(depths):
height, width = depth.shape
header = OpenEXR.Header(width, height)
header["channels"] = {"Z": Imath.Channel(Imath.PixelType(Imath.PixelType.HALF))}
with tempfile.NamedTemporaryFile(suffix=".exr") as f:
exr = OpenEXR.OutputFile(f.name, header)
exr.writePixels({"Z": depth.astype(np.float16).tobytes()})
exr.close()
z.write(f.name, f"{index:05d}.exr")
def gradient_loss(gt, pred, u):
"""Gradient loss."""
del u
diff = pred - gt
v_gradient = torch.abs(diff[..., 0:-2, 1:-1] - diff[..., 2:, 1:-1]) # * mask_v
h_gradient = torch.abs(diff[..., 1:-1, 0:-2] - diff[..., 1:-1, 2:]) # * mask_h
pred_grad = torch.abs(pred[..., 0:-2, 1:-1] - (pred[..., 2:, 1:-1])) + torch.abs(
pred[..., 1:-1, 0:-2] - pred[..., 1:-1, 2:]
)
gt_grad = torch.abs(gt[..., 0:-2, 1:-1] - (gt[..., 2:, 1:-1])) + torch.abs(
gt[..., 1:-1, 0:-2] - gt[..., 1:-1, 2:]
)
grad_diff = torch.abs(pred_grad - gt_grad)
nearby_mask = (torch.exp(gt[..., 1:-1, 1:-1]) > 1.0).float().detach()
# weight = (1. - torch.exp(-(grad_diff * 5.)).detach())
weight = 1.0 - torch.exp(-(grad_diff * 5.0)).detach()
weight *= nearby_mask
g_loss = torch.mean(h_gradient * weight) + torch.mean(v_gradient * weight)
return g_loss
def si_loss(gt, pred):
log_gt = torch.log(torch.clamp(gt, 1e-3, 1e3)).view(gt.shape[0], -1)
log_pred = torch.log(torch.clamp(pred, 1e-3, 1e3)).view(pred.shape[0], -1)
log_diff = log_gt - log_pred
num_pixels = gt.shape[-2] * gt.shape[-1]
data_loss = torch.sum(log_diff**2, dim=-1) / num_pixels - torch.sum(
log_diff, dim=-1
) ** 2 / (num_pixels**2)
return torch.mean(data_loss)
def sobel_fg_alpha(disp, mode="sobel", beta=10.0):
sobel_grad = kornia.filters.spatial_gradient(disp, mode=mode, normalized=False)
sobel_mag = torch.sqrt(
sobel_grad[:, :, 0, Ellipsis] ** 2 + sobel_grad[:, :, 1, Ellipsis] ** 2
)
alpha = torch.exp(-1.0 * beta * sobel_mag).detach()
return alpha
ALPHA_MOTION = 0.25
RESIZE_FACTOR = 0.5
def consistency_loss(
cam_c2w,
K,
K_inv,
disp_data,
init_disp,
uncertainty,
flows,
flow_masks,
ii,
jj,
compute_normals,
fg_alpha,
w_ratio=1.0,
w_flow=0.2,
w_si=1.0,
w_grad=2.0,
w_normal=4.0,
):
"""Consistency loss."""
_, H, W = disp_data.shape
# mesh grid
xx = torch.arange(0, W).view(1, -1).repeat(H, 1)
yy = torch.arange(0, H).view(-1, 1).repeat(1, W)
xx = xx.view(1, 1, H, W) # .repeat(B ,1 ,1 ,1)
yy = yy.view(1, 1, H, W) # .repeat(B ,1 ,1 ,1)
grid = torch.cat((xx, yy), 1).float().cuda().permute(0, 2, 3, 1) # [None, ...]
loss_flow = 0.0 # flow reprojection loss
loss_d_ratio = 0.0 # depth consistency loss
flows_step = flows.permute(0, 2, 3, 1)
flow_masks_step = flow_masks.permute(0, 2, 3, 1).squeeze(-1)
cam_1to2 = torch.bmm(
torch.linalg.inv(torch.index_select(cam_c2w, dim=0, index=jj)),
torch.index_select(cam_c2w, dim=0, index=ii),
)
# warp disp from target time
pixel_locations = grid + flows_step
resize_factor = torch.tensor([W - 1.0, H - 1.0]).cuda()[None, None, None, ...]
normalized_pixel_locations = 2 * (pixel_locations / resize_factor) - 1.0
disp_sampled = torch.nn.functional.grid_sample(
torch.index_select(disp_data, dim=0, index=jj)[:, None, ...],
normalized_pixel_locations,
align_corners=True,
)
uu = torch.index_select(uncertainty, dim=0, index=ii).squeeze(1)
grid_h = torch.cat([grid, torch.ones_like(grid[..., 0:1])], dim=-1).unsqueeze(-1)
# depth of reference view
ref_depth = 1.0 / torch.clamp(
torch.index_select(disp_data, dim=0, index=ii), 1e-3, 1e3
)
pts_3d_ref = ref_depth[..., None, None] * (K_inv[None, None, None] @ grid_h)
rot = cam_1to2[:, None, None, :3, :3]
trans = cam_1to2[:, None, None, :3, 3:4]
pts_3d_tgt = (rot @ pts_3d_ref) + trans # [:, None, None, :, None]
depth_tgt = pts_3d_tgt[:, :, :, 2:3, 0]
disp_tgt = 1.0 / torch.clamp(depth_tgt, 0.1, 1e3)
# flow consistency loss
pts_2D_tgt = K[None, None, None] @ pts_3d_tgt
flow_masks_step_ = flow_masks_step * (pts_2D_tgt[:, :, :, 2, 0] > 0.1)
pts_2D_tgt = pts_2D_tgt[:, :, :, :2, 0] / torch.clamp(
pts_2D_tgt[:, :, :, 2:, 0], 1e-3, 1e3
)
disp_sampled = torch.clamp(disp_sampled, 1e-3, 1e2)
disp_tgt = torch.clamp(disp_tgt, 1e-3, 1e2)
ratio = torch.maximum(
disp_sampled.squeeze() / disp_tgt.squeeze(),
disp_tgt.squeeze() / disp_sampled.squeeze(),
)
ratio_error = torch.abs(ratio - 1.0) #
loss_d_ratio += torch.sum(
(ratio_error * uu + ALPHA_MOTION * torch.log(1.0 / uu)) * flow_masks_step_
) / (torch.sum(flow_masks_step_) + 1e-8)
flow_error = torch.abs(pts_2D_tgt - pixel_locations)
loss_flow += torch.sum(
(flow_error * uu[..., None] + ALPHA_MOTION * torch.log(1.0 / uu[..., None]))
* flow_masks_step_[..., None]
) / (torch.sum(flow_masks_step_) * 2.0 + 1e-8)
# prior mono-depth reg loss
loss_prior = si_loss(init_disp, disp_data)
KK = torch.inverse(K_inv)
# multi gradient consistency
disp_data_ds = disp_data[:, None, ...]
init_disp_ds = init_disp[:, None, ...]
K_rescale = KK.clone()
K_inv_rescale = torch.inverse(K_rescale)
pred_normal = compute_normals[0](
1.0 / torch.clamp(disp_data_ds, 1e-3, 1e3), K_inv_rescale[None]
)
init_normal = compute_normals[0](
1.0 / torch.clamp(init_disp_ds, 1e-3, 1e3), K_inv_rescale[None]
)
loss_normal = torch.mean(
fg_alpha * (1.0 - torch.sum(pred_normal * init_normal, dim=1))
) # / (1e-8 + torch.sum(fg_alpha))
loss_grad = 0.0
for scale in range(4):
interval = 2**scale
disp_data_ds = torch.nn.functional.interpolate(
disp_data[:, None, ...],
scale_factor=(1.0 / interval, 1.0 / interval),
mode="nearest-exact",
)
init_disp_ds = torch.nn.functional.interpolate(
init_disp[:, None, ...],
scale_factor=(1.0 / interval, 1.0 / interval),
mode="nearest-exact",
)
uncertainty_rs = torch.nn.functional.interpolate(
uncertainty,
scale_factor=(1.0 / interval, 1.0 / interval),
mode="nearest-exact",
)
loss_grad += gradient_loss(
torch.log(disp_data_ds), torch.log(init_disp_ds), uncertainty_rs
)
return (
w_ratio * loss_d_ratio
+ w_si * loss_prior
+ w_flow * loss_flow
+ w_normal * loss_normal
+ loss_grad * w_grad
)
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--w_grad", type=float, default=2.0, help="w_grad")
parser.add_argument("--w_normal", type=float, default=6.0, help="w_normal")
parser.add_argument("--dir_path", type=str, default=".", help="directory path")
parser.add_argument("--only_depth", action="store_true", help="only save optimize depth")
return parser.parse_args()
if __name__ == "__main__":
args = parse_args()
scene_name = os.path.basename(args.dir_path)
cache_dir = os.path.join(args.dir_path, "cache-flow")
rootdir = os.path.join(args.dir_path, "reconstructions")
print("***************************** ", scene_name)
img_data = np.load(os.path.join(rootdir, "images.npy"))[:, ::-1, ...]
disp_data = np.load(os.path.join(rootdir, "disps.npy")) + 1e-6
intrinsics = np.load(os.path.join(rootdir, "intrinsics.npy"))
poses = np.load(os.path.join(rootdir, "poses.npy"))
mot_prob = np.load(os.path.join(rootdir, "motion_prob.npy"))
flows = np.load(os.path.join(cache_dir, "flows.npy"), allow_pickle=True)
flow_masks = np.load(os.path.join(cache_dir, "flows_masks.npy"), allow_pickle=True)
flow_masks = np.float32(flow_masks)
iijj = np.load(os.path.join(cache_dir, "ii-jj.npy"), allow_pickle=True)
intrinsics = intrinsics[0]
poses_th = torch.as_tensor(poses, device="cpu").float().cuda()
K = np.eye(3)
K[0, 0] = intrinsics[0]
K[1, 1] = intrinsics[1]
K[0, 2] = intrinsics[2]
K[1, 2] = intrinsics[3]
img_data_pt = (
torch.from_numpy(np.ascontiguousarray(img_data)).float().cuda() / 255.0
)
flows = torch.from_numpy(np.ascontiguousarray(flows)).float().cuda()
flow_masks = (
torch.from_numpy(np.ascontiguousarray(flow_masks)).float().cuda()
) # .unsqueeze(1)
iijj = torch.from_numpy(np.ascontiguousarray(iijj)).float().cuda()
ii = iijj[0, ...].long()
jj = iijj[1, ...].long()
K = torch.from_numpy(K).float().cuda()
init_disp = torch.from_numpy(disp_data).float().cuda()
disp_data = torch.from_numpy(disp_data).float().cuda()
assert init_disp.shape == disp_data.shape
init_disp = torch.nn.functional.interpolate(
init_disp.unsqueeze(1),
scale_factor=(RESIZE_FACTOR, RESIZE_FACTOR),
mode="bilinear",
).squeeze(1)
disp_data = torch.nn.functional.interpolate(
disp_data.unsqueeze(1),
scale_factor=(RESIZE_FACTOR, RESIZE_FACTOR),
mode="bilinear",
).squeeze(1)
fg_alpha = sobel_fg_alpha(init_disp[:, None, ...]) > 0.2
fg_alpha = fg_alpha.squeeze(1).float() + 0.2
cvd_prob = torch.nn.functional.interpolate(
torch.from_numpy(mot_prob).unsqueeze(1).cuda(),
scale_factor=(4, 4),
mode="bilinear",
)
cvd_prob[cvd_prob > 0.5] = 0.5
cvd_prob = torch.clamp(cvd_prob, 1e-3, 1.0)
# rescale intrinsic matrix to small resolution
K_o = K.clone()
K[0:2, ...] *= RESIZE_FACTOR
K_inv = torch.linalg.inv(K)
disp_data.requires_grad = False
poses_th.requires_grad = False
uncertainty = cvd_prob
# First optimize scale and shift to align them
log_scale_ = torch.log(torch.ones(init_disp.shape[0]).to(disp_data.device))
shift_ = torch.zeros(init_disp.shape[0]).to(disp_data.device)
log_scale_.requires_grad = True
shift_.requires_grad = True
uncertainty.requires_grad = True
optim = torch.optim.Adam(
[
{"params": log_scale_, "lr": 1e-2},
{"params": shift_, "lr": 1e-2},
{"params": uncertainty, "lr": 1e-2},
]
)
compute_normals = []
compute_normals.append(NormalGenerator(disp_data.shape[-2], disp_data.shape[-1]))
init_disp = torch.clamp(init_disp, 1e-3, 1e3)
for i in range(100):
optim.zero_grad()
cam_c2w = SE3(poses_th).inv().matrix()
scale_ = torch.exp(log_scale_)
loss = consistency_loss(
cam_c2w,
K,
K_inv,
torch.clamp(
disp_data * scale_[..., None, None] + shift_[..., None, None],
1e-3,
1e3,
),
init_disp,
torch.clamp(uncertainty, 1e-4, 1e3),
flows,
flow_masks,
ii,
jj,
compute_normals,
fg_alpha,
)
loss.backward()
uncertainty.grad = torch.nan_to_num(uncertainty.grad, nan=0.0)
log_scale_.grad = torch.nan_to_num(log_scale_.grad, nan=0.0)
shift_.grad = torch.nan_to_num(shift_.grad, nan=0.0)
optim.step()
print("step ", i, loss.item())
# Then optimize depth and uncertainty
disp_data = (
disp_data * torch.exp(log_scale_)[..., None, None].detach()
+ shift_[..., None, None].detach()
)
init_disp = (
init_disp * torch.exp(log_scale_)[..., None, None].detach()
+ shift_[..., None, None].detach()
)
init_disp = torch.clamp(init_disp, 1e-3, 1e3)
disp_data.requires_grad = True
uncertainty.requires_grad = True
poses_th.requires_grad = False # True
optim = torch.optim.Adam(
[
{"params": disp_data, "lr": 5e-3},
{"params": uncertainty, "lr": 5e-3},
]
)
losses = []
for i in range(400):
optim.zero_grad()
cam_c2w = SE3(poses_th).inv().matrix()
loss = consistency_loss(
cam_c2w,
K,
K_inv,
torch.clamp(disp_data, 1e-3, 1e3),
init_disp,
torch.clamp(uncertainty, 1e-4, 1e3),
flows,
flow_masks,
ii,
jj,
compute_normals,
fg_alpha,
w_ratio=1.0,
w_flow=0.2,
w_si=1,
w_grad=args.w_grad,
w_normal=args.w_normal,
)
loss.backward()
disp_data.grad = torch.nan_to_num(disp_data.grad, nan=0.0)
uncertainty.grad = torch.nan_to_num(uncertainty.grad, nan=0.0)
optim.step()
print("step ", i, loss.item())
losses.append(loss)
disp_data_opt = (
torch.nn.functional.interpolate(
disp_data.unsqueeze(1), scale_factor=(2, 2), mode="bilinear"
)
.squeeze(1)
.detach()
.cpu()
.numpy()
)
if args.only_depth:
save_depth(
os.path.join(args.dir_path, "depth_opt.zip"),
disp_data_opt
)
else:
np.savez(
os.path.join(args.dir_path, "sgd_cvd_hr.npz"),
images=np.uint8(img_data_pt.cpu().numpy().transpose(0, 2, 3, 1) * 255.0),
depths=np.clip(np.float16(1.0 / disp_data_opt), 1e-3, 1e2),
intrinsic=K_o.detach().cpu().numpy(),
cam_c2w=cam_c2w.detach().cpu().numpy(),
)
================================================
FILE: camera_pose_annotation/cvd_opt/geometry_utils.py
================================================
# Copyright 2025 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Geometry utils for MegaSaM."""
# pylint: disable=invalid-name
import kornia
import numpy as np
import torch
from torch import jit
from torch import nn
from torch import Tensor # pylint: disable=g-importing-member
import torch.nn.functional as F
@torch.jit.script
def to_homogeneous(input_tensor: Tensor, dim: int = 0) -> Tensor:
"""Converts tensor to homogeneous coordinates by adding ones to the specified dimension."""
ones = torch.ones_like(input_tensor.select(dim, 0).unsqueeze(dim))
output_bkn = torch.cat([input_tensor, ones], dim=dim)
return output_bkn
class BackprojectDepth(nn.Module):
"""Layer that projects points from 2D camera to 3D space.
The 3D points are represented in homogeneous coordinates.
"""
def __init__(self, height: int, width: int):
super().__init__()
self.height = height
self.width = width
xx, yy = torch.meshgrid(
torch.arange(self.width),
torch.arange(self.height),
indexing="xy",
)
pix_coords_2hw = torch.stack((xx, yy), axis=0) + 0.5
pix_coords_13N = (
to_homogeneous(
pix_coords_2hw,
dim=0,
)
.flatten(1)
.unsqueeze(0)
)
# make these tensors into buffers so they are put on the correct GPU
# automatically
self.register_buffer("pix_coords_13N", pix_coords_13N)
# @jit.script_method
def forward(self, depth_b1hw: Tensor, invK_b44: Tensor) -> Tensor:
"""Backprojects spatial points in 2D image space to world space using invK_b44 at the depths defined in depth_b1hw."""
cam_points_b3N = torch.matmul(
invK_b44[:, :3, :3], self.pix_coords_13N.float().cuda()
)
cam_points_b3N = depth_b1hw.flatten(start_dim=2) * cam_points_b3N
cam_points_b4N = to_homogeneous(cam_points_b3N, dim=1)
return cam_points_b4N
class Project3D(jit.ScriptModule):
"""Layer that projects 3D points into the 2D camera."""
def __init__(self, eps: float = 1e-8):
super().__init__()
self.register_buffer("eps", torch.tensor(eps).view(1, 1, 1))
@jit.script_method
def forward(
self, points_b4N: Tensor, K_b44: Tensor, cam_T_world_b44: Tensor
) -> Tensor:
"""Projects spatial points in 3D world space to camera image space using the extrinsics matrix cam_T_world_b44 and intrinsics K_b44."""
P_b44 = K_b44 @ cam_T_world_b44
cam_points_b3N = P_b44[:, :3] @ points_b4N
# from Kornia and OpenCV:
# https://kornia.readthedocs.io/en/latest/_modules/kornia/geometry/conversions.html#convert_points_from_homogeneous
mask = torch.abs(cam_points_b3N[:, 2:]) > self.eps
depth_b1N = cam_points_b3N[:, 2:] + self.eps
scale = torch.where(
mask, 1.0 / depth_b1N, torch.tensor(1.0, device=depth_b1N.device)
)
pix_coords_b2N = cam_points_b3N[:, :2] * scale
return torch.cat([pix_coords_b2N, depth_b1N], dim=1)
class NormalGenerator(nn.Module):
"""Estimates normals from depth maps."""
def __init__(
self,
height: int,
width: int,
smoothing_kernel_size: int = 5,
smoothing_kernel_std: float = 2.0,
):
"""Estimates normals from depth maps."""
super().__init__()
self.height = height
self.width = width
self.backproject = BackprojectDepth(self.height, self.width)
self.kernel_size = smoothing_kernel_size
self.std = smoothing_kernel_std
# @jit.script_method
def forward(self, depth_b1hw: Tensor, invK_b44: Tensor) -> Tensor:
"""Estimates a normal at each location in the depth map."""
# First smoothes incoming depth maps with a gaussian blur, backprojects
# those depth points into world space (see BackprojectDepth), estimates
# the spatial gradient at those points, and finally uses normalized cross
# correlation to estimate a normal vector at each location.
depth_smooth_b1hw = kornia.filters.gaussian_blur2d(
depth_b1hw,
(self.kernel_size, self.kernel_size),
(self.std, self.std),
)
cam_points_b4N = self.backproject(depth_smooth_b1hw, invK_b44)
cam_points_b3hw = cam_points_b4N[:, :3].view(-1, 3, self.height, self.width)
gradients_b32hw = kornia.filters.spatial_gradient(cam_points_b3hw)
return F.normalize(
torch.cross(
gradients_b32hw[:, :, 0],
gradients_b32hw[:, :, 1],
dim=1,
),
dim=1,
)
def get_camera_rays(
world_T_cam_b44,
world_points_b3N,
in_camera_frame,
cam_T_world_b44=None,
eps=1e-4,
):
"""Computes camera rays for given camera data and points, optionally shifts rays to camera frame."""
del eps
if in_camera_frame:
batch_size = world_points_b3N.shape[0]
num_points = world_points_b3N.shape[2]
world_points_b4N = torch.cat(
[
world_points_b3N,
torch.ones(batch_size, 1, num_points).to(world_points_b3N.device),
],
1,
)
camera_points_b3N = torch.matmul(
cam_T_world_b44[:, :3, :4], world_points_b4N
)
rays_b3N = camera_points_b3N
else:
rays_b3N = world_points_b3N - world_T_cam_b44[:, 0:3, 3][:, :, None].expand(
world_points_b3N.shape
)
rays_b3N = torch.nn.functional.normalize(rays_b3N, dim=1)
return rays_b3N
def pose_distance(pose_b44):
"""DVMVS frame pose distance."""
R = pose_b44[:, :3, :3]
t = pose_b44[:, :3, 3]
R_trace = R.diagonal(offset=0, dim1=-1, dim2=-2).sum(-1)
R_measure = torch.sqrt(
2 * (1 - torch.minimum(torch.ones_like(R_trace) * 3.0, R_trace) / 3)
)
t_measure = torch.norm(t, dim=1)
combined_measure = torch.sqrt(t_measure**2 + R_measure**2)
return combined_measure, R_measure, t_measure
def qvec2rotmat(qvec):
"""Quaternion to 3x3 rotation matrix."""
return np.array([
[
1 - 2 * qvec[2] ** 2 - 2 * qvec[3] ** 2,
2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3],
2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2],
],
[
2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3],
1 - 2 * qvec[1] ** 2 - 2 * qvec[3] ** 2,
2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1],
],
[
2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2],
2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1],
1 - 2 * qvec[1] ** 2 - 2 * qvec[2] ** 2,
],
])
def rotx(t):
"""3D Rotation about the x-axis."""
c = np.cos(t)
s = np.sin(t)
return np.array([[1, 0, 0], [0, c, -s], [0, s, c]])
def roty(t):
"""3D Rotation about the y-axis."""
c = np.cos(t)
s = np.sin(t)
return np.array([[c, 0, s], [0, 1, 0], [-s, 0, c]])
def rotz(t):
"""3D Rotation about the z-axis."""
c = np.cos(t)
s = np.sin(t)
return np.array([[c, -s, 0], [s, c, 0], [0, 0, 1]])
================================================
FILE: camera_pose_annotation/cvd_opt/inference_batch.py
================================================
"""
Batch inference script for CVD (Consistent Video Depth) optimization.
Processes multiple video clips in parallel using multi-GPU setup.
"""
import pandas as pd
import os
import argparse
import concurrent.futures
from multiprocessing import Manager
import subprocess
import queue
from tqdm import tqdm
def process_single_row(row, index, args, worker_id=0):
"""Process a single video clip for CVD optimization."""
dir_path = os.path.join(args.dir_path, row["id"])
device_id = worker_id % args.gpu_num
# Build command for CVD optimization with specific GPU
cmd = (
f"CUDA_VISIBLE_DEVICES={args.gpu_id[device_id]} python camera_pose_annotation/cvd_opt/cvd_opt.py "
f"--dir_path {dir_path} "
f"--w_grad 2.0 --w_normal 5.0 "
)
if args.only_depth:
cmd += "--only_depth "
process = subprocess.Popen(
cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE
)
stdout, stderr = process.communicate()
if process.returncode != 0:
print(f"Error optimizing CVD for {row['id']}: {stderr.decode()}")
def worker(task_queue, args, worker_id, pbar):
"""Worker function for parallel CVD optimization processing."""
while True:
try:
index, row = task_queue.get(timeout=1)
except queue.Empty:
break
process_single_row(row, index, args, worker_id)
task_queue.task_done()
pbar.update(1)
def parse_args():
"""Parse command line arguments for CVD batch processing."""
parser = argparse.ArgumentParser()
parser.add_argument("--csv_path", type=str, help="Path to the csv file")
parser.add_argument("--dir_path", type=str, default="./outputs")
parser.add_argument("--only_depth", action="store_true", help="Only save optimized depth")
parser.add_argument(
"--gpu_id", type=str, default="0", help="Comma-separated list of GPU IDs to use"
)
parser.add_argument(
"--num_workers",
type=int,
default=4,
help="Number of workers for parallel processing",
)
parser.add_argument(
"--disable_parallel", action="store_true", help="Disable parallel processing"
)
return parser.parse_args()
def main():
args = parse_args()
# Parse GPU configuration
args.gpu_num = len(args.gpu_id.split(","))
args.gpu_id = [int(gpu) for gpu in args.gpu_id.split(",")]
df = pd.read_csv(args.csv_path)
if args.disable_parallel:
# Sequential processing
for index, row in tqdm(df.iterrows(), total=len(df), desc="Processing rows"):
process_single_row(row, index, args)
else:
# Parallel processing with multiple workers
manager = Manager()
task_queue = manager.Queue()
for index, row in df.iterrows():
task_queue.put((index, row))
with tqdm(total=len(df), desc="Processing rows") as pbar:
with concurrent.futures.ThreadPoolExecutor(
max_workers=args.num_workers
) as executor:
futures = []
for id in range(args.num_workers):
futures.append(executor.submit(worker, task_queue, args, id, pbar))
for future in concurrent.futures.as_completed(futures):
future.result()
if __name__ == "__main__":
main()
================================================
FILE: camera_pose_annotation/cvd_opt/preprocess/__init__.py
================================================
================================================
FILE: camera_pose_annotation/cvd_opt/preprocess/core/__init__.py
================================================
================================================
FILE: camera_pose_annotation/cvd_opt/preprocess/core/corr.py
================================================
# Copyright 2025 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Correlation block for MegaSaM."""
import torch
import torch.nn.functional as F
from .utils.utils import bilinear_sampler
# pylint: disable=g-import-not-at-top
try:
import alt_cuda_corr
except: # pylint: disable=bare-except
# alt_cuda_corr is not compiled
pass
class CorrBlock:
"""Correlation block for MegaSaM."""
def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
self.num_levels = num_levels
self.radius = radius
self.corr_pyramid = []
# all pairs correlation
corr = CorrBlock.corr(fmap1, fmap2)
batch, h1, w1, dim, h2, w2 = corr.shape
corr = corr.reshape(batch * h1 * w1, dim, h2, w2)
self.corr_pyramid.append(corr)
for _ in range(self.num_levels - 1):
corr = F.avg_pool2d(corr, 2, stride=2)
self.corr_pyramid.append(corr)
def __call__(self, coords):
r = self.radius
coords = coords.permute(0, 2, 3, 1)
batch, h1, w1, _ = coords.shape
out_pyramid = []
for i in range(self.num_levels):
corr = self.corr_pyramid[i]
dx = torch.linspace(-r, r, 2 * r + 1)
dy = torch.linspace(-r, r, 2 * r + 1)
delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(coords.device)
centroid_lvl = coords.reshape(batch * h1 * w1, 1, 1, 2) / 2**i
delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2)
coords_lvl = centroid_lvl + delta_lvl
corr = bilinear_sampler(corr, coords_lvl)
corr = corr.view(batch, h1, w1, -1)
out_pyramid.append(corr)
out = torch.cat(out_pyramid, dim=-1)
return out.permute(0, 3, 1, 2).contiguous().float()
@classmethod
def corr(cls, fmap1, fmap2):
del cls
batch, dim, ht, wd = fmap1.shape
fmap1 = fmap1.view(batch, dim, ht * wd)
fmap2 = fmap2.view(batch, dim, ht * wd)
corr = torch.matmul(fmap1.transpose(1, 2), fmap2)
corr = corr.view(batch, ht, wd, 1, ht, wd)
return corr / torch.sqrt(torch.tensor(dim).float())
class AlternateCorrBlock:
"""Correlation block for MegaSaM."""
def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
self.num_levels = num_levels
self.radius = radius
self.pyramid = [(fmap1, fmap2)]
for _ in range(self.num_levels):
fmap1 = F.avg_pool2d(fmap1, 2, stride=2)
fmap2 = F.avg_pool2d(fmap2, 2, stride=2)
self.pyramid.append((fmap1, fmap2))
def __call__(self, coords):
coords = coords.permute(0, 2, 3, 1)
# pylint: disable=invalid-name
B, H, W, _ = coords.shape
dim = self.pyramid[0][0].shape[1]
corr_list = []
for i in range(self.num_levels):
r = self.radius
fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1).contiguous()
fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1).contiguous()
coords_i = (coords / 2**i).reshape(B, 1, H, W, 2).contiguous()
(corr,) = alt_cuda_corr.forward(fmap1_i, fmap2_i, coords_i, r)
corr_list.append(corr.squeeze(1))
corr = torch.stack(corr_list, dim=1)
corr = corr.reshape(B, -1, H, W)
return corr / torch.sqrt(torch.tensor(dim).float())
================================================
FILE: camera_pose_annotation/cvd_opt/preprocess/core/datasets.py
================================================
# Copyright 2025 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Dataset classes for MegaSaM."""
from glob import glob
import os
import os.path as osp
import random
import numpy as np
import torch
from torch.utils import data
from utils import frame_utils
from utils.augmentor import FlowAugmentor
from utils.augmentor import SparseFlowAugmentor
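# FlowDataset and its subclasses pair consecutive frames with ground-truth flow
# for the standard benchmarks (Sintel, FlyingChairs, FlyingThings3D, KITTI,
# HD1K); fetch_dataloader mixes them with per-dataset weights via __rmul__.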
class FlowDataset(data.Dataset):
"""Base class for flow datasets."""
def __init__(self, aug_params=None, sparse=False):
self.augmentor = None
self.sparse = sparse
if aug_params is not None:
if sparse:
self.augmentor = SparseFlowAugmentor(**aug_params)
else:
self.augmentor = FlowAugmentor(**aug_params)
self.is_test = False
self.init_seed = False
self.flow_list = []
self.image_list = []
self.extra_info = []
def __getitem__(self, index):
if self.is_test:
img1 = frame_utils.read_gen(self.image_list[index][0])
img2 = frame_utils.read_gen(self.image_list[index][1])
img1 = np.array(img1).astype(np.uint8)[..., :3]
img2 = np.array(img2).astype(np.uint8)[..., :3]
img1 = torch.from_numpy(img1).permute(2, 0, 1).float()
img2 = torch.from_numpy(img2).permute(2, 0, 1).float()
return img1, img2, self.extra_info[index]
if not self.init_seed:
worker_info = torch.utils.data.get_worker_info()
if worker_info is not None:
torch.manual_seed(worker_info.id)
np.random.seed(worker_info.id)
random.seed(worker_info.id)
self.init_seed = True
index = index % len(self.image_list)
valid = None
if self.sparse:
flow, valid = frame_utils.readFlowKITTI(self.flow_list[index])
else:
flow = frame_utils.read_gen(self.flow_list[index])
img1 = frame_utils.read_gen(self.image_list[index][0])
img2 = frame_utils.read_gen(self.image_list[index][1])
flow = np.array(flow).astype(np.float32)
img1 = np.array(img1).astype(np.uint8)
img2 = np.array(img2).astype(np.uint8)
# grayscale images
if len(img1.shape) == 2:
img1 = np.tile(img1[..., None], (1, 1, 3))
img2 = np.tile(img2[..., None], (1, 1, 3))
else:
img1 = img1[..., :3]
img2 = img2[..., :3]
if self.augmentor is not None:
if self.sparse:
img1, img2, flow, valid = self.augmentor(img1, img2, flow, valid)
else:
img1, img2, flow = self.augmentor(img1, img2, flow)
img1 = torch.from_numpy(img1).permute(2, 0, 1).float()
img2 = torch.from_numpy(img2).permute(2, 0, 1).float()
flow = torch.from_numpy(flow).permute(2, 0, 1).float()
if valid is not None:
valid = torch.from_numpy(valid)
else:
valid = (flow[0].abs() < 1000) & (flow[1].abs() < 1000)
return img1, img2, flow, valid.float()
def __rmul__(self, v):
self.flow_list = v * self.flow_list
self.image_list = v * self.image_list
return self
def __len__(self):
return len(self.image_list)
class MpiSintel(FlowDataset):
"""MpiSintel dataset."""
def __init__(
self,
aug_params=None,
split='training',
root='datasets/Sintel',
dstype='clean',
):
super(MpiSintel, self).__init__(aug_params)
flow_root = osp.join(root, split, 'flow')
image_root = osp.join(root, split, dstype)
if split == 'test':
self.is_test = True
for scene in os.listdir(image_root):
image_list = sorted(glob(osp.join(image_root, scene, '*.png')))
for i in range(len(image_list) - 1):
self.image_list += [[image_list[i], image_list[i + 1]]]
self.extra_info += [(scene, i)] # scene and frame_id
if split != 'test':
self.flow_list += sorted(glob(osp.join(flow_root, scene, '*.flo')))
class FlyingChairs(FlowDataset):
"""FlyingChairs dataset."""
def __init__(
self,
aug_params=None,
split='train',
root='datasets/FlyingChairs_release/data',
):
super(FlyingChairs, self).__init__(aug_params)
images = sorted(glob(osp.join(root, '*.ppm')))
flows = sorted(glob(osp.join(root, '*.flo')))
assert len(images) // 2 == len(flows)
split_list = np.loadtxt('chairs_split.txt', dtype=np.int32)
for i in range(len(flows)):
exid = split_list[i]
if (split == 'training' and exid == 1) or (
split == 'validation' and exid == 2
):
self.flow_list += [flows[i]]
self.image_list += [[images[2 * i], images[2 * i + 1]]]
class FlyingThings3D(FlowDataset):
"""FlyingThings3D dataset."""
def __init__(
self,
aug_params=None,
root='datasets/FlyingThings3D',
dstype='frames_cleanpass',
):
super(FlyingThings3D, self).__init__(aug_params)
for cam in ['left']:
for direction in ['into_future', 'into_past']:
image_dirs = sorted(glob(osp.join(root, dstype, 'TRAIN/*/*')))
image_dirs = sorted([osp.join(f, cam) for f in image_dirs])
flow_dirs = sorted(glob(osp.join(root, 'optical_flow/TRAIN/*/*')))
flow_dirs = sorted([osp.join(f, direction, cam) for f in flow_dirs])
for idir, fdir in zip(image_dirs, flow_dirs):
images = sorted(glob(osp.join(idir, '*.png')))
flows = sorted(glob(osp.join(fdir, '*.pfm')))
for i in range(len(flows) - 1):
if direction == 'into_future':
self.image_list += [[images[i], images[i + 1]]]
self.flow_list += [flows[i]]
elif direction == 'into_past':
self.image_list += [[images[i + 1], images[i]]]
self.flow_list += [flows[i + 1]]
class KITTI(FlowDataset):
"""KITTI dataset."""
def __init__(self, aug_params=None, split='training', root='datasets/KITTI'):
super(KITTI, self).__init__(aug_params, sparse=True)
if split == 'testing':
self.is_test = True
root = osp.join(root, split)
images1 = sorted(glob(osp.join(root, 'image_2/*_10.png')))
images2 = sorted(glob(osp.join(root, 'image_2/*_11.png')))
for img1, img2 in zip(images1, images2):
frame_id = img1.split('/')[-1]
self.extra_info += [[frame_id]]
self.image_list += [[img1, img2]]
if split == 'training':
self.flow_list = sorted(glob(osp.join(root, 'flow_occ/*_10.png')))
class HD1K(FlowDataset):
"""HD1K dataset."""
def __init__(self, aug_params=None, root='datasets/HD1k'):
super(HD1K, self).__init__(aug_params, sparse=True)
seq_ix = 0
while 1:
flows = sorted(
glob(
os.path.join(root, 'hd1k_flow_gt', 'flow_occ/%06d_*.png' % seq_ix)
)
)
images = sorted(
glob(os.path.join(root, 'hd1k_input', 'image_2/%06d_*.png' % seq_ix))
)
if not flows:
break
for i in range(len(flows) - 1):
self.flow_list += [flows[i]]
self.image_list += [[images[i], images[i + 1]]]
seq_ix += 1
# pylint: disable=invalid-name
def fetch_dataloader(args, TRAIN_DS='C+T+K+S+H'):
"""Create the data loader for the corresponding training set."""
if args.stage == 'chairs':
aug_params = {
'crop_size': args.image_size,
'min_scale': -0.1,
'max_scale': 1.0,
'do_flip': True,
}
train_dataset = FlyingChairs(aug_params, split='training')
elif args.stage == 'things':
aug_params = {
'crop_size': args.image_size,
'min_scale': -0.4,
'max_scale': 0.8,
'do_flip': True,
}
clean_dataset = FlyingThings3D(aug_params, dstype='frames_cleanpass')
final_dataset = FlyingThings3D(aug_params, dstype='frames_finalpass')
train_dataset = clean_dataset + final_dataset
elif args.stage == 'sintel':
aug_params = {
'crop_size': args.image_size,
'min_scale': -0.2,
'max_scale': 0.6,
'do_flip': True,
}
things = FlyingThings3D(aug_params, dstype='frames_cleanpass')
sintel_clean = MpiSintel(aug_params, split='training', dstype='clean')
sintel_final = MpiSintel(aug_params, split='training', dstype='final')
if TRAIN_DS == 'C+T+K+S+H':
kitti = KITTI({
'crop_size': args.image_size,
'min_scale': -0.3,
'max_scale': 0.5,
'do_flip': True,
})
hd1k = HD1K({
'crop_size': args.image_size,
'min_scale': -0.5,
'max_scale': 0.2,
'do_flip': True,
})
train_dataset = (
100 * sintel_clean
+ 100 * sintel_final
+ 200 * kitti
+ 5 * hd1k
+ things
)
elif TRAIN_DS == 'C+T+K/S':
train_dataset = 100 * sintel_clean + 100 * sintel_final + things
else:
raise ValueError('Unknown split: %s' % TRAIN_DS)
elif args.stage == 'kitti':
aug_params = {
'crop_size': args.image_size,
'min_scale': -0.2,
'max_scale': 0.4,
'do_flip': False,
}
train_dataset = KITTI(aug_params, split='training')
else:
raise ValueError('Unknown training set: %s' % args.stage)
train_loader = data.DataLoader(
train_dataset,
batch_size=args.batch_size,
pin_memory=False,
shuffle=True,
num_workers=4,
drop_last=True,
)
print('Training with %d image pairs' % len(train_dataset))
return train_loader
================================================
FILE: camera_pose_annotation/cvd_opt/preprocess/core/extractor.py
================================================
# Copyright 2025 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Network layer classes for MegaSaM."""
import torch
from torch import nn
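# RAFT encoder backbones: BasicEncoder (ResidualBlock stages) and SmallEncoder
# (BottleneckBlock stages) both reduce spatial resolution by 8x via a stride-2
# stem followed by two stride-2 stages.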
class ResidualBlock(nn.Module):
"""Residual block for MegaSaM."""
def __init__(self, in_planes, planes, norm_fn='group', stride=1):
super(ResidualBlock, self).__init__()
self.conv1 = nn.Conv2d(
in_planes, planes, kernel_size=3, padding=1, stride=stride
)
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1)
self.relu = nn.ReLU(inplace=True)
num_groups = planes // 8
if norm_fn == 'group':
self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
if stride != 1:
self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
elif norm_fn == 'batch':
self.norm1 = nn.BatchNorm2d(planes)
self.norm2 = nn.BatchNorm2d(planes)
if stride != 1:
self.norm3 = nn.BatchNorm2d(planes)
elif norm_fn == 'instance':
self.norm1 = nn.InstanceNorm2d(planes)
self.norm2 = nn.InstanceNorm2d(planes)
if stride != 1:
self.norm3 = nn.InstanceNorm2d(planes)
elif norm_fn == 'none':
self.norm1 = nn.Sequential()
self.norm2 = nn.Sequential()
if stride != 1:
self.norm3 = nn.Sequential()
if stride == 1:
self.downsample = None
else:
self.downsample = nn.Sequential(
nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3
)
def forward(self, x):
y = x
y = self.relu(self.norm1(self.conv1(y)))
y = self.relu(self.norm2(self.conv2(y)))
if self.downsample is not None:
x = self.downsample(x)
return self.relu(x + y)
class BottleneckBlock(nn.Module):
"""Bottleneck block for MegaSaM."""
def __init__(self, in_planes, planes, norm_fn='group', stride=1):
super(BottleneckBlock, self).__init__()
self.conv1 = nn.Conv2d(in_planes, planes // 4, kernel_size=1, padding=0)
self.conv2 = nn.Conv2d(
planes // 4, planes // 4, kernel_size=3, padding=1, stride=stride
)
self.conv3 = nn.Conv2d(planes // 4, planes, kernel_size=1, padding=0)
self.relu = nn.ReLU(inplace=True)
num_groups = planes // 8
if norm_fn == 'group':
self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes // 4)
self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes // 4)
self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
if stride != 1:
self.norm4 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
elif norm_fn == 'batch':
self.norm1 = nn.BatchNorm2d(planes // 4)
self.norm2 = nn.BatchNorm2d(planes // 4)
self.norm3 = nn.BatchNorm2d(planes)
if stride != 1:
self.norm4 = nn.BatchNorm2d(planes)
elif norm_fn == 'instance':
self.norm1 = nn.InstanceNorm2d(planes // 4)
self.norm2 = nn.InstanceNorm2d(planes // 4)
self.norm3 = nn.InstanceNorm2d(planes)
if stride != 1:
self.norm4 = nn.InstanceNorm2d(planes)
elif norm_fn == 'none':
self.norm1 = nn.Sequential()
self.norm2 = nn.Sequential()
self.norm3 = nn.Sequential()
if stride != 1:
self.norm4 = nn.Sequential()
if stride == 1:
self.downsample = None
else:
self.downsample = nn.Sequential(
nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4
)
def forward(self, x):
y = x
y = self.relu(self.norm1(self.conv1(y)))
y = self.relu(self.norm2(self.conv2(y)))
y = self.relu(self.norm3(self.conv3(y)))
if self.downsample is not None:
x = self.downsample(x)
return self.relu(x + y)
class BasicEncoder(nn.Module):
"""Basic encoder for MegaSaM."""
def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):
super(BasicEncoder, self).__init__()
self.norm_fn = norm_fn
if self.norm_fn == 'group':
self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64)
elif self.norm_fn == 'batch':
self.norm1 = nn.BatchNorm2d(64)
elif self.norm_fn == 'instance':
self.norm1 = nn.InstanceNorm2d(64)
elif self.norm_fn == 'none':
self.norm1 = nn.Sequential()
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
self.relu1 = nn.ReLU(inplace=True)
self.in_planes = 64
self.layer1 = self._make_layer(64, stride=1)
self.layer2 = self._make_layer(96, stride=2)
self.layer3 = self._make_layer(128, stride=2)
# output convolution
self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1)
self.dropout = None
if dropout > 0:
self.dropout = nn.Dropout2d(p=dropout)
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
if m.weight is not None:
nn.init.constant_(m.weight, 1)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
def _make_layer(self, dim, stride=1):
layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
layers = (layer1, layer2)
self.in_planes = dim
return nn.Sequential(*layers)
def forward(self, x):
# if input is list, combine batch dimension
is_list = isinstance(x, tuple) or isinstance(x, list)
if is_list:
batch_dim = x[0].shape[0]
x = torch.cat(x, dim=0)
x = self.conv1(x)
x = self.norm1(x)
x = self.relu1(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.conv2(x)
if self.training and self.dropout is not None:
x = self.dropout(x)
if is_list:
x = torch.split(x, [batch_dim, batch_dim], dim=0) # pylint: disable=undefined-variable
return x
class SmallEncoder(nn.Module):
"""Small encoder for MegaSaM."""
def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):
super(SmallEncoder, self).__init__()
self.norm_fn = norm_fn
if self.norm_fn == 'group':
self.norm1 = nn.GroupNorm(num_groups=8, num_channels=32)
elif self.norm_fn == 'batch':
self.norm1 = nn.BatchNorm2d(32)
elif self.norm_fn == 'instance':
self.norm1 = nn.InstanceNorm2d(32)
elif self.norm_fn == 'none':
self.norm1 = nn.Sequential()
self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3)
self.relu1 = nn.ReLU(inplace=True)
self.in_planes = 32
self.layer1 = self._make_layer(32, stride=1)
self.layer2 = self._make_layer(64, stride=2)
self.layer3 = self._make_layer(96, stride=2)
self.dropout = None
if dropout > 0:
self.dropout = nn.Dropout2d(p=dropout)
self.conv2 = nn.Conv2d(96, output_dim, kernel_size=1)
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
if m.weight is not None:
nn.init.constant_(m.weight, 1)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
def _make_layer(self, dim, stride=1):
layer1 = BottleneckBlock(self.in_planes, dim, self.norm_fn, stride=stride)
layer2 = BottleneckBlock(dim, dim, self.norm_fn, stride=1)
layers = (layer1, layer2)
self.in_planes = dim
return nn.Sequential(*layers)
def forward(self, x):
# if input is list, combine batch dimension
is_list = isinstance(x, tuple) or isinstance(x, list)
if is_list:
batch_dim = x[0].shape[0]
x = torch.cat(x, dim=0)
x = self.conv1(x)
x = self.norm1(x)
x = self.relu1(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.conv2(x)
if self.training and self.dropout is not None:
x = self.dropout(x)
if is_list:
x = torch.split(x, [batch_dim, batch_dim], dim=0) # pylint: disable=undefined-variable
return x
================================================
FILE: camera_pose_annotation/cvd_opt/preprocess/core/raft.py
================================================
# Copyright 2025 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""RAFT network for MegaSaM."""
from .corr import AlternateCorrBlock
from .corr import CorrBlock
from .extractor import BasicEncoder
from .extractor import SmallEncoder
import torch
from torch import nn
import torch.nn.functional as F
from .update import BasicUpdateBlock
from .update import SmallUpdateBlock
from .utils.utils import coords_grid
from .utils.utils import upflow8
try:
autocast = torch.cuda.amp.autocast
except: # pylint: disable=bare-except
# dummy autocast for PyTorch < 1.6
class autocast: # pylint: disable=invalid-name
def __init__(self, enabled):
pass
def __enter__(self):
pass
def __exit__(self, *args):
pass
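# RAFT estimates flow iteratively: fnet features form a correlation volume,
# cnet is split into a hidden state (tanh) and a context vector (relu), and the
# update block refines coords1 for `iters` steps, with each prediction
# convex-upsampled by 8x.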
class RAFT(nn.Module):
"""RAFT network for MegaSaM."""
def __init__(self, args):
super(RAFT, self).__init__()
self.args = args
self.mixed_precision = True
if args.small:
self.hidden_dim = hdim = 96
self.context_dim = cdim = 64
args.corr_levels = 4
args.corr_radius = 3
else:
self.hidden_dim = hdim = 128
self.context_dim = cdim = 128
args.corr_levels = 4
args.corr_radius = 4
if 'dropout' not in self.args:
self.args.dropout = 0
if 'alternate_corr' not in self.args:
self.args.alternate_corr = False
# feature network, context network, and update block
if args.small:
self.fnet = SmallEncoder(
output_dim=128, norm_fn='instance', dropout=args.dropout
)
self.cnet = SmallEncoder(
output_dim=hdim + cdim, norm_fn='none', dropout=args.dropout
)
self.update_block = SmallUpdateBlock(self.args, hidden_dim=hdim)
else:
self.fnet = BasicEncoder(
output_dim=256, norm_fn='instance', dropout=args.dropout
)
self.cnet = BasicEncoder(
output_dim=hdim + cdim, norm_fn='batch', dropout=args.dropout
)
self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim)
def freeze_bn(self):
for m in self.modules():
if isinstance(m, nn.BatchNorm2d):
m.eval()
def initialize_flow(self, img):
"""Flow is represented as difference between two coordinate grids flow = coords1 - coords0."""
# pylint: disable=invalid-name
N, _, H, W = img.shape
coords0 = coords_grid(N, H // 8, W // 8).to(img.device)
coords1 = coords_grid(N, H // 8, W // 8).to(img.device)
# optical flow computed as difference: flow = coords1 - coords0
return coords0, coords1
def upsample_flow(self, flow, mask):
"""Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination."""
# pylint: disable=invalid-name
N, _, H, W = flow.shape
mask = mask.view(N, 1, 9, 8, 8, H, W)
mask = torch.softmax(mask, dim=2)
up_flow = F.unfold(8 * flow, [3, 3], padding=1)
up_flow = up_flow.view(N, 2, 9, 1, 1, H, W)
up_flow = torch.sum(mask * up_flow, dim=2)
up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
return up_flow.reshape(N, 2, 8 * H, 8 * W)
def forward(
self,
image1,
image2,
iters=12,
flow_init=None,
upsample=True,
test_mode=False,
):
"""Estimate optical flow between pair of frames."""
image1 = 2 * (image1 / 255.0) - 1.0
image2 = 2 * (image2 / 255.0) - 1.0
image1 = image1.contiguous()
image2 = image2.contiguous()
hdim = self.hidden_dim
cdim = self.context_dim
# run the feature network
with autocast(enabled=self.mixed_precision):
fmap1, fmap2 = self.fnet([image1, image2])
fmap1 = fmap1.float()
fmap2 = fmap2.float()
if self.args.alternate_corr:
corr_fn = AlternateCorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
else:
corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
# run the context network
with autocast(enabled=self.mixed_precision):
cnet = self.cnet(image1)
net, inp = torch.split(cnet, [hdim, cdim], dim=1)
net = torch.tanh(net)
inp = torch.relu(inp)
coords0, coords1 = self.initialize_flow(image1)
if flow_init is not None:
coords1 = coords1 + flow_init
flow_predictions = []
flow_up = None
for _ in range(iters):
coords1 = coords1.detach()
corr = corr_fn(coords1) # index correlation volume
flow = coords1 - coords0
with autocast(enabled=self.mixed_precision):
net, up_mask, delta_flow = self.update_block(net, inp, corr, flow)
# F(t+1) = F(t) + \Delta(t)
coords1 = coords1 + delta_flow
# upsample predictions
if up_mask is None:
flow_up = upflow8(coords1 - coords0)
else:
flow_up = self.upsample_flow(coords1 - coords0, up_mask)
flow_predictions.append(flow_up)
if test_mode:
if flow_up is None:
raise ValueError('flow_up is None')
return coords1 - coords0, flow_up, net
return flow_predictions
================================================
FILE: camera_pose_annotation/cvd_opt/preprocess/core/update.py
================================================
# Copyright 2025 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Update block for consistent video depth optimization."""
import torch
from torch import nn
import torch.nn.functional as F
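# RAFT-style update machinery: a motion encoder fuses correlation lookups with
# the current flow, a (Sep)ConvGRU updates the hidden state, and FlowHead
# regresses the residual flow; BasicUpdateBlock additionally predicts the
# convex-upsampling mask.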
class FlowHead(nn.Module):
def __init__(self, input_dim=128, hidden_dim=256):
super(FlowHead, self).__init__()
self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1)
self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1)
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
return self.conv2(self.relu(self.conv1(x)))
class ConvGRU(nn.Module):
"""GRU with convolution."""
def __init__(self, hidden_dim=128, input_dim=192 + 128):
super(ConvGRU, self).__init__()
self.convz = nn.Conv2d(hidden_dim + input_dim, hidden_dim, 3, padding=1)
self.convr = nn.Conv2d(hidden_dim + input_dim, hidden_dim, 3, padding=1)
self.convq = nn.Conv2d(hidden_dim + input_dim, hidden_dim, 3, padding=1)
def forward(self, h, x):
hx = torch.cat([h, x], dim=1)
z = torch.sigmoid(self.convz(hx))
r = torch.sigmoid(self.convr(hx))
q = torch.tanh(self.convq(torch.cat([r * h, x], dim=1)))
h = (1 - z) * h + z * q
return h
class SepConvGRU(nn.Module):
"""GRU with separate convolution for horizontal and vertical directions."""
def __init__(self, hidden_dim=128, input_dim=192 + 128):
super(SepConvGRU, self).__init__()
self.convz1 = nn.Conv2d(
hidden_dim + input_dim, hidden_dim, (1, 5), padding=(0, 2)
)
self.convr1 = nn.Conv2d(
hidden_dim + input_dim, hidden_dim, (1, 5), padding=(0, 2)
)
self.convq1 = nn.Conv2d(
hidden_dim + input_dim, hidden_dim, (1, 5), padding=(0, 2)
)
self.convz2 = nn.Conv2d(
hidden_dim + input_dim, hidden_dim, (5, 1), padding=(2, 0)
)
self.convr2 = nn.Conv2d(
hidden_dim + input_dim, hidden_dim, (5, 1), padding=(2, 0)
)
self.convq2 = nn.Conv2d(
hidden_dim + input_dim, hidden_dim, (5, 1), padding=(2, 0)
)
def forward(self, h, x):
# horizontal
hx = torch.cat([h, x], dim=1)
z = torch.sigmoid(self.convz1(hx))
r = torch.sigmoid(self.convr1(hx))
q = torch.tanh(self.convq1(torch.cat([r * h, x], dim=1)))
h = (1 - z) * h + z * q
# vertical
hx = torch.cat([h, x], dim=1)
z = torch.sigmoid(self.convz2(hx))
r = torch.sigmoid(self.convr2(hx))
q = torch.tanh(self.convq2(torch.cat([r * h, x], dim=1)))
h = (1 - z) * h + z * q
return h
class SmallMotionEncoder(nn.Module):
"""Small motion encoder for MegaSaM."""
def __init__(self, args):
super(SmallMotionEncoder, self).__init__()
cor_planes = args.corr_levels * (2 * args.corr_radius + 1) ** 2
self.convc1 = nn.Conv2d(cor_planes, 96, 1, padding=0)
self.convf1 = nn.Conv2d(2, 64, 7, padding=3)
self.convf2 = nn.Conv2d(64, 32, 3, padding=1)
self.conv = nn.Conv2d(128, 80, 3, padding=1)
def forward(self, flow, corr):
cor = F.relu(self.convc1(corr))
flo = F.relu(self.convf1(flow))
flo = F.relu(self.convf2(flo))
cor_flo = torch.cat([cor, flo], dim=1)
out = F.relu(self.conv(cor_flo))
return torch.cat([out, flow], dim=1)
class BasicMotionEncoder(nn.Module):
"""Basic motion encoder for MegaSaM."""
def __init__(self, args):
super(BasicMotionEncoder, self).__init__()
cor_planes = args.corr_levels * (2 * args.corr_radius + 1) ** 2
self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0)
self.convc2 = nn.Conv2d(256, 192, 3, padding=1)
self.convf1 = nn.Conv2d(2, 128, 7, padding=3)
self.convf2 = nn.Conv2d(128, 64, 3, padding=1)
self.conv = nn.Conv2d(64 + 192, 128 - 2, 3, padding=1)
def forward(self, flow, corr):
cor = F.relu(self.convc1(corr))
cor = F.relu(self.convc2(cor))
flo = F.relu(self.convf1(flow))
flo = F.relu(self.convf2(flo))
cor_flo = torch.cat([cor, flo], dim=1)
out = F.relu(self.conv(cor_flo))
return torch.cat([out, flow], dim=1)
class SmallUpdateBlock(nn.Module):
"""Small update block for MegaSaM."""
def __init__(self, args, hidden_dim=96):
super(SmallUpdateBlock, self).__init__()
self.encoder = SmallMotionEncoder(args)
self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=82 + 64)
self.flow_head = FlowHead(hidden_dim, hidden_dim=128)
def forward(self, net, inp, corr, flow):
motion_features = self.encoder(flow, corr)
inp = torch.cat([inp, motion_features], dim=1)
net = self.gru(net, inp)
delta_flow = self.flow_head(net)
return net, None, delta_flow
class BasicUpdateBlock(nn.Module):
"""Basic update block for MegaSaM."""
def __init__(self, args, hidden_dim=128, input_dim=128):
super(BasicUpdateBlock, self).__init__()
self.args = args
self.encoder = BasicMotionEncoder(args)
self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128 + hidden_dim)
self.flow_head = FlowHead(hidden_dim, hidden_dim=256)
self.mask = nn.Sequential(
nn.Conv2d(128, 256, 3, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(256, 64 * 9, 1, padding=0),
)
def forward(self, net, inp, corr, flow, upsample=True):
motion_features = self.encoder(flow, corr)
inp = torch.cat([inp, motion_features], dim=1)
net = self.gru(net, inp)
delta_flow = self.flow_head(net)
# scale mask to balance gradients
mask = 0.25 * self.mask(net)
return net, mask, delta_flow
================================================
FILE: camera_pose_annotation/cvd_opt/preprocess/core/utils/__init__.py
================================================
================================================
FILE: camera_pose_annotation/cvd_opt/preprocess/core/utils/augmentor.py
================================================
# Copyright 2025 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Augmentation utils for MegaSaM."""
# pylint: disable=g-import-not-at-top
# pylint: disable=g-importing-member
import cv2
import numpy as np
from PIL import Image
cv2.setNumThreads(0)
cv2.ocl.setUseOpenCL(False)
from torchvision.transforms import ColorJitter
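# FlowAugmentor applies photometric (ColorJitter), occlusion-eraser, and
# spatial (scale/stretch/flip/crop) augmentations to dense flow pairs; the
# sparse variant re-splats valid flow vectors onto the resized grid instead of
# interpolating them.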
class FlowAugmentor:
"""Augmentation for flow for MegaSaM."""
def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=True):
# spatial augmentation params
self.crop_size = crop_size
self.min_scale = min_scale
self.max_scale = max_scale
self.spatial_aug_prob = 0.8
self.stretch_prob = 0.8
self.max_stretch = 0.2
# flip augmentation params
self.do_flip = do_flip
self.h_flip_prob = 0.5
self.v_flip_prob = 0.1
# photometric augmentation params
self.photo_aug = ColorJitter(
brightness=0.4, contrast=0.4, saturation=0.4, hue=0.5 / 3.14
)
self.asymmetric_color_aug_prob = 0.2
self.eraser_aug_prob = 0.5
def color_transform(self, img1, img2):
"""Photometric augmentation."""
# asymmetric
if np.random.rand() < self.asymmetric_color_aug_prob:
img1 = np.array(self.photo_aug(Image.fromarray(img1)), dtype=np.uint8)
img2 = np.array(self.photo_aug(Image.fromarray(img2)), dtype=np.uint8)
# symmetric
else:
image_stack = np.concatenate([img1, img2], axis=0)
image_stack = np.array(
self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8
)
img1, img2 = np.split(image_stack, 2, axis=0)
return img1, img2
def eraser_transform(self, img1, img2, bounds=[50, 100]): # pylint: disable=dangerous-default-value
"""Occlusion augmentation."""
ht, wd = img1.shape[:2]
if np.random.rand() < self.eraser_aug_prob:
mean_color = np.mean(img2.reshape(-1, 3), axis=0)
for _ in range(np.random.randint(1, 3)):
x0 = np.random.randint(0, wd)
y0 = np.random.randint(0, ht)
dx = np.random.randint(bounds[0], bounds[1])
dy = np.random.randint(bounds[0], bounds[1])
img2[y0 : y0 + dy, x0 : x0 + dx, :] = mean_color
return img1, img2
def spatial_transform(self, img1, img2, flow):
"""Spatial augmentation."""
# randomly sample scale
ht, wd = img1.shape[:2]
min_scale = np.maximum(
(self.crop_size[0] + 8) / float(ht), (self.crop_size[1] + 8) / float(wd)
)
scale = 2 ** np.random.uniform(self.min_scale, self.max_scale)
scale_x = scale
scale_y = scale
if np.random.rand() < self.stretch_prob:
scale_x *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch)
scale_y *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch)
scale_x = np.clip(scale_x, min_scale, None)
scale_y = np.clip(scale_y, min_scale, None)
if np.random.rand() < self.spatial_aug_prob:
# rescale the images
img1 = cv2.resize(
img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR
)
img2 = cv2.resize(
img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR
)
flow = cv2.resize(
flow, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR
)
flow = flow * [scale_x, scale_y]
if self.do_flip:
if np.random.rand() < self.h_flip_prob: # h-flip
img1 = img1[:, ::-1]
img2 = img2[:, ::-1]
flow = flow[:, ::-1] * [-1.0, 1.0]
if np.random.rand() < self.v_flip_prob: # v-flip
img1 = img1[::-1, :]
img2 = img2[::-1, :]
flow = flow[::-1, :] * [1.0, -1.0]
y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0])
x0 = np.random.randint(0, img1.shape[1] - self.crop_size[1])
img1 = img1[y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]]
img2 = img2[y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]]
flow = flow[y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]]
return img1, img2, flow
def __call__(self, img1, img2, flow):
img1, img2 = self.color_transform(img1, img2)
img1, img2 = self.eraser_transform(img1, img2)
img1, img2, flow = self.spatial_transform(img1, img2, flow)
img1 = np.ascontiguousarray(img1)
img2 = np.ascontiguousarray(img2)
flow = np.ascontiguousarray(flow)
return img1, img2, flow
class SparseFlowAugmentor:
"""Augmentation for sparse flow for MegaSaM."""
def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=False):
# spatial augmentation params
self.crop_size = crop_size
self.min_scale = min_scale
self.max_scale = max_scale
self.spatial_aug_prob = 0.8
self.stretch_prob = 0.8
self.max_stretch = 0.2
# flip augmentation params
self.do_flip = do_flip
self.h_flip_prob = 0.5
self.v_flip_prob = 0.1
# photometric augmentation params
self.photo_aug = ColorJitter(
brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3 / 3.14
)
self.asymmetric_color_aug_prob = 0.2
self.eraser_aug_prob = 0.5
def color_transform(self, img1, img2):
image_stack = np.concatenate([img1, img2], axis=0)
image_stack = np.array(
self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8
)
img1, img2 = np.split(image_stack, 2, axis=0)
return img1, img2
def eraser_transform(self, img1, img2):
ht, wd = img1.shape[:2]
if np.random.rand() < self.eraser_aug_prob:
mean_color = np.mean(img2.reshape(-1, 3), axis=0)
for _ in range(np.random.randint(1, 3)):
x0 = np.random.randint(0, wd)
y0 = np.random.randint(0, ht)
dx = np.random.randint(50, 100)
dy = np.random.randint(50, 100)
img2[y0 : y0 + dy, x0 : x0 + dx, :] = mean_color
return img1, img2
def resize_sparse_flow_map(self, flow, valid, fx=1.0, fy=1.0):
"""Resize sparse flow map."""
ht, wd = flow.shape[:2]
coords = np.meshgrid(np.arange(wd), np.arange(ht))
coords = np.stack(coords, axis=-1)
coords = coords.reshape(-1, 2).astype(np.float32)
flow = flow.reshape(-1, 2).astype(np.float32)
valid = valid.reshape(-1).astype(np.float32)
coords0 = coords[valid >= 1]
flow0 = flow[valid >= 1]
ht1 = int(round(ht * fy))
wd1 = int(round(wd * fx))
coords1 = coords0 * [fx, fy]
flow1 = flow0 * [fx, fy]
xx = np.round(coords1[:, 0]).astype(np.int32)
yy = np.round(coords1[:, 1]).astype(np.int32)
v = (xx > 0) & (xx < wd1) & (yy > 0) & (yy < ht1)
xx = xx[v]
yy = yy[v]
flow1 = flow1[v]
flow_img = np.zeros([ht1, wd1, 2], dtype=np.float32)
valid_img = np.zeros([ht1, wd1], dtype=np.int32)
flow_img[yy, xx] = flow1
valid_img[yy, xx] = 1
return flow_img, valid_img
def spatial_transform(self, img1, img2, flow, valid):
"""Randomly sample scale and apply it to images and flow map."""
ht, wd = img1.shape[:2]
min_scale = np.maximum(
(self.crop_size[0] + 1) / float(ht), (self.crop_size[1] + 1) / float(wd)
)
scale = 2 ** np.random.uniform(self.min_scale, self.max_scale)
scale_x = np.clip(scale, min_scale, None)
scale_y = np.clip(scale, min_scale, None)
if np.random.rand() < self.spatial_aug_prob:
# rescale the images
img1 = cv2.resize(
img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR
)
img2 = cv2.resize(
img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR
)
flow, valid = self.resize_sparse_flow_map(
flow, valid, fx=scale_x, fy=scale_y
)
if self.do_flip:
if np.random.rand() < 0.5: # h-flip
img1 = img1[:, ::-1]
img2 = img2[:, ::-1]
flow = flow[:, ::-1] * [-1.0, 1.0]
valid = valid[:, ::-1]
margin_y = 20
margin_x = 50
y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0] + margin_y)
x0 = np.random.randint(
-margin_x, img1.shape[1] - self.crop_size[1] + margin_x
)
y0 = np.clip(y0, 0, img1.shape[0] - self.crop_size[0])
x0 = np.clip(x0, 0, img1.shape[1] - self.crop_size[1])
img1 = img1[y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]]
img2 = img2[y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]]
flow = flow[y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]]
valid = valid[y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]]
return img1, img2, flow, valid
def __call__(self, img1, img2, flow, valid):
img1, img2 = self.color_transform(img1, img2)
img1, img2 = self.eraser_transform(img1, img2)
img1, img2, flow, valid = self.spatial_transform(img1, img2, flow, valid)
img1 = np.ascontiguousarray(img1)
img2 = np.ascontiguousarray(img2)
flow = np.ascontiguousarray(flow)
valid = np.ascontiguousarray(valid)
return img1, img2, flow, valid
================================================
FILE: camera_pose_annotation/cvd_opt/preprocess/core/utils/flow_viz.py
================================================
# Copyright 2025 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Flow visualization code.
Based on https://github.com/tomrunia/OpticalFlow_Visualization
"""
import numpy as np
def make_colorwheel():
"""Generates a color wheel for optical flow visualization.
Baker et al. "A Database and Evaluation Methodology for Optical Flow"
(ICCV, 2007)
URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf
Code follows the original C++ source code of Daniel Scharstein.
Code follows the Matlab source code of Deqing Sun.
Returns:
np.ndarray: Color wheel
"""
# pylint: disable=invalid-name
RY = 15
YG = 6
GC = 4
CB = 11
BM = 13
MR = 6
ncols = RY + YG + GC + CB + BM + MR
colorwheel = np.zeros((ncols, 3))
col = 0
# RY
colorwheel[0:RY, 0] = 255
colorwheel[0:RY, 1] = np.floor(255 * np.arange(0, RY) / RY)
col = col + RY
# YG
colorwheel[col : col + YG, 0] = 255 - np.floor(255 * np.arange(0, YG) / YG)
colorwheel[col : col + YG, 1] = 255
col = col + YG
# GC
colorwheel[col : col + GC, 1] = 255
colorwheel[col : col + GC, 2] = np.floor(255 * np.arange(0, GC) / GC)
col = col + GC
# CB
colorwheel[col : col + CB, 1] = 255 - np.floor(255 * np.arange(CB) / CB)
colorwheel[col : col + CB, 2] = 255
col = col + CB
# BM
colorwheel[col : col + BM, 2] = 255
colorwheel[col : col + BM, 0] = np.floor(255 * np.arange(0, BM) / BM)
col = col + BM
# MR
colorwheel[col : col + MR, 2] = 255 - np.floor(255 * np.arange(MR) / MR)
colorwheel[col : col + MR, 0] = 255
return colorwheel
def flow_uv_to_colors(u, v, convert_to_bgr=False):
"""Applies the flow color wheel to (possibly clipped) flow components u and v.
According to the C++ source code of Daniel Scharstein
According to the Matlab source code of Deqing Sun
Args:
u (np.ndarray): Input horizontal flow of shape [H,W]
v (np.ndarray): Input vertical flow of shape [H,W]
convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to
False.
Returns:
np.ndarray: Flow visualization image of shape [H,W,3]
"""
flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8)
colorwheel = make_colorwheel() # shape [55x3]
ncols = colorwheel.shape[0]
rad = np.sqrt(np.square(u) + np.square(v))
a = np.arctan2(-v, -u) / np.pi
fk = (a + 1) / 2 * (ncols - 1)
k0 = np.floor(fk).astype(np.int32)
k1 = k0 + 1
k1[k1 == ncols] = 0
f = fk - k0
for i in range(colorwheel.shape[1]):
tmp = colorwheel[:, i]
col0 = tmp[k0] / 255.0
col1 = tmp[k1] / 255.0
col = (1 - f) * col0 + f * col1
idx = rad <= 1
col[idx] = 1 - rad[idx] * (1 - col[idx])
col[~idx] = col[~idx] * 0.75 # out of range
# Note the 2-i => BGR instead of RGB
ch_idx = 2 - i if convert_to_bgr else i
flow_image[:, :, ch_idx] = np.floor(255 * col)
return flow_image
def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False):
"""Expects a two dimensional flow image of shape.
Args:
flow_uv (np.ndarray): Flow UV image of shape [H,W,2]
clip_flow (float, optional): Clip maximum of flow values. Defaults to
None.
convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to
False.
Returns:
np.ndarray: Flow visualization image of shape [H,W,3]
"""
assert flow_uv.ndim == 3, 'input flow must have three dimensions'
assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]'
if clip_flow is not None:
flow_uv = np.clip(flow_uv, 0, clip_flow)
u = flow_uv[:, :, 0]
v = flow_uv[:, :, 1]
rad = np.sqrt(np.square(u) + np.square(v))
rad_max = np.max(rad)
epsilon = 1e-5
u = u / (rad_max + epsilon)
v = v / (rad_max + epsilon)
return flow_uv_to_colors(u, v, convert_to_bgr)
================================================
FILE: camera_pose_annotation/cvd_opt/preprocess/core/utils/frame_utils.py
================================================
# Copyright 2025 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Frame utils for MegaSaM."""
# pylint: disable=invalid-name
# pylint: disable=g-doc-args
# pylint: disable=broad-exception-raised
import os
import re
import cv2
import numpy as np
from PIL import Image
cv2.setNumThreads(0)
cv2.ocl.setUseOpenCL(False)
TAG_CHAR = np.array([202021.25], np.float32)
def readFlow(fn):
"""Read .flo file in Middlebury format."""
# Code adapted from:
# http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy
# WARNING: this will work on little-endian architectures (eg Intel x86) only!
# print 'fn = %s'%(fn)
with open(fn, 'rb') as f:
magic = np.fromfile(f, np.float32, count=1)
if 202021.25 != magic:
print('Magic number incorrect. Invalid .flo file')
return None
else:
w = np.fromfile(f, np.int32, count=1)
h = np.fromfile(f, np.int32, count=1)
# print 'Reading %d x %d flo file\n' % (w, h)
data = np.fromfile(f, np.float32, count=2 * int(w) * int(h))
# Reshape data into 3D array (columns, rows, bands)
# The reshape here is for visualization, the original code is (w,h,2)
return np.resize(data, (int(h), int(w), 2))
def readPFM(file):
"""Read PFM file."""
file = open(file, 'rb')
header = file.readline().rstrip()
if header == b'PF':
color = True
elif header == b'Pf':
color = False
else:
raise Exception('Not a PFM file.')
dim_match = re.match(rb'^(\d+)\s(\d+)\s$', file.readline())
if dim_match:
width, height = map(int, dim_match.groups())
else:
raise Exception('Malformed PFM header.')
scale = float(file.readline().rstrip())
if scale < 0: # little-endian
endian = '<'
else:
endian = '>' # big-endian
data = np.fromfile(file, endian + 'f')
shape = (height, width, 3) if color else (height, width)
data = np.reshape(data, shape)
data = np.flipud(data)
return data
def writeFlow(filename, uv, v=None):
"""Write optical flow to file.
If v is None, uv is assumed to contain both u and v channels,
stacked in depth.
Original code by Deqing Sun, adapted from Daniel Scharstein.
"""
nBands = 2
if v is None:
assert uv.ndim == 3
assert uv.shape[2] == 2
u = uv[:, :, 0]
v = uv[:, :, 1]
else:
u = uv
assert u.shape == v.shape
height, width = u.shape
f = open(filename, 'wb')
# write the header
f.write(TAG_CHAR)
np.array(width).astype(np.int32).tofile(f)
np.array(height).astype(np.int32).tofile(f)
# arrange into matrix form
tmp = np.zeros((height, width * nBands))
tmp[:, np.arange(width) * 2] = u
tmp[:, np.arange(width) * 2 + 1] = v
tmp.astype(np.float32).tofile(f)
f.close()
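# KITTI stores flow as 16-bit PNGs: (u, v) are scaled by 64 and offset by 2^15,
# with a third channel marking valid pixels; readFlowKITTI/writeFlowKITTI
# invert and apply that encoding.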
def readFlowKITTI(filename):
flow = cv2.imread(filename, cv2.IMREAD_ANYDEPTH | cv2.IMREAD_COLOR)
flow = flow[:, :, ::-1].astype(np.float32)
flow, valid = flow[:, :, :2], flow[:, :, 2]
flow = (flow - 2**15) / 64.0
return flow, valid
def readDispKITTI(filename):
disp = cv2.imread(filename, cv2.IMREAD_ANYDEPTH) / 256.0
valid = disp > 0.0
flow = np.stack([-disp, np.zeros_like(disp)], -1)
return flow, valid
def writeFlowKITTI(filename, uv):
uv = 64.0 * uv + 2**15
valid = np.ones([uv.shape[0], uv.shape[1], 1])
uv = np.concatenate([uv, valid], axis=-1).astype(np.uint16)
cv2.imwrite(filename, uv[..., ::-1])
def read_gen(file_name, pil=False):
"""Read image or flow file."""
del pil
ext = os.path.splitext(file_name)[-1]
if ext == '.png' or ext == '.jpeg' or ext == '.ppm' or ext == '.jpg':
return Image.open(file_name)
elif ext == '.bin' or ext == '.raw':
return np.load(file_name)
elif ext == '.flo':
return readFlow(file_name).astype(np.float32) # pylint: disable=attribute-error
elif ext == '.pfm':
flow = readPFM(file_name).astype(np.float32)
if len(flow.shape) == 2:
return flow
else:
return flow[:, :, :-1]
return []
================================================
FILE: camera_pose_annotation/cvd_opt/preprocess/core/utils/utils.py
================================================
# Copyright 2025 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utility functions for MegaSaM."""
# pylint: disable=invalid-name
import numpy as np
from scipy import interpolate
import torch
import torch.nn.functional as F
class InputPadder:
"""Pads images such that dimensions are divisible by 8."""
def __init__(self, dims, mode='sintel'):
self.ht, self.wd = dims[-2:]
pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8
pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8
if mode == 'sintel':
self._pad = [
pad_wd // 2,
pad_wd - pad_wd // 2,
pad_ht // 2,
pad_ht - pad_ht // 2,
]
else:
self._pad = [pad_wd // 2, pad_wd - pad_wd // 2, 0, pad_ht]
def pad(self, *inputs):
return [F.pad(x, self._pad, mode='replicate') for x in inputs]
def unpad(self, x):
ht, wd = x.shape[-2:]
c = [self._pad[2], ht - self._pad[3], self._pad[0], wd - self._pad[1]]
return x[..., c[0] : c[1], c[2] : c[3]]
def forward_interpolate(flow):
"""Interpolate flow map to match the original image size."""
flow = flow.detach().cpu().numpy()
dx, dy = flow[0], flow[1]
ht, wd = dx.shape
x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht))
x1 = x0 + dx
y1 = y0 + dy
x1 = x1.reshape(-1)
y1 = y1.reshape(-1)
dx = dx.reshape(-1)
dy = dy.reshape(-1)
valid = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht)
x1 = x1[valid]
y1 = y1[valid]
dx = dx[valid]
dy = dy[valid]
flow_x = interpolate.griddata(
(x1, y1), dx, (x0, y0), method='nearest', fill_value=0
)
flow_y = interpolate.griddata(
(x1, y1), dy, (x0, y0), method='nearest', fill_value=0
)
flow = np.stack([flow_x, flow_y], axis=0)
return torch.from_numpy(flow).float()
def bilinear_sampler(img, coords, mode='bilinear', mask=False):
"""Wrapper for grid_sample, uses pixel coordinates."""
del mode
H, W = img.shape[-2:]
xgrid, ygrid = coords.split([1, 1], dim=-1)
xgrid = 2 * xgrid / (W - 1) - 1
ygrid = 2 * ygrid / (H - 1) - 1
grid = torch.cat([xgrid, ygrid], dim=-1)
img = F.grid_sample(img, grid, align_corners=True)
if mask:
mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
return img, mask.float()
return img
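# coords_grid returns a [batch, 2, ht, wd] grid of (x, y) pixel coordinates;
# upflow8 bilinearly upsamples a 1/8-resolution flow field and scales the
# vectors by 8 to match.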
def coords_grid(batch, ht, wd):
coords = torch.meshgrid(torch.arange(ht), torch.arange(wd))
coords = torch.stack(coords[::-1], dim=0).float()
return coords[None].repeat(batch, 1, 1, 1)
def upflow8(flow, mode='bilinear'):
new_size = (8 * flow.shape[2], 8 * flow.shape[3])
return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True)
================================================
FILE: camera_pose_annotation/cvd_opt/preprocess/inference_batch.py
================================================
"""
Batch inference script for optical flow preprocessing using the RAFT model.
Processes multiple video clips in parallel to generate optical flow data for CVD optimization.
"""
import pandas as pd
import os
import argparse
import concurrent.futures
from multiprocessing import Manager
import subprocess
import queue
from tqdm import tqdm
def process_single_row(row, index, args, worker_id=0):
"""Process a single video clip for optical flow generation."""
dir_path = os.path.join(args.dir_path, row["id"])
device_id = worker_id % args.gpu_num
# Build command for optical flow preprocessing with RAFT model
cmd = (
f"CUDA_VISIBLE_DEVICES={args.gpu_id[device_id]} python camera_pose_annotation/cvd_opt/preprocess/preprocess_flow.py "
f"--dir_path {dir_path} "
f"--model {args.checkpoints_path}/raft-things.pth "
f"--mixed_precision"
)
process = subprocess.Popen(
cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE
)
stdout, stderr = process.communicate()
if process.returncode != 0:
print(f"Error generating optical flow for {row['id']}: {stderr.decode()}")
def worker(task_queue, args, worker_id, pbar):
"""Worker function for parallel optical flow preprocessing."""
while True:
try:
index, row = task_queue.get(timeout=1)
except queue.Empty:
break
process_single_row(row, index, args, worker_id)
task_queue.task_done()
pbar.update(1)
def parse_args():
"""Parse command line arguments for optical flow preprocessing."""
parser = argparse.ArgumentParser()
parser.add_argument("--csv_path", type=str, help="Path to the csv file")
parser.add_argument("--dir_path", type=str, default="./outputs")
parser.add_argument("--checkpoints_path", type=str, default="./checkpoints")
parser.add_argument(
"--gpu_id", type=str, default="0", help="Comma-separated list of GPU IDs to use"
)
parser.add_argument(
"--num_workers",
type=int,
default=4,
help="Number of workers for parallel processing",
)
parser.add_argument(
"--disable_parallel", action="store_true", help="Disable parallel processing"
)
return parser.parse_args()
def main():
args = parse_args()
# Parse GPU configuration
args.gpu_num = len(args.gpu_id.split(","))
args.gpu_id = [int(gpu) for gpu in args.gpu_id.split(",")]
df = pd.read_csv(args.csv_path)
if args.disable_parallel:
# Sequential processing
for index, row in tqdm(df.iterrows(), total=len(df), desc="Processing rows"):
process_single_row(row, index, args)
else:
# Parallel processing with multiple workers
manager = Manager()
task_queue = manager.Queue()
for index, row in df.iterrows():
task_queue.put((index, row))
with tqdm(total=len(df), desc="Processing rows") as pbar:
with concurrent.futures.ThreadPoolExecutor(
max_workers=args.num_workers
) as executor:
futures = []
for id in range(args.num_workers):
futures.append(executor.submit(worker, task_queue, args, id, pbar))
for future in concurrent.futures.as_completed(futures):
future.result()
if __name__ == "__main__":
main()
================================================
FILE: camera_pose_annotation/cvd_opt/preprocess/preprocess_flow.py
================================================
# Copyright 2025 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Preprocess flow for MegaSaM."""
import cv2
import tqdm
import argparse
from pathlib import Path # pylint: disable=g-importing-member
from core.utils.utils import InputPadder
from core.raft import RAFT
import glob
import os
import sys
import numpy as np
import torch
def warp_flow(img, flow):
h, w = flow.shape[:2]
flow_new = flow.copy()
flow_new[:, :, 0] += np.arange(w)
flow_new[:, :, 1] += np.arange(h)[:, np.newaxis]
res = cv2.remap(
img, flow_new, None, cv2.INTER_LINEAR, borderMode=cv2.BORDER_CONSTANT
)
return res
def resize_flow(flow, img_h, img_w):
# flow = np.load(flow_path)
flow_h, flow_w = flow.shape[0], flow.shape[1]
flow[:, :, 0] *= float(img_w) / float(flow_w)
flow[:, :, 1] *= float(img_h) / float(flow_h)
flow = cv2.resize(flow, (img_w, img_h), interpolation=cv2.INTER_LINEAR)
return flow
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--model", default="raft-things.pth", help="restore checkpoint")
parser.add_argument("--small", action="store_true", help="use small model")
parser.add_argument("--dir_path", help="dataset for evaluation")
parser.add_argument(
"--num_heads",
default=1,
type=int,
help="number of heads in attention and aggregation",
)
parser.add_argument(
"--position_only",
default=False,
action="store_true",
help="only use position-wise attention",
)
parser.add_argument(
"--position_and_content",
default=False,
action="store_true",
help="use position and content-wise attention",
)
parser.add_argument(
"--mixed_precision", action="store_true", help="use mixed precision"
)
return parser.parse_args()
if __name__ == "__main__":
args = parse_args()
model = torch.nn.DataParallel(RAFT(args))
model.load_state_dict(torch.load(args.model))
flow_model = model.module
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
flow_model.to(device).eval()
img_path = os.path.join(args.dir_path, "img")
image_list = sorted(glob.glob(os.path.join(img_path, "*.png"))) # [::stride]
image_list += sorted(glob.glob(os.path.join(img_path, "*.jpg"))) # [::stride]
img_data = []
for t, (image_file) in tqdm.tqdm(enumerate(image_list)):
image = cv2.imread(image_file)[..., ::-1] # rgb
h0, w0, _ = image.shape
h1 = int(h0 * np.sqrt((384 * 512) / (h0 * w0)))
w1 = int(w0 * np.sqrt((384 * 512) / (h0 * w0)))
image = cv2.resize(image, (w1, h1))
image = image[: h1 - h1 % 8, : w1 - w1 % 8].transpose(2, 0, 1)
img_data.append(image)
img_data = np.array(img_data)
flows_low = []
flows_high = []
flow_masks_high = []
flow_init = None
flows_arr_low_bwd = {}
flows_arr_low_fwd = {}
ii = []
jj = []
flows_arr_up = []
masks_arr_up = []
for step in [1, 2, 4, 8, 15]:
flows_arr_low = []
for i in tqdm.tqdm(range(max(0, -step), img_data.shape[0] - max(0, step))):
image1 = (
torch.as_tensor(np.ascontiguousarray(img_data[i : i + 1]))
.float()
.cuda()
)
image2 = (
torch.as_tensor(np.ascontiguousarray(img_data[i + step : i + step + 1]))
.float()
.cuda()
)
ii.append(i)
jj.append(i + step)
with torch.no_grad():
padder = InputPadder(image1.shape)
image1, image2 = padder.pad(image1, image2)
if np.abs(step) > 1:
flow_init = np.stack(
[flows_arr_low_fwd[i], flows_arr_low_bwd[i + step]], axis=0
)
flow_init = (
torch.as_tensor(np.ascontiguousarray(flow_init))
.float()
.cuda()
.permute(0, 3, 1, 2)
)
else:
flow_init = None
flow_low, flow_up, _ = flow_model(
torch.cat([image1, image2], dim=0),
torch.cat([image2, image1], dim=0),
iters=22,
test_mode=True,
flow_init=flow_init,
)
flow_low_fwd = flow_low[0].cpu().numpy().transpose(1, 2, 0)
flow_low_bwd = flow_low[1].cpu().numpy().transpose(1, 2, 0)
flow_up_fwd = resize_flow(
flow_up[0].cpu().numpy().transpose(1, 2, 0),
flow_up.shape[-2] // 2,
flow_up.shape[-1] // 2,
)
flow_up_bwd = resize_flow(
flow_up[1].cpu().numpy().transpose(1, 2, 0),
flow_up.shape[-2] // 2,
flow_up.shape[-1] // 2,
)
bwd2fwd_flow = warp_flow(flow_up_bwd, flow_up_fwd)
fwd_lr_error = np.linalg.norm(flow_up_fwd + bwd2fwd_flow, axis=-1)
fwd_mask_up = fwd_lr_error < 1.0
# flows_arr_low.append(flow_low_fwd)
flows_arr_low_bwd[i + step] = flow_low_bwd
flows_arr_low_fwd[i] = flow_low_fwd
# masks_arr_low.append(fwd_mask_low)
flows_arr_up.append(flow_up_fwd)
masks_arr_up.append(fwd_mask_up)
iijj = np.stack((ii, jj), axis=0)
flows_high = np.array(flows_arr_up).transpose(0, 3, 1, 2)
flow_masks_high = np.array(masks_arr_up)[:, None, ...]
output_path = os.path.join(args.dir_path, "cache-flow")
if not os.path.exists(output_path):
os.makedirs(output_path)
np.save(os.path.join(output_path, "flows.npy"), np.float16(flows_high))
np.save(os.path.join(output_path, "flows_masks.npy"), flow_masks_high)
np.save(os.path.join(output_path, "ii-jj.npy"), iijj)
================================================
FILE: camera_pose_annotation/depth_estimation/Depth-Anything/__init__.py
================================================
================================================
FILE: camera_pose_annotation/depth_estimation/Depth-Anything/depth_anything_v2/dinov2.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
# References:
# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
from functools import partial
import math
import logging
from typing import Sequence, Tuple, Union, Callable
import torch
import torch.nn as nn
import torch.utils.checkpoint
from torch.nn.init import trunc_normal_
from .dinov2_layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block
logger = logging.getLogger("dinov2")
def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
if not depth_first and include_root:
fn(module=module, name=name)
for child_name, child_module in module.named_children():
child_name = ".".join((name, child_name)) if name else child_name
named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
if depth_first and include_root:
fn(module=module, name=name)
return module
class BlockChunk(nn.ModuleList):
def forward(self, x):
for b in self:
x = b(x)
return x
class DinoVisionTransformer(nn.Module):
def __init__(
self,
img_size=224,
patch_size=16,
in_chans=3,
embed_dim=768,
depth=12,
num_heads=12,
mlp_ratio=4.0,
qkv_bias=True,
ffn_bias=True,
proj_bias=True,
drop_path_rate=0.0,
drop_path_uniform=False,
init_values=None, # for layerscale: None or 0 => no layerscale
embed_layer=PatchEmbed,
act_layer=nn.GELU,
block_fn=Block,
ffn_layer="mlp",
block_chunks=1,
num_register_tokens=0,
interpolate_antialias=False,
interpolate_offset=0.1,
):
"""
Args:
img_size (int, tuple): input image size
patch_size (int, tuple): patch size
in_chans (int): number of input channels
embed_dim (int): embedding dimension
depth (int): depth of transformer
num_heads (int): number of attention heads
mlp_ratio (int): ratio of mlp hidden dim to embedding dim
qkv_bias (bool): enable bias for qkv if True
proj_bias (bool): enable bias for proj in attn if True
ffn_bias (bool): enable bias for ffn if True
drop_path_rate (float): stochastic depth rate
drop_path_uniform (bool): apply uniform drop rate across blocks
weight_init (str): weight init scheme
init_values (float): layer-scale init values
embed_layer (nn.Module): patch embedding layer
act_layer (nn.Module): MLP activation layer
block_fn (nn.Module): transformer block class
ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
num_register_tokens: (int) number of extra cls tokens (so-called "registers")
            interpolate_antialias: (bool) flag to apply anti-aliasing when interpolating positional embeddings
interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings
"""
super().__init__()
norm_layer = partial(nn.LayerNorm, eps=1e-6)
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
self.num_tokens = 1
self.n_blocks = depth
self.num_heads = num_heads
self.patch_size = patch_size
self.num_register_tokens = num_register_tokens
self.interpolate_antialias = interpolate_antialias
self.interpolate_offset = interpolate_offset
self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
num_patches = self.patch_embed.num_patches
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
assert num_register_tokens >= 0
self.register_tokens = (
nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None
)
if drop_path_uniform is True:
dpr = [drop_path_rate] * depth
else:
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
if ffn_layer == "mlp":
logger.info("using MLP layer as FFN")
ffn_layer = Mlp
elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
logger.info("using SwiGLU layer as FFN")
ffn_layer = SwiGLUFFNFused
elif ffn_layer == "identity":
logger.info("using Identity layer as FFN")
def f(*args, **kwargs):
return nn.Identity()
ffn_layer = f
else:
raise NotImplementedError
blocks_list = [
block_fn(
dim=embed_dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
proj_bias=proj_bias,
ffn_bias=ffn_bias,
drop_path=dpr[i],
norm_layer=norm_layer,
act_layer=act_layer,
ffn_layer=ffn_layer,
init_values=init_values,
)
for i in range(depth)
]
if block_chunks > 0:
self.chunked_blocks = True
chunked_blocks = []
chunksize = depth // block_chunks
for i in range(0, depth, chunksize):
# this is to keep the block index consistent if we chunk the block list
chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
else:
self.chunked_blocks = False
self.blocks = nn.ModuleList(blocks_list)
self.norm = norm_layer(embed_dim)
self.head = nn.Identity()
self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))
self.init_weights()
def init_weights(self):
trunc_normal_(self.pos_embed, std=0.02)
nn.init.normal_(self.cls_token, std=1e-6)
if self.register_tokens is not None:
nn.init.normal_(self.register_tokens, std=1e-6)
named_apply(init_weights_vit_timm, self)
def interpolate_pos_encoding(self, x, w, h):
previous_dtype = x.dtype
npatch = x.shape[1] - 1
N = self.pos_embed.shape[1] - 1
if npatch == N and w == h:
return self.pos_embed
pos_embed = self.pos_embed.float()
class_pos_embed = pos_embed[:, 0]
patch_pos_embed = pos_embed[:, 1:]
dim = x.shape[-1]
w0 = w // self.patch_size
h0 = h // self.patch_size
# we add a small number to avoid floating point error in the interpolation
# see discussion at https://github.com/facebookresearch/dino/issues/8
# DINOv2 with register modify the interpolate_offset from 0.1 to 0.0
w0, h0 = w0 + self.interpolate_offset, h0 + self.interpolate_offset
# w0, h0 = w0 + 0.1, h0 + 0.1
sqrt_N = math.sqrt(N)
sx, sy = float(w0) / sqrt_N, float(h0) / sqrt_N
patch_pos_embed = nn.functional.interpolate(
patch_pos_embed.reshape(1, int(sqrt_N), int(sqrt_N), dim).permute(0, 3, 1, 2),
scale_factor=(sx, sy),
# (int(w0), int(h0)), # to solve the upsampling shape issue
mode="bicubic",
antialias=self.interpolate_antialias
)
assert int(w0) == patch_pos_embed.shape[-2]
assert int(h0) == patch_pos_embed.shape[-1]
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)
def prepare_tokens_with_masks(self, x, masks=None):
B, nc, w, h = x.shape
x = self.patch_embed(x)
if masks is not None:
x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)
x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
x = x + self.interpolate_pos_encoding(x, w, h)
if self.register_tokens is not None:
x = torch.cat(
(
x[:, :1],
self.register_tokens.expand(x.shape[0], -1, -1),
x[:, 1:],
),
dim=1,
)
return x
def forward_features_list(self, x_list, masks_list):
x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
for blk in self.blocks:
x = blk(x)
all_x = x
output = []
for x, masks in zip(all_x, masks_list):
x_norm = self.norm(x)
output.append(
{
"x_norm_clstoken": x_norm[:, 0],
"x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
"x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
"x_prenorm": x,
"masks": masks,
}
)
return output
def forward_features(self, x, masks=None):
if isinstance(x, list):
return self.forward_features_list(x, masks)
x = self.prepare_tokens_with_masks(x, masks)
for blk in self.blocks:
x = blk(x)
x_norm = self.norm(x)
return {
"x_norm_clstoken": x_norm[:, 0],
"x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
"x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
"x_prenorm": x,
"masks": masks,
}
def _get_intermediate_layers_not_chunked(self, x, n=1):
x = self.prepare_tokens_with_masks(x)
# If n is an int, take the n last blocks. If it's a list, take them
output, total_block_len = [], len(self.blocks)
blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
for i, blk in enumerate(self.blocks):
x = blk(x)
if i in blocks_to_take:
output.append(x)
assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
return output
def _get_intermediate_layers_chunked(self, x, n=1):
x = self.prepare_tokens_with_masks(x)
output, i, total_block_len = [], 0, len(self.blocks[-1])
# If n is an int, take the n last blocks. If it's a list, take them
blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
for block_chunk in self.blocks:
for blk in block_chunk[i:]: # Passing the nn.Identity()
x = blk(x)
if i in blocks_to_take:
output.append(x)
i += 1
assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
return output
def get_intermediate_layers(
self,
x: torch.Tensor,
n: Union[int, Sequence] = 1, # Layers or n last layers to take
reshape: bool = False,
return_class_token: bool = False,
norm=True
) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
if self.chunked_blocks:
outputs = self._get_intermediate_layers_chunked(x, n)
else:
outputs = self._get_intermediate_layers_not_chunked(x, n)
if norm:
outputs = [self.norm(out) for out in outputs]
class_tokens = [out[:, 0] for out in outputs]
outputs = [out[:, 1 + self.num_register_tokens:] for out in outputs]
if reshape:
B, _, w, h = x.shape
outputs = [
out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
for out in outputs
]
if return_class_token:
return tuple(zip(outputs, class_tokens))
return tuple(outputs)
def forward(self, *args, is_training=False, **kwargs):
ret = self.forward_features(*args, **kwargs)
if is_training:
return ret
else:
return self.head(ret["x_norm_clstoken"])
def init_weights_vit_timm(module: nn.Module, name: str = ""):
"""ViT weight initialization, original timm impl (for reproducibility)"""
if isinstance(module, nn.Linear):
trunc_normal_(module.weight, std=0.02)
if module.bias is not None:
nn.init.zeros_(module.bias)
def vit_small(patch_size=16, num_register_tokens=0, **kwargs):
model = DinoVisionTransformer(
patch_size=patch_size,
embed_dim=384,
depth=12,
num_heads=6,
mlp_ratio=4,
block_fn=partial(Block, attn_class=MemEffAttention),
num_register_tokens=num_register_tokens,
**kwargs,
)
return model
def vit_base(patch_size=16, num_register_tokens=0, **kwargs):
model = DinoVisionTransformer(
patch_size=patch_size,
embed_dim=768,
depth=12,
num_heads=12,
mlp_ratio=4,
block_fn=partial(Block, attn_class=MemEffAttention),
num_register_tokens=num_register_tokens,
**kwargs,
)
return model
def vit_large(patch_size=16, num_register_tokens=0, **kwargs):
model = DinoVisionTransformer(
patch_size=patch_size,
embed_dim=1024,
depth=24,
num_heads=16,
mlp_ratio=4,
block_fn=partial(Block, attn_class=MemEffAttention),
num_register_tokens=num_register_tokens,
**kwargs,
)
return model
def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs):
"""
Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
"""
model = DinoVisionTransformer(
patch_size=patch_size,
embed_dim=1536,
depth=40,
num_heads=24,
mlp_ratio=4,
block_fn=partial(Block, attn_class=MemEffAttention),
num_register_tokens=num_register_tokens,
**kwargs,
)
return model
def DINOv2(model_name):
model_zoo = {
"vits": vit_small,
"vitb": vit_base,
"vitl": vit_large,
"vitg": vit_giant2
}
return model_zoo[model_name](
img_size=518,
patch_size=14,
init_values=1.0,
ffn_layer="mlp" if model_name != "vitg" else "swiglufused",
block_chunks=0,
num_register_tokens=0,
interpolate_antialias=False,
interpolate_offset=0.1
)
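A hedged usage sketch for the `DINOv2` factory above; the input resolution and the layer indices passed to `get_intermediate_layers` are illustrative choices, not repository defaults.
```python
import torch

backbone = DINOv2("vits").eval()        # ViT-S/14, img_size=518, no register tokens
x = torch.randn(1, 3, 518, 518)         # height and width must be multiples of patch_size=14
with torch.no_grad():
    feats = backbone.get_intermediate_layers(x, n=[2, 5, 8, 11], reshape=True)
for f in feats:
    print(f.shape)                      # torch.Size([1, 384, 37, 37]) patch-token feature maps
```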
================================================
FILE: camera_pose_annotation/depth_estimation/Depth-Anything/depth_anything_v2/dinov2_layers/__init__.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from .mlp import Mlp
from .patch_embed import PatchEmbed
from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
from .block import NestedTensorBlock
from .attention import MemEffAttention
================================================
FILE: camera_pose_annotation/depth_estimation/Depth-Anything/depth_anything_v2/dinov2_layers/attention.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# References:
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
import logging
from torch import Tensor
from torch import nn
logger = logging.getLogger("dinov2")
try:
from xformers.ops import memory_efficient_attention, unbind, fmha
XFORMERS_AVAILABLE = True
except ImportError:
logger.warning("xFormers not available")
XFORMERS_AVAILABLE = False
class Attention(nn.Module):
def __init__(
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = False,
proj_bias: bool = True,
attn_drop: float = 0.0,
proj_drop: float = 0.0,
) -> None:
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = head_dim**-0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim, bias=proj_bias)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x: Tensor) -> Tensor:
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
attn = q @ k.transpose(-2, -1)
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class MemEffAttention(Attention):
def forward(self, x: Tensor, attn_bias=None) -> Tensor:
if not XFORMERS_AVAILABLE:
assert attn_bias is None, "xFormers is required for nested tensors usage"
return super().forward(x)
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
q, k, v = unbind(qkv, 2)
x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
x = x.reshape([B, N, C])
x = self.proj(x)
x = self.proj_drop(x)
return x
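A small sanity-check sketch (not repository code): with no `attn_bias`, `MemEffAttention` either dispatches to xFormers' `memory_efficient_attention` or falls back to the dense `Attention.forward`, so the call below works in both cases; the shapes are arbitrary.
```python
import torch

attn = MemEffAttention(dim=64, num_heads=8, qkv_bias=True).eval()
tokens = torch.randn(2, 16, 64)          # (batch, sequence length, dim)
with torch.no_grad():
    out = attn(tokens)                   # xFormers kernel if installed, dense softmax attention otherwise
print(out.shape)                         # torch.Size([2, 16, 64])
```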
================================================
FILE: camera_pose_annotation/depth_estimation/Depth-Anything/depth_anything_v2/dinov2_layers/block.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# References:
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
import logging
from typing import Callable, List, Any, Tuple, Dict
import torch
from torch import nn, Tensor
from .attention import Attention, MemEffAttention
from .drop_path import DropPath
from .layer_scale import LayerScale
from .mlp import Mlp
logger = logging.getLogger("dinov2")
try:
from xformers.ops import fmha
from xformers.ops import scaled_index_add, index_select_cat
XFORMERS_AVAILABLE = True
except ImportError:
logger.warning("xFormers not available")
XFORMERS_AVAILABLE = False
class Block(nn.Module):
def __init__(
self,
dim: int,
num_heads: int,
mlp_ratio: float = 4.0,
qkv_bias: bool = False,
proj_bias: bool = True,
ffn_bias: bool = True,
drop: float = 0.0,
attn_drop: float = 0.0,
init_values=None,
drop_path: float = 0.0,
act_layer: Callable[..., nn.Module] = nn.GELU,
norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
attn_class: Callable[..., nn.Module] = Attention,
ffn_layer: Callable[..., nn.Module] = Mlp,
) -> None:
super().__init__()
# print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
self.norm1 = norm_layer(dim)
self.attn = attn_class(
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
proj_bias=proj_bias,
attn_drop=attn_drop,
proj_drop=drop,
)
self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = ffn_layer(
in_features=dim,
hidden_features=mlp_hidden_dim,
act_layer=act_layer,
drop=drop,
bias=ffn_bias,
)
self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.sample_drop_ratio = drop_path
def forward(self, x: Tensor) -> Tensor:
def attn_residual_func(x: Tensor) -> Tensor:
return self.ls1(self.attn(self.norm1(x)))
def ffn_residual_func(x: Tensor) -> Tensor:
return self.ls2(self.mlp(self.norm2(x)))
if self.training and self.sample_drop_ratio > 0.1:
# the overhead is compensated only for a drop path rate larger than 0.1
x = drop_add_residual_stochastic_depth(
x,
residual_func=attn_residual_func,
sample_drop_ratio=self.sample_drop_ratio,
)
x = drop_add_residual_stochastic_depth(
x,
residual_func=ffn_residual_func,
sample_drop_ratio=self.sample_drop_ratio,
)
elif self.training and self.sample_drop_ratio > 0.0:
x = x + self.drop_path1(attn_residual_func(x))
            x = x + self.drop_path2(ffn_residual_func(x))  # FFN branch uses its own drop-path (same rate as drop_path1)
else:
x = x + attn_residual_func(x)
x = x + ffn_residual_func(x)
return x
def drop_add_residual_stochastic_depth(
x: Tensor,
residual_func: Callable[[Tensor], Tensor],
sample_drop_ratio: float = 0.0,
) -> Tensor:
# 1) extract subset using permutation
b, n, d = x.shape
sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
x_subset = x[brange]
# 2) apply residual_func to get residual
residual = residual_func(x_subset)
x_flat = x.flatten(1)
residual = residual.flatten(1)
residual_scale_factor = b / sample_subset_size
# 3) add the residual
x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
return x_plus_residual.view_as(x)
def get_branges_scales(x, sample_drop_ratio=0.0):
b, n, d = x.shape
sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
residual_scale_factor = b / sample_subset_size
return brange, residual_scale_factor
def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
if scaling_vector is None:
x_flat = x.flatten(1)
residual = residual.flatten(1)
x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
else:
x_plus_residual = scaled_index_add(
x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
)
return x_plus_residual
attn_bias_cache: Dict[Tuple, Any] = {}
def get_attn_bias_and_cat(x_list, branges=None):
"""
this will perform the index select, cat the tensors, and provide the attn_bias from cache
"""
batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
if all_shapes not in attn_bias_cache.keys():
seqlens = []
for b, x in zip(batch_sizes, x_list):
for _ in range(b):
seqlens.append(x.shape[1])
attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
attn_bias._batch_sizes = batch_sizes
attn_bias_cache[all_shapes] = attn_bias
if branges is not None:
cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
else:
tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
cat_tensors = torch.cat(tensors_bs1, dim=1)
return attn_bias_cache[all_shapes], cat_tensors
def drop_add_residual_stochastic_depth_list(
x_list: List[Tensor],
residual_func: Callable[[Tensor, Any], Tensor],
sample_drop_ratio: float = 0.0,
scaling_vector=None,
) -> Tensor:
# 1) generate random set of indices for dropping samples in the batch
branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
branges = [s[0] for s in branges_scales]
residual_scale_factors = [s[1] for s in branges_scales]
# 2) get attention bias and index+concat the tensors
attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)
# 3) apply residual_func to get residual, and split the result
residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore
outputs = []
for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
return outputs
class NestedTensorBlock(Block):
def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
"""
x_list contains a list of tensors to nest together and run
"""
assert isinstance(self.attn, MemEffAttention)
if self.training and self.sample_drop_ratio > 0.0:
def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
return self.attn(self.norm1(x), attn_bias=attn_bias)
def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
return self.mlp(self.norm2(x))
x_list = drop_add_residual_stochastic_depth_list(
x_list,
residual_func=attn_residual_func,
sample_drop_ratio=self.sample_drop_ratio,
scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
)
x_list = drop_add_residual_stochastic_depth_list(
x_list,
residual_func=ffn_residual_func,
sample_drop_ratio=self.sample_drop_ratio,
                scaling_vector=self.ls2.gamma if isinstance(self.ls2, LayerScale) else None,
)
return x_list
else:
def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))
def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
return self.ls2(self.mlp(self.norm2(x)))
attn_bias, x = get_attn_bias_and_cat(x_list)
x = x + attn_residual_func(x, attn_bias=attn_bias)
x = x + ffn_residual_func(x)
return attn_bias.split(x)
def forward(self, x_or_x_list):
if isinstance(x_or_x_list, Tensor):
return super().forward(x_or_x_list)
elif isinstance(x_or_x_list, list):
assert XFORMERS_AVAILABLE, "Please install xFormers for nested tensors usage"
return self.forward_nested(x_or_x_list)
else:
raise AssertionError
================================================
FILE: camera_pose_annotation/depth_estimation/Depth-Anything/depth_anything_v2/dinov2_layers/drop_path.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# References:
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py
from torch import nn
def drop_path(x, drop_prob: float = 0.0, training: bool = False):
if drop_prob == 0.0 or not training:
return x
keep_prob = 1 - drop_prob
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
if keep_prob > 0.0:
random_tensor.div_(keep_prob)
output = x * random_tensor
return output
class DropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
def __init__(self, drop_prob=None):
super(DropPath, self).__init__()
self.drop_prob = drop_prob
def forward(self, x):
return drop_path(x, self.drop_prob, self.training)
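An illustrative sketch (not repository code) of the `DropPath` behaviour: whole samples in the residual branch are zeroed during training and the survivors re-scaled by `1 / keep_prob`, while evaluation mode is an identity.
```python
import torch

dp = DropPath(drop_prob=0.2)
x = torch.ones(8, 4, 16)

dp.train()
y_train = dp(x)                 # roughly 20% of the 8 samples zeroed, survivors scaled by 1/0.8
dp.eval()
y_eval = dp(x)                  # unchanged
print(torch.equal(y_eval, x))   # True
```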
================================================
FILE: camera_pose_annotation/depth_estimation/Depth-Anything/depth_anything_v2/dinov2_layers/layer_scale.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110
from typing import Union
import torch
from torch import Tensor
from torch import nn
class LayerScale(nn.Module):
def __init__(
self,
dim: int,
init_values: Union[float, Tensor] = 1e-5,
inplace: bool = False,
) -> None:
super().__init__()
self.inplace = inplace
self.gamma = nn.Parameter(init_values * torch.ones(dim))
def forward(self, x: Tensor) -> Tensor:
return x.mul_(self.gamma) if self.inplace else x * self.gamma
================================================
FILE: camera_pose_annotation/depth_estimation/Depth-Anything/depth_anything_v2/dinov2_layers/mlp.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# References:
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py
from typing import Callable, Optional
from torch import Tensor, nn
class Mlp(nn.Module):
def __init__(
self,
in_features: int,
hidden_features: Optional[int] = None,
out_features: Optional[int] = None,
act_layer: Callable[..., nn.Module] = nn.GELU,
drop: float = 0.0,
bias: bool = True,
) -> None:
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
self.drop = nn.Dropout(drop)
def forward(self, x: Tensor) -> Tensor:
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
================================================
FILE: camera_pose_annotation/depth_estimation/Depth-Anything/depth_anything_v2/dinov2_layers/patch_embed.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# References:
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
from typing import Callable, Optional, Tuple, Union
from torch import Tensor
import torch.nn as nn
def make_2tuple(x):
if isinstance(x, tuple):
assert len(x) == 2
return x
assert isinstance(x, int)
return (x, x)
class PatchEmbed(nn.Module):
"""
2D image to patch embedding: (B,C,H,W) -> (B,N,D)
Args:
img_size: Image size.
patch_size: Patch token size.
in_chans: Number of input image channels.
embed_dim: Number of linear projection output channels.
norm_layer: Normalization layer.
"""
def __init__(
self,
img_size: Union[int, Tuple[int, int]] = 224,
patch_size: Union[int, Tuple[int, int]] = 16,
in_chans: int = 3,
embed_dim: int = 768,
norm_layer: Optional[Callable] = None,
flatten_embedding: bool = True,
) -> None:
super().__init__()
image_HW = make_2tuple(img_size)
patch_HW = make_2tuple(patch_size)
patch_grid_size = (
image_HW[0] // patch_HW[0],
image_HW[1] // patch_HW[1],
)
self.img_size = image_HW
self.patch_size = patch_HW
self.patches_resolution = patch_grid_size
self.num_patches = patch_grid_size[0] * patch_grid_size[1]
self.in_chans = in_chans
self.embed_dim = embed_dim
self.flatten_embedding = flatten_embedding
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
def forward(self, x: Tensor) -> Tensor:
_, _, H, W = x.shape
patch_H, patch_W = self.patch_size
assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"
x = self.proj(x) # B C H W
H, W = x.size(2), x.size(3)
x = x.flatten(2).transpose(1, 2) # B HW C
x = self.norm(x)
if not self.flatten_embedding:
x = x.reshape(-1, H, W, self.embed_dim) # B H W C
return x
def flops(self) -> float:
Ho, Wo = self.patches_resolution
flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
if self.norm is not None:
flops += Ho * Wo * self.embed_dim
return flops
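A shape sketch for `PatchEmbed` (illustrative only; 224/16 is the class default rather than the 518/14 configuration used by the DINOv2 factory above).
```python
import torch

embed = PatchEmbed(img_size=224, patch_size=16, in_chans=3, embed_dim=768)
img = torch.randn(1, 3, 224, 224)
tokens = embed(img)
print(tokens.shape)             # torch.Size([1, 196, 768]): (224 / 16) ** 2 = 196 patch tokens
```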
================================================
FILE: camera_pose_annotation/depth_estimation/Depth-Anything/depth_anything_v2/dinov2_layers/swiglu_ffn.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from typing import Callable, Optional
from torch import Tensor, nn
import torch.nn.functional as F
class SwiGLUFFN(nn.Module):
def __init__(
self,
in_features: int,
hidden_features: Optional[int] = None,
out_features: Optional[int] = None,
act_layer: Callable[..., nn.Module] = None,
drop: float = 0.0,
bias: bool = True,
) -> None:
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
def forward(self, x: Tensor) -> Tensor:
x12 = self.w12(x)
x1, x2 = x12.chunk(2, dim=-1)
hidden = F.silu(x1) * x2
return self.w3(hidden)
try:
from xformers.ops import SwiGLU
XFORMERS_AVAILABLE = True
except ImportError:
SwiGLU = SwiGLUFFN
XFORMERS_AVAILABLE = False
class SwiGLUFFNFused(SwiGLU):
def __init__(
self,
in_features: int,
hidden_features: Optional[int] = None,
out_features: Optional[int] = None,
act_layer: Callable[..., nn.Module] = None,
drop: float = 0.0,
bias: bool = True,
) -> None:
out_features = out_features or in_features
hidden_features = hidden_features or in_features
hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
super().__init__(
in_features=in_features,
hidden_features=hidden_features,
out_features=out_features,
bias=bias,
)
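A worked example of the hidden-width rule in `SwiGLUFFNFused.__init__` above, which shrinks the requested width by 2/3 (keeping the parameter count close to a plain GELU MLP) and rounds up to a multiple of 8; the 6144 figure is illustrative (embed_dim 1536 with mlp_ratio 4).
```python
requested = 6144                                 # e.g. 1536 * 4
hidden = (int(requested * 2 / 3) + 7) // 8 * 8
print(hidden)                                    # 4096
```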
================================================
FILE: camera_pose_annotation/depth_estimation/Depth-Anything/depth_anything_v2/dpt.py
================================================
import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.transforms import Compose
from .dinov2 import DINOv2
from .util.blocks import FeatureFusionBlock, _make_scratch
from .util.transform import Resize, NormalizeImage, PrepareForNet
def _make_fusion_block(features, use_bn, size=None):
return FeatureFusionBlock(
features,
nn.ReLU(False),
deconv=False,
bn=use_bn,
expand=False,
align_corners=True,
size=size,
)
class ConvBlock(nn.Module):
def __init__(self, in_feature, out_feature):
super().__init__()
self.conv_block = nn.Sequential(
nn.Conv2d(in_feature, out_feature, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(out_feature),
nn.ReLU(True)
)
def forward(self, x):
return self.conv_block(x)
class DPTHead(nn.Module):
def __init__(
self,
in_channels,
features=256,
use_bn=False,
out_channels=[256, 512, 1024, 1024],
use_clstoken=False
):
super(DPTHead, self).__init__()
self.use_clstoken = use_clstoken
self.projects = nn.ModuleList([
nn.Conv2d(
in_channels=in_channels,
out_channels=out_channel,
kernel_size=1,
stride=1,
padding=0,
) for out_channel in out_channels
])
self.resize_layers = nn.ModuleList([
nn.ConvTranspose2d(
in_channels=out_channels[0],
out_channels=out_channels[0],
kernel_size=4,
stride=4,
padding=0),
nn.ConvTranspose2d(
in_channels=out_channels[1],
out_channels=out_channels[1],
kernel_size=2,
stride=2,
padding=0),
nn.Identity(),
nn.Conv2d(
in_channels=out_channels[3],
out_channels=out_channels[3],
kernel_size=3,
stride=2,
padding=1)
])
if use_clstoken:
self.readout_projects = nn.ModuleList()
for _ in range(len(self.projects)):
self.readout_projects.append(
nn.Sequential(
nn.Linear(2 * in_channels, in_channels),
nn.GELU()))
self.scratch = _make_scratch(
out_channels,
features,
groups=1,
expand=False,
)
self.scratch.stem_transpose = None
self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
head_features_1 = features
head_features_2 = 32
self.scratch.output_conv1 = nn.Conv2d(head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1)
self.scratch.output_conv2 = nn.Sequential(
nn.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1),
nn.ReLU(True),
nn.Conv2d(head_features_2, 1, kernel_size=1, stride=1, padding=0),
nn.ReLU(True),
nn.Identity(),
)
def forward(self, out_features, patch_h, patch_w):
out = []
for i, x in enumerate(out_features):
if self.use_clstoken:
x, cls_token = x[0], x[1]
readout = cls_token.unsqueeze(1).expand_as(x)
x = self.readout_projects[i](torch.cat((x, readout), -1))
else:
x = x[0]
x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w))
x = self.projects[i](x)
x = self.resize_layers[i](x)
out.append(x)
layer_1, layer_2, layer_3, layer_4 = out
layer_1_rn = self.scratch.layer1_rn(layer_1)
layer_2_rn = self.scratch.layer2_rn(layer_2)
layer_3_rn = self.scratch.layer3_rn(layer_3)
layer_4_rn = self.scratch.layer4_rn(layer_4)
path_4 = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
path_3 = self.scratch.refinenet3(path_4, layer_3_rn, size=layer_2_rn.shape[2:])
path_2 = self.scratch.refinenet2(path_3, layer_2_rn, size=layer_1_rn.shape[2:])
path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
out = self.scratch.output_conv1(path_1)
out = F.interpolate(out, (int(patch_h * 14), int(patch_w * 14)), mode="bilinear", align_corners=True)
out = self.scratch.output_conv2(out)
return out
class DepthAnythingV2(nn.Module):
def __init__(
self,
encoder='vitl',
features=256,
out_channels=[256, 512, 1024, 1024],
use_bn=False,
use_clstoken=False
):
super(DepthAnythingV2, self).__init__()
self.intermediate_layer_idx = {
'vits': [2, 5, 8, 11],
'vitb': [2, 5, 8, 11],
'vitl': [4, 11, 17, 23],
'vitg': [9, 19, 29, 39]
}
self.encoder = encoder
self.pretrained = DINOv2(model_name=encoder)
self.depth_head = DPTHead(self.pretrained.embed_dim, features, use_bn, out_channels=out_channels, use_clstoken=use_clstoken)
def forward(self, x):
patch_h, patch_w = x.shape[-2] // 14, x.shape[-1] // 14
features = self.pretrained.get_intermediate_layers(x, self.intermediate_layer_idx[self.encoder], return_class_token=True)
depth = self.depth_head(features, patch_h, patch_w)
depth = F.relu(depth)
return depth.squeeze(1)
@torch.no_grad()
def infer_image(self, raw_image, input_size=518):
image, (h, w) = self.image2tensor(raw_image, input_size)
depth = self.forward(image)
depth = F.interpolate(depth[:, None], (h, w), mode="bilinear", align_corners=True)[0, 0]
return depth.cpu().numpy()
def image2tensor(self, raw_image, input_size=518):
transform = Compose([
Resize(
width=input_size,
height=input_size,
resize_target=False,
keep_aspect_ratio=True,
ensure_multiple_of=14,
resize_method='lower_bound',
image_interpolation_method=cv2.INTER_CUBIC,
),
NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
PrepareForNet(),
])
h, w = raw_image.shape[:2]
image = cv2.cvtColor(raw_image, cv2.COLOR_BGR2RGB) / 255.0
image = transform({'image': image})['image']
image = torch.from_numpy(image).unsqueeze(0)
DEVICE = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
image = image.to(DEVICE)
return image, (h, w)
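A hedged usage sketch for `DepthAnythingV2`; the checkpoint path matches the default of the inference script below, while the image file name is an assumption.
```python
import cv2
import torch

model = DepthAnythingV2(encoder="vitl", features=256, out_channels=[256, 512, 1024, 1024])
state = torch.load("checkpoints/Depth-Anything/depth_anything_v2_vitl.pth", map_location="cpu")
model.load_state_dict(state)
model = model.cuda().eval()

bgr = cv2.imread("example_frame.jpg")            # infer_image expects a BGR image (cv2 convention)
depth = model.infer_image(bgr, input_size=518)   # (H, W) float32 relative depth map
print(depth.shape, float(depth.min()), float(depth.max()))
```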
================================================
FILE: camera_pose_annotation/depth_estimation/Depth-Anything/depth_anything_v2/util/blocks.py
================================================
import torch.nn as nn
def _make_scratch(in_shape, out_shape, groups=1, expand=False):
scratch = nn.Module()
out_shape1 = out_shape
out_shape2 = out_shape
out_shape3 = out_shape
if len(in_shape) >= 4:
out_shape4 = out_shape
if expand:
out_shape1 = out_shape
out_shape2 = out_shape * 2
out_shape3 = out_shape * 4
if len(in_shape) >= 4:
out_shape4 = out_shape * 8
scratch.layer1_rn = nn.Conv2d(in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups)
scratch.layer2_rn = nn.Conv2d(in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups)
scratch.layer3_rn = nn.Conv2d(in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups)
if len(in_shape) >= 4:
scratch.layer4_rn = nn.Conv2d(in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups)
return scratch
class ResidualConvUnit(nn.Module):
"""Residual convolution module.
"""
def __init__(self, features, activation, bn):
"""Init.
Args:
features (int): number of features
"""
super().__init__()
self.bn = bn
self.groups=1
self.conv1 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
self.conv2 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
if self.bn == True:
self.bn1 = nn.BatchNorm2d(features)
self.bn2 = nn.BatchNorm2d(features)
self.activation = activation
self.skip_add = nn.quantized.FloatFunctional()
def forward(self, x):
"""Forward pass.
Args:
x (tensor): input
Returns:
tensor: output
"""
out = self.activation(x)
out = self.conv1(out)
if self.bn == True:
out = self.bn1(out)
out = self.activation(out)
out = self.conv2(out)
if self.bn == True:
out = self.bn2(out)
if self.groups > 1:
out = self.conv_merge(out)
return self.skip_add.add(out, x)
class FeatureFusionBlock(nn.Module):
"""Feature fusion block.
"""
def __init__(
self,
features,
activation,
deconv=False,
bn=False,
expand=False,
align_corners=True,
size=None
):
"""Init.
Args:
features (int): number of features
"""
super(FeatureFusionBlock, self).__init__()
self.deconv = deconv
self.align_corners = align_corners
self.groups=1
self.expand = expand
out_features = features
if self.expand == True:
out_features = features // 2
self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)
self.resConfUnit1 = ResidualConvUnit(features, activation, bn)
self.resConfUnit2 = ResidualConvUnit(features, activation, bn)
self.skip_add = nn.quantized.FloatFunctional()
self.size=size
def forward(self, *xs, size=None):
"""Forward pass.
Returns:
tensor: output
"""
output = xs[0]
if len(xs) == 2:
res = self.resConfUnit1(xs[1])
output = self.skip_add.add(output, res)
output = self.resConfUnit2(output)
if (size is None) and (self.size is None):
modifier = {"scale_factor": 2}
elif size is None:
modifier = {"size": self.size}
else:
modifier = {"size": size}
output = nn.functional.interpolate(output, **modifier, mode="bilinear", align_corners=self.align_corners)
output = self.out_conv(output)
return output
================================================
FILE: camera_pose_annotation/depth_estimation/Depth-Anything/depth_anything_v2/util/transform.py
================================================
import numpy as np
import cv2
class Resize(object):
"""Resize sample to given size (width, height).
"""
def __init__(
self,
width,
height,
resize_target=True,
keep_aspect_ratio=False,
ensure_multiple_of=1,
resize_method="lower_bound",
image_interpolation_method=cv2.INTER_AREA,
):
"""Init.
Args:
width (int): desired output width
height (int): desired output height
resize_target (bool, optional):
True: Resize the full sample (image, mask, target).
False: Resize image only.
Defaults to True.
keep_aspect_ratio (bool, optional):
True: Keep the aspect ratio of the input sample.
Output sample might not have the given width and height, and
resize behaviour depends on the parameter 'resize_method'.
Defaults to False.
ensure_multiple_of (int, optional):
Output width and height is constrained to be multiple of this parameter.
Defaults to 1.
resize_method (str, optional):
"lower_bound": Output will be at least as large as the given size.
"upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.)
"minimal": Scale as least as possible. (Output size might be smaller than given size.)
Defaults to "lower_bound".
"""
self.__width = width
self.__height = height
self.__resize_target = resize_target
self.__keep_aspect_ratio = keep_aspect_ratio
self.__multiple_of = ensure_multiple_of
self.__resize_method = resize_method
self.__image_interpolation_method = image_interpolation_method
def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)
if max_val is not None and y > max_val:
y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)
if y < min_val:
y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)
return y
def get_size(self, width, height):
# determine new height and width
scale_height = self.__height / height
scale_width = self.__width / width
if self.__keep_aspect_ratio:
if self.__resize_method == "lower_bound":
# scale such that output size is lower bound
if scale_width > scale_height:
# fit width
scale_height = scale_width
else:
# fit height
scale_width = scale_height
elif self.__resize_method == "upper_bound":
# scale such that output size is upper bound
if scale_width < scale_height:
# fit width
scale_height = scale_width
else:
# fit height
scale_width = scale_height
elif self.__resize_method == "minimal":
# scale as least as possbile
if abs(1 - scale_width) < abs(1 - scale_height):
# fit width
scale_height = scale_width
else:
# fit height
scale_width = scale_height
else:
raise ValueError(f"resize_method {self.__resize_method} not implemented")
if self.__resize_method == "lower_bound":
new_height = self.constrain_to_multiple_of(scale_height * height, min_val=self.__height)
new_width = self.constrain_to_multiple_of(scale_width * width, min_val=self.__width)
elif self.__resize_method == "upper_bound":
new_height = self.constrain_to_multiple_of(scale_height * height, max_val=self.__height)
new_width = self.constrain_to_multiple_of(scale_width * width, max_val=self.__width)
elif self.__resize_method == "minimal":
new_height = self.constrain_to_multiple_of(scale_height * height)
new_width = self.constrain_to_multiple_of(scale_width * width)
else:
raise ValueError(f"resize_method {self.__resize_method} not implemented")
return (new_width, new_height)
def __call__(self, sample):
width, height = self.get_size(sample["image"].shape[1], sample["image"].shape[0])
# resize sample
sample["image"] = cv2.resize(sample["image"], (width, height), interpolation=self.__image_interpolation_method)
if self.__resize_target:
if "depth" in sample:
sample["depth"] = cv2.resize(sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST)
if "mask" in sample:
sample["mask"] = cv2.resize(sample["mask"].astype(np.float32), (width, height), interpolation=cv2.INTER_NEAREST)
return sample
class NormalizeImage(object):
"""Normlize image by given mean and std.
"""
def __init__(self, mean, std):
self.__mean = mean
self.__std = std
def __call__(self, sample):
sample["image"] = (sample["image"] - self.__mean) / self.__std
return sample
class PrepareForNet(object):
"""Prepare sample for usage as network input.
"""
def __init__(self):
pass
def __call__(self, sample):
image = np.transpose(sample["image"], (2, 0, 1))
sample["image"] = np.ascontiguousarray(image).astype(np.float32)
if "depth" in sample:
depth = sample["depth"].astype(np.float32)
sample["depth"] = np.ascontiguousarray(depth)
if "mask" in sample:
sample["mask"] = sample["mask"].astype(np.float32)
sample["mask"] = np.ascontiguousarray(sample["mask"])
return sample
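A worked example of the `Resize` policy the depth scripts below rely on (`lower_bound`, aspect ratio kept, sides snapped to multiples of 14); the 1280x720 input is illustrative.
```python
resize = Resize(
    width=518,
    height=518,
    resize_target=False,
    keep_aspect_ratio=True,
    ensure_multiple_of=14,
    resize_method="lower_bound",
)
print(resize.get_size(1280, 720))   # (924, 518): shorter side snapped up to 518, both sides multiples of 14
```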
================================================
FILE: camera_pose_annotation/depth_estimation/Depth-Anything/inference.py
================================================
"""
Single-threaded inference script for Depth-Anything V2 model.
Processes images in a directory to generate depth maps sequentially.
"""
import argparse
import cv2
import glob
import numpy as np
import os
import torch
from depth_anything_v2.dpt import DepthAnythingV2
# Model configuration for different encoder variants
model_configs = {
"vits": {"encoder": "vits", "features": 64, "out_channels": [48, 96, 192, 384]},
"vitb": {"encoder": "vitb", "features": 128, "out_channels": [96, 192, 384, 768]},
"vitl": {
"encoder": "vitl",
"features": 256,
"out_channels": [256, 512, 1024, 1024],
},
"vitg": {
"encoder": "vitg",
"features": 384,
"out_channels": [1536, 1536, 1536, 1536],
},
}
def parse_args():
"""Parse command line arguments for depth estimation."""
parser = argparse.ArgumentParser(description="Depth Anything V2")
parser.add_argument("--input-size", type=int, default=518)
parser.add_argument("--dir_path", type=str, default="./vis_depth")
parser.add_argument(
"--encoder", type=str, default="vitl", choices=["vits", "vitb", "vitl", "vitg"]
)
parser.add_argument(
"--load-from",
type=str,
default="checkpoints/Depth-Anything/depth_anything_v2_vitl.pth",
)
return parser.parse_args()
if __name__ == "__main__":
args = parse_args()
# Auto-detect best available device
DEVICE = (
"cuda"
if torch.cuda.is_available()
else "mps" if torch.backends.mps.is_available() else "cpu"
)
print(f"Using device: {DEVICE}")
# Initialize Depth-Anything V2 model
depth_anything = DepthAnythingV2(**model_configs[args.encoder])
depth_anything.load_state_dict(torch.load(args.load_from, map_location="cpu"))
depth_anything = depth_anything.to(DEVICE).eval()
# Setup input and output paths
img_path = os.path.join(args.dir_path, "img")
out_path = os.path.join(args.dir_path, "depth-anything")
if not os.path.exists(out_path):
os.makedirs(out_path)
# Collect all image files
img_list = sorted(glob.glob(os.path.join(img_path, "*.jpg")))
img_list += sorted(glob.glob(os.path.join(img_path, "*.png")))
# Process each image sequentially
for k, img in enumerate(img_list):
print(f"Progress {k+1}/{len(img_list)}: {img}")
# Load and process image
raw_image = cv2.imread(img)
# Generate depth map
depth = depth_anything.infer_image(raw_image, args.input_size)
# Save depth map as numpy array
output_path = os.path.join(
out_path, os.path.splitext(os.path.basename(img))[0] + ".npy"
)
np.save(output_path, depth)
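A hedged sketch of consuming the output written above; the frame file name is hypothetical, only the `<dir_path>/depth-anything/<frame>.npy` layout follows the script.
```python
import numpy as np

depth = np.load("path/to/clip/depth-anything/000001.npy")   # hypothetical frame name
print(depth.shape, depth.dtype)                             # (H, W) float32 relative depth map
```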
================================================
FILE: camera_pose_annotation/depth_estimation/Depth-Anything/inference_batch.py
================================================
"""
Distributed batch inference script for Depth-Anything V2 model.
Processes video frames to generate depth maps using distributed computing.
"""
import argparse
from datetime import timedelta
import cv2
import glob
import numpy as np
import pandas as pd
import os
import torch
import torch.distributed as dist
from torch.utils.data import Dataset, DataLoader, DistributedSampler
from torchvision.transforms import Compose
import torch.nn.functional as F
from torchvision.transforms import ToTensor
from tqdm import tqdm
from depth_anything_v2.util.transform import Resize, NormalizeImage, PrepareForNet
from depth_anything_v2.dpt import DepthAnythingV2
# Model configuration for different encoder variants
model_configs = {
"vits": {"encoder": "vits", "features": 64, "out_channels": [48, 96, 192, 384]},
"vitb": {"encoder": "vitb", "features": 128, "out_channels": [96, 192, 384, 768]},
"vitl": {
"encoder": "vitl",
"features": 256,
"out_channels": [256, 512, 1024, 1024],
},
"vitg": {
"encoder": "vitg",
"features": 384,
"out_channels": [1536, 1536, 1536, 1536],
},
}
class ImageDataset(Dataset):
"""Dataset for loading and preprocessing images for depth estimation."""
def __init__(self, img_list, input_size):
self.img_list = img_list
self.input_size = input_size
self.transform = Compose(
[
Resize(
width=input_size,
height=input_size,
resize_target=False,
keep_aspect_ratio=True,
ensure_multiple_of=14,
resize_method="lower_bound",
image_interpolation_method=cv2.INTER_CUBIC,
),
NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
PrepareForNet(),
]
)
def __len__(self):
return len(self.img_list)
def image2tensor(self, raw_image):
"""Convert raw image to tensor format for model input."""
h, w = raw_image.shape[:2]
image = cv2.cvtColor(raw_image, cv2.COLOR_BGR2RGB) / 255.0
image = self.transform({"image": image})["image"]
image = torch.from_numpy(image)
return image, (h, w)
def __getitem__(self, idx):
"""Load and preprocess a single image with error handling."""
def inner_func(idx):
img_path = self.img_list[idx]
raw_image = cv2.imread(img_path)
image, (original_h, original_w) = self.image2tensor(raw_image)
data = {
"image": image,
"path": img_path,
"original_size": (original_h, original_w),
}
return data
while True:
try:
return inner_func(idx)
except Exception as e:
print(f"e: [{e}], path: {self.img_list[idx]}, try to get next idx")
idx += 1
if idx >= len(self.img_list):
raise StopIteration
def parse_args():
"""Parse command line arguments for depth estimation."""
parser = argparse.ArgumentParser(
description="Depth Anything V2 Distributed Inference"
)
parser.add_argument("--csv_path", type=str, help="Path to the csv file")
parser.add_argument("--input-size", type=int, default=518)
parser.add_argument("--output_dir", type=str, default="./output")
parser.add_argument(
"--encoder", type=str, default="vitl", choices=["vits", "vitb", "vitl", "vitg"]
)
parser.add_argument("--checkpoints_path", type=str, default="./checkpoints")
parser.add_argument("--bs", type=int, default=8, help="Batch size for inference")
parser.add_argument(
"--num_workers", type=int, default=4, help="Number of data loading workers"
)
return parser.parse_args()
def collate_fn(batch):
"""Custom collate function for batching data."""
return_batch = {}
for key in batch[0].keys():
if key == "image":
return_batch[key] = torch.stack([item[key] for item in batch], dim=0)
else:
return_batch[key] = [item[key] for item in batch]
return return_batch
def main():
args = parse_args()
# Initialize distributed environment
dist.init_process_group(backend="nccl", timeout=timedelta(hours=24))
local_rank = dist.get_rank()
torch.cuda.set_device(local_rank)
DEVICE = f"cuda:{local_rank}"
# Load data list from CSV
df = pd.read_csv(args.csv_path)
img_list = []
for index, row in tqdm(
df.iterrows(), total=len(df), desc="Loading images", disable=(local_rank != 0)
):
img_dir = os.path.join(args.output_dir, row["id"], "img")
if not os.path.exists(img_dir):
print(f"Image directory not found: {img_dir}")
continue
img_list += sorted(glob.glob(os.path.join(img_dir, "*.jpg")))
img_list += sorted(glob.glob(os.path.join(img_dir, "*.png")))
# Create dataset and distributed sampler
dataset = ImageDataset(img_list, args.input_size)
sampler = DistributedSampler(
dataset,
num_replicas=dist.get_world_size(),
rank=local_rank,
shuffle=False,
drop_last=False,
)
dataloader = DataLoader(
dataset,
batch_size=args.bs,
sampler=sampler,
num_workers=args.num_workers,
pin_memory=True,
collate_fn=collate_fn,
)
# Initialize Depth-Anything V2 model
depth_anything = DepthAnythingV2(**model_configs[args.encoder])
load_from = os.path.join(
args.checkpoints_path, f"Depth-Anything/depth_anything_v2_{args.encoder}.pth"
)
depth_anything.load_state_dict(torch.load(load_from, map_location="cpu"))
depth_anything = depth_anything.to(DEVICE).eval()
# Run inference and save depth maps
with torch.no_grad():
for batch in tqdm(
dataloader, desc="Depth inference", disable=(local_rank != 0)
):
images = batch["image"].to(DEVICE)
original_sizes = batch["original_size"]
paths = batch["path"]
# Forward pass through depth model
depth = depth_anything(images)
# Upsample to original image size
original_h, original_w = original_sizes[0]
depth = F.interpolate(
depth[:, None],
size=(original_h, original_w),
mode="bilinear",
align_corners=False,
)
# Save depth maps as numpy arrays
for i in range(depth.shape[0]):
depth_i = depth[i, 0].cpu().numpy()
img_path = paths[i]
output_filename = (
os.path.splitext(os.path.basename(img_path))[0] + ".npy"
)
output_dir = os.path.join(
os.path.dirname(os.path.dirname(img_path)), "depth-anything"
)
os.makedirs(output_dir, exist_ok=True)
output_path = os.path.join(output_dir, output_filename)
np.save(output_path, depth_i)
dist.destroy_process_group()
if __name__ == "__main__":
main()
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/__init__.py
================================================
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/inference.py
================================================
"""
Single-threaded inference script for UniDepth V2 model.
Processes images in a directory to generate depth maps and camera parameters sequentially.
"""
import argparse
import glob
import os
import cv2
import numpy as np
from PIL import Image
import torch
from unidepth.models import UniDepthV2
# Maximum dimension for image resizing
LONG_DIM = 640
def parse_args():
"""Parse command line arguments for UniDepth inference."""
parser = argparse.ArgumentParser()
parser.add_argument("--dir_path", type=str, default="./vis_depth")
parser.add_argument("--load-from", type=str, default="checkpoints/UniDepth")
return parser.parse_args()
def main():
args = parse_args()
# Initialize UniDepth V2 model
model = UniDepthV2.from_pretrained(args.load_from)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
# Setup input and output paths
img_path = os.path.join(args.dir_path, "img")
out_path = os.path.join(args.dir_path, "unidepth")
if not os.path.exists(out_path):
os.makedirs(out_path)
# Collect all image files
img_list = sorted(glob.glob(os.path.join(img_path, "*.jpg")))
img_list += sorted(glob.glob(os.path.join(img_path, "*.png")))
fovs = []
# Process each image sequentially
for img_path in img_list:
# Load and preprocess image
rgb = np.array(Image.open(img_path))[..., :3]
# Calculate target size maintaining aspect ratio
if rgb.shape[1] > rgb.shape[0]:
final_w, final_h = LONG_DIM, int(
round(LONG_DIM * rgb.shape[0] / rgb.shape[1])
)
else:
final_w, final_h = (
int(round(LONG_DIM * rgb.shape[1] / rgb.shape[0])),
LONG_DIM,
)
        rgb = cv2.resize(rgb, (final_w, final_h), interpolation=cv2.INTER_AREA)
# Convert to tensor format
rgb_torch = torch.from_numpy(rgb).permute(2, 0, 1)
# Predict depth and intrinsics
predictions = model.infer(rgb_torch)
# Calculate FOV (horizontal field of view) from predicted intrinsics
fov_ = np.rad2deg(
2
* np.arctan(
predictions["depth"].shape[-1]
/ (2 * predictions["intrinsics"][0, 0, 0].cpu().numpy())
)
)
depth = predictions["depth"][0, 0].cpu().numpy()
print(fov_)
fovs.append(fov_)
# Save depth map and FOV
np.savez(
os.path.join(out_path, img_path.split("/")[-1][:-4] + ".npz"),
depth=np.float32(depth),
fov=fov_,
)
if __name__ == "__main__":
main()
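A hedged sketch (the frame file name is an assumption) showing how the horizontal FOV stored in the `.npz` above can be turned back into a focal length in pixels, inverting the `fov = 2 * arctan(W / (2 * fx))` computation used at save time.
```python
import numpy as np

data = np.load("path/to/clip/unidepth/000001.npz")       # hypothetical frame name
depth, fov_deg = data["depth"], float(data["fov"])
fx = depth.shape[-1] / (2.0 * np.tan(np.deg2rad(fov_deg) / 2.0))   # fx = W / (2 * tan(fov / 2))
print(depth.shape, f"fov = {fov_deg:.1f} deg, fx = {fx:.1f} px")
```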
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/inference_batch.py
================================================
"""
Distributed batch inference script for UniDepth V2 model.
Processes video frames to generate depth maps and camera intrinsics using distributed computing.
"""
import argparse
from datetime import timedelta
import glob
import os
import cv2
import pandas as pd
import numpy as np
from PIL import Image
import torch
import torch.distributed as dist
from torch.utils.data import Dataset, DataLoader, DistributedSampler
from tqdm import tqdm
from unidepth.models import UniDepthV2
class ImageDataset(Dataset):
"""Dataset for loading and preprocessing images for UniDepth inference."""
def __init__(self, img_list, input_size):
self.img_list = img_list
self.input_size = input_size
def __len__(self):
return len(self.img_list)
def __getitem__(self, idx):
"""Load and preprocess a single image with error handling."""
def inner_func(idx):
img_path = self.img_list[idx]
rgb = np.array(Image.open(img_path))[..., :3]
h, w = rgb.shape[:2]
# Calculate target size maintaining aspect ratio
if w > h:
final_w, final_h = self.input_size, int(round(self.input_size * h / w))
else:
final_w, final_h = int(round(self.input_size * w / h)), self.input_size
            rgb_resized = cv2.resize(rgb, (final_w, final_h), interpolation=cv2.INTER_AREA)
rgb_torch = (
torch.from_numpy(rgb_resized).permute(2, 0, 1).float()
) # Convert to CHW format
return {
"image": rgb_torch,
"path": img_path,
}
while True:
try:
return inner_func(idx)
except Exception as e:
print(f"e: [{e}], path: {self.img_list[idx]}, try to get next idx")
                idx += 1
                if idx >= len(self.img_list):
                    raise StopIteration
def collate_fn(batch):
"""Custom collate function for batching data."""
return_batch = {}
for key in batch[0].keys():
if key == "image":
return_batch[key] = torch.stack([item[key] for item in batch], dim=0)
else:
return_batch[key] = [item[key] for item in batch]
return return_batch
def parse_args():
"""Parse command line arguments for UniDepth inference."""
parser = argparse.ArgumentParser()
parser.add_argument("--csv_path", type=str, help="Path to the csv file")
parser.add_argument("--output_dir", type=str, default="./output")
parser.add_argument("--checkpoints_path", type=str, default="./checkpoints")
parser.add_argument(
"--input_size", type=int, default=640, help="Input size for the model"
)
parser.add_argument("--bs", type=int, default=8, help="Inference batch size")
parser.add_argument(
"--num_workers", type=int, default=4, help="Data loading workers"
)
return parser.parse_args()
def main():
args = parse_args()
# Initialize distributed environment
dist.init_process_group(backend="nccl", timeout=timedelta(hours=24))
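    # NOTE: assumes a single-node run, so the global process rank doubles as the local CUDA device index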
local_rank = dist.get_rank()
torch.cuda.set_device(local_rank)
DEVICE = f"cuda:{local_rank}"
# Load data list from CSV
df = pd.read_csv(args.csv_path)
img_list = []
for index, row in tqdm(
df.iterrows(), total=len(df), desc="Loading images", disable=local_rank != 0
):
img_dir = os.path.join(args.output_dir, row["id"], "img")
if not os.path.exists(img_dir):
print(f"Image directory not found: {img_dir}")
continue
img_list += sorted(glob.glob(os.path.join(img_dir, "*.jpg")))
img_list += sorted(glob.glob(os.path.join(img_dir, "*.png")))
# Create dataset and distributed sampler
dataset = ImageDataset(img_list, args.input_size)
sampler = DistributedSampler(
dataset,
num_replicas=dist.get_world_size(),
rank=local_rank,
shuffle=False,
drop_last=False,
)
dataloader = DataLoader(
dataset,
batch_size=args.bs,
sampler=sampler,
num_workers=args.num_workers,
pin_memory=True,
collate_fn=collate_fn,
)
# Initialize UniDepth V2 model
load_from = os.path.join(args.checkpoints_path, "UniDepth")
model = UniDepthV2.from_pretrained(load_from)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device).eval()
# Run inference and save results
with torch.no_grad():
for batch in tqdm(
dataloader, desc="Processing batches", disable=(local_rank != 0)
):
images = batch["image"].to(device)
paths = batch["path"]
# Model inference
predictions = model.infer(images)
# Process results for each sample
for i in range(len(paths)):
depth = predictions["depth"][i, 0].cpu().numpy() # [H, W]
intrinsics = predictions["intrinsics"][i].cpu().numpy()
focal_length = intrinsics[
0, 0
] # Assume principal point at center, take fx
w = depth.shape[-1] # Width
# Calculate FOV (horizontal field of view)
fov = np.rad2deg(2 * np.arctan(w / (2 * focal_length)))
# Save results
img_path = paths[i]
output_filename = (
os.path.splitext(os.path.basename(img_path))[0] + ".npz"
)
output_dir = os.path.join(
os.path.dirname(os.path.dirname(img_path)), "unidepth"
)
os.makedirs(output_dir, exist_ok=True)
output_path = os.path.join(output_dir, output_filename)
np.savez(output_path, depth=np.float32(depth), fov=fov)
if __name__ == "__main__":
main()
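# Illustrative launch command (assumed; adjust the GPU count and paths to your setup).
# The script relies on the environment variables set by torchrun for
# dist.init_process_group(backend="nccl"):
#   torchrun --nproc_per_node=4 inference_batch.py \
#       --csv_path clips.csv --output_dir ./output --checkpoints_path ./checkpoints --bs 8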
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/_2d3ds.py
================================================
from typing import Any
import torch
from unidepth.datasets.pipelines import Compose, PanoCrop, PanoRoll
from unidepth.datasets.sequence_dataset import SequenceDataset
class _2D3DS(SequenceDataset):
min_depth = 0.01
max_depth = 10.0
depth_scale = 512.0
test_split = "train.txt"
train_split = "train.txt"
sequences_file = "sequences.json"
hdf5_paths = [f"2D3DS.hdf5"]
def __init__(
self,
image_shape: tuple[int, int],
split_file: str,
test_mode: bool,
normalize: bool,
augmentations_db: dict[str, Any],
resize_method: str,
mini: float = 1.0,
num_frames: int = 1,
benchmark: bool = False,
decode_fields: list[str] = ["image", "depth"],
inplace_fields: list[str] = ["cam2w", "camera_params"],
**kwargs,
) -> None:
super().__init__(
image_shape=image_shape,
split_file=split_file,
test_mode=test_mode,
benchmark=benchmark,
normalize=normalize,
augmentations_db=augmentations_db,
resize_method=resize_method,
mini=mini,
num_frames=num_frames,
decode_fields=decode_fields,
inplace_fields=inplace_fields,
**kwargs,
)
self.resizer = Compose(
[PanoCrop(), PanoRoll(test_mode=test_mode), self.resizer]
)
def preprocess(self, results):
self.resizer.ctx = None
if self.test_mode:
for i, seq in enumerate(results["sequence_fields"]):
results[seq]["points"] = results[seq]["camera"].reconstruct(
results[seq]["depth"]
)
results[seq]["depth"] = results[seq]["points"][:, -1:]
results[seq]["gt_fields"].add("points")
return super().preprocess(results)
def pre_pipeline(self, results):
results = super().pre_pipeline(results)
results["dense"] = [True] * self.num_frames * self.num_copies
results["synthetic"] = [False] * self.num_frames * self.num_copies
results["quality"] = [1] * self.num_frames * self.num_copies
return results
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/_4dor.py
================================================
from typing import Any
from unidepth.datasets.sequence_dataset import SequenceDataset
class _4DOR(SequenceDataset):
min_depth = 0.01
max_depth = 10.0
depth_scale = 1000.0
default_fps = 10
test_split = "train.txt"
train_split = "train.txt"
sequences_file = "sequences.json"
hdf5_paths = ["4DOR.hdf5"]
def __init__(
self,
image_shape: tuple[int, int],
split_file: str,
test_mode: bool,
normalize: bool,
augmentations_db: dict[str, Any],
resize_method: str,
mini: float = 1.0,
num_frames: int = 1,
benchmark: bool = False,
decode_fields: list[str] = ["image", "depth"],
inplace_fields: list[str] = ["camera_params", "cam2w"],
**kwargs,
) -> None:
super().__init__(
image_shape=image_shape,
split_file=split_file,
test_mode=test_mode,
benchmark=benchmark,
normalize=normalize,
augmentations_db=augmentations_db,
resize_method=resize_method,
mini=mini,
num_frames=num_frames,
decode_fields=decode_fields,
inplace_fields=inplace_fields,
**kwargs,
)
def pre_pipeline(self, results):
results = super().pre_pipeline(results)
results["dense"] = [True] * self.num_frames * self.num_copies
results["synthetic"] = [False] * self.num_frames * self.num_copies
results["si"] = [False] * self.num_frames * self.num_copies
results["quality"] = [2] * self.num_frames * self.num_copies
return results
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/__init__.py
================================================
from ._2d3ds import _2D3DS
from ._4dor import _4DOR
from .a2d2 import A2D2
from .adt import ADT
from .aimotive import aiMotive
from .argoverse import Argoverse
from .argoverse2 import Argoverse2
from .arkit import ARKit
from .ase import ASE
from .base_dataset import BaseDataset
from .bdd import BDD
from .bedlam import BEDLAM
from .behave import Behave
from .blendedmvg import BlendedMVG
from .cityscape import Cityscape
from .ddad import DDAD
from .deep360 import Deep360
from .dense import DENSE
from .diode import DiodeIndoor, DiodeIndoor_F
from .dl3dv import DL3DV
from .driving_stereo import DrivingStereo
from .dtu_rmvd import DTURMVD
from .dummy import Dummy
from .dynamic_replica import DynReplica
from .eden import EDEN
from .eth3d import ETH3D, ETH3D_F
from .eth3d_rmvd import ETH3DRMVD
from .facedepth import FaceDepth
from .flsea import FLSea
from .futurehouse import FutureHouse
from .gibson import Gibson
from .hammer import HAMMER
from .hm3d import HM3D
from .hoi4d import HOI4D
from .hypersim import HyperSim
from .ibims import IBims, IBims_F
from .image_dataset import ImageDataset
from .ken_burns import KenBurns
from .kitti import KITTI, KITTIBenchmark
from .kitti360 import KITTI360
from .kitti_multi import KITTIMulti
from .kitti_rmvd import KITTIRMVD
from .lyft import Lyft
from .mapillary import Mapillary
from .matrix_city import MatrixCity
from .matterport3d import Matterport3D
from .megadepth import MegaDepth
from .megadepth_s import MegaDepthS
from .midair import MidAir
from .mip import MIP
from .ms2 import MS2
from .mvimgnet import MVImgNet
from .mvsynth import MVSynth
from .nerds360 import NeRDS360
from .niantic_mapfree import NianticMapFree
from .nuscenes import Nuscenes
from .nyuv2 import NYUv2Depth
from .point_odyssey import PointOdyssey
from .proteus import Proteus
from .samplers import DistributedSamplerNoDuplicate
from .scannet import ScanNet
from .scannetpp import ScanNetpp, ScanNetpp_F
from .sequence_dataset import SequenceDataset
from .sintel import Sintel
from .sunrgbd import SUNRGBD
from .synscapes import Synscapes
from .tartanair import TartanAir
from .taskonomy import Taskonomy
from .tat_rmvd import TATRMVD
from .theo import Theo
from .unrealstereo4k import UnrealStereo4K
from .urbansyn import UrbanSyn
from .utils import ConcatDataset, collate_fn, get_weights
from .vkitti import VKITTI
from .void import VOID
from .waymo import Waymo
from .wildrgbd import WildRGBD
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/a2d2.py
================================================
import json
import os
import h5py
import numpy as np
import torch
from unidepth.datasets.image_dataset import ImageDataset
from unidepth.datasets.utils import DatasetFromList
class A2D2(ImageDataset):
min_depth = 0.01
max_depth = 120.0
depth_scale = 256.0
train_split = "train_clean.txt"
intrisics_file = "intrinsics.json"
hdf5_paths = ["a2d2.hdf5"]
def __init__(
self,
image_shape,
split_file,
test_mode,
crop=None,
benchmark=False,
augmentations_db={},
normalize=True,
resize_method="hard",
mini=1.0,
**kwargs,
):
super().__init__(
image_shape=image_shape,
split_file=split_file,
test_mode=test_mode,
benchmark=benchmark,
normalize=normalize,
augmentations_db=augmentations_db,
resize_method=resize_method,
mini=mini,
**kwargs,
)
self.test_mode = test_mode
self.load_dataset()
def load_dataset(self):
h5file = h5py.File(
os.path.join(self.data_root, self.hdf5_paths[0]),
"r",
libver="latest",
swmr=True,
)
txt_file = np.array(h5file[self.split_file])
txt_string = txt_file.tostring().decode("ascii")[:-1] # correct the -1
intrinsics = np.array(h5file[self.intrisics_file]).tostring().decode("ascii")
intrinsics = json.loads(intrinsics)
h5file.close()
dataset = []
for line in txt_string.split("\n"):
image_filename, depth_filename = line.strip().split(" ")
intrinsics_val = torch.tensor(
intrinsics[os.path.join(*image_filename.split("/")[:2])]
).squeeze()[:, :3]
sample = [image_filename, depth_filename, intrinsics_val]
dataset.append(sample)
# if not self.test_mode:
# dataset = self.chunk(dataset, chunk_dim=1, pct=self.mini)
self.dataset = DatasetFromList(dataset)
self.log_load_dataset()
def pre_pipeline(self, results):
results = super().pre_pipeline(results)
results["dense"] = [False] * self.num_copies
results["quality"] = [1] * self.num_copies
return results
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/adt.py
================================================
from typing import Any
import torch
from unidepth.datasets.sequence_dataset import SequenceDataset
class ADT(SequenceDataset):
min_depth = 0.01
max_depth = 20.0
depth_scale = 1000.0
test_split = "val.txt"
train_split = "train.txt"
sequences_file = "sequences.json"
hdf5_paths = [f"ADT.hdf5"]
def __init__(
self,
image_shape: tuple[int, int],
split_file: str,
test_mode: bool,
normalize: bool,
augmentations_db: dict[str, Any],
resize_method: str,
mini: float,
num_frames: int = 1,
benchmark: bool = False,
decode_fields: list[str] = ["image", "depth"],
inplace_fields: list[str] = ["camera_params", "cam2w"],
**kwargs,
) -> None:
super().__init__(
image_shape=image_shape,
split_file=split_file,
test_mode=test_mode,
benchmark=benchmark,
normalize=normalize,
augmentations_db=augmentations_db,
resize_method=resize_method,
mini=mini,
num_frames=num_frames,
decode_fields=decode_fields, # if not test_mode else [*decode_fields, "points"],
inplace_fields=inplace_fields,
**kwargs,
)
def preprocess(self, results):
self.resizer.ctx = None
for i, seq in enumerate(results["sequence_fields"]):
# Create a mask where the distance from the center is less than H/2
H, W = results[seq]["image"].shape[-2:]
x = torch.linspace(-W / 2 - 0.5, W / 2 + 0.5, W)
y = torch.linspace(-H / 2 - 0.5, H / 2 + 0.5, H)
xv, yv = torch.meshgrid(x, y, indexing="xy")
distance_from_center = torch.sqrt(xv**2 + yv**2).reshape(1, 1, H, W)
results[seq]["validity_mask"] = distance_from_center < (H / 2) + 20
results[seq]["depth_mask"] = results[seq]["validity_mask"].clone()
results[seq]["mask_fields"].add("depth_mask")
results[seq]["mask_fields"].add("validity_mask")
return super().preprocess(results)
def pre_pipeline(self, results):
results = super().pre_pipeline(results)
results["dense"] = [True] * self.num_frames * self.num_copies
results["synthetic"] = [True] * self.num_frames * self.num_copies
results["quality"] = [0] * self.num_frames * self.num_copies
return results
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/aimotive.py
================================================
from typing import Any
from unidepth.datasets.sequence_dataset import SequenceDataset
class aiMotive(SequenceDataset):
min_depth = 0.01
max_depth = 100.0
depth_scale = 256.0
default_fps = 10
test_split = "train.txt"
train_split = "train.txt"
sequences_file = "sequences.json"
hdf5_paths = ["aiMotive.hdf5"]
def __init__(
self,
image_shape: tuple[int, int],
split_file: str,
test_mode: bool,
normalize: bool,
augmentations_db: dict[str, Any],
resize_method: str,
mini: float = 1.0,
num_frames: int = 1,
benchmark: bool = False,
decode_fields: list[str] = ["image", "depth"],
inplace_fields: list[str] = ["camera_params", "cam2w"],
**kwargs,
) -> None:
super().__init__(
image_shape=image_shape,
split_file=split_file,
test_mode=test_mode,
benchmark=benchmark,
normalize=normalize,
augmentations_db=augmentations_db,
resize_method=resize_method,
mini=mini,
num_frames=num_frames,
decode_fields=decode_fields,
inplace_fields=inplace_fields,
**kwargs,
)
def pre_pipeline(self, results):
results = super().pre_pipeline(results)
results["dense"] = [False] * self.num_frames * self.num_copies
results["synthetic"] = [False] * self.num_frames * self.num_copies
results["quality"] = [2] * self.num_frames * self.num_copies
return results
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/argoverse.py
================================================
import json
import os
import h5py
import numpy as np
import torch
from unidepth.datasets.image_dataset import ImageDataset
from unidepth.datasets.utils import DatasetFromList
class Argoverse(ImageDataset):
min_depth = 0.05
max_depth = 120.0
depth_scale = 256.0
test_split = "argo_val.txt"
train_split = "argo_train.txt"
intrisics_file = "argo_intrinsics.json"
hdf5_paths = ["argoverse11.hdf5"]
def __init__(
self,
image_shape,
split_file,
test_mode,
crop=None,
benchmark=False,
augmentations_db={},
normalize=True,
resize_method="hard",
mini=1.0,
**kwargs,
):
super().__init__(
image_shape=image_shape,
split_file=split_file,
test_mode=test_mode,
benchmark=benchmark,
normalize=normalize,
augmentations_db=augmentations_db,
resize_method=resize_method,
mini=mini,
**kwargs,
)
self.test_mode = test_mode
self.crop = crop
self.load_dataset()
def load_dataset(self):
h5file = h5py.File(
os.path.join(self.data_root, self.hdf5_paths[0]),
"r",
libver="latest",
swmr=True,
)
txt_file = np.array(h5file[self.split_file])
txt_string = txt_file.tostring().decode("ascii")[:-1] # correct the -1
intrinsics = np.array(h5file[self.intrisics_file]).tostring().decode("ascii")
intrinsics = json.loads(intrinsics)
h5file.close()
dataset = []
for line in txt_string.split("\n"):
image_filename, depth_filename = line.strip().split(" ")
intrinsics_val = torch.tensor(intrinsics[image_filename]).squeeze()[:, :3]
sample = [image_filename, depth_filename, intrinsics_val]
dataset.append(sample)
if not self.test_mode:
dataset = self.chunk(dataset, chunk_dim=1, pct=self.mini)
self.dataset = DatasetFromList(dataset)
self.log_load_dataset()
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/argoverse2.py
================================================
from typing import Any
from unidepth.datasets.sequence_dataset import SequenceDataset
class Argoverse2(SequenceDataset):
min_depth = 0.05
max_depth = 120.0
depth_scale = 256.0
test_split = "val.txt"
train_split = "train.txt"
sequences_file = "sequences_clean.json"
hdf5_paths = [f"AV2_viz.hdf5"]
def __init__(
self,
image_shape: tuple[int, int],
split_file: str,
test_mode: bool,
normalize: bool,
augmentations_db: dict[str, Any],
resize_method: str,
mini: float = 1.0,
num_frames: int = 1,
benchmark: bool = False,
decode_fields: list[str] = ["image", "depth"],
inplace_fields: list[str] = ["K", "cam2w"],
**kwargs,
) -> None:
super().__init__(
image_shape=image_shape,
split_file=split_file,
test_mode=test_mode,
benchmark=benchmark,
normalize=normalize,
augmentations_db=augmentations_db,
resize_method=resize_method,
mini=mini,
num_frames=num_frames,
decode_fields=decode_fields,
inplace_fields=inplace_fields,
**kwargs,
)
def pre_pipeline(self, results):
results = super().pre_pipeline(results)
results["dense"] = [False] * self.num_frames * self.num_copies
results["quality"] = [1] * self.num_frames * self.num_copies
return results
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/arkit.py
================================================
from typing import Any
from unidepth.datasets.sequence_dataset import SequenceDataset
class ARKit(SequenceDataset):
min_depth = 0.01
max_depth = 10.0
depth_scale = 1000.0
test_split = "Training.txt"
train_split = "Training.txt"
sequences_file = "sequences.json"
hdf5_paths = ["ARKitS.hdf5"]
def __init__(
self,
image_shape: tuple[int, int],
split_file: str,
test_mode: bool,
normalize: bool,
augmentations_db: dict[str, Any],
resize_method: str,
mini: float = 1.0,
num_frames: int = 1,
benchmark: bool = False,
decode_fields: list[str] = ["image", "depth"],
inplace_fields: list[str] = ["K", "cam2w"],
**kwargs,
) -> None:
super().__init__(
image_shape=image_shape,
split_file=split_file,
test_mode=test_mode,
benchmark=benchmark,
normalize=normalize,
augmentations_db=augmentations_db,
resize_method=resize_method,
mini=mini,
num_frames=num_frames,
decode_fields=decode_fields,
inplace_fields=inplace_fields,
**kwargs,
)
def pre_pipeline(self, results):
results = super().pre_pipeline(results)
results["dense"] = [True] * self.num_frames * self.num_copies
results["quality"] = [2] * self.num_frames * self.num_copies
return results
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/ase.py
================================================
from typing import Any
import torch
from unidepth.datasets.sequence_dataset import SequenceDataset
class ASE(SequenceDataset):
min_depth = 0.01
max_depth = 20.0
depth_scale = 1000.0
test_split = "val.txt"
train_split = "train.txt"
sequences_file = "sequences.json"
hdf5_paths = [f"ASE.hdf5"]
def __init__(
self,
image_shape: tuple[int, int],
split_file: str,
test_mode: bool,
normalize: bool,
augmentations_db: dict[str, Any],
resize_method: str,
mini: float = 1.0,
num_frames: int = 1,
benchmark: bool = False,
decode_fields: list[str] = ["image", "depth"],
inplace_fields: list[str] = ["camera_params", "cam2w"],
**kwargs,
) -> None:
super().__init__(
image_shape=image_shape,
split_file=split_file,
test_mode=test_mode,
benchmark=benchmark,
normalize=normalize,
augmentations_db=augmentations_db,
resize_method=resize_method,
mini=mini,
num_frames=num_frames,
decode_fields=decode_fields,
inplace_fields=inplace_fields,
**kwargs,
)
def preprocess(self, results):
self.resizer.ctx = None
for i, seq in enumerate(results["sequence_fields"]):
# Create a mask where the distance from the center is less than H/2
H, W = results[seq]["image"].shape[-2:]
x = torch.linspace(-W / 2 - 0.5, W / 2 + 0.5, W)
y = torch.linspace(-H / 2 - 0.5, H / 2 + 0.5, H)
xv, yv = torch.meshgrid(x, y, indexing="xy")
distance_from_center = torch.sqrt(xv**2 + yv**2).reshape(1, 1, H, W)
results[seq]["validity_mask"] = distance_from_center < (H / 2) + 20
results[seq]["mask_fields"].add("validity_mask")
return super().preprocess(results)
def pre_pipeline(self, results):
results = super().pre_pipeline(results)
results["dense"] = [True] * self.num_frames * self.num_copies
results["synthetic"] = [True] * self.num_frames * self.num_copies
results["quality"] = [0] * self.num_frames * self.num_copies
return results
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/base_dataset.py
================================================
import os
from abc import abstractmethod
from copy import deepcopy
from math import ceil, log
from typing import Any, Dict, Tuple
import numpy as np
import torch
from torch.utils.data import Dataset
import unidepth.datasets.pipelines as pipelines
from unidepth.utils import (eval_3d, eval_depth, identity, is_main_process,
recursive_index, sync_tensor_across_gpus)
from unidepth.utils.constants import (IMAGENET_DATASET_MEAN,
IMAGENET_DATASET_STD,
OPENAI_DATASET_MEAN, OPENAI_DATASET_STD)
class BaseDataset(Dataset):
min_depth = 0.01
max_depth = 1000.0
def __init__(
self,
image_shape: Tuple[int, int],
split_file: str,
test_mode: bool,
benchmark: bool,
normalize: bool,
augmentations_db: Dict[str, Any],
resize_method: str,
mini: float,
num_copies: int = 1,
**kwargs,
) -> None:
super().__init__()
assert normalize in [None, "imagenet", "openai"]
self.split_file = split_file
self.test_mode = test_mode
self.data_root = os.environ["DATAROOT"]
self.image_shape = image_shape
self.resize_method = resize_method
self.mini = mini
self.num_frames = 1
self.num_copies = num_copies
self.metrics_store = {}
self.metrics_count = {}
if normalize == "imagenet":
self.normalization_stats = {
"mean": torch.tensor(IMAGENET_DATASET_MEAN),
"std": torch.tensor(IMAGENET_DATASET_STD),
}
elif normalize == "openai":
self.normalization_stats = {
"mean": torch.tensor(OPENAI_DATASET_MEAN),
"std": torch.tensor(OPENAI_DATASET_STD),
}
else:
self.normalization_stats = {
"mean": torch.tensor([0.0, 0.0, 0.0]),
"std": torch.tensor([1.0, 1.0, 1.0]),
}
for k, v in augmentations_db.items():
setattr(self, k, v)
if not self.test_mode:
self._augmentation_space()
self.masker = pipelines.AnnotationMask(
min_value=0.0,
max_value=self.max_depth if test_mode else None,
custom_fn=identity,
)
self.filler = pipelines.RandomFiller(noise_pad=True)
shape_mult = self.shape_constraints["shape_mult"]
self.image_shape = [
ceil(self.image_shape[0] / shape_mult) * shape_mult,
ceil(self.image_shape[1] / shape_mult) * shape_mult,
]
self.resizer = pipelines.ContextCrop(
image_shape=self.image_shape,
train_ctx_range=(1.0 / self.random_scale, 1.0 * self.random_scale),
test_min_ctx=self.test_context,
keep_original=test_mode,
shape_constraints=self.shape_constraints,
)
self.collecter = pipelines.Collect(
keys=["image_fields", "mask_fields", "gt_fields", "camera_fields"]
)
def __len__(self):
return len(self.dataset)
def pack_batch(self, results):
results["paddings"] = [
results[x]["paddings"][0] for x in results["sequence_fields"]
]
for fields_name in [
"image_fields",
"gt_fields",
"mask_fields",
"camera_fields",
]:
fields = results.get(fields_name)
packed = {
field: torch.cat(
[results[seq][field] for seq in results["sequence_fields"]]
)
for field in fields
}
results.update(packed)
return results
def unpack_batch(self, results):
for fields_name in [
"image_fields",
"gt_fields",
"mask_fields",
"camera_fields",
]:
fields = results.get(fields_name)
unpacked = {
field: {
seq: results[field][idx : idx + 1]
for idx, seq in enumerate(results["sequence_fields"])
}
for field in fields
}
results.update(unpacked)
return results
def _augmentation_space(self):
self.augmentations_dict = {
"Flip": pipelines.RandomFlip(prob=self.flip_p),
"Jitter": pipelines.RandomColorJitter(
(-self.random_jitter, self.random_jitter), prob=self.jitter_p
),
"Gamma": pipelines.RandomGamma(
(-self.random_gamma, self.random_gamma), prob=self.gamma_p
),
"Blur": pipelines.GaussianBlur(
kernel_size=13, sigma=(0.1, self.random_blur), prob=self.blur_p
),
"Grayscale": pipelines.RandomGrayscale(prob=self.grayscale_p),
}
def augment(self, results):
for name, aug in self.augmentations_dict.items():
results = aug(results)
return results
def prepare_depth_eval(self, inputs, preds):
new_preds = {}
keyframe_idx = getattr(self, "keyframe_idx", None)
slice_idx = slice(
keyframe_idx, keyframe_idx + 1 if keyframe_idx is not None else None
)
new_gts = inputs["depth"][slice_idx]
new_masks = inputs["depth_mask"][slice_idx].bool()
for key, val in preds.items():
if "depth" in key:
new_preds[key] = val[slice_idx]
return new_gts, new_preds, new_masks
def prepare_points_eval(self, inputs, preds):
new_preds = {}
new_gts = inputs["points"]
new_masks = inputs["depth_mask"].bool()
if "points_mask" in inputs:
new_masks = inputs["points_mask"].bool()
for key, val in preds.items():
if "points" in key:
new_preds[key] = val
return new_gts, new_preds, new_masks
def add_points(self, inputs):
inputs["points"] = inputs.get("camera_original", inputs["camera"]).reconstruct(
inputs["depth"]
)
return inputs
@torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32)
def accumulate_metrics(
self,
inputs,
preds,
keyframe_idx=None,
metrics=["depth", "points", "flow_fwd", "pairwise"],
):
if "depth" in inputs and "points" not in inputs:
inputs = self.add_points(inputs)
available_metrics = []
for metric in metrics:
metric_in_gt = any((metric in k for k in inputs.keys()))
metric_in_pred = any((metric in k for k in preds.keys()))
if metric_in_gt and metric_in_pred:
available_metrics.append(metric)
if keyframe_idx is not None:
inputs = recursive_index(inputs, slice(keyframe_idx, keyframe_idx + 1))
preds = recursive_index(preds, slice(keyframe_idx, keyframe_idx + 1))
if "depth" in available_metrics:
depth_gt, depth_pred, depth_masks = self.prepare_depth_eval(inputs, preds)
self.accumulate_metrics_depth(depth_gt, depth_pred, depth_masks)
if "points" in available_metrics:
points_gt, points_pred, points_masks = self.prepare_points_eval(
inputs, preds
)
self.accumulate_metrics_3d(points_gt, points_pred, points_masks)
@torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32)
def accumulate_metrics_depth(self, gts, preds, masks):
for eval_type, pred in preds.items():
log_name = eval_type.replace("depth", "").strip("-").strip("_")
if log_name not in self.metrics_store:
self.metrics_store[log_name] = {}
current_count = self.metrics_count.get(
log_name, torch.tensor([], device=gts.device)
)
new_count = masks.view(gts.shape[0], -1).sum(dim=-1)
self.metrics_count[log_name] = torch.cat([current_count, new_count])
for k, v in eval_depth(gts, pred, masks, max_depth=self.max_depth).items():
current_metric = self.metrics_store[log_name].get(
k, torch.tensor([], device=gts.device)
)
self.metrics_store[log_name][k] = torch.cat([current_metric, v])
@torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32)
def accumulate_metrics_3d(self, gts, preds, masks):
thresholds = torch.linspace(
log(self.min_depth),
log(self.max_depth / 20),
steps=100,
device=gts.device,
).exp()
for eval_type, pred in preds.items():
log_name = eval_type.replace("points", "").strip("-").strip("_")
if log_name not in self.metrics_store:
self.metrics_store[log_name] = {}
current_count = self.metrics_count.get(
log_name, torch.tensor([], device=gts.device)
)
new_count = masks.view(gts.shape[0], -1).sum(dim=-1)
self.metrics_count[log_name] = torch.cat([current_count, new_count])
for k, v in eval_3d(gts, pred, masks, thresholds=thresholds).items():
current_metric = self.metrics_store[log_name].get(
k, torch.tensor([], device=gts.device)
)
self.metrics_store[log_name][k] = torch.cat([current_metric, v])
def get_evaluation(self, metrics=None):
metric_vals = {}
for eval_type in metrics if metrics is not None else self.metrics_store.keys():
assert self.metrics_store[eval_type]
cnts = sync_tensor_across_gpus(self.metrics_count[eval_type])
for name, val in self.metrics_store[eval_type].items():
# vals_r = (sync_tensor_across_gpus(val) * cnts / cnts.sum()).sum()
vals_r = sync_tensor_across_gpus(val).mean()
metric_vals[f"{eval_type}_{name}".strip("_")] = np.round(
vals_r.cpu().item(), 5
)
self.metrics_store[eval_type] = {}
self.metrics_count = {}
return metric_vals
def replicate(self, results):
for i in range(1, self.num_copies):
results[(0, i)] = {k: deepcopy(v) for k, v in results[(0, 0)].items()}
results["sequence_fields"].append((0, i))
return results
def log_load_dataset(self):
if is_main_process():
info = f"Loaded {self.__class__.__name__} with {len(self)} images."
print(info)
def pre_pipeline(self, results):
results["image_fields"] = results.get("image_fields", set())
results["gt_fields"] = results.get("gt_fields", set())
results["mask_fields"] = results.get("mask_fields", set())
results["sequence_fields"] = results.get("sequence_fields", set())
results["camera_fields"] = results.get("camera_fields", set())
results["dataset_name"] = (
[self.__class__.__name__] * self.num_frames * self.num_copies
)
results["depth_scale"] = [self.depth_scale] * self.num_frames * self.num_copies
results["si"] = [False] * self.num_frames * self.num_copies
results["dense"] = [False] * self.num_frames * self.num_copies
results["synthetic"] = [False] * self.num_frames * self.num_copies
results["quality"] = [0] * self.num_frames * self.num_copies
results["valid_camera"] = [True] * self.num_frames * self.num_copies
results["valid_pose"] = [True] * self.num_frames * self.num_copies
return results
def eval_mask(self, valid_mask):
return valid_mask
def chunk(self, dataset, chunk_dim=1, pct=1.0):
subsampled_datasets = [
x
for i in range(0, len(dataset), int(1 / pct * chunk_dim))
for x in dataset[i : i + chunk_dim]
]
return subsampled_datasets
@abstractmethod
def preprocess(self, results):
raise NotImplementedError
@abstractmethod
def postprocess(self, results):
raise NotImplementedError
@abstractmethod
def get_mapper(self):
raise NotImplementedError
@abstractmethod
def get_intrinsics(self, idx, image_name):
raise NotImplementedError
@abstractmethod
def get_extrinsics(self, idx, image_name):
raise NotImplementedError
@abstractmethod
def load_dataset(self):
raise NotImplementedError
@abstractmethod
def get_single_item(self, idx, sample=None, mapper=None):
raise NotImplementedError
@abstractmethod
def __getitem__(self, idx):
raise NotImplementedError
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/bdd.py
================================================
import json
import os
import h5py
import numpy as np
import torch
from unidepth.datasets.image_dataset import ImageDataset
from unidepth.datasets.utils import DatasetFromList
class BDD(ImageDataset):
min_depth = 0.01
max_depth = 70.0
depth_scale = 256.0
test_split = "val.txt"
train_split = "train_clean.txt"
intrisics_file = "intrinsics.json"
hdf5_paths = ["BDD.hdf5"]
def __init__(
self,
image_shape,
split_file,
test_mode,
benchmark=False,
augmentations_db={},
normalize=True,
resize_method="hard",
mini=1.0,
**kwargs,
):
super().__init__(
image_shape=image_shape,
split_file=split_file,
test_mode=test_mode,
benchmark=benchmark,
normalize=normalize,
augmentations_db=augmentations_db,
resize_method=resize_method,
mini=mini,
**kwargs,
)
self.test_mode = test_mode
self.load_dataset()
def load_dataset(self):
h5file = h5py.File(
os.path.join(self.data_root, self.hdf5_paths[0]),
"r",
libver="latest",
swmr=True,
)
txt_file = np.array(h5file[self.split_file])
txt_string = txt_file.tostring().decode("ascii") # [:-1] # correct the -1
intrinsics = np.array(h5file[self.intrisics_file]).tostring().decode("ascii")
intrinsics = json.loads(intrinsics)
dataset = []
for line in txt_string.split("\n"):
image_filename, depth_filename = line.strip().split(" ")
intrinsics_val = torch.tensor(
intrinsics[os.path.join(*image_filename.split("/")[:2])]
).squeeze()[:, :3]
sample = [image_filename, depth_filename, intrinsics_val]
dataset.append(sample)
h5file.close()
if not self.test_mode:
dataset = self.chunk(dataset, chunk_dim=1, pct=self.mini)
if self.test_mode and not self.benchmark:
dataset = self.chunk(dataset, chunk_dim=1, pct=0.1)
self.dataset = DatasetFromList(dataset)
self.log_load_dataset()
def pre_pipeline(self, results):
results = super().pre_pipeline(results)
results["si"] = [True] * self.num_copies
results["valid_camera"] = [False] * self.num_copies
results["dense"] = [False] * self.num_copies
results["quality"] = [2] * self.num_copies
return results
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/bedlam.py
================================================
from typing import Any
from unidepth.datasets.sequence_dataset import SequenceDataset
class BEDLAM(SequenceDataset):
min_depth = 0.01
max_depth = 256.0
depth_scale = 1000.0
test_split = "train.txt"
train_split = "val.txt"
sequences_file = "sequences.json"
hdf5_paths = ["BEDLAM.hdf5"]
def __init__(
self,
image_shape: tuple[int, int],
split_file: str,
test_mode: bool,
normalize: bool,
augmentations_db: dict[str, Any],
resize_method: str,
mini: float = 1.0,
num_frames: int = 1,
benchmark: bool = False,
decode_fields: list[str] = ["image", "depth"],
inplace_fields: list[str] = ["K", "cam2w"],
**kwargs,
) -> None:
super().__init__(
image_shape=image_shape,
split_file=split_file,
test_mode=test_mode,
benchmark=benchmark,
normalize=normalize,
augmentations_db=augmentations_db,
resize_method=resize_method,
mini=mini,
num_frames=num_frames,
decode_fields=decode_fields,
inplace_fields=inplace_fields,
**kwargs,
)
def pre_pipeline(self, results):
results = super().pre_pipeline(results)
results["dense"] = [True] * self.num_frames * self.num_copies
results["synthetic"] = [True] * self.num_frames * self.num_copies
results["quality"] = [0] * self.num_frames * self.num_copies
return results
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/behave.py
================================================
from typing import Any
from unidepth.datasets.sequence_dataset import SequenceDataset
class Behave(SequenceDataset):
min_depth = 0.01
max_depth = 10.0
depth_scale = 1000.0
default_fps = 10
test_split = "train.txt"
train_split = "train.txt"
sequences_file = "sequences.json"
hdf5_paths = ["Behave.hdf5"]
def __init__(
self,
image_shape: tuple[int, int],
split_file: str,
test_mode: bool,
normalize: bool,
augmentations_db: dict[str, Any],
resize_method: str,
mini: float = 1.0,
num_frames: int = 1,
benchmark: bool = False,
decode_fields: list[str] = ["image", "depth"],
inplace_fields: list[str] = ["camera_params", "cam2w"],
**kwargs,
) -> None:
super().__init__(
image_shape=image_shape,
split_file=split_file,
test_mode=test_mode,
benchmark=benchmark,
normalize=normalize,
augmentations_db=augmentations_db,
resize_method=resize_method,
mini=mini,
num_frames=num_frames,
decode_fields=decode_fields,
inplace_fields=inplace_fields,
**kwargs,
)
def pre_pipeline(self, results):
results = super().pre_pipeline(results)
results["dense"] = [True] * self.num_frames * self.num_copies
results["synthetic"] = [False] * self.num_frames * self.num_copies
results["si"] = [False] * self.num_frames * self.num_copies
results["quality"] = [1] * self.num_frames * self.num_copies
return results
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/blendedmvg.py
================================================
from typing import Any
from unidepth.datasets.sequence_dataset import SequenceDataset
class BlendedMVG(SequenceDataset):
min_depth = 0.01
max_depth = 5000.0
depth_scale = 1000.0
test_split = "train.txt"
train_split = "train.txt"
sequences_file = "sequences_clean.json"
hdf5_paths = ["BlendedMVG_.hdf5"]
def __init__(
self,
image_shape: tuple[int, int],
split_file: str,
test_mode: bool,
normalize: bool,
augmentations_db: dict[str, Any],
resize_method: str,
mini: float = 1.0,
num_frames: int = 1,
benchmark: bool = False,
decode_fields: list[str] = ["image", "depth"],
inplace_fields: list[str] = ["K", "cam2w"],
**kwargs,
) -> None:
super().__init__(
image_shape=image_shape,
split_file=split_file,
test_mode=test_mode,
benchmark=benchmark,
normalize=normalize,
augmentations_db=augmentations_db,
resize_method=resize_method,
mini=mini,
num_frames=num_frames,
decode_fields=decode_fields,
inplace_fields=inplace_fields,
**kwargs,
)
def pre_pipeline(self, results):
results = super().pre_pipeline(results)
results["dense"] = [True] * self.num_frames * self.num_copies
results["si"] = [False] * self.num_frames * self.num_copies
results["quality"] = [2] * self.num_frames * self.num_copies
return results
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/cityscape.py
================================================
import json
import os
import h5py
import numpy as np
import torch
from unidepth.datasets.image_dataset import ImageDataset
from unidepth.datasets.utils import DatasetFromList
class Cityscape(ImageDataset):
min_depth = 0.05
max_depth = 80.0
depth_scale = 256.0
test_split = "val.txt"
train_split = "train.txt"
intrisics_file = "intrinsics.json"
hdf5_paths = ["cityscape.hdf5"]
def __init__(
self,
image_shape,
split_file,
test_mode,
crop=None,
benchmark=False,
augmentations_db={},
normalize=True,
resize_method="hard",
mini=1.0,
**kwargs,
):
super().__init__(
image_shape=image_shape,
split_file=split_file,
test_mode=test_mode,
benchmark=benchmark,
normalize=normalize,
augmentations_db=augmentations_db,
resize_method=resize_method,
mini=mini,
**kwargs,
)
self.test_mode = test_mode
self.crop = crop
self.load_dataset()
def load_dataset(self):
h5file = h5py.File(
os.path.join(self.data_root, self.hdf5_paths[0]),
"r",
libver="latest",
swmr=True,
)
txt_file = np.array(h5file[self.split_file])
txt_string = txt_file.tostring().decode("ascii")[:-1] # correct the -1
intrinsics = np.array(h5file[self.intrisics_file]).tostring().decode("ascii")
intrinsics = json.loads(intrinsics)
h5file.close()
dataset = []
for line in txt_string.split("\n"):
image_filename, depth_filename = line.strip().split(" ")
intrinsics_val = torch.tensor(intrinsics[image_filename]).squeeze()[:, :3]
sample = [image_filename, depth_filename, intrinsics_val]
dataset.append(sample)
if not self.test_mode:
dataset = self.chunk(dataset, chunk_dim=1, pct=self.mini)
self.dataset = DatasetFromList(dataset)
self.log_load_dataset()
def pre_pipeline(self, results):
results = super().pre_pipeline(results)
results["quality"] = [2] * self.num_copies
return results
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/ddad.py
================================================
import json
import os
import h5py
import numpy as np
import torch
from unidepth.datasets.image_dataset import ImageDataset
from unidepth.datasets.utils import DatasetFromList
class DDAD(ImageDataset):
min_depth = 0.05
max_depth = 120.0
depth_scale = 256.0
test_split = "val.txt"
train_split = "train.txt"
intrisics_file = "intrinsics.json"
hdf5_paths = [f"ddad/ddad_{i}.hdf5" for i in range(8)]
def __init__(
self,
image_shape,
split_file,
test_mode,
benchmark=False,
augmentations_db={},
normalize=True,
resize_method="hard",
mini=1.0,
**kwargs,
):
super().__init__(
image_shape=image_shape,
split_file=split_file,
test_mode=test_mode,
benchmark=benchmark,
normalize=normalize,
augmentations_db=augmentations_db,
resize_method=resize_method,
mini=mini,
**kwargs,
)
self.test_mode = test_mode
self.load_dataset()
def load_dataset(self):
h5file = h5py.File(
os.path.join(self.data_root, self.hdf5_paths[0]),
"r",
libver="latest",
swmr=True,
)
txt_file = np.array(h5file[self.split_file])
txt_string = txt_file.tostring().decode("ascii").strip("\n")
intrinsics = np.array(h5file[self.intrisics_file]).tostring().decode("ascii")
intrinsics = json.loads(intrinsics)
h5file.close()
dataset = []
for line in txt_string.split("\n"):
image_filename, depth_filename, chunk_idx = line.strip().split(" ")
intrinsics_val = torch.tensor(intrinsics[image_filename]).squeeze()[:, :3]
sample = [image_filename, depth_filename, intrinsics_val, chunk_idx]
dataset.append(sample)
if not self.test_mode:
dataset = self.chunk(dataset, chunk_dim=1, pct=self.mini)
self.dataset = DatasetFromList(dataset)
self.log_load_dataset()
def get_mapper(self):
return {
"image_filename": 0,
"depth_filename": 1,
"K": 2,
"chunk_idx": 3,
}
def pre_pipeline(self, results):
results = super().pre_pipeline(results)
results["dense"] = [False] * self.num_copies
results["quality"] = [1] * self.num_copies
return results
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/deep360.py
================================================
from typing import Any
import torch
from unidepth.datasets.pipelines import Compose, PanoCrop, PanoRoll
from unidepth.datasets.sequence_dataset import SequenceDataset
class Deep360(SequenceDataset):
min_depth = 0.1
max_depth = 1000.0
depth_scale = 1000.0
test_split = "train.txt"
train_split = "train.txt"
sequences_file = "sequences.json"
hdf5_paths = [f"Deep360.hdf5"]
def __init__(
self,
image_shape: tuple[int, int],
split_file: str,
test_mode: bool,
normalize: bool,
augmentations_db: dict[str, Any],
resize_method: str,
mini: float = 1.0,
num_frames: int = 1,
benchmark: bool = False,
decode_fields: list[str] = ["image", "depth"],
inplace_fields: list[str] = ["cam2w", "camera_params"],
**kwargs,
) -> None:
super().__init__(
image_shape=image_shape,
split_file=split_file,
test_mode=test_mode,
benchmark=benchmark,
normalize=normalize,
augmentations_db=augmentations_db,
resize_method=resize_method,
mini=mini,
num_frames=num_frames,
decode_fields=decode_fields,
inplace_fields=inplace_fields,
**kwargs,
)
self.resizer = Compose(
[PanoCrop(), PanoRoll(test_mode=test_mode), self.resizer]
)
def pre_pipeline(self, results):
results = super().pre_pipeline(results)
results["dense"] = [True] * self.num_frames * self.num_copies
results["synthetic"] = [True] * self.num_frames * self.num_copies
results["quality"] = [0] * self.num_frames * self.num_copies
return results
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/dense.py
================================================
import os
import h5py
import numpy as np
import torch
from unidepth.datasets.image_dataset import ImageDataset
from unidepth.datasets.utils import DatasetFromList
class DENSE(ImageDataset):
CAM_INTRINSIC = {
"ALL": torch.tensor(
[
[1177.8614, 0.0, 474.319027],
[0.0, 1177.8614, 224.275919],
[0.0, 0.0, 1.0],
]
)
}
min_depth = 0.05
max_depth = 80.0
depth_scale = 255.0
test_split = "train.txt"
train_split = "train.txt"
hdf5_paths = ["DENSE.hdf5"]
def __init__(
self,
image_shape,
split_file,
test_mode,
benchmark=False,
augmentations_db={},
normalize=True,
resize_method="hard",
mini=1.0,
**kwargs,
):
super().__init__(
image_shape=image_shape,
split_file=split_file,
test_mode=test_mode,
benchmark=benchmark,
normalize=normalize,
augmentations_db=augmentations_db,
resize_method=resize_method,
mini=mini,
**kwargs,
)
self.test_mode = test_mode
self.intrisics = {}
self.load_dataset()
def load_dataset(self):
h5file = h5py.File(
os.path.join(self.data_root, self.hdf5_paths[0]),
"r",
libver="latest",
swmr=True,
)
txt_file = np.array(h5file[self.split_file])
txt_string = txt_file.tostring().decode("ascii")[:-1] # correct the -1
h5file.close()
dataset = []
for line in txt_string.split("\n"):
image_filename, depth_filename = line.strip().split(" ")
sample = [image_filename, depth_filename]
dataset.append(sample)
if not self.test_mode:
dataset = self.chunk(dataset, chunk_dim=1, pct=self.mini)
self.dataset = DatasetFromList(dataset)
self.log_load_dataset()
def get_intrinsics(self, idx, image_name):
return self.CAM_INTRINSIC["ALL"].clone()
def get_mapper(self):
return {
"image_filename": 0,
"depth_filename": 1,
}
def pre_pipeline(self, results):
results = super().pre_pipeline(results)
results["dense"] = [False] * self.num_copies
results["quality"] = [1] * self.num_copies
return results
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/diml.py
================================================
import json
import os
import h5py
import numpy as np
import torch
from unidepth.datasets.image_dataset import ImageDataset
from unidepth.datasets.utils import DatasetFromList
class DIML(ImageDataset):
min_depth = 0.01
max_depth = 100.0
depth_scale = 256.0
test_split = "test.txt"
train_split = "train.txt"
intrisics_file = "intrinsics.json"
hdf5_paths = ["DIML.hdf5"]
def __init__(
self,
image_shape,
split_file,
test_mode,
benchmark=False,
augmentations_db={},
normalize=True,
resize_method="hard",
mini=1.0,
**kwargs,
):
super().__init__(
image_shape=image_shape,
split_file=split_file,
test_mode=test_mode,
benchmark=benchmark,
normalize=normalize,
augmentations_db=augmentations_db,
resize_method=resize_method,
mini=mini,
**kwargs,
)
self.test_mode = test_mode
self.intrisics = {}
self.load_dataset()
def load_dataset(self):
h5file = h5py.File(
os.path.join(self.data_root, self.hdf5_paths[0]),
"r",
libver="latest",
swmr=True,
)
txt_file = np.array(h5file[self.split_file])
txt_string = txt_file.tostring().decode("ascii") # [:-1] # correct the -1
intrinsics = np.array(h5file[self.intrisics_file]).tostring().decode("ascii")
intrinsics = json.loads(intrinsics)
dataset = []
for line in txt_string.split("\n"):
image_filename, depth_filename = line.strip().split(" ")
intrinsics_val = torch.tensor(
intrinsics[image_filename.split("/")[0]]
).squeeze()[:, :3]
sample = [image_filename, depth_filename, intrinsics_val]
dataset.append(sample)
h5file.close()
if not self.test_mode:
dataset = self.chunk(dataset, chunk_dim=1, pct=self.mini)
self.dataset = DatasetFromList(dataset)
self.log_load_dataset()
def pre_pipeline(self, results):
results = super().pre_pipeline(results)
results["dense"] = [True] * self.num_copies
results["quality"] = [2] * self.num_copies
return results
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/diode.py
================================================
import os
import h5py
import numpy as np
import polars as pl  # assumed: provides pl.from_dict used by DiodeOutdoor/Diode below
import torch
from unidepth.datasets.image_dataset import ImageDataset
from unidepth.datasets.pipelines import AnnotationMask  # used by DiodeOutdoor/Diode below
from unidepth.datasets.sequence_dataset import SequenceDataset
from unidepth.datasets.utils import DatasetFromList
class DiodeIndoor(ImageDataset):
CAM_INTRINSIC = {
"ALL": torch.tensor([[886.81, 0, 512], [0, 927.06, 384], [0, 0, 1]])
}
min_depth = 0.01
max_depth = 25.0
depth_scale = 256.0
test_split = "val.txt"
train_split = "train.txt"
hdf5_paths = ["DiodeIndoor.hdf5"]
def __init__(
self,
image_shape,
split_file,
test_mode,
crop=None,
benchmark=False,
augmentations_db={},
normalize=True,
mini=1.0,
**kwargs,
):
super().__init__(
image_shape=image_shape,
split_file=split_file,
test_mode=test_mode,
benchmark=benchmark,
normalize=normalize,
augmentations_db=augmentations_db,
mini=mini,
**kwargs,
)
self.test_mode = test_mode
# load annotations
self.load_dataset()
def load_dataset(self):
h5file = h5py.File(
os.path.join(self.data_root, self.hdf5_paths[0]),
"r",
libver="latest",
swmr=True,
)
txt_file = np.array(h5file[self.split_file])
txt_string = txt_file.tostring().decode("ascii")[:-1] # correct the -1
h5file.close()
dataset = []
for line in txt_string.split("\n"):
image_filename, depth_filename = line.strip().split(" ")
sample = [
image_filename,
depth_filename,
]
dataset.append(sample)
if not self.test_mode:
dataset = self.chunk(dataset, chunk_dim=1, pct=self.mini)
self.dataset = DatasetFromList(dataset)
self.log_load_dataset()
def get_intrinsics(self, *args, **kwargs):
return self.CAM_INTRINSIC["ALL"].clone()
def get_mapper(self):
return {
"image_filename": 0,
"depth_filename": 1,
}
def pre_pipeline(self, results):
results = super().pre_pipeline(results)
results["dense"] = [True] * self.num_copies
results["quality"] = [1] * self.num_copies
return results
class DiodeIndoor_F(SequenceDataset):
min_depth = 0.01
max_depth = 25.0
depth_scale = 1000.0
test_split = "train.txt"
train_split = "train.txt"
sequences_file = "sequences.json"
hdf5_paths = ["DiodeIndoor-F.hdf5"]
def __init__(
self,
image_shape: tuple[int, int],
split_file: str,
test_mode: bool,
normalize: bool,
augmentations_db: dict[str, float],
resize_method: str,
mini: float = 1.0,
num_frames: int = 1,
benchmark: bool = False,
decode_fields: list[str] = ["image", "depth"],
inplace_fields: list[str] = ["camera_params", "cam2w"],
**kwargs,
) -> None:
super().__init__(
image_shape=image_shape,
split_file=split_file,
test_mode=test_mode,
benchmark=benchmark,
normalize=normalize,
augmentations_db=augmentations_db,
resize_method=resize_method,
mini=mini,
num_frames=num_frames,
decode_fields=(
decode_fields if not test_mode else [*decode_fields, "points"]
),
inplace_fields=inplace_fields,
**kwargs,
)
def pre_pipeline(self, results):
results = super().pre_pipeline(results)
results["dense"] = [True] * self.num_frames * self.num_copies
results["quality"] = [1] * self.num_frames * self.num_copies
return results
class DiodeOutdoor(ImageDataset):
CAM_INTRINSIC = {
"ALL": torch.tensor([[886.81, 0, 512], [0, 927.06, 384], [0, 0, 1]])
}
min_depth = 0.1
max_depth = 80.0
log_mean = 0
log_std = 1
test_split = "diode_outdoor_val.txt"
train_split = "diode_outdoor_train.txt"
hdf5_paths = ["diode.hdf5"]
def __init__(
self,
image_shape,
split_file,
test_mode,
depth_scale=256,
crop=None,
benchmark=False,
augmentations_db={},
normalize=True,
resize_method="hard",
mini=1.0,
**kwargs,
):
super().__init__(
image_shape=image_shape,
split_file=split_file,
test_mode=test_mode,
benchmark=benchmark,
normalize=normalize,
augmentations_db=augmentations_db,
resize_method=resize_method,
mini=mini,
**kwargs,
)
self.test_mode = test_mode
self.depth_scale = depth_scale
self.masker = AnnotationMask(
min_value=self.min_depth,
max_value=self.max_depth if test_mode else None,
custom_fn=self.eval_mask if test_mode else lambda x, *args, **kwargs: x,
)
# load annotations
self.load_dataset()
def load_dataset(self):
self.h5file = h5py.File(
            os.path.join(self.data_root, self.hdf5_paths[0]),
"r",
libver="latest",
swmr=True,
)
txt_file = np.array(self.h5file[self.split_file])
txt_string = txt_file.tostring().decode("ascii")[:-1]
dataset = {"depth_filename": [], "image_filename": []}
for line in txt_string.split("\n"):
depth_filename = line.strip().split(" ")[1]
img_name = line.strip().split(" ")[0]
image_filename = img_name
dataset["depth_filename"].append(depth_filename)
dataset["image_filename"].append(image_filename)
self.dataset = pl.from_dict(dataset)
if not self.test_mode and self.mini:
self.dataset = self.dataset[::2]
class Diode(ImageDataset):
CAM_INTRINSIC = {
"ALL": torch.tensor([[886.81, 0, 512], [0, 927.06, 384], [0, 0, 1]])
}
log_mean = 0
log_std = 1
min_depth = 0.6
max_depth = 80.0
test_split = "diode_val.txt"
train_split = "diode_train.txt"
hdf5_paths = ["diode.hdf5"]
def __init__(
self,
image_shape,
split_file,
test_mode,
depth_scale=256,
crop=None,
benchmark=False,
augmentations_db={},
normalize=True,
mini=1.0,
**kwargs,
):
super().__init__(
image_shape=image_shape,
split_file=split_file,
test_mode=test_mode,
benchmark=benchmark,
normalize=normalize,
augmentations_db=augmentations_db,
mini=mini,
**kwargs,
)
self.test_mode = test_mode
self.depth_scale = depth_scale
self.masker = AnnotationMask(
min_value=self.min_depth,
max_value=self.max_depth if test_mode else None,
custom_fn=self.eval_mask if test_mode else lambda x, *args, **kwargs: x,
)
# load annotations
self.load_dataset()
def load_dataset(self):
self.h5file = h5py.File(
            os.path.join(self.data_root, self.hdf5_paths[0]),
"r",
libver="latest",
swmr=True,
)
txt_file = np.array(self.h5file[self.split_file])
txt_string = txt_file.tostring().decode("ascii")[:-1]
dataset = {"depth_filename": [], "image_filename": []}
for line in txt_string.split("\n"):
depth_filename = line.strip().split(" ")[1]
image_filename = line.strip().split(" ")[0]
dataset["depth_filename"].append(depth_filename)
dataset["image_filename"].append(image_filename)
self.dataset = pl.from_dict(dataset)
if not self.test_mode and self.mini:
self.dataset = self.dataset[::2]
def get_intrinsics(self, *args, **kwargs):
return self.CAM_INTRINSIC["ALL"].clone()
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/dl3dv.py
================================================
from typing import Any
from unidepth.datasets.sequence_dataset import SequenceDataset
class DL3DV(SequenceDataset):
min_depth = 0.001
max_depth = 250.0
depth_scale = 512.0
test_split = "train.txt"
train_split = "train.txt"
sequences_file = "sequences.json"
hdf5_paths = [f"DL3DVcv.hdf5"]
def __init__(
self,
image_shape: tuple[int, int],
split_file: str,
test_mode: bool,
normalize: bool,
augmentations_db: dict[str, Any],
resize_method: str,
mini: float = 1.0,
num_frames: int = 1,
benchmark: bool = False,
decode_fields: list[str] = ["image", "depth"],
inplace_fields: list[str] = ["camera_params", "cam2w"],
**kwargs,
) -> None:
super().__init__(
image_shape=image_shape,
split_file=split_file,
test_mode=test_mode,
benchmark=benchmark,
normalize=normalize,
augmentations_db=augmentations_db,
resize_method=resize_method,
mini=mini,
num_frames=num_frames,
decode_fields=decode_fields,
inplace_fields=inplace_fields,
**kwargs,
)
def pre_pipeline(self, results):
results = super().pre_pipeline(results)
results["si"] = [True] * self.num_frames * self.num_copies
results["quality"] = [2] * self.num_frames * self.num_copies
return results
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/driving_stereo.py
================================================
import json
import os
import h5py
import numpy as np
import torch
from unidepth.datasets.image_dataset import ImageDataset
from unidepth.datasets.utils import DatasetFromList
class DrivingStereo(ImageDataset):
min_depth = 0.05
max_depth = 80.0
depth_scale = 256.0
test_split = "drivingstereo_val.txt"
train_split = "drivingstereo_train.txt"
intrisics_file = "drivingstereo_intrinsics.json"
hdf5_paths = ["DrivingStereo.hdf5"]
def __init__(
self,
image_shape,
split_file,
test_mode,
crop=None,
benchmark=False,
augmentations_db={},
normalize=True,
resize_method="hard",
mini=1.0,
**kwargs,
):
super().__init__(
image_shape=image_shape,
split_file=split_file,
test_mode=test_mode,
benchmark=benchmark,
normalize=normalize,
augmentations_db=augmentations_db,
resize_method=resize_method,
mini=mini,
**kwargs,
)
self.test_mode = test_mode
self.crop = crop
self.load_dataset()
def load_dataset(self):
h5file = h5py.File(
os.path.join(self.data_root, self.hdf5_paths[0]),
"r",
libver="latest",
swmr=True,
)
txt_file = np.array(h5file[self.split_file])
txt_string = txt_file.tostring().decode("ascii") # [:-1] # correct the -1
intrinsics = np.array(h5file[self.intrisics_file]).tostring().decode("ascii")
intrinsics = json.loads(intrinsics)
dataset = []
for line in txt_string.split("\n"):
image_filename, depth_filename = line.strip().split(" ")
intrinsics_val = torch.tensor(intrinsics[image_filename]).squeeze()[:, :3]
sample = [image_filename, depth_filename, intrinsics_val]
dataset.append(sample)
h5file.close()
if not self.test_mode:
dataset = self.chunk(dataset, chunk_dim=1, pct=self.mini)
if self.test_mode and not self.benchmark:
dataset = self.chunk(dataset, chunk_dim=1, pct=1.0)
self.dataset = DatasetFromList(dataset)
self.log_load_dataset()
def pre_pipeline(self, results):
results = super().pre_pipeline(results)
results["dense"] = [False] * self.num_copies
results["quality"] = [1] * self.num_copies
return results
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/dtu_rmvd.py
================================================
import json
import os
from typing import Any
import h5py
import numpy as np
import torch
from unidepth.datasets.image_dataset import ImageDataset
from unidepth.datasets.pipelines import AnnotationMask, KittiCrop
from unidepth.datasets.sequence_dataset import SequenceDataset
from unidepth.datasets.utils import DatasetFromList
from unidepth.utils import identity
class DTURMVD(SequenceDataset):
min_depth = 0.05
max_depth = 3.0
depth_scale = 1000.0
default_fps = 6
test_split = "test.txt"
train_split = "test.txt"
sequences_file = "sequences.json"
hdf5_paths = ["dtu_rmvd.hdf5"]
def __init__(
self,
image_shape,
split_file,
test_mode,
crop=None,
augmentations_db={},
normalize=True,
resize_method="hard",
mini: float = 1.0,
num_frames: int = 1,
benchmark: bool = False,
decode_fields: list[str] = ["image", "depth"],
inplace_fields: list[str] = ["K", "cam2w"],
**kwargs,
):
super().__init__(
image_shape=image_shape,
split_file=split_file,
test_mode=test_mode,
benchmark=benchmark,
normalize=normalize,
augmentations_db=augmentations_db,
resize_method=resize_method,
mini=mini,
num_frames=num_frames,
decode_fields=decode_fields,
inplace_fields=inplace_fields,
**kwargs,
)
def pre_pipeline(self, results):
results = super().pre_pipeline(results)
results["dense"] = [True] * self.num_frames * self.num_copies
results["si"] = [True] * self.num_frames * self.num_copies
results["quality"] = [1] * self.num_frames * self.num_copies
return results
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/dummy.py
================================================
import numpy as np
import torch
from torch.utils.data import Dataset
class Dummy(Dataset):
train_split = None
test_split = None
def __init__(self, *args, **kwargs):
super().__init__()
self.dataset = np.arange(1_000_000)
def get_single_item(self, idx):
# results = {}
# results["cam2w"] = torch.eye(4).unsqueeze(0)
# results["K"] = torch.eye(3).unsqueeze(0)
# results["image"] = torch.zeros(1, 3, 1024, 1024).to(torch.uint8)
# results["depth"] = torch.zeros(1, 1, 1024, 1024).to(torch.float32)
return {
"x": {(0, 0): torch.rand(1, 3, 1024, 1024, dtype=torch.float32)},
"img_metas": {"val": torch.rand(1, 1024, dtype=torch.float32)},
}
def __getitem__(self, idx):
if isinstance(idx, (list, tuple)):
results = [self.get_single_item(i) for i in idx]
else:
results = self.get_single_item(idx)
return results
def __len__(self):
return len(self.dataset)
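
`Dummy` yields fixed-size random tensors with the same nested layout as real batches, which makes it handy for isolating input-pipeline overhead. A hypothetical usage sketch (the import path mirrors the FILE header above; batch size and step count are arbitrary):

```python
import time

from torch.utils.data import DataLoader

from unidepth.datasets.dummy import Dummy  # the class defined above

loader = DataLoader(Dummy(), batch_size=2, num_workers=2)
start = time.time()
for step, batch in enumerate(loader):
    if step == 10:
        break
print(f"~{(time.time() - start) / 10:.3f} s per batch of pure data loading")
```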
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/dynamic_replica.py
================================================
from typing import Any
from unidepth.datasets.sequence_dataset import SequenceDataset
class DynReplica(SequenceDataset):
min_depth = 0.01
max_depth = 20.0
default_fps = 30.0
depth_scale = 512.0
test_split = "val.txt"
train_split = "train.txt"
sequences_file = "sequences_clean.json"
hdf5_paths = ["DynReplica.hdf5"]
def __init__(
self,
image_shape: tuple[int, int],
split_file: str,
test_mode: bool,
normalize: bool,
augmentations_db: dict[str, Any],
resize_method: str,
mini: float = 1.0,
num_frames: int = 1,
benchmark: bool = False,
decode_fields: list[str] = ["image", "depth"],
inplace_fields: list[str] = ["K", "cam2w"],
**kwargs,
) -> None:
super().__init__(
image_shape=image_shape,
split_file=split_file,
test_mode=test_mode,
benchmark=benchmark,
normalize=normalize,
augmentations_db=augmentations_db,
resize_method=resize_method,
mini=mini,
num_frames=num_frames,
decode_fields=decode_fields,
inplace_fields=inplace_fields,
**kwargs,
)
def pre_pipeline(self, results):
results = super().pre_pipeline(results)
results["dense"] = [True] * self.num_frames * self.num_copies
results["synthetic"] = [True] * self.num_frames * self.num_copies
results["quality"] = [0] * self.num_frames * self.num_copies
return results
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/eden.py
================================================
from typing import Any
from unidepth.datasets.sequence_dataset import SequenceDataset
class EDEN(SequenceDataset):
min_depth = 0.1
max_depth = 100.0
depth_scale = 256.0
test_split = "train.txt"
train_split = "train.txt"
sequences_file = "sequences.json"
hdf5_paths = [f"EDEN.hdf5"]
def __init__(
self,
image_shape: tuple[int, int],
split_file: str,
test_mode: bool,
normalize: bool,
augmentations_db: dict[str, Any],
resize_method: str,
mini: float = 1.0,
num_frames: int = 1,
benchmark: bool = False,
decode_fields: list[str] = ["image", "depth"],
inplace_fields: list[str] = ["K", "cam2w"],
**kwargs,
) -> None:
super().__init__(
image_shape=image_shape,
split_file=split_file,
test_mode=test_mode,
benchmark=benchmark,
normalize=normalize,
augmentations_db=augmentations_db,
resize_method=resize_method,
mini=mini,
num_frames=num_frames,
decode_fields=decode_fields,
inplace_fields=inplace_fields,
**kwargs,
)
def pre_pipeline(self, results):
results = super().pre_pipeline(results)
results["dense"] = [True] * self.num_frames * self.num_copies
results["synthetic"] = [True] * self.num_frames * self.num_copies
results["quality"] = [0] * self.num_frames * self.num_copies
return results
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/eth3d.py
================================================
import json
import os
import h5py
import numpy as np
import torch
from unidepth.datasets.image_dataset import ImageDataset
from unidepth.datasets.sequence_dataset import SequenceDataset
from unidepth.datasets.utils import DatasetFromList
class ETH3D(ImageDataset):
min_depth = 0.01
max_depth = 50.0
depth_scale = 1000.0
test_split = "train.txt"
train_split = "train.txt"
intrisics_file = "intrinsics.json"
hdf5_paths = ["ETH3D.hdf5"]
def __init__(
self,
image_shape,
split_file,
test_mode,
benchmark=False,
augmentations_db={},
normalize=True,
resize_method="hard",
mini=1.0,
**kwargs,
):
super().__init__(
image_shape=image_shape,
split_file=split_file,
test_mode=test_mode,
benchmark=benchmark,
normalize=normalize,
augmentations_db=augmentations_db,
resize_method=resize_method,
mini=mini,
**kwargs,
)
self.test_mode = test_mode
self.load_dataset()
def load_dataset(self):
h5file = h5py.File(
os.path.join(self.data_root, self.hdf5_paths[0]),
"r",
libver="latest",
swmr=True,
)
txt_file = np.array(h5file[self.split_file])
txt_string = txt_file.tostring().decode("ascii")[:-1] # correct the -1
intrinsics = np.array(h5file[self.intrisics_file]).tostring().decode("ascii")
intrinsics = json.loads(intrinsics)
h5file.close()
dataset = []
for line in txt_string.split("\n"):
image_filename, depth_filename = line.strip().split(" ")
intrinsics_val = torch.tensor(intrinsics[image_filename]).squeeze()[:, :3]
sample = [image_filename, depth_filename, intrinsics_val]
dataset.append(sample)
self.dataset = DatasetFromList(dataset)
self.log_load_dataset()
def pre_pipeline(self, results):
results = super().pre_pipeline(results)
results["dense"] = [True] * self.num_frames * self.num_copies
results["quality"] = [1] * self.num_frames * self.num_copies
return results
class ETH3D_F(SequenceDataset):
min_depth = 0.05
max_depth = 60.0
depth_scale = 1000.0
test_split = "train.txt"
train_split = "train.txt"
sequences_file = "sequences.json"
hdf5_paths = ["ETH3D-F.hdf5"]
def __init__(
self,
image_shape: tuple[int, int],
split_file: str,
test_mode: bool,
normalize: bool,
augmentations_db: dict[str, float],
resize_method: str,
mini: float = 1.0,
num_frames: int = 1,
benchmark: bool = False,
decode_fields: list[str] = ["image", "depth"],
inplace_fields: list[str] = ["camera_params", "cam2w"],
**kwargs,
) -> None:
super().__init__(
image_shape=image_shape,
split_file=split_file,
test_mode=test_mode,
benchmark=benchmark,
normalize=normalize,
augmentations_db=augmentations_db,
resize_method=resize_method,
mini=mini,
num_frames=num_frames,
decode_fields=(
decode_fields if not test_mode else [*decode_fields, "points"]
),
inplace_fields=inplace_fields,
**kwargs,
)
def pre_pipeline(self, results):
results = super().pre_pipeline(results)
results["dense"] = [True] * self.num_frames * self.num_copies
results["quality"] = [1] * self.num_frames * self.num_copies
return results
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/eth3d_rmvd.py
================================================
import json
import os
from typing import Any
import h5py
import numpy as np
import torch
from unidepth.datasets.image_dataset import ImageDataset
from unidepth.datasets.pipelines import AnnotationMask, KittiCrop
from unidepth.datasets.sequence_dataset import SequenceDataset
from unidepth.datasets.utils import DatasetFromList
from unidepth.utils import identity
class ETH3DRMVD(SequenceDataset):
min_depth = 0.01
max_depth = 50.0
depth_scale = 1000.0
default_fps = 6
test_split = "test.txt"
train_split = "test.txt"
sequences_file = "sequences.json"
hdf5_paths = ["eth3d_rmvd.hdf5"]
def __init__(
self,
image_shape,
split_file,
test_mode,
crop=None,
augmentations_db={},
normalize=True,
resize_method="hard",
mini: float = 1.0,
num_frames: int = 1,
benchmark: bool = False,
decode_fields: list[str] = ["image", "depth"],
inplace_fields: list[str] = ["K", "cam2w"],
**kwargs,
):
super().__init__(
image_shape=image_shape,
split_file=split_file,
test_mode=test_mode,
benchmark=benchmark,
normalize=normalize,
augmentations_db=augmentations_db,
resize_method=resize_method,
mini=mini,
num_frames=num_frames,
decode_fields=decode_fields,
inplace_fields=inplace_fields,
**kwargs,
)
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/facedepth.py
================================================
from typing import Any
from unidepth.datasets.sequence_dataset import SequenceDataset
class FaceDepth(SequenceDataset):
min_depth = 0.01
max_depth = 10.0
depth_scale = 1000.0
default_fps = 10
test_split = "train.txt"
train_split = "train.txt"
sequences_file = "sequences.json"
hdf5_paths = ["FaceDepth.hdf5"]
def __init__(
self,
image_shape: tuple[int, int],
split_file: str,
test_mode: bool,
normalize: bool,
augmentations_db: dict[str, Any],
resize_method: str,
mini: float = 1.0,
num_frames: int = 1,
benchmark: bool = False,
decode_fields: list[str] = ["image", "depth"],
inplace_fields: list[str] = ["K", "cam2w"],
**kwargs,
) -> None:
super().__init__(
image_shape=image_shape,
split_file=split_file,
test_mode=test_mode,
benchmark=benchmark,
normalize=normalize,
augmentations_db=augmentations_db,
resize_method=resize_method,
mini=mini,
num_frames=num_frames,
decode_fields=decode_fields,
inplace_fields=inplace_fields,
**kwargs,
)
def pre_pipeline(self, results):
results = super().pre_pipeline(results)
results["dense"] = [True] * self.num_frames * self.num_copies
results["synthetic"] = [True] * self.num_frames * self.num_copies
results["quality"] = [0] * self.num_frames * self.num_copies
return results
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/flsea.py
================================================
import os
import h5py
import numpy as np
import torch
from unidepth.datasets.image_dataset import ImageDataset
from unidepth.datasets.utils import DatasetFromList
class FLSea(ImageDataset):
CAM_INTRINSIC = {
"canyons": torch.tensor(
[
[1175.3913431656817, 0.0, 466.2595428966926],
[0.0, 1174.2805075232263, 271.2116633091501],
[0.0, 0.0, 1.0],
]
),
"red_sea": torch.tensor(
[
[1296.666758476217, 0.0, 501.50386149846],
[0.0, 1300.831316354508, 276.161712082695],
[0.0, 0.0, 1.0],
]
),
}
min_depth = 0.05
max_depth = 20.0
depth_scale = 1000.0
train_split = "train.txt"
hdf5_paths = ["FLSea.hdf5"]
def __init__(
self,
image_shape,
split_file,
test_mode,
crop=None,
benchmark=False,
augmentations_db={},
normalize=True,
resize_method="hard",
mini=1.0,
**kwargs,
):
super().__init__(
image_shape=image_shape,
split_file=split_file,
test_mode=test_mode,
benchmark=benchmark,
normalize=normalize,
augmentations_db=augmentations_db,
resize_method=resize_method,
mini=mini,
**kwargs,
)
self.test_mode = test_mode
self.crop = crop
self.load_dataset()
def load_dataset(self):
h5file = h5py.File(
os.path.join(self.data_root, self.hdf5_paths[0]),
"r",
libver="latest",
swmr=True,
)
txt_file = np.array(h5file[self.split_file])
txt_string = txt_file.tostring().decode("ascii")[:-1] # correct the -1
h5file.close()
dataset = []
for line in txt_string.split("\n"):
image_filename, depth_filename = line.strip().split(" ")
sample = [image_filename, depth_filename]
dataset.append(sample)
if not self.test_mode:
dataset = self.chunk(dataset, chunk_dim=1, pct=self.mini)
if self.test_mode and not self.benchmark:
dataset = self.chunk(dataset, chunk_dim=1, pct=0.33)
self.dataset = DatasetFromList(dataset)
self.log_load_dataset()
def get_intrinsics(self, idx, image_name):
return self.CAM_INTRINSIC[image_name.split("/")[0]][:, :3].clone()
def get_mapper(self):
return {
"image_filename": 0,
"depth_filename": 1,
}
def pre_pipeline(self, results):
results = super().pre_pipeline(results)
results["dense"] = [True] * self.num_copies
results["quality"] = [2] * self.num_copies
return results
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/futurehouse.py
================================================
from typing import Any
import torch
from unidepth.datasets.pipelines import Compose, PanoCrop, PanoRoll
from unidepth.datasets.sequence_dataset import SequenceDataset
class FutureHouse(SequenceDataset):
min_depth = 0.01
max_depth = 10.0
depth_scale = 1000.0
test_split = "train.txt"
train_split = "train.txt"
sequences_file = "sequences.json"
hdf5_paths = [f"FutureHouse.hdf5"]
def __init__(
self,
image_shape: tuple[int, int],
split_file: str,
test_mode: bool,
normalize: bool,
augmentations_db: dict[str, Any],
resize_method: str,
mini: float = 1.0,
num_frames: int = 1,
benchmark: bool = False,
decode_fields: list[str] = ["image", "depth"],
inplace_fields: list[str] = ["cam2w", "camera_params"],
**kwargs,
) -> None:
super().__init__(
image_shape=image_shape,
split_file=split_file,
test_mode=test_mode,
benchmark=benchmark,
normalize=normalize,
augmentations_db=augmentations_db,
resize_method=resize_method,
mini=mini,
num_frames=num_frames,
decode_fields=decode_fields,
inplace_fields=inplace_fields,
**kwargs,
)
self.resizer = Compose(
[PanoCrop(), PanoRoll(test_mode=test_mode), self.resizer]
)
def pre_pipeline(self, results):
results = super().pre_pipeline(results)
results["dense"] = [True] * self.num_frames * self.num_copies
results["synthetic"] = [True] * self.num_frames * self.num_copies
results["quality"] = [0] * self.num_frames * self.num_copies
return results
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/gibson.py
================================================
from typing import Any
import torch
from unidepth.datasets.pipelines import Compose, PanoCrop, PanoRoll
from unidepth.datasets.sequence_dataset import SequenceDataset
class Gibson(SequenceDataset):
min_depth = 0.01
max_depth = 10.0
depth_scale = 1000.0
test_split = "train.txt"
train_split = "train.txt"
sequences_file = "sequences.json"
hdf5_paths = [f"Gibson.hdf5"]
def __init__(
self,
image_shape: tuple[int, int],
split_file: str,
test_mode: bool,
normalize: bool,
augmentations_db: dict[str, Any],
resize_method: str,
mini: float = 1.0,
num_frames: int = 1,
benchmark: bool = False,
decode_fields: list[str] = ["image", "depth"],
inplace_fields: list[str] = ["cam2w", "camera_params"],
**kwargs,
) -> None:
super().__init__(
image_shape=image_shape,
split_file=split_file,
test_mode=test_mode,
benchmark=benchmark,
normalize=normalize,
augmentations_db=augmentations_db,
resize_method=resize_method,
mini=mini,
num_frames=num_frames,
decode_fields=decode_fields,
inplace_fields=inplace_fields,
**kwargs,
)
self.resizer = Compose(
[PanoCrop(), PanoRoll(test_mode=test_mode), self.resizer]
)
def pre_pipeline(self, results):
results = super().pre_pipeline(results)
results["dense"] = [True] * self.num_frames * self.num_copies
results["synthetic"] = [True] * self.num_frames * self.num_copies
results["quality"] = [1] * self.num_frames * self.num_copies
return results
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/hammer.py
================================================
import json
import os
import h5py
import numpy as np
import torch
from unidepth.datasets.image_dataset import ImageDataset
from unidepth.datasets.utils import DatasetFromList
class HAMMER(ImageDataset):
min_depth = 0.005
max_depth = 10.0
depth_scale = 1000.0
train_split = "test.txt"
test_split = "test.txt"
intrisics_file = "intrinsics.json"
hdf5_paths = ["hammer.hdf5"]
def __init__(
self,
image_shape,
split_file,
test_mode,
crop=None,
benchmark=False,
augmentations_db={},
normalize=True,
resize_method="hard",
mini=1.0,
**kwargs,
):
super().__init__(
image_shape=image_shape,
split_file=split_file,
test_mode=test_mode,
benchmark=benchmark,
normalize=normalize,
augmentations_db=augmentations_db,
resize_method=resize_method,
mini=mini,
**kwargs,
)
self.test_mode = test_mode
self.crop = crop
self.load_dataset()
def load_dataset(self):
h5file = h5py.File(
os.path.join(self.data_root, self.hdf5_paths[0]),
"r",
libver="latest",
swmr=True,
)
txt_file = np.array(h5file[self.split_file])
txt_string = txt_file.tostring().decode("ascii")[:-1] # correct the -1
intrinsics = np.array(h5file[self.intrisics_file]).tostring().decode("ascii")
intrinsics = json.loads(intrinsics)
h5file.close()
dataset = []
for line in txt_string.split("\n"):
image_filename, depth_filename = line.strip().split(" ")
intrinsics_val = torch.tensor(intrinsics[image_filename]).squeeze()[:, :3]
sample = [image_filename, depth_filename, intrinsics_val]
dataset.append(sample)
self.dataset = DatasetFromList(dataset)
self.log_load_dataset()
def pre_pipeline(self, results):
results = super().pre_pipeline(results)
results["dense"] = [True] * self.num_frames * self.num_copies
results["quality"] = [1] * self.num_frames * self.num_copies
return results
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/hm3d.py
================================================
from typing import Any
from unidepth.datasets.sequence_dataset import SequenceDataset
class HM3D(SequenceDataset):
min_depth = 0.01
max_depth = 10.0
depth_scale = 1000.0
test_split = "val.txt"
train_split = "full.txt"
sequences_file = "sequences.json"
hdf5_paths = [f"HM3D.hdf5"]
def __init__(
self,
image_shape: tuple[int, int],
split_file: str,
test_mode: bool,
normalize: bool,
augmentations_db: dict[str, Any],
resize_method: str,
mini: float = 1.0,
num_frames: int = 1,
benchmark: bool = False,
decode_fields: list[str] = ["image", "depth"],
inplace_fields: list[str] = ["K", "cam2w"],
**kwargs,
) -> None:
super().__init__(
image_shape=image_shape,
split_file=split_file,
test_mode=test_mode,
benchmark=benchmark,
normalize=normalize,
augmentations_db=augmentations_db,
resize_method=resize_method,
mini=mini,
num_frames=num_frames,
decode_fields=decode_fields,
inplace_fields=inplace_fields,
**kwargs,
)
def pre_pipeline(self, results):
results = super().pre_pipeline(results)
results["dense"] = [True] * self.num_frames * self.num_copies
results["quality"] = [2] * self.num_frames * self.num_copies
return results
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/hoi4d.py
================================================
from typing import Any
from unidepth.datasets.sequence_dataset import SequenceDataset
class HOI4D(SequenceDataset):
min_depth = 0.01
max_depth = 10.0
depth_scale = 1000.0
default_fps = 5
test_split = "train.txt"
train_split = "train.txt"
sequences_file = "sequences.json"
hdf5_paths = ["HOI4D.hdf5"]
def __init__(
self,
image_shape: tuple[int, int],
split_file: str,
test_mode: bool,
normalize: bool,
augmentations_db: dict[str, Any],
resize_method: str,
mini: float = 1.0,
num_frames: int = 1,
benchmark: bool = False,
decode_fields: list[str] = ["image", "depth"],
inplace_fields: list[str] = ["K", "cam2w"],
**kwargs,
) -> None:
super().__init__(
image_shape=image_shape,
split_file=split_file,
test_mode=test_mode,
benchmark=benchmark,
normalize=normalize,
augmentations_db=augmentations_db,
resize_method=resize_method,
mini=mini,
num_frames=num_frames,
decode_fields=decode_fields,
inplace_fields=inplace_fields,
**kwargs,
)
def pre_pipeline(self, results):
results = super().pre_pipeline(results)
results["dense"] = [True] * self.num_frames * self.num_copies
results["synthetic"] = [False] * self.num_frames * self.num_copies
results["quality"] = [1] * self.num_frames * self.num_copies
return results
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/hrwsi.py
================================================
import os
import h5py
import numpy as np
import torch
from unidepth.datasets.image_dataset import ImageDataset
from unidepth.datasets.utils import DatasetFromList
class HRWSI(ImageDataset):
min_depth = 0.01
max_depth = 1000.0
depth_scale = 50.0
test_split = "val.txt"
train_split = "train.txt"
hdf5_paths = ["HRWSI.hdf5"]
def __init__(
self,
image_shape,
split_file,
test_mode,
benchmark=False,
augmentations_db={},
normalize=True,
resize_method="hard",
mini=1.0,
**kwargs,
):
super().__init__(
image_shape=image_shape,
split_file=split_file,
test_mode=test_mode,
benchmark=benchmark,
normalize=normalize,
augmentations_db=augmentations_db,
resize_method=resize_method,
mini=mini,
**kwargs,
)
self.test_mode = test_mode
self.load_dataset()
def load_dataset(self):
h5file = h5py.File(
os.path.join(self.data_root, self.hdf5_paths[0]),
"r",
libver="latest",
swmr=True,
)
txt_file = np.array(h5file[self.split_file])
txt_string = txt_file.tostring().decode("ascii") # [:-1] # correct the -1
# with open(os.path.join(os.environ["TMPDIR"], self.split_file), "w") as f:
# f.write(txt_string)
dataset = []
for line in txt_string.split("\n"):
image_filename, depth_filename = line.strip().split(" ")
sample = [
image_filename,
depth_filename,
]
dataset.append(sample)
h5file.close()
if not self.test_mode:
dataset = self.chunk(dataset, chunk_dim=1, pct=self.mini)
self.dataset = DatasetFromList(dataset)
self.log_load_dataset()
def pre_pipeline(self, results):
results = super().pre_pipeline(results)
results["ssi"] = [True]
results["valid_camera"] = [False]
return results
def get_mapper(self):
return {
"image_filename": 0,
"depth_filename": 1,
}
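
HRWSI marks its samples with `ssi` (and `valid_camera = False`), which presumably means the depth is only usable up to an unknown scale and shift. As a purely illustrative aside (not UniDepth's actual loss), scale-shift-invariant evaluation typically aligns a prediction to the ground truth with a per-image least-squares fit before computing errors:

```python
import torch


def align_scale_shift(pred: torch.Tensor, gt: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Solve min_{s,t} || s * pred + t - gt ||^2 over valid pixels, then apply it."""
    p = pred[mask]
    g = gt[mask]
    A = torch.stack([p, torch.ones_like(p)], dim=-1)          # (N, 2)
    sol = torch.linalg.lstsq(A, g.unsqueeze(-1)).solution.squeeze(-1)
    s, t = sol[0], sol[1]
    return s * pred + t


pred = torch.rand(1, 1, 4, 4)
gt = 2.0 * pred + 0.5
aligned = align_scale_shift(pred, gt, gt > 0)
print(torch.allclose(aligned, gt, atol=1e-5))  # True: scale/shift recovered
```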
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/hypersim.py
================================================
import json
import os
import h5py
import numpy as np
import torch
from unidepth.datasets.image_dataset import ImageDataset
from unidepth.datasets.utils import DatasetFromList
class HyperSim(ImageDataset):
min_depth = 0.01
max_depth = 50.0
depth_scale = 1000.0
test_split = "val.txt"
train_split = "train.txt"
intrisics_file = "intrinsics.json"
hdf5_paths = [f"hypersim/hypersim_{i}.hdf5" for i in range(8)]
def __init__(
self,
image_shape,
split_file,
test_mode,
benchmark=False,
augmentations_db={},
normalize=True,
resize_method="hard",
mini=1.0,
**kwargs,
):
super().__init__(
image_shape=image_shape,
split_file=split_file,
test_mode=test_mode,
benchmark=benchmark,
normalize=normalize,
augmentations_db=augmentations_db,
resize_method=resize_method,
mini=mini,
**kwargs,
)
self.test_mode = test_mode
self.load_dataset()
def load_dataset(self):
h5file = h5py.File(
os.path.join(self.data_root, self.hdf5_paths[0]),
"r",
libver="latest",
swmr=True,
)
txt_file = np.array(h5file[self.split_file])
txt_string = txt_file.tostring().decode("ascii").strip("\n")
intrinsics = np.array(h5file[self.intrisics_file]).tostring().decode("ascii")
intrinsics = json.loads(intrinsics)
# with open(os.path.join(os.environ["TMPDIR"], self.split_file), "w") as f:
# f.write(txt_string)
# with open(os.path.join(os.environ["TMPDIR"], self.intrisics_file), "w") as f:
# json.dump(intrinsics, f)
dataset = []
for line in txt_string.split("\n"):
image_filename, depth_filename, chunk_idx = line.strip().split(" ")
intrinsics_val = torch.tensor(
intrinsics[os.path.join(*image_filename.split("/")[:2])]
).squeeze()[:, :3]
sample = [image_filename, depth_filename, intrinsics_val, chunk_idx]
dataset.append(sample)
h5file.close()
if not self.test_mode:
dataset = self.chunk(dataset, chunk_dim=1, pct=self.mini)
if self.test_mode and not self.benchmark: # corresponds to 712 images
dataset = self.chunk(dataset, chunk_dim=1, pct=0.1)
self.dataset = DatasetFromList(dataset)
self.log_load_dataset()
def get_mapper(self):
return {
"image_filename": 0,
"depth_filename": 1,
"K": 2,
"chunk_idx": 3,
}
def pre_pipeline(self, results):
results = super().pre_pipeline(results)
results["dense"] = [True] * self.num_copies
results["synthetic"] = [True] * self.num_copies
results["quality"] = [0] * self.num_copies
return results
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/ibims.py
================================================
import json
import os
import h5py
import numpy as np
import torch
from unidepth.datasets.image_dataset import ImageDataset
from unidepth.datasets.sequence_dataset import SequenceDataset
from unidepth.datasets.utils import DatasetFromList
class IBims(ImageDataset):
min_depth = 0.005
max_depth = 25.0
depth_scale = 1000.0
train_split = "ibims_val.txt"
test_split = "ibims_val.txt"
intrisics_file = "ibims_intrinsics.json"
hdf5_paths = ["ibims.hdf5"]
def __init__(
self,
image_shape,
split_file,
test_mode,
crop=None,
benchmark=False,
augmentations_db={},
normalize=True,
resize_method="hard",
mini=1.0,
**kwargs,
):
super().__init__(
image_shape=image_shape,
split_file=split_file,
test_mode=test_mode,
benchmark=benchmark,
normalize=normalize,
augmentations_db=augmentations_db,
resize_method=resize_method,
mini=mini,
**kwargs,
)
self.test_mode = test_mode
self.crop = crop
self.load_dataset()
def load_dataset(self):
h5file = h5py.File(
os.path.join(self.data_root, self.hdf5_paths[0]),
"r",
libver="latest",
swmr=True,
)
txt_file = np.array(h5file[self.split_file])
txt_string = txt_file.tostring().decode("ascii")[:-1] # correct the -1
intrinsics = np.array(h5file[self.intrisics_file]).tostring().decode("ascii")
intrinsics = json.loads(intrinsics)
h5file.close()
dataset = []
for line in txt_string.split("\n"):
image_filename, depth_filename = line.strip().split(" ")
intrinsics_val = torch.tensor(intrinsics[image_filename]).squeeze()[:, :3]
sample = [image_filename, depth_filename, intrinsics_val]
dataset.append(sample)
self.dataset = DatasetFromList(dataset)
self.log_load_dataset()
def pre_pipeline(self, results):
results = super().pre_pipeline(results)
results["dense"] = [True] * self.num_copies
results["quality"] = [1] * self.num_copies
return results
class IBims_F(SequenceDataset):
min_depth = 0.01
max_depth = 25.0
depth_scale = 1000.0
test_split = "train.txt"
train_split = "train.txt"
sequences_file = "sequences.json"
hdf5_paths = ["IBims-F.hdf5"]
def __init__(
self,
image_shape: tuple[int, int],
split_file: str,
test_mode: bool,
normalize: bool,
augmentations_db: dict[str, float],
resize_method: str,
mini: float,
num_frames: int = 1,
benchmark: bool = False,
decode_fields: list[str] = ["image", "depth"],
inplace_fields: list[str] = ["camera_params", "cam2w"],
**kwargs,
) -> None:
super().__init__(
image_shape=image_shape,
split_file=split_file,
test_mode=test_mode,
benchmark=benchmark,
normalize=normalize,
augmentations_db=augmentations_db,
resize_method=resize_method,
mini=mini,
num_frames=num_frames,
decode_fields=(
decode_fields if not test_mode else [*decode_fields, "points"]
),
inplace_fields=inplace_fields,
**kwargs,
)
def pre_pipeline(self, results):
results = super().pre_pipeline(results)
results["dense"] = [True] * self.num_frames * self.num_copies
results["quality"] = [1] * self.num_frames * self.num_copies
return results
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/image_dataset.py
================================================
import io
import os
from time import time
from typing import Any, Dict, List, Tuple
import numpy as np
import tables
import torch
import torchvision
import torchvision.transforms.v2.functional as TF
from PIL import Image
from unidepth.datasets.base_dataset import BaseDataset
from unidepth.utils import is_main_process
from unidepth.utils.camera import BatchCamera, Pinhole
"""
Awful class for legacy reasons, we assume only pinhole cameras
And we "fake" sequences by setting sequence_fields to [(0, 0)] and cam2w as eye(4)
"""
class ImageDataset(BaseDataset):
def __init__(
self,
image_shape: Tuple[int, int],
split_file: str,
test_mode: bool,
normalize: bool,
augmentations_db: Dict[str, Any],
resize_method: str,
mini: float,
benchmark: bool = False,
**kwargs,
) -> None:
super().__init__(
image_shape=image_shape,
split_file=split_file,
test_mode=test_mode,
benchmark=benchmark,
normalize=normalize,
augmentations_db=augmentations_db,
resize_method=resize_method,
mini=mini,
**kwargs,
)
self.mapper = self.get_mapper()
def get_single_item(self, idx, sample=None, mapper=None):
sample = self.dataset[idx] if sample is None else sample
mapper = self.mapper if mapper is None else mapper
results = {
(0, 0): dict(
gt_fields=set(),
image_fields=set(),
mask_fields=set(),
camera_fields=set(),
)
}
results = self.pre_pipeline(results)
results["sequence_fields"] = [(0, 0)]
chunk_idx = (
int(sample[self.mapper["chunk_idx"]]) if "chunk_idx" in self.mapper else 0
)
h5_path = os.path.join(self.data_root, self.hdf5_paths[chunk_idx])
with tables.File(
h5_path,
mode="r",
libver="latest",
swmr=True,
) as h5file_chunk:
for key_mapper, idx_mapper in mapper.items():
if "image" not in key_mapper and "depth" not in key_mapper:
continue
value = sample[idx_mapper]
results[(0, 0)][key_mapper] = value
name = key_mapper.replace("_filename", "")
value_root = "/" + value
if "image" in key_mapper:
results[(0, 0)]["filename"] = value
file = h5file_chunk.get_node(value_root).read()
image = (
torchvision.io.decode_image(torch.from_numpy(file))
.to(torch.uint8)
.squeeze()
)
results[(0, 0)]["image_fields"].add(name)
results[(0, 0)][f"image_ori_shape"] = image.shape[-2:]
results[(0, 0)][name] = image[None, ...]
# collect camera information for the given image
name = name.replace("image_", "")
results[(0, 0)]["camera_fields"].update({"camera", "cam2w"})
K = self.get_intrinsics(idx, value)
if K is None:
K = torch.eye(3)
K[0, 0] = K[1, 1] = 0.7 * self.image_shape[1]
K[0, 2] = 0.5 * self.image_shape[1]
K[1, 2] = 0.5 * self.image_shape[0]
camera = Pinhole(K=K[None, ...].clone())
results[(0, 0)]["camera"] = BatchCamera.from_camera(camera)
results[(0, 0)]["cam2w"] = self.get_extrinsics(idx, value)[
None, ...
]
elif "depth" in key_mapper:
# start = time()
file = h5file_chunk.get_node(value_root).read()
depth = Image.open(io.BytesIO(file))
depth = TF.pil_to_tensor(depth).squeeze().to(torch.float32)
if depth.ndim == 3:
depth = depth[2] + depth[1] * 255 + depth[0] * 255 * 255
results[(0, 0)]["gt_fields"].add(name)
results[(0, 0)][f"depth_ori_shape"] = depth.shape
depth = (
depth.view(1, 1, *depth.shape).contiguous() / self.depth_scale
)
results[(0, 0)][name] = depth
results = self.preprocess(results)
if not self.test_mode:
results = self.augment(results)
results = self.postprocess(results)
return results
def preprocess(self, results):
results = self.replicate(results)
for i, seq in enumerate(results["sequence_fields"]):
self.resizer.ctx = None
results[seq] = self.resizer(results[seq])
num_pts = torch.count_nonzero(results[seq]["depth"] > 0)
if num_pts < 50:
raise IndexError(f"Too few points in depth map ({num_pts})")
for key in results[seq].get("image_fields", ["image"]):
results[seq][key] = results[seq][key].to(torch.float32) / 255
# update fields common in sequence
for key in ["image_fields", "gt_fields", "mask_fields", "camera_fields"]:
if key in results[(0, 0)]:
results[key] = results[(0, 0)][key]
results = self.pack_batch(results)
return results
def postprocess(self, results):
# normalize after because color aug requires [0,255]?
for key in results.get("image_fields", ["image"]):
results[key] = TF.normalize(results[key], **self.normalization_stats)
results = self.filler(results)
results = self.unpack_batch(results)
results = self.masker(results)
results = self.collecter(results)
return results
def __getitem__(self, idx):
try:
if isinstance(idx, (list, tuple)):
results = [self.get_single_item(i) for i in idx]
else:
results = self.get_single_item(idx)
except Exception as e:
print(f"Error loading sequence {idx} for {self.__class__.__name__}: {e}")
idx = np.random.randint(0, len(self.dataset))
results = self[idx]
return results
def get_intrinsics(self, idx, image_name):
idx_sample = self.mapper.get("K", 1000)
sample = self.dataset[idx]
if idx_sample >= len(sample):
return None
return sample[idx_sample]
def get_extrinsics(self, idx, image_name):
idx_sample = self.mapper.get("cam2w", 1000)
sample = self.dataset[idx]
if idx_sample >= len(sample):
return torch.eye(4)
return sample[idx_sample]
def get_mapper(self):
return {
"image_filename": 0,
"depth_filename": 1,
"K": 2,
}
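
Concrete subclasses of `ImageDataset` (e.g. FLSea above, KITTI below) customize a few hooks: `load_dataset` builds the sample list, `get_mapper` declares which column holds which field, and `get_intrinsics`/`pre_pipeline` supply camera parameters and metadata flags. The following is a minimal hypothetical subclass in that style; the class name, file names, and intrinsics are made up, so treat it as a sketch rather than part of UniDepth.

```python
import torch

from unidepth.datasets.image_dataset import ImageDataset
from unidepth.datasets.utils import DatasetFromList


class MyToy(ImageDataset):
    min_depth = 0.05
    max_depth = 50.0
    depth_scale = 256.0
    train_split = "train.txt"    # text dataset stored inside the HDF5 file
    hdf5_paths = ["MyToy.hdf5"]  # resolved relative to self.data_root

    FIXED_K = torch.tensor(
        [[500.0, 0.0, 320.0], [0.0, 500.0, 240.0], [0.0, 0.0, 1.0]]
    )

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.load_dataset()

    def load_dataset(self):
        # The real loaders parse this list from the split file in HDF5;
        # it is hard-coded here to keep the sketch self-contained.
        dataset = [["images/000000.png", "depths/000000.png"]]
        self.dataset = DatasetFromList(dataset)
        self.log_load_dataset()

    def get_mapper(self):
        return {"image_filename": 0, "depth_filename": 1}

    def get_intrinsics(self, idx, image_name):
        return self.FIXED_K.clone()

    def pre_pipeline(self, results):
        results = super().pre_pipeline(results)
        results["dense"] = [False] * self.num_copies
        results["quality"] = [1] * self.num_copies
        return results
```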
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/ken_burns.py
================================================
import json
import os
import h5py
import numpy as np
import torch
from unidepth.datasets.image_dataset import ImageDataset
from unidepth.datasets.utils import DatasetFromList
class KenBurns(ImageDataset):
min_depth = 0.05
max_depth = 50.0
depth_scale = 256.0
test_split = "val.txt"
train_split = "train.txt"
intrisics_file = "intrinsics.json"
hdf5_paths = [f"3dkenburns/3DKenBurns_{i}.hdf5" for i in range(8)]
def __init__(
self,
image_shape,
split_file,
test_mode,
benchmark=False,
augmentations_db={},
normalize=True,
resize_method="hard",
mini=1.0,
**kwargs,
):
super().__init__(
image_shape=image_shape,
split_file=split_file,
test_mode=test_mode,
benchmark=benchmark,
normalize=normalize,
augmentations_db=augmentations_db,
resize_method=resize_method,
mini=mini,
**kwargs,
)
self.test_mode = test_mode
self.load_dataset()
def load_dataset(self):
h5file = h5py.File(
os.path.join(self.data_root, self.hdf5_paths[0]),
"r",
libver="latest",
swmr=True,
)
txt_file = np.array(h5file[self.split_file])
txt_string = txt_file.tostring().decode("ascii").strip("\n")
intrinsics = np.array(h5file[self.intrisics_file]).tostring().decode("ascii")
intrinsics = json.loads(intrinsics)
# with open(os.path.join(os.environ["TMPDIR"], self.split_file), "w") as f:
# f.write(txt_string)
# with open(os.path.join(os.environ["TMPDIR"], self.intrisics_file), "w") as f:
# json.dump(intrinsics, f)
dataset = []
for line in txt_string.split("\n"):
image_filename, depth_filename, chunk_idx = line.strip().split(" ")
intrinsics_val = torch.tensor(intrinsics[image_filename]).squeeze()[:, :3]
sample = [image_filename, depth_filename, intrinsics_val, chunk_idx]
dataset.append(sample)
h5file.close()
if not self.test_mode:
dataset = self.chunk(dataset, chunk_dim=1, pct=self.mini)
if self.test_mode and not self.benchmark: # corresponds to 500 images
dataset = self.chunk(dataset, chunk_dim=1, pct=0.25)
self.dataset = DatasetFromList(dataset)
self.log_load_dataset()
def get_mapper(self):
return {
"image_filename": 0,
"depth_filename": 1,
"K": 2,
"chunk_idx": 3,
}
def pre_pipeline(self, results):
results = super().pre_pipeline(results)
results["dense"] = [True] * self.num_copies
results["synthetic"] = [True] * self.num_copies
results["quality"] = [0] * self.num_copies
return results
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/kitti.py
================================================
import json
import os
import h5py
import numpy as np
import torch
from unidepth.datasets.image_dataset import ImageDataset
from unidepth.datasets.pipelines import AnnotationMask, Collect, KittiCrop  # Collect is used by KITTIBenchmark below
from unidepth.datasets.utils import DatasetFromList
from unidepth.utils import identity
class KITTI(ImageDataset):
CAM_INTRINSIC = {
"2011_09_26": torch.tensor(
[
[7.215377e02, 0.000000e00, 6.095593e02, 4.485728e01],
[0.000000e00, 7.215377e02, 1.728540e02, 2.163791e-01],
[0.000000e00, 0.000000e00, 1.000000e00, 2.745884e-03],
]
),
"2011_09_28": torch.tensor(
[
[7.070493e02, 0.000000e00, 6.040814e02, 4.575831e01],
[0.000000e00, 7.070493e02, 1.805066e02, -3.454157e-01],
[0.000000e00, 0.000000e00, 1.000000e00, 4.981016e-03],
]
),
"2011_09_29": torch.tensor(
[
[7.183351e02, 0.000000e00, 6.003891e02, 4.450382e01],
[0.000000e00, 7.183351e02, 1.815122e02, -5.951107e-01],
[0.000000e00, 0.000000e00, 1.000000e00, 2.616315e-03],
]
),
"2011_09_30": torch.tensor(
[
[7.070912e02, 0.000000e00, 6.018873e02, 4.688783e01],
[0.000000e00, 7.070912e02, 1.831104e02, 1.178601e-01],
[0.000000e00, 0.000000e00, 1.000000e00, 6.203223e-03],
]
),
"2011_10_03": torch.tensor(
[
[7.188560e02, 0.000000e00, 6.071928e02, 4.538225e01],
[0.000000e00, 7.188560e02, 1.852157e02, -1.130887e-01],
[0.000000e00, 0.000000e00, 1.000000e00, 3.779761e-03],
]
),
}
min_depth = 0.05
max_depth = 80.0
depth_scale = 256.0
log_mean = 2.5462
log_std = 0.5871
test_split = "kitti_eigen_test.txt"
train_split = "kitti_eigen_train.txt"
test_split_benchmark = "kitti_test.txt"
hdf5_paths = ["kitti.hdf5"]
def __init__(
self,
image_shape,
split_file,
test_mode,
crop=None,
benchmark=False,
augmentations_db={},
normalize=True,
resize_method="hard",
mini=1.0,
**kwargs,
):
super().__init__(
image_shape=image_shape,
split_file=split_file,
test_mode=test_mode,
benchmark=benchmark,
normalize=normalize,
augmentations_db=augmentations_db,
resize_method=resize_method,
mini=mini,
**kwargs,
)
self.masker = AnnotationMask(
min_value=0.0,
max_value=self.max_depth if test_mode else None,
custom_fn=self.eval_mask if test_mode else lambda x, *args, **kwargs: x,
)
self.test_mode = test_mode
self.crop = crop
self.cropper_base = KittiCrop(crop_size=(352, 1216))
self.load_dataset()
def load_dataset(self):
h5file = h5py.File(
os.path.join(self.data_root, self.hdf5_paths[0]),
"r",
libver="latest",
swmr=True,
)
txt_file = np.array(h5file[self.split_file])
txt_string = txt_file.tostring().decode("ascii")[:-1] # correct the -1
h5file.close()
dataset = []
for line in txt_string.split("\n"):
image_filename = line.strip().split(" ")[0]
depth_filename = line.strip().split(" ")[1]
if depth_filename == "None":
self.invalid_depth_num += 1
continue
sample = [
image_filename,
depth_filename,
]
dataset.append(sample)
if not self.test_mode:
dataset = self.chunk(dataset, chunk_dim=1, pct=self.mini)
self.dataset = DatasetFromList(dataset)
self.log_load_dataset()
def get_intrinsics(self, idx, image_name):
return self.CAM_INTRINSIC[image_name.split("/")[0]][:, :3].clone()
def preprocess(self, results):
results = self.replicate(results)
for i, seq in enumerate(results["sequence_fields"]):
self.resizer.ctx = None
results[seq] = self.cropper_base(results[seq])
results[seq] = self.resizer(results[seq])
num_pts = torch.count_nonzero(results[seq]["depth"] > 0)
if num_pts < 50:
raise IndexError(f"Too few points in depth map ({num_pts})")
for key in results[seq].get("image_fields", ["image"]):
results[seq][key] = results[seq][key].to(torch.float32) / 255
# update fields common in sequence
for key in ["image_fields", "gt_fields", "mask_fields", "camera_fields"]:
if key in results[(0, 0)]:
results[key] = results[(0, 0)][key]
results = self.pack_batch(results)
return results
def eval_mask(self, valid_mask, info={}):
"""Do grag_crop or eigen_crop for testing"""
mask_height, mask_width = valid_mask.shape[-2:]
eval_mask = torch.zeros_like(valid_mask)
if "garg" in self.crop:
eval_mask[
...,
int(0.40810811 * mask_height) : int(0.99189189 * mask_height),
int(0.03594771 * mask_width) : int(0.96405229 * mask_width),
] = 1
elif "eigen" in self.crop:
eval_mask[
...,
int(0.3324324 * mask_height) : int(0.91351351 * mask_height),
int(0.03594771 * mask_width) : int(0.96405229 * mask_width),
] = 1
return torch.logical_and(valid_mask, eval_mask)
def get_mapper(self):
return {
"image_filename": 0,
"depth_filename": 1,
}
def pre_pipeline(self, results):
results = super().pre_pipeline(results)
results["dense"] = [False] * self.num_copies
results["quality"] = [1] * self.num_copies
return results
class KITTIBenchmark(ImageDataset):
min_depth = 0.05
max_depth = 80.0
depth_scale = 256.0
test_split = "test_split.txt"
train_split = "val_split.txt"
intrinsics_file = "intrinsics.json"
hdf5_paths = ["kitti_benchmark.hdf5"]
def __init__(
self,
image_shape,
split_file,
test_mode,
crop=None,
benchmark=False,
augmentations_db={},
normalize=True,
resize_method="hard",
mini=1.0,
**kwargs,
):
super().__init__(
image_shape=image_shape,
split_file=split_file,
test_mode=test_mode,
benchmark=True,
normalize=normalize,
augmentations_db=augmentations_db,
resize_method=resize_method,
mini=mini,
**kwargs,
)
self.test_mode = test_mode
self.crop = crop
self.masker = AnnotationMask(
min_value=self.min_depth,
max_value=self.max_depth if test_mode else None,
custom_fn=lambda x, *args, **kwargs: x,
)
self.collecter = Collect(keys=["image_fields", "mask_fields", "gt_fields"])
self.load_dataset()
    def load_dataset(self):
        h5file = h5py.File(
            os.path.join(self.data_root, self.hdf5_paths[0]),
            "r",
            libver="latest",
            swmr=True,
        )
        txt_file = np.array(h5file[self.split_file])
        txt_string = txt_file.tostring().decode("ascii")[:-1]  # correct the -1
        intrinsics = np.array(h5file[self.intrinsics_file]).tostring().decode("ascii")
        intrinsics = json.loads(intrinsics)
        h5file.close()
        dataset = []
        for line in txt_string.split("\n"):
            image_filename, depth_filename = line.strip().split(" ")
            # keep the intrinsics dict intact; store the per-image tensor separately
            intrinsics_val = torch.tensor(
                intrinsics[os.path.join(*image_filename.split("/")[:2])]
            ).squeeze()[:, :3]
            # list layout matches the default get_mapper(): image, depth, K
            sample = [image_filename, depth_filename, intrinsics_val]
            dataset.append(sample)
        self.dataset = DatasetFromList(dataset)
        self.log_load_dataset()
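
`KITTI.eval_mask` restricts evaluation to the standard Garg or Eigen crop, expressed as fixed fractions of the image height and width. Below is the same Garg crop as a standalone function, using the fractions from the listing; the toy mask shape is only for illustration.

```python
import torch


def garg_crop_mask(valid_mask: torch.Tensor) -> torch.Tensor:
    """Keep only pixels inside the Garg evaluation crop."""
    h, w = valid_mask.shape[-2:]
    crop = torch.zeros_like(valid_mask)
    crop[..., int(0.40810811 * h) : int(0.99189189 * h),
         int(0.03594771 * w) : int(0.96405229 * w)] = 1
    return torch.logical_and(valid_mask, crop)


mask = torch.ones(1, 1, 352, 1216, dtype=torch.bool)
print(garg_crop_mask(mask).float().mean())  # fraction of pixels kept by the crop
```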
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/kitti360.py
================================================
from typing import Any
import torch
from unidepth.datasets.sequence_dataset import SequenceDataset
class KITTI360(SequenceDataset):
min_depth = 0.01
max_depth = 80.0
depth_scale = 256.0
train_split = "train.txt"
test_split = "val_split.txt"
sequences_file = "sequences_split.json"
hdf5_paths = [f"KITTI360.hdf5"]
def __init__(
self,
image_shape: tuple[int, int],
split_file: str,
test_mode: bool,
normalize: bool,
augmentations_db: dict[str, Any],
resize_method: str,
mini: float = 1.0,
num_frames: int = 1,
benchmark: bool = False,
decode_fields: list[str] = ["image", "depth"],
inplace_fields: list[str] = ["camera_params", "cam2w"],
**kwargs,
) -> None:
super().__init__(
image_shape=image_shape,
split_file=split_file,
test_mode=test_mode,
benchmark=benchmark,
normalize=normalize,
augmentations_db=augmentations_db,
resize_method=resize_method,
mini=mini,
num_frames=num_frames,
decode_fields=(
decode_fields if not test_mode else [*decode_fields, "points"]
),
inplace_fields=inplace_fields,
**kwargs,
)
def preprocess(self, results):
self.resizer.ctx = None
for i, seq in enumerate(results["sequence_fields"]):
# Create a mask where the distance from the center is less than H/2
H, W = results[seq]["image"].shape[-2:]
x = torch.linspace(-W / 2, W / 2, W)
y = torch.linspace(-H / 2, H / 2, H)
xv, yv = torch.meshgrid(x, y, indexing="xy")
distance_from_center = torch.sqrt(xv**2 + yv**2).reshape(1, 1, H, W)
results[seq]["validity_mask"] = distance_from_center < (H / 2)
return super().preprocess(results)
def pre_pipeline(self, results):
results = super().pre_pipeline(results)
results["dense"] = [False] * self.num_frames * self.num_copies
results["quality"] = [1] * self.num_frames * self.num_copies
return results
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/kitti_multi.py
================================================
import json
import os
from typing import Any
import h5py
import numpy as np
import torch
from unidepth.datasets.image_dataset import ImageDataset
from unidepth.datasets.pipelines import AnnotationMask, KittiCrop
from unidepth.datasets.sequence_dataset import SequenceDataset
from unidepth.datasets.utils import DatasetFromList
from unidepth.utils import identity
class KITTIMulti(SequenceDataset):
min_depth = 0.05
max_depth = 80.0
depth_scale = 256.0
default_fps = 10.0
test_split = "val.txt"
train_split = "train.txt"
sequences_file = "sequences.json"
hdf5_paths = ["KITTI_sequence.hdf5"]
def __init__(
self,
image_shape,
split_file,
test_mode,
crop=None,
augmentations_db={},
normalize=True,
resize_method="hard",
mini: float = 1.0,
num_frames: int = 1,
benchmark: bool = False,
decode_fields: list[str] = ["image", "depth"],
inplace_fields: list[str] = ["K", "cam2w"],
**kwargs,
):
super().__init__(
image_shape=image_shape,
split_file=split_file,
test_mode=test_mode,
benchmark=benchmark,
normalize=normalize,
augmentations_db=augmentations_db,
resize_method=resize_method,
mini=mini,
num_frames=num_frames,
decode_fields=decode_fields,
inplace_fields=inplace_fields,
**kwargs,
)
self.test_mode = test_mode
self.crop = crop
self.cropper_base = KittiCrop(crop_size=(352, 1216))
self.masker = AnnotationMask(
min_value=0.0,
max_value=self.max_depth if test_mode else None,
custom_fn=self.eval_mask if test_mode else identity,
)
self.eval_last = True
def __len__(self):
if self.test_mode:
return 64 # FIXME: Hardcoded for now
return len(self.dataset)
def preprocess(self, results):
self.resizer.ctx = None
for i, seq in enumerate(results["sequence_fields"]):
results[seq] = self.cropper_base(results[seq])
results[seq] = self.resizer(results[seq])
for key in results[seq].get("image_fields", ["image"]):
results[seq][key] = results[seq][key].to(torch.float32) / 255
results.update({k: v for k, v in results[(0, 0)].items() if "fields" in k})
results = self.pack_batch(results)
return results
def eval_mask(self, valid_mask, info={}):
"""Do grag_crop or eigen_crop for testing"""
mask_height, mask_width = valid_mask.shape[-2:]
eval_mask = torch.zeros_like(valid_mask)
if "garg" in self.crop:
eval_mask[
...,
int(0.40810811 * mask_height) : int(0.99189189 * mask_height),
int(0.03594771 * mask_width) : int(0.96405229 * mask_width),
] = 1
elif "eigen" in self.crop:
eval_mask[
...,
int(0.3324324 * mask_height) : int(0.91351351 * mask_height),
int(0.03594771 * mask_width) : int(0.96405229 * mask_width),
] = 1
else:
return valid_mask
return torch.logical_and(valid_mask, eval_mask)
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/kitti_rmvd.py
================================================
import json
import os
from typing import Any
import h5py
import numpy as np
import torch
from unidepth.datasets.pipelines import AnnotationMask, Compose, KittiCrop
from unidepth.datasets.sequence_dataset import SequenceDataset
from unidepth.utils import identity
class KITTIRMVD(SequenceDataset):
min_depth = 0.05
max_depth = 80.0
depth_scale = 256.0
default_fps = 10
test_split = "test.txt"
train_split = "test.txt"
sequences_file = "sequences.json"
hdf5_paths = ["kitti_rmvd.hdf5"]
def __init__(
self,
image_shape,
split_file,
test_mode,
crop=None,
augmentations_db={},
normalize=True,
resize_method="hard",
mini: float = 1.0,
num_frames: int = 1,
benchmark: bool = False,
decode_fields: list[str] = ["image", "depth"],
inplace_fields: list[str] = ["K", "cam2w"],
**kwargs,
):
super().__init__(
image_shape=image_shape,
split_file=split_file,
test_mode=test_mode,
benchmark=benchmark,
normalize=normalize,
augmentations_db=augmentations_db,
resize_method=resize_method,
mini=mini,
num_frames=num_frames,
decode_fields=decode_fields,
inplace_fields=inplace_fields,
**kwargs,
)
self.crop = crop
self.resizer = Compose([KittiCrop(crop_size=(352, 1216)), self.resizer])
def eval_mask(self, valid_mask, info={}):
"""Do grag_crop or eigen_crop for testing"""
mask_height, mask_width = valid_mask.shape[-2:]
eval_mask = torch.zeros_like(valid_mask)
if "garg" in self.crop:
eval_mask[
...,
int(0.40810811 * mask_height) : int(0.99189189 * mask_height),
int(0.03594771 * mask_width) : int(0.96405229 * mask_width),
] = 1
elif "eigen" in self.crop:
eval_mask[
...,
int(0.3324324 * mask_height) : int(0.91351351 * mask_height),
int(0.03594771 * mask_width) : int(0.96405229 * mask_width),
] = 1
else:
return valid_mask
return torch.logical_and(valid_mask, eval_mask)
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/lyft.py
================================================
import json
import os
import h5py
import numpy as np
import torch
from unidepth.datasets.image_dataset import ImageDataset
from unidepth.datasets.utils import DatasetFromList
class Lyft(ImageDataset):
min_depth = 0.05
max_depth = 80.0
depth_scale = 256.0
test_split = "test.txt"
train_split = "train.txt"
intrisics_file = "intrinsics.json"
hdf5_paths = ["Lyft2.hdf5"]
def __init__(
self,
image_shape,
split_file,
test_mode,
benchmark=False,
augmentations_db={},
normalize=True,
resize_method="hard",
mini=1.0,
**kwargs,
):
super().__init__(
image_shape=image_shape,
split_file=split_file,
test_mode=test_mode,
benchmark=benchmark,
normalize=normalize,
augmentations_db=augmentations_db,
resize_method=resize_method,
mini=mini,
**kwargs,
)
self.test_mode = test_mode
self.load_dataset()
def load_dataset(self):
h5file = h5py.File(
os.path.join(self.data_root, self.hdf5_paths[0]),
"r",
libver="latest",
swmr=True,
)
txt_file = np.array(h5file[self.split_file])
txt_string = txt_file.tostring().decode("ascii") # [:-1] # correct the -1
intrinsics = np.array(h5file[self.intrisics_file]).tostring().decode("ascii")
intrinsics = json.loads(intrinsics)
# with open(os.path.join(os.environ["TMPDIR"], self.split_file), "w") as f:
# f.write(txt_string)
# with open(os.path.join(os.environ["TMPDIR"], self.intrisics_file), "w") as f:
# json.dump(intrinsics, f)
dataset = []
for line in txt_string.split("\n"):
image_filename, depth_filename = line.strip().split(" ")
intrinsics_val = torch.tensor(intrinsics[image_filename]).squeeze()[:, :3]
sample = [
image_filename,
depth_filename,
intrinsics_val,
]
dataset.append(sample)
h5file.close()
if not self.test_mode:
dataset = self.chunk(dataset, chunk_dim=1, pct=self.mini)
self.dataset = DatasetFromList(dataset)
self.log_load_dataset()
def pre_pipeline(self, results):
results = super().pre_pipeline(results)
results["dense"] = [False]
return results
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/mapillary.py
================================================
import json
import os
import h5py
import numpy as np
import torch
from unidepth.datasets.image_dataset import ImageDataset
from unidepth.datasets.utils import DatasetFromList
class Mapillary(ImageDataset):
min_depth = 0.01
max_depth = 70.0
depth_scale = 256.0
test_split = "mapillary_val.txt"
train_split = "mapillary_train_clean.txt"
intrisics_file = "intrinsics.json"
hdf5_paths = ["Mapillary.hdf5"]
def __init__(
self,
image_shape,
split_file,
test_mode,
crop=None,
benchmark=False,
augmentations_db={},
normalize=True,
resize_method="hard",
mini=1.0,
**kwargs,
):
super().__init__(
image_shape=image_shape,
split_file=split_file,
test_mode=test_mode,
benchmark=benchmark,
normalize=normalize,
augmentations_db=augmentations_db,
resize_method=resize_method,
mini=mini,
**kwargs,
)
self.test_mode = test_mode
self.crop = crop
self.load_dataset()
def load_dataset(self):
h5file = h5py.File(
os.path.join(self.data_root, self.hdf5_paths[0]),
"r",
libver="latest",
swmr=True,
)
txt_file = np.array(h5file[self.split_file])
txt_string = txt_file.tostring().decode("ascii") # [:-1] # correct the -1
intrinsics = np.array(h5file[self.intrisics_file]).tostring().decode("ascii")
intrinsics = json.loads(intrinsics)
dataset = []
for line in txt_string.split("\n"):
image_filename, depth_filename = line.strip().split(" ")
intrinsics_val = torch.tensor(intrinsics[image_filename]).squeeze()[:, :3]
sample = [image_filename, depth_filename, intrinsics_val]
dataset.append(sample)
h5file.close()
if not self.test_mode:
dataset = self.chunk(dataset, chunk_dim=1, pct=self.mini)
if self.test_mode and not self.benchmark:
dataset = self.chunk(dataset, chunk_dim=1, pct=0.05)
self.dataset = DatasetFromList(dataset)
self.log_load_dataset()
def pre_pipeline(self, results):
results = super().pre_pipeline(results)
results["si"] = [True] * self.num_copies
results["valid_camera"] = [False] * self.num_copies
results["dense"] = [False] * self.num_copies
results["quality"] = [2] * self.num_copies
return results
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/matrix_city.py
================================================
from typing import Any
from unidepth.datasets.sequence_dataset import SequenceDataset
class MatrixCity(SequenceDataset):
min_depth = 0.01
max_depth = 200.0
depth_scale = 1000.0
test_split = "test.txt"
train_split = "train_full.txt"
sequences_file = "sequences.json"
hdf5_paths = [f"MatrixCity.hdf5"]
def __init__(
self,
image_shape: tuple[int, int],
split_file: str,
test_mode: bool,
normalize: bool,
augmentations_db: dict[str, Any],
resize_method: str,
mini: float = 1.0,
num_frames: int = 1,
benchmark: bool = False,
decode_fields: list[str] = ["image", "depth"],
inplace_fields: list[str] = ["K", "cam2w"],
**kwargs,
) -> None:
super().__init__(
image_shape=image_shape,
split_file=split_file,
test_mode=test_mode,
benchmark=benchmark,
normalize=normalize,
augmentations_db=augmentations_db,
resize_method=resize_method,
mini=mini,
num_frames=num_frames,
decode_fields=decode_fields,
inplace_fields=inplace_fields,
**kwargs,
)
def pre_pipeline(self, results):
results = super().pre_pipeline(results)
results["dense"] = [True] * self.num_frames * self.num_copies
results["synthetic"] = [True] * self.num_frames * self.num_copies
results["quality"] = [0] * self.num_frames * self.num_copies
return results
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/matterport3d.py
================================================
from typing import Any
import torch
from unidepth.datasets.pipelines import Compose, PanoCrop, PanoRoll
from unidepth.datasets.sequence_dataset import SequenceDataset
class Matterport3D(SequenceDataset):
min_depth = 0.01
max_depth = 10.0
depth_scale = 1000.0
test_split = "train.txt"
train_split = "train.txt"
sequences_file = "sequences.json"
hdf5_paths = [f"Matterport3D.hdf5"]
def __init__(
self,
image_shape: tuple[int, int],
split_file: str,
test_mode: bool,
normalize: bool,
augmentations_db: dict[str, Any],
resize_method: str,
mini: float = 1.0,
num_frames: int = 1,
benchmark: bool = False,
decode_fields: list[str] = ["image", "depth"],
inplace_fields: list[str] = ["cam2w", "camera_params"],
**kwargs,
) -> None:
super().__init__(
image_shape=image_shape,
split_file=split_file,
test_mode=test_mode,
benchmark=benchmark,
normalize=normalize,
augmentations_db=augmentations_db,
resize_method=resize_method,
mini=mini,
num_frames=num_frames,
decode_fields=decode_fields,
inplace_fields=inplace_fields,
**kwargs,
)
self.resizer = Compose(
[PanoCrop(), PanoRoll(test_mode=test_mode), self.resizer]
)
def pre_pipeline(self, results):
results = super().pre_pipeline(results)
results["dense"] = [True] * self.num_frames * self.num_copies
results["synthetic"] = [True] * self.num_frames * self.num_copies
results["quality"] = [1] * self.num_frames * self.num_copies
return results
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/megadepth.py
================================================
import os
import h5py
import numpy as np
from unidepth.datasets.image_dataset import ImageDataset
from unidepth.datasets.utils import DatasetFromList
class MegaDepth(ImageDataset):
min_depth = 0.01
max_depth = 1000.0
depth_scale = 50.0
test_split = "test.txt"
train_split = "train.txt"
hdf5_paths = ["MegaDepth.hdf5"]
def __init__(
self,
image_shape,
split_file,
test_mode,
benchmark=False,
augmentations_db={},
normalize=True,
resize_method="hard",
mini=1.0,
**kwargs,
):
super().__init__(
image_shape=image_shape,
split_file=split_file,
test_mode=test_mode,
benchmark=benchmark,
normalize=normalize,
augmentations_db=augmentations_db,
resize_method=resize_method,
mini=mini,
**kwargs,
)
self.test_mode = test_mode
self.load_dataset()
def load_dataset(self):
h5file = h5py.File(
os.path.join(self.data_root, self.hdf5_paths[0]),
"r",
libver="latest",
swmr=True,
)
txt_file = np.array(h5file[self.split_file])
        txt_string = txt_file.tobytes().decode("ascii")  # [:-1] # correct the -1
# with open(os.path.join(os.environ["TMPDIR"], self.split_file), "w") as f:
# f.write(txt_string)
dataset = []
        for line in txt_string.split("\n"):
            if not line.strip():
                continue  # skip empty trailing lines produced by the split
            image_filename, depth_filename = line.strip().split(" ")
sample = [
image_filename,
depth_filename,
]
dataset.append(sample)
h5file.close()
if not self.test_mode:
dataset = self.chunk(dataset, chunk_dim=1, pct=self.mini)
else:
dataset = self.chunk(dataset, chunk_dim=1, pct=0.5)
self.dataset = DatasetFromList(dataset)
self.log_load_dataset()
def pre_pipeline(self, results):
results = super().pre_pipeline(results)
results["ssi"] = [True]
results["valid_camera"] = [False]
results["dense"] = [False]
return results
def get_mapper(self):
return {
"image_filename": 0,
"depth_filename": 1,
}
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/megadepth_s.py
================================================
from typing import Any
from unidepth.datasets.sequence_dataset import SequenceDataset
class MegaDepthS(SequenceDataset):
min_depth = 0.001
max_depth = 10000.0
depth_scale = 512.0
test_split = "train.txt"
train_split = "train.txt"
sequences_file = "sequences_filter_clean.json"
hdf5_paths = ["MegaDepthS.hdf5"]
def __init__(
self,
image_shape: tuple[int, int],
split_file: str,
test_mode: bool,
normalize: bool,
augmentations_db: dict[str, Any],
resize_method: str,
mini: float = 1.0,
num_frames: int = 1,
benchmark: bool = False,
decode_fields: list[str] = ["image", "depth"],
inplace_fields: list[str] = ["intrinsics", "cam2w"],
**kwargs,
) -> None:
super().__init__(
image_shape=image_shape,
split_file=split_file,
test_mode=test_mode,
benchmark=benchmark,
normalize=normalize,
augmentations_db=augmentations_db,
resize_method=resize_method,
mini=mini,
num_frames=num_frames,
decode_fields=decode_fields,
inplace_fields=inplace_fields,
**kwargs,
)
def pre_pipeline(self, results):
results = super().pre_pipeline(results)
results["si"] = [True] * self.num_frames * self.num_copies
results["dense"] = [False] * self.num_frames * self.num_copies
results["quality"] = [2] * self.num_frames * self.num_copies
return results
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/midair.py
================================================
from typing import Any
from unidepth.datasets.sequence_dataset import SequenceDataset
class MidAir(SequenceDataset):
min_depth = 0.1
max_depth = 1000.0
depth_scale = 1000.0
default_fps = 6
test_split = "train.txt"
train_split = "train.txt"
sequences_file = "sequences.json"
hdf5_paths = ["MidAir.hdf5"]
def __init__(
self,
image_shape: tuple[int, int],
split_file: str,
test_mode: bool,
normalize: bool,
augmentations_db: dict[str, Any],
resize_method: str,
mini: float = 1.0,
num_frames: int = 1,
benchmark: bool = False,
decode_fields: list[str] = ["image", "depth"],
inplace_fields: list[str] = ["K", "cam2w"],
**kwargs,
) -> None:
super().__init__(
image_shape=image_shape,
split_file=split_file,
test_mode=test_mode,
benchmark=benchmark,
normalize=normalize,
augmentations_db=augmentations_db,
resize_method=resize_method,
mini=mini,
num_frames=num_frames,
decode_fields=decode_fields,
inplace_fields=inplace_fields,
**kwargs,
)
def pre_pipeline(self, results):
results = super().pre_pipeline(results)
results["dense"] = [True] * self.num_frames * self.num_copies
results["synthetic"] = [True] * self.num_frames * self.num_copies
results["quality"] = [0] * self.num_frames * self.num_copies
return results
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/mip.py
================================================
from typing import Any
from unidepth.datasets.sequence_dataset import SequenceDataset
class MIP(SequenceDataset):
min_depth = 0.01
max_depth = 100.0
depth_scale = 1000.0
default_fps = 10
test_split = "train.txt"
train_split = "train.txt"
sequences_file = "sequences.json"
hdf5_paths = ["MIP.hdf5"]
def __init__(
self,
image_shape: tuple[int, int],
split_file: str,
test_mode: bool,
normalize: bool,
augmentations_db: dict[str, Any],
resize_method: str,
mini: float = 1.0,
num_frames: int = 1,
benchmark: bool = False,
decode_fields: list[str] = ["image", "depth"],
inplace_fields: list[str] = ["K", "cam2w"],
**kwargs,
) -> None:
super().__init__(
image_shape=image_shape,
split_file=split_file,
test_mode=test_mode,
benchmark=benchmark,
normalize=normalize,
augmentations_db=augmentations_db,
resize_method=resize_method,
mini=mini,
num_frames=num_frames,
decode_fields=decode_fields,
inplace_fields=inplace_fields,
**kwargs,
)
def pre_pipeline(self, results):
results = super().pre_pipeline(results)
results["dense"] = [True] * self.num_frames * self.num_copies
results["synthetic"] = [False] * self.num_frames * self.num_copies
results["si"] = [True] * self.num_frames * self.num_copies
results["quality"] = [2] * self.num_frames * self.num_copies
return results
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/ms2.py
================================================
from typing import Any
from unidepth.datasets.sequence_dataset import SequenceDataset
class MS2(SequenceDataset):
min_depth = 0.01
max_depth = 100.0
depth_scale = 256.0
default_fps = 5
test_split = "train.txt"
train_split = "train.txt"
sequences_file = "sequences.json"
hdf5_paths = ["MS2.hdf5"]
def __init__(
self,
image_shape: tuple[int, int],
split_file: str,
test_mode: bool,
normalize: bool,
augmentations_db: dict[str, Any],
resize_method: str,
mini: float = 1.0,
num_frames: int = 1,
benchmark: bool = False,
decode_fields: list[str] = ["image", "depth"],
inplace_fields: list[str] = ["K", "cam2w"],
**kwargs,
) -> None:
super().__init__(
image_shape=image_shape,
split_file=split_file,
test_mode=test_mode,
benchmark=benchmark,
normalize=normalize,
augmentations_db=augmentations_db,
resize_method=resize_method,
mini=mini,
num_frames=num_frames,
decode_fields=decode_fields,
inplace_fields=inplace_fields,
**kwargs,
)
def pre_pipeline(self, results):
results = super().pre_pipeline(results)
results["dense"] = [False] * self.num_frames * self.num_copies
results["synthetic"] = [False] * self.num_frames * self.num_copies
results["quality"] = [1] * self.num_frames * self.num_copies
return results
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/mvimgnet.py
================================================
from typing import Any
from unidepth.datasets.sequence_dataset import SequenceDataset
INVALID_SEQUENCES = [
"1/000121f2-0",
"15/1600ae56-0",
"26/000000f3-0",
"33/1d00e677-0",
"43/22008925-0",
"49/000147db-0",
"51/23002a43-0",
"51/23000916-0",
"108/000133ae-0",
"129/000037f2-0",
"141/17012545-0",
"141/1700f3de-0",
"152/1b00e061-0",
"154/1d00decb-0",
"154/1d017c1c-0",
"154/1d0019a5-0",
"154/1d00334d-0",
"154/1d012ed6-0",
"154/1d016b8a-0",
"154/1d016cc1-0",
"154/1d008d5f-0",
"159/000157f9-0",
"159/00000b96-0",
"159/000075c0-0",
"159/0000445c-0",
"159/000056a0-0",
"159/00010c68-0",
"159/0000573b-0",
"159/00002698-0",
"159/00008fca-0",
"159/00009ef8-0",
"159/00015f05-0",
"159/0000c6df-0",
"159/0000ee59-0",
"163/290159d2-0",
"163/29016c7c-0",
"163/2900239c-0",
"163/29002f7b-0",
"163/29014b05-0",
"163/29000196-0",
"163/2901750f-0",
"164/1b0145cf-0",
"164/1b00eb1d-0",
"164/1b00c28b-0",
"164/1b0110d0-0",
"164/1b00dd20-0",
"165/2600e15a-0",
"165/26008444-0",
"165/260145c5-0",
"165/26003a0c-0",
"165/260106ba-0",
"165/26001548-0",
"167/2a0092b0-0",
"167/2a014dbe-0",
"167/2a003ce6-0",
"169/1800c645-0",
"171/2500014d-0",
"176/1d0021c2-0",
"176/1d014abf-0",
"176/1d00e714-0",
"176/1d0159cb-0",
"176/1e016629-0",
"178/000102b8-0",
"191/23008fdb-0",
"191/2300187f-0",
"191/2300ae68-0",
"191/230076dd-0",
"191/24007d7e-0",
"192/000107b5-0",
"195/1f012359-0",
"195/1f00f751-0",
"195/1f011331-0",
"195/1e00d999-0",
"196/1c01304e-0",
"198/1a00e02f-0",
"198/050084ac-0",
"198/1a0075fa-0",
"199/1e001742-0",
"199/1e00116a-0",
"199/1e011d00-0",
"199/1e018040-0",
"199/1e001107-0",
]
class MVImgNet(SequenceDataset):
min_depth = 0.005
max_depth = 10.0
    # weird scale issue: depth_scale should be 1000, but the average depth is ~10 meters...
depth_scale = 1000.0
test_split = "train.txt"
train_split = "train.txt"
sequences_file = "sequences.json"
hdf5_paths = ["MVImgNet.hdf5"]
invalid_sequences = INVALID_SEQUENCES
def __init__(
self,
image_shape: tuple[int, int],
split_file: str,
test_mode: bool,
normalize: bool,
augmentations_db: dict[str, Any],
resize_method: str,
mini: float = 1.0,
num_frames: int = 1,
benchmark: bool = False,
decode_fields: list[str] = ["image", "depth"],
inplace_fields: list[str] = ["intrinsics", "cam2w"],
**kwargs,
) -> None:
super().__init__(
image_shape=image_shape,
split_file=split_file,
test_mode=test_mode,
benchmark=benchmark,
normalize=normalize,
augmentations_db=augmentations_db,
resize_method=resize_method,
mini=mini,
num_frames=num_frames,
decode_fields=decode_fields,
inplace_fields=inplace_fields,
**kwargs,
)
def pre_pipeline(self, results):
results = super().pre_pipeline(results)
results["si"] = [True] * self.num_frames * self.num_copies
results["dense"] = [False] * self.num_frames * self.num_copies
results["quality"] = [2] * self.num_frames * self.num_copies
return results
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/mvsynth.py
================================================
from typing import Any
from unidepth.datasets.sequence_dataset import SequenceDataset
class MVSynth(SequenceDataset):
min_depth = 0.1
max_depth = 1000.0
depth_scale = 256.0
test_split = "val.txt"
train_split = "train.txt"
sequences_file = "sequences.json"
hdf5_paths = [f"MVSynth.hdf5"]
def __init__(
self,
image_shape: tuple[int, int],
split_file: str,
test_mode: bool,
normalize: bool,
augmentations_db: dict[str, Any],
resize_method: str,
mini: float = 1.0,
num_frames: int = 1,
benchmark: bool = False,
decode_fields: list[str] = ["image", "depth"],
inplace_fields: list[str] = ["K", "cam2w"],
**kwargs,
) -> None:
super().__init__(
image_shape=image_shape,
split_file=split_file,
test_mode=test_mode,
benchmark=benchmark,
normalize=normalize,
augmentations_db=augmentations_db,
resize_method=resize_method,
mini=mini,
num_frames=num_frames,
decode_fields=decode_fields,
inplace_fields=inplace_fields,
**kwargs,
)
def pre_pipeline(self, results):
results = super().pre_pipeline(results)
results["dense"] = [True] * self.num_frames * self.num_copies
results["synthetic"] = [True] * self.num_frames * self.num_copies
results["si"] = [True] * self.num_frames * self.num_copies
results["quality"] = [0] * self.num_frames * self.num_copies
return results
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/nerds360.py
================================================
from typing import Any
from unidepth.datasets.sequence_dataset import SequenceDataset
class NeRDS360(SequenceDataset):
min_depth = 0.01
max_depth = 1000.0
depth_scale = 1000.0
test_split = "val.txt"
train_split = "train.txt"
sequences_file = "sequences.json"
hdf5_paths = ["NeRDS360.hdf5"]
def __init__(
self,
image_shape: tuple[int, int],
split_file: str,
test_mode: bool,
normalize: bool,
augmentations_db: dict[str, Any],
resize_method: str,
mini: float = 1.0,
num_frames: int = 1,
benchmark: bool = False,
decode_fields: list[str] = ["image", "depth"],
inplace_fields: list[str] = ["K", "cam2w"],
**kwargs,
) -> None:
super().__init__(
image_shape=image_shape,
split_file=split_file,
test_mode=test_mode,
benchmark=benchmark,
normalize=normalize,
augmentations_db=augmentations_db,
resize_method=resize_method,
mini=mini,
num_frames=num_frames,
decode_fields=decode_fields,
inplace_fields=inplace_fields,
**kwargs,
)
def pre_pipeline(self, results):
results = super().pre_pipeline(results)
results["dense"] = [True] * self.num_frames * self.num_copies
results["quality"] = [1] * self.num_frames * self.num_copies
return results
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/niantic_mapfree.py
================================================
from typing import Any
from unidepth.datasets.sequence_dataset import SequenceDataset
class NianticMapFree(SequenceDataset):
min_depth = 0.1
max_depth = 250.0
depth_scale = 512.0
test_split = "train.txt"
train_split = "train.txt"
sequences_file = "sequences.json"
hdf5_paths = [f"NianticMapFree.hdf5"]
def __init__(
self,
image_shape: tuple[int, int],
split_file: str,
test_mode: bool,
normalize: bool,
augmentations_db: dict[str, Any],
resize_method: str,
mini: float = 1.0,
num_frames: int = 1,
benchmark: bool = False,
decode_fields: list[str] = ["image", "depth"],
inplace_fields: list[str] = ["K", "cam2w"],
**kwargs,
) -> None:
super().__init__(
image_shape=image_shape,
split_file=split_file,
test_mode=test_mode,
benchmark=benchmark,
normalize=normalize,
augmentations_db=augmentations_db,
resize_method=resize_method,
mini=mini,
num_frames=num_frames,
decode_fields=decode_fields,
inplace_fields=inplace_fields,
**kwargs,
)
def pre_pipeline(self, results):
results = super().pre_pipeline(results)
results["si"] = [True] * self.num_frames * self.num_copies
results["dense"] = [False] * self.num_frames * self.num_copies
results["quality"] = [2] * self.num_frames * self.num_copies
return results
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/nuscenes.py
================================================
import json
import os
import h5py
import numpy as np
import torch
from unidepth.datasets.image_dataset import ImageDataset
from unidepth.datasets.utils import DatasetFromList
class Nuscenes(ImageDataset):
min_depth = 0.05
max_depth = 80.0
depth_scale = 256.0
test_split = "val.txt"
train_split = "train.txt"
intrisics_file = "intrinsics.json"
# hdf5_paths = ["Nuscenes2.hdf5"]
hdf5_paths = [f"nuscenes/nuscenes_{i}.hdf5" for i in range(8)]
def __init__(
self,
image_shape,
split_file,
test_mode,
benchmark=False,
augmentations_db={},
normalize=True,
resize_method="hard",
mini=1.0,
**kwargs,
):
super().__init__(
image_shape=image_shape,
split_file=split_file,
test_mode=test_mode,
benchmark=benchmark,
normalize=normalize,
augmentations_db=augmentations_db,
resize_method=resize_method,
mini=mini,
**kwargs,
)
self.test_mode = test_mode
self.load_dataset()
def load_dataset(self):
h5file = h5py.File(
os.path.join(self.data_root, self.hdf5_paths[0]),
"r",
libver="latest",
swmr=True,
)
txt_file = np.array(h5file[self.split_file])
        txt_string = txt_file.tobytes().decode("ascii").strip("\n")
        intrinsics = np.array(h5file[self.intrisics_file]).tobytes().decode("ascii")
intrinsics = json.loads(intrinsics)
dataset = []
for line in txt_string.split("\n"):
image_filename, depth_filename, chunk_idx = line.strip().split(" ")
intrinsics_val = torch.tensor(intrinsics[image_filename]).squeeze()[:, :3]
sample = [image_filename, depth_filename, intrinsics_val, chunk_idx]
dataset.append(sample)
h5file.close()
if not self.test_mode:
dataset = self.chunk(dataset, chunk_dim=6, pct=self.mini)
if self.test_mode and not self.benchmark:
dataset = self.chunk(dataset, chunk_dim=6, pct=0.1)
self.dataset = DatasetFromList(dataset)
self.log_load_dataset()
def get_mapper(self):
return {
"image_filename": 0,
"depth_filename": 1,
"K": 2,
"chunk_idx": 3,
}
def pre_pipeline(self, results):
results = super().pre_pipeline(results)
results["dense"] = [False] * self.num_copies
results["quality"] = [1] * self.num_copies
return results
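# --- Illustrative note (not part of the upstream file): per-image intrinsics in Nuscenes ---
# Each entry of self.dataset is a plain list [image_filename, depth_filename, K, chunk_idx],
# and get_mapper() exposes the position of every field in that list. Assuming the base
# ImageDataset consumes samples through this mapper, the intrinsics decoded from the JSON
# blob stored in the HDF5 file would be recovered per image as, e.g.:
#
#   sample = self.dataset[idx]
#   K = sample[self.get_mapper()["K"]]   # per-image intrinsics tensor built in load_dataset()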
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/nyuv2.py
================================================
import os
import h5py
import numpy as np
import torch
from unidepth.datasets.image_dataset import ImageDataset
from unidepth.datasets.pipelines import AnnotationMask
from unidepth.datasets.utils import DatasetFromList
from unidepth.utils import identity
class NYUv2Depth(ImageDataset):
CAM_INTRINSIC = {
"ALL": torch.tensor(
[
[5.1885790117450188e02, 0, 3.2558244941119034e02],
[0, 5.1946961112127485e02, 2.5373616633400465e02],
[0, 0, 1],
]
)
}
min_depth = 0.005
max_depth = 10.0
depth_scale = 1000.0
log_mean = 0.9140
log_std = 0.4825
test_split = "nyu_test.txt"
train_split = "nyu_train.txt"
hdf5_paths = ["nyuv2.hdf5"]
def __init__(
self,
image_shape,
split_file,
test_mode,
crop=None,
benchmark=False,
augmentations_db={},
normalize=True,
resize_method="hard",
mini=1.0,
**kwargs,
):
super().__init__(
image_shape=image_shape,
split_file=split_file,
test_mode=test_mode,
benchmark=benchmark,
normalize=normalize,
augmentations_db=augmentations_db,
resize_method=resize_method,
mini=mini,
**kwargs,
)
self.masker = AnnotationMask(
min_value=0.0,
max_value=self.max_depth if test_mode else None,
custom_fn=self.eval_mask if test_mode else lambda x, *args, **kwargs: x,
)
self.test_mode = test_mode
self.load_dataset()
def load_dataset(self):
h5file = h5py.File(
os.path.join(self.data_root, self.hdf5_paths[0]),
"r",
libver="latest",
swmr=True,
)
txt_file = np.array(h5file[self.split_file])
        txt_string = txt_file.tobytes().decode("ascii")[:-1]  # correct the -1
h5file.close()
dataset = []
for line in txt_string.split("\n"):
image_filename, depth_filename, _ = line.strip().split(" ")
sample = [
image_filename,
depth_filename,
]
dataset.append(sample)
if not self.test_mode:
dataset = self.chunk(dataset, chunk_dim=1, pct=self.mini)
self.dataset = DatasetFromList(dataset)
self.log_load_dataset()
def get_intrinsics(self, idx, image_name):
return self.CAM_INTRINSIC["ALL"].clone()
    def eval_mask(self, valid_mask, info={}):
        # Standard NYUv2 evaluation crop (Eigen crop): ignore the unreliable image border.
        border_mask = torch.zeros_like(valid_mask)
        border_mask[..., 45:-9, 41:-39] = 1
return torch.logical_and(valid_mask, border_mask)
def get_mapper(self):
return {
"image_filename": 0,
"depth_filename": 1,
}
def pre_pipeline(self, results):
results = super().pre_pipeline(results)
results["dense"] = [True] * self.num_copies
results["quality"] = [2] * self.num_copies
return results
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/oasis.py
================================================
import os
import h5py
import numpy as np
import torch
from unidepth.datasets.image_dataset import ImageDataset
from unidepth.datasets.utils import DatasetFromList
class OASISv2(ImageDataset):
min_depth = 0.01
max_depth = 400.0
depth_scale = 1000.0
test_split = "val.txt"
train_split = "train.txt"
hdf5_paths = ["Oasis2.hdf5"]
def __init__(
self,
image_shape,
split_file,
test_mode,
crop=None,
benchmark=False,
augmentations_db={},
normalize=True,
resize_method="hard",
mini=1.0,
**kwargs,
):
super().__init__(
image_shape=image_shape,
split_file=split_file,
test_mode=test_mode,
benchmark=benchmark,
normalize=normalize,
augmentations_db=augmentations_db,
resize_method=resize_method,
mini=mini,
**kwargs,
)
self.test_mode = test_mode
self.load_dataset()
def load_dataset(self):
h5file = h5py.File(
os.path.join(self.data_root, self.hdf5_paths[0]),
"r",
libver="latest",
swmr=True,
)
txt_file = np.array(h5file[self.split_file])
txt_string = txt_file.tostring().decode("ascii") # [:-1] # correct the -1
dataset = []
# with open(os.path.join(os.environ["TMPDIR"], self.split_file), "w") as f:
# f.write(txt_string)
        for line in txt_string.split("\n"):
            if not line.strip():
                continue  # skip empty trailing lines produced by the split
            image_filename, depth_filename = line.strip().split(" ")
sample = [image_filename, depth_filename]
dataset.append(sample)
h5file.close()
if not self.test_mode:
dataset = self.chunk(dataset, chunk_dim=1, pct=self.mini)
self.dataset = DatasetFromList(dataset)
self.log_load_dataset()
def pre_pipeline(self, results):
results = super().pre_pipeline(results)
results["ssi"] = [True]
results["valid_camera"] = [False]
return results
def get_mapper(self):
return {
"image_filename": 0,
"depth_filename": 1,
}
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/pipelines/__init__.py
================================================
from .formating import AnnotationMask, Collect
from .transforms import (Compose, ContextCrop, Crop, GaussianBlur, KittiCrop,
PanoCrop, PanoRoll, RandomAutoContrast,
RandomBrightness, RandomColor, RandomColorJitter,
RandomContrast, RandomEqualize, RandomFiller,
RandomFlip, RandomGamma, RandomGrayscale,
RandomInvert, RandomMasking, RandomPosterize,
RandomSaturation, RandomSharpness, RandomShear,
RandomSolarize, RandomTranslate, Rotate)
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/pipelines/formating.py
================================================
from collections.abc import Sequence
import numpy as np
import torch
class Collect(object):
def __init__(
self,
keys,
meta_keys=(
"filename",
"keyframe_idx",
"sequence_name",
"image_filename",
"depth_filename",
"image_ori_shape",
"camera",
"original_camera",
"sfm",
"image_shape",
"resized_shape",
"scale_factor",
"rotation",
"resize_factor",
"flip",
"flip_direction",
"dataset_name",
"paddings",
"max_value",
"log_mean",
"log_std",
"image_rescale",
"focal_rescale",
"depth_rescale",
),
):
self.keys = keys
self.meta_keys = meta_keys
def __call__(self, results):
data_keys = [key for field in self.keys for key in results.get(field, [])]
data = {
key: {
sequence_key: results[key][sequence_key]
for sequence_key in results["sequence_fields"]
}
for key in data_keys
}
data["img_metas"] = {
key: value for key, value in results.items() if key not in data_keys
}
return data
def __repr__(self):
return (
self.__class__.__name__ + f"(keys={self.keys}, meta_keys={self.meta_keys})"
)
class AnnotationMask(object):
    def __init__(self, min_value, max_value, custom_fn=lambda x, *args, **kwargs: x):
self.min_value = min_value
self.max_value = max_value
self.custom_fn = custom_fn
def __call__(self, results):
for key in results.get("gt_fields", []):
if key + "_mask" in results["mask_fields"]:
if "flow" in key:
for sequence_idx in results.get("sequence_fields", []):
boundaries = (results[key][sequence_idx] >= -1) & (
results[key][sequence_idx] <= 1
)
boundaries = boundaries[:, :1] & boundaries[:, 1:]
results[key + "_mask"][sequence_idx] = (
results[key + "_mask"][sequence_idx] & boundaries
)
continue
for sequence_idx in results.get("sequence_fields", []):
mask = results[key][sequence_idx] > self.min_value
if self.max_value is not None:
mask = mask & (results[key][sequence_idx] < self.max_value)
mask = self.custom_fn(mask, info=results)
if key + "_mask" not in results:
results[key + "_mask"] = {}
results[key + "_mask"][sequence_idx] = mask.bool()
results["mask_fields"].add(key + "_mask")
return results
def __repr__(self):
return (
self.__class__.__name__
+ f"(min_value={self.min_value}, max_value={ self.max_value})"
)
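# --- Illustrative note (not part of the upstream file): how Collect and AnnotationMask cooperate ---
# A minimal sketch with hypothetical field names; it assumes a results dict produced by the
# loading pipeline, where per-frame tensors are stored under integer sequence indices:
#
#   results = {
#       "depth": {0: depth},                # (1, 1, H, W) tensor for sequence index 0
#       "gt_fields": {"depth"},
#       "mask_fields": {"depth_mask"},      # registering the mask name enables mask creation
#       "sequence_fields": [0],
#   }
#   results = AnnotationMask(min_value=0.0, max_value=80.0)(results)
#   #   -> results["depth_mask"][0] is True where 0.0 < depth < 80.0
#   batch = Collect(keys=["gt_fields"])(results)
#   #   -> batch["depth"][0] holds the tensor, batch["img_metas"] everything else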
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/pipelines/transforms.py
================================================
import os
import random
from copy import deepcopy
from math import ceil, exp, log, log2, log10, tanh
from typing import Dict, List, Tuple
import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms.v2.functional as TF
from unidepth.utils.geometric import downsample
class PanoCrop:
def __init__(self, crop_v=0.1):
self.crop_v = crop_v
def _crop_data(self, results, crop_size):
"""Function to randomly crop images, bounding boxes, masks, semantic
segmentation maps.
Args:
results (dict): Result dict from loading pipeline.
crop_size (tuple): Expected absolute size after cropping, (h, w).
allow_negative_crop (bool): Whether to allow a crop that does not
contain any bbox area. Default to False.
Returns:
dict: Randomly cropped results, 'image_shape' key in result dict is
updated according to crop size.
"""
offset_w, offset_h = crop_size
left, top, right, bottom = offset_w[0], offset_h[0], offset_w[1], offset_h[1]
H, W = results["image"].shape[-2:]
for key in results.get("image_fields", ["image"]):
img = results[key][..., top : H - bottom, left : W - right]
results[key] = img
results["image_shape"] = tuple(img.shape)
for key in results.get("gt_fields", []):
results[key] = results[key][..., top : H - bottom, left : W - right]
for key in results.get("mask_fields", []):
results[key] = results[key][..., top : H - bottom, left : W - right]
results["camera"].crop(left, top, right, bottom)
return results
def __call__(self, results):
H, W = results["image"].shape[-2:]
crop_w = (0, 0)
crop_h = (int(H * self.crop_v), int(H * self.crop_v))
results = self._crop_data(results, (crop_w, crop_h))
return results
class PanoRoll:
def __init__(self, roll=[-0.5, 0.5]):
self.roll = roll
def __call__(self, results):
W = results["image"].shape[-1]
roll = random.randint(int(W * self.roll[0]), int(W * self.roll[1]))
for key in results.get("image_fields", ["image"]):
img = results[key]
img = torch.roll(img, roll, dims=-1)
results[key] = img
for key in results.get("gt_fields", []):
results[key] = torch.roll(results[key], roll, dims=-1)
for key in results.get("mask_fields", []):
results[key] = torch.roll(results[key], roll, dims=-1)
return results
class RandomFlip:
"""Flip the points & bbox.
If the input dict contains the key "flip", then the flag will be used,
otherwise it will be randomly decided by a ratio specified in the init
method.
Args:
flip_ratio_bev_horizontal (float, optional): The flipping probability
in horizontal direction. Defaults to 0.0.
flip_ratio_bev_vertical (float, optional): The flipping probability
in vertical direction. Defaults to 0.0.
"""
def __init__(self, direction="horizontal", prob=0.5, **kwargs):
self.flip_ratio = prob
valid_directions = ["horizontal", "vertical", "diagonal"]
if isinstance(direction, str):
assert direction in valid_directions
elif isinstance(direction, list):
assert set(direction).issubset(set(valid_directions))
else:
raise ValueError("direction must be either str or list of str")
self.direction = direction
def __call__(self, results):
"""Call function to flip points, values in the ``bbox3d_fields`` and
also flip 2D image and its annotations.
Args:
results (dict): Result dict from loading pipeline.
Returns:
dict: Flipped results, 'flip', 'flip_direction',
"""
if "flip" not in results:
if isinstance(self.direction, list):
# None means non-flip
direction_list = self.direction + [None]
else:
# None means non-flip
direction_list = [self.direction, None]
if isinstance(self.flip_ratio, list):
non_flip_ratio = 1 - sum(self.flip_ratio)
flip_ratio_list = self.flip_ratio + [non_flip_ratio]
else:
non_flip_ratio = 1 - self.flip_ratio
# exclude non-flip
single_ratio = self.flip_ratio / (len(direction_list) - 1)
flip_ratio_list = [single_ratio] * (len(direction_list) - 1) + [
non_flip_ratio
]
cur_dir = np.random.choice(direction_list, p=flip_ratio_list)
results["flip"] = cur_dir is not None
if "flip_direction" not in results:
results["flip_direction"] = cur_dir
if results["flip"]:
# flip image
if results["flip_direction"] != "vertical":
for key in results.get("image_fields", ["image"]):
results[key] = TF.hflip(results[key])
for key in results.get("mask_fields", []):
results[key] = TF.hflip(results[key])
for key in results.get("gt_fields", []):
results[key] = TF.hflip(results[key])
if "flow" in key: # flip u direction
results[key][:, 0] = -results[key][:, 0]
H, W = results["image"].shape[-2:]
results["camera"] = results["camera"].flip(
H=H, W=W, direction="horizontal"
)
# results["K"][..., 0, 2] = results["image"].shape[-1] - results["K"][..., 0, 2]
# flip: - t_x rotate around y by: pi - angle_y * 2
flip_transform = torch.tensor(
[[-1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]],
dtype=torch.float32,
).unsqueeze(0)
repeats = (results["cam2w"].shape[0],) + (1,) * (
results["cam2w"].ndim - 1
)
results["cam2w"] = flip_transform.repeat(*repeats) @ results["cam2w"]
if results["flip_direction"] != "horizontal":
for key in results.get("image_fields", ["image"]):
results[key] = TF.vflip(results[key])
for key in results.get("mask_fields", []):
results[key] = TF.vflip(results[key])
for key in results.get("gt_fields", []):
results[key] = TF.vflip(results[key])
results["K"][..., 1, 2] = (
results["image"].shape[-2] - results["K"][..., 1, 2]
)
results["flip"] = [results["flip"]] * len(results["image"])
return results
def __repr__(self):
"""str: Return a string that describes the module."""
repr_str = self.__class__.__name__
repr_str += f" flip_ratio={self.flip_ratio})"
return repr_str
class Crop:
def __init__(
self,
crop_size,
crop_type="absolute",
crop_offset=(0, 0),
):
if crop_type not in [
"relative_range",
"relative",
"absolute",
"absolute_range",
]:
raise ValueError(f"Invalid crop_type {crop_type}.")
if crop_type in ["absolute", "absolute_range"]:
assert crop_size[0] > 0 and crop_size[1] > 0
assert isinstance(crop_size[0], int) and isinstance(crop_size[1], int)
else:
assert 0 < crop_size[0] <= 1 and 0 < crop_size[1] <= 1
self.crop_size = crop_size
self.crop_type = crop_type
self.offset_h, self.offset_w = (
crop_offset[: len(crop_offset) // 2],
crop_offset[len(crop_offset) // 2 :],
)
def _get_crop_size(self, image_shape):
h, w = image_shape
if self.crop_type == "absolute":
return (min(self.crop_size[0], h), min(self.crop_size[1], w))
elif self.crop_type == "absolute_range":
assert self.crop_size[0] <= self.crop_size[1]
crop_h = np.random.randint(
min(h, self.crop_size[0]), min(h, self.crop_size[1]) + 1
)
crop_w = np.random.randint(
min(w, self.crop_size[0]), min(w, self.crop_size[1]) + 1
)
return crop_h, crop_w
elif self.crop_type == "relative":
crop_h, crop_w = self.crop_size
return int(h * crop_h + 0.5), int(w * crop_w + 0.5)
elif self.crop_type == "relative_range":
crop_size = np.asarray(self.crop_size, dtype=np.float32)
crop_h, crop_w = crop_size + np.random.rand(2) * (1 - crop_size)
return int(h * crop_h + 0.5), int(w * crop_w + 0.5)
def _crop_data(self, results, crop_size):
assert crop_size[0] > 0 and crop_size[1] > 0
for key in results.get("image_fields", ["image"]):
img = results[key]
img = TF.crop(
img, self.offset_h[0], self.offset_w[0], crop_size[0], crop_size[1]
)
results[key] = img
results["image_shape"] = tuple(img.shape)
for key in results.get("gt_fields", []):
gt = results[key]
results[key] = TF.crop(
gt, self.offset_h[0], self.offset_w[0], crop_size[0], crop_size[1]
)
# crop semantic seg
for key in results.get("mask_fields", []):
mask = results[key]
results[key] = TF.crop(
mask, self.offset_h[0], self.offset_w[0], crop_size[0], crop_size[1]
)
results["K"][..., 0, 2] = results["K"][..., 0, 2] - self.offset_w[0]
results["K"][..., 1, 2] = results["K"][..., 1, 2] - self.offset_h[0]
return results
def __call__(self, results):
image_shape = results["image"].shape[-2:]
crop_size = self._get_crop_size(image_shape)
results = self._crop_data(results, crop_size)
return results
def __repr__(self):
repr_str = self.__class__.__name__
repr_str += f"(crop_size={self.crop_size}, "
repr_str += f"crop_type={self.crop_type}, "
return repr_str
class KittiCrop:
def __init__(self, crop_size):
self.crop_size = crop_size
def _crop_data(self, results, crop_size):
"""Function to randomly crop images, bounding boxes, masks, semantic
segmentation maps.
Args:
results (dict): Result dict from loading pipeline.
crop_size (tuple): Expected absolute size after cropping, (h, w).
allow_negative_crop (bool): Whether to allow a crop that does not
contain any bbox area. Default to False.
Returns:
dict: Randomly cropped results, 'image_shape' key in result dict is
updated according to crop size.
"""
assert crop_size[0] > 0 and crop_size[1] > 0
for key in results.get("image_fields", ["image"]):
img = results[key]
h, w = img.shape[-2:]
offset_h, offset_w = int(h - self.crop_size[0]), int(
(w - self.crop_size[1]) / 2
)
# crop the image
img = TF.crop(img, offset_h, offset_w, crop_size[0], crop_size[1])
results[key] = img
results["image_shape"] = tuple(img.shape)
for key in results.get("gt_fields", []):
gt = results[key]
results[key] = TF.crop(gt, offset_h, offset_w, crop_size[0], crop_size[1])
# crop semantic seg
for key in results.get("mask_fields", []):
mask = results[key]
results[key] = TF.crop(mask, offset_h, offset_w, crop_size[0], crop_size[1])
results["camera"].crop(offset_w, offset_h)
return results
def __call__(self, results):
"""Call function to randomly crop images, bounding boxes, masks,
semantic segmentation maps.
Args:
results (dict): Result dict from loading pipeline.
Returns:
dict: Randomly cropped results, 'image_shape' key in result dict is
updated according to crop size.
"""
results = self._crop_data(results, self.crop_size)
return results
def __repr__(self):
repr_str = self.__class__.__name__
repr_str += f"(crop_size={self.crop_size}, "
return repr_str
class RandomMasking:
def __init__(
self,
mask_ratio,
mask_patch=16,
prob=0.5,
warmup_steps=50000,
sampling="random",
curriculum=False,
):
self.mask_patch = mask_patch
self.prob = prob
self.mask_ratio = mask_ratio
self.warmup_steps = max(1, warmup_steps)
self.hard_bound = 1
self.idx = 0
self.curriculum = curriculum
self.sampling = sampling
self.low_bound = 0.0
self.up_bound = 0.0
def __call__(self, results):
B, _, H, W = results["image"].shape
device = results["image"].device
down_size = H // self.mask_patch, W // self.mask_patch
if np.random.random() > self.prob: # fill with dummy
return self._nop(results, down_size, device)
validity_mask = results["validity_mask"].float().reshape(B, -1, H, W)
validity_mask = F.interpolate(validity_mask, size=down_size).bool()
validity_mask = validity_mask.reshape(B, 1, *down_size)
is_random = self.is_warmup or results.get("guidance") is None
if not is_random:
guidance = F.interpolate(results["guidance"], size=(H, W), mode="bilinear")
results["guidance"] = -F.max_pool2d(
-guidance, kernel_size=self.mask_patch, stride=self.mask_patch
)
if is_random and self.sampling == "inverse":
sampling = self.inverse_sampling
elif is_random and self.sampling == "random":
sampling = self.random_sampling
else:
sampling = self.guided_sampling
mask_ratio = np.random.uniform(self.low_bound, self.up_bound)
for key in results.get("image_fields", ["image"]):
mask = sampling(results, mask_ratio, down_size, validity_mask, device)
results[key + "_mask"] = mask
return results
def _nop(self, results, down_size, device):
B = results["image"].shape[0]
for key in results.get("image_fields", ["image"]):
mask_blocks = torch.zeros(size=(B, 1, *down_size), device=device)
results[key + "_mask"] = mask_blocks
return results
def random_sampling(self, results, mask_ratio, down_size, validity_mask, device):
B = results["image"].shape[0]
prob_blocks = torch.rand(size=(B, 1, *down_size), device=device)
mask_blocks = torch.logical_and(prob_blocks < mask_ratio, validity_mask)
return mask_blocks
def inverse_sampling(self, results, mask_ratio, down_size, validity_mask, device):
# from PIL import Image
# from unidepth.utils import colorize
def area_sample(depth, fx, fy):
dtype = depth.dtype
B = depth.shape[0]
H, W = down_size
depth = downsample(depth, depth.shape[-2] // H)
depth[depth > 200] = 50 # set sky as if depth 50 meters
pixel_area3d = depth / torch.sqrt(fx * fy)
# Set invalid as -1 (no div problem) -> then clip to 0.0
pixel_area3d[depth == 0.0] = -1
prob_density = (1 / pixel_area3d).clamp(min=0.0).square()
prob_density = prob_density / prob_density.sum(
dim=(-1, -2), keepdim=True
).clamp(min=1e-5)
# Image.fromarray((prob_density[0] * 255 * 100).clamp(min=0.0, max=255.0).squeeze().cpu().byte().numpy()).save("prob_density.png")
# Sample locations based on prob_density
prob_density_flat = prob_density.view(B, -1)
            # Count the valid locations per sample; of those, mask a mask_ratio fraction
valid_locations = (prob_density_flat > 0).to(dtype).sum(dim=1)
masks = []
for i in range(B):
num_samples = int(valid_locations[i] * mask_ratio)
mask = torch.zeros_like(prob_density_flat[i])
# Sample indices
if num_samples > 0:
sampled_indices_flat = torch.multinomial(
prob_density_flat[i], num_samples, replacement=False
)
mask.scatter_(0, sampled_indices_flat, 1)
masks.append(mask)
return torch.stack(masks).bool().view(B, 1, H, W)
def random_sample(validity_mask):
prob_blocks = torch.rand(
size=(validity_mask.shape[0], 1, *down_size), device=device
)
mask = torch.logical_and(prob_blocks < mask_ratio, validity_mask)
return mask
fx = results["K"][..., 0, 0].view(-1, 1, 1, 1) / self.mask_patch
fy = results["K"][..., 1, 1].view(-1, 1, 1, 1) / self.mask_patch
valid = ~results["ssi"] & ~results["si"] & results["valid_camera"]
mask_blocks = torch.zeros_like(validity_mask)
if valid.any():
out = area_sample(results["depth"][valid], fx[valid], fy[valid])
mask_blocks[valid] = out
if (~valid).any():
mask_blocks[~valid] = random_sample(validity_mask[~valid])
# mask_blocks_ = (mask_blocks.float() * 255).squeeze(1).byte().cpu().numpy()
# Image.fromarray(mask_blocks_[0]).save("mask1.png")
# Image.fromarray(mask_blocks_[-1]).save("mask2.png")
# dd = results["depth"]
# Image.fromarray(colorize(dd[0].squeeze().cpu().numpy())).save("depth1_p.png")
# Image.fromarray(colorize(dd[-1].squeeze().cpu().numpy())).save("depth2_p.png")
# dd = downsample(dd, dd.shape[-2] // down_size[0])
# Image.fromarray(colorize(dd[0].squeeze().cpu().numpy())).save("depth1.png")
# Image.fromarray(colorize(dd[-1].squeeze().cpu().numpy())).save("depth2.png")
# raise ValueError
return mask_blocks
def guided_sampling(self, results, mask_ratio, down_size, validity_mask, device):
# get the lowest (based on guidance) "mask_ratio" quantile of the patches that are in validity mask
B = results["image"].shape[0]
guidance = results["guidance"]
mask_blocks = torch.zeros(size=(B, 1, *down_size), device=device)
for b in range(B):
low_bound = torch.quantile(
guidance[b][validity_mask[b]], max(0.0, self.hard_bound - mask_ratio)
)
up_bound = torch.quantile(
guidance[b][validity_mask[b]], min(1.0, self.hard_bound)
)
mask_blocks[b] = torch.logical_and(
guidance[b] < up_bound, guidance[b] > low_bound
)
mask_blocks = torch.logical_and(mask_blocks, validity_mask)
return mask_blocks
def step(self):
self.idx += 1
# schedule hard from 1.0 to self.mask_ratio
if self.curriculum:
step = max(0, self.idx / self.warmup_steps / 2 - 0.5)
self.hard_bound = 1 - (1 - self.mask_ratio) * tanh(step)
self.up_bound = self.mask_ratio * tanh(step)
self.low_bound = 0.2 * tanh(step)
@property
def is_warmup(self):
return self.idx < self.warmup_steps
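# --- Illustrative note (not part of the upstream file): the masking curriculum in step() ---
# With curriculum=True the bounds follow tanh(step) where step = max(0, idx / warmup_steps / 2 - 0.5),
# so no patches are masked during warmup and the ratios ramp up afterwards. For example, with
# mask_ratio=0.7 and warmup_steps=50000:
#   idx =  25_000 -> step = 0.0 -> hard_bound = 1.00, up_bound = 0.00, low_bound = 0.00
#   idx = 150_000 -> step = 1.0 -> hard_bound ~ 0.77, up_bound ~ 0.53, low_bound ~ 0.15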
class Rotate:
def __init__(
self, angle, center=None, img_fill_val=(123.68, 116.28, 103.53), prob=0.5
):
if isinstance(img_fill_val, (float, int)):
img_fill_val = tuple([float(img_fill_val)] * 3)
elif isinstance(img_fill_val, tuple):
assert len(img_fill_val) == 3, (
"image_fill_val as tuple must "
f"have 3 elements. got {len(img_fill_val)}."
)
img_fill_val = tuple([float(val) for val in img_fill_val])
else:
raise ValueError("image_fill_val must be float or tuple with 3 elements.")
assert np.all(
[0 <= val <= 255 for val in img_fill_val]
), f"all elements of img_fill_val should between range [0,255] got {img_fill_val}."
assert 0 <= prob <= 1.0, f"The probability should be in range [0,1]bgot {prob}."
self.center = center
self.img_fill_val = img_fill_val
self.prob = prob
self.random = not isinstance(angle, (float, int))
self.angle = angle
def _rotate(self, results, angle, center=None, fill_val=0.0):
for key in results.get("image_fields", ["image"]):
img = results[key]
img_rotated = TF.rotate(
img,
angle,
center=center,
interpolation=TF.InterpolationMode.NEAREST_EXACT,
fill=self.img_fill_val,
)
results[key] = img_rotated.to(img.dtype)
results["image_shape"] = results[key].shape
for key in results.get("mask_fields", []):
results[key] = TF.rotate(
results[key],
angle,
center=center,
interpolation=TF.InterpolationMode.NEAREST_EXACT,
fill=fill_val,
)
for key in results.get("gt_fields", []):
results[key] = TF.rotate(
results[key],
angle,
center=center,
interpolation=TF.InterpolationMode.NEAREST_EXACT,
fill=fill_val,
)
def __call__(self, results):
if np.random.random() > self.prob:
return results
angle = (
(self.angle[1] - self.angle[0]) * np.random.rand() + self.angle[0]
if self.random
else np.random.choice([-1, 1], size=1) * self.angle
)
self._rotate(results, angle, None, fill_val=0.0)
results["rotation"] = angle
return results
class RandomColor:
def __init__(self, level, prob=0.5):
self.random = not isinstance(level, (float, int))
self.level = level
self.prob = prob
def _adjust_color_img(self, results, factor=1.0):
for key in results.get("image_fields", ["image"]):
results[key] = TF.adjust_hue(results[key], factor) # .to(img.dtype)
def __call__(self, results):
if np.random.random() > self.prob:
return results
factor = (
((self.level[1] - self.level[0]) * np.random.rand() + self.level[0])
if self.random
else self.level
)
self._adjust_color_img(results, factor)
return results
class RandomSaturation:
def __init__(self, level, prob=0.5):
self.random = not isinstance(level, (float, int))
self.level = level
self.prob = prob
def _adjust_saturation_img(self, results, factor=1.0):
for key in results.get("image_fields", ["image"]):
            # NOTE: by default the image is expected to be in BGR format
results[key] = TF.adjust_saturation(results[key], factor) # .to(img.dtype)
def __call__(self, results):
if np.random.random() > self.prob:
return results
factor = (
2 ** ((self.level[1] - self.level[0]) * np.random.rand() + self.level[0])
if self.random
else 2**self.level
)
self._adjust_saturation_img(results, factor)
return results
class RandomSharpness:
def __init__(self, level, prob=0.5):
self.random = not isinstance(level, (float, int))
self.level = level
self.prob = prob
    def _adjust_sharpness_img(self, results, factor=1.0):
        for key in results.get("image_fields", ["image"]):
            # NOTE: by default the image is expected to be in BGR format
            results[key] = TF.adjust_sharpness(results[key], factor)  # .to(img.dtype)
def __call__(self, results):
if np.random.random() > self.prob:
return results
factor = (
2 ** ((self.level[1] - self.level[0]) * np.random.rand() + self.level[0])
if self.random
else 2**self.level
)
        self._adjust_sharpness_img(results, factor)
return results
def __repr__(self):
repr_str = self.__class__.__name__
repr_str += f"(level={self.level}, "
repr_str += f"prob={self.prob})"
return repr_str
class RandomSolarize:
def __init__(self, level, prob=0.5):
self.random = not isinstance(level, (float, int))
self.level = level
self.prob = prob
def _adjust_solarize_img(self, results, factor=255.0):
for key in results.get("image_fields", ["image"]):
results[key] = TF.solarize(results[key], factor) # .to(img.dtype)
def __call__(self, results):
if np.random.random() > self.prob:
return results
factor = (
((self.level[1] - self.level[0]) * np.random.rand() + self.level[0])
if self.random
else self.level
)
self._adjust_solarize_img(results, factor)
return results
class RandomPosterize:
def __init__(self, level, prob=0.5):
self.random = not isinstance(level, (float, int))
self.level = level
self.prob = prob
def _posterize_img(self, results, factor=1.0):
for key in results.get("image_fields", ["image"]):
results[key] = TF.posterize(results[key], int(factor)) # .to(img.dtype)
def __call__(self, results):
if np.random.random() > self.prob:
return results
factor = (
((self.level[1] - self.level[0]) * np.random.rand() + self.level[0])
if self.random
else self.level
)
self._posterize_img(results, factor)
return results
class RandomEqualize:
def __init__(self, prob=0.5):
assert 0 <= prob <= 1.0, "The probability should be in range [0,1]."
self.prob = prob
def _imequalize(self, results):
for key in results.get("image_fields", ["image"]):
results[key] = TF.equalize(results[key]) # .to(img.dtype)
def __call__(self, results):
if np.random.random() > self.prob:
return results
self._imequalize(results)
return results
class RandomBrightness:
def __init__(self, level, prob=0.5):
self.random = not isinstance(level, (float, int))
self.level = level
self.prob = prob
def _adjust_brightness_img(self, results, factor=1.0):
for key in results.get("image_fields", ["image"]):
results[key] = TF.adjust_brightness(results[key], factor) # .to(img.dtype)
def __call__(self, results, level=None):
if np.random.random() > self.prob:
return results
factor = (
2 ** ((self.level[1] - self.level[0]) * np.random.rand() + self.level[0])
if self.random
else 2**self.level
)
self._adjust_brightness_img(results, factor)
return results
class RandomContrast:
def __init__(self, level, prob=0.5):
self.random = not isinstance(level, (float, int))
self.level = level
self.prob = prob
def _adjust_contrast_img(self, results, factor=1.0):
for key in results.get("image_fields", ["image"]):
results[key] = TF.adjust_contrast(results[key], factor) # .to(img.dtype)
def __call__(self, results, level=None):
if np.random.random() > self.prob:
return results
factor = (
2 ** ((self.level[1] - self.level[0]) * np.random.rand() + self.level[0])
if self.random
else 2**self.level
)
self._adjust_contrast_img(results, factor)
return results
class RandomGamma:
def __init__(self, level, prob=0.5):
self.random = not isinstance(level, (float, int))
self.level = level
self.prob = prob
def __call__(self, results, level=None):
if np.random.random() > self.prob:
return results
factor = (self.level[1] - self.level[0]) * np.random.rand() + self.level[0]
for key in results.get("image_fields", ["image"]):
if "original" not in key:
results[key] = TF.adjust_gamma(results[key], 1 + factor)
return results
class RandomInvert:
def __init__(self, prob=0.5):
self.prob = prob
def __call__(self, results):
if np.random.random() > self.prob:
return results
for key in results.get("image_fields", ["image"]):
if "original" not in key:
results[key] = TF.invert(results[key]) # .to(img.dtype)
return results
class RandomAutoContrast:
def __init__(self, prob=0.5):
self.prob = prob
def _autocontrast_img(self, results):
for key in results.get("image_fields", ["image"]):
img = results[key]
results[key] = TF.autocontrast(img) # .to(img.dtype)
def __call__(self, results):
if np.random.random() > self.prob:
return results
self._autocontrast_img(results)
return results
class RandomShear(object):
def __init__(
self,
level,
prob=0.5,
direction="horizontal",
):
self.random = not isinstance(level, (float, int))
self.level = level
self.prob = prob
self.direction = direction
def _shear_img(self, results, magnitude):
for key in results.get("image_fields", ["image"]):
img_sheared = TF.affine(
results[key],
angle=0.0,
translate=[0.0, 0.0],
scale=1.0,
shear=magnitude,
interpolation=TF.InterpolationMode.BILINEAR,
fill=0.0,
)
results[key] = img_sheared
def _shear_masks(self, results, magnitude):
for key in results.get("mask_fields", []):
mask_sheared = TF.affine(
results[key],
angle=0.0,
translate=[0.0, 0.0],
scale=1.0,
shear=magnitude,
interpolation=TF.InterpolationMode.NEAREST_EXACT,
fill=0.0,
)
results[key] = mask_sheared
def _shear_gt(
self,
results,
magnitude,
):
for key in results.get("gt_fields", []):
mask_sheared = TF.affine(
results[key],
angle=0.0,
translate=[0.0, 0.0],
scale=1.0,
shear=magnitude,
interpolation=TF.InterpolationMode.NEAREST_EXACT,
fill=0.0,
)
results[key] = mask_sheared
def __call__(self, results):
if np.random.random() > self.prob:
return results
magnitude = (
((self.level[1] - self.level[0]) * np.random.rand() + self.level[0])
if self.random
else np.random.choice([-1, 1], size=1) * self.level
)
if self.direction == "horizontal":
magnitude = [magnitude, 0.0]
else:
magnitude = [0.0, magnitude]
self._shear_img(results, magnitude)
self._shear_masks(results, magnitude)
self._shear_gt(results, magnitude)
return results
class RandomTranslate(object):
def __init__(
self,
range,
prob=0.5,
direction="horizontal",
):
self.range = range
self.prob = prob
self.direction = direction
def _translate_img(self, results, magnitude):
"""Shear the image.
Args:
results (dict): Result dict from loading pipeline.
magnitude (int | float): The magnitude used for shear.
direction (str): The direction for shear, either "horizontal"
or "vertical".
interpolation (str): Same as in :func:`mmcv.imshear`.
"""
for key in results.get("image_fields", ["image"]):
img_sheared = TF.affine(
results[key],
angle=0.0,
translate=magnitude,
scale=1.0,
shear=[0.0, 0.0],
interpolation=TF.InterpolationMode.BILINEAR,
fill=(123.68, 116.28, 103.53),
)
results[key] = img_sheared
def _translate_mask(self, results, magnitude):
"""Shear the masks."""
for key in results.get("mask_fields", []):
mask_sheared = TF.affine(
results[key],
angle=0.0,
translate=magnitude,
scale=1.0,
shear=[0.0, 0.0],
interpolation=TF.InterpolationMode.NEAREST_EXACT,
fill=0.0,
)
results[key] = mask_sheared
def _translate_gt(
self,
results,
magnitude,
):
"""Shear the segmentation maps."""
for key in results.get("gt_fields", []):
mask_sheared = TF.affine(
results[key],
angle=0.0,
translate=magnitude,
scale=1.0,
shear=[0.0, 0.0],
interpolation=TF.InterpolationMode.NEAREST_EXACT,
fill=0.0,
)
results[key] = mask_sheared
def __call__(self, results):
"""Call function to shear images, bounding boxes, masks and semantic
segmentation maps.
Args:
results (dict): Result dict from loading pipeline.
Returns:
dict: Sheared results.
"""
if np.random.random() > self.prob:
return results
magnitude = (self.range[1] - self.range[0]) * np.random.rand() + self.range[0]
if self.direction == "horizontal":
magnitude = [magnitude * results["image"].shape[1], 0]
else:
magnitude = [0, magnitude * results["image"].shape[0]]
self._translate_img(results, magnitude)
self._translate_mask(results, magnitude)
self._translate_gt(results, magnitude)
results["K"][..., 0, 2] = results["K"][..., 0, 2] + magnitude[0]
results["K"][..., 1, 2] = results["K"][..., 1, 2] + magnitude[1]
return results
def __repr__(self):
repr_str = self.__class__.__name__
repr_str += f"(range={self.range}, "
repr_str += f"prob={self.prob}, "
repr_str += f"direction={self.direction}, "
return repr_str
class RandomColorJitter:
def __init__(self, level, prob=0.9):
self.level = level
self.prob = prob
self.list_transform = [
self._adjust_brightness_img,
# self._adjust_sharpness_img,
self._adjust_contrast_img,
self._adjust_saturation_img,
self._adjust_color_img,
]
def _adjust_contrast_img(self, results, factor=1.0):
"""Adjust the image contrast."""
for key in results.get("image_fields", ["image"]):
if "original" not in key:
img = results[key]
results[key] = TF.adjust_contrast(img, factor)
def _adjust_sharpness_img(self, results, factor=1.0):
"""Adjust the image contrast."""
for key in results.get("image_fields", ["image"]):
if "original" not in key:
img = results[key]
results[key] = TF.adjust_sharpness(img, factor)
def _adjust_brightness_img(self, results, factor=1.0):
"""Adjust the brightness of image."""
for key in results.get("image_fields", ["image"]):
if "original" not in key:
img = results[key]
results[key] = TF.adjust_brightness(img, factor)
def _adjust_saturation_img(self, results, factor=1.0):
"""Apply Color transformation to image."""
for key in results.get("image_fields", ["image"]):
if "original" not in key:
img = results[key]
results[key] = TF.adjust_saturation(img, factor / 2.0)
def _adjust_color_img(self, results, factor=1.0):
"""Apply Color transformation to image."""
for key in results.get("image_fields", ["image"]):
if "original" not in key:
img = results[key]
results[key] = TF.adjust_hue(img, (factor - 1.0) / 4.0)
def __call__(self, results):
"""Call function for color transformation.
Args:
results (dict): Results dict from loading pipeline.
Returns:
dict: Results after the transformation.
"""
random.shuffle(self.list_transform)
for op in self.list_transform:
if np.random.random() < self.prob:
factor = 1.0 + (
(self.level[1] - self.level[0]) * np.random.random() + self.level[0]
)
op(results, factor)
return results
class RandomGrayscale:
def __init__(self, prob=0.1, num_output_channels=3):
super().__init__()
self.prob = prob
self.num_output_channels = num_output_channels
def __call__(self, results):
if np.random.random() > self.prob:
return results
for key in results.get("image_fields", ["image"]):
if "original" not in key:
results[key] = TF.rgb_to_grayscale(
results[key], num_output_channels=self.num_output_channels
)
return results
def masked_nearest_interpolation(input, mask, target_size):
"""
Resize the depth map using bilinear interpolation, considering only valid pixels within NxN neighbors.
Args:
depth (torch.Tensor): The depth map tensor of shape (H, W).
mask (torch.Tensor): The mask tensor of shape (H, W) where 1 indicates valid depth and 0 indicates missing depth.
target_size (tuple): The desired output size (target_H, target_W).
Returns:
torch.Tensor: The resized depth map.
"""
B, C, H, W = input.shape
target_H, target_W = target_size
mask = mask.float()
# Generate a grid of coordinates in the target space
grid_y, grid_x = torch.meshgrid(
torch.linspace(0, H - 1, target_H),
torch.linspace(0, W - 1, target_W),
indexing="ij",
)
grid_y = grid_y.to(input.device)
grid_x = grid_x.to(input.device)
# Calculate the floor and ceil of the grid coordinates to get the bounding box
x0 = torch.floor(grid_x).long().clamp(0, W - 1)
x1 = (x0 + 1).clamp(0, W - 1)
y0 = torch.floor(grid_y).long().clamp(0, H - 1)
y1 = (y0 + 1).clamp(0, H - 1)
# Gather depth values at the four corners
Ia = input[..., y0, x0]
Ib = input[..., y1, x0]
Ic = input[..., y0, x1]
Id = input[..., y1, x1]
# Gather corresponding mask values
ma = mask[..., y0, x0]
mb = mask[..., y1, x0]
mc = mask[..., y0, x1]
md = mask[..., y1, x1]
# Calculate distances to each neighbor
# The distances are calculated from the center (grid_x, grid_y) to each corner
dist_a = (grid_x - x0.float()) ** 2 + (grid_y - y0.float()) ** 2 # Top-left
dist_b = (grid_x - x0.float()) ** 2 + (grid_y - y1.float()) ** 2 # Bottom-left
dist_c = (grid_x - x1.float()) ** 2 + (grid_y - y0.float()) ** 2 # Top-right
dist_d = (grid_x - x1.float()) ** 2 + (grid_y - y1.float()) ** 2 # Bottom-right
# Stack the neighbors, their masks, and distances
stacked_values = torch.stack(
[Ia, Ib, Ic, Id], dim=-1
) # Shape: (B, C, target_H, target_W, 4)
stacked_masks = torch.stack(
[ma, mb, mc, md], dim=-1
) # Shape: (B, 1, target_H, target_W, 4)
stacked_distances = torch.stack(
[dist_a, dist_b, dist_c, dist_d], dim=-1
) # Shape: (target_H, target_W, 4)
stacked_distances = (
stacked_distances.unsqueeze(0).unsqueeze(1).repeat(B, 1, 1, 1, 1)
) # Shape: (B, 1, target_H, target_W, 4)
# Set distances to infinity for invalid neighbors (so that invalid neighbors are never chosen)
stacked_distances[stacked_masks == 0] = float("inf")
# Find the index of the nearest valid neighbor (the one with the smallest distance)
nearest_indices = stacked_distances.argmin(dim=-1, keepdim=True)[
..., :1
] # Shape: (B, 1, target_H, target_W, 1)
# Select the corresponding depth value using the nearest valid neighbor index
interpolated_depth = torch.gather(
stacked_values, dim=-1, index=nearest_indices.repeat(1, C, 1, 1, 1)
).squeeze(-1)
# Set depth to zero where no valid neighbors were found
interpolated_depth = interpolated_depth * stacked_masks.sum(dim=-1).clip(
min=0.0, max=1.0
)
return interpolated_depth
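# --- Illustrative note (not part of the upstream file): typical use of masked_nearest_interpolation ---
# ContextCrop below calls it to resize sparse ground truth without smearing valid values into
# holes. A hypothetical call for a (B, 1, H, W) depth map where zeros mark missing values:
#
#   resized = masked_nearest_interpolation(depth, depth > 0, target_size=(228, 304))
#   # target pixels whose four source neighbors are all invalid stay exactly zero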
class ContextCrop:
def __init__(
self,
image_shape,
keep_original=False,
test_min_ctx=1.0,
train_ctx_range=[0.5, 1.5],
shape_constraints={},
):
self.image_shape = image_shape
self.keep_original = keep_original
self.test_min_ctx = test_min_ctx
self.train_ctx_range = train_ctx_range
self.shape_mult = shape_constraints["shape_mult"]
self.sample = shape_constraints["sample"]
self.ratio_bounds = shape_constraints["ratio_bounds"]
pixels_min = shape_constraints["pixels_min"] / (
self.shape_mult * self.shape_mult
)
pixels_max = shape_constraints["pixels_max"] / (
self.shape_mult * self.shape_mult
)
self.pixels_bounds = (pixels_min, pixels_max)
self.ctx = None
def _transform_img(self, results, shapes):
for key in results.get("image_fields", ["image"]):
img = self.crop(results[key], **shapes)
img = TF.resize(
img,
results["resized_shape"],
interpolation=TF.InterpolationMode.BICUBIC,
antialias=True,
)
results[key] = img
def _transform_masks(self, results, shapes):
for key in results.get("mask_fields", []):
mask = self.crop(results[key].float(), **shapes).byte()
mask = masked_nearest_interpolation(
mask, mask > 0, results["resized_shape"]
)
results[key] = mask
def _transform_gt(self, results, shapes):
for key in results.get("gt_fields", []):
gt = self.crop(results[key], **shapes)
gt = masked_nearest_interpolation(gt, gt > 0, results["resized_shape"])
results[key] = gt
@staticmethod
def crop(img, height, width, top, left) -> torch.Tensor:
h, w = img.shape[-2:]
right = left + width
bottom = top + height
padding_ltrb = [
max(-left + min(0, right), 0),
max(-top + min(0, bottom), 0),
max(right - max(w, left), 0),
max(bottom - max(h, top), 0),
]
image_cropped = img[..., max(top, 0) : bottom, max(left, 0) : right]
return TF.pad(image_cropped, padding_ltrb)
def test_closest_shape(self, image_shape):
h, w = image_shape
input_ratio = w / h
if self.sample:
input_pixels = int(ceil(h / self.shape_mult * w / self.shape_mult))
pixels = max(
min(input_pixels, self.pixels_bounds[1]), self.pixels_bounds[0]
)
ratio = min(max(input_ratio, self.ratio_bounds[0]), self.ratio_bounds[1])
h = round((pixels / ratio) ** 0.5)
w = h * ratio
self.image_shape[0] = int(h) * self.shape_mult
self.image_shape[1] = int(w) * self.shape_mult
def _get_crop_shapes(self, image_shape, ctx=None):
h, w = image_shape
input_ratio = w / h
if self.keep_original:
self.test_closest_shape(image_shape)
ctx = 1.0
elif ctx is None:
ctx = float(
torch.empty(1)
.uniform_(self.train_ctx_range[0], self.train_ctx_range[1])
.item()
)
output_ratio = self.image_shape[1] / self.image_shape[0]
        if output_ratio <= input_ratio:  # output aspect (e.g. 4:3) narrower than or equal to the input aspect (e.g. a wide KITTI frame)
if (
ctx >= 1
): # fully in -> use just max_length with sqrt(ctx), here max is width
new_w = w * ctx**0.5
            # the crop overshoots the source image in only one dimension
            # we know that in_width will stick out before in_height (partial overshoot)
# new_h > old_h via area -> new_h ** 2 * ratio_new = old_h ** 2 * ratio_old * ctx
elif output_ratio / input_ratio * ctx > 1:
new_w = w * ctx
else: # fully contained -> use area
new_w = w * (ctx * output_ratio / input_ratio) ** 0.5
new_h = new_w / output_ratio
else:
if ctx >= 1:
new_h = h * ctx**0.5
elif input_ratio / output_ratio * ctx > 1:
new_h = h * ctx
else:
new_h = h * (ctx * input_ratio / output_ratio) ** 0.5
new_w = new_h * output_ratio
return (int(ceil(new_h - 0.5)), int(ceil(new_w - 0.5))), ctx
def __call__(self, results):
h, w = results["image"].shape[-2:]
results["image_ori_shape"] = (h, w)
results.get("mask_fields", set()).add("validity_mask")
if "validity_mask" not in results:
results["validity_mask"] = torch.ones(
(results["image"].shape[0], 1, h, w),
dtype=torch.uint8,
device=results["image"].device,
)
n_iter = 1 if self.keep_original or not self.sample else 100
min_valid_area = 0.5
results["camera_fields"].add("camera_original")
results["camera_original"] = results["camera"].clone()
max_hfov, max_vfov = results["camera"].max_fov[0] # it is a 1-dim list
ctx = None
for ii in range(n_iter):
(height, width), ctx = self._get_crop_shapes((h, w), ctx=self.ctx or ctx)
margin_h = h - height
margin_w = w - width
            # start from a crop centered in the frame (randomly jittered below unless keep_original)
top = margin_h // 2
left = margin_w // 2
if not self.keep_original:
left = left + np.random.randint(
-self.shape_mult // 2, self.shape_mult // 2 + 1
)
top = top + np.random.randint(
-self.shape_mult // 2, self.shape_mult // 2 + 1
)
right = left + width
bottom = top + height
x_zoom = self.image_shape[0] / height
paddings = [
max(-left + min(0, right), 0),
max(bottom - max(h, top), 0),
max(right - max(w, left), 0),
max(-top + min(0, bottom), 0),
]
valid_area = (
h
* w
/ (h + paddings[1] + paddings[3])
/ (w + paddings[0] + paddings[2])
)
new_hfov, new_vfov = results["camera_original"].get_new_fov(
new_shape=(height, width), original_shape=(h, w)
)[0]
if (
valid_area >= min_valid_area
and new_hfov < max_hfov
and new_vfov < max_vfov
):
results["camera"] = results["camera"].crop(
left, top, right=w - right, bottom=h - bottom
)
results["camera"] = results["camera"].resize(x_zoom)
break
ctx = (
ctx * 0.96
) # if not enough valid area, try again with less ctx (more zoom)
# save ctx for next iteration of sequences?
self.ctx = ctx
results["resized_shape"] = self.image_shape
results["paddings"] = paddings # left ,top ,right, bottom
results["image_rescale"] = x_zoom
results["scale_factor"] = results.get("scale_factor", 1.0) * x_zoom
shapes = dict(height=height, width=width, top=top, left=left)
self._transform_img(results, shapes)
if not self.keep_original:
self._transform_gt(results, shapes)
self._transform_masks(results, shapes)
else:
            # only validity_mask (the RGB masks follow the RGB transform) #FIXME
mask = results["validity_mask"].float()
mask = self.crop(mask, **shapes).byte()
mask = TF.resize(
mask,
results["resized_shape"],
interpolation=TF.InterpolationMode.NEAREST,
)
results["validity_mask"] = mask
# keep original images before photo-augment
results["image_original"] = results["image"].clone()
results["image_fields"].add(
*[
field.replace("image", "image_original")
for field in results["image_fields"]
]
)
# repeat for batch resized shape and paddings
results["paddings"] = [results["paddings"]] * results["image"].shape[0]
results["resized_shape"] = [results["resized_shape"]] * results["image"].shape[
0
]
return results
class RandomFiller:
def __init__(self, *args, **kwargs):
super().__init__()
def _transform(self, results):
def fill_noise(size, device):
return torch.normal(0, 2.0, size=size, device=device)
def fill_black(size, device):
return -4 * torch.ones(size, device=device, dtype=torch.float32)
def fill_white(size, device):
return 4 * torch.ones(size, device=device, dtype=torch.float32)
def fill_zero(size, device):
return torch.zeros(size, device=device, dtype=torch.float32)
B, C = results["image"].shape[:2]
mismatch = B // results["validity_mask"].shape[0]
if mismatch:
results["validity_mask"] = results["validity_mask"].repeat(
mismatch, 1, 1, 1
)
validity_mask = results["validity_mask"].repeat(1, C, 1, 1).bool()
filler_fn = np.random.choice([fill_noise, fill_black, fill_white, fill_zero])
for key in results.get("image_fields", ["image"]):
results[key][~validity_mask] = filler_fn(
size=results[key][~validity_mask].shape, device=results[key].device
)
def __call__(self, results):
# generate mask for filler
if "validity_mask" not in results:
paddings = results.get("padding_size", [0] * 4)
height, width = results["image"].shape[-2:]
results.get("mask_fields", []).add("validity_mask")
results["validity_mask"] = torch.zeros_like(results["image"][:, :1])
results["validity_mask"][
...,
paddings[1] : height - paddings[3],
paddings[0] : width - paddings[2],
] = 1.0
self._transform(results)
return results
class GaussianBlur:
def __init__(self, kernel_size, sigma=(0.1, 2.0), prob=0.9):
super().__init__()
self.kernel_size = kernel_size
self.sigma = sigma
self.prob = prob
self.padding = kernel_size // 2
def apply(self, x, kernel):
# Pad the input tensor
x = F.pad(
x, (self.padding, self.padding, self.padding, self.padding), mode="reflect"
)
# Apply the convolution with the Gaussian kernel
return F.conv2d(x, kernel, stride=1, padding=0, groups=x.size(1))
def _create_kernel(self, sigma):
# Create a 1D Gaussian kernel
kernel_1d = torch.exp(
-torch.arange(-self.padding, self.padding + 1) ** 2 / (2 * sigma**2)
)
kernel_1d = kernel_1d / kernel_1d.sum()
# Expand the kernel to 2D and match size of the input
kernel_2d = kernel_1d.unsqueeze(0) * kernel_1d.unsqueeze(1)
kernel_2d = kernel_2d.view(1, 1, self.kernel_size, self.kernel_size).expand(
3, 1, -1, -1
)
return kernel_2d
def __call__(self, results):
if np.random.random() > self.prob:
return results
sigma = (self.sigma[1] - self.sigma[0]) * np.random.rand() + self.sigma[0]
kernel = self._create_kernel(sigma)
for key in results.get("image_fields", ["image"]):
if "original" not in key:
results[key] = self.apply(results[key], kernel)
return results
class Compose:
def __init__(self, transforms):
self.transforms = deepcopy(transforms)
def __call__(self, results):
for t in self.transforms:
results = t(results)
return results
def __setattr__(self, name: str, value) -> None:
super().__setattr__(name, value)
for t in self.transforms:
setattr(t, name, value)
def __repr__(self):
format_string = self.__class__.__name__ + "("
for t in self.transforms:
format_string += f"\n {t}"
format_string += "\n)"
return format_string
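For reference, the nearest-valid-neighbor resizing used throughout the pipeline above (the `masked_nearest_interpolation` helper at the top of this file) can be reproduced on a toy tensor. The sketch below is a minimal standalone re-implementation of the same idea, not the repository function itself; the name `nearest_valid_resize` and the toy shapes are illustrative only.

```python
import torch

def nearest_valid_resize(depth: torch.Tensor, mask: torch.Tensor, size: tuple[int, int]) -> torch.Tensor:
    """Resize a sparse map by copying the nearest *valid* of the four bilinear neighbors.

    depth: (B, C, H, W) float tensor; mask: (B, 1, H, W) bool tensor; size: (target_H, target_W).
    Target pixels with no valid neighbor are set to zero.
    """
    B, C, H, W = depth.shape
    tH, tW = size
    gy, gx = torch.meshgrid(
        torch.linspace(0, H - 1, tH), torch.linspace(0, W - 1, tW), indexing="ij"
    )
    x0 = gx.floor().long().clamp(0, W - 1)
    x1 = (x0 + 1).clamp(0, W - 1)
    y0 = gy.floor().long().clamp(0, H - 1)
    y1 = (y0 + 1).clamp(0, H - 1)
    corners = [(y0, x0), (y1, x0), (y0, x1), (y1, x1)]
    vals = torch.stack([depth[..., y, x] for y, x in corners], dim=-1)   # (B, C, tH, tW, 4)
    valid = torch.stack([mask[..., y, x] for y, x in corners], dim=-1)   # (B, 1, tH, tW, 4)
    dist = torch.stack(
        [(gx - x.float()) ** 2 + (gy - y.float()) ** 2 for y, x in corners], dim=-1
    ).expand(B, 1, tH, tW, 4).clone()
    dist[valid == 0] = float("inf")                                      # never pick an invalid corner
    idx = dist.argmin(dim=-1, keepdim=True)                              # index of the closest valid corner
    out = torch.gather(vals, -1, idx.expand(B, C, tH, tW, 1)).squeeze(-1)
    return out * valid.any(dim=-1)                                       # zero where no corner was valid

# toy usage: a 2x2 depth map with a single valid pixel, upsampled to 4x4
depth = torch.tensor([[[[5.0, 0.0], [0.0, 0.0]]]])
print(nearest_valid_resize(depth, depth > 0, (4, 4)))
```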
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/point_odyssey.py
================================================
from typing import Any
from unidepth.datasets.sequence_dataset import SequenceDataset
class PointOdyssey(SequenceDataset):
min_depth = 0.01
max_depth = 250.0
depth_scale = 1000.0
test_split = "test.txt"
train_split = "train.txt"
sequences_file = "sequences_clean.json"
hdf5_paths = [f"PointOdyssey.hdf5"]
def __init__(
self,
image_shape: tuple[int, int],
split_file: str,
test_mode: bool,
normalize: bool,
augmentations_db: dict[str, Any],
resize_method: str,
mini: float = 1.0,
num_frames: int = 1,
benchmark: bool = False,
decode_fields: list[str] = ["image", "depth"],
inplace_fields: list[str] = ["K", "cam2w"],
**kwargs,
) -> None:
super().__init__(
image_shape=image_shape,
split_file=split_file,
test_mode=test_mode,
benchmark=benchmark,
normalize=normalize,
augmentations_db=augmentations_db,
resize_method=resize_method,
mini=mini,
num_frames=num_frames,
decode_fields=decode_fields,
inplace_fields=inplace_fields,
**kwargs,
)
def pre_pipeline(self, results):
results = super().pre_pipeline(results)
results["dense"] = [True] * self.num_frames * self.num_copies
results["synthetic"] = [True] * self.num_frames * self.num_copies
results["quality"] = [0] * self.num_frames * self.num_copies
return results
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/proteus.py
================================================
from typing import Any
from unidepth.datasets.sequence_dataset import SequenceDataset
class Proteus(SequenceDataset):
min_depth = 0.01
max_depth = 10.0
depth_scale = 1000.0
default_fps = 5
test_split = "train.txt"
train_split = "train.txt"
sequences_file = "sequences.json"
hdf5_paths = ["Proteus.hdf5"]
def __init__(
self,
image_shape: tuple[int, int],
split_file: str,
test_mode: bool,
normalize: bool,
augmentations_db: dict[str, Any],
resize_method: str,
mini: float = 1.0,
num_frames: int = 1,
benchmark: bool = False,
decode_fields: list[str] = ["image", "depth"],
inplace_fields: list[str] = ["K", "cam2w"],
**kwargs,
) -> None:
super().__init__(
image_shape=image_shape,
split_file=split_file,
test_mode=test_mode,
benchmark=benchmark,
normalize=normalize,
augmentations_db=augmentations_db,
resize_method=resize_method,
mini=mini,
num_frames=num_frames,
decode_fields=decode_fields,
inplace_fields=inplace_fields,
**kwargs,
)
def pre_pipeline(self, results):
results = super().pre_pipeline(results)
results["dense"] = [True] * self.num_frames * self.num_copies
results["synthetic"] = [True] * self.num_frames * self.num_copies
results["quality"] = [0] * self.num_frames * self.num_copies
return results
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/samplers copy.py
================================================
import itertools
import warnings
from operator import itemgetter
from typing import Any, Optional
import numpy as np
import torch
from torch.utils.data import Sampler
from unidepth.utils import get_dist_info
def _get_numpy_dtype(size: int) -> Any:
return np.int32 if size <= 2**31 else np.int64
def _get_torch_dtype(size: int) -> Any:
return torch.int32 if size <= 2**31 else torch.int64
def _generate_randperm_indices(*, size: int, generator: torch.Generator):
"""Generate the indices of a random permutation."""
dtype = _get_torch_dtype(size)
# This is actually matching PyTorch's CPU implementation, see: https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/TensorFactories.cpp#L900-L921
perm = torch.arange(size, dtype=dtype)
for i in range(size):
j = torch.randint(i, size, size=(1,), generator=generator).item()
# Always swap even if no-op
value = perm[j].item()
perm[j] = perm[i].item()
perm[i] = value
yield value
# The following function is somewhat equivalent to _new_shuffle_tensor_slice below,
# but avoids a full in-place random permutation generation.
def _shuffle_tensor_slice(
*, tensor: torch.Tensor, start: int = 0, step: int = 1, generator: torch.Generator
) -> np.ndarray:
stop = len(tensor)
count = stop // step
drop_count = stop - step * count
if drop_count:
warnings.warn(f"# of dropped samples: {drop_count}")
dtype = _get_numpy_dtype(stop)
result = np.empty(count, dtype=dtype)
for i in range(count):
j = (
torch.randint(0, i + 1, size=(1,), generator=generator).item()
if i > 0
else 0
)
result[i] = result[j]
result[j] = tensor[start + i * step].item()
return result
def _new_shuffle_tensor_slice(
*, tensor: torch.Tensor, start: int = 0, step: int = 1, generator: torch.Generator
) -> np.ndarray:
stop = len(tensor)
count = stop // step
dtype = torch.int64 # Needed for using randperm result as indices
drop_count = stop - step * count
if drop_count:
warnings.warn(f"# of dropped samples: {drop_count}")
indices = torch.randperm(count, dtype=dtype, generator=generator)
return tensor[start::step][indices].numpy()
def _make_seed(seed: int, start: int, iter_count: int) -> int:
# NOTE: Tried a few variants (including iter_count << 32), this one worked best.
return seed + start + (iter_count << 24)
class ShardedInfiniteSampler(Sampler):
def __init__(
self,
*,
sample_count: int,
shuffle: bool = False,
seed: int = 0,
start: Optional[int] = None,
step: Optional[int] = None,
advance: int = 0,
use_new_shuffle_tensor_slice: bool = False,
):
self._sample_count = sample_count
self._seed = seed
self._shuffle = shuffle
rank, world_size = get_dist_info()
self._start = rank if start is None else start
self._step = world_size if step is None else step
self._advance = advance
self._iter_count = 0
self._shuffle_tensor_slice_fn = (
_new_shuffle_tensor_slice
if use_new_shuffle_tensor_slice
else _shuffle_tensor_slice
)
def __iter__(self):
iter_count = self._advance // self._sample_count
if iter_count > 0:
self._advance -= iter_count * self._sample_count
self._iter_count += iter_count
if self._shuffle:
iterator = self._shuffled_iterator()
else:
iterator = self._iterator()
yield from itertools.islice(iterator, self._advance, None)
def _iterator(self):
assert not self._shuffle
while True:
iterable = range(self._sample_count)
yield from itertools.islice(iterable, self._start, None, self._step)
def _shuffled_iterator(self):
assert self._shuffle
        # Instantiate a generator here (rather than in the ctor) to keep the class
        # picklable (requirement of mp.spawn)
generator = torch.Generator()
# Always shuffle everything first
generator.manual_seed(self._seed)
dtype = _get_torch_dtype(self._sample_count)
perm = torch.randperm(self._sample_count, dtype=dtype, generator=generator)
while True:
# Re-seed on each iteration to allow skipping whole permutations
seed = _make_seed(self._seed, self._start, self._iter_count)
generator.manual_seed(seed)
iterable = self._shuffle_tensor_slice_fn(
tensor=perm, start=self._start, step=self._step, generator=generator
)
yield from iterable
self._iter_count += 1
class DistributedSamplerNoDuplicate(torch.utils.data.DistributedSampler):
"""A distributed sampler that doesn't add duplicates. Arguments are the same as DistributedSampler"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if not self.drop_last and len(self.dataset) % self.num_replicas != 0:
            # some ranks may have fewer samples; that's fine
if self.rank >= len(self.dataset) % self.num_replicas:
self.num_samples -= 1
self.total_size = len(self.dataset)
class DatasetFromSampler(torch.utils.data.Dataset):
"""Dataset to create indexes from `Sampler`.
Args:
sampler: PyTorch sampler
"""
def __init__(self, sampler: Sampler):
"""Initialisation for DatasetFromSampler."""
self.sampler = sampler
self.sampler_list = None
def __getitem__(self, index: int):
"""Gets element of the dataset.
Args:
index: index of the element in the dataset
Returns:
Single element by index
"""
if self.sampler_list is None:
self.sampler_list = list(self.sampler)
return self.sampler_list[index]
def __len__(self) -> int:
"""
Returns:
int: length of the dataset
"""
return len(self.sampler)
class DistributedSamplerWrapper(torch.utils.data.DistributedSampler):
"""
Wrapper over `Sampler` for distributed training
Allows you to use any sampler in distributed mode.
It is especially useful in conjunction with
`torch.nn.parallel.DistributedDataParallel`. In such case, each
process can pass a DistributedSamplerWrapper instance as a DataLoader
sampler, and load a subset of subsampled data of the original dataset
that is exclusive to it.
.. note::
Sampler is assumed to be of constant size.
"""
def __init__(
self,
sampler,
num_replicas: Optional[int] = None,
rank: Optional[int] = None,
shuffle: bool = True,
):
"""
Args:
sampler: Sampler used for subsampling
num_replicas (int, optional): Number of processes participating in
distributed training
rank (int, optional): Rank of the current process
within ``num_replicas``
shuffle (bool, optional): If true (default),
sampler will shuffle the indices
"""
super(DistributedSamplerWrapper, self).__init__(
DatasetFromSampler(sampler),
num_replicas=num_replicas,
rank=rank,
shuffle=shuffle,
)
self.sampler = sampler
def __iter__(self):
self.dataset = DatasetFromSampler(self.sampler)
indexes_of_indexes = super().__iter__()
subsampler_indexes = self.dataset
return iter(itemgetter(*indexes_of_indexes)(subsampler_indexes))
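The sharded infinite sampling implemented by `ShardedInfiniteSampler` above boils down to: all ranks share one seeded base permutation, each rank walks its own slice of it (`start=rank`, `step=world_size`), and the order of that slice is reshuffled on every pass with the `_make_seed` recipe so that training can run indefinitely and be resumed via `advance`. The toy generator below is a standalone sketch of that scheme following the `_new_shuffle_tensor_slice` variant; the function name and numbers are illustrative.

```python
import itertools
import torch

def sharded_infinite_indices(sample_count: int, rank: int, world_size: int, seed: int = 0):
    """Yield sample indices forever: one shared seeded permutation, sliced per rank,
    with the slice order reshuffled on each pass using a per-rank, per-pass seed."""
    base = torch.randperm(sample_count, generator=torch.Generator().manual_seed(seed))
    iter_count = 0
    while True:
        # same recipe as _make_seed above: seed + start + (iter_count << 24)
        g = torch.Generator().manual_seed(seed + rank + (iter_count << 24))
        shard = base[rank::world_size]
        order = torch.randperm(len(shard), generator=g)
        yield from shard[order].tolist()
        iter_count += 1

# two "ranks" of a world of size 2 cover all 100 samples exactly once per pass
a = list(itertools.islice(sharded_infinite_indices(100, rank=0, world_size=2), 50))
b = list(itertools.islice(sharded_infinite_indices(100, rank=1, world_size=2), 50))
assert set(a).isdisjoint(b) and len(set(a) | set(b)) == 100
```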
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/samplers.py
================================================
import torch
class DistributedSamplerNoDuplicate(torch.utils.data.DistributedSampler):
"""A distributed sampler that doesn't add duplicates. Arguments are the same as DistributedSampler"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if not self.drop_last and len(self.dataset) % self.num_replicas != 0:
            # some ranks may have fewer samples; that's fine
if self.rank >= len(self.dataset) % self.num_replicas:
self.num_samples -= 1
self.total_size = len(self.dataset)
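`DistributedSamplerNoDuplicate` only changes the bookkeeping of the stock `DistributedSampler`: instead of padding the last ranks with repeated indices so that every rank gets the same count, it lets those ranks receive one sample fewer, which matters for evaluation where duplicates would bias metrics. A small sketch of the difference, assuming the UniDepth package above is importable (e.g. installed or on `PYTHONPATH`):

```python
import torch
from torch.utils.data import DistributedSampler, TensorDataset

from unidepth.datasets.samplers import DistributedSamplerNoDuplicate

ds = TensorDataset(torch.arange(10))  # 10 samples split across 4 "ranks"
for rank in range(4):
    plain = list(DistributedSampler(ds, num_replicas=4, rank=rank, shuffle=False))
    nodup = list(DistributedSamplerNoDuplicate(ds, num_replicas=4, rank=rank, shuffle=False))
    print(rank, plain, nodup)
# the stock sampler pads to 3 samples per rank (12 indices, 2 of them duplicated);
# the no-duplicate variant gives ranks 2 and 3 only 2 samples, covering the 10 indices exactly once
```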
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/scannet.py
================================================
from typing import Any
from unidepth.datasets.sequence_dataset import SequenceDataset
class ScanNet(SequenceDataset):
min_depth = 0.005
max_depth = 10.0
depth_scale = 1000.0
test_split = "test.txt"
train_split = "train.txt"
sequences_file = "sequences.json"
hdf5_paths = ["ScanNetS.hdf5"]
def __init__(
self,
image_shape: tuple[int, int],
split_file: str,
test_mode: bool,
normalize: bool,
augmentations_db: dict[str, Any],
resize_method: str,
mini: float = 1.0,
num_frames: int = 1,
benchmark: bool = False,
decode_fields: list[str] = ["image", "depth"],
inplace_fields: list[str] = ["K", "cam2w"],
**kwargs,
) -> None:
super().__init__(
image_shape=image_shape,
split_file=split_file,
test_mode=test_mode,
benchmark=benchmark,
normalize=normalize,
augmentations_db=augmentations_db,
resize_method=resize_method,
mini=mini,
num_frames=num_frames,
decode_fields=decode_fields,
inplace_fields=inplace_fields,
**kwargs,
)
def pre_pipeline(self, results):
results = super().pre_pipeline(results)
results["dense"] = [True] * self.num_frames * self.num_copies
results["quality"] = [1] * self.num_frames * self.num_copies
return results
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/scannetpp.py
================================================
from typing import Any
from unidepth.datasets.sequence_dataset import SequenceDataset
class ScanNetpp(SequenceDataset):
min_depth = 0.001
max_depth = 10.0
depth_scale = 1000.0
test_split = "val_iphone.txt"
train_split = "train_iphone.txt"
sequences_file = "sequences_iphone_clean.json"
hdf5_paths = [f"ScanNetpp_viz.hdf5"]
def __init__(
self,
image_shape: tuple[int, int],
split_file: str,
test_mode: bool,
normalize: bool,
augmentations_db: dict[str, Any],
resize_method: str,
mini: float = 1.0,
num_frames: int = 1,
benchmark: bool = False,
decode_fields: list[str] = ["image", "depth"],
inplace_fields: list[str] = ["K", "cam2w"],
**kwargs,
) -> None:
super().__init__(
image_shape=image_shape,
split_file=split_file,
test_mode=test_mode,
benchmark=benchmark,
normalize=normalize,
augmentations_db=augmentations_db,
resize_method=resize_method,
mini=mini,
num_frames=num_frames,
decode_fields=decode_fields,
inplace_fields=inplace_fields,
**kwargs,
)
def pre_pipeline(self, results):
results = super().pre_pipeline(results)
results["dense"] = [True] * self.num_frames * self.num_copies
results["quality"] = [1] * self.num_frames * self.num_copies
return results
class ScanNetpp_F(SequenceDataset):
min_depth = 0.001
max_depth = 10.0
depth_scale = 1000.0
train_split = "train.txt"
test_split = "val_split.txt"
sequences_file = "sequences_split.json"
hdf5_paths = [f"ScanNetpp_F.hdf5"]
def __init__(
self,
image_shape: tuple[int, int],
split_file: str,
test_mode: bool,
normalize: bool,
augmentations_db: dict[str, Any],
resize_method: str,
mini: float = 1.0,
num_frames: int = 1,
benchmark: bool = False,
decode_fields: list[str] = ["image", "depth"],
inplace_fields: list[str] = ["camera_params", "cam2w"],
**kwargs,
) -> None:
super().__init__(
image_shape=image_shape,
split_file=split_file,
test_mode=test_mode,
benchmark=benchmark,
normalize=normalize,
augmentations_db=augmentations_db,
resize_method=resize_method,
mini=mini,
num_frames=num_frames,
decode_fields=(
decode_fields if not test_mode else [*decode_fields, "points"]
),
inplace_fields=inplace_fields,
**kwargs,
)
def pre_pipeline(self, results):
results = super().pre_pipeline(results)
results["dense"] = [True] * self.num_frames * self.num_copies
results["quality"] = [1] * self.num_frames * self.num_copies
return results
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/sequence_dataset.py
================================================
import json
import os
from functools import partial
from typing import Any, Dict, Tuple
import h5py
import numpy as np
import tables
import torch
import torchvision.transforms.v2.functional as TF
from unidepth.datasets.base_dataset import BaseDataset
from unidepth.datasets.utils import DatasetFromList
from unidepth.datasets.utils_decode import (decode_camera, decode_depth,
decode_flow, decode_K, decode_mask,
decode_numpy, decode_rgb,
decode_tensor)
from unidepth.utils.distributed import is_main_process
class SequenceDataset(BaseDataset):
DECODE_FNS = {
"image": partial(decode_rgb, name="image"),
"points": partial(decode_numpy, name="points"),
"K": partial(decode_K, name="camera"),
"camera_params": partial(decode_camera, name="camera"),
"cam2w": partial(decode_tensor, name="cam2w"),
"depth": partial(decode_depth, name="depth"),
"flow_fwd": partial(decode_flow, name="flow_fwd"),
"flow_bwd": partial(decode_flow, name="flow_bwd"),
"flow_fwd_mask": partial(decode_mask, name="flow_fwd_mask"),
"flow_bwd_mask": partial(decode_mask, name="flow_bwd_mask"),
}
default_fps = 5
def __init__(
self,
image_shape: Tuple[int, int],
split_file: str,
test_mode: bool,
normalize: bool,
augmentations_db: Dict[str, Any],
resize_method: str,
mini: float,
num_frames: int = 1,
benchmark: bool = False,
decode_fields: list[str] = ["image", "depth", "flow_fwd", "flow_fwd_mask"],
inplace_fields: list[str] = ["K", "cam2w"],
**kwargs,
) -> None:
super().__init__(
image_shape=image_shape,
split_file=split_file,
test_mode=test_mode,
benchmark=benchmark,
normalize=normalize,
augmentations_db=augmentations_db,
resize_method=resize_method,
mini=mini,
**kwargs,
)
self.num_frames = num_frames
self.original_num_frames = num_frames
self.decode_fields = decode_fields
self.inplace_fields = inplace_fields
self.fps = self.default_fps
self.fps_range = kwargs.get("fps_range", None)
if self.fps_range is not None:
self.fps_range[1] = min(self.default_fps, self.fps_range[1])
self.load_dataset()
def load_dataset(self):
h5file = h5py.File(
os.path.join(self.data_root, self.hdf5_paths[0]),
"r",
libver="latest",
swmr=True,
)
txt_file = np.array(h5file[self.split_file])
txt_string = txt_file.tostring().decode("ascii").strip()
sequences = np.array(h5file[self.sequences_file]).tostring().decode("ascii")
sequences = json.loads(sequences)
h5file.close()
dataset = []
for line in txt_string.split("\n"):
if len(line.strip().split(" ")) == 1:
print(line)
continue
sequence_name, num_samples = line.strip().split(" ")
dataset.append(
{
"sequence_name": sequence_name,
"num_samples": int(num_samples),
"chunk_idx": 0,
}
)
# filter dataset based on attr "invalid_sequences"
invalid_sequences = getattr(self, "invalid_sequences", [])
dataset = [
sample
for sample in dataset
if sample["sequence_name"] not in invalid_sequences
]
self.dataset = DatasetFromList(dataset)
self.sequences = DatasetFromList(
[sequences[sample["sequence_name"]] for sample in dataset]
)
self.log_load_dataset()
def get_random_idxs(self, num_samples_sequence):
if self.num_frames == 1:
return [np.random.randint(0, num_samples_sequence)], 0
# Check if we can satisfy the required number of frames
if self.num_frames > num_samples_sequence:
raise ValueError(
"Cannot sample more frames than available in the sequence."
)
# Restrict FPS range to be within default FPS
min_fps, max_fps = self.fps_range
max_fps = min(max_fps, self.default_fps)
if min_fps > self.default_fps:
sampled_fps = self.default_fps
else:
# Compute minimal viable FPS
min_required_fps = (
self.num_frames / num_samples_sequence
) * self.default_fps
min_fps = max(min_fps, min_required_fps)
# Sample an FPS from the viable range
sampled_fps = np.random.uniform(min_fps, max_fps)
# Compute the stride based on the sampled FPS
stride = self.default_fps / sampled_fps
max_start_index = num_samples_sequence - int(stride * (self.num_frames - 1))
# Ensure a valid starting position
if max_start_index <= 0:
raise ValueError(
"No valid start position allows sampling num_frames with the chosen FPS."
)
start_index = np.random.randint(0, max_start_index + 1)
# Compute indices based on the sampled FPS
indices = [int(start_index + i * stride) for i in range(self.num_frames)]
return indices, np.random.randint(0, len(indices))
def get_test_idxs(self, num_samples_sequence, keyframe_idx):
if self.num_frames == 1:
return [
keyframe_idx if keyframe_idx is not None else num_samples_sequence // 2
], 0
if self.num_frames == -1:
cap_idxs = min(32, num_samples_sequence) # CAP 32 images
idxs = list(
range(max(0, num_samples_sequence - cap_idxs), num_samples_sequence, 1)
)
return idxs, keyframe_idx
        # pick the frames closest to keyframe_idx so that they surround it, clamped to [0, num_samples_sequence)
keyframe_idx = (
keyframe_idx if keyframe_idx is not None else num_samples_sequence - 1
)
excess_tail = 0 - min(0, keyframe_idx - self.num_frames // 2)
excess_head = (
max(num_samples_sequence, keyframe_idx + (self.num_frames - 1) // 2)
- num_samples_sequence
)
start = keyframe_idx - self.num_frames // 2 + excess_tail - excess_head
end = keyframe_idx + (self.num_frames - 1) // 2 + excess_head - excess_tail
idxs = list(range(start, 1 + end))
return idxs, idxs.index(keyframe_idx)
def get_single_sequence(self, idx):
self.num_frames = self.original_num_frames
# sequence_name = self.dataset[idx]["sequence_name"]
sample = self.sequences[idx]
chunk_idx = int(sample.get("chunk_idx", 0))
h5_path = os.path.join(self.data_root, self.hdf5_paths[chunk_idx])
num_samples_sequence = len(sample["image"])
if self.num_frames > 0 and num_samples_sequence < self.num_frames:
raise IndexError(f"Sequence {idx} has less than {self.num_frames} frames")
keyframe_idx = None
if not self.test_mode:
idxs, keyframe_idx = self.get_random_idxs(num_samples_sequence)
else:
idxs, keyframe_idx = self.get_test_idxs(
num_samples_sequence, sample.get("keyframe_idx", None)
)
self.num_frames = len(idxs)
results = {}
results = self.pre_pipeline(results)
results["sequence_fields"] = [(i, 0) for i in range(self.num_frames)]
results["keyframe_idx"] = keyframe_idx
with tables.File(
h5_path,
mode="r",
libver="latest",
swmr=True,
) as h5file_chunk:
for i, j in enumerate(idxs):
results[(i, 0)] = {
k: v.copy() for k, v in results.items() if "fields" in k
}
for inplace_field in self.inplace_fields:
inplace_field_ = inplace_field.replace("intrinsics", "K").replace(
"extrinsics", "cam2w"
)
results = self.DECODE_FNS[inplace_field_](
results, sample[inplace_field][j], idx=i, sample=sample, j=j
)
for i, j in enumerate(idxs):
for decode_field in self.decode_fields:
results = self.DECODE_FNS[decode_field](
results,
h5file_chunk,
sample[decode_field][j],
idx=i,
depth_scale=self.depth_scale,
)
results["filename"] = sample["image"][j]
results = self.preprocess(results)
if not self.test_mode:
results = self.augment(results)
results = self.postprocess(results)
return results
def preprocess(self, results):
results = self.replicate(results)
for i, seq in enumerate(results["sequence_fields"]):
results[seq] = self.resizer(results[seq])
self.resizer.ctx = None if self.num_copies > 1 else self.resizer.ctx
num_pts = torch.count_nonzero(results[seq]["depth"] > 0)
if num_pts < 50:
raise IndexError(f"Too few points in depth map ({num_pts})")
for key in results[seq].get("image_fields", ["image"]):
results[seq][key] = results[seq][key].to(torch.float32) / 255
# update fields common in sequence
for key in [
"image_fields",
"gt_fields",
"mask_fields",
"camera_fields",
]:
if key in results[(0, 0)]:
results[key] = results[(0, 0)][key]
results = self.pack_batch(results)
return results
def postprocess(self, results):
        # normalize here, after augmentation, because the color augmentations expect values in [0, 255]
for key in results.get("image_fields", ["image"]):
results[key] = TF.normalize(results[key], **self.normalization_stats)
results = self.filler(results)
results = self.unpack_batch(results)
results = self.masker(results)
results = self.collecter(results)
return results
def __getitem__(self, idx):
try:
if isinstance(idx, (list, tuple)):
results = [self.get_single_sequence(i) for i in idx]
else:
results = self.get_single_sequence(idx)
except Exception as e:
print(f"Error loading sequence {idx} for {self.__class__.__name__}: {e}")
idx = np.random.randint(0, len(self.dataset))
results = self[idx]
return results
def log_load_dataset(self):
if is_main_process():
info = f"Loaded {self.__class__.__name__} with {sum([len(x['image']) for x in self.sequences])} images in {len(self)} sequences."
print(info)
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/sintel copy.py
================================================
from typing import Any
from unidepth.datasets.sequence_dataset import SequenceDataset
class Sintel(SequenceDataset):
min_depth = 0.001
max_depth = 1000.0
depth_scale = 1000.0
test_split = "training.txt"
train_split = "training.txt"
sequences_file = "sequences.json"
hdf5_paths = ["Sintel.hdf5"]
def __init__(
self,
image_shape: tuple[int, int],
split_file: str,
test_mode: bool,
normalize: bool,
augmentations_db: dict[str, Any],
resize_method: str,
mini: float = 1.0,
num_frames: int = 1,
benchmark: bool = False,
decode_fields: list[str] = ["image", "depth", "flow_fwd", "flow_fwd_mask"],
inplace_fields: list[str] = ["K", "cam2w"],
**kwargs,
) -> None:
super().__init__(
image_shape=image_shape,
split_file=split_file,
test_mode=test_mode,
benchmark=benchmark,
normalize=normalize,
augmentations_db=augmentations_db,
resize_method=resize_method,
mini=mini,
num_frames=num_frames,
decode_fields=decode_fields,
inplace_fields=inplace_fields,
**kwargs,
)
def pre_pipeline(self, results):
results = super().pre_pipeline(results)
results["dense"] = [True] * self.num_frames * self.num_copies
results["synthetic"] = [True] * self.num_frames * self.num_copies
results["quality"] = [0] * self.num_frames * self.num_copies
return results
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/sintel.py
================================================
from typing import Any
from unidepth.datasets.sequence_dataset import SequenceDataset
class Sintel(SequenceDataset):
min_depth = 0.001
max_depth = 1000.0
depth_scale = 1000.0
test_split = "training.txt"
train_split = "training.txt"
sequences_file = "sequences.json"
hdf5_paths = ["Sintel.hdf5"]
def __init__(
self,
image_shape: tuple[int, int],
split_file: str,
test_mode: bool,
normalize: bool,
augmentations_db: dict[str, Any],
resize_method: str,
mini: float,
num_frames: int = 1,
benchmark: bool = False,
decode_fields: list[str] = ["image", "depth"],
inplace_fields: list[str] = ["K", "cam2w"],
**kwargs,
) -> None:
super().__init__(
image_shape=image_shape,
split_file=split_file,
test_mode=test_mode,
benchmark=benchmark,
normalize=normalize,
augmentations_db=augmentations_db,
resize_method=resize_method,
mini=mini,
num_frames=num_frames,
decode_fields=decode_fields,
inplace_fields=inplace_fields,
**kwargs,
)
def pre_pipeline(self, results):
results = super().pre_pipeline(results)
results["dense"] = [True] * self.num_frames
results["synthetic"] = [True] * self.num_frames
return results
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/sunrgbd.py
================================================
import json
import os
import h5py
import numpy as np
import torch
from unidepth.datasets.image_dataset import ImageDataset
from unidepth.datasets.utils import DatasetFromList
class SUNRGBD(ImageDataset):
min_depth = 0.005
max_depth = 8.0
depth_scale = 1000.0
test_split = "alltest.txt"
train_split = "alltrain.txt"
intrisics_file = "intrinsics.json"
hdf5_paths = ["SUNRGB.hdf5"]
def __init__(
self,
image_shape,
split_file,
test_mode,
crop=None,
benchmark=False,
augmentations_db={},
normalize=True,
resize_method="hard",
mini=1.0,
**kwargs,
):
super().__init__(
image_shape=image_shape,
split_file=split_file,
test_mode=test_mode,
benchmark=benchmark,
normalize=normalize,
augmentations_db=augmentations_db,
resize_method=resize_method,
mini=mini,
**kwargs,
)
self.test_mode = test_mode
self.crop = crop
self.load_dataset()
def load_dataset(self):
h5file = h5py.File(
os.path.join(self.data_root, self.hdf5_paths[0]),
"r",
libver="latest",
swmr=True,
)
txt_file = np.array(h5file[self.split_file])
        txt_string = txt_file.tostring().decode("ascii")[:-1]  # drop the trailing newline
intrinsics = np.array(h5file[self.intrisics_file]).tostring().decode("ascii")
intrinsics = json.loads(intrinsics)
h5file.close()
dataset = []
for line in txt_string.split("\n"):
image_filename, depth_filename = line.strip().split(" ")
intrinsics_val = torch.tensor(intrinsics[image_filename]).squeeze()[:, :3]
sample = [image_filename, depth_filename, intrinsics_val]
dataset.append(sample)
if not self.test_mode:
dataset = self.chunk(dataset, chunk_dim=1, pct=self.mini)
self.dataset = DatasetFromList(dataset)
self.log_load_dataset()
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/synscapes.py
================================================
from typing import Any
from unidepth.datasets.sequence_dataset import SequenceDataset
class Synscapes(SequenceDataset):
min_depth = 0.1
max_depth = 1000.0
depth_scale = 256.0
test_split = "train.txt"
train_split = "train.txt"
sequences_file = "sequences.json"
hdf5_paths = [f"Synscapes.hdf5"]
def __init__(
self,
image_shape: tuple[int, int],
split_file: str,
test_mode: bool,
normalize: bool,
augmentations_db: dict[str, Any],
resize_method: str,
mini: float = 1.0,
num_frames: int = 1,
benchmark: bool = False,
decode_fields: list[str] = ["image", "depth"],
inplace_fields: list[str] = ["K", "cam2w"],
**kwargs,
) -> None:
super().__init__(
image_shape=image_shape,
split_file=split_file,
test_mode=test_mode,
benchmark=benchmark,
normalize=normalize,
augmentations_db=augmentations_db,
resize_method=resize_method,
mini=mini,
num_frames=num_frames,
decode_fields=decode_fields,
inplace_fields=inplace_fields,
**kwargs,
)
def pre_pipeline(self, results):
results = super().pre_pipeline(results)
results["dense"] = [True] * self.num_frames * self.num_copies
results["synthetic"] = [True] * self.num_frames * self.num_copies
results["quality"] = [0] * self.num_frames * self.num_copies
return results
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/tartanair.py
================================================
from typing import Any
from unidepth.datasets.sequence_dataset import SequenceDataset
class TartanAir(SequenceDataset):
min_depth = 0.01
max_depth = 512.0
depth_scale = 1000.0
default_fps = 15
test_split = "train.txt"
train_split = "train.txt"
sequences_file = "sequences.json"
hdf5_paths = ["TartanAir.hdf5"]
def __init__(
self,
image_shape: tuple[int, int],
split_file: str,
test_mode: bool,
normalize: bool,
augmentations_db: dict[str, Any],
resize_method: str,
mini: float = 1.0,
num_frames: int = 1,
benchmark: bool = False,
decode_fields: list[str] = ["image", "depth"],
inplace_fields: list[str] = ["K", "cam2w"],
**kwargs,
) -> None:
super().__init__(
image_shape=image_shape,
split_file=split_file,
test_mode=test_mode,
benchmark=benchmark,
normalize=normalize,
augmentations_db=augmentations_db,
resize_method=resize_method,
mini=mini,
num_frames=num_frames,
decode_fields=decode_fields,
inplace_fields=inplace_fields,
**kwargs,
)
def pre_pipeline(self, results):
results = super().pre_pipeline(results)
results["dense"] = [True] * self.num_frames * self.num_copies
results["synthetic"] = [True] * self.num_frames * self.num_copies
results["quality"] = [0] * self.num_frames * self.num_copies
return results
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/taskonomy.py
================================================
import json
import os
import h5py
import numpy as np
import torch
from unidepth.datasets.image_dataset import ImageDataset
from unidepth.datasets.utils import DatasetFromList
class Taskonomy(ImageDataset):
min_depth = 0.005
max_depth = 15.0
depth_scale = 512.0
test_split = "val.txt"
train_split = "train_clean.txt"
intrisics_file = "intrinsics.json"
hdf5_paths = ["Taskonomy.hdf5"]
def __init__(
self,
image_shape,
split_file,
test_mode,
benchmark=False,
augmentations_db={},
normalize=True,
resize_method="hard",
mini=1.0,
**kwargs,
):
super().__init__(
image_shape=image_shape,
split_file=split_file,
test_mode=test_mode,
benchmark=benchmark,
normalize=normalize,
augmentations_db=augmentations_db,
resize_method=resize_method,
mini=mini,
**kwargs,
)
self.test_mode = test_mode
self.load_dataset()
def load_dataset(self):
h5file = h5py.File(
os.path.join(self.data_root, self.hdf5_paths[0]),
"r",
libver="latest",
swmr=True,
)
txt_file = np.array(h5file[self.split_file])
txt_string = txt_file.tostring().decode("ascii") # [:-1] # correct the -1
intrinsics = np.array(h5file[self.intrisics_file]).tostring().decode("ascii")
intrinsics = json.loads(intrinsics)
# with open(os.path.join(os.environ["TMPDIR"], self.split_file), "w") as f:
# f.write(txt_string)
# with open(os.path.join(os.environ["TMPDIR"], self.intrisics_file), "w") as f:
# json.dump(intrinsics, f)
dataset = []
for line in txt_string.split("\n"):
image_filename, depth_filename, chunk_idx = line.strip().split(" ")
intrinsics_val = torch.tensor(intrinsics[image_filename]).squeeze()[:, :3]
sample = [image_filename, depth_filename, intrinsics_val, chunk_idx]
dataset.append(sample)
h5file.close()
if not self.test_mode:
dataset = self.chunk(dataset, chunk_dim=1, pct=self.mini)
if self.test_mode and not self.benchmark:
dataset = self.chunk(dataset, chunk_dim=1, pct=0.01)
self.dataset = DatasetFromList(dataset)
self.log_load_dataset()
def get_mapper(self):
return {
"image_filename": 0,
"depth_filename": 1,
"K": 2,
}
def pre_pipeline(self, results):
results = super().pre_pipeline(results)
results["dense"] = [True] * self.num_frames * self.num_copies
results["quality"] = [2] * self.num_frames * self.num_copies
return results
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/tat_rmvd.py
================================================
import json
import os
from copy import deepcopy
from typing import Any
import h5py
import numpy as np
import torch
from unidepth.datasets.image_dataset import ImageDataset
from unidepth.datasets.pipelines import AnnotationMask, KittiCrop
from unidepth.datasets.sequence_dataset import SequenceDataset
from unidepth.datasets.utils import DatasetFromList
from unidepth.utils import identity
class TATRMVD(SequenceDataset):
min_depth = 0.001
max_depth = 50.0
depth_scale = 1000.0
default_fps = 6
test_split = "test.txt"
train_split = "test.txt"
sequences_file = "sequences.json"
hdf5_paths = ["tanks_and_temples_rmvd.hdf5"]
def __init__(
self,
image_shape,
split_file,
test_mode,
crop=None,
augmentations_db={},
normalize=True,
resize_method="hard",
mini: float = 1.0,
num_frames: int = 1,
benchmark: bool = False,
decode_fields: list[str] = ["image", "depth"],
inplace_fields: list[str] = ["K", "cam2w"],
**kwargs,
):
super().__init__(
image_shape=image_shape,
split_file=split_file,
test_mode=test_mode,
benchmark=benchmark,
normalize=normalize,
augmentations_db=augmentations_db,
resize_method=resize_method,
mini=mini,
num_frames=num_frames,
decode_fields=decode_fields,
inplace_fields=inplace_fields,
**kwargs,
)
def pre_pipeline(self, results):
results = super().pre_pipeline(results)
results["dense"] = [False] * self.num_frames * self.num_copies
results["si"] = [True] * self.num_frames * self.num_copies
results["quality"] = [2] * self.num_frames * self.num_copies
return results
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/theo.py
================================================
from typing import Any
import torch
from unidepth.datasets.sequence_dataset import SequenceDataset
class Theo(SequenceDataset):
min_depth = 0.01
max_depth = 10.0
depth_scale = 1000.0
default_fps = 5
test_split = "train.txt"
train_split = "train.txt"
sequences_file = "sequences.json"
hdf5_paths = ["THEO.hdf5"]
def __init__(
self,
image_shape: tuple[int, int],
split_file: str,
test_mode: bool,
normalize: bool,
augmentations_db: dict[str, Any],
resize_method: str,
mini: float = 1.0,
num_frames: int = 1,
benchmark: bool = False,
decode_fields: list[str] = ["image", "depth"],
inplace_fields: list[str] = ["camera_params", "cam2w"],
**kwargs,
) -> None:
super().__init__(
image_shape=image_shape,
split_file=split_file,
test_mode=test_mode,
benchmark=benchmark,
normalize=normalize,
augmentations_db=augmentations_db,
resize_method=resize_method,
mini=mini,
num_frames=num_frames,
decode_fields=decode_fields,
inplace_fields=inplace_fields,
**kwargs,
)
def preprocess(self, results):
self.resizer.ctx = None
for i, seq in enumerate(results["sequence_fields"]):
# Create a mask where the distance from the center is less than H/2
H, W = results[seq]["image"].shape[-2:]
x = torch.linspace(-(W - 1) / 2, (W - 1) / 2, W)
y = torch.linspace(-(H - 1) / 2, (H - 1) / 2, H)
xv, yv = torch.meshgrid(x, y, indexing="xy")
distance_from_center = torch.sqrt(xv**2 + yv**2).reshape(1, 1, H, W)
results[seq]["validity_mask"] = distance_from_center < (H - 1) / 2
return super().preprocess(results)
def pre_pipeline(self, results):
results = super().pre_pipeline(results)
results["dense"] = [True] * self.num_frames * self.num_copies
results["synthetic"] = [True] * self.num_frames * self.num_copies
results["quality"] = [0] * self.num_frames * self.num_copies
return results
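The validity mask built in `Theo.preprocess` above keeps only pixels whose distance from the image centre is below `(H - 1) / 2`, i.e. the circular footprint of a fisheye image inscribed in the frame height. A minimal standalone sketch of the same construction (function name and shape are illustrative):

```python
import torch

def circular_validity_mask(H: int, W: int) -> torch.Tensor:
    """Boolean (1, 1, H, W) mask that is True inside the circle of radius (H - 1) / 2."""
    x = torch.linspace(-(W - 1) / 2, (W - 1) / 2, W)
    y = torch.linspace(-(H - 1) / 2, (H - 1) / 2, H)
    xv, yv = torch.meshgrid(x, y, indexing="xy")  # both (H, W)
    distance_from_center = torch.sqrt(xv**2 + yv**2)
    return (distance_from_center < (H - 1) / 2).reshape(1, 1, H, W)

mask = circular_validity_mask(8, 12)
print(mask.float().mean())  # fraction of the frame inside the valid circle
```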
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/unrealstereo4k.py
================================================
from typing import Any
from unidepth.datasets.sequence_dataset import SequenceDataset
class UnrealStereo4K(SequenceDataset):
min_depth = 0.01
max_depth = 200.0
depth_scale = 1000.0
test_split = "train.txt"
train_split = "train.txt"
sequences_file = "sequences.json"
hdf5_paths = [f"UnrealStereo4K.hdf5"]
def __init__(
self,
image_shape: tuple[int, int],
split_file: str,
test_mode: bool,
normalize: bool,
augmentations_db: dict[str, Any],
resize_method: str,
mini: float = 1.0,
num_frames: int = 1,
benchmark: bool = False,
decode_fields: list[str] = ["image", "depth"],
inplace_fields: list[str] = ["K", "cam2w"],
**kwargs,
) -> None:
super().__init__(
image_shape=image_shape,
split_file=split_file,
test_mode=test_mode,
benchmark=benchmark,
normalize=normalize,
augmentations_db=augmentations_db,
resize_method=resize_method,
mini=mini,
num_frames=num_frames,
decode_fields=decode_fields,
inplace_fields=inplace_fields,
**kwargs,
)
def pre_pipeline(self, results):
results = super().pre_pipeline(results)
results["dense"] = [True] * self.num_frames * self.num_copies
results["synthetic"] = [True] * self.num_frames * self.num_copies
results["quality"] = [0] * self.num_frames * self.num_copies
return results
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/urbansyn.py
================================================
from typing import Any
from unidepth.datasets.sequence_dataset import SequenceDataset
class UrbanSyn(SequenceDataset):
min_depth = 0.1
max_depth = 1000.0
depth_scale = 256.0
test_split = "train.txt"
train_split = "train.txt"
sequences_file = "sequences.json"
hdf5_paths = [f"UrbanSyn.hdf5"]
def __init__(
self,
image_shape: tuple[int, int],
split_file: str,
test_mode: bool,
normalize: bool,
augmentations_db: dict[str, Any],
resize_method: str,
mini: float = 1.0,
num_frames: int = 1,
benchmark: bool = False,
decode_fields: list[str] = ["image", "depth"],
inplace_fields: list[str] = ["K", "cam2w"],
**kwargs,
) -> None:
super().__init__(
image_shape=image_shape,
split_file=split_file,
test_mode=test_mode,
benchmark=benchmark,
normalize=normalize,
augmentations_db=augmentations_db,
resize_method=resize_method,
mini=mini,
num_frames=num_frames,
decode_fields=decode_fields,
inplace_fields=inplace_fields,
**kwargs,
)
def pre_pipeline(self, results):
results = super().pre_pipeline(results)
results["dense"] = [True] * self.num_frames * self.num_copies
results["synthetic"] = [True] * self.num_frames * self.num_copies
results["quality"] = [0] * self.num_frames * self.num_copies
return results
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/utils.py
================================================
import copy
import multiprocessing as mp
import pickle
from collections import defaultdict
from typing import Any, Dict, List
import numpy as np
import torch
import torch.utils.data
from unidepth.utils.distributed import (all_gather, get_local_rank,
get_local_size, get_rank,
get_world_size)
class ConcatDataset(torch.utils.data.ConcatDataset):
def __init__(self, datasets, shape_constraints: dict[str, list[int]] = {}):
super().__init__(datasets)
self.sample = shape_constraints["sample"]
self.shape_mult = shape_constraints["shape_mult"]
self.ratio_bounds = shape_constraints["ratio_bounds"]
self.pixels_max = shape_constraints["pixels_max"]
self.pixels_min = shape_constraints["pixels_min"]
self.height_min = shape_constraints["height_min"]
self.width_min = shape_constraints["width_min"]
def sample_shape(self):
if not self.sample:
return
# 1: sample image ratio
ratio = np.random.uniform(*self.ratio_bounds)
pixels_min = self.pixels_min // (self.shape_mult * self.shape_mult)
pixels_max = self.pixels_max // (self.shape_mult * self.shape_mult)
# 2: sample image height or width, if ratio > 1 or < 1
if ratio > 1:
height_min = max(self.height_min, np.sqrt(pixels_min / ratio))
height = np.random.uniform(height_min, np.sqrt(pixels_max / ratio))
width = height * ratio
else:
width_min = max(self.width_min, np.sqrt(pixels_min * ratio))
width = np.random.uniform(width_min, np.sqrt(pixels_max * ratio))
height = width / ratio
# 3: get final shape based on the shape_mult
shape = [int(height) * self.shape_mult, int(width) * self.shape_mult]
for dataset in self.datasets:
setattr(dataset, "image_shape", shape)
setattr(dataset.resizer, "image_shape", shape)
def __getitem__(self, idxs):
self.sample_shape()
return [super(ConcatDataset, self).__getitem__(idx) for idx in idxs]
def _paddings(image_shape, network_shape):
cur_h, cur_w = image_shape
h, w = network_shape
pad_top, pad_bottom = (h - cur_h) // 2, h - cur_h - (h - cur_h) // 2
pad_left, pad_right = (w - cur_w) // 2, w - cur_w - (w - cur_w) // 2
return pad_left, pad_right, pad_top, pad_bottom
def collate_fn(in_data: List[List[Dict[str, Any]]], is_batched: bool = True):
out_data = defaultdict(list)
img_metas = []
in_data = in_data[0] if is_batched else in_data
# get max_shape and paddings
shapes = [tensor.shape[-2:] for x in in_data for tensor in x["depth"].values()]
max_shape_tuple = tuple(max(elements) for elements in zip(*shapes))
paddings = [
[
_paddings(tensor.shape[-2:], max_shape_tuple)
for tensor in x["depth"].values()
]
for x in in_data
]
for x in in_data: # here iter over batches
padding = paddings.pop(0)
for k, v in x.items():
if "img_metas" not in k:
values = list(v.values())
v = torch.cat(values)
out_data[k].append(v)
else:
v["depth_paddings"] = padding
img_metas.append(v)
output_dict = {
"data": {k: torch.stack(v, dim=0) for k, v in out_data.items()},
"img_metas": img_metas,
}
    # cameras are always flattened and then stacked/concatenated, so a list of B cameras of
    # shape (T, 3, 3) becomes (B * T, 3, 3); reshape to stay consistent with the image shape
if "camera" in output_dict["data"]:
output_dict["data"]["camera"] = output_dict["data"]["camera"].reshape(
*output_dict["data"]["image"].shape[:2]
)
return output_dict
def local_scatter(array: list[Any]):
"""
Scatter an array from local leader to all local workers.
The i-th local worker gets array[i].
Args:
array: Array with same size of #local workers.
"""
if get_world_size() == 1:
return array[0]
if get_local_rank() == 0:
assert len(array) == get_local_size()
all_gather(array)
else:
all_data = all_gather(None)
array = all_data[get_rank() - get_local_rank()]
return array[get_local_rank()]
class DatasetFromList(torch.utils.data.Dataset): # type: ignore
"""Wrap a list to a torch Dataset.
We serialize and wrap big python objects in a torch.Dataset due to a
memory leak when dealing with large python objects using multiple workers.
See: https://github.com/pytorch/pytorch/issues/13246
"""
def __init__(self, lst: List[Any], deepcopy: bool = False, serialize: bool = True):
"""Creates an instance of the class.
Args:
lst: a list which contains elements to produce.
deepcopy: whether to deepcopy the element when producing it, s.t.
the result can be modified in place without affecting the source
in the list.
serialize: whether to hold memory using serialized objects. When
enabled, data loader workers can use shared RAM from master
process instead of making a copy.
"""
self._copy = deepcopy
self._serialize = serialize
def _serialize(data: Any):
buffer = pickle.dumps(data, protocol=pickle.HIGHEST_PROTOCOL)
return torch.frombuffer(buffer, dtype=torch.uint8)
if self._serialize:
# load only on 0th rank
if get_local_rank() == 0:
_lst = [_serialize(x) for x in lst]
self._addr = torch.cumsum(
torch.tensor([len(x) for x in _lst], dtype=torch.int64), dim=0
)
self._lst = torch.concatenate(_lst)
# Move data to shared memory, obtain a handle to send to each local worker.
handles = [None] + [
bytes(mp.reduction.ForkingPickler.dumps((self._addr, self._lst)))
for _ in range(get_local_size() - 1)
]
else:
handles = None
# Each worker receives the handle from local leader (rank 0)
# then materialize the tensor from shared memory
handle = local_scatter(handles)
if get_local_rank() > 0:
self._addr, self._lst = mp.reduction.ForkingPickler.loads(handle)
else:
self._lst = lst
def __len__(self) -> int:
"""Return len of list."""
if self._serialize:
return len(self._addr)
return len(self._lst)
def __getitem__(self, idx: int) -> Any:
"""Return item of list at idx."""
if self._serialize:
start_addr = 0 if idx == 0 else self._addr[idx - 1]
end_addr = self._addr[idx]
bytes_ = memoryview(self._lst[start_addr:end_addr].numpy())
return pickle.loads(bytes_)
if self._copy:
return copy.deepcopy(self._lst[idx])
return self._lst[idx]
def get_weights(
train_datasets: dict[str, torch.utils.data.Dataset], sampling: dict[str, float]
) -> tuple[torch.Tensor, int]:
from .image_dataset import ImageDataset
from .sequence_dataset import SequenceDataset
weights = []
num_samples = 0
info_weights = {}
for dataset_name, dataset in train_datasets.items():
assert (
dataset_name in sampling
), f"Dataset {dataset_name} not found in {sampling.keys()}"
if isinstance(dataset, ImageDataset):
            # the summed weight of all samples equals sampling[dataset_name], so the dataset as a
            # whole is drawn with the configured probability; within the dataset, sampling is uniform
weight = sampling[dataset_name] / len(dataset)
weights.append(torch.full((len(dataset),), weight).double())
num_samples += len(dataset)
elif isinstance(dataset, SequenceDataset):
# local weight is num_samples, but global must be as in sampling
# hence is num_samples / (sum num_samples / sampling[dataset_name])
# s.t. sampling anything from the dataset is
# sum(num_samples / (sum num_samples / sampling[dataset_name]))
# -> sampling[dataset_name]
numerator = [int(data["num_samples"]) for data in dataset.dataset]
weights.append(
sampling[dataset_name]
* torch.tensor(numerator).double()
/ sum(numerator)
)
num_samples += sum(numerator)
else:
weight = sampling[dataset_name] / len(dataset)
weights.append(torch.full((len(dataset),), weight).double())
info_weights[dataset_name] = weights[-1][-1]
return torch.cat(weights), num_samples
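The per-sample weights returned by `get_weights` above are designed so that each dataset is drawn with its configured probability while samples inside a dataset are drawn uniformly; they are the kind of input `torch.utils.data.WeightedRandomSampler` expects. A minimal sketch of that wiring with made-up dataset names, sizes, and sampling ratios (the actual training wiring lives outside this file):

```python
import torch
from torch.utils.data import WeightedRandomSampler

sizes = {"DatasetA": 1000, "DatasetB": 250}    # hypothetical dataset lengths
sampling = {"DatasetA": 0.7, "DatasetB": 0.3}  # configured sampling ratios

# uniform within each dataset, total weight per dataset equal to its sampling ratio
weights = torch.cat(
    [torch.full((n,), sampling[name] / n, dtype=torch.float64) for name, n in sizes.items()]
)
sampler = WeightedRandomSampler(weights, num_samples=10_000, replacement=True)
draws = torch.tensor(list(sampler))
print((draws < sizes["DatasetA"]).float().mean())  # ~0.70: fraction of draws from DatasetA
```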
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/utils_decode.py
================================================
import io
import cv2
import numpy as np
import torch
import torchvision
import torchvision.transforms.v2.functional as TF
from PIL import Image
from unidepth.utils.camera import (EUCM, MEI, BatchCamera, Fisheye624, Pinhole,
Spherical)
def decode_depth(results, h5file, value, idx, depth_scale, name="depth", **kwargs):
file = h5file.get_node("/" + value).read()
decoded_data = Image.open(io.BytesIO(file))
decoded_data = TF.pil_to_tensor(decoded_data).squeeze()
if decoded_data.ndim == 3: # 24 channel loading
decoded_channels = [
(decoded_data[0] & 0xFF).to(torch.int32),
(decoded_data[1] & 0xFF).to(torch.int32),
(decoded_data[2] & 0xFF).to(torch.int32),
]
# Reshape and extract the original depth map
decoded_data = (
decoded_channels[0]
| (decoded_channels[1] << 8)
| (decoded_channels[2] << 16)
)
decoded_data = decoded_data.to(torch.float32)
results.get("gt_fields", set()).add(name)
results[(idx, 0)].get("gt_fields", set()).add(name)
results[f"{name}_ori_shape"] = decoded_data.shape
results[(idx, 0)][name] = (
decoded_data.view(1, 1, *decoded_data.shape).contiguous() / depth_scale
)
return results
def decode_numpy(results, h5file, value, idx, name="points", **kwargs):
file = h5file.get_node("/" + value).read()
decoded_data = np.load(io.BytesIO(file), allow_pickle=False)
decoded_data = torch.from_numpy(decoded_data).to(torch.float32)
if decoded_data.ndim > 2:
decoded_data = decoded_data.permute(2, 0, 1)
results.get("gt_fields", set()).add(name)
results[(idx, 0)].get("gt_fields", set()).add(name)
results[(idx, 0)][name] = decoded_data.unsqueeze(0)
return results
def decode_tensor(results, value, idx, name, **kwargs):
results.get("camera_fields", set()).add(name)
results[(idx, 0)].get("camera_fields", set()).add(name)
results[(idx, 0)][name] = torch.tensor(value).unsqueeze(0)
return results
def decode_camera(results, value, idx, name, sample, j, **kwargs):
results.get("camera_fields", set()).add(name)
results[(idx, 0)].get("camera_fields", set()).add(name)
camera = eval(sample["camera_model"][j])(params=torch.tensor(value).unsqueeze(0))
results[(idx, 0)][name] = BatchCamera.from_camera(camera)
return results
def decode_K(results, value, idx, name, **kwargs):
results.get("camera_fields", set()).add(name)
results[(idx, 0)].get("camera_fields", set()).add(name)
camera = Pinhole(K=torch.tensor(value).unsqueeze(0))
results[(idx, 0)][name] = BatchCamera.from_camera(camera)
return results
def decode_mask(results, h5file, value, idx, name, **kwargs):
file = h5file.get_node("/" + value).read()
mask = torchvision.io.decode_image(torch.from_numpy(file)).bool().squeeze()
results.get("mask_fields", set()).add(name)
results[(idx, 0)].get("mask_fields", set()).add(name)
results[f"{name}_ori_shape"] = mask.shape[-2:]
results[(idx, 0)][name] = mask.view(1, 1, *mask.shape).contiguous()
return results
def decode_rgb(results, h5file, value, idx, name="image", **kwargs):
file = h5file.get_node("/" + value).read()
image = (
torchvision.io.decode_image(torch.from_numpy(file)).to(torch.uint8).squeeze()
)
results.get("image_fields", set()).add(name)
results[(idx, 0)].get("image_fields", set()).add(name)
results[f"{name}_ori_shape"] = image.shape[-2:]
if image.ndim == 2:
image = image.unsqueeze(0).repeat(3, 1, 1)
results[(idx, 0)][name] = image.unsqueeze(0)
return results
def decode_flow(results, h5file, value, idx, name, **kwargs):
file = h5file.get_node("/" + value).read()
image = (
torchvision.io.decode_image(torch.from_numpy(file)).to(torch.uint8).squeeze()
)
decoded_channels = [
(image[0] & 0xFF).to(torch.int16),
(image[1] & 0xFF).to(torch.int16),
(image[2] & 0xFF).to(torch.int16),
]
# Unpack the two 12-bit flow components packed across the three 8-bit channels
flow = torch.zeros((2, image.shape[1], image.shape[2]), dtype=torch.int16)
flow[0] = (decoded_channels[0] | decoded_channels[1] << 8) & 0xFFF
flow[1] = (decoded_channels[1] >> 4 | decoded_channels[2] << 4) & 0xFFF
results.get("gt_fields", set()).add(name)
results[(idx, 0)].get("gt_fields", set()).add(name)
results[f"{name}_ori_shape"] = flow.shape[-2:]
flow = flow.unsqueeze(0).contiguous().float()
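# rescale the 12-bit integer codes (0..4095) to approximately [-1, 1]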
results[(idx, 0)][name] = (0.5 + flow) / 4095.0 * 2 - 1
return results
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/vkitti.py
================================================
from typing import Any
from unidepth.datasets.sequence_dataset import SequenceDataset
class VKITTI(SequenceDataset):
min_depth = 0.01
max_depth = 255.0
depth_scale = 256.0
test_split = "training.txt"
train_split = "training.txt"
sequences_file = "sequences.json"
hdf5_paths = ["VKITTI2.hdf5"]
def __init__(
self,
image_shape: tuple[int, int],
split_file: str,
test_mode: bool,
normalize: bool,
augmentations_db: dict[str, Any],
resize_method: str,
mini: float = 1.0,
num_frames: int = 1,
benchmark: bool = False,
decode_fields: list[str] = ["image", "depth", "flow_fwd", "flow_fwd_mask"],
inplace_fields: list[str] = ["K", "cam2w"],
**kwargs,
) -> None:
super().__init__(
image_shape=image_shape,
split_file=split_file,
test_mode=test_mode,
benchmark=benchmark,
normalize=normalize,
augmentations_db=augmentations_db,
resize_method=resize_method,
mini=mini,
num_frames=num_frames,
decode_fields=decode_fields,
inplace_fields=inplace_fields,
**kwargs,
)
def pre_pipeline(self, results):
results = super().pre_pipeline(results)
results["dense"] = [True] * self.num_frames * self.num_copies
results["synthetic"] = [True] * self.num_frames * self.num_copies
results["quality"] = [0] * self.num_frames * self.num_copies
return results
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/void.py
================================================
import json
import os
import h5py
import numpy as np
import torch
from unidepth.datasets.image_dataset import ImageDataset
from unidepth.datasets.utils import DatasetFromList
class VOID(ImageDataset):
min_depth = 0.01
max_depth = 10.0
depth_scale = 256.0
test_split = "void_val.txt"
train_split = "void_train.txt"
intrisics_file = "void_intrinsics.json"
hdf5_paths = ["void.hdf5"]
def __init__(
self,
image_shape,
split_file,
test_mode,
crop=None,
benchmark=False,
augmentations_db={},
normalize=True,
resize_method="hard",
mini=1.0,
**kwargs,
):
super().__init__(
image_shape=image_shape,
split_file=split_file,
test_mode=test_mode,
benchmark=benchmark,
normalize=normalize,
augmentations_db=augmentations_db,
resize_method=resize_method,
mini=mini,
**kwargs,
)
self.test_mode = test_mode
self.crop = crop
self.load_dataset()
def load_dataset(self):
h5file = h5py.File(
os.path.join(self.data_root, self.hdf5_paths[0]),
"r",
libver="latest",
swmr=True,
)
txt_file = np.array(h5file[self.split_file])
txt_string = txt_file.tobytes().decode("ascii")[:-1]  # drop the trailing newline
intrinsics = np.array(h5file[self.intrisics_file]).tobytes().decode("ascii")
intrinsics = json.loads(intrinsics)
h5file.close()
dataset = []
for line in txt_string.split("\n"):
image_filename, depth_filename = line.strip().split(" ")
intrinsics_val = torch.tensor(intrinsics[image_filename]).squeeze()[:, :3]
sample = [image_filename, depth_filename, intrinsics_val]
dataset.append(sample)
if not self.test_mode:
dataset = self.chunk(dataset, chunk_dim=1, pct=self.mini)
self.dataset = DatasetFromList(dataset)
self.log_load_dataset()
def pre_pipeline(self, results):
results = super().pre_pipeline(results)
results["dense"] = [True] * self.num_copies
results["quality"] = [2] * self.num_copies
return results
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/waymo.py
================================================
from typing import Any
from unidepth.datasets.sequence_dataset import SequenceDataset
class Waymo(SequenceDataset):
min_depth = 0.05
max_depth = 70.0
depth_scale = 256.0
test_split = "validation.txt"
train_split = "training.txt"
sequences_file = "sequences.json"
hdf5_paths = [f"Waymo_viz.hdf5"]
def __init__(
self,
image_shape: tuple[int, int],
split_file: str,
test_mode: bool,
normalize: bool,
augmentations_db: dict[str, Any],
resize_method: str,
mini: float = 1.0,
num_frames: int = 1,
benchmark: bool = False,
decode_fields: list[str] = ["image", "depth"],
inplace_fields: list[str] = ["K", "cam2w"],
**kwargs,
) -> None:
super().__init__(
image_shape=image_shape,
split_file=split_file,
test_mode=test_mode,
benchmark=benchmark,
normalize=normalize,
augmentations_db=augmentations_db,
resize_method=resize_method,
mini=mini,
num_frames=num_frames,
decode_fields=decode_fields,
inplace_fields=inplace_fields,
**kwargs,
)
def pre_pipeline(self, results):
results = super().pre_pipeline(results)
results["dense"] = [False] * self.num_frames * self.num_copies
results["synthetic"] = [False] * self.num_frames * self.num_copies
results["quality"] = [1] * self.num_frames * self.num_copies
return results
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/datasets/wildrgbd.py
================================================
from typing import Any
from unidepth.datasets.sequence_dataset import SequenceDataset
class WildRGBD(SequenceDataset):
min_depth = 0.01
max_depth = 10.0
depth_scale = 1000.0
test_split = "train.txt"
train_split = "train.txt"
sequences_file = "sequences.json"
hdf5_paths = ["WildRGBD.hdf5"]
default_fps = 30
def __init__(
self,
image_shape: tuple[int, int],
split_file: str,
test_mode: bool,
normalize: bool,
augmentations_db: dict[str, Any],
resize_method: str,
mini: float = 1.0,
num_frames: int = 1,
benchmark: bool = False,
decode_fields: list[str] = ["image", "depth"],
inplace_fields: list[str] = ["K", "cam2w"],
**kwargs,
) -> None:
super().__init__(
image_shape=image_shape,
split_file=split_file,
test_mode=test_mode,
benchmark=benchmark,
normalize=normalize,
augmentations_db=augmentations_db,
resize_method=resize_method,
mini=mini,
num_frames=num_frames,
decode_fields=decode_fields,
inplace_fields=inplace_fields,
**kwargs,
)
def pre_pipeline(self, results):
results = super().pre_pipeline(results)
results["dense"] = [True] * self.num_frames * self.num_copies
results["quality"] = [1] * self.num_frames * self.num_copies
return results
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/layers/__init__.py
================================================
from .activation import GEGLU, SwiGLU
from .attention import AttentionBlock, AttentionDecoderBlock, AttentionLayer
from .convnext import CvnxtBlock
from .mlp import MLP
from .nystrom_attention import NystromBlock
from .positional_encoding import PositionEmbeddingSine
from .upsample import (ConvUpsample, ConvUpsampleShuffle,
ConvUpsampleShuffleResidual, ResUpsampleBil)
__all__ = [
"SwiGLU",
"GEGLU",
"CvnxtBlock",
"AttentionBlock",
"NystromBlock",
"PositionEmbeddingSine",
"ConvUpsample",
"MLP",
"ConvUpsampleShuffle",
"AttentionDecoderBlock",
"ConvUpsampleShuffleResidual",
"AttentionLayer",
"ResUpsampleBil",
]
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/layers/activation.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
class SwiGLU(nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
x, gates = x.chunk(2, dim=-1)
return x * F.silu(gates)
class GEGLU(nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
x, gates = x.chunk(2, dim=-1)
return x * F.gelu(gates)
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/layers/attention.py
================================================
"""
Author: Luigi Piccinelli
Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/)
"""
from functools import partial
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from .layer_scale import LayerScale
from .mlp import MLP
class SimpleAttention(nn.Module):
def __init__(
self,
dim: int,
num_heads: int = 4,
dropout: float = 0.0,
cosine: bool = False,
context_dim: int | None = None,
):
super().__init__()
self.dropout = dropout
self.num_heads = num_heads
self.hidden_dim = dim
context_dim = context_dim or dim
self.kv = nn.Linear(context_dim, dim * 2, bias=False)
self.q = nn.Linear(dim, dim, bias=False)
self.norm_attnx = nn.LayerNorm(dim)
self.norm_attnctx = nn.LayerNorm(context_dim)
self.cosine = cosine
self.out = nn.Linear(dim, dim)
def forward(
self,
x: torch.Tensor,
attn_bias: torch.Tensor | None = None,
context: torch.Tensor | None = None,
pos_embed: torch.Tensor | None = None,
pos_embed_context: torch.Tensor | None = None,
rope: nn.Module | None = None,
) -> torch.Tensor:
context = x if context is None else context
x = self.norm_attnx(x)
context = self.norm_attnctx(context)
k, v = rearrange(
self.kv(context), "b n (kv h d) -> b h n d kv", h=self.num_heads, kv=2
).unbind(dim=-1)
q = rearrange(self.q(x), "b n (h d) -> b h n d", h=self.num_heads)
if rope is not None:
q = rope(q)
k = rope(k)
else:
if pos_embed is not None:
pos_embed = rearrange(
pos_embed, "b n (h d) -> b h n d", h=self.num_heads
)
q = q + pos_embed
if pos_embed_context is not None:
pos_embed_context = rearrange(
pos_embed_context, "b n (h d) -> b h n d", h=self.num_heads
)
k = k + pos_embed_context
if self.cosine:
q, k = map(partial(F.normalize, p=2, dim=-1), (q, k)) # cosine sim
x = F.scaled_dot_product_attention(
q, k, v, dropout_p=self.dropout, attn_mask=attn_bias
)
x = rearrange(x, "b h n d -> b n (h d)")
x = self.out(x)
return x
class AttentionBlock(nn.Module):
def __init__(
self,
dim: int,
num_heads: int = 4,
expansion: int = 4,
dropout: float = 0.0,
cosine: bool = False,
gated: bool = False,
layer_scale: float = 1.0,
context_dim: int | None = None,
use_bias: bool = True,
):
super().__init__()
self.dropout = dropout
self.num_heads = num_heads
self.hidden_dim = dim
context_dim = context_dim or dim
self.mlp = MLP(dim, expansion=expansion, dropout=dropout, gated=gated)
self.kv = nn.Linear(context_dim, dim * 2, bias=use_bias)
self.q = nn.Linear(dim, dim, bias=use_bias)
self.norm_attnx = nn.LayerNorm(dim)
self.norm_attnctx = nn.LayerNorm(context_dim)
self.cosine = cosine
self.out = nn.Linear(dim, dim, bias=use_bias)
self.ls1 = LayerScale(dim, layer_scale) if layer_scale > 0.0 else nn.Identity()
self.ls2 = LayerScale(dim, layer_scale) if layer_scale > 0.0 else nn.Identity()
def attn(
self,
x: torch.Tensor,
attn_bias: torch.Tensor | None = None,
context: torch.Tensor | None = None,
pos_embed: torch.Tensor | None = None,
pos_embed_context: torch.Tensor | None = None,
) -> torch.Tensor:
x = self.norm_attnx(x)
context = self.norm_attnctx(context)
k, v = rearrange(
self.kv(context), "b n (kv h d) -> b h n d kv", h=self.num_heads, kv=2
).unbind(dim=-1)
q = rearrange(self.q(x), "b n (h d) -> b h n d", h=self.num_heads)
if pos_embed is not None:
pos_embed = rearrange(pos_embed, "b n (h d) -> b h n d", h=self.num_heads)
q = q + pos_embed
if pos_embed_context is not None:
pos_embed_context = rearrange(
pos_embed_context, "b n (h d) -> b h n d", h=self.num_heads
)
k = k + pos_embed_context
if self.cosine:
q, k = map(partial(F.normalize, p=2, dim=-1), (q, k)) # cosine sim
x = F.scaled_dot_product_attention(
q, k, v, dropout_p=self.dropout, attn_mask=attn_bias
)
x = rearrange(x, "b h n d -> b n (h d)")
x = self.out(x)
return x
def forward(
self,
x: torch.Tensor,
attn_bias: torch.Tensor | None = None,
context: torch.Tensor | None = None,
pos_embed: torch.Tensor | None = None,
pos_embed_context: torch.Tensor | None = None,
) -> torch.Tensor:
context = x if context is None else context
x = (
self.ls1(
self.attn(
x,
attn_bias=attn_bias,
context=context,
pos_embed=pos_embed,
pos_embed_context=pos_embed_context,
)
)
+ x
)
x = self.ls2(self.mlp(x)) + x
return x
class AttentionLayer(nn.Module):
def __init__(
self,
num_blocks: int,
dim: int,
num_heads: int = 4,
expansion: int = 4,
dropout: float = 0.0,
cosine: bool = False,
gated: bool = False,
layer_scale: float = 1.0,
context_dim: int | None = None,
use_bias: bool = True,
):
super().__init__()
self.layers = nn.ModuleList(
[
AttentionBlock(
dim=dim,
num_heads=num_heads,
expansion=expansion,
dropout=dropout,
cosine=cosine,
gated=gated,
layer_scale=layer_scale,
context_dim=context_dim,
use_bias=use_bias,
)
for _ in range(num_blocks)
]
)
def forward(
self,
x: torch.Tensor,
context: torch.Tensor | None = None,
pos_embed: torch.Tensor | None = None,
pos_embed_context: torch.Tensor | None = None,
attn_bias: torch.Tensor | None = None,
) -> torch.Tensor:
for layer in self.layers:
x = layer(
x,
context=context,
pos_embed=pos_embed,
pos_embed_context=pos_embed_context,
attn_bias=attn_bias,
)
return x
class AttentionDecoderBlock(nn.Module):
def __init__(
self,
dim: int,
num_heads: int = 4,
expansion: int = 4,
dropout: float = 0.0,
cosine: bool = False,
gated: bool = False,
layer_scale: float = 1.0,
context_dim: int | None = None,
single_head_ca: bool = True,
):
super().__init__()
self.dropout = dropout
self.num_heads = num_heads
self.hidden_dim = dim
self.single_head_ca = single_head_ca
context_dim = context_dim or dim
self.mlp = MLP(dim, expansion=expansion, dropout=dropout, gated=gated)
self.kv_ca = nn.Linear(context_dim, dim * 2)
self.q_ca = nn.Linear(dim, dim)
self.kv_sa = nn.Linear(dim, dim * 2)
self.q_sa = nn.Linear(dim, dim)
self.norm_x_sa = nn.LayerNorm(dim)
self.norm_x_ca = nn.LayerNorm(dim)
self.norm_ctx_ca = nn.LayerNorm(context_dim)
self.cosine = cosine
self.out_ca = nn.Linear(dim, dim)
self.out_sa = nn.Linear(dim, dim)
self.ls1 = LayerScale(dim, layer_scale) if layer_scale > 0.0 else nn.Identity()
self.ls2 = LayerScale(dim, layer_scale) if layer_scale > 0.0 else nn.Identity()
self.ls3 = LayerScale(dim, layer_scale) if layer_scale > 0.0 else nn.Identity()
def cross_attn(
self,
x: torch.Tensor,
attn_bias: torch.Tensor | None = None,
context: torch.Tensor | None = None,
pos_embed: torch.Tensor | None = None,
pos_embed_context: torch.Tensor | None = None,
rope: nn.Module | None = None,
) -> torch.Tensor:
num_heads = 1 if self.single_head_ca else self.num_heads
x = self.norm_x_ca(x)
context = self.norm_ctx_ca(context)
k, v = rearrange(
self.kv_ca(context), "b n (kv h d) -> b h n d kv", h=num_heads, kv=2
).unbind(dim=-1)
q = rearrange(self.q_ca(x), "b n (h d) -> b h n d", h=num_heads)
if rope is not None:
q = rope(q)
k = rope(k)
else:
if pos_embed is not None:
pos_embed = rearrange(pos_embed, "b n (h d) -> b h n d", h=num_heads)
q = q + pos_embed
if pos_embed_context is not None:
pos_embed_context = rearrange(
pos_embed_context, "b n (h d) -> b h n d", h=num_heads
)
k = k + pos_embed_context
if self.cosine:
q, k = map(partial(F.normalize, p=2, dim=-1), (q, k)) # cosine sim
x = F.scaled_dot_product_attention(
q, k, v, dropout_p=self.dropout, attn_mask=attn_bias
)
x = rearrange(x, "b h n d -> b n (h d)")
x = self.out_ca(x)
return x
def self_attn(
self,
x: torch.Tensor,
attn_bias: torch.Tensor | None = None,
pos_embed: torch.Tensor | None = None,
rope: nn.Module | None = None,
) -> torch.Tensor:
x = self.norm_x_sa(x)
k, v = rearrange(
self.kv_sa(x), "b n (kv h d) -> b h n d kv", h=self.num_heads, kv=2
).unbind(dim=-1)
q = rearrange(self.q_sa(x), "b n (h d) -> b h n d", h=self.num_heads)
if rope is not None:
q = rope(q)
k = rope(k)
elif pos_embed is not None:
pos_embed = rearrange(pos_embed, "b n (h d) -> b h n d", h=self.num_heads)
q = q + pos_embed
if self.cosine:
q, k = map(partial(F.normalize, p=2, dim=-1), (q, k)) # cosine sim
x = F.scaled_dot_product_attention(
q, k, v, dropout_p=self.dropout, attn_mask=attn_bias
)
x = rearrange(x, "b h n d -> b n (h d)")
x = self.out_sa(x)
return x
def forward(
self,
x: torch.Tensor,
attn_bias: torch.Tensor | None = None,
context: torch.Tensor | None = None,
pos_embed: torch.Tensor | None = None,
pos_embed_context: torch.Tensor | None = None,
rope: nn.Module | None = None,
) -> torch.Tensor:
context = x if context is None else context
x = (
self.ls1(
self.cross_attn(
x,
rope=rope,
attn_bias=attn_bias,
context=context,
pos_embed=pos_embed,
pos_embed_context=pos_embed_context,
)
)
+ x
)
x = (
self.ls2(
self.self_attn(x, rope=rope, attn_bias=attn_bias, pos_embed=pos_embed)
)
+ x
)
x = self.ls3(self.mlp(x)) + x
return x
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/layers/convnext.py
================================================
import torch
import torch.nn as nn
class CvnxtBlock(nn.Module):
def __init__(
self,
dim,
kernel_size=7,
layer_scale=1.0,
expansion=4,
dilation=1,
padding_mode: str = "zeros",
):
super().__init__()
self.dwconv = nn.Conv2d(
dim,
dim,
kernel_size=kernel_size,
padding=dilation * (kernel_size - 1) // 2,
groups=dim,
dilation=dilation,
padding_mode=padding_mode,
) # depthwise conv
self.norm = nn.LayerNorm(dim)
self.pwconv1 = nn.Linear(dim, expansion * dim)
self.act = nn.GELU()
self.pwconv2 = nn.Linear(expansion * dim, dim)
self.gamma = (
nn.Parameter(layer_scale * torch.ones((dim))) if layer_scale > 0.0 else 1.0
)
def forward(self, x):
input = x
x = self.dwconv(x)
x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
x = self.norm(x)
x = self.pwconv1(x)
x = self.act(x)
x = self.pwconv2(x)
x = self.gamma * x
x = input + x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
return x
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/layers/drop_path.py
================================================
import torch
import torch.nn as nn
def drop_path(x: torch.Tensor, drop_prob: float = 0.0, training: bool = False):
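# Stochastic depth: during training, zero the residual branch for a random
# subset of samples and rescale the survivors by 1 / keep_prob.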
if drop_prob == 0.0 or not training:
return x
keep_prob = 1 - drop_prob
shape = (x.shape[0],) + (1,) * (
x.ndim - 1
) # work with diff dim tensors, not just 2D ConvNets
random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
if keep_prob > 0.0:
random_tensor.div_(keep_prob)
output = x * random_tensor
return output
class DropPath(nn.Module):
def __init__(self, drop_prob=None):
super(DropPath, self).__init__()
self.drop_prob = drop_prob
def forward(self, x):
return drop_path(x, self.drop_prob, self.training)
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/layers/layer_scale.py
================================================
import torch
import torch.nn as nn
class LayerScale(nn.Module):
def __init__(
self,
dim: int,
init_values: float | torch.Tensor = 1e-5,
inplace: bool = False,
) -> None:
super().__init__()
self.inplace = inplace
self.gamma = nn.Parameter(init_values * torch.ones(dim))
def forward(self, x: torch.Tensor) -> torch.Tensor:
return x.mul_(self.gamma) if self.inplace else x * self.gamma
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/layers/mlp.py
================================================
import torch
import torch.nn as nn
from unidepth.utils.misc import default
from .activation import SwiGLU
class MLP(nn.Module):
def __init__(
self,
input_dim: int,
expansion: int = 4,
dropout: float = 0.0,
gated: bool = False,
output_dim: int | None = None,
):
super().__init__()
if gated:
expansion = int(expansion * 2 / 3)
hidden_dim = int(input_dim * expansion)
output_dim = default(output_dim, input_dim)
self.norm = nn.LayerNorm(input_dim)
self.proj1 = nn.Linear(input_dim, hidden_dim)
self.proj2 = nn.Linear(hidden_dim, output_dim)
self.act = nn.GELU() if not gated else SwiGLU()
self.dropout = nn.Dropout(dropout) if dropout > 0.0 else nn.Identity()
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.norm(x)
x = self.proj1(x)
x = self.act(x)
x = self.proj2(x)
x = self.dropout(x)
return x
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/layers/nystrom.py
================================================
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import logging
from dataclasses import dataclass
from typing import Optional
import torch
import torch.nn as nn
from xformers.components.attention import Attention, AttentionConfig, register_attention
from xformers.components.attention.core import (
scaled_dot_product_attention,
scaled_query_key_softmax,
)
from xformers.components.attention.utils import (
bool_mask_to_additive,
iterative_pinv,
reshape_key_padding_mask,
)
logger = logging.getLogger("xformers")
@dataclass
class NystromSelfAttentionConfig(AttentionConfig):
"""
num_heads Number of heads.
num_landmarks Number of landmarks to use for softmax approximation. 64 is often sufficient for a good
approximation according to https://arxiv.org/pdf/2102.03902.pdf.
causal Apply a causal mask, in that the attention cannot be applied to the future.
use_razavi_pinverse If true, use iterative method from (Razavi et al. 2014) to approximate the Moore-Penrose
inverse, otherwise use standard torch inverse.
pinverse_original_init True if using original initialization when calculating Moore-Penrose pseudo inverse using
method from (Razavi et al. 2014).
False if using exact coefficient computation (leads to faster convergence).
inv_iterations Number of iterations for calculating the Moore-Penrose pseudo inverse.
v_skip_connection A module that will take V as input and will be added as a skip connection to the
softmax approximation. A skip connection is added in the paper to help with training.
conv_kernel_size Kernel size for convolution optionally added to help in training.
If v_skip_connection is not specified, this will be used to define the default
depthwise convolution used as a skip connection.
If both conv_kernel_size and v_skip_connection are None, no skip connection will
be added.
landmark_pooling Which module to use when computing landmarks. Default is AdaptiveAvgPool2d.
"""
num_heads: int
num_landmarks: Optional[int]
landmark_pooling: Optional[nn.Module]
causal: Optional[bool]
pinverse_original_init: Optional[bool]
inv_iterations: Optional[int]
v_skip_connection: Optional[nn.Module]
conv_kernel_size: Optional[int]
use_razavi_pinverse: Optional[bool]
class AvgPool(nn.Module):
def __init__(self, n: int):
super().__init__()
self.n = n
def forward(self, x: torch.Tensor):
# Average independently for every segment in the sequence dimension
seq_len = x.shape[1]
head_dim = x.shape[2]
segments = seq_len // self.n
assert segments > 0, "num_landmarks should be smaller than the sequence length"
# Dimensions are a match
if seq_len % self.n == 0:
return x.reshape(
-1,
self.n,
segments,
head_dim,
).mean(dim=-2)
# Handle the last segment boundary being off
n_round = self.n - seq_len % self.n
x_avg_round = (
x[:, : n_round * segments, :]
.reshape(-1, n_round, segments, head_dim)
.mean(dim=-2)
)
x_avg_off = (
x[:, n_round * segments :, :]
.reshape(-1, self.n - n_round, segments + 1, head_dim)
.mean(dim=-2)
)
return torch.cat((x_avg_round, x_avg_off), dim=-2)
@register_attention("nystrom", NystromSelfAttentionConfig)
class NystromAttention(Attention):
# TODO: update defaults for use_razavi_pinverse and inv_iterations
def __init__(
self,
dropout: float,
num_heads: int,
num_landmarks: int = 64,
landmark_pooling: Optional[nn.Module] = None,
causal: bool = False,
use_razavi_pinverse: bool = True,
pinverse_original_init: bool = False,
inv_iterations: int = 6, # recommended default in paper was 6.
v_skip_connection: Optional[nn.Module] = None,
conv_kernel_size: Optional[int] = None,
*args,
**kwargs,
):
"""
Nystrom attention mechanism, from Nystromformer_.
::
"A Nystrom-based Algorithm for Approximating Self-Attention."
Xiong, Y., Zeng, Z., Chakraborty, R., Tan, M., Fung, G., Li, Y., Singh, V. (2021)
Reference codebase: https://github.com/mlpen/Nystromformer
.. _Nystromformer: https://arxiv.org/pdf/2102.03902.pdf
"""
super().__init__()
# merged key padding mask and attention mask is not accepted
self.requires_separate_masks = True
self.num_landmarks = num_landmarks
# TODO: should be able to not have to pass in num_heads
self.num_heads = num_heads
self.use_razavi_pinverse = use_razavi_pinverse
self.pinverse_original_init = pinverse_original_init
self.inv_iterations = inv_iterations
self.attn_drop = nn.Dropout(dropout)
self.skip_connection = v_skip_connection
self.causal = causal
if self.skip_connection is None and conv_kernel_size is not None:
self.skip_connection = nn.Conv2d(
in_channels=self.num_heads,
out_channels=self.num_heads,
kernel_size=(conv_kernel_size, 1),
padding=(conv_kernel_size // 2, 0),
bias=False,
groups=self.num_heads,
)
if landmark_pooling is not None:
self.landmark_pooling = landmark_pooling
else:
self.landmark_pooling = AvgPool(n=self.num_landmarks)
# Optional lower triangular masks for causal attention
self.causal_mask_1: Optional[torch.Tensor] = None
self.causal_mask_2: Optional[torch.Tensor] = None
self.causal_mask_3: Optional[torch.Tensor] = None
# This attention does not support attention masks
self.supports_attention_mask = False
self.supports_key_padding_mask = True
def forward(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
key_padding_mask: Optional[torch.Tensor] = None,
*args,
**kwargs,
):
r"""
key_padding_mask Only a key padding mask is accepted here. The size must be (batch size, sequence length) or
(batch size * num_heads, 1, sequence length). If dimensions are not correct, the mask will
be ignored. An additive mask is expected, meaning float values using "-inf" to mask values
"""
batched_dim = k.size(0)
seq_len = k.size(-2)
tt = {"dtype": q.dtype, "device": q.device}
if key_padding_mask is not None:
if key_padding_mask.dtype == torch.bool:
logger.warning(
"Bool mask found, but an additive mask is expected. Converting but this is slow"
)
key_padding_mask = bool_mask_to_additive(key_padding_mask)
if key_padding_mask.ndim == 2:
key_padding_mask = reshape_key_padding_mask(
key_padding_mask, batched_dim
)
zeros = torch.zeros_like(key_padding_mask)
ones = torch.ones_like(key_padding_mask)
is_masked = torch.isinf(-key_padding_mask)
# _mask takes 1 if the token is not padded, otherwise 0.
_mask = torch.where(is_masked, zeros, ones)
_mask = _mask.transpose(2, 1)
assert _mask.shape == (batched_dim, q.shape[1], 1)
# Mask q and k before pooling
# https://github.com/mlpen/Nystromformer/blob/main/code/attention_nystrom.py#L31
q = q * _mask
k = k * _mask
assert key_padding_mask.size() == (batched_dim, 1, seq_len), (
f"key_padding_mask has invalid dimensions {key_padding_mask.size()}."
f" Must have dimensions {batched_dim, 1, seq_len} or (batch_size, {seq_len})."
)
if self.num_landmarks >= seq_len:
mask: Optional[torch.Tensor] = None
if self.causal:
mask = self._triu_mask(batched_dim, seq_len, seq_len, **tt)
if key_padding_mask is not None:
mask = key_padding_mask if mask is None else mask + key_padding_mask
x = scaled_dot_product_attention(q=q, k=k, v=v, att_mask=mask)
else:
q_landmarks = self.landmark_pooling(q)
k_landmarks = self.landmark_pooling(k)
if self.causal and (
self.causal_mask_1 is None
or (batched_dim, seq_len, self.num_landmarks)
!= self.causal_mask_1.size()
):
self.causal_mask_1 = self._triu_mask(
batched_dim, seq_len, self.num_landmarks, **tt
)
self.causal_mask_2 = self._triu_mask(
batched_dim, self.num_landmarks, self.num_landmarks, **tt
)
self.causal_mask_3 = self._triu_mask(
batched_dim, self.num_landmarks, seq_len, **tt
)
mask_3: Optional[torch.Tensor] = self.causal_mask_3
if key_padding_mask is not None:
mask_3 = (
key_padding_mask if mask_3 is None else mask_3 + key_padding_mask
)
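# Nystrom approximation: softmax(Q K^T / sqrt(d)) V ~= kernel_1 @ pinv(kernel_2) @ kernel_3,
# with landmarks obtained by pooling the queries and keys.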
kernel_1 = scaled_query_key_softmax(q=q, k=k_landmarks, att_mask=None)
kernel_2 = scaled_query_key_softmax(
q=q_landmarks, k=k_landmarks, att_mask=None
)
kernel_3 = scaled_dot_product_attention(
q=q_landmarks, k=k, v=v, att_mask=mask_3
)
kernel_2_inv = (
iterative_pinv(
kernel_2, self.inv_iterations, self.pinverse_original_init
)
if self.use_razavi_pinverse
else torch.linalg.pinv(kernel_2)
)
x = torch.matmul(
torch.matmul(
kernel_1,
kernel_2_inv,
),
kernel_3,
)
if self.skip_connection:
# Assumption here is that v is 3D.
v_conv = self.skip_connection(
v.reshape(-1, self.num_heads, v.size(-2), v.size(-1))
)
x += v_conv.reshape(-1, v_conv.size(-2), v_conv.size(-1))
x = self.attn_drop(x)
return x
def _triu_mask(self, dim_1: int, dim_2: int, dim_3: int, **kwargs) -> torch.Tensor:
device = kwargs["device"]
dtype = kwargs["dtype"]
return torch.triu(
torch.ones(dim_2, dim_3, dtype=dtype, device=device) * float("-inf"),
diagonal=1,
).expand(
dim_1, -1, -1
) # micro optim, save memory on the batch dimension
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/layers/nystrom_attention.py
================================================
from functools import partial
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from .nystrom import NystromAttention
from .attention import AttentionBlock
class NystromBlock(AttentionBlock):
def __init__(
self,
dim: int,
num_heads: int = 4,
expansion: int = 4,
dropout: float = 0.0,
cosine: bool = False,
gated: bool = False,
layer_scale: float = 1.0,
context_dim: int | None = None,
):
super().__init__(
dim=dim,
num_heads=num_heads,
expansion=expansion,
dropout=dropout,
cosine=cosine,
gated=gated,
layer_scale=layer_scale,
context_dim=context_dim,
)
self.attention_fn = NystromAttention(
num_landmarks=128, num_heads=num_heads, dropout=dropout
)
def attn(
self,
x: torch.Tensor,
attn_bias: torch.Tensor | None = None,
context: torch.Tensor | None = None,
pos_embed: torch.Tensor | None = None,
pos_embed_context: torch.Tensor | None = None,
rope: nn.Module | None = None,
) -> torch.Tensor:
x = self.norm_attnx(x)
context = self.norm_attnctx(context)
k, v = rearrange(
self.kv(context), "b n (kv h d) -> b n h d kv", h=self.num_heads, kv=2
).unbind(dim=-1)
q = rearrange(self.q(x), "b n (h d) -> b n h d", h=self.num_heads)
if rope is not None:
q = rope(q)
k = rope(k)
else:
if pos_embed is not None:
pos_embed = rearrange(
pos_embed, "b n (h d) -> b n h d", h=self.num_heads
)
q = q + pos_embed
if pos_embed_context is not None:
pos_embed_context = rearrange(
pos_embed_context, "b n (h d) -> b n h d", h=self.num_heads
)
k = k + pos_embed_context
if self.cosine:
q, k = map(partial(F.normalize, p=2, dim=-1), (q, k)) # cosine sim
x = self.attention_fn(q, k, v, key_padding_mask=attn_bias)
x = rearrange(x, "b n h d -> b n (h d)")
x = self.out(x)
return x
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/layers/positional_encoding.py
================================================
"""
Author: Luigi Piccinelli
Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/)
"""
from math import pi
from typing import Optional
import torch
import torch.nn as nn
from einops import rearrange, repeat
class PositionEmbeddingSine(nn.Module):
def __init__(
self, num_pos_feats=64, temperature=10000, normalize=False, scale=None
):
super().__init__()
self.num_pos_feats = num_pos_feats
self.temperature = temperature
self.normalize = normalize
if scale is not None and normalize is False:
raise ValueError("normalize should be True if scale is passed")
if scale is None:
scale = 2 * pi
self.scale = scale
def forward(
self, x: torch.Tensor, mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
if mask is None:
mask = torch.zeros(
(x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool
)
not_mask = ~mask
y_embed = not_mask.cumsum(1, dtype=torch.float32)
x_embed = not_mask.cumsum(2, dtype=torch.float32)
if self.normalize:
eps = 1e-6
y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
dim_t = self.temperature ** (
2 * torch.div(dim_t, 2, rounding_mode="floor") / self.num_pos_feats
)
pos_x = x_embed[:, :, :, None] / dim_t
pos_y = y_embed[:, :, :, None] / dim_t
pos_x = torch.stack(
(pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
).flatten(3)
pos_y = torch.stack(
(pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
).flatten(3)
pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
return pos
def __repr__(self, _repr_indent=4):
head = "Positional encoding " + self.__class__.__name__
body = [
"num_pos_feats: {}".format(self.num_pos_feats),
"temperature: {}".format(self.temperature),
"normalize: {}".format(self.normalize),
"scale: {}".format(self.scale),
]
# _repr_indent = 4
lines = [head] + [" " * _repr_indent + line for line in body]
return "\n".join(lines)
class LearnedSinusoidalPosEmb(nn.Module):
def __init__(self, dim):
super().__init__()
assert (dim % 2) == 0
half_dim = dim // 2
self.weights = nn.Parameter(torch.randn(half_dim))
def forward(self, x):
x = rearrange(x, "b -> b 1")
freqs = x * rearrange(self.weights, "d -> 1 d") * 2 * pi
fouriered = torch.cat((freqs.sin(), freqs.cos()), dim=-1)
fouriered = torch.cat((x, fouriered), dim=-1)
return fouriered
def generate_fourier_features(x, max_freq=64, num_bands=16):
x = x.unsqueeze(-1)
device, dtype, orig_x = x.device, x.dtype, x
scales = torch.linspace(
-max_freq / 2, max_freq / 2, num_bands, device=device, dtype=dtype
)
scales = scales[(*((None,) * (len(x.shape) - 1)), Ellipsis)]
x = x * scales * pi
x = torch.cat([x.sin(), x.cos()], dim=-1)
x = torch.cat((x, orig_x), dim=-1)
return x.flatten(-2)
def broadcat(tensors, dim=-1):
num_tensors = len(tensors)
shape_lens = set(list(map(lambda t: len(t.shape), tensors)))
assert len(shape_lens) == 1, "tensors must all have the same number of dimensions"
shape_len = list(shape_lens)[0]
dim = (dim + shape_len) if dim < 0 else dim
dims = list(zip(*map(lambda t: list(t.shape), tensors)))
expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
assert all(
[*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]
), "invalid dimensions for broadcastable concatentation"
max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims))
expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims))
expanded_dims.insert(dim, (dim, dims[dim]))
expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims)))
tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes)))
return torch.cat(tensors, dim=dim)
def rotate_half(x):
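# RoPE helper: rotate consecutive feature pairs (x1, x2) -> (-x2, x1)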
x = rearrange(x, "... (d r) -> ... d r", r=2)
x1, x2 = x.unbind(dim=-1)
x = torch.stack((-x2, x1), dim=-1)
return rearrange(x, "... d r -> ... (d r)")
class VisionRotaryEmbedding(nn.Module):
def __init__(
self,
dim,
pt_seq_len,
ft_seq_len=None,
custom_freqs=None,
freqs_for="lang",
theta=10000,
max_freq=10,
num_freqs=1,
):
super().__init__()
if custom_freqs:
freqs = custom_freqs
elif freqs_for == "lang":
freqs = 1.0 / (
theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)
)
elif freqs_for == "pixel":
freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi
elif freqs_for == "constant":
freqs = torch.ones(num_freqs).float()
else:
raise ValueError(f"unknown modality {freqs_for}")
if ft_seq_len is None:
ft_seq_len = pt_seq_len
t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len
freqs_h = torch.einsum("..., f -> ... f", t, freqs)
freqs_h = repeat(freqs_h, "... n -> ... (n r)", r=2)
freqs_w = torch.einsum("..., f -> ... f", t, freqs)
freqs_w = repeat(freqs_w, "... n -> ... (n r)", r=2)
freqs = broadcat((freqs_h[:, None, :], freqs_w[None, :, :]), dim=-1)
self.register_buffer("freqs_cos", freqs.cos())
self.register_buffer("freqs_sin", freqs.sin())
print("======== shape of rope freq", self.freqs_cos.shape, "========")
def forward(self, t, start_index=0):
rot_dim = self.freqs_cos.shape[-1]
end_index = start_index + rot_dim
assert (
rot_dim <= t.shape[-1]
), f"feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}"
t_left, t, t_right = (
t[..., :start_index],
t[..., start_index:end_index],
t[..., end_index:],
)
t = (t * self.freqs_cos) + (rotate_half(t) * self.freqs_sin)
return torch.cat((t_left, t, t_right), dim=-1)
class VisionRotaryEmbeddingFast(nn.Module):
def __init__(
self,
dim,
pt_seq_len,
ft_seq_len=None,
custom_freqs=None,
freqs_for="lang",
theta=10000,
max_freq=10,
num_freqs=1,
):
super().__init__()
if custom_freqs:
freqs = custom_freqs
elif freqs_for == "lang":
freqs = 1.0 / (
theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)
)
elif freqs_for == "pixel":
freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi
elif freqs_for == "constant":
freqs = torch.ones(num_freqs).float()
else:
raise ValueError(f"unknown modality {freqs_for}")
if ft_seq_len is None:
ft_seq_len = pt_seq_len
t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len
freqs = torch.einsum("..., f -> ... f", t, freqs)
freqs = repeat(freqs, "... n -> ... (n r)", r=2)
freqs = broadcat((freqs[:, None, :], freqs[None, :, :]), dim=-1)
freqs_cos = freqs.cos().view(-1, freqs.shape[-1])
freqs_sin = freqs.sin().view(-1, freqs.shape[-1])
self.register_buffer("freqs_cos", freqs_cos)
self.register_buffer("freqs_sin", freqs_sin)
def forward(self, t):
return t * self.freqs_cos + rotate_half(t) * self.freqs_sin
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/layers/upsample.py
================================================
"""
Author: Luigi Piccinelli
Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/)
"""
import torch
import torch.nn as nn
from einops import rearrange
from .convnext import CvnxtBlock
class ConvUpsample(nn.Module):
def __init__(
self,
hidden_dim,
num_layers: int = 2,
expansion: int = 4,
layer_scale: float = 1.0,
kernel_size: int = 7,
**kwargs,
):
super().__init__()
self.convs = nn.ModuleList([])
for _ in range(num_layers):
self.convs.append(
CvnxtBlock(
hidden_dim,
kernel_size=kernel_size,
expansion=expansion,
layer_scale=layer_scale,
)
)
self.up = nn.Sequential(
nn.Conv2d(hidden_dim, hidden_dim // 2, kernel_size=1, padding=0),
nn.UpsamplingBilinear2d(scale_factor=2),
nn.Conv2d(hidden_dim // 2, hidden_dim // 2, kernel_size=3, padding=1),
)
def forward(self, x: torch.Tensor):
for conv in self.convs:
x = conv(x)
x = self.up(x)
x = rearrange(x, "b c h w -> b (h w) c")
return x
class ConvUpsampleShuffle(nn.Module):
def __init__(
self,
hidden_dim,
num_layers: int = 2,
expansion: int = 4,
layer_scale: float = 1.0,
kernel_size: int = 7,
**kwargs,
):
super().__init__()
self.convs = nn.ModuleList([])
for _ in range(num_layers):
self.convs.append(
CvnxtBlock(
hidden_dim,
kernel_size=kernel_size,
expansion=expansion,
layer_scale=layer_scale,
)
)
self.up = nn.Sequential(
nn.PixelShuffle(2),
nn.Conv2d(hidden_dim // 4, hidden_dim // 2, kernel_size=3, padding=1),
)
def forward(self, x: torch.Tensor):
for conv in self.convs:
x = conv(x)
x = self.up(x)
x = rearrange(x, "b c h w -> b (h w) c")
return x
class ConvUpsampleShuffleResidual(nn.Module):
def __init__(
self,
hidden_dim,
num_layers: int = 2,
expansion: int = 4,
layer_scale: float = 1.0,
kernel_size: int = 7,
padding_mode: str = "zeros",
**kwargs,
):
super().__init__()
self.convs = nn.ModuleList([])
for _ in range(num_layers):
self.convs.append(
CvnxtBlock(
hidden_dim,
kernel_size=kernel_size,
expansion=expansion,
layer_scale=layer_scale,
padding_mode=padding_mode,
)
)
self.up = nn.Sequential(
nn.PixelShuffle(2),
nn.Conv2d(
hidden_dim // 4,
hidden_dim // 4,
kernel_size=7,
padding=3,
padding_mode=padding_mode,
groups=hidden_dim // 4,
),
nn.ReLU(),
nn.Conv2d(
hidden_dim // 4,
hidden_dim // 2,
kernel_size=3,
padding=1,
padding_mode=padding_mode,
),
)
self.residual = nn.Sequential(
nn.Conv2d(hidden_dim, hidden_dim // 2, kernel_size=1, padding=0),
nn.UpsamplingBilinear2d(scale_factor=2),
)
def forward(self, x: torch.Tensor):
for conv in self.convs:
x = conv(x)
x = self.up(x) + self.residual(x)
x = rearrange(x, "b c h w -> b (h w) c")
return x
class ResidualConvUnit(nn.Module):
def __init__(
self,
dim,
kernel_size: int = 3,
padding_mode: str = "zeros",
dilation: int = 1,
layer_scale: float = 1.0,
use_norm: bool = False,
):
super().__init__()
self.conv1 = nn.Conv2d(
dim,
dim,
kernel_size=kernel_size,
padding=dilation * (kernel_size - 1) // 2,
dilation=dilation,
padding_mode=padding_mode,
)
self.conv2 = nn.Conv2d(
dim,
dim,
kernel_size=kernel_size,
padding=dilation * (kernel_size - 1) // 2,
dilation=dilation,
padding_mode=padding_mode,
)
self.activation = nn.LeakyReLU()
self.gamma = (
nn.Parameter(layer_scale * torch.ones(1, dim, 1, 1))
if layer_scale > 0.0
else 1.0
)
self.norm1 = nn.GroupNorm(dim // 16, dim) if use_norm else nn.Identity()
self.norm2 = nn.GroupNorm(dim // 16, dim) if use_norm else nn.Identity()
def forward(self, x):
out = self.activation(x)
out = self.conv1(out)
out = self.norm1(out)
out = self.activation(out)
out = self.conv2(out)
out = self.norm2(out)
return self.gamma * out + x
class ResUpsampleBil(nn.Module):
def __init__(
self,
hidden_dim,
output_dim: int = None,
num_layers: int = 2,
kernel_size: int = 3,
layer_scale: float = 1.0,
padding_mode: str = "zeros",
use_norm: bool = False,
**kwargs,
):
super().__init__()
output_dim = output_dim if output_dim is not None else hidden_dim // 2
self.convs = nn.ModuleList([])
for _ in range(num_layers):
self.convs.append(
ResidualConvUnit(
hidden_dim,
kernel_size=kernel_size,
layer_scale=layer_scale,
padding_mode=padding_mode,
use_norm=use_norm,
)
)
self.up = nn.Sequential(
nn.Conv2d(
hidden_dim,
output_dim,
kernel_size=1,
padding=0,
padding_mode=padding_mode,
),
nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False),
)
def forward(self, x: torch.Tensor):
for conv in self.convs:
x = conv(x)
x = self.up(x)
return x
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/models/__init__.py
================================================
from .unidepthv1 import UniDepthV1
from .unidepthv2 import UniDepthV2, UniDepthV2old
__all__ = [
"UniDepthV1",
"UniDepthV2old",
"UniDepthV2",
]
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/models/backbones/__init__.py
================================================
from .convnext import ConvNeXt
from .convnext2 import ConvNeXtV2
from .dinov2 import _make_dinov2_model
__all__ = [
"ConvNeXt",
"ConvNeXtV2",
"_make_dinov2_model",
]
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/models/backbones/convnext.py
================================================
from collections import OrderedDict
from functools import partial
from typing import Callable, Optional, Sequence, Tuple, Union
import torch
import torch.nn as nn
from timm.layers import (AvgPool2dSame, DropPath, GlobalResponseNormMlp,
LayerNorm, LayerNorm2d, Mlp, create_conv2d,
get_act_layer, make_divisible, to_ntuple,
trunc_normal_)
from torch.utils.checkpoint import checkpoint
def get_num_layer_for_convnext(var_name):
"""
Divide [3, 3, 27, 3] layers into 12 groups; each group is three
consecutive blocks, including possible neighboring downsample layers;
adapted from https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py
"""
if var_name.startswith("downsample_layers"):
stage_id = int(var_name.split(".")[1])
if stage_id == 0:
layer_id = 0
elif stage_id == 1 or stage_id == 2:
layer_id = stage_id + 1
elif stage_id == 3:
layer_id = 12
elif var_name.startswith("stages"):
stage_id = int(var_name.split(".")[1])
block_id = int(var_name.split(".")[3])
if stage_id == 0 or stage_id == 1:
layer_id = stage_id + 1
elif stage_id == 2:
layer_id = 3 + block_id // 3
elif stage_id == 3:
layer_id = 12
elif var_name.startswith("stem"):
return 0
else:
layer_id = 12
return layer_id + 1
def get_parameter_groups(model, lr, wd=1e-5, ld=0.9, skip_list=None):
parameter_group_names = {}
parameter_group_vars = {}
skip = set()
if skip_list is not None:
skip = skip_list
if hasattr(model, "no_weight_decay"):
skip.update(model.no_weight_decay())
num_layers = 12
layer_scale = list(ld ** (num_layers + 1 - i) for i in range(num_layers + 2))
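# layer-wise lr decay: earlier layers get lr scaled by higher powers of ld,
# while the last group keeps the base lr (scale ld**0 = 1)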
for name, param in model.named_parameters():
if not param.requires_grad:
continue # frozen weights
if len(param.shape) == 1 or name.endswith(".bias") or name in skip:
group_name = "no_decay"
this_wd = 0.0
else:
group_name = "decay"
this_wd = wd
layer_id = get_num_layer_for_convnext(name)
group_name = "layer_%d_%s" % (layer_id, group_name)
if group_name not in parameter_group_names:
scale = layer_scale[layer_id]
cur_lr = lr * scale
parameter_group_names[group_name] = {
"weight_decay": this_wd,
"weight_decay_init": this_wd,
"weight_decay_base": this_wd,
"params": [],
"lr_init": cur_lr,
"lr_base": lr,
"lr": cur_lr,
}
parameter_group_vars[group_name] = {
"weight_decay": this_wd,
"weight_decay_init": this_wd,
"weight_decay_base": this_wd,
"params": [],
"lr_init": cur_lr,
"lr_base": lr,
"lr": cur_lr,
}
if this_wd == 0.0:
parameter_group_names[group_name]["weight_decay_final"] = 0.0
parameter_group_vars[group_name]["weight_decay_final"] = 0.0
parameter_group_vars[group_name]["params"].append(param)
parameter_group_names[group_name]["params"].append(name)
# from unidepth.utils import is_main_process
# import json
# if is_main_process():
# print("Param groups = %s" % json.dumps(parameter_group_names, indent=2))
return list(parameter_group_vars.values()), [
v["lr"] for k, v in parameter_group_vars.items()
]
class Downsample(nn.Module):
def __init__(self, in_chs, out_chs, stride=1, dilation=1):
super().__init__()
avg_stride = stride if dilation == 1 else 1
if stride > 1 or dilation > 1:
avg_pool_fn = (
AvgPool2dSame if avg_stride == 1 and dilation > 1 else nn.AvgPool2d
)
self.pool = avg_pool_fn(
2, avg_stride, ceil_mode=True, count_include_pad=False
)
else:
self.pool = nn.Identity()
if in_chs != out_chs:
self.conv = create_conv2d(in_chs, out_chs, 1, stride=1)
else:
self.conv = nn.Identity()
def forward(self, x):
x = self.pool(x)
x = self.conv(x)
return x
class ConvNeXtBlock(nn.Module):
"""ConvNeXt Block
There are two equivalent implementations:
(1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
(2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
Unlike the official impl, this one allows choice of 1 or 2, 1x1 conv can be faster with appropriate
choice of LayerNorm impl, however as model size increases the tradeoffs appear to change and nn.Linear
is a better choice. This was observed with PyTorch 1.10 on 3090 GPU, it could change over time & w/ different HW.
"""
def __init__(
self,
in_chs: int,
out_chs: Optional[int] = None,
kernel_size: int = 7,
stride: int = 1,
dilation: Union[int, Tuple[int, int]] = (1, 1),
mlp_ratio: float = 4,
conv_mlp: bool = False,
conv_bias: bool = True,
use_grn: bool = False,
ls_init_value: Optional[float] = 1e-6,
act_layer: Union[str, Callable] = "gelu",
norm_layer: Optional[Callable] = None,
drop_path: float = 0.0,
):
"""
Args:
in_chs: Block input channels.
out_chs: Block output channels (same as in_chs if None).
kernel_size: Depthwise convolution kernel size.
stride: Stride of depthwise convolution.
dilation: Tuple specifying input and output dilation of block.
mlp_ratio: MLP expansion ratio.
conv_mlp: Use 1x1 convolutions for MLP and a NCHW compatible norm layer if True.
conv_bias: Apply bias for all convolution (linear) layers.
use_grn: Use GlobalResponseNorm in MLP (from ConvNeXt-V2)
ls_init_value: Layer-scale init values, layer-scale applied if not None.
act_layer: Activation layer.
norm_layer: Normalization layer (defaults to LN if not specified).
drop_path: Stochastic depth probability.
"""
super().__init__()
out_chs = out_chs or in_chs
dilation = to_ntuple(2)(dilation)
act_layer = get_act_layer(act_layer)
if not norm_layer:
norm_layer = LayerNorm2d if conv_mlp else LayerNorm
mlp_layer = partial(
GlobalResponseNormMlp if use_grn else Mlp, use_conv=conv_mlp
)
self.use_conv_mlp = conv_mlp
self.conv_dw = create_conv2d(
in_chs,
out_chs,
kernel_size=kernel_size,
stride=stride,
dilation=dilation[0],
depthwise=True,
bias=conv_bias,
)
self.norm = norm_layer(out_chs)
self.mlp = mlp_layer(out_chs, int(mlp_ratio * out_chs), act_layer=act_layer)
self.gamma = (
nn.Parameter(ls_init_value * torch.ones(out_chs))
if ls_init_value is not None
else None
)
if in_chs != out_chs or stride != 1 or dilation[0] != dilation[1]:
self.shortcut = Downsample(
in_chs, out_chs, stride=stride, dilation=dilation[0]
)
else:
self.shortcut = nn.Identity()
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
def forward(self, x):
shortcut = x
x = self.conv_dw(x.contiguous())
if self.use_conv_mlp:
x = self.norm(x)
x = self.mlp(x)
else:
x = x.permute(0, 2, 3, 1).contiguous()
x = self.norm(x)
x = self.mlp(x)
x = x.permute(0, 3, 1, 2).contiguous()
if self.gamma is not None:
x = x.mul(self.gamma.reshape(1, -1, 1, 1))
x = self.drop_path(x) + self.shortcut(shortcut)
return x.contiguous()
class ConvNeXtStage(nn.Module):
def __init__(
self,
in_chs,
out_chs,
kernel_size=7,
stride=2,
depth=2,
dilation=(1, 1),
drop_path_rates=None,
ls_init_value=1.0,
conv_mlp=False,
conv_bias=True,
use_grn=False,
act_layer="gelu",
norm_layer=None,
norm_layer_cl=None,
):
super().__init__()
self.grad_checkpointing = False
if in_chs != out_chs or stride > 1 or dilation[0] != dilation[1]:
ds_ks = 2 if stride > 1 or dilation[0] != dilation[1] else 1
pad = (
"same" if dilation[1] > 1 else 0
) # same padding needed if dilation used
self.downsample = nn.Sequential(
norm_layer(in_chs),
create_conv2d(
in_chs,
out_chs,
kernel_size=ds_ks,
stride=stride,
dilation=dilation[0],
padding=pad,
bias=conv_bias,
),
)
in_chs = out_chs
else:
self.downsample = nn.Identity()
drop_path_rates = drop_path_rates or [0.0] * depth
stage_blocks = []
for i in range(depth):
stage_blocks.append(
ConvNeXtBlock(
in_chs=in_chs,
out_chs=out_chs,
kernel_size=kernel_size,
dilation=dilation[1],
drop_path=drop_path_rates[i],
ls_init_value=ls_init_value,
conv_mlp=conv_mlp,
conv_bias=conv_bias,
use_grn=use_grn,
act_layer=act_layer,
norm_layer=norm_layer if conv_mlp else norm_layer_cl,
)
)
in_chs = out_chs
self.blocks = nn.ModuleList(stage_blocks)
def forward(self, x):
xs = []
x = self.downsample(x)
for block in self.blocks:
if self.grad_checkpointing:
x = checkpoint(block, x)
else:
x = block(x)
xs.append(x)
return xs
class ConvNeXt(nn.Module):
def __init__(
self,
in_chans: int = 3,
output_stride: int = 32,
depths: Tuple[int, ...] = (3, 3, 9, 3),
dims: Tuple[int, ...] = (96, 192, 384, 768),
kernel_sizes: Union[int, Tuple[int, ...]] = 7,
ls_init_value: Optional[float] = 1e-6,
stem_type: str = "patch",
patch_size: int = 4,
conv_mlp: bool = False,
conv_bias: bool = True,
use_grn: bool = False,
act_layer: Union[str, Callable] = "gelu",
norm_layer: Optional[Union[str, Callable]] = None,
norm_eps: Optional[float] = None,
drop_path_rate: float = 0.0,
output_idx=[],
use_checkpoint=False,
):
"""
Args:
in_chans: Number of input image channels.
num_classes: Number of classes for classification head.
global_pool: Global pooling type.
output_stride: Output stride of network, one of (8, 16, 32).
depths: Number of blocks at each stage.
dims: Feature dimension at each stage.
kernel_sizes: Depthwise convolution kernel-sizes for each stage.
ls_init_value: Init value for Layer Scale, disabled if None.
stem_type: Type of stem.
patch_size: Stem patch size for patch stem.
head_init_scale: Init scaling value for classifier weights and biases.
head_norm_first: Apply normalization before global pool + head.
head_hidden_size: Size of MLP hidden layer in head if not None and head_norm_first == False.
conv_mlp: Use 1x1 conv in MLP, improves speed for small networks w/ chan last.
conv_bias: Use bias layers w/ all convolutions.
use_grn: Use Global Response Norm (ConvNeXt-V2) in MLP.
act_layer: Activation layer type.
norm_layer: Normalization layer type.
drop_rate: Head pre-classifier dropout rate.
drop_path_rate: Stochastic depth drop rate.
"""
super().__init__()
self.num_layers = len(depths)
self.depths = output_idx
self.embed_dims = [
int(dim) for i, dim in enumerate(dims) for _ in range(depths[i])
]
self.embed_dim = dims[0]
assert output_stride in (8, 16, 32)
kernel_sizes = to_ntuple(4)(kernel_sizes)
if norm_layer is None:
norm_layer = LayerNorm2d
norm_layer_cl = norm_layer if conv_mlp else LayerNorm
if norm_eps is not None:
norm_layer = partial(norm_layer, eps=norm_eps)
norm_layer_cl = partial(norm_layer_cl, eps=norm_eps)
else:
assert (
conv_mlp
), "If a norm_layer is specified, conv MLP must be used so all norm expect rank-4, channels-first input"
norm_layer_cl = norm_layer
if norm_eps is not None:
norm_layer_cl = partial(norm_layer_cl, eps=norm_eps)
self.feature_info = []
assert stem_type in ("patch", "overlap", "overlap_tiered")
if stem_type == "patch":
# NOTE: this stem is a minimal form of ViT PatchEmbed, as used in SwinTransformer w/ patch_size = 4
self.stem = nn.Sequential(
nn.Conv2d(
in_chans,
dims[0],
kernel_size=patch_size,
stride=patch_size,
bias=conv_bias,
),
norm_layer(dims[0]),
)
stem_stride = patch_size
else:
mid_chs = make_divisible(dims[0] // 2) if "tiered" in stem_type else dims[0]
self.stem = nn.Sequential(
nn.Conv2d(
in_chans,
mid_chs,
kernel_size=3,
stride=2,
padding=1,
bias=conv_bias,
),
nn.Conv2d(
mid_chs, dims[0], kernel_size=3, stride=2, padding=1, bias=conv_bias
),
norm_layer(dims[0]),
)
stem_stride = 4
self.stages = nn.Sequential()
dp_rates = [
x.tolist()
for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)
]
stages = []
prev_chs = dims[0]
curr_stride = stem_stride
dilation = 1
# 4 feature resolution stages, each consisting of multiple residual blocks
for i in range(4):
stride = 2 if curr_stride == 2 or i > 0 else 1
if curr_stride >= output_stride and stride > 1:
dilation *= stride
stride = 1
curr_stride *= stride
first_dilation = 1 if dilation in (1, 2) else 2
out_chs = dims[i]
stages.append(
ConvNeXtStage(
prev_chs,
out_chs,
kernel_size=kernel_sizes[i],
stride=stride,
dilation=(first_dilation, dilation),
depth=depths[i],
drop_path_rates=dp_rates[i],
ls_init_value=ls_init_value,
conv_mlp=conv_mlp,
conv_bias=conv_bias,
use_grn=use_grn,
act_layer=act_layer,
norm_layer=norm_layer,
norm_layer_cl=norm_layer_cl,
)
)
prev_chs = out_chs
# NOTE feature_info use currently assumes stage 0 == stride 1, rest are stride 2
self.feature_info += [
dict(num_chs=prev_chs, reduction=curr_stride, module=f"stages.{i}")
]
self.stages = nn.ModuleList(stages)
self.mask_token = nn.Parameter(torch.zeros(1, self.embed_dim, 1, 1))
self.num_features = prev_chs
self.apply(self._init_weights)
self.set_grad_checkpointing(use_checkpoint)
def _init_weights(self, module):
if isinstance(module, nn.Conv2d):
trunc_normal_(module.weight, std=0.02)
if module.bias is not None:
nn.init.zeros_(module.bias)
elif isinstance(module, nn.Linear):
trunc_normal_(module.weight, std=0.02)
nn.init.zeros_(module.bias)
def forward(self, x, masks=None):
outs = []
x = self.stem(x)
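# optional MIM-style masking: masked spatial locations are replaced with the learned mask token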
if masks is not None:
masks = torch.nn.functional.interpolate(
masks.float(), size=x.shape[-2:], mode="nearest"
)
x = torch.where(masks.bool(), self.mask_token.to(x.dtype), x).contiguous()
for stage in self.stages:
xs = stage(x)
outs.extend([x.permute(0, 2, 3, 1).contiguous() for x in xs])
x = xs[-1]
return outs, [x.mean(dim=(1, 2)).unsqueeze(1).contiguous() for x in outs]
@torch.jit.ignore
def group_matcher(self, coarse=False):
return dict(
stem=r"^stem",
blocks=(
r"^stages\.(\d+)"
if coarse
else [
(r"^stages\.(\d+)\.downsample", (0,)), # blocks
(r"^stages\.(\d+)\.blocks\.(\d+)", None),
(r"^norm_pre", (99999,)),
]
),
)
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
for s in self.stages:
s.grad_checkpointing = enable
def freeze(self) -> None:
for module in self.modules():
module.eval()
for parameters in self.parameters():
parameters.requires_grad = False
def get_params(self, lr, wd, ld, *args, **kwargs):
encoder_p, encoder_lr = get_parameter_groups(self, lr, wd, ld)
return encoder_p, encoder_lr
def no_weight_decay(self):
return {"mask_token"}
@classmethod
def build(cls, config):
obj = globals()[config["model"]["encoder"]["name"]](config)
return obj
def checkpoint_filter_fn(state_dict, model):
"""Remap FB checkpoints -> timm"""
if "head.norm.weight" in state_dict or "norm_pre.weight" in state_dict:
return state_dict # non-FB checkpoint
if "model" in state_dict:
state_dict = state_dict["model"]
out_dict = {}
if "visual.trunk.stem.0.weight" in state_dict:
out_dict = {
k.replace("visual.trunk.", ""): v
for k, v in state_dict.items()
if k.startswith("visual.trunk.")
}
if "visual.head.proj.weight" in state_dict:
out_dict["head.fc.weight"] = state_dict["visual.head.proj.weight"]
out_dict["head.fc.bias"] = torch.zeros(
state_dict["visual.head.proj.weight"].shape[0]
)
elif "visual.head.mlp.fc1.weight" in state_dict:
out_dict["head.pre_logits.fc.weight"] = state_dict[
"visual.head.mlp.fc1.weight"
]
out_dict["head.pre_logits.fc.bias"] = state_dict["visual.head.mlp.fc1.bias"]
out_dict["head.fc.weight"] = state_dict["visual.head.mlp.fc2.weight"]
out_dict["head.fc.bias"] = torch.zeros(
state_dict["visual.head.mlp.fc2.weight"].shape[0]
)
return out_dict
import re
for k, v in state_dict.items():
k = k.replace("downsample_layers.0.", "stem.")
k = re.sub(r"stages.([0-9]+).([0-9]+)", r"stages.\1.blocks.\2", k)
k = re.sub(
r"downsample_layers.([0-9]+).([0-9]+)", r"stages.\1.downsample.\2", k
)
k = k.replace("dwconv", "conv_dw")
k = k.replace("pwconv", "mlp.fc")
if "grn" in k:
k = k.replace("grn.beta", "mlp.grn.bias")
k = k.replace("grn.gamma", "mlp.grn.weight")
v = v.reshape(v.shape[-1])
k = k.replace("head.", "head.fc.")
if k.startswith("norm."):
k = k.replace("norm", "head.norm")
if v.ndim == 2 and "head" not in k:
model_shape = model.state_dict()[k].shape
v = v.reshape(model_shape)
out_dict[k] = v
return out_dict
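# Illustrative examples (not part of the original file) of the key remapping that
# checkpoint_filter_fn performs above, FB/OpenCLIP naming -> timm naming:
#   "downsample_layers.0.0.weight" -> "stem.0.weight"
#   "stages.2.5.dwconv.weight"     -> "stages.2.blocks.5.conv_dw.weight"
#   "stages.0.0.pwconv1.weight"    -> "stages.0.blocks.0.mlp.fc1.weight"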
HF_URL = {
"convnext_xxlarge_pt": (
"laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg-soup",
"open_clip_pytorch_model.bin",
),
"convnext_large_pt": (
"laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft-soup",
"open_clip_pytorch_model.bin",
),
"convnext_large": (
"timm/convnext_large_mlp.clip_laion2b_soup_ft_in12k_in1k_384",
"pytorch_model.bin",
),
}
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/models/backbones/convnext2.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.models.layers import DropPath, trunc_normal_
def get_num_layer_for_convnext_single(var_name, depths):
"""
Each layer is assigned distinctive layer ids
"""
if var_name.startswith("downsample_layers"):
stage_id = int(var_name.split(".")[1])
layer_id = sum(depths[:stage_id]) + 1
return layer_id
elif var_name.startswith("stages"):
stage_id = int(var_name.split(".")[1])
block_id = int(var_name.split(".")[2])
layer_id = sum(depths[:stage_id]) + block_id + 1
return layer_id
else:
return sum(depths) + 1
def get_num_layer_for_convnext(var_name):
"""
Divide [3, 3, 27, 3] layers into 12 groups; each group is three
consecutive blocks, including possible neighboring downsample layers;
adapted from https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py
"""
num_max_layer = 12
if var_name.startswith("downsample_layers"):
stage_id = int(var_name.split(".")[1])
if stage_id == 0:
layer_id = 0
elif stage_id == 1 or stage_id == 2:
layer_id = stage_id + 1
elif stage_id == 3:
layer_id = 12
return layer_id
elif var_name.startswith("stages"):
stage_id = int(var_name.split(".")[1])
block_id = int(var_name.split(".")[2])
if stage_id == 0 or stage_id == 1:
layer_id = stage_id + 1
elif stage_id == 2:
layer_id = 3 + block_id // 3
elif stage_id == 3:
layer_id = 12
return layer_id
else:
return num_max_layer + 1
def get_parameter_groups(model, lr, wd=1e-5, ld=0.9, skip_list=()):
parameter_group_names = {}
parameter_group_vars = {}
skip = {}
if skip_list is not None:
skip = skip_list
elif hasattr(model, "no_weight_decay"):
skip = model.no_weight_decay()
num_layers = 12 # sum(model.depths)
layer_scale = list(ld ** (num_layers + 1 - i) for i in range(num_layers + 2))
for name, param in model.named_parameters():
if not param.requires_grad:
continue # frozen weights
if (
len(param.shape) == 1
or name.endswith(".bias")
or name in skip
or name.endswith(".gamma")
or name.endswith(".beta")
):
group_name = "no_decay"
this_weight_decay = 0.0
else:
group_name = "decay"
this_weight_decay = wd
# layer_id = get_num_layer_for_convnext_single(name, model.depths)
layer_id = get_num_layer_for_convnext(name)
group_name = "layer_%d_%s" % (layer_id, group_name)
if group_name not in parameter_group_names:
scale = layer_scale[layer_id]
cur_lr = lr * scale
parameter_group_names[group_name] = {
"weight_decay": this_weight_decay,
"params": [],
"lr_scale": scale,
"lr": cur_lr,
}
parameter_group_vars[group_name] = {
"weight_decay": this_weight_decay,
"params": [],
"lr_scale": scale,
"lr": cur_lr,
}
parameter_group_vars[group_name]["params"].append(param)
parameter_group_names[group_name]["params"].append(name)
# if is_main_process():
# print("Param groups = %s" % json.dumps(parameter_group_names, indent=2))
return list(parameter_group_vars.values()), [
v["lr"] for k, v in parameter_group_vars.items()
]
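# Illustrative usage (not part of the original file): the two return values are
# the optimizer parameter groups and their layer-decayed learning rates, e.g.
#   param_groups, lrs = get_parameter_groups(model, lr=1e-4, wd=0.05, ld=0.9)
#   optimizer = torch.optim.AdamW(param_groups)
# Each group already carries its own "lr" (the base lr scaled by
# ld ** (num_layers + 1 - layer_id)), so no extra per-layer handling is needed
# on the optimizer side.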
class LayerNorm(nn.Module):
"""LayerNorm that supports two data formats: channels_last (default) or channels_first.
The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
shape (batch_size, height, width, channels) while channels_first corresponds to inputs
with shape (batch_size, channels, height, width).
"""
def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
super().__init__()
self.weight = nn.Parameter(torch.ones(normalized_shape))
self.bias = nn.Parameter(torch.zeros(normalized_shape))
self.eps = eps
self.data_format = data_format
if self.data_format not in ["channels_last", "channels_first"]:
raise NotImplementedError
self.normalized_shape = (normalized_shape,)
def forward(self, x):
if self.data_format == "channels_last":
return F.layer_norm(
x, self.normalized_shape, self.weight, self.bias, self.eps
)
elif self.data_format == "channels_first":
u = x.mean(1, keepdim=True)
s = (x - u).pow(2).mean(1, keepdim=True)
x = (x - u) / torch.sqrt(s + self.eps)
x = self.weight[:, None, None] * x + self.bias[:, None, None]
return x
class GRN(nn.Module):
"""GRN (Global Response Normalization) layer"""
def __init__(self, dim):
super().__init__()
self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim))
self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim))
def forward(self, x):
Gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
return self.gamma * (x * Nx) + self.beta + x
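# Note on the GRN forward above: Gx is the per-channel L2 norm aggregated over the
# spatial dimensions (dims 1 and 2 of the channels-last input), Nx divides it by its
# mean across channels to give a relative "global response", and gamma/beta are
# initialised to zero so the layer starts out as an identity mapping.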
class Block(nn.Module):
"""ConvNeXtV2 Block.
Args:
dim (int): Number of input channels.
drop_path (float): Stochastic depth rate. Default: 0.0
"""
def __init__(self, dim, drop_path=0.0, mult=4, use_checkpoint=False):
super().__init__()
self.dwconv = nn.Conv2d(
dim, dim, kernel_size=7, padding=3, groups=dim
) # depthwise conv
self.norm = LayerNorm(dim, eps=1e-6)
self.pwconv1 = nn.Linear(
dim, mult * dim
) # pointwise/1x1 convs, implemented with linear layers
self.act = nn.GELU()
self.grn = GRN(mult * dim)
self.pwconv2 = nn.Linear(mult * dim, dim)
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.use_checkpoint = use_checkpoint
def forward(self, x):
input = x
x = self.dwconv(x)
x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
x = self.norm(x)
x = self.pwconv1(x)
x = self.act(x)
x = self.grn(x)
x = self.pwconv2(x)
x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
x = input + self.drop_path(x)
return x
class ConvNeXtV2(nn.Module):
"""ConvNeXt V2
Args:
in_chans (int): Number of input image channels. Default: 3
num_classes (int): Number of classes for classification head. Default: 1000
depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3]
dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768]
drop_path_rate (float): Stochastic depth rate. Default: 0.
head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1.
"""
def __init__(
self,
in_chans=3,
depths=[3, 3, 9, 3],
        dims=[96, 192, 384, 768],
drop_path_rate=0.0,
output_idx=[],
use_checkpoint=False,
):
super().__init__()
self.num_layers = len(depths)
self.depths = output_idx
self.embed_dims = [
int(dim) for i, dim in enumerate(dims) for _ in range(depths[i])
]
self.embed_dim = dims[0]
self.downsample_layers = (
nn.ModuleList()
) # stem and 3 intermediate downsampling conv layers
stem = nn.Sequential(
nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
LayerNorm(dims[0], eps=1e-6, data_format="channels_first"),
)
self.downsample_layers.append(stem)
for i in range(3):
downsample_layer = nn.Sequential(
LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
nn.Conv2d(dims[i], dims[i + 1], kernel_size=2, stride=2),
)
self.downsample_layers.append(downsample_layer)
self.stages = (
nn.ModuleList()
) # 4 feature resolution stages, each consisting of multiple residual blocks
self.out_norms = nn.ModuleList()
dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
cur = 0
for i in range(4):
stage = nn.ModuleList(
[
Block(
dim=dims[i],
drop_path=dp_rates[cur + j],
use_checkpoint=use_checkpoint,
)
for j in range(depths[i])
]
)
self.stages.append(stage)
cur += depths[i]
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, (nn.Conv2d, nn.Linear)):
trunc_normal_(m.weight, std=0.02)
nn.init.constant_(m.bias, 0)
def forward(self, x):
outs = []
for i in range(4):
x = self.downsample_layers[i](x)
for stage in self.stages[i]:
x = stage(x)
outs.append(x.permute(0, 2, 3, 1))
cls_tokens = [x.mean(dim=(1, 2)).unsqueeze(1).contiguous() for x in outs]
return outs, cls_tokens
def get_params(self, lr, wd, ld, *args, **kwargs):
encoder_p, encoder_lr = get_parameter_groups(self, lr, wd, ld)
return encoder_p, encoder_lr
def freeze(self) -> None:
for module in self.modules():
module.eval()
for parameters in self.parameters():
parameters.requires_grad = False
@classmethod
def build(cls, config):
obj = globals()[config["model"]["encoder"]["name"]](config)
return obj
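As a quick sanity check of the `ConvNeXtV2` backbone defined in this file, the sketch below runs it on a dummy image. The backbone returns one channels-last feature map per block (which is why `output_idx` values such as `[3, 6, 33, 36]` index into a flat per-block list) plus one pooled class token per block. The small `depths`/`dims` used here are placeholders chosen only for speed, not UniDepth's actual configuration.

```python
# Minimal sketch (toy sizes assumed, not UniDepth defaults).
import torch

model = ConvNeXtV2(depths=[2, 2, 4, 2], dims=[32, 64, 128, 256],
                   output_idx=[2, 4, 8, 10])
model.eval()

x = torch.randn(1, 3, 64, 64)
with torch.no_grad():
    feats, cls_tokens = model(x)

print(len(feats))            # 10 = 2 + 2 + 4 + 2, one entry per block
print(feats[0].shape)        # torch.Size([1, 16, 16, 32])   (N, H, W, C)
print(feats[-1].shape)       # torch.Size([1, 2, 2, 256])
print(cls_tokens[-1].shape)  # torch.Size([1, 1, 256])
```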
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/models/backbones/dinov2.py
================================================
import contextlib
import logging
import math
from functools import partial
from typing import Callable, Sequence
import torch
import torch.nn as nn
from torch.nn.init import trunc_normal_
from .metadinov2 import Attention, MemEffAttention, Mlp
from .metadinov2 import NestedTensorBlock as Block
from .metadinov2 import PatchEmbed, SwiGLUFFNFused
_DINOV2_BASE_URL = "https://dl.fbaipublicfiles.com/dinov2"
logger = logging.getLogger("dinov2")
def named_apply(
fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False
) -> nn.Module:
if not depth_first and include_root:
fn(module=module, name=name)
for child_name, child_module in module.named_children():
child_name = ".".join((name, child_name)) if name else child_name
named_apply(
fn=fn,
module=child_module,
name=child_name,
depth_first=depth_first,
include_root=True,
)
if depth_first and include_root:
fn(module=module, name=name)
return module
def get_parameter_groups(model, lr, wd=1e-5, ld=0.9, skip_list=()):
parameter_group_names = {}
parameter_group_vars = {}
skip = {}
if skip_list is not None:
skip = skip_list
elif hasattr(model, "no_weight_decay"):
skip = model.no_weight_decay()
num_layers = model.n_blocks
layer_scale = list(ld ** (num_layers - i) for i in range(num_layers))
for name, param in model.named_parameters():
if not param.requires_grad:
continue
if len(param.shape) == 1: # norm
group_name = "no_decay"
this_wd = 0.0
# layer scale, bias beta?
elif (
name in skip
or name.endswith(".gamma")
or name.endswith(".beta")
or name.endswith(".bias")
):
group_name = "no_decay"
this_wd = 0.0
elif "cls_token" in name or "pos_embed" in name or "mask_token" in name:
group_name = "no_decay"
this_wd = 0.0
else:
group_name = "decay"
this_wd = wd
if name.startswith("blocks"):
layer_id = int(name.split(".")[1])
elif name.startswith("patch_embed"):
layer_id = 0
else:
layer_id = 0
group_name = f"layer_{layer_id}_{group_name}"
if group_name not in parameter_group_names:
scale = layer_scale[layer_id]
cur_lr = lr * scale
parameter_group_names[group_name] = {
"weight_decay": this_wd,
"params": [],
"lr_init": cur_lr,
"lr_base": lr,
"lr": cur_lr,
}
parameter_group_vars[group_name] = {
"weight_decay": this_wd,
"params": [],
"lr_init": cur_lr,
"lr_base": lr,
"lr": cur_lr,
}
parameter_group_vars[group_name]["params"].append(param)
parameter_group_names[group_name]["params"].append(name)
return list(parameter_group_vars.values()), [
v["lr"] for k, v in parameter_group_vars.items()
]
class BlockChunk(nn.ModuleList):
def forward(self, x):
for b in self:
x = b(x)
return x
class DinoVisionTransformer(nn.Module):
def __init__(
self,
img_size=224,
patch_size=16,
in_chans=3,
embed_dim=768,
depth=12,
num_heads=12,
mlp_ratio=4.0,
qkv_bias=True,
ffn_bias=True,
proj_bias=True,
drop_path_rate=0.0,
drop_path_uniform=False,
init_values=None, # for layerscale: None or 0 => no layerscale
embed_layer=PatchEmbed,
act_layer=nn.GELU,
block_fn=Block,
ffn_layer="mlp",
block_chunks=1,
output_idx=[5, 12, 18, 24],
checkpoint: bool = False,
num_register_tokens=0,
interpolate_antialias=False,
interpolate_offset=0.0,
use_norm=False,
frozen_stages=0,
):
"""
Args:
img_size (int, tuple): input image size
patch_size (int, tuple): patch size
in_chans (int): number of input channels
embed_dim (int): embedding dimension
depth (int): depth of transformer
num_heads (int): number of attention heads
mlp_ratio (int): ratio of mlp hidden dim to embedding dim
qkv_bias (bool): enable bias for qkv if True
proj_bias (bool): enable bias for proj in attn if True
ffn_bias (bool): enable bias for ffn if True
drop_path_rate (float): stochastic depth rate
drop_path_uniform (bool): apply uniform drop rate across blocks
weight_init (str): weight init scheme
init_values (float): layer-scale init values
embed_layer (nn.Module): patch embedding layer
act_layer (nn.Module): MLP activation layer
block_fn (nn.Module): transformer block class
ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
"""
super().__init__()
norm_layer = partial(nn.LayerNorm, eps=1e-6)
self.num_features = self.embed_dim = (
embed_dim # num_features for consistency with other models
)
self.frozen_stages = frozen_stages
self.embed_dims = [embed_dim] * output_idx[-1]
self.num_tokens = 1
self.n_blocks = depth
self.num_heads = num_heads
self.patch_size = patch_size
self.depths = output_idx
self.checkpoint = checkpoint
self.num_register_tokens = num_register_tokens
self.interpolate_antialias = interpolate_antialias
self.interpolate_offset = interpolate_offset
self.patch_embed = embed_layer(
img_size=img_size,
patch_size=patch_size,
in_chans=in_chans,
embed_dim=embed_dim,
)
num_patches = self.patch_embed.num_patches
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
self.pos_embed = nn.Parameter(
torch.zeros(1, num_patches + self.num_tokens, embed_dim)
)
assert num_register_tokens >= 0
self.register_tokens = nn.Parameter(
torch.zeros(1, max(1, num_register_tokens), embed_dim)
)
if drop_path_uniform is True:
dpr = [drop_path_rate] * depth
else:
dpr = [
x.item() for x in torch.linspace(0, drop_path_rate, depth)
] # stochastic depth decay rule
if ffn_layer == "mlp":
logger.info("using MLP layer as FFN")
ffn_layer = Mlp
elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
logger.info("using SwiGLU layer as FFN")
ffn_layer = SwiGLUFFNFused
elif ffn_layer == "identity":
logger.info("using Identity layer as FFN")
def f(*args, **kwargs):
return nn.Identity()
ffn_layer = f
else:
raise NotImplementedError
blocks_list = [
block_fn(
dim=embed_dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
proj_bias=proj_bias,
ffn_bias=ffn_bias,
drop_path=dpr[i],
norm_layer=norm_layer,
act_layer=act_layer,
ffn_layer=ffn_layer,
init_values=init_values,
)
for i in range(depth)
]
if block_chunks > 0:
self.chunked_blocks = True
chunked_blocks = []
chunksize = depth // block_chunks
for i in range(0, depth, chunksize):
# this is to keep the block index consistent if we chunk the block list
chunked_blocks.append(
[nn.Identity()] * i + blocks_list[i : i + chunksize]
)
self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
else:
self.chunked_blocks = False
self.blocks = nn.ModuleList(blocks_list)
self.norm = nn.LayerNorm(embed_dim)
self.use_norm = use_norm
self.head = nn.Identity()
self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))
self.init_weights()
def init_weights(self):
trunc_normal_(self.pos_embed, std=0.02)
nn.init.normal_(self.cls_token, std=1e-6)
if self.num_register_tokens:
nn.init.normal_(self.register_tokens, std=1e-6)
named_apply(init_weights_vit_timm, self)
def interpolate_pos_encoding(self, x, w, h):
previous_dtype = x.dtype
npatch = x.shape[1] - 1
N = self.pos_embed.shape[1] - 1
if npatch == N and w == h:
return self.pos_embed
pos_embed = self.pos_embed.float()
class_pos_embed = pos_embed[:, 0]
patch_pos_embed = pos_embed[:, 1:]
dim = x.shape[-1]
w0 = w // self.patch_size
h0 = h // self.patch_size
M = int(math.sqrt(N)) # Recover the number of patches in each dimension
assert N == M * M
kwargs = {}
if self.interpolate_offset:
# Historical kludge: add a small number to avoid floating point error in the interpolation, see https://github.com/facebookresearch/dino/issues/8
# Note: still needed for backward-compatibility, the underlying operators are using both output size and scale factors
sx = float(w0 + self.interpolate_offset) / M
sy = float(h0 + self.interpolate_offset) / M
kwargs["scale_factor"] = (sx, sy)
else:
# Simply specify an output size instead of a scale factor
kwargs["size"] = (w0, h0)
patch_pos_embed = nn.functional.interpolate(
patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2),
mode="bicubic",
antialias=self.interpolate_antialias,
**kwargs,
)
assert (w0, h0) == patch_pos_embed.shape[-2:]
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(
previous_dtype
)
def prepare_tokens_with_masks(self, x, masks=None):
B, nc, w, h = x.shape
with torch.no_grad() if self.frozen_stages > -1 else contextlib.nullcontext():
x = self.patch_embed(x)
if masks is not None:
masks = masks.bool().view(B, -1, 1)
x = torch.where(masks, self.mask_token.to(x.dtype).unsqueeze(0), x)
x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
x = x + self.interpolate_pos_encoding(x, w, h)
if self.num_register_tokens:
x = torch.cat(
(x[:, :1], self.register_tokens.expand(x.shape[0], -1, -1), x[:, 1:]),
dim=1,
)
return x
def forward(self, x, masks=None):
shapes = [val // self.patch_size for val in x.shape[-2:]]
batch_size = x.shape[0]
x = self.prepare_tokens_with_masks(x, masks)
outputs = []
for i, blk in enumerate(self.blocks):
with (
torch.no_grad() if i < self.frozen_stages else contextlib.nullcontext()
):
x = blk(x)
outputs.append(x)
if self.use_norm:
with (
torch.no_grad()
if self.frozen_stages >= len(self.blocks)
else contextlib.nullcontext()
):
outputs = [self.norm(out) for out in outputs]
class_tokens = [out[:, :1] for out in outputs]
outputs = [out[:, self.num_register_tokens + 1 :] for out in outputs]
outputs = [out.reshape(batch_size, *shapes, -1) for out in outputs]
return (outputs, class_tokens)
def get_params(self, lr, wd, ld, *args, **kwargs):
encoder_p, encoder_lr = get_parameter_groups(self, lr, wd, ld)
return encoder_p, encoder_lr
def freeze(self) -> None:
for module in self.modules():
module.eval()
for parameters in self.parameters():
parameters.requires_grad = False
def train(self, mode=True):
super().train(mode)
if self.frozen_stages > -1:
for p in self.patch_embed.parameters():
p.requires_grad = False
for i, blk in enumerate(self.blocks):
if i < self.frozen_stages:
blk.eval()
for p in blk.parameters():
p.requires_grad = False
for p in self.norm.parameters():
p.requires_grad = self.frozen_stages <= len(self.blocks) and self.use_norm
self.cls_token.requires_grad = self.frozen_stages < 1
self.pos_embed.requires_grad = self.frozen_stages < 1
self.mask_token.requires_grad = False
self.register_tokens.requires_grad = False
def init_weights_vit_timm(module: nn.Module, name: str = ""):
"""ViT weight initialization, original timm impl (for reproducibility)"""
if isinstance(module, nn.Linear):
trunc_normal_(module.weight, std=0.02)
if module.bias is not None:
nn.init.zeros_(module.bias)
def vit_small(patch_size=16, num_register_tokens=0, export=False, **kwargs):
model = DinoVisionTransformer(
patch_size=patch_size,
embed_dim=384,
depth=12,
num_heads=6,
mlp_ratio=4,
num_register_tokens=num_register_tokens,
block_fn=partial(Block, attn_class=Attention if export else MemEffAttention),
**kwargs,
)
return model
def vit_base(patch_size=16, num_register_tokens=0, export=False, **kwargs):
model = DinoVisionTransformer(
patch_size=patch_size,
embed_dim=768,
depth=12,
num_heads=12,
mlp_ratio=4,
num_register_tokens=num_register_tokens,
block_fn=partial(Block, attn_class=Attention if export else MemEffAttention),
**kwargs,
)
return model
def vit_large(patch_size=16, num_register_tokens=0, export=False, **kwargs):
model = DinoVisionTransformer(
patch_size=patch_size,
embed_dim=1024,
depth=24,
num_heads=16,
mlp_ratio=4,
num_register_tokens=num_register_tokens,
block_fn=partial(Block, attn_class=Attention if export else MemEffAttention),
**kwargs,
)
return model
def _make_dinov2_model_name(arch_name: str, patch_size: int) -> str:
compact_arch_name = arch_name.replace("_", "")[:4]
return f"dinov2_{compact_arch_name}{patch_size}"
def _make_dinov2_model(
*,
arch_name: str = "vit_large",
img_size: int = 518,
patch_size: int = 14,
init_values: float = 1.0,
ffn_layer: str = "mlp",
block_chunks: int = 0,
pretrained: str = "",
output_idx: Sequence[int] = [],
num_register_tokens: int = 0,
drop_path_rate: float = 0.0,
use_norm: bool = False,
export: bool = False,
interpolate_offset: float = 0.0,
frozen_stages: int = 0,
**kwargs,
):
model_name = _make_dinov2_model_name(arch_name, patch_size)
vit_kwargs = dict(
img_size=img_size,
patch_size=patch_size,
init_values=init_values,
ffn_layer=ffn_layer,
block_chunks=block_chunks,
output_idx=output_idx,
drop_path_rate=drop_path_rate,
num_register_tokens=num_register_tokens,
use_norm=use_norm,
export=export,
interpolate_offset=interpolate_offset,
frozen_stages=frozen_stages,
)
vit_kwargs.update(**kwargs)
model = eval(arch_name)(**vit_kwargs)
if pretrained == "":
url = _DINOV2_BASE_URL + f"/{model_name}/{model_name}"
if num_register_tokens > 0:
url += "_reg4"
url += "_pretrain.pth"
state_dict = torch.hub.load_state_dict_from_url(
url, map_location="cpu", progress=False
)
info = model.load_state_dict(state_dict, strict=False)
print(info)
elif pretrained is not None:
state_dict = torch.load(pretrained, map_location="cpu")
info = model.load_state_dict(state_dict, strict=False)
print(f"loading from {pretrained} with:", info)
else:
print("Not loading pretrained weights for backbone")
return model
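For reference, the sketch below builds an untrained ViT-S/14 through `_make_dinov2_model`; passing `pretrained=None` skips both the URL-download branch and the local-checkpoint branch, so no weights are fetched. The 224x224 input is an arbitrary choice that is divisible by the patch size of 14.

```python
# Minimal sketch: randomly initialised DINOv2 ViT-S/14, no download.
import torch

vit = _make_dinov2_model(
    arch_name="vit_small",
    pretrained=None,             # None: random init (no URL, no local checkpoint)
    output_idx=[3, 6, 9, 12],
)
vit.eval()

x = torch.randn(1, 3, 224, 224)  # 224 / 14 = 16 patches per side
with torch.no_grad():
    feats, cls_tokens = vit(x)

print(len(feats))            # 12, one feature map per transformer block
print(feats[-1].shape)       # torch.Size([1, 16, 16, 384])
print(cls_tokens[-1].shape)  # torch.Size([1, 1, 384])
```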
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/models/backbones/metadinov2/__init__.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from .attention import Attention, MemEffAttention
from .block import NestedTensorBlock
from .dino_head import DINOHead
from .mlp import Mlp
from .patch_embed import PatchEmbed
from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/models/backbones/metadinov2/attention.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# References:
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
import logging
import torch
import torch.nn as nn
import torch.nn.functional as F
logger = logging.getLogger("dinov2")
try:
from xformers.ops import fmha, memory_efficient_attention, unbind
XFORMERS_AVAILABLE = True
except ImportError:
logger.warning("xFormers not available")
XFORMERS_AVAILABLE = False
XFORMERS_AVAILABLE = XFORMERS_AVAILABLE and torch.cuda.is_available()
class Attention(nn.Module):
def __init__(
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = False,
proj_bias: bool = True,
attn_drop: float = 0.0,
proj_drop: float = 0.0,
) -> None:
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = head_dim**-0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim, bias=proj_bias)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x: torch.Tensor) -> torch.Tensor:
B, N, C = x.shape
qkv = (
self.qkv(x)
.reshape(B, N, 3, self.num_heads, C // self.num_heads)
.permute(2, 0, 3, 1, 4)
)
x = F.scaled_dot_product_attention(qkv[0], qkv[1], qkv[2])
x = x.transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class MemEffAttention(Attention):
def forward(self, x: torch.Tensor, attn_bias=None) -> torch.Tensor:
        # newer PyTorch versions provide efficient scaled-dot-product attention, so xFormers is optional here
if not XFORMERS_AVAILABLE or x.device.type == "cpu":
assert attn_bias is None, "xFormers is required for nested tensors usage"
return super().forward(x)
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
q, k, v = unbind(qkv, 2)
x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
x = x.reshape([B, N, C])
x = self.proj(x)
x = self.proj_drop(x)
return x
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/models/backbones/metadinov2/block.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# References:
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
import logging
from typing import Any, Callable, Dict, List, Tuple
import torch
import torch.nn as nn
from .attention import Attention, MemEffAttention
from .drop_path import DropPath
from .layer_scale import LayerScale
from .mlp import Mlp
logger = logging.getLogger("dinov2")
try:
from xformers.ops import fmha, index_select_cat, scaled_index_add
XFORMERS_AVAILABLE = True
except ImportError:
logger.warning("xFormers not available")
XFORMERS_AVAILABLE = False
class Block(nn.Module):
def __init__(
self,
dim: int,
num_heads: int,
mlp_ratio: float = 4.0,
qkv_bias: bool = False,
proj_bias: bool = True,
ffn_bias: bool = True,
drop: float = 0.0,
attn_drop: float = 0.0,
init_values=None,
drop_path: float = 0.0,
act_layer: Callable[..., nn.Module] = nn.GELU,
norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
attn_class: Callable[..., nn.Module] = Attention,
ffn_layer: Callable[..., nn.Module] = Mlp,
) -> None:
super().__init__()
# print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
self.norm1 = norm_layer(dim)
self.attn = attn_class(
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
proj_bias=proj_bias,
attn_drop=attn_drop,
proj_drop=drop,
)
self.ls1 = (
LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
)
self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = ffn_layer(
in_features=dim,
hidden_features=mlp_hidden_dim,
act_layer=act_layer,
drop=drop,
bias=ffn_bias,
)
self.ls2 = (
LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
)
self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.sample_drop_ratio = drop_path
def forward(self, x: torch.Tensor) -> torch.Tensor:
def attn_residual_func(x: torch.Tensor) -> torch.Tensor:
return self.ls1(self.attn(self.norm1(x)))
def ffn_residual_func(x: torch.Tensor) -> torch.Tensor:
return self.ls2(self.mlp(self.norm2(x)))
if self.training and self.sample_drop_ratio > 0.1:
# the overhead is compensated only for a drop path rate larger than 0.1
x = drop_add_residual_stochastic_depth(
x,
residual_func=attn_residual_func,
sample_drop_ratio=self.sample_drop_ratio,
)
x = drop_add_residual_stochastic_depth(
x,
residual_func=ffn_residual_func,
sample_drop_ratio=self.sample_drop_ratio,
)
elif self.training and self.sample_drop_ratio > 0.0:
x = x + self.drop_path1(attn_residual_func(x))
            x = x + self.drop_path2(ffn_residual_func(x))
else:
x = x + attn_residual_func(x)
x = x + ffn_residual_func(x)
return x
def drop_add_residual_stochastic_depth(
x: torch.Tensor,
residual_func: Callable[[torch.Tensor], torch.Tensor],
sample_drop_ratio: float = 0.0,
) -> torch.Tensor:
# 1) extract subset using permutation
b, n, d = x.shape
sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
x_subset = x[brange]
# 2) apply residual_func to get residual
residual = residual_func(x_subset)
x_flat = x.flatten(1)
residual = residual.flatten(1)
residual_scale_factor = b / sample_subset_size
# 3) add the residual
x_plus_residual = torch.index_add(
x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor
)
return x_plus_residual.view_as(x)
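# Note on the scaling above: only a random subset of
# max(int(b * (1 - sample_drop_ratio)), 1) samples receives the residual, and that
# residual is multiplied by b / sample_subset_size via index_add's alpha, so its
# expected contribution per sample matches applying the residual branch to the
# whole batch.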
def get_branges_scales(x, sample_drop_ratio=0.0):
b, n, d = x.shape
sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
residual_scale_factor = b / sample_subset_size
return brange, residual_scale_factor
def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
if scaling_vector is None:
x_flat = x.flatten(1)
residual = residual.flatten(1)
x_plus_residual = torch.index_add(
x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor
)
else:
x_plus_residual = scaled_index_add(
x,
brange,
residual.to(dtype=x.dtype),
scaling=scaling_vector,
alpha=residual_scale_factor,
)
return x_plus_residual
attn_bias_cache: Dict[Tuple, Any] = {}
def get_attn_bias_and_cat(x_list, branges=None):
"""
this will perform the index select, cat the tensors, and provide the attn_bias from cache
"""
batch_sizes = (
[b.shape[0] for b in branges]
if branges is not None
else [x.shape[0] for x in x_list]
)
all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
if all_shapes not in attn_bias_cache.keys():
seqlens = []
for b, x in zip(batch_sizes, x_list):
for _ in range(b):
seqlens.append(x.shape[1])
attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
attn_bias._batch_sizes = batch_sizes
attn_bias_cache[all_shapes] = attn_bias
if branges is not None:
cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(
1, -1, x_list[0].shape[-1]
)
else:
tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
cat_tensors = torch.cat(tensors_bs1, dim=1)
return attn_bias_cache[all_shapes], cat_tensors
def drop_add_residual_stochastic_depth_list(
x_list: List[torch.Tensor],
residual_func: Callable[[torch.Tensor, Any], torch.Tensor],
sample_drop_ratio: float = 0.0,
scaling_vector=None,
) -> torch.Tensor:
# 1) generate random set of indices for dropping samples in the batch
branges_scales = [
get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list
]
branges = [s[0] for s in branges_scales]
residual_scale_factors = [s[1] for s in branges_scales]
# 2) get attention bias and index+concat the tensors
attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)
# 3) apply residual_func to get residual, and split the result
residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore
outputs = []
for x, brange, residual, residual_scale_factor in zip(
x_list, branges, residual_list, residual_scale_factors
):
outputs.append(
add_residual(
x, brange, residual, residual_scale_factor, scaling_vector
).view_as(x)
)
return outputs
class NestedTensorBlock(Block):
def forward_nested(self, x_list: List[torch.Tensor]) -> List[torch.Tensor]:
"""
x_list contains a list of tensors to nest together and run
"""
assert isinstance(self.attn, MemEffAttention)
if self.training and self.sample_drop_ratio > 0.0:
def attn_residual_func(x: torch.Tensor, attn_bias=None) -> torch.Tensor:
return self.attn(self.norm1(x), attn_bias=attn_bias)
def ffn_residual_func(x: torch.Tensor, attn_bias=None) -> torch.Tensor:
return self.mlp(self.norm2(x))
x_list = drop_add_residual_stochastic_depth_list(
x_list,
residual_func=attn_residual_func,
sample_drop_ratio=self.sample_drop_ratio,
scaling_vector=(
self.ls1.gamma if isinstance(self.ls1, LayerScale) else None
),
)
x_list = drop_add_residual_stochastic_depth_list(
x_list,
residual_func=ffn_residual_func,
sample_drop_ratio=self.sample_drop_ratio,
scaling_vector=(
self.ls2.gamma if isinstance(self.ls1, LayerScale) else None
),
)
return x_list
else:
def attn_residual_func(x: torch.Tensor, attn_bias=None) -> torch.Tensor:
return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))
def ffn_residual_func(x: torch.Tensor, attn_bias=None) -> torch.Tensor:
return self.ls2(self.mlp(self.norm2(x)))
attn_bias, x = get_attn_bias_and_cat(x_list)
x = x + attn_residual_func(x, attn_bias=attn_bias)
x = x + ffn_residual_func(x)
return attn_bias.split(x)
def forward(self, x_or_x_list):
if isinstance(x_or_x_list, torch.Tensor):
return super(NestedTensorBlock, self).forward(x_or_x_list)
elif isinstance(x_or_x_list, list):
assert (
XFORMERS_AVAILABLE
), "Please install xFormers for nested tensors usage"
return self.forward_nested(x_or_x_list)
else:
raise AssertionError
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/models/backbones/metadinov2/dino_head.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn as nn
from torch.nn.init import trunc_normal_
from torch.nn.utils import weight_norm
class DINOHead(nn.Module):
def __init__(
self,
in_dim,
out_dim,
use_bn=False,
nlayers=3,
hidden_dim=2048,
bottleneck_dim=256,
mlp_bias=True,
):
super().__init__()
nlayers = max(nlayers, 1)
self.mlp = _build_mlp(
nlayers,
in_dim,
bottleneck_dim,
hidden_dim=hidden_dim,
use_bn=use_bn,
bias=mlp_bias,
)
self.apply(self._init_weights)
self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
self.last_layer.weight_g.data.fill_(1)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=0.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
def forward(self, x):
x = self.mlp(x)
eps = 1e-6 if x.dtype == torch.float16 else 1e-12
x = nn.functional.normalize(x, dim=-1, p=2, eps=eps)
x = self.last_layer(x)
return x
def _build_mlp(
nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True
):
if nlayers == 1:
return nn.Linear(in_dim, bottleneck_dim, bias=bias)
else:
layers = [nn.Linear(in_dim, hidden_dim, bias=bias)]
if use_bn:
layers.append(nn.BatchNorm1d(hidden_dim))
layers.append(nn.GELU())
for _ in range(nlayers - 2):
layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias))
if use_bn:
layers.append(nn.BatchNorm1d(hidden_dim))
layers.append(nn.GELU())
layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias))
return nn.Sequential(*layers)
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/models/backbones/metadinov2/drop_path.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# References:
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py
import torch.nn as nn
def drop_path(x, drop_prob: float = 0.0, training: bool = False):
if drop_prob == 0.0 or not training:
return x
keep_prob = 1 - drop_prob
shape = (x.shape[0],) + (1,) * (
x.ndim - 1
) # work with diff dim tensors, not just 2D ConvNets
random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
if keep_prob > 0.0:
random_tensor.div_(keep_prob)
output = x * random_tensor
return output
class DropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
def __init__(self, drop_prob=None):
super(DropPath, self).__init__()
self.drop_prob = drop_prob
def forward(self, x):
return drop_path(x, self.drop_prob, self.training)
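The following sketch illustrates the behaviour of `DropPath` above: during training an entire residual branch is zeroed per sample with probability `drop_prob`, and the surviving samples are rescaled by `1 / keep_prob` so the expectation is unchanged; at inference the module is the identity. The tensor shape and seed are arbitrary.

```python
# Minimal sketch of per-sample stochastic depth.
import torch

torch.manual_seed(0)
dp = DropPath(drop_prob=0.25)

x = torch.ones(8, 4)          # 8 samples; any trailing shape works
dp.train()
y = dp(x)                     # each row is either all 0 or all 1 / 0.75 ~= 1.333
print(y[:, 0])

dp.eval()
print(torch.equal(dp(x), x))  # True: identity at inference time
```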
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/models/backbones/metadinov2/layer_scale.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110
from typing import Union
import torch
import torch.nn as nn
from torch import Tensor
class LayerScale(nn.Module):
def __init__(
self,
dim: int,
init_values: Union[float, Tensor] = 1e-5,
inplace: bool = False,
) -> None:
super().__init__()
self.inplace = inplace
self.gamma = nn.Parameter(init_values * torch.ones(dim))
def forward(self, x: Tensor) -> Tensor:
return x.mul_(self.gamma) if self.inplace else x * self.gamma
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/models/backbones/metadinov2/mlp.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# References:
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py
from typing import Callable, Optional
from torch import Tensor, nn
class Mlp(nn.Module):
def __init__(
self,
in_features: int,
hidden_features: Optional[int] = None,
out_features: Optional[int] = None,
act_layer: Callable[..., nn.Module] = nn.GELU,
drop: float = 0.0,
bias: bool = True,
) -> None:
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
self.drop = nn.Dropout(drop)
def forward(self, x: Tensor) -> Tensor:
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/models/backbones/metadinov2/patch_embed.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# References:
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
from typing import Callable, Optional, Tuple, Union
import torch.nn as nn
from torch import Tensor
def make_2tuple(x):
if isinstance(x, tuple):
assert len(x) == 2
return x
assert isinstance(x, int)
return (x, x)
class PatchEmbed(nn.Module):
"""
2D image to patch embedding: (B,C,H,W) -> (B,N,D)
Args:
img_size: Image size.
patch_size: Patch token size.
in_chans: Number of input image channels.
embed_dim: Number of linear projection output channels.
norm_layer: Normalization layer.
"""
def __init__(
self,
img_size: Union[int, Tuple[int, int]] = 224,
patch_size: Union[int, Tuple[int, int]] = 16,
in_chans: int = 3,
embed_dim: int = 768,
norm_layer: Optional[Callable] = None,
flatten_embedding: bool = True,
) -> None:
super().__init__()
image_HW = make_2tuple(img_size)
patch_HW = make_2tuple(patch_size)
patch_grid_size = (
image_HW[0] // patch_HW[0],
image_HW[1] // patch_HW[1],
)
self.img_size = image_HW
self.patch_size = patch_HW
self.patches_resolution = patch_grid_size
self.num_patches = patch_grid_size[0] * patch_grid_size[1]
self.in_chans = in_chans
self.embed_dim = embed_dim
self.flatten_embedding = flatten_embedding
self.proj = nn.Conv2d(
in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW
)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
def forward(self, x: Tensor) -> Tensor:
_, _, H, W = x.shape
patch_H, patch_W = self.patch_size
assert (
H % patch_H == 0
), f"Input image height {H} is not a multiple of patch height {patch_H}"
assert (
W % patch_W == 0
), f"Input image width {W} is not a multiple of patch width: {patch_W}"
x = self.proj(x) # B C H W
H, W = x.size(2), x.size(3)
x = x.flatten(2).transpose(1, 2) # B HW C
x = self.norm(x)
if not self.flatten_embedding:
x = x.reshape(-1, H, W, self.embed_dim) # B H W C
return x
def flops(self) -> float:
Ho, Wo = self.patches_resolution
flops = (
Ho
* Wo
* self.embed_dim
* self.in_chans
* (self.patch_size[0] * self.patch_size[1])
)
if self.norm is not None:
flops += Ho * Wo * self.embed_dim
return flops
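As a shape check for the `PatchEmbed` above: a 224x224 image split into 16x16 patches yields 14 x 14 = 196 tokens of width `embed_dim`. These sizes are illustrative; the DINOv2 backbones used in UniDepth work with a patch size of 14.

```python
# Minimal sketch: (B, C, H, W) image -> (B, N, D) patch tokens.
import torch

pe = PatchEmbed(img_size=224, patch_size=16, in_chans=3, embed_dim=768)
x = torch.randn(2, 3, 224, 224)
tokens = pe(x)
print(tokens.shape)    # torch.Size([2, 196, 768])
print(pe.num_patches)  # 196
```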
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/models/backbones/metadinov2/swiglu_ffn.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from typing import Callable, Optional
import torch.nn.functional as F
from torch import Tensor, nn
class SwiGLUFFN(nn.Module):
def __init__(
self,
in_features: int,
hidden_features: Optional[int] = None,
out_features: Optional[int] = None,
act_layer: Callable[..., nn.Module] = None,
drop: float = 0.0,
bias: bool = True,
) -> None:
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
def forward(self, x: Tensor) -> Tensor:
x12 = self.w12(x)
x1, x2 = x12.chunk(2, dim=-1)
hidden = F.silu(x1) * x2
return self.w3(hidden)
try:
from xformers.ops import SwiGLU
XFORMERS_AVAILABLE = True
except ImportError:
SwiGLU = SwiGLUFFN
XFORMERS_AVAILABLE = False
class SwiGLUFFNFused(SwiGLU):
def __init__(
self,
in_features: int,
hidden_features: Optional[int] = None,
out_features: Optional[int] = None,
act_layer: Callable[..., nn.Module] = None,
drop: float = 0.0,
bias: bool = True,
) -> None:
out_features = out_features or in_features
hidden_features = hidden_features or in_features
hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
super().__init__(
in_features=in_features,
hidden_features=hidden_features,
out_features=out_features,
bias=bias,
)
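The sketch below exercises `SwiGLUFFN` and spells out the hidden-width rule used by `SwiGLUFFNFused`: the requested hidden dimension is shrunk to roughly 2/3 and rounded up to a multiple of 8, which keeps the parameter count close to a plain MLP of the same nominal ratio. The feature sizes used here are arbitrary.

```python
# Minimal sketch of the gated SwiGLU feed-forward defined above.
import torch

ffn = SwiGLUFFN(in_features=64, hidden_features=128)
x = torch.randn(2, 10, 64)
print(ffn(x).shape)  # torch.Size([2, 10, 64])

# Hidden-width rule applied by SwiGLUFFNFused: (int(h * 2 / 3) + 7) // 8 * 8
print((int(4 * 384 * 2 / 3) + 7) // 8 * 8)  # 1024 for a ViT-S style 4 * 384 hidden dim
```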
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/models/encoder.py
================================================
import torch
import torch.nn as nn
from unidepth.models.backbones import ConvNeXt, ConvNeXtV2, _make_dinov2_model
class ModelWrap(nn.Module):
def __init__(self, model) -> None:
super().__init__()
self.backbone = model
def forward(self, x, *args, **kwargs):
features = []
for layer in self.backbone.features:
x = layer(x)
features.append(x)
return features
def convnextv2_base(config, **kwargs):
model = ConvNeXtV2(
depths=[3, 3, 27, 3],
dims=[128, 256, 512, 1024],
output_idx=config.get("output_idx", [3, 6, 33, 36]),
use_checkpoint=config.get("use_checkpoint", False),
**kwargs,
)
url = "https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_base_22k_384_ema.pt"
state_dict = torch.hub.load_state_dict_from_url(
url, map_location="cpu", progress=False
)["model"]
info = model.load_state_dict(state_dict, strict=False)
print(info)
return model
def convnextv2_large(config, **kwargs):
model = ConvNeXtV2(
depths=[3, 3, 27, 3],
dims=[192, 384, 768, 1536],
output_idx=config.get("output_idx", [3, 6, 33, 36]),
use_checkpoint=config.get("use_checkpoint", False),
**kwargs,
)
url = "https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_large_22k_384_ema.pt"
state_dict = torch.hub.load_state_dict_from_url(
url, map_location="cpu", progress=False
)["model"]
info = model.load_state_dict(state_dict, strict=False)
print(info)
return model
def convnextv2_large_mae(config, **kwargs):
model = ConvNeXtV2(
depths=[3, 3, 27, 3],
dims=[192, 384, 768, 1536],
output_idx=config.get("output_idx", [3, 6, 33, 36]),
use_checkpoint=config.get("use_checkpoint", False),
**kwargs,
)
url = "https://dl.fbaipublicfiles.com/convnext/convnextv2/pt_only/convnextv2_large_1k_224_fcmae.pt"
state_dict = torch.hub.load_state_dict_from_url(
url, map_location="cpu", progress=False
)["model"]
info = model.load_state_dict(state_dict, strict=False)
print(info)
return model
def convnextv2_huge(config, **kwargs):
model = ConvNeXtV2(
depths=[3, 3, 27, 3],
dims=[352, 704, 1408, 2816],
output_idx=config.get("output_idx", [3, 6, 33, 36]),
use_checkpoint=config.get("use_checkpoint", False),
**kwargs,
)
url = "https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_huge_22k_512_ema.pt"
state_dict = torch.hub.load_state_dict_from_url(
url, map_location="cpu", progress=False
)["model"]
info = model.load_state_dict(state_dict, strict=False)
print(info)
return model
def convnextv2_huge_mae(config, **kwargs):
model = ConvNeXtV2(
depths=[3, 3, 27, 3],
dims=[352, 704, 1408, 2816],
output_idx=config.get("output_idx", [3, 6, 33, 36]),
use_checkpoint=config.get("use_checkpoint", False),
**kwargs,
)
url = "https://dl.fbaipublicfiles.com/convnext/convnextv2/pt_only/convnextv2_huge_1k_224_fcmae.pt"
state_dict = torch.hub.load_state_dict_from_url(
url, map_location="cpu", progress=False
)["model"]
info = model.load_state_dict(state_dict, strict=False)
print(info)
return model
def convnext_large_pt(config, **kwargs):
model = ConvNeXt(
depths=[3, 3, 27, 3],
dims=[192, 384, 768, 1536],
output_idx=config.get("output_idx", [3, 6, 33, 36]),
use_checkpoint=config.get("use_checkpoint", False),
**kwargs,
)
from huggingface_hub import hf_hub_download
from huggingface_hub.utils import disable_progress_bars
from unidepth.models.backbones.convnext import HF_URL, checkpoint_filter_fn
disable_progress_bars()
repo_id, filename = HF_URL["convnext_large_pt"]
state_dict = torch.load(hf_hub_download(repo_id=repo_id, filename=filename))
state_dict = checkpoint_filter_fn(state_dict, model)
info = model.load_state_dict(state_dict, strict=False)
print(info)
return model
def convnext_large(config, **kwargs):
model = ConvNeXt(
depths=[3, 3, 27, 3],
dims=[192, 384, 768, 1536],
output_idx=config.get("output_idx", [3, 6, 33, 36]),
use_checkpoint=config.get("use_checkpoint", False),
drop_path_rate=config.get("drop_path", 0.0),
**kwargs,
)
return model
def dinov2_vits14(config, pretrained: bool = True, **kwargs):
"""
DINOv2 ViT-S/14 model (optionally) pretrained on the LVD-142M dataset.
"""
vit = _make_dinov2_model(
arch_name="vit_small",
pretrained=config["pretrained"],
output_idx=config.get("output_idx", [3, 6, 9, 12]),
checkpoint=config.get("use_checkpoint", False),
drop_path_rate=config.get("drop_path", 0.0),
num_register_tokens=config.get("num_register_tokens", 0),
use_norm=config.get("use_norm", False),
export=config.get("export", False),
interpolate_offset=config.get("interpolate_offset", 0.0),
**kwargs,
)
return vit
def dinov2_vitb14(config, pretrained: bool = True, **kwargs):
"""
DINOv2 ViT-B/14 model (optionally) pretrained on the LVD-142M dataset.
"""
vit = _make_dinov2_model(
arch_name="vit_base",
pretrained=config["pretrained"],
output_idx=config.get("output_idx", [3, 6, 9, 12]),
checkpoint=config.get("use_checkpoint", False),
drop_path_rate=config.get("drop_path", 0.0),
num_register_tokens=config.get("num_register_tokens", 0),
use_norm=config.get("use_norm", False),
export=config.get("export", False),
interpolate_offset=config.get("interpolate_offset", 0.0),
**kwargs,
)
return vit
def dinov2_vitl14(config, pretrained: str = "", **kwargs):
"""
DINOv2 ViT-L/14 model (optionally) pretrained on the LVD-142M dataset.
"""
vit = _make_dinov2_model(
arch_name="vit_large",
pretrained=config["pretrained"],
output_idx=config.get("output_idx", [5, 12, 18, 24]),
checkpoint=config.get("use_checkpoint", False),
drop_path_rate=config.get("drop_path", 0.0),
num_register_tokens=config.get("num_register_tokens", 0),
use_norm=config.get("use_norm", False),
export=config.get("export", False),
interpolate_offset=config.get("interpolate_offset", 0.0),
**kwargs,
)
return vit
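The factory functions in this file all read a plain config dict. The sketch below lists the keys they consume, taken from the `.get(...)` calls and the `config["pretrained"]` access above; the concrete values are placeholders, with `pretrained=None` meaning random initialisation rather than downloading the official weights.

```python
# Placeholder config for illustration; keys mirror the lookups above.
config = {
    "pretrained": None,             # None: random init, "": official URL, else: local checkpoint path
    "output_idx": [5, 12, 18, 24],  # transformer blocks whose features are used downstream (ViT-L/14 has 24)
    "use_checkpoint": False,        # gradient checkpointing
    "drop_path": 0.0,
    "num_register_tokens": 0,
    "use_norm": False,
    "export": False,
    "interpolate_offset": 0.0,
}
backbone = dinov2_vitl14(config)
```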
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/models/unidepthv1/__init__.py
================================================
from .unidepthv1 import UniDepthV1
__all__ = [
"UniDepthV1",
]
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/models/unidepthv1/decoder.py
================================================
"""
Author: Luigi Piccinelli
Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/)
"""
from typing import List, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from timm.models.layers import trunc_normal_
from unidepth.layers import (MLP, AttentionBlock, ConvUpsample, NystromBlock,
PositionEmbeddingSine)
from unidepth.utils.geometric import flat_interpolate, generate_rays
from unidepth.utils.misc import max_stack
from unidepth.utils.sht import rsh_cart_8
class ListAdapter(nn.Module):
def __init__(self, input_dims: List[int], hidden_dim: int):
super().__init__()
self.input_adapters = nn.ModuleList([])
self.num_chunks = len(input_dims)
for input_dim in input_dims:
self.input_adapters.append(
nn.Sequential(
nn.LayerNorm(input_dim), nn.Linear(input_dim, hidden_dim), nn.GELU()
)
)
def forward(self, x: torch.Tensor, splits: torch.Tensor) -> torch.Tensor:
xs = torch.split(x, splits.int().tolist(), dim=-1)
xs = [adapter(x) for x, adapter in zip(xs, self.input_adapters)]
return torch.cat(xs, dim=-1)
class CameraHead(nn.Module):
def __init__(
self,
input_dim: int,
hidden_dim: int,
num_heads: int = 8,
expansion: int = 4,
depth: int = 4,
dropout: float = 0.0,
layer_scale: float = 1.0,
**kwargs,
):
super().__init__()
self.aggregate = AttentionBlock(
hidden_dim,
num_heads=1,
expansion=expansion,
dropout=dropout,
layer_scale=layer_scale,
)
self.latents_pos = nn.Parameter(
torch.randn(1, 4, hidden_dim), requires_grad=True
)
self.layers = nn.ModuleList([])
self.in_features = MLP(hidden_dim, expansion=2, dropout=dropout)
for _ in range(depth):
blk = AttentionBlock(
hidden_dim,
num_heads=num_heads,
expansion=expansion,
dropout=dropout,
layer_scale=layer_scale,
)
self.layers.append(blk)
self.out = MLP(hidden_dim, expansion=2, dropout=0.0, output_dim=1)
self.cls_project = nn.Sequential(
nn.LayerNorm(input_dim),
nn.Linear(input_dim, hidden_dim // 2),
nn.GELU(),
nn.Linear(hidden_dim // 2, hidden_dim),
)
def forward(self, features, cls_tokens, pos_embed) -> torch.Tensor:
features = features.unbind(dim=-1)
cls_tokens = self.cls_project(cls_tokens)
features_stack = torch.cat(features, dim=1)
features_stack = features_stack + pos_embed
latents_pos = self.latents_pos.expand(cls_tokens.shape[0], -1, -1)
features_stack = self.in_features(features_stack)
features = torch.cat((features_stack, cls_tokens), dim=1)
cls_tokens = self.aggregate(cls_tokens, context=features, pos_embed=latents_pos)
for i, layer in enumerate(self.layers):
cls_tokens = layer(cls_tokens, pos_embed=latents_pos)
# project
x = self.out(cls_tokens).squeeze(-1)
camera_intrinsics = torch.zeros(
x.shape[0], 3, 3, device=x.device, requires_grad=False
)
camera_intrinsics[:, 0, 0] = x[:, 0].exp()
camera_intrinsics[:, 1, 1] = x[:, 1].exp()
camera_intrinsics[:, 0, 2] = x[:, 2].sigmoid()
camera_intrinsics[:, 1, 2] = x[:, 3].sigmoid()
camera_intrinsics[:, 2, 2] = 1.0
return camera_intrinsics
def set_shapes(self, shapes: Tuple[int, int]):
self.shapes = shapes
class DepthHead(nn.Module):
def __init__(
self,
hidden_dim: int,
num_heads: int = 8,
expansion: int = 4,
depths: int | list[int] = 4,
camera_dim: int = 256,
num_resolutions: int = 4,
dropout: float = 0.0,
layer_scale: float = 1.0,
**kwargs,
) -> None:
super().__init__()
if isinstance(depths, int):
depths = [depths] * 3
assert len(depths) == 3
self.project_rays16 = MLP(
camera_dim, expansion=expansion, dropout=dropout, output_dim=hidden_dim
)
self.project_rays8 = MLP(
camera_dim, expansion=expansion, dropout=dropout, output_dim=hidden_dim // 2
)
self.project_rays4 = MLP(
camera_dim, expansion=expansion, dropout=dropout, output_dim=hidden_dim // 4
)
self.to_latents = MLP(hidden_dim, expansion=2, dropout=dropout)
self.features_channel_cat = nn.Linear(hidden_dim * num_resolutions, hidden_dim)
self.up8 = ConvUpsample(
hidden_dim, expansion=expansion, layer_scale=layer_scale
)
self.up4 = ConvUpsample(
hidden_dim // 2, expansion=expansion, layer_scale=layer_scale
)
self.up2 = ConvUpsample(
hidden_dim // 4, expansion=expansion, layer_scale=layer_scale
)
self.layers_16 = nn.ModuleList([])
self.layers_8 = nn.ModuleList([])
self.layers_4 = nn.ModuleList([])
self.aggregate_16 = AttentionBlock(
hidden_dim,
num_heads=1,
expansion=expansion,
dropout=dropout,
layer_scale=layer_scale,
context_dim=hidden_dim,
)
self.prompt_camera = AttentionBlock(
hidden_dim,
num_heads=1,
expansion=expansion,
dropout=dropout,
layer_scale=layer_scale,
context_dim=hidden_dim,
)
for i, (blk_lst, depth) in enumerate(
zip([self.layers_16, self.layers_8, self.layers_4], depths)
):
attn_cls = AttentionBlock if i == 0 else NystromBlock
for _ in range(depth):
blk_lst.append(
attn_cls(
hidden_dim // (2**i),
num_heads=num_heads // (2**i),
expansion=expansion,
dropout=dropout,
layer_scale=layer_scale,
)
)
self.out2 = nn.Conv2d(hidden_dim // 8, 1, 3, padding=1)
self.out4 = nn.Conv2d(hidden_dim // 4, 1, 3, padding=1)
self.out8 = nn.Conv2d(hidden_dim // 2, 1, 3, padding=1)
def set_original_shapes(self, shapes: Tuple[int, int]):
self.original_shapes = shapes
def set_shapes(self, shapes: Tuple[int, int]):
self.shapes = shapes
def forward(
self, features: torch.Tensor, rays_hr: torch.Tensor, pos_embed, level_embed
) -> torch.Tensor:
features = features.unbind(dim=-1)
shapes = self.shapes
rays_hr = rays_hr.detach()
# camera_embedding
rays_embedding_16 = F.normalize(
flat_interpolate(rays_hr, old=self.original_shapes, new=shapes), dim=-1
)
rays_embedding_8 = F.normalize(
flat_interpolate(
rays_hr, old=self.original_shapes, new=[x * 2 for x in shapes]
),
dim=-1,
)
rays_embedding_4 = F.normalize(
flat_interpolate(
rays_hr, old=self.original_shapes, new=[x * 4 for x in shapes]
),
dim=-1,
)
rays_embedding_16 = self.project_rays16(rsh_cart_8(rays_embedding_16))
rays_embedding_8 = self.project_rays8(rsh_cart_8(rays_embedding_8))
rays_embedding_4 = self.project_rays4(rsh_cart_8(rays_embedding_4))
features_tokens = torch.cat(features, dim=1)
features_tokens_pos = pos_embed + level_embed
# Generate latents with init as pooled features
features_channels = torch.cat(features, dim=-1)
features_16 = self.features_channel_cat(features_channels)
latents_16 = self.to_latents(
flat_interpolate(features_16, old=self.shapes, new=shapes, antialias=False)
)
# Aggregate features: F -> D
latents_16 = self.aggregate_16(
latents_16, context=features_tokens, pos_embed_context=features_tokens_pos
)
        # Aggregate camera: D -> D|E
latents_16 = self.prompt_camera(latents_16, context=rays_embedding_16)
# Block 16 - Out 8
for layer in self.layers_16:
latents_16 = layer(latents_16, pos_embed=rays_embedding_16)
latents_8 = self.up8(
rearrange(
latents_16 + rays_embedding_16,
"b (h w) c -> b c h w",
h=shapes[0],
w=shapes[1],
).contiguous()
)
out8 = self.out8(
rearrange(
latents_8, "b (h w) c -> b c h w", h=shapes[0] * 2, w=shapes[1] * 2
)
)
# Block 8 - Out 4
for layer in self.layers_8:
latents_8 = layer(latents_8, pos_embed=rays_embedding_8)
latents_4 = self.up4(
rearrange(
latents_8 + rays_embedding_8,
"b (h w) c -> b c h w",
h=shapes[0] * 2,
w=shapes[1] * 2,
).contiguous()
)
out4 = self.out4(
rearrange(
latents_4, "b (h w) c -> b c h w", h=shapes[0] * 4, w=shapes[1] * 4
)
)
# Block 4 - Out 2
for layer in self.layers_4:
latents_4 = layer(latents_4, pos_embed=rays_embedding_4)
latents_2 = self.up2(
rearrange(
latents_4 + rays_embedding_4,
"b (h w) c -> b c h w",
h=shapes[0] * 4,
w=shapes[1] * 4,
).contiguous()
)
out2 = self.out2(
rearrange(
latents_2, "b (h w) c -> b c h w", h=shapes[0] * 8, w=shapes[1] * 8
)
)
# Depth features
proj_latents_16 = rearrange(
latents_16, "b (h w) c -> b c h w", h=shapes[0], w=shapes[1]
).contiguous()
        # multi-scale outputs: clamp log-depth and exponentiate to metric depth
out2 = out2.clamp(-10.0, 10.0).exp()
out4 = out4.clamp(-10.0, 10.0).exp()
out8 = out8.clamp(-10.0, 10.0).exp()
return out8, out4, out2, proj_latents_16
class Decoder(nn.Module):
def __init__(
self,
config,
*args,
**kwargs,
):
super().__init__()
self.build(config)
self.apply(self._init_weights)
self.test_fixed_camera = False
self.skip_camera = False
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=0.02)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Conv2d):
trunc_normal_(m.weight, std=0.02)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def get_adapted_features(self, features_flat, splits):
features_flat_cat = torch.cat(features_flat, dim=-1)
features_projected = self.input_adapter(
features_flat_cat, splits
) # list [b hw c] shapes
features = torch.chunk(features_projected, len(splits), dim=-1)
return features
def run_camera(self, cls_tokens, features, pos_embed, original_shapes, rays):
# get cls tokens projections
cls_tokens_splits = torch.tensor(
[x.shape[-1] for x in cls_tokens],
device=features.device,
requires_grad=False,
dtype=features.dtype,
)
cls_tokens = torch.cat(cls_tokens, dim=-1)
cls_tokens = self.token_adapter(cls_tokens, cls_tokens_splits)
cls_tokens = torch.cat(
torch.chunk(cls_tokens, len(cls_tokens_splits), dim=-1), dim=1
)
# camera layer
intrinsics = self.camera_layer(
features=features, cls_tokens=cls_tokens, pos_embed=pos_embed
)
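        # denormalize the predicted intrinsics: focal lengths are rescaled by
        # max(H, W) / 2, the principal point by the image width / height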
intrinsics[:, 0, 0] = max(original_shapes) / 2 * intrinsics[:, 0, 0]
intrinsics[:, 1, 1] = max(original_shapes) / 2 * intrinsics[:, 1, 1]
intrinsics[:, 0, 2] = intrinsics[:, 0, 2] * original_shapes[1]
intrinsics[:, 1, 2] = intrinsics[:, 1, 2] * original_shapes[0]
if not self.test_fixed_camera:
rays, _ = generate_rays(intrinsics, original_shapes, noisy=False)
return intrinsics, rays
def forward(self, inputs, image_metas) -> torch.Tensor:
B, _, H, W = inputs["image"].shape
device = inputs["image"].device
        # make tensors contiguous so stride-dependent ops downstream are safe
original_encoder_outputs = [x.contiguous() for x in inputs["encoder_outputs"]]
cls_tokens = [x.contiguous() for x in inputs["cls_tokens"]]
# collect features and tokens
original_encoder_outputs = [
max_stack(original_encoder_outputs[i:j])
for i, j in self.slices_encoder_range
]
# detach tokens for camera
cls_tokens = [
cls_tokens[-i - 1].detach() for i in range(len(self.slices_encoder_range))
]
        # get features in (b, n, d) format
        # level shapes: the spatial shape per level, e.g. [[128, 128], [64, 64], ...] for
        # Swin-like encoders; a plain ViT gives a single shape [[32, 32]], repeated per resolution
resolutions = [
tuple(sorted([x.shape[1], x.shape[2]])) for x in original_encoder_outputs
]
level_shapes = sorted(list(set(resolutions)))[::-1]
if len(level_shapes) == 1:
level_shapes = level_shapes * self.num_resolutions
input_shapes = [
level_shapes[i]
for i, (start, end) in enumerate(self.slices_encoder)
for _ in range(end - start)
]
common_shape = level_shapes[-2]
        # repeat each level's shape once per encoder layer belonging to that level:
features_flat = [
flat_interpolate(
rearrange(x, "b h w c -> b (h w) c"), old=input_shape, new=common_shape
)
for x, input_shape in zip(original_encoder_outputs, input_shapes)
]
features_splits = torch.tensor(
[x.shape[-1] for x in features_flat],
device=device,
requires_grad=False,
dtype=torch.float32,
)
        # input adapter: project each level's features to the shared hidden dim
features = self.get_adapted_features(features_flat, features_splits)
features = torch.stack(features, dim=-1)
# positional embeddings, spatial and level
level_embed = torch.cat(
[
self.level_embed_layer(self.level_embeds)[i : i + 1]
.unsqueeze(0)
.repeat(B, common_shape[0] * common_shape[1], 1)
for i in range(self.num_resolutions)
],
dim=1,
)
pos_embed = self.pos_embed(
torch.zeros(
B,
1,
common_shape[0],
common_shape[1],
device=device,
requires_grad=False,
)
)
pos_embed = rearrange(pos_embed, "b c h w -> b (h w) c").repeat(
1, self.num_resolutions, 1
)
self.camera_layer.set_shapes(common_shape)
intrinsics, rays = (
self.run_camera(
cls_tokens,
features=features,
pos_embed=pos_embed + level_embed,
original_shapes=(H, W),
rays=inputs.get("rays", None),
)
if not self.skip_camera
else (inputs["K"], inputs["rays"])
)
# run bulk of the model
self.depth_layer.set_shapes(common_shape)
self.depth_layer.set_original_shapes((H, W))
out8, out4, out2, depth_features = self.depth_layer(
features=features,
rays_hr=rays,
pos_embed=pos_embed,
level_embed=level_embed,
)
return intrinsics, [out8, out4, out2], depth_features
@torch.jit.ignore
def no_weight_decay_keywords(self):
return {"latents_pos", "level_embeds"}
def build(self, config):
depth = config["model"]["pixel_decoder"]["depths"]
input_dims = config["model"]["pixel_encoder"]["embed_dims"]
hidden_dim = config["model"]["pixel_decoder"]["hidden_dim"]
num_heads = config["model"]["num_heads"]
expansion = config["model"]["expansion"]
dropout = config["model"]["pixel_decoder"]["dropout"]
depths_encoder = config["model"]["pixel_encoder"]["depths"]
layer_scale = 1.0
self.depth = depth
self.dim = hidden_dim
self.downsample = 4
self.num_heads = num_heads
self.num_resolutions = len(depths_encoder)
self.depths_encoder = depths_encoder
self.slices_encoder_single = list(
zip([d - 1 for d in self.depths_encoder], self.depths_encoder)
)
self.slices_encoder_range = list(
zip([0, *self.depths_encoder[:-1]], self.depths_encoder)
)
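        # slices_encoder_single keeps only the last layer of each block, e.g. depths
        # [3, 6, 9, 12] -> [(2, 3), (5, 6), (8, 9), (11, 12)]; slices_encoder_range spans
        # whole blocks -> [(0, 3), (3, 6), (6, 9), (9, 12)]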
cls_token_input_dims = [input_dims[-i - 1] for i in range(len(depths_encoder))]
input_dims = [input_dims[d - 1] for d in depths_encoder]
self.slices_encoder = self.slices_encoder_single
# adapt from encoder features, just project
self.input_adapter = ListAdapter(input_dims, hidden_dim)
self.token_adapter = ListAdapter(cls_token_input_dims, hidden_dim)
# camera layer
self.camera_layer = CameraHead(
input_dim=hidden_dim,
hidden_dim=hidden_dim,
num_heads=num_heads,
expansion=expansion,
depth=2,
dropout=dropout,
layer_scale=layer_scale,
)
self.depth_layer = DepthHead(
hidden_dim=hidden_dim,
num_heads=num_heads,
expansion=expansion,
depths=depth,
dropout=dropout,
camera_dim=81,
num_resolutions=self.num_resolutions,
layer_scale=layer_scale,
)
# transformer part
self.pos_embed = PositionEmbeddingSine(hidden_dim // 2, normalize=True)
self.level_embeds = nn.Parameter(
torch.randn(len(input_dims), hidden_dim), requires_grad=True
)
self.level_embed_layer = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim),
nn.GELU(),
nn.Linear(hidden_dim, hidden_dim),
nn.LayerNorm(hidden_dim),
)
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/models/unidepthv1/unidepthv1.py
================================================
"""
Author: Luigi Piccinelli
Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/)
"""
import importlib
from copy import deepcopy
from math import ceil
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from einops import rearrange
from huggingface_hub import PyTorchModelHubMixin
from unidepth.models.unidepthv1.decoder import Decoder
from unidepth.utils.constants import (IMAGENET_DATASET_MEAN,
IMAGENET_DATASET_STD)
from unidepth.utils.distributed import is_main_process
from unidepth.utils.geometric import (generate_rays,
spherical_zbuffer_to_euclidean)
from unidepth.utils.misc import (get_params, match_gt, match_intrinsics,
profile_method)
VERBOSE = False
# inference helpers
def _paddings(image_shape, network_shape):
cur_h, cur_w = image_shape
h, w = network_shape
pad_top, pad_bottom = (h - cur_h) // 2, h - cur_h - (h - cur_h) // 2
pad_left, pad_right = (w - cur_w) // 2, w - cur_w - (w - cur_w) // 2
return pad_left, pad_right, pad_top, pad_bottom
def _shapes(image_shape, network_shape):
h, w = image_shape
input_ratio = w / h
output_ratio = network_shape[1] / network_shape[0]
if output_ratio > input_ratio:
ratio = network_shape[0] / h
elif output_ratio <= input_ratio:
ratio = network_shape[1] / w
return (ceil(h * ratio - 0.5), ceil(w * ratio - 0.5)), ratio
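# _shapes rescales the input to fit inside the network resolution while preserving aspect
# ratio; _paddings then centers it with symmetric padding. _preprocess applies both and
# rescales the intrinsics accordingly; _postprocess undoes them after inference.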
def _preprocess(rgbs, intrinsics, shapes, pads, ratio, output_shapes):
(pad_left, pad_right, pad_top, pad_bottom) = pads
rgbs = F.interpolate(
rgbs, size=shapes, mode="bilinear", align_corners=False, antialias=True
)
rgbs = F.pad(rgbs, (pad_left, pad_right, pad_top, pad_bottom), mode="constant")
if intrinsics is not None:
intrinsics = intrinsics.clone()
intrinsics[:, 0, 0] = intrinsics[:, 0, 0] * ratio
intrinsics[:, 1, 1] = intrinsics[:, 1, 1] * ratio
intrinsics[:, 0, 2] = intrinsics[:, 0, 2] * ratio + pad_left
intrinsics[:, 1, 2] = intrinsics[:, 1, 2] * ratio + pad_top
return rgbs, intrinsics
return rgbs, None
def _postprocess(predictions, intrinsics, shapes, pads, ratio, original_shapes):
(pad_left, pad_right, pad_top, pad_bottom) = pads
    # average the multi-scale predictions, trim the paddings, and upsample to the input resolution
predictions = sum(
[
F.interpolate(
x.clone(),
size=shapes,
mode="bilinear",
align_corners=False,
antialias=True,
)
for x in predictions
]
) / len(predictions)
predictions = predictions[
..., pad_top : shapes[0] - pad_bottom, pad_left : shapes[1] - pad_right
]
predictions = F.interpolate(
predictions,
size=original_shapes,
mode="bilinear",
align_corners=False,
antialias=True,
)
intrinsics[:, 0, 0] = intrinsics[:, 0, 0] / ratio
intrinsics[:, 1, 1] = intrinsics[:, 1, 1] / ratio
intrinsics[:, 0, 2] = (intrinsics[:, 0, 2] - pad_left) / ratio
intrinsics[:, 1, 2] = (intrinsics[:, 1, 2] - pad_top) / ratio
return predictions, intrinsics
class UniDepthV1(
nn.Module,
PyTorchModelHubMixin,
library_name="UniDepth",
repo_url="https://github.com/lpiccinelli-eth/UniDepth",
tags=["monocular-metric-depth-estimation"],
):
def __init__(
self,
config,
eps: float = 1e-6,
**kwargs,
):
super().__init__()
self.build(config)
self.build_losses(config)
self.eps = eps
@profile_method(verbose=VERBOSE)
def forward_train(self, inputs, image_metas):
inputs, outputs = self.encode_decode(inputs, image_metas)
losses = self.compute_losses(outputs, inputs, image_metas)
return outputs, losses
@profile_method(verbose=VERBOSE)
def forward_test(self, inputs, image_metas):
inputs, outputs = self.encode_decode(inputs, image_metas)
depth_gt = inputs["depth"]
test_outputs = {}
test_outputs["depth"] = match_gt(
outputs["depth"], depth_gt, padding1=inputs["paddings"], padding2=None
)
test_outputs["points"] = match_gt(
outputs["points"], depth_gt, padding1=inputs["paddings"], padding2=None
)
test_outputs["confidence"] = match_gt(
outputs["confidence"], depth_gt, padding1=inputs["paddings"], padding2=None
)
test_outputs["rays"] = match_gt(
outputs["rays"], depth_gt, padding1=inputs["paddings"], padding2=None
)
test_outputs["rays"] = outputs["rays"] / torch.norm(
outputs["rays"], dim=1, keepdim=True
).clip(min=1e-5)
test_outputs["intrinsics"] = match_intrinsics(
outputs["intrinsics"],
inputs["image"],
depth_gt,
padding1=inputs["paddings"],
padding2=None,
)
return test_outputs
def forward(self, inputs, image_metas):
if self.training:
return self.forward_train(inputs, image_metas)
else:
return self.forward_test(inputs, image_metas)
def encode_decode(self, inputs, image_metas):
rgbs = inputs["image"]
B, _, H, W = rgbs.shape
cameras = inputs["camera"]
        # evaluation shortcut: only build padding tensors when the image metas provide them
if len(image_metas) and "paddings" in image_metas[0]:
inputs["paddings"] = torch.tensor(
[image_meta["paddings"] for image_meta in image_metas],
device=self.device,
)[
..., [0, 2, 1, 3]
            ]  # reorder to (left, right, top, bottom)
inputs["depth_paddings"] = torch.tensor(
[image_meta["depth_paddings"] for image_meta in image_metas],
device=self.device,
)
if (
self.training
        ):  # at inference there are no image paddings on top of the depth ones (no crop is applied to the GT in ContextCrop)
inputs["depth_paddings"] = inputs["depth_paddings"] + inputs["paddings"]
# Get camera rays for supervision, all in unit sphere
if inputs.get("camera", None) is not None:
inputs["rays"] = rearrange(
inputs["camera"].get_rays(shapes=(B, H, W)), "b c h w -> b (h w) c"
)
# Encode
encoder_outputs, cls_tokens = self.pixel_encoder(rgbs)
if "dino" in self.pixel_encoder.__class__.__name__.lower():
encoder_outputs = [
(x + y.unsqueeze(1)).contiguous()
for x, y in zip(encoder_outputs, cls_tokens)
]
inputs["encoder_outputs"] = encoder_outputs
inputs["cls_tokens"] = cls_tokens
# Decode
pred_intrinsics, predictions, depth_features = self.pixel_decoder(inputs, {})
predictions = sum(
[
F.interpolate(
x.clone(),
size=(H, W),
mode="bilinear",
align_corners=False,
antialias=True,
)
for x in predictions
]
) / len(predictions)
# Final 3D points backprojection
pred_rays, pred_angles = generate_rays(pred_intrinsics, (H, W), noisy=False)
        # inputs["angles"] could be used here instead, when available
pred_angles = rearrange(pred_angles, "b (h w) c -> b c h w", h=H, w=W)
points_3d = torch.cat((pred_angles, predictions), dim=1)
points_3d = spherical_zbuffer_to_euclidean(
points_3d.permute(0, 2, 3, 1)
).permute(0, 3, 1, 2)
# Output data, use for loss computation
outputs = {
"angles": pred_angles,
"rays": pred_rays,
"intrinsics": pred_intrinsics,
"points": points_3d,
"depth": predictions[:, -1:],
"cond_features": depth_features,
}
self.pixel_decoder.test_fixed_camera = False
outputs["rays"] = rearrange(outputs["rays"], "b (h w) c -> b c h w", h=H, w=W)
if "rays" in inputs:
inputs["rays"] = rearrange(inputs["rays"], "b (h w) c -> b c h w", h=H, w=W)
return inputs, outputs
def compute_losses(self, outputs, inputs, image_metas):
B, _, H, W = inputs["image"].shape
losses = {"opt": {}, "stat": {}}
if (
not self.training
        ):  # losses are only computed during training; this avoids size mismatches between predictions and GT at eval time
return losses
losses_to_be_computed = list(self.losses.keys())
# depth loss
si = torch.tensor(
[x.get("si", False) for x in image_metas], device=self.device
).reshape(B)
loss = self.losses["depth"]
depth_losses = loss(
outputs["depth"],
target=inputs["depth"],
mask=inputs["depth_mask"].clone(),
si=si,
)
losses["opt"][loss.name] = loss.weight * depth_losses.mean()
losses_to_be_computed.remove("depth")
        # camera loss: applied to rays for simplicity; the original training used
        # angles, but we observed no difference (see supplementary)
loss = self.losses["camera"]
camera_losses = loss(outputs["rays"], target=inputs["rays"])
losses["opt"][loss.name] = loss.weight * camera_losses.mean()
losses_to_be_computed.remove("camera")
# invariance loss
flips = torch.tensor(
[x.get("flip", False) for x in image_metas], device=self.device
).reshape(B)
loss = self.losses["invariance"]
invariance_losses = loss(
outputs["cond_features"],
intrinsics=inputs["camera"].K,
mask=inputs["depth_mask"],
flips=flips,
)
losses["opt"][loss.name] = loss.weight * invariance_losses.mean()
losses_to_be_computed.remove("invariance")
        # sanity check: every configured loss must have been consumed above
assert (
not losses_to_be_computed
), f"Losses {losses_to_be_computed} not computed, revise `compute_loss` method"
return losses
@torch.no_grad()
def infer(self, rgbs: torch.Tensor, intrinsics=None, skip_camera=False):
if rgbs.ndim == 3:
rgbs = rgbs.unsqueeze(0)
if intrinsics is not None and intrinsics.ndim == 2:
intrinsics = intrinsics.unsqueeze(0)
B, _, H, W = rgbs.shape
rgbs = rgbs.to(self.device)
if intrinsics is not None:
intrinsics = intrinsics.to(self.device)
        # process the image and intrinsics (if any) to match the network input (can be slow)
if rgbs.max() > 5 or rgbs.dtype == torch.uint8:
rgbs = rgbs.to(torch.float32).div(255)
if rgbs.min() >= 0.0 and rgbs.max() <= 1.0:
rgbs = TF.normalize(
rgbs,
mean=IMAGENET_DATASET_MEAN,
std=IMAGENET_DATASET_STD,
)
(h, w), ratio = _shapes((H, W), self.image_shape)
pad_left, pad_right, pad_top, pad_bottom = _paddings((h, w), self.image_shape)
rgbs, gt_intrinsics = _preprocess(
rgbs,
intrinsics,
(h, w),
(pad_left, pad_right, pad_top, pad_bottom),
ratio,
self.image_shape,
)
# run encoder
encoder_outputs, cls_tokens = self.pixel_encoder(rgbs)
if "dino" in self.pixel_encoder.__class__.__name__.lower():
encoder_outputs = [
(x + y.unsqueeze(1)).contiguous()
for x, y in zip(encoder_outputs, cls_tokens)
]
# get data for decoder and adapt to given camera
inputs = {}
inputs["encoder_outputs"] = encoder_outputs
inputs["cls_tokens"] = cls_tokens
inputs["image"] = rgbs
if gt_intrinsics is not None:
rays, angles = generate_rays(
gt_intrinsics, self.image_shape, noisy=self.training
)
inputs["rays"] = rays
inputs["angles"] = angles
inputs["K"] = gt_intrinsics
self.pixel_decoder.test_fixed_camera = True
self.pixel_decoder.skip_camera = skip_camera
# decode all
pred_intrinsics, predictions, _ = self.pixel_decoder(inputs, {})
# undo the reshaping and get original image size (slow)
predictions, pred_intrinsics = _postprocess(
predictions,
pred_intrinsics,
self.image_shape,
(pad_left, pad_right, pad_top, pad_bottom),
ratio,
(H, W),
)
# final 3D points backprojection
intrinsics = gt_intrinsics if gt_intrinsics is not None else pred_intrinsics
angles = generate_rays(intrinsics, (H, W), noisy=False)[-1]
angles = rearrange(angles, "b (h w) c -> b c h w", h=H, w=W)
points_3d = torch.cat((angles, predictions), dim=1)
points_3d = spherical_zbuffer_to_euclidean(
points_3d.permute(0, 2, 3, 1)
).permute(0, 3, 1, 2)
# output data
outputs = {
"intrinsics": pred_intrinsics,
"points": points_3d,
"depth": predictions[:, -1:],
}
self.pixel_decoder.test_fixed_camera = False
self.pixel_decoder.skip_camera = False
return outputs
def load_pretrained(self, model_file):
device = (
torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
)
dict_model = torch.load(model_file, map_location=device)
if "model" in dict_model:
dict_model = dict_model["model"]
new_state_dict = deepcopy(
{k.replace("module.", ""): v for k, v in dict_model.items()}
)
info = self.load_state_dict(new_state_dict, strict=False)
if is_main_process():
print(
f"Loaded from {model_file} for {self.__class__.__name__} results in:",
info,
)
def get_params(self, config):
if hasattr(self.pixel_encoder, "get_params"):
encoder_p, encoder_lr = self.pixel_encoder.get_params(
config["model"]["pixel_encoder"]["lr"],
config["training"]["wd"],
config["training"]["ld"],
)
else:
encoder_p, encoder_lr = get_params(
self.pixel_encoder,
config["model"]["pixel_encoder"]["lr"],
config["training"]["wd"],
)
decoder_p, decoder_lr = get_params(
self.pixel_decoder, config["training"]["lr"], config["training"]["wd"]
)
return [*encoder_p, *decoder_p]
@property
def device(self):
return next(self.parameters()).device
def build(self, config):
mod = importlib.import_module("unidepth.models.encoder")
pixel_encoder_factory = getattr(mod, config["model"]["pixel_encoder"]["name"])
pixel_encoder_config = {
**config["training"],
**config["data"],
**config["model"]["pixel_encoder"],
"interpolate_offset": 0.1,
}
pixel_encoder = pixel_encoder_factory(pixel_encoder_config)
config["model"]["pixel_encoder"]["patch_size"] = (
14 if "dino" in config["model"]["pixel_encoder"]["name"] else 16
)
pixel_encoder_embed_dims = (
pixel_encoder.embed_dims
if hasattr(pixel_encoder, "embed_dims")
else [getattr(pixel_encoder, "embed_dim") * 2**i for i in range(4)]
)
config["model"]["pixel_encoder"]["embed_dim"] = getattr(
pixel_encoder, "embed_dim"
)
config["model"]["pixel_encoder"]["embed_dims"] = pixel_encoder_embed_dims
config["model"]["pixel_encoder"]["depths"] = pixel_encoder.depths
self.pixel_encoder = pixel_encoder
self.pixel_decoder = Decoder(config)
self.image_shape = config["data"]["image_shape"]
def build_losses(self, config):
self.losses = {}
for loss_name, loss_config in config["training"].get("losses", {}).items():
mod = importlib.import_module("unidepth.ops.losses")
loss_factory = getattr(mod, loss_config["name"])
self.losses[loss_name] = loss_factory.build(loss_config)
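For reference, the `UniDepthV1.infer` method above only requires an RGB tensor; intrinsics are optional and, when omitted, are predicted by the camera head. The snippet below is a minimal usage sketch, not part of the repository: the checkpoint id `lpiccinelli/unidepth-v1-vitl14` and the image path are assumptions for illustration.

```python
import numpy as np
import torch
from PIL import Image

from unidepth.models.unidepthv1.unidepthv1 import UniDepthV1

# Load a pretrained checkpoint through the PyTorchModelHubMixin interface.
# NOTE: the repo id is an assumption for illustration purposes.
model = UniDepthV1.from_pretrained("lpiccinelli/unidepth-v1-vitl14")
model = model.eval().to("cuda" if torch.cuda.is_available() else "cpu")

# infer() accepts (3, H, W) or (B, 3, H, W); uint8 input is rescaled and
# ImageNet-normalized internally, and the call is already wrapped in torch.no_grad().
rgb = torch.from_numpy(np.array(Image.open("example.jpg"))).permute(2, 0, 1)
outputs = model.infer(rgb)

depth = outputs["depth"]            # (B, 1, H, W) metric depth
points = outputs["points"]          # (B, 3, H, W) back-projected 3D points
intrinsics = outputs["intrinsics"]  # (B, 3, 3) predicted pinhole intrinsics
```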
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/models/unidepthv2/__init__.py
================================================
from .unidepthv2 import UniDepthV2
from .unidepthv2_old import UniDepthV2old
__all__ = [
"UniDepthV2",
"UniDepthV2old",
]
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/models/unidepthv2/decoder.py
================================================
"""
Author: Luigi Piccinelli
Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/)
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from timm.models.layers import trunc_normal_
from unidepth.layers import (MLP, AttentionBlock, AttentionLayer,
PositionEmbeddingSine, ResUpsampleBil)
from unidepth.utils.coordinate import coords_grid
from unidepth.utils.geometric import flat_interpolate
from unidepth.utils.positional_embedding import generate_fourier_features
def orthonormal_init(num_tokens, dims):
pe = torch.randn(num_tokens, dims)
# Apply Gram-Schmidt process to make the matrix orthonormal
for i in range(num_tokens):
for j in range(i):
# Subtract the projection of current row onto previous row
pe[i] -= torch.dot(pe[i], pe[j]) * pe[j]
# Normalize the current row
pe[i] = F.normalize(pe[i], p=2, dim=0)
return pe
class ListAdapter(nn.Module):
def __init__(self, input_dims: list[int], hidden_dim: int):
super().__init__()
self.input_adapters = nn.ModuleList([])
self.num_chunks = len(input_dims)
for input_dim in input_dims:
self.input_adapters.append(nn.Linear(input_dim, hidden_dim))
def forward(self, xs: torch.Tensor) -> list[torch.Tensor]:
outs = [self.input_adapters[i](x) for i, x in enumerate(xs)]
return outs
class CameraHead(nn.Module):
def __init__(
self,
hidden_dim: int,
num_heads: int = 8,
expansion: int = 4,
dropout: float = 0.0,
layer_scale: float = 1.0,
**kwargs,
):
super().__init__()
self.num_params = 4
self.aggregate1 = AttentionBlock(
hidden_dim,
num_heads=num_heads,
expansion=expansion,
dropout=dropout,
layer_scale=layer_scale,
use_bias=False,
)
self.aggregate2 = AttentionBlock(
hidden_dim,
num_heads=num_heads,
expansion=expansion,
dropout=dropout,
layer_scale=layer_scale,
use_bias=False,
)
self.latents_pos = nn.Parameter(
torch.randn(1, self.num_params, hidden_dim), requires_grad=True
)
self.project = MLP(
hidden_dim, expansion=1, dropout=dropout, output_dim=hidden_dim
)
self.out_pinhole = MLP(hidden_dim, expansion=1, dropout=dropout, output_dim=1)
def fill_intrinsics(self, x):
fx, fy, cx, cy = x.unbind(dim=-1)
fx = torch.exp(fx)
fy = torch.exp(fy)
cx = torch.sigmoid(cx)
cy = torch.sigmoid(cy)
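        # rescale the normalized predictions: focal lengths relative to 0.7x the image
        # diagonal, principal point relative to the image width (cx) and height (cy)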
diagonal = (self.shapes[0] ** 2 + self.shapes[1] ** 2) ** 0.5
correction_tensor = torch.tensor(
[0.7 * diagonal, 0.7 * diagonal, self.shapes[1], self.shapes[0]],
device=x.device,
dtype=x.dtype,
)
intrinsics = torch.stack([fx, fy, cx, cy], dim=1)
intrinsics = correction_tensor.unsqueeze(0) * intrinsics
return intrinsics
def forward(self, features, cls_tokens, pos_embed) -> torch.Tensor:
features = features.unbind(dim=-1)
tokens = self.project(cls_tokens)
latents_pos = self.latents_pos.expand(cls_tokens.shape[0], -1, -1)
tokens = self.aggregate1(tokens, pos_embed=latents_pos)
tokens = self.aggregate2(tokens, pos_embed=latents_pos)
x = self.out_pinhole(tokens.clone()).squeeze(-1)
camera_intrinsics = self.fill_intrinsics(x)
return camera_intrinsics
def set_shapes(self, shapes: tuple[int, int]):
self.shapes = shapes
class DepthHead(nn.Module):
def __init__(
self,
hidden_dim: int,
num_heads: int = 8,
expansion: int = 4,
depths: int | list[int] = 4,
camera_dim: int = 256,
dropout: float = 0.0,
kernel_size: int = 7,
layer_scale: float = 1.0,
out_dim: int = 1,
use_norm=False,
num_prompt_blocks=1,
**kwargs,
) -> None:
super().__init__()
self.camera_dim = camera_dim
self.out_dim = out_dim
self.hidden_dim = hidden_dim
self.ups = nn.ModuleList([])
self.depth_mlp = nn.ModuleList([])
self.process_features = nn.ModuleList([])
self.project_features = nn.ModuleList([])
self.prompt_camera = nn.ModuleList([])
mult = 2
self.to_latents = nn.Linear(hidden_dim, hidden_dim)
for _ in range(4):
self.prompt_camera.append(
AttentionLayer(
num_blocks=num_prompt_blocks,
dim=hidden_dim,
num_heads=num_heads,
expansion=expansion,
dropout=dropout,
layer_scale=-1.0,
context_dim=hidden_dim,
use_bias=False,
)
)
for i, depth in enumerate(depths):
current_dim = min(hidden_dim, mult * hidden_dim // int(2**i))
next_dim = mult * hidden_dim // int(2 ** (i + 1))
output_dim = max(next_dim, out_dim)
self.process_features.append(
nn.ConvTranspose2d(
hidden_dim,
current_dim,
kernel_size=max(1, 2 * i),
stride=max(1, 2 * i),
padding=0,
)
)
self.ups.append(
ResUpsampleBil(
current_dim,
output_dim=output_dim,
expansion=expansion,
layer_scale=layer_scale,
kernel_size=kernel_size,
num_layers=depth,
use_norm=use_norm,
)
)
depth_mlp = nn.Identity()
if i == len(depths) - 1:
depth_mlp = nn.Sequential(
nn.LayerNorm(next_dim), nn.Linear(next_dim, output_dim)
)
self.depth_mlp.append(depth_mlp)
self.confidence_mlp = nn.Sequential(
nn.LayerNorm(next_dim), nn.Linear(next_dim, output_dim)
)
self.to_depth_lr = nn.Conv2d(
output_dim,
output_dim // 2,
kernel_size=3,
padding=1,
padding_mode="reflect",
)
self.to_confidence_lr = nn.Conv2d(
output_dim,
output_dim // 2,
kernel_size=3,
padding=1,
padding_mode="reflect",
)
self.to_depth_hr = nn.Sequential(
nn.Conv2d(
output_dim // 2, 32, kernel_size=3, padding=1, padding_mode="reflect"
),
nn.LeakyReLU(),
nn.Conv2d(32, 1, kernel_size=1),
)
self.to_confidence_hr = nn.Sequential(
nn.Conv2d(
output_dim // 2, 32, kernel_size=3, padding=1, padding_mode="reflect"
),
nn.LeakyReLU(),
nn.Conv2d(32, 1, kernel_size=1),
)
def set_original_shapes(self, shapes: tuple[int, int]):
self.original_shapes = shapes
def set_shapes(self, shapes: tuple[int, int]):
self.shapes = shapes
def embed_rays(self, rays):
rays_embedding = flat_interpolate(
rays, old=self.original_shapes, new=self.shapes, antialias=True
)
rays_embedding = rays_embedding / torch.norm(
rays_embedding, dim=-1, keepdim=True
).clip(min=1e-4)
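        # convert unit rays to spherical (polar, azimuth) angles; x is clamped away from
        # zero with its sign preserved to keep atan2 numerically stable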
x, y, z = rays_embedding[..., 0], rays_embedding[..., 1], rays_embedding[..., 2]
polar = torch.acos(z)
x_clipped = x.abs().clip(min=1e-3) * (2 * (x >= 0).int() - 1)
azimuth = torch.atan2(y, x_clipped)
rays_embedding = torch.stack([polar, azimuth], dim=-1)
rays_embedding = generate_fourier_features(
rays_embedding,
dim=self.hidden_dim,
max_freq=max(self.shapes) // 2,
use_log=True,
cat_orig=False,
)
return rays_embedding
def condition(self, feat, rays_embeddings):
conditioned_features = [
prompter(rearrange(feature, "b h w c -> b (h w) c"), rays_embeddings)
for prompter, feature in zip(self.prompt_camera, feat)
]
return conditioned_features
def process(self, features_list, rays_embeddings):
conditioned_features = self.condition(features_list, rays_embeddings)
init_latents = self.to_latents(conditioned_features[0])
init_latents = rearrange(
init_latents, "b (h w) c -> b c h w", h=self.shapes[0], w=self.shapes[1]
).contiguous()
conditioned_features = [
rearrange(
x, "b (h w) c -> b c h w", h=self.shapes[0], w=self.shapes[1]
).contiguous()
for x in conditioned_features
]
latents = init_latents
out_features = []
for i, up in enumerate(self.ups):
latents = latents + self.process_features[i](conditioned_features[i + 1])
latents = up(latents)
out_features.append(latents)
return out_features, init_latents
def depth_proj(self, out_features):
h_out, w_out = out_features[-1].shape[-2:]
# aggregate output and project to depth
for i, (layer, features) in enumerate(zip(self.depth_mlp, out_features)):
out_depth_features = layer(features.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
out_depth_features = F.interpolate(
out_depth_features,
size=(h_out, w_out),
mode="bilinear",
align_corners=True,
)
if i == len(self.depth_mlp) - 1:
logdepth = out_depth_features
logdepth = self.to_depth_lr(logdepth)
logdepth = F.interpolate(
logdepth, size=self.original_shapes, mode="bilinear", align_corners=True
)
logdepth = self.to_depth_hr(logdepth)
return logdepth
def confidence_proj(self, out_features):
highres_features = out_features[-1].permute(0, 2, 3, 1)
confidence = self.confidence_mlp(highres_features).permute(0, 3, 1, 2)
confidence = self.to_confidence_lr(confidence)
confidence = F.interpolate(
confidence, size=self.original_shapes, mode="bilinear", align_corners=True
)
confidence = self.to_confidence_hr(confidence)
return confidence
def decode(self, out_features):
logdepth = self.depth_proj(out_features)
confidence = self.confidence_proj(out_features)
return logdepth, confidence
def forward(
self,
features: list[torch.Tensor],
rays_hr: torch.Tensor,
pos_embed,
level_embed,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
B = features[0].shape[0]
rays_embeddings = self.embed_rays(rays_hr)
features, proj_latents_16 = self.process(features, rays_embeddings)
logdepth, logconf = self.decode(features)
return logdepth, logconf, proj_latents_16
class Decoder(nn.Module):
def __init__(
self,
config,
):
super().__init__()
self.build(config)
self.apply(self._init_weights)
self.test_gt_camera = False
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=0.02)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Conv2d):
trunc_normal_(m.weight, std=0.02)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
if m.bias is not None:
nn.init.constant_(m.bias, 0)
if m.weight is not None:
nn.init.constant_(m.weight, 1.0)
def run_camera(self, cls_tokens, features, pos_embed, original_shapes, rays_gt):
H, W = original_shapes
# camera layer
intrinsics = self.camera_layer(
features=features, cls_tokens=cls_tokens, pos_embed=pos_embed
)
B, N = intrinsics.shape
device = intrinsics.device
dtype = intrinsics.dtype
id_coords = coords_grid(B, H, W, device=features.device, homogeneous=True)
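        # build K and its inverse analytically for a pinhole camera, then unproject the
        # homogeneous pixel grid into unit-norm viewing rays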
intrinsics_matrix_inverse = torch.eye(3, device=device, dtype=dtype).repeat(
B, 1, 1
)
intrinsics_matrix_inverse[:, 0, 0] = 1.0 / intrinsics[:, 0]
intrinsics_matrix_inverse[:, 1, 1] = 1.0 / intrinsics[:, 1]
intrinsics_matrix_inverse[:, 0, 2] = -intrinsics[:, 2] / intrinsics[:, 0]
intrinsics_matrix_inverse[:, 1, 2] = -intrinsics[:, 3] / intrinsics[:, 1]
intrinsics_matrix = torch.eye(3, device=device, dtype=dtype).repeat(B, 1, 1)
intrinsics_matrix[:, 0, 0] = intrinsics[:, 0]
intrinsics_matrix[:, 1, 1] = intrinsics[:, 1]
intrinsics_matrix[:, 0, 2] = intrinsics[:, 2]
intrinsics_matrix[:, 1, 2] = intrinsics[:, 3]
rays_pred = intrinsics_matrix_inverse @ id_coords.reshape(B, 3, -1)
rays_pred = rays_pred.reshape(B, 3, H, W)
rays_pred = rays_pred / torch.norm(rays_pred, dim=1, keepdim=True).clamp(
min=1e-5
)
### LEGACY CODE FOR TRAINING
# if self.training and rays_gt is not None:
# prob = -1.0 # 0.8 * (1 - tanh(self.steps / 100000)) + 0.2
# where_use_gt_rays = torch.rand(B, 1, 1, device=device, dtype=dtype) < prob
# where_use_gt_rays = where_use_gt_rays.int()
# rays = rays_gt * where_use_gt_rays + rays_pred * (1 - where_use_gt_rays)
rays = rays_pred if rays_gt is None else rays_gt
rays = rearrange(rays, "b c h w -> b (h w) c")
return intrinsics_matrix, rays
def forward(
self,
inputs: dict[str, torch.Tensor],
image_metas: list[dict[str, torch.Tensor]],
) -> dict[str, torch.Tensor]:
B, C, H, W = inputs["image"].shape
device = inputs["image"].device
dtype = inputs["features"][0].dtype
# get features in b n d format
common_shape = inputs["features"][0].shape[1:3]
# input shapes repeat shapes for each level, times the amount of the layers:
features = self.input_adapter(inputs["features"])
# positional embeddings, spatial and level
level_embed = self.level_embeds.repeat(
B, common_shape[0] * common_shape[1], 1, 1
)
level_embed = rearrange(level_embed, "b n l d -> b (n l) d")
dummy_tensor = torch.zeros(
B, 1, common_shape[0], common_shape[1], device=device, requires_grad=False
)
pos_embed = self.pos_embed(dummy_tensor)
pos_embed = rearrange(pos_embed, "b c h w -> b (h w) c").repeat(
1, self.num_resolutions, 1
)
# get cls tokens projections
camera_tokens = inputs["tokens"]
camera_tokens = self.camera_token_adapter(camera_tokens)
self.camera_layer.set_shapes((H, W))
intrinsics, rays = self.run_camera(
torch.cat(camera_tokens, dim=1),
features=torch.stack(features, dim=-1).detach(),
pos_embed=(pos_embed + level_embed).detach(),
original_shapes=(H, W),
rays_gt=inputs.get("rays", None),
)
# run bulk of the model
self.depth_layer.set_shapes(common_shape)
self.depth_layer.set_original_shapes((H, W))
logdepth, logconfidence, depth_features = self.depth_layer(
features=features,
rays_hr=rays,
pos_embed=pos_embed,
level_embed=level_embed,
)
return {
"radius": torch.exp(logdepth.clip(min=-8.0, max=8.0) + 2.0),
"depth_features": depth_features,
"confidence": torch.exp(logconfidence.clip(min=-8.0, max=8.0)),
"intrinsics": intrinsics,
"rays": rays,
}
@torch.jit.ignore
def no_weight_decay_keywords(self):
return {"latents_pos", "level_embeds"}
def build(self, config):
input_dims = config["model"]["pixel_encoder"]["embed_dims"]
hidden_dim = config["model"]["pixel_decoder"]["hidden_dim"]
expansion = config["model"]["expansion"]
num_heads = config["model"]["num_heads"]
dropout = config["model"]["pixel_decoder"]["dropout"]
depths_encoder = config["model"]["pixel_encoder"]["depths"]
layer_scale = config["model"]["layer_scale"]
depth = config["model"]["pixel_decoder"]["depths"]
self.downsample = 4
depths_encoder = config["model"]["pixel_encoder"]["depths"]
self.num_resolutions = len(depths_encoder)
self.test_fixed_camera = False
out_dim = config["model"]["pixel_decoder"]["out_dim"]
kernel_size = config["model"]["pixel_decoder"].get("kernel_size", 7)
self.slices_encoder = list(zip([d - 1 for d in depths_encoder], depths_encoder))
input_dims = [input_dims[d - 1] for d in depths_encoder]
        # adapt from encoder features, just project
camera_dims = input_dims
self.input_adapter = ListAdapter(input_dims, hidden_dim)
self.camera_token_adapter = ListAdapter(camera_dims, hidden_dim)
        # camera layer
self.camera_layer = CameraHead(
hidden_dim=hidden_dim,
num_heads=num_heads,
expansion=expansion,
dropout=dropout,
layer_scale=layer_scale,
)
self.depth_layer = DepthHead(
hidden_dim=hidden_dim,
num_heads=num_heads,
expansion=expansion,
depths=depth,
dropout=dropout,
camera_dim=96,
num_resolutions=self.num_resolutions,
layer_scale=layer_scale,
out_dim=out_dim,
kernel_size=kernel_size,
num_prompt_blocks=1,
use_norm=False,
)
self.pos_embed = PositionEmbeddingSine(hidden_dim // 2, normalize=True)
self.level_embeds = nn.Parameter(
orthonormal_init(len(input_dims), hidden_dim).reshape(
1, 1, len(input_dims), hidden_dim
),
requires_grad=False,
)
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/models/unidepthv2/decoder_old.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from timm.models.layers import trunc_normal_
from unidepth.layers import (MLP, AttentionBlock, ConvUpsampleShuffleResidual,
NystromBlock, PositionEmbeddingSine)
from unidepth.utils.geometric import flat_interpolate, generate_rays
from unidepth.utils.positional_embedding import generate_fourier_features
class ListAdapter(nn.Module):
def __init__(self, input_dims: list[int], hidden_dim: int):
super().__init__()
self.input_adapters = nn.ModuleList([])
self.num_chunks = len(input_dims)
self.checkpoint = True
for input_dim in input_dims:
self.input_adapters.append(
nn.Sequential(
nn.LayerNorm(input_dim), nn.Linear(input_dim, hidden_dim), nn.GELU()
)
)
def forward(self, x: torch.Tensor, splits: torch.Tensor) -> torch.Tensor:
xs = torch.split(x, splits.int().tolist(), dim=-1)
xs = [adapter(x) for x, adapter in zip(xs, self.input_adapters)]
return torch.cat(xs, dim=-1)
class CameraHead(nn.Module):
def __init__(
self,
hidden_dim: int,
num_heads: int = 8,
expansion: int = 4,
dropout: float = 0.0,
**kwargs,
):
super().__init__()
self.aggregate1 = AttentionBlock(
hidden_dim, num_heads=1, expansion=expansion, dropout=dropout
)
self.aggregate2 = AttentionBlock(
hidden_dim, num_heads=1, expansion=expansion, dropout=dropout
)
self.latents_pos = nn.Parameter(
torch.randn(1, 4, hidden_dim), requires_grad=True
)
self.in_features = MLP(hidden_dim, expansion=2, dropout=dropout)
self.project_cls = MLP(hidden_dim, dropout=dropout)
self.out = MLP(hidden_dim, expansion=2, dropout=0.0, output_dim=1)
def fill_intrinsics(self, x):
camera_intrinsics = torch.zeros(
x.shape[0], 3, 3, device=x.device, requires_grad=False
)
camera_intrinsics[:, 0, 0] = x[:, 0].exp()
camera_intrinsics[:, 1, 1] = x[:, 1].exp()
camera_intrinsics[:, 0, 2] = x[:, 2].sigmoid()
camera_intrinsics[:, 1, 2] = x[:, 3].sigmoid()
camera_intrinsics[:, 2, 2] = 1.0
return camera_intrinsics
def forward(self, features, cls_tokens, pos_embed) -> torch.Tensor:
features = features.unbind(dim=-1)
cls_tokens = self.project_cls(cls_tokens)
latents_pos = self.latents_pos.expand(cls_tokens.shape[0], -1, -1)
features = self.in_features(torch.cat(features, dim=1) + pos_embed)
features = torch.cat((features, cls_tokens), dim=1)
cls_tokens = self.aggregate1(
cls_tokens, context=features, pos_embed=latents_pos
)
cls_tokens = self.aggregate2(
cls_tokens, context=features, pos_embed=latents_pos
)
# project to intrinsics
x = self.out(cls_tokens).squeeze(-1)
camera_intrinsics = self.fill_intrinsics(x)
return camera_intrinsics
def set_shapes(self, shapes: tuple[int, int]):
self.shapes = shapes
class GlobalHead(nn.Module):
def __init__(
self,
hidden_dim: int,
camera_dim: int,
expansion: int = 4,
dropout: float = 0.0,
**kwargs,
):
super().__init__()
self.camera_dim = camera_dim
self.in_features = nn.Linear(hidden_dim, hidden_dim)
self.project_rays = nn.Linear(camera_dim + 3, hidden_dim)
self.aggregate1 = AttentionBlock(
hidden_dim, num_heads=1, expansion=expansion, dropout=dropout
)
self.aggregate2 = AttentionBlock(
hidden_dim, num_heads=1, expansion=expansion, dropout=dropout
)
self.project_cls = MLP(hidden_dim, dropout=dropout)
self.out = MLP(hidden_dim, expansion=2, dropout=0.0, output_dim=1)
def embed_rays(self, rays, shapes):
rays_embedding = flat_interpolate(rays, old=self.original_shapes, new=shapes)
rays_embedding = F.normalize(rays_embedding, dim=-1)
rays_embedding = generate_fourier_features(
rays_embedding,
dim=self.camera_dim,
max_freq=max(shapes) // 2,
use_log=True,
cat_orig=True,
)
return rays_embedding
def set_original_shapes(self, shapes: tuple[int, int]):
self.original_shapes = shapes
def set_shapes(self, shapes: tuple[int, int]):
self.shapes = shapes
def get_scaleshift(self, x):
scale, shift = torch.chunk(x, 2, dim=1)
scale = scale.exp().reshape(-1, 1, 1, 1)
shift = shift.reshape(-1, 1, 1, 1)
return scale, shift
def forward(self, features, cls_tokens, rays) -> torch.Tensor:
features = features.unbind(dim=-1)
cls_tokens = self.project_cls(cls_tokens)
rays_embedding = self.project_rays(self.embed_rays(rays, self.shapes))
rays_embedding = rays_embedding.repeat(1, len(features), 1)
features = self.in_features(torch.cat(features, dim=1) + rays_embedding)
features = torch.cat((features, cls_tokens), dim=1)
cls_tokens = self.aggregate1(cls_tokens, context=features)
cls_tokens = self.aggregate2(cls_tokens, context=features)
x = self.out(cls_tokens).squeeze(-1)
scale, shift = self.get_scaleshift(x)
return scale, shift
class DepthHead(nn.Module):
def __init__(
self,
hidden_dim: int,
num_heads: int = 8,
expansion: int = 4,
depths: int | list[int] = 4,
checkpoint: bool = True,
camera_dim: int = 256,
num_resolutions: int = 4,
dropout: float = 0.0,
**kwargs,
) -> None:
super().__init__()
self.checkpoint = checkpoint
self.camera_dim = camera_dim
self.skip_depth = False
self.to_latents = MLP(hidden_dim, expansion=2, dropout=dropout)
self.features_channel_cat = nn.Linear(hidden_dim * num_resolutions, hidden_dim)
self.aggregate_16 = AttentionBlock(
hidden_dim,
num_heads=1,
expansion=expansion,
dropout=dropout,
context_dim=hidden_dim,
)
self.prompt_camera = AttentionBlock(
hidden_dim,
num_heads=1,
expansion=expansion,
dropout=dropout,
context_dim=hidden_dim,
)
self.rays_layers = nn.ModuleList([])
self.ups = nn.ModuleList([])
self.process_layers = nn.ModuleList([])
self.norms, self.out_layers = nn.ModuleList([]), nn.ModuleList([])
self.depth_mlp, self.confidence_mlp = nn.ModuleList([]), nn.ModuleList([])
for i, depth in enumerate(depths):
blk_lst = nn.ModuleList([])
for _ in range(depth):
blk_lst.append(
NystromBlock(
hidden_dim // int(2**i),
num_heads=num_heads // int(2**i),
expansion=expansion,
dropout=dropout,
)
)
self.process_layers.append(blk_lst)
self.rays_layers.append(nn.Linear(camera_dim + 3, hidden_dim // int(2**i)))
self.ups.append(
ConvUpsampleShuffleResidual(
hidden_dim // int(2**i),
expansion=expansion,
kernel_size=7,
num_layers=2,
)
)
self.depth_mlp.append(
MLP(
input_dim=hidden_dim // int(2 ** (i + 1)),
output_dim=16,
expansion=1,
)
)
self.confidence_mlp.append(
MLP(
input_dim=hidden_dim // int(2 ** (i + 1)),
output_dim=16,
expansion=1,
)
)
self.to_depth = nn.Conv2d(
16 * len(depths), 1, 7, padding=3, padding_mode="reflect"
)
self.to_confidence = nn.Conv2d(
16 * len(depths), 1, 7, padding=3, padding_mode="reflect"
)
def set_original_shapes(self, shapes: tuple[int, int]):
self.original_shapes = shapes
def set_shapes(self, shapes: tuple[int, int]):
self.shapes = shapes
def embed_rays(self, rays, shapes):
rays_embedding = flat_interpolate(rays, old=self.original_shapes, new=shapes)
rays_embedding = F.normalize(rays_embedding, dim=-1)
rays_embedding = generate_fourier_features(
rays_embedding,
dim=self.camera_dim,
max_freq=max(shapes) // 2,
use_log=True,
cat_orig=True,
)
return rays_embedding
def project_rays(self, rays, shapes):
embedded_rays = []
for i, layer in enumerate(self.rays_layers):
embedded_rays.append(
layer(self.embed_rays(rays, [(2**i) * x for x in shapes]))
)
return embedded_rays
def decode_depth(self, latents_16, rays, shapes):
latents = latents_16
out_features, depths, confidences = [], [], []
for i, (up, layers, rays_embedding) in enumerate(
zip(self.ups, self.process_layers, rays)
):
for layer in layers:
latents = layer(latents, pos_embed=rays_embedding)
latents = up(
rearrange(
latents + rays_embedding,
"b (h w) c -> b c h w",
h=shapes[0] * int(2**i),
w=shapes[1] * int(2**i),
).contiguous()
)
out = rearrange(
latents,
"b (h w) c -> b h w c",
h=shapes[0] * int(2 ** (1 + i)),
w=shapes[1] * int(2 ** (1 + i)),
)
out_features.append(out)
# aggregate output and project to depth
for i, (layer, features) in enumerate(
zip(self.depth_mlp[::-1], out_features[::-1])
):
out_depth_features = layer(features).permute(0, 3, 1, 2)
out_depth_features = F.interpolate(
out_depth_features, size=self.original_shapes, mode="bilinear"
)
depths.append(out_depth_features)
logdepth = self.to_depth(torch.cat(depths, dim=1))
# aggregate output and project to confidences
for i, (layer, features) in enumerate(
zip(self.confidence_mlp[::-1], out_features[::-1])
):
out_conf_features = layer(features).permute(0, 3, 1, 2)
out_conf_features = F.interpolate(
out_conf_features, size=self.original_shapes, mode="bilinear"
)
confidences.append(out_conf_features)
confidence = self.to_confidence(torch.cat(confidences, dim=1))
        # apply sigmoid to get confidence in [0, 1]
confidence = torch.sigmoid(confidence)
return logdepth, confidence
def init_latents(self, features, shapes):
# Generate latents with init as pooled features
features_channels = torch.cat(features, dim=-1)
features_16 = self.features_channel_cat(features_channels)
latents_16 = features_16 + self.to_latents(
flat_interpolate(features_16, old=self.shapes, new=shapes, antialias=False)
)
return latents_16
def forward(
self, features: torch.Tensor, rays_hr: torch.Tensor, pos_embed, level_embed
) -> torch.Tensor:
B = features.shape[0]
features = features.unbind(dim=-1)
shapes = self.shapes
# camera_embedding
rays_embeddings = self.project_rays(rays_hr, shapes)
# Init latents
init_latents_16 = self.init_latents(features, shapes)
# Aggregate features: F -> D
latents_16 = self.aggregate_16(
init_latents_16,
context=torch.cat(features, dim=1),
pos_embed_context=pos_embed + level_embed,
)
# Aggregate camera: D -> D|E
latents_16 = self.prompt_camera(latents_16, context=rays_embeddings[0])
# Decode depth
logdepth, confidence = self.decode_depth(latents_16, rays_embeddings, shapes)
return logdepth, confidence, latents_16
class Decoder(nn.Module):
def __init__(
self,
config,
):
super().__init__()
self.build(config)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=0.02)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Conv2d):
trunc_normal_(m.weight, std=0.02)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
if m.bias is not None:
nn.init.constant_(m.bias, 0)
if m.weight is not None:
nn.init.constant_(m.weight, 1.0)
def get_adapted_features(self, features_flat, splits):
features_flat_cat = torch.cat(features_flat, dim=-1)
features_projected = self.input_adapter(
features_flat_cat, splits
) # list [b hw c] shapes
features = torch.chunk(features_projected, splits.shape[0], dim=-1)
return features
def run_camera(self, cls_tokens, features, pos_embed, original_shapes, rays_gt):
# get cls tokens projections
cls_tokens_splits = torch.tensor(
[x.shape[-1] for x in cls_tokens],
device=features.device,
requires_grad=False,
dtype=features.dtype,
)
cls_tokens = torch.cat(cls_tokens, dim=-1)
cls_tokens = self.camera_token_adapter(cls_tokens, cls_tokens_splits)
cls_tokens = torch.cat(
torch.chunk(cls_tokens, cls_tokens_splits.shape[0], dim=-1), dim=1
)
# camera layer
intrinsics = self.camera_layer(
features=features, cls_tokens=cls_tokens, pos_embed=pos_embed
)
intrinsics[:, 0, 0] = max(original_shapes) / 2 * intrinsics[:, 0, 0]
intrinsics[:, 1, 1] = max(original_shapes) / 2 * intrinsics[:, 1, 1]
intrinsics[:, 0, 2] = intrinsics[:, 0, 2] * original_shapes[1]
intrinsics[:, 1, 2] = intrinsics[:, 1, 2] * original_shapes[0]
rays = (
rays_gt
if rays_gt is not None
else generate_rays(intrinsics, original_shapes)[0]
)
return intrinsics, rays
def run_global(self, cls_tokens, features, rays):
# get cls tokens projections
cls_tokens_splits = torch.tensor(
[x.shape[-1] for x in cls_tokens],
device=features.device,
requires_grad=False,
dtype=torch.float32,
)
cls_tokens = torch.cat(cls_tokens, dim=-1)
cls_tokens = self.global_token_adapter(cls_tokens, cls_tokens_splits)
cls_tokens = torch.cat(
torch.chunk(cls_tokens, cls_tokens_splits.shape[0], dim=-1), dim=1
)
scale, shift = self.global_layer(
features=features, rays=rays, cls_tokens=cls_tokens
)
return scale, shift
def forward(self, inputs, image_metas) -> torch.Tensor:
B, C, H, W = inputs["image"].shape
device = inputs["image"].device
dtype = inputs["image"].dtype
        # get features in (b, n, d) format
        # level shapes: the spatial shape per level, e.g. [[128, 128], [64, 64], ...] for
        # Swin-like encoders; a plain ViT gives a single shape [[32, 32]], repeated per resolution
level_shapes = sorted(
list(set([tuple([x.shape[1], x.shape[2]]) for x in inputs["features"]]))
)[::-1]
if len(level_shapes) == 1:
level_shapes = level_shapes * self.num_resolutions
input_shapes = [
level_shapes[i]
for i, (start, end) in enumerate(self.slices_encoder)
for _ in range(end - start)
]
common_shape = level_shapes[-2]
        # repeat each level's shape once per encoder layer belonging to that level:
features_flat = [
flat_interpolate(
rearrange(x, "b h w c -> b (h w) c"), old=input_shape, new=common_shape
)
for x, input_shape in zip(inputs["features"], input_shapes)
]
features_splits = torch.tensor(
[x.shape[-1] for x in features_flat],
device=device,
requires_grad=False,
dtype=torch.float32,
)
features = self.get_adapted_features(features_flat, features_splits)
features = torch.stack(features, dim=-1)
# positional embeddings, spatial and level
level_embed = torch.cat(
[
self.level_embed_layer(self.level_embeds)[i : i + 1]
.unsqueeze(0)
.repeat(B, common_shape[0] * common_shape[1], 1)
for i in range(self.num_resolutions)
],
dim=1,
)
dummy_tensor = torch.zeros(
B, 1, common_shape[0], common_shape[1], device=device, requires_grad=False
)
pos_embed = self.pos_embed(dummy_tensor)
pos_embed = rearrange(pos_embed, "b c h w -> b (h w) c").repeat(
1, self.num_resolutions, 1
)
self.camera_layer.set_shapes(common_shape)
intrinsics, rays = self.run_camera(
inputs["camera_tokens"],
features=features,
pos_embed=pos_embed + level_embed,
original_shapes=(H, W),
rays_gt=inputs.get("rays"),
)
self.global_layer.set_shapes(common_shape)
self.global_layer.set_original_shapes((H, W))
scale, shift = self.run_global(
inputs["global_tokens"], features=features, rays=rays
)
# run bulk of the model
self.depth_layer.set_shapes(common_shape)
self.depth_layer.set_original_shapes((H, W))
logdepth, confidence, depth_features = self.depth_layer(
features=features,
rays_hr=rays,
pos_embed=pos_embed,
level_embed=level_embed,
)
logdepth = logdepth.to(torch.float32, non_blocking=True)
        # normalize in log space (empirically performs better)
shapes = [int(x) for x in logdepth.shape[-2:]]
depth_normalized = F.layer_norm(logdepth, shapes).exp()
depth = (
depth_normalized + shift
) * scale # shift is scale invariant if we do (x + mu) * sigma
depth = F.softplus(depth, beta=10.0).to(dtype, non_blocking=True)
outputs = {
"depth": depth,
"confidence": confidence,
"depth_features": depth_features,
"K": intrinsics,
}
return outputs
@torch.jit.ignore
def no_weight_decay_keywords(self):
return {"latents_pos", "level_embeds"}
def build(self, config):
input_dims = config["model"]["pixel_encoder"]["embed_dims"]
hidden_dim = config["model"]["pixel_decoder"]["hidden_dim"]
expansion = config["model"]["expansion"]
num_heads = config["model"]["num_heads"]
dropout = config["model"]["pixel_decoder"]["dropout"]
depths_encoder = config["model"]["pixel_encoder"]["depths"]
depth = config["model"]["pixel_decoder"]["depths"]
depths_encoder = config["model"]["pixel_encoder"]["depths"]
self.downsample = 4
self.num_resolutions = len(depths_encoder)
self.slices_encoder = list(zip([d - 1 for d in depths_encoder], depths_encoder))
cls_token_input_dims = [input_dims[i] for i in [-1, -2, -3, -4]]
input_dims = [input_dims[d - 1] for d in depths_encoder]
        # camera layer
self.camera_layer = CameraHead(
hidden_dim=hidden_dim,
num_heads=num_heads,
expansion=expansion,
dropout=dropout,
)
        # scale shift layer
self.global_layer = GlobalHead(
hidden_dim=hidden_dim,
camera_dim=96,
num_heads=num_heads,
expansion=expansion,
dropout=dropout,
)
        # adapt from encoder features, just project
self.input_adapter = ListAdapter(input_dims, hidden_dim)
self.camera_token_adapter = ListAdapter(cls_token_input_dims, hidden_dim)
self.global_token_adapter = ListAdapter(cls_token_input_dims[:2], hidden_dim)
self.depth_layer = DepthHead(
hidden_dim=hidden_dim,
num_heads=num_heads,
expansion=expansion,
depths=depth,
dropout=dropout,
camera_dim=96,
num_resolutions=self.num_resolutions,
)
self.pos_embed = PositionEmbeddingSine(hidden_dim // 2, normalize=True)
self.level_embeds = nn.Parameter(
torch.randn(len(input_dims), hidden_dim), requires_grad=True
)
self.level_embed_layer = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim),
nn.GELU(),
nn.Linear(hidden_dim, hidden_dim),
nn.LayerNorm(hidden_dim),
)
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/models/unidepthv2/export.py
================================================
"""
Author: Luigi Piccinelli
Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/)
"""
import argparse
import json
import os
from math import ceil
import huggingface_hub
import torch.nn.functional as F
import torch.onnx
from unidepth.models.unidepthv2 import UniDepthV2
class UniDepthV2ONNX(UniDepthV2):
def __init__(
self,
config,
eps: float = 1e-6,
**kwargs,
):
super().__init__(config, eps)
def forward(self, rgbs):
B, _, H, W = rgbs.shape
features, tokens = self.pixel_encoder(rgbs)
inputs = {}
inputs["image"] = rgbs
inputs["features"] = [
self.stacking_fn(features[i:j]).contiguous()
for i, j in self.slices_encoder_range
]
inputs["tokens"] = [
self.stacking_fn(tokens[i:j]).contiguous()
for i, j in self.slices_encoder_range
]
outputs = self.pixel_decoder(inputs, [])
outputs["rays"] = outputs["rays"].permute(0, 2, 1).reshape(B, 3, H, W)
pts_3d = outputs["rays"] * outputs["radius"]
return pts_3d, outputs["confidence"], outputs["intrinsics"]
class UniDepthV2ONNXcam(UniDepthV2):
def __init__(
self,
config,
eps: float = 1e-6,
**kwargs,
):
super().__init__(config, eps)
def forward(self, rgbs, rays):
B, _, H, W = rgbs.shape
features, tokens = self.pixel_encoder(rgbs)
inputs = {}
inputs["image"] = rgbs
inputs["rays"] = rays
inputs["features"] = [
self.stacking_fn(features[i:j]).contiguous()
for i, j in self.slices_encoder_range
]
inputs["tokens"] = [
self.stacking_fn(tokens[i:j]).contiguous()
for i, j in self.slices_encoder_range
]
outputs = self.pixel_decoder(inputs, [])
outputs["rays"] = outputs["rays"].permute(0, 2, 1).reshape(B, 3, H, W)
pts_3d = outputs["rays"] * outputs["radius"]
return pts_3d, outputs["confidence"], outputs["intrinsics"]
def export(model, path, shape=(462, 630), with_camera=False):
model.eval()
image = torch.rand(1, 3, *shape)
dynamic_axes_in = {"rgbs": {0: "batch"}}
inputs = [image]
if with_camera:
rays = torch.rand(1, 3, *shape)
inputs.append(rays)
dynamic_axes_in["rays"] = {0: "batch"}
dynamic_axes_out = {
"pts_3d": {0: "batch"},
"confidence": {0: "batch"},
"intrinsics": {0: "batch"},
}
torch.onnx.export(
model,
tuple(inputs),
path,
input_names=list(dynamic_axes_in.keys()),
output_names=list(dynamic_axes_out.keys()),
opset_version=14,
dynamic_axes={**dynamic_axes_in, **dynamic_axes_out},
)
print(f"Model exported to {path}")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Export UniDepthV2 model to ONNX")
parser.add_argument(
"--version", type=str, default="v2", choices=["v2"], help="UniDepth version"
)
parser.add_argument(
"--backbone",
type=str,
default="vitl",
choices=["vits", "vitb", "vitl"],
help="Backbone model",
)
parser.add_argument(
"--shape",
type=int,
nargs=2,
default=(462, 630),
help="Input shape. No dyamic shape supported!",
)
parser.add_argument(
"--output-path", type=str, default="unidepthv2.onnx", help="Output ONNX file"
)
parser.add_argument(
"--with-camera",
action="store_true",
help="Export model that expects GT camera as unprojected rays at inference",
)
args = parser.parse_args()
version = args.version
backbone = args.backbone
shape = args.shape
output_path = args.output_path
with_camera = args.with_camera
# force shape to be multiple of 14
shape_rounded = [14 * ceil(x // 14 - 0.5) for x in shape]
if list(shape) != list(shape_rounded):
print(f"Shape {shape} is not multiple of 14. Rounding to {shape_rounded}")
shape = shape_rounded
# assumes command is from root of repo
with open(os.path.join("configs", f"config_{version}_{backbone}14.json")) as f:
config = json.load(f)
# tell DINO not to use efficient attention: not exportable
config["training"]["export"] = True
model = UniDepthV2ONNX(config) if not with_camera else UniDepthV2ONNXcam(config)
path = huggingface_hub.hf_hub_download(
repo_id=f"lpiccinelli/unidepth-{version}-{backbone}14",
        filename="pytorch_model.bin",
repo_type="model",
)
info = model.load_state_dict(torch.load(path), strict=False)
print(f"UniDepth_{version}_{backbone} is loaded with:")
print(f"\t missing keys: {info.missing_keys}")
print(f"\t additional keys: {info.unexpected_keys}")
export(
model=model,
path=os.path.join(os.environ.get("TMPDIR", "."), output_path),
shape=shape,
with_camera=with_camera,
)
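# Illustrative sketch of running the exported graph (assumes `onnxruntime` is
# installed, the default output name `unidepthv2.onnx`, and export without
# `--with-camera`; the file is written under $TMPDIR when set, else the current dir):
#
#   import numpy as np
#   import onnxruntime as ort
#
#   session = ort.InferenceSession("unidepthv2.onnx", providers=["CPUExecutionProvider"])
#   rgbs = np.random.rand(1, 3, 462, 630).astype(np.float32)
#   pts_3d, confidence, intrinsics = session.run(None, {"rgbs": rgbs})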
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/models/unidepthv2/unidepthv2.py
================================================
"""
Author: Luigi Piccinelli
Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/)
"""
import importlib
from copy import deepcopy
from math import ceil
import warnings
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms.v2.functional as TF
from einops import rearrange
from huggingface_hub import PyTorchModelHubMixin
from unidepth.models.unidepthv2.decoder import Decoder
from unidepth.utils.camera import BatchCamera, Camera, Pinhole
from unidepth.utils.constants import (IMAGENET_DATASET_MEAN,
IMAGENET_DATASET_STD)
from unidepth.utils.distributed import is_main_process
from unidepth.utils.misc import (first_stack, get_params, last_stack, match_gt,
match_intrinsics, max_stack, mean_stack,
softmax_stack)
STACKING_FNS = {
"max": max_stack,
"mean": mean_stack,
"first": first_stack,
"last": last_stack,
"softmax": softmax_stack,
}
def get_paddings(original_shape, aspect_ratio_range):
# Original dimensions
H_ori, W_ori = original_shape
orig_aspect_ratio = W_ori / H_ori
# Determine the closest aspect ratio within the range
min_ratio, max_ratio = aspect_ratio_range
target_aspect_ratio = min(max_ratio, max(min_ratio, orig_aspect_ratio))
if orig_aspect_ratio > target_aspect_ratio: # Too wide
W_new = W_ori
H_new = int(W_ori / target_aspect_ratio)
pad_top = (H_new - H_ori) // 2
pad_bottom = H_new - H_ori - pad_top
pad_left, pad_right = 0, 0
else: # Too tall
H_new = H_ori
W_new = int(H_ori * target_aspect_ratio)
pad_left = (W_new - W_ori) // 2
pad_right = W_new - W_ori - pad_left
pad_top, pad_bottom = 0, 0
return (pad_left, pad_right, pad_top, pad_bottom), (H_new, W_new)
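# Worked example (illustrative numbers): a 480 x 854 frame has aspect ratio ~1.78.
# With aspect_ratio_range = (1.0, 1.5) the target ratio is clamped to 1.5, the frame
# counts as "too wide", and only the height is padded: H_new = int(854 / 1.5) = 569,
# pad_top = (569 - 480) // 2 = 44, pad_bottom = 45, pad_left = pad_right = 0.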
def get_resize_factor(original_shape, pixels_range, shape_multiplier=14):
# Original dimensions
H_ori, W_ori = original_shape
n_pixels_ori = W_ori * H_ori
# Determine the closest number of pixels within the range
min_pixels, max_pixels = pixels_range
target_pixels = min(max_pixels, max(min_pixels, n_pixels_ori))
# Calculate the resize factor
resize_factor = (target_pixels / n_pixels_ori) ** 0.5
new_width = int(W_ori * resize_factor)
new_height = int(H_ori * resize_factor)
new_height = ceil(new_height / shape_multiplier) * shape_multiplier
new_width = ceil(new_width / shape_multiplier) * shape_multiplier
return resize_factor, (new_height, new_width)
def _postprocess(tensor, shapes, paddings, interpolation_mode="bilinear"):
# interpolate to original size
tensor = F.interpolate(
tensor, size=shapes, mode=interpolation_mode, align_corners=False
)
# remove paddings
pad1_l, pad1_r, pad1_t, pad1_b = paddings
tensor = tensor[..., pad1_t : shapes[0] - pad1_b, pad1_l : shapes[1] - pad1_r]
return tensor
def _postprocess_intrinsics(K, resize_factors, paddings):
batch_size = K.shape[0]
K_new = K.clone()
for i in range(batch_size):
scale = resize_factors[i]
pad_l, _, pad_t, _ = paddings[i]
K_new[i, 0, 0] /= scale # fx
K_new[i, 1, 1] /= scale # fy
K_new[i, 0, 2] /= scale # cx
K_new[i, 1, 2] /= scale # cy
K_new[i, 0, 2] -= pad_l # cx
K_new[i, 1, 2] -= pad_t # cy
return K_new
class UniDepthV2(
nn.Module,
PyTorchModelHubMixin,
library_name="UniDepth",
repo_url="https://github.com/lpiccinelli-eth/UniDepth",
tags=["monocular-metric-depth-estimation"],
):
def __init__(
self,
config,
eps: float = 1e-6,
**kwargs,
):
super().__init__()
self.eps = eps
self.build(config)
self.build_losses(config)
def forward_train(self, inputs, image_metas):
inputs, outputs = self.encode_decode(inputs, image_metas)
losses = self.compute_losses(outputs, inputs, image_metas)
return outputs, losses
def forward_test(self, inputs, image_metas):
inputs, outputs = self.encode_decode(inputs, image_metas)
depth_gt = inputs["depth"]
test_outputs = {}
test_outputs["depth"] = match_gt(
outputs["depth"], depth_gt, padding1=inputs["paddings"], padding2=None
)
test_outputs["points"] = match_gt(
outputs["points"], depth_gt, padding1=inputs["paddings"], padding2=None
)
test_outputs["confidence"] = match_gt(
outputs["confidence"], depth_gt, padding1=inputs["paddings"], padding2=None
)
test_outputs["rays"] = match_gt(
outputs["rays"], depth_gt, padding1=inputs["paddings"], padding2=None
)
test_outputs["rays"] = outputs["rays"] / torch.norm(
outputs["rays"], dim=1, keepdim=True
).clip(min=1e-5)
test_outputs["intrinsics"] = match_intrinsics(
outputs["intrinsics"],
inputs["image"],
depth_gt,
padding1=inputs["paddings"],
padding2=None,
)
return test_outputs
def forward(self, inputs, image_metas):
if self.training:
return self.forward_train(inputs, image_metas)
else:
return self.forward_test(inputs, image_metas)
def compute_losses(self, outputs, inputs, image_metas):
B, _, H, W = inputs["image"].shape
losses = {"opt": {}, "stat": {}}
losses_to_be_computed = list(self.losses.keys())
# depth loss
si = torch.tensor(
[x.get("si", False) for x in image_metas], device=self.device
).reshape(B)
loss = self.losses["depth"]
depth_losses = loss(
outputs["depth"],
target=inputs["depth"],
mask=inputs["depth_mask"].clone(),
si=si,
)
losses["opt"][loss.name] = loss.weight * depth_losses.mean()
losses_to_be_computed.remove("depth")
# camera loss; here we apply it to rays for simplicity
# in the original training it was applied to angles,
# but we saw no difference (see supplementary)
loss = self.losses["camera"]
camera_losses = loss(outputs["rays"], target=inputs["rays"])
losses["opt"][loss.name] = loss.weight * camera_losses.mean()
losses_to_be_computed.remove("camera")
# invariance loss on output depth
flips = torch.tensor(
[x.get("flip", False) for x in image_metas], device=self.device
).reshape(B)
loss = self.losses["invariance"]
invariance_losses = loss(
outputs["depth"],
intrinsics=inputs["camera"].K,
mask=inputs["depth_mask"],
flips=flips,
downsample_ratio=1,
)
losses["opt"][loss.name] = loss.weight * invariance_losses.mean()
losses_to_be_computed.remove("invariance")
# edge guided ssi
loss = self.losses["ssi"]
ssi_losses = loss(
outputs["depth"],
target=inputs["depth"],
mask=inputs["depth_mask"].clone(),
image=inputs["image"],
validity_mask=inputs["validity_mask"],
)
losses["opt"][loss.name] = loss.weight * ssi_losses.mean()
losses_to_be_computed.remove("ssi")
# remaining loss (confidence); after this we expect no further losses to be computed
loss = self.losses["confidence"]
conf_losses = loss(
outputs["confidence"].log(),
target_gt=inputs["depth"],
target_pred=outputs["depth"],
mask=inputs["depth_mask"].clone(),
)
losses["opt"][loss.name + "_conf"] = loss.weight * conf_losses.mean()
losses_to_be_computed.remove("confidence")
assert (
not losses_to_be_computed
), f"Losses {losses_to_be_computed} not computed, revise `compute_loss` method"
return losses
@torch.no_grad()
@torch.autocast(device_type="cuda", enabled=True, dtype=torch.float16)
def infer(
self,
rgb: torch.Tensor,
camera: torch.Tensor | Camera | None = None,
normalize=True,
):
ratio_bounds = self.shape_constraints["ratio_bounds"]
pixels_bounds = [
self.shape_constraints["pixels_min"],
self.shape_constraints["pixels_max"],
]
if hasattr(self, "resolution_level"):
assert (
self.resolution_level >= 0 and self.resolution_level < 10
), "resolution_level should be in [0, 10)"
pixels_range = pixels_bounds[1] - pixels_bounds[0]
interval = pixels_range / 10
new_lowbound = self.resolution_level * interval + pixels_bounds[0]
new_upbound = (self.resolution_level + 1) * interval + pixels_bounds[0]
pixels_bounds = (new_lowbound, new_upbound)
else:
warnings.warn("!! self.resolution_level not set, using default bounds !!")
# housekeeping: move to cpu/cuda and batchify
if rgb.ndim == 3:
rgb = rgb.unsqueeze(0)
if camera is not None:
if isinstance(camera, torch.Tensor):
assert (
camera.shape[-1] == 3 and camera.shape[-2] == 3
), "camera tensor should be of shape (..., 3, 3): assume pinhole"
camera = Pinhole(K=camera)
camera = BatchCamera.from_camera(camera)
camera = camera.to(self.device)
B, _, H, W = rgb.shape
rgb = rgb.to(self.device)
if camera is not None:
camera = camera.to(self.device)
# preprocess
paddings, (padded_H, padded_W) = get_paddings((H, W), ratio_bounds)
(pad_left, pad_right, pad_top, pad_bottom) = paddings
resize_factor, (new_H, new_W) = get_resize_factor(
(padded_H, padded_W), pixels_bounds
)
# -> rgb preprocess (input std-ized and resized)
if normalize:
rgb = TF.normalize(
rgb.float() / 255.0,
mean=IMAGENET_DATASET_MEAN,
std=IMAGENET_DATASET_STD,
)
rgb = F.pad(rgb, (pad_left, pad_right, pad_top, pad_bottom), value=0.0)
rgb = F.interpolate(
rgb, size=(new_H, new_W), mode="bilinear", align_corners=False
)
# -> camera preprocess
if camera is not None:
camera = camera.crop(
left=-pad_left, top=-pad_top, right=-pad_right, bottom=-pad_bottom
)
camera = camera.resize(resize_factor)
# run model
_, model_outputs = self.encode_decode(
inputs={"image": rgb, "camera": camera}, image_metas=[]
)
# collect outputs
out = {}
out["confidence"] = _postprocess(
model_outputs["confidence"],
(padded_H, padded_W),
paddings=paddings,
interpolation_mode=self.interpolation_mode,
)
points = _postprocess(
model_outputs["points"],
(padded_H, padded_W),
paddings=paddings,
interpolation_mode=self.interpolation_mode,
)
rays = _postprocess(
model_outputs["rays"],
(padded_H, padded_W),
paddings=paddings,
interpolation_mode=self.interpolation_mode,
)
out["intrinsics"] = _postprocess_intrinsics(
model_outputs["intrinsics"], [resize_factor] * B, [paddings] * B
)
out["radius"] = points.norm(dim=1, keepdim=True)
out["depth"] = points[:, -1:]
out["points"] = points
out["rays"] = rays / torch.norm(rays, dim=1, keepdim=True).clip(min=1e-5)
out["depth_features"] = model_outputs["depth_features"]
return out
def encode_decode(self, inputs, image_metas=[]):
B, _, H, W = inputs["image"].shape
# shortcut eval should avoid errors
if len(image_metas) and "paddings" in image_metas[0]:
inputs["paddings"] = torch.tensor(
[image_meta["paddings"] for image_meta in image_metas],
device=self.device,
)[
..., [0, 2, 1, 3]
] # lrtb
inputs["depth_paddings"] = torch.tensor(
[image_meta["depth_paddings"] for image_meta in image_metas],
device=self.device,
)
if (
self.training
): # at inference there are no image paddings on top of the depth ones (no "crop" is applied to the GT in ContextCrop)
inputs["depth_paddings"] = inputs["depth_paddings"] + inputs["paddings"]
if inputs.get("camera", None) is not None:
inputs["rays"] = inputs["camera"].get_rays(shapes=(B, H, W))
features, tokens = self.pixel_encoder(inputs["image"])
inputs["features"] = [
self.stacking_fn(features[i:j]).contiguous()
for i, j in self.slices_encoder_range
]
inputs["tokens"] = [
self.stacking_fn(tokens[i:j]).contiguous()
for i, j in self.slices_encoder_range
]
outputs = self.pixel_decoder(inputs, image_metas)
outputs["rays"] = rearrange(outputs["rays"], "b (h w) c -> b c h w", h=H, w=W)
pts_3d = outputs["rays"] * outputs["radius"]
outputs.update({"points": pts_3d, "depth": pts_3d[:, -1:]})
return inputs, outputs
def load_pretrained(self, model_file):
device = (
torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
)
dict_model = torch.load(model_file, map_location=device, weights_only=False)
if "model" in dict_model:
dict_model = dict_model["model"]
dict_model = {k.replace("module.", ""): v for k, v in dict_model.items()}
info = self.load_state_dict(dict_model, strict=False)
if is_main_process():
print(
f"Loaded from {model_file} for {self.__class__.__name__} results in:",
info,
)
def get_params(self, config):
if hasattr(self.pixel_encoder, "get_params"):
encoder_p, encoder_lr = self.pixel_encoder.get_params(
config["model"]["pixel_encoder"]["lr"],
config["training"]["wd"],
config["training"]["ld"],
)
else:
encoder_p, encoder_lr = get_params(
self.pixel_encoder,
config["model"]["pixel_encoder"]["lr"],
config["training"]["wd"],
)
decoder_p, decoder_lr = get_params(
self.pixel_decoder, config["training"]["lr"], config["training"]["wd"]
)
return [*encoder_p, *decoder_p]
@property
def device(self):
return next(self.parameters()).device
def build(self, config):
mod = importlib.import_module("unidepth.models.encoder")
pixel_encoder_factory = getattr(mod, config["model"]["pixel_encoder"]["name"])
pixel_encoder_config = {
**config["training"],
**config["model"]["pixel_encoder"],
**config["data"],
}
pixel_encoder = pixel_encoder_factory(pixel_encoder_config)
config["model"]["pixel_encoder"]["patch_size"] = (
14 if "dino" in config["model"]["pixel_encoder"]["name"] else 16
)
pixel_encoder_embed_dims = (
pixel_encoder.embed_dims
if hasattr(pixel_encoder, "embed_dims")
else [getattr(pixel_encoder, "embed_dim") * 2**i for i in range(4)]
)
config["model"]["pixel_encoder"]["embed_dim"] = getattr(
pixel_encoder, "embed_dim"
)
config["model"]["pixel_encoder"]["embed_dims"] = pixel_encoder_embed_dims
config["model"]["pixel_encoder"]["depths"] = pixel_encoder.depths
config["model"]["pixel_encoder"]["cls_token_embed_dims"] = getattr(
pixel_encoder, "cls_token_embed_dims", pixel_encoder_embed_dims
)
pixel_decoder = Decoder(config)
self.pixel_encoder = pixel_encoder
self.pixel_decoder = pixel_decoder
self.slices_encoder_range = list(
zip([0, *self.pixel_encoder.depths[:-1]], self.pixel_encoder.depths)
)
stacking_fn = config["model"]["pixel_encoder"]["stacking_fn"]
assert (
stacking_fn in STACKING_FNS
), f"Stacking function {stacking_fn} not found in {STACKING_FNS.keys()}"
self.stacking_fn = STACKING_FNS[stacking_fn]
self.shape_constraints = config["data"]["augmentations"]["shape_constraints"]
self.interpolation_mode = "bilinear"
def build_losses(self, config):
self.losses = {}
for loss_name, loss_config in config["training"]["losses"].items():
mod = importlib.import_module("unidepth.ops.losses")
loss_factory = getattr(mod, loss_config["name"])
self.losses[loss_name] = loss_factory.build(loss_config)
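if __name__ == "__main__":
    # Minimal usage sketch: load pretrained weights from the HuggingFace Hub (the same
    # `lpiccinelli/unidepth-v2-vitl14` checkpoint referenced by the ONNX export script)
    # and run inference on a synthetic frame. Assumes Hub access; a CUDA GPU is optional.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = UniDepthV2.from_pretrained("lpiccinelli/unidepth-v2-vitl14").to(device).eval()
    rgb = (torch.rand(3, 480, 640) * 255).to(torch.uint8)  # H x W RGB frame in [0, 255]
    predictions = model.infer(rgb)  # camera is optional; intrinsics are predicted if omitted
    print(predictions["depth"].shape, predictions["intrinsics"].shape)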
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/models/unidepthv2/unidepthv2_old.py
================================================
import importlib
import warnings
from copy import deepcopy
from math import ceil
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from einops import rearrange
from huggingface_hub import PyTorchModelHubMixin
from unidepth.models.unidepthv2.decoder_old import Decoder
from unidepth.utils.constants import (IMAGENET_DATASET_MEAN,
IMAGENET_DATASET_STD)
from unidepth.utils.distributed import is_main_process
from unidepth.utils.geometric import (generate_rays,
spherical_zbuffer_to_euclidean)
from unidepth.utils.misc import (first_stack, last_stack, max_stack,
mean_stack, softmax_stack)
STACKING_FNS = {
"max": max_stack,
"mean": mean_stack,
"first": first_stack,
"last": last_stack,
"softmax": softmax_stack,
}
RESOLUTION_LEVELS = 10
# inference helpers
def _check_ratio(image_ratio, ratio_bounds):
ratio_bounds = sorted(ratio_bounds)
if ratio_bounds is not None and (
image_ratio < ratio_bounds[0] or image_ratio > ratio_bounds[1]
):
warnings.warn(
f"Input image ratio ({image_ratio:.3f}) is out of training "
f"distribution: {ratio_bounds}. This may lead to unexpected results. "
f"Consider resizing/padding the image to match the training distribution."
)
def _check_resolution(shape_constraints, resolution_level):
if resolution_level is None:
warnings.warn(
"Resolution level is not set. Using max resolution. "
"You can tradeoff resolution for speed by setting a number in [0,10]. "
"This can be achieved by setting model's `resolution_level` attribute."
)
resolution_level = RESOLUTION_LEVELS
pixel_bounds = sorted(shape_constraints["pixels_bounds_ori"])
pixel_range = pixel_bounds[-1] - pixel_bounds[0]
clipped_resolution_level = min(max(resolution_level, 0), RESOLUTION_LEVELS)
if clipped_resolution_level != resolution_level:
warnings.warn(
f"Resolution level {resolution_level} is out of bounds ([0,{RESOLUTION_LEVELS}]). "
f"Clipping to {clipped_resolution_level}."
)
shape_constraints["pixels_bounds"] = [
pixel_bounds[0]
+ ceil(pixel_range * clipped_resolution_level / RESOLUTION_LEVELS),
pixel_bounds[0]
+ ceil(pixel_range * clipped_resolution_level / RESOLUTION_LEVELS),
]
return shape_constraints
def _get_closes_num_pixels(image_shape, pixels_bounds):
h, w = image_shape
num_pixels = h * w
pixels_bounds = sorted(pixels_bounds)
num_pixels = max(min(num_pixels, pixels_bounds[1]), pixels_bounds[0])
return num_pixels
def _shapes(image_shape, shape_constraints):
h, w = image_shape
image_ratio = w / h
_check_ratio(image_ratio, shape_constraints["ratio_bounds"])
num_pixels = _get_closes_num_pixels(
(h / shape_constraints["patch_size"], w / shape_constraints["patch_size"]),
shape_constraints["pixels_bounds"],
)
h = ceil((num_pixels / image_ratio) ** 0.5 - 0.5)
w = ceil(h * image_ratio - 0.5)
ratio = h / image_shape[0] * shape_constraints["patch_size"]
return (
h * shape_constraints["patch_size"],
w * shape_constraints["patch_size"],
), ratio
def _preprocess(rgbs, intrinsics, shapes, ratio):
rgbs = F.interpolate(rgbs, size=shapes, mode="bilinear", antialias=True)
if intrinsics is not None:
intrinsics = intrinsics.clone()
intrinsics[:, 0, 0] = intrinsics[:, 0, 0] * ratio
intrinsics[:, 1, 1] = intrinsics[:, 1, 1] * ratio
intrinsics[:, 0, 2] = intrinsics[:, 0, 2] * ratio
intrinsics[:, 1, 2] = intrinsics[:, 1, 2] * ratio
return rgbs, intrinsics
return rgbs, None
def _postprocess(outs, ratio, original_shapes, mode="nearest-exact"):
outs["depth"] = F.interpolate(outs["depth"], size=original_shapes, mode=mode)
outs["confidence"] = F.interpolate(
outs["confidence"], size=original_shapes, mode="bilinear", antialias=True
)
outs["K"][:, 0, 0] = outs["K"][:, 0, 0] / ratio
outs["K"][:, 1, 1] = outs["K"][:, 1, 1] / ratio
outs["K"][:, 0, 2] = outs["K"][:, 0, 2] / ratio
outs["K"][:, 1, 2] = outs["K"][:, 1, 2] / ratio
return outs
class UniDepthV2old(
nn.Module,
PyTorchModelHubMixin,
library_name="UniDepth",
repo_url="https://github.com/lpiccinelli-eth/UniDepth",
tags=["monocular-metric-depth-estimation"],
):
def __init__(
self,
config,
**kwargs,
):
super().__init__()
self.build(config)
def forward(self, inputs, image_metas):
H, W = inputs["depth"].shape[-2:]
if "K" in inputs:
rays, angles = generate_rays(inputs["K"], (H, W))
inputs["rays"] = rays
inputs["angles"] = angles
features, tokens = self.pixel_encoder(inputs["image"])
cls_tokens = [x.contiguous() for x in tokens]
features = [
self.stacking_fn(features[i:j]).contiguous()
for i, j in self.slices_encoder_range
]
tokens = [
self.stacking_fn(tokens[i:j]).contiguous()
for i, j in self.slices_encoder_range
]
global_tokens = [cls_tokens[i] for i in [-2, -1]]
camera_tokens = [cls_tokens[i] for i in [-3, -2, -1]] + [tokens[-2]]
inputs["features"] = features
inputs["tokens"] = tokens
inputs["global_tokens"] = global_tokens
inputs["camera_tokens"] = camera_tokens
outs = self.pixel_decoder(inputs, image_metas)
angles = rearrange(
generate_rays(outs["K"], (H, W), noisy=False)[-1],
"b (h w) c -> b c h w",
h=H,
w=W,
)
predictions = F.interpolate(
outs["depth"],
size=(H, W),
mode="bilinear",
align_corners=False,
antialias=True,
)
confidence = F.interpolate(
outs["confidence"],
size=(H, W),
mode="bilinear",
align_corners=False,
antialias=True,
)
predictions_3d = torch.cat((angles, predictions), dim=1)
predictions_3d = spherical_zbuffer_to_euclidean(
predictions_3d.permute(0, 2, 3, 1)
).permute(0, 3, 1, 2)
outputs = {
"K": outs["K"],
"depth": predictions,
"confidence": confidence,
"points": predictions_3d,
"depth_features": outs["depth_features"],
}
return outputs
@torch.no_grad()
def infer(self, rgbs: torch.Tensor, intrinsics=None):
shape_constraints = self.shape_constraints
if rgbs.ndim == 3:
rgbs = rgbs.unsqueeze(0)
if intrinsics is not None and intrinsics.ndim == 2:
intrinsics = intrinsics.unsqueeze(0)
B, _, H, W = rgbs.shape
rgbs = rgbs.to(self.device)
if intrinsics is not None:
intrinsics = intrinsics.to(self.device)
# process image and intrinsics (if any) to match the network input (slow?)
if rgbs.max() > 5 or rgbs.dtype == torch.uint8:
rgbs = rgbs.to(torch.float32).div(255)
if rgbs.min() >= 0.0 and rgbs.max() <= 1.0:
rgbs = TF.normalize(
rgbs,
mean=IMAGENET_DATASET_MEAN,
std=IMAGENET_DATASET_STD,
)
# check resolution constraints: tradeoff resolution and speed
shape_constraints = _check_resolution(shape_constraints, self.resolution_level)
# get image shape
(h, w), ratio = _shapes((H, W), shape_constraints)
rgbs, gt_intrinsics = _preprocess(
rgbs,
intrinsics,
(h, w),
ratio,
)
# run encoder
features, tokens = self.pixel_encoder(rgbs)
cls_tokens = [x.contiguous() for x in tokens]
features = [
self.stacking_fn(features[i:j]).contiguous()
for i, j in self.slices_encoder_range
]
tokens = [
self.stacking_fn(tokens[i:j]).contiguous()
for i, j in self.slices_encoder_range
]
global_tokens = [cls_tokens[i] for i in [-2, -1]]
camera_tokens = [cls_tokens[i] for i in [-3, -2, -1]] + [tokens[-2]]
# gather data for the decoder and adapt to the given camera
inputs = {}
inputs["features"] = features
inputs["tokens"] = tokens
inputs["global_tokens"] = global_tokens
inputs["camera_tokens"] = camera_tokens
inputs["image"] = rgbs
if gt_intrinsics is not None:
rays, angles = generate_rays(gt_intrinsics, (h, w))
inputs["rays"] = rays
inputs["angles"] = angles
inputs["K"] = gt_intrinsics
outs = self.pixel_decoder(inputs, {})
# undo the reshaping and get original image size (slow)
outs = _postprocess(outs, ratio, (H, W), mode=self.interpolation_mode)
pred_intrinsics = outs["K"]
depth = outs["depth"]
confidence = outs["confidence"]
# final 3D points backprojection
intrinsics = intrinsics if intrinsics is not None else pred_intrinsics
angles = generate_rays(intrinsics, (H, W))[-1]
angles = rearrange(angles, "b (h w) c -> b c h w", h=H, w=W)
points_3d = torch.cat((angles, depth), dim=1)
points_3d = spherical_zbuffer_to_euclidean(
points_3d.permute(0, 2, 3, 1)
).permute(0, 3, 1, 2)
outputs = {
"intrinsics": pred_intrinsics,
"points": points_3d,
"depth": depth,
"confidence": confidence,
}
return outputs
def load_pretrained(self, model_file):
device = (
torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
)
dict_model = torch.load(model_file, map_location=device)
if "model" in dict_model:
dict_model = dict_model["model"]
dict_model = deepcopy(
{k.replace("module.", ""): v for k, v in dict_model.items()}
)
info = self.load_state_dict(dict_model, strict=False)
if is_main_process():
print(
f"Loaded from {model_file} for {self.__class__.__name__} results in:",
info,
)
@property
def device(self):
return next(self.parameters()).device
def build(self, config):
mod = importlib.import_module("unidepth.models.encoder")
pixel_encoder_factory = getattr(mod, config["model"]["pixel_encoder"]["name"])
pixel_encoder_config = {
**config["training"],
**config["model"]["pixel_encoder"],
**config["data"],
}
pixel_encoder = pixel_encoder_factory(pixel_encoder_config)
config["model"]["pixel_encoder"]["patch_size"] = (
14 if "dino" in config["model"]["pixel_encoder"]["name"] else 16
)
pixel_encoder_embed_dims = (
pixel_encoder.embed_dims
if hasattr(pixel_encoder, "embed_dims")
else [getattr(pixel_encoder, "embed_dim") * 2**i for i in range(4)]
)
config["model"]["pixel_encoder"]["embed_dim"] = getattr(
pixel_encoder, "embed_dim"
)
config["model"]["pixel_encoder"]["embed_dims"] = pixel_encoder_embed_dims
config["model"]["pixel_encoder"]["depths"] = pixel_encoder.depths
pixel_decoder = Decoder(config)
self.pixel_encoder = pixel_encoder
self.pixel_decoder = pixel_decoder
stacking_fn = config["model"]["pixel_encoder"]["stacking_fn"]
assert (
stacking_fn in STACKING_FNS
), f"Stacking function {stacking_fn} not found in {STACKING_FNS.keys()}"
self.stacking_fn = STACKING_FNS[stacking_fn]
self.slices_encoder_range = list(
zip([0, *pixel_encoder.depths[:-1]], pixel_encoder.depths)
)
self.shape_constraints = config["data"]["shape_constraints"]
self.shape_constraints["pixels_bounds_ori"] = self.shape_constraints.get(
"pixels_bounds", [1400, 2400]
)
self.interpolation_mode = "bilinear"
self.eps = 1e-6
self.resolution_level = None
def build_losses(self, config):
self.losses = {}
for loss_name, loss_config in config["training"]["losses"].items():
mod = importlib.import_module("unidepth.ops.losses")
loss_factory = getattr(mod, loss_config["name"])
self.losses[loss_name] = loss_factory.build(loss_config)
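if __name__ == "__main__":
    # Small self-contained check of the inference-shape helpers above; the constraint
    # values are illustrative, not taken from a released config.
    constraints = {
        "ratio_bounds": [0.5, 2.5],
        "pixels_bounds": [1400, 2400],
        "pixels_bounds_ori": [1400, 2400],
        "patch_size": 14,
    }
    constraints = _check_resolution(constraints, resolution_level=5)
    (h, w), ratio = _shapes((480, 640), constraints)
    print((h, w), ratio)  # network input size (multiples of patch_size) and resize ratio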
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/ops/__init__.py
================================================
from .losses import (ARel, Confidence, Dummy, EdgeGuidedLocalSSI, LocalSSI,
Regression, SelfDistill, SILog, TeacherDistill)
from .scheduler import CosineScheduler, PlainCosineScheduler
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/ops/extract_patches/__init__.py
================================================
from .functions import ExtractPatchesFunction
from .modules import RandomPatchExtractor
__all__ = ["ExtractPatchesFunction", "RandomPatchExtractor"]
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/ops/extract_patches/compile.sh
================================================
#!/usr/bin/env bash
if [ -z "$TORCH_CUDA_ARCH_LIST" ]; then
export TORCH_CUDA_ARCH_LIST="7.5 8.0 8.6+PTX"
fi
python setup.py build install
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/ops/extract_patches/functions/__init__.py
================================================
from .extract_patches import ExtractPatchesFunction
__all__ = ["ExtractPatchesFunction"]
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/ops/extract_patches/functions/extract_patches.py
================================================
import RandomPatchExtraction
import torch
from torch.autograd import Function
class ExtractPatchesFunction(Function):
@staticmethod
def forward(ctx, input, centers, h, w):
# Save variables for backward pass. inputs for shapes
ctx.save_for_backward(input, centers)
return RandomPatchExtraction.extract_patches_forward(input, centers, h, w)
@staticmethod
def backward(ctx, grad_output):
input, centers = ctx.saved_tensors
(grad_input,) = RandomPatchExtraction.extract_patches_backward(
grad_output, centers, input.shape[2], input.shape[3]
)
# breakpoint()
# Return gradients with respect to inputs only
return grad_input, None, None, None
# Test
if __name__ == "__main__":
B, C, H, W = 1, 1, 10, 10
N = 2
h, w = 3, 3
input = torch.arange(
B * C * H * W, device="cuda", dtype=torch.float32, requires_grad=True
).view(B, C, H, W)
centers = torch.tensor([[[4, 4], [6, 6]]], device="cuda", dtype=torch.int32)
output = ExtractPatchesFunction.apply(input, centers, h, w)
output.mean().backward()
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/ops/extract_patches/modules/__init__.py
================================================
from .patch_extractor import RandomPatchExtractor
__all__ = ["RandomPatchExtractor"]
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/ops/extract_patches/modules/patch_extractor.py
================================================
from __future__ import absolute_import, division, print_function
import torch
import torch.nn.functional as F
from torch import nn
from ..functions import ExtractPatchesFunction
class RandomPatchExtractor(nn.Module):
def __init__(
self,
):
super().__init__()
def forward(
self, tensor: torch.Tensor, centers: torch.Tensor, patch_size: tuple[int, int]
):
device = tensor.device
dtype = tensor.dtype
patch_width, patch_height = patch_size
pad_width = patch_width // 2
pad_height = patch_height // 2
dtype = tensor.dtype
# Pad input to avoid out-of-bounds
tensor_padded = F.pad(
tensor,
(pad_width, pad_width, pad_height, pad_height),
mode="constant",
value=0.0,
)
# Shift the center coordinates to account for the padding
centers_padded = centers + torch.tensor(
[pad_height, pad_width], dtype=dtype, device=device
).reshape(1, 1, 2)
output = ExtractPatchesFunction.apply(
tensor_padded.float(), centers_padded.int(), patch_height, patch_width
)
return output.to(dtype)
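if __name__ == "__main__":
    # Minimal usage sketch; requires the compiled `RandomPatchExtraction` CUDA
    # extension (see compile.sh / setup.py in this directory) and a CUDA device.
    feats = torch.rand(2, 16, 64, 64, device="cuda")
    # one (y, x) patch center per sample, int32 as expected by the kernel
    centers = torch.tensor([[[10, 12]], [[30, 40]]], device="cuda", dtype=torch.int32)
    patches = RandomPatchExtractor()(feats, centers, (5, 7))  # (patch_width, patch_height)
    print(patches.shape)  # (B, C, N, patch_height, patch_width) == (2, 16, 1, 7, 5)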
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/ops/extract_patches/setup.py
================================================
import glob
import os
import torch
from setuptools import find_packages, setup
from torch.utils.cpp_extension import CUDA_HOME, CppExtension, CUDAExtension
requirements = ["torch", "torchvision"]
def get_extensions():
this_dir = os.path.dirname(os.path.abspath(__file__))
extensions_dir = os.path.join(this_dir, "src")
main_file = glob.glob(os.path.join(extensions_dir, "*.cpp"))
source_cpu = glob.glob(os.path.join(extensions_dir, "cpu", "*.cpp"))
source_cuda = glob.glob(os.path.join(extensions_dir, "cuda", "*.cu"))
sources = main_file + source_cpu
extension = CppExtension
extra_compile_args = {"cxx": ["-O2"]}
define_macros = []
if torch.cuda.is_available() and CUDA_HOME is not None:
extension = CUDAExtension
sources += source_cuda
define_macros += [("WITH_CUDA", None)]
extra_compile_args["nvcc"] = [
"-O2",
]
else:
raise NotImplementedError("Cuda is not available")
sources = list(set([os.path.join(extensions_dir, s) for s in sources]))
include_dirs = [extensions_dir]
ext_modules = [
extension(
"RandomPatchExtraction",
sources,
include_dirs=include_dirs,
define_macros=define_macros,
extra_compile_args=extra_compile_args,
)
]
return ext_modules
setup(
name="RandomPatchExtraction",
version="0.1",
author="Luigi Piccinelli",
ext_modules=get_extensions(),
packages=find_packages(
exclude=(
"configs",
"tests",
)
),
cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension},
)
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/ops/extract_patches/src/cpu/extract_patches_cpu.cpp
================================================
#include
#include
#include
torch::Tensor extract_patches_cpu_forward(
const torch::Tensor &input,
const torch::Tensor ¢ers,
int h,
int w
) {
AT_ERROR("Not implement on cpu");
}
std::vector<torch::Tensor> extract_patches_cpu_backward(
const torch::Tensor &grad_patches,
const torch::Tensor &coords,
int H,
int W
) {
AT_ERROR("Not implement on cpu");
}
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/ops/extract_patches/src/cpu/extract_patches_cpu.h
================================================
#pragma once
#include
#include
torch::Tensor extract_patches_cpu_forward(
const torch::Tensor &input,
const torch::Tensor ¢ers,
int h,
int w
);
std::vector<torch::Tensor> extract_patches_cpu_backward(
const torch::Tensor &grad_patches,
const torch::Tensor &coords,
int H,
int W
);
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/ops/extract_patches/src/cuda/extract_patches_cuda.h
================================================
#ifndef EXTRACT_PATCHES_CUDA_H
#define EXTRACT_PATCHES_CUDA_H
#include
#include
#include
#include
// Function prototypes for the CUDA functions
torch::Tensor extract_patches_cuda_forward(
const torch::Tensor &input,
const torch::Tensor ¢ers,
int h,
int w
);
std::vector<torch::Tensor> extract_patches_cuda_backward(
const torch::Tensor &grad_output,
const torch::Tensor ¢ers,
int H,
int W
);
#endif // EXTRACT_PATCHES_CUDA_H
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/ops/extract_patches/src/cuda/extract_patches_kernel.cu
================================================
#include
#include
#include "cuda/extract_patches_kernel.cuh"
#include "cuda/extract_patches_cuda.h"
// Need to templatize these two to get fp16 working, but there are compilation problems...
torch::Tensor extract_patches_cuda_forward(
const torch::Tensor &input,
const torch::Tensor ¢ers,
int h,
int w
) {
int B = input.size(0);
int C = input.size(1);
int H = input.size(2);
int W = input.size(3);
int N = centers.size(1);
auto output = torch::zeros({B, C, N, h, w}, input.options());
const int threads = C;
const dim3 blocks(B, N);
extract_patches_cuda_forward_kernel<float><<<blocks, threads>>>(
input.data_ptr<float>(),
output.data_ptr<float>(),
centers.data_ptr<int>(),
B, C, H, W,
N, h, w);
return {output};
}
std::vector<torch::Tensor> extract_patches_cuda_backward(
const torch::Tensor &grad_output,
const torch::Tensor ¢ers,
int H,
int W
) {
int B = grad_output.size(0);
int C = grad_output.size(1);
int N = centers.size(1);
int h = grad_output.size(3);
int w = grad_output.size(4);
auto grad_input = torch::zeros({B, C, H, W}, grad_output.options());
const int threads = C;
const dim3 blocks(B, N);
extract_patches_cuda_backward_kernel<float><<<blocks, threads>>>(
grad_output.data_ptr<float>(),
grad_input.data_ptr<float>(),
centers.data_ptr<int>(),
B, C, H, W,
N, h, w);
return {grad_input};
}
template <typename T>
__global__ void extract_patches_cuda_forward_kernel(
const T* __restrict__ input,
T* __restrict__ output,
const int* __restrict__ centers,
int B, int C, int H, int W,
int N, int h, int w) {
// Calculate thread indices
int batch_idx = blockIdx.x;
int patch_idx = blockIdx.y;
int channel_idx = threadIdx.x;
// Extract center coordinates
int center_y = centers[(batch_idx * N + patch_idx) * 2];
int center_x = centers[(batch_idx * N + patch_idx) * 2 + 1];
// Calculate half patch size
int half_h = h / 2;
int half_w = w / 2;
// Extract patch
for (int i = 0; i < h; ++i) {
for (int j = 0; j < w; ++j) {
int y = center_y - half_h + i;
int x = center_x - half_w + j;
output[batch_idx * C * N * h * w + patch_idx * C * h * w + channel_idx * h * w + i * w + j] =
input[batch_idx * C * H * W + channel_idx * H * W + y * W + x];
}
}
}
template __global__ void extract_patches_cuda_forward_kernel<float>(
const float* __restrict__ input,
float* __restrict__ output,
const int* __restrict__ centers,
int B, int C, int H, int W,
int N, int h, int w);
template __global__ void extract_patches_cuda_forward_kernel<__half>(
const __half* __restrict__ input,
__half* __restrict__ output,
const int* __restrict__ centers,
int B, int C, int H, int W,
int N, int h, int w);
template <typename T>
__global__ void extract_patches_cuda_backward_kernel(
const T* __restrict__ grad_output,
T* __restrict__ grad_input,
const int* __restrict__ centers,
int B, int C, int H, int W,
int N, int h, int w) {
// Calculate thread indices
int batch_idx = blockIdx.x;
int patch_idx = blockIdx.y;
int channel_idx = threadIdx.x;
// Extract center coordinates
int center_y = centers[(batch_idx * N + patch_idx) * 2];
int center_x = centers[(batch_idx * N + patch_idx) * 2 + 1];
// Calculate half patch size
int half_h = h / 2;
int half_w = w / 2;
// Compute gradients with respect to input tensor using chain rule
for (int i = 0; i < h; ++i) {
for (int j = 0; j < w; ++j) {
int y = center_y - half_h + i;
int x = center_x - half_w + j;
atomicAdd(
&grad_input[batch_idx * C * H * W + channel_idx * H * W + y * W + x],
grad_output[batch_idx * C * N * h * w + patch_idx * C * h * w + channel_idx * h * w + i * w + j]
);
}
}
}
template __global__ void extract_patches_cuda_backward_kernel<float>(
const float* __restrict__ grad_output,
float* __restrict__ grad_input,
const int* __restrict__ centers,
int B, int C, int H, int W,
int N, int h, int w);
template __global__ void extract_patches_cuda_backward_kernel<__half>(
const __half* __restrict__ grad_output,
__half* __restrict__ grad_input,
const int* __restrict__ centers,
int B, int C, int H, int W,
int N, int h, int w);
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/ops/extract_patches/src/cuda/extract_patches_kernel.cuh
================================================
#ifndef EXTRACT_PATCHES_KERNEL_CUH
#define EXTRACT_PATCHES_KERNEL_CUH
#include
#include
#include
#include
#include <cuda_fp16.h> // should contain __half
// Declare the forward CUDA kernel function
template __global__ void extract_patches_cuda_forward_kernel(
const T* __restrict__ input,
T* __restrict__ output,
const int* __restrict__ centers,
int B, int C, int H, int W,
int N, int h, int w);
// Declare the backward CUDA kernel function
template __global__ void extract_patches_cuda_backward_kernel(
const T* __restrict__ grad_output,
T* __restrict__ grad_input,
const int* __restrict__ centers,
int B, int C, int H, int W,
int N, int h, int w);
#endif // EXTRACT_PATCHES_KERNEL_CUH
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/ops/extract_patches/src/extract_patches.cpp
================================================
#include "extract_patches.h"
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("extract_patches_forward", &extract_patches_forward, "Extract patches forward (CUDA)");
m.def("extract_patches_backward", &extract_patches_backward, "Extract patches backward (CUDA)");
}
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/ops/extract_patches/src/extract_patches.h
================================================
#pragma once
#include "cpu/extract_patches_cpu.h"
#ifdef WITH_CUDA
#include "cuda/extract_patches_cuda.h"
#endif
#include
#include
#include
torch::Tensor extract_patches_forward(
const torch::Tensor &images,
const torch::Tensor &coords,
int patch_height,
int patch_width)
{
if (images.type().is_cuda())
{
#ifdef WITH_CUDA
return extract_patches_cuda_forward(images, coords, patch_height, patch_width);
#else
AT_ERROR("Not compiled with GPU support");
#endif
}
AT_ERROR("Not implemented on the CPU");
}
std::vector<torch::Tensor> extract_patches_backward(
const torch::Tensor &grad_patches,
const torch::Tensor &coords,
int H,
int W)
{
if (grad_patches.type().is_cuda())
{
#ifdef WITH_CUDA
return extract_patches_cuda_backward(grad_patches, coords, H, W);
#else
AT_ERROR("Not compiled with GPU support");
#endif
}
AT_ERROR("Not implemented on the CPU");
}
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/ops/extract_patches/test.py
================================================
import RandomPatchExtraction
import torch
def extract_patches(input, centers, patch_size):
h, w = patch_size
output = RandomPatchExtraction.extract_patches_forward(input, centers, h, w)
breakpoint()
return output
# Example usage
if __name__ == "__main__":
B, C, H, W = 1, 1, 10, 10
N = 2
h, w = 3, 3
input = torch.arange(
B * C * H * W, device="cuda", dtype=torch.float32, requires_grad=True
).view(B, C, H, W)
centers = torch.tensor([[[4, 4], [6, 6]]], device="cuda", dtype=torch.int32)
patches = extract_patches(input, centers, (h, w))
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/ops/knn/__init__.py
================================================
from .functions.knn import knn_gather, knn_points
__all__ = [
"knn_points",
"knn_gather",
]
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/ops/knn/compile.sh
================================================
#!/usr/bin/env bash
export TORCH_CUDA_ARCH_LIST="6.1 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
# export FORCE_CUDA=1 #if you do not actually have cuda, workaround
python setup.py build install
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/ops/knn/functions/__init__.py
================================================
from .knn import knn_gather, knn_points
__all__ = [
"knn_points",
"knn_gather",
]
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/ops/knn/functions/knn.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# pyre-unsafe
from collections import namedtuple
from typing import Union
import torch
from KNN import knn_points_backward, knn_points_idx
from torch.autograd import Function
from torch.autograd.function import once_differentiable
_KNN = namedtuple("KNN", "dists idx knn")
class _knn_points(Function):
"""
Torch autograd Function wrapper for KNN C++/CUDA implementations.
"""
@staticmethod
# pyre-fixme[14]: `forward` overrides method defined in `Function` inconsistently.
def forward(
ctx,
p1,
p2,
lengths1,
lengths2,
K,
version,
norm: int = 2,
return_sorted: bool = True,
):
"""
K-Nearest neighbors on point clouds.
Args:
p1: Tensor of shape (N, P1, D) giving a batch of N point clouds, each
containing up to P1 points of dimension D.
p2: Tensor of shape (N, P2, D) giving a batch of N point clouds, each
containing up to P2 points of dimension D.
lengths1: LongTensor of shape (N,) of values in the range [0, P1], giving the
length of each pointcloud in p1. Or None to indicate that every cloud has
length P1.
lengths2: LongTensor of shape (N,) of values in the range [0, P2], giving the
length of each pointcloud in p2. Or None to indicate that every cloud has
length P2.
K: Integer giving the number of nearest neighbors to return.
version: Which KNN implementation to use in the backend. If version=-1,
the correct implementation is selected based on the shapes of the inputs.
norm: (int) indicating the norm. Only supports 1 (for L1) and 2 (for L2).
return_sorted: (bool) whether to return the nearest neighbors sorted in
ascending order of distance.
Returns:
p1_dists: Tensor of shape (N, P1, K) giving the squared distances to
the nearest neighbors. This is padded with zeros both where a cloud in p2
has fewer than K points and where a cloud in p1 has fewer than P1 points.
p1_idx: LongTensor of shape (N, P1, K) giving the indices of the
K nearest neighbors from points in p1 to points in p2.
Concretely, if `p1_idx[n, i, k] = j` then `p2[n, j]` is the k-th nearest
neighbors to `p1[n, i]` in `p2[n]`. This is padded with zeros both where a cloud
in p2 has fewer than K points and where a cloud in p1 has fewer than P1 points.
"""
if not ((norm == 1) or (norm == 2)):
raise ValueError("Support for 1 or 2 norm.")
idx, dists = knn_points_idx(p1, p2, lengths1, lengths2, norm, K, version)
# sort KNN in ascending order if K > 1
if K > 1 and return_sorted:
if lengths2.min() < K:
P1 = p1.shape[1]
mask = lengths2[:, None] <= torch.arange(K, device=dists.device)[None]
# mask has shape [N, K], true where dists irrelevant
mask = mask[:, None].expand(-1, P1, -1)
# mask has shape [N, P1, K], true where dists irrelevant
dists[mask] = float("inf")
dists, sort_idx = dists.sort(dim=2)
dists[mask] = 0
else:
dists, sort_idx = dists.sort(dim=2)
idx = idx.gather(2, sort_idx)
ctx.save_for_backward(p1, p2, lengths1, lengths2, idx)
ctx.mark_non_differentiable(idx)
ctx.norm = norm
return dists, idx
@staticmethod
@once_differentiable
def backward(ctx, grad_dists, grad_idx):
p1, p2, lengths1, lengths2, idx = ctx.saved_tensors
norm = ctx.norm
# TODO(gkioxari) Change cast to floats once we add support for doubles.
if not (grad_dists.dtype == torch.float32):
grad_dists = grad_dists.float()
if not (p1.dtype == torch.float32):
p1 = p1.float()
if not (p2.dtype == torch.float32):
p2 = p2.float()
grad_p1, grad_p2 = knn_points_backward(
p1, p2, lengths1, lengths2, idx, norm, grad_dists
)
return grad_p1, grad_p2, None, None, None, None, None, None
def knn_points(
p1: torch.Tensor,
p2: torch.Tensor,
lengths1: Union[torch.Tensor, None] = None,
lengths2: Union[torch.Tensor, None] = None,
norm: int = 2,
K: int = 1,
version: int = -1,
return_nn: bool = False,
return_sorted: bool = True,
) -> _KNN:
"""
K-Nearest neighbors on point clouds.
Args:
p1: Tensor of shape (N, P1, D) giving a batch of N point clouds, each
containing up to P1 points of dimension D.
p2: Tensor of shape (N, P2, D) giving a batch of N point clouds, each
containing up to P2 points of dimension D.
lengths1: LongTensor of shape (N,) of values in the range [0, P1], giving the
length of each pointcloud in p1. Or None to indicate that every cloud has
length P1.
lengths2: LongTensor of shape (N,) of values in the range [0, P2], giving the
length of each pointcloud in p2. Or None to indicate that every cloud has
length P2.
norm: Integer indicating the norm of the distance. Supports only 1 for L1, 2 for L2.
K: Integer giving the number of nearest neighbors to return.
version: Which KNN implementation to use in the backend. If version=-1,
the correct implementation is selected based on the shapes of the inputs.
return_nn: If set to True returns the K nearest neighbors in p2 for each point in p1.
return_sorted: (bool) whether to return the nearest neighbors sorted in
ascending order of distance.
Returns:
dists: Tensor of shape (N, P1, K) giving the squared distances to
the nearest neighbors. This is padded with zeros both where a cloud in p2
has fewer than K points and where a cloud in p1 has fewer than P1 points.
idx: LongTensor of shape (N, P1, K) giving the indices of the
K nearest neighbors from points in p1 to points in p2.
Concretely, if `p1_idx[n, i, k] = j` then `p2[n, j]` is the k-th nearest
neighbors to `p1[n, i]` in `p2[n]`. This is padded with zeros both where a cloud
in p2 has fewer than K points and where a cloud in p1 has fewer than P1
points.
nn: Tensor of shape (N, P1, K, D) giving the K nearest neighbors in p2 for
each point in p1. Concretely, `p2_nn[n, i, k]` gives the k-th nearest neighbor
for `p1[n, i]`. Returned if `return_nn` is True.
The nearest neighbors are collected using `knn_gather`
.. code-block::
p2_nn = knn_gather(p2, p1_idx, lengths2)
which is a helper function that allows indexing any tensor of shape (N, P2, U) with
the indices `p1_idx` returned by `knn_points`. The output is a tensor
of shape (N, P1, K, U).
"""
if p1.shape[0] != p2.shape[0]:
raise ValueError("pts1 and pts2 must have the same batch dimension.")
if p1.shape[2] != p2.shape[2]:
raise ValueError("pts1 and pts2 must have the same point dimension.")
p1 = p1.contiguous()
p2 = p2.contiguous()
P1 = p1.shape[1]
P2 = p2.shape[1]
if lengths1 is None:
lengths1 = torch.full((p1.shape[0],), P1, dtype=torch.int64, device=p1.device)
if lengths2 is None:
lengths2 = torch.full((p1.shape[0],), P2, dtype=torch.int64, device=p1.device)
p1_dists, p1_idx = _knn_points.apply(
p1, p2, lengths1, lengths2, K, version, norm, return_sorted
)
p2_nn = None
if return_nn:
p2_nn = knn_gather(p2, p1_idx, lengths2)
return _KNN(dists=p1_dists, idx=p1_idx, knn=p2_nn if return_nn else None)
def knn_gather(
x: torch.Tensor, idx: torch.Tensor, lengths: Union[torch.Tensor, None] = None
):
"""
A helper function for knn that allows indexing a tensor x with the indices `idx`
returned by `knn_points`.
For example, if `dists, idx = knn_points(p, x, lengths_p, lengths, K)`
where p is a tensor of shape (N, L, D) and x a tensor of shape (N, M, D),
then one can compute the K nearest neighbors of p with `p_nn = knn_gather(x, idx, lengths)`.
It can also be applied for any tensor x of shape (N, M, U) where U != D.
Args:
x: Tensor of shape (N, M, U) containing U-dimensional features to
be gathered.
idx: LongTensor of shape (N, L, K) giving the indices returned by `knn_points`.
lengths: LongTensor of shape (N,) of values in the range [0, M], giving the
length of each example in the batch in x. Or None to indicate that every
example has length M.
Returns:
x_out: Tensor of shape (N, L, K, U) resulting from gathering the elements of x
with idx, s.t. `x_out[n, l, k] = x[n, idx[n, l, k]]`.
If `k > lengths[n]` then `x_out[n, l, k]` is filled with 0.0.
"""
N, M, U = x.shape
_N, L, K = idx.shape
if N != _N:
raise ValueError("x and idx must have same batch dimension.")
if lengths is None:
lengths = torch.full((x.shape[0],), M, dtype=torch.int64, device=x.device)
idx_expanded = idx[:, :, :, None].expand(-1, -1, -1, U)
# idx_expanded has shape [N, L, K, U]
x_out = x[:, :, None].expand(-1, -1, K, -1).gather(1, idx_expanded)
# p2_nn has shape [N, L, K, U]
needs_mask = lengths.min() < K
if needs_mask:
# mask has shape [N, K], true where idx is irrelevant because
# there are fewer points in p2 than K
mask = lengths[:, None] <= torch.arange(K, device=x.device)[None]
# expand mask to shape [N, L, K, U]
mask = mask[:, None].expand(-1, L, -1)
mask = mask[:, :, :, None].expand(-1, -1, -1, U)
x_out[mask] = 0.0
return x_out
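if __name__ == "__main__":
    # Minimal usage sketch; requires the compiled `KNN` CUDA extension (see
    # compile.sh / setup.py in this directory) and a CUDA device.
    p1 = torch.rand(2, 128, 3, device="cuda")  # query point clouds
    p2 = torch.rand(2, 256, 3, device="cuda")  # reference point clouds
    dists, idx, nn = knn_points(p1, p2, K=4, return_nn=True)
    print(dists.shape, idx.shape, nn.shape)  # (2, 128, 4) (2, 128, 4) (2, 128, 4, 3)
    # the same indices can gather arbitrary per-point features of p2
    feats = torch.rand(2, 256, 16, device="cuda")
    print(knn_gather(feats, idx).shape)  # (2, 128, 4, 16)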
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/ops/knn/setup.py
================================================
import glob
import os
import torch
from setuptools import find_packages, setup
from torch.utils.cpp_extension import CUDA_HOME, CppExtension, CUDAExtension
requirements = ["torch", "torchvision"]
def get_extensions():
this_dir = os.path.dirname(os.path.abspath(__file__))
extensions_dir = os.path.join(this_dir, "src")
main_file = glob.glob(os.path.join(extensions_dir, "*.cpp"))
source_cpu = glob.glob(os.path.join(extensions_dir, "*.cpp"))
source_cuda = glob.glob(os.path.join(extensions_dir, "*.cu"))
sources = main_file + source_cpu
extension = CppExtension
extra_compile_args = {"cxx": ["-O3"]}
define_macros = []
if torch.cuda.is_available() and CUDA_HOME is not None:
extension = CUDAExtension
sources += source_cuda
define_macros += [("WITH_CUDA", None)]
extra_compile_args["nvcc"] = [
"-O3",
]
else:
raise NotImplementedError("Cuda is not available")
sources = list(set([os.path.join(extensions_dir, s) for s in sources]))
include_dirs = [extensions_dir]
ext_modules = [
extension(
"KNN",
sources,
include_dirs=include_dirs,
define_macros=define_macros,
extra_compile_args=extra_compile_args,
)
]
return ext_modules
setup(
name="KNN",
version="0.1",
author="Luigi Piccinelli",
ext_modules=get_extensions(),
packages=find_packages(
exclude=(
"configs",
"tests",
)
),
cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension},
)
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/ops/knn/src/knn.cu
================================================
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#include
#include
#include
#include
#include
#include
#include "utils/dispatch.cuh"
#include "utils/mink.cuh"
// A chunk of work is blocksize-many points of P1.
// The number of potential chunks to do is N*(1+(P1-1)/blocksize)
// call (1+(P1-1)/blocksize) chunks_per_cloud
// These chunks are divided among the gridSize-many blocks.
// In block b, we work on chunks b, b+gridSize, b+2*gridSize etc .
// In chunk i, we work on cloud i/chunks_per_cloud on points starting from
// blocksize*(i%chunks_per_cloud).
template <typename scalar_t>
__global__ void KNearestNeighborKernelV0(
const scalar_t* __restrict__ points1,
const scalar_t* __restrict__ points2,
const int64_t* __restrict__ lengths1,
const int64_t* __restrict__ lengths2,
scalar_t* __restrict__ dists,
int64_t* __restrict__ idxs,
const size_t N,
const size_t P1,
const size_t P2,
const size_t D,
const size_t K,
const size_t norm) {
// Store both dists and indices for knn in global memory.
const int64_t chunks_per_cloud = (1 + (P1 - 1) / blockDim.x);
const int64_t chunks_to_do = N * chunks_per_cloud;
for (int64_t chunk = blockIdx.x; chunk < chunks_to_do; chunk += gridDim.x) {
const int64_t n = chunk / chunks_per_cloud;
const int64_t start_point = blockDim.x * (chunk % chunks_per_cloud);
int64_t p1 = start_point + threadIdx.x;
if (p1 >= lengths1[n])
continue;
int offset = n * P1 * K + p1 * K;
int64_t length2 = lengths2[n];
MinK<scalar_t, int64_t> mink(dists + offset, idxs + offset, K);
for (int p2 = 0; p2 < length2; ++p2) {
// Find the distance between points1[n, p1] and points[n, p2]
scalar_t dist = 0;
for (int d = 0; d < D; ++d) {
scalar_t coord1 = points1[n * P1 * D + p1 * D + d];
scalar_t coord2 = points2[n * P2 * D + p2 * D + d];
scalar_t diff = coord1 - coord2;
scalar_t norm_diff = (norm == 2) ? (diff * diff) : abs(diff);
dist += norm_diff;
}
mink.add(dist, p2);
}
}
}
template <typename scalar_t, int64_t D>
__global__ void KNearestNeighborKernelV1(
const scalar_t* __restrict__ points1,
const scalar_t* __restrict__ points2,
const int64_t* __restrict__ lengths1,
const int64_t* __restrict__ lengths2,
scalar_t* __restrict__ dists,
int64_t* __restrict__ idxs,
const size_t N,
const size_t P1,
const size_t P2,
const size_t K,
const size_t norm) {
// Same idea as the previous version, but hoist D into a template argument
// so we can cache the current point in a thread-local array. We still store
// the current best K dists and indices in global memory, so this should work
// for very large K and fairly large D.
scalar_t cur_point[D];
const int64_t chunks_per_cloud = (1 + (P1 - 1) / blockDim.x);
const int64_t chunks_to_do = N * chunks_per_cloud;
for (int64_t chunk = blockIdx.x; chunk < chunks_to_do; chunk += gridDim.x) {
const int64_t n = chunk / chunks_per_cloud;
const int64_t start_point = blockDim.x * (chunk % chunks_per_cloud);
int64_t p1 = start_point + threadIdx.x;
if (p1 >= lengths1[n])
continue;
for (int d = 0; d < D; ++d) {
cur_point[d] = points1[n * P1 * D + p1 * D + d];
}
int offset = n * P1 * K + p1 * K;
int64_t length2 = lengths2[n];
MinK<scalar_t, int64_t> mink(dists + offset, idxs + offset, K);
for (int p2 = 0; p2 < length2; ++p2) {
// Find the distance between cur_point and points[n, p2]
scalar_t dist = 0;
for (int d = 0; d < D; ++d) {
scalar_t diff = cur_point[d] - points2[n * P2 * D + p2 * D + d];
scalar_t norm_diff = (norm == 2) ? (diff * diff) : abs(diff);
dist += norm_diff;
}
mink.add(dist, p2);
}
}
}
// This is a shim functor to allow us to dispatch using DispatchKernel1D
template <typename scalar_t, int64_t D>
struct KNearestNeighborV1Functor {
static void run(
size_t blocks,
size_t threads,
const scalar_t* __restrict__ points1,
const scalar_t* __restrict__ points2,
const int64_t* __restrict__ lengths1,
const int64_t* __restrict__ lengths2,
scalar_t* __restrict__ dists,
int64_t* __restrict__ idxs,
const size_t N,
const size_t P1,
const size_t P2,
const size_t K,
const size_t norm) {
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
KNearestNeighborKernelV1<scalar_t, D><<<blocks, threads, 0, stream>>>(
points1, points2, lengths1, lengths2, dists, idxs, N, P1, P2, K, norm);
}
};
template <typename scalar_t, int64_t D, int64_t K>
__global__ void KNearestNeighborKernelV2(
const scalar_t* __restrict__ points1,
const scalar_t* __restrict__ points2,
const int64_t* __restrict__ lengths1,
const int64_t* __restrict__ lengths2,
scalar_t* __restrict__ dists,
int64_t* __restrict__ idxs,
const int64_t N,
const int64_t P1,
const int64_t P2,
const size_t norm) {
// Same general implementation as V1, but also hoist K into a template arg.
scalar_t cur_point[D];
scalar_t min_dists[K];
int min_idxs[K];
const int64_t chunks_per_cloud = (1 + (P1 - 1) / blockDim.x);
const int64_t chunks_to_do = N * chunks_per_cloud;
for (int64_t chunk = blockIdx.x; chunk < chunks_to_do; chunk += gridDim.x) {
const int64_t n = chunk / chunks_per_cloud;
const int64_t start_point = blockDim.x * (chunk % chunks_per_cloud);
int64_t p1 = start_point + threadIdx.x;
if (p1 >= lengths1[n])
continue;
for (int d = 0; d < D; ++d) {
cur_point[d] = points1[n * P1 * D + p1 * D + d];
}
int64_t length2 = lengths2[n];
MinK<scalar_t, int> mink(min_dists, min_idxs, K);
for (int p2 = 0; p2 < length2; ++p2) {
scalar_t dist = 0;
for (int d = 0; d < D; ++d) {
int offset = n * P2 * D + p2 * D + d;
scalar_t diff = cur_point[d] - points2[offset];
scalar_t norm_diff = (norm == 2) ? (diff * diff) : abs(diff);
dist += norm_diff;
}
mink.add(dist, p2);
}
for (int k = 0; k < mink.size(); ++k) {
idxs[n * P1 * K + p1 * K + k] = min_idxs[k];
dists[n * P1 * K + p1 * K + k] = min_dists[k];
}
}
}
// This is a shim so we can dispatch using DispatchKernel2D
template <typename scalar_t, int64_t D, int64_t K>
struct KNearestNeighborKernelV2Functor {
static void run(
size_t blocks,
size_t threads,
const scalar_t* __restrict__ points1,
const scalar_t* __restrict__ points2,
const int64_t* __restrict__ lengths1,
const int64_t* __restrict__ lengths2,
scalar_t* __restrict__ dists,
int64_t* __restrict__ idxs,
const int64_t N,
const int64_t P1,
const int64_t P2,
const size_t norm) {
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
KNearestNeighborKernelV2<scalar_t, D, K><<<blocks, threads, 0, stream>>>(
points1, points2, lengths1, lengths2, dists, idxs, N, P1, P2, norm);
}
};
template <typename scalar_t, int64_t D, int64_t K>
__global__ void KNearestNeighborKernelV3(
const scalar_t* __restrict__ points1,
const scalar_t* __restrict__ points2,
const int64_t* __restrict__ lengths1,
const int64_t* __restrict__ lengths2,
scalar_t* __restrict__ dists,
int64_t* __restrict__ idxs,
const size_t N,
const size_t P1,
const size_t P2,
const size_t norm) {
// Same idea as V2, but use register indexing for thread-local arrays.
// Enabling sorting for this version leads to huge slowdowns; I suspect
// that it forces min_dists into local memory rather than registers.
// As a result this version is always unsorted.
scalar_t cur_point[D];
scalar_t min_dists[K];
int min_idxs[K];
const int64_t chunks_per_cloud = (1 + (P1 - 1) / blockDim.x);
const int64_t chunks_to_do = N * chunks_per_cloud;
for (int64_t chunk = blockIdx.x; chunk < chunks_to_do; chunk += gridDim.x) {
const int64_t n = chunk / chunks_per_cloud;
const int64_t start_point = blockDim.x * (chunk % chunks_per_cloud);
int64_t p1 = start_point + threadIdx.x;
if (p1 >= lengths1[n])
continue;
for (int d = 0; d < D; ++d) {
cur_point[d] = points1[n * P1 * D + p1 * D + d];
}
int64_t length2 = lengths2[n];
RegisterMinK<scalar_t, int, K> mink(min_dists, min_idxs);
for (int p2 = 0; p2 < length2; ++p2) {
scalar_t dist = 0;
for (int d = 0; d < D; ++d) {
int offset = n * P2 * D + p2 * D + d;
scalar_t diff = cur_point[d] - points2[offset];
scalar_t norm_diff = (norm == 2) ? (diff * diff) : abs(diff);
dist += norm_diff;
}
mink.add(dist, p2);
}
for (int k = 0; k < mink.size(); ++k) {
idxs[n * P1 * K + p1 * K + k] = min_idxs[k];
dists[n * P1 * K + p1 * K + k] = min_dists[k];
}
}
}
// This is a shim so we can dispatch using DispatchKernel2D
template <typename scalar_t, int64_t D, int64_t K>
struct KNearestNeighborKernelV3Functor {
static void run(
size_t blocks,
size_t threads,
const scalar_t* __restrict__ points1,
const scalar_t* __restrict__ points2,
const int64_t* __restrict__ lengths1,
const int64_t* __restrict__ lengths2,
scalar_t* __restrict__ dists,
int64_t* __restrict__ idxs,
const size_t N,
const size_t P1,
const size_t P2,
const size_t norm) {
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
KNearestNeighborKernelV3<scalar_t, D, K><<<blocks, threads, 0, stream>>>(
points1, points2, lengths1, lengths2, dists, idxs, N, P1, P2, norm);
}
};
constexpr int V1_MIN_D = 1;
constexpr int V1_MAX_D = 32;
constexpr int V2_MIN_D = 1;
constexpr int V2_MAX_D = 8;
constexpr int V2_MIN_K = 1;
constexpr int V2_MAX_K = 32;
constexpr int V3_MIN_D = 1;
constexpr int V3_MAX_D = 8;
constexpr int V3_MIN_K = 1;
constexpr int V3_MAX_K = 4;
bool InBounds(const int64_t min, const int64_t x, const int64_t max) {
return min <= x && x <= max;
}
bool KnnCheckVersion(int version, const int64_t D, const int64_t K) {
if (version == 0) {
return true;
} else if (version == 1) {
return InBounds(V1_MIN_D, D, V1_MAX_D);
} else if (version == 2) {
return InBounds(V2_MIN_D, D, V2_MAX_D) && InBounds(V2_MIN_K, K, V2_MAX_K);
} else if (version == 3) {
return InBounds(V3_MIN_D, D, V3_MAX_D) && InBounds(V3_MIN_K, K, V3_MAX_K);
}
return false;
}
int ChooseVersion(const int64_t D, const int64_t K) {
for (int version = 3; version >= 1; version--) {
if (KnnCheckVersion(version, D, K)) {
return version;
}
}
return 0;
}
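// Example: with D = 3 and K = 5 the V3 bounds fail (K > V3_MAX_K) but the V2
// bounds hold, so ChooseVersion returns 2; with D = 64 no specialized version
// is in range and the fully dynamic V0 kernel (version 0) is used.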
std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
const at::Tensor& p1,
const at::Tensor& p2,
const at::Tensor& lengths1,
const at::Tensor& lengths2,
const int norm,
const int K,
int version) {
// Check inputs are on the same device
at::TensorArg p1_t{p1, "p1", 1}, p2_t{p2, "p2", 2},
lengths1_t{lengths1, "lengths1", 3}, lengths2_t{lengths2, "lengths2", 4};
at::CheckedFrom c = "KNearestNeighborIdxCuda";
at::checkAllSameGPU(c, {p1_t, p2_t, lengths1_t, lengths2_t});
at::checkAllSameType(c, {p1_t, p2_t});
// Set the device for the kernel launch based on the device of the input
at::cuda::CUDAGuard device_guard(p1.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const auto N = p1.size(0);
const auto P1 = p1.size(1);
const auto P2 = p2.size(1);
const auto D = p2.size(2);
const int64_t K_64 = K;
TORCH_CHECK((norm == 1) || (norm == 2), "Norm must be 1 or 2.");
TORCH_CHECK(p1.size(2) == D, "Point sets must have the same last dimension");
auto long_dtype = lengths1.options().dtype(at::kLong);
auto idxs = at::zeros({N, P1, K}, long_dtype);
auto dists = at::zeros({N, P1, K}, p1.options());
if (idxs.numel() == 0) {
AT_CUDA_CHECK(cudaGetLastError());
return std::make_tuple(idxs, dists);
}
if (version < 0) {
version = ChooseVersion(D, K);
} else if (!KnnCheckVersion(version, D, K)) {
int new_version = ChooseVersion(D, K);
std::cout << "WARNING: Requested KNN version " << version
<< " is not compatible with D = " << D << "; K = " << K
<< ". Falling back to version = " << new_version << std::endl;
version = new_version;
}
// At this point we should have a valid version no matter what data the user
// gave us. But we can check once more to be sure; however this time
// assert fail since failing at this point means we have a bug in our version
// selection or checking code.
AT_ASSERTM(KnnCheckVersion(version, D, K), "Invalid version");
const size_t threads = 256;
const size_t blocks = 256;
if (version == 0) {
AT_DISPATCH_FLOATING_TYPES(
p1.scalar_type(), "knn_kernel_cuda", ([&] {
KNearestNeighborKernelV0<scalar_t><<<blocks, threads, 0, stream>>>(
p1.contiguous().data_ptr<scalar_t>(),
p2.contiguous().data_ptr<scalar_t>(),
lengths1.contiguous().data_ptr<int64_t>(),
lengths2.contiguous().data_ptr<int64_t>(),
dists.data_ptr<scalar_t>(),
idxs.data_ptr<int64_t>(),
N,
P1,
P2,
D,
K,
norm);
}));
} else if (version == 1) {
AT_DISPATCH_FLOATING_TYPES(p1.scalar_type(), "knn_kernel_cuda", ([&] {
DispatchKernel1D<
KNearestNeighborV1Functor,
scalar_t,
V1_MIN_D,
V1_MAX_D>(
D,
blocks,
threads,
p1.contiguous().data_ptr<scalar_t>(),
p2.contiguous().data_ptr<scalar_t>(),
lengths1.contiguous().data_ptr<int64_t>(),
lengths2.contiguous().data_ptr<int64_t>(),
dists.data_ptr<scalar_t>(),
idxs.data_ptr<int64_t>(),
N,
P1,
P2,
K,
norm);
}));
} else if (version == 2) {
AT_DISPATCH_FLOATING_TYPES(p1.scalar_type(), "knn_kernel_cuda", ([&] {
DispatchKernel2D<
KNearestNeighborKernelV2Functor,
scalar_t,
V2_MIN_D,
V2_MAX_D,
V2_MIN_K,
V2_MAX_K>(
D,
K_64,
blocks,
threads,
p1.contiguous().data_ptr<scalar_t>(),
p2.contiguous().data_ptr<scalar_t>(),
lengths1.contiguous().data_ptr<int64_t>(),
lengths2.contiguous().data_ptr<int64_t>(),
dists.data_ptr<scalar_t>(),
idxs.data_ptr<int64_t>(),
N,
P1,
P2,
norm);
}));
} else if (version == 3) {
AT_DISPATCH_FLOATING_TYPES(p1.scalar_type(), "knn_kernel_cuda", ([&] {
DispatchKernel2D<
KNearestNeighborKernelV3Functor,
scalar_t,
V3_MIN_D,
V3_MAX_D,
V3_MIN_K,
V3_MAX_K>(
D,
K_64,
blocks,
threads,
p1.contiguous().data_ptr<scalar_t>(),
p2.contiguous().data_ptr<scalar_t>(),
lengths1.contiguous().data_ptr<int64_t>(),
lengths2.contiguous().data_ptr<int64_t>(),
dists.data_ptr<scalar_t>(),
idxs.data_ptr<int64_t>(),
N,
P1,
P2,
norm);
}));
}
AT_CUDA_CHECK(cudaGetLastError());
return std::make_tuple(idxs, dists);
}
// ------------------------------------------------------------- //
// Backward Operators //
// ------------------------------------------------------------- //
// TODO(gkioxari) support all data types once AtomicAdd supports doubles.
// Currently, support is for floats only.
__global__ void KNearestNeighborBackwardKernel(
const float* __restrict__ p1, // (N, P1, D)
const float* __restrict__ p2, // (N, P2, D)
const int64_t* __restrict__ lengths1, // (N,)
const int64_t* __restrict__ lengths2, // (N,)
const int64_t* __restrict__ idxs, // (N, P1, K)
const float* __restrict__ grad_dists, // (N, P1, K)
float* __restrict__ grad_p1, // (N, P1, D)
float* __restrict__ grad_p2, // (N, P2, D)
const size_t N,
const size_t P1,
const size_t P2,
const size_t K,
const size_t D,
const size_t norm) {
const size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
const size_t stride = gridDim.x * blockDim.x;
for (size_t i = tid; i < N * P1 * K * D; i += stride) {
const size_t n = i / (P1 * K * D); // batch index
size_t rem = i % (P1 * K * D);
const size_t p1_idx = rem / (K * D); // index of point in p1
rem = rem % (K * D);
const size_t k = rem / D; // k-th nearest neighbor
const size_t d = rem % D; // d-th dimension in the feature vector
const size_t num1 = lengths1[n]; // number of valid points in p1 in batch
const size_t num2 = lengths2[n]; // number of valid points in p2 in batch
if ((p1_idx < num1) && (k < num2)) {
const float grad_dist = grad_dists[n * P1 * K + p1_idx * K + k];
// index of point in p2 corresponding to the k-th nearest neighbor
const int64_t p2_idx = idxs[n * P1 * K + p1_idx * K + k];
// If the index is the pad value of -1 then ignore it
if (p2_idx == -1) {
continue;
}
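// Gradient of the distance w.r.t. the coordinates: for the L2 norm the
// per-dimension term is (p1_d - p2_d)^2, so its derivative is
// 2 * (p1_d - p2_d); for the L1 norm the term is |p1_d - p2_d|, whose
// derivative is +/-1 depending on the sign of the difference.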
float diff = 0.0;
if (norm == 1) {
float sign =
(p1[n * P1 * D + p1_idx * D + d] > p2[n * P2 * D + p2_idx * D + d])
? 1.0
: -1.0;
diff = grad_dist * sign;
} else { // norm is 2
diff = 2.0 * grad_dist *
(p1[n * P1 * D + p1_idx * D + d] - p2[n * P2 * D + p2_idx * D + d]);
}
atomicAdd(grad_p1 + n * P1 * D + p1_idx * D + d, diff);
atomicAdd(grad_p2 + n * P2 * D + p2_idx * D + d, -1.0f * diff);
}
}
}
std::tuple<at::Tensor, at::Tensor> KNearestNeighborBackwardCuda(
const at::Tensor& p1,
const at::Tensor& p2,
const at::Tensor& lengths1,
const at::Tensor& lengths2,
const at::Tensor& idxs,
int norm,
const at::Tensor& grad_dists) {
// Check inputs are on the same device
at::TensorArg p1_t{p1, "p1", 1}, p2_t{p2, "p2", 2},
lengths1_t{lengths1, "lengths1", 3}, lengths2_t{lengths2, "lengths2", 4},
idxs_t{idxs, "idxs", 5}, grad_dists_t{grad_dists, "grad_dists", 6};
at::CheckedFrom c = "KNearestNeighborBackwardCuda";
at::checkAllSameGPU(
c, {p1_t, p2_t, lengths1_t, lengths2_t, idxs_t, grad_dists_t});
at::checkAllSameType(c, {p1_t, p2_t, grad_dists_t});
// This is nondeterministic because atomicAdd
at::globalContext().alertNotDeterministic("KNearestNeighborBackwardCuda");
// Set the device for the kernel launch based on the device of the input
at::cuda::CUDAGuard device_guard(p1.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const auto N = p1.size(0);
const auto P1 = p1.size(1);
const auto P2 = p2.size(1);
const auto D = p2.size(2);
const auto K = idxs.size(2);
TORCH_CHECK(p1.size(2) == D, "Point sets must have the same last dimension");
TORCH_CHECK(idxs.size(0) == N, "KNN idxs must have the same batch dimension");
TORCH_CHECK(
idxs.size(1) == P1, "KNN idxs must have the same point dimension as p1");
TORCH_CHECK(grad_dists.size(0) == N);
TORCH_CHECK(grad_dists.size(1) == P1);
TORCH_CHECK(grad_dists.size(2) == K);
auto grad_p1 = at::zeros({N, P1, D}, p1.options());
auto grad_p2 = at::zeros({N, P2, D}, p2.options());
if (grad_p1.numel() == 0 || grad_p2.numel() == 0) {
AT_CUDA_CHECK(cudaGetLastError());
return std::make_tuple(grad_p1, grad_p2);
}
const int blocks = 64;
const int threads = 512;
KNearestNeighborBackwardKernel<<<blocks, threads, 0, stream>>>(
p1.contiguous().data_ptr<float>(),
p2.contiguous().data_ptr<float>(),
lengths1.contiguous().data_ptr<int64_t>(),
lengths2.contiguous().data_ptr<int64_t>(),
idxs.contiguous().data_ptr<int64_t>(),
grad_dists.contiguous().data_ptr<float>(),
grad_p1.data_ptr<float>(),
grad_p2.data_ptr<float>(),
N,
P1,
P2,
K,
D,
norm);
AT_CUDA_CHECK(cudaGetLastError());
return std::make_tuple(grad_p1, grad_p2);
}
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/ops/knn/src/knn.h
================================================
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#pragma once
#include <torch/extension.h>
#include <tuple>
#include "utils/pytorch3d_cutils.h"
// Compute indices of K nearest neighbors in pointcloud p2 to points
// in pointcloud p1.
//
// Args:
// p1: FloatTensor of shape (N, P1, D) giving a batch of pointclouds each
// containing P1 points of dimension D.
// p2: FloatTensor of shape (N, P2, D) giving a batch of pointclouds each
// containing P2 points of dimension D.
// lengths1: LongTensor, shape (N,), giving actual length of each P1 cloud.
// lengths2: LongTensor, shape (N,), giving actual length of each P2 cloud.
// norm: int specifying the norm for the distance (1 for L1, 2 for L2)
// K: int giving the number of nearest points to return.
// version: Integer telling which implementation to use.
//
// Returns:
// p1_neighbor_idx: LongTensor of shape (N, P1, K), where
// p1_neighbor_idx[n, i, k] = j means that the kth nearest
// neighbor to p1[n, i] in the cloud p2[n] is p2[n, j].
// It is padded with zeros so that it can be used easily in a later
// gather() operation.
//
// p1_neighbor_dists: FloatTensor of shape (N, P1, K) containing the squared
// distance from each point p1[n, p, :] to its K neighbors
// p2[n, p1_neighbor_idx[n, p, k], :].
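// For example, with N = 1, P1 = 2, P2 = 4 and K = 2, p1_neighbor_idx has
// shape (1, 2, 2) and p1_neighbor_idx[0, i, 0] is the index into p2[0] of the
// point closest to p1[0, i]; the matching entry of p1_neighbor_dists holds
// its distance (squared when norm == 2).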
// CPU implementation.
std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCpu(
const at::Tensor& p1,
const at::Tensor& p2,
const at::Tensor& lengths1,
const at::Tensor& lengths2,
const int norm,
const int K);
// CUDA implementation
std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
const at::Tensor& p1,
const at::Tensor& p2,
const at::Tensor& lengths1,
const at::Tensor& lengths2,
const int norm,
const int K,
const int version);
// Implementation which is exposed.
std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdx(
const at::Tensor& p1,
const at::Tensor& p2,
const at::Tensor& lengths1,
const at::Tensor& lengths2,
const int norm,
const int K,
const int version) {
if (p1.is_cuda() || p2.is_cuda()) {
#ifdef WITH_CUDA
CHECK_CUDA(p1);
CHECK_CUDA(p2);
return KNearestNeighborIdxCuda(
p1, p2, lengths1, lengths2, norm, K, version);
#else
AT_ERROR("Not compiled with GPU support.");
#endif
}
return KNearestNeighborIdxCpu(p1, p2, lengths1, lengths2, norm, K);
}
// Compute gradients with respect to p1 and p2
//
// Args:
// p1: FloatTensor of shape (N, P1, D) giving a batch of pointclouds each
// containing P1 points of dimension D.
// p2: FloatTensor of shape (N, P2, D) giving a batch of pointclouds each
// containing P2 points of dimension D.
// lengths1: LongTensor, shape (N,), giving actual length of each P1 cloud.
// lengths2: LongTensor, shape (N,), giving actual length of each P2 cloud.
// p1_neighbor_idx: LongTensor of shape (N, P1, K), where
// p1_neighbor_idx[n, i, k] = j means that the kth nearest
// neighbor to p1[n, i] in the cloud p2[n] is p2[n, j].
// It is padded with zeros so that it can be used easily in a later
// gather() operation. This is computed from the forward pass.
// norm: int specifying the norm for the distance (1 for L1, 2 for L2)
// grad_dists: FloatTensor of shape (N, P1, K) which contains the input
// gradients.
//
// Returns:
// grad_p1: FloatTensor of shape (N, P1, D) containing the output gradients
// wrt p1.
// grad_p2: FloatTensor of shape (N, P2, D) containing the output gradients
// wrt p2.
// CPU implementation.
std::tuple<at::Tensor, at::Tensor> KNearestNeighborBackwardCpu(
const at::Tensor& p1,
const at::Tensor& p2,
const at::Tensor& lengths1,
const at::Tensor& lengths2,
const at::Tensor& idxs,
const int norm,
const at::Tensor& grad_dists);
// CUDA implementation
std::tuple<at::Tensor, at::Tensor> KNearestNeighborBackwardCuda(
const at::Tensor& p1,
const at::Tensor& p2,
const at::Tensor& lengths1,
const at::Tensor& lengths2,
const at::Tensor& idxs,
const int norm,
const at::Tensor& grad_dists);
// Implementation which is exposed.
std::tuple<at::Tensor, at::Tensor> KNearestNeighborBackward(
const at::Tensor& p1,
const at::Tensor& p2,
const at::Tensor& lengths1,
const at::Tensor& lengths2,
const at::Tensor& idxs,
const int norm,
const at::Tensor& grad_dists) {
if (p1.is_cuda() || p2.is_cuda()) {
#ifdef WITH_CUDA
CHECK_CUDA(p1);
CHECK_CUDA(p2);
return KNearestNeighborBackwardCuda(
p1, p2, lengths1, lengths2, idxs, norm, grad_dists);
#else
AT_ERROR("Not compiled with GPU support.");
#endif
}
return KNearestNeighborBackwardCpu(
p1, p2, lengths1, lengths2, idxs, norm, grad_dists);
}
// Utility to check whether a KNN version can be used.
//
// Args:
// version: Integer in the range 0 <= version <= 3 indicating one of our
// KNN implementations.
// D: Number of dimensions for the input and query point clouds
// K: Number of neighbors to be found
//
// Returns:
// Whether the indicated KNN version can be used.
bool KnnCheckVersion(int version, const int64_t D, const int64_t K);
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/ops/knn/src/knn_cpu.cpp
================================================
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <torch/extension.h>
#include <queue>
#include <tuple>
std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCpu(
const at::Tensor& p1,
const at::Tensor& p2,
const at::Tensor& lengths1,
const at::Tensor& lengths2,
const int norm,
const int K) {
const int N = p1.size(0);
const int P1 = p1.size(1);
const int D = p1.size(2);
auto long_opts = lengths1.options().dtype(torch::kInt64);
torch::Tensor idxs = torch::full({N, P1, K}, 0, long_opts);
torch::Tensor dists = torch::full({N, P1, K}, 0, p1.options());
auto p1_a = p1.accessor<float, 3>();
auto p2_a = p2.accessor<float, 3>();
auto lengths1_a = lengths1.accessor<int64_t, 1>();
auto lengths2_a = lengths2.accessor<int64_t, 1>();
auto idxs_a = idxs.accessor<int64_t, 3>();
auto dists_a = dists.accessor<float, 3>();
for (int n = 0; n < N; ++n) {
const int64_t length1 = lengths1_a[n];
const int64_t length2 = lengths2_a[n];
for (int64_t i1 = 0; i1 < length1; ++i1) {
// Use a priority queue to store (distance, index) tuples.
std::priority_queue<std::tuple<float, int64_t>> q;
for (int64_t i2 = 0; i2 < length2; ++i2) {
float dist = 0;
for (int d = 0; d < D; ++d) {
float diff = p1_a[n][i1][d] - p2_a[n][i2][d];
if (norm == 1) {
dist += abs(diff);
} else { // norm is 2 (default)
dist += diff * diff;
}
}
int size = static_cast<int>(q.size());
if (size < K || dist < std::get<0>(q.top())) {
q.emplace(dist, i2);
if (size >= K) {
q.pop();
}
}
}
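// q is a max-heap holding the K smallest distances seen so far; popping
// returns them from largest to smallest, and writing each popped entry at
// index k = q.size() fills dists/idxs in ascending order of distance.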
while (!q.empty()) {
auto t = q.top();
q.pop();
const int k = q.size();
dists_a[n][i1][k] = std::get<0>(t);
idxs_a[n][i1][k] = std::get<1>(t);
}
}
}
return std::make_tuple(idxs, dists);
}
// ------------------------------------------------------------- //
// Backward Operators //
// ------------------------------------------------------------- //
std::tuple<at::Tensor, at::Tensor> KNearestNeighborBackwardCpu(
const at::Tensor& p1,
const at::Tensor& p2,
const at::Tensor& lengths1,
const at::Tensor& lengths2,
const at::Tensor& idxs,
const int norm,
const at::Tensor& grad_dists) {
const int N = p1.size(0);
const int P1 = p1.size(1);
const int D = p1.size(2);
const int P2 = p2.size(1);
const int K = idxs.size(2);
torch::Tensor grad_p1 = torch::full({N, P1, D}, 0, p1.options());
torch::Tensor grad_p2 = torch::full({N, P2, D}, 0, p2.options());
auto p1_a = p1.accessor<float, 3>();
auto p2_a = p2.accessor<float, 3>();
auto lengths1_a = lengths1.accessor<int64_t, 1>();
auto lengths2_a = lengths2.accessor<int64_t, 1>();
auto idxs_a = idxs.accessor<int64_t, 3>();
auto grad_dists_a = grad_dists.accessor<float, 3>();
auto grad_p1_a = grad_p1.accessor<float, 3>();
auto grad_p2_a = grad_p2.accessor<float, 3>();
for (int n = 0; n < N; ++n) {
const int64_t length1 = lengths1_a[n];
int64_t length2 = lengths2_a[n];
length2 = (length2 < K) ? length2 : K;
for (int64_t i1 = 0; i1 < length1; ++i1) {
for (int64_t k = 0; k < length2; ++k) {
const int64_t i2 = idxs_a[n][i1][k];
// If the index is the pad value of -1 then ignore it
if (i2 == -1) {
continue;
}
for (int64_t d = 0; d < D; ++d) {
float diff = 0.0;
if (norm == 1) {
float sign = (p1_a[n][i1][d] > p2_a[n][i2][d]) ? 1.0 : -1.0;
diff = grad_dists_a[n][i1][k] * sign;
} else { // norm is 2 (default)
diff = 2.0f * grad_dists_a[n][i1][k] *
(p1_a[n][i1][d] - p2_a[n][i2][d]);
}
grad_p1_a[n][i1][d] += diff;
grad_p2_a[n][i2][d] += -1.0f * diff;
}
}
}
}
return std::make_tuple(grad_p1, grad_p2);
}
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/ops/knn/src/knn_ext.cpp
================================================
#include <torch/extension.h>
#include "knn.h"
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
#ifdef WITH_CUDA
m.def("knn_check_version", &KnnCheckVersion);
#endif
m.def("knn_points_idx", &KNearestNeighborIdx);
m.def("knn_points_backward", &KNearestNeighborBackward);
}
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/ops/knn/src/utils/dispatch.cuh
================================================
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
// This file provides utilities for dispatching to specialized versions of
// functions. This is especially useful for CUDA kernels, since specializing
// them to particular input sizes can often allow the compiler to unroll loops
// and place arrays into registers, which can give huge performance speedups.
//
// As an example, suppose we have the following function which is specialized
// based on a compile-time int64_t value:
//
// template <typename T, int64_t x>
// struct SquareOffset {
// static void run(T y) {
// T val = x * x + y;
// std::cout << val << std::endl;
// }
// }
//
// This function takes one compile-time argument x, and one run-time argument y.
// We might want to compile specialized versions of this for x=0, x=1, etc and
// then dispatch to the correct one based on the runtime value of x.
// One simple way to achieve this is with a lookup table:
//
// template <typename T>
// void DispatchSquareOffset(const int64_t x, T y) {
// if (x == 0) {
// SquareOffset<T, 0>::run(y);
// } else if (x == 1) {
// SquareOffset<T, 1>::run(y);
// } else if (x == 2) {
// SquareOffset<T, 2>::run(y);
// }
// }
//
// This function takes both x and y as run-time arguments, and dispatches to
// different specialized versions of SquareOffset based on the run-time value
// of x. This works, but it's tedious and error-prone. If we want to change the
// set of x values for which we provide compile-time specializations, then we
// will need to do a lot of tedious editing of the dispatch function. Also, if we
// want to provide compile-time specializations for another function other than
// SquareOffset, we will need to duplicate the entire lookup table.
//
// To solve these problems, we can use the DispatchKernel1D function provided by
// this file instead:
//
// template <typename T>
// void DispatchSquareOffset(const int64_t x, T y) {
// constexpr int64_t xmin = 0;
// constexpr int64_t xmax = 2;
// DispatchKernel1D<SquareOffset, T, xmin, xmax>(x, y);
// }
//
// DispatchKernel1D uses template metaprogramming to compile specialized
// versions of SquareOffset for all values of x with xmin <= x <= xmax, and
// then dispatches to the correct one based on the run-time value of x. If we
// want to change the range of x values for which SquareOffset is specialized
// at compile-time, then all we have to do is change the values of the
// compile-time constants xmin and xmax.
//
// This file also allows us to similarly dispatch functions that depend on two
// compile-time int64_t values, using the DispatchKernel2D function like this:
//
// template <typename T, int64_t x, int64_t y>
// struct Sum {
// static void run(T z, T w) {
// T val = x + y + z + w;
// std::cout << val << std::endl;
// }
// }
//
// template <typename T>
// void DispatchSum(const int64_t x, const int64_t y, int z, int w) {
// constexpr int64_t xmin = 1;
// constexpr int64_t xmax = 3;
// constexpr int64_t ymin = 2;
// constexpr int64_t ymax = 5;
// DispatchKernel2D<Sum, T, xmin, xmax, ymin, ymax>(x, y, z, w);
// }
//
// Like its 1D counterpart, DispatchKernel2D uses template metaprogramming to
// compile specialized versions of sum for all values of (x, y) with
// xmin <= x <= xmax and ymin <= y <= ymax, then dispatches to the correct
// specialized version based on the runtime values of x and y.
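// In the KNN code above, DispatchKernel1D specializes KNearestNeighborV1Functor
// over D in [V1_MIN_D, V1_MAX_D], and DispatchKernel2D specializes the V2/V3
// functors over (D, K) within their bounds, so each supported (D, K) pair gets
// its own fully specialized kernel.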
// Define some helper structs in an anonymous namespace.
namespace {
// 1D dispatch: general case.
// Kernel is the function we want to dispatch to; it should take a typename and
// an int64_t as template args, and it should define a static void function
// run which takes any number of arguments of any type.
// In order to dispatch, we will take an additional template argument curN,
// and increment it via template recursion until it is equal to the run-time
// argument N.
template <
template <typename, int64_t>
class Kernel,
typename T,
int64_t minN,
int64_t maxN,
int64_t curN,
typename... Args>
struct DispatchKernelHelper1D {
static void run(const int64_t N, Args... args) {
if (curN == N) {
// The compile-time value curN is equal to the run-time value N, so we
// can dispatch to the run method of the Kernel.
Kernel<T, curN>::run(args...);
} else if (curN < N) {
// Increment curN via template recursion
DispatchKernelHelper1D<Kernel, T, minN, maxN, curN + 1, Args...>::run(
N, args...);
}
// We shouldn't get here -- throw an error?
}
};
// 1D dispatch: Specialization when curN == maxN
// We need this base case to avoid infinite template recursion.
template <
template <typename, int64_t>
class Kernel,
typename T,
int64_t minN,
int64_t maxN,
typename... Args>
struct DispatchKernelHelper1D<Kernel, T, minN, maxN, maxN, Args...> {
static void run(const int64_t N, Args... args) {
if (N == maxN) {
Kernel<T, maxN>::run(args...);
}
// We shouldn't get here -- throw an error?
}
};
// 2D dispatch, general case.
// This is similar to the 1D case: we take additional template args curN and
// curM, and increment them via template recursion until they are equal to
// the run-time values of N and M, at which point we dispatch to the run
// method of the kernel.
template <
template <typename, int64_t, int64_t>
class Kernel,
typename T,
int64_t minN,
int64_t maxN,
int64_t curN,
int64_t minM,
int64_t maxM,
int64_t curM,
typename... Args>
struct DispatchKernelHelper2D {
static void run(const int64_t N, const int64_t M, Args... args) {
if (curN == N && curM == M) {
Kernel<T, curN, curM>::run(args...);
} else if (curN < N && curM < M) {
// Increment both curN and curM. This isn't strictly necessary; we could
// just increment one or the other at each step. But this helps to cut
// on the number of recursive calls we make.
DispatchKernelHelper2D<
Kernel,
T,
minN,
maxN,
curN + 1,
minM,
maxM,
curM + 1,
Args...>::run(N, M, args...);
} else if (curN < N) {
// Increment curN only
DispatchKernelHelper2D<
Kernel,
T,
minN,
maxN,
curN + 1,
minM,
maxM,
curM,
Args...>::run(N, M, args...);
} else if (curM < M) {
// Increment curM only
DispatchKernelHelper2D<
Kernel,
T,
minN,
maxN,
curN,
minM,
maxM,
curM + 1,
Args...>::run(N, M, args...);
}
}
};
// 2D dispatch, specialization for curN == maxN
template <
template <typename, int64_t, int64_t>
class Kernel,
typename T,
int64_t minN,
int64_t maxN,
int64_t minM,
int64_t maxM,
int64_t curM,
typename... Args>
struct DispatchKernelHelper2D<
Kernel,
T,
minN,
maxN,
maxN,
minM,
maxM,
curM,
Args...> {
static void run(const int64_t N, const int64_t M, Args... args) {
if (maxN == N && curM == M) {
Kernel<T, maxN, curM>::run(args...);
} else if (curM < maxM) {
DispatchKernelHelper2D<
Kernel,
T,
minN,
maxN,
maxN,
minM,
maxM,
curM + 1,
Args...>::run(N, M, args...);
}
// We should not get here -- throw an error?
}
};
// 2D dispatch, specialization for curM == maxM
template <
template <typename, int64_t, int64_t>
class Kernel,
typename T,
int64_t minN,
int64_t maxN,
int64_t curN,
int64_t minM,
int64_t maxM,
typename... Args>
struct DispatchKernelHelper2D<
Kernel,
T,
minN,
maxN,
curN,
minM,
maxM,
maxM,
Args...> {
static void run(const int64_t N, const int64_t M, Args... args) {
if (curN == N && maxM == M) {
Kernel<T, curN, maxM>::run(args...);
} else if (curN < maxN) {
DispatchKernelHelper2D<
Kernel,
T,
minN,
maxN,
curN + 1,
minM,
maxM,
maxM,
Args...>::run(N, M, args...);
}
// We should not get here -- throw an error?
}
};
// 2D dispatch, specialization for curN == maxN, curM == maxM
template <
template <typename, int64_t, int64_t>
class Kernel,
typename T,
int64_t minN,
int64_t maxN,
int64_t minM,
int64_t maxM,
typename... Args>
struct DispatchKernelHelper2D<
Kernel,
T,
minN,
maxN,
maxN,
minM,
maxM,
maxM,
Args...> {
static void run(const int64_t N, const int64_t M, Args... args) {
if (maxN == N && maxM == M) {
Kernel<T, maxN, maxM>::run(args...);
}
// We should not get here -- throw an error?
}
};
} // namespace
// This is the function we expect users to call to dispatch to 1D functions
template <
template <typename, int64_t>
class Kernel,
typename T,
int64_t minN,
int64_t maxN,
typename... Args>
void DispatchKernel1D(const int64_t N, Args... args) {
if (minN <= N && N <= maxN) {
// Kick off the template recursion by calling the Helper with curN = minN
DispatchKernelHelper1D<Kernel, T, minN, maxN, minN, Args...>::run(
N, args...);
}
// Maybe throw an error if we tried to dispatch outside the allowed range?
}
// This is the function we expect users to call to dispatch to 2D functions
template <
template <typename, int64_t, int64_t>
class Kernel,
typename T,
int64_t minN,
int64_t maxN,
int64_t minM,
int64_t maxM,
typename... Args>
void DispatchKernel2D(const int64_t N, const int64_t M, Args... args) {
if (minN <= N && N <= maxN && minM <= M && M <= maxM) {
// Kick off the template recursion by calling the Helper with curN = minN
// and curM = minM
DispatchKernelHelper2D<
Kernel,
T,
minN,
maxN,
minN,
minM,
maxM,
minM,
Args...>::run(N, M, args...);
}
// Maybe throw an error if we tried to dispatch outside the specified range?
}
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/ops/knn/src/utils/index_utils.cuh
================================================
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
// This converts dynamic array lookups into static array lookups, for small
// arrays up to size 32.
//
// Suppose we have a small thread-local array:
//
// float vals[10];
//
// Ideally we should only index this array using static indices:
//
// for (int i = 0; i < 10; ++i) vals[i] = i * i;
//
// If we do so, then the CUDA compiler may be able to place the array into
// registers, which can have a big performance improvement. However if we
// access the array dynamically, then the compiler may force the array into
// local memory, which has the same latency as global memory.
//
// These functions convert dynamic array access into static array access
// using a brute-force lookup table. It can be used like this:
//
// float vals[10];
// int idx = 3;
// float val = 3.14f;
// RegisterIndexUtils<float, 10>::set(vals, idx, val);
// float val2 = RegisterIndexUtils<float, 10>::get(vals, idx);
//
// The implementation is based on fbcuda/RegisterUtils.cuh:
// https://github.com/facebook/fbcuda/blob/master/RegisterUtils.cuh
// To avoid depending on the entire library, we just reimplement these two
// functions. The fbcuda implementation is a bit more sophisticated, and uses
// the preprocessor to generate switch statements that go up to N for each
// value of N. We are lazy and just have a giant explicit switch statement.
//
// We might be able to use a template metaprogramming approach similar to
// DispatchKernel1D for this. However DispatchKernel1D is intended to be used
// for dispatching to the correct CUDA kernel on the host, while this
// is intended to run on the device. I was concerned that a metaprogramming
// approach for this might lead to extra function calls at runtime if the
// compiler fails to optimize them away, which could be very slow on device.
// However I didn't actually benchmark or test this.
template <typename T, int N>
struct RegisterIndexUtils {
__device__ __forceinline__ static T get(const T arr[N], int idx) {
if (idx < 0 || idx >= N)
return T();
switch (idx) {
case 0:
return arr[0];
case 1:
return arr[1];
case 2:
return arr[2];
case 3:
return arr[3];
case 4:
return arr[4];
case 5:
return arr[5];
case 6:
return arr[6];
case 7:
return arr[7];
case 8:
return arr[8];
case 9:
return arr[9];
case 10:
return arr[10];
case 11:
return arr[11];
case 12:
return arr[12];
case 13:
return arr[13];
case 14:
return arr[14];
case 15:
return arr[15];
case 16:
return arr[16];
case 17:
return arr[17];
case 18:
return arr[18];
case 19:
return arr[19];
case 20:
return arr[20];
case 21:
return arr[21];
case 22:
return arr[22];
case 23:
return arr[23];
case 24:
return arr[24];
case 25:
return arr[25];
case 26:
return arr[26];
case 27:
return arr[27];
case 28:
return arr[28];
case 29:
return arr[29];
case 30:
return arr[30];
case 31:
return arr[31];
};
return T();
}
__device__ __forceinline__ static void set(T arr[N], int idx, T val) {
if (idx < 0 || idx >= N)
return;
switch (idx) {
case 0:
arr[0] = val;
break;
case 1:
arr[1] = val;
break;
case 2:
arr[2] = val;
break;
case 3:
arr[3] = val;
break;
case 4:
arr[4] = val;
break;
case 5:
arr[5] = val;
break;
case 6:
arr[6] = val;
break;
case 7:
arr[7] = val;
break;
case 8:
arr[8] = val;
break;
case 9:
arr[9] = val;
break;
case 10:
arr[10] = val;
break;
case 11:
arr[11] = val;
break;
case 12:
arr[12] = val;
break;
case 13:
arr[13] = val;
break;
case 14:
arr[14] = val;
break;
case 15:
arr[15] = val;
break;
case 16:
arr[16] = val;
break;
case 17:
arr[17] = val;
break;
case 18:
arr[18] = val;
break;
case 19:
arr[19] = val;
break;
case 20:
arr[20] = val;
break;
case 21:
arr[21] = val;
break;
case 22:
arr[22] = val;
break;
case 23:
arr[23] = val;
break;
case 24:
arr[24] = val;
break;
case 25:
arr[25] = val;
break;
case 26:
arr[26] = val;
break;
case 27:
arr[27] = val;
break;
case 28:
arr[28] = val;
break;
case 29:
arr[29] = val;
break;
case 30:
arr[30] = val;
break;
case 31:
arr[31] = val;
break;
}
}
};
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/ops/knn/src/utils/mink.cuh
================================================
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#pragma once
#define MINK_H
#include "index_utils.cuh"
// A data structure to keep track of the smallest K keys seen so far as well
// as their associated values, intended to be used in device code.
// This data structure doesn't allocate any memory; keys and values are stored
// in arrays passed to the constructor.
//
// The implementation is generic; it can be used for any key type that supports
// the < operator, and can be used with any value type.
//
// Example usage:
//
// float keys[K];
// int values[K];
// MinK<float, int> mink(keys, values, K);
// for (...) {
// // Produce some key and value from somewhere
// mink.add(key, value);
// }
// mink.sort();
//
// Now keys and values store the smallest K keys seen so far and the values
// associated to these keys:
//
// for (int k = 0; k < K; ++k) {
// float key_k = keys[k];
// int value_k = values[k];
// }
template <typename key_t, typename value_t>
class MinK {
public:
// Constructor.
//
// Arguments:
// keys: Array in which to store keys
// values: Array in which to store values
// K: How many values to keep track of
__device__ MinK(key_t* keys, value_t* vals, int K)
: keys(keys), vals(vals), K(K), _size(0) {}
// Try to add a new key and associated value to the data structure. If the key
// is one of the smallest K seen so far then it will be kept; otherwise it
// will not be kept.
//
// This takes O(1) operations if the new key is not kept, or if the structure
// currently contains fewer than K elements. Otherwise this takes O(K) time.
//
// Arguments:
// key: The key to add
// val: The value associated to the key
__device__ __forceinline__ void add(const key_t& key, const value_t& val) {
if (_size < K) {
keys[_size] = key;
vals[_size] = val;
if (_size == 0 || key > max_key) {
max_key = key;
max_idx = _size;
}
_size++;
} else if (key < max_key) {
keys[max_idx] = key;
vals[max_idx] = val;
max_key = key;
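// The evicted slot held the previous maximum, so rescan the K kept keys to
// find the new largest key and its position for the next eviction.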
for (int k = 0; k < K; ++k) {
key_t cur_key = keys[k];
if (cur_key > max_key) {
max_key = cur_key;
max_idx = k;
}
}
}
}
// Get the number of items currently stored in the structure.
// This takes O(1) time.
__device__ __forceinline__ int size() {
return _size;
}
// Sort the items stored in the structure using bubble sort.
// This takes O(K^2) time.
__device__ __forceinline__ void sort() {
for (int i = 0; i < _size - 1; ++i) {
for (int j = 0; j < _size - i - 1; ++j) {
if (keys[j + 1] < keys[j]) {
key_t key = keys[j];
value_t val = vals[j];
keys[j] = keys[j + 1];
vals[j] = vals[j + 1];
keys[j + 1] = key;
vals[j + 1] = val;
}
}
}
}
private:
key_t* keys;
value_t* vals;
int K;
int _size;
key_t max_key;
int max_idx;
};
// This is a version of MinK that only touches the arrays using static indexing
// via RegisterIndexUtils. If the keys and values are stored in thread-local
// arrays, then this may allow the compiler to place them in registers for
// fast access.
//
// This has the same API as MinK, but doesn't support sorting.
// We found that sorting via RegisterIndexUtils gave very poor performance,
// and suspect it may have prevented the compiler from placing the arrays
// into registers.
template <typename key_t, typename value_t, int K>
class RegisterMinK {
public:
__device__ RegisterMinK(key_t* keys, value_t* vals)
: keys(keys), vals(vals), _size(0) {}
__device__ __forceinline__ void add(const key_t& key, const value_t& val) {
if (_size < K) {
RegisterIndexUtils<key_t, K>::set(keys, _size, key);
RegisterIndexUtils<value_t, K>::set(vals, _size, val);
if (_size == 0 || key > max_key) {
max_key = key;
max_idx = _size;
}
_size++;
} else if (key < max_key) {
RegisterIndexUtils<key_t, K>::set(keys, max_idx, key);
RegisterIndexUtils<value_t, K>::set(vals, max_idx, val);
max_key = key;
for (int k = 0; k < K; ++k) {
key_t cur_key = RegisterIndexUtils<key_t, K>::get(keys, k);
if (cur_key > max_key) {
max_key = cur_key;
max_idx = k;
}
}
}
}
__device__ __forceinline__ int size() {
return _size;
}
private:
key_t* keys;
value_t* vals;
int _size;
key_t max_key;
int max_idx;
};
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/ops/knn/src/utils/pytorch3d_cutils.h
================================================
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#pragma once
#include <torch/extension.h>
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor.")
#define CHECK_CONTIGUOUS(x) \
TORCH_CHECK(x.is_contiguous(), #x " must be contiguous.")
#define CHECK_CONTIGUOUS_CUDA(x) \
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/ops/losses/__init__.py
================================================
from .arel import ARel
from .confidence import Confidence
from .distill import SelfDistill, TeacherDistill
from .dummy import Dummy
from .local_ssi import EdgeGuidedLocalSSI, LocalSSI
from .regression import Regression
from .silog import SILog
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/ops/losses/arel.py
================================================
import torch
import torch.nn as nn
from .utils import FNS, masked_mean
class ARel(nn.Module):
def __init__(
self,
weight: float,
output_fn: str = "sqrt",
input_fn: str = "linear",
eps: float = 1e-5,
):
super().__init__()
self.name: str = self.__class__.__name__
self.weight: float = weight
self.dims = [-2, -1]
self.output_fn = FNS[output_fn]
self.input_fn = FNS[input_fn]
self.eps: float = eps
@torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32)
def forward(
self, input: torch.Tensor, target: torch.Tensor, mask: torch.Tensor, **kwargs
) -> torch.Tensor:
mask = mask.bool().clone()
input = self.input_fn(input.float())
target = self.input_fn(target.float())
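# Absolute relative error in vector form: channel-wise L2 norm of
# (prediction - target) divided by the target norm (clipped away from zero),
# then averaged over the valid mask below.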
error = (input - target).norm(dim=1) / target.norm(dim=1).clip(min=0.05)
mask = mask.squeeze(1)
error_image = masked_mean(data=error, mask=mask, dim=self.dims).squeeze(1, 2)
error_image = self.output_fn(error_image)
return error_image
@classmethod
def build(cls, config):
obj = cls(
weight=config["weight"],
output_fn=config["output_fn"],
input_fn=config["input_fn"],
)
return obj
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/ops/losses/confidence.py
================================================
import torch
import torch.nn as nn
from .utils import FNS, masked_mean
class Confidence(nn.Module):
def __init__(
self,
weight: float,
output_fn: str = "sqrt",
input_fn: str = "linear",
rescale: bool = True,
eps: float = 1e-5,
):
super(Confidence, self).__init__()
self.name: str = self.__class__.__name__
self.weight = weight
self.rescale = rescale
self.eps = eps
self.output_fn = FNS[output_fn]
self.input_fn = FNS[input_fn]
@torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32)
def forward(
self,
input: torch.Tensor,
target_pred: torch.Tensor,
target_gt: torch.Tensor,
mask: torch.Tensor,
):
B, C = target_gt.shape[:2]
mask = mask.bool()
target_gt = target_gt.float().reshape(B, C, -1)
target_pred = target_pred.float().reshape(B, C, -1)
input = input.float().reshape(B, -1)
mask = mask.reshape(B, -1)
if self.rescale:
target_pred = torch.stack(
[
p * torch.median(gt[:, m]) / torch.median(p[:, m])
for p, gt, m in zip(target_pred, target_gt, mask)
]
)
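# Median-based scale alignment: each predicted map is rescaled so its median
# over valid pixels matches the ground-truth median, so the confidence head
# below regresses the scale-aligned prediction error.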
error = torch.abs(
(self.input_fn(target_pred) - self.input_fn(target_gt)).norm(dim=1) - input
)
losses = masked_mean(error, dim=[-1], mask=mask).squeeze(dim=-1)
losses = self.output_fn(losses)
return losses
@classmethod
def build(cls, config):
obj = cls(
weight=config["weight"],
output_fn=config["output_fn"],
input_fn=config["input_fn"],
rescale=config.get("rescale", True),
)
return obj
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/ops/losses/distill.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from .utils import FNS, masked_mean
class SelfDistill(nn.Module):
def __init__(self, weight: float, output_fn: str = "sqrt", eps: float = 1e-5):
super().__init__()
self.name: str = self.__class__.__name__
self.weight: float = weight
self.dims = (-2, -1)
self.output_fn = FNS[output_fn]
self.eps: float = eps
@torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32)
def forward(
self,
input: torch.Tensor,
intrinsics: torch.Tensor,
mask: torch.Tensor,
flips: torch.Tensor,
downsample_ratio=14,
) -> torch.Tensor:
chunks = input.shape[0] // 2
mask = F.interpolate(mask.float(), size=input.shape[-2:], mode="nearest")
iters = zip(
input.chunk(chunks),
mask.chunk(chunks),
intrinsics.chunk(chunks),
flips.chunk(chunks),
)
inputs0, inputs1, masks = [], [], []
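# Each consecutive pair holds two augmented views of the same image. The first
# view is warped into the frame of the second using the intrinsics: mirror it
# if exactly one view is flipped, zoom by the focal-length ratio, then pad and
# crop according to the principal-point offset, so the two predictions can be
# compared pixel-to-pixel under the shared validity mask.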
for i, (pair_input, pair_mask, pair_cam, pair_flip) in enumerate(iters):
mask0, mask1 = pair_mask
input0, input1 = pair_input
cam0, cam1 = pair_cam
flip0, flip1 = pair_flip
fx_0 = cam0[0, 0] / downsample_ratio
fx_1 = cam1[0, 0] / downsample_ratio
cx_0 = cam0[0, 2] / downsample_ratio
cx_1 = cam1[0, 2] / downsample_ratio
cy_0 = cam0[1, 2] / downsample_ratio
cy_1 = cam1[1, 2] / downsample_ratio
# flip image
if flip0 ^ flip1:
input0 = torch.flip(input0, dims=(2,))
mask0 = torch.flip(mask0, dims=(2,))
cx_0 = input0.shape[-1] - cx_0
# calc zoom
zoom_x = float(fx_1 / fx_0)
# apply zoom
input0 = F.interpolate(
input0.unsqueeze(0), scale_factor=zoom_x, mode="bilinear"
).squeeze(0)
mask0 = F.interpolate(
mask0.unsqueeze(0), scale_factor=zoom_x, mode="nearest"
).squeeze(0)
# calc translation
change_left = int(cx_1 - (cx_0 - 0.5) * zoom_x - 0.5)
change_top = int(cy_1 - (cy_0 - 0.5) * zoom_x - 0.5)
change_right = input1.shape[-1] - change_left - input0.shape[-1]
change_bottom = input1.shape[-2] - change_top - input0.shape[-2]
# apply translation
pad_left = max(0, change_left)
pad_right = max(0, change_right)
pad_top = max(0, change_top)
pad_bottom = max(0, change_bottom)
crop_left = max(0, -change_left)
crop_right = max(0, -change_right)
crop_top = max(0, -change_top)
crop_bottom = max(0, -change_bottom)
input0 = F.pad(
input0,
(pad_left, pad_right, pad_top, pad_bottom),
mode="constant",
value=0,
)
mask0 = F.pad(
mask0,
(pad_left, pad_right, pad_top, pad_bottom),
mode="constant",
value=0,
)
input0 = input0[
:,
crop_top : input0.shape[-2] - crop_bottom,
crop_left : input0.shape[-1] - crop_right,
]
mask0 = mask0[
:,
crop_top : mask0.shape[-2] - crop_bottom,
crop_left : mask0.shape[-1] - crop_right,
]
mask = torch.logical_and(mask0, mask1)
inputs0.append(input0)
inputs1.append(input1)
masks.append(mask)
inputs0 = torch.stack(inputs0, dim=0)
inputs1 = torch.stack(inputs1, dim=0)
masks = torch.stack(masks, dim=0)
loss1 = self.loss(inputs0, inputs1.detach(), masks)
loss2 = self.loss(inputs1, inputs0.detach(), masks)
return torch.cat([loss1, loss2], dim=0)
def loss(
self,
input: torch.Tensor,
target: torch.Tensor,
mask: torch.Tensor,
) -> torch.Tensor:
loss = masked_mean(
(input - target).square().mean(dim=1), mask=mask, dim=[-2, -1]
)
return self.output_fn(loss + self.eps)
@classmethod
def build(cls, config):
obj = cls(
weight=config["weight"],
output_fn=config["output_fn"],
)
return obj
class TeacherDistill(nn.Module):
def __init__(
self,
weight: float,
output_fn: str = "sqrt",
cross: bool = False,
eps: float = 1e-5,
):
super().__init__()
assert output_fn in FNS
self.name: str = self.__class__.__name__
self.weight: float = weight
self.dims = (-2, -1)
self.output_fn = FNS[output_fn]
self.eps: float = eps
self.cross = cross
self.threshold = 0.05
self.head_dim = 64 # hardcoded for vit
@torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32)
def forward(
self,
student_features: torch.Tensor,
teacher_features: torch.Tensor,
student_tokens: torch.Tensor,
teacher_tokens: torch.Tensor,
mask: torch.Tensor,
# metas: List[Dict[str, torch.Tensor]],
) -> torch.Tensor:
B = student_features.shape[0]
device = student_features.device
chunks = student_features.shape[0] // 2
mask = (
F.interpolate(
mask.float() + 1e-3, size=student_features.shape[-2:], mode="nearest"
)
> 0.5
)
# chunk features as self.head_dim
student_features = rearrange(
student_features, "b (n c) h w -> b c h w n", c=self.head_dim
)
teacher_features = rearrange(
teacher_features, "b (n c) h w -> b c h w n", c=self.head_dim
)
student_tokens = rearrange(
student_tokens, "b t (n c) -> b t c n", c=self.head_dim
)
teacher_tokens = rearrange(
teacher_tokens, "b t (n c) -> b t c n", c=self.head_dim
)
distance = (
(student_features - teacher_features)
.square()
.sum(dim=1, keepdim=True)
.sqrt()
.mean(dim=-1)
)
loss_features = masked_mean(distance, mask=mask, dim=[-2, -1])
loss_features = self.output_fn(loss_features.clamp(min=self.eps)).squeeze(
1, 2, 3
)
distance = (
(student_tokens - teacher_tokens).square().sum(dim=-2).sqrt().mean(dim=-1)
)
loss_tokens = self.output_fn(distance.clamp(min=self.eps)).squeeze(1)
return loss_features + 0.01 * loss_tokens
@classmethod
def build(cls, config):
obj = cls(
weight=config["weight"],
output_fn=config["output_fn"],
cross=config["cross"],
)
return obj
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/ops/losses/dummy.py
================================================
import torch
import torch.nn as nn
class Dummy(nn.Module):
def __init__(self, *args, **kwargs):
super().__init__()
self.name: str = self.__class__.__name__
self.weight = 1.0
def forward(self, dummy: torch.Tensor, *args, **kwargs) -> torch.Tensor:
return torch.tensor([0.0] * dummy.shape[0], device=dummy.device)
@classmethod
def build(cls, config):
obj = cls()
return obj
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/ops/losses/local_ssi.py
================================================
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from unidepth.utils.geometric import erode
from .utils import FNS, ind2sub, masked_mean, masked_quantile, ssi
def sample_strong_edges(edges_img, quantile=0.95, reshape=8):
# flat
edges_img = F.interpolate(
edges_img, scale_factor=1 / reshape, mode="bilinear", align_corners=False
)
edges_img_flat = edges_img.flatten(1)
# Find strong edges
edges_mask = edges_img_flat > torch.quantile(
edges_img_flat, quantile, dim=-1, keepdim=True
)
num_samples = edges_mask.sum(dim=-1)
if (num_samples < 10).any():
# sample random edges where num_samples < 10
random = torch.rand_like(edges_img_flat[num_samples < 10, :]) > quantile
edges_mask[num_samples < 10, :] = torch.logical_or(
edges_mask[num_samples < 10, :], random
)
num_samples = edges_mask.sum(dim=-1)
min_samples = num_samples.min()
# Compute the coordinates of the strong edges as B, N, 2
edges_coords = torch.stack(
[torch.nonzero(x, as_tuple=False)[:min_samples].squeeze() for x in edges_mask]
)
edges_coords = (
torch.stack(ind2sub(edges_coords, edges_img.shape[-1]), dim=-1) * reshape
)
return edges_coords
@torch.jit.script
def extract_patches(tensor, sample_coords, patch_size: tuple[int, int] = (32, 32)):
N, _, H, W = tensor.shape
device = tensor.device
dtype = tensor.dtype
patch_width, patch_height = patch_size
pad_width = patch_width // 2
pad_height = patch_height // 2
# Pad the RGB images for both sheep
tensor_padded = F.pad(
tensor,
(pad_width, pad_width, pad_height, pad_height),
mode="constant",
value=0.0,
)
# Adjust edge coordinates to account for padding
sample_coords_padded = sample_coords + torch.tensor(
[pad_height, pad_width], dtype=dtype, device=device
).reshape(1, 1, 2)
# Calculate the indices for gather operation
x_centers = sample_coords_padded[:, :, 1].int()
y_centers = sample_coords_padded[:, :, 0].int()
all_patches = []
for tensor_i, x_centers_i, y_centers_i in zip(tensor_padded, x_centers, y_centers):
patches = []
for x_center, y_center in zip(x_centers_i, y_centers_i):
y_start, y_end = y_center - pad_height, y_center + pad_height + 1
x_start, x_end = x_center - pad_width, x_center + pad_width + 1
patches.append(tensor_i[..., y_start:y_end, x_start:x_end])
all_patches.append(torch.stack(patches, dim=0))
return torch.stack(all_patches, dim=0).reshape(N, -1, patch_height * patch_width)
class LocalSSI(nn.Module):
def __init__(
self,
weight: float,
output_fn: str = "sqrt",
patch_size: tuple[int, int] = (32, 32),
min_samples: int = 4,
num_levels: int = 4,
input_fn: str = "linear",
eps: float = 1e-5,
):
super(LocalSSI, self).__init__()
self.name: str = self.__class__.__name__
self.weight = weight
self.output_fn = FNS[output_fn]
self.input_fn = FNS[input_fn]
self.min_samples = min_samples
self.eps = eps
patch_logrange = np.linspace(
start=np.log2(min(patch_size)),
stop=np.log2(max(patch_size)),
endpoint=True,
num=num_levels + 1,
)
self.patch_logrange = [
(x, y) for x, y in zip(patch_logrange[:-1], patch_logrange[1:])
]
self.rescale_fn = ssi
@torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32)
def forward(
self,
input: torch.Tensor,
target: torch.Tensor,
mask: torch.Tensor,
*args,
**kwargs,
) -> torch.Tensor:
mask = mask.bool()
input = self.input_fn(input.float())
target = self.input_fn(target.float())
B, C, H, W = input.shape
total_errors = []
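# For each level a patch size is drawn log-uniformly from that level's range,
# the maps are unfolded into overlapping patches, and each patch is aligned in
# scale and shift (ssi) independently before taking the absolute error; a
# final pass applies the same alignment globally over the whole image.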
for ii, patch_logrange in enumerate(self.patch_logrange):
log_kernel = (
np.random.uniform(*patch_logrange)
if self.training
else np.mean(patch_logrange)
)
kernel_size = int(
(2**log_kernel) * min(input.shape[-2:])
) # always smaller than min_shape
kernel_size = (kernel_size, kernel_size)
stride = (int(kernel_size[0] * 0.9), int(kernel_size[1] * 0.9))
# unfold is always exceeding right/bottom, roll image only negative
# to have them back in the unfolding window
max_roll = (
(W - kernel_size[1]) % stride[1],
(H - kernel_size[0]) % stride[0],
)
roll_x, roll_y = np.random.randint(-max_roll[0], 1), np.random.randint(
-max_roll[1], 1
)
input_fold = torch.roll(input, shifts=(roll_y, roll_x), dims=(2, 3))
target_fold = torch.roll(target, shifts=(roll_y, roll_x), dims=(2, 3))
mask_fold = torch.roll(mask.float(), shifts=(roll_y, roll_x), dims=(2, 3))
# unfold in patches
input_fold = F.unfold(
input_fold, kernel_size=kernel_size, stride=stride
).permute(
0, 2, 1
) # B N C*H_p*W_p
target_fold = F.unfold(
target_fold, kernel_size=kernel_size, stride=stride
).permute(0, 2, 1)
mask_fold = (
F.unfold(mask_fold, kernel_size=kernel_size, stride=stride)
.bool()
.permute(0, 2, 1)
)
# calculate the error patch-wise, average within each patch, then over the image, keeping only patches with enough valid samples
input_fold, target_fold, _ = self.rescale_fn(
input_fold, target_fold, mask_fold, dim=[-1]
)
error = (input_fold - target_fold).abs()
# keep only patches with at least min_samples valid pixels
valid_patches = mask_fold.sum(dim=-1) >= self.min_samples
error_mean_patch = masked_mean(error, mask_fold, dim=[-1]).squeeze(-1)
error_mean_image = self.output_fn(error_mean_patch.clamp(min=self.eps))
error_mean_image = masked_mean(
error_mean_image, mask=valid_patches, dim=[-1]
)
total_errors.append(error_mean_image.squeeze(-1))
# global
input_rescale = input.reshape(B, C, -1)
target_rescale = target.reshape(B, C, -1)
mask = mask.reshape(B, 1, -1).clone()
input, target, mask = self.rescale_fn(
input_rescale, target_rescale, mask, dim=[-1]
)
error = (input - target).abs().squeeze(1)
mask = mask.squeeze(1)
error_mean_image = masked_mean(error, mask, dim=[-1]).squeeze(-1)
error_mean_image = self.output_fn(error_mean_image.clamp(min=self.eps))
total_errors.append(error_mean_image)
errors = torch.stack(total_errors).mean(dim=0)
return errors
@classmethod
def build(cls, config):
obj = cls(
weight=config["weight"],
patch_size=config["patch_size"],
output_fn=config["output_fn"],
min_samples=config["min_samples"],
num_levels=config["num_levels"],
input_fn=config["input_fn"],
)
return obj
class EdgeGuidedLocalSSI(nn.Module):
def __init__(
self,
weight: float,
output_fn: str = "sqrt",
min_samples: int = 4,
input_fn: str = "linear",
use_global: bool = True,
eps: float = 1e-5,
):
super(EdgeGuidedLocalSSI, self).__init__()
self.name: str = self.__class__.__name__
self.weight = weight
self.output_fn = FNS[output_fn]
self.input_fn = FNS[input_fn]
self.min_samples = min_samples
self.eps = eps
self.use_global = use_global
self.rescale_fn = ssi
delta_x = torch.tensor(
[[-1.0, 0.0, 1.0], [-2.0, 0.0, 2.0], [-1.0, 0.0, 1.0]], requires_grad=False
)
delta_y = torch.tensor(
[[-1.0, -2.0, -1.0], [0.0, 0.0, 0.0], [1.0, 2.0, 1.0]], requires_grad=False
)
self.delta_x = delta_x.reshape(1, 1, 3, 3)
self.delta_y = delta_y.reshape(1, 1, 3, 3)
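# delta_x and delta_y are 3x3 Sobel kernels; get_edge convolves them over the
# image (normalized by 8), takes the RMS over color channels, and combines the
# horizontal and vertical responses into a gradient-magnitude map used to
# sample edge-centered patches.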
try:
from unidepth.ops.extract_patches import RandomPatchExtractor
self.random_patch_extractor = RandomPatchExtractor()
except Exception as e:
self.random_patch_extractor = extract_patches
print(
"EdgeGuidedLocalSSI reverts to a non cuda-optimized operation, "
"you will experince large slowdown, "
"please install it: ",
"`cd ./unidepth/ops/extract_patches && bash compile.sh`",
)
def get_edge(self, image, mask):
channels = image.shape[1]
device = image.device
delta_x = self.delta_x.to(device).repeat(channels, 1, 1, 1)
delta_y = self.delta_y.to(device).repeat(channels, 1, 1, 1)
image_Gx = F.conv2d(image, delta_x, groups=channels, padding="same") / 8
image_Gy = F.conv2d(image, delta_y, groups=channels, padding="same") / 8
image_Gx = (
image_Gx.square().mean(dim=1, keepdim=True).sqrt()
) # RMSE over color dim
image_Gy = image_Gy.square().mean(dim=1, keepdim=True).sqrt()
edges = torch.sqrt(image_Gx**2 + image_Gy**2)
edges[:, :, :3, :] = 0
edges[:, :, -3:, :] = 0
edges[:, :, :, :3] = 0
edges[:, :, :, -3:] = 0
edges[~mask.bool()] = 0
return edges
def compute_sample_patch_error(
self, input, target, mask, sampling_coords, kernel_size, image_size
):
B, C, H, W = input.shape
patch_size = kernel_size[0] * kernel_size[1]
input = self.random_patch_extractor(
input, sampling_coords, kernel_size
).reshape(B, -1, patch_size)
target = self.random_patch_extractor(
target, sampling_coords, kernel_size
).reshape(B, -1, patch_size)
mask = (
self.random_patch_extractor(mask.float(), sampling_coords, kernel_size)
.bool()
.reshape(B, -1, patch_size)
)
input, target, mask = self.rescale_fn(input, target, mask, dim=[-1])
error = (input - target).abs().clamp(min=self.eps)
valid_patches = mask.sum(dim=-1) >= self.min_samples
error_mean_patch = masked_mean(error, mask, dim=[-1]).squeeze(-1)
error_mean_image = self.output_fn(error_mean_patch.clamp(min=self.eps))
error_mean_image = masked_mean(error_mean_image, mask=valid_patches, dim=[-1])
return error_mean_image
def compute_image_error(self, input, target, mask, image_size):
H, W = image_size
input = input.reshape(-1, 1, H * W)
target = target.reshape(-1, 1, H * W)
mask = mask.reshape(-1, 1, H * W)
input, target, mask = self.rescale_fn(input, target, mask, dim=[-1])
error = (input - target).abs().clamp(min=self.eps)
error_mean_image = masked_mean(error, mask, dim=[-1]).squeeze(-1)
error_mean_image = self.output_fn(error_mean_image.clamp(min=self.eps))
return error_mean_image
@torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32)
def forward(
self,
input: torch.Tensor,
target: torch.Tensor,
mask: torch.Tensor,
image: torch.Tensor | None = None,
validity_mask: torch.Tensor | None = None,
*args,
**kwargs,
) -> torch.Tensor:
mask = mask.bool()
input = self.input_fn(input.float())
target = self.input_fn(target.float())
B, _, H, W = input.shape
total_errors = []
# remove border and black border
if validity_mask is not None:
validity_mask = erode(validity_mask.float(), kernel_size=3)
edges = self.get_edge(image, validity_mask)
# quantile was 0.95?
edges_coords = sample_strong_edges(edges, quantile=0.9, reshape=14)
log_kernel = np.random.uniform(0.04, 0.08) if self.training else 0.05
kernel_size = int(
log_kernel * min(input.shape[-2:])
) # always smaller than min_shape
kernel_size = kernel_size + int(kernel_size % 2 == 0) # odd num
kernel_size = (kernel_size, kernel_size)
error_mean_image = self.compute_sample_patch_error(
input, target, mask, edges_coords, kernel_size, (H, W)
)
total_errors.append(error_mean_image.squeeze(-1))
if self.use_global:
error_mean_image = self.compute_image_error(input, target, mask, (H, W))
total_errors.append(error_mean_image.squeeze(-1))
errors = torch.stack(total_errors).mean(dim=0)
return errors
@classmethod
def build(cls, config):
obj = cls(
weight=config["weight"],
output_fn=config["output_fn"],
input_fn=config["input_fn"],
use_global=config["use_global"],
min_samples=config.get("min_samples", 6),
)
return obj
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/ops/losses/regression.py
================================================
import torch
import torch.nn as nn
from .utils import FNS, REGRESSION_DICT, masked_mean, masked_quantile
class Regression(nn.Module):
def __init__(
self,
weight: float,
input_fn: str,
output_fn: str,
alpha: float,
gamma: float,
fn: str,
dims: list[int] = [-1],
quantile: float = 0.0,
**kwargs,
):
super().__init__()
self.name = self.__class__.__name__
self.output_fn = FNS[output_fn]
self.input_fn = FNS[input_fn]
self.weight = weight
self.dims = dims
self.quantile = quantile
self.alpha = alpha
self.gamma = gamma
self.fn = REGRESSION_DICT[fn]
@torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32)
def forward(
self,
input: torch.Tensor,
target: torch.Tensor,
mask: torch.Tensor | None = None,
**kwargs,
) -> torch.Tensor:
if mask is not None: # usually it is just repeated
mask = mask[:, 0]
input = self.input_fn(input.float())
target = self.input_fn(target.float())
error = self.fn(input - target, gamma=self.gamma, alpha=self.alpha).mean(dim=1)
mean_error = masked_mean(data=error, mask=mask, dim=self.dims).squeeze(
self.dims
)
mean_error = self.output_fn(mean_error)
return mean_error
@classmethod
def build(cls, config):
obj = cls(
weight=config["weight"],
output_fn=config["output_fn"],
input_fn=config["input_fn"],
dims=config.get("dims", (-1,)),
alpha=config["alpha"],
gamma=config["gamma"],
fn=config["fn"],
)
return obj
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/ops/losses/silog.py
================================================
import torch
import torch.nn as nn
from .utils import (FNS, REGRESSION_DICT, masked_mean, masked_mean_var,
masked_quantile)
class SILog(nn.Module):
def __init__(
self,
weight: float,
input_fn: str = "linear",
output_fn: str = "sqrt",
integrated: float = 0.15,
dims: list[int] = [-3, -2, -1],
eps: float = 1e-5,
):
super().__init__()
self.name: str = self.__class__.__name__
self.weight: float = weight
self.dims = dims
self.input_fn = FNS[input_fn]
self.output_fn = FNS[output_fn]
self.eps: float = eps
self.integrated = integrated
@torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32)
def forward(
self,
input: torch.Tensor,
target: torch.Tensor,
mask: torch.Tensor,
si: torch.Tensor,
**kwargs,
) -> torch.Tensor:
mask = mask.bool()
error = self.input_fn(input.float()) - self.input_fn(target.float())
mean_error, var_error = masked_mean_var(
data=error, mask=mask, dim=self.dims, keepdim=False
)
if var_error.ndim > 1:
var_error = var_error.mean(dim=-1)
if self.integrated > 0.0:
scale_error = mean_error**2
var_error = var_error + self.integrated * scale_error * (1 - si.int())
out_loss = self.output_fn(var_error)
return out_loss
@classmethod
def build(cls, config):
obj = cls(
weight=config["weight"],
dims=config["dims"],
output_fn=config["output_fn"],
input_fn=config["input_fn"],
integrated=config.get("integrated", 0.15),
)
return obj
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/ops/losses/utils.py
================================================
from math import prod
from typing import Any, Dict, List, Optional, Tuple
import torch
FNS = {
"sqrt": lambda x: torch.sqrt(x + 1e-4),
"log": lambda x: torch.log(x + 1e-4),
"log1": lambda x: torch.log(x + 1),
# if x -> 0   : behaves like log(1/x)
# if x -> inf : log(1 + 1/x) -> 1/x + higher-order terms
"log1i": lambda x: torch.log(1 + 50 / (1e-4 + x)),
"linear": lambda x: x,
"square": torch.square,
"disp": lambda x: 1 / (x + 1e-4),
"disp1": lambda x: 1 / (1 + x),
}
FNS_INV = {
"sqrt": torch.square,
"log": torch.exp,
"log1": lambda x: torch.exp(x) - 1,
"linear": lambda x: x,
"square": torch.sqrt,
"disp": lambda x: 1 / x,
}
def masked_mean_var(
data: torch.Tensor, mask: torch.Tensor, dim: List[int], keepdim: bool = True
):
if mask is None:
return data.mean(dim=dim, keepdim=keepdim), data.var(dim=dim, keepdim=keepdim)
mask = mask.float()
mask_sum = torch.sum(mask, dim=dim, keepdim=True)
# data = torch.nan_to_num(data, nan=0.0)
mask_mean = torch.sum(data * mask, dim=dim, keepdim=True) / torch.clamp(
mask_sum, min=1.0
)
mask_var = torch.sum(
mask * (data - mask_mean) ** 2, dim=dim, keepdim=True
) / torch.clamp(mask_sum, min=1.0)
if not keepdim:
mask_mean, mask_var = mask_mean.squeeze(dim), mask_var.squeeze(dim)
return mask_mean, mask_var
def masked_mean(data: torch.Tensor, mask: torch.Tensor | None, dim: List[int]):
if mask is None:
return data.mean(dim=dim, keepdim=True)
mask = mask.float()
mask_sum = torch.sum(mask, dim=dim, keepdim=True)
mask_mean = torch.sum(
torch.nan_to_num(data, nan=0.0) * mask, dim=dim, keepdim=True
) / mask_sum.clamp(min=1.0)
return mask_mean
def masked_quantile(
data: torch.Tensor, mask: torch.Tensor | None, dims: List[int], q: float
):
"""
Compute the quantile of the data only where the mask is 1 along specified dimensions.
Args:
data (torch.Tensor): The input data tensor.
mask (torch.Tensor): The mask tensor with the same shape as data, containing 1s where data should be considered.
dims (list of int): The dimensions to compute the quantile over.
q (float): The quantile to compute, must be between 0 and 1.
Returns:
torch.Tensor: The quantile computed over the specified dimensions, ignoring masked values.
"""
masked_data = data * mask if mask is not None else data
# Get a list of all dimensions
all_dims = list(range(masked_data.dim()))
# Revert negative dimensions
dims = [d % masked_data.dim() for d in dims]
# Find the dimensions to keep (not included in the `dims` list)
keep_dims = [d for d in all_dims if d not in dims]
# Permute dimensions to bring `dims` to the front
permute_order = dims + keep_dims
permuted_data = masked_data.permute(permute_order)
# Reshape into 2D: (-1, remaining_dims)
collapsed_shape = (
-1,
prod([permuted_data.size(d) for d in range(len(dims), permuted_data.dim())]),
)
reshaped_data = permuted_data.reshape(collapsed_shape)
if mask is None:
return torch.quantile(reshaped_data, q, dim=0)
permuted_mask = mask.permute(permute_order)
reshaped_mask = permuted_mask.reshape(collapsed_shape)
# Calculate quantile along the first dimension where mask is true
quantiles = []
for i in range(reshaped_data.shape[1]):
valid_data = reshaped_data[:, i][reshaped_mask[:, i]]
if valid_data.numel() == 0:
# print("Warning: No valid data found for quantile calculation.")
quantiles.append(reshaped_data[:, i].min() * 0.99)
else:
quantiles.append(torch.quantile(valid_data, q, dim=0))
# Stack back into a tensor with reduced dimensions
quantiles = torch.stack(quantiles)
quantiles = quantiles.reshape(
[permuted_data.size(d) for d in range(len(dims), permuted_data.dim())]
)
return quantiles
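# Usage sketch (illustrative shapes and values): per-(batch, channel) 90th
# percentile of an error map, ignoring invalid pixels.
#
#   data = torch.rand(2, 3, 8, 8)          # e.g. per-pixel errors
#   mask = torch.rand(2, 3, 8, 8) > 0.5    # True where the pixel is valid
#   q90 = masked_quantile(data, mask, dims=[-2, -1], q=0.9)  # shape (2, 3)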
def masked_median(data: torch.Tensor, mask: torch.Tensor, dim: List[int]):
ndim = data.ndim
data = data.flatten(ndim - len(dim))
mask = mask.flatten(ndim - len(dim))
mask_median = torch.median(data[..., mask], dim=-1).values
return mask_median
def masked_median_mad(data: torch.Tensor, mask: torch.Tensor, dim: List[int]):
ndim = data.ndim
data = data.flatten(ndim - len(dim))
mask = mask.flatten(ndim - len(dim))
mask_median = torch.median(data[mask], dim=-1, keepdim=True).values
mask_mad = masked_mean((data - mask_median).abs(), mask, dim=[-1])
return mask_median, mask_mad
def masked_weighted_mean_var(
data: torch.Tensor, mask: torch.Tensor, weights: torch.Tensor, dim: Tuple[int, ...]
):
if mask is None:
return data.mean(dim=dim, keepdim=True), data.var(dim=dim, keepdim=True)
mask = mask.float()
mask_mean = torch.sum(data * mask * weights, dim=dim, keepdim=True) / torch.sum(
mask * weights, dim=dim, keepdim=True
).clamp(min=1.0)
# V1**2 - V2, V1: sum w_i, V2: sum w_i**2
denom = torch.sum(weights * mask, dim=dim, keepdim=True).square() - torch.sum(
(mask * weights).square(), dim=dim, keepdim=True
)
# correction is V1 / (V1**2 - V2); with w_i = 1 this gives N / (N**2 - N) = 1 / (N - 1), the unbiased (Bessel-corrected) variance estimator
correction_factor = torch.sum(mask * weights, dim=dim, keepdim=True) / denom.clamp(
min=1.0
)
mask_var = correction_factor * torch.sum(
weights * mask * (data - mask_mean) ** 2, dim=dim, keepdim=True
)
return mask_mean, mask_var
def ssi(
input: torch.Tensor, target: torch.Tensor, mask: torch.Tensor, dim: list[int]
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
# recalculate mask with points in 95% confidence interval
# the statistics are calculated on the stable points and
# are similar to median/MAD, but median/MAD gradients
# are really weird, so this is a workaround
input_detach = input.detach()
input_mean, input_var = masked_mean_var(input_detach, mask=mask, dim=dim)
target_mean, target_var = masked_mean_var(target, mask=mask, dim=dim)
input_std = (input_var).clip(min=1e-6).sqrt()
target_std = (target_var).clip(min=1e-6).sqrt()
stable_points_input = torch.logical_and(
input_detach > input_mean - 1.96 * input_std,
input_detach < input_mean + 1.96 * input_std,
)
stable_points_target = torch.logical_and(
target > target_mean - 1.96 * target_std,
target < target_mean + 1.96 * target_std,
)
stable_mask = stable_points_target & stable_points_input & mask
input_mean, input_var = masked_mean_var(input, mask=stable_mask, dim=dim)
target_mean, target_var = masked_mean_var(target, mask=stable_mask, dim=dim)
target_normalized = (target - target_mean) / FNS["sqrt"](target_var)
input_normalized = (input - input_mean) / FNS["sqrt"](input_var)
return input_normalized, target_normalized, stable_mask
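# Sketch of the intended use (illustrative shapes): ssi standardizes prediction
# and target with mean/std computed only on the "stable" points (within the 95%
# interval of each distribution), so a scale-and-shift-invariant error can be
# taken on the outputs.
#
#   pred, gt = torch.rand(4, 1, 64, 64), torch.rand(4, 1, 64, 64)
#   valid = gt > 0
#   pred_n, gt_n, stable = ssi(pred, gt, valid, dim=[-2, -1])
#   err = masked_mean((pred_n - gt_n).abs(), stable, dim=[-2, -1]).mean()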
def ind2sub(idx, cols):
r = idx // cols
c = idx % cols
return r, c
def sub2ind(r, c, cols):
idx = r * cols + c
return idx
def l2(input_tensor: torch.Tensor, gamma: float = 1.0, *args, **kwargs) -> torch.Tensor:
return gamma * (input_tensor / gamma) ** 2
def l1(input_tensor: torch.Tensor, gamma: float = 1.0, *args, **kwargs) -> torch.Tensor:
return torch.abs(input_tensor)
def charbonnier(
input_tensor: torch.Tensor, gamma: float = 1.0, *args, **kwargs
) -> torch.Tensor:
return torch.sqrt(torch.square(input_tensor) + gamma**2) - gamma
def cauchy(
input_tensor: torch.Tensor, gamma: float = 1.0, *args, **kwargs
) -> torch.Tensor:
return gamma * torch.log(torch.square(input_tensor) / gamma + 1)
def geman_mcclure(
input_tensor: torch.Tensor, gamma: float = 1.0, *args, **kwargs
) -> torch.Tensor:
return gamma * torch.square(input_tensor) / (torch.square(input_tensor) + gamma)
def robust_loss(
input_tensor: torch.Tensor, alpha: float, gamma: float = 1.0, *args, **kwargs
) -> torch.Tensor:
coeff = abs(alpha - 2) / alpha
power = torch.square(input_tensor) / abs(alpha - 2) / (gamma**2) + 1
return (
gamma * coeff * (torch.pow(power, alpha / 2) - 1)
) # mult gamma to keep grad magnitude invariant wrt gamma
REGRESSION_DICT = {
"l2": l2,
"l1": l1,
"cauchy": cauchy,
"charbonnier": charbonnier,
"geman_mcclure": geman_mcclure,
"robust_loss": robust_loss,
}
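# Usage sketch (illustrative values): every entry of REGRESSION_DICT is an
# element-wise penalty with signature fn(residual, alpha=..., gamma=...).
#
#   r = torch.tensor([-2.0, 0.0, 0.5])
#   l1(r)                       # |r|              -> [2.00, 0.00, 0.50]
#   l2(r, gamma=1.0)            # r**2             -> [4.00, 0.00, 0.25]
#   charbonnier(r, gamma=1.0)   # sqrt(r**2+1) - 1 -> [1.24, 0.00, 0.12] (approx.)
#
# `Regression` selects one of these via its `fn` config key and reduces the
# result with masked_mean over `dims`.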
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/ops/scheduler.py
================================================
import weakref
import numpy as np
class PlainCosineScheduler(object):
def __init__(
self,
klass,
key,
warmup_iters,
total_iters,
overwrite=False,
init_value=None,
base_value=None,
final_value=None,
step_init=-1,
):
super().__init__()
self.iter = step_init
self.overwrite = overwrite
self.base_value = base_value
self.init_value = init_value if init_value is not None else base_value
self.final_value = final_value
self.total_iters = total_iters
self.warmup_iters = warmup_iters
self.key = key
self.klass = klass
self.schedulers = [self.get_scheduler()]
def get_scheduler(self):
init_value = self.init_value
base_value = self.base_value
final_value = self.final_value
warmup_iters = self.warmup_iters
total_iters = self.total_iters
# normalize in 0,1, then apply function (power) and denormalize
normalized_schedule = np.linspace(0, 1, warmup_iters, endpoint=True)
normalized_schedule = np.power(normalized_schedule, 1)
warmup_schedule = (base_value - init_value) * normalized_schedule + init_value
# main scheduling
iters = np.arange(total_iters - warmup_iters + 1)
schedule = final_value + 0.5 * (base_value - final_value) * (
1 + np.cos(np.pi * iters / (len(iters) - 1))
)
return np.concatenate((warmup_schedule, schedule))
def step(self):
self.iter = self.iter + 1
vals = self[self.iter]
for i, val in enumerate(vals):
setattr(self.klass, self.key, val)
def __getitem__(self, it):
it = min(it, self.total_iters)
return [scheduler[it] for scheduler in self.schedulers]
class CosineScheduler(object):
def __init__(
self,
optimizer,
warmup_iters,
total_iters,
key,
overwrite=False,
init_value=None,
base_value=None,
final_value=None,
step_init=-1,
):
super().__init__()
self.iter = step_init
self.overwrite = overwrite
self.optimizer = optimizer
self.base_value = base_value
self.init_value = init_value
self.final_value = final_value
self.total_iters = total_iters
self.warmup_iters = warmup_iters
self.key = key
self.schedulers = [
self.get_schedulers(group) for group in optimizer.param_groups
]
def get_schedulers(self, group):
init_value = group.get(self.key + "_init", self.init_value)
base_value = group.get(self.key + "_base", self.base_value)
final_value = group.get(self.key + "_final", self.final_value)
warmup_iters = self.warmup_iters
total_iters = self.total_iters
if self.overwrite:
final_value = self.final_value
# normalize in 0,1, then apply function (power) and denormalize
normalized_schedule = np.linspace(0, 1, warmup_iters, endpoint=True)
normalized_schedule = np.power(normalized_schedule, 1)
warmup_schedule = (base_value - init_value) * normalized_schedule + init_value
# main scheduling
iters = np.arange(total_iters - warmup_iters + 1)
schedule = final_value + 0.5 * (base_value - final_value) * (
1 + np.cos(np.pi * iters / (len(iters) - 1))
)
return np.concatenate((warmup_schedule, schedule))
def step(self):
self.iter = self.iter + 1
vals = self[self.iter]
for group, val in zip(self.optimizer.param_groups, vals):
if isinstance(group[self.key], (tuple, list)):
val = (val, *group[self.key][1:])
group[self.key] = val
def __getitem__(self, it):
it = min(it, self.total_iters)
return [scheduler[it] for scheduler in self.schedulers]
def get(self):
return [group[self.key] for group in self.optimizer.param_groups]
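# Usage sketch (hypothetical hyper-parameters): linear warmup followed by a
# cosine decay, written into each param group's "lr" at every iteration.
#
#   optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
#   lr_sched = CosineScheduler(
#       optimizer, warmup_iters=1_000, total_iters=100_000, key="lr",
#       init_value=1e-6, base_value=1e-4, final_value=1e-6,
#   )
#   for _ in range(100_000):
#       ...                    # forward / backward
#       optimizer.step()
#       lr_sched.step()        # overwrites group["lr"] with the scheduled value
#
# Per-group overrides ("lr_init", "lr_base", "lr_final") take precedence over
# the constructor defaults, see `get_schedulers`.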
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/utils/__init__.py
================================================
from .camera import invert_pinhole
# from .validation import validate
from .coordinate import coords_grid, normalize_coords
from .distributed import (barrier, get_dist_info, get_rank, is_main_process,
setup_multi_processes, setup_slurm,
sync_tensor_across_gpus)
from .evaluation_depth import (DICT_METRICS, DICT_METRICS_3D, eval_3d,
eval_depth)
from .geometric import spherical_zbuffer_to_euclidean, unproject_points
from .misc import (format_seconds, get_params, identity, recursive_index,
remove_padding, to_cpu)
from .visualization import colorize, image_grid, log_train_artifacts
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/utils/camera.py
================================================
"""
Author: Luigi Piccinelli
Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/)
"""
from copy import deepcopy
import numpy as np
import torch
import torch.nn.functional as F
from .coordinate import coords_grid
from .misc import recursive_to, squeeze_list
def invert_pinhole(K):
fx = K[..., 0, 0]
fy = K[..., 1, 1]
cx = K[..., 0, 2]
cy = K[..., 1, 2]
K_inv = torch.zeros_like(K)
K_inv[..., 0, 0] = 1.0 / fx
K_inv[..., 1, 1] = 1.0 / fy
K_inv[..., 0, 2] = -cx / fx
K_inv[..., 1, 2] = -cy / fy
K_inv[..., 2, 2] = 1.0
return K_inv
class Camera:
"""
This is meant to be an abstract parent class; use the subclasses below as the actual cameras:
Pinhole, Fisheye624, MEI, OPENCV, EUCM, Spherical (Equirectangular).
"""
def __init__(self, params=None, K=None):
if params.ndim == 1:
params = params.unsqueeze(0)
if K is None:
K = (
torch.eye(3, device=params.device, dtype=params.dtype)
.unsqueeze(0)
.repeat(params.shape[0], 1, 1)
)
K[..., 0, 0] = params[..., 0]
K[..., 1, 1] = params[..., 1]
K[..., 0, 2] = params[..., 2]
K[..., 1, 2] = params[..., 3]
self.params = params
self.K = K
self.overlap_mask = None
self.projection_mask = None
def project(self, xyz):
raise NotImplementedError
def unproject(self, uv):
raise NotImplementedError
def get_projection_mask(self):
return self.projection_mask
def get_overlap_mask(self):
return self.overlap_mask
def reconstruct(self, depth):
id_coords = coords_grid(
1, depth.shape[-2], depth.shape[-1], device=depth.device
)
rays = self.unproject(id_coords)
return (
rays / rays[:, -1:].clamp(min=1e-4) * depth.clamp(min=1e-4)
) # assumption z>0!!!
def resize(self, factor):
self.K[..., :2, :] *= factor
self.params[..., :4] *= factor
return self
def to(self, device, non_blocking=False):
self.params = self.params.to(device, non_blocking=non_blocking)
self.K = self.K.to(device, non_blocking=non_blocking)
return self
def get_rays(self, shapes, noisy=False):
b, h, w = shapes
uv = coords_grid(1, h, w, device=self.K.device, noisy=noisy)
rays = self.unproject(uv)
return rays / torch.norm(rays, dim=1, keepdim=True).clamp(min=1e-4)
def get_pinhole_rays(self, shapes, noisy=False):
b, h, w = shapes
uv = coords_grid(b, h, w, device=self.K.device, homogeneous=True, noisy=noisy)
rays = (invert_pinhole(self.K) @ uv.reshape(b, 3, -1)).reshape(b, 3, h, w)
return rays / torch.norm(rays, dim=1, keepdim=True).clamp(min=1e-4)
def flip(self, H, W, direction="horizontal"):
new_cx = (
W - self.params[:, 2] if direction == "horizontal" else self.params[:, 2]
)
new_cy = H - self.params[:, 3] if direction == "vertical" else self.params[:, 3]
self.params = torch.stack(
[self.params[:, 0], self.params[:, 1], new_cx, new_cy], dim=1
)
self.K[..., 0, 2] = new_cx
self.K[..., 1, 2] = new_cy
return self
def clone(self):
return deepcopy(self)
def crop(self, left, top, right=None, bottom=None):
self.K[..., 0, 2] -= left
self.K[..., 1, 2] -= top
self.params[..., 2] -= left
self.params[..., 3] -= top
return self
# helper function to get how fov changes based on new original size and new size
def get_new_fov(self, new_shape, original_shape):
new_hfov = 2 * torch.atan(
self.params[..., 2] / self.params[..., 0] * new_shape[1] / original_shape[1]
)
new_vfov = 2 * torch.atan(
self.params[..., 3] / self.params[..., 1] * new_shape[0] / original_shape[0]
)
return new_hfov, new_vfov
def mask_overlap_projection(self, projected):
B, _, H, W = projected.shape
id_coords = coords_grid(B, H, W, device=projected.device)
# check for mask where flow would overlap with other part of the image
# elements coming from the border are then masked out
flow = projected - id_coords
gamma = 0.1
sample_grid = gamma * flow + id_coords # sample along the flow
sample_grid[:, 0] = sample_grid[:, 0] / (W - 1) * 2 - 1
sample_grid[:, 1] = sample_grid[:, 1] / (H - 1) * 2 - 1
sampled_flow = F.grid_sample(
flow,
sample_grid.permute(0, 2, 3, 1),
mode="bilinear",
align_corners=False,
padding_mode="border",
)
mask = (
(1 - gamma) * torch.norm(flow, dim=1, keepdim=True)
< torch.norm(sampled_flow, dim=1, keepdim=True)
) | (torch.norm(flow, dim=1, keepdim=True) < 1)
return mask
def _pad_params(self):
# Ensure params are padded to length 16
if self.params.shape[1] < 16:
padding = torch.zeros(
16 - self.params.shape[1],
device=self.params.device,
dtype=self.params.dtype,
)
padding = padding.unsqueeze(0).repeat(self.params.shape[0], 1)
return torch.cat([self.params, padding], dim=1)
return self.params
@staticmethod
def flatten_cameras(cameras): # -> list[Camera]:
# Recursively flatten BatchCamera into primitive cameras
flattened_cameras = []
for camera in cameras:
if isinstance(camera, BatchCamera):
flattened_cameras.extend(BatchCamera.flatten_cameras(camera.cameras))
elif isinstance(camera, list):
flattened_cameras.extend(camera)
else:
flattened_cameras.append(camera)
return flattened_cameras
@staticmethod
def _stack_or_cat_cameras(cameras, func, **kwargs):
# Generalized method to handle stacking or concatenation
flat_cameras = BatchCamera.flatten_cameras(cameras)
K_matrices = [camera.K for camera in flat_cameras]
padded_params = [camera._pad_params() for camera in flat_cameras]
stacked_K = func(K_matrices, **kwargs)
stacked_params = func(padded_params, **kwargs)
# Keep track of the original classes
original_class = [x.__class__.__name__ for x in flat_cameras]
return BatchCamera(stacked_params, stacked_K, original_class, flat_cameras)
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
if kwargs is None:
kwargs = {}
if func is torch.cat:
return Camera._stack_or_cat_cameras(args[0], func, **kwargs)
if func is torch.stack:
return Camera._stack_or_cat_cameras(args[0], func, **kwargs)
if func is torch.flatten:
return Camera._stack_or_cat_cameras(args[0], torch.cat, **kwargs)
return super().__torch_function__(func, types, args, kwargs)
@property
def device(self):
return self.K.device
# here we assume that cx,cy are more or less H/2 and W/2
@property
def hfov(self):
return 2 * torch.atan(self.params[..., 2] / self.params[..., 0])
@property
def vfov(self):
return 2 * torch.atan(self.params[..., 3] / self.params[..., 1])
@property
def max_fov(self):
return 150.0 / 180.0 * np.pi, 150.0 / 180.0 * np.pi
class Pinhole(Camera):
def __init__(self, params=None, K=None):
assert params is not None or K is not None
if params is None:
params = torch.stack(
[K[..., 0, 0], K[..., 1, 1], K[..., 0, 2], K[..., 1, 2]], dim=-1
)
super().__init__(params=params, K=K)
@torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32)
def project(self, pcd):
b, _, h, w = pcd.shape
pcd_flat = pcd.reshape(b, 3, -1) # [B, 3, H*W]
cam_coords = self.K @ pcd_flat
pcd_proj = cam_coords[:, :2] / cam_coords[:, -1:].clamp(min=0.01)
pcd_proj = pcd_proj.reshape(b, 2, h, w)
invalid = (
(pcd_proj[:, 0] >= 0)
& (pcd_proj[:, 0] < w)
& (pcd_proj[:, 1] >= 0)
& (pcd_proj[:, 1] < h)
)
self.projection_mask = (~invalid).unsqueeze(1)
return pcd_proj
@torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32)
def unproject(self, uv):
b, _, h, w = uv.shape
uv_flat = uv.reshape(b, 2, -1) # [B, 2, H*W]
uv_homogeneous = torch.cat(
[uv_flat, torch.ones(b, 1, h * w, device=uv.device)], dim=1
) # [B, 3, H*W]
K_inv = torch.inverse(self.K.float())
xyz = K_inv @ uv_homogeneous
xyz = xyz / xyz[:, -1:].clip(min=1e-4)
xyz = xyz.reshape(b, 3, h, w)
self.unprojection_mask = xyz[:, -1:] > 1e-4
return xyz
@torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32)
def reconstruct(self, depth):
b, _, h, w = depth.shape
uv = coords_grid(b, h, w, device=depth.device)
xyz = self.unproject(uv) * depth.clip(min=0.0)
return xyz
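# Usage sketch (illustrative intrinsics): back-project a depth map into a
# camera-frame point cloud with the Pinhole model and re-project it.
#
#   K = torch.tensor([[[500.0, 0.0, 320.0],
#                      [0.0, 500.0, 240.0],
#                      [0.0, 0.0, 1.0]]])        # (1, 3, 3)
#   cam = Pinhole(K=K)
#   depth = torch.rand(1, 1, 480, 640) * 10.0    # (B, 1, H, W), metric depth
#   xyz = cam.reconstruct(depth)                 # (B, 3, H, W) points, z > 0
#   uv = cam.project(xyz)                        # (B, 2, H, W) pixel coords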
class EUCM(Camera):
def __init__(self, params):
super().__init__(params=params, K=None)
@torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32)
def project(self, xyz):
H, W = xyz.shape[-2:]
fx, fy, cx, cy, alpha, beta = self.params[:6].unbind(dim=1)
x, y, z = xyz.unbind(dim=1)
d = torch.sqrt(beta * (x**2 + y**2) + z**2)
x = x / (alpha * d + (1 - alpha) * z).clip(min=1e-3)
y = y / (alpha * d + (1 - alpha) * z).clip(min=1e-3)
Xnorm = fx * x + cx
Ynorm = fy * y + cy
coords = torch.stack([Xnorm, Ynorm], dim=1)
invalid = (
(coords[:, 0] < 0)
| (coords[:, 0] > W)
| (coords[:, 1] < 0)
| (coords[:, 1] > H)
| (z < 0)
)
self.projection_mask = (~invalid).unsqueeze(1)
return coords
@torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32)
def unproject(self, uv):
u, v = uv.unbind(dim=1)
fx, fy, cx, cy, alpha, beta = self.params.unbind(dim=1)
mx = (u - cx) / fx
my = (v - cy) / fy
r_square = mx**2 + my**2
valid_mask = r_square < torch.where(
alpha < 0.5, 1e6, 1 / (beta * (2 * alpha - 1))
)
sqrt_val = 1 - (2 * alpha - 1) * beta * r_square
mz = (1 - beta * (alpha**2) * r_square) / (
alpha * torch.sqrt(sqrt_val.clip(min=1e-5)) + (1 - alpha)
)
coeff = 1 / torch.sqrt(mx**2 + my**2 + mz**2 + 1e-5)
x = coeff * mx
y = coeff * my
z = coeff * mz
self.unprojection_mask = valid_mask & (z > 1e-3)
xnorm = torch.stack((x, y, z.clamp(1e-3)), dim=1)
return xnorm
class Spherical(Camera):
def __init__(self, params):
# hfov and vfov are in radians and halved!
super().__init__(params=params, K=None)
def resize(self, factor):
self.K[..., :2, :] *= factor
self.params[..., :6] *= factor
return self
def crop(self, left, top, right, bottom):
self.K[..., 0, 2] -= left
self.K[..., 1, 2] -= top
self.params[..., 2] -= left
self.params[..., 3] -= top
W, H = self.params[..., 4], self.params[..., 5]
angle_ratio_W = (W - left - right) / W
angle_ratio_H = (H - top - bottom) / H
self.params[..., 4] -= left + right
self.params[..., 5] -= top + bottom
# rescale hfov and vfov
self.params[..., 6] *= angle_ratio_W
self.params[..., 7] *= angle_ratio_H
return self
@torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32)
def project(self, xyz):
width, height = self.params[..., 4], self.params[..., 5]
hfov, vfov = 2 * self.params[..., 6], 2 * self.params[..., 7]
longitude = torch.atan2(xyz[:, 0], xyz[:, 2])
latitude = torch.asin(xyz[:, 1] / torch.norm(xyz, dim=1).clamp(min=1e-5))
u = longitude / hfov * (width - 1) + (width - 1) / 2
v = latitude / vfov * (height - 1) + (height - 1) / 2
return torch.stack([u, v], dim=1)
@torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32)
def unproject(self, uv):
u, v = uv.unbind(dim=1)
width, height = self.params[..., 4], self.params[..., 5]
hfov, vfov = 2 * self.params[..., 6], 2 * self.params[..., 7]
longitude = (u - (width - 1) / 2) / (width - 1) * hfov
latitude = (v - (height - 1) / 2) / (height - 1) * vfov
x = torch.cos(latitude) * torch.sin(longitude)
z = torch.cos(latitude) * torch.cos(longitude)
y = torch.sin(latitude)
unit_sphere = torch.stack([x, y, z], dim=1)
unit_sphere = unit_sphere / torch.norm(unit_sphere, dim=1, keepdim=True).clip(
min=1e-5
)
return unit_sphere
def reconstruct(self, depth):
id_coords = coords_grid(
1, depth.shape[-2], depth.shape[-1], device=depth.device
)
return self.unproject(id_coords) * depth
def get_new_fov(self, new_shape, original_shape):
new_hfov = 2 * self.params[..., 6] * new_shape[1] / original_shape[1]
new_vfov = 2 * self.params[..., 7] * new_shape[0] / original_shape[0]
return new_hfov, new_vfov
@property
def hfov(self):
return 2 * self.params[..., 6]
@property
def vfov(self):
return 2 * self.params[..., 7]
@property
def max_fov(self):
return 2 * np.pi, 0.9 * np.pi # avoid strong distortion on tops
class OPENCV(Camera):
def __init__(self, params):
super().__init__(params=params, K=None)
self.use_radial = self.params[..., 4:10].abs().sum() > 1e-6
assert (
self.params[..., 7:10].abs().sum() == 0.0
), "Do not support poly division model"
self.use_tangential = self.params[..., 10:12].abs().sum() > 1e-6
self.use_thin_prism = self.params[..., 12:].abs().sum() > 1e-6
@torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32)
def project(self, xyz):
eps = 1e-9
B, _, H, W = xyz.shape
N = H * W
xyz = xyz.permute(0, 2, 3, 1).reshape(B, N, 3)
# Radial correction.
z = xyz[:, :, 2].reshape(B, N, 1)
z = torch.where(torch.abs(z) < eps, eps * torch.sign(z), z)
ab = xyz[:, :, :2] / z
r = torch.norm(ab, dim=-1, p=2, keepdim=True)
th = r
# Create powers of th (th^2, th^4, th^6)
th_pow = torch.cat([torch.pow(th, 2 + i * 2) for i in range(3)], dim=-1)
distortion_coeffs_num = self.params[:, 4:7].reshape(B, 1, 3)
distortion_coeffs_den = self.params[:, 7:10].reshape(B, 1, 3)
th_num = 1 + torch.sum(th_pow * distortion_coeffs_num, dim=-1, keepdim=True)
th_den = 1 + torch.sum(th_pow * distortion_coeffs_den, dim=-1, keepdim=True)
xr_yr = ab * th_num / th_den
uv_dist = xr_yr
# Tangential correction.
p0 = self.params[..., -6].reshape(B, 1)
p1 = self.params[..., -5].reshape(B, 1)
xr = xr_yr[:, :, 0].reshape(B, N)
yr = xr_yr[:, :, 1].reshape(B, N)
xr_yr_sq = torch.square(xr_yr)
xr_sq = xr_yr_sq[:, :, 0].reshape(B, N)
yr_sq = xr_yr_sq[:, :, 1].reshape(B, N)
rd_sq = xr_sq + yr_sq
uv_dist_tu = uv_dist[:, :, 0] + (
(2.0 * xr_sq + rd_sq) * p0 + 2.0 * xr * yr * p1
)
uv_dist_tv = uv_dist[:, :, 1] + (
(2.0 * yr_sq + rd_sq) * p1 + 2.0 * xr * yr * p0
)
uv_dist = torch.stack(
[uv_dist_tu, uv_dist_tv], dim=-1
) # Avoids in-place complaint.
# Thin Prism correction.
s0 = self.params[..., -4].reshape(B, 1)
s1 = self.params[..., -3].reshape(B, 1)
s2 = self.params[..., -2].reshape(B, 1)
s3 = self.params[..., -1].reshape(B, 1)
rd_4 = torch.square(rd_sq)
uv_dist[:, :, 0] = uv_dist[:, :, 0] + (s0 * rd_sq + s1 * rd_4)
uv_dist[:, :, 1] = uv_dist[:, :, 1] + (s2 * rd_sq + s3 * rd_4)
# Finally, apply standard terms: focal length and camera centers.
if self.params.shape[-1] == 15:
fx_fy = self.params[..., 0].reshape(B, 1, 1)
cx_cy = self.params[..., 1:3].reshape(B, 1, 2)
else:
fx_fy = self.params[..., 0:2].reshape(B, 1, 2)
cx_cy = self.params[..., 2:4].reshape(B, 1, 2)
result = uv_dist * fx_fy + cx_cy
result = result.reshape(B, H, W, 2).permute(0, 3, 1, 2)
invalid = (
(result[:, 0] < 0)
| (result[:, 0] > W)
| (result[:, 1] < 0)
| (result[:, 1] > H)
)
self.projection_mask = (~invalid).unsqueeze(1)
self.overlap_mask = self.mask_overlap_projection(result)
return result
@torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32)
def unproject(self, uv, max_iters: int = 10):
eps = 1e-3
B, _, H, W = uv.shape
N = H * W
uv = uv.permute(0, 2, 3, 1).reshape(B, N, 2)
if self.params.shape[-1] == 15:
fx_fy = self.params[..., 0].reshape(B, 1, 1)
cx_cy = self.params[..., 1:3].reshape(B, 1, 2)
else:
fx_fy = self.params[..., 0:2].reshape(B, 1, 2)
cx_cy = self.params[..., 2:4].reshape(B, 1, 2)
uv_dist = (uv - cx_cy) / fx_fy
# Compute xr_yr using Newton's method.
xr_yr = uv_dist.clone() # Initial guess.
max_iters_tanprism = (
max_iters if self.use_thin_prism or self.use_tangential else 0
)
for _ in range(max_iters_tanprism):
uv_dist_est = xr_yr.clone()
xr = xr_yr[..., 0].reshape(B, N)
yr = xr_yr[..., 1].reshape(B, N)
xr_yr_sq = torch.square(xr_yr)
xr_sq = xr_yr_sq[..., 0].reshape(B, N)
yr_sq = xr_yr_sq[..., 1].reshape(B, N)
rd_sq = xr_sq + yr_sq
if self.use_tangential:
# Tangential terms.
p0 = self.params[..., -6].reshape(B, 1)
p1 = self.params[..., -5].reshape(B, 1)
uv_dist_est[..., 0] = uv_dist_est[..., 0] + (
(2.0 * xr_sq + rd_sq) * p0 + 2.0 * xr * yr * p1
)
uv_dist_est[..., 1] = uv_dist_est[..., 1] + (
(2.0 * yr_sq + rd_sq) * p1 + 2.0 * xr * yr * p0
)
if self.use_thin_prism:
# Thin Prism terms.
s0 = self.params[..., -4].reshape(B, 1)
s1 = self.params[..., -3].reshape(B, 1)
s2 = self.params[..., -2].reshape(B, 1)
s3 = self.params[..., -1].reshape(B, 1)
rd_4 = torch.square(rd_sq)
uv_dist_est[:, :, 0] = uv_dist_est[:, :, 0] + (s0 * rd_sq + s1 * rd_4)
uv_dist_est[:, :, 1] = uv_dist_est[:, :, 1] + (s2 * rd_sq + s3 * rd_4)
# Compute the derivative of uv_dist w.r.t. xr_yr.
duv_dist_dxr_yr = uv.new_ones(B, N, 2, 2)
if self.use_tangential:
duv_dist_dxr_yr[..., 0, 0] = 1.0 + 6.0 * xr * p0 + 2.0 * yr * p1
offdiag = 2.0 * (xr * p1 + yr * p0)
duv_dist_dxr_yr[..., 0, 1] = offdiag
duv_dist_dxr_yr[..., 1, 0] = offdiag
duv_dist_dxr_yr[..., 1, 1] = 1.0 + 6.0 * yr * p1 + 2.0 * xr * p0
if self.use_thin_prism:
xr_yr_sq_norm = xr_sq + yr_sq
temp1 = 2.0 * (s0 + 2.0 * s1 * xr_yr_sq_norm)
duv_dist_dxr_yr[..., 0, 0] = duv_dist_dxr_yr[..., 0, 0] + (xr * temp1)
duv_dist_dxr_yr[..., 0, 1] = duv_dist_dxr_yr[..., 0, 1] + (yr * temp1)
temp2 = 2.0 * (s2 + 2.0 * s3 * xr_yr_sq_norm)
duv_dist_dxr_yr[..., 1, 0] = duv_dist_dxr_yr[..., 1, 0] + (xr * temp2)
duv_dist_dxr_yr[..., 1, 1] = duv_dist_dxr_yr[..., 1, 1] + (yr * temp2)
mat = duv_dist_dxr_yr.reshape(-1, 2, 2)
a = mat[:, 0, 0].reshape(-1, 1, 1)
b = mat[:, 0, 1].reshape(-1, 1, 1)
c = mat[:, 1, 0].reshape(-1, 1, 1)
d = mat[:, 1, 1].reshape(-1, 1, 1)
det = 1.0 / ((a * d) - (b * c))
top = torch.cat([d, -b], dim=-1)
bot = torch.cat([-c, a], dim=-1)
inv = det * torch.cat([top, bot], dim=-2)
inv = inv.reshape(B, N, 2, 2)
diff = uv_dist - uv_dist_est
a = inv[..., 0, 0]
b = inv[..., 0, 1]
c = inv[..., 1, 0]
d = inv[..., 1, 1]
e = diff[..., 0]
f = diff[..., 1]
step = torch.stack([a * e + b * f, c * e + d * f], dim=-1)
# Newton step.
xr_yr = xr_yr + step
# Compute theta using Newton's method.
xr_yr_norm = xr_yr.norm(p=2, dim=2).reshape(B, N, 1)
th = xr_yr_norm.clone()
max_iters_radial = max_iters if self.use_radial else 0
c = (
torch.tensor([2.0 * i + 3 for i in range(3)], device=self.device)
.reshape(1, 1, 3)
.repeat(B, 1, 1)
)
radial_params_num = self.params[..., 4:7].reshape(B, 1, 3)
# Trust region parameters
delta = torch.full((B, N, 1), 0.1, device=self.device) # Initial trust radius
delta_max = torch.tensor(1.0, device=self.device) # Maximum trust radius
eta = 0.1 # Acceptable reduction threshold
for i in range(max_iters_radial):
th_sq = th * th # th^2
# Compute powers of th^2 up to th^6
theta_powers = torch.cat(
[th_sq ** (i + 1) for i in range(3)], dim=-1
)  # Shape: (B, N, 3)
# Compute th_radial: radial distortion model applied to th
th_radial = 1.0 + torch.sum(
theta_powers * radial_params_num, dim=-1, keepdim=True
)
th_radial = th_radial * th # Multiply by th at the end
# Compute derivative dthd_th
dthd_th = 1.0 + torch.sum(
c * radial_params_num * theta_powers, dim=-1, keepdim=True
)
dthd_th = dthd_th # Already includes derivative terms
# Compute residual
residual = th_radial - xr_yr_norm # Shape: (B, N, 1)
residual_norm = torch.norm(residual, dim=2, keepdim=True) # For each pixel
# Check for convergence
if torch.max(torch.abs(residual)) < eps:
break
# Avoid division by zero by adding a small epsilon
safe_dthd_th = dthd_th.clone()
zero_derivative_mask = dthd_th.abs() < eps
safe_dthd_th[zero_derivative_mask] = eps
# Compute Newton's step
step = -residual / safe_dthd_th
# Compute predicted reduction
predicted_reduction = -(residual * step).sum(dim=2, keepdim=True)
# Adjust step based on trust region
step_norm = torch.norm(step, dim=2, keepdim=True)
over_trust_mask = step_norm > delta
# Scale step if it exceeds trust radius
step_scaled = step.clone()
step_scaled[over_trust_mask] = step[over_trust_mask] * (
delta[over_trust_mask] / step_norm[over_trust_mask]
)
# Update theta
th_new = th + step_scaled
# Compute new residual
th_sq_new = th_new * th_new
theta_powers_new = torch.cat(
[th_sq_new ** (j + 1) for j in range(3)], dim=-1
)
th_radial_new = 1.0 + torch.sum(
theta_powers_new * radial_params_num, dim=-1, keepdim=True
)
th_radial_new = th_radial_new * th_new
residual_new = th_radial_new - xr_yr_norm
residual_new_norm = torch.norm(residual_new, dim=2, keepdim=True)
# Compute actual reduction
actual_reduction = residual_norm - residual_new_norm
# Compute ratio of actual to predicted reduction
# predicted_reduction[predicted_reduction.abs() < eps] = eps #* torch.sign(predicted_reduction[predicted_reduction.abs() < eps])
rho = actual_reduction / predicted_reduction
rho[(actual_reduction == 0) & (predicted_reduction == 0)] = 1.0
# Update trust radius delta
delta_update_mask = rho > 0.5
delta[delta_update_mask] = torch.min(
2.0 * delta[delta_update_mask], delta_max
)
delta_decrease_mask = rho < 0.2
delta[delta_decrease_mask] = 0.25 * delta[delta_decrease_mask]
# Accept or reject the step
accept_step_mask = rho > eta
th = torch.where(accept_step_mask, th_new, th)
# Compute the ray direction using theta and xr_yr.
close_to_zero = torch.logical_and(th.abs() < eps, xr_yr_norm.abs() < eps)
ray_dir = torch.where(close_to_zero, xr_yr, th / xr_yr_norm * xr_yr)
ray = torch.cat([ray_dir, uv.new_ones(B, N, 1)], dim=2)
ray = ray.reshape(B, H, W, 3).permute(0, 3, 1, 2)
return ray
class Fisheye624(Camera):
def __init__(self, params):
super().__init__(params=params, K=None)
self.use_radial = self.params[..., 4:10].abs().sum() > 1e-6
self.use_tangential = self.params[..., 10:12].abs().sum() > 1e-6
self.use_thin_prism = self.params[..., 12:].abs().sum() > 1e-6
@torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32)
def project(self, xyz):
eps = 1e-9
B, _, H, W = xyz.shape
N = H * W
xyz = xyz.permute(0, 2, 3, 1).reshape(B, N, 3)
# Radial correction.
z = xyz[:, :, 2].reshape(B, N, 1)
z = torch.where(torch.abs(z) < eps, eps * torch.sign(z), z)
ab = xyz[:, :, :2] / z
r = torch.norm(ab, dim=-1, p=2, keepdim=True)
th = torch.atan(r)
th_divr = torch.where(r < eps, torch.ones_like(ab), ab / r)
th_pow = torch.cat(
[torch.pow(th, 3 + i * 2) for i in range(6)], dim=-1
) # Create powers of th (th^3, th^5, ...)
distortion_coeffs = self.params[:, 4:10].reshape(B, 1, 6)
th_k = th + torch.sum(th_pow * distortion_coeffs, dim=-1, keepdim=True)
xr_yr = th_k * th_divr
uv_dist = xr_yr
# Tangential correction.
p0 = self.params[..., -6].reshape(B, 1)
p1 = self.params[..., -5].reshape(B, 1)
xr = xr_yr[:, :, 0].reshape(B, N)
yr = xr_yr[:, :, 1].reshape(B, N)
xr_yr_sq = torch.square(xr_yr)
xr_sq = xr_yr_sq[:, :, 0].reshape(B, N)
yr_sq = xr_yr_sq[:, :, 1].reshape(B, N)
rd_sq = xr_sq + yr_sq
uv_dist_tu = uv_dist[:, :, 0] + (
(2.0 * xr_sq + rd_sq) * p0 + 2.0 * xr * yr * p1
)
uv_dist_tv = uv_dist[:, :, 1] + (
(2.0 * yr_sq + rd_sq) * p1 + 2.0 * xr * yr * p0
)
uv_dist = torch.stack(
[uv_dist_tu, uv_dist_tv], dim=-1
) # Avoids in-place complaint.
# Thin Prism correction.
s0 = self.params[..., -4].reshape(B, 1)
s1 = self.params[..., -3].reshape(B, 1)
s2 = self.params[..., -2].reshape(B, 1)
s3 = self.params[..., -1].reshape(B, 1)
rd_4 = torch.square(rd_sq)
uv_dist[:, :, 0] = uv_dist[:, :, 0] + (s0 * rd_sq + s1 * rd_4)
uv_dist[:, :, 1] = uv_dist[:, :, 1] + (s2 * rd_sq + s3 * rd_4)
# Finally, apply standard terms: focal length and camera centers.
if self.params.shape[-1] == 15:
fx_fy = self.params[..., 0].reshape(B, 1, 1)
cx_cy = self.params[..., 1:3].reshape(B, 1, 2)
else:
fx_fy = self.params[..., 0:2].reshape(B, 1, 2)
cx_cy = self.params[..., 2:4].reshape(B, 1, 2)
result = uv_dist * fx_fy + cx_cy
result = result.reshape(B, H, W, 2).permute(0, 3, 1, 2)
invalid = (
(result[:, 0] < 0)
| (result[:, 0] > W)
| (result[:, 1] < 0)
| (result[:, 1] > H)
)
self.projection_mask = (~invalid).unsqueeze(1)
self.overlap_mask = self.mask_overlap_projection(result)
return result
@torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32)
def unproject(self, uv, max_iters: int = 10):
eps = 1e-3
B, _, H, W = uv.shape
N = H * W
uv = uv.permute(0, 2, 3, 1).reshape(B, N, 2)
if self.params.shape[-1] == 15:
fx_fy = self.params[..., 0].reshape(B, 1, 1)
cx_cy = self.params[..., 1:3].reshape(B, 1, 2)
else:
fx_fy = self.params[..., 0:2].reshape(B, 1, 2)
cx_cy = self.params[..., 2:4].reshape(B, 1, 2)
uv_dist = (uv - cx_cy) / fx_fy
# Compute xr_yr using Trust-region method.
xr_yr = uv_dist.clone()
max_iters_tanprism = (
max_iters if self.use_thin_prism or self.use_tangential else 0
)
for _ in range(max_iters_tanprism):
uv_dist_est = xr_yr.clone()
xr = xr_yr[..., 0].reshape(B, N)
yr = xr_yr[..., 1].reshape(B, N)
xr_yr_sq = torch.square(xr_yr)
xr_sq = xr_yr_sq[..., 0].reshape(B, N)
yr_sq = xr_yr_sq[..., 1].reshape(B, N)
rd_sq = xr_sq + yr_sq
if self.use_tangential:
# Tangential terms.
p0 = self.params[..., -6].reshape(B, 1)
p1 = self.params[..., -5].reshape(B, 1)
uv_dist_est[..., 0] = uv_dist_est[..., 0] + (
(2.0 * xr_sq + rd_sq) * p0 + 2.0 * xr * yr * p1
)
uv_dist_est[..., 1] = uv_dist_est[..., 1] + (
(2.0 * yr_sq + rd_sq) * p1 + 2.0 * xr * yr * p0
)
if self.use_thin_prism:
# Thin Prism terms.
s0 = self.params[..., -4].reshape(B, 1)
s1 = self.params[..., -3].reshape(B, 1)
s2 = self.params[..., -2].reshape(B, 1)
s3 = self.params[..., -1].reshape(B, 1)
rd_4 = torch.square(rd_sq)
uv_dist_est[:, :, 0] = uv_dist_est[:, :, 0] + (s0 * rd_sq + s1 * rd_4)
uv_dist_est[:, :, 1] = uv_dist_est[:, :, 1] + (s2 * rd_sq + s3 * rd_4)
# Compute the derivative of uv_dist w.r.t. xr_yr.
duv_dist_dxr_yr = uv.new_ones(B, N, 2, 2)
if self.use_tangential:
duv_dist_dxr_yr[..., 0, 0] = 1.0 + 6.0 * xr * p0 + 2.0 * yr * p1
offdiag = 2.0 * (xr * p1 + yr * p0)
duv_dist_dxr_yr[..., 0, 1] = offdiag
duv_dist_dxr_yr[..., 1, 0] = offdiag
duv_dist_dxr_yr[..., 1, 1] = 1.0 + 6.0 * yr * p1 + 2.0 * xr * p0
if self.use_thin_prism:
xr_yr_sq_norm = xr_sq + yr_sq
temp1 = 2.0 * (s0 + 2.0 * s1 * xr_yr_sq_norm)
duv_dist_dxr_yr[..., 0, 0] = duv_dist_dxr_yr[..., 0, 0] + (xr * temp1)
duv_dist_dxr_yr[..., 0, 1] = duv_dist_dxr_yr[..., 0, 1] + (yr * temp1)
temp2 = 2.0 * (s2 + 2.0 * s3 * xr_yr_sq_norm)
duv_dist_dxr_yr[..., 1, 0] = duv_dist_dxr_yr[..., 1, 0] + (xr * temp2)
duv_dist_dxr_yr[..., 1, 1] = duv_dist_dxr_yr[..., 1, 1] + (yr * temp2)
mat = duv_dist_dxr_yr.reshape(-1, 2, 2)
a = mat[:, 0, 0].reshape(-1, 1, 1)
b = mat[:, 0, 1].reshape(-1, 1, 1)
c = mat[:, 1, 0].reshape(-1, 1, 1)
d = mat[:, 1, 1].reshape(-1, 1, 1)
det = 1.0 / ((a * d) - (b * c))
top = torch.cat([d, -b], dim=-1)
bot = torch.cat([-c, a], dim=-1)
inv = det * torch.cat([top, bot], dim=-2)
inv = inv.reshape(B, N, 2, 2)
diff = uv_dist - uv_dist_est
a = inv[..., 0, 0]
b = inv[..., 0, 1]
c = inv[..., 1, 0]
d = inv[..., 1, 1]
e = diff[..., 0]
f = diff[..., 1]
step = torch.stack([a * e + b * f, c * e + d * f], dim=-1)
# Newton step.
xr_yr = xr_yr + step
# Compute theta using Newton's method.
xr_yr_norm = xr_yr.norm(p=2, dim=2).reshape(B, N, 1)
th = xr_yr_norm.clone()
max_iters_radial = max_iters if self.use_radial else 0
c = (
torch.tensor([2.0 * i + 3 for i in range(6)], device=self.device)
.reshape(1, 1, 6)
.repeat(B, 1, 1)
)
radial_params = self.params[..., 4:10].reshape(B, 1, 6)
# Trust region parameters
delta = torch.full((B, N, 1), 0.1, device=self.device) # Initial trust radius
delta_max = torch.tensor(1.0, device=self.device) # Maximum trust radius
eta = 0.1 # Acceptable reduction threshold
for i in range(max_iters_radial):
th_sq = th * th
# Compute powers of th^2 up to th^(12)
theta_powers = torch.cat(
[th_sq ** (i + 1) for i in range(6)], dim=-1
) # Shape: (B, N, 6)
# Compute th_radial: radial distortion model applied to th
th_radial = 1.0 + torch.sum(
theta_powers * radial_params, dim=-1, keepdim=True
)
th_radial = th_radial * th
# Compute derivative dthd_th
dthd_th = 1.0 + torch.sum(
c * radial_params * theta_powers, dim=-1, keepdim=True
)
# Compute residual
residual = th_radial - xr_yr_norm # Shape: (B, N, 1)
residual_norm = torch.norm(residual, dim=2, keepdim=True)
# Check for convergence
if torch.max(torch.abs(residual)) < eps:
break
# Avoid division by zero by adding a small epsilon
safe_dthd_th = dthd_th.clone()
zero_derivative_mask = dthd_th.abs() < eps
safe_dthd_th[zero_derivative_mask] = eps
# Compute Newton's step
step = -residual / safe_dthd_th
# Compute predicted reduction
predicted_reduction = -(residual * step).sum(dim=2, keepdim=True)
# Adjust step based on trust region
step_norm = torch.norm(step, dim=2, keepdim=True)
over_trust_mask = step_norm > delta
# Scale step if it exceeds trust radius
step_scaled = step.clone()
step_scaled[over_trust_mask] = step[over_trust_mask] * (
delta[over_trust_mask] / step_norm[over_trust_mask]
)
# Update theta
th_new = th + step_scaled
# Compute new residual
th_sq_new = th_new * th_new
theta_powers_new = torch.cat(
[th_sq_new ** (j + 1) for j in range(6)], dim=-1
)
th_radial_new = 1.0 + torch.sum(
theta_powers_new * radial_params, dim=-1, keepdim=True
)
th_radial_new = th_radial_new * th_new
residual_new = th_radial_new - xr_yr_norm
residual_new_norm = torch.norm(residual_new, dim=2, keepdim=True)
# Compute actual reduction
actual_reduction = residual_norm - residual_new_norm
# Compute ratio of actual to predicted reduction
rho = actual_reduction / predicted_reduction
rho[(actual_reduction == 0) & (predicted_reduction == 0)] = 1.0
# Update trust radius delta
delta_update_mask = rho > 0.5
delta[delta_update_mask] = torch.min(
2.0 * delta[delta_update_mask], delta_max
)
delta_decrease_mask = rho < 0.2
delta[delta_decrease_mask] = 0.25 * delta[delta_decrease_mask]
# Accept or reject the step
accept_step_mask = rho > eta
th = torch.where(accept_step_mask, th_new, th)
# Compute the ray direction using theta and xr_yr.
close_to_zero = torch.logical_and(th.abs() < eps, xr_yr_norm.abs() < eps)
ray_dir = torch.where(close_to_zero, xr_yr, torch.tan(th) / xr_yr_norm * xr_yr)
ray = torch.cat([ray_dir, uv.new_ones(B, N, 1)], dim=2)
ray = ray.reshape(B, H, W, 3).permute(0, 3, 1, 2)
return ray
class MEI(Camera):
def __init__(self, params):
super().__init__(params=params, K=None)
# fx fy cx cy k1 k2 p1 p2 xi
self.use_radial = self.params[..., 4:6].abs().sum() > 1e-6
self.use_tangential = self.params[..., 6:8].abs().sum() > 1e-6
@torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32)
def unproject(self, uv, max_iters: int = 20):
eps = 1e-6
B, _, H, W = uv.shape
N = H * W
uv = uv.permute(0, 2, 3, 1).reshape(B, N, 2)
k1, k2, p0, p1, xi = self.params[..., 4:9].unbind(dim=1)
fx_fy = self.params[..., 0:2].reshape(B, 1, 2)
cx_cy = self.params[..., 2:4].reshape(B, 1, 2)
uv_dist = (uv - cx_cy) / fx_fy
# Compute xr_yr using Newton's method.
xr_yr = uv_dist.clone() # Initial guess.
max_iters_tangential = max_iters if self.use_tangential else 0
for _ in range(max_iters_tangential):
uv_dist_est = xr_yr.clone()
# Tangential terms.
xr = xr_yr[..., 0]
yr = xr_yr[..., 1]
xr_yr_sq = xr_yr**2
xr_sq = xr_yr_sq[..., 0]
yr_sq = xr_yr_sq[..., 1]
rd_sq = xr_sq + yr_sq
uv_dist_est[..., 0] = uv_dist_est[..., 0] + (
(2.0 * xr_sq + rd_sq) * p0 + 2.0 * xr * yr * p1
)
uv_dist_est[..., 1] = uv_dist_est[..., 1] + (
(2.0 * yr_sq + rd_sq) * p1 + 2.0 * xr * yr * p0
)
# Compute the derivative of uv_dist w.r.t. xr_yr.
duv_dist_dxr_yr = torch.ones((B, N, 2, 2), device=uv.device)
duv_dist_dxr_yr[..., 0, 0] = 1.0 + 6.0 * xr * p0 + 2.0 * yr * p1
offdiag = 2.0 * (xr * p1 + yr * p0)
duv_dist_dxr_yr[..., 0, 1] = offdiag
duv_dist_dxr_yr[..., 1, 0] = offdiag
duv_dist_dxr_yr[..., 1, 1] = 1.0 + 6.0 * yr * p1 + 2.0 * xr * p0
mat = duv_dist_dxr_yr.reshape(-1, 2, 2)
a = mat[:, 0, 0].reshape(-1, 1, 1)
b = mat[:, 0, 1].reshape(-1, 1, 1)
c = mat[:, 1, 0].reshape(-1, 1, 1)
d = mat[:, 1, 1].reshape(-1, 1, 1)
det = 1.0 / ((a * d) - (b * c))
top = torch.cat([d, -b], dim=-1)
bot = torch.cat([-c, a], dim=-1)
inv = det * torch.cat([top, bot], dim=-2)
inv = inv.reshape(B, N, 2, 2)
diff = uv_dist - uv_dist_est
a = inv[..., 0, 0]
b = inv[..., 0, 1]
c = inv[..., 1, 0]
d = inv[..., 1, 1]
e = diff[..., 0]
f = diff[..., 1]
step = torch.stack([a * e + b * f, c * e + d * f], dim=-1)
# Newton step.
xr_yr = xr_yr + step
# Compute theta using Newton's method.
xr_yr_norm = xr_yr.norm(p=2, dim=2).reshape(B, N, 1)
th = xr_yr_norm.clone()
max_iters_radial = max_iters if self.use_radial else 0
for _ in range(max_iters_radial):
th_radial = 1.0 + k1 * torch.pow(th, 2) + k2 * torch.pow(th, 4)
dthd_th = 1.0 + 3.0 * k1 * torch.pow(th, 2) + 5.0 * k2 * torch.pow(th, 4)
th_radial = th_radial * th
step = (xr_yr_norm - th_radial) / dthd_th
# handle dthd_th close to 0.
step = torch.where(
torch.abs(dthd_th) > eps, step, torch.sign(step) * eps * 10.0
)
th = th + step
# Compute the ray direction using theta and xr_yr.
close_to_zero = (torch.abs(th) < eps) & (torch.abs(xr_yr_norm) < eps)
ray_dir = torch.where(close_to_zero, xr_yr, th * xr_yr / xr_yr_norm)
# Compute the 3D projective ray
rho2_u = (
ray_dir.norm(p=2, dim=2, keepdim=True) ** 2
) # B N 1 # x_c * x_c + y_c * y_c
xi = xi.reshape(B, 1, 1)
sqrt_term = torch.sqrt(1.0 + (1.0 - xi * xi) * rho2_u)
P_z = 1.0 - xi * (rho2_u + 1.0) / (xi + sqrt_term)
# Special case when xi is 1.0 (unit sphere projection ??)
P_z = torch.where(xi == 1.0, (1.0 - rho2_u) / 2.0, P_z)
ray = torch.cat([ray_dir, P_z], dim=-1)
ray = ray.reshape(B, H, W, 3).permute(0, 3, 1, 2)
return ray
@torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32)
def project(self, xyz):
is_flat = xyz.ndim == 3
B, N = xyz.shape[:2]
if not is_flat:
B, _, H, W = xyz.shape
N = H * W
xyz = xyz.permute(0, 2, 3, 1).reshape(B, N, 3)
k1, k2, p0, p1, xi = self.params[..., 4:].unbind(dim=1)
fx_fy = self.params[..., 0:2].reshape(B, 1, 2)
cx_cy = self.params[..., 2:4].reshape(B, 1, 2)
norm = xyz.norm(p=2, dim=-1, keepdim=True)
ab = xyz[..., :-1] / (xyz[..., -1:] + xi.reshape(B, 1, 1) * norm)
# radial correction
r = ab.norm(dim=-1, p=2, keepdim=True)
k1 = self.params[..., 4].reshape(B, 1, 1)
k2 = self.params[..., 5].reshape(B, 1, 1)
# ab / r * th * (1 + k1 * (th ** 2) + k2 * (th**4))
# but here r = th, no spherical distortion
xr_yr = ab * (1 + k1 * (r**2) + k2 * (r**4))
# Tangential correction.
uv_dist = xr_yr
p0 = self.params[:, -3].reshape(B, 1)
p1 = self.params[:, -2].reshape(B, 1)
xr = xr_yr[..., 0].reshape(B, N)
yr = xr_yr[..., 1].reshape(B, N)
xr_yr_sq = torch.square(xr_yr)
xr_sq = xr_yr_sq[:, :, 0].reshape(B, N)
yr_sq = xr_yr_sq[:, :, 1].reshape(B, N)
rd_sq = xr_sq + yr_sq
uv_dist_tu = uv_dist[:, :, 0] + (
(2.0 * xr_sq + rd_sq) * p0 + 2.0 * xr * yr * p1
)
uv_dist_tv = uv_dist[:, :, 1] + (
(2.0 * yr_sq + rd_sq) * p1 + 2.0 * xr * yr * p0
)
uv_dist = torch.stack(
[uv_dist_tu, uv_dist_tv], dim=-1
) # Avoids in-place complaint.
result = uv_dist * fx_fy + cx_cy
if not is_flat:
result = result.reshape(B, H, W, 2).permute(0, 3, 1, 2)
invalid = (
(result[:, 0] < 0)
| (result[:, 0] > W)
| (result[:, 1] < 0)
| (result[:, 1] > H)
)
self.projection_mask = (~invalid).unsqueeze(1)
# creates hole in the middle... ??
# self.overlap_mask = self.mask_overlap_projection(result)
return result
class BatchCamera(Camera):
"""
This is not meant to be used directly; it acts as a wrapper around multiple cameras.
It should expose only the `from_camera` method, as it is the only way to create a BatchCamera.
"""
def __init__(self, params, K, original_class, cameras):
super().__init__(params, K)
self.original_class = original_class
self.cameras = cameras
# Delegate these methods to original camera
@torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32)
def project(self, points_3d):
return torch.cat(
[
camera.project(points_3d[i : i + 1])
for i, camera in enumerate(self.cameras)
]
)
@torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32)
def unproject(self, points_2d):
val = torch.cat(
[camera.unproject(points_2d) for i, camera in enumerate(self.cameras)]
)
return val
def crop(self, left, top, right=None, bottom=None):
val = torch.cat(
[
camera.crop(left, top, right, bottom)
for i, camera in enumerate(self.cameras)
]
)
return val
def resize(self, ratio):
val = torch.cat([camera.resize(ratio) for i, camera in enumerate(self.cameras)])
return val
def reconstruct(self, depth):
val = torch.cat(
[
camera.reconstruct(depth[i : i + 1])
for i, camera in enumerate(self.cameras)
]
)
return val
def get_projection_mask(self):
return torch.cat(
[camera.projection_mask for i, camera in enumerate(self.cameras)]
)
def to(self, device, non_blocking=False):
self = super().to(device, non_blocking=non_blocking)
self.cameras = recursive_to(
self.cameras, device, non_blocking=non_blocking, cls=Camera
)
return self
def reshape(self, *shape):
# Reshape the intrinsic matrix (K) and params
# we know that the shape of K is (..., 3, 3) and params is (..., 16)
reshaped_K = self.K.reshape(*shape, 3, 3)
reshaped_params = self.params.reshape(*shape, self.params.shape[-1])
self.cameras = np.array(self.cameras, dtype=object).reshape(shape).tolist()
self.original_class = (
np.array(self.original_class, dtype=object).reshape(shape).tolist()
)
# Create a new BatchCamera with reshaped K and params
return BatchCamera(
reshaped_params, reshaped_K, self.original_class, self.cameras
)
def get_new_fov(self, new_shape, original_shape):
return [
camera.get_new_fov(new_shape, original_shape)
for i, camera in enumerate(self.cameras)
]
def squeeze(self, dim):
return BatchCamera(
self.params.squeeze(dim),
self.K.squeeze(dim),
squeeze_list(self.original_class, dim=dim),
squeeze_list(self.cameras, dim=dim),
)
def __getitem__(self, idx):
if isinstance(idx, int):
return self.cameras[idx]
elif isinstance(idx, slice):
return BatchCamera(
self.params[idx],
self.K[idx],
self.original_class[idx],
self.cameras[idx],
)
raise TypeError(f"Invalid index type: {type(idx)}")
def __setitem__(self, idx, value):
# If it's an integer index, return a single camera
if isinstance(idx, int):
self.cameras[idx] = value
self.params[idx, :] = 0.0
self.params[idx, : value.params.shape[1]] = value.params[0]
self.K[idx] = value.K[0]
self.original_class[idx] = getattr(
value, "original_class", value.__class__.__name__
)
# If it's a slice, return a new BatchCamera with sliced cameras
elif isinstance(idx, slice):
# Update each internal attribute using the slice
self.params[idx] = value.params
self.K[idx] = value.K
self.original_class[idx] = value.original_class
self.cameras[idx] = value.cameras
def __len__(self):
return len(self.cameras)
@classmethod
def from_camera(cls, camera):
return cls(camera.params, camera.K, [camera.__class__.__name__], [camera])
@property
def is_perspective(self):
return [isinstance(camera, Pinhole) for camera in self.cameras]
@property
def is_spherical(self):
return [isinstance(camera, Spherical) for camera in self.cameras]
@property
def is_eucm(self):
return [isinstance(camera, EUCM) for camera in self.cameras]
@property
def is_fisheye(self):
return [isinstance(camera, Fisheye624) for camera in self.cameras]
@property
def is_pinhole(self):
return [isinstance(camera, Pinhole) for camera in self.cameras]
@property
def hfov(self):
return [camera.hfov for camera in self.cameras]
@property
def vfov(self):
return [camera.vfov for camera in self.cameras]
@property
def max_fov(self):
return [camera.max_fov for camera in self.cameras]
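# Usage sketch (illustrative cameras): a BatchCamera is obtained either from a
# single camera via `from_camera`, or by torch.cat / torch.stack over cameras,
# which are intercepted by `__torch_function__` above.
#
#   cam0, cam1 = Pinhole(K=K0), Pinhole(K=K1)    # K0, K1: (1, 3, 3) intrinsics
#   batch = torch.cat([cam0, cam1])              # BatchCamera of length 2
#   single = BatchCamera.from_camera(cam0)       # BatchCamera of length 1
#   xyz = batch.reconstruct(depth)               # depth: (2, 1, H, W), per-camera dispatch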
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/utils/chamfer_distance.py
================================================
import warnings
from typing import Union
import torch
try:
from unidepth.ops.knn import knn_points
except ImportError as e:
warnings.warn(
"!! To run evaluation you need KNN. Please compile KNN: "
"`cd unidepth/ops/knn with && bash compile.sh`."
)
def knn_points(*args, **kwargs):
    raise ImportError("knn_points is unavailable; compile it with `cd unidepth/ops/knn && bash compile.sh`.")
def _validate_chamfer_reduction_inputs(
batch_reduction: Union[str, None], point_reduction: str
):
"""Check the requested reductions are valid.
Args:
batch_reduction: Reduction operation to apply for the loss across the
batch, can be one of ["mean", "sum"] or None.
point_reduction: Reduction operation to apply for the loss across the
points, can be one of ["mean", "sum"].
"""
if batch_reduction is not None and batch_reduction not in ["mean", "sum"]:
raise ValueError('batch_reduction must be one of ["mean", "sum"] or None')
if point_reduction not in ["mean", "sum"]:
raise ValueError('point_reduction must be one of ["mean", "sum"]')
def _handle_pointcloud_input(
points: torch.Tensor,
lengths: Union[torch.Tensor, None],
normals: Union[torch.Tensor, None],
):
"""
If points is an instance of Pointclouds, retrieve the padded points tensor
along with the number of points per batch and the padded normals.
Otherwise, return the input points (and normals) with the number of points per cloud
set to the size of the second dimension of `points`.
"""
if points.ndim != 3:
raise ValueError("Expected points to be of shape (N, P, D)")
X = points
if lengths is not None and (lengths.ndim != 1 or lengths.shape[0] != X.shape[0]):
raise ValueError("Expected lengths to be of shape (N,)")
if lengths is None:
lengths = torch.full(
(X.shape[0],), X.shape[1], dtype=torch.int64, device=points.device
)
if normals is not None and normals.ndim != 3:
raise ValueError("Expected normals to be of shape (N, P, 3")
return X, lengths, normals
class ChamferDistance(torch.nn.Module):
def forward(
self,
x,
y,
x_lengths=None,
y_lengths=None,
x_normals=None,
y_normals=None,
weights=None,
batch_reduction: Union[str, None] = "mean",
point_reduction: str = "mean",
):
"""
Chamfer distance between two pointclouds x and y.
Args:
x: FloatTensor of shape (N, P1, D) or a Pointclouds object representing
a batch of point clouds with at most P1 points in each batch element,
batch size N and feature dimension D.
y: FloatTensor of shape (N, P2, D) or a Pointclouds object representing
a batch of point clouds with at most P2 points in each batch element,
batch size N and feature dimension D.
x_lengths: Optional LongTensor of shape (N,) giving the number of points in each
cloud in x.
y_lengths: Optional LongTensor of shape (N,) giving the number of points in each
cloud in y.
x_normals: Optional FloatTensor of shape (N, P1, D).
y_normals: Optional FloatTensor of shape (N, P2, D).
weights: Optional FloatTensor of shape (N,) giving weights for
batch elements for reduction operation.
batch_reduction: Reduction operation to apply for the loss across the
batch, can be one of ["mean", "sum"] or None.
point_reduction: Reduction operation to apply for the loss across the
points, can be one of ["mean", "sum"].
        Returns:
            4-element tuple containing
            - **cham_x**: Tensor of shape (N, P1) with the squared distance from
                each point in x to its nearest neighbor in y.
            - **cham_y**: Tensor of shape (N, P2) with the squared distance from
                each point in y to its nearest neighbor in x.
            - **idx_x**: LongTensor of shape (N, P1) with the index in y of each
                point's nearest neighbor.
            - **idx_y**: LongTensor of shape (N, P2) with the index in x of each
                point's nearest neighbor.
"""
_validate_chamfer_reduction_inputs(batch_reduction, point_reduction)
x, x_lengths, x_normals = _handle_pointcloud_input(x, x_lengths, x_normals)
y, y_lengths, y_normals = _handle_pointcloud_input(y, y_lengths, y_normals)
return_normals = x_normals is not None and y_normals is not None
N, P1, D = x.shape
P2 = y.shape[1]
# Check if inputs are heterogeneous and create a lengths mask.
is_x_heterogeneous = (x_lengths != P1).any()
is_y_heterogeneous = (y_lengths != P2).any()
x_mask = (
torch.arange(P1, device=x.device)[None] >= x_lengths[:, None]
) # shape [N, P1]
y_mask = (
torch.arange(P2, device=y.device)[None] >= y_lengths[:, None]
) # shape [N, P2]
if y.shape[0] != N or y.shape[2] != D:
raise ValueError("y does not have the correct shape.")
if weights is not None:
if weights.size(0) != N:
raise ValueError("weights must be of shape (N,).")
if not (weights >= 0).all():
raise ValueError("weights cannot be negative.")
if weights.sum() == 0.0:
weights = weights.view(N, 1)
if batch_reduction in ["mean", "sum"]:
return (
(x.sum((1, 2)) * weights).sum() * 0.0,
(x.sum((1, 2)) * weights).sum() * 0.0,
)
return (
(x.sum((1, 2)) * weights) * 0.0,
(x.sum((1, 2)) * weights) * 0.0,
)
x_nn = knn_points(x, y, lengths1=x_lengths, lengths2=y_lengths, K=1)
y_nn = knn_points(y, x, lengths1=y_lengths, lengths2=x_lengths, K=1)
cham_x = x_nn.dists[..., 0] # (N, P1)
cham_y = y_nn.dists[..., 0] # (N, P2)
if is_x_heterogeneous:
cham_x[x_mask] = 0.0
if is_y_heterogeneous:
cham_y[y_mask] = 0.0
if weights is not None:
cham_x *= weights.view(N, 1)
cham_y *= weights.view(N, 1)
return cham_x, cham_y, x_nn.idx[..., -1], y_nn.idx[..., -1]
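# ---------------------------------------------------------------------------
# Editor's illustrative usage sketch (not part of the upstream UniDepth file).
# It assumes the custom KNN op has been compiled
# (`cd unidepth/ops/knn && bash compile.sh`) and that a CUDA GPU is available.
if __name__ == "__main__":
    chamfer = ChamferDistance()
    x = torch.rand(2, 1024, 3, device="cuda")  # two clouds, 1024 points each
    y = torch.rand(2, 2048, 3, device="cuda")  # two clouds, 2048 points each
    cham_x, cham_y, idx_x, idx_y = chamfer(x, y)  # per-point squared NN distances
    loss = cham_x.mean() + cham_y.mean()  # symmetric Chamfer loss
    print(loss.item())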
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/utils/constants.py
================================================
"""
Author: Luigi Piccinelli
Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/)
"""
import math
import torch
OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)
IMAGENET_DATASET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DATASET_STD = (0.229, 0.224, 0.225)
DEPTH_BINS = torch.cat(
(
torch.logspace(math.log10(0.1), math.log10(180.0), steps=512),
torch.tensor([260.0]),
),
dim=0,
)
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/utils/coordinate.py
================================================
import torch
def coords_grid(b, h, w, homogeneous=False, device=None, noisy=False):
pixel_coords_x = torch.linspace(0.5, w - 0.5, w, device=device)
pixel_coords_y = torch.linspace(0.5, h - 0.5, h, device=device)
if noisy: # \pm 0.5px noise
pixel_coords_x += torch.rand_like(pixel_coords_x) - 0.5
pixel_coords_y += torch.rand_like(pixel_coords_y) - 0.5
stacks = [pixel_coords_x.repeat(h, 1), pixel_coords_y.repeat(w, 1).t()]
if homogeneous:
ones = torch.ones_like(stacks[0]) # [H, W]
stacks.append(ones)
grid = torch.stack(stacks, dim=0).float() # [2, H, W] or [3, H, W]
grid = grid[None].repeat(b, 1, 1, 1) # [B, 2, H, W] or [B, 3, H, W]
if device is not None:
grid = grid.to(device)
return grid
def normalize_coords(coords, h, w):
c = torch.tensor([(w - 1) / 2.0, (h - 1) / 2.0], device=coords.device).view(
1, 2, 1, 1
)
return (coords - c) / c
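# ---------------------------------------------------------------------------
# Editor's illustrative usage sketch (not part of the upstream UniDepth file):
# coords_grid builds a grid of pixel-centre coordinates; normalize_coords maps
# pixel index 0 to -1 and index w-1 (resp. h-1) to +1 around the image centre.
if __name__ == "__main__":
    grid = coords_grid(b=1, h=4, w=6)        # (1, 2, 4, 6), values at pixel centres
    norm = normalize_coords(grid, h=4, w=6)  # centred, roughly in [-1, 1]
    print(grid[0, :, 0, 0], norm[0, :, 0, 0])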
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/utils/distributed.py
================================================
import os
import pickle
import platform
import subprocess
import warnings
import cv2
import torch
import torch.utils.data.distributed
from torch import distributed as dist
from torch import multiprocessing as mp
_LOCAL_PROCESS_GROUP = None
def is_dist_avail_and_initialized():
if not dist.is_available():
return False
if not dist.is_initialized():
return False
return True
def get_rank():
if not is_dist_avail_and_initialized():
return 0
return dist.get_rank()
def get_local_rank() -> int:
"""
Returns:
The rank of the current process within the local (per-machine) process group.
"""
if not is_dist_avail_and_initialized():
return 0
assert _LOCAL_PROCESS_GROUP is not None
return dist.get_rank(group=_LOCAL_PROCESS_GROUP)
def get_local_size() -> int:
"""
Returns:
The size of the per-machine process group,
i.e. the number of processes per machine.
"""
if not is_dist_avail_and_initialized():
return 1
assert _LOCAL_PROCESS_GROUP is not None
return dist.get_world_size(group=_LOCAL_PROCESS_GROUP)
def get_world_size():
if not is_dist_avail_and_initialized():
return 1
return dist.get_world_size()
def barrier():
if not is_dist_avail_and_initialized():
return
dist.barrier()
def is_main_process():
return get_rank() == 0
def is_rank_zero(args):
return args.rank == 0
def get_dist_info():
if dist.is_available() and dist.is_initialized():
rank = dist.get_rank()
world_size = dist.get_world_size()
else:
rank = 0
world_size = 1
return rank, world_size
def setup_multi_processes(cfg):
"""Setup multi-processing environment variables."""
# set multi-process start method as `fork` to speed up the training
if platform.system() != "Windows":
mp_start_method = cfg.get("mp_start_method", "fork")
current_method = mp.get_start_method(allow_none=True)
if current_method is not None and current_method != mp_start_method:
            warnings.warn(
                f"Multi-processing start method `{mp_start_method}` is "
                f"different from the previous setting `{current_method}`. "
                f"It will be forcibly set to `{mp_start_method}`. You can "
                "change this behavior by changing `mp_start_method` in your config."
            )
mp.set_start_method(mp_start_method, force=True)
# disable opencv multithreading to avoid system being overloaded
# opencv_num_threads = cfg.get('opencv_num_threads', 0)
# cv2.setNumThreads(opencv_num_threads)
# setup OMP threads
# This code is referred from https://github.com/pytorch/pytorch/blob/master/torch/distributed/run.py # noqa
# workers_per_gpu = cfg.get('workers_per_gpu', 4)
# if 'OMP_NUM_THREADS' not in os.environ and workers_per_gpu > 1:
# omp_num_threads = 1
# warnings.warn(
# f'Setting OMP_NUM_THREADS environment variable for each process '
# f'to be {omp_num_threads} in default, to avoid your system being '
# f'overloaded, please further tune the variable for optimal '
# f'performance in your application as needed.')
# os.environ['OMP_NUM_THREADS'] = str(omp_num_threads)
# setup MKL threads
# if 'MKL_NUM_THREADS' not in os.environ and workers_per_gpu > 1:
# mkl_num_threads = os.environ.get('OMP_NUM_THREADS', 1)
# warnings.warn(
# f'Setting MKL_NUM_THREADS environment variable for each process '
# f'to be {mkl_num_threads} in default, to avoid your system being '
# f'overloaded, please further tune the variable for optimal '
# f'performance in your application as needed.')
# os.environ['MKL_NUM_THREADS'] = str(mkl_num_threads)
def setup_slurm(backend: str, port: str) -> None:
proc_id = int(os.environ["SLURM_PROCID"])
ntasks = int(os.environ["SLURM_NTASKS"])
node_list = os.environ["SLURM_NODELIST"]
num_gpus = torch.cuda.device_count()
torch.cuda.set_device(proc_id % num_gpus)
addr = subprocess.getoutput(f"scontrol show hostname {node_list} | head -n1")
os.environ["MASTER_PORT"] = str(port)
os.environ["MASTER_ADDR"] = addr
os.environ["WORLD_SIZE"] = str(ntasks)
os.environ["LOCAL_RANK"] = str(proc_id % num_gpus)
os.environ["RANK"] = str(proc_id)
print(
proc_id,
ntasks,
num_gpus,
proc_id % num_gpus,
node_list,
addr,
os.environ["MASTER_PORT"],
os.system("nvidia-smi -L"),
)
dist.init_process_group(backend, rank=proc_id, world_size=ntasks)
def sync_tensor_across_gpus(t, dim=0, cat=True):
if t is None or not (dist.is_available() and dist.is_initialized()):
return t
t = torch.atleast_1d(t)
group = dist.group.WORLD
group_size = torch.distributed.get_world_size(group)
local_size = torch.tensor(t.size(dim), device=t.device)
all_sizes = [torch.zeros_like(local_size) for _ in range(group_size)]
dist.all_gather(all_sizes, local_size)
max_size = max(all_sizes)
size_diff = max_size.item() - local_size.item()
if size_diff:
padding = torch.zeros(size_diff, device=t.device, dtype=t.dtype)
t = torch.cat((t, padding))
gather_t_tensor = [torch.zeros_like(t) for _ in range(group_size)]
dist.all_gather(gather_t_tensor, t)
all_ts = []
for t, size in zip(gather_t_tensor, all_sizes):
all_ts.append(t[:size])
if cat:
return torch.cat(all_ts, dim=0)
return all_ts
def sync_string_across_gpus(keys: list[str], device, dim=0):
keys_serialized = pickle.dumps(keys, protocol=pickle.HIGHEST_PROTOCOL)
keys_serialized_tensor = (
torch.frombuffer(keys_serialized, dtype=torch.uint8).clone().to(device)
)
keys_serialized_tensor = sync_tensor_across_gpus(
keys_serialized_tensor, dim=0, cat=False
)
keys = [
key
for keys in keys_serialized_tensor
for key in pickle.loads(bytes(keys.cpu().tolist()))
]
return keys
def create_local_process_group() -> None:
num_workers_per_machine = torch.cuda.device_count()
global _LOCAL_PROCESS_GROUP
assert _LOCAL_PROCESS_GROUP is None
assert get_world_size() % num_workers_per_machine == 0
num_machines = get_world_size() // num_workers_per_machine
machine_rank = get_rank() // num_workers_per_machine
for i in range(num_machines):
ranks_on_i = list(
range(i * num_workers_per_machine, (i + 1) * num_workers_per_machine)
)
pg = dist.new_group(ranks_on_i)
if i == machine_rank:
_LOCAL_PROCESS_GROUP = pg
def _get_global_gloo_group():
if dist.get_backend() == "nccl":
return dist.new_group(backend="gloo")
else:
return dist.group.WORLD
def all_gather(data, group=None):
if get_world_size() == 1:
return [data]
if group is None:
group = (
_get_global_gloo_group()
) # use CPU group by default, to reduce GPU RAM usage.
world_size = dist.get_world_size(group)
if world_size == 1:
return [data]
output = [None for _ in range(world_size)]
dist.all_gather_object(output, data, group=group)
return output
def local_broadcast_process_authkey():
if get_local_size() == 1:
return
local_rank = get_local_rank()
authkey = bytes(mp.current_process().authkey)
all_keys = all_gather(authkey)
local_leader_key = all_keys[get_rank() - local_rank]
if authkey != local_leader_key:
# print("Process authkey is different from the key of local leader! workers are launched independently ??")
# print("Overwriting local authkey ...")
mp.current_process().authkey = local_leader_key
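# ---------------------------------------------------------------------------
# Editor's illustrative usage sketch (not part of the upstream UniDepth file):
# the rank/world-size helpers fall back to single-process defaults when
# torch.distributed has not been initialised, so the same code path works for
# single-GPU debugging and multi-node runs.
if __name__ == "__main__":
    rank, world_size = get_dist_info()  # (0, 1) when no process group exists
    print(f"rank={rank} world_size={world_size} main_process={is_main_process()}")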
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/utils/ema_torch.py
================================================
"""
Author: Luigi Piccinelli
Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/)
"""
from __future__ import division, unicode_literals
import contextlib
import copy
import weakref
from math import tanh
from typing import Iterable, Optional
import torch
class DummyExponentialMovingAverage:
def __init__(self, *args, **kwargs):
pass
def _get_parameters(self, *args, **kwargs):
pass
def get_current_decay(self, *args, **kwargs):
pass
def update(self, *args, **kwargs):
pass
def copy_to(self, *args, **kwargs):
pass
def store(self, *args, **kwargs):
return
def restore(self, *args, **kwargs):
return
@contextlib.contextmanager
def average_parameters(self, *args, **kwargs):
try:
yield
finally:
pass
def to(self, *args, **kwargs):
pass
def state_dict(self, *args, **kwargs):
pass
def load_state_dict(self, *args, **kwargs):
pass
class ExponentialMovingAverage:
"""
Maintains (exponential) moving average of a set of parameters.
Args:
parameters: Iterable of `torch.nn.Parameter` (typically from
`model.parameters()`).
Note that EMA is computed on *all* provided parameters,
regardless of whether or not they have `requires_grad = True`;
            this allows a single EMA object to be consistently used even
            if the set of trainable parameters changes from step to step.
            If you do not want some parameters included in the EMA, do not
            pass them to the object in the first place. For example:
ExponentialMovingAverage(
parameters=[p for p in model.parameters() if p.requires_grad],
decay=0.9
)
will ignore parameters that do not require grad.
decay: The exponential decay.
use_num_updates: Whether to use number of updates when computing
averages.
"""
def __init__(
self,
parameters: Iterable[torch.nn.Parameter],
decay: float,
use_num_updates: bool = True,
update_after_step: int = 10000,
tau: int = 20000,
switch: bool = False,
):
if decay < 0.0 or decay > 1.0:
raise ValueError("Decay must be between 0 and 1")
self.decay = decay
        self.switch = switch  # if True, keep the EMA weights in the model after average_parameters()
self.num_updates = 0 if use_num_updates else None
parameters = list(parameters)
self.shadow_params = [p.clone().detach() for p in parameters]
self.collected_params = None
# By maintaining only a weakref to each parameter,
# we maintain the old GC behaviour of ExponentialMovingAverage:
# if the model goes out of scope but the ExponentialMovingAverage
# is kept, no references to the model or its parameters will be
# maintained, and the model will be cleaned up.
self._params_refs = [weakref.ref(p) for p in parameters]
self.update_after_step = update_after_step
self.tau = tau
def _get_parameters(
self, parameters: Optional[Iterable[torch.nn.Parameter]]
) -> Iterable[torch.nn.Parameter]:
if parameters is None:
parameters = [p() for p in self._params_refs]
if any(p is None for p in parameters):
raise ValueError(
"(One of) the parameters with which this ExponentialMovingAverage was initialized no longer exists (was garbage collected);"
" please either provide `parameters` explicitly or keep the model to which they belong from being garbage collected."
)
return parameters
else:
parameters = list(parameters)
if len(parameters) != len(self.shadow_params):
raise ValueError(
"Number of parameters passed as argument is different "
"from number of shadow parameters maintained by this "
"ExponentialMovingAverage"
)
return parameters
def get_current_decay(self):
epoch = max(self.num_updates - self.update_after_step - 1, 0.0)
if epoch <= 0:
return 0.0
value = tanh(epoch / self.tau) * self.decay
return value
def update(self, parameters: Optional[Iterable[torch.nn.Parameter]] = None) -> None:
"""
Update currently maintained parameters.
Call this every time the parameters are updated, such as the result of
the `optimizer.step()` call.
Args:
parameters: Iterable of `torch.nn.Parameter`; usually the same set of
parameters used to initialize this object. If `None`, the
parameters with which this `ExponentialMovingAverage` was
initialized will be used.
"""
parameters = self._get_parameters(parameters)
decay = self.get_current_decay()
if self.num_updates is not None:
self.num_updates += 1
one_minus_decay = 1.0 - decay
with torch.no_grad():
for s_param, param in zip(self.shadow_params, parameters):
tmp = s_param - param
# tmp will be a new tensor so we can do in-place
tmp.mul_(one_minus_decay)
s_param.sub_(tmp)
def copy_to(
self, parameters: Optional[Iterable[torch.nn.Parameter]] = None
) -> None:
"""
Copy current averaged parameters into given collection of parameters.
Args:
parameters: Iterable of `torch.nn.Parameter`; the parameters to be
updated with the stored moving averages. If `None`, the
parameters with which this `ExponentialMovingAverage` was
initialized will be used.
"""
parameters = self._get_parameters(parameters)
for s_param, param in zip(self.shadow_params, parameters):
param.data.copy_(s_param.data)
def store(self, parameters: Optional[Iterable[torch.nn.Parameter]] = None) -> None:
"""
Save the current parameters for restoring later.
Args:
parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                temporarily stored. If `None`, the parameters with which this
`ExponentialMovingAverage` was initialized will be used.
"""
parameters = self._get_parameters(parameters)
self.collected_params = [param.detach().clone() for param in parameters]
def restore(
self, parameters: Optional[Iterable[torch.nn.Parameter]] = None
) -> None:
"""
Restore the parameters stored with the `store` method.
Useful to validate the model with EMA parameters without affecting the
original optimization process. Store the parameters before the
`copy_to` method. After validation (or model saving), use this to
restore the former parameters.
Args:
parameters: Iterable of `torch.nn.Parameter`; the parameters to be
updated with the stored parameters. If `None`, the
parameters with which this `ExponentialMovingAverage` was
initialized will be used.
"""
if self.collected_params is None:
raise RuntimeError(
"This ExponentialMovingAverage has no `store()`ed weights "
"to `restore()`"
)
parameters = self._get_parameters(parameters)
for c_param, param in zip(self.collected_params, parameters):
param.data.copy_(c_param.data)
@contextlib.contextmanager
def average_parameters(
self, parameters: Optional[Iterable[torch.nn.Parameter]] = None
):
r"""
Context manager for validation/inference with averaged parameters.
Equivalent to:
ema.store()
ema.copy_to()
try:
...
finally:
ema.restore()
Args:
parameters: Iterable of `torch.nn.Parameter`; the parameters to be
updated with the stored parameters. If `None`, the
parameters with which this `ExponentialMovingAverage` was
initialized will be used.
"""
parameters = self._get_parameters(parameters)
self.store(parameters)
self.copy_to(parameters)
try:
yield
finally:
if not self.switch:
self.restore(parameters)
def to(self, device=None, dtype=None) -> None:
r"""Move internal buffers of the ExponentialMovingAverage to `device`.
Args:
device: like `device` argument to `torch.Tensor.to`
"""
# .to() on the tensors handles None correctly
self.shadow_params = [
(
p.to(device=device, dtype=dtype)
if p.is_floating_point()
else p.to(device=device)
)
for p in self.shadow_params
]
if self.collected_params is not None:
self.collected_params = [
(
p.to(device=device, dtype=dtype)
if p.is_floating_point()
else p.to(device=device)
)
for p in self.collected_params
]
return
def state_dict(self) -> dict:
r"""Returns the state of the ExponentialMovingAverage as a dict."""
# Following PyTorch conventions, references to tensors are returned:
# "returns a reference to the state and not its copy!" -
# https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict
return {
"decay": self.decay,
"num_updates": self.num_updates,
"shadow_params": self.shadow_params,
"collected_params": self.collected_params,
}
def load_state_dict(self, state_dict: dict) -> None:
r"""Loads the ExponentialMovingAverage state.
Args:
state_dict (dict): EMA state. Should be an object returned
from a call to :meth:`state_dict`.
"""
# deepcopy, to be consistent with module API
state_dict = copy.deepcopy(state_dict)
self.decay = state_dict["decay"]
if self.decay < 0.0 or self.decay > 1.0:
raise ValueError("Decay must be between 0 and 1")
self.num_updates = state_dict["num_updates"]
assert self.num_updates is None or isinstance(
self.num_updates, int
), "Invalid num_updates"
self.shadow_params = state_dict["shadow_params"]
assert isinstance(self.shadow_params, list), "shadow_params must be a list"
assert all(
isinstance(p, torch.Tensor) for p in self.shadow_params
), "shadow_params must all be Tensors"
self.collected_params = state_dict["collected_params"]
if self.collected_params is not None:
assert isinstance(
self.collected_params, list
), "collected_params must be a list"
assert all(
isinstance(p, torch.Tensor) for p in self.collected_params
), "collected_params must all be Tensors"
assert len(self.collected_params) == len(
self.shadow_params
), "collected_params and shadow_params had different lengths"
if len(self.shadow_params) == len(self._params_refs):
            # Consistent with torch.optim.Optimizer, cast things to a consistent
# device and dtype with the parameters
params = [p() for p in self._params_refs]
# If parameters have been garbage collected, just load the state
# we were given without change.
if not any(p is None for p in params):
# ^ parameter references are still good
for i, p in enumerate(params):
self.shadow_params[i] = self.shadow_params[i].to(
device=p.device, dtype=p.dtype
)
if self.collected_params is not None:
self.collected_params[i] = self.collected_params[i].to(
device=p.device, dtype=p.dtype
)
else:
raise ValueError(
"Tried to `load_state_dict()` with the wrong number of "
"parameters in the saved state."
)
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/utils/evaluation_depth.py
================================================
from collections import defaultdict
from functools import partial
import torch
import torch.nn.functional as F
from unidepth.utils.chamfer_distance import ChamferDistance
chamfer_cls = ChamferDistance()
def chamfer_dist(tensor1, tensor2):
x_lengths = torch.tensor((tensor1.shape[1],), device=tensor1.device)
y_lengths = torch.tensor((tensor2.shape[1],), device=tensor2.device)
dist1, dist2, idx1, idx2 = chamfer_cls(
tensor1, tensor2, x_lengths=x_lengths, y_lengths=y_lengths
)
return (torch.sqrt(dist1) + torch.sqrt(dist2)) / 2
def auc(tensor1, tensor2, thresholds):
x_lengths = torch.tensor((tensor1.shape[1],), device=tensor1.device)
y_lengths = torch.tensor((tensor2.shape[1],), device=tensor2.device)
dist1, dist2, idx1, idx2 = chamfer_cls(
tensor1, tensor2, x_lengths=x_lengths, y_lengths=y_lengths
)
# compute precision recall
precisions = [(dist1 < threshold).sum() / dist1.numel() for threshold in thresholds]
recalls = [(dist2 < threshold).sum() / dist2.numel() for threshold in thresholds]
auc_value = torch.trapz(
torch.tensor(precisions, device=tensor1.device),
torch.tensor(recalls, device=tensor1.device),
)
return auc_value
def delta(tensor1, tensor2, exponent):
inlier = torch.maximum((tensor1 / tensor2), (tensor2 / tensor1))
return (inlier < 1.25**exponent).to(torch.float32).mean()
def tau(tensor1, tensor2, perc):
inlier = torch.maximum((tensor1 / tensor2), (tensor2 / tensor1))
return (inlier < (1.0 + perc)).to(torch.float32).mean()
def ssi(tensor1, tensor2):
stability_mat = 1e-9 * torch.eye(2, device=tensor1.device)
tensor2_one = torch.stack(
[tensor2.detach(), torch.ones_like(tensor2).detach()], dim=1
)
scale_shift = torch.inverse(tensor2_one.T @ tensor2_one + stability_mat) @ (
tensor2_one.T @ tensor1.unsqueeze(1)
)
scale, shift = scale_shift.squeeze().chunk(2, dim=0)
return tensor2 * scale + shift
def si(tensor1, tensor2):
return tensor2 * torch.median(tensor1) / torch.median(tensor2)
def arel(tensor1, tensor2):
tensor2 = tensor2 * torch.median(tensor1) / torch.median(tensor2)
return (torch.abs(tensor1 - tensor2) / tensor1).mean()
def d_auc(tensor1, tensor2):
exponents = torch.linspace(0.01, 5.0, steps=100, device=tensor1.device)
deltas = [delta(tensor1, tensor2, exponent) for exponent in exponents]
return torch.trapz(torch.tensor(deltas, device=tensor1.device), exponents) / 5.0
def f1_score(tensor1, tensor2, thresholds):
x_lengths = torch.tensor((tensor1.shape[1],), device=tensor1.device)
y_lengths = torch.tensor((tensor2.shape[1],), device=tensor2.device)
dist1, dist2, idx1, idx2 = chamfer_cls(
tensor1, tensor2, x_lengths=x_lengths, y_lengths=y_lengths
)
# compute precision recall
precisions = [(dist1 < threshold).sum() / dist1.numel() for threshold in thresholds]
recalls = [(dist2 < threshold).sum() / dist2.numel() for threshold in thresholds]
precisions = torch.tensor(precisions, device=tensor1.device)
recalls = torch.tensor(recalls, device=tensor1.device)
f1_thresholds = 2 * precisions * recalls / (precisions + recalls)
f1_thresholds = torch.where(
torch.isnan(f1_thresholds), torch.zeros_like(f1_thresholds), f1_thresholds
)
f1_value = torch.trapz(f1_thresholds) / len(thresholds)
return f1_value
DICT_METRICS = {
"d1": partial(delta, exponent=1.0),
"d2": partial(delta, exponent=2.0),
"d3": partial(delta, exponent=3.0),
"rmse": lambda gt, pred: torch.sqrt(((gt - pred) ** 2).mean()),
"rmselog": lambda gt, pred: torch.sqrt(
((torch.log(gt) - torch.log(pred)) ** 2).mean()
),
"arel": lambda gt, pred: (torch.abs(gt - pred) / gt).mean(),
"sqrel": lambda gt, pred: (((gt - pred) ** 2) / gt).mean(),
"log10": lambda gt, pred: torch.abs(torch.log10(pred) - torch.log10(gt)).mean(),
"silog": lambda gt, pred: 100 * torch.std(torch.log(pred) - torch.log(gt)).mean(),
"medianlog": lambda gt, pred: 100
* (torch.log(pred) - torch.log(gt)).median().abs(),
"d_auc": d_auc,
"tau": partial(tau, perc=0.03),
}
DICT_METRICS_3D = {
"MSE_3d": lambda gt, pred, thresholds: torch.norm(gt - pred, dim=0, p=2),
"chamfer": lambda gt, pred, thresholds: chamfer_dist(
gt.unsqueeze(0).permute(0, 2, 1), pred.unsqueeze(0).permute(0, 2, 1)
),
"F1": lambda gt, pred, thresholds: f1_score(
gt.unsqueeze(0).permute(0, 2, 1),
pred.unsqueeze(0).permute(0, 2, 1),
thresholds=thresholds,
),
}
DICT_METRICS_D = {
"a1": lambda gt, pred: (torch.maximum((gt / pred), (pred / gt)) > 1.25**1.0).to(
torch.float32
),
"abs_rel": lambda gt, pred: (torch.abs(gt - pred) / gt),
}
def eval_depth(
gts: torch.Tensor, preds: torch.Tensor, masks: torch.Tensor, max_depth=None
):
summary_metrics = defaultdict(list)
preds = F.interpolate(preds, gts.shape[-2:], mode="bilinear")
for i, (gt, pred, mask) in enumerate(zip(gts, preds, masks)):
if max_depth is not None:
mask = mask & (gt <= max_depth)
for name, fn in DICT_METRICS.items():
if name in ["tau", "d1", "arel"]:
for rescale_fn in ["ssi", "si"]:
summary_metrics[f"{name}_{rescale_fn}"].append(
fn(gt[mask], eval(rescale_fn)(gt[mask], pred[mask]))
)
summary_metrics[name].append(fn(gt[mask], pred[mask]).mean())
return {name: torch.stack(vals, dim=0) for name, vals in summary_metrics.items()}
def eval_3d(
gts: torch.Tensor, preds: torch.Tensor, masks: torch.Tensor, thresholds=None
):
summary_metrics = defaultdict(list)
ratio = min(
1.0, (240 * 320 / masks.sum()) ** 0.5
) # rescale to avoid OOM during eval, FIXME
h_max, w_max = int(gts.shape[-2] * ratio), int(gts.shape[-1] * ratio)
gts = F.interpolate(gts, size=(h_max, w_max), mode="nearest-exact")
preds = F.interpolate(preds, size=(h_max, w_max), mode="nearest-exact")
masks = F.interpolate(
masks.float(), size=(h_max, w_max), mode="nearest-exact"
).bool()
for i, (gt, pred, mask) in enumerate(zip(gts, preds, masks)):
if not torch.any(mask):
continue
for name, fn in DICT_METRICS_3D.items():
summary_metrics[name].append(
fn(gt[:, mask.squeeze()], pred[:, mask.squeeze()], thresholds).mean()
)
return {name: torch.stack(vals, dim=0) for name, vals in summary_metrics.items()}
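# ---------------------------------------------------------------------------
# Editor's illustrative usage sketch (not part of the upstream UniDepth file),
# assuming the `unidepth` package is importable. Depth values must be strictly
# positive because several of the metrics take logarithms.
if __name__ == "__main__":
    gts = torch.rand(1, 1, 32, 32) * 9.0 + 1.0      # depths in [1, 10) metres
    preds = gts * 1.05 + 0.1                        # slightly biased prediction
    masks = torch.ones_like(gts, dtype=torch.bool)  # all pixels valid
    metrics = eval_depth(gts, preds, masks, max_depth=10.0)
    print({k: v.mean().item() for k, v in metrics.items()})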
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/utils/geometric.py
================================================
"""
Author: Luigi Piccinelli
Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/)
"""
from typing import Tuple
import torch
from torch.nn import functional as F
@torch.jit.script
def generate_rays(
camera_intrinsics: torch.Tensor, image_shape: Tuple[int, int], noisy: bool = False
):
batch_size, device, dtype = (
camera_intrinsics.shape[0],
camera_intrinsics.device,
camera_intrinsics.dtype,
)
height, width = image_shape
# Generate grid of pixel coordinates
pixel_coords_x = torch.linspace(0, width - 1, width, device=device, dtype=dtype)
pixel_coords_y = torch.linspace(0, height - 1, height, device=device, dtype=dtype)
if noisy:
pixel_coords_x += torch.rand_like(pixel_coords_x) - 0.5
pixel_coords_y += torch.rand_like(pixel_coords_y) - 0.5
pixel_coords = torch.stack(
[pixel_coords_x.repeat(height, 1), pixel_coords_y.repeat(width, 1).t()], dim=2
) # (H, W, 2)
pixel_coords = pixel_coords + 0.5
# Calculate ray directions
intrinsics_inv = torch.eye(3, device=device).unsqueeze(0).repeat(batch_size, 1, 1)
intrinsics_inv[:, 0, 0] = 1.0 / camera_intrinsics[:, 0, 0]
intrinsics_inv[:, 1, 1] = 1.0 / camera_intrinsics[:, 1, 1]
intrinsics_inv[:, 0, 2] = -camera_intrinsics[:, 0, 2] / camera_intrinsics[:, 0, 0]
intrinsics_inv[:, 1, 2] = -camera_intrinsics[:, 1, 2] / camera_intrinsics[:, 1, 1]
homogeneous_coords = torch.cat(
[pixel_coords, torch.ones_like(pixel_coords[:, :, :1])], dim=2
) # (H, W, 3)
ray_directions = torch.matmul(
intrinsics_inv, homogeneous_coords.permute(2, 0, 1).flatten(1)
) # (3, H*W)
ray_directions = F.normalize(ray_directions, dim=1) # (B, 3, H*W)
ray_directions = ray_directions.permute(0, 2, 1) # (B, H*W, 3)
theta = torch.atan2(ray_directions[..., 0], ray_directions[..., -1])
phi = torch.acos(ray_directions[..., 1])
# pitch = torch.asin(ray_directions[..., 1])
# roll = torch.atan2(ray_directions[..., 0], - ray_directions[..., 1])
angles = torch.stack([theta, phi], dim=-1)
return ray_directions, angles
@torch.jit.script
def spherical_zbuffer_to_euclidean(spherical_tensor: torch.Tensor) -> torch.Tensor:
theta = spherical_tensor[..., 0] # Extract polar angle
phi = spherical_tensor[..., 1] # Extract azimuthal angle
z = spherical_tensor[..., 2] # Extract zbuffer depth
# y = r * cos(phi)
# x = r * sin(phi) * sin(theta)
# z = r * sin(phi) * cos(theta)
# =>
# r = z / sin(phi) / cos(theta)
# y = z / (sin(phi) / cos(phi)) / cos(theta)
# x = z * sin(theta) / cos(theta)
x = z * torch.tan(theta)
y = z / torch.tan(phi) / torch.cos(theta)
euclidean_tensor = torch.stack((x, y, z), dim=-1)
return euclidean_tensor
@torch.jit.script
def spherical_to_euclidean(spherical_tensor: torch.Tensor) -> torch.Tensor:
theta = spherical_tensor[..., 0] # Extract polar angle
phi = spherical_tensor[..., 1] # Extract azimuthal angle
r = spherical_tensor[..., 2] # Extract radius
# y = r * cos(phi)
# x = r * sin(phi) * sin(theta)
# z = r * sin(phi) * cos(theta)
x = r * torch.sin(phi) * torch.sin(theta)
y = r * torch.cos(phi)
z = r * torch.cos(theta) * torch.sin(phi)
euclidean_tensor = torch.stack((x, y, z), dim=-1)
return euclidean_tensor
@torch.jit.script
def euclidean_to_spherical(spherical_tensor: torch.Tensor) -> torch.Tensor:
    x = spherical_tensor[..., 0]  # Extract x coordinate
    y = spherical_tensor[..., 1]  # Extract y coordinate
    z = spherical_tensor[..., 2]  # Extract z coordinate
# y = r * cos(phi)
# x = r * sin(phi) * sin(theta)
# z = r * sin(phi) * cos(theta)
r = torch.sqrt(x**2 + y**2 + z**2)
theta = torch.atan2(x / r, z / r)
phi = torch.acos(y / r)
euclidean_tensor = torch.stack((theta, phi, r), dim=-1)
return euclidean_tensor
@torch.jit.script
def euclidean_to_spherical_zbuffer(euclidean_tensor: torch.Tensor) -> torch.Tensor:
pitch = torch.asin(euclidean_tensor[..., 1])
yaw = torch.atan2(euclidean_tensor[..., 0], euclidean_tensor[..., -1])
z = euclidean_tensor[..., 2] # Extract zbuffer depth
euclidean_tensor = torch.stack((pitch, yaw, z), dim=-1)
return euclidean_tensor
@torch.jit.script
def unproject_points(
depth: torch.Tensor, camera_intrinsics: torch.Tensor
) -> torch.Tensor:
"""
Unprojects a batch of depth maps to 3D point clouds using camera intrinsics.
Args:
depth (torch.Tensor): Batch of depth maps of shape (B, 1, H, W).
camera_intrinsics (torch.Tensor): Camera intrinsic matrix of shape (B, 3, 3).
Returns:
torch.Tensor: Batch of 3D point clouds of shape (B, 3, H, W).
"""
batch_size, _, height, width = depth.shape
device = depth.device
# Create pixel grid
y_coords, x_coords = torch.meshgrid(
torch.arange(height, device=device),
torch.arange(width, device=device),
indexing="ij",
)
pixel_coords = torch.stack((x_coords, y_coords), dim=-1) # (H, W, 2)
# Get homogeneous coords (u v 1)
pixel_coords_homogeneous = torch.cat(
(pixel_coords, torch.ones((height, width, 1), device=device)), dim=-1
)
pixel_coords_homogeneous = pixel_coords_homogeneous.permute(2, 0, 1).flatten(
1
) # (3, H*W)
# Apply K^-1 @ (u v 1): [B, 3, 3] @ [3, H*W] -> [B, 3, H*W]
unprojected_points = torch.matmul(
torch.inverse(camera_intrinsics), pixel_coords_homogeneous
) # (B, 3, H*W)
unprojected_points = unprojected_points.view(
batch_size, 3, height, width
) # (B, 3, H, W)
unprojected_points = unprojected_points * depth # (B, 3, H, W)
return unprojected_points
@torch.jit.script
def project_points(
points_3d: torch.Tensor,
intrinsic_matrix: torch.Tensor,
image_shape: Tuple[int, int],
) -> torch.Tensor:
# Project 3D points onto the image plane via intrinsics (u v w) = (x y z) @ K^T
points_2d = torch.matmul(points_3d, intrinsic_matrix.transpose(1, 2))
# Normalize projected points: (u v w) -> (u / w, v / w, 1)
points_2d = points_2d[..., :2] / points_2d[..., 2:]
points_2d = points_2d.int()
# points need to be inside the image (can it diverge onto all points out???)
valid_mask = (
(points_2d[..., 0] >= 0)
& (points_2d[..., 0] < image_shape[1])
& (points_2d[..., 1] >= 0)
& (points_2d[..., 1] < image_shape[0])
)
# Calculate the flat indices of the valid pixels
flat_points_2d = points_2d[..., 0] + points_2d[..., 1] * image_shape[1]
flat_indices = flat_points_2d.long()
# Create depth maps and counts using scatter_add, (B, H, W)
depth_maps = torch.zeros(
[points_3d.shape[0], *image_shape], device=points_3d.device
)
counts = torch.zeros([points_3d.shape[0], *image_shape], device=points_3d.device)
# Loop over batches to apply masks and accumulate depth/count values
for i in range(points_3d.shape[0]):
valid_indices = flat_indices[i, valid_mask[i]]
depth_maps[i].view(-1).scatter_add_(
0, valid_indices, points_3d[i, valid_mask[i], 2]
)
counts[i].view(-1).scatter_add_(
0, valid_indices, torch.ones_like(points_3d[i, valid_mask[i], 2])
)
# Calculate mean depth for each pixel in each batch
mean_depth_maps = depth_maps / counts.clamp(min=1.0)
return mean_depth_maps.reshape(-1, 1, *image_shape) # (B, 1, H, W)
@torch.jit.script
def downsample(data: torch.Tensor, downsample_factor: int = 2):
N, _, H, W = data.shape
data = data.view(
N,
H // downsample_factor,
downsample_factor,
W // downsample_factor,
downsample_factor,
1,
)
data = data.permute(0, 1, 3, 5, 2, 4).contiguous()
data = data.view(-1, downsample_factor * downsample_factor)
data_tmp = torch.where(data == 0.0, 1e5 * torch.ones_like(data), data)
data = torch.min(data_tmp, dim=-1).values
data = data.view(N, 1, H // downsample_factor, W // downsample_factor)
data = torch.where(data > 1000, torch.zeros_like(data), data)
return data
@torch.jit.script
def flat_interpolate(
flat_tensor: torch.Tensor,
old: Tuple[int, int],
new: Tuple[int, int],
antialias: bool = True,
mode: str = "bilinear",
) -> torch.Tensor:
if old[0] == new[0] and old[1] == new[1]:
return flat_tensor
tensor = flat_tensor.view(flat_tensor.shape[0], old[0], old[1], -1).permute(
0, 3, 1, 2
) # b c h w
tensor_interp = F.interpolate(
tensor,
size=(new[0], new[1]),
mode=mode,
align_corners=False,
antialias=antialias,
)
flat_tensor_interp = tensor_interp.view(
flat_tensor.shape[0], -1, new[0] * new[1]
).permute(
0, 2, 1
) # b (h w) c
return flat_tensor_interp.contiguous()
@torch.jit.script
def dilate(image, kernel_size: int | tuple[int, int]):
if isinstance(kernel_size, int):
kernel_size = (kernel_size, kernel_size)
device, dtype = image.device, image.dtype
padding = (kernel_size[0] // 2, kernel_size[1] // 2)
kernel = torch.ones((1, 1, *kernel_size), dtype=torch.float32, device=image.device)
dilated_image = F.conv2d(image.float(), kernel, padding=padding, stride=1)
dilated_image = torch.where(
dilated_image > 0,
torch.tensor(1.0, device=device),
torch.tensor(0.0, device=device),
)
return dilated_image.to(dtype)
@torch.jit.script
def erode(image, kernel_size: int | tuple[int, int]):
if isinstance(kernel_size, int):
kernel_size = (kernel_size, kernel_size)
device, dtype = image.device, image.dtype
padding = (kernel_size[0] // 2, kernel_size[1] // 2)
kernel = torch.ones((1, 1, *kernel_size), dtype=torch.float32, device=image.device)
eroded_image = F.conv2d(image.float(), kernel, padding=padding, stride=1)
eroded_image = torch.where(
eroded_image == (kernel_size[0] * kernel_size[1]),
torch.tensor(1.0, device=device),
torch.tensor(0.0, device=device),
)
return eroded_image.to(dtype)
@torch.jit.script
def iou(mask1: torch.Tensor, mask2: torch.Tensor) -> torch.Tensor:
device = mask1.device
# Ensure the masks are binary (0 or 1)
mask1 = mask1.to(torch.bool)
mask2 = mask2.to(torch.bool)
# Compute intersection and union
intersection = torch.sum(mask1 & mask2).to(torch.float32)
union = torch.sum(mask1 | mask2).to(torch.float32)
# Compute IoU
iou = intersection / union.clip(min=1.0)
return iou
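# ---------------------------------------------------------------------------
# Editor's illustrative usage sketch (not part of the upstream UniDepth file):
# generate per-pixel ray directions from the intrinsics and back-project a
# depth map into a camera-space point cloud. The focal length and image size
# below are arbitrary illustrative values.
if __name__ == "__main__":
    B, H, W = 1, 48, 64
    K = torch.tensor([[[60.0, 0.0, W / 2], [0.0, 60.0, H / 2], [0.0, 0.0, 1.0]]])
    rays, angles = generate_rays(K, (H, W))  # (B, H*W, 3) unit rays, (B, H*W, 2) angles
    depth = torch.rand(B, 1, H, W) * 4.0 + 1.0
    points = unproject_points(depth, K)      # (B, 3, H, W) camera-space points
    print(rays.shape, angles.shape, points.shape)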
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/utils/misc.py
================================================
"""
Author: Luigi Piccinelli
Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/)
"""
from functools import wraps
from time import time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, reduce, repeat
from scipy import interpolate
@torch.jit.script
def max_stack(tensors: list[torch.Tensor]) -> torch.Tensor:
if len(tensors) == 1:
return tensors[0]
return torch.stack(tensors, dim=-1).max(dim=-1).values
def last_stack(tensors: list[torch.Tensor]) -> torch.Tensor:
return tensors[-1]
def first_stack(tensors: list[torch.Tensor]) -> torch.Tensor:
return tensors[0]
@torch.jit.script
def softmax_stack(
tensors: list[torch.Tensor], temperature: float = 1.0
) -> torch.Tensor:
if len(tensors) == 1:
return tensors[0]
return F.softmax(torch.stack(tensors, dim=-1) / temperature, dim=-1).sum(dim=-1)
@torch.jit.script
def mean_stack(tensors: list[torch.Tensor]) -> torch.Tensor:
if len(tensors) == 1:
return tensors[0]
return torch.stack(tensors, dim=-1).mean(dim=-1)
@torch.jit.script
def sum_stack(tensors: list[torch.Tensor]) -> torch.Tensor:
if len(tensors) == 1:
return tensors[0]
return torch.stack(tensors, dim=-1).sum(dim=-1)
def convert_module_to_f16(l):
"""
Convert primitive modules to float16.
"""
if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
l.weight.data = l.weight.data.half()
if l.bias is not None:
l.bias.data = l.bias.data.half()
def convert_module_to_f32(l):
"""
Convert primitive modules to float32, undoing convert_module_to_f16().
"""
if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
l.weight.data = l.weight.data.float()
if l.bias is not None:
l.bias.data = l.bias.data.float()
def format_seconds(seconds):
minutes, seconds = divmod(seconds, 60)
hours, minutes = divmod(minutes, 60)
return f"{hours:d}:{minutes:02d}:{seconds:02d}"
def get_params(module, lr, wd):
skip_list = {}
skip_keywords = {}
if hasattr(module, "no_weight_decay"):
skip_list = module.no_weight_decay()
if hasattr(module, "no_weight_decay_keywords"):
skip_keywords = module.no_weight_decay_keywords()
has_decay = []
no_decay = []
for name, param in module.named_parameters():
if not param.requires_grad:
continue # frozen weights
if (
(name in skip_list)
or any((kw in name for kw in skip_keywords))
or len(param.shape) == 1
or name.endswith(".gamma")
or name.endswith(".beta")
or name.endswith(".bias")
):
# if (name in skip_list) or any((kw in name for kw in skip_keywords)) or len(param.shape) == 1:
no_decay.append(param)
else:
has_decay.append(param)
group1 = {
"params": has_decay,
"weight_decay": wd,
"lr": lr,
"weight_decay_init": wd,
"weight_decay_base": wd,
# "lr_init": lr,
"lr_base": lr,
}
group2 = {
"params": no_decay,
"weight_decay": 0.0,
"lr": lr,
"weight_decay_init": 0.0,
"weight_decay_base": 0.0,
"weight_decay_final": 0.0,
# "lr_init": lr,
"lr_base": lr,
}
return [group1, group2], [lr, lr]
def get_num_layer_for_swin(var_name, num_max_layer, layers_per_stage):
if var_name in ("cls_token", "mask_token", "pos_embed", "absolute_pos_embed"):
return 0
elif var_name.startswith("patch_embed"):
return 0
elif var_name.startswith("layers"):
if var_name.split(".")[2] == "blocks":
stage_id = int(var_name.split(".")[1])
layer_id = int(var_name.split(".")[3]) + sum(layers_per_stage[:stage_id])
return layer_id + 1
elif var_name.split(".")[2] == "downsample":
stage_id = int(var_name.split(".")[1])
layer_id = sum(layers_per_stage[: stage_id + 1])
return layer_id
else:
return num_max_layer - 1
def get_params_layerdecayswin(module, lr, wd, ld):
skip_list = {}
skip_keywords = {}
if hasattr(module, "no_weight_decay"):
skip_list = module.no_weight_decay()
if hasattr(module, "no_weight_decay_keywords"):
skip_keywords = module.no_weight_decay_keywords()
layers_per_stage = module.depths
num_layers = sum(layers_per_stage) + 1
lrs = []
params = []
for name, param in module.named_parameters():
if not param.requires_grad:
print(f"{name} frozen")
continue # frozen weights
layer_id = get_num_layer_for_swin(name, num_layers, layers_per_stage)
lr_cur = lr * ld ** (num_layers - layer_id - 1)
# if (name in skip_list) or any((kw in name for kw in skip_keywords)) or len(param.shape) == 1 or name.endswith(".bias"):
if (name in skip_list) or any((kw in name for kw in skip_keywords)):
wd_cur = 0.0
else:
wd_cur = wd
params.append({"params": param, "weight_decay": wd_cur, "lr": lr_cur})
lrs.append(lr_cur)
return params, lrs
def log(t, eps: float = 1e-5):
return torch.log(t.clamp(min=eps))
def l2norm(t):
return F.normalize(t, dim=-1)
def exists(val):
return val is not None
def identity(t, *args, **kwargs):
return t
def divisible_by(numer, denom):
return (numer % denom) == 0
def first(arr, d=None):
if len(arr) == 0:
return d
return arr[0]
def default(val, d):
if exists(val):
return val
return d() if callable(d) else d
def maybe(fn):
@wraps(fn)
def inner(x):
if not exists(x):
return x
return fn(x)
return inner
def once(fn):
called = False
@wraps(fn)
def inner(x):
nonlocal called
if called:
return
called = True
return fn(x)
return inner
def _many(fn):
@wraps(fn)
def inner(tensors, pattern, **kwargs):
return (fn(tensor, pattern, **kwargs) for tensor in tensors)
return inner
rearrange_many = _many(rearrange)
repeat_many = _many(repeat)
reduce_many = _many(reduce)
def load_pretrained(state_dict, checkpoint):
checkpoint_model = checkpoint["model"]
    if any("encoder." in k for k in checkpoint_model.keys()):
checkpoint_model = {
k.replace("encoder.", ""): v
for k, v in checkpoint_model.items()
if k.startswith("encoder.")
}
        print("Detected a pre-trained encoder checkpoint; removing the 'encoder.' prefix.")
    else:
        print("Detected a non-pre-trained checkpoint; passing it through unchanged.")
    print(">>>>>>>>>> Remapping pre-trained keys for SWIN ..........")
checkpoint = load_checkpoint_swin(state_dict, checkpoint_model)
def load_checkpoint_swin(model, checkpoint_model):
state_dict = model.state_dict()
# Geometric interpolation when pre-trained patch size mismatch with fine-tuned patch size
all_keys = list(checkpoint_model.keys())
for key in all_keys:
if "relative_position_bias_table" in key:
relative_position_bias_table_pretrained = checkpoint_model[key]
relative_position_bias_table_current = state_dict[key]
L1, nH1 = relative_position_bias_table_pretrained.size()
L2, nH2 = relative_position_bias_table_current.size()
if nH1 != nH2:
print(f"Error in loading {key}, passing......")
else:
if L1 != L2:
print(f"{key}: Interpolate relative_position_bias_table using geo.")
src_size = int(L1**0.5)
dst_size = int(L2**0.5)
def geometric_progression(a, r, n):
return a * (1.0 - r**n) / (1.0 - r)
left, right = 1.01, 1.5
while right - left > 1e-6:
q = (left + right) / 2.0
gp = geometric_progression(1, q, src_size // 2)
if gp > dst_size // 2:
right = q
else:
left = q
# if q > 1.090307:
# q = 1.090307
dis = []
cur = 1
for i in range(src_size // 2):
dis.append(cur)
cur += q ** (i + 1)
r_ids = [-_ for _ in reversed(dis)]
x = r_ids + [0] + dis
y = r_ids + [0] + dis
t = dst_size // 2.0
dx = np.arange(-t, t + 0.1, 1.0)
dy = np.arange(-t, t + 0.1, 1.0)
print("Original positions = %s" % str(x))
print("Target positions = %s" % str(dx))
all_rel_pos_bias = []
for i in range(nH1):
z = (
relative_position_bias_table_pretrained[:, i]
.view(src_size, src_size)
.float()
.numpy()
)
f_cubic = interpolate.interp2d(x, y, z, kind="cubic")
all_rel_pos_bias.append(
torch.Tensor(f_cubic(dx, dy))
.contiguous()
.view(-1, 1)
.to(relative_position_bias_table_pretrained.device)
)
new_rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1)
checkpoint_model[key] = new_rel_pos_bias
# delete relative_position_index since we always re-init it
relative_position_index_keys = [
k for k in checkpoint_model.keys() if "relative_position_index" in k
]
for k in relative_position_index_keys:
del checkpoint_model[k]
# delete relative_coords_table since we always re-init it
relative_coords_table_keys = [
k for k in checkpoint_model.keys() if "relative_coords_table" in k
]
for k in relative_coords_table_keys:
del checkpoint_model[k]
# # re-map keys due to name change
rpe_mlp_keys = [k for k in checkpoint_model.keys() if "cpb_mlp" in k]
for k in rpe_mlp_keys:
checkpoint_model[k.replace("cpb_mlp", "rpe_mlp")] = checkpoint_model.pop(k)
# delete attn_mask since we always re-init it
attn_mask_keys = [k for k in checkpoint_model.keys() if "attn_mask" in k]
for k in attn_mask_keys:
del checkpoint_model[k]
encoder_keys = [k for k in checkpoint_model.keys() if k.startswith("encoder.")]
for k in encoder_keys:
checkpoint_model[k.replace("encoder.", "")] = checkpoint_model.pop(k)
return checkpoint_model
def add_padding_metas(out, image_metas):
device = out.device
# left, right, top, bottom
paddings = [img_meta.get("paddings", [0] * 4) for img_meta in image_metas]
paddings = torch.stack(paddings).to(device)
outs = [F.pad(o, padding, value=0.0) for padding, o in zip(paddings, out)]
return torch.stack(outs)
# left, right, top, bottom
def remove_padding(out, paddings):
H, W = out.shape[-2:]
outs = [
o[..., padding[2] : H - padding[3], padding[0] : W - padding[1]]
for padding, o in zip(paddings, out)
]
return torch.stack(outs)
def remove_padding_metas(out, image_metas):
B, C, H, W = out.shape
device = out.device
# left, right, top, bottom
paddings = [
torch.tensor(img_meta.get("paddings", [0] * 4)) for img_meta in image_metas
]
return remove_padding(out, paddings)
def ssi_helper(tensor1, tensor2):
stability_mat = 1e-4 * torch.eye(2, device=tensor1.device)
tensor2_one = torch.stack([tensor2, torch.ones_like(tensor2)], dim=1)
scale_shift = torch.inverse(tensor2_one.T @ tensor2_one + stability_mat) @ (
tensor2_one.T @ tensor1.unsqueeze(1)
)
scale, shift = scale_shift.squeeze().chunk(2, dim=0)
return scale, shift
def calculate_mean_values(names, values):
# Create a defaultdict to store sum and count for each name
name_values = {name: {} for name in names}
# Iterate through the lists and accumulate values for each name
for name, value in zip(names, values):
name_values[name]["sum"] = name_values[name].get("sum", 0.0) + value
name_values[name]["count"] = name_values[name].get("count", 0.0) + 1
# Calculate mean values and create the output dictionary
output_dict = {
name: name_values[name]["sum"] / name_values[name]["count"]
for name in name_values
}
return output_dict
def remove_leading_dim(infos):
if isinstance(infos, dict):
return {k: remove_leading_dim(v) for k, v in infos.items()}
elif isinstance(infos, torch.Tensor):
return infos.squeeze(0)
else:
return infos
def recursive_index(infos, index):
if isinstance(infos, dict):
return {k: recursive_index(v, index) for k, v in infos.items()}
elif isinstance(infos, torch.Tensor):
return infos[index]
else:
return infos
def to_cpu(infos):
if isinstance(infos, dict):
return {k: to_cpu(v) for k, v in infos.items()}
elif isinstance(infos, torch.Tensor):
return infos.detach()
else:
return infos
def recursive_to(infos, device, non_blocking, cls):
if isinstance(infos, dict):
return {k: recursive_to(v, device, non_blocking, cls) for k, v in infos.items()}
elif isinstance(infos, list):
return [recursive_to(v, device, non_blocking, cls) for v in infos]
elif isinstance(infos, cls):
return infos.to(device, non_blocking=non_blocking)
else:
return infos
def masked_mean(
data: torch.Tensor,
mask: torch.Tensor | None = None,
dim: list[int] | None = None,
keepdim: bool = False,
) -> torch.Tensor:
dim = dim if dim is not None else list(range(data.dim()))
if mask is None:
return data.mean(dim=dim, keepdim=keepdim)
mask = mask.float()
mask_sum = torch.sum(mask, dim=dim, keepdim=True)
mask_mean = torch.sum(data * mask, dim=dim, keepdim=True) / torch.clamp(
mask_sum, min=1.0
)
return mask_mean.squeeze(dim) if not keepdim else mask_mean
class ProfileMethod:
def __init__(self, model, func_name, track_statistics=True, verbose=False):
self.model = model
self.func_name = func_name
self.verbose = verbose
self.track_statistics = track_statistics
self.timings = []
def __enter__(self):
# Start timing
if self.verbose:
if torch.cuda.is_available():
torch.cuda.synchronize()
self.start_time = time()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
if self.verbose:
if torch.cuda.is_available():
torch.cuda.synchronize()
self.end_time = time()
elapsed_time = self.end_time - self.start_time
self.timings.append(elapsed_time)
if self.track_statistics and len(self.timings) > 25:
# Compute statistics if tracking
timings_array = np.array(self.timings)
mean_time = np.mean(timings_array)
std_time = np.std(timings_array)
quantiles = np.percentile(timings_array, [0, 25, 50, 75, 100])
print(
f"{self.model.__class__.__name__}.{self.func_name} took {elapsed_time:.4f} seconds"
)
print(f"Mean Time: {mean_time:.4f} seconds")
print(f"Std Time: {std_time:.4f} seconds")
print(
f"Quantiles: Min={quantiles[0]:.4f}, 25%={quantiles[1]:.4f}, Median={quantiles[2]:.4f}, 75%={quantiles[3]:.4f}, Max={quantiles[4]:.4f}"
)
else:
print(
f"{self.model.__class__.__name__}.{self.func_name} took {elapsed_time:.4f} seconds"
)
def profile_method(track_statistics=True, verbose=False):
def decorator(func):
@wraps(func)
def wrapper(self, *args, **kwargs):
with ProfileMethod(self, func.__name__, track_statistics, verbose):
return func(self, *args, **kwargs)
return wrapper
return decorator
class ProfileFunction:
def __init__(self, func_name, track_statistics=True, verbose=False):
self.func_name = func_name
self.verbose = verbose
self.track_statistics = track_statistics
self.timings = []
def __enter__(self):
# Start timing
if self.verbose:
if torch.cuda.is_available():
torch.cuda.synchronize()
self.start_time = time()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
if self.verbose:
if torch.cuda.is_available():
torch.cuda.synchronize()
self.end_time = time()
elapsed_time = self.end_time - self.start_time
self.timings.append(elapsed_time)
if self.track_statistics and len(self.timings) > 25:
# Compute statistics if tracking
timings_array = np.array(self.timings)
mean_time = np.mean(timings_array)
std_time = np.std(timings_array)
quantiles = np.percentile(timings_array, [0, 25, 50, 75, 100])
print(f"{self.func_name} took {elapsed_time:.4f} seconds")
print(f"Mean Time: {mean_time:.4f} seconds")
print(f"Std Time: {std_time:.4f} seconds")
print(
f"Quantiles: Min={quantiles[0]:.4f}, 25%={quantiles[1]:.4f}, Median={quantiles[2]:.4f}, 75%={quantiles[3]:.4f}, Max={quantiles[4]:.4f}"
)
else:
print(f"{self.func_name} took {elapsed_time:.4f} seconds")
def profile_function(track_statistics=True, verbose=False):
def decorator(func):
@wraps(func)
def wrapper(self, *args, **kwargs):
with ProfileFunction(func.__name__, track_statistics, verbose):
return func(self, *args, **kwargs)
return wrapper
return decorator
def squeeze_list(nested_list, dim, current_dim=0):
# If the current dimension is in the list of indices to squeeze
if isinstance(nested_list, list) and len(nested_list) == 1 and current_dim == dim:
return squeeze_list(nested_list[0], dim, current_dim + 1)
elif isinstance(nested_list, list):
return [squeeze_list(item, dim, current_dim + 1) for item in nested_list]
else:
return nested_list
def match_gt(tensor1, tensor2, padding1, padding2, mode: str = "bilinear"):
"""
Transform each item in tensor1 batch to match tensor2's dimensions and padding.
Args:
tensor1 (torch.Tensor): The input tensor to transform, with shape (batch_size, channels, height, width).
tensor2 (torch.Tensor): The target tensor to match, with shape (batch_size, channels, height, width).
padding1 (tuple): Padding applied to tensor1 (pad_left, pad_right, pad_top, pad_bottom).
padding2 (tuple): Desired padding to be applied to match tensor2 (pad_left, pad_right, pad_top, pad_bottom).
Returns:
torch.Tensor: The batch of transformed tensors matching tensor2's size and padding.
"""
# Get batch size
batch_size = len(tensor1)
src_dtype = tensor1[0].dtype
tgt_dtype = tensor2[0].dtype
# List to store transformed tensors
transformed_tensors = []
for i in range(batch_size):
item1 = tensor1[i]
item2 = tensor2[i]
h1, w1 = item1.shape[1], item1.shape[2]
pad1_l, pad1_r, pad1_t, pad1_b = (
padding1[i] if padding1 is not None else (0, 0, 0, 0)
)
pad2_l, pad2_r, pad2_t, pad2_b = (
padding2[i] if padding2 is not None else (0, 0, 0, 0)
)
item1_unpadded = item1[:, pad1_t : h1 - pad1_b, pad1_l : w1 - pad1_r]
h2, w2 = (
item2.shape[1] - pad2_t - pad2_b,
item2.shape[2] - pad2_l - pad2_r,
)
item1_resized = F.interpolate(
item1_unpadded.unsqueeze(0).to(tgt_dtype), size=(h2, w2), mode=mode
)
item1_padded = F.pad(item1_resized, (pad2_l, pad2_r, pad2_t, pad2_b))
transformed_tensors.append(item1_padded)
transformed_batch = torch.cat(transformed_tensors)
return transformed_batch.to(src_dtype)
def match_intrinsics(K1, tensor1, tensor2, padding1, padding2):
"""
Adjust camera intrinsics K1 to match the size and padding transformation applied to tensor1
so that it corresponds correctly to tensor2.
Args:
K1 (torch.Tensor): The camera intrinsics matrix for tensor1, shape (batch_size, 3, 3).
tensor1 (torch.Tensor): The original image tensor, shape (batch_size, C, H1, W1).
tensor2 (torch.Tensor): The target image tensor, shape (batch_size, C, H2, W2).
padding1 (list of tuples): List of padding applied to tensor1 (pad_left, pad_right, pad_top, pad_bottom).
padding2 (list of tuples): Desired padding to be applied to match tensor2 (pad_left, pad_right, pad_top, pad_bottom).
Returns:
torch.Tensor: The adjusted intrinsics matrix of shape (batch_size, 3, 3).
"""
batch_size = K1.shape[0]
K1_new = K1.clone()
for i in range(batch_size):
h1, w1 = tensor1.shape[2], tensor1.shape[3]
h2, w2 = tensor2.shape[2], tensor2.shape[3]
# Remove original padding
pad1_l, pad1_r, pad1_t, pad1_b = (
padding1[i] if padding1 is not None else (0, 0, 0, 0)
)
w1_unpadded, h1_unpadded = w1 - (pad1_l + pad1_r), h1 - (pad1_t + pad1_b)
# Compute new image size after removing original padding
pad2_l, pad2_r, pad2_t, pad2_b = (
padding2[i] if padding2 is not None else (0, 0, 0, 0)
)
w2_unpadded, h2_unpadded = w2 - (pad2_l + pad2_r), h2 - (pad2_t + pad2_b)
# Compute scaling factors
scale_x = w2_unpadded / w1_unpadded
scale_y = h2_unpadded / h1_unpadded
# Update focal length (fx, fy) and principal point (cx, cy)
K1_new[i, 0, 0] *= scale_x # fx
K1_new[i, 1, 1] *= scale_y # fy
K1_new[i, 0, 2] = (K1[i, 0, 2] - pad1_l) * scale_x + pad2_l # cx
K1_new[i, 1, 2] = (K1[i, 1, 2] - pad1_t) * scale_y + pad2_t # cy
return K1_new
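# ---------------------------------------------------------------------------
# Editor's illustrative usage sketch (not part of the upstream UniDepth file):
# masked_mean averages only over entries selected by the mask, and ssi_helper
# solves in closed form for the scale and shift that best align one depth
# tensor to another in the least-squares sense.
if __name__ == "__main__":
    data = torch.tensor([1.0, 2.0, 3.0, 100.0])
    mask = torch.tensor([True, True, True, False])
    print(masked_mean(data, mask))  # 2.0: the masked-out outlier is ignored
    gt = torch.rand(1000) * 5.0 + 1.0
    pred = 0.5 * gt + 0.3           # mis-scaled, shifted "prediction"
    scale, shift = ssi_helper(gt, pred)
    print(scale, shift)             # roughly 2.0 and -0.6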
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/utils/positional_embedding.py
================================================
"""
Author: Luigi Piccinelli
Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/)
"""
from math import pi
from typing import Optional
import torch
import torch.nn as nn
from einops import rearrange, repeat
class PositionEmbeddingSine(nn.Module):
def __init__(
self, num_pos_feats=64, temperature=10000, normalize=False, scale=None
):
super().__init__()
self.num_pos_feats = num_pos_feats
self.temperature = temperature
self.normalize = normalize
if scale is not None and normalize is False:
raise ValueError("normalize should be True if scale is passed")
if scale is None:
scale = 2 * pi
self.scale = scale
def forward(
self, x: torch.Tensor, mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
if mask is None:
mask = torch.zeros(
(x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool
)
not_mask = ~mask
y_embed = not_mask.cumsum(1, dtype=torch.float32)
x_embed = not_mask.cumsum(2, dtype=torch.float32)
if self.normalize:
eps = 1e-6
y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
dim_t = self.temperature ** (
2 * torch.div(dim_t, 2, rounding_mode="floor") / self.num_pos_feats
)
pos_x = x_embed[:, :, :, None] / dim_t
pos_y = y_embed[:, :, :, None] / dim_t
pos_x = torch.stack(
(pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
).flatten(3)
pos_y = torch.stack(
(pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
).flatten(3)
pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
return pos
def __repr__(self, _repr_indent=4):
head = "Positional encoding " + self.__class__.__name__
body = [
"num_pos_feats: {}".format(self.num_pos_feats),
"temperature: {}".format(self.temperature),
"normalize: {}".format(self.normalize),
"scale: {}".format(self.scale),
]
# _repr_indent = 4
lines = [head] + [" " * _repr_indent + line for line in body]
return "\n".join(lines)
class LearnedSinusoidalPosEmb(nn.Module):
def __init__(self, dim):
super().__init__()
assert (dim % 2) == 0
half_dim = dim // 2
self.weights = nn.Parameter(torch.randn(half_dim))
def forward(self, x):
x = rearrange(x, "b -> b 1")
freqs = x * rearrange(self.weights, "d -> 1 d") * 2 * pi
fouriered = torch.cat((freqs.sin(), freqs.cos()), dim=-1)
fouriered = torch.cat((x, fouriered), dim=-1)
return fouriered
def broadcat(tensors, dim=-1):
num_tensors = len(tensors)
shape_lens = set(list(map(lambda t: len(t.shape), tensors)))
assert len(shape_lens) == 1, "tensors must all have the same number of dimensions"
shape_len = list(shape_lens)[0]
dim = (dim + shape_len) if dim < 0 else dim
dims = list(zip(*map(lambda t: list(t.shape), tensors)))
expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
assert all(
[*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]
), "invalid dimensions for broadcastable concatentation"
max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims))
expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims))
expanded_dims.insert(dim, (dim, dims[dim]))
expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims)))
tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes)))
return torch.cat(tensors, dim=dim)
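# Commented example of what broadcat does (shapes are illustrative): every
# non-concatenated dimension is broadcast to the common maximum before the
# concatenation along `dim`, e.g. (4, 1, 8) and (1, 4, 8) -> (4, 4, 16).
# a, b = torch.ones(4, 1, 8), torch.ones(1, 4, 8)
# assert broadcat((a, b), dim=-1).shape == (4, 4, 16)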
def rotate_half(x):
x = rearrange(x, "... (d r) -> ... d r", r=2)
x1, x2 = x.unbind(dim=-1)
x = torch.stack((-x2, x1), dim=-1)
return rearrange(x, "... d r -> ... (d r)")
class VisionRotaryEmbedding(nn.Module):
def __init__(
self,
dim,
pt_seq_len,
ft_seq_len=None,
custom_freqs=None,
freqs_for="lang",
theta=10000,
max_freq=10,
num_freqs=1,
):
super().__init__()
if custom_freqs:
freqs = custom_freqs
elif freqs_for == "lang":
freqs = 1.0 / (
theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)
)
elif freqs_for == "pixel":
freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi
elif freqs_for == "constant":
freqs = torch.ones(num_freqs).float()
else:
raise ValueError(f"unknown modality {freqs_for}")
if ft_seq_len is None:
ft_seq_len = pt_seq_len
t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len
freqs_h = torch.einsum("..., f -> ... f", t, freqs)
freqs_h = repeat(freqs_h, "... n -> ... (n r)", r=2)
freqs_w = torch.einsum("..., f -> ... f", t, freqs)
freqs_w = repeat(freqs_w, "... n -> ... (n r)", r=2)
freqs = broadcat((freqs_h[:, None, :], freqs_w[None, :, :]), dim=-1)
self.register_buffer("freqs_cos", freqs.cos())
self.register_buffer("freqs_sin", freqs.sin())
print("======== shape of rope freq", self.freqs_cos.shape, "========")
def forward(self, t, start_index=0):
rot_dim = self.freqs_cos.shape[-1]
end_index = start_index + rot_dim
assert (
rot_dim <= t.shape[-1]
), f"feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}"
t_left, t, t_right = (
t[..., :start_index],
t[..., start_index:end_index],
t[..., end_index:],
)
t = (t * self.freqs_cos) + (rotate_half(t) * self.freqs_sin)
return torch.cat((t_left, t, t_right), dim=-1)
class VisionRotaryEmbeddingFast(nn.Module):
def __init__(
self,
dim,
pt_seq_len,
ft_seq_len=None,
custom_freqs=None,
freqs_for="lang",
theta=10000,
max_freq=10,
num_freqs=1,
):
super().__init__()
if custom_freqs:
freqs = custom_freqs
elif freqs_for == "lang":
freqs = 1.0 / (
theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)
)
elif freqs_for == "pixel":
freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi
elif freqs_for == "constant":
freqs = torch.ones(num_freqs).float()
else:
raise ValueError(f"unknown modality {freqs_for}")
if ft_seq_len is None:
ft_seq_len = pt_seq_len
t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len
freqs = torch.einsum("..., f -> ... f", t, freqs)
freqs = repeat(freqs, "... n -> ... (n r)", r=2)
freqs = broadcat((freqs[:, None, :], freqs[None, :, :]), dim=-1)
freqs_cos = freqs.cos().view(-1, freqs.shape[-1])
freqs_sin = freqs.sin().view(-1, freqs.shape[-1])
self.register_buffer("freqs_cos", freqs_cos)
self.register_buffer("freqs_sin", freqs_sin)
def forward(self, t):
return t * self.freqs_cos + rotate_half(t) * self.freqs_sin
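# Commented usage sketch (all sizes are assumptions): with dim=16 and
# pt_seq_len=16 the cos/sin buffers have shape (16 * 16, 32), so the rotation
# applies to token grids flattened to length 256 with a 32-dim rotary slice.
# rope = VisionRotaryEmbeddingFast(dim=16, pt_seq_len=16)
# tokens = torch.randn(2, 8, 256, 32)   # (batch, heads, H * W, rotary dim)
# rotated = rope(tokens)
# assert rotated.shape == tokens.shape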
from math import log2
def generate_fourier_features(
x: torch.Tensor,
dim: int = 512,
max_freq: int = 64,
use_cos: bool = False,
use_log: bool = False,
cat_orig: bool = False,
):
x_orig = x
device, dtype, input_dim = x.device, x.dtype, x.shape[-1]
num_bands = dim // (2 * input_dim) if use_cos else dim // input_dim
if use_log:
scales = 2.0 ** torch.linspace(
0.0, log2(max_freq), steps=num_bands, device=device, dtype=dtype
)
else:
scales = torch.linspace(
1.0, max_freq / 2, num_bands, device=device, dtype=dtype
)
x = x.unsqueeze(-1)
scales = scales[(*((None,) * (len(x.shape) - 1)), Ellipsis)]
x = x * scales * pi
x = torch.cat(
(
[x.sin(), x.cos()]
if use_cos
else [
x.sin(),
]
),
dim=-1,
)
x = x.flatten(-2)
if cat_orig:
return torch.cat((x, x_orig), dim=-1)
return x
# from PIL import Image
# from unidepth.utils import image_grid, colorize
# if __name__ == "__main__":
# H, W = 512, 512
# resolution = 128
# mesh = torch.meshgrid(torch.linspace(-1, 1, H), torch.linspace(-1, 1, W))
# mesh = torch.stack(mesh, dim=0).unsqueeze(0)
# mesh = mesh.view(1, 2, -1).permute(0, 2, 1)
# features = generate_fourier_features(mesh, dim=32, max_freq=resolution, use_log=True)
# channels = features.shape[-1]
# print(features.shape)
# features = features[0].view(H, W, channels).permute(2, 0, 1).numpy()
# Image.fromarray(image_grid([colorize(1+x, 0.0, 2.0, "viridis") for x in features], rows=8, cols=4)).save(f"tmp_{resolution}.png")
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/utils/sht.py
================================================
"""Real spherical harmonics in Cartesian form for PyTorch.
This is an autogenerated file. See
https://github.com/cheind/torch-spherical-harmonics
for more information.
"""
import torch
def rsh_cart_0(xyz: torch.Tensor):
"""Computes all real spherical harmonics up to degree 0.
This is an autogenerated method. See
https://github.com/cheind/torch-spherical-harmonics
for more information.
Params:
xyz: (N,...,3) tensor of points on the unit sphere
Returns:
rsh: (N,...,1) real spherical harmonics
projections of input. Ynm is found at index
`n*(n+1) + m`, with `0 <= n <= degree` and
`-n <= m <= n`.
"""
return torch.stack(
[
xyz.new_tensor(0.282094791773878).expand(xyz.shape[:-1]),
],
-1,
)
def rsh_cart_1(xyz: torch.Tensor):
"""Computes all real spherical harmonics up to degree 1.
This is an autogenerated method. See
https://github.com/cheind/torch-spherical-harmonics
for more information.
Params:
xyz: (N,...,3) tensor of points on the unit sphere
Returns:
rsh: (N,...,4) real spherical harmonics
projections of input. Ynm is found at index
`n*(n+1) + m`, with `0 <= n <= degree` and
`-n <= m <= n`.
"""
x = xyz[..., 0]
y = xyz[..., 1]
z = xyz[..., 2]
return torch.stack(
[
xyz.new_tensor(0.282094791773878).expand(xyz.shape[:-1]),
-0.48860251190292 * y,
0.48860251190292 * z,
-0.48860251190292 * x,
],
-1,
)
def rsh_cart_2(xyz: torch.Tensor):
"""Computes all real spherical harmonics up to degree 2.
This is an autogenerated method. See
https://github.com/cheind/torch-spherical-harmonics
for more information.
Params:
xyz: (N,...,3) tensor of points on the unit sphere
Returns:
rsh: (N,...,9) real spherical harmonics
projections of input. Ynm is found at index
`n*(n+1) + m`, with `0 <= n <= degree` and
`-n <= m <= n`.
"""
x = xyz[..., 0]
y = xyz[..., 1]
z = xyz[..., 2]
x2 = x**2
y2 = y**2
z2 = z**2
xy = x * y
xz = x * z
yz = y * z
return torch.stack(
[
xyz.new_tensor(0.282094791773878).expand(xyz.shape[:-1]),
-0.48860251190292 * y,
0.48860251190292 * z,
-0.48860251190292 * x,
1.09254843059208 * xy,
-1.09254843059208 * yz,
0.94617469575756 * z2 - 0.31539156525252,
-1.09254843059208 * xz,
0.54627421529604 * x2 - 0.54627421529604 * y2,
],
-1,
)
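# Commented index example for the `n*(n+1) + m` layout described in the
# docstrings above (values follow directly from the expressions in rsh_cart_2):
# at the north pole (0, 0, 1), the degree-2 / order-0 harmonic sits at index
# 2*(2+1) + 0 = 6 and equals 0.94617469575756 - 0.31539156525252.
# sh = rsh_cart_2(torch.tensor([[0.0, 0.0, 1.0]]))   # shape (1, 9)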
def rsh_cart_3(xyz: torch.Tensor):
"""Computes all real spherical harmonics up to degree 3.
This is an autogenerated method. See
https://github.com/cheind/torch-spherical-harmonics
for more information.
Params:
xyz: (N,...,3) tensor of points on the unit sphere
Returns:
rsh: (N,...,16) real spherical harmonics
projections of input. Ynm is found at index
`n*(n+1) + m`, with `0 <= n <= degree` and
`-n <= m <= n`.
"""
x = xyz[..., 0]
y = xyz[..., 1]
z = xyz[..., 2]
x2 = x**2
y2 = y**2
z2 = z**2
xy = x * y
xz = x * z
yz = y * z
return torch.stack(
[
xyz.new_tensor(0.282094791773878).expand(xyz.shape[:-1]),
-0.48860251190292 * y,
0.48860251190292 * z,
-0.48860251190292 * x,
1.09254843059208 * xy,
-1.09254843059208 * yz,
0.94617469575756 * z2 - 0.31539156525252,
-1.09254843059208 * xz,
0.54627421529604 * x2 - 0.54627421529604 * y2,
-0.590043589926644 * y * (3.0 * x2 - y2),
2.89061144264055 * xy * z,
0.304697199642977 * y * (1.5 - 7.5 * z2),
1.24392110863372 * z * (1.5 * z2 - 0.5) - 0.497568443453487 * z,
0.304697199642977 * x * (1.5 - 7.5 * z2),
1.44530572132028 * z * (x2 - y2),
-0.590043589926644 * x * (x2 - 3.0 * y2),
],
-1,
)
def rsh_cart_4(xyz: torch.Tensor):
"""Computes all real spherical harmonics up to degree 4.
This is an autogenerated method. See
https://github.com/cheind/torch-spherical-harmonics
for more information.
Params:
xyz: (N,...,3) tensor of points on the unit sphere
Returns:
rsh: (N,...,25) real spherical harmonics
projections of input. Ynm is found at index
`n*(n+1) + m`, with `0 <= n <= degree` and
`-n <= m <= n`.
"""
x = xyz[..., 0]
y = xyz[..., 1]
z = xyz[..., 2]
x2 = x**2
y2 = y**2
z2 = z**2
xy = x * y
xz = x * z
yz = y * z
x4 = x2**2
y4 = y2**2
z4 = z2**2
return torch.stack(
[
xyz.new_tensor(0.282094791773878).expand(xyz.shape[:-1]),
-0.48860251190292 * y,
0.48860251190292 * z,
-0.48860251190292 * x,
1.09254843059208 * xy,
-1.09254843059208 * yz,
0.94617469575756 * z2 - 0.31539156525252,
-1.09254843059208 * xz,
0.54627421529604 * x2 - 0.54627421529604 * y2,
-0.590043589926644 * y * (3.0 * x2 - y2),
2.89061144264055 * xy * z,
0.304697199642977 * y * (1.5 - 7.5 * z2),
1.24392110863372 * z * (1.5 * z2 - 0.5) - 0.497568443453487 * z,
0.304697199642977 * x * (1.5 - 7.5 * z2),
1.44530572132028 * z * (x2 - y2),
-0.590043589926644 * x * (x2 - 3.0 * y2),
2.5033429417967 * xy * (x2 - y2),
-1.77013076977993 * yz * (3.0 * x2 - y2),
0.126156626101008 * xy * (52.5 * z2 - 7.5),
0.267618617422916 * y * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z),
1.48099765681286
* z
* (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
- 0.952069922236839 * z2
+ 0.317356640745613,
0.267618617422916 * x * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z),
0.063078313050504 * (x2 - y2) * (52.5 * z2 - 7.5),
-1.77013076977993 * xz * (x2 - 3.0 * y2),
-3.75501441269506 * x2 * y2
+ 0.625835735449176 * x4
+ 0.625835735449176 * y4,
],
-1,
)
def rsh_cart_5(xyz: torch.Tensor):
"""Computes all real spherical harmonics up to degree 5.
This is an autogenerated method. See
https://github.com/cheind/torch-spherical-harmonics
for more information.
Params:
xyz: (N,...,3) tensor of points on the unit sphere
Returns:
rsh: (N,...,36) real spherical harmonics
projections of input. Ynm is found at index
`n*(n+1) + m`, with `0 <= n <= degree` and
`-n <= m <= n`.
"""
x = xyz[..., 0]
y = xyz[..., 1]
z = xyz[..., 2]
x2 = x**2
y2 = y**2
z2 = z**2
xy = x * y
xz = x * z
yz = y * z
x4 = x2**2
y4 = y2**2
z4 = z2**2
return torch.stack(
[
xyz.new_tensor(0.282094791773878).expand(xyz.shape[:-1]),
-0.48860251190292 * y,
0.48860251190292 * z,
-0.48860251190292 * x,
1.09254843059208 * xy,
-1.09254843059208 * yz,
0.94617469575756 * z2 - 0.31539156525252,
-1.09254843059208 * xz,
0.54627421529604 * x2 - 0.54627421529604 * y2,
-0.590043589926644 * y * (3.0 * x2 - y2),
2.89061144264055 * xy * z,
0.304697199642977 * y * (1.5 - 7.5 * z2),
1.24392110863372 * z * (1.5 * z2 - 0.5) - 0.497568443453487 * z,
0.304697199642977 * x * (1.5 - 7.5 * z2),
1.44530572132028 * z * (x2 - y2),
-0.590043589926644 * x * (x2 - 3.0 * y2),
2.5033429417967 * xy * (x2 - y2),
-1.77013076977993 * yz * (3.0 * x2 - y2),
0.126156626101008 * xy * (52.5 * z2 - 7.5),
0.267618617422916 * y * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z),
1.48099765681286
* z
* (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
- 0.952069922236839 * z2
+ 0.317356640745613,
0.267618617422916 * x * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z),
0.063078313050504 * (x2 - y2) * (52.5 * z2 - 7.5),
-1.77013076977993 * xz * (x2 - 3.0 * y2),
-3.75501441269506 * x2 * y2
+ 0.625835735449176 * x4
+ 0.625835735449176 * y4,
-0.65638205684017 * y * (-10.0 * x2 * y2 + 5.0 * x4 + y4),
8.30264925952416 * xy * z * (x2 - y2),
0.00931882475114763 * y * (52.5 - 472.5 * z2) * (3.0 * x2 - y2),
0.0913054625709205 * xy * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z),
0.241571547304372
* y
* (
2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
+ 9.375 * z2
- 1.875
),
-1.24747010616985 * z * (1.5 * z2 - 0.5)
+ 1.6840846433293
* z
* (
1.75
* z
* (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
- 1.125 * z2
+ 0.375
)
+ 0.498988042467941 * z,
0.241571547304372
* x
* (
2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
+ 9.375 * z2
- 1.875
),
0.0456527312854602 * (x2 - y2) * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z),
0.00931882475114763 * x * (52.5 - 472.5 * z2) * (x2 - 3.0 * y2),
2.07566231488104 * z * (-6.0 * x2 * y2 + x4 + y4),
-0.65638205684017 * x * (-10.0 * x2 * y2 + x4 + 5.0 * y4),
],
-1,
)
def rsh_cart_6(xyz: torch.Tensor):
"""Computes all real spherical harmonics up to degree 6.
This is an autogenerated method. See
https://github.com/cheind/torch-spherical-harmonics
for more information.
Params:
xyz: (N,...,3) tensor of points on the unit sphere
Returns:
rsh: (N,...,49) real spherical harmonics
projections of input. Ynm is found at index
`n*(n+1) + m`, with `0 <= n <= degree` and
`-n <= m <= n`.
"""
x = xyz[..., 0]
y = xyz[..., 1]
z = xyz[..., 2]
x2 = x**2
y2 = y**2
z2 = z**2
xy = x * y
xz = x * z
yz = y * z
x4 = x2**2
y4 = y2**2
z4 = z2**2
return torch.stack(
[
xyz.new_tensor(0.282094791773878).expand(xyz.shape[:-1]),
-0.48860251190292 * y,
0.48860251190292 * z,
-0.48860251190292 * x,
1.09254843059208 * xy,
-1.09254843059208 * yz,
0.94617469575756 * z2 - 0.31539156525252,
-1.09254843059208 * xz,
0.54627421529604 * x2 - 0.54627421529604 * y2,
-0.590043589926644 * y * (3.0 * x2 - y2),
2.89061144264055 * xy * z,
0.304697199642977 * y * (1.5 - 7.5 * z2),
1.24392110863372 * z * (1.5 * z2 - 0.5) - 0.497568443453487 * z,
0.304697199642977 * x * (1.5 - 7.5 * z2),
1.44530572132028 * z * (x2 - y2),
-0.590043589926644 * x * (x2 - 3.0 * y2),
2.5033429417967 * xy * (x2 - y2),
-1.77013076977993 * yz * (3.0 * x2 - y2),
0.126156626101008 * xy * (52.5 * z2 - 7.5),
0.267618617422916 * y * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z),
1.48099765681286
* z
* (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
- 0.952069922236839 * z2
+ 0.317356640745613,
0.267618617422916 * x * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z),
0.063078313050504 * (x2 - y2) * (52.5 * z2 - 7.5),
-1.77013076977993 * xz * (x2 - 3.0 * y2),
-3.75501441269506 * x2 * y2
+ 0.625835735449176 * x4
+ 0.625835735449176 * y4,
-0.65638205684017 * y * (-10.0 * x2 * y2 + 5.0 * x4 + y4),
8.30264925952416 * xy * z * (x2 - y2),
0.00931882475114763 * y * (52.5 - 472.5 * z2) * (3.0 * x2 - y2),
0.0913054625709205 * xy * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z),
0.241571547304372
* y
* (
2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
+ 9.375 * z2
- 1.875
),
-1.24747010616985 * z * (1.5 * z2 - 0.5)
+ 1.6840846433293
* z
* (
1.75
* z
* (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
- 1.125 * z2
+ 0.375
)
+ 0.498988042467941 * z,
0.241571547304372
* x
* (
2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
+ 9.375 * z2
- 1.875
),
0.0456527312854602 * (x2 - y2) * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z),
0.00931882475114763 * x * (52.5 - 472.5 * z2) * (x2 - 3.0 * y2),
2.07566231488104 * z * (-6.0 * x2 * y2 + x4 + y4),
-0.65638205684017 * x * (-10.0 * x2 * y2 + x4 + 5.0 * y4),
4.09910463115149 * x**4 * xy
- 13.6636821038383 * xy**3
+ 4.09910463115149 * xy * y**4,
-2.36661916223175 * yz * (-10.0 * x2 * y2 + 5.0 * x4 + y4),
0.00427144889505798 * xy * (x2 - y2) * (5197.5 * z2 - 472.5),
0.00584892228263444
* y
* (3.0 * x2 - y2)
* (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z),
0.0701870673916132
* xy
* (
2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z)
- 91.875 * z2
+ 13.125
),
0.221950995245231
* y
* (
-2.8 * z * (1.5 - 7.5 * z2)
+ 2.2
* z
* (
2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
+ 9.375 * z2
- 1.875
)
- 4.8 * z
),
-1.48328138624466
* z
* (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
+ 1.86469659985043
* z
* (
-1.33333333333333 * z * (1.5 * z2 - 0.5)
+ 1.8
* z
* (
1.75
* z
* (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
- 1.125 * z2
+ 0.375
)
+ 0.533333333333333 * z
)
+ 0.953538034014426 * z2
- 0.317846011338142,
0.221950995245231
* x
* (
-2.8 * z * (1.5 - 7.5 * z2)
+ 2.2
* z
* (
2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
+ 9.375 * z2
- 1.875
)
- 4.8 * z
),
0.0350935336958066
* (x2 - y2)
* (
2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z)
- 91.875 * z2
+ 13.125
),
0.00584892228263444
* x
* (x2 - 3.0 * y2)
* (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z),
0.0010678622237645 * (5197.5 * z2 - 472.5) * (-6.0 * x2 * y2 + x4 + y4),
-2.36661916223175 * xz * (-10.0 * x2 * y2 + x4 + 5.0 * y4),
0.683184105191914 * x2**3
+ 10.2477615778787 * x2 * y4
- 10.2477615778787 * x4 * y2
- 0.683184105191914 * y2**3,
],
-1,
)
def rsh_cart_7(xyz: torch.Tensor):
"""Computes all real spherical harmonics up to degree 7.
This is an autogenerated method. See
https://github.com/cheind/torch-spherical-harmonics
for more information.
Params:
xyz: (N,...,3) tensor of points on the unit sphere
Returns:
rsh: (N,...,64) real spherical harmonics
projections of input. Ynm is found at index
`n*(n+1) + m`, with `0 <= n <= degree` and
`-n <= m <= n`.
"""
x = xyz[..., 0]
y = xyz[..., 1]
z = xyz[..., 2]
x2 = x**2
y2 = y**2
z2 = z**2
xy = x * y
xz = x * z
yz = y * z
x4 = x2**2
y4 = y2**2
z4 = z2**2
return torch.stack(
[
xyz.new_tensor(0.282094791773878).expand(xyz.shape[:-1]),
-0.48860251190292 * y,
0.48860251190292 * z,
-0.48860251190292 * x,
1.09254843059208 * xy,
-1.09254843059208 * yz,
0.94617469575756 * z2 - 0.31539156525252,
-1.09254843059208 * xz,
0.54627421529604 * x2 - 0.54627421529604 * y2,
-0.590043589926644 * y * (3.0 * x2 - y2),
2.89061144264055 * xy * z,
0.304697199642977 * y * (1.5 - 7.5 * z2),
1.24392110863372 * z * (1.5 * z2 - 0.5) - 0.497568443453487 * z,
0.304697199642977 * x * (1.5 - 7.5 * z2),
1.44530572132028 * z * (x2 - y2),
-0.590043589926644 * x * (x2 - 3.0 * y2),
2.5033429417967 * xy * (x2 - y2),
-1.77013076977993 * yz * (3.0 * x2 - y2),
0.126156626101008 * xy * (52.5 * z2 - 7.5),
0.267618617422916 * y * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z),
1.48099765681286
* z
* (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
- 0.952069922236839 * z2
+ 0.317356640745613,
0.267618617422916 * x * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z),
0.063078313050504 * (x2 - y2) * (52.5 * z2 - 7.5),
-1.77013076977993 * xz * (x2 - 3.0 * y2),
-3.75501441269506 * x2 * y2
+ 0.625835735449176 * x4
+ 0.625835735449176 * y4,
-0.65638205684017 * y * (-10.0 * x2 * y2 + 5.0 * x4 + y4),
8.30264925952416 * xy * z * (x2 - y2),
0.00931882475114763 * y * (52.5 - 472.5 * z2) * (3.0 * x2 - y2),
0.0913054625709205 * xy * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z),
0.241571547304372
* y
* (
2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
+ 9.375 * z2
- 1.875
),
-1.24747010616985 * z * (1.5 * z2 - 0.5)
+ 1.6840846433293
* z
* (
1.75
* z
* (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
- 1.125 * z2
+ 0.375
)
+ 0.498988042467941 * z,
0.241571547304372
* x
* (
2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
+ 9.375 * z2
- 1.875
),
0.0456527312854602 * (x2 - y2) * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z),
0.00931882475114763 * x * (52.5 - 472.5 * z2) * (x2 - 3.0 * y2),
2.07566231488104 * z * (-6.0 * x2 * y2 + x4 + y4),
-0.65638205684017 * x * (-10.0 * x2 * y2 + x4 + 5.0 * y4),
4.09910463115149 * x**4 * xy
- 13.6636821038383 * xy**3
+ 4.09910463115149 * xy * y**4,
-2.36661916223175 * yz * (-10.0 * x2 * y2 + 5.0 * x4 + y4),
0.00427144889505798 * xy * (x2 - y2) * (5197.5 * z2 - 472.5),
0.00584892228263444
* y
* (3.0 * x2 - y2)
* (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z),
0.0701870673916132
* xy
* (
2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z)
- 91.875 * z2
+ 13.125
),
0.221950995245231
* y
* (
-2.8 * z * (1.5 - 7.5 * z2)
+ 2.2
* z
* (
2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
+ 9.375 * z2
- 1.875
)
- 4.8 * z
),
-1.48328138624466
* z
* (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
+ 1.86469659985043
* z
* (
-1.33333333333333 * z * (1.5 * z2 - 0.5)
+ 1.8
* z
* (
1.75
* z
* (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
- 1.125 * z2
+ 0.375
)
+ 0.533333333333333 * z
)
+ 0.953538034014426 * z2
- 0.317846011338142,
0.221950995245231
* x
* (
-2.8 * z * (1.5 - 7.5 * z2)
+ 2.2
* z
* (
2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
+ 9.375 * z2
- 1.875
)
- 4.8 * z
),
0.0350935336958066
* (x2 - y2)
* (
2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z)
- 91.875 * z2
+ 13.125
),
0.00584892228263444
* x
* (x2 - 3.0 * y2)
* (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z),
0.0010678622237645 * (5197.5 * z2 - 472.5) * (-6.0 * x2 * y2 + x4 + y4),
-2.36661916223175 * xz * (-10.0 * x2 * y2 + x4 + 5.0 * y4),
0.683184105191914 * x2**3
+ 10.2477615778787 * x2 * y4
- 10.2477615778787 * x4 * y2
- 0.683184105191914 * y2**3,
-0.707162732524596
* y
* (7.0 * x2**3 + 21.0 * x2 * y4 - 35.0 * x4 * y2 - y2**3),
2.6459606618019 * z * (6.0 * x**4 * xy - 20.0 * xy**3 + 6.0 * xy * y**4),
9.98394571852353e-5
* y
* (5197.5 - 67567.5 * z2)
* (-10.0 * x2 * y2 + 5.0 * x4 + y4),
0.00239614697244565
* xy
* (x2 - y2)
* (4.33333333333333 * z * (5197.5 * z2 - 472.5) - 3150.0 * z),
0.00397356022507413
* y
* (3.0 * x2 - y2)
* (
3.25 * z * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z)
+ 1063.125 * z2
- 118.125
),
0.0561946276120613
* xy
* (
-4.8 * z * (52.5 * z2 - 7.5)
+ 2.6
* z
* (
2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z)
- 91.875 * z2
+ 13.125
)
+ 48.0 * z
),
0.206472245902897
* y
* (
-2.625 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
+ 2.16666666666667
* z
* (
-2.8 * z * (1.5 - 7.5 * z2)
+ 2.2
* z
* (
2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
+ 9.375 * z2
- 1.875
)
- 4.8 * z
)
- 10.9375 * z2
+ 2.1875
),
1.24862677781952 * z * (1.5 * z2 - 0.5)
- 1.68564615005635
* z
* (
1.75
* z
* (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
- 1.125 * z2
+ 0.375
)
+ 2.02901851395672
* z
* (
-1.45833333333333
* z
* (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
+ 1.83333333333333
* z
* (
-1.33333333333333 * z * (1.5 * z2 - 0.5)
+ 1.8
* z
* (
1.75
* z
* (
1.66666666666667 * z * (1.5 * z2 - 0.5)
- 0.666666666666667 * z
)
- 1.125 * z2
+ 0.375
)
+ 0.533333333333333 * z
)
+ 0.9375 * z2
- 0.3125
)
- 0.499450711127808 * z,
0.206472245902897
* x
* (
-2.625 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
+ 2.16666666666667
* z
* (
-2.8 * z * (1.5 - 7.5 * z2)
+ 2.2
* z
* (
2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
+ 9.375 * z2
- 1.875
)
- 4.8 * z
)
- 10.9375 * z2
+ 2.1875
),
0.0280973138060306
* (x2 - y2)
* (
-4.8 * z * (52.5 * z2 - 7.5)
+ 2.6
* z
* (
2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z)
- 91.875 * z2
+ 13.125
)
+ 48.0 * z
),
0.00397356022507413
* x
* (x2 - 3.0 * y2)
* (
3.25 * z * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z)
+ 1063.125 * z2
- 118.125
),
0.000599036743111412
* (4.33333333333333 * z * (5197.5 * z2 - 472.5) - 3150.0 * z)
* (-6.0 * x2 * y2 + x4 + y4),
9.98394571852353e-5
* x
* (5197.5 - 67567.5 * z2)
* (-10.0 * x2 * y2 + x4 + 5.0 * y4),
2.6459606618019 * z * (x2**3 + 15.0 * x2 * y4 - 15.0 * x4 * y2 - y2**3),
-0.707162732524596
* x
* (x2**3 + 35.0 * x2 * y4 - 21.0 * x4 * y2 - 7.0 * y2**3),
],
-1,
)
# @torch.jit.script
def rsh_cart_8(xyz: torch.Tensor):
"""Computes all real spherical harmonics up to degree 8.
This is an autogenerated method. See
https://github.com/cheind/torch-spherical-harmonics
for more information.
Params:
xyz: (N,...,3) tensor of points on the unit sphere
Returns:
rsh: (N,...,81) real spherical harmonics
projections of input. Ynm is found at index
`n*(n+1) + m`, with `0 <= n <= degree` and
`-n <= m <= n`.
"""
x = xyz[..., 0]
y = xyz[..., 1]
z = xyz[..., 2]
x2 = x**2
y2 = y**2
z2 = z**2
xy = x * y
xz = x * z
yz = y * z
x4 = x2**2
y4 = y2**2
# z4 = z2**2
return torch.stack(
[
0.282094791773878 * torch.ones(1, device=xyz.device).expand(xyz.shape[:-1]),
-0.48860251190292 * y,
0.48860251190292 * z,
-0.48860251190292 * x,
1.09254843059208 * xy,
-1.09254843059208 * yz,
0.94617469575756 * z2 - 0.31539156525252,
-1.09254843059208 * xz,
0.54627421529604 * x2 - 0.54627421529604 * y2,
-0.590043589926644 * y * (3.0 * x2 - y2),
2.89061144264055 * xy * z,
0.304697199642977 * y * (1.5 - 7.5 * z2),
1.24392110863372 * z * (1.5 * z2 - 0.5) - 0.497568443453487 * z,
0.304697199642977 * x * (1.5 - 7.5 * z2),
1.44530572132028 * z * (x2 - y2),
-0.590043589926644 * x * (x2 - 3.0 * y2),
2.5033429417967 * xy * (x2 - y2),
-1.77013076977993 * yz * (3.0 * x2 - y2),
0.126156626101008 * xy * (52.5 * z2 - 7.5),
0.267618617422916 * y * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z),
1.48099765681286
* z
* (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
- 0.952069922236839 * z2
+ 0.317356640745613,
0.267618617422916 * x * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z),
0.063078313050504 * (x2 - y2) * (52.5 * z2 - 7.5),
-1.77013076977993 * xz * (x2 - 3.0 * y2),
-3.75501441269506 * x2 * y2
+ 0.625835735449176 * x4
+ 0.625835735449176 * y4,
-0.65638205684017 * y * (-10.0 * x2 * y2 + 5.0 * x4 + y4),
8.30264925952416 * xy * z * (x2 - y2),
0.00931882475114763 * y * (52.5 - 472.5 * z2) * (3.0 * x2 - y2),
0.0913054625709205 * xy * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z),
0.241571547304372
* y
* (
2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
+ 9.375 * z2
- 1.875
),
-1.24747010616985 * z * (1.5 * z2 - 0.5)
+ 1.6840846433293
* z
* (
1.75
* z
* (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
- 1.125 * z2
+ 0.375
)
+ 0.498988042467941 * z,
0.241571547304372
* x
* (
2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
+ 9.375 * z2
- 1.875
),
0.0456527312854602 * (x2 - y2) * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z),
0.00931882475114763 * x * (52.5 - 472.5 * z2) * (x2 - 3.0 * y2),
2.07566231488104 * z * (-6.0 * x2 * y2 + x4 + y4),
-0.65638205684017 * x * (-10.0 * x2 * y2 + x4 + 5.0 * y4),
4.09910463115149 * x**4 * xy
- 13.6636821038383 * xy**3
+ 4.09910463115149 * xy * y**4,
-2.36661916223175 * yz * (-10.0 * x2 * y2 + 5.0 * x4 + y4),
0.00427144889505798 * xy * (x2 - y2) * (5197.5 * z2 - 472.5),
0.00584892228263444
* y
* (3.0 * x2 - y2)
* (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z),
0.0701870673916132
* xy
* (
2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z)
- 91.875 * z2
+ 13.125
),
0.221950995245231
* y
* (
-2.8 * z * (1.5 - 7.5 * z2)
+ 2.2
* z
* (
2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
+ 9.375 * z2
- 1.875
)
- 4.8 * z
),
-1.48328138624466
* z
* (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
+ 1.86469659985043
* z
* (
-1.33333333333333 * z * (1.5 * z2 - 0.5)
+ 1.8
* z
* (
1.75
* z
* (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
- 1.125 * z2
+ 0.375
)
+ 0.533333333333333 * z
)
+ 0.953538034014426 * z2
- 0.317846011338142,
0.221950995245231
* x
* (
-2.8 * z * (1.5 - 7.5 * z2)
+ 2.2
* z
* (
2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
+ 9.375 * z2
- 1.875
)
- 4.8 * z
),
0.0350935336958066
* (x2 - y2)
* (
2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z)
- 91.875 * z2
+ 13.125
),
0.00584892228263444
* x
* (x2 - 3.0 * y2)
* (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z),
0.0010678622237645 * (5197.5 * z2 - 472.5) * (-6.0 * x2 * y2 + x4 + y4),
-2.36661916223175 * xz * (-10.0 * x2 * y2 + x4 + 5.0 * y4),
0.683184105191914 * x2**3
+ 10.2477615778787 * x2 * y4
- 10.2477615778787 * x4 * y2
- 0.683184105191914 * y2**3,
-0.707162732524596
* y
* (7.0 * x2**3 + 21.0 * x2 * y4 - 35.0 * x4 * y2 - y2**3),
2.6459606618019 * z * (6.0 * x**4 * xy - 20.0 * xy**3 + 6.0 * xy * y**4),
9.98394571852353e-5
* y
* (5197.5 - 67567.5 * z2)
* (-10.0 * x2 * y2 + 5.0 * x4 + y4),
0.00239614697244565
* xy
* (x2 - y2)
* (4.33333333333333 * z * (5197.5 * z2 - 472.5) - 3150.0 * z),
0.00397356022507413
* y
* (3.0 * x2 - y2)
* (
3.25 * z * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z)
+ 1063.125 * z2
- 118.125
),
0.0561946276120613
* xy
* (
-4.8 * z * (52.5 * z2 - 7.5)
+ 2.6
* z
* (
2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z)
- 91.875 * z2
+ 13.125
)
+ 48.0 * z
),
0.206472245902897
* y
* (
-2.625 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
+ 2.16666666666667
* z
* (
-2.8 * z * (1.5 - 7.5 * z2)
+ 2.2
* z
* (
2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
+ 9.375 * z2
- 1.875
)
- 4.8 * z
)
- 10.9375 * z2
+ 2.1875
),
1.24862677781952 * z * (1.5 * z2 - 0.5)
- 1.68564615005635
* z
* (
1.75
* z
* (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
- 1.125 * z2
+ 0.375
)
+ 2.02901851395672
* z
* (
-1.45833333333333
* z
* (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
+ 1.83333333333333
* z
* (
-1.33333333333333 * z * (1.5 * z2 - 0.5)
+ 1.8
* z
* (
1.75
* z
* (
1.66666666666667 * z * (1.5 * z2 - 0.5)
- 0.666666666666667 * z
)
- 1.125 * z2
+ 0.375
)
+ 0.533333333333333 * z
)
+ 0.9375 * z2
- 0.3125
)
- 0.499450711127808 * z,
0.206472245902897
* x
* (
-2.625 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
+ 2.16666666666667
* z
* (
-2.8 * z * (1.5 - 7.5 * z2)
+ 2.2
* z
* (
2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
+ 9.375 * z2
- 1.875
)
- 4.8 * z
)
- 10.9375 * z2
+ 2.1875
),
0.0280973138060306
* (x2 - y2)
* (
-4.8 * z * (52.5 * z2 - 7.5)
+ 2.6
* z
* (
2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z)
- 91.875 * z2
+ 13.125
)
+ 48.0 * z
),
0.00397356022507413
* x
* (x2 - 3.0 * y2)
* (
3.25 * z * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z)
+ 1063.125 * z2
- 118.125
),
0.000599036743111412
* (4.33333333333333 * z * (5197.5 * z2 - 472.5) - 3150.0 * z)
* (-6.0 * x2 * y2 + x4 + y4),
9.98394571852353e-5
* x
* (5197.5 - 67567.5 * z2)
* (-10.0 * x2 * y2 + x4 + 5.0 * y4),
2.6459606618019 * z * (x2**3 + 15.0 * x2 * y4 - 15.0 * x4 * y2 - y2**3),
-0.707162732524596
* x
* (x2**3 + 35.0 * x2 * y4 - 21.0 * x4 * y2 - 7.0 * y2**3),
5.83141328139864 * xy * (x2**3 + 7.0 * x2 * y4 - 7.0 * x4 * y2 - y2**3),
-2.91570664069932
* yz
* (7.0 * x2**3 + 21.0 * x2 * y4 - 35.0 * x4 * y2 - y2**3),
7.87853281621404e-6
* (1013512.5 * z2 - 67567.5)
* (6.0 * x**4 * xy - 20.0 * xy**3 + 6.0 * xy * y**4),
5.10587282657803e-5
* y
* (5.0 * z * (5197.5 - 67567.5 * z2) + 41580.0 * z)
* (-10.0 * x2 * y2 + 5.0 * x4 + y4),
0.00147275890257803
* xy
* (x2 - y2)
* (
3.75 * z * (4.33333333333333 * z * (5197.5 * z2 - 472.5) - 3150.0 * z)
- 14293.125 * z2
+ 1299.375
),
0.0028519853513317
* y
* (3.0 * x2 - y2)
* (
-7.33333333333333 * z * (52.5 - 472.5 * z2)
+ 3.0
* z
* (
3.25 * z * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z)
+ 1063.125 * z2
- 118.125
)
- 560.0 * z
),
0.0463392770473559
* xy
* (
-4.125 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z)
+ 2.5
* z
* (
-4.8 * z * (52.5 * z2 - 7.5)
+ 2.6
* z
* (
2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z)
- 91.875 * z2
+ 13.125
)
+ 48.0 * z
)
+ 137.8125 * z2
- 19.6875
),
0.193851103820053
* y
* (
3.2 * z * (1.5 - 7.5 * z2)
- 2.51428571428571
* z
* (
2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
+ 9.375 * z2
- 1.875
)
+ 2.14285714285714
* z
* (
-2.625 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
+ 2.16666666666667
* z
* (
-2.8 * z * (1.5 - 7.5 * z2)
+ 2.2
* z
* (
2.25
* z
* (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
+ 9.375 * z2
- 1.875
)
- 4.8 * z
)
- 10.9375 * z2
+ 2.1875
)
+ 5.48571428571429 * z
),
1.48417251362228
* z
* (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
- 1.86581687426801
* z
* (
-1.33333333333333 * z * (1.5 * z2 - 0.5)
+ 1.8
* z
* (
1.75
* z
* (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
- 1.125 * z2
+ 0.375
)
+ 0.533333333333333 * z
)
+ 2.1808249179756
* z
* (
1.14285714285714 * z * (1.5 * z2 - 0.5)
- 1.54285714285714
* z
* (
1.75
* z
* (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
- 1.125 * z2
+ 0.375
)
+ 1.85714285714286
* z
* (
-1.45833333333333
* z
* (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
+ 1.83333333333333
* z
* (
-1.33333333333333 * z * (1.5 * z2 - 0.5)
+ 1.8
* z
* (
1.75
* z
* (
1.66666666666667 * z * (1.5 * z2 - 0.5)
- 0.666666666666667 * z
)
- 1.125 * z2
+ 0.375
)
+ 0.533333333333333 * z
)
+ 0.9375 * z2
- 0.3125
)
- 0.457142857142857 * z
)
- 0.954110901614325 * z2
+ 0.318036967204775,
0.193851103820053
* x
* (
3.2 * z * (1.5 - 7.5 * z2)
- 2.51428571428571
* z
* (
2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
+ 9.375 * z2
- 1.875
)
+ 2.14285714285714
* z
* (
-2.625 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
+ 2.16666666666667
* z
* (
-2.8 * z * (1.5 - 7.5 * z2)
+ 2.2
* z
* (
2.25
* z
* (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
+ 9.375 * z2
- 1.875
)
- 4.8 * z
)
- 10.9375 * z2
+ 2.1875
)
+ 5.48571428571429 * z
),
0.0231696385236779
* (x2 - y2)
* (
-4.125 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z)
+ 2.5
* z
* (
-4.8 * z * (52.5 * z2 - 7.5)
+ 2.6
* z
* (
2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z)
- 91.875 * z2
+ 13.125
)
+ 48.0 * z
)
+ 137.8125 * z2
- 19.6875
),
0.0028519853513317
* x
* (x2 - 3.0 * y2)
* (
-7.33333333333333 * z * (52.5 - 472.5 * z2)
+ 3.0
* z
* (
3.25 * z * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z)
+ 1063.125 * z2
- 118.125
)
- 560.0 * z
),
0.000368189725644507
* (-6.0 * x2 * y2 + x4 + y4)
* (
3.75 * z * (4.33333333333333 * z * (5197.5 * z2 - 472.5) - 3150.0 * z)
- 14293.125 * z2
+ 1299.375
),
5.10587282657803e-5
* x
* (5.0 * z * (5197.5 - 67567.5 * z2) + 41580.0 * z)
* (-10.0 * x2 * y2 + x4 + 5.0 * y4),
7.87853281621404e-6
* (1013512.5 * z2 - 67567.5)
* (x2**3 + 15.0 * x2 * y4 - 15.0 * x4 * y2 - y2**3),
-2.91570664069932
* xz
* (x2**3 + 35.0 * x2 * y4 - 21.0 * x4 * y2 - 7.0 * y2**3),
-20.4099464848952 * x2**3 * y2
- 20.4099464848952 * x2 * y2**3
+ 0.72892666017483 * x4**2
+ 51.0248662122381 * x4 * y4
+ 0.72892666017483 * y4**2,
],
-1,
)
__all__ = [
"rsh_cart_0",
"rsh_cart_1",
"rsh_cart_2",
"rsh_cart_3",
"rsh_cart_4",
"rsh_cart_5",
"rsh_cart_6",
"rsh_cart_7",
"rsh_cart_8",
]
from typing import Optional
import torch
class SphHarm(torch.nn.Module):
def __init__(self, m, n, dtype=torch.float32) -> None:
super().__init__()
self.dtype = dtype
m = torch.tensor(list(range(-m + 1, m)))
n = torch.tensor(list(range(n)))
self.is_normalized = False
vals = torch.cartesian_prod(m, n).T
vals = vals[:, vals[0] <= vals[1]]
m, n = vals.unbind(0)
self.register_buffer("m", tensor=m)
self.register_buffer("n", tensor=n)
self.register_buffer("l_max", tensor=torch.max(self.n))
f_a, f_b, initial_value, d0_mask_3d, d1_mask_3d = self._init_legendre()
self.register_buffer("f_a", tensor=f_a)
self.register_buffer("f_b", tensor=f_b)
self.register_buffer("d0_mask_3d", tensor=d0_mask_3d)
self.register_buffer("d1_mask_3d", tensor=d1_mask_3d)
self.register_buffer("initial_value", tensor=initial_value)
@property
def device(self):
return next(self.buffers()).device
def forward(self, points: torch.Tensor) -> torch.Tensor:
"""Computes the spherical harmonics."""
# Y_l^m = (-1) ^ m c_l^m P_l^m(cos(theta)) exp(i m phi)
B, N, D = points.shape
dtype = points.dtype
theta, phi = points.view(-1, D).to(self.dtype).unbind(-1)
cos_colatitude = torch.cos(phi)
legendre = self._gen_associated_legendre(cos_colatitude)
vals = torch.stack([self.m.abs(), self.n], dim=0)
vals = torch.cat(
[
vals.repeat(1, theta.shape[0]),
torch.arange(theta.shape[0], device=theta.device)
.unsqueeze(0)
.repeat_interleave(vals.shape[1], dim=1),
],
dim=0,
)
legendre_vals = legendre[vals[0], vals[1], vals[2]]
legendre_vals = legendre_vals.reshape(-1, theta.shape[0])
angle = torch.outer(self.m.abs(), theta)
vandermonde = torch.complex(torch.cos(angle), torch.sin(angle))
harmonics = torch.complex(
legendre_vals * torch.real(vandermonde),
legendre_vals * torch.imag(vandermonde),
)
# Negative order.
m = self.m.unsqueeze(-1)
harmonics = torch.where(
m < 0, (-1.0) ** m.abs() * torch.conj(harmonics), harmonics
)
harmonics = harmonics.permute(1, 0).reshape(B, N, -1).to(dtype)
return harmonics
def _gen_recurrence_mask(self) -> tuple[torch.Tensor, torch.Tensor]:
"""Generates mask for recurrence relation on the remaining entries.
The remaining entries are with respect to the diagonal and offdiagonal
entries.
Args:
l_max: see `gen_normalized_legendre`.
Returns:
torch.Tensors representing the mask used by the recurrence relations.
"""
# Computes all coefficients.
m_mat, l_mat = torch.meshgrid(
torch.arange(0, self.l_max + 1, device=self.device, dtype=self.dtype),
torch.arange(0, self.l_max + 1, device=self.device, dtype=self.dtype),
indexing="ij",
)
if self.is_normalized:
c0 = l_mat * l_mat
c1 = m_mat * m_mat
c2 = 2.0 * l_mat
c3 = (l_mat - 1.0) * (l_mat - 1.0)
d0 = torch.sqrt((4.0 * c0 - 1.0) / (c0 - c1))
d1 = torch.sqrt(((c2 + 1.0) * (c3 - c1)) / ((c2 - 3.0) * (c0 - c1)))
else:
d0 = (2.0 * l_mat - 1.0) / (l_mat - m_mat)
d1 = (l_mat + m_mat - 1.0) / (l_mat - m_mat)
d0_mask_indices = torch.triu_indices(self.l_max + 1, 1)
d1_mask_indices = torch.triu_indices(self.l_max + 1, 2)
d_zeros = torch.zeros(
(self.l_max + 1, self.l_max + 1), dtype=self.dtype, device=self.device
)
d_zeros[d0_mask_indices] = d0[d0_mask_indices]
d0_mask = d_zeros
d_zeros = torch.zeros(
(self.l_max + 1, self.l_max + 1), dtype=self.dtype, device=self.device
)
d_zeros[d1_mask_indices] = d1[d1_mask_indices]
d1_mask = d_zeros
# Creates a 3D mask that contains 1s on the diagonal plane and 0s elsewhere.
i = torch.arange(self.l_max + 1, device=self.device)[:, None, None]
j = torch.arange(self.l_max + 1, device=self.device)[None, :, None]
k = torch.arange(self.l_max + 1, device=self.device)[None, None, :]
mask = (i + j - k == 0).to(self.dtype)
d0_mask_3d = torch.einsum("jk,ijk->ijk", d0_mask, mask)
d1_mask_3d = torch.einsum("jk,ijk->ijk", d1_mask, mask)
return (d0_mask_3d, d1_mask_3d)
def _recursive(self, i: int, p_val: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
coeff_0 = self.d0_mask_3d[i]
coeff_1 = self.d1_mask_3d[i]
h = torch.einsum(
"ij,ijk->ijk",
coeff_0,
torch.einsum("ijk,k->ijk", torch.roll(p_val, shifts=1, dims=1), x),
) - torch.einsum("ij,ijk->ijk", coeff_1, torch.roll(p_val, shifts=2, dims=1))
p_val = p_val + h
return p_val
def _init_legendre(self):
a_idx = torch.arange(1, self.l_max + 1, dtype=self.dtype, device=self.device)
b_idx = torch.arange(self.l_max, dtype=self.dtype, device=self.device)
if self.is_normalized:
# The initial value p(0,0).
initial_value: torch.Tensor = torch.tensor(
0.5 / (torch.pi**0.5), device=self.device
)
f_a = torch.cumprod(-1 * torch.sqrt(1.0 + 0.5 / a_idx), dim=0)
f_b = torch.sqrt(2.0 * b_idx + 3.0)
else:
# The initial value p(0,0).
initial_value = torch.tensor(1.0, device=self.device)
f_a = torch.cumprod(1.0 - 2.0 * a_idx, dim=0)
f_b = 2.0 * b_idx + 1.0
d0_mask_3d, d1_mask_3d = self._gen_recurrence_mask()
return f_a, f_b, initial_value, d0_mask_3d, d1_mask_3d
def _gen_associated_legendre(self, x: torch.Tensor) -> torch.Tensor:
r"""Computes associated Legendre functions (ALFs) of the first kind.
The ALFs of the first kind are used in spherical harmonics. The spherical
harmonic of degree `l` and order `m` can be written as
`Y_l^m(θ, φ) = N_l^m * P_l^m(cos(θ)) * exp(i m φ)`, where `N_l^m` is the
normalization factor and θ and φ are the colatitude and longitude,
        respectively. `N_l^m` is chosen such that the spherical harmonics form
        a set of orthonormal basis functions of L^2(S^2). For the computational
efficiency of spherical harmonics transform, the normalization factor is
used in the computation of the ALFs. In addition, normalizing `P_l^m`
avoids overflow/underflow and achieves better numerical stability. Three
recurrence relations are used in the computation.
Args:
l_max: The maximum degree of the associated Legendre function. Both the
degrees and orders are `[0, 1, 2, ..., l_max]`.
x: A vector of type `float32`, `float64` containing the sampled points in
spherical coordinates, at which the ALFs are computed; `x` is essentially
`cos(θ)`. For the numerical integration used by the spherical harmonics
transforms, `x` contains the quadrature points in the interval of
`[-1, 1]`. There are several approaches to provide the quadrature points:
Gauss-Legendre method (`scipy.special.roots_legendre`), Gauss-Chebyshev
method (`scipy.special.roots_chebyu`), and Driscoll & Healy
method (Driscoll, James R., and Dennis M. Healy. "Computing Fourier
transforms and convolutions on the 2-sphere." Advances in applied
mathematics 15, no. 2 (1994): 202-250.). The Gauss-Legendre quadrature
points are nearly equal-spaced along θ and provide exact discrete
orthogonality, (P^m)^T W P_m = I, where `T` represents the transpose
operation, `W` is a diagonal matrix containing the quadrature weights,
and `I` is the identity matrix. The Gauss-Chebyshev points are equally
spaced, which only provide approximate discrete orthogonality. The
Driscoll & Healy qudarture points are equally spaced and provide the
exact discrete orthogonality. The number of sampling points is required to
        be twice the number of frequency points (modes) in the Driscoll & Healy
approach, which enables FFT and achieves a fast spherical harmonics
transform.
is_normalized: True if the associated Legendre functions are normalized.
With normalization, `N_l^m` is applied such that the spherical harmonics
form a set of orthonormal basis functions of L^2(S^2).
Returns:
The 3D array of shape `(l_max + 1, l_max + 1, len(x))` containing the values
        of the ALFs at `x`; the dimensions are ordered as (order, degree,
        evaluation points).
"""
p = torch.zeros(
(self.l_max + 1, self.l_max + 1, x.shape[0]), dtype=x.dtype, device=x.device
)
p[0, 0] = self.initial_value
# Compute the diagonal entries p(l,l) with recurrence.
y = torch.cumprod(
torch.broadcast_to(torch.sqrt(1.0 - x * x), (self.l_max, x.shape[0])), dim=0
)
p_diag = self.initial_value * torch.einsum("i,ij->ij", self.f_a, y)
# torch.diag_indices(l_max + 1)
diag_indices = torch.stack(
[torch.arange(0, self.l_max + 1, device=x.device)] * 2, dim=0
)
p[(diag_indices[0][1:], diag_indices[1][1:])] = p_diag
diag_indices = torch.stack(
[torch.arange(0, self.l_max, device=x.device)] * 2, dim=0
)
# Compute the off-diagonal entries with recurrence.
p_offdiag = torch.einsum(
"ij,ij->ij",
torch.einsum("i,j->ij", self.f_b, x),
p[(diag_indices[0], diag_indices[1])],
) # p[torch.diag_indices(l_max)])
p[(diag_indices[0][: self.l_max], diag_indices[1][: self.l_max] + 1)] = (
p_offdiag
)
# Compute the remaining entries with recurrence.
if self.l_max > 1:
for i in range(2, self.l_max + 1):
p = self._recursive(i, p, x)
return p
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/utils/validation.py
================================================
"""
Author: Luigi Piccinelli
Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/)
"""
import torch
import torch.utils.data.distributed
import wandb
from torch.nn import functional as F
from unidepth.utils import barrier, is_main_process
from unidepth.utils.misc import remove_padding
def original_image(batch, preds=None):
paddings = [
torch.tensor(pads)
for img_meta in batch["img_metas"]
for pads in img_meta.get("paddings", [[0] * 4])
]
paddings = torch.stack(paddings).to(batch["data"]["image"].device)[
..., [0, 2, 1, 3]
] # lrtb
T, _, H, W = batch["data"]["depth"].shape
batch["data"]["image"] = F.interpolate(
batch["data"]["image"],
        (H + paddings[2] + paddings[3], W + paddings[0] + paddings[1]),
mode="bilinear",
align_corners=False,
antialias=True,
)
batch["data"]["image"] = remove_padding(
batch["data"]["image"], paddings.repeat(T, 1)
)
if preds is not None:
for key in ["depth"]:
if key in preds:
preds[key] = F.interpolate(
preds[key],
                    (H + paddings[2] + paddings[3], W + paddings[0] + paddings[1]),
mode="bilinear",
align_corners=False,
antialias=True,
)
preds[key] = remove_padding(preds[key], paddings.repeat(T, 1))
return batch, preds
def log_metrics(metrics_all, step):
for name_ds, metrics in metrics_all.items():
for metrics_name, metrics_value in metrics.items():
try:
print(f"Metrics/{name_ds}/{metrics_name} {round(metrics_value, 4)}")
wandb.log(
{f"Metrics/{name_ds}/{metrics_name}": metrics_value}, step=step
)
            except Exception:
pass
def validate(model, test_loaders, step, context):
metrics_all = {}
for name_ds, test_loader in test_loaders.items():
for i, batch in enumerate(test_loader):
with context:
batch["data"] = {
k: v.to(model.device) for k, v in batch["data"].items()
}
                # Remove the temporal dimension from the dataloader; it is always 1 here.
batch["data"] = {k: v.squeeze(1) for k, v in batch["data"].items()}
batch["img_metas"] = [
{k: v[0] for k, v in meta.items() if isinstance(v, list)}
for meta in batch["img_metas"]
]
preds = model(batch["data"], batch["img_metas"])
batch, _ = original_image(batch, preds=None)
test_loader.dataset.accumulate_metrics(
inputs=batch["data"],
preds=preds,
keyframe_idx=batch["img_metas"][0].get("keyframe_idx"),
)
barrier()
metrics_all[name_ds] = test_loader.dataset.get_evaluation()
barrier()
if is_main_process():
log_metrics(metrics_all=metrics_all, step=step)
return metrics_all
================================================
FILE: camera_pose_annotation/depth_estimation/UniDepth/unidepth/utils/visualization.py
================================================
"""
Author: Luigi Piccinelli
Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/)
"""
import os
import matplotlib.pyplot as plt
import numpy as np
import torch
import wandb
from PIL import Image
from unidepth.utils.misc import ssi_helper
def colorize(
value: np.ndarray, vmin: float = None, vmax: float = None, cmap: str = "magma_r"
):
# if already RGB, do nothing
if value.ndim > 2:
if value.shape[-1] > 1:
return value
value = value[..., 0]
invalid_mask = value < 0.0001
# normalize
vmin = value.min() if vmin is None else vmin
vmax = value.max() if vmax is None else vmax
value = (value - vmin) / (vmax - vmin) # vmin..vmax
# set color
cmapper = plt.get_cmap(cmap)
value = cmapper(value, bytes=True) # (nxmx4)
value[invalid_mask] = 0
img = value[..., :3]
return img
def image_grid(imgs: list[np.ndarray], rows: int, cols: int) -> np.ndarray:
if not len(imgs):
return None
assert len(imgs) == rows * cols
h, w = imgs[0].shape[:2]
grid = Image.new("RGB", size=(cols * w, rows * h))
for i, img in enumerate(imgs):
grid.paste(
Image.fromarray(img.astype(np.uint8)).resize(
(w, h), resample=Image.BILINEAR
),
box=(i % cols * w, i // cols * h),
)
return np.array(grid)
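# Commented usage sketch combining the two helpers above (sizes are
# illustrative): colorize two depth maps and tile them into a single image.
# depths = [np.random.rand(240, 320) for _ in range(2)]
# tiles = [colorize(d, vmin=0.0, vmax=1.0) for d in depths]   # each (240, 320, 3) uint8
# grid = image_grid(tiles, rows=1, cols=2)                    # (240, 640, 3)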
def get_pointcloud_from_rgbd(
image: np.array,
depth: np.array,
mask: np.ndarray,
intrinsic_matrix: np.array,
extrinsic_matrix: np.array = None,
):
depth = np.array(depth).squeeze()
mask = np.array(mask).squeeze()
# Mask the depth array
masked_depth = np.ma.masked_where(mask == False, depth)
# masked_depth = np.ma.masked_greater(masked_depth, 8000)
# Create idx array
idxs = np.indices(masked_depth.shape)
u_idxs = idxs[1]
v_idxs = idxs[0]
# Get only non-masked depth and idxs
z = masked_depth[~masked_depth.mask]
compressed_u_idxs = u_idxs[~masked_depth.mask]
compressed_v_idxs = v_idxs[~masked_depth.mask]
image = np.stack(
[image[..., i][~masked_depth.mask] for i in range(image.shape[-1])], axis=-1
)
# Calculate local position of each point
# Apply vectorized math to depth using compressed arrays
cx = intrinsic_matrix[0, 2]
fx = intrinsic_matrix[0, 0]
x = (compressed_u_idxs - cx) * z / fx
cy = intrinsic_matrix[1, 2]
fy = intrinsic_matrix[1, 1]
# Flip y as we want +y pointing up not down
y = -((compressed_v_idxs - cy) * z / fy)
# # Apply camera_matrix to pointcloud as to get the pointcloud in world coords
# if extrinsic_matrix is not None:
# # Calculate camera pose from extrinsic matrix
# camera_matrix = np.linalg.inv(extrinsic_matrix)
# # Create homogenous array of vectors by adding 4th entry of 1
# # At the same time flip z as for eye space the camera is looking down the -z axis
# w = np.ones(z.shape)
# x_y_z_eye_hom = np.vstack((x, y, -z, w))
# # Transform the points from eye space to world space
# x_y_z_world = np.dot(camera_matrix, x_y_z_eye_hom)[:3]
# return x_y_z_world.T
# else:
x_y_z_local = np.stack((x, y, z), axis=-1)
return np.concatenate([x_y_z_local, image], axis=-1)
def save_file_ply(xyz, rgb, pc_file):
if rgb.max() < 1.001:
rgb = rgb * 255.0
rgb = rgb.astype(np.uint8)
# print(rgb)
with open(pc_file, "w") as f:
# headers
f.writelines(
[
"ply\n" "format ascii 1.0\n",
"element vertex {}\n".format(xyz.shape[0]),
"property float x\n",
"property float y\n",
"property float z\n",
"property uchar red\n",
"property uchar green\n",
"property uchar blue\n",
"end_header\n",
]
)
for i in range(xyz.shape[0]):
str_v = "{:10.6f} {:10.6f} {:10.6f} {:d} {:d} {:d}\n".format(
xyz[i, 0], xyz[i, 1], xyz[i, 2], rgb[i, 0], rgb[i, 1], rgb[i, 2]
)
f.write(str_v)
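# Commented usage sketch tying the two helpers above together (shapes, values
# and the output filename are assumptions): back-project a masked depth map
# into a colored point cloud and write it as an ASCII PLY.
# rgb = np.random.randint(0, 255, (240, 320, 3), dtype=np.uint8)
# depth = np.linspace(0.5, 5.0, 240 * 320).reshape(240, 320)
# valid = depth < 4.0                                   # drop far points
# K = np.array([[300.0, 0.0, 160.0],
#               [0.0, 300.0, 120.0],
#               [0.0, 0.0, 1.0]])
# pc = get_pointcloud_from_rgbd(rgb, depth, valid, K)   # (num_valid, 6): xyz + rgb
# save_file_ply(pc[:, :3], pc[:, 3:], "cloud.ply")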
# really awful fct... FIXME
def log_train_artifacts(rgbs, gts, preds, ds_name, step, infos={}):
rgbs = [
(127.5 * (rgb + 1))
.clip(0, 255)
.to(torch.uint8)
.cpu()
.detach()
.permute(1, 2, 0)
.numpy()
for rgb in rgbs
]
new_gts, new_preds = [], []
if len(gts) > 0:
for i, gt in enumerate(gts):
scale, shift = ssi_helper(
gts[i][gts[i] > 0].cpu().detach(), preds[i][gts[i] > 0].cpu().detach()
)
gt = gts[i].cpu().detach().squeeze().numpy()
pred = (preds[i].cpu().detach() * scale + shift).squeeze().numpy()
vmin = gt[gt > 0].min() if (gt > 0).any() else 0.0
vmax = gt.max() if (gt > 0).any() else 0.1
new_gts.append(colorize(gt, vmin=vmin, vmax=vmax))
new_preds.append(colorize(pred, vmin=vmin, vmax=vmax))
gts, preds = new_gts, new_preds
else:
preds = [
colorize(pred.cpu().detach().squeeze().numpy(), 0.0, 80.0)
for i, pred in enumerate(preds)
]
num_additional, additionals = 0, []
for name, info in infos.items():
num_additional += 1
if info.shape[1] == 3:
additionals.extend(
[
(127.5 * (x + 1))
.clip(0, 255)
.to(torch.uint8)
.cpu()
.detach()
.permute(1, 2, 0)
.numpy()
for x in info[:4]
]
)
else:
additionals.extend(
[
colorize(x.cpu().detach().squeeze().numpy())
for i, x in enumerate(info[:4])
]
)
num_rows = 2 + int(len(gts) > 0) + num_additional
artifacts_grid = image_grid(
[*rgbs, *gts, *preds, *additionals], num_rows, len(rgbs)
)
try:
wandb.log({f"{ds_name}_training": [wandb.Image(artifacts_grid)]}, step=step)
    except Exception:
Image.fromarray(artifacts_grid).save(
os.path.join(os.environ["HOME"], "Workspace", f"art_grid{step}.png")
)
print("Logging training images failed")
================================================
FILE: camera_pose_annotation/depth_estimation/__init__.py
================================================
================================================
FILE: camera_pose_annotation/dynamic_mask/__init__.py
================================================
================================================
FILE: camera_pose_annotation/dynamic_mask/inference_batch.py
================================================
"""
Batch inference script for dynamic mask generation using SAM2.
Processes video frames to generate dynamic object masks based on motion probabilities.
"""
import os
import numpy as np
import torch
import torch.nn.functional as F
from glob import glob
import cv2
from scipy import ndimage
from scipy.sparse import csr_matrix
import argparse
import pandas as pd
import subprocess
from multiprocessing import Manager
import queue
import concurrent.futures
from tqdm import tqdm
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
def compress(dyn_masks, save_path=None):
"""Compress dynamic masks using sparse matrix representation."""
assert save_path.endswith(".npz")
# Transform to sparse matrices
sparse_matrices_list = [csr_matrix(dyn_mask) for dyn_mask in dyn_masks]
sparse_matrices = {}
for i, dyn_mask in enumerate(sparse_matrices_list):
sparse_matrices[f"f_{i}_data"] = dyn_mask.data
sparse_matrices[f"f_{i}_indices"] = dyn_mask.indices
sparse_matrices[f"f_{i}_indptr"] = dyn_mask.indptr
if i == 0:
sparse_matrices["shape"] = dyn_mask.shape
np.savez_compressed(save_path, **sparse_matrices)
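# Sketch of the inverse operation (not part of the original pipeline; the name
# `decompress` is illustrative). It rebuilds dense per-frame masks from the
# sparse key layout written by `compress` above.
def decompress(npz_path):
    """Load a dyn_masks.npz file back into a (num_frames, H, W) dense array."""
    data = np.load(npz_path)
    shape = tuple(data["shape"])
    num_frames = (len(data.files) - 1) // 3  # 3 keys per frame plus "shape"
    return np.stack(
        [
            csr_matrix(
                (data[f"f_{i}_data"], data[f"f_{i}_indices"], data[f"f_{i}_indptr"]),
                shape=shape,
            ).toarray()
            for i in range(num_frames)
        ],
        axis=0,
    )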
def segment_sky(image):
"""Segment sky regions from image using HSV color space and morphological operations."""
# Convert RGB to HSV
hsv = cv2.cvtColor(image, cv2.COLOR_RGB2HSV)
# Define range for blue color and create mask
lower_blue = np.array([0, 0, 100])
upper_blue = np.array([30, 255, 255])
mask = cv2.inRange(hsv, lower_blue, upper_blue).view(bool)
# Add luminous gray regions (likely sky)
mask |= (hsv[:, :, 1] < 10) & (hsv[:, :, 2] > 150)
mask |= (hsv[:, :, 1] < 30) & (hsv[:, :, 2] > 180)
mask |= (hsv[:, :, 1] < 50) & (hsv[:, :, 2] > 220)
# Morphological operations to clean up mask
kernel = np.ones((5, 5), np.uint8)
mask2 = ndimage.binary_opening(mask, structure=kernel)
# Keep only largest connected components
_, labels, stats, _ = cv2.connectedComponentsWithStats(
mask2.view(np.uint8), connectivity=8
)
cc_sizes = stats[1:, cv2.CC_STAT_AREA]
order = cc_sizes.argsort()[::-1] # Bigger first
i = 0
selection = []
while i < len(order) and cc_sizes[order[i]] > cc_sizes[order[0]] / 2:
selection.append(1 + order[i])
i += 1
mask3 = np.isin(labels, selection).reshape(labels.shape)
# Return as tensor
return torch.from_numpy(mask3)
def predict_mask(predictor, row, args, device):
"""Generate dynamic masks for a video using SAM2 and motion probabilities."""
dir_path = os.path.join(args.dir_path, str(row["id"]))
if not os.path.exists(dir_path):
print(f"Directory '{dir_path}' not found. Exit.")
return
img_dir = os.path.join(args.dir_path, str(row["id"]), "img")
if not os.path.exists(img_dir):
print(f"Image directory '{img_dir}' not found. Exit.")
return
rec_dir = os.path.join(dir_path, "reconstructions")
if not os.path.exists(rec_dir):
print(f"Reconstructions directory '{rec_dir}' not found. Exit.")
return
prob_file = os.path.join(rec_dir, "motion_prob.npy")
if not os.path.exists(prob_file):
print(f"Motion probability file '{prob_file}' not found. Exit.")
return
compress_file = os.path.join(rec_dir, "dyn_masks.npz")
if os.path.exists(compress_file):
return
# Load motion probabilities
motion_probs = torch.from_numpy(np.load(prob_file)).to(device)
# Load images
images_list = list(sorted(glob(os.path.join(img_dir, "*.jpg"))))
images = [cv2.imread(img_path) for img_path in images_list]
if len(images) == 0 or len(images) != len(motion_probs):
print(
f"{row['video_path']},Number of frames ({len(images)}) does not match number of motion probabilities ({len(motion_probs)}). Exit."
)
return
width, height = images[0].shape[1], images[0].shape[0]
area = width * height
masks = []
# Process each frame
for i in range(len(images)):
motion_prob = motion_probs[i].to(device)
# Segment sky to avoid false detections
sky_mask = segment_sky(images[i])
predictor.set_image(images[i])
        # Adaptive thresholding based on the motion probability distribution:
        # a per-frame threshold at 40% of the probability range selects seed
        # regions, which are then refined into masks with SAM2 point prompts.
prob_min, prob_max = motion_prob.min(), motion_prob.max()
threshold = (prob_max - prob_min) * 0.4 + prob_min
if threshold > prob_max - 0.1:
masks.append(np.zeros((height, width), dtype=np.uint8))
continue
# Create initial mask from motion probabilities
mask = (motion_prob < threshold).float()
mask = F.interpolate(
mask.unsqueeze(0).unsqueeze(0),
size=(height, width),
mode="bilinear",
align_corners=False,
).squeeze()
# Find contours and use them as SAM2 prompts
mask_np = mask.cpu().numpy().astype(np.uint8)
contours, _ = cv2.findContours(
mask_np, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
)
merged_mask = np.zeros_like(mask_np)
for c in contours:
points = []
for point in c:
points.append(point[0])
points = np.array(points)
# Sample points from contour as prompts
interval = max(1, len(points) // 3)
input_points = points[::interval].astype(np.float32)
# Skip if points are in sky region
            if sky_mask[input_points[:, 1].astype(int), input_points[:, 0].astype(int)].any():
continue
input_labels = np.ones(input_points.shape[0], dtype=np.int64)
# Use SAM2 to refine mask
mask, score, _ = predictor.predict(
point_coords=input_points,
point_labels=input_labels,
multimask_output=False,
)
# Skip if mask area is too large (likely background)
if mask[0].sum() > area * 0.3:
continue
merged_mask = np.logical_or(merged_mask, mask[0])
masks.append(merged_mask)
# Save compressed masks
masks = np.stack(masks, axis=0)
compress(masks, compress_file)
def worker(task_queue, progress_queue, args, id):
"""Worker function for parallel dynamic mask generation."""
gpu_id = id % args.gpu_num
os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id) # Bind to specific GPU
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
sam2_model = None
predictor = None
while True:
try:
index, row = task_queue.get_nowait()
except queue.Empty:
break
# Initialize SAM2 model and predictor lazily
if sam2_model is None:
sam2_model = build_sam2(
args.model_cfg, args.checkpoints_path, device=device
)
if predictor is None:
predictor = SAM2ImagePredictor(sam2_model)
predictor.reset_predictor()
predict_mask(predictor, row, args, device)
progress_queue.put(index)
def parse_args():
"""Parse command line arguments for dynamic mask generation."""
parser = argparse.ArgumentParser(description="SAM2 Image Predictor")
parser.add_argument("--csv_path", type=str, help="Path to the csv file")
parser.add_argument(
"--dir_path",
type=str,
required=True,
help="Path to the directory containing images and masks",
)
parser.add_argument(
"--num_workers", type=int, default=16, help="#workers for concurrent.futures"
)
parser.add_argument("--gpu_num", type=int, default=1, help="gpu number")
parser.add_argument(
"--checkpoints_path",
type=str,
default="checkpoints",
help="Path to the model checkpoint",
)
parser.add_argument(
"--model_cfg",
type=str,
default="configs/sam2.1/sam2.1_hiera_l.yaml",
help="Path to the model configuration file",
)
return parser.parse_args()
def main():
args = parse_args()
if not os.path.exists(args.csv_path):
print(f"Meta file '{args.csv_path}' not found. Exit.")
return
# Set SAM2 checkpoint path
args.checkpoints_path = os.path.join(
args.checkpoints_path, "SAM2/sam2.1_hiera_large.pt"
)
df = pd.read_csv(args.csv_path)
# Setup multiprocessing
manager = Manager()
task_queue = manager.Queue()
progress_queue = manager.Queue()
for index, row in df.iterrows():
task_queue.put((index, row))
# Process tasks with multiple workers
with concurrent.futures.ProcessPoolExecutor(
max_workers=args.num_workers
) as executor:
futures = []
for id in range(args.num_workers):
futures.append(executor.submit(worker, task_queue, progress_queue, args, id))
processed = 0
total_tasks = len(df)
with tqdm(total=total_tasks, desc="Processing rows") as pbar:
while processed < total_tasks:
try:
progress_queue.get(timeout=1)
processed += 1
pbar.update(1)
except queue.Empty:
if all(f.done() for f in futures) and progress_queue.empty():
break
for future in futures:
future.result()
if __name__ == "__main__":
main()
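For reference, the adaptive threshold used in `predict_mask` above can be reproduced in isolation. This is a minimal sketch with hypothetical motion probabilities; the 0.4 mixing factor and the 0.1 guard are the ones hard-coded in the script.

```python
import torch

# Hypothetical per-patch motion probabilities for a single frame.
motion_prob = torch.tensor([0.05, 0.10, 0.80, 0.90])

prob_min, prob_max = motion_prob.min(), motion_prob.max()
threshold = (prob_max - prob_min) * 0.4 + prob_min   # (0.90 - 0.05) * 0.4 + 0.05 = 0.39
if threshold > prob_max - 0.1:
    # Nearly uniform probabilities: the frame gets an empty dynamic mask.
    mask = torch.zeros_like(motion_prob)
else:
    # Values below the threshold seed the initial mask, as in the script above.
    mask = (motion_prob < threshold).float()          # tensor([1., 1., 0., 0.])
```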
================================================
FILE: camera_pose_annotation/dynamic_mask/sam2/__init__.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from hydra import initialize_config_module
from hydra.core.global_hydra import GlobalHydra
if not GlobalHydra.instance().is_initialized():
initialize_config_module("sam2", version_base="1.2")
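The Hydra initialization above is what lets `build_sam2` (in `build_sam.py` below) compose the packaged configs by name. A minimal sketch of the expected state after importing the package, assuming `sam2` is installed and not shadowed by the repository directory:

```python
# Importing the installed `sam2` package registers its Hydra config module.
import sam2
from hydra.core.global_hydra import GlobalHydra

assert GlobalHydra.instance().is_initialized()  # set by sam2/__init__.py on first import
```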
================================================
FILE: camera_pose_annotation/dynamic_mask/sam2/automatic_mask_generator.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# Adapted from https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/automatic_mask_generator.py
from typing import Any, Dict, List, Optional, Tuple
import numpy as np
import torch
from torchvision.ops.boxes import batched_nms, box_area # type: ignore
from sam2.modeling.sam2_base import SAM2Base
from sam2.sam2_image_predictor import SAM2ImagePredictor
from sam2.utils.amg import (
area_from_rle,
batch_iterator,
batched_mask_to_box,
box_xyxy_to_xywh,
build_all_layer_point_grids,
calculate_stability_score,
coco_encode_rle,
generate_crop_boxes,
is_box_near_crop_edge,
mask_to_rle_pytorch,
MaskData,
remove_small_regions,
rle_to_mask,
uncrop_boxes_xyxy,
uncrop_masks,
uncrop_points,
)
class SAM2AutomaticMaskGenerator:
def __init__(
self,
model: SAM2Base,
points_per_side: Optional[int] = 32,
points_per_batch: int = 64,
pred_iou_thresh: float = 0.8,
stability_score_thresh: float = 0.95,
stability_score_offset: float = 1.0,
mask_threshold: float = 0.0,
box_nms_thresh: float = 0.7,
crop_n_layers: int = 0,
crop_nms_thresh: float = 0.7,
crop_overlap_ratio: float = 512 / 1500,
crop_n_points_downscale_factor: int = 1,
point_grids: Optional[List[np.ndarray]] = None,
min_mask_region_area: int = 0,
output_mode: str = "binary_mask",
use_m2m: bool = False,
multimask_output: bool = True,
**kwargs,
) -> None:
"""
Using a SAM 2 model, generates masks for the entire image.
Generates a grid of point prompts over the image, then filters
low quality and duplicate masks. The default settings are chosen
for SAM 2 with a HieraL backbone.
Arguments:
model (Sam): The SAM 2 model to use for mask prediction.
points_per_side (int or None): The number of points to be sampled
along one side of the image. The total number of points is
points_per_side**2. If None, 'point_grids' must provide explicit
point sampling.
points_per_batch (int): Sets the number of points run simultaneously
by the model. Higher numbers may be faster but use more GPU memory.
pred_iou_thresh (float): A filtering threshold in [0,1], using the
model's predicted mask quality.
stability_score_thresh (float): A filtering threshold in [0,1], using
the stability of the mask under changes to the cutoff used to binarize
the model's mask predictions.
stability_score_offset (float): The amount to shift the cutoff when
            calculating the stability score.
mask_threshold (float): Threshold for binarizing the mask logits
box_nms_thresh (float): The box IoU cutoff used by non-maximal
suppression to filter duplicate masks.
crop_n_layers (int): If >0, mask prediction will be run again on
crops of the image. Sets the number of layers to run, where each
layer has 2**i_layer number of image crops.
crop_nms_thresh (float): The box IoU cutoff used by non-maximal
suppression to filter duplicate masks between different crops.
crop_overlap_ratio (float): Sets the degree to which crops overlap.
In the first crop layer, crops will overlap by this fraction of
the image length. Later layers with more crops scale down this overlap.
crop_n_points_downscale_factor (int): The number of points-per-side
sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
point_grids (list(np.ndarray) or None): A list over explicit grids
of points used for sampling, normalized to [0,1]. The nth grid in the
list is used in the nth crop layer. Exclusive with points_per_side.
min_mask_region_area (int): If >0, postprocessing will be applied
to remove disconnected regions and holes in masks with area smaller
than min_mask_region_area. Requires opencv.
output_mode (str): The form masks are returned in. Can be 'binary_mask',
'uncompressed_rle', or 'coco_rle'. 'coco_rle' requires pycocotools.
For large resolutions, 'binary_mask' may consume large amounts of
memory.
use_m2m (bool): Whether to add a one step refinement using previous mask predictions.
multimask_output (bool): Whether to output multimask at each point of the grid.
"""
assert (points_per_side is None) != (
point_grids is None
), "Exactly one of points_per_side or point_grid must be provided."
if points_per_side is not None:
self.point_grids = build_all_layer_point_grids(
points_per_side,
crop_n_layers,
crop_n_points_downscale_factor,
)
elif point_grids is not None:
self.point_grids = point_grids
else:
raise ValueError("Can't have both points_per_side and point_grid be None.")
assert output_mode in [
"binary_mask",
"uncompressed_rle",
"coco_rle",
], f"Unknown output_mode {output_mode}."
if output_mode == "coco_rle":
try:
from pycocotools import mask as mask_utils # type: ignore # noqa: F401
except ImportError as e:
print("Please install pycocotools")
raise e
self.predictor = SAM2ImagePredictor(
model,
max_hole_area=min_mask_region_area,
max_sprinkle_area=min_mask_region_area,
)
self.points_per_batch = points_per_batch
self.pred_iou_thresh = pred_iou_thresh
self.stability_score_thresh = stability_score_thresh
self.stability_score_offset = stability_score_offset
self.mask_threshold = mask_threshold
self.box_nms_thresh = box_nms_thresh
self.crop_n_layers = crop_n_layers
self.crop_nms_thresh = crop_nms_thresh
self.crop_overlap_ratio = crop_overlap_ratio
self.crop_n_points_downscale_factor = crop_n_points_downscale_factor
self.min_mask_region_area = min_mask_region_area
self.output_mode = output_mode
self.use_m2m = use_m2m
self.multimask_output = multimask_output
@classmethod
def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2AutomaticMaskGenerator":
"""
Load a pretrained model from the Hugging Face hub.
Arguments:
model_id (str): The Hugging Face repository ID.
**kwargs: Additional arguments to pass to the model constructor.
Returns:
(SAM2AutomaticMaskGenerator): The loaded model.
"""
from sam2.build_sam import build_sam2_hf
sam_model = build_sam2_hf(model_id, **kwargs)
return cls(sam_model, **kwargs)
@torch.no_grad()
def generate(self, image: np.ndarray) -> List[Dict[str, Any]]:
"""
Generates masks for the given image.
Arguments:
image (np.ndarray): The image to generate masks for, in HWC uint8 format.
Returns:
list(dict(str, any)): A list over records for masks. Each record is
a dict containing the following keys:
segmentation (dict(str, any) or np.ndarray): The mask. If
output_mode='binary_mask', is an array of shape HW. Otherwise,
is a dictionary containing the RLE.
bbox (list(float)): The box around the mask, in XYWH format.
area (int): The area in pixels of the mask.
predicted_iou (float): The model's own prediction of the mask's
quality. This is filtered by the pred_iou_thresh parameter.
point_coords (list(list(float))): The point coordinates input
to the model to generate this mask.
stability_score (float): A measure of the mask's quality. This
is filtered on using the stability_score_thresh parameter.
crop_box (list(float)): The crop of the image used to generate
the mask, given in XYWH format.
"""
# Generate masks
mask_data = self._generate_masks(image)
# Encode masks
if self.output_mode == "coco_rle":
mask_data["segmentations"] = [
coco_encode_rle(rle) for rle in mask_data["rles"]
]
elif self.output_mode == "binary_mask":
mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]]
else:
mask_data["segmentations"] = mask_data["rles"]
# Write mask records
curr_anns = []
for idx in range(len(mask_data["segmentations"])):
ann = {
"segmentation": mask_data["segmentations"][idx],
"area": area_from_rle(mask_data["rles"][idx]),
"bbox": box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(),
"predicted_iou": mask_data["iou_preds"][idx].item(),
"point_coords": [mask_data["points"][idx].tolist()],
"stability_score": mask_data["stability_score"][idx].item(),
"crop_box": box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(),
}
curr_anns.append(ann)
return curr_anns
def _generate_masks(self, image: np.ndarray) -> MaskData:
orig_size = image.shape[:2]
crop_boxes, layer_idxs = generate_crop_boxes(
orig_size, self.crop_n_layers, self.crop_overlap_ratio
)
# Iterate over image crops
data = MaskData()
for crop_box, layer_idx in zip(crop_boxes, layer_idxs):
crop_data = self._process_crop(image, crop_box, layer_idx, orig_size)
data.cat(crop_data)
# Remove duplicate masks between crops
if len(crop_boxes) > 1:
# Prefer masks from smaller crops
scores = 1 / box_area(data["crop_boxes"])
scores = scores.to(data["boxes"].device)
keep_by_nms = batched_nms(
data["boxes"].float(),
scores,
torch.zeros_like(data["boxes"][:, 0]), # categories
iou_threshold=self.crop_nms_thresh,
)
data.filter(keep_by_nms)
data.to_numpy()
return data
def _process_crop(
self,
image: np.ndarray,
crop_box: List[int],
crop_layer_idx: int,
orig_size: Tuple[int, ...],
) -> MaskData:
# Crop the image and calculate embeddings
x0, y0, x1, y1 = crop_box
cropped_im = image[y0:y1, x0:x1, :]
cropped_im_size = cropped_im.shape[:2]
self.predictor.set_image(cropped_im)
# Get points for this crop
points_scale = np.array(cropped_im_size)[None, ::-1]
points_for_image = self.point_grids[crop_layer_idx] * points_scale
# Generate masks for this crop in batches
data = MaskData()
for (points,) in batch_iterator(self.points_per_batch, points_for_image):
batch_data = self._process_batch(
points, cropped_im_size, crop_box, orig_size, normalize=True
)
data.cat(batch_data)
del batch_data
self.predictor.reset_predictor()
# Remove duplicates within this crop.
keep_by_nms = batched_nms(
data["boxes"].float(),
data["iou_preds"],
torch.zeros_like(data["boxes"][:, 0]), # categories
iou_threshold=self.box_nms_thresh,
)
data.filter(keep_by_nms)
# Return to the original image frame
data["boxes"] = uncrop_boxes_xyxy(data["boxes"], crop_box)
data["points"] = uncrop_points(data["points"], crop_box)
data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))])
return data
def _process_batch(
self,
points: np.ndarray,
im_size: Tuple[int, ...],
crop_box: List[int],
orig_size: Tuple[int, ...],
normalize=False,
) -> MaskData:
orig_h, orig_w = orig_size
# Run model on this batch
points = torch.as_tensor(
points, dtype=torch.float32, device=self.predictor.device
)
in_points = self.predictor._transforms.transform_coords(
points, normalize=normalize, orig_hw=im_size
)
in_labels = torch.ones(
in_points.shape[0], dtype=torch.int, device=in_points.device
)
masks, iou_preds, low_res_masks = self.predictor._predict(
in_points[:, None, :],
in_labels[:, None],
multimask_output=self.multimask_output,
return_logits=True,
)
# Serialize predictions and store in MaskData
data = MaskData(
masks=masks.flatten(0, 1),
iou_preds=iou_preds.flatten(0, 1),
points=points.repeat_interleave(masks.shape[1], dim=0),
low_res_masks=low_res_masks.flatten(0, 1),
)
del masks
if not self.use_m2m:
# Filter by predicted IoU
if self.pred_iou_thresh > 0.0:
keep_mask = data["iou_preds"] > self.pred_iou_thresh
data.filter(keep_mask)
# Calculate and filter by stability score
data["stability_score"] = calculate_stability_score(
data["masks"], self.mask_threshold, self.stability_score_offset
)
if self.stability_score_thresh > 0.0:
keep_mask = data["stability_score"] >= self.stability_score_thresh
data.filter(keep_mask)
else:
# One step refinement using previous mask predictions
in_points = self.predictor._transforms.transform_coords(
data["points"], normalize=normalize, orig_hw=im_size
)
labels = torch.ones(
in_points.shape[0], dtype=torch.int, device=in_points.device
)
masks, ious = self.refine_with_m2m(
in_points, labels, data["low_res_masks"], self.points_per_batch
)
data["masks"] = masks.squeeze(1)
data["iou_preds"] = ious.squeeze(1)
if self.pred_iou_thresh > 0.0:
keep_mask = data["iou_preds"] > self.pred_iou_thresh
data.filter(keep_mask)
data["stability_score"] = calculate_stability_score(
data["masks"], self.mask_threshold, self.stability_score_offset
)
if self.stability_score_thresh > 0.0:
keep_mask = data["stability_score"] >= self.stability_score_thresh
data.filter(keep_mask)
# Threshold masks and calculate boxes
data["masks"] = data["masks"] > self.mask_threshold
data["boxes"] = batched_mask_to_box(data["masks"])
# Filter boxes that touch crop boundaries
keep_mask = ~is_box_near_crop_edge(
data["boxes"], crop_box, [0, 0, orig_w, orig_h]
)
if not torch.all(keep_mask):
data.filter(keep_mask)
# Compress to RLE
data["masks"] = uncrop_masks(data["masks"], crop_box, orig_h, orig_w)
data["rles"] = mask_to_rle_pytorch(data["masks"])
del data["masks"]
return data
@staticmethod
def postprocess_small_regions(
mask_data: MaskData, min_area: int, nms_thresh: float
) -> MaskData:
"""
Removes small disconnected regions and holes in masks, then reruns
box NMS to remove any new duplicates.
Edits mask_data in place.
Requires open-cv as a dependency.
"""
if len(mask_data["rles"]) == 0:
return mask_data
# Filter small disconnected regions and holes
new_masks = []
scores = []
for rle in mask_data["rles"]:
mask = rle_to_mask(rle)
mask, changed = remove_small_regions(mask, min_area, mode="holes")
unchanged = not changed
mask, changed = remove_small_regions(mask, min_area, mode="islands")
unchanged = unchanged and not changed
new_masks.append(torch.as_tensor(mask).unsqueeze(0))
# Give score=0 to changed masks and score=1 to unchanged masks
# so NMS will prefer ones that didn't need postprocessing
scores.append(float(unchanged))
# Recalculate boxes and remove any new duplicates
masks = torch.cat(new_masks, dim=0)
boxes = batched_mask_to_box(masks)
keep_by_nms = batched_nms(
boxes.float(),
torch.as_tensor(scores),
torch.zeros_like(boxes[:, 0]), # categories
iou_threshold=nms_thresh,
)
# Only recalculate RLEs for masks that have changed
for i_mask in keep_by_nms:
if scores[i_mask] == 0.0:
mask_torch = masks[i_mask].unsqueeze(0)
mask_data["rles"][i_mask] = mask_to_rle_pytorch(mask_torch)[0]
mask_data["boxes"][i_mask] = boxes[i_mask] # update res directly
mask_data.filter(keep_by_nms)
return mask_data
def refine_with_m2m(self, points, point_labels, low_res_masks, points_per_batch):
new_masks = []
new_iou_preds = []
for cur_points, cur_point_labels, low_res_mask in batch_iterator(
points_per_batch, points, point_labels, low_res_masks
):
best_masks, best_iou_preds, _ = self.predictor._predict(
cur_points[:, None, :],
cur_point_labels[:, None],
mask_input=low_res_mask[:, None, :],
multimask_output=False,
return_logits=True,
)
new_masks.append(best_masks)
new_iou_preds.append(best_iou_preds)
masks = torch.cat(new_masks, dim=0)
return masks, torch.cat(new_iou_preds, dim=0)
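The generator above is driven entirely through `generate()`, which takes an HWC uint8 image and returns one record per mask. A minimal usage sketch, assuming the SAM 2.1 Hiera-Large config and the checkpoint layout used by the dynamic-mask script (paths and the input image are illustrative):

```python
import cv2
from sam2.build_sam import build_sam2
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator

# Build the model from a local config/checkpoint pair (adjust paths to your setup).
model = build_sam2(
    "configs/sam2.1/sam2.1_hiera_l.yaml",
    "checkpoints/SAM2/sam2.1_hiera_large.pt",
    device="cuda",
)
mask_generator = SAM2AutomaticMaskGenerator(model, points_per_side=32, pred_iou_thresh=0.8)

# HWC uint8 image, converted to RGB.
image = cv2.cvtColor(cv2.imread("frame.jpg"), cv2.COLOR_BGR2RGB)
records = mask_generator.generate(image)
for rec in records:
    print(rec["area"], rec["bbox"], rec["predicted_iou"])
```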
================================================
FILE: camera_pose_annotation/dynamic_mask/sam2/benchmark.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os
import time
import numpy as np
import torch
from tqdm import tqdm
from sam2.build_sam import build_sam2_video_predictor
# Only cuda supported
assert torch.cuda.is_available()
device = torch.device("cuda")
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
if torch.cuda.get_device_properties(0).major >= 8:
# turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
# Config and checkpoint
sam2_checkpoint = "checkpoints/sam2.1_hiera_base_plus.pt"
model_cfg = "configs/sam2.1/sam2.1_hiera_b+.yaml"
# Build video predictor with vos_optimized=True setting
predictor = build_sam2_video_predictor(
model_cfg, sam2_checkpoint, device=device, vos_optimized=True
)
# Initialize with video
video_dir = "notebooks/videos/bedroom"
# scan all the JPEG frame names in this directory
frame_names = [
p
for p in os.listdir(video_dir)
if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
]
frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
inference_state = predictor.init_state(video_path=video_dir)
# Number of runs, warmup etc
warm_up, runs = 5, 25
verbose = True
num_frames = len(frame_names)
total, count = 0, 0
torch.cuda.empty_cache()
# We will select an object with a click.
# See video_predictor_example.ipynb for more detailed explanation
ann_frame_idx, ann_obj_id = 0, 1
# Add a positive click at (x, y) = (210, 350)
# For labels, `1` means positive click
points = np.array([[210, 350]], dtype=np.float32)
labels = np.array([1], np.int32)
_, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box(
inference_state=inference_state,
frame_idx=ann_frame_idx,
obj_id=ann_obj_id,
points=points,
labels=labels,
)
# Warmup and then average FPS over several runs
with torch.autocast("cuda", torch.bfloat16):
with torch.inference_mode():
for i in tqdm(range(runs), disable=not verbose, desc="Benchmarking"):
start = time.time()
# Start tracking
for (
out_frame_idx,
out_obj_ids,
out_mask_logits,
) in predictor.propagate_in_video(inference_state):
pass
end = time.time()
total += end - start
count += 1
if i == warm_up - 1:
print("Warmup FPS: ", count * num_frames / total)
total = 0
count = 0
print("FPS: ", count * num_frames / total)
================================================
FILE: camera_pose_annotation/dynamic_mask/sam2/build_sam.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import logging
import os
import torch
from hydra import compose
from hydra.utils import instantiate
from omegaconf import OmegaConf
import sam2
# Check if the user is running Python from the parent directory of the sam2 repo
# (i.e. the directory where this repo is cloned into) -- this is not supported since
# it could shadow the sam2 package and cause issues.
if os.path.isdir(os.path.join(sam2.__path__[0], "sam2")):
    # If the user has "sam2/sam2" in their path, they are likely importing the repo itself
# as "sam2" rather than importing the "sam2" python package (i.e. "sam2/sam2" directory).
# This typically happens because the user is running Python from the parent directory
# that contains the sam2 repo they cloned.
raise RuntimeError(
"You're likely running Python from the parent directory of the sam2 repository "
"(i.e. the directory where https://github.com/facebookresearch/sam2 is cloned into). "
"This is not supported since the `sam2` Python package could be shadowed by the "
"repository name (the repository is also named `sam2` and contains the Python package "
"in `sam2/sam2`). Please run Python from another directory (e.g. from the repo dir "
"rather than its parent dir, or from your home directory) after installing SAM 2."
)
HF_MODEL_ID_TO_FILENAMES = {
"facebook/sam2-hiera-tiny": (
"configs/sam2/sam2_hiera_t.yaml",
"sam2_hiera_tiny.pt",
),
"facebook/sam2-hiera-small": (
"configs/sam2/sam2_hiera_s.yaml",
"sam2_hiera_small.pt",
),
"facebook/sam2-hiera-base-plus": (
"configs/sam2/sam2_hiera_b+.yaml",
"sam2_hiera_base_plus.pt",
),
"facebook/sam2-hiera-large": (
"configs/sam2/sam2_hiera_l.yaml",
"sam2_hiera_large.pt",
),
"facebook/sam2.1-hiera-tiny": (
"configs/sam2.1/sam2.1_hiera_t.yaml",
"sam2.1_hiera_tiny.pt",
),
"facebook/sam2.1-hiera-small": (
"configs/sam2.1/sam2.1_hiera_s.yaml",
"sam2.1_hiera_small.pt",
),
"facebook/sam2.1-hiera-base-plus": (
"configs/sam2.1/sam2.1_hiera_b+.yaml",
"sam2.1_hiera_base_plus.pt",
),
"facebook/sam2.1-hiera-large": (
"configs/sam2.1/sam2.1_hiera_l.yaml",
"sam2.1_hiera_large.pt",
),
}
def build_sam2(
config_file,
ckpt_path=None,
device="cuda",
mode="eval",
hydra_overrides_extra=[],
apply_postprocessing=True,
**kwargs,
):
if apply_postprocessing:
hydra_overrides_extra = hydra_overrides_extra.copy()
hydra_overrides_extra += [
# dynamically fall back to multi-mask if the single mask is not stable
"++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true",
"++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05",
"++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98",
]
# Read config and init model
cfg = compose(config_name=config_file, overrides=hydra_overrides_extra)
OmegaConf.resolve(cfg)
model = instantiate(cfg.model, _recursive_=True)
_load_checkpoint(model, ckpt_path)
model = model.to(device)
if mode == "eval":
model.eval()
return model
def build_sam2_video_predictor(
config_file,
ckpt_path=None,
device="cuda",
mode="eval",
hydra_overrides_extra=[],
apply_postprocessing=True,
vos_optimized=False,
**kwargs,
):
hydra_overrides = [
"++model._target_=sam2.sam2_video_predictor.SAM2VideoPredictor",
]
if vos_optimized:
hydra_overrides = [
"++model._target_=sam2.sam2_video_predictor.SAM2VideoPredictorVOS",
"++model.compile_image_encoder=True", # Let sam2_base handle this
]
if apply_postprocessing:
hydra_overrides_extra = hydra_overrides_extra.copy()
hydra_overrides_extra += [
# dynamically fall back to multi-mask if the single mask is not stable
"++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true",
"++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05",
"++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98",
            # binarize the sigmoid mask logits on interacted frames with clicks in the memory encoder so that the encoded masks are exactly as what users see from clicking
"++model.binarize_mask_from_pts_for_mem_enc=true",
# fill small holes in the low-res masks up to `fill_hole_area` (before resizing them to the original video resolution)
"++model.fill_hole_area=8",
]
hydra_overrides.extend(hydra_overrides_extra)
# Read config and init model
cfg = compose(config_name=config_file, overrides=hydra_overrides)
OmegaConf.resolve(cfg)
model = instantiate(cfg.model, _recursive_=True)
_load_checkpoint(model, ckpt_path)
model = model.to(device)
if mode == "eval":
model.eval()
return model
def _hf_download(model_id):
from huggingface_hub import hf_hub_download
config_name, checkpoint_name = HF_MODEL_ID_TO_FILENAMES[model_id]
ckpt_path = hf_hub_download(repo_id=model_id, filename=checkpoint_name)
return config_name, ckpt_path
def build_sam2_hf(model_id, **kwargs):
config_name, ckpt_path = _hf_download(model_id)
return build_sam2(config_file=config_name, ckpt_path=ckpt_path, **kwargs)
def build_sam2_video_predictor_hf(model_id, **kwargs):
config_name, ckpt_path = _hf_download(model_id)
return build_sam2_video_predictor(
config_file=config_name, ckpt_path=ckpt_path, **kwargs
)
def _load_checkpoint(model, ckpt_path):
if ckpt_path is not None:
sd = torch.load(ckpt_path, map_location="cpu", weights_only=True)["model"]
missing_keys, unexpected_keys = model.load_state_dict(sd)
if missing_keys:
logging.error(missing_keys)
raise RuntimeError()
if unexpected_keys:
logging.error(unexpected_keys)
raise RuntimeError()
        logging.info("Loaded checkpoint successfully")
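For checkpoints hosted on the Hugging Face hub, the `*_hf` helpers combine the `HF_MODEL_ID_TO_FILENAMES` lookup above with `build_sam2` / `build_sam2_video_predictor`. A minimal sketch, assuming `huggingface_hub` is installed and network access is available:

```python
from sam2.build_sam import build_sam2_hf, build_sam2_video_predictor_hf

# Downloads the checkpoint listed in HF_MODEL_ID_TO_FILENAMES and composes the matching packaged config.
image_model = build_sam2_hf("facebook/sam2.1-hiera-large", device="cuda")
video_predictor = build_sam2_video_predictor_hf("facebook/sam2.1-hiera-large", device="cuda")
```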
================================================
FILE: camera_pose_annotation/dynamic_mask/sam2/configs/sam2/sam2_hiera_b+.yaml
================================================
# @package _global_
# Model
model:
_target_: sam2.modeling.sam2_base.SAM2Base
image_encoder:
_target_: sam2.modeling.backbones.image_encoder.ImageEncoder
scalp: 1
trunk:
_target_: sam2.modeling.backbones.hieradet.Hiera
embed_dim: 112
num_heads: 2
neck:
_target_: sam2.modeling.backbones.image_encoder.FpnNeck
position_encoding:
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
num_pos_feats: 256
normalize: true
scale: null
temperature: 10000
d_model: 256
backbone_channel_list: [896, 448, 224, 112]
fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
fpn_interp_model: nearest
memory_attention:
_target_: sam2.modeling.memory_attention.MemoryAttention
d_model: 256
pos_enc_at_input: true
layer:
_target_: sam2.modeling.memory_attention.MemoryAttentionLayer
activation: relu
dim_feedforward: 2048
dropout: 0.1
pos_enc_at_attn: false
self_attention:
_target_: sam2.modeling.sam.transformer.RoPEAttention
rope_theta: 10000.0
feat_sizes: [64, 64]
embedding_dim: 256
num_heads: 1
downsample_rate: 1
dropout: 0.1
d_model: 256
pos_enc_at_cross_attn_keys: true
pos_enc_at_cross_attn_queries: false
cross_attention:
_target_: sam2.modeling.sam.transformer.RoPEAttention
rope_theta: 10000.0
feat_sizes: [64, 64]
rope_k_repeat: True
embedding_dim: 256
num_heads: 1
downsample_rate: 1
dropout: 0.1
kv_in_dim: 64
num_layers: 4
memory_encoder:
_target_: sam2.modeling.memory_encoder.MemoryEncoder
out_dim: 64
position_encoding:
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
num_pos_feats: 64
normalize: true
scale: null
temperature: 10000
mask_downsampler:
_target_: sam2.modeling.memory_encoder.MaskDownSampler
kernel_size: 3
stride: 2
padding: 1
fuser:
_target_: sam2.modeling.memory_encoder.Fuser
layer:
_target_: sam2.modeling.memory_encoder.CXBlock
dim: 256
kernel_size: 7
padding: 3
layer_scale_init_value: 1e-6
use_dwconv: True # depth-wise convs
num_layers: 2
num_maskmem: 7
image_size: 1024
# apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
sigmoid_scale_for_mem_enc: 20.0
sigmoid_bias_for_mem_enc: -10.0
use_mask_input_as_output_without_sam: true
# Memory
directly_add_no_mem_embed: true
# use high-resolution feature map in the SAM mask decoder
use_high_res_features_in_sam: true
# output 3 masks on the first click on initial conditioning frames
multimask_output_in_sam: true
# SAM heads
iou_prediction_use_sigmoid: True
# cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
use_obj_ptrs_in_encoder: true
add_tpos_enc_to_obj_ptrs: false
only_obj_ptrs_in_the_past_for_eval: true
# object occlusion prediction
pred_obj_scores: true
pred_obj_scores_mlp: true
fixed_no_obj_ptr: true
# multimask tracking settings
multimask_output_for_tracking: true
use_multimask_token_for_obj_ptr: true
multimask_min_pt_num: 0
multimask_max_pt_num: 1
use_mlp_for_obj_ptr_proj: true
# Compilation flag
compile_image_encoder: False
================================================
FILE: camera_pose_annotation/dynamic_mask/sam2/configs/sam2/sam2_hiera_l.yaml
================================================
# @package _global_
# Model
model:
_target_: sam2.modeling.sam2_base.SAM2Base
image_encoder:
_target_: sam2.modeling.backbones.image_encoder.ImageEncoder
scalp: 1
trunk:
_target_: sam2.modeling.backbones.hieradet.Hiera
embed_dim: 144
num_heads: 2
stages: [2, 6, 36, 4]
global_att_blocks: [23, 33, 43]
window_pos_embed_bkg_spatial_size: [7, 7]
window_spec: [8, 4, 16, 8]
neck:
_target_: sam2.modeling.backbones.image_encoder.FpnNeck
position_encoding:
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
num_pos_feats: 256
normalize: true
scale: null
temperature: 10000
d_model: 256
backbone_channel_list: [1152, 576, 288, 144]
fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
fpn_interp_model: nearest
memory_attention:
_target_: sam2.modeling.memory_attention.MemoryAttention
d_model: 256
pos_enc_at_input: true
layer:
_target_: sam2.modeling.memory_attention.MemoryAttentionLayer
activation: relu
dim_feedforward: 2048
dropout: 0.1
pos_enc_at_attn: false
self_attention:
_target_: sam2.modeling.sam.transformer.RoPEAttention
rope_theta: 10000.0
feat_sizes: [64, 64]
embedding_dim: 256
num_heads: 1
downsample_rate: 1
dropout: 0.1
d_model: 256
pos_enc_at_cross_attn_keys: true
pos_enc_at_cross_attn_queries: false
cross_attention:
_target_: sam2.modeling.sam.transformer.RoPEAttention
rope_theta: 10000.0
feat_sizes: [64, 64]
rope_k_repeat: True
embedding_dim: 256
num_heads: 1
downsample_rate: 1
dropout: 0.1
kv_in_dim: 64
num_layers: 4
memory_encoder:
_target_: sam2.modeling.memory_encoder.MemoryEncoder
out_dim: 64
position_encoding:
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
num_pos_feats: 64
normalize: true
scale: null
temperature: 10000
mask_downsampler:
_target_: sam2.modeling.memory_encoder.MaskDownSampler
kernel_size: 3
stride: 2
padding: 1
fuser:
_target_: sam2.modeling.memory_encoder.Fuser
layer:
_target_: sam2.modeling.memory_encoder.CXBlock
dim: 256
kernel_size: 7
padding: 3
layer_scale_init_value: 1e-6
use_dwconv: True # depth-wise convs
num_layers: 2
num_maskmem: 7
image_size: 1024
# apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
sigmoid_scale_for_mem_enc: 20.0
sigmoid_bias_for_mem_enc: -10.0
use_mask_input_as_output_without_sam: true
# Memory
directly_add_no_mem_embed: true
# use high-resolution feature map in the SAM mask decoder
use_high_res_features_in_sam: true
# output 3 masks on the first click on initial conditioning frames
multimask_output_in_sam: true
# SAM heads
iou_prediction_use_sigmoid: True
# cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
use_obj_ptrs_in_encoder: true
add_tpos_enc_to_obj_ptrs: false
only_obj_ptrs_in_the_past_for_eval: true
# object occlusion prediction
pred_obj_scores: true
pred_obj_scores_mlp: true
fixed_no_obj_ptr: true
# multimask tracking settings
multimask_output_for_tracking: true
use_multimask_token_for_obj_ptr: true
multimask_min_pt_num: 0
multimask_max_pt_num: 1
use_mlp_for_obj_ptr_proj: true
# Compilation flag
compile_image_encoder: False
================================================
FILE: camera_pose_annotation/dynamic_mask/sam2/configs/sam2/sam2_hiera_s.yaml
================================================
# @package _global_
# Model
model:
_target_: sam2.modeling.sam2_base.SAM2Base
image_encoder:
_target_: sam2.modeling.backbones.image_encoder.ImageEncoder
scalp: 1
trunk:
_target_: sam2.modeling.backbones.hieradet.Hiera
embed_dim: 96
num_heads: 1
stages: [1, 2, 11, 2]
global_att_blocks: [7, 10, 13]
window_pos_embed_bkg_spatial_size: [7, 7]
neck:
_target_: sam2.modeling.backbones.image_encoder.FpnNeck
position_encoding:
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
num_pos_feats: 256
normalize: true
scale: null
temperature: 10000
d_model: 256
backbone_channel_list: [768, 384, 192, 96]
fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
fpn_interp_model: nearest
memory_attention:
_target_: sam2.modeling.memory_attention.MemoryAttention
d_model: 256
pos_enc_at_input: true
layer:
_target_: sam2.modeling.memory_attention.MemoryAttentionLayer
activation: relu
dim_feedforward: 2048
dropout: 0.1
pos_enc_at_attn: false
self_attention:
_target_: sam2.modeling.sam.transformer.RoPEAttention
rope_theta: 10000.0
feat_sizes: [64, 64]
embedding_dim: 256
num_heads: 1
downsample_rate: 1
dropout: 0.1
d_model: 256
pos_enc_at_cross_attn_keys: true
pos_enc_at_cross_attn_queries: false
cross_attention:
_target_: sam2.modeling.sam.transformer.RoPEAttention
rope_theta: 10000.0
feat_sizes: [64, 64]
rope_k_repeat: True
embedding_dim: 256
num_heads: 1
downsample_rate: 1
dropout: 0.1
kv_in_dim: 64
num_layers: 4
memory_encoder:
_target_: sam2.modeling.memory_encoder.MemoryEncoder
out_dim: 64
position_encoding:
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
num_pos_feats: 64
normalize: true
scale: null
temperature: 10000
mask_downsampler:
_target_: sam2.modeling.memory_encoder.MaskDownSampler
kernel_size: 3
stride: 2
padding: 1
fuser:
_target_: sam2.modeling.memory_encoder.Fuser
layer:
_target_: sam2.modeling.memory_encoder.CXBlock
dim: 256
kernel_size: 7
padding: 3
layer_scale_init_value: 1e-6
use_dwconv: True # depth-wise convs
num_layers: 2
num_maskmem: 7
image_size: 1024
# apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
sigmoid_scale_for_mem_enc: 20.0
sigmoid_bias_for_mem_enc: -10.0
use_mask_input_as_output_without_sam: true
# Memory
directly_add_no_mem_embed: true
# use high-resolution feature map in the SAM mask decoder
use_high_res_features_in_sam: true
# output 3 masks on the first click on initial conditioning frames
multimask_output_in_sam: true
# SAM heads
iou_prediction_use_sigmoid: True
# cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
use_obj_ptrs_in_encoder: true
add_tpos_enc_to_obj_ptrs: false
only_obj_ptrs_in_the_past_for_eval: true
# object occlusion prediction
pred_obj_scores: true
pred_obj_scores_mlp: true
fixed_no_obj_ptr: true
# multimask tracking settings
multimask_output_for_tracking: true
use_multimask_token_for_obj_ptr: true
multimask_min_pt_num: 0
multimask_max_pt_num: 1
use_mlp_for_obj_ptr_proj: true
# Compilation flag
compile_image_encoder: False
================================================
FILE: camera_pose_annotation/dynamic_mask/sam2/configs/sam2/sam2_hiera_t.yaml
================================================
# @package _global_
# Model
model:
_target_: sam2.modeling.sam2_base.SAM2Base
image_encoder:
_target_: sam2.modeling.backbones.image_encoder.ImageEncoder
scalp: 1
trunk:
_target_: sam2.modeling.backbones.hieradet.Hiera
embed_dim: 96
num_heads: 1
stages: [1, 2, 7, 2]
global_att_blocks: [5, 7, 9]
window_pos_embed_bkg_spatial_size: [7, 7]
neck:
_target_: sam2.modeling.backbones.image_encoder.FpnNeck
position_encoding:
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
num_pos_feats: 256
normalize: true
scale: null
temperature: 10000
d_model: 256
backbone_channel_list: [768, 384, 192, 96]
fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
fpn_interp_model: nearest
memory_attention:
_target_: sam2.modeling.memory_attention.MemoryAttention
d_model: 256
pos_enc_at_input: true
layer:
_target_: sam2.modeling.memory_attention.MemoryAttentionLayer
activation: relu
dim_feedforward: 2048
dropout: 0.1
pos_enc_at_attn: false
self_attention:
_target_: sam2.modeling.sam.transformer.RoPEAttention
rope_theta: 10000.0
feat_sizes: [64, 64]
embedding_dim: 256
num_heads: 1
downsample_rate: 1
dropout: 0.1
d_model: 256
pos_enc_at_cross_attn_keys: true
pos_enc_at_cross_attn_queries: false
cross_attention:
_target_: sam2.modeling.sam.transformer.RoPEAttention
rope_theta: 10000.0
feat_sizes: [64, 64]
rope_k_repeat: True
embedding_dim: 256
num_heads: 1
downsample_rate: 1
dropout: 0.1
kv_in_dim: 64
num_layers: 4
memory_encoder:
_target_: sam2.modeling.memory_encoder.MemoryEncoder
out_dim: 64
position_encoding:
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
num_pos_feats: 64
normalize: true
scale: null
temperature: 10000
mask_downsampler:
_target_: sam2.modeling.memory_encoder.MaskDownSampler
kernel_size: 3
stride: 2
padding: 1
fuser:
_target_: sam2.modeling.memory_encoder.Fuser
layer:
_target_: sam2.modeling.memory_encoder.CXBlock
dim: 256
kernel_size: 7
padding: 3
layer_scale_init_value: 1e-6
use_dwconv: True # depth-wise convs
num_layers: 2
num_maskmem: 7
image_size: 1024
# apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
# SAM decoder
sigmoid_scale_for_mem_enc: 20.0
sigmoid_bias_for_mem_enc: -10.0
use_mask_input_as_output_without_sam: true
# Memory
directly_add_no_mem_embed: true
# use high-resolution feature map in the SAM mask decoder
use_high_res_features_in_sam: true
# output 3 masks on the first click on initial conditioning frames
multimask_output_in_sam: true
# SAM heads
iou_prediction_use_sigmoid: True
# cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
use_obj_ptrs_in_encoder: true
add_tpos_enc_to_obj_ptrs: false
only_obj_ptrs_in_the_past_for_eval: true
# object occlusion prediction
pred_obj_scores: true
pred_obj_scores_mlp: true
fixed_no_obj_ptr: true
# multimask tracking settings
multimask_output_for_tracking: true
use_multimask_token_for_obj_ptr: true
multimask_min_pt_num: 0
multimask_max_pt_num: 1
use_mlp_for_obj_ptr_proj: true
# Compilation flag
# HieraT does not currently support compilation, should always be set to False
compile_image_encoder: False
================================================
FILE: camera_pose_annotation/dynamic_mask/sam2/configs/sam2.1/sam2.1_hiera_b+.yaml
================================================
# @package _global_
# Model
model:
_target_: sam2.modeling.sam2_base.SAM2Base
image_encoder:
_target_: sam2.modeling.backbones.image_encoder.ImageEncoder
scalp: 1
trunk:
_target_: sam2.modeling.backbones.hieradet.Hiera
embed_dim: 112
num_heads: 2
neck:
_target_: sam2.modeling.backbones.image_encoder.FpnNeck
position_encoding:
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
num_pos_feats: 256
normalize: true
scale: null
temperature: 10000
d_model: 256
backbone_channel_list: [896, 448, 224, 112]
fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
fpn_interp_model: nearest
memory_attention:
_target_: sam2.modeling.memory_attention.MemoryAttention
d_model: 256
pos_enc_at_input: true
layer:
_target_: sam2.modeling.memory_attention.MemoryAttentionLayer
activation: relu
dim_feedforward: 2048
dropout: 0.1
pos_enc_at_attn: false
self_attention:
_target_: sam2.modeling.sam.transformer.RoPEAttention
rope_theta: 10000.0
feat_sizes: [64, 64]
embedding_dim: 256
num_heads: 1
downsample_rate: 1
dropout: 0.1
d_model: 256
pos_enc_at_cross_attn_keys: true
pos_enc_at_cross_attn_queries: false
cross_attention:
_target_: sam2.modeling.sam.transformer.RoPEAttention
rope_theta: 10000.0
feat_sizes: [64, 64]
rope_k_repeat: True
embedding_dim: 256
num_heads: 1
downsample_rate: 1
dropout: 0.1
kv_in_dim: 64
num_layers: 4
memory_encoder:
_target_: sam2.modeling.memory_encoder.MemoryEncoder
out_dim: 64
position_encoding:
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
num_pos_feats: 64
normalize: true
scale: null
temperature: 10000
mask_downsampler:
_target_: sam2.modeling.memory_encoder.MaskDownSampler
kernel_size: 3
stride: 2
padding: 1
fuser:
_target_: sam2.modeling.memory_encoder.Fuser
layer:
_target_: sam2.modeling.memory_encoder.CXBlock
dim: 256
kernel_size: 7
padding: 3
layer_scale_init_value: 1e-6
use_dwconv: True # depth-wise convs
num_layers: 2
num_maskmem: 7
image_size: 1024
# apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
sigmoid_scale_for_mem_enc: 20.0
sigmoid_bias_for_mem_enc: -10.0
use_mask_input_as_output_without_sam: true
# Memory
directly_add_no_mem_embed: true
no_obj_embed_spatial: true
# use high-resolution feature map in the SAM mask decoder
use_high_res_features_in_sam: true
# output 3 masks on the first click on initial conditioning frames
multimask_output_in_sam: true
# SAM heads
iou_prediction_use_sigmoid: True
# cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
use_obj_ptrs_in_encoder: true
add_tpos_enc_to_obj_ptrs: true
proj_tpos_enc_in_obj_ptrs: true
use_signed_tpos_enc_to_obj_ptrs: true
only_obj_ptrs_in_the_past_for_eval: true
# object occlusion prediction
pred_obj_scores: true
pred_obj_scores_mlp: true
fixed_no_obj_ptr: true
# multimask tracking settings
multimask_output_for_tracking: true
use_multimask_token_for_obj_ptr: true
multimask_min_pt_num: 0
multimask_max_pt_num: 1
use_mlp_for_obj_ptr_proj: true
# Compilation flag
compile_image_encoder: False
================================================
FILE: camera_pose_annotation/dynamic_mask/sam2/configs/sam2.1/sam2.1_hiera_l.yaml
================================================
# @package _global_
# Model
model:
_target_: sam2.modeling.sam2_base.SAM2Base
image_encoder:
_target_: sam2.modeling.backbones.image_encoder.ImageEncoder
scalp: 1
trunk:
_target_: sam2.modeling.backbones.hieradet.Hiera
embed_dim: 144
num_heads: 2
stages: [2, 6, 36, 4]
global_att_blocks: [23, 33, 43]
window_pos_embed_bkg_spatial_size: [7, 7]
window_spec: [8, 4, 16, 8]
neck:
_target_: sam2.modeling.backbones.image_encoder.FpnNeck
position_encoding:
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
num_pos_feats: 256
normalize: true
scale: null
temperature: 10000
d_model: 256
backbone_channel_list: [1152, 576, 288, 144]
fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
fpn_interp_model: nearest
memory_attention:
_target_: sam2.modeling.memory_attention.MemoryAttention
d_model: 256
pos_enc_at_input: true
layer:
_target_: sam2.modeling.memory_attention.MemoryAttentionLayer
activation: relu
dim_feedforward: 2048
dropout: 0.1
pos_enc_at_attn: false
self_attention:
_target_: sam2.modeling.sam.transformer.RoPEAttention
rope_theta: 10000.0
feat_sizes: [64, 64]
embedding_dim: 256
num_heads: 1
downsample_rate: 1
dropout: 0.1
d_model: 256
pos_enc_at_cross_attn_keys: true
pos_enc_at_cross_attn_queries: false
cross_attention:
_target_: sam2.modeling.sam.transformer.RoPEAttention
rope_theta: 10000.0
feat_sizes: [64, 64]
rope_k_repeat: True
embedding_dim: 256
num_heads: 1
downsample_rate: 1
dropout: 0.1
kv_in_dim: 64
num_layers: 4
memory_encoder:
_target_: sam2.modeling.memory_encoder.MemoryEncoder
out_dim: 64
position_encoding:
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
num_pos_feats: 64
normalize: true
scale: null
temperature: 10000
mask_downsampler:
_target_: sam2.modeling.memory_encoder.MaskDownSampler
kernel_size: 3
stride: 2
padding: 1
fuser:
_target_: sam2.modeling.memory_encoder.Fuser
layer:
_target_: sam2.modeling.memory_encoder.CXBlock
dim: 256
kernel_size: 7
padding: 3
layer_scale_init_value: 1e-6
use_dwconv: True # depth-wise convs
num_layers: 2
num_maskmem: 7
image_size: 1024
# apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
sigmoid_scale_for_mem_enc: 20.0
sigmoid_bias_for_mem_enc: -10.0
use_mask_input_as_output_without_sam: true
# Memory
directly_add_no_mem_embed: true
no_obj_embed_spatial: true
# use high-resolution feature map in the SAM mask decoder
use_high_res_features_in_sam: true
# output 3 masks on the first click on initial conditioning frames
multimask_output_in_sam: true
# SAM heads
iou_prediction_use_sigmoid: True
# cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
use_obj_ptrs_in_encoder: true
add_tpos_enc_to_obj_ptrs: true
proj_tpos_enc_in_obj_ptrs: true
use_signed_tpos_enc_to_obj_ptrs: true
only_obj_ptrs_in_the_past_for_eval: true
# object occlusion prediction
pred_obj_scores: true
pred_obj_scores_mlp: true
fixed_no_obj_ptr: true
# multimask tracking settings
multimask_output_for_tracking: true
use_multimask_token_for_obj_ptr: true
multimask_min_pt_num: 0
multimask_max_pt_num: 1
use_mlp_for_obj_ptr_proj: true
# Compilation flag
compile_image_encoder: False
================================================
FILE: camera_pose_annotation/dynamic_mask/sam2/configs/sam2.1/sam2.1_hiera_s.yaml
================================================
# @package _global_
# Model
model:
_target_: sam2.modeling.sam2_base.SAM2Base
image_encoder:
_target_: sam2.modeling.backbones.image_encoder.ImageEncoder
scalp: 1
trunk:
_target_: sam2.modeling.backbones.hieradet.Hiera
embed_dim: 96
num_heads: 1
stages: [1, 2, 11, 2]
global_att_blocks: [7, 10, 13]
window_pos_embed_bkg_spatial_size: [7, 7]
neck:
_target_: sam2.modeling.backbones.image_encoder.FpnNeck
position_encoding:
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
num_pos_feats: 256
normalize: true
scale: null
temperature: 10000
d_model: 256
backbone_channel_list: [768, 384, 192, 96]
fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
fpn_interp_model: nearest
memory_attention:
_target_: sam2.modeling.memory_attention.MemoryAttention
d_model: 256
pos_enc_at_input: true
layer:
_target_: sam2.modeling.memory_attention.MemoryAttentionLayer
activation: relu
dim_feedforward: 2048
dropout: 0.1
pos_enc_at_attn: false
self_attention:
_target_: sam2.modeling.sam.transformer.RoPEAttention
rope_theta: 10000.0
feat_sizes: [64, 64]
embedding_dim: 256
num_heads: 1
downsample_rate: 1
dropout: 0.1
d_model: 256
pos_enc_at_cross_attn_keys: true
pos_enc_at_cross_attn_queries: false
cross_attention:
_target_: sam2.modeling.sam.transformer.RoPEAttention
rope_theta: 10000.0
feat_sizes: [64, 64]
rope_k_repeat: True
embedding_dim: 256
num_heads: 1
downsample_rate: 1
dropout: 0.1
kv_in_dim: 64
num_layers: 4
memory_encoder:
_target_: sam2.modeling.memory_encoder.MemoryEncoder
out_dim: 64
position_encoding:
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
num_pos_feats: 64
normalize: true
scale: null
temperature: 10000
mask_downsampler:
_target_: sam2.modeling.memory_encoder.MaskDownSampler
kernel_size: 3
stride: 2
padding: 1
fuser:
_target_: sam2.modeling.memory_encoder.Fuser
layer:
_target_: sam2.modeling.memory_encoder.CXBlock
dim: 256
kernel_size: 7
padding: 3
layer_scale_init_value: 1e-6
use_dwconv: True # depth-wise convs
num_layers: 2
num_maskmem: 7
image_size: 1024
# apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
sigmoid_scale_for_mem_enc: 20.0
sigmoid_bias_for_mem_enc: -10.0
use_mask_input_as_output_without_sam: true
# Memory
directly_add_no_mem_embed: true
no_obj_embed_spatial: true
# use high-resolution feature map in the SAM mask decoder
use_high_res_features_in_sam: true
# output 3 masks on the first click on initial conditioning frames
multimask_output_in_sam: true
# SAM heads
iou_prediction_use_sigmoid: True
# cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
use_obj_ptrs_in_encoder: true
add_tpos_enc_to_obj_ptrs: true
proj_tpos_enc_in_obj_ptrs: true
use_signed_tpos_enc_to_obj_ptrs: true
only_obj_ptrs_in_the_past_for_eval: true
# object occlusion prediction
pred_obj_scores: true
pred_obj_scores_mlp: true
fixed_no_obj_ptr: true
# multimask tracking settings
multimask_output_for_tracking: true
use_multimask_token_for_obj_ptr: true
multimask_min_pt_num: 0
multimask_max_pt_num: 1
use_mlp_for_obj_ptr_proj: true
# Compilation flag
compile_image_encoder: False
================================================
FILE: camera_pose_annotation/dynamic_mask/sam2/configs/sam2.1/sam2.1_hiera_t.yaml
================================================
# @package _global_
# Model
model:
_target_: sam2.modeling.sam2_base.SAM2Base
image_encoder:
_target_: sam2.modeling.backbones.image_encoder.ImageEncoder
scalp: 1
trunk:
_target_: sam2.modeling.backbones.hieradet.Hiera
embed_dim: 96
num_heads: 1
stages: [1, 2, 7, 2]
global_att_blocks: [5, 7, 9]
window_pos_embed_bkg_spatial_size: [7, 7]
neck:
_target_: sam2.modeling.backbones.image_encoder.FpnNeck
position_encoding:
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
num_pos_feats: 256
normalize: true
scale: null
temperature: 10000
d_model: 256
backbone_channel_list: [768, 384, 192, 96]
fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
fpn_interp_model: nearest
memory_attention:
_target_: sam2.modeling.memory_attention.MemoryAttention
d_model: 256
pos_enc_at_input: true
layer:
_target_: sam2.modeling.memory_attention.MemoryAttentionLayer
activation: relu
dim_feedforward: 2048
dropout: 0.1
pos_enc_at_attn: false
self_attention:
_target_: sam2.modeling.sam.transformer.RoPEAttention
rope_theta: 10000.0
feat_sizes: [64, 64]
embedding_dim: 256
num_heads: 1
downsample_rate: 1
dropout: 0.1
d_model: 256
pos_enc_at_cross_attn_keys: true
pos_enc_at_cross_attn_queries: false
cross_attention:
_target_: sam2.modeling.sam.transformer.RoPEAttention
rope_theta: 10000.0
feat_sizes: [64, 64]
rope_k_repeat: True
embedding_dim: 256
num_heads: 1
downsample_rate: 1
dropout: 0.1
kv_in_dim: 64
num_layers: 4
memory_encoder:
_target_: sam2.modeling.memory_encoder.MemoryEncoder
out_dim: 64
position_encoding:
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
num_pos_feats: 64
normalize: true
scale: null
temperature: 10000
mask_downsampler:
_target_: sam2.modeling.memory_encoder.MaskDownSampler
kernel_size: 3
stride: 2
padding: 1
fuser:
_target_: sam2.modeling.memory_encoder.Fuser
layer:
_target_: sam2.modeling.memory_encoder.CXBlock
dim: 256
kernel_size: 7
padding: 3
layer_scale_init_value: 1e-6
use_dwconv: True # depth-wise convs
num_layers: 2
num_maskmem: 7
image_size: 1024
# apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
# SAM decoder
sigmoid_scale_for_mem_enc: 20.0
sigmoid_bias_for_mem_enc: -10.0
use_mask_input_as_output_without_sam: true
# Memory
directly_add_no_mem_embed: true
no_obj_embed_spatial: true
# use high-resolution feature map in the SAM mask decoder
use_high_res_features_in_sam: true
# output 3 masks on the first click on initial conditioning frames
multimask_output_in_sam: true
# SAM heads
iou_prediction_use_sigmoid: True
# cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
use_obj_ptrs_in_encoder: true
add_tpos_enc_to_obj_ptrs: true
proj_tpos_enc_in_obj_ptrs: true
use_signed_tpos_enc_to_obj_ptrs: true
only_obj_ptrs_in_the_past_for_eval: true
# object occlusion prediction
pred_obj_scores: true
pred_obj_scores_mlp: true
fixed_no_obj_ptr: true
# multimask tracking settings
multimask_output_for_tracking: true
use_multimask_token_for_obj_ptr: true
multimask_min_pt_num: 0
multimask_max_pt_num: 1
use_mlp_for_obj_ptr_proj: true
# Compilation flag
# HieraT does not currently support compilation, should always be set to False
compile_image_encoder: False
================================================
FILE: camera_pose_annotation/dynamic_mask/sam2/configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml
================================================
# @package _global_
scratch:
resolution: 1024
train_batch_size: 1
num_train_workers: 10
num_frames: 8
max_num_objects: 3
base_lr: 5.0e-6
vision_lr: 3.0e-06
phases_per_epoch: 1
num_epochs: 40
dataset:
# PATHS to Dataset
img_folder: null # PATH to MOSE JPEGImages folder
gt_folder: null # PATH to MOSE Annotations folder
file_list_txt: training/assets/MOSE_sample_train_list.txt # Optional PATH to filelist containing a subset of videos to be used for training
multiplier: 2
# Video transforms
vos:
train_transforms:
- _target_: training.dataset.transforms.ComposeAPI
transforms:
- _target_: training.dataset.transforms.RandomHorizontalFlip
consistent_transform: True
- _target_: training.dataset.transforms.RandomAffine
degrees: 25
shear: 20
image_interpolation: bilinear
consistent_transform: True
- _target_: training.dataset.transforms.RandomResizeAPI
sizes: ${scratch.resolution}
square: true
consistent_transform: True
- _target_: training.dataset.transforms.ColorJitter
consistent_transform: True
brightness: 0.1
contrast: 0.03
saturation: 0.03
hue: null
- _target_: training.dataset.transforms.RandomGrayscale
p: 0.05
consistent_transform: True
- _target_: training.dataset.transforms.ColorJitter
consistent_transform: False
brightness: 0.1
contrast: 0.05
saturation: 0.05
hue: null
- _target_: training.dataset.transforms.ToTensorAPI
- _target_: training.dataset.transforms.NormalizeAPI
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
trainer:
_target_: training.trainer.Trainer
mode: train_only
max_epochs: ${times:${scratch.num_epochs},${scratch.phases_per_epoch}}
accelerator: cuda
seed_value: 123
model:
_target_: training.model.sam2.SAM2Train
image_encoder:
_target_: sam2.modeling.backbones.image_encoder.ImageEncoder
scalp: 1
trunk:
_target_: sam2.modeling.backbones.hieradet.Hiera
embed_dim: 112
num_heads: 2
drop_path_rate: 0.1
neck:
_target_: sam2.modeling.backbones.image_encoder.FpnNeck
position_encoding:
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
num_pos_feats: 256
normalize: true
scale: null
temperature: 10000
d_model: 256
backbone_channel_list: [896, 448, 224, 112]
fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
fpn_interp_model: nearest
memory_attention:
_target_: sam2.modeling.memory_attention.MemoryAttention
d_model: 256
pos_enc_at_input: true
layer:
_target_: sam2.modeling.memory_attention.MemoryAttentionLayer
activation: relu
dim_feedforward: 2048
dropout: 0.1
pos_enc_at_attn: false
self_attention:
_target_: sam2.modeling.sam.transformer.RoPEAttention
rope_theta: 10000.0
feat_sizes: [64, 64]
embedding_dim: 256
num_heads: 1
downsample_rate: 1
dropout: 0.1
d_model: 256
pos_enc_at_cross_attn_keys: true
pos_enc_at_cross_attn_queries: false
cross_attention:
_target_: sam2.modeling.sam.transformer.RoPEAttention
rope_theta: 10000.0
feat_sizes: [64, 64]
rope_k_repeat: True
embedding_dim: 256
num_heads: 1
downsample_rate: 1
dropout: 0.1
kv_in_dim: 64
num_layers: 4
memory_encoder:
_target_: sam2.modeling.memory_encoder.MemoryEncoder
out_dim: 64
position_encoding:
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
num_pos_feats: 64
normalize: true
scale: null
temperature: 10000
mask_downsampler:
_target_: sam2.modeling.memory_encoder.MaskDownSampler
kernel_size: 3
stride: 2
padding: 1
fuser:
_target_: sam2.modeling.memory_encoder.Fuser
layer:
_target_: sam2.modeling.memory_encoder.CXBlock
dim: 256
kernel_size: 7
padding: 3
layer_scale_init_value: 1e-6
use_dwconv: True # depth-wise convs
num_layers: 2
num_maskmem: 7
image_size: ${scratch.resolution}
# apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
sigmoid_scale_for_mem_enc: 20.0
sigmoid_bias_for_mem_enc: -10.0
use_mask_input_as_output_without_sam: true
# Memory
directly_add_no_mem_embed: true
no_obj_embed_spatial: true
# use high-resolution feature map in the SAM mask decoder
use_high_res_features_in_sam: true
# output 3 masks on the first click on initial conditioning frames
multimask_output_in_sam: true
# SAM heads
iou_prediction_use_sigmoid: True
# cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
use_obj_ptrs_in_encoder: true
add_tpos_enc_to_obj_ptrs: true
proj_tpos_enc_in_obj_ptrs: true
use_signed_tpos_enc_to_obj_ptrs: true
only_obj_ptrs_in_the_past_for_eval: true
# object occlusion prediction
pred_obj_scores: true
pred_obj_scores_mlp: true
fixed_no_obj_ptr: true
# multimask tracking settings
multimask_output_for_tracking: true
use_multimask_token_for_obj_ptr: true
multimask_min_pt_num: 0
multimask_max_pt_num: 1
use_mlp_for_obj_ptr_proj: true
# Compilation flag
# compile_image_encoder: False
####### Training specific params #######
# box/point input and corrections
prob_to_use_pt_input_for_train: 0.5
prob_to_use_pt_input_for_eval: 0.0
prob_to_use_box_input_for_train: 0.5 # 0.5*0.5 = 0.25 prob to use box instead of points
prob_to_use_box_input_for_eval: 0.0
prob_to_sample_from_gt_for_train: 0.1 # with a small prob, sampling correction points from GT mask instead of prediction errors
num_frames_to_correct_for_train: 2 # iteratively sample on random 1~2 frames (always include the first frame)
num_frames_to_correct_for_eval: 1 # only iteratively sample on first frame
rand_frames_to_correct_for_train: True # random #init-cond-frame ~ 2
add_all_frames_to_correct_as_cond: True # when a frame receives a correction click, it becomes a conditioning frame (even if it's not initially a conditioning frame)
# maximum 2 initial conditioning frames
num_init_cond_frames_for_train: 2
rand_init_cond_frames_for_train: True # random 1~2
num_correction_pt_per_frame: 7
use_act_ckpt_iterative_pt_sampling: false
num_init_cond_frames_for_eval: 1 # only mask on the first frame
forward_backbone_per_frame_for_eval: True
data:
train:
_target_: training.dataset.sam2_datasets.TorchTrainMixedDataset
phases_per_epoch: ${scratch.phases_per_epoch}
batch_sizes:
- ${scratch.train_batch_size}
datasets:
- _target_: training.dataset.utils.RepeatFactorWrapper
dataset:
_target_: training.dataset.utils.ConcatDataset
datasets:
- _target_: training.dataset.vos_dataset.VOSDataset
transforms: ${vos.train_transforms}
training: true
video_dataset:
_target_: training.dataset.vos_raw_dataset.PNGRawDataset
img_folder: ${dataset.img_folder}
gt_folder: ${dataset.gt_folder}
file_list_txt: ${dataset.file_list_txt}
sampler:
_target_: training.dataset.vos_sampler.RandomUniformSampler
num_frames: ${scratch.num_frames}
max_num_objects: ${scratch.max_num_objects}
multiplier: ${dataset.multiplier}
shuffle: True
num_workers: ${scratch.num_train_workers}
pin_memory: True
drop_last: True
collate_fn:
_target_: training.utils.data_utils.collate_fn
_partial_: true
dict_key: all
optim:
amp:
enabled: True
amp_dtype: bfloat16
optimizer:
_target_: torch.optim.AdamW
gradient_clip:
_target_: training.optimizer.GradientClipper
max_norm: 0.1
norm_type: 2
param_group_modifiers:
- _target_: training.optimizer.layer_decay_param_modifier
_partial_: True
layer_decay_value: 0.9
apply_to: 'image_encoder.trunk'
overrides:
- pattern: '*pos_embed*'
value: 1.0
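    # Two cosine LR schedules: the default (base_lr) applies to all parameters and a
    # second one (vision_lr) to parameters matching 'image_encoder.*'; both decay to
    # one tenth of their start value.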
options:
lr:
- scheduler:
_target_: fvcore.common.param_scheduler.CosineParamScheduler
start_value: ${scratch.base_lr}
end_value: ${divide:${scratch.base_lr},10}
- scheduler:
_target_: fvcore.common.param_scheduler.CosineParamScheduler
start_value: ${scratch.vision_lr}
end_value: ${divide:${scratch.vision_lr},10}
param_names:
- 'image_encoder.*'
weight_decay:
- scheduler:
_target_: fvcore.common.param_scheduler.ConstantParamScheduler
value: 0.1
- scheduler:
_target_: fvcore.common.param_scheduler.ConstantParamScheduler
value: 0.0
param_names:
- '*bias*'
module_cls_names: ['torch.nn.LayerNorm']
loss:
all:
_target_: training.loss_fns.MultiStepMultiMasksAndIous
weight_dict:
loss_mask: 20
loss_dice: 1
loss_iou: 1
loss_class: 1
supervise_all_iou: true
iou_use_l1_loss: true
pred_obj_scores: true
focal_gamma_obj_score: 0.0
focal_alpha_obj_score: -1.0
distributed:
backend: nccl
find_unused_parameters: True
logging:
tensorboard_writer:
_target_: training.utils.logger.make_tensorboard_logger
log_dir: ${launcher.experiment_log_dir}/tensorboard
flush_secs: 120
should_log: True
log_dir: ${launcher.experiment_log_dir}/logs
log_freq: 10
# initialize from a SAM 2 checkpoint
checkpoint:
save_dir: ${launcher.experiment_log_dir}/checkpoints
    save_freq: 0 # 0 means that only the last checkpoint is saved.
model_weight_initializer:
_partial_: True
_target_: training.utils.checkpoint_utils.load_state_dict_into_model
strict: True
ignore_unexpected_keys: null
ignore_missing_keys: null
state_dict:
_target_: training.utils.checkpoint_utils.load_checkpoint_and_apply_kernels
checkpoint_path: ./checkpoints/sam2.1_hiera_base_plus.pt # PATH to SAM 2.1 checkpoint
ckpt_state_dict_keys: ['model']
launcher:
num_nodes: 1
gpus_per_node: 8
experiment_log_dir: null # Path to log directory, defaults to ./sam2_logs/${config_name}
# SLURM args if running on a cluster
submitit:
partition: null
account: null
qos: null
cpus_per_task: 10
use_cluster: false
timeout_hour: 24
name: null
port_range: [10000, 65000]
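# A typical launch with this config (assuming the standard SAM 2 training entry point
# and its flags) would look like:
#   python training/train.py -c <path to this config> --use-cluster 0 --num-gpus 8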
================================================
FILE: camera_pose_annotation/dynamic_mask/sam2/csrc/connected_components.cu
================================================
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.
// adapted from https://github.com/zsef123/Connected_components_PyTorch
// with license found in the LICENSE_cctorch file in the root directory.
#include <ATen/cuda/CUDAContext.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <torch/extension.h>
#include <torch/script.h>
#include <vector>
// 2d
#define BLOCK_ROWS 16
#define BLOCK_COLS 16
namespace cc2d {
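// Union-find (disjoint-set) connected-components labeling with 8-connectivity,
// computed over 2x2 pixel blocks: labels are initialized per block, neighboring
// foreground blocks are merged, the trees are path-compressed, and per-pixel
// labels and component sizes are written out.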
template <typename T>
__device__ __forceinline__ unsigned char hasBit(T bitmap, unsigned char pos) {
return (bitmap >> pos) & 1;
}
__device__ int32_t find(const int32_t* s_buf, int32_t n) {
while (s_buf[n] != n)
n = s_buf[n];
return n;
}
__device__ int32_t find_n_compress(int32_t* s_buf, int32_t n) {
const int32_t id = n;
while (s_buf[n] != n) {
n = s_buf[n];
s_buf[id] = n;
}
return n;
}
__device__ void union_(int32_t* s_buf, int32_t a, int32_t b) {
bool done;
do {
a = find(s_buf, a);
b = find(s_buf, b);
if (a < b) {
int32_t old = atomicMin(s_buf + b, a);
done = (old == b);
b = old;
} else if (b < a) {
int32_t old = atomicMin(s_buf + a, b);
done = (old == a);
a = old;
} else
done = true;
} while (!done);
}
__global__ void
init_labeling(int32_t* label, const uint32_t W, const uint32_t H) {
const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y) * 2;
const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x) * 2;
const uint32_t idx = row * W + col;
if (row < H && col < W)
label[idx] = idx;
}
__global__ void
merge(uint8_t* img, int32_t* label, const uint32_t W, const uint32_t H) {
const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y) * 2;
const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x) * 2;
const uint32_t idx = row * W + col;
if (row >= H || col >= W)
return;
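  // P is a bitmask over this 2x2 block's neighborhood: bits are set for neighbor
  // positions that must be checked, then cleared below for positions outside the image.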
uint32_t P = 0;
if (img[idx])
P |= 0x777;
if (row + 1 < H && img[idx + W])
P |= 0x777 << 4;
if (col + 1 < W && img[idx + 1])
P |= 0x777 << 1;
if (col == 0)
P &= 0xEEEE;
if (col + 1 >= W)
P &= 0x3333;
else if (col + 2 >= W)
P &= 0x7777;
if (row == 0)
P &= 0xFFF0;
if (row + 1 >= H)
P &= 0xFF;
if (P > 0) {
    // If the first bit of P is set, check the top-left pixel; if it is
    // foreground, union this block with the top-left block.
if (hasBit(P, 0) && img[idx - W - 1]) {
union_(label, idx, idx - 2 * W - 2); // top left block
}
if ((hasBit(P, 1) && img[idx - W]) || (hasBit(P, 2) && img[idx - W + 1]))
union_(label, idx, idx - 2 * W); // top bottom block
if (hasBit(P, 3) && img[idx + 2 - W])
union_(label, idx, idx - 2 * W + 2); // top right block
if ((hasBit(P, 4) && img[idx - 1]) || (hasBit(P, 8) && img[idx + W - 1]))
union_(label, idx, idx - 2); // just left block
}
}
__global__ void compression(int32_t* label, const int32_t W, const int32_t H) {
const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y) * 2;
const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x) * 2;
const uint32_t idx = row * W + col;
if (row < H && col < W)
find_n_compress(label, idx);
}
__global__ void final_labeling(
const uint8_t* img,
int32_t* label,
const int32_t W,
const int32_t H) {
const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y) * 2;
const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x) * 2;
const uint32_t idx = row * W + col;
if (row >= H || col >= W)
return;
int32_t y = label[idx] + 1;
if (img[idx])
label[idx] = y;
else
label[idx] = 0;
if (col + 1 < W) {
if (img[idx + 1])
label[idx + 1] = y;
else
label[idx + 1] = 0;
if (row + 1 < H) {
if (img[idx + W + 1])
label[idx + W + 1] = y;
else
label[idx + W + 1] = 0;
}
}
if (row + 1 < H) {
if (img[idx + W])
label[idx + W] = y;
else
label[idx + W] = 0;
}
}
__global__ void init_counting(
const int32_t* label,
int32_t* count_init,
const int32_t W,
const int32_t H) {
const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y);
const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x);
const uint32_t idx = row * W + col;
if (row >= H || col >= W)
return;
int32_t y = label[idx];
if (y > 0) {
int32_t count_idx = y - 1;
atomicAdd(count_init + count_idx, 1);
}
}
__global__ void final_counting(
const int32_t* label,
const int32_t* count_init,
int32_t* count_final,
const int32_t W,
const int32_t H) {
const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y);
const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x);
const uint32_t idx = row * W + col;
if (row >= H || col >= W)
return;
int32_t y = label[idx];
if (y > 0) {
int32_t count_idx = y - 1;
count_final[idx] = count_init[count_idx];
} else {
count_final[idx] = 0;
}
}
} // namespace cc2d
std::vector<torch::Tensor> get_connected_componnets(
const torch::Tensor& inputs) {
AT_ASSERTM(inputs.is_cuda(), "inputs must be a CUDA tensor");
AT_ASSERTM(inputs.ndimension() == 4, "inputs must be [N, 1, H, W] shape");
AT_ASSERTM(
inputs.scalar_type() == torch::kUInt8, "inputs must be a uint8 type");
const uint32_t N = inputs.size(0);
const uint32_t C = inputs.size(1);
const uint32_t H = inputs.size(2);
const uint32_t W = inputs.size(3);
AT_ASSERTM(C == 1, "inputs must be [N, 1, H, W] shape");
AT_ASSERTM((H % 2) == 0, "height must be an even number");
AT_ASSERTM((W % 2) == 0, "width must be an even number");
// label must be uint32_t
auto label_options =
torch::TensorOptions().dtype(torch::kInt32).device(inputs.device());
torch::Tensor labels = torch::zeros({N, C, H, W}, label_options);
torch::Tensor counts_init = torch::zeros({N, C, H, W}, label_options);
torch::Tensor counts_final = torch::zeros({N, C, H, W}, label_options);
dim3 grid = dim3(
((W + 1) / 2 + BLOCK_COLS - 1) / BLOCK_COLS,
((H + 1) / 2 + BLOCK_ROWS - 1) / BLOCK_ROWS);
dim3 block = dim3(BLOCK_COLS, BLOCK_ROWS);
dim3 grid_count =
dim3((W + BLOCK_COLS) / BLOCK_COLS, (H + BLOCK_ROWS) / BLOCK_ROWS);
dim3 block_count = dim3(BLOCK_COLS, BLOCK_ROWS);
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
for (int n = 0; n < N; n++) {
uint32_t offset = n * H * W;
    cc2d::init_labeling<<<grid, block, 0, stream>>>(
        labels.data_ptr<int32_t>() + offset, W, H);
    cc2d::merge<<<grid, block, 0, stream>>>(
        inputs.data_ptr<uint8_t>() + offset,
        labels.data_ptr<int32_t>() + offset,
        W,
        H);
    cc2d::compression<<<grid, block, 0, stream>>>(
        labels.data_ptr<int32_t>() + offset, W, H);
    cc2d::final_labeling<<<grid, block, 0, stream>>>(
        inputs.data_ptr<uint8_t>() + offset,
        labels.data_ptr<int32_t>() + offset,
        W,
        H);
    // get the counting of each pixel
    cc2d::init_counting<<<grid_count, block_count, 0, stream>>>(
        labels.data_ptr<int32_t>() + offset,
        counts_init.data_ptr<int32_t>() + offset,
        W,
        H);
    cc2d::final_counting<<<grid_count, block_count, 0, stream>>>(
        labels.data_ptr<int32_t>() + offset,
        counts_init.data_ptr<int32_t>() + offset,
        counts_final.data_ptr<int32_t>() + offset,
        W,
        H);
}
// returned values are [labels, counts]
std::vector outputs;
outputs.push_back(labels);
outputs.push_back(counts_final);
return outputs;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def(
"get_connected_componnets",
&get_connected_componnets,
"get_connected_componnets");
}
================================================
FILE: camera_pose_annotation/dynamic_mask/sam2/modeling/__init__.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
================================================
FILE: camera_pose_annotation/dynamic_mask/sam2/modeling/backbones/__init__.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
================================================
FILE: camera_pose_annotation/dynamic_mask/sam2/modeling/backbones/hieradet.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import logging
from functools import partial
from typing import List, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from iopath.common.file_io import g_pathmgr
from sam2.modeling.backbones.utils import (
PatchEmbed,
window_partition,
window_unpartition,
)
from sam2.modeling.sam2_utils import DropPath, MLP
def do_pool(x: torch.Tensor, pool: nn.Module, norm: nn.Module = None) -> torch.Tensor:
if pool is None:
return x
# (B, H, W, C) -> (B, C, H, W)
x = x.permute(0, 3, 1, 2)
x = pool(x)
# (B, C, H', W') -> (B, H', W', C)
x = x.permute(0, 2, 3, 1)
if norm:
x = norm(x)
return x
class MultiScaleAttention(nn.Module):
def __init__(
self,
dim: int,
dim_out: int,
num_heads: int,
q_pool: nn.Module = None,
):
super().__init__()
self.dim = dim
self.dim_out = dim_out
self.num_heads = num_heads
self.q_pool = q_pool
self.qkv = nn.Linear(dim, dim_out * 3)
self.proj = nn.Linear(dim_out, dim_out)
def forward(self, x: torch.Tensor) -> torch.Tensor:
B, H, W, _ = x.shape
# qkv with shape (B, H * W, 3, nHead, C)
qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1)
# q, k, v with shape (B, H * W, nheads, C)
q, k, v = torch.unbind(qkv, 2)
# Q pooling (for downsample at stage changes)
if self.q_pool:
q = do_pool(q.reshape(B, H, W, -1), self.q_pool)
H, W = q.shape[1:3] # downsampled shape
q = q.reshape(B, H * W, self.num_heads, -1)
# Torch's SDPA expects [B, nheads, H*W, C] so we transpose
x = F.scaled_dot_product_attention(
q.transpose(1, 2),
k.transpose(1, 2),
v.transpose(1, 2),
)
# Transpose back
x = x.transpose(1, 2)
x = x.reshape(B, H, W, -1)
x = self.proj(x)
return x
class MultiScaleBlock(nn.Module):
def __init__(
self,
dim: int,
dim_out: int,
num_heads: int,
mlp_ratio: float = 4.0,
drop_path: float = 0.0,
norm_layer: Union[nn.Module, str] = "LayerNorm",
q_stride: Tuple[int, int] = None,
act_layer: nn.Module = nn.GELU,
window_size: int = 0,
):
super().__init__()
if isinstance(norm_layer, str):
norm_layer = partial(getattr(nn, norm_layer), eps=1e-6)
self.dim = dim
self.dim_out = dim_out
self.norm1 = norm_layer(dim)
self.window_size = window_size
self.pool, self.q_stride = None, q_stride
if self.q_stride:
self.pool = nn.MaxPool2d(
kernel_size=q_stride, stride=q_stride, ceil_mode=False
)
self.attn = MultiScaleAttention(
dim,
dim_out,
num_heads=num_heads,
q_pool=self.pool,
)
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.norm2 = norm_layer(dim_out)
self.mlp = MLP(
dim_out,
int(dim_out * mlp_ratio),
dim_out,
num_layers=2,
activation=act_layer,
)
if dim != dim_out:
self.proj = nn.Linear(dim, dim_out)
def forward(self, x: torch.Tensor) -> torch.Tensor:
shortcut = x # B, H, W, C
x = self.norm1(x)
# Skip connection
if self.dim != self.dim_out:
shortcut = do_pool(self.proj(x), self.pool)
# Window partition
window_size = self.window_size
if window_size > 0:
H, W = x.shape[1], x.shape[2]
x, pad_hw = window_partition(x, window_size)
# Window Attention + Q Pooling (if stage change)
x = self.attn(x)
if self.q_stride:
# Shapes have changed due to Q pooling
window_size = self.window_size // self.q_stride[0]
H, W = shortcut.shape[1:3]
pad_h = (window_size - H % window_size) % window_size
pad_w = (window_size - W % window_size) % window_size
pad_hw = (H + pad_h, W + pad_w)
# Reverse window partition
if self.window_size > 0:
x = window_unpartition(x, window_size, pad_hw, (H, W))
x = shortcut + self.drop_path(x)
# MLP
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
class Hiera(nn.Module):
"""
Reference: https://arxiv.org/abs/2306.00989
"""
def __init__(
self,
embed_dim: int = 96, # initial embed dim
num_heads: int = 1, # initial number of heads
drop_path_rate: float = 0.0, # stochastic depth
q_pool: int = 3, # number of q_pool stages
q_stride: Tuple[int, int] = (2, 2), # downsample stride bet. stages
stages: Tuple[int, ...] = (2, 3, 16, 3), # blocks per stage
dim_mul: float = 2.0, # dim_mul factor at stage shift
head_mul: float = 2.0, # head_mul factor at stage shift
window_pos_embed_bkg_spatial_size: Tuple[int, int] = (14, 14),
# window size per stage, when not using global att.
window_spec: Tuple[int, ...] = (
8,
4,
14,
7,
),
# global attn in these blocks
global_att_blocks: Tuple[int, ...] = (
12,
16,
20,
),
weights_path=None,
return_interm_layers=True, # return feats from every stage
):
super().__init__()
assert len(stages) == len(window_spec)
self.window_spec = window_spec
depth = sum(stages)
self.q_stride = q_stride
self.stage_ends = [sum(stages[:i]) - 1 for i in range(1, len(stages) + 1)]
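        # stage_ends[i] is the index of the last block in stage i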
assert 0 <= q_pool <= len(self.stage_ends[:-1])
self.q_pool_blocks = [x + 1 for x in self.stage_ends[:-1]][:q_pool]
self.return_interm_layers = return_interm_layers
self.patch_embed = PatchEmbed(
embed_dim=embed_dim,
)
# Which blocks have global att?
self.global_att_blocks = global_att_blocks
# Windowed positional embedding (https://arxiv.org/abs/2311.05613)
self.window_pos_embed_bkg_spatial_size = window_pos_embed_bkg_spatial_size
self.pos_embed = nn.Parameter(
torch.zeros(1, embed_dim, *self.window_pos_embed_bkg_spatial_size)
)
self.pos_embed_window = nn.Parameter(
torch.zeros(1, embed_dim, self.window_spec[0], self.window_spec[0])
)
dpr = [
x.item() for x in torch.linspace(0, drop_path_rate, depth)
] # stochastic depth decay rule
cur_stage = 1
self.blocks = nn.ModuleList()
for i in range(depth):
dim_out = embed_dim
# lags by a block, so first block of
# next stage uses an initial window size
# of previous stage and final window size of current stage
window_size = self.window_spec[cur_stage - 1]
if self.global_att_blocks is not None:
window_size = 0 if i in self.global_att_blocks else window_size
if i - 1 in self.stage_ends:
dim_out = int(embed_dim * dim_mul)
num_heads = int(num_heads * head_mul)
cur_stage += 1
block = MultiScaleBlock(
dim=embed_dim,
dim_out=dim_out,
num_heads=num_heads,
drop_path=dpr[i],
q_stride=self.q_stride if i in self.q_pool_blocks else None,
window_size=window_size,
)
embed_dim = dim_out
self.blocks.append(block)
self.channel_list = (
[self.blocks[i].dim_out for i in self.stage_ends[::-1]]
if return_interm_layers
else [self.blocks[-1].dim_out]
)
if weights_path is not None:
with g_pathmgr.open(weights_path, "rb") as f:
chkpt = torch.load(f, map_location="cpu")
logging.info("loading Hiera", self.load_state_dict(chkpt, strict=False))
def _get_pos_embed(self, hw: Tuple[int, int]) -> torch.Tensor:
h, w = hw
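        # Resize the global background embedding to the feature-map size, then add
        # the small window embedding tiled across it.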
window_embed = self.pos_embed_window
pos_embed = F.interpolate(self.pos_embed, size=(h, w), mode="bicubic")
pos_embed = pos_embed + window_embed.tile(
[x // y for x, y in zip(pos_embed.shape, window_embed.shape)]
)
pos_embed = pos_embed.permute(0, 2, 3, 1)
return pos_embed
def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
x = self.patch_embed(x)
# x: (B, H, W, C)
# Add pos embed
x = x + self._get_pos_embed(x.shape[1:3])
outputs = []
for i, blk in enumerate(self.blocks):
x = blk(x)
if (i == self.stage_ends[-1]) or (
i in self.stage_ends and self.return_interm_layers
):
feats = x.permute(0, 3, 1, 2)
outputs.append(feats)
return outputs
def get_layer_id(self, layer_name):
# https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33
num_layers = self.get_num_layers()
if layer_name.find("rel_pos") != -1:
return num_layers + 1
elif layer_name.find("pos_embed") != -1:
return 0
elif layer_name.find("patch_embed") != -1:
return 0
elif layer_name.find("blocks") != -1:
return int(layer_name.split("blocks")[1].split(".")[1]) + 1
else:
return num_layers + 1
def get_num_layers(self) -> int:
return len(self.blocks)
================================================
FILE: camera_pose_annotation/dynamic_mask/sam2/modeling/backbones/image_encoder.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from typing import List, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
class ImageEncoder(nn.Module):
def __init__(
self,
trunk: nn.Module,
neck: nn.Module,
scalp: int = 0,
):
super().__init__()
self.trunk = trunk
self.neck = neck
self.scalp = scalp
assert (
self.trunk.channel_list == self.neck.backbone_channel_list
), f"Channel dims of trunk and neck do not match. Trunk: {self.trunk.channel_list}, neck: {self.neck.backbone_channel_list}"
def forward(self, sample: torch.Tensor):
# Forward through backbone
features, pos = self.neck(self.trunk(sample))
if self.scalp > 0:
# Discard the lowest resolution features
features, pos = features[: -self.scalp], pos[: -self.scalp]
src = features[-1]
output = {
"vision_features": src,
"vision_pos_enc": pos,
"backbone_fpn": features,
}
return output
class FpnNeck(nn.Module):
"""
A modified variant of Feature Pyramid Network (FPN) neck
(we remove output conv and also do bicubic interpolation similar to ViT
pos embed interpolation)
"""
def __init__(
self,
position_encoding: nn.Module,
d_model: int,
backbone_channel_list: List[int],
kernel_size: int = 1,
stride: int = 1,
padding: int = 0,
fpn_interp_model: str = "bilinear",
fuse_type: str = "sum",
fpn_top_down_levels: Optional[List[int]] = None,
):
"""Initialize the neck
:param trunk: the backbone
:param position_encoding: the positional encoding to use
:param d_model: the dimension of the model
:param neck_norm: the normalization to use
"""
super().__init__()
self.position_encoding = position_encoding
self.convs = nn.ModuleList()
self.backbone_channel_list = backbone_channel_list
self.d_model = d_model
for dim in backbone_channel_list:
current = nn.Sequential()
current.add_module(
"conv",
nn.Conv2d(
in_channels=dim,
out_channels=d_model,
kernel_size=kernel_size,
stride=stride,
padding=padding,
),
)
self.convs.append(current)
self.fpn_interp_model = fpn_interp_model
assert fuse_type in ["sum", "avg"]
self.fuse_type = fuse_type
# levels to have top-down features in its outputs
# e.g. if fpn_top_down_levels is [2, 3], then only outputs of level 2 and 3
# have top-down propagation, while outputs of level 0 and level 1 have only
# lateral features from the same backbone level.
if fpn_top_down_levels is None:
# default is to have top-down features on all levels
fpn_top_down_levels = range(len(self.convs))
self.fpn_top_down_levels = list(fpn_top_down_levels)
def forward(self, xs: List[torch.Tensor]):
out = [None] * len(self.convs)
pos = [None] * len(self.convs)
assert len(xs) == len(self.convs)
# fpn forward pass
# see https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/fpn.py
prev_features = None
# forward in top-down order (from low to high resolution)
n = len(self.convs) - 1
for i in range(n, -1, -1):
x = xs[i]
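            # xs is ordered from high to low resolution, while self.convs follows
            # backbone_channel_list (lowest-resolution stage first), hence the
            # reversed index n - i.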
lateral_features = self.convs[n - i](x)
if i in self.fpn_top_down_levels and prev_features is not None:
top_down_features = F.interpolate(
prev_features.to(dtype=torch.float32),
scale_factor=2.0,
mode=self.fpn_interp_model,
align_corners=(
None if self.fpn_interp_model == "nearest" else False
),
antialias=False,
)
prev_features = lateral_features + top_down_features
if self.fuse_type == "avg":
prev_features /= 2
else:
prev_features = lateral_features
x_out = prev_features
out[i] = x_out
pos[i] = self.position_encoding(x_out).to(x_out.dtype)
return out, pos
================================================
FILE: camera_pose_annotation/dynamic_mask/sam2/modeling/backbones/utils.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Some utilities for backbones, in particular for windowing"""
from typing import Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
def window_partition(x, window_size):
"""
Partition into non-overlapping windows with padding if needed.
Args:
x (tensor): input tokens with [B, H, W, C].
window_size (int): window size.
Returns:
windows: windows after partition with [B * num_windows, window_size, window_size, C].
(Hp, Wp): padded height and width before partition
"""
B, H, W, C = x.shape
pad_h = (window_size - H % window_size) % window_size
pad_w = (window_size - W % window_size) % window_size
if pad_h > 0 or pad_w > 0:
x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
Hp, Wp = H + pad_h, W + pad_w
x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
windows = x.permute(0, 1, 3, 2, 4, 5).reshape(-1, window_size, window_size, C)
return windows, (Hp, Wp)
def window_unpartition(windows, window_size, pad_hw, hw):
"""
    Window unpartition back into original sequences, removing padding.
    Args:
        windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
window_size (int): window size.
pad_hw (Tuple): padded height and width (Hp, Wp).
hw (Tuple): original height and width (H, W) before padding.
Returns:
x: unpartitioned sequences with [B, H, W, C].
"""
Hp, Wp = pad_hw
H, W = hw
B = windows.shape[0] // (Hp * Wp // window_size // window_size)
x = windows.reshape(
B, Hp // window_size, Wp // window_size, window_size, window_size, -1
)
x = x.permute(0, 1, 3, 2, 4, 5).reshape(B, Hp, Wp, -1)
if Hp > H or Wp > W:
x = x[:, :H, :W, :]
return x
class PatchEmbed(nn.Module):
"""
Image to Patch Embedding.
"""
def __init__(
self,
kernel_size: Tuple[int, ...] = (7, 7),
stride: Tuple[int, ...] = (4, 4),
padding: Tuple[int, ...] = (3, 3),
in_chans: int = 3,
embed_dim: int = 768,
):
"""
Args:
kernel_size (Tuple): kernel size of the projection layer.
stride (Tuple): stride of the projection layer.
padding (Tuple): padding size of the projection layer.
in_chans (int): Number of input image channels.
            embed_dim (int): Patch embedding dimension.
"""
super().__init__()
self.proj = nn.Conv2d(
in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.proj(x)
# B C H W -> B H W C
x = x.permute(0, 2, 3, 1)
return x
================================================
FILE: camera_pose_annotation/dynamic_mask/sam2/modeling/memory_attention.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from typing import Optional
import torch
from torch import nn, Tensor
from sam2.modeling.sam.transformer import RoPEAttention
from sam2.modeling.sam2_utils import get_activation_fn, get_clones
class MemoryAttentionLayer(nn.Module):
def __init__(
self,
activation: str,
cross_attention: nn.Module,
d_model: int,
dim_feedforward: int,
dropout: float,
pos_enc_at_attn: bool,
pos_enc_at_cross_attn_keys: bool,
pos_enc_at_cross_attn_queries: bool,
self_attention: nn.Module,
):
super().__init__()
self.d_model = d_model
self.dim_feedforward = dim_feedforward
self.dropout_value = dropout
self.self_attn = self_attention
self.cross_attn_image = cross_attention
# Implementation of Feedforward model
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.dropout = nn.Dropout(dropout)
self.linear2 = nn.Linear(dim_feedforward, d_model)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.norm3 = nn.LayerNorm(d_model)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
self.dropout3 = nn.Dropout(dropout)
self.activation_str = activation
self.activation = get_activation_fn(activation)
# Where to add pos enc
self.pos_enc_at_attn = pos_enc_at_attn
self.pos_enc_at_cross_attn_queries = pos_enc_at_cross_attn_queries
self.pos_enc_at_cross_attn_keys = pos_enc_at_cross_attn_keys
def _forward_sa(self, tgt, query_pos):
# Self-Attention
tgt2 = self.norm1(tgt)
q = k = tgt2 + query_pos if self.pos_enc_at_attn else tgt2
tgt2 = self.self_attn(q, k, v=tgt2)
tgt = tgt + self.dropout1(tgt2)
return tgt
def _forward_ca(self, tgt, memory, query_pos, pos, num_k_exclude_rope=0):
kwds = {}
if num_k_exclude_rope > 0:
assert isinstance(self.cross_attn_image, RoPEAttention)
kwds = {"num_k_exclude_rope": num_k_exclude_rope}
# Cross-Attention
tgt2 = self.norm2(tgt)
tgt2 = self.cross_attn_image(
q=tgt2 + query_pos if self.pos_enc_at_cross_attn_queries else tgt2,
k=memory + pos if self.pos_enc_at_cross_attn_keys else memory,
v=memory,
**kwds,
)
tgt = tgt + self.dropout2(tgt2)
return tgt
def forward(
self,
tgt,
memory,
pos: Optional[Tensor] = None,
query_pos: Optional[Tensor] = None,
num_k_exclude_rope: int = 0,
) -> torch.Tensor:
# Self-Attn, Cross-Attn
tgt = self._forward_sa(tgt, query_pos)
tgt = self._forward_ca(tgt, memory, query_pos, pos, num_k_exclude_rope)
# MLP
tgt2 = self.norm3(tgt)
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
tgt = tgt + self.dropout3(tgt2)
return tgt
class MemoryAttention(nn.Module):
def __init__(
self,
d_model: int,
pos_enc_at_input: bool,
layer: nn.Module,
num_layers: int,
batch_first: bool = True, # Do layers expect batch first input?
):
super().__init__()
self.d_model = d_model
self.layers = get_clones(layer, num_layers)
self.num_layers = num_layers
self.norm = nn.LayerNorm(d_model)
self.pos_enc_at_input = pos_enc_at_input
self.batch_first = batch_first
def forward(
self,
curr: torch.Tensor, # self-attention inputs
memory: torch.Tensor, # cross-attention inputs
curr_pos: Optional[Tensor] = None, # pos_enc for self-attention inputs
memory_pos: Optional[Tensor] = None, # pos_enc for cross-attention inputs
num_obj_ptr_tokens: int = 0, # number of object pointer *tokens*
):
if isinstance(curr, list):
assert isinstance(curr_pos, list)
assert len(curr) == len(curr_pos) == 1
curr, curr_pos = (
curr[0],
curr_pos[0],
)
assert (
curr.shape[1] == memory.shape[1]
), "Batch size must be the same for curr and memory"
output = curr
if self.pos_enc_at_input and curr_pos is not None:
output = output + 0.1 * curr_pos
if self.batch_first:
# Convert to batch first
output = output.transpose(0, 1)
curr_pos = curr_pos.transpose(0, 1)
memory = memory.transpose(0, 1)
memory_pos = memory_pos.transpose(0, 1)
for layer in self.layers:
kwds = {}
if isinstance(layer.cross_attn_image, RoPEAttention):
kwds = {"num_k_exclude_rope": num_obj_ptr_tokens}
output = layer(
tgt=output,
memory=memory,
pos=memory_pos,
query_pos=curr_pos,
**kwds,
)
normed_output = self.norm(output)
if self.batch_first:
# Convert back to seq first
normed_output = normed_output.transpose(0, 1)
curr_pos = curr_pos.transpose(0, 1)
return normed_output
================================================
FILE: camera_pose_annotation/dynamic_mask/sam2/modeling/memory_encoder.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import math
from typing import Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from sam2.modeling.sam2_utils import DropPath, get_clones, LayerNorm2d
class MaskDownSampler(nn.Module):
"""
Progressively downsample a mask by total_stride, each time by stride.
Note that LayerNorm is applied per *token*, like in ViT.
With each downsample (by a factor stride**2), channel capacity increases by the same factor.
In the end, we linearly project to embed_dim channels.
"""
def __init__(
self,
embed_dim=256,
kernel_size=4,
stride=4,
padding=0,
total_stride=16,
activation=nn.GELU,
):
super().__init__()
num_layers = int(math.log2(total_stride) // math.log2(stride))
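        # e.g. total_stride=16 with stride=4 gives 2 conv stages, each downsampling by 4x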
assert stride**num_layers == total_stride
self.encoder = nn.Sequential()
mask_in_chans, mask_out_chans = 1, 1
for _ in range(num_layers):
mask_out_chans = mask_in_chans * (stride**2)
self.encoder.append(
nn.Conv2d(
mask_in_chans,
mask_out_chans,
kernel_size=kernel_size,
stride=stride,
padding=padding,
)
)
self.encoder.append(LayerNorm2d(mask_out_chans))
self.encoder.append(activation())
mask_in_chans = mask_out_chans
self.encoder.append(nn.Conv2d(mask_out_chans, embed_dim, kernel_size=1))
def forward(self, x):
return self.encoder(x)
# Lightly adapted from ConvNext (https://github.com/facebookresearch/ConvNeXt)
class CXBlock(nn.Module):
r"""ConvNeXt Block. There are two equivalent implementations:
(1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
(2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
We use (2) as we find it slightly faster in PyTorch
Args:
dim (int): Number of input channels.
drop_path (float): Stochastic depth rate. Default: 0.0
layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
"""
def __init__(
self,
dim,
kernel_size=7,
padding=3,
drop_path=0.0,
layer_scale_init_value=1e-6,
use_dwconv=True,
):
super().__init__()
self.dwconv = nn.Conv2d(
dim,
dim,
kernel_size=kernel_size,
padding=padding,
groups=dim if use_dwconv else 1,
) # depthwise conv
self.norm = LayerNorm2d(dim, eps=1e-6)
self.pwconv1 = nn.Linear(
dim, 4 * dim
) # pointwise/1x1 convs, implemented with linear layers
self.act = nn.GELU()
self.pwconv2 = nn.Linear(4 * dim, dim)
self.gamma = (
nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
if layer_scale_init_value > 0
else None
)
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
def forward(self, x):
input = x
x = self.dwconv(x)
x = self.norm(x)
x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
x = self.pwconv1(x)
x = self.act(x)
x = self.pwconv2(x)
if self.gamma is not None:
x = self.gamma * x
x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
x = input + self.drop_path(x)
return x
class Fuser(nn.Module):
def __init__(self, layer, num_layers, dim=None, input_projection=False):
super().__init__()
self.proj = nn.Identity()
self.layers = get_clones(layer, num_layers)
if input_projection:
assert dim is not None
self.proj = nn.Conv2d(dim, dim, kernel_size=1)
def forward(self, x):
# normally x: (N, C, H, W)
x = self.proj(x)
for layer in self.layers:
x = layer(x)
return x
class MemoryEncoder(nn.Module):
def __init__(
self,
out_dim,
mask_downsampler,
fuser,
position_encoding,
in_dim=256, # in_dim of pix_feats
):
super().__init__()
self.mask_downsampler = mask_downsampler
self.pix_feat_proj = nn.Conv2d(in_dim, in_dim, kernel_size=1)
self.fuser = fuser
self.position_encoding = position_encoding
self.out_proj = nn.Identity()
if out_dim != in_dim:
self.out_proj = nn.Conv2d(in_dim, out_dim, kernel_size=1)
def forward(
self,
pix_feat: torch.Tensor,
masks: torch.Tensor,
skip_mask_sigmoid: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
## Process masks
        # apply sigmoid so mask values lie in [0, 1], reducing domain shift from GT masks (which are boolean)
if not skip_mask_sigmoid:
masks = F.sigmoid(masks)
masks = self.mask_downsampler(masks)
## Fuse pix_feats and downsampled masks
# in case the visual features are on CPU, cast them to CUDA
pix_feat = pix_feat.to(masks.device)
x = self.pix_feat_proj(pix_feat)
x = x + masks
x = self.fuser(x)
x = self.out_proj(x)
pos = self.position_encoding(x).to(x.dtype)
return {"vision_features": x, "vision_pos_enc": [pos]}
================================================
FILE: camera_pose_annotation/dynamic_mask/sam2/modeling/position_encoding.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import math
from typing import Any, Optional, Tuple
import numpy as np
import torch
from torch import nn
class PositionEmbeddingSine(nn.Module):
"""
This is a more standard version of the position embedding, very similar to the one
used by the Attention Is All You Need paper, generalized to work on images.
"""
def __init__(
self,
num_pos_feats,
temperature: int = 10000,
normalize: bool = True,
scale: Optional[float] = None,
        # The following settings are only relevant
        # for warming up the cache for compilation
warmup_cache: bool = True,
image_size: int = 1024,
strides: Tuple[int] = (4, 8, 16, 32),
):
super().__init__()
assert num_pos_feats % 2 == 0, "Expecting even model width"
self.num_pos_feats = num_pos_feats // 2
self.temperature = temperature
self.normalize = normalize
if scale is not None and normalize is False:
raise ValueError("normalize should be True if scale is passed")
if scale is None:
scale = 2 * math.pi
self.scale = scale
self.cache = {}
if warmup_cache and torch.cuda.is_available():
# Warmup cache for cuda, to help with compilation
device = torch.device("cuda")
for stride in strides:
cache_key = (image_size // stride, image_size // stride)
self._pe(1, device, *cache_key)
def _encode_xy(self, x, y):
# The positions are expected to be normalized
assert len(x) == len(y) and x.ndim == y.ndim == 1
x_embed = x * self.scale
y_embed = y * self.scale
dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
pos_x = x_embed[:, None] / dim_t
pos_y = y_embed[:, None] / dim_t
pos_x = torch.stack(
(pos_x[:, 0::2].sin(), pos_x[:, 1::2].cos()), dim=2
).flatten(1)
pos_y = torch.stack(
(pos_y[:, 0::2].sin(), pos_y[:, 1::2].cos()), dim=2
).flatten(1)
return pos_x, pos_y
@torch.no_grad()
def encode_boxes(self, x, y, w, h):
pos_x, pos_y = self._encode_xy(x, y)
pos = torch.cat((pos_y, pos_x, h[:, None], w[:, None]), dim=1)
return pos
encode = encode_boxes # Backwards compatibility
@torch.no_grad()
def encode_points(self, x, y, labels):
(bx, nx), (by, ny), (bl, nl) = x.shape, y.shape, labels.shape
assert bx == by and nx == ny and bx == bl and nx == nl
pos_x, pos_y = self._encode_xy(x.flatten(), y.flatten())
pos_x, pos_y = pos_x.reshape(bx, nx, -1), pos_y.reshape(by, ny, -1)
pos = torch.cat((pos_y, pos_x, labels[:, :, None]), dim=2)
return pos
@torch.no_grad()
def _pe(self, B, device, *cache_key):
H, W = cache_key
if cache_key in self.cache:
return self.cache[cache_key].to(device)[None].repeat(B, 1, 1, 1)
y_embed = (
torch.arange(1, H + 1, dtype=torch.float32, device=device)
.view(1, -1, 1)
.repeat(B, 1, W)
)
x_embed = (
torch.arange(1, W + 1, dtype=torch.float32, device=device)
.view(1, 1, -1)
.repeat(B, H, 1)
)
if self.normalize:
eps = 1e-6
y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=device)
dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
pos_x = x_embed[:, :, :, None] / dim_t
pos_y = y_embed[:, :, :, None] / dim_t
pos_x = torch.stack(
(pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
).flatten(3)
pos_y = torch.stack(
(pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
).flatten(3)
pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
self.cache[cache_key] = pos[0]
return pos
@torch.no_grad()
def forward(self, x: torch.Tensor):
B = x.shape[0]
cache_key = (x.shape[-2], x.shape[-1])
return self._pe(B, x.device, *cache_key)
class PositionEmbeddingRandom(nn.Module):
"""
Positional encoding using random spatial frequencies.
"""
def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
super().__init__()
if scale is None or scale <= 0.0:
scale = 1.0
self.register_buffer(
"positional_encoding_gaussian_matrix",
scale * torch.randn((2, num_pos_feats)),
)
def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
"""Positionally encode points that are normalized to [0,1]."""
# assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
coords = 2 * coords - 1
coords = coords @ self.positional_encoding_gaussian_matrix
coords = 2 * np.pi * coords
# outputs d_1 x ... x d_n x C shape
return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)
def forward(self, size: Tuple[int, int]) -> torch.Tensor:
"""Generate positional encoding for a grid of the specified size."""
h, w = size
device: Any = self.positional_encoding_gaussian_matrix.device
grid = torch.ones((h, w), device=device, dtype=torch.float32)
y_embed = grid.cumsum(dim=0) - 0.5
x_embed = grid.cumsum(dim=1) - 0.5
y_embed = y_embed / h
x_embed = x_embed / w
pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1))
return pe.permute(2, 0, 1) # C x H x W
def forward_with_coords(
self, coords_input: torch.Tensor, image_size: Tuple[int, int]
) -> torch.Tensor:
"""Positionally encode points that are not normalized to [0,1]."""
coords = coords_input.clone()
coords[:, :, 0] = coords[:, :, 0] / image_size[1]
coords[:, :, 1] = coords[:, :, 1] / image_size[0]
return self._pe_encoding(coords.to(torch.float)) # B x N x C
# Rotary Positional Encoding, adapted from:
# 1. https://github.com/meta-llama/codellama/blob/main/llama/model.py
# 2. https://github.com/naver-ai/rope-vit
# 3. https://github.com/lucidrains/rotary-embedding-torch
def init_t_xy(end_x: int, end_y: int):
t = torch.arange(end_x * end_y, dtype=torch.float32)
t_x = (t % end_x).float()
t_y = torch.div(t, end_x, rounding_mode="floor").float()
return t_x, t_y
def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0):
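    # Axial 2D RoPE: the first dim//4 complex channels rotate with the x coordinate
    # and the next dim//4 with the y coordinate (dim//2 complex pairs = dim real channels).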
freqs_x = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
freqs_y = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
t_x, t_y = init_t_xy(end_x, end_y)
freqs_x = torch.outer(t_x, freqs_x)
freqs_y = torch.outer(t_y, freqs_y)
freqs_cis_x = torch.polar(torch.ones_like(freqs_x), freqs_x)
freqs_cis_y = torch.polar(torch.ones_like(freqs_y), freqs_y)
return torch.cat([freqs_cis_x, freqs_cis_y], dim=-1)
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
ndim = x.ndim
assert 0 <= 1 < ndim
assert freqs_cis.shape == (x.shape[-2], x.shape[-1])
shape = [d if i >= ndim - 2 else 1 for i, d in enumerate(x.shape)]
return freqs_cis.view(*shape)
def apply_rotary_enc(
xq: torch.Tensor,
xk: torch.Tensor,
freqs_cis: torch.Tensor,
repeat_freqs_k: bool = False,
):
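    # Treat consecutive channel pairs of q/k as complex numbers and rotate them by the
    # position-dependent phases in freqs_cis.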
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
xk_ = (
torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
if xk.shape[-2] != 0
else None
)
freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
if xk_ is None:
# no keys to rotate, due to dropout
return xq_out.type_as(xq).to(xq.device), xk
# repeat freqs along seq_len dim to match k seq_len
if repeat_freqs_k:
r = xk_.shape[-2] // xq_.shape[-2]
if freqs_cis.is_cuda:
freqs_cis = freqs_cis.repeat(*([1] * (freqs_cis.ndim - 2)), r, 1)
else:
# torch.repeat on complex numbers may not be supported on non-CUDA devices
# (freqs_cis has 4 dims and we repeat on dim 2) so we use expand + flatten
freqs_cis = freqs_cis.unsqueeze(2).expand(-1, -1, r, -1, -1).flatten(2, 3)
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
return xq_out.type_as(xq).to(xq.device), xk_out.type_as(xk).to(xk.device)
================================================
FILE: camera_pose_annotation/dynamic_mask/sam2/modeling/sam/__init__.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
================================================
FILE: camera_pose_annotation/dynamic_mask/sam2/modeling/sam/mask_decoder.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from typing import List, Optional, Tuple, Type
import torch
from torch import nn
from sam2.modeling.sam2_utils import LayerNorm2d, MLP
class MaskDecoder(nn.Module):
def __init__(
self,
*,
transformer_dim: int,
transformer: nn.Module,
num_multimask_outputs: int = 3,
activation: Type[nn.Module] = nn.GELU,
iou_head_depth: int = 3,
iou_head_hidden_dim: int = 256,
use_high_res_features: bool = False,
iou_prediction_use_sigmoid=False,
dynamic_multimask_via_stability=False,
dynamic_multimask_stability_delta=0.05,
dynamic_multimask_stability_thresh=0.98,
pred_obj_scores: bool = False,
pred_obj_scores_mlp: bool = False,
use_multimask_token_for_obj_ptr: bool = False,
) -> None:
"""
Predicts masks given an image and prompt embeddings, using a
transformer architecture.
Arguments:
transformer_dim (int): the channel dimension of the transformer
transformer (nn.Module): the transformer used to predict masks
num_multimask_outputs (int): the number of masks to predict
when disambiguating masks
activation (nn.Module): the type of activation to use when
upscaling masks
iou_head_depth (int): the depth of the MLP used to predict
mask quality
iou_head_hidden_dim (int): the hidden dimension of the MLP
used to predict mask quality
"""
super().__init__()
self.transformer_dim = transformer_dim
self.transformer = transformer
self.num_multimask_outputs = num_multimask_outputs
self.iou_token = nn.Embedding(1, transformer_dim)
self.num_mask_tokens = num_multimask_outputs + 1
self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)
self.pred_obj_scores = pred_obj_scores
if self.pred_obj_scores:
self.obj_score_token = nn.Embedding(1, transformer_dim)
self.use_multimask_token_for_obj_ptr = use_multimask_token_for_obj_ptr
self.output_upscaling = nn.Sequential(
nn.ConvTranspose2d(
transformer_dim, transformer_dim // 4, kernel_size=2, stride=2
),
LayerNorm2d(transformer_dim // 4),
activation(),
nn.ConvTranspose2d(
transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2
),
activation(),
)
self.use_high_res_features = use_high_res_features
if use_high_res_features:
self.conv_s0 = nn.Conv2d(
transformer_dim, transformer_dim // 8, kernel_size=1, stride=1
)
self.conv_s1 = nn.Conv2d(
transformer_dim, transformer_dim // 4, kernel_size=1, stride=1
)
self.output_hypernetworks_mlps = nn.ModuleList(
[
MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3)
for i in range(self.num_mask_tokens)
]
)
self.iou_prediction_head = MLP(
transformer_dim,
iou_head_hidden_dim,
self.num_mask_tokens,
iou_head_depth,
sigmoid_output=iou_prediction_use_sigmoid,
)
if self.pred_obj_scores:
self.pred_obj_score_head = nn.Linear(transformer_dim, 1)
if pred_obj_scores_mlp:
self.pred_obj_score_head = MLP(transformer_dim, transformer_dim, 1, 3)
# When outputting a single mask, optionally we can dynamically fall back to the best
# multimask output token if the single mask output token gives low stability scores.
self.dynamic_multimask_via_stability = dynamic_multimask_via_stability
self.dynamic_multimask_stability_delta = dynamic_multimask_stability_delta
self.dynamic_multimask_stability_thresh = dynamic_multimask_stability_thresh
def forward(
self,
image_embeddings: torch.Tensor,
image_pe: torch.Tensor,
sparse_prompt_embeddings: torch.Tensor,
dense_prompt_embeddings: torch.Tensor,
multimask_output: bool,
repeat_image: bool,
high_res_features: Optional[List[torch.Tensor]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Predict masks given image and prompt embeddings.
Arguments:
image_embeddings (torch.Tensor): the embeddings from the image encoder
image_pe (torch.Tensor): positional encoding with the shape of image_embeddings
sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes
dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs
multimask_output (bool): Whether to return multiple masks or a single
mask.
Returns:
torch.Tensor: batched predicted masks
torch.Tensor: batched predictions of mask quality
torch.Tensor: batched SAM token for mask output
"""
masks, iou_pred, mask_tokens_out, object_score_logits = self.predict_masks(
image_embeddings=image_embeddings,
image_pe=image_pe,
sparse_prompt_embeddings=sparse_prompt_embeddings,
dense_prompt_embeddings=dense_prompt_embeddings,
repeat_image=repeat_image,
high_res_features=high_res_features,
)
# Select the correct mask or masks for output
if multimask_output:
masks = masks[:, 1:, :, :]
iou_pred = iou_pred[:, 1:]
elif self.dynamic_multimask_via_stability and not self.training:
masks, iou_pred = self._dynamic_multimask_via_stability(masks, iou_pred)
else:
masks = masks[:, 0:1, :, :]
iou_pred = iou_pred[:, 0:1]
if multimask_output and self.use_multimask_token_for_obj_ptr:
sam_tokens_out = mask_tokens_out[:, 1:] # [b, 3, c] shape
else:
# Take the mask output token. Here we *always* use the token for single mask output.
# At test time, even if we track after 1-click (and using multimask_output=True),
# we still take the single mask token here. The rationale is that we always track
# after multiple clicks during training, so the past tokens seen during training
# are always the single mask token (and we'll let it be the object-memory token).
sam_tokens_out = mask_tokens_out[:, 0:1] # [b, 1, c] shape
# Prepare output
return masks, iou_pred, sam_tokens_out, object_score_logits
def predict_masks(
self,
image_embeddings: torch.Tensor,
image_pe: torch.Tensor,
sparse_prompt_embeddings: torch.Tensor,
dense_prompt_embeddings: torch.Tensor,
repeat_image: bool,
high_res_features: Optional[List[torch.Tensor]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Predicts masks. See 'forward' for more details."""
# Concatenate output tokens
s = 0
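        # s is the offset of the IoU token in the transformer output: 1 when an
        # object-score token is prepended below, otherwise 0.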
if self.pred_obj_scores:
output_tokens = torch.cat(
[
self.obj_score_token.weight,
self.iou_token.weight,
self.mask_tokens.weight,
],
dim=0,
)
s = 1
else:
output_tokens = torch.cat(
[self.iou_token.weight, self.mask_tokens.weight], dim=0
)
output_tokens = output_tokens.unsqueeze(0).expand(
sparse_prompt_embeddings.size(0), -1, -1
)
tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)
# Expand per-image data in batch direction to be per-mask
if repeat_image:
src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
else:
assert image_embeddings.shape[0] == tokens.shape[0]
src = image_embeddings
src = src + dense_prompt_embeddings
assert (
image_pe.size(0) == 1
), "image_pe should have size 1 in batch dim (from `get_dense_pe()`)"
pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
b, c, h, w = src.shape
# Run the transformer
hs, src = self.transformer(src, pos_src, tokens)
iou_token_out = hs[:, s, :]
mask_tokens_out = hs[:, s + 1 : (s + 1 + self.num_mask_tokens), :]
# Upscale mask embeddings and predict masks using the mask tokens
src = src.transpose(1, 2).view(b, c, h, w)
if not self.use_high_res_features:
upscaled_embedding = self.output_upscaling(src)
else:
dc1, ln1, act1, dc2, act2 = self.output_upscaling
feat_s0, feat_s1 = high_res_features
upscaled_embedding = act1(ln1(dc1(src) + feat_s1))
upscaled_embedding = act2(dc2(upscaled_embedding) + feat_s0)
hyper_in_list: List[torch.Tensor] = []
for i in range(self.num_mask_tokens):
hyper_in_list.append(
self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :])
)
hyper_in = torch.stack(hyper_in_list, dim=1)
b, c, h, w = upscaled_embedding.shape
masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)
# Generate mask quality predictions
iou_pred = self.iou_prediction_head(iou_token_out)
if self.pred_obj_scores:
assert s == 1
object_score_logits = self.pred_obj_score_head(hs[:, 0, :])
else:
# Obj scores logits - default to 10.0, i.e. assuming the object is present, sigmoid(10)=1
object_score_logits = 10.0 * iou_pred.new_ones(iou_pred.shape[0], 1)
return masks, iou_pred, mask_tokens_out, object_score_logits
def _get_stability_scores(self, mask_logits):
"""
Compute stability scores of the mask logits based on the IoU between upper and
lower thresholds.
"""
mask_logits = mask_logits.flatten(-2)
stability_delta = self.dynamic_multimask_stability_delta
area_i = torch.sum(mask_logits > stability_delta, dim=-1).float()
area_u = torch.sum(mask_logits > -stability_delta, dim=-1).float()
stability_scores = torch.where(area_u > 0, area_i / area_u, 1.0)
return stability_scores
def _dynamic_multimask_via_stability(self, all_mask_logits, all_iou_scores):
"""
When outputting a single mask, if the stability score from the current single-mask
output (based on output token 0) falls below a threshold, we instead select from
multi-mask outputs (based on output token 1~3) the mask with the highest predicted
IoU score. This is intended to ensure a valid mask for both clicking and tracking.
"""
# The best mask from multimask output tokens (1~3)
multimask_logits = all_mask_logits[:, 1:, :, :]
multimask_iou_scores = all_iou_scores[:, 1:]
best_scores_inds = torch.argmax(multimask_iou_scores, dim=-1)
batch_inds = torch.arange(
multimask_iou_scores.size(0), device=all_iou_scores.device
)
best_multimask_logits = multimask_logits[batch_inds, best_scores_inds]
best_multimask_logits = best_multimask_logits.unsqueeze(1)
best_multimask_iou_scores = multimask_iou_scores[batch_inds, best_scores_inds]
best_multimask_iou_scores = best_multimask_iou_scores.unsqueeze(1)
# The mask from singlemask output token 0 and its stability score
singlemask_logits = all_mask_logits[:, 0:1, :, :]
singlemask_iou_scores = all_iou_scores[:, 0:1]
stability_scores = self._get_stability_scores(singlemask_logits)
is_stable = stability_scores >= self.dynamic_multimask_stability_thresh
# Dynamically fall back to best multimask output upon low stability scores.
mask_logits_out = torch.where(
is_stable[..., None, None].expand_as(singlemask_logits),
singlemask_logits,
best_multimask_logits,
)
iou_scores_out = torch.where(
is_stable.expand_as(singlemask_iou_scores),
singlemask_iou_scores,
best_multimask_iou_scores,
)
return mask_logits_out, iou_scores_out
================================================
FILE: camera_pose_annotation/dynamic_mask/sam2/modeling/sam/prompt_encoder.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from typing import Optional, Tuple, Type
import torch
from torch import nn
from sam2.modeling.position_encoding import PositionEmbeddingRandom
from sam2.modeling.sam2_utils import LayerNorm2d
class PromptEncoder(nn.Module):
def __init__(
self,
embed_dim: int,
image_embedding_size: Tuple[int, int],
input_image_size: Tuple[int, int],
mask_in_chans: int,
activation: Type[nn.Module] = nn.GELU,
) -> None:
"""
Encodes prompts for input to SAM's mask decoder.
Arguments:
embed_dim (int): The prompts' embedding dimension
image_embedding_size (tuple(int, int)): The spatial size of the
image embedding, as (H, W).
            input_image_size (tuple(int, int)): The padded size of the image as input
to the image encoder, as (H, W).
mask_in_chans (int): The number of hidden channels used for
encoding input masks.
activation (nn.Module): The activation to use when encoding
input masks.
"""
super().__init__()
self.embed_dim = embed_dim
self.input_image_size = input_image_size
self.image_embedding_size = image_embedding_size
self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)
self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners
point_embeddings = [
nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)
]
self.point_embeddings = nn.ModuleList(point_embeddings)
self.not_a_point_embed = nn.Embedding(1, embed_dim)
self.mask_input_size = (
4 * image_embedding_size[0],
4 * image_embedding_size[1],
)
self.mask_downscaling = nn.Sequential(
nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2),
LayerNorm2d(mask_in_chans // 4),
activation(),
nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2),
LayerNorm2d(mask_in_chans),
activation(),
nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1),
)
self.no_mask_embed = nn.Embedding(1, embed_dim)
def get_dense_pe(self) -> torch.Tensor:
"""
Returns the positional encoding used to encode point prompts,
applied to a dense set of points matching the shape of the image encoding.
Returns:
torch.Tensor: Positional encoding with shape
1x(embed_dim)x(embedding_h)x(embedding_w)
"""
return self.pe_layer(self.image_embedding_size).unsqueeze(0)
def _embed_points(
self,
points: torch.Tensor,
labels: torch.Tensor,
pad: bool,
) -> torch.Tensor:
"""Embeds point prompts."""
points = points + 0.5 # Shift to center of pixel
if pad:
padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device)
padding_label = -torch.ones((labels.shape[0], 1), device=labels.device)
points = torch.cat([points, padding_point], dim=1)
labels = torch.cat([labels, padding_label], dim=1)
point_embedding = self.pe_layer.forward_with_coords(
points, self.input_image_size
)
point_embedding = torch.where(
(labels == -1).unsqueeze(-1),
torch.zeros_like(point_embedding) + self.not_a_point_embed.weight,
point_embedding,
)
point_embedding = torch.where(
(labels == 0).unsqueeze(-1),
point_embedding + self.point_embeddings[0].weight,
point_embedding,
)
point_embedding = torch.where(
(labels == 1).unsqueeze(-1),
point_embedding + self.point_embeddings[1].weight,
point_embedding,
)
point_embedding = torch.where(
(labels == 2).unsqueeze(-1),
point_embedding + self.point_embeddings[2].weight,
point_embedding,
)
point_embedding = torch.where(
(labels == 3).unsqueeze(-1),
point_embedding + self.point_embeddings[3].weight,
point_embedding,
)
return point_embedding
def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
"""Embeds box prompts."""
boxes = boxes + 0.5 # Shift to center of pixel
coords = boxes.reshape(-1, 2, 2)
corner_embedding = self.pe_layer.forward_with_coords(
coords, self.input_image_size
)
corner_embedding[:, 0, :] += self.point_embeddings[2].weight
corner_embedding[:, 1, :] += self.point_embeddings[3].weight
return corner_embedding
def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor:
"""Embeds mask inputs."""
mask_embedding = self.mask_downscaling(masks)
return mask_embedding
def _get_batch_size(
self,
points: Optional[Tuple[torch.Tensor, torch.Tensor]],
boxes: Optional[torch.Tensor],
masks: Optional[torch.Tensor],
) -> int:
"""
Gets the batch size of the output given the batch size of the input prompts.
"""
if points is not None:
return points[0].shape[0]
elif boxes is not None:
return boxes.shape[0]
elif masks is not None:
return masks.shape[0]
else:
return 1
def _get_device(self) -> torch.device:
return self.point_embeddings[0].weight.device
def forward(
self,
points: Optional[Tuple[torch.Tensor, torch.Tensor]],
boxes: Optional[torch.Tensor],
masks: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Embeds different types of prompts, returning both sparse and dense
embeddings.
Arguments:
points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates
and labels to embed.
boxes (torch.Tensor or none): boxes to embed
masks (torch.Tensor or none): masks to embed
Returns:
torch.Tensor: sparse embeddings for the points and boxes, with shape
BxNx(embed_dim), where N is determined by the number of input points
and boxes.
torch.Tensor: dense embeddings for the masks, in the shape
Bx(embed_dim)x(embed_H)x(embed_W)
"""
bs = self._get_batch_size(points, boxes, masks)
sparse_embeddings = torch.empty(
(bs, 0, self.embed_dim), device=self._get_device()
)
if points is not None:
coords, labels = points
point_embeddings = self._embed_points(coords, labels, pad=(boxes is None))
sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)
if boxes is not None:
box_embeddings = self._embed_boxes(boxes)
sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1)
if masks is not None:
dense_embeddings = self._embed_masks(masks)
else:
dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
bs, -1, self.image_embedding_size[0], self.image_embedding_size[1]
)
return sparse_embeddings, dense_embeddings
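# A minimal usage sketch (illustrative only; sizes are hypothetical): a single
# positive click yields sparse embeddings for the click plus one padding point,
# and the learned no-mask embedding expanded as the dense output.
if __name__ == "__main__":
    prompt_encoder = PromptEncoder(
        embed_dim=256,
        image_embedding_size=(64, 64),
        input_image_size=(1024, 1024),
        mask_in_chans=16,
    )
    coords = torch.tensor([[[512.0, 512.0]]])  # [B=1, P=1, 2] pixel (x, y)
    labels = torch.tensor([[1]])               # 1 = positive click
    sparse, dense = prompt_encoder(points=(coords, labels), boxes=None, masks=None)
    assert sparse.shape == (1, 2, 256) and dense.shape == (1, 256, 64, 64)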
================================================
FILE: camera_pose_annotation/dynamic_mask/sam2/modeling/sam/transformer.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import math
from functools import partial
from typing import Tuple, Type
import torch
import torch.nn.functional as F
from torch import nn, Tensor
from sam2.modeling.position_encoding import apply_rotary_enc, compute_axial_cis
from sam2.modeling.sam2_utils import MLP
class TwoWayTransformer(nn.Module):
def __init__(
self,
depth: int,
embedding_dim: int,
num_heads: int,
mlp_dim: int,
activation: Type[nn.Module] = nn.ReLU,
attention_downsample_rate: int = 2,
) -> None:
"""
A transformer decoder that attends to an input image using
queries whose positional embedding is supplied.
Args:
depth (int): number of layers in the transformer
embedding_dim (int): the channel dimension for the input embeddings
num_heads (int): the number of heads for multihead attention. Must
divide embedding_dim
mlp_dim (int): the channel dimension internal to the MLP block
activation (nn.Module): the activation to use in the MLP block
"""
super().__init__()
self.depth = depth
self.embedding_dim = embedding_dim
self.num_heads = num_heads
self.mlp_dim = mlp_dim
self.layers = nn.ModuleList()
for i in range(depth):
self.layers.append(
TwoWayAttentionBlock(
embedding_dim=embedding_dim,
num_heads=num_heads,
mlp_dim=mlp_dim,
activation=activation,
attention_downsample_rate=attention_downsample_rate,
skip_first_layer_pe=(i == 0),
)
)
self.final_attn_token_to_image = Attention(
embedding_dim, num_heads, downsample_rate=attention_downsample_rate
)
self.norm_final_attn = nn.LayerNorm(embedding_dim)
def forward(
self,
image_embedding: Tensor,
image_pe: Tensor,
point_embedding: Tensor,
) -> Tuple[Tensor, Tensor]:
"""
Args:
image_embedding (torch.Tensor): image to attend to. Should be shape
B x embedding_dim x h x w for any h and w.
image_pe (torch.Tensor): the positional encoding to add to the image. Must
have the same shape as image_embedding.
point_embedding (torch.Tensor): the embedding to add to the query points.
Must have shape B x N_points x embedding_dim for any N_points.
Returns:
torch.Tensor: the processed point_embedding
torch.Tensor: the processed image_embedding
"""
# BxCxHxW -> BxHWxC == B x N_image_tokens x C
bs, c, h, w = image_embedding.shape
image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
image_pe = image_pe.flatten(2).permute(0, 2, 1)
# Prepare queries
queries = point_embedding
keys = image_embedding
# Apply transformer blocks and final layernorm
for layer in self.layers:
queries, keys = layer(
queries=queries,
keys=keys,
query_pe=point_embedding,
key_pe=image_pe,
)
# Apply the final attention layer from the points to the image
q = queries + point_embedding
k = keys + image_pe
attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys)
queries = queries + attn_out
queries = self.norm_final_attn(queries)
return queries, keys
class TwoWayAttentionBlock(nn.Module):
def __init__(
self,
embedding_dim: int,
num_heads: int,
mlp_dim: int = 2048,
activation: Type[nn.Module] = nn.ReLU,
attention_downsample_rate: int = 2,
skip_first_layer_pe: bool = False,
) -> None:
"""
A transformer block with four layers: (1) self-attention of sparse
inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp
block on sparse inputs, and (4) cross attention of dense inputs to sparse
inputs.
Arguments:
embedding_dim (int): the channel dimension of the embeddings
num_heads (int): the number of heads in the attention layers
mlp_dim (int): the hidden dimension of the mlp block
activation (nn.Module): the activation of the mlp block
skip_first_layer_pe (bool): skip the PE on the first layer
"""
super().__init__()
self.self_attn = Attention(embedding_dim, num_heads)
self.norm1 = nn.LayerNorm(embedding_dim)
self.cross_attn_token_to_image = Attention(
embedding_dim, num_heads, downsample_rate=attention_downsample_rate
)
self.norm2 = nn.LayerNorm(embedding_dim)
self.mlp = MLP(
embedding_dim, mlp_dim, embedding_dim, num_layers=2, activation=activation
)
self.norm3 = nn.LayerNorm(embedding_dim)
self.norm4 = nn.LayerNorm(embedding_dim)
self.cross_attn_image_to_token = Attention(
embedding_dim, num_heads, downsample_rate=attention_downsample_rate
)
self.skip_first_layer_pe = skip_first_layer_pe
def forward(
self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor
) -> Tuple[Tensor, Tensor]:
# Self attention block
if self.skip_first_layer_pe:
queries = self.self_attn(q=queries, k=queries, v=queries)
else:
q = queries + query_pe
attn_out = self.self_attn(q=q, k=q, v=queries)
queries = queries + attn_out
queries = self.norm1(queries)
# Cross attention block, tokens attending to image embedding
q = queries + query_pe
k = keys + key_pe
attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)
queries = queries + attn_out
queries = self.norm2(queries)
# MLP block
mlp_out = self.mlp(queries)
queries = queries + mlp_out
queries = self.norm3(queries)
# Cross attention block, image embedding attending to tokens
q = queries + query_pe
k = keys + key_pe
attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)
keys = keys + attn_out
keys = self.norm4(keys)
return queries, keys
class Attention(nn.Module):
"""
An attention layer that allows for downscaling the size of the embedding
after projection to queries, keys, and values.
"""
def __init__(
self,
embedding_dim: int,
num_heads: int,
downsample_rate: int = 1,
dropout: float = 0.0,
kv_in_dim: int = None,
) -> None:
super().__init__()
self.embedding_dim = embedding_dim
self.kv_in_dim = kv_in_dim if kv_in_dim is not None else embedding_dim
self.internal_dim = embedding_dim // downsample_rate
self.num_heads = num_heads
assert (
self.internal_dim % num_heads == 0
), "num_heads must divide embedding_dim."
self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
self.k_proj = nn.Linear(self.kv_in_dim, self.internal_dim)
self.v_proj = nn.Linear(self.kv_in_dim, self.internal_dim)
self.out_proj = nn.Linear(self.internal_dim, embedding_dim)
self.dropout_p = dropout
def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor:
b, n, c = x.shape
x = x.reshape(b, n, num_heads, c // num_heads)
return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head
def _recombine_heads(self, x: Tensor) -> Tensor:
b, n_heads, n_tokens, c_per_head = x.shape
x = x.transpose(1, 2)
return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C
def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
# Input projections
q = self.q_proj(q)
k = self.k_proj(k)
v = self.v_proj(v)
# Separate into heads
q = self._separate_heads(q, self.num_heads)
k = self._separate_heads(k, self.num_heads)
v = self._separate_heads(v, self.num_heads)
dropout_p = self.dropout_p if self.training else 0.0
# Attention
out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
out = self._recombine_heads(out)
out = self.out_proj(out)
return out
class RoPEAttention(Attention):
"""Attention with rotary position encoding."""
def __init__(
self,
*args,
rope_theta=10000.0,
# whether to repeat q rope to match k length
# this is needed for cross-attention to memories
rope_k_repeat=False,
feat_sizes=(64, 64), # [w, h] for stride 16 feats at 1024 resolution
**kwargs,
):
super().__init__(*args, **kwargs)
self.compute_cis = partial(
compute_axial_cis, dim=self.internal_dim // self.num_heads, theta=rope_theta
)
freqs_cis = self.compute_cis(end_x=feat_sizes[0], end_y=feat_sizes[1])
self.freqs_cis = (
freqs_cis.to("cuda") if torch.cuda.is_available() else freqs_cis
)
self.rope_k_repeat = rope_k_repeat
def forward(
self, q: Tensor, k: Tensor, v: Tensor, num_k_exclude_rope: int = 0
) -> Tensor:
# Input projections
q = self.q_proj(q)
k = self.k_proj(k)
v = self.v_proj(v)
# Separate into heads
q = self._separate_heads(q, self.num_heads)
k = self._separate_heads(k, self.num_heads)
v = self._separate_heads(v, self.num_heads)
# Apply rotary position encoding
w = h = math.sqrt(q.shape[-2])
self.freqs_cis = self.freqs_cis.to(q.device)
if self.freqs_cis.shape[0] != q.shape[-2]:
self.freqs_cis = self.compute_cis(end_x=w, end_y=h).to(q.device)
if q.shape[-2] != k.shape[-2]:
assert self.rope_k_repeat
num_k_rope = k.size(-2) - num_k_exclude_rope
q, k[:, :, :num_k_rope] = apply_rotary_enc(
q,
k[:, :, :num_k_rope],
freqs_cis=self.freqs_cis,
repeat_freqs_k=self.rope_k_repeat,
)
dropout_p = self.dropout_p if self.training else 0.0
# Attention
out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
out = self._recombine_heads(out)
out = self.out_proj(out)
return out
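# A minimal shape sketch for TwoWayTransformer (illustrative only; sizes are
# hypothetical): five point tokens cross-attend with a 64x64 image embedding and
# come back as [B, 5, C], while the flattened image tokens come back as [B, 4096, C].
if __name__ == "__main__":
    tfm = TwoWayTransformer(depth=2, embedding_dim=256, num_heads=8, mlp_dim=2048)
    img = torch.randn(1, 256, 64, 64)
    img_pe = torch.randn(1, 256, 64, 64)
    pts = torch.randn(1, 5, 256)
    q_out, k_out = tfm(image_embedding=img, image_pe=img_pe, point_embedding=pts)
    assert q_out.shape == (1, 5, 256) and k_out.shape == (1, 4096, 256)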
================================================
FILE: camera_pose_annotation/dynamic_mask/sam2/modeling/sam2_base.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.distributed
import torch.nn.functional as F
from torch.nn.init import trunc_normal_
from sam2.modeling.sam.mask_decoder import MaskDecoder
from sam2.modeling.sam.prompt_encoder import PromptEncoder
from sam2.modeling.sam.transformer import TwoWayTransformer
from sam2.modeling.sam2_utils import get_1d_sine_pe, MLP, select_closest_cond_frames
# a large negative value as a placeholder score for missing objects
NO_OBJ_SCORE = -1024.0
class SAM2Base(torch.nn.Module):
def __init__(
self,
image_encoder,
memory_attention,
memory_encoder,
num_maskmem=7, # default 1 input frame + 6 previous frames
image_size=512,
backbone_stride=16, # stride of the image backbone output
sigmoid_scale_for_mem_enc=1.0, # scale factor for mask sigmoid prob
sigmoid_bias_for_mem_enc=0.0, # bias factor for mask sigmoid prob
# During evaluation, whether to binarize the sigmoid mask logits on interacted frames with clicks
binarize_mask_from_pts_for_mem_enc=False,
use_mask_input_as_output_without_sam=False, # on frames with mask input, whether to directly output the input mask without using a SAM prompt encoder + mask decoder
# The maximum number of conditioning frames to participate in the memory attention (-1 means no limit; if there are more conditioning frames than this limit,
# we only cross-attend to the temporally closest `max_cond_frames_in_attn` conditioning frames in the encoder when tracking each frame). This gives the model
# a temporal locality when handling a large number of annotated frames (since closer frames should be more important) and also avoids GPU OOM.
max_cond_frames_in_attn=-1,
# on the first frame, whether to directly add the no-memory embedding to the image feature
# (instead of using the transformer encoder)
directly_add_no_mem_embed=False,
# whether to use high-resolution feature maps in the SAM mask decoder
use_high_res_features_in_sam=False,
# whether to output multiple (3) masks for the first click on initial conditioning frames
multimask_output_in_sam=False,
# the minimum and maximum number of clicks to use multimask_output_in_sam (only relevant when `multimask_output_in_sam=True`;
# default is 1 for both, meaning that only the first click gives multimask output; also note that a box counts as two points)
multimask_min_pt_num=1,
multimask_max_pt_num=1,
# whether to also use multimask output for tracking (not just for the first click on initial conditioning frames; only relevant when `multimask_output_in_sam=True`)
multimask_output_for_tracking=False,
# Whether to use multimask tokens for obj ptr; Only relevant when both
# use_obj_ptrs_in_encoder=True and multimask_output_for_tracking=True
use_multimask_token_for_obj_ptr: bool = False,
# whether to use sigmoid to restrict ious prediction to [0-1]
iou_prediction_use_sigmoid=False,
# The memory bank's temporal stride during evaluation (i.e. the `r` parameter in XMem and Cutie; XMem and Cutie use r=5).
# For r>1, the (self.num_maskmem - 1) non-conditioning memory frames consist of
# (self.num_maskmem - 2) nearest frames from every r-th frames, plus the last frame.
memory_temporal_stride_for_eval=1,
# whether to apply non-overlapping constraints on the object masks in the memory encoder during evaluation (to avoid/alleviate superposing masks)
non_overlap_masks_for_mem_enc=False,
# whether to cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
use_obj_ptrs_in_encoder=False,
# the maximum number of object pointers from other frames in encoder cross attention (only relevant when `use_obj_ptrs_in_encoder=True`)
max_obj_ptrs_in_encoder=16,
# whether to add temporal positional encoding to the object pointers in the encoder (only relevant when `use_obj_ptrs_in_encoder=True`)
add_tpos_enc_to_obj_ptrs=True,
# whether to add an extra linear projection layer for the temporal positional encoding in the object pointers to avoid potential interference
# with spatial positional encoding (only relevant when both `use_obj_ptrs_in_encoder=True` and `add_tpos_enc_to_obj_ptrs=True`)
proj_tpos_enc_in_obj_ptrs=False,
# whether to use signed distance (instead of unsigned absolute distance) in the temporal positional encoding in the object pointers
# (only relevant when both `use_obj_ptrs_in_encoder=True` and `add_tpos_enc_to_obj_ptrs=True`)
use_signed_tpos_enc_to_obj_ptrs=False,
# whether to only attend to object pointers in the past (before the current frame) in the encoder during evaluation
# (only relevant when `use_obj_ptrs_in_encoder=True`; this might avoid pointer information too far in the future to distract the initial tracking)
only_obj_ptrs_in_the_past_for_eval=False,
# Whether to predict if there is an object in the frame
pred_obj_scores: bool = False,
# Whether to use an MLP to predict object scores
pred_obj_scores_mlp: bool = False,
# Only relevant if pred_obj_scores=True and use_obj_ptrs_in_encoder=True;
# Whether to have a fixed no obj pointer when there is no object present
# or to use it as an additive embedding with obj_ptr produced by decoder
fixed_no_obj_ptr: bool = False,
# Soft no object, i.e. mix in no_obj_ptr softly,
# hope to make recovery easier if there is a mistake and mitigate accumulation of errors
soft_no_obj_ptr: bool = False,
use_mlp_for_obj_ptr_proj: bool = False,
# add no obj embedding to spatial frames
no_obj_embed_spatial: bool = False,
# extra arguments used to construct the SAM mask decoder; if not None, it should be a dict of kwargs to be passed into `MaskDecoder` class.
sam_mask_decoder_extra_args=None,
compile_image_encoder: bool = False,
):
super().__init__()
# Part 1: the image backbone
self.image_encoder = image_encoder
# Use level 0, 1, 2 for high-res setting, or just level 2 for the default setting
self.use_high_res_features_in_sam = use_high_res_features_in_sam
self.num_feature_levels = 3 if use_high_res_features_in_sam else 1
self.use_obj_ptrs_in_encoder = use_obj_ptrs_in_encoder
self.max_obj_ptrs_in_encoder = max_obj_ptrs_in_encoder
if use_obj_ptrs_in_encoder:
# A conv layer to downsample the mask prompt to stride 4 (the same stride as
# low-res SAM mask logits) and to change its scales from 0~1 to SAM logit scale,
# so that it can be fed into the SAM mask decoder to generate a pointer.
self.mask_downsample = torch.nn.Conv2d(1, 1, kernel_size=4, stride=4)
self.add_tpos_enc_to_obj_ptrs = add_tpos_enc_to_obj_ptrs
if proj_tpos_enc_in_obj_ptrs:
assert add_tpos_enc_to_obj_ptrs # these options need to be used together
self.proj_tpos_enc_in_obj_ptrs = proj_tpos_enc_in_obj_ptrs
self.use_signed_tpos_enc_to_obj_ptrs = use_signed_tpos_enc_to_obj_ptrs
self.only_obj_ptrs_in_the_past_for_eval = only_obj_ptrs_in_the_past_for_eval
# Part 2: memory attention to condition current frame's visual features
# with memories (and obj ptrs) from past frames
self.memory_attention = memory_attention
self.hidden_dim = image_encoder.neck.d_model
# Part 3: memory encoder for the previous frame's outputs
self.memory_encoder = memory_encoder
self.mem_dim = self.hidden_dim
if hasattr(self.memory_encoder, "out_proj") and hasattr(
self.memory_encoder.out_proj, "weight"
):
# if there is compression of memories along channel dim
self.mem_dim = self.memory_encoder.out_proj.weight.shape[0]
self.num_maskmem = num_maskmem # Number of memories accessible
# Temporal encoding of the memories
self.maskmem_tpos_enc = torch.nn.Parameter(
torch.zeros(num_maskmem, 1, 1, self.mem_dim)
)
trunc_normal_(self.maskmem_tpos_enc, std=0.02)
# a single token to indicate no memory embedding from previous frames
self.no_mem_embed = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim))
self.no_mem_pos_enc = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim))
trunc_normal_(self.no_mem_embed, std=0.02)
trunc_normal_(self.no_mem_pos_enc, std=0.02)
self.directly_add_no_mem_embed = directly_add_no_mem_embed
# Apply sigmoid to the output raw mask logits (to turn them from
# range (-inf, +inf) to range (0, 1)) before feeding them into the memory encoder
self.sigmoid_scale_for_mem_enc = sigmoid_scale_for_mem_enc
self.sigmoid_bias_for_mem_enc = sigmoid_bias_for_mem_enc
self.binarize_mask_from_pts_for_mem_enc = binarize_mask_from_pts_for_mem_enc
self.non_overlap_masks_for_mem_enc = non_overlap_masks_for_mem_enc
self.memory_temporal_stride_for_eval = memory_temporal_stride_for_eval
# On frames with mask input, whether to directly output the input mask without
# using a SAM prompt encoder + mask decoder
self.use_mask_input_as_output_without_sam = use_mask_input_as_output_without_sam
self.multimask_output_in_sam = multimask_output_in_sam
self.multimask_min_pt_num = multimask_min_pt_num
self.multimask_max_pt_num = multimask_max_pt_num
self.multimask_output_for_tracking = multimask_output_for_tracking
self.use_multimask_token_for_obj_ptr = use_multimask_token_for_obj_ptr
self.iou_prediction_use_sigmoid = iou_prediction_use_sigmoid
# Part 4: SAM-style prompt encoder (for both mask and point inputs)
# and SAM-style mask decoder for the final mask output
self.image_size = image_size
self.backbone_stride = backbone_stride
self.sam_mask_decoder_extra_args = sam_mask_decoder_extra_args
self.pred_obj_scores = pred_obj_scores
self.pred_obj_scores_mlp = pred_obj_scores_mlp
self.fixed_no_obj_ptr = fixed_no_obj_ptr
self.soft_no_obj_ptr = soft_no_obj_ptr
if self.fixed_no_obj_ptr:
assert self.pred_obj_scores
assert self.use_obj_ptrs_in_encoder
if self.pred_obj_scores and self.use_obj_ptrs_in_encoder:
self.no_obj_ptr = torch.nn.Parameter(torch.zeros(1, self.hidden_dim))
trunc_normal_(self.no_obj_ptr, std=0.02)
self.use_mlp_for_obj_ptr_proj = use_mlp_for_obj_ptr_proj
self.no_obj_embed_spatial = None
if no_obj_embed_spatial:
self.no_obj_embed_spatial = torch.nn.Parameter(torch.zeros(1, self.mem_dim))
trunc_normal_(self.no_obj_embed_spatial, std=0.02)
self._build_sam_heads()
self.max_cond_frames_in_attn = max_cond_frames_in_attn
# Model compilation
if compile_image_encoder:
# Compile the forward function (not the full module) to allow loading checkpoints.
print(
"Image encoder compilation is enabled. First forward pass will be slow."
)
self.image_encoder.forward = torch.compile(
self.image_encoder.forward,
mode="max-autotune",
fullgraph=True,
dynamic=False,
)
@property
def device(self):
return next(self.parameters()).device
def forward(self, *args, **kwargs):
raise NotImplementedError(
"Please use the corresponding methods in SAM2VideoPredictor for inference or SAM2Train for training/fine-tuning"
"See notebooks/video_predictor_example.ipynb for an inference example."
)
def _build_sam_heads(self):
"""Build SAM-style prompt encoder and mask decoder."""
self.sam_prompt_embed_dim = self.hidden_dim
self.sam_image_embedding_size = self.image_size // self.backbone_stride
# build PromptEncoder and MaskDecoder from SAM
# (their hyperparameters like `mask_in_chans=16` are from SAM code)
self.sam_prompt_encoder = PromptEncoder(
embed_dim=self.sam_prompt_embed_dim,
image_embedding_size=(
self.sam_image_embedding_size,
self.sam_image_embedding_size,
),
input_image_size=(self.image_size, self.image_size),
mask_in_chans=16,
)
self.sam_mask_decoder = MaskDecoder(
num_multimask_outputs=3,
transformer=TwoWayTransformer(
depth=2,
embedding_dim=self.sam_prompt_embed_dim,
mlp_dim=2048,
num_heads=8,
),
transformer_dim=self.sam_prompt_embed_dim,
iou_head_depth=3,
iou_head_hidden_dim=256,
use_high_res_features=self.use_high_res_features_in_sam,
iou_prediction_use_sigmoid=self.iou_prediction_use_sigmoid,
pred_obj_scores=self.pred_obj_scores,
pred_obj_scores_mlp=self.pred_obj_scores_mlp,
use_multimask_token_for_obj_ptr=self.use_multimask_token_for_obj_ptr,
**(self.sam_mask_decoder_extra_args or {}),
)
if self.use_obj_ptrs_in_encoder:
# a linear projection on SAM output tokens to turn them into object pointers
self.obj_ptr_proj = torch.nn.Linear(self.hidden_dim, self.hidden_dim)
if self.use_mlp_for_obj_ptr_proj:
self.obj_ptr_proj = MLP(
self.hidden_dim, self.hidden_dim, self.hidden_dim, 3
)
else:
self.obj_ptr_proj = torch.nn.Identity()
if self.proj_tpos_enc_in_obj_ptrs:
# a linear projection on temporal positional encoding in object pointers to
# avoid potential interference with spatial positional encoding
self.obj_ptr_tpos_proj = torch.nn.Linear(self.hidden_dim, self.mem_dim)
else:
self.obj_ptr_tpos_proj = torch.nn.Identity()
def _forward_sam_heads(
self,
backbone_features,
point_inputs=None,
mask_inputs=None,
high_res_features=None,
multimask_output=False,
):
"""
Forward SAM prompt encoders and mask heads.
Inputs:
- backbone_features: image features of [B, C, H, W] shape
- point_inputs: a dictionary with "point_coords" and "point_labels", where
1) "point_coords" has [B, P, 2] shape and float32 dtype and contains the
absolute pixel-unit coordinate in (x, y) format of the P input points
2) "point_labels" has shape [B, P] and int32 dtype, where 1 means
positive clicks, 0 means negative clicks, and -1 means padding
- mask_inputs: a mask of [B, 1, H*16, W*16] shape, float or bool, with the
same spatial size as the image.
- high_res_features: either 1) None or 2) a list of length 2 containing
two feature maps of [B, C, 4*H, 4*W] and [B, C, 2*H, 2*W] shapes respectively,
which will be used as high-resolution feature maps for SAM decoder.
- multimask_output: if it's True, we output 3 candidate masks and their 3
corresponding IoU estimates, and if it's False, we output only 1 mask and
its corresponding IoU estimate.
Outputs:
- low_res_multimasks: [B, M, H*4, W*4] shape (where M = 3 if
`multimask_output=True` and M = 1 if `multimask_output=False`), the SAM
output mask logits (before sigmoid) for the low-resolution masks, with 4x
the resolution (1/4 stride) of the input backbone_features.
- high_res_multimasks: [B, M, H*16, W*16] shape (where M = 3
if `multimask_output=True` and M = 1 if `multimask_output=False`),
upsampled from the low-resolution masks, with the same spatial size as the image
(stride is 1 pixel).
- ious: [B, M] shape (where M = 3 if `multimask_output=True` and M = 1
if `multimask_output=False`), the estimated IoU of each output mask.
- low_res_masks: [B, 1, H*4, W*4] shape, the best mask in `low_res_multimasks`.
If `multimask_output=True`, it's the mask with the highest IoU estimate.
If `multimask_output=False`, it's the same as `low_res_multimasks`.
- high_res_masks: [B, 1, H*16, W*16] shape, the best mask in `high_res_multimasks`.
If `multimask_output=True`, it's the mask with the highest IoU estimate.
If `multimask_output=False`, it's the same as `high_res_multimasks`.
- obj_ptr: [B, C] shape, the object pointer vector for the output mask, extracted
based on the output token from the SAM mask decoder.
"""
B = backbone_features.size(0)
device = backbone_features.device
assert backbone_features.size(1) == self.sam_prompt_embed_dim
assert backbone_features.size(2) == self.sam_image_embedding_size
assert backbone_features.size(3) == self.sam_image_embedding_size
# a) Handle point prompts
if point_inputs is not None:
sam_point_coords = point_inputs["point_coords"]
sam_point_labels = point_inputs["point_labels"]
assert sam_point_coords.size(0) == B and sam_point_labels.size(0) == B
else:
# If no points are provided, pad with an empty point (with label -1)
sam_point_coords = torch.zeros(B, 1, 2, device=device)
sam_point_labels = -torch.ones(B, 1, dtype=torch.int32, device=device)
# b) Handle mask prompts
if mask_inputs is not None:
# If mask_inputs is provided, downsize it into low-res mask input if needed
# and feed it as a dense mask prompt into the SAM mask encoder
assert len(mask_inputs.shape) == 4 and mask_inputs.shape[:2] == (B, 1)
if mask_inputs.shape[-2:] != self.sam_prompt_encoder.mask_input_size:
sam_mask_prompt = F.interpolate(
mask_inputs.float(),
size=self.sam_prompt_encoder.mask_input_size,
align_corners=False,
mode="bilinear",
antialias=True, # use antialias for downsampling
)
else:
sam_mask_prompt = mask_inputs
else:
# Otherwise, simply feed None (and SAM's prompt encoder will add
# a learned `no_mask_embed` to indicate no mask input in this case).
sam_mask_prompt = None
sparse_embeddings, dense_embeddings = self.sam_prompt_encoder(
points=(sam_point_coords, sam_point_labels),
boxes=None,
masks=sam_mask_prompt,
)
(
low_res_multimasks,
ious,
sam_output_tokens,
object_score_logits,
) = self.sam_mask_decoder(
image_embeddings=backbone_features,
image_pe=self.sam_prompt_encoder.get_dense_pe(),
sparse_prompt_embeddings=sparse_embeddings,
dense_prompt_embeddings=dense_embeddings,
multimask_output=multimask_output,
repeat_image=False, # the image is already batched
high_res_features=high_res_features,
)
if self.pred_obj_scores:
is_obj_appearing = object_score_logits > 0
# Mask used for spatial memories is always a *hard* choice between obj and no obj,
# consistent with the actual mask prediction
low_res_multimasks = torch.where(
is_obj_appearing[:, None, None],
low_res_multimasks,
NO_OBJ_SCORE,
)
# convert masks from possibly bfloat16 (or float16) to float32
# (older PyTorch versions before 2.1 don't support `interpolate` on bf16)
low_res_multimasks = low_res_multimasks.float()
high_res_multimasks = F.interpolate(
low_res_multimasks,
size=(self.image_size, self.image_size),
mode="bilinear",
align_corners=False,
)
sam_output_token = sam_output_tokens[:, 0]
if multimask_output:
# take the best mask prediction (with the highest IoU estimation)
best_iou_inds = torch.argmax(ious, dim=-1)
batch_inds = torch.arange(B, device=device)
low_res_masks = low_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)
high_res_masks = high_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)
if sam_output_tokens.size(1) > 1:
sam_output_token = sam_output_tokens[batch_inds, best_iou_inds]
else:
low_res_masks, high_res_masks = low_res_multimasks, high_res_multimasks
# Extract object pointer from the SAM output token (with occlusion handling)
obj_ptr = self.obj_ptr_proj(sam_output_token)
if self.pred_obj_scores:
# Allow *soft* no obj ptr, unlike for masks
if self.soft_no_obj_ptr:
lambda_is_obj_appearing = object_score_logits.sigmoid()
else:
lambda_is_obj_appearing = is_obj_appearing.float()
if self.fixed_no_obj_ptr:
obj_ptr = lambda_is_obj_appearing * obj_ptr
obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr
return (
low_res_multimasks,
high_res_multimasks,
ious,
low_res_masks,
high_res_masks,
obj_ptr,
object_score_logits,
)
def _use_mask_as_output(self, backbone_features, high_res_features, mask_inputs):
"""
Directly turn binary `mask_inputs` into output mask logits without using SAM.
(same input and output shapes as in _forward_sam_heads above).
"""
# Use -10/+10 as logits for neg/pos pixels (very close to 0/1 in prob after sigmoid).
out_scale, out_bias = 20.0, -10.0 # sigmoid(-10.0)=4.5398e-05
mask_inputs_float = mask_inputs.float()
high_res_masks = mask_inputs_float * out_scale + out_bias
low_res_masks = F.interpolate(
high_res_masks,
size=(high_res_masks.size(-2) // 4, high_res_masks.size(-1) // 4),
align_corners=False,
mode="bilinear",
antialias=True, # use antialias for downsampling
)
# a dummy IoU prediction of all 1's under mask input
ious = mask_inputs.new_ones(mask_inputs.size(0), 1).float()
if not self.use_obj_ptrs_in_encoder:
# all zeros as a dummy object pointer (of shape [B, C])
obj_ptr = torch.zeros(
mask_inputs.size(0), self.hidden_dim, device=mask_inputs.device
)
else:
# produce an object pointer using the SAM decoder from the mask input
_, _, _, _, _, obj_ptr, _ = self._forward_sam_heads(
backbone_features=backbone_features,
mask_inputs=self.mask_downsample(mask_inputs_float),
high_res_features=high_res_features,
)
# In this method, we are treating mask_input as output, e.g. using it directly to create spatial mem;
# Below, we follow the same design axiom to use mask_input to decide if obj appears or not instead of relying
# on the object_scores from the SAM decoder.
is_obj_appearing = torch.any(mask_inputs.flatten(1).float() > 0.0, dim=1)
is_obj_appearing = is_obj_appearing[..., None]
lambda_is_obj_appearing = is_obj_appearing.float()
object_score_logits = out_scale * lambda_is_obj_appearing + out_bias
if self.pred_obj_scores:
if self.fixed_no_obj_ptr:
obj_ptr = lambda_is_obj_appearing * obj_ptr
obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr
return (
low_res_masks,
high_res_masks,
ious,
low_res_masks,
high_res_masks,
obj_ptr,
object_score_logits,
)
def forward_image(self, img_batch: torch.Tensor):
"""Get the image feature on the input batch."""
backbone_out = self.image_encoder(img_batch)
if self.use_high_res_features_in_sam:
# precompute projected level 0 and level 1 features in SAM decoder
# to avoid running it again on every SAM click
backbone_out["backbone_fpn"][0] = self.sam_mask_decoder.conv_s0(
backbone_out["backbone_fpn"][0]
)
backbone_out["backbone_fpn"][1] = self.sam_mask_decoder.conv_s1(
backbone_out["backbone_fpn"][1]
)
return backbone_out
def _prepare_backbone_features(self, backbone_out):
"""Prepare and flatten visual features."""
backbone_out = backbone_out.copy()
assert len(backbone_out["backbone_fpn"]) == len(backbone_out["vision_pos_enc"])
assert len(backbone_out["backbone_fpn"]) >= self.num_feature_levels
feature_maps = backbone_out["backbone_fpn"][-self.num_feature_levels :]
vision_pos_embeds = backbone_out["vision_pos_enc"][-self.num_feature_levels :]
feat_sizes = [(x.shape[-2], x.shape[-1]) for x in vision_pos_embeds]
# flatten NxCxHxW to HWxNxC
vision_feats = [x.flatten(2).permute(2, 0, 1) for x in feature_maps]
vision_pos_embeds = [x.flatten(2).permute(2, 0, 1) for x in vision_pos_embeds]
return backbone_out, vision_feats, vision_pos_embeds, feat_sizes
def _prepare_memory_conditioned_features(
self,
frame_idx,
is_init_cond_frame,
current_vision_feats,
current_vision_pos_embeds,
feat_sizes,
output_dict,
num_frames,
track_in_reverse=False, # tracking in reverse time order (for demo usage)
):
"""Fuse the current frame's visual feature map with previous memory."""
B = current_vision_feats[-1].size(1) # batch size on this frame
C = self.hidden_dim
H, W = feat_sizes[-1] # top-level (lowest-resolution) feature size
device = current_vision_feats[-1].device
# The case of `self.num_maskmem == 0` below is primarily used for reproducing SAM on images.
# In this case, we skip the fusion with any memory.
if self.num_maskmem == 0: # Disable memory and skip fusion
pix_feat = current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W)
return pix_feat
num_obj_ptr_tokens = 0
tpos_sign_mul = -1 if track_in_reverse else 1
# Step 1: condition the visual features of the current frame on previous memories
if not is_init_cond_frame:
# Retrieve the memories encoded with the maskmem backbone
to_cat_memory, to_cat_memory_pos_embed = [], []
# Add the conditioning frames' outputs first (all cond frames have t_pos=0
# when getting the temporal positional embedding below)
assert len(output_dict["cond_frame_outputs"]) > 0
# Select a maximum number of temporally closest cond frames for cross attention
cond_outputs = output_dict["cond_frame_outputs"]
selected_cond_outputs, unselected_cond_outputs = select_closest_cond_frames(
frame_idx, cond_outputs, self.max_cond_frames_in_attn
)
t_pos_and_prevs = [(0, out) for out in selected_cond_outputs.values()]
# Add last (self.num_maskmem - 1) frames before current frame for non-conditioning memory
# the earliest one has t_pos=1 and the latest one has t_pos=self.num_maskmem-1
# We also allow taking the memory frame non-consecutively (with stride>1), in which case
# we take (self.num_maskmem - 2) frames among every stride-th frames plus the last frame.
stride = 1 if self.training else self.memory_temporal_stride_for_eval
for t_pos in range(1, self.num_maskmem):
t_rel = self.num_maskmem - t_pos # how many frames before current frame
if t_rel == 1:
# for t_rel == 1, we take the last frame (regardless of r)
if not track_in_reverse:
# the frame immediately before this frame (i.e. frame_idx - 1)
prev_frame_idx = frame_idx - t_rel
else:
# the frame immediately after this frame (i.e. frame_idx + 1)
prev_frame_idx = frame_idx + t_rel
else:
# for t_rel >= 2, we take the memory frame from every r-th frames
if not track_in_reverse:
# first find the nearest frame among every r-th frames before this frame
# for r=1, this would be (frame_idx - 2)
prev_frame_idx = ((frame_idx - 2) // stride) * stride
# then seek further among every r-th frames
prev_frame_idx = prev_frame_idx - (t_rel - 2) * stride
else:
# first find the nearest frame among every r-th frames after this frame
# for r=1, this would be (frame_idx + 2)
prev_frame_idx = -(-(frame_idx + 2) // stride) * stride
# then seek further among every r-th frames
prev_frame_idx = prev_frame_idx + (t_rel - 2) * stride
out = output_dict["non_cond_frame_outputs"].get(prev_frame_idx, None)
if out is None:
# If an unselected conditioning frame is among the last (self.num_maskmem - 1)
# frames, we still attend to it as if it's a non-conditioning frame.
out = unselected_cond_outputs.get(prev_frame_idx, None)
t_pos_and_prevs.append((t_pos, out))
for t_pos, prev in t_pos_and_prevs:
if prev is None:
continue # skip padding frames
# "maskmem_features" might have been offloaded to CPU in demo use cases,
# so we load it back to GPU (it's a no-op if it's already on GPU).
feats = prev["maskmem_features"].to(device, non_blocking=True)
to_cat_memory.append(feats.flatten(2).permute(2, 0, 1))
# Spatial positional encoding (it might have been offloaded to CPU in eval)
maskmem_enc = prev["maskmem_pos_enc"][-1].to(device)
maskmem_enc = maskmem_enc.flatten(2).permute(2, 0, 1)
# Temporal positional encoding
maskmem_enc = (
maskmem_enc + self.maskmem_tpos_enc[self.num_maskmem - t_pos - 1]
)
to_cat_memory_pos_embed.append(maskmem_enc)
# Construct the list of past object pointers
if self.use_obj_ptrs_in_encoder:
max_obj_ptrs_in_encoder = min(num_frames, self.max_obj_ptrs_in_encoder)
# First add those object pointers from selected conditioning frames
# (optionally, only include object pointers in the past during evaluation)
if not self.training and self.only_obj_ptrs_in_the_past_for_eval:
ptr_cond_outputs = {
t: out
for t, out in selected_cond_outputs.items()
if (t >= frame_idx if track_in_reverse else t <= frame_idx)
}
else:
ptr_cond_outputs = selected_cond_outputs
pos_and_ptrs = [
# Temporal pos encoding contains how far away each pointer is from current frame
(
(
(frame_idx - t) * tpos_sign_mul
if self.use_signed_tpos_enc_to_obj_ptrs
else abs(frame_idx - t)
),
out["obj_ptr"],
)
for t, out in ptr_cond_outputs.items()
]
# Add up to (max_obj_ptrs_in_encoder - 1) non-conditioning frames before current frame
for t_diff in range(1, max_obj_ptrs_in_encoder):
t = frame_idx + t_diff if track_in_reverse else frame_idx - t_diff
if t < 0 or (num_frames is not None and t >= num_frames):
break
out = output_dict["non_cond_frame_outputs"].get(
t, unselected_cond_outputs.get(t, None)
)
if out is not None:
pos_and_ptrs.append((t_diff, out["obj_ptr"]))
# If we have at least one object pointer, add them to the cross attention
if len(pos_and_ptrs) > 0:
pos_list, ptrs_list = zip(*pos_and_ptrs)
# stack object pointers along dim=0 into [ptr_seq_len, B, C] shape
obj_ptrs = torch.stack(ptrs_list, dim=0)
# a temporal positional embedding based on how far each object pointer is from
# the current frame (sine embedding normalized by the max pointer num).
if self.add_tpos_enc_to_obj_ptrs:
t_diff_max = max_obj_ptrs_in_encoder - 1
tpos_dim = C if self.proj_tpos_enc_in_obj_ptrs else self.mem_dim
obj_pos = torch.tensor(pos_list).to(
device=device, non_blocking=True
)
obj_pos = get_1d_sine_pe(obj_pos / t_diff_max, dim=tpos_dim)
obj_pos = self.obj_ptr_tpos_proj(obj_pos)
obj_pos = obj_pos.unsqueeze(1).expand(-1, B, self.mem_dim)
else:
obj_pos = obj_ptrs.new_zeros(len(pos_list), B, self.mem_dim)
if self.mem_dim < C:
# split a pointer into (C // self.mem_dim) tokens for self.mem_dim < C
obj_ptrs = obj_ptrs.reshape(
-1, B, C // self.mem_dim, self.mem_dim
)
obj_ptrs = obj_ptrs.permute(0, 2, 1, 3).flatten(0, 1)
obj_pos = obj_pos.repeat_interleave(C // self.mem_dim, dim=0)
to_cat_memory.append(obj_ptrs)
to_cat_memory_pos_embed.append(obj_pos)
num_obj_ptr_tokens = obj_ptrs.shape[0]
else:
num_obj_ptr_tokens = 0
else:
# for initial conditioning frames, encode them without using any previous memory
if self.directly_add_no_mem_embed:
# directly add no-mem embedding (instead of using the transformer encoder)
pix_feat_with_mem = current_vision_feats[-1] + self.no_mem_embed
pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W)
return pix_feat_with_mem
# Use a dummy token on the first frame (to avoid an empty memory input to the transformer encoder)
to_cat_memory = [self.no_mem_embed.expand(1, B, self.mem_dim)]
to_cat_memory_pos_embed = [self.no_mem_pos_enc.expand(1, B, self.mem_dim)]
# Step 2: Concatenate the memories and forward through the transformer encoder
memory = torch.cat(to_cat_memory, dim=0)
memory_pos_embed = torch.cat(to_cat_memory_pos_embed, dim=0)
pix_feat_with_mem = self.memory_attention(
curr=current_vision_feats,
curr_pos=current_vision_pos_embeds,
memory=memory,
memory_pos=memory_pos_embed,
num_obj_ptr_tokens=num_obj_ptr_tokens,
)
# reshape the output (HW)BC => BCHW
pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W)
return pix_feat_with_mem
def _encode_new_memory(
self,
current_vision_feats,
feat_sizes,
pred_masks_high_res,
object_score_logits,
is_mask_from_pts,
):
"""Encode the current image and its prediction into a memory feature."""
B = current_vision_feats[-1].size(1) # batch size on this frame
C = self.hidden_dim
H, W = feat_sizes[-1] # top-level (lowest-resolution) feature size
# top-level feature, (HW)BC => BCHW
pix_feat = current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W)
if self.non_overlap_masks_for_mem_enc and not self.training:
# optionally, apply non-overlapping constraints to the masks (it's applied
# in the batch dimension and should only be used during eval, where all
# the objects come from the same video under batch size 1).
pred_masks_high_res = self._apply_non_overlapping_constraints(
pred_masks_high_res
)
# scale the raw mask logits with a temperature before applying sigmoid
binarize = self.binarize_mask_from_pts_for_mem_enc and is_mask_from_pts
if binarize and not self.training:
mask_for_mem = (pred_masks_high_res > 0).float()
else:
# apply sigmoid on the raw mask logits to turn them into range (0, 1)
mask_for_mem = torch.sigmoid(pred_masks_high_res)
# apply scale and bias terms to the sigmoid probabilities
if self.sigmoid_scale_for_mem_enc != 1.0:
mask_for_mem = mask_for_mem * self.sigmoid_scale_for_mem_enc
if self.sigmoid_bias_for_mem_enc != 0.0:
mask_for_mem = mask_for_mem + self.sigmoid_bias_for_mem_enc
maskmem_out = self.memory_encoder(
pix_feat, mask_for_mem, skip_mask_sigmoid=True # sigmoid already applied
)
maskmem_features = maskmem_out["vision_features"]
maskmem_pos_enc = maskmem_out["vision_pos_enc"]
# add a no-object embedding to the spatial memory to indicate that the frame
# is predicted to be occluded (i.e. no object is appearing in the frame)
if self.no_obj_embed_spatial is not None:
is_obj_appearing = (object_score_logits > 0).float()
maskmem_features += (
1 - is_obj_appearing[..., None, None]
) * self.no_obj_embed_spatial[..., None, None].expand(
*maskmem_features.shape
)
return maskmem_features, maskmem_pos_enc
def _track_step(
self,
frame_idx,
is_init_cond_frame,
current_vision_feats,
current_vision_pos_embeds,
feat_sizes,
point_inputs,
mask_inputs,
output_dict,
num_frames,
track_in_reverse,
prev_sam_mask_logits,
):
current_out = {"point_inputs": point_inputs, "mask_inputs": mask_inputs}
# High-resolution feature maps for the SAM head, reshape (HW)BC => BCHW
if len(current_vision_feats) > 1:
high_res_features = [
x.permute(1, 2, 0).view(x.size(1), x.size(2), *s)
for x, s in zip(current_vision_feats[:-1], feat_sizes[:-1])
]
else:
high_res_features = None
if mask_inputs is not None and self.use_mask_input_as_output_without_sam:
# When use_mask_input_as_output_without_sam=True, we directly output the mask input
# (see it as a GT mask) without using a SAM prompt encoder + mask decoder.
pix_feat = current_vision_feats[-1].permute(1, 2, 0)
pix_feat = pix_feat.view(-1, self.hidden_dim, *feat_sizes[-1])
sam_outputs = self._use_mask_as_output(
pix_feat, high_res_features, mask_inputs
)
else:
# fuse the visual feature with previous memory features in the memory bank
pix_feat = self._prepare_memory_conditioned_features(
frame_idx=frame_idx,
is_init_cond_frame=is_init_cond_frame,
current_vision_feats=current_vision_feats[-1:],
current_vision_pos_embeds=current_vision_pos_embeds[-1:],
feat_sizes=feat_sizes[-1:],
output_dict=output_dict,
num_frames=num_frames,
track_in_reverse=track_in_reverse,
)
# apply SAM-style segmentation head
# here we might feed previously predicted low-res SAM mask logits into the SAM mask decoder,
# e.g. in demo where such logits come from earlier interaction instead of correction sampling
# (in this case, any `mask_inputs` shouldn't reach here as they are sent to _use_mask_as_output instead)
if prev_sam_mask_logits is not None:
assert point_inputs is not None and mask_inputs is None
mask_inputs = prev_sam_mask_logits
multimask_output = self._use_multimask(is_init_cond_frame, point_inputs)
sam_outputs = self._forward_sam_heads(
backbone_features=pix_feat,
point_inputs=point_inputs,
mask_inputs=mask_inputs,
high_res_features=high_res_features,
multimask_output=multimask_output,
)
return current_out, sam_outputs, high_res_features, pix_feat
def _encode_memory_in_output(
self,
current_vision_feats,
feat_sizes,
point_inputs,
run_mem_encoder,
high_res_masks,
object_score_logits,
current_out,
):
if run_mem_encoder and self.num_maskmem > 0:
high_res_masks_for_mem_enc = high_res_masks
maskmem_features, maskmem_pos_enc = self._encode_new_memory(
current_vision_feats=current_vision_feats,
feat_sizes=feat_sizes,
pred_masks_high_res=high_res_masks_for_mem_enc,
object_score_logits=object_score_logits,
is_mask_from_pts=(point_inputs is not None),
)
current_out["maskmem_features"] = maskmem_features
current_out["maskmem_pos_enc"] = maskmem_pos_enc
else:
current_out["maskmem_features"] = None
current_out["maskmem_pos_enc"] = None
def track_step(
self,
frame_idx,
is_init_cond_frame,
current_vision_feats,
current_vision_pos_embeds,
feat_sizes,
point_inputs,
mask_inputs,
output_dict,
num_frames,
track_in_reverse=False, # tracking in reverse time order (for demo usage)
# Whether to run the memory encoder on the predicted masks. Sometimes we might want
# to skip the memory encoder with `run_mem_encoder=False`. For example,
# in demo we might call `track_step` multiple times for each user click,
# and only encode the memory when the user finalizes their clicks. And in ablation
# settings like SAM training on static images, we don't need the memory encoder.
run_mem_encoder=True,
# The previously predicted SAM mask logits (which can be fed together with new clicks in demo).
prev_sam_mask_logits=None,
):
current_out, sam_outputs, _, _ = self._track_step(
frame_idx,
is_init_cond_frame,
current_vision_feats,
current_vision_pos_embeds,
feat_sizes,
point_inputs,
mask_inputs,
output_dict,
num_frames,
track_in_reverse,
prev_sam_mask_logits,
)
(
_,
_,
_,
low_res_masks,
high_res_masks,
obj_ptr,
object_score_logits,
) = sam_outputs
current_out["pred_masks"] = low_res_masks
current_out["pred_masks_high_res"] = high_res_masks
current_out["obj_ptr"] = obj_ptr
if not self.training:
# Only add this in inference (to avoid unused param in activation checkpointing;
# it's mainly used in the demo to encode spatial memories w/ consolidated masks)
current_out["object_score_logits"] = object_score_logits
# Finally run the memory encoder on the predicted mask to encode
# it into a new memory feature (that can be used in future frames)
self._encode_memory_in_output(
current_vision_feats,
feat_sizes,
point_inputs,
run_mem_encoder,
high_res_masks,
object_score_logits,
current_out,
)
return current_out
def _use_multimask(self, is_init_cond_frame, point_inputs):
"""Whether to use multimask output in the SAM head."""
num_pts = 0 if point_inputs is None else point_inputs["point_labels"].size(1)
multimask_output = (
self.multimask_output_in_sam
and (is_init_cond_frame or self.multimask_output_for_tracking)
and (self.multimask_min_pt_num <= num_pts <= self.multimask_max_pt_num)
)
return multimask_output
def _apply_non_overlapping_constraints(self, pred_masks):
"""
Apply non-overlapping constraints to the object scores in pred_masks. Here we
keep only the highest scoring object at each spatial location in pred_masks.
"""
batch_size = pred_masks.size(0)
if batch_size == 1:
return pred_masks
device = pred_masks.device
# "max_obj_inds": object index of the object with the highest score at each location
max_obj_inds = torch.argmax(pred_masks, dim=0, keepdim=True)
# "batch_obj_inds": object index of each object slice (along dim 0) in `pred_masks`
batch_obj_inds = torch.arange(batch_size, device=device)[:, None, None, None]
keep = max_obj_inds == batch_obj_inds
# suppress overlapping regions' scores below -10.0 so that the foreground regions
# don't overlap (here sigmoid(-10.0)=4.5398e-05)
pred_masks = torch.where(keep, pred_masks, torch.clamp(pred_masks, max=-10.0))
return pred_masks
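# A standalone toy sketch of the non-overlapping constraint above (illustrative
# only): with two overlapping object score maps, only the higher-scoring object
# keeps its score at each location; the other is clamped to at most -10.0.
if __name__ == "__main__":
    scores = torch.tensor([[[[2.0]]], [[[5.0]]]])  # two objects, 1x1 "masks"
    max_inds = torch.argmax(scores, dim=0, keepdim=True)
    keep = max_inds == torch.arange(scores.size(0))[:, None, None, None]
    suppressed = torch.where(keep, scores, torch.clamp(scores, max=-10.0))
    assert suppressed[0].item() == -10.0 and suppressed[1].item() == 5.0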
================================================
FILE: camera_pose_annotation/dynamic_mask/sam2/modeling/sam2_utils.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import copy
from typing import Tuple
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from sam2.utils.misc import mask_to_box
def select_closest_cond_frames(frame_idx, cond_frame_outputs, max_cond_frame_num):
"""
Select up to `max_cond_frame_num` conditioning frames from `cond_frame_outputs`
that are temporally closest to the current frame at `frame_idx`. Here, we take
- a) the closest conditioning frame before `frame_idx` (if any);
- b) the closest conditioning frame after `frame_idx` (if any);
- c) any other temporally closest conditioning frames until reaching a total
of `max_cond_frame_num` conditioning frames.
Outputs:
- selected_outputs: selected items (keys & values) from `cond_frame_outputs`.
- unselected_outputs: items (keys & values) not selected in `cond_frame_outputs`.
"""
if max_cond_frame_num == -1 or len(cond_frame_outputs) <= max_cond_frame_num:
selected_outputs = cond_frame_outputs
unselected_outputs = {}
else:
assert max_cond_frame_num >= 2, "we should allow using 2+ conditioning frames"
selected_outputs = {}
# the closest conditioning frame before `frame_idx` (if any)
idx_before = max((t for t in cond_frame_outputs if t < frame_idx), default=None)
if idx_before is not None:
selected_outputs[idx_before] = cond_frame_outputs[idx_before]
# the closest conditioning frame after `frame_idx` (if any)
idx_after = min((t for t in cond_frame_outputs if t >= frame_idx), default=None)
if idx_after is not None:
selected_outputs[idx_after] = cond_frame_outputs[idx_after]
# add other temporally closest conditioning frames until reaching a total
# of `max_cond_frame_num` conditioning frames.
num_remain = max_cond_frame_num - len(selected_outputs)
inds_remain = sorted(
(t for t in cond_frame_outputs if t not in selected_outputs),
key=lambda x: abs(x - frame_idx),
)[:num_remain]
selected_outputs.update((t, cond_frame_outputs[t]) for t in inds_remain)
unselected_outputs = {
t: v for t, v in cond_frame_outputs.items() if t not in selected_outputs
}
return selected_outputs, unselected_outputs
def get_1d_sine_pe(pos_inds, dim, temperature=10000):
"""
Get 1D sine positional embedding as in the original Transformer paper.
"""
pe_dim = dim // 2
dim_t = torch.arange(pe_dim, dtype=torch.float32, device=pos_inds.device)
dim_t = temperature ** (2 * (dim_t // 2) / pe_dim)
pos_embed = pos_inds.unsqueeze(-1) / dim_t
pos_embed = torch.cat([pos_embed.sin(), pos_embed.cos()], dim=-1)
return pos_embed
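# Toy sanity sketches for the two helpers above (illustrative only; run this file
# directly): select_closest_cond_frames keeps the temporally closest conditioning
# frames under a budget, and get_1d_sine_pe maps temporal offsets to sine/cosine codes.
if __name__ == "__main__":
    sel, unsel = select_closest_cond_frames(12, {0: "a", 10: "b", 20: "c", 30: "d"}, 2)
    assert sorted(sel) == [10, 20] and sorted(unsel) == [0, 30]
    pe_codes = get_1d_sine_pe(torch.tensor([0.0, 0.5, 1.0]), dim=64)
    assert pe_codes.shape == (3, 64)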
def get_activation_fn(activation):
"""Return an activation function given a string"""
if activation == "relu":
return F.relu
if activation == "gelu":
return F.gelu
if activation == "glu":
return F.glu
raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")
def get_clones(module, N):
return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
class DropPath(nn.Module):
# adapted from https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/drop.py
def __init__(self, drop_prob=0.0, scale_by_keep=True):
super(DropPath, self).__init__()
self.drop_prob = drop_prob
self.scale_by_keep = scale_by_keep
def forward(self, x):
if self.drop_prob == 0.0 or not self.training:
return x
keep_prob = 1 - self.drop_prob
shape = (x.shape[0],) + (1,) * (x.ndim - 1)
random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
if keep_prob > 0.0 and self.scale_by_keep:
random_tensor.div_(keep_prob)
return x * random_tensor
# Lightly adapted from
# https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa
class MLP(nn.Module):
def __init__(
self,
input_dim: int,
hidden_dim: int,
output_dim: int,
num_layers: int,
activation: nn.Module = nn.ReLU,
sigmoid_output: bool = False,
) -> None:
super().__init__()
self.num_layers = num_layers
h = [hidden_dim] * (num_layers - 1)
self.layers = nn.ModuleList(
nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
)
self.sigmoid_output = sigmoid_output
self.act = activation()
def forward(self, x):
for i, layer in enumerate(self.layers):
x = self.act(layer(x)) if i < self.num_layers - 1 else layer(x)
if self.sigmoid_output:
x = F.sigmoid(x)
return x
# From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa
# Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa
class LayerNorm2d(nn.Module):
def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
super().__init__()
self.weight = nn.Parameter(torch.ones(num_channels))
self.bias = nn.Parameter(torch.zeros(num_channels))
self.eps = eps
def forward(self, x: torch.Tensor) -> torch.Tensor:
u = x.mean(1, keepdim=True)
s = (x - u).pow(2).mean(1, keepdim=True)
x = (x - u) / torch.sqrt(s + self.eps)
x = self.weight[:, None, None] * x + self.bias[:, None, None]
return x
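# Quick shape sketch for LayerNorm2d (illustrative only): normalization runs over
# the channel dimension of a BCHW tensor, so the output keeps the input shape.
if __name__ == "__main__":
    layer_norm_2d = LayerNorm2d(8)
    feat = torch.randn(2, 8, 4, 4)
    assert layer_norm_2d(feat).shape == feat.shape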
def sample_box_points(
masks: torch.Tensor,
noise: float = 0.1, # SAM default
noise_bound: int = 20, # SAM default
top_left_label: int = 2,
bottom_right_label: int = 3,
) -> Tuple[np.array, np.array]:
"""
Sample a noised version of the top left and bottom right corners of the bounding box derived from each input mask.
Inputs:
- masks: [B, 1, H, W] masks, dtype=torch.Tensor
- noise: noise as a fraction of box width and height, dtype=float
- noise_bound: maximum amount of noise (in pure pixels), dtype=int
Returns:
- box_coords: [B, num_pt, 2], contains (x, y) coordinates of top left and bottom right box corners, dtype=torch.float
- box_labels: [B, num_pt], label 2 is reserved for top left and 3 for bottom right corners, dtype=torch.int32
"""
device = masks.device
box_coords = mask_to_box(masks)
B, _, H, W = masks.shape
box_labels = torch.tensor(
[top_left_label, bottom_right_label], dtype=torch.int, device=device
).repeat(B)
if noise > 0.0:
if not isinstance(noise_bound, torch.Tensor):
noise_bound = torch.tensor(noise_bound, device=device)
bbox_w = box_coords[..., 2] - box_coords[..., 0]
bbox_h = box_coords[..., 3] - box_coords[..., 1]
max_dx = torch.min(bbox_w * noise, noise_bound)
max_dy = torch.min(bbox_h * noise, noise_bound)
box_noise = 2 * torch.rand(B, 1, 4, device=device) - 1
box_noise = box_noise * torch.stack((max_dx, max_dy, max_dx, max_dy), dim=-1)
box_coords = box_coords + box_noise
img_bounds = (
torch.tensor([W, H, W, H], device=device) - 1
) # uncentered pixel coords
box_coords.clamp_(torch.zeros_like(img_bounds), img_bounds) # In place clamping
box_coords = box_coords.reshape(-1, 2, 2) # always 2 points
box_labels = box_labels.reshape(-1, 2)
return box_coords, box_labels
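# Usage sketch (illustrative only; relies on `mask_to_box` defined earlier in
# this file): turn a dummy object mask into a jittered box prompt encoded as
# two labeled corner points (label 2 = top-left, label 3 = bottom-right).
#
#   masks = torch.zeros(1, 1, 64, 64, dtype=torch.bool)
#   masks[:, :, 16:48, 20:40] = True
#   coords, labels = sample_box_points(masks, noise=0.1, noise_bound=20)
#   # coords: (1, 2, 2) float corner coordinates; labels: (1, 2) == [[2, 3]]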
def sample_random_points_from_errors(gt_masks, pred_masks, num_pt=1):
"""
Sample `num_pt` random points (along with their labels) independently from the error regions.
Inputs:
- gt_masks: [B, 1, H_im, W_im] masks, dtype=torch.bool
- pred_masks: [B, 1, H_im, W_im] masks, dtype=torch.bool or None
- num_pt: int, number of points to sample independently for each of the B error maps
Outputs:
- points: [B, num_pt, 2], dtype=torch.float, contains (x, y) coordinates of each sampled point
- labels: [B, num_pt], dtype=torch.int32, where 1 means positive clicks and 0 means
negative clicks
"""
if pred_masks is None: # if pred_masks is not provided, treat it as empty
pred_masks = torch.zeros_like(gt_masks)
assert gt_masks.dtype == torch.bool and gt_masks.size(1) == 1
assert pred_masks.dtype == torch.bool and pred_masks.shape == gt_masks.shape
assert num_pt >= 0
B, _, H_im, W_im = gt_masks.shape
device = gt_masks.device
# false positive region, a new point sampled in this region should have
# negative label to correct the FP error
fp_masks = ~gt_masks & pred_masks
# false negative region, a new point sampled in this region should have
# positive label to correct the FN error
fn_masks = gt_masks & ~pred_masks
# whether the prediction completely matches the ground-truth on each mask
all_correct = torch.all((gt_masks == pred_masks).flatten(2), dim=2)
all_correct = all_correct[..., None, None]
# channel 0 is FP map, while channel 1 is FN map
pts_noise = torch.rand(B, num_pt, H_im, W_im, 2, device=device)
# sample a negative new click from FP region or a positive new click
# from FN region, depending on where the maximum falls,
# and in case the predictions are all correct (no FP or FN), we just
# sample a negative click from the background region
pts_noise[..., 0] *= fp_masks | (all_correct & ~gt_masks)
pts_noise[..., 1] *= fn_masks
pts_idx = pts_noise.flatten(2).argmax(dim=2)
labels = (pts_idx % 2).to(torch.int32)
pts_idx = pts_idx // 2
pts_x = pts_idx % W_im
pts_y = pts_idx // W_im
points = torch.stack([pts_x, pts_y], dim=2).to(torch.float)
return points, labels
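# Usage sketch (illustrative only): simulate one correction click per mask by
# sampling uniformly from the disagreement between prediction and ground truth.
#
#   gt = torch.zeros(2, 1, 64, 64, dtype=torch.bool)
#   gt[:, :, 10:40, 10:40] = True
#   pred = torch.zeros_like(gt)     # empty prediction, so all errors are false negatives
#   pts, lbls = sample_random_points_from_errors(gt, pred, num_pt=1)
#   # pts: (2, 1, 2) float (x, y) clicks; lbls: (2, 1) int32, all 1 (positive clicks)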
def sample_one_point_from_error_center(gt_masks, pred_masks, padding=True):
"""
Sample 1 random point (along with its label) from the center of each error region,
that is, the point with the largest distance to the boundary of each error region.
This is the RITM sampling method from https://github.com/saic-vul/ritm_interactive_segmentation/blob/master/isegm/inference/clicker.py
Inputs:
- gt_masks: [B, 1, H_im, W_im] masks, dtype=torch.bool
- pred_masks: [B, 1, H_im, W_im] masks, dtype=torch.bool or None
- padding: if True, pad with boundary of 1 px for distance transform
Outputs:
- points: [B, 1, 2], dtype=torch.float, contains (x, y) coordinates of each sampled point
- labels: [B, 1], dtype=torch.int32, where 1 means positive clicks and 0 means negative clicks
"""
import cv2
if pred_masks is None:
pred_masks = torch.zeros_like(gt_masks)
assert gt_masks.dtype == torch.bool and gt_masks.size(1) == 1
assert pred_masks.dtype == torch.bool and pred_masks.shape == gt_masks.shape
B, _, _, W_im = gt_masks.shape
device = gt_masks.device
# false positive region, a new point sampled in this region should have
# negative label to correct the FP error
fp_masks = ~gt_masks & pred_masks
# false negative region, a new point sampled in this region should have
# positive label to correct the FN error
fn_masks = gt_masks & ~pred_masks
fp_masks = fp_masks.cpu().numpy()
fn_masks = fn_masks.cpu().numpy()
points = torch.zeros(B, 1, 2, dtype=torch.float)
labels = torch.ones(B, 1, dtype=torch.int32)
for b in range(B):
fn_mask = fn_masks[b, 0]
fp_mask = fp_masks[b, 0]
if padding:
fn_mask = np.pad(fn_mask, ((1, 1), (1, 1)), "constant")
fp_mask = np.pad(fp_mask, ((1, 1), (1, 1)), "constant")
# compute the distance of each point in FN/FP region to its boundary
fn_mask_dt = cv2.distanceTransform(fn_mask.astype(np.uint8), cv2.DIST_L2, 0)
fp_mask_dt = cv2.distanceTransform(fp_mask.astype(np.uint8), cv2.DIST_L2, 0)
if padding:
fn_mask_dt = fn_mask_dt[1:-1, 1:-1]
fp_mask_dt = fp_mask_dt[1:-1, 1:-1]
# take the point in FN/FP region with the largest distance to its boundary
fn_mask_dt_flat = fn_mask_dt.reshape(-1)
fp_mask_dt_flat = fp_mask_dt.reshape(-1)
fn_argmax = np.argmax(fn_mask_dt_flat)
fp_argmax = np.argmax(fp_mask_dt_flat)
is_positive = fn_mask_dt_flat[fn_argmax] > fp_mask_dt_flat[fp_argmax]
pt_idx = fn_argmax if is_positive else fp_argmax
points[b, 0, 0] = pt_idx % W_im # x
points[b, 0, 1] = pt_idx // W_im # y
labels[b, 0] = int(is_positive)
points = points.to(device)
labels = labels.to(device)
return points, labels
def get_next_point(gt_masks, pred_masks, method):
if method == "uniform":
return sample_random_points_from_errors(gt_masks, pred_masks)
elif method == "center":
return sample_one_point_from_error_center(gt_masks, pred_masks)
else:
raise ValueError(f"unknown sampling method {method}")
================================================
FILE: camera_pose_annotation/dynamic_mask/sam2/sam2_hiera_b+.yaml
================================================
# @package _global_
# Model
model:
_target_: sam2.modeling.sam2_base.SAM2Base
image_encoder:
_target_: sam2.modeling.backbones.image_encoder.ImageEncoder
scalp: 1
trunk:
_target_: sam2.modeling.backbones.hieradet.Hiera
embed_dim: 112
num_heads: 2
neck:
_target_: sam2.modeling.backbones.image_encoder.FpnNeck
position_encoding:
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
num_pos_feats: 256
normalize: true
scale: null
temperature: 10000
d_model: 256
backbone_channel_list: [896, 448, 224, 112]
fpn_top_down_levels: [2, 3] # output levels 0 and 1 directly use the backbone features
fpn_interp_model: nearest
memory_attention:
_target_: sam2.modeling.memory_attention.MemoryAttention
d_model: 256
pos_enc_at_input: true
layer:
_target_: sam2.modeling.memory_attention.MemoryAttentionLayer
activation: relu
dim_feedforward: 2048
dropout: 0.1
pos_enc_at_attn: false
self_attention:
_target_: sam2.modeling.sam.transformer.RoPEAttention
rope_theta: 10000.0
feat_sizes: [64, 64]
embedding_dim: 256
num_heads: 1
downsample_rate: 1
dropout: 0.1
d_model: 256
pos_enc_at_cross_attn_keys: true
pos_enc_at_cross_attn_queries: false
cross_attention:
_target_: sam2.modeling.sam.transformer.RoPEAttention
rope_theta: 10000.0
feat_sizes: [64, 64]
rope_k_repeat: True
embedding_dim: 256
num_heads: 1
downsample_rate: 1
dropout: 0.1
kv_in_dim: 64
num_layers: 4
memory_encoder:
_target_: sam2.modeling.memory_encoder.MemoryEncoder
out_dim: 64
position_encoding:
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
num_pos_feats: 64
normalize: true
scale: null
temperature: 10000
mask_downsampler:
_target_: sam2.modeling.memory_encoder.MaskDownSampler
kernel_size: 3
stride: 2
padding: 1
fuser:
_target_: sam2.modeling.memory_encoder.Fuser
layer:
_target_: sam2.modeling.memory_encoder.CXBlock
dim: 256
kernel_size: 7
padding: 3
layer_scale_init_value: 1e-6
use_dwconv: True # depth-wise convs
num_layers: 2
num_maskmem: 7
image_size: 1024
# apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
sigmoid_scale_for_mem_enc: 20.0
sigmoid_bias_for_mem_enc: -10.0
use_mask_input_as_output_without_sam: true
# Memory
directly_add_no_mem_embed: true
# use high-resolution feature map in the SAM mask decoder
use_high_res_features_in_sam: true
# output 3 masks on the first click on initial conditioning frames
multimask_output_in_sam: true
# SAM heads
iou_prediction_use_sigmoid: True
# cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
use_obj_ptrs_in_encoder: true
add_tpos_enc_to_obj_ptrs: false
only_obj_ptrs_in_the_past_for_eval: true
# object occlusion prediction
pred_obj_scores: true
pred_obj_scores_mlp: true
fixed_no_obj_ptr: true
# multimask tracking settings
multimask_output_for_tracking: true
use_multimask_token_for_obj_ptr: true
multimask_min_pt_num: 0
multimask_max_pt_num: 1
use_mlp_for_obj_ptr_proj: true
# Compilation flag
compile_image_encoder: False
================================================
FILE: camera_pose_annotation/dynamic_mask/sam2/sam2_hiera_l.yaml
================================================
# @package _global_
# Model
model:
_target_: sam2.modeling.sam2_base.SAM2Base
image_encoder:
_target_: sam2.modeling.backbones.image_encoder.ImageEncoder
scalp: 1
trunk:
_target_: sam2.modeling.backbones.hieradet.Hiera
embed_dim: 144
num_heads: 2
stages: [2, 6, 36, 4]
global_att_blocks: [23, 33, 43]
window_pos_embed_bkg_spatial_size: [7, 7]
window_spec: [8, 4, 16, 8]
neck:
_target_: sam2.modeling.backbones.image_encoder.FpnNeck
position_encoding:
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
num_pos_feats: 256
normalize: true
scale: null
temperature: 10000
d_model: 256
backbone_channel_list: [1152, 576, 288, 144]
fpn_top_down_levels: [2, 3] # output levels 0 and 1 directly use the backbone features
fpn_interp_model: nearest
memory_attention:
_target_: sam2.modeling.memory_attention.MemoryAttention
d_model: 256
pos_enc_at_input: true
layer:
_target_: sam2.modeling.memory_attention.MemoryAttentionLayer
activation: relu
dim_feedforward: 2048
dropout: 0.1
pos_enc_at_attn: false
self_attention:
_target_: sam2.modeling.sam.transformer.RoPEAttention
rope_theta: 10000.0
feat_sizes: [64, 64]
embedding_dim: 256
num_heads: 1
downsample_rate: 1
dropout: 0.1
d_model: 256
pos_enc_at_cross_attn_keys: true
pos_enc_at_cross_attn_queries: false
cross_attention:
_target_: sam2.modeling.sam.transformer.RoPEAttention
rope_theta: 10000.0
feat_sizes: [64, 64]
rope_k_repeat: True
embedding_dim: 256
num_heads: 1
downsample_rate: 1
dropout: 0.1
kv_in_dim: 64
num_layers: 4
memory_encoder:
_target_: sam2.modeling.memory_encoder.MemoryEncoder
out_dim: 64
position_encoding:
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
num_pos_feats: 64
normalize: true
scale: null
temperature: 10000
mask_downsampler:
_target_: sam2.modeling.memory_encoder.MaskDownSampler
kernel_size: 3
stride: 2
padding: 1
fuser:
_target_: sam2.modeling.memory_encoder.Fuser
layer:
_target_: sam2.modeling.memory_encoder.CXBlock
dim: 256
kernel_size: 7
padding: 3
layer_scale_init_value: 1e-6
use_dwconv: True # depth-wise convs
num_layers: 2
num_maskmem: 7
image_size: 1024
# apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
sigmoid_scale_for_mem_enc: 20.0
sigmoid_bias_for_mem_enc: -10.0
use_mask_input_as_output_without_sam: true
# Memory
directly_add_no_mem_embed: true
# use high-resolution feature map in the SAM mask decoder
use_high_res_features_in_sam: true
# output 3 masks on the first click on initial conditioning frames
multimask_output_in_sam: true
# SAM heads
iou_prediction_use_sigmoid: True
# cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
use_obj_ptrs_in_encoder: true
add_tpos_enc_to_obj_ptrs: false
only_obj_ptrs_in_the_past_for_eval: true
# object occlusion prediction
pred_obj_scores: true
pred_obj_scores_mlp: true
fixed_no_obj_ptr: true
# multimask tracking settings
multimask_output_for_tracking: true
use_multimask_token_for_obj_ptr: true
multimask_min_pt_num: 0
multimask_max_pt_num: 1
use_mlp_for_obj_ptr_proj: true
# Compilation flag
compile_image_encoder: False
================================================
FILE: camera_pose_annotation/dynamic_mask/sam2/sam2_hiera_s.yaml
================================================
# @package _global_
# Model
model:
_target_: sam2.modeling.sam2_base.SAM2Base
image_encoder:
_target_: sam2.modeling.backbones.image_encoder.ImageEncoder
scalp: 1
trunk:
_target_: sam2.modeling.backbones.hieradet.Hiera
embed_dim: 96
num_heads: 1
stages: [1, 2, 11, 2]
global_att_blocks: [7, 10, 13]
window_pos_embed_bkg_spatial_size: [7, 7]
neck:
_target_: sam2.modeling.backbones.image_encoder.FpnNeck
position_encoding:
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
num_pos_feats: 256
normalize: true
scale: null
temperature: 10000
d_model: 256
backbone_channel_list: [768, 384, 192, 96]
fpn_top_down_levels: [2, 3] # output levels 0 and 1 directly use the backbone features
fpn_interp_model: nearest
memory_attention:
_target_: sam2.modeling.memory_attention.MemoryAttention
d_model: 256
pos_enc_at_input: true
layer:
_target_: sam2.modeling.memory_attention.MemoryAttentionLayer
activation: relu
dim_feedforward: 2048
dropout: 0.1
pos_enc_at_attn: false
self_attention:
_target_: sam2.modeling.sam.transformer.RoPEAttention
rope_theta: 10000.0
feat_sizes: [64, 64]
embedding_dim: 256
num_heads: 1
downsample_rate: 1
dropout: 0.1
d_model: 256
pos_enc_at_cross_attn_keys: true
pos_enc_at_cross_attn_queries: false
cross_attention:
_target_: sam2.modeling.sam.transformer.RoPEAttention
rope_theta: 10000.0
feat_sizes: [64, 64]
rope_k_repeat: True
embedding_dim: 256
num_heads: 1
downsample_rate: 1
dropout: 0.1
kv_in_dim: 64
num_layers: 4
memory_encoder:
_target_: sam2.modeling.memory_encoder.MemoryEncoder
out_dim: 64
position_encoding:
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
num_pos_feats: 64
normalize: true
scale: null
temperature: 10000
mask_downsampler:
_target_: sam2.modeling.memory_encoder.MaskDownSampler
kernel_size: 3
stride: 2
padding: 1
fuser:
_target_: sam2.modeling.memory_encoder.Fuser
layer:
_target_: sam2.modeling.memory_encoder.CXBlock
dim: 256
kernel_size: 7
padding: 3
layer_scale_init_value: 1e-6
use_dwconv: True # depth-wise convs
num_layers: 2
num_maskmem: 7
image_size: 1024
# apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
sigmoid_scale_for_mem_enc: 20.0
sigmoid_bias_for_mem_enc: -10.0
use_mask_input_as_output_without_sam: true
# Memory
directly_add_no_mem_embed: true
# use high-resolution feature map in the SAM mask decoder
use_high_res_features_in_sam: true
# output 3 masks on the first click on initial conditioning frames
multimask_output_in_sam: true
# SAM heads
iou_prediction_use_sigmoid: True
# cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
use_obj_ptrs_in_encoder: true
add_tpos_enc_to_obj_ptrs: false
only_obj_ptrs_in_the_past_for_eval: true
# object occlusion prediction
pred_obj_scores: true
pred_obj_scores_mlp: true
fixed_no_obj_ptr: true
# multimask tracking settings
multimask_output_for_tracking: true
use_multimask_token_for_obj_ptr: true
multimask_min_pt_num: 0
multimask_max_pt_num: 1
use_mlp_for_obj_ptr_proj: true
# Compilation flag
compile_image_encoder: False
================================================
FILE: camera_pose_annotation/dynamic_mask/sam2/sam2_hiera_t.yaml
================================================
# @package _global_
# Model
model:
_target_: sam2.modeling.sam2_base.SAM2Base
image_encoder:
_target_: sam2.modeling.backbones.image_encoder.ImageEncoder
scalp: 1
trunk:
_target_: sam2.modeling.backbones.hieradet.Hiera
embed_dim: 96
num_heads: 1
stages: [1, 2, 7, 2]
global_att_blocks: [5, 7, 9]
window_pos_embed_bkg_spatial_size: [7, 7]
neck:
_target_: sam2.modeling.backbones.image_encoder.FpnNeck
position_encoding:
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
num_pos_feats: 256
normalize: true
scale: null
temperature: 10000
d_model: 256
backbone_channel_list: [768, 384, 192, 96]
fpn_top_down_levels: [2, 3] # output levels 0 and 1 directly use the backbone features
fpn_interp_model: nearest
memory_attention:
_target_: sam2.modeling.memory_attention.MemoryAttention
d_model: 256
pos_enc_at_input: true
layer:
_target_: sam2.modeling.memory_attention.MemoryAttentionLayer
activation: relu
dim_feedforward: 2048
dropout: 0.1
pos_enc_at_attn: false
self_attention:
_target_: sam2.modeling.sam.transformer.RoPEAttention
rope_theta: 10000.0
feat_sizes: [64, 64]
embedding_dim: 256
num_heads: 1
downsample_rate: 1
dropout: 0.1
d_model: 256
pos_enc_at_cross_attn_keys: true
pos_enc_at_cross_attn_queries: false
cross_attention:
_target_: sam2.modeling.sam.transformer.RoPEAttention
rope_theta: 10000.0
feat_sizes: [64, 64]
rope_k_repeat: True
embedding_dim: 256
num_heads: 1
downsample_rate: 1
dropout: 0.1
kv_in_dim: 64
num_layers: 4
memory_encoder:
_target_: sam2.modeling.memory_encoder.MemoryEncoder
out_dim: 64
position_encoding:
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
num_pos_feats: 64
normalize: true
scale: null
temperature: 10000
mask_downsampler:
_target_: sam2.modeling.memory_encoder.MaskDownSampler
kernel_size: 3
stride: 2
padding: 1
fuser:
_target_: sam2.modeling.memory_encoder.Fuser
layer:
_target_: sam2.modeling.memory_encoder.CXBlock
dim: 256
kernel_size: 7
padding: 3
layer_scale_init_value: 1e-6
use_dwconv: True # depth-wise convs
num_layers: 2
num_maskmem: 7
image_size: 1024
# apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
# SAM decoder
sigmoid_scale_for_mem_enc: 20.0
sigmoid_bias_for_mem_enc: -10.0
use_mask_input_as_output_without_sam: true
# Memory
directly_add_no_mem_embed: true
# use high-resolution feature map in the SAM mask decoder
use_high_res_features_in_sam: true
# output 3 masks on the first click on initial conditioning frames
multimask_output_in_sam: true
# SAM heads
iou_prediction_use_sigmoid: True
# cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
use_obj_ptrs_in_encoder: true
add_tpos_enc_to_obj_ptrs: false
only_obj_ptrs_in_the_past_for_eval: true
# object occlusion prediction
pred_obj_scores: true
pred_obj_scores_mlp: true
fixed_no_obj_ptr: true
# multimask tracking settings
multimask_output_for_tracking: true
use_multimask_token_for_obj_ptr: true
multimask_min_pt_num: 0
multimask_max_pt_num: 1
use_mlp_for_obj_ptr_proj: true
# Compilation flag
# HieraT does not currently support compilation, should always be set to False
compile_image_encoder: False
================================================
FILE: camera_pose_annotation/dynamic_mask/sam2/sam2_image_predictor.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import logging
from typing import List, Optional, Tuple, Union
import numpy as np
import torch
from PIL.Image import Image
from sam2.modeling.sam2_base import SAM2Base
from sam2.utils.transforms import SAM2Transforms
class SAM2ImagePredictor:
def __init__(
self,
sam_model: SAM2Base,
mask_threshold=0.0,
max_hole_area=0.0,
max_sprinkle_area=0.0,
**kwargs,
) -> None:
"""
Uses SAM-2 to calculate the image embedding for an image, and then
allows repeated, efficient mask prediction given prompts.
Arguments:
sam_model (SAM2Base): The model to use for mask prediction.
mask_threshold (float): The threshold to use when converting mask logits
to binary masks. Masks are thresholded at 0 by default.
max_hole_area (int): If max_hole_area > 0, we fill small holes with area up to
max_hole_area in low_res_masks.
max_sprinkle_area (int): If max_sprinkle_area > 0, we remove small sprinkles with area up to
max_sprinkle_area in low_res_masks.
"""
super().__init__()
self.model = sam_model
self._transforms = SAM2Transforms(
resolution=self.model.image_size,
mask_threshold=mask_threshold,
max_hole_area=max_hole_area,
max_sprinkle_area=max_sprinkle_area,
)
# Predictor state
self._is_image_set = False
self._features = None
self._orig_hw = None
# Whether the predictor is set for single image or a batch of images
self._is_batch = False
# Predictor config
self.mask_threshold = mask_threshold
# Spatial dim for backbone feature maps
self._bb_feat_sizes = [
(256, 256),
(128, 128),
(64, 64),
]
@classmethod
def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2ImagePredictor":
"""
Load a pretrained model from the Hugging Face hub.
Arguments:
model_id (str): The Hugging Face repository ID.
**kwargs: Additional arguments to pass to the model constructor.
Returns:
(SAM2ImagePredictor): The loaded model.
"""
from sam2.build_sam import build_sam2_hf
sam_model = build_sam2_hf(model_id, **kwargs)
return cls(sam_model, **kwargs)
@torch.no_grad()
def set_image(
self,
image: Union[np.ndarray, Image],
) -> None:
"""
Calculates the image embeddings for the provided image, allowing
masks to be predicted with the 'predict' method.
Arguments:
image (np.ndarray or PIL Image): The input image to embed, in RGB format with
pixel values in [0, 255]. np.ndarray inputs are expected in HWC layout; PIL
images are converted internally.
"""
self.reset_predictor()
# Transform the image to the form expected by the model
if isinstance(image, np.ndarray):
logging.info("For numpy array image, we assume (HxWxC) format")
self._orig_hw = [image.shape[:2]]
elif isinstance(image, Image):
w, h = image.size
self._orig_hw = [(h, w)]
else:
raise NotImplementedError("Image format not supported")
input_image = self._transforms(image)
input_image = input_image[None, ...].to(self.device)
assert (
len(input_image.shape) == 4 and input_image.shape[1] == 3
), f"input_image must be of size 1x3xHxW, got {input_image.shape}"
logging.info("Computing image embeddings for the provided image...")
backbone_out = self.model.forward_image(input_image)
_, vision_feats, _, _ = self.model._prepare_backbone_features(backbone_out)
# Add no_mem_embed, which is added to the lowest-resolution feature map during training on videos
if self.model.directly_add_no_mem_embed:
vision_feats[-1] = vision_feats[-1] + self.model.no_mem_embed
feats = [
feat.permute(1, 2, 0).view(1, -1, *feat_size)
for feat, feat_size in zip(vision_feats[::-1], self._bb_feat_sizes[::-1])
][::-1]
self._features = {"image_embed": feats[-1], "high_res_feats": feats[:-1]}
self._is_image_set = True
logging.info("Image embeddings computed.")
@torch.no_grad()
def set_image_batch(
self,
image_list: List[np.ndarray],
) -> None:
"""
Calculates the image embeddings for the provided image batch, allowing
masks to be predicted with the 'predict_batch' method.
Arguments:
image_list (List[np.ndarray]): The input images to embed in RGB format. The image should be in HWC format if np.ndarray
with pixel values in [0, 255].
"""
self.reset_predictor()
assert isinstance(image_list, list)
self._orig_hw = []
for image in image_list:
assert isinstance(
image, np.ndarray
), "Images are expected to be an np.ndarray in RGB format, and of shape HWC"
self._orig_hw.append(image.shape[:2])
# Transform the image to the form expected by the model
img_batch = self._transforms.forward_batch(image_list)
img_batch = img_batch.to(self.device)
batch_size = img_batch.shape[0]
assert (
len(img_batch.shape) == 4 and img_batch.shape[1] == 3
), f"img_batch must be of size Bx3xHxW, got {img_batch.shape}"
logging.info("Computing image embeddings for the provided images...")
backbone_out = self.model.forward_image(img_batch)
_, vision_feats, _, _ = self.model._prepare_backbone_features(backbone_out)
# Add no_mem_embed, which is added to the lowest-resolution feature map during training on videos
if self.model.directly_add_no_mem_embed:
vision_feats[-1] = vision_feats[-1] + self.model.no_mem_embed
feats = [
feat.permute(1, 2, 0).view(batch_size, -1, *feat_size)
for feat, feat_size in zip(vision_feats[::-1], self._bb_feat_sizes[::-1])
][::-1]
self._features = {"image_embed": feats[-1], "high_res_feats": feats[:-1]}
self._is_image_set = True
self._is_batch = True
logging.info("Image embeddings computed.")
def predict_batch(
self,
point_coords_batch: List[np.ndarray] = None,
point_labels_batch: List[np.ndarray] = None,
box_batch: List[np.ndarray] = None,
mask_input_batch: List[np.ndarray] = None,
multimask_output: bool = True,
return_logits: bool = False,
normalize_coords=True,
) -> Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray]]:
"""This function is very similar to predict(...), however it is used for batched mode, when the model is expected to generate predictions on multiple images.
It returns a tuple of lists of masks, ious, and low_res_masks_logits.
"""
assert self._is_batch, "This function should only be used when in batched mode"
if not self._is_image_set:
raise RuntimeError(
"An image must be set with .set_image_batch(...) before mask prediction."
)
num_images = len(self._features["image_embed"])
all_masks = []
all_ious = []
all_low_res_masks = []
for img_idx in range(num_images):
# Transform input prompts
point_coords = (
point_coords_batch[img_idx] if point_coords_batch is not None else None
)
point_labels = (
point_labels_batch[img_idx] if point_labels_batch is not None else None
)
box = box_batch[img_idx] if box_batch is not None else None
mask_input = (
mask_input_batch[img_idx] if mask_input_batch is not None else None
)
mask_input, unnorm_coords, labels, unnorm_box = self._prep_prompts(
point_coords,
point_labels,
box,
mask_input,
normalize_coords,
img_idx=img_idx,
)
masks, iou_predictions, low_res_masks = self._predict(
unnorm_coords,
labels,
unnorm_box,
mask_input,
multimask_output,
return_logits=return_logits,
img_idx=img_idx,
)
masks_np = masks.squeeze(0).float().detach().cpu().numpy()
iou_predictions_np = (
iou_predictions.squeeze(0).float().detach().cpu().numpy()
)
low_res_masks_np = low_res_masks.squeeze(0).float().detach().cpu().numpy()
all_masks.append(masks_np)
all_ious.append(iou_predictions_np)
all_low_res_masks.append(low_res_masks_np)
return all_masks, all_ious, all_low_res_masks
def predict(
self,
point_coords: Optional[np.ndarray] = None,
point_labels: Optional[np.ndarray] = None,
box: Optional[np.ndarray] = None,
mask_input: Optional[np.ndarray] = None,
multimask_output: bool = True,
return_logits: bool = False,
normalize_coords=True,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
"""
Predict masks for the given input prompts, using the currently set image.
Arguments:
point_coords (np.ndarray or None): A Nx2 array of point prompts to the
model. Each point is in (X,Y) in pixels.
point_labels (np.ndarray or None): A length N array of labels for the
point prompts. 1 indicates a foreground point and 0 indicates a
background point.
box (np.ndarray or None): A length 4 array giving a box prompt to the
model, in XYXY format.
mask_input (np.ndarray): A low resolution mask input to the model, typically
coming from a previous prediction iteration. Has form 1xHxW, where
for SAM, H=W=256.
multimask_output (bool): If true, the model will return three masks.
For ambiguous input prompts (such as a single click), this will often
produce better masks than a single prediction. If only a single
mask is needed, the model's predicted quality score can be used
to select the best mask. For non-ambiguous prompts, such as multiple
input prompts, multimask_output=False can give better results.
return_logits (bool): If true, returns un-thresholded masks logits
instead of a binary mask.
normalize_coords (bool): If true, point_coords is expected in pixel units with respect to the original image dimensions and will be normalized to the range [0, 1] internally.
Returns:
(np.ndarray): The output masks in CxHxW format, where C is the
number of masks, and (H, W) is the original image size.
(np.ndarray): An array of length C containing the model's
predictions for the quality of each mask.
(np.ndarray): An array of shape CxHxW, where C is the number
of masks and H=W=256. These low resolution logits can be passed to
a subsequent iteration as mask input.
"""
if not self._is_image_set:
raise RuntimeError(
"An image must be set with .set_image(...) before mask prediction."
)
# Transform input prompts
mask_input, unnorm_coords, labels, unnorm_box = self._prep_prompts(
point_coords, point_labels, box, mask_input, normalize_coords
)
masks, iou_predictions, low_res_masks = self._predict(
unnorm_coords,
labels,
unnorm_box,
mask_input,
multimask_output,
return_logits=return_logits,
)
masks_np = masks.squeeze(0).float().detach().cpu().numpy()
iou_predictions_np = iou_predictions.squeeze(0).float().detach().cpu().numpy()
low_res_masks_np = low_res_masks.squeeze(0).float().detach().cpu().numpy()
return masks_np, iou_predictions_np, low_res_masks_np
def _prep_prompts(
self, point_coords, point_labels, box, mask_logits, normalize_coords, img_idx=-1
):
unnorm_coords, labels, unnorm_box, mask_input = None, None, None, None
if point_coords is not None:
assert (
point_labels is not None
), "point_labels must be supplied if point_coords is supplied."
point_coords = torch.as_tensor(
point_coords, dtype=torch.float, device=self.device
)
unnorm_coords = self._transforms.transform_coords(
point_coords, normalize=normalize_coords, orig_hw=self._orig_hw[img_idx]
)
labels = torch.as_tensor(point_labels, dtype=torch.int, device=self.device)
if len(unnorm_coords.shape) == 2:
unnorm_coords, labels = unnorm_coords[None, ...], labels[None, ...]
if box is not None:
box = torch.as_tensor(box, dtype=torch.float, device=self.device)
unnorm_box = self._transforms.transform_boxes(
box, normalize=normalize_coords, orig_hw=self._orig_hw[img_idx]
) # Bx2x2
if mask_logits is not None:
mask_input = torch.as_tensor(
mask_logits, dtype=torch.float, device=self.device
)
if len(mask_input.shape) == 3:
mask_input = mask_input[None, :, :, :]
return mask_input, unnorm_coords, labels, unnorm_box
@torch.no_grad()
def _predict(
self,
point_coords: Optional[torch.Tensor],
point_labels: Optional[torch.Tensor],
boxes: Optional[torch.Tensor] = None,
mask_input: Optional[torch.Tensor] = None,
multimask_output: bool = True,
return_logits: bool = False,
img_idx: int = -1,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Predict masks for the given input prompts, using the currently set image.
Input prompts are batched torch tensors and are expected to already be
transformed to the input frame using SAM2Transforms.
Arguments:
point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the
model. Each point is in (X,Y) in pixels.
point_labels (torch.Tensor or None): A BxN array of labels for the
point prompts. 1 indicates a foreground point and 0 indicates a
background point.
boxes (torch.Tensor or None): A Bx4 array giving a box prompt to the
model, in XYXY format.
mask_input (torch.Tensor or None): A low resolution mask input to the model, typically
coming from a previous prediction iteration. Has form Bx1xHxW, where
for SAM, H=W=256. Masks returned by a previous iteration of the
predict method do not need further transformation.
multimask_output (bool): If true, the model will return three masks.
For ambiguous input prompts (such as a single click), this will often
produce better masks than a single prediction. If only a single
mask is needed, the model's predicted quality score can be used
to select the best mask. For non-ambiguous prompts, such as multiple
input prompts, multimask_output=False can give better results.
return_logits (bool): If true, returns un-thresholded masks logits
instead of a binary mask.
Returns:
(torch.Tensor): The output masks in BxCxHxW format, where C is the
number of masks, and (H, W) is the original image size.
(torch.Tensor): An array of shape BxC containing the model's
predictions for the quality of each mask.
(torch.Tensor): An array of shape BxCxHxW, where C is the number
of masks and H=W=256. These low res logits can be passed to
a subsequent iteration as mask input.
"""
if not self._is_image_set:
raise RuntimeError(
"An image must be set with .set_image(...) before mask prediction."
)
if point_coords is not None:
concat_points = (point_coords, point_labels)
else:
concat_points = None
# Embed prompts
if boxes is not None:
box_coords = boxes.reshape(-1, 2, 2)
box_labels = torch.tensor([[2, 3]], dtype=torch.int, device=boxes.device)
box_labels = box_labels.repeat(boxes.size(0), 1)
# we merge "boxes" and "points" into a single "concat_points" input (where
# boxes are added at the beginning) to sam_prompt_encoder
if concat_points is not None:
concat_coords = torch.cat([box_coords, concat_points[0]], dim=1)
concat_labels = torch.cat([box_labels, concat_points[1]], dim=1)
concat_points = (concat_coords, concat_labels)
else:
concat_points = (box_coords, box_labels)
sparse_embeddings, dense_embeddings = self.model.sam_prompt_encoder(
points=concat_points,
boxes=None,
masks=mask_input,
)
# Predict masks
batched_mode = (
concat_points is not None and concat_points[0].shape[0] > 1
) # multi object prediction
high_res_features = [
feat_level[img_idx].unsqueeze(0)
for feat_level in self._features["high_res_feats"]
]
low_res_masks, iou_predictions, _, _ = self.model.sam_mask_decoder(
image_embeddings=self._features["image_embed"][img_idx].unsqueeze(0),
image_pe=self.model.sam_prompt_encoder.get_dense_pe(),
sparse_prompt_embeddings=sparse_embeddings,
dense_prompt_embeddings=dense_embeddings,
multimask_output=multimask_output,
repeat_image=batched_mode,
high_res_features=high_res_features,
)
# Upscale the masks to the original image resolution
masks = self._transforms.postprocess_masks(
low_res_masks, self._orig_hw[img_idx]
)
low_res_masks = torch.clamp(low_res_masks, -32.0, 32.0)
if not return_logits:
masks = masks > self.mask_threshold
return masks, iou_predictions, low_res_masks
def get_image_embedding(self) -> torch.Tensor:
"""
Returns the image embeddings for the currently set image, with
shape 1xCxHxW, where C is the embedding dimension and (H,W) are
the embedding spatial dimension of SAM (typically C=256, H=W=64).
"""
if not self._is_image_set:
raise RuntimeError(
"An image must be set with .set_image(...) to generate an embedding."
)
assert (
self._features is not None
), "Features must exist if an image has been set."
return self._features["image_embed"]
@property
def device(self) -> torch.device:
return self.model.device
def reset_predictor(self) -> None:
"""
Resets the image embeddings and other state variables.
"""
self._is_image_set = False
self._features = None
self._orig_hw = None
self._is_batch = False
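# Usage sketch (illustrative only; the Hugging Face model id is an example and
# not verified against this repo): single-image prompting with one positive
# point click, using the numpy import at the top of this file.
#
#   predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2-hiera-large")
#   image = np.zeros((480, 640, 3), dtype=np.uint8)           # dummy RGB HWC image
#   predictor.set_image(image)
#   masks, ious, low_res = predictor.predict(
#       point_coords=np.array([[320, 240]]),                  # (x, y) in pixels
#       point_labels=np.array([1]),                           # 1 = foreground click
#       multimask_output=True,
#   )
#   # masks: (3, 480, 640) bool; ious: (3,); low_res: (3, 256, 256) logits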
================================================
FILE: camera_pose_annotation/dynamic_mask/sam2/sam2_video_predictor.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import warnings
from collections import OrderedDict
import torch
import torch.nn.functional as F
from tqdm import tqdm
from sam2.modeling.sam2_base import NO_OBJ_SCORE, SAM2Base
from sam2.utils.misc import concat_points, fill_holes_in_mask_scores, load_video_frames
class SAM2VideoPredictor(SAM2Base):
"""The predictor class to handle user interactions and manage inference states."""
def __init__(
self,
fill_hole_area=0,
# whether to apply non-overlapping constraints on the output object masks
non_overlap_masks=False,
# whether to clear non-conditioning memory of the surrounding frames (which may contain outdated information) after adding correction clicks;
# note that this would only apply to *single-object tracking* unless `clear_non_cond_mem_for_multi_obj` is also set to True
clear_non_cond_mem_around_input=False,
# if `add_all_frames_to_correct_as_cond` is True, we also append to the conditioning frame list any frame that receives a later correction click
# if `add_all_frames_to_correct_as_cond` is False, we keep the conditioning frame list restricted to the initial conditioning frames
add_all_frames_to_correct_as_cond=False,
**kwargs,
):
super().__init__(**kwargs)
self.fill_hole_area = fill_hole_area
self.non_overlap_masks = non_overlap_masks
self.clear_non_cond_mem_around_input = clear_non_cond_mem_around_input
self.add_all_frames_to_correct_as_cond = add_all_frames_to_correct_as_cond
@torch.inference_mode()
def init_state(
self,
video_path,
offload_video_to_cpu=False,
offload_state_to_cpu=False,
async_loading_frames=False,
):
"""Initialize an inference state."""
compute_device = self.device # device of the model
images, video_height, video_width = load_video_frames(
video_path=video_path,
image_size=self.image_size,
offload_video_to_cpu=offload_video_to_cpu,
async_loading_frames=async_loading_frames,
compute_device=compute_device,
)
inference_state = {}
inference_state["images"] = images
inference_state["num_frames"] = len(images)
# whether to offload the video frames to CPU memory
# turning on this option saves the GPU memory with only a very small overhead
inference_state["offload_video_to_cpu"] = offload_video_to_cpu
# whether to offload the inference state to CPU memory
# turning on this option saves the GPU memory at the cost of a lower tracking fps
# (e.g. in a test case of 768x768 model, fps dropped from 27 to 24 when tracking one object
# and from 24 to 21 when tracking two objects)
inference_state["offload_state_to_cpu"] = offload_state_to_cpu
# the original video height and width, used for resizing final output scores
inference_state["video_height"] = video_height
inference_state["video_width"] = video_width
inference_state["device"] = compute_device
if offload_state_to_cpu:
inference_state["storage_device"] = torch.device("cpu")
else:
inference_state["storage_device"] = compute_device
# inputs on each frame
inference_state["point_inputs_per_obj"] = {}
inference_state["mask_inputs_per_obj"] = {}
# visual features on a small number of recently visited frames for quick interactions
inference_state["cached_features"] = {}
# values that don't change across frames (so we only need to hold one copy of them)
inference_state["constants"] = {}
# mapping between client-side object id and model-side object index
inference_state["obj_id_to_idx"] = OrderedDict()
inference_state["obj_idx_to_id"] = OrderedDict()
inference_state["obj_ids"] = []
# Slice (view) of each object tracking results, sharing the same memory with "output_dict"
inference_state["output_dict_per_obj"] = {}
# A temporary storage to hold new outputs when user interact with a frame
# to add clicks or mask (it's merged into "output_dict" before propagation starts)
inference_state["temp_output_dict_per_obj"] = {}
# Frames that already hold consolidated outputs from click or mask inputs
# (we directly use their consolidated outputs during tracking)
# metadata for each tracking frame (e.g. which direction it's tracked)
inference_state["frames_tracked_per_obj"] = {}
# Warm up the visual backbone and cache the image feature on frame 0
self._get_image_feature(inference_state, frame_idx=0, batch_size=1)
return inference_state
@classmethod
def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2VideoPredictor":
"""
Load a pretrained model from the Hugging Face hub.
Arguments:
model_id (str): The Hugging Face repository ID.
**kwargs: Additional arguments to pass to the model constructor.
Returns:
(SAM2VideoPredictor): The loaded model.
"""
from sam2.build_sam import build_sam2_video_predictor_hf
sam_model = build_sam2_video_predictor_hf(model_id, **kwargs)
return sam_model
def _obj_id_to_idx(self, inference_state, obj_id):
"""Map client-side object id to model-side object index."""
obj_idx = inference_state["obj_id_to_idx"].get(obj_id, None)
if obj_idx is not None:
return obj_idx
# We always allow adding new objects (including after tracking starts).
allow_new_object = True
if allow_new_object:
# get the next object slot
obj_idx = len(inference_state["obj_id_to_idx"])
inference_state["obj_id_to_idx"][obj_id] = obj_idx
inference_state["obj_idx_to_id"][obj_idx] = obj_id
inference_state["obj_ids"] = list(inference_state["obj_id_to_idx"])
# set up input and output structures for this object
inference_state["point_inputs_per_obj"][obj_idx] = {}
inference_state["mask_inputs_per_obj"][obj_idx] = {}
inference_state["output_dict_per_obj"][obj_idx] = {
"cond_frame_outputs": {}, # dict containing {frame_idx: }
"non_cond_frame_outputs": {}, # dict containing {frame_idx: }
}
inference_state["temp_output_dict_per_obj"][obj_idx] = {
"cond_frame_outputs": {}, # dict containing {frame_idx: }
"non_cond_frame_outputs": {}, # dict containing {frame_idx: }
}
inference_state["frames_tracked_per_obj"][obj_idx] = {}
return obj_idx
else:
raise RuntimeError(
f"Cannot add new object id {obj_id} after tracking starts. "
f"All existing object ids: {inference_state['obj_ids']}. "
f"Please call 'reset_state' to restart from scratch."
)
def _obj_idx_to_id(self, inference_state, obj_idx):
"""Map model-side object index to client-side object id."""
return inference_state["obj_idx_to_id"][obj_idx]
def _get_obj_num(self, inference_state):
"""Get the total number of unique object ids received so far in this session."""
return len(inference_state["obj_idx_to_id"])
@torch.inference_mode()
def add_new_points_or_box(
self,
inference_state,
frame_idx,
obj_id,
points=None,
labels=None,
clear_old_points=True,
normalize_coords=True,
box=None,
):
"""Add new points to a frame."""
obj_idx = self._obj_id_to_idx(inference_state, obj_id)
point_inputs_per_frame = inference_state["point_inputs_per_obj"][obj_idx]
mask_inputs_per_frame = inference_state["mask_inputs_per_obj"][obj_idx]
if (points is not None) != (labels is not None):
raise ValueError("points and labels must be provided together")
if points is None and box is None:
raise ValueError("at least one of points or box must be provided as input")
if points is None:
points = torch.zeros(0, 2, dtype=torch.float32)
elif not isinstance(points, torch.Tensor):
points = torch.tensor(points, dtype=torch.float32)
if labels is None:
labels = torch.zeros(0, dtype=torch.int32)
elif not isinstance(labels, torch.Tensor):
labels = torch.tensor(labels, dtype=torch.int32)
if points.dim() == 2:
points = points.unsqueeze(0) # add batch dimension
if labels.dim() == 1:
labels = labels.unsqueeze(0) # add batch dimension
# If `box` is provided, we add it as the first two points with labels 2 and 3
# along with the user-provided points (consistent with how SAM 2 is trained).
if box is not None:
if not clear_old_points:
raise ValueError(
"cannot add box without clearing old points, since "
"box prompt must be provided before any point prompt "
"(please use clear_old_points=True instead)"
)
if not isinstance(box, torch.Tensor):
box = torch.tensor(box, dtype=torch.float32, device=points.device)
box_coords = box.reshape(1, 2, 2)
box_labels = torch.tensor([2, 3], dtype=torch.int32, device=labels.device)
box_labels = box_labels.reshape(1, 2)
points = torch.cat([box_coords, points], dim=1)
labels = torch.cat([box_labels, labels], dim=1)
if normalize_coords:
video_H = inference_state["video_height"]
video_W = inference_state["video_width"]
points = points / torch.tensor([video_W, video_H]).to(points.device)
# scale the (normalized) coordinates by the model's internal image size
points = points * self.image_size
points = points.to(inference_state["device"])
labels = labels.to(inference_state["device"])
if not clear_old_points:
point_inputs = point_inputs_per_frame.get(frame_idx, None)
else:
point_inputs = None
point_inputs = concat_points(point_inputs, points, labels)
point_inputs_per_frame[frame_idx] = point_inputs
mask_inputs_per_frame.pop(frame_idx, None)
# If this frame hasn't been tracked before, we treat it as an initial conditioning
# frame, meaning that the input points are used to generate segments on this frame without
# using any memory from other frames, like in SAM. Otherwise (if it has been tracked),
# the input points will be used to correct the already tracked masks.
obj_frames_tracked = inference_state["frames_tracked_per_obj"][obj_idx]
is_init_cond_frame = frame_idx not in obj_frames_tracked
# whether to track in reverse time order
if is_init_cond_frame:
reverse = False
else:
reverse = obj_frames_tracked[frame_idx]["reverse"]
obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
# Add a frame to conditioning output if it's an initial conditioning frame or
# if the model sees all frames receiving clicks/mask as conditioning frames.
is_cond = is_init_cond_frame or self.add_all_frames_to_correct_as_cond
storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
# Get any previously predicted mask logits on this object and feed it along with
# the new clicks into the SAM mask decoder.
prev_sam_mask_logits = None
# lookup temporary output dict first, which contains the most recent output
# (if not found, then lookup conditioning and non-conditioning frame output)
prev_out = obj_temp_output_dict[storage_key].get(frame_idx)
if prev_out is None:
prev_out = obj_output_dict["cond_frame_outputs"].get(frame_idx)
if prev_out is None:
prev_out = obj_output_dict["non_cond_frame_outputs"].get(frame_idx)
if prev_out is not None and prev_out["pred_masks"] is not None:
device = inference_state["device"]
prev_sam_mask_logits = prev_out["pred_masks"].to(device, non_blocking=True)
# Clamp the scale of prev_sam_mask_logits to avoid rare numerical issues.
prev_sam_mask_logits = torch.clamp(prev_sam_mask_logits, -32.0, 32.0)
current_out, _ = self._run_single_frame_inference(
inference_state=inference_state,
output_dict=obj_output_dict, # run on the slice of a single object
frame_idx=frame_idx,
batch_size=1, # run on the slice of a single object
is_init_cond_frame=is_init_cond_frame,
point_inputs=point_inputs,
mask_inputs=None,
reverse=reverse,
# Skip the memory encoder when adding clicks or mask. We execute the memory encoder
# at the beginning of `propagate_in_video` (after the user finalizes their clicks). This
# allows us to enforce non-overlapping constraints on all objects before encoding
# them into memory.
run_mem_encoder=False,
prev_sam_mask_logits=prev_sam_mask_logits,
)
# Add the output to the output dict (to be used as future memory)
obj_temp_output_dict[storage_key][frame_idx] = current_out
# Resize the output mask to the original video resolution
obj_ids = inference_state["obj_ids"]
consolidated_out = self._consolidate_temp_output_across_obj(
inference_state,
frame_idx,
is_cond=is_cond,
consolidate_at_video_res=True,
)
_, video_res_masks = self._get_orig_video_res_output(
inference_state, consolidated_out["pred_masks_video_res"]
)
return frame_idx, obj_ids, video_res_masks
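# Usage sketch (illustrative only; the model id and video path are placeholders):
# initialize an inference state and prompt frame 0 of a clip with one click.
#
#   predictor = SAM2VideoPredictor.from_pretrained("facebook/sam2-hiera-large")
#   state = predictor.init_state(video_path="/path/to/frames_or_video")
#   frame_idx, obj_ids, masks = predictor.add_new_points_or_box(
#       inference_state=state,
#       frame_idx=0,
#       obj_id=1,
#       points=[[210.0, 350.0]],     # (x, y) in original video pixel coordinates
#       labels=[1],                  # 1 = positive click
#   )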
def add_new_points(self, *args, **kwargs):
"""Deprecated method. Please use `add_new_points_or_box` instead."""
return self.add_new_points_or_box(*args, **kwargs)
@torch.inference_mode()
def add_new_mask(
self,
inference_state,
frame_idx,
obj_id,
mask,
):
"""Add new mask to a frame."""
obj_idx = self._obj_id_to_idx(inference_state, obj_id)
point_inputs_per_frame = inference_state["point_inputs_per_obj"][obj_idx]
mask_inputs_per_frame = inference_state["mask_inputs_per_obj"][obj_idx]
if not isinstance(mask, torch.Tensor):
mask = torch.tensor(mask, dtype=torch.bool)
assert mask.dim() == 2
mask_H, mask_W = mask.shape
mask_inputs_orig = mask[None, None] # add batch and channel dimension
mask_inputs_orig = mask_inputs_orig.float().to(inference_state["device"])
# resize the mask if it doesn't match the model's image size
if mask_H != self.image_size or mask_W != self.image_size:
mask_inputs = torch.nn.functional.interpolate(
mask_inputs_orig,
size=(self.image_size, self.image_size),
align_corners=False,
mode="bilinear",
antialias=True, # use antialias for downsampling
)
mask_inputs = (mask_inputs >= 0.5).float()
else:
mask_inputs = mask_inputs_orig
mask_inputs_per_frame[frame_idx] = mask_inputs
point_inputs_per_frame.pop(frame_idx, None)
# If this frame hasn't been tracked before, we treat it as an initial conditioning
# frame, meaning that the input points are used to generate segments on this frame without
# using any memory from other frames, like in SAM. Otherwise (if it has been tracked),
# the input points will be used to correct the already tracked masks.
obj_frames_tracked = inference_state["frames_tracked_per_obj"][obj_idx]
is_init_cond_frame = frame_idx not in obj_frames_tracked
# whether to track in reverse time order
if is_init_cond_frame:
reverse = False
else:
reverse = obj_frames_tracked[frame_idx]["reverse"]
obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
# Add a frame to conditioning output if it's an initial conditioning frame or
# if the model sees all frames receiving clicks/mask as conditioning frames.
is_cond = is_init_cond_frame or self.add_all_frames_to_correct_as_cond
storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
current_out, _ = self._run_single_frame_inference(
inference_state=inference_state,
output_dict=obj_output_dict, # run on the slice of a single object
frame_idx=frame_idx,
batch_size=1, # run on the slice of a single object
is_init_cond_frame=is_init_cond_frame,
point_inputs=None,
mask_inputs=mask_inputs,
reverse=reverse,
# Skip the memory encoder when adding clicks or mask. We execute the memory encoder
# at the beginning of `propagate_in_video` (after the user finalizes their clicks). This
# allows us to enforce non-overlapping constraints on all objects before encoding
# them into memory.
run_mem_encoder=False,
)
# Add the output to the output dict (to be used as future memory)
obj_temp_output_dict[storage_key][frame_idx] = current_out
# Resize the output mask to the original video resolution
obj_ids = inference_state["obj_ids"]
consolidated_out = self._consolidate_temp_output_across_obj(
inference_state,
frame_idx,
is_cond=is_cond,
consolidate_at_video_res=True,
)
_, video_res_masks = self._get_orig_video_res_output(
inference_state, consolidated_out["pred_masks_video_res"]
)
return frame_idx, obj_ids, video_res_masks
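# Usage sketch (illustrative only, reusing `predictor` and `state` from the
# click sketch above and assuming `import numpy as np`): a binary mask can be
# used as the prompt instead of clicks, e.g. to seed tracking from an external
# segmentation.
#
#   init_mask = np.zeros((720, 1280), dtype=bool)   # (video_height, video_width)
#   init_mask[200:400, 300:600] = True
#   predictor.add_new_mask(state, frame_idx=0, obj_id=2, mask=init_mask)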
def _get_orig_video_res_output(self, inference_state, any_res_masks):
"""
Resize the object scores to the original video resolution (video_res_masks)
and apply non-overlapping constraints for final output.
"""
device = inference_state["device"]
video_H = inference_state["video_height"]
video_W = inference_state["video_width"]
any_res_masks = any_res_masks.to(device, non_blocking=True)
if any_res_masks.shape[-2:] == (video_H, video_W):
video_res_masks = any_res_masks
else:
video_res_masks = torch.nn.functional.interpolate(
any_res_masks,
size=(video_H, video_W),
mode="bilinear",
align_corners=False,
)
if self.non_overlap_masks:
video_res_masks = self._apply_non_overlapping_constraints(video_res_masks)
return any_res_masks, video_res_masks
def _consolidate_temp_output_across_obj(
self,
inference_state,
frame_idx,
is_cond,
consolidate_at_video_res=False,
):
"""
Consolidate the per-object temporary outputs in `temp_output_dict_per_obj` on
a frame into a single output for all objects, including
1) fill any missing objects either from `output_dict_per_obj` (if they exist in
`output_dict_per_obj` for this frame) or leave them as placeholder values
(if they don't exist in `output_dict_per_obj` for this frame);
2) if specified, rerun memory encoder after apply non-overlapping constraints
on the object scores.
"""
batch_size = self._get_obj_num(inference_state)
storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
# Optionally, we allow consolidating the temporary outputs at the original
# video resolution (to provide a better editing experience for mask prompts).
if consolidate_at_video_res:
consolidated_H = inference_state["video_height"]
consolidated_W = inference_state["video_width"]
consolidated_mask_key = "pred_masks_video_res"
else:
consolidated_H = consolidated_W = self.image_size // 4
consolidated_mask_key = "pred_masks"
# Initialize `consolidated_out`. Its "maskmem_features" and "maskmem_pos_enc"
# will be added when rerunning the memory encoder after applying non-overlapping
# constraints to object scores. Its "pred_masks" are prefilled with a large
# negative value (NO_OBJ_SCORE) to represent missing objects.
consolidated_out = {
consolidated_mask_key: torch.full(
size=(batch_size, 1, consolidated_H, consolidated_W),
fill_value=NO_OBJ_SCORE,
dtype=torch.float32,
device=inference_state["storage_device"],
),
}
for obj_idx in range(batch_size):
obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
out = obj_temp_output_dict[storage_key].get(frame_idx, None)
# If the object doesn't appear in "temp_output_dict_per_obj" on this frame,
# we fall back and look up its previous output in "output_dict_per_obj".
# We look up both "cond_frame_outputs" and "non_cond_frame_outputs" in
# "output_dict_per_obj" to find a previous output for this object.
if out is None:
out = obj_output_dict["cond_frame_outputs"].get(frame_idx, None)
if out is None:
out = obj_output_dict["non_cond_frame_outputs"].get(frame_idx, None)
# If the object doesn't appear in "output_dict_per_obj" either, we skip it
# and leave its mask scores to the default scores (i.e. the NO_OBJ_SCORE
# placeholder above) and set its object pointer to be a dummy pointer.
if out is None:
continue
# Add the temporary object output mask to consolidated output mask
obj_mask = out["pred_masks"]
consolidated_pred_masks = consolidated_out[consolidated_mask_key]
if obj_mask.shape[-2:] == consolidated_pred_masks.shape[-2:]:
consolidated_pred_masks[obj_idx : obj_idx + 1] = obj_mask
else:
# Resize first if temporary object mask has a different resolution
resized_obj_mask = torch.nn.functional.interpolate(
obj_mask,
size=consolidated_pred_masks.shape[-2:],
mode="bilinear",
align_corners=False,
)
consolidated_pred_masks[obj_idx : obj_idx + 1] = resized_obj_mask
return consolidated_out
@torch.inference_mode()
def propagate_in_video_preflight(self, inference_state):
"""Prepare inference_state and consolidate temporary outputs before tracking."""
# Check and make sure that every object has received input points or masks.
batch_size = self._get_obj_num(inference_state)
if batch_size == 0:
raise RuntimeError(
"No input points or masks are provided for any object; please add inputs first."
)
# Consolidate per-object temporary outputs in "temp_output_dict_per_obj" and
# add them into "output_dict".
for obj_idx in range(batch_size):
obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
for is_cond in [False, True]:
# Separately consolidate conditioning and non-conditioning temp outputs
storage_key = (
"cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
)
# Find all the frames that contain temporary outputs for any objects
# (these should be the frames that have just received clicks or mask inputs
# via `add_new_points_or_box` or `add_new_mask`)
for frame_idx, out in obj_temp_output_dict[storage_key].items():
# Run memory encoder on the temporary outputs (if the memory feature is missing)
if out["maskmem_features"] is None:
high_res_masks = torch.nn.functional.interpolate(
out["pred_masks"].to(inference_state["device"]),
size=(self.image_size, self.image_size),
mode="bilinear",
align_corners=False,
)
maskmem_features, maskmem_pos_enc = self._run_memory_encoder(
inference_state=inference_state,
frame_idx=frame_idx,
batch_size=1, # run on the slice of a single object
high_res_masks=high_res_masks,
object_score_logits=out["object_score_logits"],
# these frames are what the user interacted with
is_mask_from_pts=True,
)
out["maskmem_features"] = maskmem_features
out["maskmem_pos_enc"] = maskmem_pos_enc
obj_output_dict[storage_key][frame_idx] = out
if self.clear_non_cond_mem_around_input:
# clear non-conditioning memory of the surrounding frames
self._clear_obj_non_cond_mem_around_input(
inference_state, frame_idx, obj_idx
)
# clear temporary outputs in `temp_output_dict_per_obj`
obj_temp_output_dict[storage_key].clear()
# check and make sure that every object has received input points or masks
obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
if len(obj_output_dict["cond_frame_outputs"]) == 0:
obj_id = self._obj_idx_to_id(inference_state, obj_idx)
raise RuntimeError(
f"No input points or masks are provided for object id {obj_id}; please add inputs first."
)
# edge case: if an output is added to "cond_frame_outputs", we remove any prior
# output on the same frame in "non_cond_frame_outputs"
for frame_idx in obj_output_dict["cond_frame_outputs"]:
obj_output_dict["non_cond_frame_outputs"].pop(frame_idx, None)
@torch.inference_mode()
def propagate_in_video(
self,
inference_state,
start_frame_idx=None,
max_frame_num_to_track=None,
reverse=False,
):
"""Propagate the input points across frames to track in the entire video."""
self.propagate_in_video_preflight(inference_state)
obj_ids = inference_state["obj_ids"]
num_frames = inference_state["num_frames"]
batch_size = self._get_obj_num(inference_state)
# set start index, end index, and processing order
if start_frame_idx is None:
# default: start from the earliest frame with input points
start_frame_idx = min(
t
for obj_output_dict in inference_state["output_dict_per_obj"].values()
for t in obj_output_dict["cond_frame_outputs"]
)
if max_frame_num_to_track is None:
# default: track all the frames in the video
max_frame_num_to_track = num_frames
if reverse:
end_frame_idx = max(start_frame_idx - max_frame_num_to_track, 0)
if start_frame_idx > 0:
processing_order = range(start_frame_idx, end_frame_idx - 1, -1)
else:
processing_order = [] # skip reverse tracking if starting from frame 0
else:
end_frame_idx = min(
start_frame_idx + max_frame_num_to_track, num_frames - 1
)
processing_order = range(start_frame_idx, end_frame_idx + 1)
for frame_idx in tqdm(processing_order, desc="propagate in video"):
pred_masks_per_obj = [None] * batch_size
for obj_idx in range(batch_size):
obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
# We skip those frames already in consolidated outputs (these are frames
# that received input clicks or mask). Note that we cannot directly run
# batched forward on them via `_run_single_frame_inference` because the
# number of clicks on each object might be different.
if frame_idx in obj_output_dict["cond_frame_outputs"]:
storage_key = "cond_frame_outputs"
current_out = obj_output_dict[storage_key][frame_idx]
device = inference_state["device"]
pred_masks = current_out["pred_masks"].to(device, non_blocking=True)
if self.clear_non_cond_mem_around_input:
# clear non-conditioning memory of the surrounding frames
self._clear_obj_non_cond_mem_around_input(
inference_state, frame_idx, obj_idx
)
else:
storage_key = "non_cond_frame_outputs"
current_out, pred_masks = self._run_single_frame_inference(
inference_state=inference_state,
output_dict=obj_output_dict,
frame_idx=frame_idx,
batch_size=1, # run on the slice of a single object
is_init_cond_frame=False,
point_inputs=None,
mask_inputs=None,
reverse=reverse,
run_mem_encoder=True,
)
obj_output_dict[storage_key][frame_idx] = current_out
inference_state["frames_tracked_per_obj"][obj_idx][frame_idx] = {
"reverse": reverse
}
pred_masks_per_obj[obj_idx] = pred_masks
# Resize the output mask to the original video resolution (we directly use
# the mask scores on GPU for output to avoid any CPU conversion in between)
if len(pred_masks_per_obj) > 1:
all_pred_masks = torch.cat(pred_masks_per_obj, dim=0)
else:
all_pred_masks = pred_masks_per_obj[0]
_, video_res_masks = self._get_orig_video_res_output(
inference_state, all_pred_masks
)
yield frame_idx, obj_ids, video_res_masks
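# A minimal end-to-end usage sketch for the interactive tracking loop above
# (illustrative only; `build_sam2_video_predictor`, the config/checkpoint paths
# and the click coordinates below are assumptions, not values from this repo):
#
#     from sam2.build_sam import build_sam2_video_predictor
#     predictor = build_sam2_video_predictor("<config>.yaml", "<checkpoint>.pt")
#     state = predictor.init_state(video_path="<frames_dir>")
#     predictor.add_new_points_or_box(
#         inference_state=state, frame_idx=0, obj_id=1,
#         points=[[210.0, 350.0]], labels=[1],  # one positive click
#     )
#     for frame_idx, obj_ids, masks in predictor.propagate_in_video(state):
#         binary_masks = (masks > 0.0).cpu()  # logits -> boolean masks, (num_objs, 1, H, W)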
@torch.inference_mode()
def clear_all_prompts_in_frame(
self, inference_state, frame_idx, obj_id, need_output=True
):
"""Remove all input points or mask in a specific frame for a given object."""
obj_idx = self._obj_id_to_idx(inference_state, obj_id)
# Clear the conditioning information on the given frame
inference_state["point_inputs_per_obj"][obj_idx].pop(frame_idx, None)
inference_state["mask_inputs_per_obj"][obj_idx].pop(frame_idx, None)
temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"]
temp_output_dict_per_obj[obj_idx]["cond_frame_outputs"].pop(frame_idx, None)
temp_output_dict_per_obj[obj_idx]["non_cond_frame_outputs"].pop(frame_idx, None)
# Remove the frame's conditioning output (possibly downgrading it to non-conditioning)
obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
out = obj_output_dict["cond_frame_outputs"].pop(frame_idx, None)
if out is not None:
# The frame is not a conditioning frame anymore since it's not receiving inputs,
# so we "downgrade" its output (if exists) to a non-conditioning frame output.
obj_output_dict["non_cond_frame_outputs"][frame_idx] = out
inference_state["frames_tracked_per_obj"][obj_idx].pop(frame_idx, None)
if not need_output:
return
# Finally, output updated masks per object (after removing the inputs above)
obj_ids = inference_state["obj_ids"]
is_cond = any(
frame_idx in obj_temp_output_dict["cond_frame_outputs"]
for obj_temp_output_dict in temp_output_dict_per_obj.values()
)
consolidated_out = self._consolidate_temp_output_across_obj(
inference_state,
frame_idx,
is_cond=is_cond,
consolidate_at_video_res=True,
)
_, video_res_masks = self._get_orig_video_res_output(
inference_state, consolidated_out["pred_masks_video_res"]
)
return frame_idx, obj_ids, video_res_masks
@torch.inference_mode()
def reset_state(self, inference_state):
"""Remove all input points or mask in all frames throughout the video."""
self._reset_tracking_results(inference_state)
# Remove all object ids
inference_state["obj_id_to_idx"].clear()
inference_state["obj_idx_to_id"].clear()
inference_state["obj_ids"].clear()
inference_state["point_inputs_per_obj"].clear()
inference_state["mask_inputs_per_obj"].clear()
inference_state["output_dict_per_obj"].clear()
inference_state["temp_output_dict_per_obj"].clear()
inference_state["frames_tracked_per_obj"].clear()
def _reset_tracking_results(self, inference_state):
"""Reset all tracking inputs and results across the videos."""
for v in inference_state["point_inputs_per_obj"].values():
v.clear()
for v in inference_state["mask_inputs_per_obj"].values():
v.clear()
for v in inference_state["output_dict_per_obj"].values():
v["cond_frame_outputs"].clear()
v["non_cond_frame_outputs"].clear()
for v in inference_state["temp_output_dict_per_obj"].values():
v["cond_frame_outputs"].clear()
v["non_cond_frame_outputs"].clear()
for v in inference_state["frames_tracked_per_obj"].values():
v.clear()
def _get_image_feature(self, inference_state, frame_idx, batch_size):
"""Compute the image features on a given frame."""
# Look up in the cache first
image, backbone_out = inference_state["cached_features"].get(
frame_idx, (None, None)
)
if backbone_out is None:
# Cache miss -- we will run inference on a single image
device = inference_state["device"]
image = inference_state["images"][frame_idx].to(device).float().unsqueeze(0)
backbone_out = self.forward_image(image)
# Cache the most recent frame's feature (for repeated interactions with
# a frame; we can use an LRU cache for more frames in the future).
inference_state["cached_features"] = {frame_idx: (image, backbone_out)}
# expand the features to have the same dimension as the number of objects
expanded_image = image.expand(batch_size, -1, -1, -1)
expanded_backbone_out = {
"backbone_fpn": backbone_out["backbone_fpn"].copy(),
"vision_pos_enc": backbone_out["vision_pos_enc"].copy(),
}
for i, feat in enumerate(expanded_backbone_out["backbone_fpn"]):
expanded_backbone_out["backbone_fpn"][i] = feat.expand(
batch_size, -1, -1, -1
)
for i, pos in enumerate(expanded_backbone_out["vision_pos_enc"]):
pos = pos.expand(batch_size, -1, -1, -1)
expanded_backbone_out["vision_pos_enc"][i] = pos
features = self._prepare_backbone_features(expanded_backbone_out)
features = (expanded_image,) + features
return features
def _run_single_frame_inference(
self,
inference_state,
output_dict,
frame_idx,
batch_size,
is_init_cond_frame,
point_inputs,
mask_inputs,
reverse,
run_mem_encoder,
prev_sam_mask_logits=None,
):
"""Run tracking on a single frame based on current inputs and previous memory."""
# Retrieve correct image features
(
_,
_,
current_vision_feats,
current_vision_pos_embeds,
feat_sizes,
) = self._get_image_feature(inference_state, frame_idx, batch_size)
# point and mask should not appear as input simultaneously on the same frame
assert point_inputs is None or mask_inputs is None
current_out = self.track_step(
frame_idx=frame_idx,
is_init_cond_frame=is_init_cond_frame,
current_vision_feats=current_vision_feats,
current_vision_pos_embeds=current_vision_pos_embeds,
feat_sizes=feat_sizes,
point_inputs=point_inputs,
mask_inputs=mask_inputs,
output_dict=output_dict,
num_frames=inference_state["num_frames"],
track_in_reverse=reverse,
run_mem_encoder=run_mem_encoder,
prev_sam_mask_logits=prev_sam_mask_logits,
)
# optionally offload the output to CPU memory to save GPU space
storage_device = inference_state["storage_device"]
maskmem_features = current_out["maskmem_features"]
if maskmem_features is not None:
maskmem_features = maskmem_features.to(torch.bfloat16)
maskmem_features = maskmem_features.to(storage_device, non_blocking=True)
pred_masks_gpu = current_out["pred_masks"]
# potentially fill holes in the predicted masks
if self.fill_hole_area > 0:
pred_masks_gpu = fill_holes_in_mask_scores(
pred_masks_gpu, self.fill_hole_area
)
pred_masks = pred_masks_gpu.to(storage_device, non_blocking=True)
# "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it
maskmem_pos_enc = self._get_maskmem_pos_enc(inference_state, current_out)
# object pointer is a small tensor, so we always keep it on GPU memory for fast access
obj_ptr = current_out["obj_ptr"]
object_score_logits = current_out["object_score_logits"]
# make a compact version of this frame's output to reduce the state size
compact_current_out = {
"maskmem_features": maskmem_features,
"maskmem_pos_enc": maskmem_pos_enc,
"pred_masks": pred_masks,
"obj_ptr": obj_ptr,
"object_score_logits": object_score_logits,
}
return compact_current_out, pred_masks_gpu
def _run_memory_encoder(
self,
inference_state,
frame_idx,
batch_size,
high_res_masks,
object_score_logits,
is_mask_from_pts,
):
"""
Run the memory encoder on `high_res_masks`. This is usually after applying
non-overlapping constraints to object scores. Since their scores have changed, their
memory also needs to be recomputed with the memory encoder.
"""
# Retrieve correct image features
_, _, current_vision_feats, _, feat_sizes = self._get_image_feature(
inference_state, frame_idx, batch_size
)
maskmem_features, maskmem_pos_enc = self._encode_new_memory(
current_vision_feats=current_vision_feats,
feat_sizes=feat_sizes,
pred_masks_high_res=high_res_masks,
object_score_logits=object_score_logits,
is_mask_from_pts=is_mask_from_pts,
)
# optionally offload the output to CPU memory to save GPU space
storage_device = inference_state["storage_device"]
maskmem_features = maskmem_features.to(torch.bfloat16)
maskmem_features = maskmem_features.to(storage_device, non_blocking=True)
# "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it
maskmem_pos_enc = self._get_maskmem_pos_enc(
inference_state, {"maskmem_pos_enc": maskmem_pos_enc}
)
return maskmem_features, maskmem_pos_enc
def _get_maskmem_pos_enc(self, inference_state, current_out):
"""
`maskmem_pos_enc` is the same across frames and objects, so we cache it as
a constant in the inference session to reduce session storage size.
"""
model_constants = inference_state["constants"]
# "out_maskmem_pos_enc" should be either a list of tensors or None
out_maskmem_pos_enc = current_out["maskmem_pos_enc"]
if out_maskmem_pos_enc is not None:
if "maskmem_pos_enc" not in model_constants:
assert isinstance(out_maskmem_pos_enc, list)
# only take the slice for one object, since it's same across objects
maskmem_pos_enc = [x[0:1].clone() for x in out_maskmem_pos_enc]
model_constants["maskmem_pos_enc"] = maskmem_pos_enc
else:
maskmem_pos_enc = model_constants["maskmem_pos_enc"]
# expand the cached maskmem_pos_enc to the actual batch size
batch_size = out_maskmem_pos_enc[0].size(0)
expanded_maskmem_pos_enc = [
x.expand(batch_size, -1, -1, -1) for x in maskmem_pos_enc
]
else:
expanded_maskmem_pos_enc = None
return expanded_maskmem_pos_enc
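# Note: `Tensor.expand` above returns a broadcast view rather than a copy, so the
# single-object positional encoding cached in `model_constants` can serve any
# batch size at essentially no extra memory cost. A tiny illustration with
# made-up shapes:
#
#     pe = torch.randn(1, 64, 32, 32)                 # cached slice for one object
#     pe_batched = pe.expand(3, -1, -1, -1)           # view covering 3 objects
#     assert pe_batched.data_ptr() == pe.data_ptr()   # no new storage allocated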
@torch.inference_mode()
def remove_object(self, inference_state, obj_id, strict=False, need_output=True):
"""
Remove an object id from the tracking state. If strict is True, we check whether
the object id actually exists and raise an error if it doesn't exist.
"""
old_obj_idx_to_rm = inference_state["obj_id_to_idx"].get(obj_id, None)
updated_frames = []
# Check whether this object_id to remove actually exists and possibly raise an error.
if old_obj_idx_to_rm is None:
if not strict:
return inference_state["obj_ids"], updated_frames
raise RuntimeError(
f"Cannot remove object id {obj_id} as it doesn't exist. "
f"All existing object ids: {inference_state['obj_ids']}."
)
# If this is the only remaining object id, we simply reset the state.
if len(inference_state["obj_id_to_idx"]) == 1:
self.reset_state(inference_state)
return inference_state["obj_ids"], updated_frames
# There are still remaining objects after removing this object id. In this case,
# we need to delete the object storage from inference state tensors.
# Step 0: clear the input on those frames where this object id has point or mask input
# (note that this step is required as it might downgrade conditioning frames to
# non-conditioning ones)
obj_input_frames_inds = set()
obj_input_frames_inds.update(
inference_state["point_inputs_per_obj"][old_obj_idx_to_rm]
)
obj_input_frames_inds.update(
inference_state["mask_inputs_per_obj"][old_obj_idx_to_rm]
)
for frame_idx in obj_input_frames_inds:
self.clear_all_prompts_in_frame(
inference_state, frame_idx, obj_id, need_output=False
)
# Step 1: Update the object id mapping (note that it must be done after Step 0,
# since Step 0 still requires the old object id mappings in inference_state)
old_obj_ids = inference_state["obj_ids"]
old_obj_inds = list(range(len(old_obj_ids)))
remain_old_obj_inds = old_obj_inds.copy()
remain_old_obj_inds.remove(old_obj_idx_to_rm)
new_obj_ids = [old_obj_ids[old_idx] for old_idx in remain_old_obj_inds]
new_obj_inds = list(range(len(new_obj_ids)))
# build new mappings
old_idx_to_new_idx = dict(zip(remain_old_obj_inds, new_obj_inds))
inference_state["obj_id_to_idx"] = dict(zip(new_obj_ids, new_obj_inds))
inference_state["obj_idx_to_id"] = dict(zip(new_obj_inds, new_obj_ids))
inference_state["obj_ids"] = new_obj_ids
# Step 2: For per-object tensor storage, we shift their obj_idx in the dict keys.
def _map_keys(container):
new_kvs = []
for k in old_obj_inds:
v = container.pop(k)
if k in old_idx_to_new_idx:
new_kvs.append((old_idx_to_new_idx[k], v))
container.update(new_kvs)
_map_keys(inference_state["point_inputs_per_obj"])
_map_keys(inference_state["mask_inputs_per_obj"])
_map_keys(inference_state["output_dict_per_obj"])
_map_keys(inference_state["temp_output_dict_per_obj"])
_map_keys(inference_state["frames_tracked_per_obj"])
# Step 3: Further collect the outputs on those frames in `obj_input_frames_inds`, which
# could show an updated mask for objects previously occluded by the object being removed
if need_output:
temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"]
for frame_idx in obj_input_frames_inds:
is_cond = any(
frame_idx in obj_temp_output_dict["cond_frame_outputs"]
for obj_temp_output_dict in temp_output_dict_per_obj.values()
)
consolidated_out = self._consolidate_temp_output_across_obj(
inference_state,
frame_idx,
is_cond=is_cond,
consolidate_at_video_res=True,
)
_, video_res_masks = self._get_orig_video_res_output(
inference_state, consolidated_out["pred_masks_video_res"]
)
updated_frames.append((frame_idx, video_res_masks))
return inference_state["obj_ids"], updated_frames
def _clear_non_cond_mem_around_input(self, inference_state, frame_idx):
"""
Remove the non-conditioning memory around the input frame. When users provide
correction clicks, the surrounding frames' non-conditioning memories can still
contain outdated object appearance information and could confuse the model.
This method clears those non-conditioning memories surrounding the interacted
frame to avoid giving the model both old and new information about the object.
"""
r = self.memory_temporal_stride_for_eval
frame_idx_begin = frame_idx - r * self.num_maskmem
frame_idx_end = frame_idx + r * self.num_maskmem
batch_size = self._get_obj_num(inference_state)
for obj_idx in range(batch_size):
obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
non_cond_frame_outputs = obj_output_dict["non_cond_frame_outputs"]
for t in range(frame_idx_begin, frame_idx_end + 1):
non_cond_frame_outputs.pop(t, None)
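# For intuition: with a memory stride `r` and `num_maskmem` memory slots, the loop
# above drops the non-conditioning outputs of every frame within `r * num_maskmem`
# frames of the interacted frame. For example (hypothetical values), r = 1 and
# num_maskmem = 7 around frame_idx = 100 clears frames 93 through 107 inclusive.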
class SAM2VideoPredictorVOS(SAM2VideoPredictor):
"""Optimized for the VOS setting"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._compile_all_components()
def _compile_all_components(self):
print("Compiling all components for VOS setting. First time may be very slow.")
self.memory_encoder.forward = torch.compile(
self.memory_encoder.forward,
mode="max-autotune",
fullgraph=True,
dynamic=False,
)
self.memory_attention.forward = torch.compile(
self.memory_attention.forward,
mode="max-autotune",
fullgraph=True,
dynamic=True, # Num. of memories varies
)
self.sam_prompt_encoder.forward = torch.compile(
self.sam_prompt_encoder.forward,
mode="max-autotune",
fullgraph=True,
dynamic=False, # Accuracy regression on True
)
self.sam_mask_decoder.forward = torch.compile(
self.sam_mask_decoder.forward,
mode="max-autotune",
fullgraph=True,
dynamic=False, # Accuracy regression on True
)
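# `torch.compile` returns a compiled wrapper around the original callable; assigning
# it back onto each module's `forward` means later calls run through the captured
# graph, while the first call for a given input shape pays the compilation cost
# (hence the warning printed above). A standalone sketch of the same pattern:
#
#     def f(x):
#         return torch.relu(x) + 1
#     f_compiled = torch.compile(f, mode="max-autotune", fullgraph=True)
#     y = f_compiled(torch.randn(8))  # first call compiles; later calls reuse the graph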
def forward_image(self, img_batch: torch.Tensor):
"""
Identical to the corresponding method in the parent (SAM2VideoPredictor), but
cloning the backbone features and pos encoding to enable compilation.
"""
backbone_out = self.image_encoder(img_batch)
if self.use_high_res_features_in_sam:
# precompute projected level 0 and level 1 features in SAM decoder
# to avoid running it again on every SAM click
backbone_out["backbone_fpn"][0] = self.sam_mask_decoder.conv_s0(
backbone_out["backbone_fpn"][0]
)
backbone_out["backbone_fpn"][1] = self.sam_mask_decoder.conv_s1(
backbone_out["backbone_fpn"][1]
)
# Clone to help torch.compile
for i in range(len(backbone_out["backbone_fpn"])):
backbone_out["backbone_fpn"][i] = backbone_out["backbone_fpn"][i].clone()
backbone_out["vision_pos_enc"][i] = backbone_out["vision_pos_enc"][
i
].clone()
return backbone_out
def _forward_sam_heads(
self,
backbone_features,
point_inputs=None,
mask_inputs=None,
high_res_features=None,
multimask_output=False,
):
"""
Identical to the corresponding method in the parent (SAM2VideoPredictor), but
cloning the outputs of prompt_encoder and mask_decoder to enable compilation.
"""
B = backbone_features.size(0)
device = backbone_features.device
assert backbone_features.size(1) == self.sam_prompt_embed_dim
assert backbone_features.size(2) == self.sam_image_embedding_size
assert backbone_features.size(3) == self.sam_image_embedding_size
# a) Handle point prompts
if point_inputs is not None:
sam_point_coords = point_inputs["point_coords"]
sam_point_labels = point_inputs["point_labels"]
assert sam_point_coords.size(0) == B and sam_point_labels.size(0) == B
else:
# If no points are provided, pad with an empty point (with label -1)
sam_point_coords = torch.zeros(B, 1, 2, device=device)
sam_point_labels = -torch.ones(B, 1, dtype=torch.int32, device=device)
# b) Handle mask prompts
if mask_inputs is not None:
# If mask_inputs is provided, downsize it into low-res mask input if needed
# and feed it as a dense mask prompt into the SAM mask encoder
assert len(mask_inputs.shape) == 4 and mask_inputs.shape[:2] == (B, 1)
if mask_inputs.shape[-2:] != self.sam_prompt_encoder.mask_input_size:
sam_mask_prompt = F.interpolate(
mask_inputs.float(),
size=self.sam_prompt_encoder.mask_input_size,
align_corners=False,
mode="bilinear",
antialias=True, # use antialias for downsampling
)
else:
sam_mask_prompt = mask_inputs
else:
# Otherwise, simply feed None (and SAM's prompt encoder will add
# a learned `no_mask_embed` to indicate no mask input in this case).
sam_mask_prompt = None
sparse_embeddings, dense_embeddings = self.sam_prompt_encoder(
points=(sam_point_coords, sam_point_labels),
boxes=None,
masks=sam_mask_prompt,
)
# Clone image_pe and the outputs of sam_prompt_encoder
# to enable compilation
sparse_embeddings = sparse_embeddings.clone()
dense_embeddings = dense_embeddings.clone()
image_pe = self.sam_prompt_encoder.get_dense_pe().clone()
(
low_res_multimasks,
ious,
sam_output_tokens,
object_score_logits,
) = self.sam_mask_decoder(
image_embeddings=backbone_features,
image_pe=image_pe,
sparse_prompt_embeddings=sparse_embeddings,
dense_prompt_embeddings=dense_embeddings,
multimask_output=multimask_output,
repeat_image=False, # the image is already batched
high_res_features=high_res_features,
)
# Clone the output of sam_mask_decoder
# to enable compilation
low_res_multimasks = low_res_multimasks.clone()
ious = ious.clone()
sam_output_tokens = sam_output_tokens.clone()
object_score_logits = object_score_logits.clone()
if self.pred_obj_scores:
is_obj_appearing = object_score_logits > 0
# Mask used for spatial memories is always a *hard* choice between obj and no obj,
# consistent with the actual mask prediction
low_res_multimasks = torch.where(
is_obj_appearing[:, None, None],
low_res_multimasks,
NO_OBJ_SCORE,
)
# convert masks from possibly bfloat16 (or float16) to float32
# (older PyTorch versions before 2.1 don't support `interpolate` on bf16)
low_res_multimasks = low_res_multimasks.float()
high_res_multimasks = F.interpolate(
low_res_multimasks,
size=(self.image_size, self.image_size),
mode="bilinear",
align_corners=False,
)
sam_output_token = sam_output_tokens[:, 0]
if multimask_output:
# take the best mask prediction (with the highest IoU estimation)
best_iou_inds = torch.argmax(ious, dim=-1)
batch_inds = torch.arange(B, device=device)
low_res_masks = low_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)
high_res_masks = high_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)
if sam_output_tokens.size(1) > 1:
sam_output_token = sam_output_tokens[batch_inds, best_iou_inds]
else:
low_res_masks, high_res_masks = low_res_multimasks, high_res_multimasks
# Extract object pointer from the SAM output token (with occlusion handling)
obj_ptr = self.obj_ptr_proj(sam_output_token)
if self.pred_obj_scores:
# Allow *soft* no obj ptr, unlike for masks
if self.soft_no_obj_ptr:
lambda_is_obj_appearing = object_score_logits.sigmoid()
else:
lambda_is_obj_appearing = is_obj_appearing.float()
if self.fixed_no_obj_ptr:
obj_ptr = lambda_is_obj_appearing * obj_ptr
obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr
return (
low_res_multimasks,
high_res_multimasks,
ious,
low_res_masks,
high_res_masks,
obj_ptr,
object_score_logits,
)
def _encode_new_memory(
self,
current_vision_feats,
feat_sizes,
pred_masks_high_res,
object_score_logits,
is_mask_from_pts,
):
"""
Identical to the corresponding method in the parent (SAM2VideoPredictor), but
cloning the memories and their pos enc to enable compilation.
"""
B = current_vision_feats[-1].size(1) # batch size on this frame
C = self.hidden_dim
H, W = feat_sizes[-1] # top-level (lowest-resolution) feature size
# top-level feature, (HW)BC => BCHW
pix_feat = current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W)
if self.non_overlap_masks_for_mem_enc and not self.training:
# optionally, apply non-overlapping constraints to the masks (it's applied
# in the batch dimension and should only be used during eval, where all
# the objects come from the same video under batch size 1).
pred_masks_high_res = self._apply_non_overlapping_constraints(
pred_masks_high_res
)
# scale the raw mask logits with a temperature before applying sigmoid
binarize = self.binarize_mask_from_pts_for_mem_enc and is_mask_from_pts
if binarize and not self.training:
mask_for_mem = (pred_masks_high_res > 0).float()
else:
# apply sigmoid on the raw mask logits to turn them into range (0, 1)
mask_for_mem = torch.sigmoid(pred_masks_high_res)
# apply scale and bias terms to the sigmoid probabilities
if self.sigmoid_scale_for_mem_enc != 1.0:
mask_for_mem = mask_for_mem * self.sigmoid_scale_for_mem_enc
if self.sigmoid_bias_for_mem_enc != 0.0:
mask_for_mem = mask_for_mem + self.sigmoid_bias_for_mem_enc
maskmem_out = self.memory_encoder(
pix_feat, mask_for_mem, skip_mask_sigmoid=True # sigmoid already applied
)
# Clone the feats and pos_enc to enable compilation
maskmem_features = maskmem_out["vision_features"].clone()
maskmem_pos_enc = [m.clone() for m in maskmem_out["vision_pos_enc"]]
# add a no-object embedding to the spatial memory to indicate that the frame
# is predicted to be occluded (i.e. no object is appearing in the frame)
if self.no_obj_embed_spatial is not None:
is_obj_appearing = (object_score_logits > 0).float()
maskmem_features += (
1 - is_obj_appearing[..., None, None]
) * self.no_obj_embed_spatial[..., None, None].expand(
*maskmem_features.shape
)
return maskmem_features, maskmem_pos_enc
================================================
FILE: camera_pose_annotation/dynamic_mask/sam2/sam2_video_predictor_legacy.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import warnings
from collections import OrderedDict
import torch
from tqdm import tqdm
from sam2.modeling.sam2_base import NO_OBJ_SCORE, SAM2Base
from sam2.utils.misc import concat_points, fill_holes_in_mask_scores, load_video_frames
class SAM2VideoPredictor(SAM2Base):
"""The predictor class to handle user interactions and manage inference states."""
def __init__(
self,
fill_hole_area=0,
# whether to apply non-overlapping constraints on the output object masks
non_overlap_masks=False,
# whether to clear non-conditioning memory of the surrounding frames (which may contain outdated information) after adding correction clicks;
# note that this only applies to *single-object tracking* unless `clear_non_cond_mem_for_multi_obj` is also set to True
clear_non_cond_mem_around_input=False,
# whether to also clear non-conditioning memory of the surrounding frames in multi-object tracking (only effective when `clear_non_cond_mem_around_input` is True).
clear_non_cond_mem_for_multi_obj=False,
# if `add_all_frames_to_correct_as_cond` is True, we also append to the conditioning frame list any frame that receives a later correction click
# if `add_all_frames_to_correct_as_cond` is False, we restrict the conditioning frame list to only the initial conditioning frames
add_all_frames_to_correct_as_cond=False,
**kwargs,
):
super().__init__(**kwargs)
self.fill_hole_area = fill_hole_area
self.non_overlap_masks = non_overlap_masks
self.clear_non_cond_mem_around_input = clear_non_cond_mem_around_input
self.clear_non_cond_mem_for_multi_obj = clear_non_cond_mem_for_multi_obj
self.add_all_frames_to_correct_as_cond = add_all_frames_to_correct_as_cond
@torch.inference_mode()
def init_state(
self,
video_path,
offload_video_to_cpu=False,
offload_state_to_cpu=False,
async_loading_frames=False,
):
"""Initialize an inference state."""
compute_device = self.device # device of the model
images, video_height, video_width = load_video_frames(
video_path=video_path,
image_size=self.image_size,
offload_video_to_cpu=offload_video_to_cpu,
async_loading_frames=async_loading_frames,
compute_device=compute_device,
)
inference_state = {}
inference_state["images"] = images
inference_state["num_frames"] = len(images)
# whether to offload the video frames to CPU memory
# turning on this option saves the GPU memory with only a very small overhead
inference_state["offload_video_to_cpu"] = offload_video_to_cpu
# whether to offload the inference state to CPU memory
# turning on this option saves the GPU memory at the cost of a lower tracking fps
# (e.g. in a test case of 768x768 model, fps dropped from 27 to 24 when tracking one object
# and from 24 to 21 when tracking two objects)
inference_state["offload_state_to_cpu"] = offload_state_to_cpu
# the original video height and width, used for resizing final output scores
inference_state["video_height"] = video_height
inference_state["video_width"] = video_width
inference_state["device"] = compute_device
if offload_state_to_cpu:
inference_state["storage_device"] = torch.device("cpu")
else:
inference_state["storage_device"] = compute_device
# inputs on each frame
inference_state["point_inputs_per_obj"] = {}
inference_state["mask_inputs_per_obj"] = {}
# visual features on a small number of recently visited frames for quick interactions
inference_state["cached_features"] = {}
# values that don't change across frames (so we only need to hold one copy of them)
inference_state["constants"] = {}
# mapping between client-side object id and model-side object index
inference_state["obj_id_to_idx"] = OrderedDict()
inference_state["obj_idx_to_id"] = OrderedDict()
inference_state["obj_ids"] = []
# A storage to hold the model's tracking results and states on each frame
inference_state["output_dict"] = {
"cond_frame_outputs": {}, # dict containing {frame_idx: }
"non_cond_frame_outputs": {}, # dict containing {frame_idx: }
}
# Slice (view) of each object tracking results, sharing the same memory with "output_dict"
inference_state["output_dict_per_obj"] = {}
# A temporary storage to hold new outputs when the user interacts with a frame
# to add clicks or mask (it's merged into "output_dict" before propagation starts)
inference_state["temp_output_dict_per_obj"] = {}
# Frames that already hold consolidated outputs from click or mask inputs
# (we directly use their consolidated outputs during tracking)
inference_state["consolidated_frame_inds"] = {
"cond_frame_outputs": set(), # set containing frame indices
"non_cond_frame_outputs": set(), # set containing frame indices
}
# metadata for each tracking frame (e.g. which direction it's tracked)
inference_state["tracking_has_started"] = False
inference_state["frames_already_tracked"] = {}
# Warm up the visual backbone and cache the image feature on frame 0
self._get_image_feature(inference_state, frame_idx=0, batch_size=1)
return inference_state
@classmethod
def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2VideoPredictor":
"""
Load a pretrained model from the Hugging Face hub.
Arguments:
model_id (str): The Hugging Face repository ID.
**kwargs: Additional arguments to pass to the model constructor.
Returns:
(SAM2VideoPredictor): The loaded model.
"""
from sam2.build_sam import build_sam2_video_predictor_hf
sam_model = build_sam2_video_predictor_hf(model_id, **kwargs)
return sam_model
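# Usage sketch (the repository id and flags below are illustrative, not fixed by
# this file):
#
#     predictor = SAM2VideoPredictor.from_pretrained("facebook/sam2-hiera-large")
#     state = predictor.init_state(
#         video_path="<frames_dir>",
#         offload_video_to_cpu=True,   # trade a little speed for GPU memory
#         offload_state_to_cpu=True,
#     )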
def _obj_id_to_idx(self, inference_state, obj_id):
"""Map client-side object id to model-side object index."""
obj_idx = inference_state["obj_id_to_idx"].get(obj_id, None)
if obj_idx is not None:
return obj_idx
# This is a new object id not sent to the server before. We only allow adding
# new objects *before* the tracking starts.
allow_new_object = not inference_state["tracking_has_started"]
if allow_new_object:
# get the next object slot
obj_idx = len(inference_state["obj_id_to_idx"])
inference_state["obj_id_to_idx"][obj_id] = obj_idx
inference_state["obj_idx_to_id"][obj_idx] = obj_id
inference_state["obj_ids"] = list(inference_state["obj_id_to_idx"])
# set up input and output structures for this object
inference_state["point_inputs_per_obj"][obj_idx] = {}
inference_state["mask_inputs_per_obj"][obj_idx] = {}
inference_state["output_dict_per_obj"][obj_idx] = {
"cond_frame_outputs": {}, # dict containing {frame_idx: }
"non_cond_frame_outputs": {}, # dict containing {frame_idx: }
}
inference_state["temp_output_dict_per_obj"][obj_idx] = {
"cond_frame_outputs": {}, # dict containing {frame_idx: }
"non_cond_frame_outputs": {}, # dict containing {frame_idx: }
}
return obj_idx
else:
raise RuntimeError(
f"Cannot add new object id {obj_id} after tracking starts. "
f"All existing object ids: {inference_state['obj_ids']}. "
f"Please call 'reset_state' to restart from scratch."
)
def _obj_idx_to_id(self, inference_state, obj_idx):
"""Map model-side object index to client-side object id."""
return inference_state["obj_idx_to_id"][obj_idx]
def _get_obj_num(self, inference_state):
"""Get the total number of unique object ids received so far in this session."""
return len(inference_state["obj_idx_to_id"])
@torch.inference_mode()
def add_new_points_or_box(
self,
inference_state,
frame_idx,
obj_id,
points=None,
labels=None,
clear_old_points=True,
normalize_coords=True,
box=None,
):
"""Add new points to a frame."""
obj_idx = self._obj_id_to_idx(inference_state, obj_id)
point_inputs_per_frame = inference_state["point_inputs_per_obj"][obj_idx]
mask_inputs_per_frame = inference_state["mask_inputs_per_obj"][obj_idx]
if (points is not None) != (labels is not None):
raise ValueError("points and labels must be provided together")
if points is None and box is None:
raise ValueError("at least one of points or box must be provided as input")
if points is None:
points = torch.zeros(0, 2, dtype=torch.float32)
elif not isinstance(points, torch.Tensor):
points = torch.tensor(points, dtype=torch.float32)
if labels is None:
labels = torch.zeros(0, dtype=torch.int32)
elif not isinstance(labels, torch.Tensor):
labels = torch.tensor(labels, dtype=torch.int32)
if points.dim() == 2:
points = points.unsqueeze(0) # add batch dimension
if labels.dim() == 1:
labels = labels.unsqueeze(0) # add batch dimension
# If `box` is provided, we add it as the first two points with labels 2 and 3
# along with the user-provided points (consistent with how SAM 2 is trained).
if box is not None:
if not clear_old_points:
raise ValueError(
"cannot add box without clearing old points, since "
"box prompt must be provided before any point prompt "
"(please use clear_old_points=True instead)"
)
if inference_state["tracking_has_started"]:
warnings.warn(
"You are adding a box after tracking starts. SAM 2 may not always be "
"able to incorporate a box prompt for *refinement*. If you intend to "
"use box prompt as an *initial* input before tracking, please call "
"'reset_state' on the inference state to restart from scratch.",
category=UserWarning,
stacklevel=2,
)
if not isinstance(box, torch.Tensor):
box = torch.tensor(box, dtype=torch.float32, device=points.device)
box_coords = box.reshape(1, 2, 2)
box_labels = torch.tensor([2, 3], dtype=torch.int32, device=labels.device)
box_labels = box_labels.reshape(1, 2)
points = torch.cat([box_coords, points], dim=1)
labels = torch.cat([box_labels, labels], dim=1)
if normalize_coords:
video_H = inference_state["video_height"]
video_W = inference_state["video_width"]
points = points / torch.tensor([video_W, video_H]).to(points.device)
# scale the (normalized) coordinates by the model's internal image size
points = points * self.image_size
points = points.to(inference_state["device"])
labels = labels.to(inference_state["device"])
if not clear_old_points:
point_inputs = point_inputs_per_frame.get(frame_idx, None)
else:
point_inputs = None
point_inputs = concat_points(point_inputs, points, labels)
point_inputs_per_frame[frame_idx] = point_inputs
mask_inputs_per_frame.pop(frame_idx, None)
# If this frame hasn't been tracked before, we treat it as an initial conditioning
# frame, meaning that the input points are used to generate segments on this frame without
# using any memory from other frames, like in SAM. Otherwise (if it has been tracked),
# the input points will be used to correct the already tracked masks.
is_init_cond_frame = frame_idx not in inference_state["frames_already_tracked"]
# whether to track in reverse time order
if is_init_cond_frame:
reverse = False
else:
reverse = inference_state["frames_already_tracked"][frame_idx]["reverse"]
obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
# Add a frame to conditioning output if it's an initial conditioning frame or
# if the model sees all frames receiving clicks/mask as conditioning frames.
is_cond = is_init_cond_frame or self.add_all_frames_to_correct_as_cond
storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
# Get any previously predicted mask logits on this object and feed it along with
# the new clicks into the SAM mask decoder.
prev_sam_mask_logits = None
# lookup temporary output dict first, which contains the most recent output
# (if not found, then lookup conditioning and non-conditioning frame output)
prev_out = obj_temp_output_dict[storage_key].get(frame_idx)
if prev_out is None:
prev_out = obj_output_dict["cond_frame_outputs"].get(frame_idx)
if prev_out is None:
prev_out = obj_output_dict["non_cond_frame_outputs"].get(frame_idx)
if prev_out is not None and prev_out["pred_masks"] is not None:
device = inference_state["device"]
prev_sam_mask_logits = prev_out["pred_masks"].to(device, non_blocking=True)
# Clamp the scale of prev_sam_mask_logits to avoid rare numerical issues.
prev_sam_mask_logits = torch.clamp(prev_sam_mask_logits, -32.0, 32.0)
current_out, _ = self._run_single_frame_inference(
inference_state=inference_state,
output_dict=obj_output_dict, # run on the slice of a single object
frame_idx=frame_idx,
batch_size=1, # run on the slice of a single object
is_init_cond_frame=is_init_cond_frame,
point_inputs=point_inputs,
mask_inputs=None,
reverse=reverse,
# Skip the memory encoder when adding clicks or mask. We execute the memory encoder
# at the beginning of `propagate_in_video` (after the user finalizes their clicks). This
# allows us to enforce non-overlapping constraints on all objects before encoding
# them into memory.
run_mem_encoder=False,
prev_sam_mask_logits=prev_sam_mask_logits,
)
# Add the output to the output dict (to be used as future memory)
obj_temp_output_dict[storage_key][frame_idx] = current_out
# Resize the output mask to the original video resolution
obj_ids = inference_state["obj_ids"]
consolidated_out = self._consolidate_temp_output_across_obj(
inference_state,
frame_idx,
is_cond=is_cond,
run_mem_encoder=False,
consolidate_at_video_res=True,
)
_, video_res_masks = self._get_orig_video_res_output(
inference_state, consolidated_out["pred_masks_video_res"]
)
return frame_idx, obj_ids, video_res_masks
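# Example sketch: prompting with a box, then refining with a negative click on the
# same frame. With `normalize_coords=True` (the default) coordinates are given in
# original-video pixel space; the numbers below are illustrative only.
#
#     predictor.add_new_points_or_box(
#         inference_state=state, frame_idx=0, obj_id=1,
#         box=[300.0, 0.0, 500.0, 400.0],       # (x_min, y_min, x_max, y_max)
#     )
#     predictor.add_new_points_or_box(
#         inference_state=state, frame_idx=0, obj_id=1,
#         points=[[460.0, 60.0]], labels=[0],   # label 0 = negative (background) click
#         clear_old_points=False,               # keep the earlier box prompt
#     )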
def add_new_points(self, *args, **kwargs):
"""Deprecated method. Please use `add_new_points_or_box` instead."""
return self.add_new_points_or_box(*args, **kwargs)
@torch.inference_mode()
def add_new_mask(
self,
inference_state,
frame_idx,
obj_id,
mask,
):
"""Add new mask to a frame."""
obj_idx = self._obj_id_to_idx(inference_state, obj_id)
point_inputs_per_frame = inference_state["point_inputs_per_obj"][obj_idx]
mask_inputs_per_frame = inference_state["mask_inputs_per_obj"][obj_idx]
if not isinstance(mask, torch.Tensor):
mask = torch.tensor(mask, dtype=torch.bool)
assert mask.dim() == 2
mask_H, mask_W = mask.shape
mask_inputs_orig = mask[None, None] # add batch and channel dimension
mask_inputs_orig = mask_inputs_orig.float().to(inference_state["device"])
# resize the mask if it doesn't match the model's image size
if mask_H != self.image_size or mask_W != self.image_size:
mask_inputs = torch.nn.functional.interpolate(
mask_inputs_orig,
size=(self.image_size, self.image_size),
align_corners=False,
mode="bilinear",
antialias=True, # use antialias for downsampling
)
mask_inputs = (mask_inputs >= 0.5).float()
else:
mask_inputs = mask_inputs_orig
mask_inputs_per_frame[frame_idx] = mask_inputs
point_inputs_per_frame.pop(frame_idx, None)
# If this frame hasn't been tracked before, we treat it as an initial conditioning
# frame, meaning that the input mask is used to generate segments on this frame without
# using any memory from other frames, like in SAM. Otherwise (if it has been tracked),
# the input mask will be used to correct the already tracked masks.
is_init_cond_frame = frame_idx not in inference_state["frames_already_tracked"]
# whether to track in reverse time order
if is_init_cond_frame:
reverse = False
else:
reverse = inference_state["frames_already_tracked"][frame_idx]["reverse"]
obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
# Add a frame to conditioning output if it's an initial conditioning frame or
# if the model sees all frames receiving clicks/mask as conditioning frames.
is_cond = is_init_cond_frame or self.add_all_frames_to_correct_as_cond
storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
current_out, _ = self._run_single_frame_inference(
inference_state=inference_state,
output_dict=obj_output_dict, # run on the slice of a single object
frame_idx=frame_idx,
batch_size=1, # run on the slice of a single object
is_init_cond_frame=is_init_cond_frame,
point_inputs=None,
mask_inputs=mask_inputs,
reverse=reverse,
# Skip the memory encoder when adding clicks or mask. We execute the memory encoder
# at the beginning of `propagate_in_video` (after the user finalizes their clicks). This
# allows us to enforce non-overlapping constraints on all objects before encoding
# them into memory.
run_mem_encoder=False,
)
# Add the output to the output dict (to be used as future memory)
obj_temp_output_dict[storage_key][frame_idx] = current_out
# Resize the output mask to the original video resolution
obj_ids = inference_state["obj_ids"]
consolidated_out = self._consolidate_temp_output_across_obj(
inference_state,
frame_idx,
is_cond=is_cond,
run_mem_encoder=False,
consolidate_at_video_res=True,
)
_, video_res_masks = self._get_orig_video_res_output(
inference_state, consolidated_out["pred_masks_video_res"]
)
return frame_idx, obj_ids, video_res_masks
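# Example sketch: seeding an object from a binary mask instead of clicks.
# `first_frame_mask` is a hypothetical (H, W) boolean array at the original video
# resolution; it is resized to the model's image size internally if needed.
#
#     predictor.add_new_mask(
#         inference_state=state, frame_idx=0, obj_id=2, mask=first_frame_mask,
#     )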
def _get_orig_video_res_output(self, inference_state, any_res_masks):
"""
Resize the object scores to the original video resolution (video_res_masks)
and apply non-overlapping constraints for final output.
"""
device = inference_state["device"]
video_H = inference_state["video_height"]
video_W = inference_state["video_width"]
any_res_masks = any_res_masks.to(device, non_blocking=True)
if any_res_masks.shape[-2:] == (video_H, video_W):
video_res_masks = any_res_masks
else:
video_res_masks = torch.nn.functional.interpolate(
any_res_masks,
size=(video_H, video_W),
mode="bilinear",
align_corners=False,
)
if self.non_overlap_masks:
video_res_masks = self._apply_non_overlapping_constraints(video_res_masks)
return any_res_masks, video_res_masks
def _consolidate_temp_output_across_obj(
self,
inference_state,
frame_idx,
is_cond,
run_mem_encoder,
consolidate_at_video_res=False,
):
"""
Consolidate the per-object temporary outputs in `temp_output_dict_per_obj` on
a frame into a single output for all objects, including
1) filling any missing objects from `output_dict_per_obj` (if they exist in
`output_dict_per_obj` for this frame) or leaving them as placeholder values
(if they don't exist in `output_dict_per_obj` for this frame);
2) if specified, rerunning the memory encoder after applying non-overlapping
constraints on the object scores.
"""
batch_size = self._get_obj_num(inference_state)
storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
# Optionally, we allow consolidating the temporary outputs at the original
# video resolution (to provide a better editing experience for mask prompts).
if consolidate_at_video_res:
assert not run_mem_encoder, "memory encoder cannot run at video resolution"
consolidated_H = inference_state["video_height"]
consolidated_W = inference_state["video_width"]
consolidated_mask_key = "pred_masks_video_res"
else:
consolidated_H = consolidated_W = self.image_size // 4
consolidated_mask_key = "pred_masks"
# Initialize `consolidated_out`. Its "maskmem_features" and "maskmem_pos_enc"
# will be added when rerunning the memory encoder after applying non-overlapping
# constraints to object scores. Its "pred_masks" are prefilled with a large
# negative value (NO_OBJ_SCORE) to represent missing objects.
consolidated_out = {
"maskmem_features": None,
"maskmem_pos_enc": None,
consolidated_mask_key: torch.full(
size=(batch_size, 1, consolidated_H, consolidated_W),
fill_value=NO_OBJ_SCORE,
dtype=torch.float32,
device=inference_state["storage_device"],
),
"obj_ptr": torch.full(
size=(batch_size, self.hidden_dim),
fill_value=NO_OBJ_SCORE,
dtype=torch.float32,
device=inference_state["device"],
),
"object_score_logits": torch.full(
size=(batch_size, 1),
# default to 10.0 for object_score_logits, i.e. assuming the object is
# present as sigmoid(10)=1, same as in `predict_masks` of `MaskDecoder`
fill_value=10.0,
dtype=torch.float32,
device=inference_state["device"],
),
}
empty_mask_ptr = None
for obj_idx in range(batch_size):
obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
out = obj_temp_output_dict[storage_key].get(frame_idx, None)
# If the object doesn't appear in "temp_output_dict_per_obj" on this frame,
# we fall back and look up its previous output in "output_dict_per_obj".
# We look up both "cond_frame_outputs" and "non_cond_frame_outputs" in
# "output_dict_per_obj" to find a previous output for this object.
if out is None:
out = obj_output_dict["cond_frame_outputs"].get(frame_idx, None)
if out is None:
out = obj_output_dict["non_cond_frame_outputs"].get(frame_idx, None)
# If the object doesn't appear in "output_dict_per_obj" either, we skip it
# and leave its mask scores to the default scores (i.e. the NO_OBJ_SCORE
# placeholder above) and set its object pointer to be a dummy pointer.
if out is None:
# Fill in dummy object pointers for those objects without any inputs or
# tracking outcomes on this frame (only do it under `run_mem_encoder=True`,
# i.e. when we need to build the memory for tracking).
if run_mem_encoder:
if empty_mask_ptr is None:
empty_mask_ptr = self._get_empty_mask_ptr(
inference_state, frame_idx
)
# fill object pointer with a dummy pointer (based on an empty mask)
consolidated_out["obj_ptr"][obj_idx : obj_idx + 1] = empty_mask_ptr
continue
# Add the temporary object output mask to consolidated output mask
obj_mask = out["pred_masks"]
consolidated_pred_masks = consolidated_out[consolidated_mask_key]
if obj_mask.shape[-2:] == consolidated_pred_masks.shape[-2:]:
consolidated_pred_masks[obj_idx : obj_idx + 1] = obj_mask
else:
# Resize first if temporary object mask has a different resolution
resized_obj_mask = torch.nn.functional.interpolate(
obj_mask,
size=consolidated_pred_masks.shape[-2:],
mode="bilinear",
align_corners=False,
)
consolidated_pred_masks[obj_idx : obj_idx + 1] = resized_obj_mask
consolidated_out["obj_ptr"][obj_idx : obj_idx + 1] = out["obj_ptr"]
consolidated_out["object_score_logits"][obj_idx : obj_idx + 1] = out[
"object_score_logits"
]
# Optionally, apply non-overlapping constraints on the consolidated scores
# and rerun the memory encoder
if run_mem_encoder:
device = inference_state["device"]
high_res_masks = torch.nn.functional.interpolate(
consolidated_out["pred_masks"].to(device, non_blocking=True),
size=(self.image_size, self.image_size),
mode="bilinear",
align_corners=False,
)
if self.non_overlap_masks_for_mem_enc:
high_res_masks = self._apply_non_overlapping_constraints(high_res_masks)
maskmem_features, maskmem_pos_enc = self._run_memory_encoder(
inference_state=inference_state,
frame_idx=frame_idx,
batch_size=batch_size,
high_res_masks=high_res_masks,
object_score_logits=consolidated_out["object_score_logits"],
is_mask_from_pts=True, # these frames are what the user interacted with
)
consolidated_out["maskmem_features"] = maskmem_features
consolidated_out["maskmem_pos_enc"] = maskmem_pos_enc
return consolidated_out
def _get_empty_mask_ptr(self, inference_state, frame_idx):
"""Get a dummy object pointer based on an empty mask on the current frame."""
# A dummy (empty) mask with a single object
batch_size = 1
mask_inputs = torch.zeros(
(batch_size, 1, self.image_size, self.image_size),
dtype=torch.float32,
device=inference_state["device"],
)
# Retrieve correct image features
(
_,
_,
current_vision_feats,
current_vision_pos_embeds,
feat_sizes,
) = self._get_image_feature(inference_state, frame_idx, batch_size)
# Feed the empty mask and image feature above to get a dummy object pointer
current_out = self.track_step(
frame_idx=frame_idx,
is_init_cond_frame=True,
current_vision_feats=current_vision_feats,
current_vision_pos_embeds=current_vision_pos_embeds,
feat_sizes=feat_sizes,
point_inputs=None,
mask_inputs=mask_inputs,
output_dict={},
num_frames=inference_state["num_frames"],
track_in_reverse=False,
run_mem_encoder=False,
prev_sam_mask_logits=None,
)
return current_out["obj_ptr"]
@torch.inference_mode()
def propagate_in_video_preflight(self, inference_state):
"""Prepare inference_state and consolidate temporary outputs before tracking."""
# Tracking has started and we don't allow adding new objects until the session is reset.
inference_state["tracking_has_started"] = True
batch_size = self._get_obj_num(inference_state)
# Consolidate per-object temporary outputs in "temp_output_dict_per_obj" and
# add them into "output_dict".
temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"]
output_dict = inference_state["output_dict"]
# "consolidated_frame_inds" contains indices of those frames where consolidated
# temporary outputs have been added (either in this call or any previous calls
# to `propagate_in_video_preflight`).
consolidated_frame_inds = inference_state["consolidated_frame_inds"]
for is_cond in [False, True]:
# Separately consolidate conditioning and non-conditioning temp outputs
storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
# Find all the frames that contain temporary outputs for any objects
# (these should be the frames that have just received clicks or mask inputs
# via `add_new_points_or_box` or `add_new_mask`)
temp_frame_inds = set()
for obj_temp_output_dict in temp_output_dict_per_obj.values():
temp_frame_inds.update(obj_temp_output_dict[storage_key].keys())
consolidated_frame_inds[storage_key].update(temp_frame_inds)
# consolidate the temporary output across all objects on this frame
for frame_idx in temp_frame_inds:
consolidated_out = self._consolidate_temp_output_across_obj(
inference_state, frame_idx, is_cond=is_cond, run_mem_encoder=True
)
# merge them into "output_dict" and also create per-object slices
output_dict[storage_key][frame_idx] = consolidated_out
self._add_output_per_object(
inference_state, frame_idx, consolidated_out, storage_key
)
clear_non_cond_mem = self.clear_non_cond_mem_around_input and (
self.clear_non_cond_mem_for_multi_obj or batch_size <= 1
)
if clear_non_cond_mem:
# clear non-conditioning memory of the surrounding frames
self._clear_non_cond_mem_around_input(inference_state, frame_idx)
# clear temporary outputs in `temp_output_dict_per_obj`
for obj_temp_output_dict in temp_output_dict_per_obj.values():
obj_temp_output_dict[storage_key].clear()
# edge case: if an output is added to "cond_frame_outputs", we remove any prior
# output on the same frame in "non_cond_frame_outputs"
for frame_idx in output_dict["cond_frame_outputs"]:
output_dict["non_cond_frame_outputs"].pop(frame_idx, None)
for obj_output_dict in inference_state["output_dict_per_obj"].values():
for frame_idx in obj_output_dict["cond_frame_outputs"]:
obj_output_dict["non_cond_frame_outputs"].pop(frame_idx, None)
for frame_idx in consolidated_frame_inds["cond_frame_outputs"]:
assert frame_idx in output_dict["cond_frame_outputs"]
consolidated_frame_inds["non_cond_frame_outputs"].discard(frame_idx)
# Make sure that the frame indices in "consolidated_frame_inds" are exactly those frames
# with either points or mask inputs (which should be true under a correct workflow).
all_consolidated_frame_inds = (
consolidated_frame_inds["cond_frame_outputs"]
| consolidated_frame_inds["non_cond_frame_outputs"]
)
input_frames_inds = set()
for point_inputs_per_frame in inference_state["point_inputs_per_obj"].values():
input_frames_inds.update(point_inputs_per_frame.keys())
for mask_inputs_per_frame in inference_state["mask_inputs_per_obj"].values():
input_frames_inds.update(mask_inputs_per_frame.keys())
assert all_consolidated_frame_inds == input_frames_inds
@torch.inference_mode()
def propagate_in_video(
self,
inference_state,
start_frame_idx=None,
max_frame_num_to_track=None,
reverse=False,
):
"""Propagate the input points across frames to track in the entire video."""
self.propagate_in_video_preflight(inference_state)
output_dict = inference_state["output_dict"]
consolidated_frame_inds = inference_state["consolidated_frame_inds"]
obj_ids = inference_state["obj_ids"]
num_frames = inference_state["num_frames"]
batch_size = self._get_obj_num(inference_state)
if len(output_dict["cond_frame_outputs"]) == 0:
raise RuntimeError("No points are provided; please add points first")
clear_non_cond_mem = self.clear_non_cond_mem_around_input and (
self.clear_non_cond_mem_for_multi_obj or batch_size <= 1
)
# set start index, end index, and processing order
if start_frame_idx is None:
# default: start from the earliest frame with input points
start_frame_idx = min(output_dict["cond_frame_outputs"])
if max_frame_num_to_track is None:
# default: track all the frames in the video
max_frame_num_to_track = num_frames
if reverse:
end_frame_idx = max(start_frame_idx - max_frame_num_to_track, 0)
if start_frame_idx > 0:
processing_order = range(start_frame_idx, end_frame_idx - 1, -1)
else:
processing_order = [] # skip reverse tracking if starting from frame 0
else:
end_frame_idx = min(
start_frame_idx + max_frame_num_to_track, num_frames - 1
)
processing_order = range(start_frame_idx, end_frame_idx + 1)
for frame_idx in tqdm(processing_order, desc="propagate in video"):
# We skip those frames already in consolidated outputs (these are frames
# that received input clicks or mask). Note that we cannot directly run
# batched forward on them via `_run_single_frame_inference` because the
# number of clicks on each object might be different.
if frame_idx in consolidated_frame_inds["cond_frame_outputs"]:
storage_key = "cond_frame_outputs"
current_out = output_dict[storage_key][frame_idx]
pred_masks = current_out["pred_masks"]
if clear_non_cond_mem:
# clear non-conditioning memory of the surrounding frames
self._clear_non_cond_mem_around_input(inference_state, frame_idx)
elif frame_idx in consolidated_frame_inds["non_cond_frame_outputs"]:
storage_key = "non_cond_frame_outputs"
current_out = output_dict[storage_key][frame_idx]
pred_masks = current_out["pred_masks"]
else:
storage_key = "non_cond_frame_outputs"
current_out, pred_masks = self._run_single_frame_inference(
inference_state=inference_state,
output_dict=output_dict,
frame_idx=frame_idx,
batch_size=batch_size,
is_init_cond_frame=False,
point_inputs=None,
mask_inputs=None,
reverse=reverse,
run_mem_encoder=True,
)
output_dict[storage_key][frame_idx] = current_out
# Create slices of per-object outputs for subsequent interaction with each
# individual object after tracking.
self._add_output_per_object(
inference_state, frame_idx, current_out, storage_key
)
inference_state["frames_already_tracked"][frame_idx] = {"reverse": reverse}
# Resize the output mask to the original video resolution (we directly use
# the mask scores on GPU for output to avoid any CPU conversion in between)
_, video_res_masks = self._get_orig_video_res_output(
inference_state, pred_masks
)
yield frame_idx, obj_ids, video_res_masks
def _add_output_per_object(
self, inference_state, frame_idx, current_out, storage_key
):
"""
Split a multi-object output into per-object output slices and add them into
`output_dict_per_obj`. The resulting slices share the same tensor storage.
"""
maskmem_features = current_out["maskmem_features"]
assert maskmem_features is None or isinstance(maskmem_features, torch.Tensor)
maskmem_pos_enc = current_out["maskmem_pos_enc"]
assert maskmem_pos_enc is None or isinstance(maskmem_pos_enc, list)
output_dict_per_obj = inference_state["output_dict_per_obj"]
for obj_idx, obj_output_dict in output_dict_per_obj.items():
obj_slice = slice(obj_idx, obj_idx + 1)
obj_out = {
"maskmem_features": None,
"maskmem_pos_enc": None,
"pred_masks": current_out["pred_masks"][obj_slice],
"obj_ptr": current_out["obj_ptr"][obj_slice],
"object_score_logits": current_out["object_score_logits"][obj_slice],
}
if maskmem_features is not None:
obj_out["maskmem_features"] = maskmem_features[obj_slice]
if maskmem_pos_enc is not None:
obj_out["maskmem_pos_enc"] = [x[obj_slice] for x in maskmem_pos_enc]
obj_output_dict[storage_key][frame_idx] = obj_out
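# Note: the `obj_slice` indexing above yields views, so the per-object entries
# written into `output_dict_per_obj` share storage with the batched tensors in
# `current_out` instead of copying them. A tiny illustration with made-up shapes:
#
#     batched = torch.zeros(3, 1, 256, 256)
#     per_obj = batched[1:2]                  # basic slicing -> a view, not a copy
#     per_obj[...] = 5.0
#     assert batched[1].max().item() == 5.0   # the write is visible through `batched`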
@torch.inference_mode()
def clear_all_prompts_in_frame(
self, inference_state, frame_idx, obj_id, need_output=True
):
"""Remove all input points or mask in a specific frame for a given object."""
obj_idx = self._obj_id_to_idx(inference_state, obj_id)
# Clear the conditioning information on the given frame
inference_state["point_inputs_per_obj"][obj_idx].pop(frame_idx, None)
inference_state["mask_inputs_per_obj"][obj_idx].pop(frame_idx, None)
temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"]
temp_output_dict_per_obj[obj_idx]["cond_frame_outputs"].pop(frame_idx, None)
temp_output_dict_per_obj[obj_idx]["non_cond_frame_outputs"].pop(frame_idx, None)
# Check and see if there are still any inputs left on this frame
batch_size = self._get_obj_num(inference_state)
frame_has_input = False
for obj_idx2 in range(batch_size):
if frame_idx in inference_state["point_inputs_per_obj"][obj_idx2]:
frame_has_input = True
break
if frame_idx in inference_state["mask_inputs_per_obj"][obj_idx2]:
frame_has_input = True
break
# If this frame has no remaining inputs for any objects, we further clear its
# conditioning frame status
if not frame_has_input:
output_dict = inference_state["output_dict"]
consolidated_frame_inds = inference_state["consolidated_frame_inds"]
consolidated_frame_inds["cond_frame_outputs"].discard(frame_idx)
consolidated_frame_inds["non_cond_frame_outputs"].discard(frame_idx)
# Remove the frame's conditioning output (possibly downgrading it to non-conditioning)
out = output_dict["cond_frame_outputs"].pop(frame_idx, None)
if out is not None:
# The frame is not a conditioning frame anymore since it's not receiving inputs,
# so we "downgrade" its output (if exists) to a non-conditioning frame output.
output_dict["non_cond_frame_outputs"][frame_idx] = out
inference_state["frames_already_tracked"].pop(frame_idx, None)
# Similarly, do it for the sliced output on each object.
for obj_idx2 in range(batch_size):
obj_output_dict = inference_state["output_dict_per_obj"][obj_idx2]
obj_out = obj_output_dict["cond_frame_outputs"].pop(frame_idx, None)
if obj_out is not None:
obj_output_dict["non_cond_frame_outputs"][frame_idx] = obj_out
# If all the conditioning frames have been removed, we also clear the tracking outputs
if len(output_dict["cond_frame_outputs"]) == 0:
self._reset_tracking_results(inference_state)
if not need_output:
return
# Finally, output updated masks per object (after removing the inputs above)
obj_ids = inference_state["obj_ids"]
is_cond = any(
frame_idx in obj_temp_output_dict["cond_frame_outputs"]
for obj_temp_output_dict in temp_output_dict_per_obj.values()
)
consolidated_out = self._consolidate_temp_output_across_obj(
inference_state,
frame_idx,
is_cond=is_cond,
run_mem_encoder=False,
consolidate_at_video_res=True,
)
_, video_res_masks = self._get_orig_video_res_output(
inference_state, consolidated_out["pred_masks_video_res"]
)
return frame_idx, obj_ids, video_res_masks
@torch.inference_mode()
def reset_state(self, inference_state):
"""Remove all input points or mask in all frames throughout the video."""
self._reset_tracking_results(inference_state)
# Remove all object ids
inference_state["obj_id_to_idx"].clear()
inference_state["obj_idx_to_id"].clear()
inference_state["obj_ids"].clear()
inference_state["point_inputs_per_obj"].clear()
inference_state["mask_inputs_per_obj"].clear()
inference_state["output_dict_per_obj"].clear()
inference_state["temp_output_dict_per_obj"].clear()
def _reset_tracking_results(self, inference_state):
"""Reset all tracking inputs and results across the videos."""
for v in inference_state["point_inputs_per_obj"].values():
v.clear()
for v in inference_state["mask_inputs_per_obj"].values():
v.clear()
for v in inference_state["output_dict_per_obj"].values():
v["cond_frame_outputs"].clear()
v["non_cond_frame_outputs"].clear()
for v in inference_state["temp_output_dict_per_obj"].values():
v["cond_frame_outputs"].clear()
v["non_cond_frame_outputs"].clear()
inference_state["output_dict"]["cond_frame_outputs"].clear()
inference_state["output_dict"]["non_cond_frame_outputs"].clear()
inference_state["consolidated_frame_inds"]["cond_frame_outputs"].clear()
inference_state["consolidated_frame_inds"]["non_cond_frame_outputs"].clear()
inference_state["tracking_has_started"] = False
inference_state["frames_already_tracked"].clear()
def _get_image_feature(self, inference_state, frame_idx, batch_size):
"""Compute the image features on a given frame."""
# Look up in the cache first
image, backbone_out = inference_state["cached_features"].get(
frame_idx, (None, None)
)
if backbone_out is None:
# Cache miss -- we will run inference on a single image
device = inference_state["device"]
image = inference_state["images"][frame_idx].to(device).float().unsqueeze(0)
backbone_out = self.forward_image(image)
# Cache the most recent frame's feature (for repeated interactions with
# a frame; we can use an LRU cache for more frames in the future).
inference_state["cached_features"] = {frame_idx: (image, backbone_out)}
# expand the features to have the same dimension as the number of objects
expanded_image = image.expand(batch_size, -1, -1, -1)
expanded_backbone_out = {
"backbone_fpn": backbone_out["backbone_fpn"].copy(),
"vision_pos_enc": backbone_out["vision_pos_enc"].copy(),
}
for i, feat in enumerate(expanded_backbone_out["backbone_fpn"]):
expanded_backbone_out["backbone_fpn"][i] = feat.expand(
batch_size, -1, -1, -1
)
for i, pos in enumerate(expanded_backbone_out["vision_pos_enc"]):
pos = pos.expand(batch_size, -1, -1, -1)
expanded_backbone_out["vision_pos_enc"][i] = pos
features = self._prepare_backbone_features(expanded_backbone_out)
features = (expanded_image,) + features
return features
def _run_single_frame_inference(
self,
inference_state,
output_dict,
frame_idx,
batch_size,
is_init_cond_frame,
point_inputs,
mask_inputs,
reverse,
run_mem_encoder,
prev_sam_mask_logits=None,
):
"""Run tracking on a single frame based on current inputs and previous memory."""
# Retrieve correct image features
(
_,
_,
current_vision_feats,
current_vision_pos_embeds,
feat_sizes,
) = self._get_image_feature(inference_state, frame_idx, batch_size)
# point and mask should not appear as input simultaneously on the same frame
assert point_inputs is None or mask_inputs is None
current_out = self.track_step(
frame_idx=frame_idx,
is_init_cond_frame=is_init_cond_frame,
current_vision_feats=current_vision_feats,
current_vision_pos_embeds=current_vision_pos_embeds,
feat_sizes=feat_sizes,
point_inputs=point_inputs,
mask_inputs=mask_inputs,
output_dict=output_dict,
num_frames=inference_state["num_frames"],
track_in_reverse=reverse,
run_mem_encoder=run_mem_encoder,
prev_sam_mask_logits=prev_sam_mask_logits,
)
# optionally offload the output to CPU memory to save GPU space
storage_device = inference_state["storage_device"]
maskmem_features = current_out["maskmem_features"]
if maskmem_features is not None:
maskmem_features = maskmem_features.to(torch.bfloat16)
maskmem_features = maskmem_features.to(storage_device, non_blocking=True)
pred_masks_gpu = current_out["pred_masks"]
# potentially fill holes in the predicted masks
if self.fill_hole_area > 0:
pred_masks_gpu = fill_holes_in_mask_scores(
pred_masks_gpu, self.fill_hole_area
)
pred_masks = pred_masks_gpu.to(storage_device, non_blocking=True)
# "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it
maskmem_pos_enc = self._get_maskmem_pos_enc(inference_state, current_out)
# object pointer is a small tensor, so we always keep it on GPU memory for fast access
obj_ptr = current_out["obj_ptr"]
object_score_logits = current_out["object_score_logits"]
# make a compact version of this frame's output to reduce the state size
compact_current_out = {
"maskmem_features": maskmem_features,
"maskmem_pos_enc": maskmem_pos_enc,
"pred_masks": pred_masks,
"obj_ptr": obj_ptr,
"object_score_logits": object_score_logits,
}
return compact_current_out, pred_masks_gpu
def _run_memory_encoder(
self,
inference_state,
frame_idx,
batch_size,
high_res_masks,
object_score_logits,
is_mask_from_pts,
):
"""
Run the memory encoder on `high_res_masks`. This is usually after applying
non-overlapping constraints to object scores. Since their scores have changed, their
memory also needs to be recomputed with the memory encoder.
"""
# Retrieve correct image features
_, _, current_vision_feats, _, feat_sizes = self._get_image_feature(
inference_state, frame_idx, batch_size
)
maskmem_features, maskmem_pos_enc = self._encode_new_memory(
current_vision_feats=current_vision_feats,
feat_sizes=feat_sizes,
pred_masks_high_res=high_res_masks,
object_score_logits=object_score_logits,
is_mask_from_pts=is_mask_from_pts,
)
# optionally offload the output to CPU memory to save GPU space
storage_device = inference_state["storage_device"]
maskmem_features = maskmem_features.to(torch.bfloat16)
maskmem_features = maskmem_features.to(storage_device, non_blocking=True)
# "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it
maskmem_pos_enc = self._get_maskmem_pos_enc(
inference_state, {"maskmem_pos_enc": maskmem_pos_enc}
)
return maskmem_features, maskmem_pos_enc
def _get_maskmem_pos_enc(self, inference_state, current_out):
"""
`maskmem_pos_enc` is the same across frames and objects, so we cache it as
a constant in the inference session to reduce session storage size.
"""
model_constants = inference_state["constants"]
# "out_maskmem_pos_enc" should be either a list of tensors or None
out_maskmem_pos_enc = current_out["maskmem_pos_enc"]
if out_maskmem_pos_enc is not None:
if "maskmem_pos_enc" not in model_constants:
assert isinstance(out_maskmem_pos_enc, list)
# only take the slice for one object, since it's same across objects
maskmem_pos_enc = [x[0:1].clone() for x in out_maskmem_pos_enc]
model_constants["maskmem_pos_enc"] = maskmem_pos_enc
else:
maskmem_pos_enc = model_constants["maskmem_pos_enc"]
# expand the cached maskmem_pos_enc to the actual batch size
batch_size = out_maskmem_pos_enc[0].size(0)
expanded_maskmem_pos_enc = [
x.expand(batch_size, -1, -1, -1) for x in maskmem_pos_enc
]
else:
expanded_maskmem_pos_enc = None
return expanded_maskmem_pos_enc
@torch.inference_mode()
def remove_object(self, inference_state, obj_id, strict=False, need_output=True):
"""
Remove an object id from the tracking state. If strict is True, we check whether
the object id actually exists and raise an error if it doesn't exist.
"""
old_obj_idx_to_rm = inference_state["obj_id_to_idx"].get(obj_id, None)
updated_frames = []
# Check whether this object_id to remove actually exists and possibly raise an error.
if old_obj_idx_to_rm is None:
if not strict:
return inference_state["obj_ids"], updated_frames
raise RuntimeError(
f"Cannot remove object id {obj_id} as it doesn't exist. "
f"All existing object ids: {inference_state['obj_ids']}."
)
# If this is the only remaining object id, we simply reset the state.
if len(inference_state["obj_id_to_idx"]) == 1:
self.reset_state(inference_state)
return inference_state["obj_ids"], updated_frames
# There are still remaining objects after removing this object id. In this case,
# we need to delete the object storage from inference state tensors.
# Step 0: clear the input on those frames where this object id has point or mask input
# (note that this step is required as it might downgrade conditioning frames to
# non-conditioning ones)
obj_input_frames_inds = set()
obj_input_frames_inds.update(
inference_state["point_inputs_per_obj"][old_obj_idx_to_rm]
)
obj_input_frames_inds.update(
inference_state["mask_inputs_per_obj"][old_obj_idx_to_rm]
)
for frame_idx in obj_input_frames_inds:
self.clear_all_prompts_in_frame(
inference_state, frame_idx, obj_id, need_output=False
)
# Step 1: Update the object id mapping (note that it must be done after Step 0,
# since Step 0 still requires the old object id mappings in inference_state)
old_obj_ids = inference_state["obj_ids"]
old_obj_inds = list(range(len(old_obj_ids)))
remain_old_obj_inds = old_obj_inds.copy()
remain_old_obj_inds.remove(old_obj_idx_to_rm)
new_obj_ids = [old_obj_ids[old_idx] for old_idx in remain_old_obj_inds]
new_obj_inds = list(range(len(new_obj_ids)))
# build new mappings
old_idx_to_new_idx = dict(zip(remain_old_obj_inds, new_obj_inds))
inference_state["obj_id_to_idx"] = dict(zip(new_obj_ids, new_obj_inds))
inference_state["obj_idx_to_id"] = dict(zip(new_obj_inds, new_obj_ids))
inference_state["obj_ids"] = new_obj_ids
# Step 2: For per-object tensor storage, we shift their obj_idx in the dict keys.
# (note that "consolidated_frame_inds" doesn't need to be updated in this step as
# it's already handled in Step 0)
def _map_keys(container):
new_kvs = []
for k in old_obj_inds:
v = container.pop(k)
if k in old_idx_to_new_idx:
new_kvs.append((old_idx_to_new_idx[k], v))
container.update(new_kvs)
_map_keys(inference_state["point_inputs_per_obj"])
_map_keys(inference_state["mask_inputs_per_obj"])
_map_keys(inference_state["output_dict_per_obj"])
_map_keys(inference_state["temp_output_dict_per_obj"])
# Step 3: For packed tensor storage, we index the remaining ids and rebuild the per-object slices.
def _slice_state(output_dict, storage_key):
for frame_idx, out in output_dict[storage_key].items():
out["maskmem_features"] = out["maskmem_features"][remain_old_obj_inds]
out["maskmem_pos_enc"] = [
x[remain_old_obj_inds] for x in out["maskmem_pos_enc"]
]
# "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it
out["maskmem_pos_enc"] = self._get_maskmem_pos_enc(inference_state, out)
out["pred_masks"] = out["pred_masks"][remain_old_obj_inds]
out["obj_ptr"] = out["obj_ptr"][remain_old_obj_inds]
out["object_score_logits"] = out["object_score_logits"][
remain_old_obj_inds
]
# also update the per-object slices
self._add_output_per_object(
inference_state, frame_idx, out, storage_key
)
_slice_state(inference_state["output_dict"], "cond_frame_outputs")
_slice_state(inference_state["output_dict"], "non_cond_frame_outputs")
# Step 4: Further collect the outputs on those frames in `obj_input_frames_inds`, which
# could show an updated mask for objects previously occluded by the object being removed
if need_output:
temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"]
for frame_idx in obj_input_frames_inds:
is_cond = any(
frame_idx in obj_temp_output_dict["cond_frame_outputs"]
for obj_temp_output_dict in temp_output_dict_per_obj.values()
)
consolidated_out = self._consolidate_temp_output_across_obj(
inference_state,
frame_idx,
is_cond=is_cond,
run_mem_encoder=False,
consolidate_at_video_res=True,
)
_, video_res_masks = self._get_orig_video_res_output(
inference_state, consolidated_out["pred_masks_video_res"]
)
updated_frames.append((frame_idx, video_res_masks))
return inference_state["obj_ids"], updated_frames
def _clear_non_cond_mem_around_input(self, inference_state, frame_idx):
"""
Remove the non-conditioning memory around the input frame. When users provide
correction clicks, the surrounding frames' non-conditioning memories can still
contain outdated object appearance information and could confuse the model.
This method clears those non-conditioning memories surrounding the interacted
frame to avoid giving the model both old and new information about the object.
"""
r = self.memory_temporal_stride_for_eval
frame_idx_begin = frame_idx - r * self.num_maskmem
frame_idx_end = frame_idx + r * self.num_maskmem
output_dict = inference_state["output_dict"]
non_cond_frame_outputs = output_dict["non_cond_frame_outputs"]
for t in range(frame_idx_begin, frame_idx_end + 1):
non_cond_frame_outputs.pop(t, None)
for obj_output_dict in inference_state["output_dict_per_obj"].values():
obj_output_dict["non_cond_frame_outputs"].pop(t, None)
================================================
FILE: camera_pose_annotation/dynamic_mask/sam2/utils/__init__.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
================================================
FILE: camera_pose_annotation/dynamic_mask/sam2/utils/amg.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import math
from copy import deepcopy
from itertools import product
from typing import Any, Dict, Generator, ItemsView, List, Tuple
import numpy as np
import torch
# Very lightly adapted from https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/utils/amg.py
class MaskData:
"""
A structure for storing masks and their related data in batched format.
Implements basic filtering and concatenation.
"""
def __init__(self, **kwargs) -> None:
for v in kwargs.values():
assert isinstance(
v, (list, np.ndarray, torch.Tensor)
), "MaskData only supports list, numpy arrays, and torch tensors."
self._stats = dict(**kwargs)
def __setitem__(self, key: str, item: Any) -> None:
assert isinstance(
item, (list, np.ndarray, torch.Tensor)
), "MaskData only supports list, numpy arrays, and torch tensors."
self._stats[key] = item
def __delitem__(self, key: str) -> None:
del self._stats[key]
def __getitem__(self, key: str) -> Any:
return self._stats[key]
def items(self) -> ItemsView[str, Any]:
return self._stats.items()
def filter(self, keep: torch.Tensor) -> None:
for k, v in self._stats.items():
if v is None:
self._stats[k] = None
elif isinstance(v, torch.Tensor):
self._stats[k] = v[torch.as_tensor(keep, device=v.device)]
elif isinstance(v, np.ndarray):
self._stats[k] = v[keep.detach().cpu().numpy()]
elif isinstance(v, list) and keep.dtype == torch.bool:
self._stats[k] = [a for i, a in enumerate(v) if keep[i]]
elif isinstance(v, list):
self._stats[k] = [v[i] for i in keep]
else:
raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.")
def cat(self, new_stats: "MaskData") -> None:
for k, v in new_stats.items():
if k not in self._stats or self._stats[k] is None:
self._stats[k] = deepcopy(v)
elif isinstance(v, torch.Tensor):
self._stats[k] = torch.cat([self._stats[k], v], dim=0)
elif isinstance(v, np.ndarray):
self._stats[k] = np.concatenate([self._stats[k], v], axis=0)
elif isinstance(v, list):
self._stats[k] = self._stats[k] + deepcopy(v)
else:
raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.")
def to_numpy(self) -> None:
for k, v in self._stats.items():
if isinstance(v, torch.Tensor):
self._stats[k] = v.float().detach().cpu().numpy()
def is_box_near_crop_edge(
boxes: torch.Tensor, crop_box: List[int], orig_box: List[int], atol: float = 20.0
) -> torch.Tensor:
"""Filter masks at the edge of a crop, but not at the edge of the original image."""
crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device)
orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device)
boxes = uncrop_boxes_xyxy(boxes, crop_box).float()
near_crop_edge = torch.isclose(boxes, crop_box_torch[None, :], atol=atol, rtol=0)
near_image_edge = torch.isclose(boxes, orig_box_torch[None, :], atol=atol, rtol=0)
near_crop_edge = torch.logical_and(near_crop_edge, ~near_image_edge)
return torch.any(near_crop_edge, dim=1)
def box_xyxy_to_xywh(box_xyxy: torch.Tensor) -> torch.Tensor:
box_xywh = deepcopy(box_xyxy)
box_xywh[2] = box_xywh[2] - box_xywh[0]
box_xywh[3] = box_xywh[3] - box_xywh[1]
return box_xywh
def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]:
assert len(args) > 0 and all(
len(a) == len(args[0]) for a in args
), "Batched iteration must have inputs of all the same size."
n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0)
for b in range(n_batches):
yield [arg[b * batch_size : (b + 1) * batch_size] for arg in args]
def mask_to_rle_pytorch(tensor: torch.Tensor) -> List[Dict[str, Any]]:
"""
Encodes masks to an uncompressed RLE, in the format expected by
pycoco tools.
"""
# Put in fortran order and flatten h,w
b, h, w = tensor.shape
tensor = tensor.permute(0, 2, 1).flatten(1)
# Compute change indices
diff = tensor[:, 1:] ^ tensor[:, :-1]
change_indices = diff.nonzero()
# Encode run length
out = []
for i in range(b):
cur_idxs = change_indices[change_indices[:, 0] == i, 1]
cur_idxs = torch.cat(
[
torch.tensor([0], dtype=cur_idxs.dtype, device=cur_idxs.device),
cur_idxs + 1,
torch.tensor([h * w], dtype=cur_idxs.dtype, device=cur_idxs.device),
]
)
btw_idxs = cur_idxs[1:] - cur_idxs[:-1]
counts = [] if tensor[i, 0] == 0 else [0]
counts.extend(btw_idxs.detach().cpu().tolist())
out.append({"size": [h, w], "counts": counts})
return out
def rle_to_mask(rle: Dict[str, Any]) -> np.ndarray:
"""Compute a binary mask from an uncompressed RLE."""
h, w = rle["size"]
mask = np.empty(h * w, dtype=bool)
idx = 0
parity = False
for count in rle["counts"]:
mask[idx : idx + count] = parity
idx += count
parity ^= True
mask = mask.reshape(w, h)
return mask.transpose() # Put in C order
def area_from_rle(rle: Dict[str, Any]) -> int:
return sum(rle["counts"][1::2])
def calculate_stability_score(
masks: torch.Tensor, mask_threshold: float, threshold_offset: float
) -> torch.Tensor:
"""
Computes the stability score for a batch of masks. The stability
score is the IoU between the binary masks obtained by thresholding
the predicted mask logits at high and low values.
"""
# One mask is always contained inside the other.
# Save memory by preventing unnecessary cast to torch.int64
intersections = (
(masks > (mask_threshold + threshold_offset))
.sum(-1, dtype=torch.int16)
.sum(-1, dtype=torch.int32)
)
unions = (
(masks > (mask_threshold - threshold_offset))
.sum(-1, dtype=torch.int16)
.sum(-1, dtype=torch.int32)
)
return intersections / unions
def build_point_grid(n_per_side: int) -> np.ndarray:
"""Generates a 2D grid of points evenly spaced in [0,1]x[0,1]."""
offset = 1 / (2 * n_per_side)
points_one_side = np.linspace(offset, 1 - offset, n_per_side)
points_x = np.tile(points_one_side[None, :], (n_per_side, 1))
points_y = np.tile(points_one_side[:, None], (1, n_per_side))
points = np.stack([points_x, points_y], axis=-1).reshape(-1, 2)
return points
def build_all_layer_point_grids(
n_per_side: int, n_layers: int, scale_per_layer: int
) -> List[np.ndarray]:
"""Generates point grids for all crop layers."""
points_by_layer = []
for i in range(n_layers + 1):
n_points = int(n_per_side / (scale_per_layer**i))
points_by_layer.append(build_point_grid(n_points))
return points_by_layer
def generate_crop_boxes(
im_size: Tuple[int, ...], n_layers: int, overlap_ratio: float
) -> Tuple[List[List[int]], List[int]]:
"""
Generates a list of crop boxes of different sizes. Each layer
has (2**i)**2 boxes for the ith layer.
"""
crop_boxes, layer_idxs = [], []
im_h, im_w = im_size
short_side = min(im_h, im_w)
# Original image
crop_boxes.append([0, 0, im_w, im_h])
layer_idxs.append(0)
def crop_len(orig_len, n_crops, overlap):
return int(math.ceil((overlap * (n_crops - 1) + orig_len) / n_crops))
for i_layer in range(n_layers):
n_crops_per_side = 2 ** (i_layer + 1)
overlap = int(overlap_ratio * short_side * (2 / n_crops_per_side))
crop_w = crop_len(im_w, n_crops_per_side, overlap)
crop_h = crop_len(im_h, n_crops_per_side, overlap)
crop_box_x0 = [int((crop_w - overlap) * i) for i in range(n_crops_per_side)]
crop_box_y0 = [int((crop_h - overlap) * i) for i in range(n_crops_per_side)]
# Crops in XYWH format
for x0, y0 in product(crop_box_x0, crop_box_y0):
box = [x0, y0, min(x0 + crop_w, im_w), min(y0 + crop_h, im_h)]
crop_boxes.append(box)
layer_idxs.append(i_layer + 1)
return crop_boxes, layer_idxs
def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
x0, y0, _, _ = crop_box
offset = torch.tensor([[x0, y0, x0, y0]], device=boxes.device)
# Check if boxes has a channel dimension
if len(boxes.shape) == 3:
offset = offset.unsqueeze(1)
return boxes + offset
def uncrop_points(points: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
x0, y0, _, _ = crop_box
offset = torch.tensor([[x0, y0]], device=points.device)
# Check if points has a channel dimension
if len(points.shape) == 3:
offset = offset.unsqueeze(1)
return points + offset
def uncrop_masks(
masks: torch.Tensor, crop_box: List[int], orig_h: int, orig_w: int
) -> torch.Tensor:
x0, y0, x1, y1 = crop_box
if x0 == 0 and y0 == 0 and x1 == orig_w and y1 == orig_h:
return masks
# Coordinate transform masks
pad_x, pad_y = orig_w - (x1 - x0), orig_h - (y1 - y0)
pad = (x0, pad_x - x0, y0, pad_y - y0)
return torch.nn.functional.pad(masks, pad, value=0)
def remove_small_regions(
mask: np.ndarray, area_thresh: float, mode: str
) -> Tuple[np.ndarray, bool]:
"""
Removes small disconnected regions and holes in a mask. Returns the
mask and an indicator of whether the mask has been modified.
"""
import cv2 # type: ignore
assert mode in ["holes", "islands"]
correct_holes = mode == "holes"
working_mask = (correct_holes ^ mask).astype(np.uint8)
n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8)
sizes = stats[:, -1][1:] # Row 0 is background label
small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh]
if len(small_regions) == 0:
return mask, False
fill_labels = [0] + small_regions
if not correct_holes:
fill_labels = [i for i in range(n_labels) if i not in fill_labels]
# If every region is below threshold, keep largest
if len(fill_labels) == 0:
fill_labels = [int(np.argmax(sizes)) + 1]
mask = np.isin(regions, fill_labels)
return mask, True
def coco_encode_rle(uncompressed_rle: Dict[str, Any]) -> Dict[str, Any]:
from pycocotools import mask as mask_utils # type: ignore
h, w = uncompressed_rle["size"]
rle = mask_utils.frPyObjects(uncompressed_rle, h, w)
rle["counts"] = rle["counts"].decode("utf-8") # Necessary to serialize with json
return rle
def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor:
"""
Calculates boxes in XYXY format around masks. Return [0,0,0,0] for
an empty mask. For input shape C1xC2x...xHxW, the output shape is C1xC2x...x4.
"""
# torch.max below raises an error on empty inputs, just skip in this case
if torch.numel(masks) == 0:
return torch.zeros(*masks.shape[:-2], 4, device=masks.device)
# Normalize shape to CxHxW
shape = masks.shape
h, w = shape[-2:]
if len(shape) > 2:
masks = masks.flatten(0, -3)
else:
masks = masks.unsqueeze(0)
# Get top and bottom edges
in_height, _ = torch.max(masks, dim=-1)
in_height_coords = in_height * torch.arange(h, device=in_height.device)[None, :]
bottom_edges, _ = torch.max(in_height_coords, dim=-1)
in_height_coords = in_height_coords + h * (~in_height)
top_edges, _ = torch.min(in_height_coords, dim=-1)
# Get left and right edges
in_width, _ = torch.max(masks, dim=-2)
in_width_coords = in_width * torch.arange(w, device=in_width.device)[None, :]
right_edges, _ = torch.max(in_width_coords, dim=-1)
in_width_coords = in_width_coords + w * (~in_width)
left_edges, _ = torch.min(in_width_coords, dim=-1)
# If the mask is empty the right edge will be to the left of the left edge.
# Replace these boxes with [0, 0, 0, 0]
empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges)
out = torch.stack([left_edges, top_edges, right_edges, bottom_edges], dim=-1)
out = out * (~empty_filter).unsqueeze(-1)
# Return to original shape
if len(shape) > 2:
out = out.reshape(*shape[:-2], 4)
else:
out = out[0]
return out
================================================
FILE: camera_pose_annotation/dynamic_mask/sam2/utils/misc.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os
import warnings
from threading import Thread
import numpy as np
import torch
from PIL import Image
from tqdm import tqdm
def get_sdpa_settings():
if torch.cuda.is_available():
old_gpu = torch.cuda.get_device_properties(0).major < 7
# only use Flash Attention on Ampere (8.0) or newer GPUs
use_flash_attn = torch.cuda.get_device_properties(0).major >= 8
if not use_flash_attn:
warnings.warn(
"Flash Attention is disabled as it requires a GPU with Ampere (8.0) CUDA capability.",
category=UserWarning,
stacklevel=2,
)
# keep math kernel for PyTorch versions before 2.2 (Flash Attention v2 is only
# available on PyTorch 2.2+, while Flash Attention v1 cannot handle all cases)
pytorch_version = tuple(int(v) for v in torch.__version__.split(".")[:2])
if pytorch_version < (2, 2):
warnings.warn(
f"You are using PyTorch {torch.__version__} without Flash Attention v2 support. "
"Consider upgrading to PyTorch 2.2+ for Flash Attention v2 (which could be faster).",
category=UserWarning,
stacklevel=2,
)
math_kernel_on = pytorch_version < (2, 2) or not use_flash_attn
else:
old_gpu = True
use_flash_attn = False
math_kernel_on = True
return old_gpu, use_flash_attn, math_kernel_on
def get_connected_components(mask):
"""
Get the connected components (8-connectivity) of binary masks of shape (N, 1, H, W).
Inputs:
- mask: A binary mask tensor of shape (N, 1, H, W), where 1 is foreground and 0 is
background.
Outputs:
- labels: A tensor of shape (N, 1, H, W) containing the connected component labels
for foreground pixels and 0 for background pixels.
- counts: A tensor of shape (N, 1, H, W) containing the area of the connected
components for foreground pixels and 0 for background pixels.
"""
from sam2 import _C
return _C.get_connected_componnets(mask.to(torch.uint8).contiguous())
def mask_to_box(masks: torch.Tensor):
"""
compute bounding box given an input mask
Inputs:
- masks: [B, 1, H, W] masks, dtype=torch.Tensor
Returns:
- box_coords: [B, 1, 4], contains (x, y) coordinates of top left and bottom right box corners, dtype=torch.Tensor
"""
B, _, h, w = masks.shape
device = masks.device
xs = torch.arange(w, device=device, dtype=torch.int32)
ys = torch.arange(h, device=device, dtype=torch.int32)
grid_xs, grid_ys = torch.meshgrid(xs, ys, indexing="xy")
grid_xs = grid_xs[None, None, ...].expand(B, 1, h, w)
grid_ys = grid_ys[None, None, ...].expand(B, 1, h, w)
min_xs, _ = torch.min(torch.where(masks, grid_xs, w).flatten(-2), dim=-1)
max_xs, _ = torch.max(torch.where(masks, grid_xs, -1).flatten(-2), dim=-1)
min_ys, _ = torch.min(torch.where(masks, grid_ys, h).flatten(-2), dim=-1)
max_ys, _ = torch.max(torch.where(masks, grid_ys, -1).flatten(-2), dim=-1)
bbox_coords = torch.stack((min_xs, min_ys, max_xs, max_ys), dim=-1)
return bbox_coords
def _load_img_as_tensor(img_path, image_size):
img_pil = Image.open(img_path)
img_np = np.array(img_pil.convert("RGB").resize((image_size, image_size)))
if img_np.dtype == np.uint8: # np.uint8 is expected for JPEG images
img_np = img_np / 255.0
else:
raise RuntimeError(f"Unknown image dtype: {img_np.dtype} on {img_path}")
img = torch.from_numpy(img_np).permute(2, 0, 1)
video_width, video_height = img_pil.size # the original video size
return img, video_height, video_width
class AsyncVideoFrameLoader:
"""
A list of video frames to be loaded asynchronously without blocking session start.
"""
def __init__(
self,
img_paths,
image_size,
offload_video_to_cpu,
img_mean,
img_std,
compute_device,
):
self.img_paths = img_paths
self.image_size = image_size
self.offload_video_to_cpu = offload_video_to_cpu
self.img_mean = img_mean
self.img_std = img_std
# items in `self.images` will be loaded asynchronously
self.images = [None] * len(img_paths)
# catch and raise any exceptions in the async loading thread
self.exception = None
# video_height and video_width will be filled when loading the first image
self.video_height = None
self.video_width = None
self.compute_device = compute_device
# load the first frame to fill video_height and video_width and also
# to cache it (since it's most likely where the user will click)
self.__getitem__(0)
# load the rest of frames asynchronously without blocking the session start
def _load_frames():
try:
for n in tqdm(range(len(self.images)), desc="frame loading (JPEG)"):
self.__getitem__(n)
except Exception as e:
self.exception = e
self.thread = Thread(target=_load_frames, daemon=True)
self.thread.start()
def __getitem__(self, index):
if self.exception is not None:
raise RuntimeError("Failure in frame loading thread") from self.exception
img = self.images[index]
if img is not None:
return img
img, video_height, video_width = _load_img_as_tensor(
self.img_paths[index], self.image_size
)
self.video_height = video_height
self.video_width = video_width
# normalize by mean and std
img -= self.img_mean
img /= self.img_std
if not self.offload_video_to_cpu:
img = img.to(self.compute_device, non_blocking=True)
self.images[index] = img
return img
def __len__(self):
return len(self.images)
def load_video_frames(
video_path,
image_size,
offload_video_to_cpu,
img_mean=(0.485, 0.456, 0.406),
img_std=(0.229, 0.224, 0.225),
async_loading_frames=False,
compute_device=torch.device("cuda"),
):
"""
Load the video frames from video_path. The frames are resized to image_size as in
the model and are loaded to GPU if offload_video_to_cpu=False. This is used by the demo.
"""
is_bytes = isinstance(video_path, bytes)
is_str = isinstance(video_path, str)
is_mp4_path = is_str and os.path.splitext(video_path)[-1] in [".mp4", ".MP4"]
if is_bytes or is_mp4_path:
return load_video_frames_from_video_file(
video_path=video_path,
image_size=image_size,
offload_video_to_cpu=offload_video_to_cpu,
img_mean=img_mean,
img_std=img_std,
compute_device=compute_device,
)
elif is_str and os.path.isdir(video_path):
return load_video_frames_from_jpg_images(
video_path=video_path,
image_size=image_size,
offload_video_to_cpu=offload_video_to_cpu,
img_mean=img_mean,
img_std=img_std,
async_loading_frames=async_loading_frames,
compute_device=compute_device,
)
else:
raise NotImplementedError(
"Only MP4 video and JPEG folder are supported at this moment"
)
def load_video_frames_from_jpg_images(
video_path,
image_size,
offload_video_to_cpu,
img_mean=(0.485, 0.456, 0.406),
img_std=(0.229, 0.224, 0.225),
async_loading_frames=False,
compute_device=torch.device("cuda"),
):
"""
Load the video frames from a directory of JPEG files (".jpg" format).
The frames are resized to image_size x image_size and are loaded to GPU if
`offload_video_to_cpu` is `False` and to CPU if `offload_video_to_cpu` is `True`.
You can load a frame asynchronously by setting `async_loading_frames` to `True`.
"""
if isinstance(video_path, str) and os.path.isdir(video_path):
jpg_folder = video_path
else:
raise NotImplementedError(
"Only JPEG frames are supported at this moment. For video files, you may use "
"ffmpeg (https://ffmpeg.org/) to extract frames into a folder of JPEG files, such as \n"
"```\n"
"ffmpeg -i .mp4 -q:v 2 -start_number 0 /'%05d.jpg'\n"
"```\n"
"where `-q:v` generates high-quality JPEG frames and `-start_number 0` asks "
"ffmpeg to start the JPEG file from 00000.jpg."
)
frame_names = [
p
for p in os.listdir(jpg_folder)
if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
]
frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
num_frames = len(frame_names)
if num_frames == 0:
raise RuntimeError(f"no images found in {jpg_folder}")
img_paths = [os.path.join(jpg_folder, frame_name) for frame_name in frame_names]
img_mean = torch.tensor(img_mean, dtype=torch.float32)[:, None, None]
img_std = torch.tensor(img_std, dtype=torch.float32)[:, None, None]
if async_loading_frames:
lazy_images = AsyncVideoFrameLoader(
img_paths,
image_size,
offload_video_to_cpu,
img_mean,
img_std,
compute_device,
)
return lazy_images, lazy_images.video_height, lazy_images.video_width
images = torch.zeros(num_frames, 3, image_size, image_size, dtype=torch.float32)
for n, img_path in enumerate(tqdm(img_paths, desc="frame loading (JPEG)")):
images[n], video_height, video_width = _load_img_as_tensor(img_path, image_size)
if not offload_video_to_cpu:
images = images.to(compute_device)
img_mean = img_mean.to(compute_device)
img_std = img_std.to(compute_device)
# normalize by mean and std
images -= img_mean
images /= img_std
return images, video_height, video_width
def load_video_frames_from_video_file(
video_path,
image_size,
offload_video_to_cpu,
img_mean=(0.485, 0.456, 0.406),
img_std=(0.229, 0.224, 0.225),
compute_device=torch.device("cuda"),
):
"""Load the video frames from a video file."""
import decord
img_mean = torch.tensor(img_mean, dtype=torch.float32)[:, None, None]
img_std = torch.tensor(img_std, dtype=torch.float32)[:, None, None]
# Get the original video height and width
decord.bridge.set_bridge("torch")
video_height, video_width, _ = decord.VideoReader(video_path).next().shape
# Iterate over all frames in the video
images = []
for frame in decord.VideoReader(video_path, width=image_size, height=image_size):
images.append(frame.permute(2, 0, 1))
images = torch.stack(images, dim=0).float() / 255.0
if not offload_video_to_cpu:
images = images.to(compute_device)
img_mean = img_mean.to(compute_device)
img_std = img_std.to(compute_device)
# normalize by mean and std
images -= img_mean
images /= img_std
return images, video_height, video_width
def fill_holes_in_mask_scores(mask, max_area):
"""
A post processor to fill small holes in mask scores with area under `max_area`.
"""
# Holes are those connected components in background with area <= max_area
# (background regions are those with mask scores <= 0)
assert max_area > 0, "max_area must be positive"
input_mask = mask
try:
labels, areas = get_connected_components(mask <= 0)
is_hole = (labels > 0) & (areas <= max_area)
# We fill holes with a small positive mask score (0.1) to change them to foreground.
mask = torch.where(is_hole, 0.1, mask)
except Exception as e:
# Skip the post-processing step on removing small holes if the CUDA kernel fails
warnings.warn(
f"{e}\n\nSkipping the post-processing step due to the error above. You can "
"still use SAM 2 and it's OK to ignore the error above, although some post-processing "
"functionality may be limited (which doesn't affect the results in most cases; see "
"https://github.com/facebookresearch/sam2/blob/main/INSTALL.md).",
category=UserWarning,
stacklevel=2,
)
mask = input_mask
return mask
def concat_points(old_point_inputs, new_points, new_labels):
"""Add new points and labels to previous point inputs (add at the end)."""
if old_point_inputs is None:
points, labels = new_points, new_labels
else:
points = torch.cat([old_point_inputs["point_coords"], new_points], dim=1)
labels = torch.cat([old_point_inputs["point_labels"], new_labels], dim=1)
return {"point_coords": points, "point_labels": labels}
================================================
FILE: camera_pose_annotation/dynamic_mask/sam2/utils/transforms.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import warnings
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.transforms import Normalize, Resize, ToTensor
class SAM2Transforms(nn.Module):
def __init__(
self, resolution, mask_threshold, max_hole_area=0.0, max_sprinkle_area=0.0
):
"""
Transforms for SAM2.
"""
super().__init__()
self.resolution = resolution
self.mask_threshold = mask_threshold
self.max_hole_area = max_hole_area
self.max_sprinkle_area = max_sprinkle_area
self.mean = [0.485, 0.456, 0.406]
self.std = [0.229, 0.224, 0.225]
self.to_tensor = ToTensor()
self.transforms = torch.jit.script(
nn.Sequential(
Resize((self.resolution, self.resolution)),
Normalize(self.mean, self.std),
)
)
def __call__(self, x):
x = self.to_tensor(x)
return self.transforms(x)
def forward_batch(self, img_list):
img_batch = [self.transforms(self.to_tensor(img)) for img in img_list]
img_batch = torch.stack(img_batch, dim=0)
return img_batch
def transform_coords(
self, coords: torch.Tensor, normalize=False, orig_hw=None
) -> torch.Tensor:
"""
Expects a torch tensor with length 2 in the last dimension. The coordinates can be in absolute image or
normalized coordinates; if the coords are in absolute image coordinates, `normalize` should be set to True
and the original image size is required.
Returns
Coordinates scaled to the model's input resolution (`self.resolution`), as expected by the SAM2 model.
"""
if normalize:
assert orig_hw is not None
h, w = orig_hw
coords = coords.clone()
coords[..., 0] = coords[..., 0] / w
coords[..., 1] = coords[..., 1] / h
coords = coords * self.resolution # unnormalize coords
return coords
def transform_boxes(
self, boxes: torch.Tensor, normalize=False, orig_hw=None
) -> torch.Tensor:
"""
Expects a tensor of shape Bx4. The coordinates can be in absolute image or normalized coordinates;
if the coords are in absolute image coordinates, `normalize` should be set to True and the original image size is required.
"""
boxes = self.transform_coords(boxes.reshape(-1, 2, 2), normalize, orig_hw)
return boxes
def postprocess_masks(self, masks: torch.Tensor, orig_hw) -> torch.Tensor:
"""
Perform PostProcessing on output masks.
"""
from sam2.utils.misc import get_connected_components
masks = masks.float()
input_masks = masks
mask_flat = masks.flatten(0, 1).unsqueeze(1) # flatten as 1-channel image
try:
if self.max_hole_area > 0:
# Holes are those connected components in background with area <= self.max_hole_area
# (background regions are those with mask scores <= self.mask_threshold)
labels, areas = get_connected_components(
mask_flat <= self.mask_threshold
)
is_hole = (labels > 0) & (areas <= self.max_hole_area)
is_hole = is_hole.reshape_as(masks)
# We fill holes with a small positive mask score (10.0) to change them to foreground.
masks = torch.where(is_hole, self.mask_threshold + 10.0, masks)
if self.max_sprinkle_area > 0:
labels, areas = get_connected_components(
mask_flat > self.mask_threshold
)
is_hole = (labels > 0) & (areas <= self.max_sprinkle_area)
is_hole = is_hole.reshape_as(masks)
# We fill holes with negative mask score (-10.0) to change them to background.
masks = torch.where(is_hole, self.mask_threshold - 10.0, masks)
except Exception as e:
# Skip the post-processing step if the CUDA kernel fails
warnings.warn(
f"{e}\n\nSkipping the post-processing step due to the error above. You can "
"still use SAM 2 and it's OK to ignore the error above, although some post-processing "
"functionality may be limited (which doesn't affect the results in most cases; see "
"https://github.com/facebookresearch/sam2/blob/main/INSTALL.md).",
category=UserWarning,
stacklevel=2,
)
masks = input_masks
masks = F.interpolate(masks, orig_hw, mode="bilinear", align_corners=False)
return masks
================================================
FILE: caption/LLM/__init__.py
================================================
================================================
FILE: caption/LLM/inference.py
================================================
import os
import time
import queue
from argparse import ArgumentParser
from multiprocessing import Manager
from concurrent.futures import ThreadPoolExecutor
from tqdm import tqdm
import concurrent
import pandas as pd
import numpy as np
import sys
sys.path.append(os.path.abspath(os.path.join(__file__, "../..")))
from utils.api_call import api_call
def get_pose(pose_dir):
"""
Retrieve and process pose data from extrinsics.npy file
"""
# Base directory for pose data
pose_path = os.path.join(pose_dir, 'extrinsics.npy')
assert os.path.isfile(pose_path), f"Pose file not found: {pose_path}"
# Load and process the pose file
poses = np.load(pose_path)
# Data processing steps
poses = poses[::5, :, 3]  # Take every 5th pose and keep its translation (camera position) column
max_value = np.max(poses)
min_value = np.min(poses)
min_abs_value = np.min(np.abs(poses))
# Normalize and convert to integers (minimize integer digits)
poses = np.round(poses / (max_value - min_value) /
min_abs_value).astype(int)
# Keep only first 3 columns and transpose
poses = poses[:, :3].T
# Extract individual axes
poses1, poses2, poses3 = poses[0], poses[1], poses[2]
# Convert each axis to string
poses1_str = ' '.join(map(str, poses1))
poses2_str = ' '.join(map(str, poses2))
poses3_str = ' '.join(map(str, poses3))
# Combine into formatted string
poses_str = f'x:{poses1_str}\ny:{poses2_str}\nz:{poses3_str}'
return poses_str
def get_prompt(pose_dir, prompt_dir, vqa_caption, dist_level):
"""
Construct a prompt by combining content from prompt1.txt, prompt2.txt, VQA caption, and pose data
"""
# Read prompt components
p1_file = os.path.join(prompt_dir, 'prompt1.txt')
p2_file = os.path.join(prompt_dir, 'prompt2.txt')
with open(p1_file, 'r', encoding='utf-8') as f:
p1_content = f.read().strip()
with open(p2_file, 'r', encoding='utf-8') as f:
p2_content = f.read().strip()
# Get pose data
poses = get_pose(pose_dir)
# Assemble final prompt
prompt = (f"{p1_content}\nGiven Information:\n{vqa_caption}\n3.Camera Position Data:\n{poses}\n"
f"\n4.Motion intensity:\n{dist_level}\n{p2_content}")
return prompt
def process_single_row(args, row):
"""
Process a single row of data by calling API and saving the result
"""
# Check if VQA file exists
vqa_path = os.path.join(args.vqa_path, f"{row['id']}.txt")
assert os.path.isfile(vqa_path), f"VQA file not found: {vqa_path}"
# Read VQA caption
with open(vqa_path, "r") as f:
vqa_caption = f.read()
# Skip processing if file already exists
save_file = os.path.join(args.llm_path, f"{row['id']}.txt")
if os.path.exists(save_file) and os.path.getsize(save_file) > 0:
return
# Call API with retry mechanism
pose_dir = os.path.join(args.pose_load_dir, row["id"], "reconstructions")
prompt_text = get_prompt(pose_dir, args.prompt_dir,
vqa_caption, row["distLevel"])
llm_caption = api_call(prompt_text, args.model,
args.api_key, args.base_domain)
assert llm_caption is not None, f"API call failed for id {row['id']}"
# Save the result with model information
with open(save_file, 'w', encoding='utf-8') as f:
f.write(llm_caption + f"\n\n6. Qwen model: \n{args.model}")
return
def worker(args, task_queue, pbar):
"""
Worker function to process tasks from the queue
Args:
task_queue: Queue containing tasks to process
pbar: Progress bar object for tracking progress
"""
while True:
try:
index, row = task_queue.get(timeout=1)
except queue.Empty:
break
# Add delay to prevent overwhelming API
time.sleep(args.wait_time)
# Process the single row
process_single_row(args, row)
# Update progress
task_queue.task_done()
pbar.update(1)
def parse_args():
"""
Parse command line arguments
Returns:
Parsed arguments object
"""
parser = ArgumentParser(description='VQA Processing Program')
parser.add_argument('--csv_path', type=str, required=True,
help='Path to CSV file')
parser.add_argument('--pose_load_dir', type=str, required=True,
help='Directory to load pose data')
parser.add_argument('--output_dir', type=str, required=True,
help='Directory to save results')
parser.add_argument('--prompt_dir', type=str,
default=os.path.dirname(__file__),
help='Directory containing the prompt templates (prompt1.txt, prompt2.txt)')
parser.add_argument('--model', type=str, default="qwen3-30b-a3b",
help='Model name')
parser.add_argument('--api_key', type=str,
default="sk-****",
help='API key')
parser.add_argument('--num_workers', type=int, default=1,
help='Number of worker threads')
parser.add_argument('--wait_time', type=float, default=0.5,
help='Time between requests in seconds')
parser.add_argument('--base_domain', type=str, default="https://cn2us02.opapi.win/",
help='API base domain')
return parser.parse_args()
def main():
"""
Main processing function that handles multiple rows using parallel workers
Args:
group_id (str): Identifier for the group
prompt_dir (str): Directory containing prompt files
model_file (str): Path to file containing model names
api_key_file (str): Path to file containing API keys
num_workers (int): Number of worker threads
wait_time (float): Time to wait between requests
base_domain (str): Base domain for API calls
record_time (bool): Whether to record processing time
Returns:
None
"""
args = parse_args()
# Create the LLM output directory if it doesn't exist
args.llm_path = os.path.join(args.output_dir, "LLM")
if not os.path.isdir(args.llm_path):
os.makedirs(args.llm_path, exist_ok=True)
# Validate VQA directory exists
args.vqa_path = os.path.join(args.output_dir, "VQA")
assert os.path.isdir(
args.vqa_path), f"VQA directory not found: {args.vqa_path}"
# Read CSV file containing scene information
df = pd.read_csv(args.csv_path)
# Initialize task queue with all rows
manager = Manager()
task_queue = manager.Queue()
for index, row in df.iterrows():
task_queue.put((index, row))
# Start processing with progress bar
with tqdm(total=len(df), desc="LLM Finished") as pbar:
with ThreadPoolExecutor(max_workers=args.num_workers) as executor:
# Start worker threads
futures = [executor.submit(worker, args, task_queue, pbar)
for _ in range(args.num_workers)]
# Wait for all workers to complete
for future in concurrent.futures.as_completed(futures):
future.result()
if __name__ == "__main__":
main()
================================================
FILE: caption/LLM/prompt1.txt
================================================
You are given a video sequence with camera trajectory data representing the camera's movement through a scene.
The data consists of:
Camera Motion Caption: A basic description of how the camera moves.
Scene Description: A detailed visual summary of the environment.
Camera position data: Three lines, representing the sequence of the camera's x-coordinate, y-coordinate, and z-coordinate. These values are derived from normalized 3D pose data using the following formula: poses = np.round(poses / (max_value - min_value) / min_abs_value).astype(int); Each value is then multiplied by 1,000,000 and rounded to the nearest integer.
Motion intensity: An integer that indicates the level of camera movement, where a value of 0 means the camera is static, 1 indicates slight movement, and 2 or higher represents normal or noticeable motion. In tasks such as Optimized Camera Motion Caption and Main Motion Trend Summary, this intensity value should be used to qualify the degree of motion described — for example, using "slight forward translate" when the intensity is 1.
Your Tasks:
1. Optimized Camera Motion Caption
Generate a refined motion caption **from the perspective of the camera itself**, using only the **camera position data** to determine movement direction and dynamics.
Use the following rules to interpret motion:
x increasing: camera moves right
x decreasing: camera moves left
y increasing: camera moves down
y decreasing: camera moves up
z increasing: camera moves forward
z decreasing: camera moves backward
Analyze the full trajectory over time to capture acceleration, deceleration, or steady motion. Integrate scene context but prioritize accuracy based on numerical data. Avoid vague phrases like "zoom out" unless it's clearly due to focal length change — here, use translation terms instead.
If motion intensity is 0, describe the fixed viewpoint and what the camera observes from that vantage point, incorporating compositional or environmental elements from the original caption. If intensity is 1, reflect subtle movement in the description (e.g., "slight right translate") without exaggerating the motion. For both cases, preserve visual context while aligning with the actual movement level. Avoid mentioning data analysis or detection explicitly — let the description itself reflect the motion state.
Target Length: 50–100 words
2. Scene Abstract Caption
Provide a single-sentence summary that captures:
- Key architectural elements
- Overall atmosphere/style
- Notable design features
Target Length: About 50 words
3. Main Motion Trend Summary
Summarize the general movement using only 1–3 short motion phrases , depending on how many are clearly present. Focus strictly on major, sustained movements — ignore minor fluctuations or brief directional changes. If only one or two movements dominate, list only those. Use directional translation terms (e.g., forward translate, left translate, upward drift)
4. Scene Keywords
Extract up to 4 keywords summarizing the key aspects of the scene. Include one term that broadly describes the scene type. Use nouns/noun phrases related to weather, place, time, lighting, scene type. Avoid adjectives/gerunds except for weather. Example: sunset, foggy, marketplace, city street, village
5. Immersive Shot Summary
Blend Optimized Camera Motion Caption and Scene Description evenly — do not focus more on the camera or the scene alone. Describe the visuals as if someone is watching a moving image unfold. Use descriptive, cinematic language that evokes imagery and emotion. Keep it concise but expressive — suitable for use in scripts, storyboards, or AI video/image generation. Target Length: 50–100 words
================================================
FILE: caption/LLM/prompt2.txt
================================================
Output Format:
1. Camera Motion Caption:
[From the perspective of the camera holder, with the camera as the subject. Combine camera pose information to describe]
2. Scene Abstract Caption:
[A concise one sentence summary of the scene]
3. Main Motion Trend Summary:
[keywords separated by commas, e.g., forward translate, downward tilt]
4. Scene Keywords:
[word1, word2, word3, ...] (max 5 words)
5. Immersive Shot Summary:
================================================
FILE: caption/README.md
================================================
# Semantic Information Annotation
This script automates the process of generating structured text descriptions (captions) for videos through a multi-step pipeline involving Visual Question Answering (VQA), Large Language Models (LLM), result combination, and tagging.
## Captioning Workflow
The video captioning process follows these sequential steps:
1. **VQA Captioning**: Uses a Visual Question Answering model to analyze visual content and generate initial captions based on predefined prompts.
2. **LLM Captioning**: Employs a Large Language Model to process pose data and generate additional descriptive captions.
3. **Result Combination**: Merges the outputs from the VQA and LLM steps into a unified structure.
4. **Tag Addition**: Enhances the combined results with relevant tags using a language model.
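For readers who prefer to drive these stages from Python rather than through `caption_pipeline.sh`, the sketch below shows how the first two steps could be launched with the arguments documented in the next section. It is a minimal illustration only: the paths and API key are placeholders, and the scripts for the combination and tagging steps are not shown here, so only the VQA and LLM stages are included.
```python
# Minimal sketch: launch the VQA and LLM captioning steps from Python.
# All paths and the API key below are placeholders; adjust them to your setup.
import subprocess

CSV = "/path/to/annotation_results.csv"   # result CSV from the annotation step
SRC_DIR = "/path/to/annotation_output"    # per-clip folders with img/ and reconstructions/
OUTPUT_DIR = "/path/to/caption_output"
API_KEY = "sk-****"

# Step 1: VQA captioning on sampled key frames
subprocess.run([
    "python", "caption/VQA/inference.py",
    "--csv_path", CSV,
    "--fig_load_dir", SRC_DIR,
    "--output_dir", OUTPUT_DIR,
    "--api_key", API_KEY,
    "--num_workers", "4",
], check=True)

# Step 2: LLM captioning from the camera poses and the VQA captions
subprocess.run([
    "python", "caption/LLM/inference.py",
    "--csv_path", CSV,
    "--pose_load_dir", SRC_DIR,
    "--output_dir", OUTPUT_DIR,
    "--api_key", API_KEY,
], check=True)
```
Both steps skip clips whose output file already exists and is non-empty, so the sketch can be re-run to resume an interrupted pass.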
## Script Explanation
### Configuration Parameters
- `CSV`: Path to the result CSV file generated in the annotation step
- `SRC_DIR`: Path to the annotation output directory containing video frames and pose data
- `OUTPUT_DIR`: Path where all output files will be saved
- `num_workers`: Number of parallel workers to use for processing
- `wait_time`: Waiting time between API requests (in seconds)
### Step 1: VQA Captioning
Generates captions by analyzing visual content using a VQA model.
Parameters:
- `--csv_path`: Path to the input CSV file
- `--fig_load_dir`: Directory containing video frames/images
- `--output_dir`: Directory to save VQA results
- `--prompt_file`: Path to VQA prompt template file
- `--model`: VQA model to use (default: gemini-2.0-flash)
- `--api_key`: API key for accessing the VQA model service
- `--base_domain`: API endpoint domain for the VQA model
- `--num_workers`: Number of parallel workers
- `--wait_time`: Waiting time between API requests
### Step 2: LLM Captioning
Generates additional captions by processing pose data using a Large Language Model.
Parameters:
- `--csv_path`: Path to the input CSV file
- `--pose_load_dir`: Directory containing pose data
- `--output_dir`: Directory to save LLM results
- `--prompt_dir`: Directory containing LLM prompt templates
- `--model`: LLM model to use (default: qwen3-30b-a3b)
- `--api_key`: API key for accessing the LLM service
- `--base_domain`: API endpoint domain for the LLM
- `--num_workers`: Number of parallel workers
- `--wait_time`: Waiting time between API requests
### Step 3: Combine Results
Merges the outputs from VQA and LLM steps into a unified format.
Parameters:
- `--csv_path`: Path to the input CSV file
- `--load_dir`: Directory containing VQA and LLM results
- `--output_dir`: Directory to save combined results
- `--num_workers`: Number of parallel workers
### Step 4: Add Tags
Enhances the combined results with relevant tags using a language model.
Parameters:
- `--csv_path`: Path to the input CSV file
- `--json_load_dir`: Directory containing combined results
- `--prompt_file`: Path to tagging prompt template file
- `--model`: Model to use for tagging (default: qwen3-30b-a3b)
- `--api_key`: API key for accessing the tagging model service
- `--base_domain`: API endpoint domain for the tagging model
- `--num_workers`: Number of parallel workers
- `--wait_time`: Waiting time between API requests
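Steps 1, 2, and 4 all send their requests through the shared `utils/api_call.py` helper, which the inference scripts invoke as `api_call(content, model, api_key, base_domain)`. The snippet below is a minimal, illustrative single request that mirrors how `caption/VQA/inference.py` assembles its multimodal message; the frame path, prompt text, and API key are placeholders, and it assumes it is run from the repository root.
```python
# Minimal sketch of one captioning request, mirroring how caption/VQA/inference.py
# builds its multimodal message and hands it to the shared api_call helper.
# The frame path, prompt text, and API key are placeholders; run from the repo root.
import sys

sys.path.append("caption")  # so `utils.api_call` and `VQA.inference` resolve as in the scripts

from VQA.inference import encode_image   # resizes a frame to 640x360 and base64-encodes it
from utils.api_call import api_call

content = [
    {"type": "image_url", "image_url": {"url": encode_image("/path/to/frame_00000.jpg")}},
    {"type": "text", "text": "Describe the scene and the apparent camera motion."},
]
caption = api_call(content, "gemini-2.0-flash", "sk-****", "https://cn2us02.opapi.win/")
print(caption)
```
The LLM step uses the same helper but passes a plain text prompt instead of an image list.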
## Usage
1. Replace all placeholder values (enclosed in square brackets) with your actual paths and API keys
2. Make the script executable: `chmod +x caption_pipeline.sh`
3. Run the script: `./caption_pipeline.sh`
The script will execute each step sequentially, displaying start/end times and duration for each step, and save all outputs to the specified `OUTPUT_DIR`.
## Results Example
Below are several samples of video captions generated by the model after each step.
================================================
FILE: caption/VQA/__init__.py
================================================
================================================
FILE: caption/VQA/inference.py
================================================
import os
import concurrent.futures
from multiprocessing import Manager
import queue
import pandas as pd
from tqdm import tqdm
import argparse
import time
import base64
import cv2
from glob import glob
import sys
sys.path.append(os.path.abspath(os.path.join(__file__, "../..")))
from utils.api_call import api_call
def encode_image(image_path):
"""
Resizes an image to 640x360 and encodes it as a Base64 string with data URI prefix.
"""
# Read image using OpenCV
image = cv2.imread(image_path)
# Resize image to standard dimensions (640x360)
resized_image = cv2.resize(image, (640, 360))
# Encode image as JPEG and convert to Base64
_, buffer = cv2.imencode('.jpeg', resized_image)
base64_data = base64.b64encode(buffer).decode("utf-8")
# Return with data URI format for API compatibility
return f"data:image/jpeg;base64,{base64_data}"
def get_prompt(fig_dir, prompt_text):
"""
    Load key frames of a video clip from a directory and construct the multimodal message content for the API request.
"""
# Get frames from directory
frames = sorted(glob(f"{fig_dir}/*.jpg"))[::5]
# Construct multimodal input content
messages_content = []
# Add encoded images to request content
for frame in frames:
try:
encoded_frame = encode_image(frame)
messages_content.append({
"type": "image_url",
"image_url": {"url": encoded_frame}
})
except Exception as e:
print(f"Image processing error: {str(e)}")
return None
# Add text prompt to request content
messages_content.append({"type": "text", "text": prompt_text})
return messages_content
def process_single_row(args, row):
"""
Process a single row: call the VQA API and save the result for one scene.
Handles retries and error logging.
"""
save_path = os.path.join(args.output_dir, "VQA")
if not os.path.isdir(save_path):
os.makedirs(save_path, exist_ok=True)
save_file = os.path.join(save_path, f"{row['id']}.txt")
if os.path.exists(save_file) and os.path.getsize(save_file) > 0:
# Skip if already exists
return
# Call API
fig_dir = os.path.join(args.fig_load_dir, row['id'], "img")
prompt_text = get_prompt(fig_dir, args.prompt_text)
vqa_caption = api_call(prompt_text, args.model,
args.api_key, args.base_domain)
assert vqa_caption is not None, f"API call failed for id {row['id']}"
# Save result
with open(save_file, 'w', encoding='utf-8') as f:
f.write(vqa_caption)
def worker(args, task_queue, pbar):
while True:
try:
index, row = task_queue.get(timeout=1)
except queue.Empty:
break
time.sleep(args.wait_time)
process_single_row(args, row)
task_queue.task_done()
pbar.update(1)
def parse_args():
"""
Parse command line arguments for VQA batch processing.
"""
parser = argparse.ArgumentParser(description='VQA batch processing script')
parser.add_argument('--csv_path', type=str, required=True,
help='CSV file path')
parser.add_argument('--fig_load_dir', type=str, required=True,
help='Directory to load figures')
parser.add_argument('--output_dir', type=str, required=True,
help='Directory to save results')
parser.add_argument('--prompt_file', type=str,
default="vqa_prompt.txt",
help='Prompt file path')
parser.add_argument('--model', type=str, default="gemini-2.0-flash",
help='Model name')
parser.add_argument('--api_key', type=str,
default="sk-****",
help='API key')
parser.add_argument('--num_workers', type=int, default=4,
help='Number of worker threads')
parser.add_argument('--wait_time', type=float, default=0.8,
help='Request interval for each thread (seconds)')
parser.add_argument('--base_domain', type=str, default="https://cn2us02.opapi.win/",
help='API base domain')
return parser.parse_args()
def main():
"""
Batch process all scenes in a group: call VQA API for each row in the CSV.
Uses a thread pool for concurrency and supports timing.
"""
args = parse_args()
df = pd.read_csv(args.csv_path)
# Read prompt text
with open(args.prompt_file, "r", encoding="utf-8") as f:
args.prompt_text = f.read().strip()
manager = Manager()
task_queue = manager.Queue()
for index, row in df.iterrows():
task_queue.put((index, row))
with tqdm(total=len(df), desc=f"VQA Finished") as pbar:
with concurrent.futures.ThreadPoolExecutor(max_workers=args.num_workers) as executor:
futures = []
for worker_id in range(args.num_workers):
futures.append(executor.submit(worker, args, task_queue, pbar))
for future in concurrent.futures.as_completed(futures):
future.result()
if __name__ == "__main__":
main()
================================================
FILE: caption/VQA/prompt.txt
================================================
You are given a sequence of video frames in chronological order. Analyze them carefully and generate two distinct captions based on the following instructions:
1. Camera Motion Caption:
From the perspective of the camera operator, describe the entire motion trajectory of the camera throughout the clip using precise cinematography terminology (e.g., static, pan, tilt, dolly, handheld, crane, aerial, zoom, etc.).
Do NOT assume the camera starts in a "static" position just because it appears stationary in the first frame. Only describe the camera as stationary if there is no visual change across multiple consecutive frames.
Instead, focus on changes between frames to infer movement. Describe motion state transitions, not frame-by-frame repetition (e.g., do not say “the camera moves forward again” if it’s continuous). For example:
- Starting with a dolly forward along a straight path,
- Then transitioning into a slow right-hand pan,
- Or shifting from handheld walking movement to a stationary pivot tilt.
Include brief environmental context where relevant to clarify direction or intent (e.g., "The camera dollies forward through a narrow alleyway, then smoothly turns left at the intersection").
Keep the final caption concise, between 50–100 words, focused only on motion and its evolution over time.
2. Scene Description:
Provide a rich, holistic description of the visual content. Include:
- Main subjects and dynamic objects: who or what is present, and what they are doing (e.g., a cyclist rides past from left to right, a group of people gather near a bench),
- Background/environment: setting (urban street, forest trail, indoor space), notable landmarks or structures,
- Lighting and atmosphere: time of day, weather conditions, mood (e.g., golden-hour lighting, overcast sky casting soft shadows, neon-lit nighttime scene),
- Overall tone or emotion conveyed by the scene.
Avoid focusing on individual frames—describe the general impression and ongoing activity across the entire clip. Aim for around 100 words, balancing detail and conciseness.
Output Format:
Do not include any explanations or extra text before or after your response.
Begin directly with:
1. Camera Motion Caption: ...
followed by
2. Scene Description: ...
================================================
FILE: caption/__init__.py
================================================
================================================
FILE: caption/tagging/__init__.py
================================================
================================================
FILE: caption/tagging/inference.py
================================================
import os
import time
import json
import queue
import argparse
import pandas as pd
from tqdm import tqdm
from multiprocessing import Manager
import concurrent.futures
import sys
sys.path.append(os.path.abspath(os.path.join(__file__, "../..")))
from utils.api_call import api_call
def parse_category_tags(tag_caption):
"""
Parse API response to structured category data using camelCase naming convention
"""
lines = [line.strip()
for line in tag_caption.strip().split('\n') if line.strip()]
# Initialize category data with default values
category_data = {
"sceneType": {
"first": "Unknown",
"second": "Unknown"
},
"lighting": "Unknown",
"timeOfDay": "Unknown",
"weather": "Unknown",
"crowdDensity": "Unknown"
}
# Parse each line to extract category information
for line in lines:
line_lower = line.lower()
if line_lower.startswith("primary scene type:"):
category_data["sceneType"]["first"] = line.split(":", 1)[1].strip()
elif line_lower.startswith("secondary scene type:"):
category_data["sceneType"]["second"] = line.split(":", 1)[
1].strip()
elif line_lower.startswith("lighting:"):
category_data["lighting"] = line.split(":", 1)[1].strip()
elif line_lower.startswith("time of day:"):
category_data["timeOfDay"] = line.split(":", 1)[1].strip()
elif line_lower.startswith("weather:"):
category_data["weather"] = line.split(":", 1)[1].strip()
elif line_lower.startswith("crowd density:"):
category_data["crowdDensity"] = line.split(":", 1)[1].strip()
return category_data
def process_single_row(args, json_file):
"""
Process a single JSON file to add category tags via API call
"""
# Check if CategoryTag field already exists
with open(json_file, 'r') as f:
data = json.load(f)
# Skip if CategoryTag already exists
if "CategoryTag" in data:
return
description = data['SceneDesc']
prompt_text = args.prompt_text + description
# Call API to get category tags with retry mechanism
tag_caption = api_call(prompt_text, args.model,
args.api_key, args.base_domain)
assert tag_caption is not None, f"API call failed for file {json_file}"
# Parse and add category tags to the JSON file
category_tag = parse_category_tags(tag_caption)
# Merge new data with existing data
data["CategoryTag"] = category_tag
# Overwrite file with updated data
with open(json_file, 'w', encoding='utf-8') as f:
json.dump(data, f, ensure_ascii=False, indent=2)
def worker(args, task_queue, pbar):
while True:
try:
index, json_file = task_queue.get(timeout=1)
except queue.Empty:
break
time.sleep(args.wait_time)
process_single_row(args, json_file)
task_queue.task_done()
pbar.update(1)
def parse_args():
"""Parse command line arguments"""
parser = argparse.ArgumentParser(
description='Category Tag Processing Program')
parser.add_argument('--csv_path', type=str, required=True,
help='Path to the CSV file')
parser.add_argument('--json_load_dir', type=str, required=True,
help='Directory containing JSON files')
parser.add_argument('--prompt_file', type=str,
default="prompt.txt",
help='Path to prompt file')
parser.add_argument('--model', type=str, default="qwen3-30b-a3b",
help='Model name')
parser.add_argument('--api_key', type=str,
default="sk-****",
help='API key')
parser.add_argument('--num_workers', type=int, default=4,
help='Number of worker threads')
parser.add_argument('--wait_time', type=float, default=0.8,
help='Time interval between requests per thread (seconds)')
parser.add_argument('--base_domain', type=str, default="https://cn2us02.opapi.win/",
help='API base domain')
return parser.parse_args()
def main():
"""
Process a group of JSON files using multiple threads to add category tags
"""
args = parse_args()
df = pd.read_csv(args.csv_path)
with open(args.prompt_file, 'r', encoding='utf-8') as f:
args.prompt_text = f.read().strip()
# Initialize task queue and add all files to process
manager = Manager()
task_queue = manager.Queue()
for index, row in df.iterrows():
clip_id = row['id']
json_file = os.path.join(args.json_load_dir, f"{clip_id}.json")
task_queue.put((index, json_file))
# Start processing with progress bar
with tqdm(total=task_queue.qsize(), desc="Tags Completed") as pbar:
with concurrent.futures.ThreadPoolExecutor(max_workers=args.num_workers) as executor:
# Start worker threads
futures = [executor.submit(worker, args, task_queue, pbar)
for _ in range(args.num_workers)]
# Wait for all workers to complete
for future in concurrent.futures.as_completed(futures):
future.result()
if __name__ == "__main__":
main()
================================================
FILE: caption/tagging/prompt.txt
================================================
You are an AI assistant specialized in analyzing scene descriptions and extracting structured metadata. Your task is to read the provided scene description and infer six attributes with hierarchical classification where applicable.
Output Rules:
1. Scene Type (Choose one from):
- [Urban, Natural Landscape, Interior, Rural, Waterfront, Unknown]
- Add a custom secondary tag (unrestricted) to further define the scene (e.g., Urban → "Street Scene", Interior → "Library")
2. Lighting: [Bright / Dim/Dark / Unknown]
3. Time of Day: [Dawn/Morning / Daytime / Dusk/Evening / Night / Unknown]
4. Weather: [Sunny / Rainy / Foggy / Cloudy / Snowy / Unknown]
5. Crowd Density: [Deserted / Sparse / Moderate / Crowded / Unknown]
Deduction Guidelines:
- Prioritize explicit descriptors over implied cues (e.g., "snow scattered" → Snowy; "wet surfaces + cloudy" → Cloudy)
- If no evidence exists for any attribute, output 'Unknown' for that field. Maintain strict objectivity - never assume information beyond the text.
Output Format (strictly follow line breaks):
Primary Scene Type: [X]
Secondary Scene Type: [CustomTag]
Lighting: [X]
Time of Day: [X]
Weather: [X]
Crowd Density: [X]
The following is a scene description:
================================================
FILE: caption/utils/__init__.py
================================================
================================================
FILE: caption/utils/api_call.py
================================================
import requests
def api_call(prompt_text, model, api_key, base_domain):
"""
Make an API call to a language model with a constructed prompt,
handling different API formats for different model providers.
"""
# Determine if using Qwen model API (Aliyun)
is_qwen = "dashscope.aliyuncs.com" in base_domain
# Configure API endpoint and payload based on model type
if is_qwen:
api_url = base_domain + "v1/chat/completions"
# Payload format specific to Qwen model
payload = {
"model": model,
"messages": [
# {"role": "system", "content": "You are a helpful assistant."}, # Optional system message
{"role": "user", "content": prompt_text}
],
"enable_thinking": False,
"temperature": 0.1 # Low temperature for more deterministic output
}
else:
# Payload format for other models
api_url = base_domain + "v1beta/openai/"
payload = {
"model": model,
"messages": [
{"role": "user", "content": prompt_text}
],
"temperature": 0.1,
"user": "User"
}
# Configure request headers
if is_qwen:
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {api_key}",
"Accept": "application/json"
}
else:
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {api_key}",
"User-Agent": f"({base_domain})",
"Accept": "application/json"
}
try:
# Execute API request with timeout
response = requests.post(
api_url,
headers=headers,
json=payload,
timeout=120 # 2-minute timeout
)
response.raise_for_status() # Raise exception for HTTP errors
response_data = response.json()
# Optional: Uncomment to log token usage
# if 'usage' in response_data:
# usage = response_data.get('usage', {})
# prompt_tokens = usage.get('prompt_tokens', 0)
# completion_tokens = usage.get('completion_tokens', 0)
# total_tokens = usage.get('total_tokens', 0)
#
# print(f"Input tokens: {prompt_tokens}")
# print(f"Output tokens: {completion_tokens}")
# print(f"Total tokens: {total_tokens}")
# else:
# print("API response does not contain token usage information")
        # Extract and return the response content (same structure for both API formats)
        return response_data.get("choices", [{}])[0].get("message", {}).get("content", "")
except Exception as e:
print(f"API request error: {str(e)}")
return None
================================================
FILE: caption/utils/combine.py
================================================
import os
import json
import re
import queue
import argparse
import pandas as pd
from multiprocessing import Manager
from concurrent.futures import ThreadPoolExecutor, as_completed
from tqdm import tqdm
def parse_text_to_json(text):
"""
Parses text in a specific format into a JSON structure.
"""
# Define mapping between text labels and JSON keys
labels = {
"Camera Motion Caption": "OptCamMotion",
"Scene Abstract Caption": "SceneSummary",
"Main Motion Trend Summary": "MotionTrends",
"Scene Keywords": "SceneTags",
"Immersive Shot Summary": "ShotImmersion",
"Qwen model": "LLM"
}
# Initialize result dictionary with empty values
result = {key: "" for key in labels.values()}
current_label = None
current_content = []
# Process text line by line
lines = text.split('\n')
i = 0
while i < len(lines):
line = lines[i].strip()
# Check if current line contains any label
for label, json_key in labels.items():
if label in line:
# Find position of first letter after the label
start_pos = line.find(label) + len(label)
# Skip non-alphabet characters
while start_pos < len(line) and not line[start_pos].isalpha():
start_pos += 1
content = line[start_pos:].strip()
# Process content if it exists after the label
if content:
if json_key in ["MotionTrends", "SceneTags"]:
# Split by commas (both Chinese and English), preserve spaces in phrases
items = re.split(r'[,,]\s*', content)
result[json_key] = [item.strip()
for item in items if item.strip()]
else:
result[json_key] = content
current_label = None
current_content = []
else:
# No content after label, continue reading subsequent lines
current_label = json_key
current_content = []
break
else:
# If collecting content for a label
if current_label:
# Check if line is not empty
if line:
# Find first letter in line
start_pos = 0
while start_pos < len(line) and not line[start_pos].isalpha():
start_pos += 1
if start_pos < len(line):
current_content.append(line[start_pos:])
else:
# Empty line indicates end of current label content
content = ' '.join(current_content).strip()
if current_label in ["MotionTrends", "SceneTags"]:
# Split by commas (both Chinese and English), preserve spaces in phrases
items = re.split(r'[,,]\s*', content)
result[current_label] = [item.strip()
for item in items if item.strip()]
else:
result[current_label] = content
current_label = None
current_content = []
i += 1
# Handle Qwen model label which might extend to the end
if current_label == "LLM" and current_content:
content = ' '.join(current_content).strip()
result[current_label] = content
return result
def vqa_parse_text_to_json(text):
"""
Parses text containing Camera Motion Caption and Scene Description into JSON format.
"""
result = {
"CamMotion": "",
"SceneDesc": ""
}
# Process Camera Motion Caption - from first letter after label to newline
camera_pattern = r'Camera Motion Caption:\s*(\w[\s\S]*?)(?=\n|$)'
camera_match = re.search(camera_pattern, text)
if camera_match:
result["CamMotion"] = camera_match.group(1).strip()
# Process Scene Description - from first letter after label to end of text
scene_pattern = r'Scene Description:\s*(\w[\s\S]*)$'
scene_match = re.search(scene_pattern, text, re.DOTALL)
if scene_match:
result["SceneDesc"] = scene_match.group(1).strip()
return result
def process_single_row(args, clip_id):
"""
Processes VQA and LLM captions for a single clip and merges them into one JSON file.
"""
# Define file paths
vqa_path = os.path.join(args.load_dir, "VQA", f"{clip_id}.txt")
assert os.path.exists(vqa_path), f"VQA path does not exist: {vqa_path}"
llm_path = os.path.join(args.load_dir, "LLM", f"{clip_id}.txt")
assert os.path.exists(llm_path), f"LLM path does not exist: {llm_path}"
output_path = os.path.join(args.output_dir, f"{clip_id}.json")
# Skip if output file already exists
if os.path.exists(output_path) and os.path.getsize(output_path) > 0:
return
# Read VQA file content
with open(vqa_path, 'r', encoding='utf-8') as f:
vqa_text = f.read()
# Read LLM file content
with open(llm_path, 'r', encoding='utf-8') as f:
llm_text = f.read()
# Parse text content to JSON
vqa_json = vqa_parse_text_to_json(vqa_text)
llm_json = parse_text_to_json(llm_text)
# Merge JSON objects
combined_json = {**vqa_json, **llm_json}
# Save merged JSON to output file
with open(output_path, 'w', encoding='utf-8') as f:
json.dump(combined_json, f, ensure_ascii=False, indent=2)
def worker(args, task_queue, pbar):
while True:
try:
idx, clip_id = task_queue.get(timeout=1)
except queue.Empty:
break
process_single_row(args, clip_id)
task_queue.task_done()
pbar.update(1)
def parse_args():
"""Parse command line arguments."""
parser = argparse.ArgumentParser(
description='Merge VQA and LLM caption data')
parser.add_argument('--csv_path', type=str,
required=True, help='Path to the CSV file')
parser.add_argument('--load_dir', type=str, required=True,
help='Directory containing caption files')
parser.add_argument('--output_dir', type=str, required=True,
help='Directory to save merged JSON files')
parser.add_argument('--num_workers', type=int,
default=32, help='Number of worker threads')
return parser.parse_args()
def main():
"""
Processes all scenes in the specified batch.
"""
args = parse_args()
df = pd.read_csv(args.csv_path)
os.makedirs(args.output_dir, exist_ok=True)
# Use multiprocessing manager for thread-safe queue
manager = Manager()
task_queue = manager.Queue()
# Add tasks to queue
for index, row in df.iterrows():
task_queue.put((index, row['id']))
# Start multi-threaded processing with progress bar
with tqdm(total=len(df), desc="Processing progress") as pbar:
with ThreadPoolExecutor(max_workers=args.num_workers) as executor:
futures = []
for _ in range(args.num_workers):
futures.append(executor.submit(worker, args, task_queue, pbar))
# Wait for all futures to complete
for future in as_completed(futures):
future.result()
if __name__ == "__main__":
main()
================================================
FILE: docker-entrypoint.sh
================================================
#!/usr/bin/env bash
# Simple entrypoint: activate venv if present and run provided command
set -euo pipefail
if [ -f "/workspace/venv/bin/activate" ]; then
echo "Activating venv"
# shellcheck disable=SC1091
source /workspace/venv/bin/activate
fi
if [ "$#" -gt 0 ]; then
exec "$@"
else
exec bash
fi
================================================
FILE: requirements/requirements.txt
================================================
torch==2.7.0 --index-url https://download.pytorch.org/whl/cu126
torchaudio==2.7.0 --index-url https://download.pytorch.org/whl/cu126
torchvision==0.22.0 --index-url https://download.pytorch.org/whl/cu126
opencv-python==4.11.0.86
tqdm==4.67.1
imageio==2.37.0
einops==0.8.1
scipy==1.15.2
matplotlib==3.10.0
ninja==1.11.1.3
numpy==1.26.4
pandas==2.2.3
huggingface_hub
================================================
FILE: requirements/requirements_annotation.txt
================================================
wandb==0.19.8
timm==1.0.15
kornia==0.8.0
xformers==0.0.30
torch_scatter==2.1.2
gradio_imageslider==0.0.20
gradio==4.29.0
# sam2
hydra-core==1.3.2
iopath==0.1.10
OpenEXR
================================================
FILE: requirements/requirements_scoring.txt
================================================
ftfy==6.3.1
diffusers==0.29.0
accelerate==1.4.0
av==14.2.0
scenedetect==0.6.5.2
decord==0.6.0
imageio-ffmpeg==0.6.0
ffmpeg-python==0.2.0
clip @ git+https://github.com/openai/CLIP.git
cpbd==1.0.7
# paddlepaddle-gpu==3.0.0 --index-url https://www.paddlepaddle.org.cn/packages/stable/cu126/
paddleocr==3.0.0
nvidia-nccl-cu12==2.26.2
numpy==1.26.4
================================================
FILE: scoring/README.md
================================================
# Scoring
## Aesthetic Score
To evaluate the aesthetic quality of videos, we use the scoring model from [CLIP+MLP Aesthetic Score Predictor](https://github.com/christophschuhmann/improved-aesthetic-predictor). This model is trained on 176K SAC (Simulacra Aesthetic Captions) pairs, 15K LAION-Logos (Logos) pairs, and 250K AVA (The Aesthetic Visual Analysis) image-text pairs.
The aesthetic score is between 1 and 10, where 5.5 can be considered the threshold for fair aesthetics and 6.5 for high aesthetics. Good text-to-image models can achieve a score of 7.0 or higher.
First, download the scoring model to `./checkpoints/aesthetic.pth`. Skip this step if you have already followed the installation instructions in [README](../README.md).
```bash
wget https://github.com/christophschuhmann/improved-aesthetic-predictor/raw/main/sac+logos+ava1-l14-linearMSE.pth -O checkpoints/aesthetic.pth
```
Then, run the following command to compute aesthetic scores.
```bash
torchrun --nproc_per_node ${GPU_NUM} scoring/aesthetic/inference.py \
${ROOT_META}/clips_info.csv \
--bs 16 \
--num_workers ${NUM_WORKERS} \
--fig_load_dir ${ROOT_FIG}
```
## Luminance Score
Luminance was calculated for the first, middle, and last frames using the standard formula $L = 0.2126 R + 0.7152 G + 0.0722 B$, where $R$, $G$, and $B$ are the respective channel values. Clips with average luminance outside the range [20, 140], either too dark or too bright, were excluded, ensuring that only videos with proper exposure were retained.
Run the following command to compute luminance scores.
```bash
torchrun --nproc_per_node ${GPU_NUM} scoring/luminance/inference.py \
${ROOT_META}/clips_info.csv \
--bs 16 \
--num_workers ${NUM_WORKERS} \
--fig_load_dir ${ROOT_FIG}
```
## Motion Score
Conventional motion analysis using optical flow is computationally expensive and less effective for videos with complex motion patterns. Inspired by Open-Sora 2.0, we adopted a lightweight VMAF-based motion analysis method integrated into FFMPEG. This method yields a motion score between 0 and 20.
Clips with scores outside the valid range of [2, 14], either too static (scores $<$ 2) or excessively chaotic (scores $>$ 14), were filtered out.
Run the following command to compute motion scores.
```bash
python scoring/motion/inference.py ${ROOT_META}/clips_info.csv \
--temp_save_dir ${ROOT_TEMP} \
--num_workers $((GPU_NUM * 4)) \
--gpu_num ${GPU_NUM}
```
## OCR
For text detection, we used the latest release of PaddleOCR, which offers high accuracy and robust multilingual support. We processed the first, middle, and last frames of each clip to detect text regions, computing the ratio of text area to frame size. Clips where the text area exceeded 30% were removed, as these were considered informational rather than visual.
Run the following command to compute OCR scores.
```bash
python scoring/ocr/inference.py ${ROOT_META}/clips_info.csv \
--fig_load_dir ${ROOT_FIG} \
--num_workers $((GPU_NUM * 4)) \
--gpu_num ${GPU_NUM}
```
================================================
FILE: scoring/__init__.py
================================================
================================================
FILE: scoring/aesthetic/__init__.py
================================================
================================================
FILE: scoring/aesthetic/inference.py
================================================
"""
Aesthetic scoring script for video frames using CLIP and MLP models.
Adapted from https://github.com/christophschuhmann/improved-aesthetic-predictor/blob/main/simple_inference.py
Calculates aesthetic scores for video clips using distributed processing.
"""
# adapted from https://github.com/christophschuhmann/improved-aesthetic-predictor/blob/main/simple_inference.py
import argparse
import gc
import os
from glob import glob
from datetime import timedelta
from PIL import Image
import clip
import numpy as np
import pandas as pd
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch.utils.data import DataLoader, DistributedSampler
from tqdm import tqdm
def merge_scores(gathered_list: list, csv: pd.DataFrame, column):
"""Merge aesthetic scores from all distributed processes."""
# Reorder results from all processes
indices_list = list(map(lambda x: x[0], gathered_list))
scores_list = list(map(lambda x: x[1], gathered_list))
flat_indices = []
for x in zip(*indices_list):
flat_indices.extend(x)
flat_scores = []
for x in zip(*scores_list):
flat_scores.extend(x)
flat_indices = np.array(flat_indices)
flat_scores = np.array(flat_scores)
# Filter duplicates from distributed processing
unique_indices, unique_indices_idx = np.unique(flat_indices, return_index=True)
csv.loc[unique_indices, column] = flat_scores[unique_indices_idx]
# Drop indices in csv not in unique_indices
csv = csv.loc[unique_indices]
return csv
class VideoTextDataset(torch.utils.data.Dataset):
"""Dataset for loading video frames for aesthetic scoring."""
def __init__(self, csv_path, fig_load_dir, transform=None):
self.csv_path = csv_path
self.csv = pd.read_csv(csv_path)
self.transform = transform
self.fig_load_dir = fig_load_dir
def __getitem__(self, index):
"""Load and transform video frames for a single sample."""
sample = self.csv.iloc[index]
# Load first 3 frames from video clip
images_dir = os.path.join(self.fig_load_dir, sample["id"])
images = sorted(glob(f"{images_dir}/img/*.jpg"))[:3]
# Apply CLIP preprocessing transforms
images = [self.transform(Image.open(img).convert("RGB")) for img in images]
# Stack images into tensor
images = torch.stack(images)
return dict(index=index, images=images)
def __len__(self):
return len(self.csv)
class MLP(nn.Module):
"""Multi-layer perceptron for aesthetic score prediction."""
def __init__(self, input_size):
super().__init__()
self.input_size = input_size
self.layers = nn.Sequential(
nn.Linear(self.input_size, 1024),
nn.Dropout(0.2),
nn.Linear(1024, 128),
nn.Dropout(0.2),
nn.Linear(128, 64),
nn.Dropout(0.1),
nn.Linear(64, 16),
nn.Linear(16, 1),
)
def forward(self, x):
return self.layers(x)
class AestheticScorer(nn.Module):
"""Combined CLIP + MLP model for aesthetic scoring."""
def __init__(self, input_size, device):
super().__init__()
self.mlp = MLP(input_size)
self.clip, self.preprocess = clip.load("ViT-L/14", device=device)
self.eval()
self.to(device)
def forward(self, x):
"""Extract CLIP features and predict aesthetic scores."""
image_features = self.clip.encode_image(x)
image_features = F.normalize(image_features, p=2, dim=-1).float()
return self.mlp(image_features)
def parse_args():
"""Parse command line arguments for aesthetic scoring."""
parser = argparse.ArgumentParser()
parser.add_argument("--csv_path", type=str, help="Path to the input CSV file")
parser.add_argument(
"--load_num", type=int, default=4, help="Number of frames to load"
)
parser.add_argument("--bs", type=int, default=1024, help="Batch size")
parser.add_argument("--num_workers", type=int, default=16, help="Number of workers")
parser.add_argument(
"--fig_load_dir",
type=str,
required=True,
help="Directory to load the extracted frames",
)
parser.add_argument(
"--prefetch_factor", type=int, default=3, help="Prefetch factor"
)
parser.add_argument("--skip_if_existing", action="store_true")
args = parser.parse_args()
return args
def main():
args = parse_args()
csv_path = args.csv_path
if not os.path.exists(csv_path):
print(f"CSV file '{csv_path}' not found. Exit.")
exit()
wo_ext, ext = os.path.splitext(csv_path)
out_path = f"{wo_ext}_aes{ext}"
if args.skip_if_existing and os.path.exists(out_path):
print(f"Output CSV file '{out_path}' already exists. Exit.")
exit()
# Initialize distributed processing
dist.init_process_group(backend="nccl", timeout=timedelta(hours=24))
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
# Build aesthetic scoring model
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
model = AestheticScorer(768, device)
model.mlp.load_state_dict(
torch.load("checkpoints/aesthetic.pth", map_location=device)
)
preprocess = model.preprocess
# Build dataset and dataloader
dataset = VideoTextDataset(
args.csv_path, transform=preprocess, fig_load_dir=args.fig_load_dir
)
dataloader = DataLoader(
dataset,
batch_size=args.bs,
num_workers=args.num_workers,
sampler=DistributedSampler(
dataset,
num_replicas=dist.get_world_size(),
rank=dist.get_rank(),
shuffle=False,
drop_last=False,
),
)
# Compute aesthetic scores for all batches
indices_list = []
scores_list = []
model.eval()
for batch in tqdm(
dataloader, disable=(dist.get_rank() != 0), position=dist.get_rank()
):
indices = batch["index"]
images = batch["images"].to(device, non_blocking=True)
B = images.shape[0]
images = rearrange(images, "B N C H W -> (B N) C H W")
# Compute aesthetic scores using CLIP + MLP
with torch.no_grad():
scores = model(images)
# Average scores across frames for each video
scores = rearrange(scores, "(B N) 1 -> B N", B=B)
scores = scores.mean(dim=1)
scores_np = scores.to(torch.float32).cpu().numpy()
indices_list.extend(indices.tolist())
scores_list.extend(scores_np.tolist())
# Wait for all ranks to finish data processing
dist.barrier()
# Gather results from all processes and save
torch.cuda.empty_cache()
gc.collect()
gathered_list = [None] * dist.get_world_size()
dist.all_gather_object(gathered_list, (indices_list, scores_list))
if dist.get_rank() == 0:
csv_new = merge_scores(gathered_list, dataset.csv, column="aesthetic score")
csv_new.to_csv(out_path, index=False)
print(f"New csv with aesthetic scores saved to '{out_path}'.")
if __name__ == "__main__":
main()
================================================
FILE: scoring/luminance/__init__.py
================================================
================================================
FILE: scoring/luminance/inference.py
================================================
"""
Luminance analysis script for video frames using distributed processing.
Calculates mean, min, and max luminance scores for video clips using PyTorch distributed computing.
"""
import argparse
import os
import gc
from glob import glob
from datetime import timedelta
from PIL import Image
import numpy as np
import pandas as pd
import torch
import torch.distributed as dist
from torch.utils.data import DataLoader, DistributedSampler
from torchvision.transforms.functional import pil_to_tensor
from tqdm import tqdm
def merge_scores(gathered_list: list, csv: pd.DataFrame):
"""Merge luminance scores from all distributed processes."""
# Reorder results from all processes
indices_list = list(map(lambda x: x[0], gathered_list))
mean_scores_list = list(map(lambda x: x[1], gathered_list))
min_scores_list = list(map(lambda x: x[2], gathered_list))
max_scores_list = list(map(lambda x: x[3], gathered_list))
flat_indices = []
for x in zip(*indices_list):
flat_indices.extend(x)
flat_mean_scores = []
for x in zip(*mean_scores_list):
flat_mean_scores.extend(x)
flat_min_scores = []
for x in zip(*min_scores_list):
flat_min_scores.extend(x)
flat_max_scores = []
for x in zip(*max_scores_list):
flat_max_scores.extend(x)
flat_indices = np.array(flat_indices)
flat_mean_scores = np.array(flat_mean_scores)
flat_min_scores = np.array(flat_min_scores)
flat_max_scores = np.array(flat_max_scores)
# Filter duplicates from distributed processing
unique_indices, unique_indices_idx = np.unique(flat_indices, return_index=True)
csv.loc[unique_indices, "luminance mean"] = flat_mean_scores[unique_indices_idx]
csv.loc[unique_indices, "luminance min"] = flat_min_scores[unique_indices_idx]
csv.loc[unique_indices, "luminance max"] = flat_max_scores[unique_indices_idx]
# Drop indices in csv not in unique_indices
csv = csv.loc[unique_indices]
return csv
class VideoDataset(torch.utils.data.Dataset):
"""Dataset to handle video luminance computation."""
def __init__(self, csv_path, fig_load_dir):
self.csv_path = csv_path
self.csv = pd.read_csv(csv_path)
self.fig_load_dir = fig_load_dir
def __getitem__(self, index):
"""Get video frames and compute luminance for a single sample."""
sample = self.csv.iloc[index]
# Load first 3 frames from video clip
images_dir = os.path.join(self.fig_load_dir, sample["id"])
images = sorted(glob(f"{images_dir}/img/*.jpg"))[:3]
# Transform images to tensors
images = torch.stack(
[pil_to_tensor(Image.open(img).convert("RGB")) for img in images]
)
return {"index": index, "images": images}
def __len__(self):
return len(self.csv)
def parse_args():
"""Parse command line arguments for luminance analysis."""
parser = argparse.ArgumentParser()
parser.add_argument("--csv_path", type=str, help="Path to the input CSV file")
parser.add_argument("--bs", type=int, default=4, help="Batch size")
parser.add_argument("--num_workers", type=int, default=16, help="Number of workers")
parser.add_argument(
"--fig_load_dir",
type=str,
required=True,
help="Directory to load the extracted frames",
)
parser.add_argument("--skip_if_existing", action="store_true")
return parser.parse_args()
def main():
args = parse_args()
csv_path = args.csv_path
if not os.path.exists(csv_path):
        print(f"CSV file '{csv_path}' not found. Exiting.")
return
output_path = csv_path.replace(".csv", "_lum.csv")
if args.skip_if_existing and os.path.exists(output_path):
print(f"Output '{output_path}' already exists. Exiting.")
return
# Initialize distributed processing
dist.init_process_group(backend="nccl", timeout=timedelta(hours=24))
    if torch.cuda.is_available():
        torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
# Setup dataset and distributed dataloader
dataset = VideoDataset(csv_path, fig_load_dir=args.fig_load_dir)
dataloader = DataLoader(
dataset,
batch_size=args.bs,
num_workers=args.num_workers,
sampler=DistributedSampler(
dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank()
),
)
# Process batches and calculate luminance scores
indices_list = []
mean_scores_list = []
max_scores_list = []
min_scores_list = []
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
for batch in tqdm(
dataloader, disable=(dist.get_rank() != 0), position=dist.get_rank()
):
indices = batch["index"]
images = batch["images"].to(device, non_blocking=True) # [B, N, C, H, W]
# Calculate luminance using standard RGB weights
R, G, B = images[:, :, 0], images[:, :, 1], images[:, :, 2]
luminance = 0.2126 * R + 0.7152 * G + 0.0722 * B
scores = luminance.mean(dim=[2, 3])
# Compute statistics across frames
mean_scores = scores.mean(dim=1).cpu().numpy()
max_scores = scores.max(dim=1)[0].cpu().numpy()
min_scores = scores.min(dim=1)[0].cpu().numpy()
indices_list.extend(indices.tolist())
mean_scores_list.extend(mean_scores.tolist())
max_scores_list.extend(max_scores.tolist())
min_scores_list.extend(min_scores.tolist())
# Wait for all ranks to finish data processing
dist.barrier()
# Gather results from all processes and save
torch.cuda.empty_cache()
gc.collect()
gathered_list = [None] * dist.get_world_size()
dist.all_gather_object(
gathered_list,
(indices_list, mean_scores_list, min_scores_list, max_scores_list),
)
if dist.get_rank() == 0:
csv_new = merge_scores(gathered_list, dataset.csv)
csv_new.to_csv(output_path, index=False)
print(f"New csv with luminance scores saved to '{output_path}'")
if __name__ == "__main__":
main()
================================================
FILE: scoring/motion/INSTALL.md
================================================
# Compiling FFmpeg with NVIDIA GPU Acceleration and VMAF on Ubuntu
This guide provides a comprehensive walkthrough for compiling FFmpeg from source on an Ubuntu system equipped with an NVIDIA GPU. The resulting build will support NVIDIA's hardware encoding/decoding (NVENC/DEC), NPP filters (NVIDIA Performance Primitives), and CUDA-based VMAF (Video Multi-Method Assessment Fusion) for video quality assessment.
## Environment and Versions
Before you begin, ensure your system environment is similar to the configuration below. Version matching is crucial for a successful compilation.
The GPU needs to support HEVC; refer to the [NVIDIA NVDEC Support Matrix](https://en.wikipedia.org/wiki/NVIDIA_Video_Coding_Engine#NVDEC).
- **GPU**: NVIDIA GeForce RTX 4090 or other compatible models
- **OS**: Ubuntu 22.04
- **NVIDIA Driver Version**: A version compatible with CUDA 12.6
- **CUDA Version (from `nvidia-smi`)**: `12.x`
- **CUDA Toolkit Version**: `12.6` (This is the version used for compilation)
- **Target FFmpeg Version**: `6.1`
**Key Tip**: The version of the `NVIDIA Codec Headers` (`ffnvcodec`) must be compatible with the `CUDA Toolkit` version installed on your system and the version of `FFmpeg` you intend to compile.
## Compilation Steps
Please follow these steps in order.
### Step 1: Install System Dependencies
Update system packages and install required development tools and libraries:
```bash
sudo apt-get update
sudo DEBIAN_FRONTEND=noninteractive apt-get install -y \
libopenjp2-7-dev \
ninja-build \
cmake \
git \
python3 \
python3-pip \
nasm \
xxd \
pkg-config \
curl \
unzip \
ca-certificates \
libnuma-dev \
libsm6 \
libxext6 \
libxrender1 \
libgl1 \
vim \
nvidia-cuda-toolkit
```
### Step 2: Clone Required Repositories
```bash
# Create a working directory (custom path allowed)
mkdir -p ~/ffmpeg-build && cd ~/ffmpeg-build
# Clone nv-codec-headers (NVIDIA codec headers)
git clone https://github.com/FFmpeg/nv-codec-headers.git
# Clone libvmaf (video quality assessment library)
git clone https://github.com/Netflix/vmaf.git
cd vmaf && git checkout master # Switch to master branch (modify version if needed)
cd ..
# Clone FFmpeg source code
git clone https://github.com/FFmpeg/FFmpeg.git
cd FFmpeg && git checkout master # Switch to master branch (modify version if needed)
cd ..
```
### Step 3: Install nv-codec-headers
```bash
cd nv-codec-headers
make
sudo make install
cd ..
```
### Step 4: Compile and Install libvmaf (with CUDA Support)
1. Install the meson build tool:
```bash
python3 -m pip install meson
```
2. Compile and install libvmaf:
```bash
cd vmaf
meson libvmaf/build libvmaf \
-Denable_cuda=true \
-Denable_avx512=true \
--buildtype release
ninja -vC libvmaf/build
sudo ninja -vC libvmaf/build install
cd ..
```
3. Update system library cache:
```bash
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/lib/x86_64-linux-gnu/
sudo ldconfig
```
### Step 5: Compile and Install FFmpeg (with NVIDIA and libvmaf Support)
```bash
cd FFmpeg
# Configure compilation options (enable CUDA, NVENC, NVDEC, and libvmaf)
./configure \
--enable-libnpp \
--enable-nonfree \
--enable-nvdec \
--enable-nvenc \
--enable-cuvid \
--enable-cuda \
--enable-cuda-nvcc \
--enable-libvmaf \
--enable-ffnvcodec \
--disable-stripping \
--extra-cflags="-I/usr/local/cuda/include" \
--extra-ldflags="-L/usr/local/cuda/lib64 -L/usr/local/cuda/lib64/stubs/"
# Compile (adjust the number after -j based on CPU cores for faster compilation)
make -j$(nproc)
# Install
sudo make install
cd ..
```
### Step 6: Configure Python Environment
1. Upgrade pip and set up links:
```bash
sudo ln -sf /usr/bin/python3 /usr/bin/python
python3 -m pip install --no-cache-dir --upgrade pip setuptools wheel
```
2. Install Python dependencies (assuming project code is cloned locally; replace with actual path):
```bash
# Navigate to the project root directory
cd /path/to/your/project
# Install dependencies
python3 -m pip --no-cache-dir install -r requirements/requirements.txt
python3 -m pip --no-cache-dir install -r requirements/requirements_scoring.txt || true
python3 -m pip --no-cache-dir install -r requirements/requirements_annotation.txt || true
```
### Step 7: Verify Installation
1. Check FFmpeg version and configuration:
```bash
ffmpeg -version
ffmpeg -encoders | grep nvenc # Verify NVENC support
ffmpeg -decoders | grep nvdec # Verify NVDEC support
ffmpeg -filters | grep vmaf # Verify libvmaf support
```
2. If all the above commands output corresponding content correctly, the installation is successful.
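As an optional smoke test, you can run libvmaf on a short clip against itself and confirm that a CSV log is written; the filter invocation mirrors the one used in `scoring/motion/inference.py` (the input filename is a placeholder):
```bash
ffmpeg -i sample.mp4 -i sample.mp4 \
    -lavfi libvmaf=log_fmt=csv:log_path=vmaf_test.csv \
    -f null -
```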
## Troubleshooting
### Issue 1: VMAF compilation fails with `vcs_version.h: No such file or directory`
- **Cause**: This error typically occurs if you downloaded the VMAF source code as a ZIP archive instead of using `git clone`. The build script relies on the `.git` directory to generate version header files.
- **Solution**: Always use `git clone` to get the source code.
```bash
git clone https://github.com/Netflix/vmaf.git
```
### Issue 2: FFmpeg `configure` fails with error about Video Codec SDK version being too low
- **Error Message**: Something like `ERROR: nvenc requested, but NVIDIA Video Codec SDK 12.1 or later is required.` (The version number may vary).
- **Cause**: This means the version of `nv-codec-headers` you checked out is not compatible with your NVIDIA driver, CUDA Toolkit, or the version of FFmpeg you are building.
- **Solution**:
1. Carefully re-check your [NVIDIA Driver](https://www.nvidia.com/Download/index.aspx) and [CUDA Toolkit](https://developer.nvidia.com/cuda-toolkit-archive) versions.
  2. Go back to [Step 3: Install nv-codec-headers](#step-3-install-nv-codec-headers) and ensure you `git checkout` the branch that best matches your environment (e.g., `sdk/12.6`); a short example is sketched below.
3. Consult the [Official NVIDIA FFmpeg Guide](https://docs.nvidia.com/video-technologies/video-codec-sdk/ffmpeg-with-nvidia-gpu/index.html) or the `nv-codec-headers` repository to confirm version compatibility.
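For example, to pin the headers to the SDK 12.6 branch referenced in this guide (adjust the branch to whatever matches your driver, CUDA Toolkit, and FFmpeg versions):
```bash
cd ~/ffmpeg-build/nv-codec-headers
git fetch --all
git branch -r            # list the available sdk/* branches
git checkout sdk/12.6    # pick the branch matching your environment
make
sudo make install
```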
## References
- [VMAF on GitHub](https://github.com/Netflix/vmaf)
- [FFmpeg Official Source](https://github.com/FFmpeg/FFmpeg/tree/release/6.1)
- [NVIDIA Codec Headers Source](https://github.com/FFmpeg/nv-codec-headers/tree/sdk/12.6)
- [Official NVIDIA Guide for Compiling FFmpeg](https://docs.nvidia.com/video-technologies/video-codec-sdk/ffmpeg-with-nvidia-gpu/index.html)
================================================
FILE: scoring/motion/__init__.py
================================================
================================================
FILE: scoring/motion/inference.py
================================================
"""
Motion analysis script for video quality assessment using FFmpeg and VMAF.
Calculates motion scores for video clips using hardware acceleration when available.
"""
import os
import argparse
import pandas as pd
import subprocess
from multiprocessing import Manager
import queue
import concurrent.futures
from tqdm import tqdm
FFMPEG_PATH = "/usr/local/bin/ffmpeg"
def get_ffmpeg_acceleration():
"""
Auto detect the best acceleration method.
Priority: NVIDIA GPU > CPU.
"""
try:
# Get the list of ffmpeg configuration
output = subprocess.check_output(
[FFMPEG_PATH, "-version"], stderr=subprocess.DEVNULL
).decode("utf-8")
if "--enable-cuda-nvcc" in output and "--enable-libvmaf" in output:
return "nvidia"
else:
return "cpu" # Use CPU
except Exception as e:
print(f"FFmpeg acceleration detection failed: {e}")
return "cpu"
ACCELERATION_TYPE = get_ffmpeg_acceleration()
print(f"FFmpeg acceleration type: {ACCELERATION_TYPE}")
def process_single_row(video_path, args, process_id):
"""Process a single video to generate motion analysis CSV using FFmpeg."""
path = os.path.join(
args.temp_save_dir, os.path.basename(video_path).split(".")[0] + ".csv"
)
# Build FFmpeg command with appropriate acceleration
command = [FFMPEG_PATH]
if ACCELERATION_TYPE == "nvidia":
command += [
"-hwaccel",
"cuda",
"-hwaccel_output_format",
"cuda",
"-hwaccel_device",
f"{process_id % args.gpu_num}",
]
command += ["-i", f"{video_path}"]
if ACCELERATION_TYPE == "nvidia":
command += [
"-hwaccel",
"cuda",
"-hwaccel_output_format",
"cuda",
"-hwaccel_device",
f"{process_id % args.gpu_num}",
]
command += ["-i", f"{video_path}"]
if ACCELERATION_TYPE == "nvidia":
command += [
"-filter_complex",
f"[0:v]scale_cuda=format=yuv420p[dis],[1:v]scale_cuda=format=yuv420p[ref],[dis][ref]libvmaf_cuda=log_fmt=csv:log_path={path}",
]
else:
command += ["-lavfi", f"libvmaf=log_fmt=csv:log_path={path}"]
command += ["-f", "null", "-"]
try:
result = subprocess.run(command, capture_output=True, text=True, check=True)
except subprocess.CalledProcessError as e:
print(f"Error: {e.stderr}")
def calculate_score(row, args):
"""Calculate motion score for a specific video clip segment."""
csv_path = os.path.join(args.temp_save_dir, f'{row["id_ori"]}.csv')
df = pd.read_csv(csv_path)
df = df[(df["Frame"] >= row["frame_start"]) & (df["Frame"] <= row["frame_end"])]
mean_value = df["integer_motion2"].mean()
return mean_value
def worker1(task_queue, progress_queue, args, process_id):
"""Worker function for processing videos in parallel."""
while True:
try:
video_path = task_queue.get(timeout=1)
except queue.Empty:
break
process_single_row(video_path, args, process_id)
progress_queue.put(video_path)
task_queue.task_done()
def worker2(task_queue, results_queue, args):
"""Worker function for calculating motion scores in parallel."""
while True:
try:
index, row = task_queue.get(timeout=1)
except queue.Empty:
break
value = calculate_score(row, args)
results_queue.put((index, value))
task_queue.task_done()
def parse_args():
"""Parse command line arguments for motion analysis."""
parser = argparse.ArgumentParser()
parser.add_argument("--csv_path", type=str, required=True, help="Path to the CSV file")
parser.add_argument(
"--temp_save_dir",
type=str,
required=True,
help="Directory to save the temporary files",
)
parser.add_argument(
"--num_workers", type=int, default=None, help="#workers for concurrent.futures"
)
parser.add_argument(
"--disable_parallel", action="store_true", help="disable parallel processing"
)
parser.add_argument("--gpu_num", type=int, default=1, help="gpu number")
parser.add_argument("--skip_if_existing", action="store_true")
args = parser.parse_args()
return args
def main():
args = parse_args()
wo_ext, ext = os.path.splitext(args.csv_path)
out_path = f"{wo_ext}_motion{ext}"
if args.skip_if_existing and os.path.exists(out_path):
print(f"Output CSV file '{out_path}' already exists. Exit.")
exit()
df = pd.read_csv(args.csv_path)
video_paths = df["video_path"].unique()
if args.disable_parallel:
# Sequential processing
results = []
for video_path in tqdm(video_paths, desc="Processing videos"):
result = process_single_row(video_path, args, 0)
results.append(result)
for index, row in tqdm(
df.iterrows(), total=len(df), desc="Calculating scores"
):
result = calculate_score(row, args)
            df.at[index, "motion score"] = result
else:
# Parallel processing
if args.num_workers is not None:
num_workers = args.num_workers
else:
num_workers = os.cpu_count() or 1
# First phase: process videos to generate CSV files
manager = Manager()
task_queue = manager.Queue()
progress_queue = manager.Queue()
for video_path in video_paths:
task_queue.put(video_path)
with concurrent.futures.ProcessPoolExecutor(
max_workers=num_workers
) as executor:
futures = []
for id in range(num_workers):
futures.append(
executor.submit(worker1, task_queue, progress_queue, args, id)
)
processed = 0
total_video_tasks = len(video_paths)
with tqdm(total=total_video_tasks, desc="Processing videos") as pbar:
while processed < total_video_tasks:
try:
progress_queue.get(timeout=1)
processed += 1
pbar.update(1)
except queue.Empty:
if all(f.done() for f in futures) and progress_queue.empty():
break
for future in futures:
future.result()
# Second phase: calculate motion scores
result_queue = manager.Queue()
task_queue = manager.Queue()
for index, row in df.iterrows():
task_queue.put((index, row))
with concurrent.futures.ProcessPoolExecutor(
max_workers=num_workers
) as executor:
futures = []
for _ in range(num_workers):
futures.append(executor.submit(worker2, task_queue, result_queue, args))
results = []
processed = 0
total_score_tasks = len(df)
with tqdm(total=total_score_tasks, desc="Calculating scores") as pbar:
while processed < total_score_tasks:
try:
results.append(result_queue.get(timeout=1))
processed += 1
pbar.update(1)
except queue.Empty:
if all(f.done() for f in futures) and result_queue.empty():
break
for future in futures:
future.result()
# Collect and sort results
while not result_queue.empty():
results.append(result_queue.get())
results.sort(key=lambda x: x[0])
results = list(map(lambda x: x[1], results))
df["motion score"] = results
df.to_csv(out_path, index=False)
print(f"New df with motion scores saved to '{out_path}'.")
if __name__ == "__main__":
main()
================================================
FILE: scoring/ocr/__init__.py
================================================
================================================
FILE: scoring/ocr/inference.py
================================================
"""
OCR analysis script for video frames using PaddleOCR.
Calculates text area ratios for video clips using distributed processing.
"""
import os
from glob import glob
import argparse
import pandas as pd
from multiprocessing import Manager
import queue
import concurrent.futures
from tqdm import tqdm
import cv2
from paddleocr import PaddleOCR
def process_single_row(row, args, model):
"""Process a single row to calculate OCR text area ratio."""
img_dir = os.path.join(args.fig_load_dir, row["id"])
img_list = sorted(glob(f"{img_dir}/img/*.jpg"))[:3]
# Load images
images = [cv2.imread(img_path) for img_path in img_list]
images = [img for img in images if img is not None]
if not images:
return 0.0
result = model.predict(input=images)
area = images[0].shape[0] * images[0].shape[1] # Image area
area_list = []
for res in result:
total_text_area = 0 # Initialize total text area
for rec_box in res["rec_boxes"]:
x_min, y_min, x_max, y_max = (
float(rec_box[0]),
float(rec_box[1]),
float(rec_box[2]),
float(rec_box[3]),
) # Extract top-left and bottom-right coordinates
text_area = (x_max - x_min) * (y_max - y_min) # Calculate text area
total_text_area += text_area
ratio = total_text_area / area
area_list.append(ratio)
return (
max(area_list) if area_list else 0.0
) # Return max area ratio, 0.0 if no text detected
def worker(task_queue, result_queue, args, id):
"""Worker function for multiprocessing OCR inference."""
gpu_id = id % args.gpu_num
os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id) # Bind to specific GPU
device = "gpu:0" # if torch.cuda.is_available() else "cpu"
# Initialize PaddleOCR model with disabled orientation and unwarping features
model = PaddleOCR(
device=device,
use_doc_orientation_classify=False, # Disable document orientation classification
use_doc_unwarping=False, # Disable text image correction
use_textline_orientation=False, # Disable text line orientation classification
)
while True:
try:
index, row = task_queue.get_nowait()
except queue.Empty:
break
area_list = process_single_row(row, args, model)
result_queue.put((index, area_list))
def parse_args():
"""Parse command line arguments for OCR inference."""
    parser = argparse.ArgumentParser(description="OCR batch processing script")
parser.add_argument("--csv_path", type=str, help="Path to the csv file")
parser.add_argument(
"--fig_load_dir", type=str, default="img", help="Directory containing images"
)
parser.add_argument(
"--num_workers", type=int, default=16, help="#workers for concurrent.futures"
)
parser.add_argument("--gpu_num", type=int, default=1, help="gpu number")
parser.add_argument("--skip_if_existing", action="store_true")
parser.add_argument(
"--disable_parallel", action="store_true", help="Disable parallel processing"
)
return parser.parse_args()
def main():
args = parse_args()
if not os.path.exists(args.csv_path):
print(f"csv file '{args.csv_path}' not found. Exit.")
return
wo_ext, ext = os.path.splitext(args.csv_path)
out_path = f"{wo_ext}_ocr{ext}"
if args.skip_if_existing and os.path.exists(out_path):
print(f"Output csv file '{out_path}' already exists. Exit.")
exit()
df = pd.read_csv(args.csv_path)
results = []
if args.disable_parallel:
# Sequential processing
model = PaddleOCR(
device="gpu:0", # if torch.cuda.is_available() else "cpu"
use_doc_orientation_classify=False, # Disable document orientation classification
use_doc_unwarping=False, # Disable text image correction
use_textline_orientation=False, # Disable text line orientation classification
)
ocr_scores = []
for index, row in tqdm(df.iterrows(), total=len(df), desc="Processing rows"):
score = process_single_row(row, args, model)
ocr_scores.append(score)
results.append((index, score))
else:
# Set up multiprocessing queues
manager = Manager()
task_queue = manager.Queue()
result_queue = manager.Queue()
for index, row in df.iterrows():
task_queue.put((index, row))
# Process tasks with multiple workers
with concurrent.futures.ProcessPoolExecutor(
max_workers=args.num_workers
) as executor:
futures = []
for id in range(args.num_workers):
futures.append(
executor.submit(worker, task_queue, result_queue, args, id)
)
processed = 0
total_tasks = len(df)
with tqdm(total=total_tasks, desc="Processing rows") as pbar:
while processed < total_tasks:
try:
results.append(result_queue.get(timeout=1))
processed += 1
pbar.update(1)
except queue.Empty:
if all(f.done() for f in futures) and result_queue.empty():
break
for future in futures:
future.result()
# Collect and sort results
while not result_queue.empty():
index, area_list = result_queue.get()
results.append((index, area_list))
results.sort(key=lambda x: x[0])
df["ocr score"] = [x[1] for x in results]
df.to_csv(out_path, index=False)
print(f"New csv (shape={df.shape}) with ocr results saved to '{out_path}'.")
if __name__ == "__main__":
main()
================================================
FILE: scripts/annotation.sh
================================================
#!/bin/bash
CSV=[Replace with the path to the CSV file generated in the scoring step]
OUTPUT_DIR=[Replace with the path to your output directory]
mkdir -p ${OUTPUT_DIR}
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
GPU_NUM=8
ENHANCED=true # Set to true to enable enhanced instruction generation
measure_time() {
local step_number=$1
shift
local green="\e[32m"
local red="\e[31m"
local no_color="\e[0m"
local yellow="\e[33m"
start_time=$(date +%s)
echo -e "${green}Step ${step_number} started at: $(date)${no_color}"
"$@"
end_time=$(date +%s)
echo -e "${red}Step ${step_number} finished at: $(date)${no_color}"
echo -e "${yellow}Duration: $((end_time - start_time)) seconds${no_color}"
echo "---------------------------------------"
}
# 1. Extract frames
measure_time 1 python utils/extract_frames.py \
--csv_path ${CSV} \
--output_dir ${OUTPUT_DIR} \
--num_workers $((GPU_NUM * 2)) \
--target_size "1280*720" \
--backend "opencv" \
--interval 0.2
# 2.1 Depth Estimation with Depth-Anything
CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES} measure_time 2.1 torchrun --standalone --nproc_per_node ${GPU_NUM} camera_pose_annotation/depth_estimation/Depth-Anything/inference_batch.py \
--csv_path ${CSV} \
--encoder vitl \
--checkpoints_path checkpoints \
--output_dir ${OUTPUT_DIR} \
--bs 16 \
--num_workers ${GPU_NUM}
# 2.2 Depth Estimation with UniDepth
CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES} measure_time 2.2 torchrun --standalone --nproc_per_node ${GPU_NUM} camera_pose_annotation/depth_estimation/UniDepth/inference_batch.py \
--csv_path ${CSV} \
--output_dir ${OUTPUT_DIR} \
--checkpoints_path checkpoints \
--bs 32 \
--num_workers ${GPU_NUM}
# 3. Camera Tracking
CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES} measure_time 3 python camera_pose_annotation/camera_tracking/inference_batch.py \
--csv_path ${CSV} \
--dir_path ${OUTPUT_DIR} \
--checkpoints_path checkpoints \
--gpu_id ${CUDA_VISIBLE_DEVICES} \
--num_workers $((GPU_NUM * 2))
# 4.1 CVD Optimization Preprocess
CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES} measure_time 4.1 python camera_pose_annotation/cvd_opt/preprocess/inference_batch.py \
--csv_path ${CSV} \
--dir_path ${OUTPUT_DIR} \
--checkpoints_path checkpoints \
--gpu_id ${CUDA_VISIBLE_DEVICES} \
--num_workers $((GPU_NUM * 2))
# 4.2 CVD Optimization
CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES} measure_time 4.2 python camera_pose_annotation/cvd_opt/inference_batch.py \
--csv_path ${CSV} \
--dir_path ${OUTPUT_DIR} \
--gpu_id ${CUDA_VISIBLE_DEVICES} \
--num_workers $((GPU_NUM * 2))
# --only_depth
# 5. Dynamic Mask Prediction
CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES} measure_time 5 python camera_pose_annotation/dynamic_mask/inference_batch.py \
--csv_path ${CSV} \
--dir_path ${OUTPUT_DIR} \
--checkpoints_path checkpoints \
--gpu_num ${GPU_NUM} \
--num_workers $((GPU_NUM * 2))
# 6. Evaluation of the results
CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES} measure_time 6 python utils/evaluation.py \
--csv_path ${CSV} \
--dir_path ${OUTPUT_DIR} \
--gpu_num ${GPU_NUM} \
--num_workers $((GPU_NUM * 2)) \
--output_path ${OUTPUT_DIR}/final_results.csv
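# Note: utils/evaluation.py reads each clip's outputs from
# ${OUTPUT_DIR}/<clip id>/reconstructions/ (poses.npy and dyn_masks.npz),
# so the camera pose and dynamic mask steps above must have finished for every clip.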
# 7. Get motion instructions
if [ "$ENHANCED" = false ] ; then
CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES} measure_time 7 python utils/get_instructions.py \
--csv_path ${CSV} \
--dir_path ${OUTPUT_DIR} \
--interval 2 \
--num_workers $((GPU_NUM * 2))
else
echo "Standard instruction generation is enabled."
CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES} measure_time 7 python utils/get_instructions_enhanced.py \
--csv_path ${OUTPUT_DIR}/final_results.csv \
--dir_path ${OUTPUT_DIR} \
--num_workers $((GPU_NUM * 2))
fi
# 8. Normalize the intrinsics
CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES} measure_time 8 python utils/normalize_intrinsics.py \
--csv_path ${CSV} \
--dir_path ${OUTPUT_DIR} \
--num_workers $((GPU_NUM * 2))
# [Optional] Convert the output poses.npy into a c2w/w2c matrix
CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES} measure_time 9 python utils/quat_to_mat.py \
--csv_path ${CSV} \
--format c2w \
--dir_path ${OUTPUT_DIR} \
--num_workers $((GPU_NUM * 2))
================================================
FILE: scripts/caption.sh
================================================
#!/bin/bash
CSV=[Replace with the path to the result CSV file generated in the annotation step]
SRC_DIR=[Replace with the path to the annotation output directory]
OUTPUT_DIR=[Replace with the path to your output directory]
mkdir -p ${OUTPUT_DIR}
num_workers=8
wait_time=1
# VQA
vqa_prompt_file=caption/VQA/prompt.txt
vqa_model=gemini-2.0-flash
vqa_api_key=[Replace with your api key]
vqa_base_domain=https://generativelanguage.googleapis.com/
# LLM
llm_prompt_dir=caption/LLM
llm_model=qwen3-30b-a3b
llm_api_key=[Replace with your api key]
llm_base_domain=https://dashscope.aliyuncs.com/compatible-mode/
# Tagging
tag_prompt_file=caption/tagging/prompt.txt
tag_model=qwen3-30b-a3b
tag_api_key=[Replace with your api key]
tag_base_domain=https://dashscope.aliyuncs.com/compatible-mode/
measure_time() {
local step_number=$1
shift
local green="\e[32m"
local red="\e[31m"
local no_color="\e[0m"
local yellow="\e[33m"
start_time=$(date +%s)
echo -e "${green}Step $step_number started at: $(date)${no_color}"
"$@"
end_time=$(date +%s)
echo -e "${red}Step $step_number finished at: $(date)${no_color}"
echo -e "${yellow}Duration: $((end_time - start_time)) seconds${no_color}"
echo "---------------------------------------"
}
# 1. VQA caption
measure_time 1 python caption/VQA/inference.py \
--csv_path ${CSV} \
--fig_load_dir ${SRC_DIR} \
--output_dir ${OUTPUT_DIR} \
--prompt_file ${vqa_prompt_file} \
--model ${vqa_model} \
--api_key ${vqa_api_key} \
--base_domain ${vqa_base_domain} \
--num_workers ${num_workers} \
--wait_time ${wait_time}
# 2. LLM caption
measure_time 2 python caption/LLM/inference.py \
--csv_path $CSV \
--pose_load_dir $SRC_DIR \
--output_dir $OUTPUT_DIR \
--prompt_dir $llm_prompt_dir \
--model $llm_model \
--api_key $llm_api_key \
--num_workers $num_workers \
--base_domain $llm_base_domain \
--wait_time $wait_time
# 3. Combine results
measure_time 3 python caption/utils/combine.py \
--csv_path $CSV \
--load_dir $OUTPUT_DIR \
--output_dir $OUTPUT_DIR/results \
--num_workers $num_workers
# 4. Add tags
python caption/tagging/inference.py \
--csv_path $CSV \
--json_load_dir $OUTPUT_DIR/results \
--prompt_file $tag_prompt_file \
--model $tag_model \
--api_key $tag_api_key \
--num_workers $num_workers \
--base_domain $tag_base_domain \
--wait_time $wait_time
================================================
FILE: scripts/docker_prepulls.sh
================================================
#!/usr/bin/env bash
# This script pre-pulls and tags GPU-related Docker images from specified registries.
set -euo pipefail
# Minimal script: pre-pull three images (builder/runtime/buildkit) and tag them to
# canonical names so downstream scripts can rely on the expected tags.
# You can override these by setting the env vars before running this script.
BUILDER_IMAGE=${BUILDER_IMAGE:-swr.cn-north-4.myhuaweicloud.com/ddn-k8s/docker.io/nvidia/cuda:12.6.3-cudnn-devel-ubuntu22.04}
RUNTIME_IMAGE=${RUNTIME_IMAGE:-swr.cn-north-4.myhuaweicloud.com/ddn-k8s/docker.io/nvidia/cuda:12.6.3-runtime-ubuntu22.04}
BUILDKIT_IMAGE=${BUILDKIT_IMAGE:-swr.cn-north-4.myhuaweicloud.com/ddn-k8s/docker.io/moby/buildkit:buildx-stable-1}
retry_pull() {
local img="$1"
for i in 1 2 3; do
echo "pull attempt $i for ${img}..."
if docker pull "${img}"; then
echo "pulled ${img}"
return 0
fi
sleep $((i * 2))
done
echo "Failed to pull ${img} after retries" >&2
return 1
}
echo "Pre-pulling images..."
echo "- builder: ${BUILDER_IMAGE}"
echo "- runtime: ${RUNTIME_IMAGE}"
echo "- buildkit: ${BUILDKIT_IMAGE}"
retry_pull "${BUILDER_IMAGE}" || true
retry_pull "${RUNTIME_IMAGE}" || true
retry_pull "${BUILDKIT_IMAGE}" || true
CANONICAL_BUILDKIT_TAG="moby/buildkit:buildx-stable-1"
if docker image inspect "${BUILDKIT_IMAGE}" >/dev/null 2>&1; then
echo "Tagging ${BUILDKIT_IMAGE} -> ${CANONICAL_BUILDKIT_TAG} (local only)"
docker tag "${BUILDKIT_IMAGE}" "${CANONICAL_BUILDKIT_TAG}" || true
fi
# Also tag the mirrored CUDA images to the original docker.io names expected by
# Dockerfiles and other scripts. This lets downstream tooling refer to
# docker.io/nvidia/cuda:12.6.3-... even when images were pulled from a mirror.
ORIG_BUILDER_TAG="docker.io/nvidia/cuda:12.6.3-cudnn-devel-ubuntu22.04"
ORIG_RUNTIME_TAG="docker.io/nvidia/cuda:12.6.3-runtime-ubuntu22.04"
if docker image inspect "${BUILDER_IMAGE}" >/dev/null 2>&1; then
echo "Tagging ${BUILDER_IMAGE} -> ${ORIG_BUILDER_TAG}"
docker tag "${BUILDER_IMAGE}" "${ORIG_BUILDER_TAG}" || true
fi
if docker image inspect "${RUNTIME_IMAGE}" >/dev/null 2>&1; then
echo "Tagging ${RUNTIME_IMAGE} -> ${ORIG_RUNTIME_TAG}"
docker tag "${RUNTIME_IMAGE}" "${ORIG_RUNTIME_TAG}" || true
fi
echo "Done pulling/tagging images."
echo "You can now run downstream build steps that expect these images to exist locally."
================================================
FILE: scripts/download_checkpoints.sh
================================================
mkdir -p ./checkpoints/
cd ./checkpoints/
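# Prerequisites (assumed to be installed): wget, gdown and huggingface-cli,
# e.g. `pip install gdown huggingface_hub`.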
# aesthetic
wget https://github.com/christophschuhmann/improved-aesthetic-predictor/raw/main/sac+logos+ava1-l14-linearMSE.pth -O aesthetic.pth
# megasam
wget https://github.com/mega-sam/mega-sam/raw/main/checkpoints/megasam_final.pth -O megasam_final.pth
# raft
gdown -c https://drive.google.com/uc?id=1MqDajR89k-xLV0HIrmJ0k-n8ZpG6_suM -O raft-things.pth
# depth anything
huggingface-cli download --resume-download depth-anything/Depth-Anything-V2-Large --local-dir Depth-Anything
# unidepth
huggingface-cli download --resume-download lpiccinelli/unidepth-v2-vitl14 --local-dir UniDepth
# sam
huggingface-cli download --resume-download facebook/sam2.1-hiera-large --local-dir SAM2
================================================
FILE: scripts/scoring.sh
================================================
#!/bin/bash
VIDEO_DIR=[Replace with the path to your video files]
OUTPUT_DIR=[Replace with the path to your output directory]
mkdir -p ${OUTPUT_DIR}
# Choose whether to cut the clips precisely based on the timestamps or to cut them fast based on keyframes.
# The precise cutting will be slower but more accurate, while the fast cutting will be faster but may not be as accurate.
FAST_CUT=False
GPU_NUM=8
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
NUM_WORKERS=$((GPU_NUM * 2))
ROOT_CLIPS=${OUTPUT_DIR}/clip
ROOT_META=${OUTPUT_DIR}/meta
ROOT_FIG=${OUTPUT_DIR}/fig
ROOT_TEMP=${OUTPUT_DIR}/temp
for dir in ${ROOT_CLIPS} ${ROOT_META} ${ROOT_FIG} ${ROOT_TEMP}; do
if [ ! -d ${dir} ]; then
mkdir -p ${dir}
fi
done
measure_time() {
local step_number=$1
shift
local green="\e[32m"
local red="\e[31m"
local no_color="\e[0m"
local yellow="\e[33m"
start_time=$(date +%s)
echo -e "${green}Step ${step_number} started at: $(date)${no_color}"
"$@"
end_time=$(date +%s)
echo -e "${red}Step ${step_number} finished at: $(date)${no_color}"
echo -e "${yellow}Duration: $((end_time - start_time)) seconds${no_color}"
echo "---------------------------------------"
}
# 1.1 Create a meta file from a video folder. This should output ${ROOT_META}/meta.csv
CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES} measure_time 1.1 python utils/convert.py \
--video_dir ${VIDEO_DIR} \
--output ${ROOT_META}/meta.csv
# 1.2 Get video information and remove broken videos. This should output ${ROOT_META}/meta_info.csv
CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES} measure_time 1.2 python utils/get_info.py \
--csv_path ${ROOT_META}/meta.csv \
--csv_save_path ${ROOT_META}/meta_info.csv \
--backend "opencv" \
--num_workers 16
# 2.1 Detect scenes. This should output ${ROOT_META}/meta_info_timestamp.csv
# You can also tune the trimming parameters, e.g. "--start_remove_sec 0.5 --end_remove_sec 0.5"
CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES} measure_time 2.1 python utils/scene_detect.py \
--csv_path ${ROOT_META}/meta_info.csv \
--backend "opencv" \
--num_workers 64 \
    --frame_skip 2 \
--start_remove_sec 0.3 \
--end_remove_sec 0.3 \
--min_seconds 3 \
--max_seconds 15
# 2.2 Get clips. This should output ${ROOT_META}/clips_info.csv
CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES} measure_time 2.2 python utils/get_clip.py \
--csv_path ${ROOT_META}/meta_info_timestamp.csv \
--csv_save_dir ${ROOT_META} \
--num_workers $((GPU_NUM * 4))
# 2.3 Extract frames for scoring.
CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES} measure_time 2.3 python utils/extract_frames.py \
--csv_path ${ROOT_META}/clips_info.csv \
--output_dir ${ROOT_FIG} \
--num_workers 64 \
--target_size "640*360" \
--backend "opencv"
# 3.1 Predict aesthetic scores. This should output ${ROOT_META}/clips_info_aes.csv
CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES} measure_time 3.1 torchrun --nproc_per_node ${GPU_NUM} scoring/aesthetic/inference.py \
--csv_path ${ROOT_META}/clips_info.csv \
--bs 16 \
--num_workers ${NUM_WORKERS} \
--fig_load_dir ${ROOT_FIG}
# 3.2 Predict luminance scores. This should output ${ROOT_META}/clips_info_lum.csv
CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES} measure_time 3.2 torchrun --nproc_per_node ${GPU_NUM} scoring/luminance/inference.py \
--csv_path ${ROOT_META}/clips_info.csv \
--bs 16 \
--num_workers ${NUM_WORKERS} \
--fig_load_dir ${ROOT_FIG}
# 3.3 Get motion scores. This should output ${ROOT_META}/clips_info_motion.csv
CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES} measure_time 3.3 python scoring/motion/inference.py \
--csv_path ${ROOT_META}/clips_info.csv \
--temp_save_dir ${ROOT_TEMP} \
--num_workers $((GPU_NUM * 4)) \
--gpu_num ${GPU_NUM}
# 3.4 Detect on-screen text with PaddleOCR. This should output ${ROOT_META}/clips_info_ocr.csv
CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES} measure_time 3.4 python scoring/ocr/inference.py \
--csv_path ${ROOT_META}/clips_info.csv \
--fig_load_dir ${ROOT_FIG} \
--num_workers $((GPU_NUM * 4)) \
--gpu_num ${GPU_NUM}
# 4 Merge all the scores. This should output ${ROOT_META}/clips_scores.csv
CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES} measure_time 4 python utils/merge_tables.py \
--csv_dir ${ROOT_META} \
--output ${ROOT_META}/clips_scores.csv
# 5 Filter the clips.
CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES} measure_time 5 python utils/filter.py \
--csv_path ${ROOT_META}/clips_scores.csv \
--csv_save_path ${ROOT_META}/filtered_clips.csv \
--aes_min 4 \
--lum_min 20 \
--lum_max 140 \
--motion_min 2 \
--motion_max 14 \
--ocr_max 0.3
# 6 Cut the clips.
if [ "$FAST_CUT" = False ]; then
echo "Using precise cutting based on timestamps."
CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES} measure_time 6 python utils/cut.py \
--csv_path ${ROOT_META}/filtered_clips.csv \
--csv_save_path ${OUTPUT_DIR}/results.csv \
--video_save_dir ${ROOT_CLIPS} \
--num_workers $((GPU_NUM * 4)) \
--gpu_num $GPU_NUM \
# --keep_audio
else
echo "Using fast cutting based on keyframes."
CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES} measure_time 6 python utils/cut_fast.py \
--csv_path ${ROOT_META}/filtered_clips.csv \
--csv_save_path ${OUTPUT_DIR}/results.csv \
--video_save_dir ${ROOT_CLIPS} \
--num_workers $((GPU_NUM * 4)) \
# --keep_audio
fi
================================================
FILE: utils/README.md
================================================
# Utils
- [`convert.py`](convert.py): scan a video directory and write the video paths to a CSV metadata file.
- [`cut.py`](cut.py): cut videos into clips.
- [`download_SpatialVID.py`](download_SpatialVID.py): download the SpatialVID dataset.
- [`download_YouTube.py`](download_YouTube.py): download videos from YouTube.
- [`evaluation.py`](evaluation.py): evaluate camera trajectory quality (anomaly detection, motion distance, rotation, turns, dynamic ratio).
- [`expand_npz.py`](expand_npz.py): expand dynamic masks stored in a compressed npz file (see the loading sketch after this list).
- [`extract_frames.py`](extract_frames.py): extract frames from videos.
- [`filter.py`](filter.py): filter video clips based on score.
- [`get_clip.py`](get_clip.py): generate per-clip metadata rows from the detected scene timestamps.
- [`get_info.py`](get_info.py): get video information, such as duration and resolution.
- [`get_instructions.py`](get_instructions.py): get motion instructions from camera poses.
- [`get_instructions_enhanced.py`](get_instructions_enhanced.py): an enhanced version to get more detailed and accurate motion instructions from camera poses.
- [`merge_tables.py`](merge_tables.py): merge multiple csv tables into one.
- [`normalize_intrinsics.py`](normalize_intrinsics.py): normalize camera intrinsics.
- [`pack_clip_assets.py`](pack_clip_assets.py): pack all the output files into an npz file for visualization.
- [`quat_to_mat.py`](quat_to_mat.py): convert camera parameters to camera-to-world or world-to-camera matrices.
- [`read_video.py`](read_video.py): read videos using opencv or av.
- [`scene_detect.py`](scene_detect.py): detect scene changes to determine clip boundaries.
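
Below is a minimal sketch of how a clip's reconstruction outputs can be loaded. It assumes the layout used by `evaluation.py`: `poses.npy` stores one row per frame as `[tx, ty, tz, qx, qy, qz, qw]` (quaternions in `[x, y, z, w]` order), and `dyn_masks.npz` holds the compressed masks expanded by `expand_npz.py`. The `outputs/<clip_id>` path is a placeholder.

```python
import numpy as np
from scipy.spatial.transform import Rotation

from expand_npz import expand  # run from utils/ or add it to PYTHONPATH

rec_dir = "outputs/<clip_id>/reconstructions"
params = np.load(f"{rec_dir}/poses.npy")              # (N, 7) per-frame camera parameters
positions = params[:, :3]                             # camera positions
rotations = Rotation.from_quat(params[:, 3:])         # scalar-last quaternions
masks = expand(np.load(f"{rec_dir}/dyn_masks.npz"))   # dense per-frame dynamic masks

print(positions.shape, rotations.as_matrix().shape, masks.shape)
```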
================================================
FILE: utils/__init__.py
================================================
================================================
FILE: utils/convert.py
================================================
"""
Video file conversion utility for the SpatialVID project.
This module provides functionality to scan directories for video files,
process them, and generate CSV metadata files containing video information.
"""
import argparse
import os
import time
import pandas as pd
# Supported video file extensions
VID_EXTENSIONS = (".mp4", ".avi", ".mov", ".mkv", ".m2ts", ".webm")
def scan_recursively(root):
"""
Recursively scan a directory tree and yield all entries.
"""
num = 0
for entry in os.scandir(root):
if entry.is_file():
yield entry
elif entry.is_dir():
num += 1
if num % 100 == 0:
print(f"Scanned {num} directories.")
yield from scan_recursively(entry.path)
def get_filelist(file_path, exts=None):
"""
Get a list of files from a directory tree, optionally filtered by extensions.
"""
filelist = []
time_start = time.time()
# Use recursive scanning to find all files
obj = scan_recursively(file_path)
for entry in obj:
if entry.is_file():
ext = os.path.splitext(entry.name)[-1].lower()
if exts is None or ext in exts:
filelist.append(entry.path)
time_end = time.time()
print(f"Scanned {len(filelist)} files in {time_end - time_start:.2f} seconds.")
return filelist
def split_by_capital(name):
"""
Split a camelCase or PascalCase string by capital letters.
"""
new_name = ""
for i in range(len(name)):
if name[i].isupper() and i != 0:
new_name += " "
new_name += name[i]
return new_name
def process_general_videos(root, output):
"""
Process video files in a directory and generate a CSV metadata file.
"""
# Expand user path (e.g., ~ to home directory)
root = os.path.expanduser(root)
if not os.path.exists(root):
return
# Get list of video files with supported extensions
path_list = get_filelist(root, VID_EXTENSIONS)
# Note: In some cases (like realestate dataset), you might want to use:
# path_list = get_filelist(root) # without extension filtering
path_list = list(set(path_list)) # Remove duplicate entries
# Extract filename without extension as ID
fname_list = [os.path.splitext(os.path.basename(x))[0] for x in path_list]
# Get relative paths from root directory
relpath_list = [os.path.relpath(x, root) for x in path_list]
# Create DataFrame with video metadata
df = pd.DataFrame(dict(video_path=path_list, id=fname_list, relpath=relpath_list))
# Ensure output directory exists
os.makedirs(os.path.dirname(output), exist_ok=True)
df.to_csv(output, index=False)
print(f"Saved {len(df)} samples to {output}.")
if __name__ == "__main__":
# Set up command line argument parser
parser = argparse.ArgumentParser(
description="Convert video directory structure to CSV metadata file"
)
parser.add_argument("--video_dir", type=str, help="Root directory containing video files")
parser.add_argument("--split", type=str, default="train", help="Dataset split name")
parser.add_argument("--info", type=str, default=None, help="Additional info file")
parser.add_argument(
"--output", type=str, default=None, required=True, help="Output CSV file path"
)
args = parser.parse_args()
# Process videos and generate metadata CSV
process_general_videos(args.video_dir, args.output)
================================================
FILE: utils/cut.py
================================================
"""
Precise frame-level video cutting tool
Strategy: Two-phase seek + forced keyframe alignment output
"""
import argparse
import os
import concurrent.futures
from functools import partial
import pandas as pd
import subprocess
from scenedetect import FrameTimecode
from tqdm import tqdm
FFMPEG_PATH = "/usr/local/bin/ffmpeg"
def get_ffmpeg_acceleration():
try:
output = subprocess.check_output(
[FFMPEG_PATH, "-encoders"], stderr=subprocess.DEVNULL
).decode("utf-8")
if "hevc_nvenc" in output:
return "nvidia"
return "cpu"
except Exception as e:
print(f"FFmpeg acceleration detection failed: {e}")
return "cpu"
ACCELERATION_TYPE = get_ffmpeg_acceleration()
print(f"FFmpeg acceleration type: {ACCELERATION_TYPE}")
# ════════════════════════════════════════════════════════════
# Core Utility Functions
# ════════════════════════════════════════════════════════════
def seconds_to_timecode(seconds: float) -> str:
"""
Convert seconds to FFmpeg precise timecode string.
Keep enough decimal places to ensure frame accuracy.
Example: 1.033333 -> "0:00:01.033333"
"""
hours = int(seconds // 3600)
minutes = int((seconds % 3600) // 60)
secs = seconds % 60
# Keep 6 decimal places (microsecond-level precision)
return f"{hours}:{minutes:02d}:{secs:09.6f}"
def build_precise_cut_cmd(
video_path: str,
start_sec: float,
end_sec: float,
save_path: str,
args,
process_id: int,
shorter_size: int | None,
) -> list[str]:
"""
Build frame-precise FFmpeg cut command.
Strategy: Two-phase seek
┌──────────────────────────────────────────────────────────┐
│ -ss (pre, coarse seek) │
│ -> Jump to nearest keyframe before start_sec │
│ -> Avoid decoding from file start (speed optimize) │
│ │
│ -i input │
│ │
│ -ss (post, fine seek) │
│ -> Decode from coarse point to exact start_sec │
│ -> value = start_sec - coarse_seek (always positive) │
│ │
│ -t duration │
│ -> Exact duration │
│ │
│ Force re-encode (cannot use -c copy, otherwise │
│ start frame won't be precise) │
└──────────────────────────────────────────────────────────┘
"""
duration = end_sec - start_sec
if duration <= 0:
raise ValueError(f"Invalid duration {duration:.4f}s (start={start_sec}, end={end_sec})")
# ==== Phase 1: Coarse seek (pre seek) ====
# Safety margin: ensure coarse point is before start_sec keyframe
# Too little -> may land after start_sec (seek ineffective)
# Too much -> decode more frames (slightly slower)
# Experience: max(GOP_size, 5s) covers most videos
GOP_SAFETY_MARGIN = 5.0
coarse_seek = max(0.0, start_sec - GOP_SAFETY_MARGIN)
# Offset for post precise seek = target time - coarse time
fine_seek = start_sec - coarse_seek
cmd = [FFMPEG_PATH, "-nostdin", "-y"]
# ==== GPU hardware acceleration (decode phase) ====
if ACCELERATION_TYPE == "nvidia":
cmd += [
"-hwaccel", "cuda",
"-hwaccel_output_format", "cuda",
"-hwaccel_device", str(process_id % args.gpu_num),
]
# ==== Phase 1: Coarse seek (pre, fast jump to GOP boundary) ====
cmd += ["-ss", seconds_to_timecode(coarse_seek)]
# ==== Input file ====
cmd += ["-i", video_path]
# ==== Phase 2: Precise seek (post, decode from GOP boundary to exact frame) ====
# Only need post seek when fine_seek > 0
# When coarse_seek == 0, fine_seek == start_sec, still correct
if fine_seek > 0.001: # Ignore errors less than 1ms
cmd += ["-ss", seconds_to_timecode(fine_seek)]
# ==== Exact duration ====
cmd += ["-t", seconds_to_timecode(duration)]
# ==== Video filters (scale + fps) ====
filters = _build_video_filters(shorter_size, args, ACCELERATION_TYPE)
if filters:
cmd += ["-vf", ",".join(filters)]
# ==== Encoder (must re-encode to ensure frame precision) ====
cmd += _build_encoder_args(ACCELERATION_TYPE)
# ==== Frame rate ====
if args.target_fps is not None:
cmd += ["-r", str(args.target_fps)]
# ==== Audio ====
if args.keep_audio:
cmd += ["-map", "0:v", "-map", "0:a?", "-c:a", "aac", "-b:a", "128k"]
else:
cmd += ["-map", "0:v", "-an"]
# ==== Output: force keyframe at first frame for easy concatenation/playback ====
cmd += [
"-force_key_frames", "expr:gte(t,0)", # Force keyframe at second 0
save_path,
]
return cmd
def _build_video_filters(shorter_size, args, accel_type) -> list[str]:
"""Build video filter list"""
filters = []
if shorter_size is not None:
if accel_type == "nvidia":
# CUDA scale filter
scale = (
f"scale_cuda="
f"'if(gt(iw,ih),-2,{shorter_size})':"
f"'if(gt(iw,ih),{shorter_size},-2)'"
)
else:
# Software scale: lanczos best quality, bicubic next
scale = (
f"scale="
f"'if(gt(iw,ih),-2,{shorter_size})':"
f"'if(gt(iw,ih),{shorter_size},-2)'"
f":flags=lanczos"
)
filters.append(scale)
if args.target_fps is not None:
# fps filter more accurate than -r parameter (-r sometimes drops frames)
filters.append(f"fps={args.target_fps}")
return filters
def _build_encoder_args(accel_type) -> list[str]:
"""Build encoder arguments"""
if accel_type == "nvidia":
return [
"-c:v", "hevc_nvenc",
"-preset", "p4", # p4=quality/speed balance, p7=slowest best
"-rc", "vbr",
"-cq", "24", # Quality factor, smaller is better (like CRF)
"-b:v", "0", # No bitrate limit in VBR mode
]
else:
return [
"-c:v", "libx264",
"-preset", "fast", # fast is best speed/quality for precise cutting
"-crf", "18", # High quality (0=lossless, 23=default, 18=visually lossless)
"-pix_fmt", "yuv420p", # Most compatible pixel format
]
# ════════════════════════════════════════════════════════════
# Single Row Processing (maintains compatibility with original interface)
# ════════════════════════════════════════════════════════════
def process_single_row(row, args, process_id):
"""
Precise frame-level cutting of a single segment.
Returns:
(row_values_list, valid, error_message)
"""
video_path = row["video_path"]
save_dir = args.video_save_dir
#
# ==== Scale size calculation ====
shorter_size = args.shorter_size
if (shorter_size is not None) and ("height" in row) and ("width" in row):
min_size = min(row["height"], row["width"])
if min_size <= shorter_size:
shorter_size = None # Already small enough, skip scaling (no upsample)
# ==== Timestamp parsing ====
try:
seg_start = FrameTimecode(timecode=row["timestamp_start"], fps=row["fps"])
seg_end = FrameTimecode(timecode=row["timestamp_end"], fps=row["fps"])
except Exception as e:
error_msg = f"Invalid timestamp for id={row.get('id', '?')}: {e}"
print(error_msg)
return row.values.tolist(), False, error_msg
start_sec = seg_start.get_seconds()
end_sec = seg_end.get_seconds()
duration = end_sec - start_sec
if duration <= 0:
error_msg = (
f"Invalid duration {duration:.4f}s for id={row.get('id','?')} "
f"(start={start_sec:.4f}, end={end_sec:.4f})"
)
print(error_msg)
return row.values.tolist(), False, error_msg
clip_id = row["id"]
save_path = os.path.join(save_dir, f"{clip_id}.mp4")
# ==== Skip if already exists ====
if os.path.exists(save_path) and os.path.getsize(save_path) > 0:
row = row.copy()
row["video_path"] = save_path
return row.values.tolist(), True, ""
# ==== Source file check ====
if not os.path.exists(video_path):
error_msg = f"Source video not found: {video_path} (id={clip_id})"
print(error_msg)
return row.values.tolist(), False, error_msg
# ==== Build precise cut command ====
try:
cmd = build_precise_cut_cmd(
video_path = video_path,
start_sec = start_sec,
end_sec = end_sec,
save_path = save_path,
args = args,
process_id = process_id,
shorter_size = shorter_size,
)
except ValueError as e:
error_msg = f"Command build failed for id={clip_id}: {e}"
print(error_msg)
return row.values.tolist(), False, error_msg
# ==== Execute FFmpeg ====
try:
subprocess.run(cmd, check=True, stderr=subprocess.PIPE)
except subprocess.CalledProcessError as e:
stderr_text = e.stderr.decode("utf-8", errors="replace") if e.stderr else str(e)
error_msg = f"FFmpeg failed for id={clip_id}:\n{stderr_text}"
print(error_msg)
_cleanup(save_path)
return row.values.tolist(), False, error_msg
except Exception as e:
error_msg = f"Unexpected error for id={clip_id}: {e}"
print(error_msg)
_cleanup(save_path)
return row.values.tolist(), False, error_msg
# ==== Basic integrity check ====
if not os.path.exists(save_path) or os.path.getsize(save_path) == 0:
_cleanup(save_path)
error_msg = f"FFmpeg produced empty/missing output for id={clip_id}"
print(error_msg)
return row.values.tolist(), False, error_msg
row = row.copy()
row["video_path"] = save_path
return row.values.tolist(), True, ""
def _cleanup(path: str):
"""Safely delete file"""
try:
if os.path.exists(path):
os.remove(path)
except OSError:
pass
# ════════════════════════════════════════════════════════════
# Argument Parsing
# ════════════════════════════════════════════════════════════
def parse_args():
parser = argparse.ArgumentParser(
description="Precise frame-level video cutting tool",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
# ==== Input/Output ====
parser.add_argument("--csv_path", type=str, required=True,
help="Input CSV file path")
parser.add_argument("--csv_save_path", type=str, required=True,
help="Output CSV file path (success records)")
parser.add_argument("--video_save_dir", type=str, required=True,
help="Directory to save cut segments")
# ==== Video parameters ====
parser.add_argument("--target_fps", type=int, default=None,
help="Target frame rate (None=keep source frame rate)")
parser.add_argument("--shorter_size", type=int, default=None,
help="Short edge target size (maintain aspect ratio, no upsample)")
parser.add_argument("--keep_audio", action="store_true",
help="Keep audio track (default: discard)")
# ==== Parallel control ====
parser.add_argument("--num_workers", type=int, default=None,
help="Number of parallel workers (None=auto=CPU cores)")
parser.add_argument("--disable_parallel", action="store_true",
help="Disable parallel processing (for debugging)")
parser.add_argument("--gpu_num", type=int, default=1,
help="Number of available GPUs")
# ==== Result handling ====
parser.add_argument("--drop_invalid_timestamps", action="store_true",
help="Filter invalid timestamps and save corrected CSV")
return parser.parse_args()
# ════════════════════════════════════════════════════════════
# Parallel Worker
# ════════════════════════════════════════════════════════════
def _worker_fn(task: tuple, args, process_id: int) -> tuple:
"""
Top-level worker function for ProcessPoolExecutor (must be serializable).
Args:
task: (index, row_dict) <- Use dict instead of Series to avoid serialization issues
Returns:
(index, row_values, valid, error_msg)
"""
index, row_dict = task
# Restore dict to pandas Series (process_single_row depends on Series interface)
row = pd.Series(row_dict)
return (index,) + tuple(process_single_row(row, args, process_id)[0:3])
# Note: process_single_row returns (row_values, valid, error_msg)
# Packed here as (index, row_values, valid, error_msg)
# ════════════════════════════════════════════════════════════
# Result Saving
# ════════════════════════════════════════════════════════════
def save_results(all_results: list, csv: pd.DataFrame, args):
"""
Save processing results to success/failure CSVs separately.
Success CSV: Remove timestamp helper columns, update video_path to cut path
Failure CSV: Keep all original columns, add error column
"""
columns = csv.columns.tolist()
success_rows, failed_rows, failed_errors = [], [], []
for index, row_values, valid, error_msg in all_results:
if valid:
success_rows.append(row_values)
else:
failed_rows.append(row_values)
failed_errors.append(error_msg)
# ==== Save success records ====
if success_rows:
success_df = pd.DataFrame(success_rows, columns=columns)
# Remove cutting process helper columns (not needed by downstream)
drop_cols = [
c for c in ["timestamp_start", "timestamp_end", "frame_start", "frame_end"]
if c in success_df.columns
]
if drop_cols:
success_df = success_df.drop(columns=drop_cols)
success_df.to_csv(args.csv_save_path, index=False)
print(f"\n[OK] Success: {len(success_df)} records -> {args.csv_save_path}")
else:
print("\n[X] No success records")
# ==== Save failure records ====
if failed_rows:
base, ext = os.path.splitext(args.csv_save_path)
failed_csv_path = f"{base}_failed{ext}"
failed_df = pd.DataFrame(failed_rows, columns=columns)
failed_df["error"] = failed_errors
failed_df.to_csv(failed_csv_path, index=False)
print(f"[X] Failed: {len(failed_df)} records -> {failed_csv_path}")
# ==== Save corrected timestamps (optional) ====
if args.drop_invalid_timestamps and failed_rows:
valid_indices = [r[0] for r in all_results if r[2]]
filtered_csv = csv.iloc[valid_indices]
assert args.csv_path.endswith("timestamp.csv"), \
"--drop_invalid_timestamps only supports *timestamp.csv files"
corrected_path = args.csv_path.replace("timestamp.csv", "correct_timestamp.csv")
filtered_csv.to_csv(corrected_path, index=False)
print(f"[OK] Corrected timestamps -> {corrected_path}")
# ════════════════════════════════════════════════════════════
# Main Function
# ════════════════════════════════════════════════════════════
def main():
args = parse_args()
# ==== Pre-check ====
if not os.path.exists(args.csv_path):
print(f"[ERROR] CSV file does not exist: {args.csv_path}")
return
os.makedirs(args.video_save_dir, exist_ok=True)
csv = pd.read_csv(args.csv_path)
total = len(csv)
print(f"Total {total} records to process")
all_results = []
# ==== Serial mode ====
if args.disable_parallel:
for index, row in tqdm(csv.iterrows(), total=total, desc="Cutting progress"):
row_values, valid, error_msg = process_single_row(row, args, process_id=0)
all_results.append((index, row_values, valid, error_msg))
# ==== Parallel mode ====
else:
num_workers = args.num_workers or (os.cpu_count() or 1)
num_workers = min(num_workers, total) # worker count not exceeding task count
# Convert row to dict to avoid pandas Series serialization issues
tasks = [
(index, row.to_dict())
for index, row in csv.iterrows()
]
with concurrent.futures.ProcessPoolExecutor(
max_workers=num_workers
) as executor:
# Use enumerate to round-robin process_id (GPU rotation)
futures = {
executor.submit(
_worker_fn,
task,
args,
task_idx % max(args.gpu_num, 1), # GPU rotation
): task_idx
for task_idx, task in enumerate(tasks)
}
with tqdm(total=total, desc="Cutting progress") as pbar:
for future in concurrent.futures.as_completed(futures):
try:
result = future.result() # (index, row_values, valid, error_msg)
all_results.append(result)
except Exception as e:
task_idx = futures[future]
index, _ = tasks[task_idx]
row_values = csv.iloc[index].values.tolist()
all_results.append((index, row_values, False, str(e)))
print(f"\n[ERROR] Worker exception (task_idx={task_idx}): {e}")
finally:
pbar.update(1)
# ==== Sort by original order ====
all_results.sort(key=lambda x: x[0])
# ==== Statistics summary ====
success_count = sum(1 for r in all_results if r[2])
failed_count = total - success_count
print(f"\n{'='*50}")
print(f"Processing complete: Total={total}, Success={success_count}, Failed={failed_count}")
print(f"{'='*50}")
# ==== Save results ====
save_results(all_results, csv, args)
if __name__ == "__main__":
main()
================================================
FILE: utils/cut_fast.py
================================================
"""
High-speed video cutting utility using FFmpeg stream copy.
Features:
- No re-encoding: uses `-c copy`
- Optional audio: use --keep_audio to retain audio tracks
- Group tasks by source video_path for better efficiency
- Parallel processing by video group
- Per-clip progress bar
- Save successful and failed CSVs
Notes:
- This method is very fast, but not always frame-accurate.
- Clip boundaries may align to nearby keyframes depending on source encoding.
"""
import argparse
import os
import queue
import subprocess
import concurrent.futures
from multiprocessing import Manager
import pandas as pd
from scenedetect import FrameTimecode
from tqdm import tqdm
FFMPEG_PATH = "/usr/local/bin/ffmpeg"
def process_single_row(row, save_dir, keep_audio=False):
"""
Cut one clip from source video using ffmpeg stream copy.
Args:
row: DataFrame row with clip metadata
save_dir: directory to save output clips
keep_audio: if True, copy audio streams; if False, drop audio
Returns:
(row_values_list, valid, error_message)
"""
video_path = row["video_path"]
sample_id = row["id"]
save_path = os.path.join(save_dir, f"{sample_id}.mp4")
# Already exists -> treat as success
if os.path.exists(save_path) and os.path.getsize(save_path) > 0:
row = row.copy()
row["video_path"] = save_path
return row.values.tolist(), True, ""
if not os.path.exists(video_path):
error_msg = f"Source video not found: {video_path} (id={sample_id})"
return row.values.tolist(), False, error_msg
# Parse timestamps
try:
fps = row["fps"]
seg_start = FrameTimecode(timecode=row["timestamp_start"], fps=fps)
seg_end = FrameTimecode(timecode=row["timestamp_end"], fps=fps)
start_sec = float(seg_start.get_seconds())
end_sec = float(seg_end.get_seconds())
duration = end_sec - start_sec
if duration <= 0:
error_msg = f"Non-positive duration for id={sample_id}: {duration}"
return row.values.tolist(), False, error_msg
except Exception as e:
error_msg = f"Invalid timestamp for id={sample_id}: {e}"
return row.values.tolist(), False, error_msg
try:
# Build stream mapping and audio arguments based on keep_audio flag.
# '0:a?' uses '?' so FFmpeg silently skips if no audio track exists.
if keep_audio:
map_args = ["-map", "0:v:0", "-map", "0:a?"]
audio_args = ["-c:a", "copy"]
else:
map_args = ["-map", "0:v:0"]
audio_args = ["-an"]
# Fast seek + stream copy; explicitly specify video codec to avoid ambiguity.
cmd = [
FFMPEG_PATH,
"-nostdin",
"-y",
"-ss",
str(start_sec),
"-t",
str(duration),
"-i",
video_path,
*map_args,
*audio_args,
"-c:v",
"copy",
"-avoid_negative_ts",
"make_zero",
save_path,
]
subprocess.run(
cmd,
check=True,
stdout=subprocess.DEVNULL,
stderr=subprocess.PIPE,
)
# Verify output exists and non-empty
if not os.path.exists(save_path) or os.path.getsize(save_path) == 0:
if os.path.exists(save_path):
os.remove(save_path)
error_msg = f"FFmpeg produced empty/missing output for id={sample_id}"
return row.values.tolist(), False, error_msg
row = row.copy()
row["video_path"] = save_path
return row.values.tolist(), True, ""
except subprocess.CalledProcessError as e:
stderr_text = e.stderr.decode("utf-8", errors="ignore") if e.stderr else str(e)
error_msg = f"FFmpeg failed for id={sample_id}: {stderr_text}"
if os.path.exists(save_path):
os.remove(save_path)
return row.values.tolist(), False, error_msg
except Exception as e:
error_msg = f"Unexpected error for id={sample_id}: {e}"
if os.path.exists(save_path):
os.remove(save_path)
return row.values.tolist(), False, error_msg
def process_video_group(group_df, save_dir, keep_audio=False):
"""
Process all clips from the same source video.
Args:
group_df: DataFrame containing rows from one source video_path
save_dir: output clip directory
keep_audio: passed through to process_single_row
Returns:
list of tuples: (index, row_values, valid, error_msg)
"""
results = []
# Sort by start timestamp to make access pattern a bit more sequential
if "timestamp_start" in group_df.columns:
group_df = group_df.sort_values(by="timestamp_start")
for index, row in group_df.iterrows():
row_values, valid, error_msg = process_single_row(
row, save_dir, keep_audio=keep_audio
)
results.append((index, row_values, valid, error_msg))
return results
def worker(task_queue, results_queue, video_save_dir, keep_audio=False):
"""
Worker that processes one video group at a time.
"""
while True:
try:
video_path, group_df = task_queue.get(timeout=1)
except queue.Empty:
break
try:
group_results = process_video_group(
group_df, video_save_dir, keep_audio=keep_audio
)
for item in group_results:
results_queue.put(item)
finally:
task_queue.task_done()
def parse_args():
parser = argparse.ArgumentParser(
description="Fast video cutting utility using FFmpeg stream copy"
)
parser.add_argument("--csv_path", type=str, required=True, help="Input CSV path")
parser.add_argument(
"--csv_save_path", type=str, required=True, help="Output CSV path"
)
parser.add_argument(
"--video_save_dir", type=str, required=True, help="Directory to save clips"
)
parser.add_argument(
"--num_workers",
type=int,
default=None,
help="Number of parallel workers (defaults to CPU count)",
)
parser.add_argument(
"--disable_parallel",
action="store_true",
help="Disable parallel processing",
)
parser.add_argument(
"--drop_invalid_timestamps",
action="store_true",
help="Drop invalid timestamp rows and save corrected CSV",
)
parser.add_argument(
"--keep_audio",
action="store_true",
help="Retain audio tracks in output clips (dropped by default)",
)
return parser.parse_args()
def main():
args = parse_args()
if not os.path.exists(args.csv_path):
print(f"csv file '{args.csv_path}' not found. Exit.")
return
os.makedirs(args.video_save_dir, exist_ok=True)
csv = pd.read_csv(args.csv_path)
if len(csv) == 0:
print("Input CSV is empty. Exit.")
return
required_cols = ["id", "video_path", "timestamp_start", "timestamp_end", "fps"]
missing_cols = [c for c in required_cols if c not in csv.columns]
if missing_cols:
raise ValueError(f"Missing required columns: {missing_cols}")
results = []
# Group by source video
grouped_items = list(csv.groupby("video_path", sort=False))
total_tasks = len(csv)
if args.disable_parallel:
success_cnt = 0
fail_cnt = 0
with tqdm(total=total_tasks, desc="Processing clips", dynamic_ncols=True) as pbar:
for video_path, group_df in grouped_items:
group_results = process_video_group(
group_df, args.video_save_dir, keep_audio=args.keep_audio
)
for item in group_results:
results.append(item)
_, _, valid, _ = item
if valid:
success_cnt += 1
else:
fail_cnt += 1
pbar.update(1)
pbar.set_postfix(success=success_cnt, fail=fail_cnt)
else:
manager = Manager()
task_queue = manager.Queue()
results_queue = manager.Queue()
for video_path, group_df in grouped_items:
task_queue.put((video_path, group_df))
num_workers = args.num_workers if args.num_workers else os.cpu_count()
num_workers = max(1, num_workers)
with concurrent.futures.ProcessPoolExecutor(max_workers=num_workers) as executor:
futures = []
for _ in range(num_workers):
futures.append(
executor.submit(
worker,
task_queue,
results_queue,
args.video_save_dir,
args.keep_audio, # Forward keep_audio flag to each worker
)
)
finished = 0
success_cnt = 0
fail_cnt = 0
with tqdm(total=total_tasks, desc="Processing clips", dynamic_ncols=True) as pbar:
while finished < total_tasks:
try:
item = results_queue.get(timeout=1)
except queue.Empty:
continue
results.append(item)
finished += 1
_, _, valid, _ = item
if valid:
success_cnt += 1
else:
fail_cnt += 1
pbar.update(1)
pbar.set_postfix(success=success_cnt, fail=fail_cnt)
for future in futures:
future.result()
# Sort back by original row index
results.sort(key=lambda x: x[0])
# Separate successful and failed
success_rows = []
failed_rows = []
failed_errors = []
for index, row_values, valid, error_msg in results:
if valid:
success_rows.append(row_values)
else:
failed_rows.append(row_values)
failed_errors.append(error_msg)
# Optional corrected timestamp CSV
if args.drop_invalid_timestamps:
valid_indices = [r[0] for r in results if r[2]]
filtered_csv = csv.iloc[valid_indices]
if args.csv_path.endswith("timestamp.csv"):
corrected_path = args.csv_path.replace("timestamp.csv", "correct_timestamp.csv")
else:
base, ext = os.path.splitext(args.csv_path)
corrected_path = f"{base}_corrected{ext}"
filtered_csv.to_csv(corrected_path, index=False)
print(f"Corrected timestamp file saved to '{corrected_path}'")
columns = csv.columns
# Save successful clips CSV
if success_rows:
success_df = pd.DataFrame(success_rows, columns=columns)
for col in ["timestamp_start", "timestamp_end", "frame_start", "frame_end"]:
if col in success_df.columns:
success_df = success_df.drop(columns=[col])
success_df.to_csv(args.csv_save_path, index=False)
print(f"Saved {len(success_df)} successful clip(s) to {args.csv_save_path}.")
else:
print("No successful clips were generated.")
# Save failed clips CSV
if failed_rows:
base, ext = os.path.splitext(args.csv_save_path)
failed_csv_path = f"{base}_failed{ext}"
failed_df = pd.DataFrame(failed_rows, columns=columns)
failed_df["error"] = failed_errors
failed_df.to_csv(failed_csv_path, index=False)
print(f"Saved {len(failed_df)} failed record(s) to {failed_csv_path}.")
if __name__ == "__main__":
main()
================================================
FILE: utils/download_SpatialVID.py
================================================
import argparse
from huggingface_hub import hf_hub_download, snapshot_download
def main():
# Setup command line arguments
parser = argparse.ArgumentParser(
description="Download SpatialVID dataset from Hugging Face Hub."
)
parser.add_argument(
"--repo_id",
type=str,
choices=["SpatialVID", "SpatialVID-HQ"],
required=True,
help="Dataset type to download (SpatialVID or SpatialVID-HQ)",
)
parser.add_argument(
"--type",
type=str,
choices=["videos", "annotations", "depths", "metadata", "all"],
required=True,
help="Type of data to download (videos, annotations, metadata, all)",
)
parser.add_argument(
"--group_id",
type=int,
help="Specific group ID to download (e.g., 'group_1'). If not provided, downloads all groups.",
default=None,
)
parser.add_argument(
"--output_dir",
type=str,
help="Local directory to save dataset",
default="./SpatialVID_data",
)
args = parser.parse_args()
repo_id = f"SpatialVID/{args.repo_id}"
# Download csv metadata
if args.type == "metadata":
hub_path = f"data/train/{args.repo_id.replace('-', '_')}_metadata.csv"
hf_hub_download(
repo_id=repo_id,
repo_type="dataset",
filename=hub_path,
local_dir=args.output_dir,
resume_download=True,
)
print(f"Downloaded file '{hub_path}' from {repo_id} to {args.output_dir}")
# Download specific group
elif args.group_id:
hub_path = f"{args.type}/group_{args.group_id:04d}.tar.gz"
hf_hub_download(
repo_id=repo_id,
repo_type="dataset",
filename=hub_path,
local_dir=args.output_dir,
resume_download=True,
)
print(f"Downloaded file '{hub_path}' from {repo_id} to {args.output_dir}")
    # Download entire type directory
    elif args.type == "all":
        snapshot_download(
            repo_id=repo_id,
            repo_type="dataset",
            local_dir=args.output_dir,
            resume_download=True,
        )
        print(f"Downloaded entire dataset from {repo_id} to {args.output_dir}")
    # Download all groups of one data type (assumes the '<type>/group_XXXX.tar.gz'
    # layout used by the single-group branch above)
    else:
        snapshot_download(
            repo_id=repo_id, repo_type="dataset", local_dir=args.output_dir,
            allow_patterns=[f"{args.type}/*"], resume_download=True,
        )
        print(f"Downloaded all '{args.type}' files from {repo_id} to {args.output_dir}")
if __name__ == "__main__":
main()
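# Illustrative invocations (flags as defined in main() above):
#   python utils/download_SpatialVID.py --repo_id SpatialVID-HQ --type metadata
#   python utils/download_SpatialVID.py --repo_id SpatialVID-HQ --type annotations --group_id 1
#   python utils/download_SpatialVID.py --repo_id SpatialVID --type all --output_dir ./SpatialVID_data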
================================================
FILE: utils/download_YouTube.py
================================================
"""
Utility script to download YouTube videos using yt-dlp with support for concurrency and sharding.
Adapted from https://huggingface.co/Ligeng-Zhu/panda70m-download
running script: python download_YouTube.py --csv="$csv_file"  # the CSV file must contain a 'url' column
if you want to download a specific youtube video, consider using:
- yt-dlp -F --list-formats https://www.youtube.com/watch\?v\=omP01s7RUSA # --proxy 127.0.0.1:xxxx --cookies cookies.txt
run 'ls -l /path/to/folder/*.json | wc -l' for counting the videos already downloaded
Customization Guide:
For customizing download settings (such as video format, cookie configurations like automatic Chrome cookie retrieval or custom cookie file usage),
refer to the official documentation at https://github.com/yt-dlp/yt-dlp-wiki.
"""
import sys, os, os.path as osp
import yt_dlp
import asyncio
from concurrent.futures import ProcessPoolExecutor
import fire
import pandas as pd
import json
import time
def ytb_download(url, json_info, output_dir="ytb_videos/"):
"""
Download a specified YouTube video using yt-dlp and save related metadata.
"""
os.makedirs(output_dir, exist_ok=True)
uid = url.split("?v=")[-1]
yt_opts = {
"format": "bv[height=720][ext=mp4]"
# "format": "bv[height=720]", # Download the best quality available
# "format": "bv[height=720][ext=mp4][vcodec!^=av]",
# "proxy": "127.0.0.1:xxxx",
"outtmpl": osp.join(output_dir, f"{uid}.%(ext)s"), # Output template
# "cookiesfrombrowser": "chrome", # Use Chrome's cookies automatically
# "cookiefile": "cookies.txt", # Use a custom cookies file
# "postprocessors": [
# {
# "key": "FFmpegVideoConvertor",
# "preferedformat": "mp4", # Convert video to mp4 format (slow)
# }
# ],
# "verbose" : True,
"abort-on-error": True, # Abort downloading when an error occurs
"retries": 60, # Number of retries
"ffmpeg_location": "/usr/bin/ffmpeg", # Path to ffmpeg
"quiet": True, # Suppress output
"sleep-requested": 5, # Sleep for 1.25 seconds between requests
"min-sleep-interval": 60,
"max-sleep-interval": 90,
}
video_path_mp4 = osp.join(output_dir, f"{uid}.mp4")
video_path_webm = osp.join(output_dir, f"{uid}.webm")
meta_path = osp.join(output_dir, f"{uid}.json")
if (osp.exists(video_path_mp4) or osp.exists(video_path_webm)) and osp.exists(
meta_path
):
print(f"\033[91m{uid} already labeled.\033[0m")
return 0
try:
with yt_dlp.YoutubeDL(yt_opts) as ydl:
ydl.download([url])
with open(meta_path, "w") as fp:
json.dump(json_info, fp, indent=2)
return 0
# exception logs
except Exception as e:
print(f"\033[91mError downloading {url}: {e}\033[0m")
err_map = {
"Requested format is not available": "z0322_dld_format_noavailable.log",
"removed by": "z0322_dld_removed_by.log",
"Private video": "z0322_dld_private_video.log",
}
for key, log_file in err_map.items():
if key in str(e):
with open(osp.join(output_dir, f"{log_file}"), "a") as f:
f.write(f"{url}\n")
break
else:
with open(osp.join(output_dir, f"z0322_dld_othererr.log"), "a") as f:
f.write(f"{url}, {str(e)}\n")
return -1
async def main(csv_path, output_dir, max_workers=10, shards=0, total=-1, limit=False):
"""
Batch download YouTube videos specified in a CSV file, supporting sharding and concurrency.
"""
PPE = ProcessPoolExecutor(max_workers=max_workers)
loop = asyncio.get_event_loop()
df = pd.read_csv(csv_path)
csv_path = os.path.basename(csv_path)
output_dir = f'{output_dir}/{csv_path.split(".")[0]}'
data_list = list(df.iterrows())
if total > 0:
chunk = len(data_list) // total
begin_idx = shards * chunk
end_idx = (shards + 1) * chunk if shards < total - 1 else len(data_list)
data_list = data_list[begin_idx:end_idx]
print(f"download total {len(data_list)} videos")
tasks = []
for idx, (index, row) in enumerate(data_list):
video_url = row["url"]
# json_info = {"caption": row["caption"]}
json_info = {"caption": ''} # for file checking.
tasks.append(
loop.run_in_executor(PPE, ytb_download, video_url, json_info, output_dir)
)
if limit and idx >= 20:
break
res = await asyncio.gather(*tasks)
print(f"[{sum(res)} / {len(res)}]")
def entry(
csv="meta_data_sample_500.csv",
output_dir="path/to/output",
shards=0,
total=-1,
limit=False,
max_workers=2,
):
"""
Command line entry function, supports fire invocation.
"""
print(csv, output_dir, shards, total, max_workers)
start_time = time.time()
print(
f"\033[92mStarting execution at {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(start_time))}\033[0m"
)
asyncio.run(
main(
csv,
output_dir,
max_workers=max_workers,
shards=shards,
total=total,
limit=limit,
)
)
end_time = time.time()
print(
f"\033[92mFinished execution at {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(end_time))}\033[0m"
)
print(f"\033[92mTotal execution time: {end_time - start_time:.2f} seconds\033[0m")
def add_download(csv_path):
"""
    Download missing videos listed in the 'YouTube id' column of the CSV file.
"""
data = pd.read_csv(csv_path)
unique_ids = data['YouTube id'].unique()
for uid in unique_ids:
video_url = f"https://www.youtube.com/watch?v={uid}"
ytb_download(video_url, json_info={}, output_dir="videos/")
print(f"Downloaded {video_url}")
if __name__ == "__main__":
# Call entry function via command line arguments
fire.Fire(entry)
# for supplement download: add_download(csv_path='xxx.csv')
================================================
FILE: utils/evaluation.py
================================================
"""
Camera trajectory evaluation utility with anomaly detection and motion analysis.
"""
import os
import argparse
import pandas as pd
import numpy as np
import torch
import concurrent.futures
import multiprocessing as mp
from multiprocessing import Manager
import queue
from tqdm import tqdm
from scipy.signal import find_peaks
from scipy.ndimage import gaussian_filter
# Import mask utility functions
from expand_npz import expand
def load_file(cam_pos_file, mask_file, device):
"""Load camera parameters and dynamic masks from files"""
try:
# Load camera parameters and split into position and rotation
params = torch.from_numpy(np.load(cam_pos_file)).float().to(device)
cam_pos = params[:, :3] # Position coordinates
cam_rotate = params[:, 3:] # Rotation quaternions
time_steps = params.shape[0]
# Load and expand dynamic masks
masks = torch.from_numpy(expand(np.load(mask_file))).to(device)
except FileNotFoundError:
print(f"Error: File not found - {cam_pos_file}")
exit()
except Exception as e:
print(f"Error processing {cam_pos_file}: {e}")
exit()
return cam_pos, cam_rotate, time_steps, masks
def anomaly_detection(cam_pos, time_steps, threshold, device):
"""Detect trajectory anomalies using linear prediction with acceleration"""
if time_steps < 4:
return True # Not enough data
preds = torch.zeros((time_steps, 3), dtype=torch.float32, device=device)
error_count = 0
# Linear prediction with acceleration
for t in range(0, time_steps - 3):
# Calculate velocity and acceleration
v1 = cam_pos[t + 2] - cam_pos[t + 1]
v2 = cam_pos[t + 1] - cam_pos[t]
acceleration = v1 - v2
# Predict next position
preds[t + 3] = cam_pos[t + 2] + v1 + 0.5 * acceleration
# Check prediction error
error = torch.sqrt(torch.sum((preds[t + 3] - cam_pos[t + 3]) ** 2))
if error > 0.03:
error_count += 1
if error_count >= threshold:
return True
else:
error_count = 0
return False
def move_distance(cam_pos, time_steps, device):
"""Calculate total movement distance and classify into levels"""
total_distance = torch.tensor(0., dtype=torch.float32, device=device)
# Distance thresholds for classification
thresholds = [0.08, 0.28, 0.92, 2.41]
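    # The movement level reported below is the number of thresholds met (0-4).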
# Calculate cumulative distance
for i in range(0, time_steps - 1):
total_distance += torch.norm(cam_pos[i + 1] - cam_pos[i])
# Determine movement level
distance_val = total_distance.item()
level = sum(1 for threshold in thresholds if distance_val >= threshold)
return distance_val, level
def quaternion_multiply(q1, q2):
"""Multiply two quaternions"""
# Extract components (q in [x, y, z, w] format)
w1, x1, y1, z1 = q1[3], q1[0], q1[1], q1[2]
w2, x2, y2, z2 = q2[3], q2[0], q2[1], q2[2]
# Quaternion multiplication
matrix = torch.tensor([
[w1, -z1, y1, x1],
[z1, w1, -x1, y1],
[-y1, x1, w1, z1],
[-x1, -y1, -z1, w1]
], dtype=q1.dtype, device=q1.device)
vector = torch.tensor([x2, y2, z2, w2], dtype=q2.dtype, device=q2.device)
result = torch.matmul(matrix, vector)
return result
def rotation_angle(cam_rotate, time_steps, device):
"""Calculate total rotation angle between consecutive frames"""
total_radians = torch.tensor(0.0, device=device)
for i in range(0, time_steps - 1):
q1 = cam_rotate[i]
q2 = cam_rotate[i + 1]
# Calculate relative rotation
q1_inverse = torch.stack([-q1[0], -q1[1], -q1[2], q1[3]], dim=0)
q_relative = quaternion_multiply(q2, q1_inverse)
w = torch.clamp(q_relative[3], -1.0, 1.0)
# Convert to angle
rotation_angle_rad = 2 * torch.arccos(w)
total_radians += rotation_angle_rad
return total_radians.item()
def trajectory_turns(cam_pos, time_steps, device, threshold=0.45):
"""Detect significant turns in camera trajectory"""
    if time_steps < 3:
        return 0
angles = []
# Calculate angles between trajectory segments
for t in range(1, time_steps - 1):
v1 = cam_pos[t] - cam_pos[0]
v2 = cam_pos[time_steps - 1] - cam_pos[t]
# Avoid division by zero
v1_norm = torch.norm(v1)
v2_norm = torch.norm(v2)
if v1_norm < 1e-8 or v2_norm < 1e-8:
continue
# Calculate angle between vectors
cos_theta = torch.dot(v1, v2) / (v1_norm * v2_norm)
cos_theta = torch.clamp(cos_theta, -1.0, 1.0)
angle = torch.arccos(cos_theta)
angles.append(angle.item())
    # Smooth and find peaks (return 0 turns if no valid segment angles were computed)
    if len(angles) == 0:
        return 0
    angles = gaussian_filter(angles, sigma=5)
    peaks, _ = find_peaks(angles, height=threshold, distance=5)
peaks_values = [angles[i] for i in peaks]
# Include maximum angle if significant
max_angle = max(angles)
if max_angle > threshold and max_angle not in peaks_values:
peaks_values.append(max_angle)
return len(peaks_values)
def dynamic_ratio(masks):
"""Calculate ratio of dynamic pixels in video frames"""
# Downsample for efficiency
masks = masks[::5, :, :]
dynamic_pixels = torch.sum(masks)
total_pixels = masks.shape[1] * masks.shape[2] * masks.shape[0]
return (dynamic_pixels / total_pixels).item()
def process_single_row(row, index, args, device):
"""Process a single video row to extract trajectory metrics"""
video_id = row['id']
rec_path = os.path.join(args.dir_path, video_id, "reconstructions")
cam_pos_file = os.path.join(rec_path, "poses.npy")
mask_file = os.path.join(rec_path, "dyn_masks.npz")
# Check file existence
if not os.path.exists(cam_pos_file) or not os.path.exists(mask_file):
print(f"File not found: {cam_pos_file} or {mask_file}")
return False, False, -1, -1, -1, -1, -1
# Load and process data
cam_pos, cam_rotate, time_steps, masks = load_file(cam_pos_file, mask_file, device)
# Calculate metrics
anomaly = anomaly_detection(cam_pos, time_steps, args.anomaly_threshold, device)
move_dist, dist_level = move_distance(cam_pos, time_steps, device)
rot_angle = rotation_angle(cam_rotate, time_steps, device)
traj_turns = trajectory_turns(cam_pos, time_steps, device)
dyn_ratio = dynamic_ratio(masks)
return True, anomaly, move_dist, dist_level, rot_angle, traj_turns, dyn_ratio
def worker(task_queue, result_queue, args, worker_id):
"""Worker function for parallel processing"""
# Assign GPU based on worker ID
device = torch.device(
f"cuda:{worker_id % args.gpu_num}"
if torch.cuda.is_available() else "cpu"
)
while True:
try:
index, row = task_queue.get(timeout=1)
except queue.Empty:
break
result = process_single_row(row, index, args, device)
result_queue.put((index, result))
task_queue.task_done()
def parse_args():
"""Parse command line arguments"""
parser = argparse.ArgumentParser(description="Camera Trajectory Evaluation")
parser.add_argument("--csv_path", type=str, help="Path to input CSV file")
parser.add_argument("--dir_path", type=str, default="./outputs", help="Base directory with reconstruction data")
parser.add_argument("--output_path", type=str, default="./outputs/evaluation_results.csv", help="Output CSV path")
parser.add_argument("--anomaly_threshold", type=int, default=2, help="Anomaly detection threshold")
parser.add_argument('--gpu_num', type=int, default=1, help='Number of GPUs to use')
parser.add_argument("--num_workers", type=int, default=4, help="Number of parallel workers")
parser.add_argument("--disable_parallel", action="store_true", help="Disable parallel processing")
return parser.parse_args()
if __name__ == "__main__":
# Setup multiprocessing
mp.set_start_method('spawn')
args = parse_args()
# Load input data
df = pd.read_csv(args.csv_path)
results = []
if args.disable_parallel:
# Sequential processing
for index, row in tqdm(df.iterrows(), total=len(df), desc="Processing videos"):
device = torch.device(f"cuda:0" if torch.cuda.is_available() else "cpu")
result = process_single_row(row, index, args, device)
results.append((index, result))
else:
# Parallel processing
manager = Manager()
task_queue = manager.Queue()
# Add tasks to queue
for index, row in df.iterrows():
task_queue.put((index, row))
result_queue = manager.Queue()
# Run workers
with concurrent.futures.ProcessPoolExecutor(max_workers=args.num_workers) as executor:
futures = []
for worker_id in range(args.num_workers):
futures.append(executor.submit(worker, task_queue, result_queue, args, worker_id))
processed = 0
total_tasks = len(df)
with tqdm(total=total_tasks, desc="Processing videos") as pbar:
while processed < total_tasks:
try:
index, result = result_queue.get(timeout=1)
results.append((index, result))
processed += 1
pbar.update(1)
except queue.Empty:
if all(f.done() for f in futures) and result_queue.empty():
break
for future in futures:
future.result()
# Collect results
while not result_queue.empty():
index, result = result_queue.get()
results.append((index, result))
# Sort and save results
results.sort(key=lambda x: x[0])
df['success'] = [result[1][0] for result in results]
df['anomaly'] = [result[1][1] for result in results]
df['moveDist'] = [result[1][2] for result in results]
df['distLevel'] = [result[1][3] for result in results]
df['rotAngle'] = [result[1][4] for result in results]
df['trajTurns'] = [result[1][5] for result in results]
df['dynamicRatio'] = [result[1][6] for result in results]
df.to_csv(args.output_path, index=False)
print(f"Results saved to {args.output_path}")
================================================
FILE: utils/expand_npz.py
================================================
"""
Mask utility functions for processing sparse matrix data.
"""
import numpy as np
from scipy.sparse import csr_matrix
def expand(loaded_data):
"""
Reconstruct 3D mask from sparse matrix data.
Args:
loaded_data (dict): Dictionary containing sparse matrix data with keys:
- 'shape': Original matrix dimensions
- 'f_{i}_data': Sparse matrix data for frame i
- 'f_{i}_indices': Sparse matrix indices for frame i
- 'f_{i}_indptr': Sparse matrix index pointers for frame i
Returns:
np.ndarray: 3D array with shape (frames, height, width)
"""
reconstructed_sparse_matrices = []
num_frames = (len(loaded_data) - 1) // 3 # Calculate number of frames
matrix_shape = loaded_data['shape'] # Get original matrix dimensions
# Reconstruct sparse matrix for each frame
for i in range(num_frames):
data = loaded_data[f'f_{i}_data']
indices = loaded_data[f'f_{i}_indices']
indptr = loaded_data[f'f_{i}_indptr']
reconstructed_matrix = csr_matrix((data, indices, indptr), shape=matrix_shape)
reconstructed_sparse_matrices.append(reconstructed_matrix)
# Stack all frames into a 3D array (frames, height, width)
reconstructed_mask_3d = np.stack([m.toarray() for m in reconstructed_sparse_matrices], axis=0)
return reconstructed_mask_3d
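# A minimal usage sketch (the path below is a placeholder assumption): dynamic masks
# are stored per clip as a dyn_masks.npz of per-frame CSR matrices, and expand()
# reassembles them into a dense (frames, height, width) array.
#
#   loaded = np.load("outputs/<clip_id>/reconstructions/dyn_masks.npz")
#   masks = expand(loaded)
#   print(masks.shape)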
================================================
FILE: utils/extract_frames.py
================================================
"""
Video frame extraction utility with parallel processing support.
"""
import os
import sys
import cv2
import av
import glob
import argparse
import pandas as pd
import queue
import concurrent.futures
from multiprocessing import Manager
from tqdm import tqdm
def extract_frames_opencv(
video_path, output_dir, interval, frame_start, num_frames, target_size=None
):
"""Extract frames from video at specified intervals"""
# Create output directory
if not os.path.exists(output_dir):
os.makedirs(output_dir)
# Open video file
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        print(f"Error: Could not open video file {video_path}")
        sys.exit(1)
    cap.set(cv2.CAP_PROP_POS_FRAMES, frame_start)
# Extract frames
for frame in range(num_frames):
ret, image = cap.read()
if not ret:
break
# Save frame at specified intervals
if frame % interval == 0:
frame_filename = os.path.join(output_dir, f"frame_{frame:06d}.jpg")
if target_size is not None:
h, w = image.shape[:2]
# Adaptively adjust target size based on video orientation
# For portrait videos (height > width), swap width and height of target size
if h > w: # Portrait video
target_w, target_h = target_size[1], target_size[0]
else: # Landscape video
target_w, target_h = target_size[0], target_size[1]
image = cv2.resize(image, (target_w, target_h))
cv2.imwrite(frame_filename, image)
cap.release()
def extract_frames_av(
video_path, output_dir, interval, frame_start, num_frames, target_size=None
):
"""
Extract frames from video at specified intervals using PyAV backend.
"""
# Create output directory
if not os.path.exists(output_dir):
os.makedirs(output_dir)
try:
# Open video file
container = av.open(video_path)
stream = container.streams.video[0]
stream.thread_type = 'AUTO'
except Exception as e:
print(f"Error: Could not open video file {video_path}. Reason: {e}")
return
# Get video properties
fps = float(stream.average_rate)
time_base = stream.time_base
target_sec = frame_start / fps
# Set a small tolerance (e.g., half a frame time) to prevent frame loss due to floating-point precision issues
epsilon = 0.5 / fps
# Seek to the target start time
if frame_start > 0:
target_pts = int(target_sec / time_base)
container.seek(target_pts, stream=stream, backward=True)
count = 0
for packet in container.demux(stream):
try:
for frame in packet.decode():
if frame.pts is None:
continue
current_sec = frame.pts * time_base
if current_sec < (target_sec - epsilon):
continue
if count >= num_frames:
break
if count % interval == 0:
image = frame.to_ndarray(format='bgr24')
frame_filename = os.path.join(output_dir, f"frame_{count:06d}.jpg")
if target_size is not None:
if isinstance(target_size, str):
w, h = map(int, target_size.split('*'))
target_size = (w, h)
h, w = image.shape[:2]
if h > w:
target_w, target_h = target_size[1], target_size[0]
else:
target_w, target_h = target_size[0], target_size[1]
image = cv2.resize(image, (target_w, target_h))
cv2.imwrite(frame_filename, image)
count += 1
if count >= num_frames:
break
except av.error.InvalidDataError:
            continue  # Skip corrupted packets
container.close()
def _calc_expected_frames(num_frames, interval):
"""Calculate the expected number of output frames based on total frames and interval."""
if interval <= 0:
return num_frames
# Frames at indices 0, interval, 2*interval, ... that are < num_frames
return (num_frames - 1) // interval + 1
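# Example: with num_frames=49 and interval=16, frames 0, 16, 32 and 48 are written,
# so _calc_expected_frames(49, 16) == (49 - 1) // 16 + 1 == 4.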
def _verify_frames(img_dir, expected_frames):
"""Check if img_dir has enough valid (non-empty) frame files.
Returns True if the directory exists and contains at least `expected_frames`
non-zero-byte frame_*.jpg files.
"""
if not os.path.isdir(img_dir):
return False
frame_files = glob.glob(os.path.join(img_dir, "frame_*.jpg"))
if len(frame_files) < expected_frames:
return False
if any(os.path.getsize(f) == 0 for f in frame_files):
return False
return True
def process_single_row(row, row_index, args):
"""Process a single video row to extract frames.
Returns:
True if processing succeeded or was skipped (already done),
False if an error occurred.
"""
video_path = row["video_path"]
frame_start = row.get("frame_start", 0)
num_frames = row["num_frames"]
output_dir = os.path.join(args.output_dir, row["id"])
img_dir = os.path.join(output_dir, "img")
# Calculate frame extraction interval
if args.interval is None:
interval = row["num_frames"] // 3 # Extract 3 frames by default
elif args.interval == 0:
interval = 1 # Extract every frame
else:
interval = int(args.interval * row["fps"])
expected_frames = _calc_expected_frames(num_frames, interval)
# --- Skip logic: already has enough valid frames ---
if _verify_frames(img_dir, expected_frames):
return True
if not os.path.exists(output_dir):
os.makedirs(output_dir)
try:
if args.backend == "opencv":
extract_frames_opencv(
video_path, img_dir, interval, frame_start, num_frames, args.target_size
)
elif args.backend == "av":
extract_frames_av(
video_path, img_dir, interval, frame_start, num_frames, args.target_size
)
# Post-extraction verification
if not _verify_frames(img_dir, expected_frames):
actual_count = len(glob.glob(os.path.join(img_dir, "frame_*.jpg")))
print(
f"[Verify FAIL] {row['id']}: expected {expected_frames} frames, "
f"got {actual_count} (or contains empty files)."
)
return False
return True
except Exception as e:
print(f"Error: Could not extract frames from video {video_path}. Reason: {e}")
return False
def worker(task_queue, progress_queue, failed_indices, args):
"""Worker function for parallel frame extraction"""
while True:
try:
index, row = task_queue.get(timeout=1)
except queue.Empty:
break
success = process_single_row(row, index, args)
if not success:
failed_indices.append(index)
progress_queue.put(index)
task_queue.task_done()
def parse_args():
"""Parse command line arguments"""
parser = argparse.ArgumentParser(description="Extract frames from video files")
parser.add_argument(
"--csv_path", type=str, help="Path to CSV file with video csvdata"
)
parser.add_argument(
"--output_dir",
type=str,
default="extract_frames",
help="Output directory for extracted frames",
)
parser.add_argument(
"--interval",
type=float,
default=None,
help="Frame extraction interval in seconds (set to 0 to extract every frame)",
)
parser.add_argument(
"--target_size",
type=str,
default=None,
help="Resize frames to size (width*height). For portrait videos (h>w), dimensions will be automatically swapped to (height*width) to maintain correct orientation.",
)
parser.add_argument(
"--num_workers", type=int, default=None, help="Number of parallel workers"
)
parser.add_argument(
"--backend",
type=str,
default="opencv",
choices=["opencv", "av"],
help="Backend for video reading",
)
parser.add_argument(
"--disable_parallel", action="store_true", help="Disable parallel processing"
)
return parser.parse_args()
def main():
"""Main function to process frame extraction"""
args = parse_args()
# Parse target size if provided
if args.target_size is not None:
args.target_size = tuple(map(int, args.target_size.split("*")))
# Create output directory
if not os.path.exists(args.output_dir):
os.makedirs(args.output_dir)
    # Load video metadata
csv = pd.read_csv(args.csv_path)
failed_indices = []
if args.disable_parallel:
# Sequential processing
for index, row in tqdm(
csv.iterrows(), total=len(csv), desc="Processing videos"
):
success = process_single_row(row, index, args)
if not success:
failed_indices.append(index)
else:
# Parallel processing
num_workers = args.num_workers if args.num_workers else os.cpu_count() or 1
manager = Manager()
task_queue = manager.Queue()
progress_queue = manager.Queue()
shared_failed_indices = manager.list()
# Add tasks to queue
for index, row in csv.iterrows():
task_queue.put((index, row))
# Execute workers
with concurrent.futures.ProcessPoolExecutor(
max_workers=num_workers
) as executor:
futures = []
for _ in range(num_workers):
future = executor.submit(worker, task_queue, progress_queue, shared_failed_indices, args)
futures.append(future)
processed = 0
total_tasks = len(csv)
with tqdm(total=total_tasks, desc="Processing videos") as pbar:
while processed < total_tasks:
try:
progress_queue.get(timeout=1)
processed += 1
pbar.update(1)
except queue.Empty:
if all(f.done() for f in futures) and progress_queue.empty():
break
for future in futures:
future.result()
failed_indices = list(shared_failed_indices)
# Save failed rows to a separate CSV; keep only successful rows in the original CSV
if failed_indices:
failed_csv = csv.loc[failed_indices]
base, ext = os.path.splitext(args.csv_path)
failed_csv_path = f"{base}_failed{ext}"
failed_csv.to_csv(failed_csv_path, index=False)
csv = csv.drop(index=failed_indices)
csv.to_csv(args.csv_path, index=False)
print(f"\n{len(failed_indices)} video(s) failed. Saved to: {failed_csv_path}")
print(f"Original CSV updated. Remaining rows: {len(csv)}")
else:
print("\nAll videos processed successfully.")
if __name__ == "__main__":
main()
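# Example invocation (a sketch; the CSV path and target size are placeholder assumptions):
#   python utils/extract_frames.py --csv_path clips_info.csv --output_dir ./outputs \
#       --interval 1 --target_size "1280*720" --backend av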
================================================
FILE: utils/filter.py
================================================
"""
Dataset filtering utility for video metadata with various quality metrics.
"""
import argparse
import os
import random
from glob import glob
import numpy as np
import pandas as pd
def main(args):
"""Apply filtering criteria to dataset"""
# Load data
data = pd.read_csv(args.csv_path)
# Apply filters based on various metrics
if args.frames_min is not None:
assert "num_frames" in data.columns
data = data[data["num_frames"] >= args.frames_min]
if args.frames_max is not None:
assert "num_frames" in data.columns
data = data[data["num_frames"] <= args.frames_max]
if args.fps_max is not None:
assert "fps" in data.columns
data = data[(data["fps"] <= args.fps_max) | np.isnan(data["fps"])]
if args.fps_min is not None:
assert "fps" in data.columns
data = data[(data["fps"] >= args.fps_min) | np.isnan(data["fps"])]
if args.resolution_max is not None:
if "resolution" not in data.columns:
height = data["height"]
width = data["width"]
data["resolution"] = height * width
data = data[data["resolution"] <= args.resolution_max]
if args.aes_min is not None:
assert "aesthetic score" in data.columns
data = data[data["aesthetic score"] >= args.aes_min]
if args.ocr_max is not None:
assert "ocr score" in data.columns
data = data[data["ocr score"] <= args.ocr_max]
if args.ocr_min is not None:
assert "ocr score" in data.columns
data = data[data["ocr score"] >= args.ocr_min]
if args.lum_min is not None:
assert "luminance mean" in data.columns
data = data[data["luminance mean"] >= args.lum_min]
if args.lum_max is not None:
assert "luminance mean" in data.columns
data = data[data["luminance mean"] <= args.lum_max]
if args.motion_min is not None:
assert "motion score" in data.columns
data = data[data["motion score"] >= args.motion_min]
if args.motion_max is not None:
assert "motion score" in data.columns
data = data[data["motion score"] <= args.motion_max]
# Save filtered data
data.to_csv(args.csv_save_path, index=False)
print(f"Saved {len(data)} samples to {args.csv_save_path}.")
def parse_args():
"""Parse command line arguments for dataset filtering"""
parser = argparse.ArgumentParser(
description="Filter video dataset by quality metrics"
)
parser.add_argument(
"--csv_path", type=str, required=True, help="Path to input CSV file"
)
parser.add_argument(
"--csv_save_path", type=str, default=None, help="Path to save output CSV file"
)
parser.add_argument("--seed", type=int, default=42, help="Random seed")
# Video property filters
parser.add_argument(
"--frames_min", type=int, default=None, help="Minimum number of frames"
)
parser.add_argument(
"--frames_max", type=int, default=None, help="Maximum number of frames"
)
parser.add_argument(
"--resolution_max", type=int, default=None, help="Maximum resolution"
)
parser.add_argument("--fps_max", type=float, default=None, help="Maximum FPS")
parser.add_argument("--fps_min", type=float, default=None, help="Minimum FPS")
# Quality metric filters
parser.add_argument(
"--aes_min", type=float, default=None, help="Minimum aesthetic score"
)
parser.add_argument(
"--flow_min", type=float, default=None, help="Minimum optical flow score"
)
parser.add_argument(
"--flow_max", type=float, default=None, help="Maximum optical flow score"
)
parser.add_argument("--ocr_max", type=float, default=None, help="Maximum OCR score")
parser.add_argument("--ocr_min", type=float, default=None, help="Minimum OCR score")
parser.add_argument(
"--lum_min", type=float, default=None, help="Minimum luminance score"
)
parser.add_argument(
"--lum_max", type=float, default=None, help="Maximum luminance score"
)
parser.add_argument(
"--blur_max", type=float, default=None, help="Maximum blur score"
)
parser.add_argument(
"--motion_min", type=float, default=None, help="Minimum motion score"
)
parser.add_argument(
"--motion_max", type=float, default=None, help="Maximum motion score"
)
return parser.parse_args()
if __name__ == "__main__":
args = parse_args()
# Set random seeds for reproducibility
if args.seed is not None:
random.seed(args.seed)
np.random.seed(args.seed)
main(args)
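# Example invocation (a sketch; the threshold values below are illustrative assumptions,
# not the thresholds used to build SpatialVID):
#   python utils/filter.py --csv_path meta.csv --csv_save_path meta_filtered.csv \
#       --frames_min 81 --aes_min 4.5 --ocr_max 0.3 --motion_min 0.2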
================================================
FILE: utils/get_clip.py
================================================
"""
Video clip information extraction utility with timestamp parsing.
"""
import argparse
import os
import queue
import concurrent.futures
from functools import partial
import pandas as pd
from scenedetect import FrameTimecode
import re
from tqdm import tqdm
def process_single_row(row, args):
"""Process a single video row to extract clip information"""
video_path = row["video_path"]
new_rows = []
try:
if "timestamp" in row:
timestamp_str = row["timestamp"]
# Parse timestamps using regex
timestamp_pattern = (
r"\('(\d{2}:\d{2}:\d{2}\.\d+)', '(\d{2}:\d{2}:\d{2}\.\d+)'\)"
)
matches = re.findall(timestamp_pattern, timestamp_str)
scene_list = [
(FrameTimecode(s, fps=row["fps"]), FrameTimecode(t, fps=row["fps"]))
for s, t in matches
]
else:
scene_list = [None]
if args.drop_invalid_timestamps:
return new_rows, True
    except Exception as e:
        print(f"Failed to parse timestamps for {video_path}: {e}")
        if args.drop_invalid_timestamps:
            return new_rows, False
        scene_list = [None]
height = row["height"]
width = row["width"]
fps = row["fps"]
# Extract clip information for each scene
for idx, scene in enumerate(scene_list):
if scene is not None:
s, t = scene # FrameTimecode objects
fname = os.path.basename(video_path)
fname_wo_ext = os.path.splitext(fname)[0]
# Calculate clip metrics
num_frames = t.frame_num - s.frame_num
aspect_ratio = width / height if height != 0 else 0
resolution = f"{width}x{height}"
timestamp_start = s.get_timecode()
timestamp_end = t.get_timecode()
frame_start = s.frame_num
frame_end = t.frame_num
id_ori = row["id"] if "id" in row else ""
id = f"{fname_wo_ext}_{idx}"
new_rows.append(
[
video_path,
id,
num_frames,
height,
width,
aspect_ratio,
fps,
resolution,
timestamp_start,
timestamp_end,
frame_start,
frame_end,
id_ori,
]
)
return (new_rows, True)
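# Note: the "timestamp" column is expected to hold a stringified list of
# ('HH:MM:SS.sss', 'HH:MM:SS.sss') pairs, e.g. (illustrative values)
#   "[('00:00:00.000', '00:00:04.833'), ('00:00:04.833', '00:00:09.666')]"
# which the regex above turns into (start, end) FrameTimecode pairs per scene.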
def worker(task_queue, results_queue, args):
"""Worker function for parallel processing"""
while True:
try:
index, row = task_queue.get(timeout=1)
except queue.Empty:
break
result = process_single_row(row, args)
results_queue.put((index, result))
task_queue.task_done()
def parse_args():
"""Parse command line arguments"""
parser = argparse.ArgumentParser(
description="Extract video clip information from csvdata"
)
parser.add_argument("--csv_path", type=str, help="Path to the input CSV file")
parser.add_argument(
"--csv_save_dir",
type=str,
required=True,
help="Directory to save output CSV file",
)
parser.add_argument(
"--num_workers", type=int, default=None, help="Number of parallel workers"
)
parser.add_argument(
"--disable_parallel", action="store_true", help="Disable parallel processing"
)
parser.add_argument(
"--drop_invalid_timestamps",
action="store_true",
help="Drop rows with invalid timestamps",
)
args = parser.parse_args()
return args
def main():
args = parse_args()
csv_path = args.csv_path
if not os.path.exists(csv_path):
print(f"csv file '{csv_path}' not found. Exit.")
return
os.makedirs(args.csv_save_dir, exist_ok=True)
    # Load CSV metadata
csv = pd.read_csv(args.csv_path)
# Setup multiprocessing
from multiprocessing import Manager
manager = Manager()
task_queue = manager.Queue()
results_queue = manager.Queue()
for index, row in csv.iterrows():
task_queue.put((index, row))
if args.disable_parallel:
# Sequential processing
results = []
for index, row in tqdm(
csv.iterrows(), total=len(csv), desc="Processing rows"
):
result = process_single_row(row, args)
results.append((index, result))
else:
# Parallel processing
num_workers = args.num_workers if args.num_workers else os.cpu_count() or 1
with concurrent.futures.ProcessPoolExecutor(
max_workers=num_workers
) as executor:
futures = []
for _ in range(num_workers):
future = executor.submit(worker, task_queue, results_queue, args)
futures.append(future)
# Per-row progress is more informative than per-worker completion.
results = []
processed = 0
total_tasks = len(csv)
with tqdm(total=total_tasks, desc="Processing rows") as pbar:
while processed < total_tasks:
try:
results.append(results_queue.get(timeout=1))
processed += 1
pbar.update(1)
except queue.Empty:
if all(f.done() for f in futures) and results_queue.empty():
break
for future in futures:
future.result()
while not results_queue.empty():
results.append(results_queue.get())
# Process results
results.sort(key=lambda x: x[0])
new_rows = []
valid_rows = []
for index, (rows, valid) in results:
if valid:
valid_rows.append(index)
new_rows.extend(rows)
# Save corrected timestamps if needed
if args.drop_invalid_timestamps:
        csv = csv.loc[valid_rows]
assert args.csv_path.endswith("timestamp.csv"), "Only support *timestamp.csv"
csv.to_csv(
args.csv_path.replace("timestamp.csv", "correct_timestamp.csv"),
index=False,
)
print(
f"Corrected timestamp file saved to '{args.csv_path.replace('timestamp.csv', 'correct_timestamp.csv')}'"
)
# Create and save clip information DataFrame
columns = [
"video_path",
"id",
"num_frames",
"height",
"width",
"aspect_ratio",
"fps",
"resolution",
"timestamp_start",
"timestamp_end",
"frame_start",
"frame_end",
"id_ori",
]
new_df = pd.DataFrame(new_rows, columns=columns)
new_csv_path = os.path.join(args.csv_save_dir, "clips_info.csv")
new_df.to_csv(new_csv_path, index=False)
print(f"Saved {len(new_df)} clip information to {new_csv_path}.")
if __name__ == "__main__":
main()
================================================
FILE: utils/get_info.py
================================================
"""
Video information extraction utility supporting multiple backends (OpenCV, TorchVision, AV).
"""
import argparse
import os
import random
import cv2
import av
import numpy as np
import pandas as pd
from tqdm import tqdm
import concurrent.futures
def get_video_length(cap, method="header"):
"""Get video frame count using different methods"""
assert method in ["header", "set"]
if method == "header":
length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
else:
cap.set(cv2.CAP_PROP_POS_AVI_RATIO, 1)
length = int(cap.get(cv2.CAP_PROP_POS_FRAMES))
return length
def get_video_info(args):
"""Extract video information using specified backend"""
idx, path, backend = args
try:
if backend == "torchvision":
from tools.utils.read_video import read_video
vframes, infos = read_video(path)
num_frames, height, width = (
vframes.shape[0],
vframes.shape[2],
vframes.shape[3],
)
fps = (
float(infos.get("video_fps", np.nan))
if isinstance(infos, dict)
else np.nan
)
elif backend == "opencv":
cap = cv2.VideoCapture(path)
if not cap.isOpened():
raise ValueError("Video open failed")
num_frames = get_video_length(cap, method="header")
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
fps = float(cap.get(cv2.CAP_PROP_FPS))
cap.release()
elif backend == "av":
container = av.open(path)
stream = container.streams.video[0]
num_frames = int(stream.frames)
height = int(stream.height)
width = int(stream.width)
if stream.average_rate is not None:
fps = float(stream.average_rate)
elif stream.guessed_rate is not None:
fps = float(stream.guessed_rate)
else:
fps = np.nan
else:
raise ValueError("Unknown backend")
# Calculate derived metrics
hw = height * width
aspect_ratio = height / width if width > 0 else np.nan
return (idx, True, num_frames, height, width, aspect_ratio, hw, fps)
except Exception:
return (idx, False, 0, 0, 0, np.nan, np.nan, np.nan)
def main(args):
"""Main function to extract video information"""
# Load data
data = pd.read_csv(args.csv_path)
if data.empty:
data.to_csv(args.csv_save_path, index=False)
print(f"Input CSV is empty. Saved 0 samples to {args.csv_save_path}.")
return
tasks = [(index, row["video_path"], args.backend) for index, row in data.iterrows()]
num_workers = args.num_workers if args.num_workers else os.cpu_count() or 1
# Process videos with a per-video progress bar (more intuitive than per-worker)
if args.disable_parallel or num_workers <= 1:
ret = [
get_video_info(task)
for task in tqdm(tasks, total=len(tasks), desc="Processing videos")
]
else:
with concurrent.futures.ProcessPoolExecutor(max_workers=num_workers) as executor:
ret = list(
tqdm(
executor.map(get_video_info, tasks, chunksize=16),
total=len(tasks),
desc="Processing videos",
)
)
ret.sort(key=lambda x: x[0])
(
_idx_list,
success_list,
num_frames_list,
height_list,
width_list,
aspect_ratio_list,
hw_list,
fps_list,
) = zip(*ret)
# Add extracted information to DataFrame
data["success"] = success_list
data["num_frames"] = num_frames_list
data["height"] = height_list
data["width"] = width_list
data["aspect_ratio"] = aspect_ratio_list
data["resolution"] = hw_list
data["fps"] = fps_list
# Filter existing files if requested
if args.ext:
assert "video_path" in data.columns
data = data[data["video_path"].apply(os.path.exists)]
# Sort by frame count
if "num_frames" in data.columns:
data = data.sort_values(by="num_frames", ascending=True)
data.to_csv(args.csv_save_path, index=False)
print(f"Saved {len(data)} samples to {args.csv_save_path}.")
def parse_args():
"""Parse command line arguments"""
parser = argparse.ArgumentParser(
description="Extract video information using multiple backends"
)
parser.add_argument(
"--csv_path", type=str, required=True, help="Path to input CSV file"
)
parser.add_argument(
"--csv_save_path", type=str, default=None, help="Path to save output CSV file"
)
parser.add_argument(
"--backend",
type=str,
default="opencv",
help="Video backend",
choices=["opencv", "torchvision", "av"],
)
parser.add_argument(
"--disable-parallel", action="store_true", help="Disable parallel processing"
)
parser.add_argument(
"--num_workers", type=int, default=None, help="Number of parallel workers"
)
parser.add_argument("--seed", type=int, default=42, help="Random seed")
# File existence checking
parser.add_argument("--ext", action="store_true", help="Check if video files exist")
return parser.parse_args()
if __name__ == "__main__":
args = parse_args()
# Set random seeds for reproducibility
if args.seed is not None:
random.seed(args.seed)
np.random.seed(args.seed)
main(args)
================================================
FILE: utils/get_instructions.py
================================================
"""
This module processes camera pose sequences and generates movement instructions.
"""
import argparse
import numpy as np
from scipy.spatial.transform import Rotation as R
import os
import pandas as pd
from multiprocessing import Manager
import concurrent.futures
import queue
from tqdm import tqdm
import json
def filter_poses(poses_array, alpha):
"""
Filter pose sequences using exponential moving average.
- Position: Exponential moving average (EMA)
- Orientation (quaternion): NLERP-based EMA with hemisphere flip handling
Args:
poses_array: Array of poses [position(3) + quaternion(4)]
alpha: Smoothing factor (0 < alpha < 1)
Returns:
Filtered pose array with same shape as input
"""
positions = poses_array[:, :3]
quaternions = poses_array[:, 3:]
filtered_positions = np.zeros_like(positions)
filtered_quaternions = np.zeros_like(quaternions)
# Initialize with first frame
filtered_positions[0] = positions[0]
filtered_quaternions[0] = quaternions[0]
for i in range(1, len(poses_array)):
filtered_positions[i] = (
alpha * positions[i] + (1 - alpha) * filtered_positions[i - 1]
)
# quaternion filtering with hemisphere check
last_q = filtered_quaternions[i - 1]
current_q = quaternions[i]
# 1. Check hemisphere to ensure interpolation takes the "shortest path"
if np.dot(last_q, current_q) < 0:
current_q = -current_q
# 2. Linear interpolation
interp_q = (1 - alpha) * last_q + alpha * current_q
# 3. Re-normalize to ensure unit quaternion
filtered_quaternions[i] = interp_q / np.linalg.norm(interp_q)
return np.hstack([filtered_positions, filtered_quaternions])
def poses_to_multi_instructions(poses_array, translation_thresh, rotation_thresh_deg):
"""
Convert camera pose sequence to concurrent movement instruction sequence.
"""
# Convert NumPy array to Scipy Rotation objects for easier computation
poses = []
for row in poses_array:
pos = row[:3]
rot = R.from_quat(row[3:])
poses.append((pos, rot))
command_sequence = []
rotation_thresh_rad = np.deg2rad(rotation_thresh_deg)
for i in range(len(poses) - 1):
# Calculate local relative movement
pos_t_w2c, rot_t_w2c = poses[i]
pos_t1_w2c, rot_t1_w2c = poses[i+1]
delta_rot = rot_t1_w2c * rot_t_w2c.inv()
pos_t_c2w = -rot_t_w2c.inv().apply(pos_t_w2c)
pos_t1_c2w = -rot_t1_w2c.inv().apply(pos_t1_w2c)
local_delta_pos = rot_t_w2c.apply(pos_t1_c2w - pos_t_c2w)
dx, dy, dz = local_delta_pos
euler_angles_rad = delta_rot.as_euler(
"yxz"
) # 'y' for yaw, 'x' for pitch, 'z' for roll
yaw_change, pitch_change, roll_change = euler_angles_rad
instructions = []
# Translation movements
if dz < -translation_thresh:
instructions.append("Dolly Out")
elif dz > translation_thresh:
instructions.append("Dolly In")
if dx > translation_thresh:
instructions.append("Truck Right")
elif dx < -translation_thresh:
instructions.append("Truck Left")
if dy > translation_thresh:
instructions.append("Pedestal Down")
elif dy < -translation_thresh:
instructions.append("Pedestal Up")
# Rotation movements
if yaw_change > rotation_thresh_rad:
instructions.append("Pan Left")
elif yaw_change < -rotation_thresh_rad:
instructions.append("Pan Right")
if pitch_change > rotation_thresh_rad:
instructions.append("Tilt Down")
elif pitch_change < -rotation_thresh_rad:
instructions.append("Tilt Up")
if roll_change > rotation_thresh_rad:
instructions.append("Roll CCW")
elif roll_change < -rotation_thresh_rad:
instructions.append("Roll CW")
if not instructions:
instructions.append("Stay")
command_sequence.append(instructions)
return command_sequence
def process_single_row(args, row):
"""Process a single video row to generate camera movement instructions."""
    json_file = os.path.join(args.dir_path, row["id"], "instructions.json")
    # Skip clips that already have a non-empty instructions file
    if os.path.exists(json_file) and os.path.getsize(json_file) > 0:
        return
    npy_path = os.path.join(args.dir_path, row["id"], "reconstructions", "poses.npy")
    # Load and subsample poses, then apply filtering
    raw_poses = np.load(npy_path)[:: args.interval]
    filtered_poses = filter_poses(raw_poses, alpha=args.alpha)
    # Generate movement instructions
    instructions = poses_to_multi_instructions(
        filtered_poses, args.translation_threshold, args.rotation_threshold
    )
# Merge consecutive identical instructions
merged_instructions = {}
start = 0
prev_cmd = instructions[0]
for i in range(1, len(instructions)):
if instructions[i] == prev_cmd:
continue
else:
key = f"{start}->{i}"
merged_instructions[key] = prev_cmd
start = i
prev_cmd = instructions[i]
# Add final segment
key = f"{start}->{len(instructions)}"
merged_instructions[key] = prev_cmd
# Save instructions to JSON file
with open(json_file, "w") as f:
json.dump(merged_instructions, f, ensure_ascii=False, indent=2)
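# The resulting instructions.json maps subsampled frame-index ranges to the list of
# concurrent commands active over that range, e.g. (illustrative values only):
#   {"0->14": ["Dolly In"], "14->20": ["Dolly In", "Pan Left"], "20->27": ["Stay"]}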
def worker(task_queue, args, pbar):
"""Worker function for parallel processing of video rows."""
while True:
try:
index, row = task_queue.get(timeout=1)
except queue.Empty:
break
process_single_row(args, row)
task_queue.task_done()
pbar.update(1)
def args_parser():
"""Parse command line arguments."""
parser = argparse.ArgumentParser()
parser.add_argument(
"--csv_path", type=str, default="outputs.csv", help="Path to the input CSV file"
)
parser.add_argument("--dir_path", type=str, default="./outputs")
parser.add_argument(
"--interval", type=int, default=2, help="Interval for computing instructions"
)
parser.add_argument(
"--alpha",
type=float,
default=0.1,
help="Smoothing factor for filtering (0 < alpha < 1)",
)
parser.add_argument(
"--translation_threshold",
type=float,
default=0.02,
help="Translation threshold for command generation",
)
parser.add_argument(
"--rotation_threshold",
type=float,
default=0.5,
help="Rotation threshold for command generation",
)
parser.add_argument(
"--num_workers", type=int, default=8, help="Number of parallel workers"
)
parser.add_argument(
"--disable_parallel", action="store_true", help="Disable parallel processing"
)
return parser.parse_args()
def main():
args = args_parser()
csv = pd.read_csv(args.csv_path)
if args.disable_parallel:
# Sequential processing
for index, row in tqdm(csv.iterrows(), total=len(csv)):
process_single_row(args, row)
else:
# Parallel processing using ThreadPoolExecutor
manager = Manager()
task_queue = manager.Queue()
for index, row in csv.iterrows():
task_queue.put((index, row))
with tqdm(total=len(csv), desc="Finished tasks") as pbar:
with concurrent.futures.ThreadPoolExecutor(
max_workers=args.num_workers
) as executor:
futures = []
for _ in range(args.num_workers):
futures.append(executor.submit(worker, task_queue, args, pbar))
for future in concurrent.futures.as_completed(futures):
future.result()
if __name__ == "__main__":
main()
================================================
FILE: utils/get_instructions_enhanced.py
================================================
"""
Enhanced camera movement instruction generator: runs a grid of (interval, alpha)
smoothing parameters per clip and merges the resulting instruction sequences by voting.
"""
import argparse
from math import sqrt
import numpy as np
from scipy.spatial.transform import Rotation as R
import os
import pandas as pd
from multiprocessing import Manager
import concurrent.futures
import queue
from tqdm import tqdm
import json
from collections import defaultdict, Counter
import itertools
def filter_poses(poses_array, alpha):
"""
Smooth pose sequence using Exponential Moving Average (EMA).
- Position: Standard EMA
- Quaternion: EMA with hemisphere check (shortest path interpolation)
"""
positions = poses_array[:, :3]
quaternions = poses_array[:, 3:]
filtered_pos = np.zeros_like(positions)
filtered_quat = np.zeros_like(quaternions)
# Initialize with first frame
filtered_pos[0], filtered_quat[0] = positions[0], quaternions[0]
for i in range(1, len(poses_array)):
# Position smoothing
filtered_pos[i] = alpha * positions[i] + \
(1 - alpha) * filtered_pos[i-1]
# Quaternion smoothing with hemisphere correction
last_q, curr_q = filtered_quat[i-1], quaternions[i]
if np.dot(last_q, curr_q) < 0: # Flip to shortest interpolation path
curr_q = -curr_q
interp_q = (1 - alpha) * last_q + alpha * curr_q
filtered_quat[i] = interp_q / \
np.linalg.norm(interp_q) # Keep unit quaternion
return np.hstack([filtered_pos, filtered_quat])
def poses_to_multi_instructions(poses_array, translation_thresh, rotation_thresh_deg, interval=1):
"""
Convert pose sequence to motion instructions (e.g., Dolly, Pan).
Calculates pose difference between frame i and i+interval (convolution-like).
"""
# Convert to (position, Rotation object) pairs
poses = [(row[:3], R.from_quat(row[3:])) for row in poses_array]
command_seq = []
# Adjust thresholds by interval (scaling for longer gaps)
rotation_thresh_deg *= sqrt(interval) / 1.8
rotation_thresh_rad = np.deg2rad(rotation_thresh_deg)
translation_thresh *= sqrt(interval)
stride = int(sqrt(interval) + 1)
i = 0
while True:
if i + interval >= len(poses): # Ensure valid frame pair
break
# Calculate relative motion (local coordinate system)
pos_t_w2c, rot_t_w2c = poses[i]
pos_t1_w2c, rot_t1_w2c = poses[i+interval]
delta_rot = rot_t1_w2c * rot_t_w2c.inv()
pos_t_c2w = -rot_t_w2c.inv().apply(pos_t_w2c)
pos_t1_c2w = -rot_t1_w2c.inv().apply(pos_t1_w2c)
local_delta_pos = rot_t_w2c.apply(pos_t1_c2w - pos_t_c2w)
dx, dy, dz = local_delta_pos
yaw, pitch, roll = delta_rot.as_euler(
"yxz") # Yaw:Pan, Pitch:Tilt, Roll:Rotate
instructions = []
# Translation commands
if dz < -translation_thresh:
instructions.append("Dolly Out")
elif dz > translation_thresh:
instructions.append("Dolly In")
if dx > translation_thresh:
instructions.append("Truck Right")
elif dx < -translation_thresh:
instructions.append("Truck Left")
if dy > translation_thresh:
instructions.append("Pedestal Down")
elif dy < -translation_thresh:
instructions.append("Pedestal Up")
# Rotation commands
if yaw > rotation_thresh_rad:
instructions.append("Pan Left")
elif yaw < -rotation_thresh_rad:
instructions.append("Pan Right")
if pitch > rotation_thresh_rad:
instructions.append("Tilt Down")
elif pitch < -rotation_thresh_rad:
instructions.append("Tilt Up")
if roll > rotation_thresh_rad:
instructions.append("Roll CCW")
elif roll < -rotation_thresh_rad:
instructions.append("Roll CW")
command_seq.append(instructions if instructions else ["Stay"])
i += stride
return command_seq
def calculate_relative_scale(total_distance, num_poses, f_translation, min_threshold=0.001):
"""
Calculate relative translation threshold (dynamic scaling by total motion).
"""
if num_poses <= 1:
return min_threshold
base_scale = total_distance / num_poses # Base scale per frame
return max(base_scale / f_translation, min_threshold)
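# Worked example (illustrative numbers): for a clip with moveDist = 2.0 over 50 poses
# and f_translation = 1.1, the per-frame base scale is 2.0 / 50 = 0.04, so the
# translation threshold becomes max(0.04 / 1.1, min_threshold) ≈ 0.036.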
def voter(args, row, interval, alpha):
"""
Process single video with specific (interval, alpha) parameter pair.
"""
# Locate pose file
npy_path = os.path.join(
args.dir_path, row["id"], "reconstructions", "poses.npy"
)
try:
raw_poses = np.load(npy_path)
filtered_poses = filter_poses(raw_poses, alpha)
# Calculate dynamic thresholds
translation_thresh = calculate_relative_scale(
row["moveDist"], len(
filtered_poses), args.f_translation, args.min_threshold_translation
)
rotation_thresh = args.rotation_threshold
return poses_to_multi_instructions(
filtered_poses, translation_thresh, rotation_thresh, interval
)
except Exception as e:
print(f"Error processing {row['id']}: {e}")
return None
def collect_all_results(args, row, param_combinations):
"""Collect instruction results for all (interval, alpha) pairs."""
results = []
for interval, alpha in param_combinations:
res = voter(args, row, interval, alpha)
if res is not None:
results.append(res)
return results
# ------------------------------ Voting Logic ------------------------------
def get_mutually_exclusive_groups():
"""Return groups of conflicting instructions (cannot coexist)."""
return [
["Dolly In", "Dolly Out"], ["Truck Left", "Truck Right"],
["Pedestal Up", "Pedestal Down"], ["Pan Left", "Pan Right"],
["Tilt Up", "Tilt Down"], ["Roll CW", "Roll CCW"]
]
def remove_conflicting_instructions(instructions, conflict_groups):
"""Remove conflicting instructions (keep higher-voted ones)."""
selected = []
selected_set = set()
for inst, count in instructions:
conflict = False
for group in conflict_groups:
if inst in group and any(s in group for s in selected_set):
conflict = True
break
if not conflict:
selected.append((inst, count))
selected_set.add(inst)
return selected
def smart_instruction_selection(non_conflicting_inst):
"""
Smart instruction selection based on vote distribution:
    - Keep only the top-voted instruction(s) when they lead by a 3x vote gap
- Max 4 instructions
- Prioritize non-"Stay" commands
"""
if not non_conflicting_inst:
return ["Stay"]
if len(non_conflicting_inst) == 1:
return [non_conflicting_inst[0][0]]
# Separate Stay and other instructions
stay = [i for i in non_conflicting_inst if i[0] == "Stay"]
others = [i for i in non_conflicting_inst if i[0] != "Stay"]
if not others:
return ["Stay"]
votes = [c for _, c in others]
max_vote = votes[0]
selected = []
# Check for vote gap (3x threshold)
if len(others) >= 2 and max_vote >= votes[1] * 3:
selected = [i[0] for i in others if i[1] == max_vote]
else:
# Select up to 4 leading instructions
gap_thresh = max_vote * 0.5
selected = [i[0] for i in others if i[1] >= gap_thresh][:4]
# Ensure minimum 2 instructions if no large gap
if len(selected) < 2 and len(others) >= 2 and max_vote < votes[1] * 3:
selected = [i[0] for i in others[:2]]
return selected if selected else ["Stay"]
def collect_interval_based_votes(all_results, param_combinations):
"""
Vote by time interval: collect all instructions covering (start_frame->end_frame).
Handles overlapping segments from different (interval, alpha) pairs.
"""
if not all_results:
return {}
# Get max frame covered by any parameter pair
max_frames = 0
for index, res in enumerate(all_results):
interval = param_combinations[index][0]
stride = int(sqrt(interval) + 1)
if res:
last_start = (len(res)-1) * stride
max_frames = max(max_frames, last_start + interval)
interval_votes = {}
for start in range(max_frames):
end = start + 1
vote_counter = Counter()
# Check all parameter results for coverage of (start->end)
for res_index, res in enumerate(all_results):
interval, _ = param_combinations[res_index]
stride = int(sqrt(interval) + 1)
for seg_index, seg in enumerate(res):
seg_start = seg_index * stride
seg_end = seg_start + interval
# Check if segment covers target interval
if seg_start <= start < seg_end and seg_start < end <= seg_end:
for inst in seg:
vote_counter[inst] += 1
interval_votes[f"{start}->{end}"] = vote_counter
return interval_votes
def vote_for_final_instructions(all_results, param_combinations=None):
"""Generate final instructions via voting (interval-based if possible)."""
if not all_results:
return []
conflict_groups = get_mutually_exclusive_groups()
final_seq = []
# Use interval-based voting if parameters are provided
if param_combinations and len(param_combinations) == len(all_results):
interval_votes = collect_interval_based_votes(
all_results, param_combinations)
for key in sorted(interval_votes.keys(), key=lambda x: int(x.split('->')[0])):
votes = interval_votes[key]
if votes:
sorted_inst = votes.most_common()
non_conflict = remove_conflicting_instructions(
sorted_inst, conflict_groups)
selected = smart_instruction_selection(non_conflict)
else:
selected = ["Stay"]
final_seq.append(selected)
else:
# Fallback: frame-wise voting
max_len = max(len(res) for res in all_results)
for frame_index in range(max_len):
votes = Counter()
for res in all_results:
if frame_index < len(res):
for inst in res[frame_index]:
votes[inst] += 1
if votes:
sorted_inst = votes.most_common()
non_conflict = remove_conflicting_instructions(
sorted_inst, conflict_groups)
selected = smart_instruction_selection(non_conflict)
else:
selected = ["Stay"]
final_seq.append(selected)
return final_seq
# ------------------------------ Main Workflow ------------------------------
def merge_consecutive_instructions(instructions):
"""Merge consecutive identical instruction lists (e.g., [A,A,A] → "0->3":[A])."""
if not instructions:
return {}
merged = {}
start, prev = 0, instructions[0]
for i in range(1, len(instructions)):
if instructions[i] != prev:
merged[f"{start}->{i}"] = prev
start, prev = i, instructions[i]
merged[f"{start}->{len(instructions)}"] = prev # Add final segment
return merged
def process_single_row(args, row, param_combinations):
# Skip if output exists
out_file = os.path.join(args.dir_path, row['id'], "instructions.json")
if os.path.exists(out_file) and os.path.getsize(out_file) > 0:
return
# Collect results & vote
all_results = collect_all_results(args, row, param_combinations)
if not all_results:
print(f"No valid results for {row['id']}")
return
final_inst = vote_for_final_instructions(all_results, param_combinations)
merged_inst = merge_consecutive_instructions(final_inst)
# Save to JSON
with open(out_file, "w") as f:
json.dump(merged_inst, f, ensure_ascii=False, indent=2)
def generate_param_combinations(args):
"""Generate all (interval, alpha) parameter pairs for grid search."""
intervals = getattr(args, "intervals", [1, 3, 5])
alphas = getattr(args, "alphas", [0.03, 0.05, 0.1])
return list(itertools.product(intervals, alphas))
def worker(task_queue, args, param_combinations, pbar):
"""Parallel worker: process tasks from queue."""
    while True:
        try:
            index, row = task_queue.get(timeout=1)
        except queue.Empty:
            break
        process_single_row(args, row, param_combinations)
        task_queue.task_done()
        pbar.update(1)
def args_parser():
parser = argparse.ArgumentParser(
description="Enhanced Camera Pose Instruction Generator")
parser.add_argument("--csv_path", type=str, required=True,
help="Input CSV path (The final_results.csv generated by evaluation.py)")
parser.add_argument("--dir_path", type=str, required=True,
help="Annotation directory path")
parser.add_argument("--intervals", type=int, nargs="+",
default=[1, 3, 5], help="Frame intervals for grid search")
parser.add_argument("--alphas", type=float, nargs="+",
default=[0.03, 0.05, 0.1], help="Smoothing factors for grid search")
parser.add_argument("--f_translation", type=float,
default=1.1, help="Translation scale factor (>1)")
parser.add_argument("--min_threshold_translation", type=float,
default=0.01, help="Min translation threshold")
parser.add_argument("--rotation_threshold", type=float,
default=1.5, help="Fixed rotation threshold (degrees)")
parser.add_argument("--num_workers", type=int,
default=8, help="Parallel workers count")
parser.add_argument("--disable_parallel", action="store_true",
help="Disable parallel processing")
return parser.parse_args()
def main():
args = args_parser()
csv = pd.read_csv(args.csv_path)
param_combinations = generate_param_combinations(args)
if args.disable_parallel:
# Serial processing
for index, row in tqdm(csv.iterrows(), total=len(csv), desc="Processing"):
process_single_row(args, row, param_combinations)
else:
# Parallel processing
manager = Manager()
task_queue = manager.Queue()
for index, row in csv.iterrows():
task_queue.put((index, row))
with tqdm(total=len(csv), desc="Processing") as pbar:
with concurrent.futures.ThreadPoolExecutor(max_workers=args.num_workers) as executor:
for _ in range(args.num_workers):
executor.submit(worker, task_queue, args,
param_combinations, pbar)
task_queue.join()
if __name__ == "__main__":
main()
================================================
FILE: utils/merge_tables.py
================================================
"""
CSV table merging utility for combining multiple clip information files.
"""
import os
import glob
import argparse
import pandas as pd
def read_csv_file(file_path):
"""Read a single CSV file"""
return pd.read_csv(file_path)
def merge_tables_from_files(file_list, output_file, merge_on=None):
"""
Merge multiple CSV files using common columns as merge keys.
Args:
file_list: List of CSV file paths to merge
output_file: Output path for merged CSV file
merge_on: List of column names for merging (defaults to first 13 columns)
"""
if not file_list:
raise ValueError("File list is empty!")
# Read all CSV files
dfs = [read_csv_file(f) for f in file_list]
# Auto-select merge keys: first 13 columns
if merge_on is None:
merge_on = dfs[0].columns[:13].tolist()
# Merge dataframes
df_merged = dfs[0]
for df in dfs[1:]:
# Check if merge keys are consistent
if merge_on != df.columns[:13].tolist():
raise ValueError(
f"Common columns in one file are inconsistent with previous files!"
)
# Merge based on specified keys
df_merged = pd.merge(df_merged, df, on=merge_on)
# Save merged result
df_merged.to_csv(output_file, index=False)
print(f"Merge completed. Saved to {output_file}")
return df_merged
def main():
parser = argparse.ArgumentParser(
description="Merge multiple CSV files from a folder"
)
parser.add_argument("--csv_dir", type=str, help="Path to folder containing CSV files")
parser.add_argument(
"--output", type=str, required=True, help="Output path for merged CSV file"
)
args = parser.parse_args()
# Match CSV files with 'clips_info_' prefix
pattern = os.path.join(args.csv_dir, "clips_info_*.csv")
file_list = glob.glob(pattern)
file_list.sort() # Sort to ensure consistent merge order
if not file_list:
raise ValueError(f"No matching CSV files found in folder {args.csv_dir}!")
print(f"Found {len(file_list)} CSV files:")
for f in file_list:
print(f" {f}")
# Perform merge
merge_tables_from_files(file_list, args.output)
if __name__ == "__main__":
main()
================================================
FILE: utils/normalize_intrinsics.py
================================================
"""
Camera intrinsics normalization utility.
This module provides functionality for:
- Normalizing camera intrinsics to standard format
- Converting focal length to normalized coordinates
- Parallel processing of multiple camera files
- Support for both threaded and sequential processing
"""
import numpy as np
import os
import pandas as pd
import argparse
import concurrent.futures
import multiprocessing as mp
from multiprocessing import Manager
import queue
from tqdm import tqdm
def process_single_row(row, args):
"""
Process a single row to normalize camera intrinsics.
"""
id = row["id"]
dir_path = os.path.join(args.dir_path, id, "reconstructions")
cam_intrinsics_file = os.path.join(dir_path, "intrinsics.npy")
# Load and normalize intrinsics
intrinsics = np.load(cam_intrinsics_file)
intrinsics[:, 0] /= intrinsics[:, 2] * 2 # Normalize focal length x
intrinsics[:, 1] /= intrinsics[:, 3] * 2 # Normalize focal length y
intrinsics[:, 2] = 0.5 # Set principal point x to center
intrinsics[:, 3] = 0.5 # Set principal point y to center
# Save normalized intrinsics
np.save(cam_intrinsics_file, intrinsics)
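# Worked example (illustrative numbers, assuming the principal point sits at the image
# center): for a 1280x720 clip with [fx, fy, cx, cy] = [1000, 1000, 640, 360] in pixels,
# cx * 2 = 1280 and cy * 2 = 720, so the row becomes
# [1000 / 1280, 1000 / 720, 0.5, 0.5] ≈ [0.781, 1.389, 0.5, 0.5].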
def worker(task_queue, args, pbar):
"""
Worker function for parallel processing of intrinsics normalization.
"""
while True:
try:
index, row = task_queue.get(timeout=1)
except queue.Empty:
break
        process_single_row(row, args)
task_queue.task_done()
pbar.update(1)
def parse_args():
"""Parse command line arguments for intrinsics normalization."""
parser = argparse.ArgumentParser(description="Normalize camera intrinsics to standard format")
parser.add_argument("--csv_path", type=str, help="Path to the csv file")
parser.add_argument("--dir_path", type=str, default="./outputs")
parser.add_argument(
"--num_workers",
type=int,
default=8,
help="Number of workers for parallel processing",
)
parser.add_argument(
"--disable_parallel", action="store_true", help="Disable parallel processing"
)
return parser.parse_args()
if __name__ == "__main__":
args = parse_args()
df = pd.read_csv(args.csv_path)
if args.disable_parallel:
# Sequential processing
for index, row in tqdm(df.iterrows(), total=len(df)):
            process_single_row(row, args)
else:
# Parallel processing using thread pool
manager = Manager()
task_queue = manager.Queue()
# Add all tasks to queue
for index, row in df.iterrows():
task_queue.put((index, row))
with tqdm(total=len(df), desc="Finished tasks") as pbar:
with concurrent.futures.ThreadPoolExecutor(
max_workers=args.num_workers
) as executor:
futures = []
for _ in range(args.num_workers):
futures.append(executor.submit(worker, task_queue, args, pbar))
for future in concurrent.futures.as_completed(futures):
future.result()
================================================
FILE: utils/pack_clip_assets.py
================================================
"""
pack_clip_assets.py
------------------
This script unifies depth, RGB frames, intrinsics, extrinsics, etc. of a specified video clip into a single npz file for downstream 3D reconstruction or analysis.
Usage example:
    python pack_clip_assets.py --base_dir /path/to/HQ --group_id <group_id> --clip_id xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx --height 328 --width 584
"""
import argparse
import numpy as np
import torch
from lietorch import SE3
import cv2
from read_depth import read_depth
def load_video(clip_path, indexes_path, height=720, width=1280):
"""
Read video frames at specified indexes and resize to (height, width).
Args:
clip_path (str): Path to video file
indexes_path (str): Path to frame indexes txt
height (int): Output frame height
width (int): Output frame width
Returns:
np.ndarray: (N, height, width, 3) RGB frames
"""
indexes = []
with open(indexes_path, 'r') as f:
for line in f:
parts = line.strip().split()
if len(parts) == 2:
indexes.append(int(parts[1]))
print(f"Frame indexes: {indexes}")
cap = cv2.VideoCapture(clip_path)
frames = []
for idx in indexes:
cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
ret, frame = cap.read()
if not ret:
raise ValueError(f"Frame at index {idx} could not be read.")
frame = cv2.resize(frame, (width, height))
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
frames.append(frame)
cap.release()
return np.array(frames)
def load_intrinsics(intrinsics_path, tgt_width=1024, tgt_height=576):
"""
Read normalized intrinsics (n,4), convert to 3x3 matrix and scale to target resolution.
Args:
intrinsics_path (str): Path to intrinsics npy
tgt_width (int): Target width
tgt_height (int): Target height
Returns:
np.ndarray: (N, 3, 3) intrinsics matrices
"""
intrinsics = np.load(intrinsics_path)
intrinsics_3x3 = []
for intrin in intrinsics:
fx, fy, cx, cy = intrin
K = np.array([[fx, 0, cx],
[0, fy, cy],
[0, 0, 1]], dtype=np.float32)
intrinsics_3x3.append(K)
intrinsics_3x3 = np.array(intrinsics_3x3)
intrinsics_3x3[:, 0, 0] *= tgt_width
intrinsics_3x3[:, 1, 1] *= tgt_height
intrinsics_3x3[:, 0, 2] *= tgt_width
intrinsics_3x3[:, 1, 2] *= tgt_height
return intrinsics_3x3
def main():
"""
Main pipeline: load depth, RGB frames, intrinsics, extrinsics, and save as npz.
"""
parser = argparse.ArgumentParser(description="Pack clip assets into a single npz file.")
parser.add_argument('--base_dir', type=str, required=True, help='Root directory of HQ data')
    parser.add_argument('--group_id', type=int, required=True, help='Group number, i.e. the xxxx in group_xxxx')
parser.add_argument('--clip_id', type=str, required=True, help='Clip ID, e.g. xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx')
parser.add_argument('--height', type=int, default=328, help='Output image height')
parser.add_argument('--width', type=int, default=584, help='Output image width')
parser.add_argument('--output', type=str, default='sgd_cvd_hr.npz', help='Output npz filename')
args = parser.parse_args()
# Path construction
annotation_dir = f'{args.base_dir}/annotations/group_{args.group_id:04d}/{args.clip_id}'
depth_path = f'{args.base_dir}/depths/group_{args.group_id:04d}/{args.clip_id}.zip'
clip_path = f'{args.base_dir}/videos/group_{args.group_id:04d}/{args.clip_id}.mp4'
intrinsics_path = f'{annotation_dir}/intrinsics.npy'
extrinsics_path = f'{annotation_dir}/poses.npy'
indexes_path = f'{annotation_dir}/indexes.txt'
# Load intrinsics and extrinsics
intrinsics = load_intrinsics(intrinsics_path, tgt_width=args.width, tgt_height=args.height)
extrinsics = np.load(extrinsics_path)
# Load and resize depth
depth = np.clip(read_depth(depth_path), 1e-3, 1e2) # (N, H, W)
resized_depth = np.zeros((depth.shape[0], args.height, args.width), dtype=depth.dtype)
for i in range(depth.shape[0]):
resized_depth[i] = cv2.resize(depth[i], (args.width, args.height), interpolation=cv2.INTER_LINEAR)
# Load RGB frames
frames = load_video(clip_path, indexes_path, args.height, args.width)
# Compute camera poses
poses_th = torch.as_tensor(extrinsics, device="cpu").float()
cam_c2w = SE3(poses_th).inv().matrix()
K = intrinsics[0]
K_o = torch.from_numpy(K).float()
# Save as npz
np.savez(
args.output,
images=frames,
depths=resized_depth,
intrinsic=K_o.detach().cpu().numpy(),
cam_c2w=cam_c2w.detach().cpu().numpy(),
)
print(f"Saved to {args.output}")
if __name__ == "__main__":
main()
================================================
FILE: utils/quat_to_mat.py
================================================
"""
Camera pose conversion utility to camera-to-world (c2w) or world-to-camera (w2c) format.
Converts quaternion representations to rotation matrices and handles pose transformations.
This module provides utilities for:
- Converting between quaternion and matrix representations of camera poses
- Transforming between world-to-camera (w2c) and camera-to-world (c2w) coordinate systems
- Parallel processing of pose conversion for large datasets
"""
import einops
import torch
import torch.nn.functional as F
import numpy as np
import os
import pandas as pd
import argparse
import concurrent.futures
import multiprocessing as mp
from multiprocessing import Manager
import queue
from tqdm import tqdm
class Pose:
"""
A class of operations on camera poses (numpy arrays with shape [...,3,4]).
Each [3,4] camera pose takes the form of [R|t].
"""
def __call__(self, R=None, t=None):
"""
Construct a camera pose from the given rotation matrix R and/or translation vector t.
Args:
R: Rotation matrix [...,3,3] or None
t: Translation vector [...,3] or None
Returns:
pose: Camera pose matrix [...,3,4]
"""
assert R is not None or t is not None
        if R is None:
            if not isinstance(t, np.ndarray):
                t = np.array(t)
            # Broadcast an identity rotation over the batch dimensions of t
            R = np.tile(np.eye(3, dtype=np.float32), (*t.shape[:-1], 1, 1))
        elif t is None:
            if not isinstance(R, np.ndarray):
                R = np.array(R)
            t = np.zeros(R.shape[:-1], dtype=np.float32)
        else:
            if not isinstance(R, np.ndarray):
                R = np.array(R)
            if not isinstance(t, np.ndarray):
                t = np.array(t)
assert R.shape[:-1] == t.shape and R.shape[-2:] == (3, 3)
R = R.astype(np.float32)
t = t.astype(np.float32)
pose = np.concatenate([R, t[..., None]], axis=-1) # [...,3,4]
assert pose.shape[-2:] == (3, 4)
return pose
def invert(self, pose, use_inverse=False): # c2w <==> w2c
"""
Invert a camera pose transformation matrix.
Converts between camera-to-world (c2w) and world-to-camera (w2c) representations.
For a pose [R|t], the inverse is [R^T | -R^T*t].
Args:
pose: Camera pose matrix [...,3,4] with shape [R|t]
use_inverse: Whether to use matrix inverse instead of transpose for rotation
Returns:
pose_inv: Inverted camera pose matrix [...,3,4]
"""
R, t = pose[..., :3], pose[..., 3:]
        R_inv = (
            np.linalg.inv(R) if use_inverse else np.swapaxes(R, -1, -2)
        )  # For orthogonal rotation matrices, the transpose equals the inverse
t_inv = (-R_inv @ t)[..., 0] # Apply inverse rotation to negative translation
pose_inv = self(R=R_inv, t=t_inv)
return pose_inv
def compose(self, pose_list):
"""
Compose a sequence of poses together.
pose_new(x) = poseN o ... o pose2 o pose1(x)
Args:
pose_list: List of camera poses to compose
Returns:
pose_new: Composed camera pose
"""
pose_new = pose_list[0]
for pose in pose_list[1:]:
pose_new = self.compose_pair(pose_new, pose)
return pose_new
def compose_pair(self, pose_a, pose_b):
"""
Compose two poses together.
pose_new(x) = pose_b o pose_a(x)
Args:
pose_a: First camera pose
pose_b: Second camera pose
Returns:
pose_new: Composed camera pose
"""
R_a, t_a = pose_a[..., :3], pose_a[..., 3:]
R_b, t_b = pose_b[..., :3], pose_b[..., 3:]
R_new = R_b @ R_a
t_new = (R_b @ t_a + t_b)[..., 0]
pose_new = self(R=R_new, t=t_new)
return pose_new
def scale_center(self, pose, scale):
"""
Scale the camera center from the origin.
0 = R@c+t --> c = -R^T@t (camera center in world coordinates)
0 = R@(sc)+t' --> t' = -R@(sc) = -R@(-R^T@st) = st
Args:
pose: Camera pose to scale
scale: Scale factor
Returns:
pose_new: Scaled camera pose
"""
R, t = pose[..., :3], pose[..., 3:]
pose_new = np.concatenate([R, t * scale], axis=-1)
return pose_new
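# Minimal usage sketch (illustrative values): build a [1,3,4] w2c pose and invert it.
#   pose_w2c = Pose()(R=np.eye(3)[None], t=np.zeros((1, 3)))
#   pose_c2w = Pose().invert(pose_w2c)
# For an identity rotation with zero translation, the w2c and c2w poses coincide.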
def quaternion_to_matrix(quaternions, eps: float = 1e-8):
"""
Convert 4-dimensional quaternions to 3x3 rotation matrices.
This is adapted from Pytorch3D:
https://github.com/facebookresearch/pytorch3d/blob/main/pytorch3d/transforms/rotation_conversions.py
Args:
quaternions: Quaternion tensor [..., 4] (order: i, j, k, r)
eps: Small value for numerical stability
Returns:
Rotation matrices [..., 3, 3]
"""
# Order changed to match scipy format!
i, j, k, r = torch.unbind(quaternions, dim=-1)
two_s = 2 / ((quaternions * quaternions).sum(dim=-1) + eps)
o = torch.stack(
(
1 - two_s * (j * j + k * k),
two_s * (i * j - k * r),
two_s * (i * k + j * r),
two_s * (i * j + k * r),
1 - two_s * (i * i + k * k),
two_s * (j * k - i * r),
two_s * (i * k - j * r),
two_s * (j * k + i * r),
1 - two_s * (i * i + j * j),
),
-1,
)
return einops.rearrange(o, "... (i j) -> ... i j", i=3, j=3)
def pose_from_quaternion(pose):
"""
Convert pose from quaternion representation to transformation matrix.
Args:
pose: Pose tensor [..., 7] where first 3 elements are translation (t)
and last 4 elements are quaternion rotation (r)
Returns:
w2c_matrix: World-to-camera transformation matrices [..., 3, 4]
"""
# Input is w2c, pose(n,7) or (n,v,7), output is (N,3,4) w2c matrix
# Tensor format from https://github.com/pointrix-project/Geomotion/blob/6ab0c364f1b44ab4ea190085dbf068f62b42727c/geomotion/model/cameras.py#L6
if type(pose) == np.ndarray:
pose = torch.tensor(pose)
if len(pose.shape) == 1:
pose = pose[None]
quat_t = pose[..., :3] # Translation
quat_r = pose[..., 3:] # Quaternion rotation
w2c_matrix = torch.zeros((*list(pose.shape)[:-1], 3, 4), device=pose.device)
w2c_matrix[..., :3, 3] = quat_t
w2c_matrix[..., :3, :3] = quaternion_to_matrix(quat_r)
return w2c_matrix
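# Illustrative check: the identity quaternion (i = j = k = 0, r = 1) with zero translation
# should map to an identity w2c matrix of shape (1, 3, 4):
#   p = torch.tensor([0., 0., 0., 0., 0., 0., 1.])  # [tx, ty, tz, qi, qj, qk, qr]
#   pose_from_quaternion(p)  # -> R = I, t = 0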
def process_single_row(row, index, args):
"""
Process a single row to convert camera poses to c2w/w2c format.
Args:
row: Data row containing video ID
index: Row index
args: Command line arguments
"""
id = row["id"]
dir_path = os.path.join(args.dir_path, id, "reconstructions")
cam_pos_file = os.path.join(dir_path, "poses.npy")
if not os.path.exists(cam_pos_file):
return
output_file = os.path.join(dir_path, "extrinsics.npy")
if os.path.exists(output_file):
return
# Load quaternion poses
pose = np.load(cam_pos_file)
# Convert w2c quaternion format (N,v,7) to w2c matrix format (N,v,3,4)
poses = pose_from_quaternion(pose)
poses = poses.cpu().numpy()
# Convert w2c matrices to c2w matrices (N,v,3,4)
if args.format == "c2w":
poses = Pose().invert(poses)
np.save(output_file, poses)
def worker(task_queue, args, pbar):
"""Worker function for parallel pose conversion processing."""
while True:
try:
index, row = task_queue.get(timeout=1)
except queue.Empty:
break
        process_single_row(row, index, args)
task_queue.task_done()
pbar.update(1)
def parse_args():
"""Parse command line arguments for camera pose conversion."""
parser = argparse.ArgumentParser(description="Convert quaternion to camera pose")
parser.add_argument("--csv_path", type=str, help="Path to the csv file")
parser.add_argument("--dir_path", type=str, default="./outputs")
parser.add_argument("--format", type=str, default="c2w", choices=["c2w", "w2c"])
parser.add_argument(
"--num_workers",
type=int,
default=8,
help="Number of workers for parallel processing",
)
parser.add_argument(
"--disable_parallel", action="store_true", help="Disable parallel processing"
)
return parser.parse_args()
if __name__ == "__main__":
args = parse_args()
df = pd.read_csv(args.csv_path)
if args.disable_parallel:
# Sequential processing
for index, row in tqdm(df.iterrows(), total=len(df)):
            process_single_row(row, index, args)
else:
# Parallel processing with multiple workers
manager = Manager()
task_queue = manager.Queue()
for index, row in df.iterrows():
task_queue.put((index, row))
with tqdm(total=len(df), desc="Finished tasks") as pbar:
with concurrent.futures.ThreadPoolExecutor(
max_workers=args.num_workers
) as executor:
futures = []
for _ in range(args.num_workers):
futures.append(executor.submit(worker, task_queue, args, pbar))
for future in concurrent.futures.as_completed(futures):
future.result()
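# Example invocation (illustrative): the csv must contain an `id` column, and each
# `<dir_path>/<id>/reconstructions/poses.npy` is expected to hold (N, 7) w2c quaternion poses.
#   python utils/quat_to_mat.py --csv_path metadata.csv --dir_path ./outputs --format c2w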
================================================
FILE: utils/read_depth.py
================================================
import zipfile
import numpy as np
import OpenEXR
def read_depth(zip_file_path):
"""
Read depth from zipped exr files.
"""
valid_width, valid_height = 0, 0
depth_data_list = []
with zipfile.ZipFile(zip_file_path, "r") as z:
for file_name in sorted(z.namelist()):
with z.open(file_name) as f:
try:
exr = OpenEXR.InputFile(f)
except OSError:
                    # The EXR loader can occasionally fail; append an all-NaN map for this frame.
assert valid_width > 0 and valid_height > 0
depth_data_list.append(
np.full((valid_height, valid_width), np.nan, dtype=np.float32))
continue
header = exr.header()
dw = header["dataWindow"]
valid_width = width = dw.max.x - dw.min.x + 1
valid_height = height = dw.max.y - dw.min.y + 1
channels = exr.channels(["Z"])
depth_data = np.frombuffer(
channels[0], dtype=np.float16).reshape((height, width))
depth_data_list.append(depth_data.astype(np.float32))
# Note that the depth with a negative value is an invalid depth.
# It can be set to the farthest point or other operations.
depth_array = np.array(depth_data_list)
depth_array_safe = np.where(depth_array == 0, 1e-12, depth_array)
return 1.0 / depth_array_safe
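# Downstream clean-up sketch (illustrative; the zip path is a placeholder). Per the note above,
# negative values mark invalid depth and can e.g. be pushed to the farthest valid depth:
#   depth = read_depth("group_0000/clip.zip")
#   valid = np.isfinite(depth) & (depth > 0)
#   depth = np.where(valid, depth, depth[valid].max())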
================================================
FILE: utils/read_video.py
================================================
"""
Video reading utilities with memory optimization and multiple backend support.
"""
import gc
import math
import os
import re
import warnings
from fractions import Fraction
from typing import Any, Dict, List, Optional, Tuple, Union
import av
import cv2
import numpy as np
import torch
from torchvision import get_video_backend
from torchvision.io.video import _check_av_available
MAX_NUM_FRAMES = 2500
def read_video_av(
filename: str,
start_pts: Union[float, Fraction] = 0,
end_pts: Optional[Union[float, Fraction]] = None,
pts_unit: str = "pts",
output_format: str = "THWC",
) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, Any]]:
"""
Read video frames using PyAV backend with memory optimization.
Modified from torchvision.io.video.read_video with improvements:
- No audio extraction (returns empty aframes)
- PyAV backend only
- Added container.close() and gc.collect() to prevent memory leaks
- Optimized for memory efficiency
"""
# Validate format
output_format = output_format.upper()
if output_format not in ("THWC", "TCHW"):
raise ValueError(f"output_format should be either 'THWC' or 'TCHW', got {output_format}.")
# Check file existence
if not os.path.exists(filename):
raise RuntimeError(f"File not found: {filename}")
# Validate backend
assert get_video_backend() == "pyav", "pyav backend is required for read_video_av"
_check_av_available()
# Validate time range
if end_pts is None:
end_pts = float("inf")
if end_pts < start_pts:
raise ValueError(f"end_pts should be larger than start_pts, got start_pts={start_pts} and end_pts={end_pts}")
# Extract video metadata
info = {}
container = av.open(filename, metadata_errors="ignore")
video_fps = container.streams.video[0].average_rate
if video_fps is not None:
info["video_fps"] = float(video_fps)
# Get frame dimensions
iter_video = container.decode(**{"video": 0})
frame = next(iter_video).to_rgb().to_ndarray()
height, width = frame.shape[:2]
total_frames = container.streams.video[0].frames
if total_frames == 0:
total_frames = MAX_NUM_FRAMES
warnings.warn(f"total_frames is 0, using {MAX_NUM_FRAMES} as a fallback")
container.close()
del container
    # Pre-allocate frame buffer (np.zeros allocates lazily, so physical memory is only committed as frames are written)
video_frames = np.zeros((total_frames, height, width, 3), dtype=np.uint8)
# Read video frames
try:
container = av.open(filename, metadata_errors="ignore")
assert container.streams.video is not None
video_frames = _read_from_stream(
video_frames,
container,
start_pts,
end_pts,
pts_unit,
container.streams.video[0],
{"video": 0},
filename=filename,
)
except av.AVError as e:
print(f"[Warning] Error while reading video {filename}: {e}")
# Convert to tensor and adjust format
vframes = torch.from_numpy(video_frames).clone()
del video_frames
if output_format == "TCHW":
# Convert [T,H,W,C] to [T,C,H,W]
vframes = vframes.permute(0, 3, 1, 2)
aframes = torch.empty((1, 0), dtype=torch.float32)
return vframes, aframes, info
def _read_from_stream(
video_frames,
container: "av.container.Container",
start_offset: float,
end_offset: float,
pts_unit: str,
stream: "av.stream.Stream",
stream_name: Dict[str, Optional[Union[int, Tuple[int, ...], List[int]]]],
filename: Optional[str] = None,
) -> np.ndarray:
"""Read frames from video stream with proper buffering and seeking"""
# Convert time units
if pts_unit == "sec":
start_offset = int(math.floor(start_offset * (1 / stream.time_base)))
if end_offset != float("inf"):
end_offset = int(math.ceil(end_offset * (1 / stream.time_base)))
else:
warnings.warn("The pts_unit 'pts' gives wrong results. Please use pts_unit 'sec'.")
# Check if buffering is needed for DivX packed B-frames
should_buffer = True
max_buffer_size = 5
if stream.type == "video":
extradata = stream.codec_context.extradata
if extradata and b"DivX" in extradata:
pos = extradata.find(b"DivX")
d = extradata[pos:]
o = re.search(rb"DivX(\d+)Build(\d+)(\w)", d)
if o is None:
o = re.search(rb"DivX(\d+)b(\d+)(\w)", d)
if o is not None:
should_buffer = o.group(3) == b"p"
# Calculate seek offset with safety margin
seek_offset = start_offset
seek_offset = max(seek_offset - 1, 0) # Safety margin for seeking
if should_buffer:
seek_offset = max(seek_offset - max_buffer_size, 0)
# Seek to start position
try:
container.seek(seek_offset, any_frame=False, backward=True, stream=stream)
except av.AVError as e:
print(f"[Warning] Error while seeking video {filename}: {e}")
return []
# Read frames from stream
buffer_count = 0
frames_pts = []
cnt = 0
try:
for _idx, frame in enumerate(container.decode(**stream_name)):
frames_pts.append(frame.pts)
video_frames[cnt] = frame.to_rgb().to_ndarray()
cnt += 1
if cnt >= len(video_frames):
break
if frame.pts >= end_offset:
if should_buffer and buffer_count < max_buffer_size:
buffer_count += 1
continue
break
except av.AVError as e:
print(f"[Warning] Error while reading video {filename}: {e}")
# Clean up resources to prevent memory leaks
container.close()
del container
gc.collect() # Force garbage collection for PyAV threads
# ensure that the results are sorted wrt the pts
# NOTE: here we assert frames_pts is sorted
start_ptr = 0
end_ptr = cnt
while start_ptr < end_ptr and frames_pts[start_ptr] < start_offset:
start_ptr += 1
while start_ptr < end_ptr and frames_pts[end_ptr - 1] > end_offset:
end_ptr -= 1
if start_offset > 0 and start_offset not in frames_pts[start_ptr:end_ptr]:
# if there is no frame that exactly matches the pts of start_offset
# add the last frame smaller than start_offset, to guarantee that
# we will have all the necessary data. This is most useful for audio
if start_ptr > 0:
start_ptr -= 1
result = video_frames[start_ptr:end_ptr].copy()
return result
def read_video_cv2(filename, start_pts=None, end_pts=None, pts_unit="pts"):
"""
Read video using OpenCV backend.
"""
if pts_unit != "frames":
warnings.warn("Using pts_unit other than 'frames' is not supported for cv2 backend")
cap = cv2.VideoCapture(filename)
# Get video metadata
fps = cap.get(cv2.CAP_PROP_FPS)
frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
# Calculate frame range
if start_pts is None:
start_pts = 0
if end_pts is None:
end_pts = frame_count
# Limit frame range to video bounds
start_pts = max(0, start_pts)
end_pts = min(frame_count, end_pts)
num_frames = end_pts - start_pts
if num_frames <= 0:
return torch.zeros(0, 3, 0, 0), None, {"video_fps": fps}
# Seek to start frame
cap.set(cv2.CAP_PROP_POS_FRAMES, start_pts)
# Read frames
frames = []
for i in range(num_frames):
ret, frame = cap.read()
if not ret:
break
# Convert BGR to RGB and change HWC to CHW format
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
frame = torch.from_numpy(frame).permute(2, 0, 1).float()
frames.append(frame)
cap.release()
if frames:
video_tensor = torch.stack(frames)
else:
video_tensor = torch.zeros(0, 3, 0, 0)
metadata = {"video_fps": fps}
return video_tensor, None, metadata
def read_video(video_path, backend="av"):
"""
Read video using specified backend.
"""
if backend == "cv2":
        vframes, _, vinfo = read_video_cv2(video_path)
elif backend == "av":
vframes, _, vinfo = read_video_av(filename=video_path, pts_unit="sec", output_format="TCHW")
else:
raise ValueError(f"Unsupported backend: {backend}")
return vframes, vinfo
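# Minimal usage sketch (the path is a placeholder):
#   vframes, vinfo = read_video("clip.mp4", backend="av")  # (T, C, H, W) uint8 tensor
#   fps = vinfo.get("video_fps")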
================================================
FILE: utils/scene_detect.py
================================================
"""
Video scene detection and timestamp processing utility.
This module provides functionality for:
- Scene detection using PySceneDetect library
- Timestamp processing and filtering
- Scene duration management
- Parallel processing of video files
"""
import argparse
import os
import ast
import concurrent.futures
import queue
import numpy as np
import pandas as pd
from tqdm import tqdm
from scenedetect import (
AdaptiveDetector,
detect,
ContentDetector,
SceneManager,
open_video,
)
from multiprocessing import Manager
def timecode_to_seconds(timecode):
"""Convert timecode string to seconds."""
h, m, s = map(float, timecode.split(":"))
return h * 3600 + m * 60 + s
def seconds_to_timecode(seconds):
"""Convert seconds to timecode string format."""
h = int(seconds // 3600)
m = int((seconds % 3600) // 60)
s = seconds % 60
return f"{h:02d}:{m:02d}:{s:06.3f}"
def process_single_row(
row,
frame_skip=0,
start_remove_sec=0,
end_remove_sec=0,
min_seconds=2,
max_seconds=15,
backend="opencv",
):
"""
Process a single video file for scene detection.
"""
video_path = row["video_path"]
detector1 = ContentDetector(threshold=21, min_scene_len=15)
detector2 = AdaptiveDetector(
adaptive_threshold=3.0, min_scene_len=15, luma_only=True
)
detector = [detector1, detector2]
try:
if isinstance(detector, list):
scene_manager = SceneManager()
for i in detector:
scene_manager.add_detector(i)
if backend == "opencv":
video = open_video(video_path)
elif backend == "av":
video = open_video(video_path, backend="pyav")
# Get video frame rate
fps = video.frame_rate
scene_manager.detect_scenes(video=video, frame_skip=frame_skip)
scene_list = scene_manager.get_scene_list()
else:
video = open_video(video_path)
# Get video frame rate
fps = video.frame_rate
scene_list = detect(video_path, detector, start_in_scene=True)
if not scene_list:
# If no scenes are detected, treat the entire video as one scene
video_duration = video.duration
timestamp = [("00:00:00.000", seconds_to_timecode(video_duration.get_seconds()))]
else:
timestamp = [(s.get_timecode(), t.get_timecode()) for s, t in scene_list]
# Process timestamps: remove specified seconds from start/end, filter by duration
new_timestamp = []
total_remove_sec = start_remove_sec + end_remove_sec
for start_timecode, end_timecode in timestamp:
start_seconds = timecode_to_seconds(start_timecode)
end_seconds = timecode_to_seconds(end_timecode)
duration = end_seconds - start_seconds
# Only record scenes longer than total removal time
if duration >= total_remove_sec:
new_start_seconds = start_seconds + start_remove_sec
new_end_seconds = end_seconds - end_remove_sec
new_duration = new_end_seconds - new_start_seconds
if new_duration <= max_seconds:
# Duration within max_seconds, check if meets min_seconds
if min_seconds <= new_duration:
new_start_timecode = seconds_to_timecode(new_start_seconds)
new_end_timecode = seconds_to_timecode(new_end_seconds)
new_timestamp.append((new_start_timecode, new_end_timecode))
else:
# Duration exceeds max_seconds, split into segments
current_start = new_start_seconds
while current_start + max_seconds <= new_end_seconds:
new_start_timecode = seconds_to_timecode(current_start)
new_end_timecode = seconds_to_timecode(
current_start + max_seconds
)
new_timestamp.append((new_start_timecode, new_end_timecode))
current_start += max_seconds
# Handle remaining segment
last_duration = new_end_seconds - current_start
if last_duration >= min_seconds:
new_start_timecode = seconds_to_timecode(current_start)
new_end_timecode = seconds_to_timecode(new_end_seconds)
new_timestamp.append((new_start_timecode, new_end_timecode))
return True, str(new_timestamp), float(fps)
except Exception as e:
print(f"Video '{video_path}' with error {e}")
return False, "", None
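# Worked example of the splitting rule above (illustrative): with min_seconds=2, max_seconds=15
# and no start/end trimming, a 40 s scene is cut into [0, 15], [15, 30], [30, 40] (the 10 s
# remainder is kept because 10 >= min_seconds), while a 1.5 s scene is dropped entirely.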
def timecode_to_frames(timecode, fps):
"""Convert timecode to frame number using fps."""
h, m, s = map(float, timecode.split(":"))
total_seconds = h * 3600 + m * 60 + s
return int(total_seconds * fps)
def worker(task_queue, results_queue, args):
"""
Worker function for parallel scene detection processing.
"""
while True:
try:
index, row = task_queue.get(timeout=1)
except queue.Empty:
break
result = process_single_row(
row,
frame_skip=args.frame_skip,
start_remove_sec=args.start_remove_sec,
end_remove_sec=args.end_remove_sec,
min_seconds=args.min_seconds,
max_seconds=args.max_seconds,
backend=args.backend,
)
results_queue.put((index, result))
task_queue.task_done()
def parse_args():
"""Parse command line arguments for scene detection."""
parser = argparse.ArgumentParser()
parser.add_argument(
"--csv_path",
type=str,
required=True,
help="Path to the input CSV file containing video paths.",
)
parser.add_argument(
"--num_workers", type=int, default=1, help="#workers for concurrent.futures"
)
parser.add_argument(
"--frame_skip", type=int, default=0, help="skip frame for detect_scenes"
)
parser.add_argument(
"--start_remove_sec",
type=float,
default=0,
help="Seconds to remove from the start of each timestamp",
)
parser.add_argument(
"--end_remove_sec",
type=float,
default=0,
help="Seconds to remove from the end of each timestamp",
)
parser.add_argument(
"--min_seconds",
type=float,
default=2,
help="Minimum duration of a scene in seconds",
)
parser.add_argument(
"--max_seconds",
type=float,
default=15,
help="Maximum duration of a scene in seconds",
)
parser.add_argument(
"--backend",
type=str,
default="opencv",
choices=["opencv", "av"],
help="Backend for video reading",
)
parser.add_argument(
"--disable_parallel", action="store_true", help="Disable parallel processing"
)
args = parser.parse_args()
return args
def main():
args = parse_args()
csv_path = args.csv_path
if not os.path.exists(csv_path):
print(f"csv file '{csv_path}' not found. Exit.")
return
csv = pd.read_csv(csv_path)
ret = []
if args.disable_parallel:
for index, row in tqdm(csv.iterrows(), total=len(csv)):
succ, timestamps, fps = process_single_row(
row,
frame_skip=args.frame_skip,
start_remove_sec=args.start_remove_sec,
end_remove_sec=args.end_remove_sec,
min_seconds=args.min_seconds,
                max_seconds=args.max_seconds,
                backend=args.backend,
            )
csv.at[index, "fps"] = fps
csv.at[index, "timestamp"] = timestamps
ret.append((index, (succ, timestamps, fps)))
else:
manager = Manager()
task_queue = manager.Queue()
results_queue = manager.Queue()
# Add all tasks to queue
for index, row in csv.iterrows():
task_queue.put((index, row))
# Set number of workers
if args.num_workers is not None:
num_workers = args.num_workers
else:
num_workers = os.cpu_count() or 1
# Process videos in parallel
with concurrent.futures.ProcessPoolExecutor(max_workers=num_workers) as executor:
futures = []
for _ in range(num_workers):
future = executor.submit(worker, task_queue, results_queue, args)
futures.append(future)
processed = 0
total_tasks = len(csv)
with tqdm(total=total_tasks, desc="Processing videos") as pbar:
while processed < total_tasks:
try:
ret.append(results_queue.get(timeout=1))
processed += 1
pbar.update(1)
except queue.Empty:
if all(f.done() for f in futures) and results_queue.empty():
break
for future in futures:
future.result()
# Collect results
while not results_queue.empty():
ret.append(results_queue.get())
# Sort results by index
ret.sort(key=lambda x: x[0])
succ, timestamps, fps_list = list(zip(*[result for _, result in ret]))
csv["fps"] = fps_list
csv["timestamp"] = timestamps
csv = csv[np.array(succ)]
def calculate_frame_numbers(row):
"""Calculate frame numbers from timestamps and fps."""
timestamp = ast.literal_eval(row["timestamp"])
fps = row["fps"]
frame_numbers = [
(timecode_to_frames(start, fps), timecode_to_frames(end, fps))
for start, end in timestamp
]
return str(frame_numbers)
csv["frame_numbers"] = csv.apply(calculate_frame_numbers, axis=1)
# Save results to new CSV file
wo_ext, ext = os.path.splitext(csv_path)
out_path = f"{wo_ext}_timestamp{ext}"
csv.to_csv(out_path, index=False)
print(
f"New csv (shape={csv.shape}) with timestamp and frame numbers saved to '{out_path}'."
)
if __name__ == "__main__":
main()
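# Example invocation (illustrative; the csv needs a `video_path` column):
#   python utils/scene_detect.py --csv_path videos.csv --num_workers 8 \
#       --min_seconds 2 --max_seconds 15 --backend opencv
# The output csv with timestamps and frame numbers is written next to the input,
# e.g. `videos_timestamp.csv`.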
================================================
FILE: viser/.clang-format
================================================
# C++ formatting rules; used for WebAssembly code.
BasedOnStyle: LLVM
AlignAfterOpenBracket: BlockIndent
BinPackArguments: false
BinPackParameters: false
IndentWidth: 4
================================================
FILE: viser/.gitignore
================================================
*.swp
*.swo
*.pyc
*.egg-info
*.ipynb_checkpoints
__pycache__
.coverage
htmlcov
.mypy_cache
.dmypy.json
.hypothesis
.envrc
.lvimrc
.DS_Store
.vite
build
src/viser/client/build
src/viser/client/.nodeenv
record3d_dance
================================================
FILE: viser/.pre-commit-config.yaml
================================================
# See https://pre-commit.com for more information
# See https://pre-commit.com/hooks.html for more hooks
default_language_version:
python: python3
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v3.2.0
hooks:
- id: trailing-whitespace
- id: end-of-file-fixer
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.6.2
hooks:
# Run the linter.
- id: ruff
args: [--fix]
# Run the formatter.
- id: ruff-format
================================================
FILE: viser/.prettierignore
================================================
*.mjs
build/
================================================
FILE: viser/LICENSE
================================================
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
================================================
FILE: viser/README.md
================================================
# viser
### This repo is a customized version of https://github.com/nerfstudio-project/viser for project MonST3R (https://monst3r-project.github.io/)
`viser` is a library for interactive 3D visualization in Python.
Features include:
- API for visualizing 3D primitives
- GUI building blocks: buttons, checkboxes, text inputs, sliders, etc.
- Scene interaction tools (clicks, selection, transform gizmos)
- Programmatic camera control and rendering
- An entirely web-based client, for easy use over SSH!
For usage and API reference, see our documentation.
## Installation
You can install `viser` with `pip`:
```bash
pip install viser
```
To include example dependencies:
```bash
pip install viser[examples]
```
After an example script is running, you can connect by navigating to the printed
URL (default: `http://localhost:8080`).
See also: our [development docs](https://viser.studio/latest/development/).
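A minimal script to get a server running looks roughly like this (a sketch based on the upstream viser API around the time of this fork; method names such as `add_point_cloud` and their arguments may differ slightly in this customized version):
```python
import time

import numpy as np
import viser

server = viser.ViserServer()  # starts the web client server (default port 8080)

# Add a small random point cloud so there is something to look at in the browser.
server.add_point_cloud(
    "/demo/points",
    points=np.random.uniform(-1.0, 1.0, size=(5_000, 3)).astype(np.float32),
    colors=np.random.randint(0, 256, size=(5_000, 3), dtype=np.uint8),
    point_size=0.01,
)

while True:
    time.sleep(1.0)  # keep the process alive so clients can stay connected
```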
## Examples
**Point cloud visualization**
https://github.com/nerfstudio-project/viser/assets/6992947/df35c6ee-78a3-43ad-a2c7-1dddf83f7458
Source: `./examples/07_record3d_visualizer.py`
**Gaussian splatting visualization**
https://github.com/nerfstudio-project/viser/assets/6992947/c51b4871-6cc8-4987-8751-2bf186bcb1ae
Source:
[WangFeng18/3d-gaussian-splatting](https://github.com/WangFeng18/3d-gaussian-splatting)
and
[heheyas/gaussian_splatting_3d](https://github.com/heheyas/gaussian_splatting_3d).
**SMPLX visualizer**
https://github.com/nerfstudio-project/viser/assets/6992947/78ba0e09-612d-4678-abf3-beaeeffddb01
Source: `./examples/08_smpl_visualizer.py`
## Acknowledgements
`viser` is heavily inspired by packages like
[Pangolin](https://github.com/stevenlovegrove/Pangolin),
[rviz](https://wiki.ros.org/rviz/),
[meshcat](https://github.com/rdeits/meshcat), and
[Gradio](https://github.com/gradio-app/gradio).
It's made possible by several open-source projects.
The web client is implemented using [React](https://react.dev/), with:
- [Vite](https://vitejs.dev/) / [Rollup](https://rollupjs.org/) for bundling
- [three.js](https://threejs.org/) via [react-three-fiber](https://github.com/pmndrs/react-three-fiber) and [drei](https://github.com/pmndrs/drei)
- [Mantine](https://mantine.dev/) for UI components
- [zustand](https://github.com/pmndrs/zustand) for state management
- [vanilla-extract](https://vanilla-extract.style/) for stylesheets
The Python API communicates via [msgpack](https://msgpack.org/index.html) and [websockets](https://websockets.readthedocs.io/en/stable/index.html).
================================================
FILE: viser/docs/.gitignore
================================================
build/
================================================
FILE: viser/docs/Makefile
================================================
# Minimal makefile for Sphinx documentation
#
# You can set these variables from the command line.
SPHINXOPTS =
SPHINXBUILD = sphinx-build
SPHINXPROJ = viser
SOURCEDIR = source
BUILDDIR = ./build
# Put it first so that "make" without argument is like "make help".
help:
@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
.PHONY: help Makefile
# Catch-all target: route all unknown targets to Sphinx using the new
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
%: Makefile
@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
================================================
FILE: viser/docs/source/_static/css/custom.css
================================================
img.sidebar-logo {
width: 5em;
margin: 1em 0 0 0;
}
================================================
FILE: viser/docs/source/_templates/sidebar/brand.html
================================================
{% block brand_content %} {%- if logo_url %}
{%- endif %} {%- if theme_light_logo and theme_dark_logo %}