Repository: ethz-vlg/mvtracker Branch: main Commit: ceea8ad2af77 Files: 183 Total size: 1.8 MB Directory structure: gitextract_p_4pogs3/ ├── .gitignore ├── README.md ├── configs/ │ ├── eval.yaml │ ├── experiment/ │ │ ├── mvtracker.yaml │ │ ├── mvtracker_overfit.yaml │ │ └── mvtracker_overfit_mini.yaml │ ├── model/ │ │ ├── copycat.yaml │ │ ├── cotracker1_offline.yaml │ │ ├── cotracker1_online.yaml │ │ ├── cotracker2_offline.yaml │ │ ├── cotracker2_online.yaml │ │ ├── cotracker3_offline.yaml │ │ ├── cotracker3_online.yaml │ │ ├── default.yaml │ │ ├── delta.yaml │ │ ├── locotrack.yaml │ │ ├── mvtracker.yaml │ │ ├── scenetracker.yaml │ │ ├── spatialtrackerv2.yaml │ │ ├── spatracker_monocular.yaml │ │ ├── spatracker_monocular_pretrained.yaml │ │ ├── spatracker_multiview.yaml │ │ └── tapip3d.yaml │ └── train.yaml ├── demo.py ├── hubconf.py ├── mvtracker/ │ ├── __init__.py │ ├── cli/ │ │ ├── __init__.py │ │ ├── eval.py │ │ ├── train.py │ │ └── utils/ │ │ ├── __init__.py │ │ ├── helpers.py │ │ ├── pylogger.py │ │ └── rich_utils.py │ ├── datasets/ │ │ ├── __init__.py │ │ ├── dexycb_multiview_dataset.py │ │ ├── generic_scene_dataset.py │ │ ├── kubric_multiview_dataset.py │ │ ├── panoptic_studio_multiview_dataset.py │ │ ├── tap_vid_datasets.py │ │ └── utils.py │ ├── evaluation/ │ │ ├── __init__.py │ │ ├── evaluator_3dpt.py │ │ └── metrics.py │ ├── models/ │ │ ├── __init__.py │ │ ├── core/ │ │ │ ├── __init__.py │ │ │ ├── copycat.py │ │ │ ├── cotracker2/ │ │ │ │ ├── __init__.py │ │ │ │ └── blocks.py │ │ │ ├── dpt/ │ │ │ │ ├── __init__.py │ │ │ │ ├── base_model.py │ │ │ │ ├── blocks.py │ │ │ │ ├── midas_net.py │ │ │ │ ├── models.py │ │ │ │ ├── transforms.py │ │ │ │ └── vit.py │ │ │ ├── dynamic3dgs/ │ │ │ │ ├── LICENSE.md │ │ │ │ ├── colormap.py │ │ │ │ ├── export_depths_from_pretrained_checkpoint.py │ │ │ │ ├── external.py │ │ │ │ ├── helpers.py │ │ │ │ ├── merge_tapvid3d_per_camera_annotations.py │ │ │ │ ├── metadata_dexycb.py │ │ │ │ ├── metadata_kubric.py │ │ │ │ ├── reorganize_dexycb.py │ │ │ │ ├── test.py │ │ │ │ ├── track_2d.py │ │ │ │ ├── track_3d.py │ │ │ │ ├── train.py │ │ │ │ └── visualize.py │ │ │ ├── embeddings.py │ │ │ ├── loftr/ │ │ │ │ ├── __init__.py │ │ │ │ ├── linear_attention.py │ │ │ │ └── transformer.py │ │ │ ├── losses.py │ │ │ ├── model_utils.py │ │ │ ├── monocular_baselines.py │ │ │ ├── mvtracker/ │ │ │ │ ├── __init__.py │ │ │ │ └── mvtracker.py │ │ │ ├── ptv3/ │ │ │ │ ├── __init__.py │ │ │ │ ├── model.py │ │ │ │ └── serialization/ │ │ │ │ ├── __init__.py │ │ │ │ ├── default.py │ │ │ │ ├── hilbert.py │ │ │ │ └── z_order.py │ │ │ ├── shape-of-motion/ │ │ │ │ ├── .gitignore │ │ │ │ ├── .gitmodules │ │ │ │ ├── LICENSE │ │ │ │ ├── README.md │ │ │ │ ├── flow3d/ │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── configs.py │ │ │ │ │ ├── data/ │ │ │ │ │ │ ├── __init__.py │ │ │ │ │ │ ├── base_dataset.py │ │ │ │ │ │ ├── casual_dataset.py │ │ │ │ │ │ ├── colmap.py │ │ │ │ │ │ ├── iphone_dataset.py │ │ │ │ │ │ ├── panoptic_dataset.py │ │ │ │ │ │ └── utils.py │ │ │ │ │ ├── init_utils.py │ │ │ │ │ ├── loss_utils.py │ │ │ │ │ ├── metrics.py │ │ │ │ │ ├── params.py │ │ │ │ │ ├── renderer.py │ │ │ │ │ ├── scene_model.py │ │ │ │ │ ├── tensor_dataclass.py │ │ │ │ │ ├── trainer.py │ │ │ │ │ ├── trajectories.py │ │ │ │ │ ├── transforms.py │ │ │ │ │ ├── validator.py │ │ │ │ │ └── vis/ │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── playback_panel.py │ │ │ │ │ ├── render_panel.py │ │ │ │ │ ├── utils.py │ │ │ │ │ └── viewer.py │ │ │ │ └── launch_davis.py │ │ │ ├── spatracker/ │ │ │ │ ├── __init__.py │ │ │ │ ├── blocks.py 
│ │ │ │ ├── softsplat.py │ │ │ │ ├── spatracker_monocular.py │ │ │ │ └── spatracker_multiview.py │ │ │ ├── vggt/ │ │ │ │ ├── __init__.py │ │ │ │ ├── heads/ │ │ │ │ │ ├── camera_head.py │ │ │ │ │ ├── dpt_head.py │ │ │ │ │ ├── head_act.py │ │ │ │ │ ├── track_head.py │ │ │ │ │ ├── track_modules/ │ │ │ │ │ │ ├── __init__.py │ │ │ │ │ │ ├── base_track_predictor.py │ │ │ │ │ │ ├── blocks.py │ │ │ │ │ │ ├── modules.py │ │ │ │ │ │ └── utils.py │ │ │ │ │ └── utils.py │ │ │ │ ├── layers/ │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── attention.py │ │ │ │ │ ├── block.py │ │ │ │ │ ├── drop_path.py │ │ │ │ │ ├── layer_scale.py │ │ │ │ │ ├── mlp.py │ │ │ │ │ ├── patch_embed.py │ │ │ │ │ ├── rope.py │ │ │ │ │ ├── swiglu_ffn.py │ │ │ │ │ └── vision_transformer.py │ │ │ │ ├── models/ │ │ │ │ │ ├── aggregator.py │ │ │ │ │ └── vggt.py │ │ │ │ └── utils/ │ │ │ │ ├── geometry.py │ │ │ │ ├── load_fn.py │ │ │ │ ├── pose_enc.py │ │ │ │ ├── rotation.py │ │ │ │ └── visual_track.py │ │ │ └── vit/ │ │ │ ├── __init__.py │ │ │ ├── common.py │ │ │ └── encoder.py │ │ └── evaluation_predictor_3dpt.py │ └── utils/ │ ├── __init__.py │ ├── basic.py │ ├── eval_utils.py │ ├── geom.py │ ├── improc.py │ ├── misc.py │ ├── visualizer_mp4.py │ └── visualizer_rerun.py ├── requirements.full.txt ├── requirements.txt └── scripts/ ├── 4ddress_preprocessing.py ├── __init__.py ├── compare_cdist-topk_against_pointops-knn.py ├── dex_ycb_to_neus_format.py ├── egoexo4d_preprocessing.py ├── estimate_depth_with_duster.py ├── hi4d_preprocessing.py ├── merge_comparison_mp4s.py ├── panoptic_studio_preprocessing.py ├── plot_aj_for_varying_depth_noise_levels.py ├── plot_aj_for_varying_n_of_views.py ├── profiling.md ├── selfcap_preprocessing.py ├── slurm/ │ ├── eval.sh │ ├── mvtracker-nodepthaugs.sh │ ├── mvtracker.sh │ ├── spatracker.sh │ ├── test_reproducibility.sh │ ├── triplane-128.sh │ └── triplane-256.sh └── summarize_eval_results.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .gitignore ================================================ .idea __pycache__/ *.DS_Store *.pth *.pt *.mp4 *.npy vis_results/ checkpoints/ logs/ slurm_logs/ submit* logs* /running /datasets /env.sh /eular_log /outputs /wandb ================================================ FILE: README.md ================================================

# Multi-View 3D Point Tracking

*(Badges: arXiv · Project Page · Interactive Results · 🤗 Demo coming soon.)*
[**Frano Rajič**](https://m43.github.io/)1 · [**Haofei Xu**](https://haofeixu.github.io/)1 · [**Marko Mihajlovic**](https://markomih.github.io/)1 · [**Siyuan Li**](https://siyuanliii.github.io/)1 · [**Irem Demir**](https://github.com/iremddemir)1 · [**Emircan Gündoğdu**](https://github.com/emircangun)1 · [**Lei Ke**](https://www.kelei.site/)2 · [**Sergey Prokudin**](https://vlg.inf.ethz.ch/team/Dr-Sergey-Prokudin.html)1,3 · [**Marc Pollefeys**](https://people.inf.ethz.ch/marc.pollefeys/)1,4 · [**Siyu Tang**](https://vlg.inf.ethz.ch/team/Prof-Dr-Siyu-Tang.html)1
1[ETH Zürich](https://vlg.inf.ethz.ch/)   2[Carnegie Mellon University](https://www.cmu.edu/)   3[Balgrist University Hospital](https://www.balgrist.ch/)   4[Microsoft](https://www.microsoft.com/)

*(Teaser clips: selfcap, dexycb, 4d-dress-stretching, 4d-dress-avatarmove.)*

MVTracker is the first **data-driven multi-view 3D point tracker** for tracking arbitrary 3D points across multiple cameras. It fuses multi-view features into a unified 3D feature point cloud, within which it leverages kNN-based correlation to capture spatiotemporal relationships across views. A transformer then iteratively refines the point tracks, handling occlusions and adapting to varying camera setups without per-sequence optimization.

## Updates

- August 28, 2025: Public release.

## Quick Start

This repo was validated on **Python 3.10.12**, **PyTorch 2.3.0** (CUDA 12.1), **cuDNN 8903**, and **gcc 11.3.0**. If you want a fresh minimal environment that runs the Hub demo and `demo.py`:

```bash
conda create -n 3dpt python=3.10.12 -y
conda activate 3dpt
conda install pytorch==2.3.0 torchvision==0.18.0 torchaudio==2.3.0 pytorch-cuda=12.1 -c pytorch -c nvidia -y
pip install -r https://raw.githubusercontent.com/ethz-vlg/mvtracker/refs/heads/main/requirements.txt

# Optional, speeds up the model
pip install --upgrade --no-build-isolation flash-attn==2.5.8  # Speeds up attention
pip install "git+https://github.com/ethz-vlg/pointcept.git@2082918#subdirectory=libs/pointops"  # Speeds up kNN search; may require gcc 11.3.0: conda install -c conda-forge gcc_linux-64=11.3.0 gxx_linux-64=11.3.0 gcc=11.3.0 gxx=11.3.0
```

With the minimal dependencies in place, you can try MVTracker directly via **PyTorch Hub**:

```python
import torch
import numpy as np
from huggingface_hub import hf_hub_download

device = "cuda" if torch.cuda.is_available() else "cpu"
mvtracker = torch.hub.load("ethz-vlg/mvtracker", "mvtracker", pretrained=True, device=device)

# Example input from demo sample (downloaded automatically)
sample = np.load(hf_hub_download("ethz-vlg/mvtracker", "data_sample.npz"))
rgbs = torch.from_numpy(sample["rgbs"]).float()
depths = torch.from_numpy(sample["depths"]).float()
intrs = torch.from_numpy(sample["intrs"]).float()
extrs = torch.from_numpy(sample["extrs"]).float()
query_points = torch.from_numpy(sample["query_points"]).float()

with torch.no_grad():
    results = mvtracker(
        rgbs=rgbs[None].to(device) / 255.0,
        depths=depths[None].to(device),
        intrs=intrs[None].to(device),
        extrs=extrs[None].to(device),
        query_points_3d=query_points[None].to(device),
    )
pred_tracks = results["traj_e"].cpu()  # [T,N,3]
pred_vis = results["vis_e"].cpu()  # [T,N]
print(pred_tracks.shape, pred_vis.shape)
```

Alternatively, you can run our interactive demo:

```bash
python demo.py --rerun save --lightweight
```

By default this saves a lightweight `.rrd` recording (e.g., `mvtracker_demo.rrd`) that you can open in any Rerun viewer. The simplest option is to drag and drop the file into the [online viewer](https://app.rerun.io/version/0.21.0). For the best experience, you can also install Rerun locally (`pip install rerun-sdk==0.21.0; rerun`). Results can be explored interactively in the viewer with WASD/QE navigation, mouse rotation and zoom, and timeline playback controls.
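Whether you call the Hub predictor directly or run `demo.py`, query points are passed as rows of `(t, x, y, z)`: the frame index at which tracking should start, followed by the 3D location in the same world frame as the camera extrinsics (this is the format of `query_points_3d` above and of the queries built in `demo.py`). A minimal sketch with made-up coordinates:

```python
import torch

# Three hypothetical queries in world coordinates; values are illustrative only.
# Column 0 is the query frame index, columns 1-3 are the (x, y, z) location.
query_points = torch.tensor([
    [0.0,  0.5, -0.2, 1.0],   # tracked from frame 0
    [0.0, -1.0,  0.3, 0.8],
    [5.0,  0.0,  0.0, 1.5],   # tracked from frame 5 onward
])

# Then pass them exactly like the demo queries:
# results = mvtracker(..., query_points_3d=query_points[None].to(device))
```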
<details>
<summary>Interactive viewer on a cluster or with GUI support - click to expand</summary>

If you are working on a cluster, you can stream results directly to your laptop by forwarding a port (`ssh -R 9876:localhost:9876 user@cluster`) and then running the demo in streaming mode (`python demo.py --rerun stream`), which sends live data into your local Rerun instance. If you are running the demo locally with GUI support, you can automatically spawn a Rerun window (`python demo.py --rerun spawn`).

</details>
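The two optional packages from the Quick Start only affect speed, so a failed build is not fatal. If you are unsure whether they were installed correctly for your CUDA toolchain, a quick import check is enough (a sanity-check sketch; `flash_attn` and `pointops` are the module names these packages provide):

```python
# Check whether the optional acceleration packages are importable.
# MVTracker still runs without them, just more slowly.
for module in ("flash_attn", "pointops"):
    try:
        __import__(module)
        print(f"{module}: OK")
    except ImportError as exc:
        print(f"{module}: not available ({exc})")
```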
## Installation

You can use a pretrained model directly via **PyTorch Hub** (see Quick Start above), or clone this repo if you want to run our demo, evaluation, or training. We recommend using **PyTorch with CUDA** for best performance. CPU-only runs are possible but very slow.

```bash
git clone https://github.com/ethz-vlg/mvtracker.git
cd mvtracker
```

To extend the conda environment from the Quick Start to support training and evaluation, install the full requirements by running `pip install -r requirements.full.txt`. Baselines based on SpatialTracker V1 also require cupy:

```bash
pip install tensorflow==2.12.1 tensorflow-datasets tensorflow-graphics tensorboard
pip install cupy-cuda12x==12.2.0
python -m cupyx.tools.install_library --cuda 12.x --library cutensor
python -m cupyx.tools.install_library --cuda 12.x --library nccl
python -m cupyx.tools.install_library --cuda 12.x --library cudnn
```

## Datasets

To benchmark multi-view 3D point tracking, we provide preprocessed versions of three datasets:

- **MV-Kubric**: a synthetic training dataset adapted from single-view Kubric into a multi-view setting.
- **Panoptic Studio**: evaluation benchmark with real-world activities such as basketball, juggling, and toy play (10 sequences).
- **DexYCB**: evaluation benchmark with real-world hand–object interactions (10 sequences).
<details>
<summary>Downloading our preprocessed datasets - click to expand</summary>

You can download and extract them as (~72 GB after extraction):

```bash
# MV-Kubric (simulated + DUSt3R depths)
wget https://huggingface.co/datasets/ethz-vlg/mv3dpt-datasets/resolve/main/kubric-multiview--test.tar.gz -P datasets/
wget https://huggingface.co/datasets/ethz-vlg/mv3dpt-datasets/resolve/main/kubric-multiview--test--dust3r-depth.tar.gz -P datasets/
tar -xvzf datasets/kubric-multiview--test.tar.gz -C datasets/
tar -xvzf datasets/kubric-multiview--test--dust3r-depth.tar.gz -C datasets/
rm datasets/kubric-multiview*.tar.gz

# Panoptic Studio (optimization-based depth from Dynamic3DGS)
wget https://huggingface.co/datasets/ethz-vlg/mv3dpt-datasets/resolve/main/panoptic-multiview.tar.gz -P datasets/
tar -xvzf datasets/panoptic-multiview.tar.gz -C datasets/
rm datasets/panoptic-multiview.tar.gz

# DexYCB (Kinect + DUSt3R depths)
wget https://huggingface.co/datasets/ethz-vlg/mv3dpt-datasets/resolve/main/dex-ycb-multiview.tar.gz -P datasets/
wget https://huggingface.co/datasets/ethz-vlg/mv3dpt-datasets/resolve/main/dex-ycb-multiview--dust3r-depth.tar.gz -P datasets/
tar -xvzf datasets/dex-ycb-multiview.tar.gz -C datasets/
tar -xvzf datasets/dex-ycb-multiview--dust3r-depth.tar.gz -C datasets/
rm datasets/dex-ycb-multiview*.tar.gz

# $ du -sch datasets/*
# 31G   kubric-multiview
# 13G   panoptic-multiview
# 29G   dex-ycb-multiview
# 72G   total
```

</details>
<details>
<summary>Regenerating datasets from scratch - click to expand</summary>

If you wish to regenerate datasets from scratch, we provide scripts with docstrings that explain usage and list the commands we used. For licensing and usage terms, please refer to the original datasets.

- MV-Kubric data for training and testing can be generated with [ethz-vlg/kubric](https://github.com/ethz-vlg/kubric/blob/multiview-point-tracking/challenges/point_tracking_3d/worker.py).
- DexYCB can be downloaded and labels regenerated using [`scripts/dex_ycb_to_neus_format.py`](./scripts/dex_ycb_to_neus_format.py); note that we have created labels for 10 sequences, but DexYCB is much larger and more labels could be produced if needed.
- Panoptic Studio can be downloaded and labels regenerated using [`scripts/panoptic_studio_preprocessing.py`](./scripts/panoptic_studio_preprocessing.py).
- DUSt3R depths can be produced for any dataset with [`scripts/estimate_depth_with_duster.py`](./scripts/estimate_depth_with_duster.py).
- For unlabeled datasets used only in qualitative experiments, we provide the following preprocessing scripts: [4D-Dress](./scripts/4ddress_preprocessing.py), [Hi4D](./scripts/hi4d_preprocessing.py), [EgoExo4D](./scripts/egoexo4d_preprocessing.py), and [SelfCap](./scripts/selfcap_preprocessing.py).

</details>
For quick testing, we also release a small **demo sample** (~200 MB):

```bash
python demo.py --random_query_points
```

Our generic loader [`GenericSceneDataset`](./mvtracker/datasets/generic_scene_dataset.py) supports adding new datasets. It can compute depths on the fly with [DUSt3R](https://github.com/naver/dust3r), [VGGT](https://vgg-t.github.io), [MonoFusion](https://imnotprepared.github.io/research/25_DSR/index.html), or [MoGe-2](https://github.com/microsoft/MoGe), and can also estimate camera poses with VGGT.

## Evaluation

Evaluation is driven by Hydra configs. See [`mvtracker/cli/eval.py`](./mvtracker/cli/eval.py) and [`configs/eval.yaml`](./configs/eval.yaml) for details.

To evaluate MVTracker with our best model, first download the checkpoint from [Hugging Face](https://huggingface.co/ethz-vlg/mvtracker):

```bash
wget https://huggingface.co/ethz-vlg/mvtracker/resolve/main/mvtracker_200000_june2025.pth -P checkpoints/
```

Then run:

```bash
python -m mvtracker.cli.eval \
  experiment_path=logs/mvtracker \
  model=mvtracker \
  datasets.eval.names=[kubric-multiview-v3-views0123] \
  restore_ckpt_path=checkpoints/mvtracker_200000_june2025.pth

# Expected result:
# {
#   "eval_kubric-multiview-v3-views0123/model__ate_visible__dynamic-static-mean": 5.07,
#   "eval_kubric-multiview-v3-views0123/model__average_jaccard__dynamic-static-mean": 81.42,
#   "eval_kubric-multiview-v3-views0123/model__average_pts_within_thresh__dynamic-static-mean": 90.00
# }
```

To evaluate a baseline, e.g. CoTracker3-Online (auto-downloaded checkpoint), run:

```bash
python -m mvtracker.cli.eval experiment_path=logs/cotracker3-online model=cotracker3_online

# Expected result:
# {
#   "eval_panoptic-multiview-views1_7_14_20/model__average_jaccard__any": 74.56
# }
```

For more baselines and dataset setups (e.g. varying camera counts, camera subsets, etc.), see [`scripts/slurm/eval.sh`](./scripts/slurm/eval.sh) for the commands used in our experiments.
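Dataset names can also select a different depth source. For example, the `-duster` variants listed in `configs/train.yaml` use DUSt3R-estimated instead of sensor/simulated depth; assuming the checkpoint downloaded above, a command like the following evaluates MVTracker with estimated depth on Kubric and DexYCB (an illustrative override; the exact setups we report numbers for are in `scripts/slurm/eval.sh`):

```bash
python -m mvtracker.cli.eval \
  experiment_path=logs/mvtracker \
  model=mvtracker \
  datasets.eval.names=[kubric-multiview-v3-duster0123,dex-ycb-multiview-duster0123] \
  restore_ckpt_path=checkpoints/mvtracker_200000_june2025.pth
```

The expandable section below describes the full set of dataset-name modifiers.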
<details>
<summary>Details on evaluation parameters - click to expand</summary>

The evaluation datasets are specified with `datasets.eval.names`. Each name is parsed by the dataset `from_name()` factory (see e.g. [`DexYCBMultiViewDataset.from_name`](./mvtracker/datasets/dexycb_multiview_dataset.py)), which supports modifiers such as `-views`, `-duster`, `-novelviews`, `-removehand`, `-2dpt`, or `-cached`. This makes it easy to select subsets of cameras, enable different depth sources, or ensure deterministic track sampling.

The main labeled benchmarks are:

- **Kubric (synthetic)** — e.g. `kubric-multiview-v3-views0123`
- **Panoptic Studio (real)** — e.g. `panoptic-multiview-views1_7_14_20`
- **DexYCB (real)** — e.g. `dex-ycb-multiview-views0123`

For reproducibility of our main results, we also provide *cached* variants of each benchmark, which freeze track selection exactly as used in our paper. Without `-cached`, random seeding ensures reproducibility, but cached versions guarantee identical tracks across environments. The following cached variants are included in the released datasets:

- `kubric-multiview-v3-views0123-cached`
- `kubric-multiview-v3-duster0123-cached`
- `panoptic-multiview-views1_7_14_20-cached`
- `panoptic-multiview-views27_16_14_8-cached`
- `panoptic-multiview-views1_4_7_11-cached`
- `dex-ycb-multiview-views0123-cached`
- `dex-ycb-multiview-duster0123-cached`

</details>
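For instance, to reproduce the exact track selection from the paper and to try a different Panoptic Studio camera subset, the cached names above can be passed directly (an illustrative combination of the `-cached` and `-views` modifiers, mirroring the evaluation command above):

```bash
python -m mvtracker.cli.eval \
  experiment_path=logs/mvtracker \
  model=mvtracker \
  datasets.eval.names=[kubric-multiview-v3-views0123-cached,panoptic-multiview-views1_4_7_11-cached] \
  restore_ckpt_path=checkpoints/mvtracker_200000_june2025.pth
```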
## Training

To run a small overfitting test that fits into 24 GB of GPU RAM:

```bash
python -m mvtracker.cli.train +experiment=mvtracker_overfit_mini
```

For a full-scale overfitting run on an 80 GB GPU:

```bash
python -m mvtracker.cli.train +experiment=mvtracker_overfit
```

The full training configuration is defined in [`configs/experiment/mvtracker.yaml`](./configs/experiment/mvtracker.yaml) and [`configs/train.yaml`](./configs/train.yaml) and is launched with `python -m mvtracker.cli.train +experiment=mvtracker`.

## Practical Considerations
<details>
<summary>Scene normalization - click to expand</summary>

Performance depends strongly on scene normalization. MVTracker was trained on Kubric with randomized but bounded scales and camera setups. At test time, scenes with very different scales, rotations, or translations must be aligned to this distribution. Our generic loader provides an automatic normalization that assumes the ground plane is parallel to the XY plane. This automatic normalization worked reasonably well for 4D-Dress, Hi4D, EgoExo4D, and SelfCap. For Panoptic and DexYCB, we applied manual similarity transforms, which are encoded in the respective dataloaders. Robust, general-purpose normalization remains an open challenge.

</details>
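The manual similarity transforms mentioned in the box above amount to remapping world coordinates and updating extrinsics and depths consistently. Below is a minimal sketch of one consistent convention (an illustration only, not the repo's actual implementation; the real transforms for Panoptic Studio and DexYCB are hard-coded in the respective dataloaders, and intrinsics are unaffected):

```python
import torch

def apply_similarity_transform(points_xyz, extrs, depths, s, R, t):
    """Map a scene into a normalized frame via X' = s * (R @ X) + t.

    Assumed shapes: points_xyz [N, 3] (world points or 3D queries),
    extrs [V, T, 3, 4] (world-to-camera [R_c | t_c]), depths [V, T, 1, H, W],
    R [3, 3] rotation, t [3] translation, s a scalar scale.
    """
    # World points and 3D query locations are transformed directly.
    points_xyz = s * points_xyz @ R.T + t

    # Choose the new camera frame to be the old one uniformly scaled by s,
    # i.e. X_cam' = s * X_cam. Substituting X = R^T (X' - t) / s gives
    #   R_c' = R_c R^T   and   t_c' = s * t_c - R_c R^T t.
    R_c, t_c = extrs[..., :3], extrs[..., 3]
    R_c_new = R_c @ R.T
    t_c_new = s * t_c - (R_c_new @ t)
    extrs = torch.cat([R_c_new, t_c_new[..., None]], dim=-1)

    # Depth is a camera-space z coordinate, so it scales by s as well.
    depths = s * depths

    return points_xyz, extrs, depths
```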
<details>
<summary>Challenges and future directions - click to expand</summary>

The central challenge in multi-view 3D point tracking is 4D reconstruction: obtaining depth maps that are accurate, temporally consistent, and available in real time, especially under sparse-view setups. MVTracker performs well when sensor depth and camera calibration are provided, but in settings where both must be estimated, errors in reconstruction quickly make tracking unreliable. While learned motion priors help tolerate moderate noise, they cannot replace a robust reconstruction backbone. We believe progress will hinge on methods that jointly solve depth estimation and tracking for mutual refinement, or large-scale foundation models for 4D reconstruction and tracking that fully leverage data and compute. We hope the community will direct future efforts toward this goal.

</details>
## Acknowledgements

Our code builds upon and was inspired by many prior works, including [SpaTracker](https://github.com/henry123-boy/SpaTracker), [CoTracker](https://github.com/facebookresearch/co-tracker), and [DUSt3R](https://github.com/naver/dust3r). We thank the authors for releasing their code and pretrained models. We are also grateful to the maintainers of [Rerun](https://rerun.io) for their helpful visualization toolkit.

## Citation

If you find our repository useful, please consider giving it a star ⭐ and citing our work:

```bibtex
@inproceedings{rajic2025mvtracker,
  title     = {Multi-View 3D Point Tracking},
  author    = {Raji{\v{c}}, Frano and Xu, Haofei and Mihajlovic, Marko and Li, Siyuan and Demir, Irem and G{\"u}ndo{\u{g}}du, Emircan and Ke, Lei and Prokudin, Sergey and Pollefeys, Marc and Tang, Siyu},
  booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
  year      = {2025}
}
```

================================================
FILE: configs/eval.yaml
================================================
defaults:
  - train
  - _self_

modes:
  eval_only: true

trainer:
  precision: 32-true

# Optional overrides specific to evaluation runs
datasets:
  eval:
    names: [ "panoptic-multiview-views1_7_14_20" ]
    max_seq_len: 1000

evaluation:
  consume_model_stats: false  # whether to report model stats (which can slow down the forward pass)
  evaluator:
    rerun_viz_indices: null
    forward_pass_log_indices: null
    mp4_track_viz_indices: null

    # rerun_viz_indices: [ 0,1,2 ]
    # forward_pass_log_indices: [ 0,1,2 ]
    # mp4_track_viz_indices: [ 0,1,2 ]

    # rerun_viz_indices: [ 0,3,27, 2,23 ]
    # forward_pass_log_indices: null
    # mp4_track_viz_indices: [ 0,3,27, 2,23 ]

    # rerun_viz_indices: [ 0, 7 ]
    # forward_pass_log_indices: null
    # mp4_track_viz_indices: [ 0, 7 ]

    # rerun_viz_indices: [ 0, 5 ]
    # forward_pass_log_indices: null
    # mp4_track_viz_indices: [ 0, 5 ]

    # rerun_viz_indices: [ 0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29 ]
    # forward_pass_log_indices: [ 0,1,2,3,4 ]
    # mp4_track_viz_indices: [ 0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29 ]

================================================
FILE: configs/experiment/mvtracker.yaml
================================================
# @package _global_

defaults:
  - override /model: mvtracker

experiment_path: ./logs/mvtracker

================================================
FILE: configs/experiment/mvtracker_overfit.yaml
================================================
# @package _global_

defaults:
  - override /model: mvtracker

experiment_path: ./logs/debug/mvtracker-overfit

datasets:
  root: ./datasets
  train:
    name: kubric-multiview-v3-views0123-training
    batch_size: 1
    sequence_len: 24
    traj_per_sample: 512
    num_workers: 4
  eval:
    names: [kubric-multiview-v3-views0123-overfit-on-training]
    num_workers: 2
    max_seq_len: 1000

trainer:
  num_steps: 1500
  eval_freq: 500
  viz_freq: 500
  save_ckpt_freq: 500
  augment_train_iters: false

augmentations:
  probability: 1.0
  rgb: false
  depth: false
  cropping: true
  variable_trajpersample: false
  scene_transform: false
  camera_params_noise: false
  variable_depth_type: false
  variable_num_views: false

modes:
  tune_per_scene: true
  dont_validate_at_start: true
  do_initial_static_pretrain: false
  pretrain_only: false
  eval_only: false
  debug: false

================================================
FILE: configs/experiment/mvtracker_overfit_mini.yaml
================================================
# @package _global_

defaults:
  - mvtracker_overfit

experiment_path:
./logs/debug/mvtracker-overfit-mini datasets: train: traj_per_sample: 8 model: fmaps_dim: 32 ================================================ FILE: configs/model/copycat.yaml ================================================ # @package _global_ defaults: - default model: _target_: mvtracker.models.core.copycat.CopyCat ================================================ FILE: configs/model/cotracker1_offline.yaml ================================================ # @package _global_ defaults: - default model: _target_: mvtracker.models.core.monocular_baselines.MonocularToMultiViewAdapter model: _target_: mvtracker.models.core.monocular_baselines.CoTrackerOfflineWrapper model_name: cotracker2v1 grid_size: 10 ================================================ FILE: configs/model/cotracker1_online.yaml ================================================ # @package _global_ defaults: - default model: _target_: mvtracker.models.core.monocular_baselines.MonocularToMultiViewAdapter model: _target_: mvtracker.models.core.monocular_baselines.CoTrackerOnlineWrapper model_name: cotracker2v1_online grid_size: 10 ================================================ FILE: configs/model/cotracker2_offline.yaml ================================================ # @package _global_ defaults: - default model: _target_: mvtracker.models.core.monocular_baselines.MonocularToMultiViewAdapter model: _target_: mvtracker.models.core.monocular_baselines.CoTrackerOfflineWrapper model_name: cotracker2 grid_size: 10 ================================================ FILE: configs/model/cotracker2_online.yaml ================================================ # @package _global_ defaults: - default model: _target_: mvtracker.models.core.monocular_baselines.MonocularToMultiViewAdapter model: _target_: mvtracker.models.core.monocular_baselines.CoTrackerOnlineWrapper model_name: cotracker2_online grid_size: 10 ================================================ FILE: configs/model/cotracker3_offline.yaml ================================================ # @package _global_ defaults: - default model: _target_: mvtracker.models.core.monocular_baselines.MonocularToMultiViewAdapter model: _target_: mvtracker.models.core.monocular_baselines.CoTrackerOfflineWrapper model_name: cotracker3_offline grid_size: 10 ================================================ FILE: configs/model/cotracker3_online.yaml ================================================ # @package _global_ defaults: - default model: _target_: mvtracker.models.core.monocular_baselines.MonocularToMultiViewAdapter model: _target_: mvtracker.models.core.monocular_baselines.CoTrackerOnlineWrapper model_name: cotracker3_online grid_size: 10 ================================================ FILE: configs/model/default.yaml ================================================ # @package _global_ model: _target_: ??? 
trainer: train_iters: 4 evaluation: eval_iters: 4 interp_shape: null predictor_settings: kubric: visibility_threshold: 0.9 grid_size: 0 n_grids_per_view: 1 local_grid_size: 0 local_extent: 50 sift_size: 0 num_uniformly_sampled_pts: 0 dex_ycb: visibility_threshold: 0.9 grid_size: 0 n_grids_per_view: 1 local_grid_size: 0 local_extent: 50 sift_size: 0 num_uniformly_sampled_pts: 0 panoptic: visibility_threshold: 0.9 grid_size: 0 n_grids_per_view: 1 local_grid_size: 0 local_extent: 50 sift_size: 0 num_uniformly_sampled_pts: 0 tapvid2d-davis: visibility_threshold: 0.9 grid_size: 0 n_grids_per_view: 1 local_grid_size: 0 local_extent: 50 sift_size: 0 num_uniformly_sampled_pts: 0 generic: visibility_threshold: 0.9 grid_size: 0 n_grids_per_view: 1 local_grid_size: 0 local_extent: 50 sift_size: 0 num_uniformly_sampled_pts: 0 ================================================ FILE: configs/model/delta.yaml ================================================ # @package _global_ defaults: - default model: _target_: mvtracker.models.core.monocular_baselines.MonocularToMultiViewAdapter model: _target_: mvtracker.models.core.monocular_baselines.DELTAWrapper ckpt: checkpoints/densetrack3d.pth upsample_factor: 4 grid_size: 20 return_2d_track: false ================================================ FILE: configs/model/locotrack.yaml ================================================ # @package _global_ defaults: - default model: _target_: mvtracker.models.core.monocular_baselines.MonocularToMultiViewAdapter model: _target_: mvtracker.models.core.monocular_baselines.LocoTrackWrapper model_size: base evaluation: interp_shape: [ 256, 256 ] ================================================ FILE: configs/model/mvtracker.yaml ================================================ # @package _global_ defaults: - default model: _target_: mvtracker.models.core.mvtracker.mvtracker.MVTracker sliding_window_len: 12 stride: 4 normalize_scene_in_fwd_pass: false fmaps_dim: 128 add_space_attn: true num_heads: 6 hidden_size: 256 space_depth: 6 time_depth: 6 num_virtual_tracks: 64 use_flash_attention: true corr_n_groups: 1 corr_n_levels: 4 corr_neighbors: 16 corr_add_neighbor_offset: true corr_add_neighbor_xyz: false corr_filter_invalid_depth: false # slower, but would make sure points with invalid depth are not considered in corr evaluation: interp_shape: [ 384, 512 ] predictor_settings: kubric: visibility_threshold: 0.5 grid_size: 4 local_grid_size: 18 dex_ycb: visibility_threshold: 0.01 grid_size: 4 local_grid_size: 18 panoptic: visibility_threshold: 0.01 grid_size: 6 local_grid_size: 18 tapvid2d-davis: visibility_threshold: 0.01 grid_size: 6 n_grids_per_view: 6 local_grid_size: 0 local_extent: 50 sift_size: 0 num_uniformly_sampled_pts: 0 generic: visibility_threshold: 0.01 grid_size: 4 local_grid_size: 18 trainer: precision: bf16-mixed ================================================ FILE: configs/model/scenetracker.yaml ================================================ # @package _global_ defaults: - default model: _target_: mvtracker.models.core.monocular_baselines.MonocularToMultiViewAdapter model: _target_: mvtracker.models.core.monocular_baselines.SceneTrackerWrapper ckpt: checkpoints/scenetracker-odyssey-200k.pth return_2d_track: false evaluation: interp_shape: [ 384, 512 ] ================================================ FILE: configs/model/spatialtrackerv2.yaml ================================================ # @package _global_ defaults: - default model: _target_: 
mvtracker.models.core.monocular_baselines.MonocularToMultiViewAdapter model: _target_: mvtracker.models.core.monocular_baselines.SpaTrackerV2Wrapper model_type: online # or offline, whichever is better on a specific dataset vo_points: 756 evaluation: predictor_settings: kubric: visibility_threshold: 0.01 dex_ycb: visibility_threshold: 0.01 panoptic: visibility_threshold: 0.01 ================================================ FILE: configs/model/spatracker_monocular.yaml ================================================ # @package _global_ defaults: - default model: _target_: mvtracker.models.core.spatracker.spatracker_monocular.SpaTrackerMultiViewAdapter sliding_window_len: 12 stride: 4 add_space_attn: true num_heads: 8 hidden_size: 384 space_depth: 6 time_depth: 6 triplane_zres: 128 evaluation: interp_shape: [ 512, 512 ] # This checkpoint was trained on 512x512 Kubric sequences predictor_settings: kubric: visibility_threshold: 0.5 grid_size: 4 local_grid_size: 18 dex_ycb: visibility_threshold: 0.5 grid_size: 0 local_grid_size: 18 panoptic: visibility_threshold: 0.5 grid_size: 4 local_grid_size: 18 #restore_ckpt_path: checkpoints/spatracker_monocular_trained-on-kubric-depth_069800.pth #restore_ckpt_path: checkpoints/spatracker_monocular_trained-on-duster-depth_090800.pth ================================================ FILE: configs/model/spatracker_monocular_pretrained.yaml ================================================ # @package _global_ defaults: - default model: _target_: mvtracker.models.core.spatracker.spatracker_monocular.SpaTrackerMultiViewAdapter sliding_window_len: 12 stride: 4 add_space_attn: true num_heads: 8 hidden_size: 384 space_depth: 6 time_depth: 6 triplane_zres: 128 evaluation: interp_shape: [ 384, 512 ] predictor_settings: kubric: visibility_threshold: 0.9 grid_size: 4 local_grid_size: 18 dex_ycb: visibility_threshold: 0.9 grid_size: 4 local_grid_size: 18 panoptic: visibility_threshold: 0.9 grid_size: 4 local_grid_size: 18 #restore_ckpt_path: checkpoints/spatracker_monocular_original-authors-ckpt.pth ================================================ FILE: configs/model/spatracker_multiview.yaml ================================================ # @package _global_ defaults: - default model: _target_: mvtracker.models.core.spatracker.spatracker_multiview.MultiViewSpaTracker sliding_window_len: 12 stride: 4 add_space_attn: true use_3d_pos_embed: true remove_zeromlpflow: true concat_triplane_features: true num_heads: 8 hidden_size: 384 space_depth: 6 time_depth: 6 fmaps_dim: 128 triplane_xres: 128 triplane_yres: 128 triplane_zres: 128 evaluation: interp_shape: [ 512, 512 ] # This checkpoint was trained on 512x512 Kubric sequences predictor_settings: kubric: visibility_threshold: 0.5 grid_size: 4 local_grid_size: 18 dex_ycb: visibility_threshold: 0.01 grid_size: 4 local_grid_size: 18 panoptic: visibility_threshold: 0.01 grid_size: 4 local_grid_size: 18 #restore_ckpt_path: checkpoints/spatracker_multiview_trained-on-kubric-depth_100000.pth #model: # triplane_xres: 128 # triplane_yres: 128 # triplane_zres: 128 #restore_ckpt_path: checkpoints/spatracker_multiview_trained-on-duster-depth_100000.pth #model: # triplane_xres: 256 # triplane_yres: 256 # triplane_zres: 128 ================================================ FILE: configs/model/tapip3d.yaml ================================================ # @package _global_ defaults: - default model: _target_: mvtracker.models.core.monocular_baselines.MonocularToMultiViewAdapter model: _target_: 
mvtracker.models.core.monocular_baselines.TAPIP3DWrapper ckpt: checkpoints/tapip3d_final.pth num_iters: 6 grid_size: 8 resolution_factor: 1 # --> [ 384, 512 ] # resolution_factor: 2 # --> [ 543, 724 ] evaluation: interp_shape: [ 384, 512 ] # --> resolution_factor = 1 # interp_shape: [ 543, 724 ] # --> resolution_factor = 2 predictor_settings: kubric: visibility_threshold: 0.01 dex_ycb: visibility_threshold: 0.01 panoptic: visibility_threshold: 0.01 ================================================ FILE: configs/train.yaml ================================================ defaults: - _self_ - model: mvtracker experiment_path: ??? # where to store checkpoints, visualizations, etc. restore_ckpt_path: null # resume from checkpoint # === Datasets === datasets: root: ./datasets train: name: kubric-multiview-v3-training batch_size: 1 sequence_len: 24 # frames per sequence traj_per_sample: 2048 # number of 3D points/trajectories per sample max_videos: null # takes all training videos by default kubric_max_depth: 24 num_workers: 8 eval: names: - panoptic-multiview-views1_7_14_20 - kubric-multiview-v3-overfit-on-training - kubric-multiview-v3-views0123 - kubric-multiview-v3-duster0123 - dex-ycb-multiview - dex-ycb-multiview-duster0123 num_workers: 4 max_seq_len: 1000 # === Trainer Settings === trainer: num_steps: 200000 eval_freq: 10000 viz_freq: 10000 save_ckpt_freq: 500 lr: 0.0005 gamma: 0.8 wdecay: 0.00001 anneal_strategy: linear grad_clip: 1.0 precision: 16-mixed # training precision (e.g., 16-mixed, bf16-mixed or 32-true) visibility_loss_weight: 0.1 augment_train_iters: false augment_train_iters_warmup: 2000 # === Evaluation Settings === evaluation: consume_model_stats: false # whether to report model stats (which can slow down the forward pass) evaluator: _target_: mvtracker.evaluation.evaluator_3dpt.Evaluator rerun_viz_indices: null forward_pass_log_indices: null mp4_track_viz_indices: [0] # === Execution Modes === modes: debug: false # enable for quick iteration tune_per_scene: false # overfit to single scene (debugging) validate_at_start: false # run eval before train starts do_initial_static_pretrain: false # run static-only phase first pretrain_only: false # stop after static pretraining eval_only: false # skip training, just run evaluation debugging_hotfix_datapoint_path: null # path to a dumped datapoint (no need to set debug flag) # === Reproducibility === reproducibility: # Note that reproducibility will not work if # floating point precision is set to 16-mixed, # but with 32 it will. Note also that the number # of data loading workers (num_workers) might # affect reproducibility as well. The number of # GPUs surely affects reproducibility. 
seed: 36 deterministic: false # speeds up training at expense of determinism # === Augmentations === augmentations: probability: 0.8 rgb: true depth: true variable_depth_type: true variable_num_views: true cropping: true cropping_size: [384, 512] variable_vggt_crop_size: false keep_principal_point_centered: false variable_trajpersample: true scene_transform: true camera_params_noise: true normalize_scene_following_vggt: false # === Logging === logging: log_wandb: false wandb_project: mvtracker-ablation tags: ["kubric", "3dpt", "multiview"] # === Extras === extras: print_config: true # pretty print config tree at the start ignore_warnings: false # disable python warnings if they annoy you enable_faulthandler_traceback: false # enable traceback dump on timeout for debugging of main process hanging faulthandler_traceback_timeout: 600 # timeout in seconds before dumping traceback (e.g. 600 = 10 min) # === Hydra Settings === hydra: run: dir: ${experiment_path} ================================================ FILE: demo.py ================================================ import argparse import os import warnings import numpy as np import rerun as rr # pip install rerun-sdk==0.21.0 import torch from huggingface_hub import hf_hub_download from mvtracker.utils.visualizer_rerun import log_pointclouds_to_rerun, log_tracks_to_rerun def main(): p = argparse.ArgumentParser() p.add_argument( "--rerun", choices=["save", "spawn", "stream"], default="save", help=( "Whether to save recording to disk, spawn a new Rerun instance, or stream to an existing one. " "If 'spawn', make sure a rerun window can be spawned in your environment. " "If 'stream', make sure a rerun instance is running at port 9876. " "If 'save', the recording will be saved to a `.rrd` file that can be drag-and-dropped into " "a running rerun viewer, including the online viewer at https://app.rerun.io/version/0.21.0. " "For the online viewer, you want to create low memory-usage recordings with --lightweight." ), ) p.add_argument( "--lightweight", action="store_true", help=( "Use lightweight rerun logging (less memory usage). This is recommended if you want to " "view the recording in the online Rerun viewer at https://app.rerun.io/version/0.21.0." ), ) p.add_argument( "--random_query_points", action="store_true", help="Use random query points instead of demo ones.", ) p.add_argument( "--rrd", default="mvtracker_demo.rrd", help=( "Path to save a .rrd file if `--rerun save` is used. " "Note that rerun prefers recordings to have a .rrd suffix." 
), ) args = p.parse_args() np.random.seed(72) torch.manual_seed(72) device = "cuda" if torch.cuda.is_available() else "cpu" # Load MVTracker predictor mvtracker = torch.hub.load("ethz-vlg/mvtracker", "mvtracker", pretrained=True, device=device) # Download demo sample from Hugging Face Hub sample_path = hf_hub_download( repo_id="ethz-vlg/mvtracker", filename="data_sample.npz", token=os.getenv("HF_TOKEN"), repo_type="model", ) sample = np.load(sample_path) rgbs = torch.from_numpy(sample["rgbs"]).float() depths = torch.from_numpy(sample["depths"]).float() intrs = torch.from_numpy(sample["intrs"]).float() extrs = torch.from_numpy(sample["extrs"]).float() query_points = torch.from_numpy(sample["query_points"]).float() # Optionally, sample random queries in a cylinder of radius 12, height [-1, +10] and replace the demo queries if args.random_query_points: from mvtracker.models.core.model_utils import init_pointcloud_from_rgbd num_queries = 512 t0 = 0 xy_radius = 12.0 z_min, z_max = -1.0, 10.0 xyz, _ = init_pointcloud_from_rgbd( fmaps=rgbs[None], # [1,V,T,1,H,W], uint8 0–255 depths=depths[None], # [1,V,T,1,H,W] intrs=intrs[None], # [1,V,T,3,3] extrs=extrs[None], # [1,V,T,3,4] stride=1, level=0, ) pts = xyz[t0] # [V*H*W, 3] at t=0 assert pts.numel() > 0, "No valid depth points to sample queries from." r2 = pts[:, 0] ** 2 + pts[:, 1] ** 2 mask = (r2 <= xy_radius ** 2) & (pts[:, 2] >= z_min) & (pts[:, 2] <= z_max) pool = pts[mask] assert pool.shape[0] > 0, "Cylinder mask removed all points; increase radius or z-range." idx = torch.randperm(pool.shape[0])[:num_queries] pts = pool[idx] ts = torch.full((pts.shape[0], 1), float(t0), device=pts.device) query_points = torch.cat([ts, pts], dim=1).float() # (N,4): (t,x,y,z) print(f"Sampled {pts.shape[0]} queries from depth at t={t0} within r<={xy_radius}, z∈[{z_min},{z_max}].") # Run prediction torch.set_float32_matmul_precision("high") amp_dtype = torch.bfloat16 if (device == "cuda" and torch.cuda.get_device_capability()[0] >= 8) else torch.float16 with torch.no_grad(), torch.cuda.amp.autocast(enabled=device == "cuda", dtype=amp_dtype): results = mvtracker( rgbs=rgbs[None].to(device) / 255.0, depths=depths[None].to(device), intrs=intrs[None].to(device), extrs=extrs[None].to(device), query_points_3d=query_points[None].to(device), ) pred_tracks = results["traj_e"].cpu() # [T,N,3] pred_vis = results["vis_e"].cpu() # [T,N] # Visualize results rr.init("3dpt", recording_id="v0.16") if args.rerun == "stream": rr.connect_tcp() elif args.rerun == "spawn": rr.spawn() log_pointclouds_to_rerun( dataset_name="demo", datapoint_idx=0, rgbs=rgbs[None], depths=depths[None], intrs=intrs[None], extrs=extrs[None], depths_conf=None, conf_thrs=[5.0], log_only_confident_pc=False, radii=-2.45, fps=12, bbox_crop=None, sphere_radius_crop=12.0, sphere_center_crop=np.array([0, 0, 0]), log_rgb_image=False, log_depthmap_as_image_v1=False, log_depthmap_as_image_v2=False, log_camera_frustrum=True, log_rgb_pointcloud=True, ) log_tracks_to_rerun( dataset_name="demo", datapoint_idx=0, predictor_name="MVTracker", gt_trajectories_3d_worldspace=None, gt_visibilities_any_view=None, query_points_3d=query_points[None], pred_trajectories=pred_tracks, pred_visibilities=pred_vis, per_track_results=None, radii_scale=1.0, fps=12, sphere_radius_crop=12.0, sphere_center_crop=np.array([0, 0, 0]), log_per_interval_results=False, max_tracks_to_log=100 if args.lightweight else None, track_batch_size=50, method_id=None, color_per_method_id=None, memory_lightweight_logging=args.lightweight, ) if args.rerun == 
"save": rr.save(args.rrd) print(f"Saved Rerun recording to: {os.path.abspath(args.rrd)}") if __name__ == "__main__": warnings.filterwarnings("ignore", message=".*DtypeTensor constructors are no longer.*", module="pointops.query") warnings.filterwarnings("ignore", message=".*Plan failed with a cudnnException.*", module="torch.nn.modules.conv") main() ================================================ FILE: hubconf.py ================================================ # Copyright (c) ETH VLG. # Licensed under the terms in the LICENSE file at the root of this repo. from pathlib import Path import os import torch _WEIGHTS = { "mvtracker_main": "hf://ethz-vlg/mvtracker::mvtracker_200000_june2025.pth", "mvtracker_cleandepth": "hf://ethz-vlg/mvtracker::mvtracker_200000_june2025_cleandepth.pth", } def _load_ckpt(spec: str): if spec.startswith("http"): return torch.hub.load_state_dict_from_url(spec, map_location="cpu") if spec.startswith("hf://"): from huggingface_hub import hf_hub_download repo_id, filename = spec[len("hf://"):].split("::", 1) path = hf_hub_download(repo_id=repo_id, filename=filename, token=os.getenv("HF_TOKEN")) return torch.load(path, map_location="cpu") path = Path(spec).expanduser().resolve() return torch.load(str(path), map_location="cpu") def _extract_model_state(sd): """ Accept: - plain state dict - {'state_dict': ...} - {'model': ..., 'optimizer': ..., 'scheduler': ..., 'total_steps': ...} Returns a clean model state_dict. """ if isinstance(sd, dict): if "state_dict" in sd and isinstance(sd["state_dict"], dict): sd = sd["state_dict"] elif "model" in sd and isinstance(sd["model"], dict): sd = sd["model"] # Strip optional "model." prefix sd = {k.replace("model.", "", 1): v for k, v in sd.items()} return sd def _build_model(**overrides): from mvtracker.models.core.mvtracker.mvtracker import MVTracker cfg = dict( sliding_window_len=12, stride=4, normalize_scene_in_fwd_pass=False, fmaps_dim=128, add_space_attn=True, num_heads=6, hidden_size=256, space_depth=6, time_depth=6, num_virtual_tracks=64, use_flash_attention=True, corr_n_groups=1, corr_n_levels=4, corr_neighbors=16, corr_add_neighbor_offset=True, corr_add_neighbor_xyz=False, corr_filter_invalid_depth=False, ) cfg.update(overrides) return MVTracker(**cfg) def _load_into(model, checkpoint_key: str): raw = _load_ckpt(_WEIGHTS[checkpoint_key]) sd = _extract_model_state(raw) missing, unexpected = model.load_state_dict(sd, strict=False) if unexpected: raise RuntimeError(f"Unexpected keys in state_dict: {unexpected}") return model def mvtracker_model(*, pretrained: bool = False, device: str = "cuda", checkpoint: str = "mvtracker_main", **model_kwargs): """ Return a bare MVTracker nn.Module. - pretrained=False: random init with model_kwargs. - pretrained=True : load from _WEIGHTS[checkpoint], then .eval(). """ model = _build_model(**model_kwargs).to(device) if pretrained: model = _load_into(model, checkpoint) model.eval() return model def mvtracker_predictor(*, pretrained: bool = True, device: str = "cuda", checkpoint: str = "mvtracker_main", model_kwargs: dict | None = None, predictor_kwargs: dict | None = None): """ Return EvaluationPredictor wrapped around MVTracker. Pass model configuration via `model_kwargs={...}` (matches MVTracker.__init__). 
Pass predictor configuration via `predictor_kwargs={...}`: - interp_shape, visibility_threshold, grid_size, n_grids_per_view, local_grid_size, local_extent, sift_size, num_uniformly_sampled_pts, n_iters """ from mvtracker.models.evaluation_predictor_3dpt import EvaluationPredictor model_kwargs = {} if model_kwargs is None else dict(model_kwargs) predictor_kwargs = {} if predictor_kwargs is None else dict(predictor_kwargs) predictor_defaults = dict( interp_shape=(384, 512), visibility_threshold=0.5, grid_size=4, n_grids_per_view=1, local_grid_size=18, local_extent=50, sift_size=0, num_uniformly_sampled_pts=0, n_iters=6, ) pk = {**predictor_defaults, **predictor_kwargs} model = mvtracker_model(pretrained=pretrained, device=device, checkpoint=checkpoint, **model_kwargs) return EvaluationPredictor(multiview_model=model, **pk) def mvtracker(pretrained: bool = True, device: str = "cuda"): """Default public endpoint: predictor with main checkpoint.""" return mvtracker_predictor(pretrained=pretrained, device=device, checkpoint="mvtracker_main") def mvtracker_cleandepth(pretrained: bool = True, device: str = "cuda"): """Predictor with 'clean depth only' checkpoint.""" return mvtracker_predictor(pretrained=pretrained, device=device, checkpoint="mvtracker_cleandepth") ================================================ FILE: mvtracker/__init__.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. ================================================ FILE: mvtracker/cli/__init__.py ================================================ ================================================ FILE: mvtracker/cli/eval.py ================================================ import hydra from omegaconf import DictConfig from mvtracker.cli.train import main as train_main @hydra.main(version_base="1.3", config_path="../../configs", config_name="eval") def main(cfg: DictConfig): train_main(cfg) if __name__ == "__main__": main() ================================================ FILE: mvtracker/cli/train.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. 
import torch torch.set_float32_matmul_precision('high') from lightning.fabric.wrappers import _unwrap_objects from mvtracker.datasets.generic_scene_dataset import GenericSceneDataset from torch.utils.tensorboard import SummaryWriter import gpustat import json import threading import warnings from pathlib import Path import hydra import numpy as np import pandas as pd import torch.optim as optim import wandb from lightning.fabric import Fabric from lightning.fabric.utilities import AttributeDict from omegaconf import DictConfig, OmegaConf from torch import nn from torch.utils.data import DataLoader from tqdm import tqdm import signal, sys from mvtracker.datasets import KubricMultiViewDataset from mvtracker.datasets import TapVidDataset from mvtracker.datasets import kubric_multiview_dataset from mvtracker.datasets.dexycb_multiview_dataset import DexYCBMultiViewDataset from mvtracker.datasets.panoptic_studio_multiview_dataset import PanopticStudioMultiViewDataset from mvtracker.datasets.utils import collate_fn, dataclass_to_cuda_ from mvtracker.models.core.losses import balanced_ce_loss, sequence_loss_3d from mvtracker.models.core.model_utils import world_space_to_pixel_xy_and_camera_z, pixel_xy_and_camera_z_to_world_space from mvtracker.models.evaluation_predictor_3dpt import EvaluationPredictor as EvaluationPredictor3D from mvtracker.utils.visualizer_mp4 import MultiViewVisualizer, Visualizer from mvtracker.cli.utils import extras from mvtracker.cli.utils.helpers import maybe_close_wandb import logging import os import torch import time from collections import deque from torchdata.stateful_dataloader import StatefulDataLoader def fetch_optimizer(trainer_cfg, model): """Create the optimizer and learning rate scheduler""" optimizer = optim.AdamW(model.parameters(), lr=trainer_cfg.lr, weight_decay=trainer_cfg.wdecay) if trainer_cfg.anneal_strategy in ["linear", "cos"]: scheduler = optim.lr_scheduler.OneCycleLR( optimizer, trainer_cfg.lr, trainer_cfg.num_steps + 100, pct_start=0.05, cycle_momentum=False, anneal_strategy=trainer_cfg.anneal_strategy, ) elif trainer_cfg.anneal_strategy == "restarts": scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts( optimizer, T_0=5000, T_mult=1, eta_min=trainer_cfg.lr / 1000, ) return optimizer, scheduler def forward_batch_multi_view(batch, model, cfg, step, train_iters, gamma, save_debug_logs=False, debug_logs_path=''): # Per view data rgbs = batch.video depths = batch.videodepth image_features = batch.feats intrs = batch.intrs extrs = batch.extrs gt_trajectories_2d_pixelspace_w_z_cameraspace = batch.trajectory gt_visibilities_per_view = batch.visibility query_points_3d = batch.query_points_3d # Non-per-view data gt_trajectories_3d_worldspace = batch.trajectory_3d valid_tracks_per_frame = batch.valid track_upscaling_factor = batch.track_upscaling_factor batch_size, num_views, num_frames, _, height, width = rgbs.shape num_points = gt_trajectories_2d_pixelspace_w_z_cameraspace.shape[3] # Assert shapes of per-view data assert rgbs.shape == (batch_size, num_views, num_frames, 3, height, width) assert depths.shape == (batch_size, num_views, num_frames, 1, height, width) assert intrs.shape == (batch_size, num_views, num_frames, 3, 3) assert extrs.shape == (batch_size, num_views, num_frames, 3, 4) assert gt_trajectories_2d_pixelspace_w_z_cameraspace.shape == (batch_size, num_views, num_frames, num_points, 3) assert gt_visibilities_per_view.shape == (batch_size, num_views, num_frames, num_points) # Assert shapes of non-per-view data assert 
query_points_3d.shape == (batch_size, num_points, 4) assert gt_trajectories_3d_worldspace.shape == (batch_size, num_frames, num_points, 3) assert valid_tracks_per_frame.shape == (batch_size, num_frames, num_points) gt_visibilities_any_view = gt_visibilities_per_view.any(dim=1) assert gt_visibilities_any_view.any(dim=1).all(), "All points should be visible at in least one frame." for batch_idx in range(batch_size): for point_idx in range(num_points): t = query_points_3d[batch_idx, point_idx, 0].long().item() valid_tracks_per_frame[batch_idx, :t, point_idx] = False # Run the model results = model( rgbs=rgbs, depths=depths, image_features=image_features, query_points=query_points_3d, iters=train_iters, is_train=True, intrs=intrs, extrs=extrs, save_debug_logs=save_debug_logs, debug_logs_path=debug_logs_path, ) pred_trajectories = results["traj_e"] pred_visibilities = results["vis_e"] vis_predictions = results["train_data"]["vis_predictions"] coord_predictions = results["train_data"]["coord_predictions"] p_idx_end_list = results["train_data"]["p_idx_end_list"] sort_inds = results["train_data"]["sort_inds"] # Prepare the ground truth for the loss functions, # which expect the data to be in the sliding-window vis_gts = [] traj_gts = [] valids_gts = [] query_points_t_min = query_points_3d[:, :, 0].long().min() for i, wind_p_idx_end in enumerate(p_idx_end_list): gt_visibilities_any_view_sorted = gt_visibilities_any_view[:, :, sort_inds] gt_trajectories_3d_worldspace_sorted = gt_trajectories_3d_worldspace[:, :, sort_inds] valid_tracks_per_frame_sorted = valid_tracks_per_frame[:, :, sort_inds] ind = query_points_t_min + i * (cfg.model.sliding_window_len // 2) vis_gts.append(gt_visibilities_any_view_sorted[:, ind: ind + cfg.model.sliding_window_len, :wind_p_idx_end]) traj_gts.append( gt_trajectories_3d_worldspace_sorted[:, ind: ind + cfg.model.sliding_window_len, :wind_p_idx_end]) valids_gts.append(valid_tracks_per_frame_sorted[:, ind: ind + cfg.model.sliding_window_len, :wind_p_idx_end]) # Compute the losses logging.info(f"[DEBUG] " f"{step=} " f"{track_upscaling_factor=} " f"{coord_predictions[0][0][0, 0, 0]=} " f"{coord_predictions[-1][0][0, 0, 0]=} " f"{vis_predictions[0][0, 0, 0]=} " f"{vis_predictions[-1][0, 0, 0]=}") xyz_loss = sequence_loss_3d(coord_predictions, traj_gts, vis_gts, valids_gts, gamma) * track_upscaling_factor vis_loss = balanced_ce_loss(vis_predictions, vis_gts, valids_gts) # Compute 3DPT metrics # eval_3dpt_results_dict = evaluate_3dpt( # gt_tracks=gt_trajectories_3d_worldspace[0].cpu().numpy(), # gt_visibilities=gt_visibilities_any_view[0].cpu().numpy(), # pred_tracks=pred_trajectories[0].detach().cpu().numpy(), # pred_visibilities=(pred_visibilities[0] > 0.5).detach().cpu().numpy(), # evaluation_setting="kubric-multiview", # track_upscaling_factor=track_upscaling_factor, # prefix="train_3dpt", # verbose=False, # query_points=query_points_3d[0].cpu().numpy(), # ) # Invert the intrinsics and extrinsics matrices intrs_inv = torch.inverse(intrs.float()) extrs_square = torch.eye(4).to(extrs.device)[None].repeat(batch_size, num_views, num_frames, 1, 1) extrs_square[:, :, :, :3, :] = extrs extrs_inv = torch.inverse(extrs_square.float()) # Project the predictions to pixel space pred_trajectories = pred_trajectories[0].detach() pred_trajectories_pixel_xy_camera_z_per_view = torch.stack([ torch.cat(world_space_to_pixel_xy_and_camera_z( world_xyz=pred_trajectories, intrs=intrs[0, view_idx], extrs=extrs[0, view_idx], ), dim=-1) for view_idx in range(num_views) ], dim=0) for view_idx in 
range(num_views): pred_trajectories_reproduced = pixel_xy_and_camera_z_to_world_space( pixel_xy=pred_trajectories_pixel_xy_camera_z_per_view[view_idx, :, :, :2], camera_z=pred_trajectories_pixel_xy_camera_z_per_view[view_idx, :, :, 2:], intrs_inv=intrs_inv[0, view_idx], extrs_inv=extrs_inv[0, view_idx], ) if not torch.allclose(pred_trajectories_reproduced, pred_trajectories, atol=1): warnings.warn(f"Reprojection of the predicted trajectories failed: " f"view_idx={view_idx}, " f"max_diff={torch.max(torch.abs(pred_trajectories_reproduced - pred_trajectories))}") logging.info( f"{step=}, " f"seq={batch.seq_name}, " f"{xyz_loss.item()=}, " f"{vis_loss.item()=}, " ) output = { "flow": { "loss": xyz_loss * 1.0, "predictions": pred_trajectories_pixel_xy_camera_z_per_view, "predictions_worldspace": pred_trajectories, }, "visibility": { "loss": vis_loss * cfg.trainer.visibility_loss_weight, "predictions": pred_visibilities[0].detach(), }, # "metrics": { # k: v # for k, v in eval_3dpt_results_dict.items() # if "per_track" not in k # }, } return output def run_test_eval(cfg, evaluator, model, dataloaders, writer, step): if len(dataloaders) == 0: return logging.info(f"Eval – GPU usage A: {gpustat.new_query()}") log_dir = cfg.experiment_path model.eval() for ds_name, dataloader in dataloaders: if ds_name.startswith("kubric"): predictor_settings = cfg.evaluation.predictor_settings["kubric"] elif ds_name.startswith("dex-ycb"): predictor_settings = cfg.evaluation.predictor_settings["dex_ycb"] elif ds_name.startswith("panoptic"): predictor_settings = cfg.evaluation.predictor_settings["panoptic"] elif ds_name.startswith("tapvid2d-davis"): predictor_settings = cfg.evaluation.predictor_settings["tapvid2d-davis"] else: predictor_settings = cfg.evaluation.predictor_settings["generic"] logging.info(f"Using generic predictor settings for dataset with name {ds_name}") predictor = EvaluationPredictor3D( multiview_model=model, interp_shape=cfg.evaluation.interp_shape, single_point="single" in ds_name, n_iters=cfg.evaluation.eval_iters, **predictor_settings ) log_dir_ds = os.path.join(log_dir, f"eval_{ds_name}") os.makedirs(log_dir_ds, exist_ok=True) if cfg.evaluation.consume_model_stats and hasattr(model, "init_stats"): model.init_stats() metrics = evaluator.evaluate_sequence( model=predictor, test_dataloader=dataloader, dataset_name=ds_name, writer=writer, step=step, log_dir=log_dir_ds, ) if cfg.evaluation.consume_model_stats and hasattr(model, "consume_stats"): model.consume_stats() metrics_to_log = { k: np.nanmean([v[k] for v in metrics.values() if k in v]).round(2) for k in metrics[0].keys() } for k, v in metrics_to_log.items(): writer.add_scalar(k, v, step) with pd.option_context( 'display.max_rows', None, 'display.max_columns', None, 'display.max_colwidth', None, 'display.width', None, ): logging.info(f"Per-sequence Metrics for {ds_name}: {pd.DataFrame(metrics)}") logging.info(f"Average metrics for {ds_name}: {json.dumps(metrics_to_log, indent=4)}") # Save metrics to csv if log_dir_ds is not None: df = pd.DataFrame(metrics) df = df.T assert df.map(lambda x: (len(x) == 1) if isinstance(x, np.ndarray) else True).all().all() df = df.map(lambda x: x[0] if isinstance(x, np.ndarray) or isinstance(x, list) else x) df.to_csv(f"{log_dir_ds}/step-{step}_metrics.csv") df = pd.DataFrame(metrics_to_log, index=["score"]) df = df.T df.to_csv(f"{log_dir_ds}/step-{step}_metrics_avg.csv") logging.info(f"Saved metrics to {log_dir_ds}/step-{step}_metrics_avg.csv") # logging.info(f"Eval – GPU usage (after {ds_name}): 
{gpustat.new_query()}") # logging.info(f"Eval – GPU usage B: {gpustat.new_query()}") del predictor del metrics # logging.info(f"Eval – GPU usage C: {gpustat.new_query()}") torch.cuda.empty_cache() # logging.info(f"Eval – GPU usage D: {gpustat.new_query()}") model.train() def augment_train_iters(train_iters: int, current_step: int, warmup_steps: int = 1000) -> int: """ Adaptive iteration scheduler with warmup: - During warmup_steps: always return 1 - After warmup: - 10% chance: return 1 - 15% chance: return random int in [2, train_iters - 1] - 75% chance: return train_iters """ if current_step < warmup_steps or train_iters <= 1: return 1 rng = torch.Generator().manual_seed(current_step) p = torch.rand(1, generator=rng).item() if p < 0.10: return 1 elif p < 0.25 and train_iters > 2: mid_candidates = list(range(2, train_iters)) idx = torch.randint(len(mid_candidates), (1,), generator=rng).item() return mid_candidates[idx] else: return train_iters @hydra.main(version_base="1.3", config_path="../../configs", config_name="train.yaml") @maybe_close_wandb def main(cfg: DictConfig): """Main entry point for training. :param cfg: DictConfig configuration composed by Hydra. :return: Optional[float] with optimized metric value. """ extras(cfg) Path(cfg.experiment_path).mkdir(exist_ok=True, parents=True) num_nodes = int(os.environ.get("SLURM_JOB_NUM_NODES", 1)) devices = int(os.environ.get("SLURM_GPUS_PER_NODE", torch.cuda.device_count())) logging.info(f"SLURM job num nodes: {num_nodes}") logging.info(f"SLURM tasks per node (devices): {devices}") from lightning.fabric.strategies import DDPStrategy fabric = Fabric( num_nodes=num_nodes, devices=devices, precision=cfg.trainer.precision, strategy=DDPStrategy(find_unused_parameters=True), ) fabric.launch() fabric.seed_everything(cfg.reproducibility.seed, workers=True) if cfg.reproducibility.deterministic: torch.use_deterministic_algorithms(True) torch.backends.cudnn.benchmark = False torch.backends.cudnn.deterministic = True torch.autograd.set_detect_anomaly(True) if cfg.logging.get("log_wandb", False) and fabric.global_rank == 0: exp_name = cfg.experiment_path.replace("./logs/", "").replace("/", "_").replace("\\", "_") wandb.init( project=cfg.logging.wandb_project, name=exp_name, tags=cfg.logging.get("tags", []), config=OmegaConf.to_container(cfg, resolve=True), sync_tensorboard=True, ) original_numpy = torch.Tensor.numpy def patched_numpy(self, *args, **kwargs): if self.dtype == torch.bfloat16: return original_numpy(self.float(), *args, **kwargs) return original_numpy(self, *args, **kwargs) torch.Tensor.numpy = patched_numpy eval_dataloaders = [] for dataset_name in cfg.datasets.eval.names: if dataset_name.startswith("tapvid2d-davis-"): eval_dataset = TapVidDataset.from_name(dataset_name, cfg.datasets.root) elif dataset_name.startswith("kubric-multiview-v3-25views"): kubric_kwargs = { "data_root": os.path.join(cfg.datasets.root, "kubric_multiview_003", "kubric_25_view"), "seq_len": 24, "traj_per_sample": 200, "seed": 72, "sample_vis_1st_frame": True, "tune_per_scene": False, "max_videos": 30, "use_duster_depths": False, "duster_views": None, "clean_duster_depths": False, "views_to_return": list(range(20)), "novel_views": list(range(20, 25)), "num_views": -1, "depth_noise_std": 0, } eval_dataset = KubricMultiViewDataset(**kubric_kwargs) elif dataset_name.startswith("kubric-multiview-v3"): eval_dataset = KubricMultiViewDataset.from_name(dataset_name, cfg.datasets.root, cfg) elif dataset_name.startswith("panoptic-multiview"): eval_dataset = 
PanopticStudioMultiViewDataset.from_name(dataset_name, cfg.datasets.root) elif dataset_name.startswith("dex-ycb-multiview"): eval_dataset = DexYCBMultiViewDataset.from_name(dataset_name, cfg.datasets.root) elif dataset_name == "egoexo4d": eval_dataset = GenericSceneDataset( dataset_dir="datasets/egoexo4d-processed/maxframes-300_downsample-1_downscale-512/", drop_first_n_frames=44, ) elif dataset_name == "4d-dress": eval_dataset = GenericSceneDataset( dataset_dir="datasets/4d-dress-processed-resized-512-selection", use_duster_depths=False, ) elif dataset_name == "hi4d": eval_dataset = GenericSceneDataset( dataset_dir="datasets/hi4d-processed-resized-512", use_duster_depths=False, use_vggt_depths_with_aligned_cameras=True, ) elif dataset_name == "selfcap-v1": eval_dataset = GenericSceneDataset( dataset_dir="datasets/selfcap-processed/numcams-8-seq-False_startframe-90_maxframes-256_downsample-10_downscale-512/", drop_first_n_frames=72, ) elif dataset_name == "selfcap-v2": eval_dataset = GenericSceneDataset( dataset_dir="datasets/selfcap-processed/numcams-8-seq-True_startframe-90_maxframes-256_downsample-10_downscale-512/", drop_first_n_frames=72, ) elif dataset_name == "selfcap-v3": eval_dataset = GenericSceneDataset( dataset_dir="datasets/selfcap-processed/numcams-8-seq-False_startframe-90_maxframes-256_downsample-20_downscale-512/", drop_first_n_frames=36, ) elif dataset_name == "selfcap-v4": eval_dataset = GenericSceneDataset( dataset_dir="datasets/selfcap-processed/numcams-8-seq-False_startframe-90_maxframes-256_downsample-30_downscale-512/", drop_first_n_frames=24, ) elif dataset_name == "selfcap-v5": eval_dataset = GenericSceneDataset( dataset_dir="datasets/selfcap-processed/numcams-8-seq-False_startframe-90_maxframes-256_downsample-5_downscale-512/", drop_first_n_frames=144, ) elif dataset_name == "selfcap-v6": eval_dataset = GenericSceneDataset( dataset_dir="datasets/selfcap-processed/numcams-8-seq-False_startframe-90_maxframes-2560_downsample-10_downscale-512/", drop_first_n_frames=44, ) elif dataset_name == "selfcap-v7": eval_dataset = GenericSceneDataset( dataset_dir="datasets/selfcap-processed/numcams-4-seq-False_startframe-90_maxframes-256_downsample-10_downscale-512/", drop_first_n_frames=72, ) else: raise ValueError(f"Dataset {dataset_name} not supported for evaluation.") eval_dataloader = torch.utils.data.DataLoader( eval_dataset, batch_size=1, shuffle=False, num_workers=cfg.datasets.eval.num_workers, collate_fn=collate_fn, ) eval_dataloaders.append((dataset_name, eval_dataloader)) # # Let each rank handle a subset of the evaluation dataloaders # eval_dataloaders_for_rank = [] # for idx, (dset_name, dset_loader) in enumerate(eval_dataloaders): # if (idx % fabric.world_size) == fabric.global_rank: # eval_dataloaders_for_rank.append((dset_name, fabric.setup_dataloaders(dset_loader))) # eval_dataloaders = eval_dataloaders_for_rank train_viz_save_dir = os.path.join(cfg.experiment_path, f"train_{cfg.datasets.train.name}") os.makedirs(train_viz_save_dir, exist_ok=True) visualizer = MultiViewVisualizer( save_dir=train_viz_save_dir, pad_value=16, fps=12, show_first_frame=0, tracks_leave_trace=0, ) evaluator = hydra.utils.instantiate(cfg.evaluation.evaluator) if cfg.modes.do_initial_static_pretrain and not cfg.modes.eval_only: pretraining_datasets = [ kubric_multiview_dataset.KubricMultiViewDataset( data_root=os.path.join(cfg.datasets.root, "kubric_multiview_003", "train"), traj_per_sample=cfg.datasets.train.traj_per_sample, ratio_dynamic=0.1, ratio_very_dynamic=0.0, num_views=4, 
enable_cropping_augs=cfg.augmentations.cropping, seq_len=seq_len, static_cropping=static_cropping, max_videos=max_videos, ) for seq_len, static_cropping, max_videos in [ (12, True, 500), (18, True, 500), (24, True, 1000), (24, False, 2000), ] ] pretraining_dataset = torch.utils.data.ConcatDataset(pretraining_datasets) pretraining_dataloader = StatefulDataLoader( pretraining_dataset, batch_size=cfg.datasets.train.batch_size, shuffle=False, num_workers=cfg.datasets.train.num_workers, pin_memory=True, pin_memory_device="cuda", collate_fn=collate_fn, drop_last=True, in_order=cfg.reproducibility.deterministic, ) pretraining_dataloader = fabric.setup_dataloaders(pretraining_dataloader) else: pretraining_dataloader = None if cfg.modes.eval_only: train_dataset = None elif cfg.datasets.train.name.startswith("kubric-multiview-v3"): train_dataset = KubricMultiViewDataset.from_name(cfg.datasets.train.name, cfg.datasets.root, cfg, fabric) else: raise ValueError(f"Dataset {cfg.datasets.train.name} not supported for training") if not cfg.modes.eval_only: train_loader = StatefulDataLoader( train_dataset, batch_size=cfg.datasets.train.batch_size, shuffle=True, num_workers=cfg.datasets.train.num_workers, pin_memory=True, collate_fn=collate_fn, drop_last=True, prefetch_factor=4 if cfg.datasets.train.num_workers > 0 else None, in_order=cfg.reproducibility.deterministic, ) # eval_dataloaders += [("kubric-multiview-v3-training", train_loader)] train_loader = fabric.setup_dataloaders(train_loader) logging.info(f"LEN TRAIN LOADER={len(train_loader)}") num_epochs = cfg.trainer.num_steps // len(train_loader) + 1 + (1 if cfg.modes.do_initial_static_pretrain else 0) if cfg.modes.do_initial_static_pretrain: cfg.trainer.num_steps += len(pretraining_dataloader) else: train_loader = None num_epochs = None epoch = -1 total_steps = 0 model: nn.Module = hydra.utils.instantiate(cfg.model) model.cuda() optimizer, scheduler = fetch_optimizer(cfg.trainer, model) model, optimizer = fabric.setup(model, optimizer) folder_ckpts = [ f for f in os.listdir(cfg.experiment_path) if f.endswith(".pth") and not os.path.isdir(f) and not "final" in f and not "unwrap_model" in f and not "unwrap_module" in f ] logging.info(f"Found {len(folder_ckpts)} checkpoints: {folder_ckpts}") if len(folder_ckpts) > 0: # We can load this checkpoint directly since we have saved it during training ckpt_name = sorted(folder_ckpts)[-1] experiment_path = os.path.join(cfg.experiment_path, ckpt_name) state = AttributeDict( model=model, optimizer=optimizer, scheduler=scheduler, total_steps=total_steps, ) logging.info(f"Total steps before loading checkpoint: {total_steps}") fabric.load(experiment_path, state) total_steps = state.total_steps # Integers are immutable, so they cannot be changed inplace if train_loader is not None: epoch = total_steps // len(train_loader) - 1 logging.info(f"Loaded checkpoint {experiment_path} (total_steps={total_steps})") logging.info(f"Total steps after loading checkpoint: {total_steps}") elif cfg.restore_ckpt_path is not None: restore_ckpt_path = cfg.restore_ckpt_path assert restore_ckpt_path.endswith(".pth") logging.info(f"Restoring pre-trained weights from {os.path.abspath(restore_ckpt_path)}") training_ckpt = "total_steps" in torch.load(restore_ckpt_path) if training_ckpt: # Loading a checkpoint saved by fabric during training logging.info("Trying to load as a training checkpoint...") state = AttributeDict(model=model) try: fabric.load(restore_ckpt_path, state, strict=True) except RuntimeError as e: logging.warning(f"Failed to load 
weights with from {restore_ckpt_path} with strict=True: {e}. " f"Trying again with strict=False.") fabric.load(restore_ckpt_path, state, strict=False) logging.info(f"Loaded checkpoint {restore_ckpt_path}") else: fabric.load_raw(restore_ckpt_path, model) tb_writer = SummaryWriter(log_dir=os.path.join(cfg.experiment_path, f"runs_{fabric.global_rank}")) if cfg.modes.eval_only or cfg.modes.validate_at_start: run_test_eval(cfg, evaluator, model, eval_dataloaders, tb_writer, total_steps - 1) fabric.barrier() if cfg.modes.eval_only: return total_durations = deque() dataloader_durations = deque() fwd_durations = deque() sync_durations = deque() bwd_durations = deque() timing_log_freq = 100 def handle_sigterm(signum, frame): logging.error(f"Signal {signum} received, saving checkpoint and exiting...") ckpt_iter = "0" * (6 - len(str(total_steps))) + str(total_steps) save_path = Path(f"{cfg.experiment_path}/model_{ckpt_iter}.pth") state = AttributeDict( model=model, optimizer=optimizer, scheduler=scheduler, total_steps=total_steps + 1, ) fabric.save(save_path, state) logging.info(f"Saved checkpoint to {save_path}. Waiting for all ranks to finish...") fabric.barrier() logging.info(f"Calling sys.exit(0) now.") sys.exit(0) signal.signal(signal.SIGUSR1, handle_sigterm) signal.signal(signal.SIGTERM, handle_sigterm) logging.info(f"Registered signal handlers for SIGUSR1 and SIGTERM.") model.train() should_keep_training = True if cfg.trainer.num_steps > 0 else False total_batches_loaded = 0 total_batches_failed = 0 if fabric.global_rank == 0: tqdm_total_steps = tqdm( total=cfg.trainer.num_steps, desc=f"Total Training Progress (rank={fabric.global_rank})", unit="batch", initial=total_steps, position=0, ) threads = [] had_run_pretraining_epoch = cfg.modes.do_initial_static_pretrain and total_steps > len(pretraining_dataloader) logging.info(f"{total_steps=}, {epoch=}/{num_epochs}, {had_run_pretraining_epoch=}") while should_keep_training: epoch += 1 i_batch = -1 if cfg.modes.do_initial_static_pretrain and not had_run_pretraining_epoch: had_run_pretraining_epoch = True data_iter = iter(pretraining_dataloader) n_batches = len(pretraining_dataloader) else: data_iter = iter(train_loader) n_batches = len(train_loader) if fabric.global_rank == 0: tqdm_epoch = tqdm(total=n_batches, desc=f"Epoch {epoch + 1}/{num_epochs}", unit="batch", position=1) while i_batch < n_batches: start_time_1 = time.time() logging.info(f"Gonna load batch {i_batch + 1}/{n_batches} (rank={fabric.global_rank})") try: batch = next(data_iter) except StopIteration: data_iter = iter(train_loader) n_batches = len(train_loader) batch = next(data_iter) batch, gotit = batch total_batches_loaded += 1 if cfg.modes.debugging_hotfix_datapoint_path is not None: logging.info(f"Debugging hotfix: loading batch from {cfg.modes.debugging_hotfix_datapoint_path}") batch = torch.load(cfg.modes.debugging_hotfix_datapoint_path, map_location="cuda:0") logging.info(f"Debugging hotfix: loaded batch {batch.seq_name} " f"with {len(batch.video)} views and {batch.video.shape[2]} frames") if not all(gotit): total_batches_failed += 1 logging.info(f"batch is None: " f"failed {total_batches_failed} / {total_batches_loaded} " f"({total_batches_failed / total_batches_loaded * 100:.2f}%) batches") continue i_batch += 1 dataclass_to_cuda_(batch) assert model.training start_time_2 = time.time() dataloader_duration = start_time_2 - start_time_1 logging.info(f"Datapoint: {batch.seq_name} (Waited for {dataloader_duration:>5.2f}s)") train_iters = cfg.trainer.train_iters if 
cfg.trainer.augment_train_iters: train_iters = augment_train_iters(train_iters, total_steps, cfg.trainer.augment_train_iters_warmup) optimizer.zero_grad() try: output = forward_batch_multi_view( batch=batch, model=model, cfg=cfg, step=total_steps, train_iters=train_iters, gamma=cfg.trainer.gamma, save_debug_logs=( ((total_steps % cfg.trainer.viz_freq) == (cfg.trainer.viz_freq - 1)) or (total_steps in [0, 10, 100, cfg.trainer.num_steps - 1]) ), debug_logs_path=os.path.join( cfg.experiment_path, f'forward_pass__train_step-{total_steps}_global_rank-{fabric.global_rank}' ), ) except Exception as e: logging.critical(f"Forward pass crashed at step {total_steps}: {e}") # Save current checkpoint save_path = Path(f"{cfg.experiment_path}/test_{total_steps:06d}.pth") state = AttributeDict( model=model, optimizer=optimizer, scheduler=scheduler, total_steps=total_steps + 1, ) fabric._strategy.checkpoint_io.save_checkpoint( checkpoint=fabric._strategy._convert_stateful_objects_in_state(_unwrap_objects(state), filter={}), path=save_path, ) logging.info(f"Saved crash checkpoint to {save_path}") # Save the batch batch_path = Path(f"{cfg.experiment_path}/crash_batch_step_{total_steps:06d}.pt") try: torch.save(batch, batch_path) logging.info(f"Saved crashing batch to {batch_path}") except Exception as batch_exc: logging.error(f"Failed to save crashing batch as .pt: {batch_exc}") raise # re-raise to crash the job after saving artifacts loss = torch.tensor(0.0).cuda() for k, v in output.items(): if k == "metrics": for metric_name, metric_value in v.items(): tb_writer.add_scalar(metric_name, metric_value, total_steps) elif "loss" in v: loss += v["loss"] tb_writer.add_scalar(f"live_{k}_loss", v["loss"].item(), total_steps) else: raise ValueError(f"Unknown key {k} in output") start_time_3 = time.time() fwd_duration = start_time_3 - start_time_2 fabric.barrier() start_time_4 = time.time() sync_duration = start_time_4 - start_time_3 fabric.backward(loss) # Log a limited number of grad + optimizer state pairs, also log current learning rate if (total_steps <= 10) or (total_steps % cfg.trainer.viz_freq == 0): log_limit = 5 logged = 0 prefix = f"[DEBUG] [RANK={fabric.global_rank:03d}]" logging.info(f"{prefix} RNG seed: {torch.initial_seed()}") logging.info(f"{prefix} Step={total_steps} – Gradients and Optimizer State") for name, param in model.named_parameters(): if param.grad is not None and param in optimizer.state: state = optimizer.state[param] exp_avg_norm = state['exp_avg'].norm().item() if 'exp_avg' in state else float('nan') exp_avg_sq_norm = state['exp_avg_sq'].norm().item() if 'exp_avg_sq' in state else float('nan') grad_norm = param.grad.norm().item() logging.info( f"{prefix} Param: {name:<60s} | " f"grad_norm={grad_norm:>14.9f} | " f"exp_avg_norm={exp_avg_norm:>14.9f} | " f"exp_avg_sq_norm={exp_avg_sq_norm:>14.9f}" ) logged += 1 if logged >= log_limit: break for name, param in model.named_parameters(): if param.grad_fn: print(f"{prefix} {name} grad_fn: {param.grad_fn}") logging.info(f"{prefix} LR at step {total_steps}: {scheduler.get_last_lr()}") fabric.clip_gradients(model, optimizer, clip_val=cfg.trainer.grad_clip) optimizer.step() scheduler.step() start_time_5 = time.time() bwd_duration = start_time_5 - start_time_4 if fabric.global_rank == 0: if (total_steps % cfg.trainer.viz_freq == 0) or ( total_steps == cfg.trainer.num_steps - 1) or total_steps in [0, 10, 100]: logging.info(f"Creating training viz logs (rank: {fabric.global_rank}, step: {total_steps})") video = batch.video.clone().cpu() video_depth 
= batch.videodepth.clone().cpu() gt_viz, vector_colors = visualizer.visualize( video=video, video_depth=video_depth, tracks=batch.trajectory.clone().cpu(), visibility=batch.visibility.clone().cpu(), query_frame=batch.query_points_3d[..., 0].long().clone().cpu(), filename="train_gt_traj", writer=tb_writer, step=total_steps, save_video=False, ) pred_viz, _ = visualizer.visualize( video=video, video_depth=video_depth, tracks=output["flow"]["predictions"][None].cpu(), visibility=(output["visibility"]["predictions"][None] > 0.5).cpu(), query_frame=batch.query_points_3d[..., 0].long().clone().cpu(), filename="train_pred_traj", writer=tb_writer, step=total_steps, save_video=False, ) viz = torch.cat([gt_viz[..., :gt_viz.shape[-1] // 2], pred_viz], dim=-1) thread = threading.Thread( target=Visualizer.save_video, args=(viz, visualizer.save_dir, f"train", tb_writer, visualizer.fps, total_steps) ) thread.start() threads.append(thread) if len(output) > 1: tb_writer.add_scalar(f"live_total_loss", loss.item(), total_steps) tb_writer.add_scalar(f"learning_rate", optimizer.param_groups[0]["lr"], total_steps) if total_steps % cfg.trainer.save_ckpt_freq == 0: ckpt_iter = "0" * (6 - len(str(total_steps))) + str(total_steps) save_path = Path(f"{cfg.experiment_path}/model_{ckpt_iter}.pth") logging.info(f"Saving file {save_path}") state = AttributeDict( model=model, optimizer=optimizer, scheduler=scheduler, total_steps=total_steps + 1, ) fabric.save(save_path, state) if total_steps % cfg.trainer.eval_freq == 0 and total_steps > 1: run_test_eval(cfg, evaluator, model, eval_dataloaders, tb_writer, total_steps) fabric.barrier() total_steps += 1 if fabric.global_rank == 0: tqdm_epoch.update(1) tqdm_total_steps.update(1) tqdm_epoch.set_postfix( loss=loss.item(), lr=optimizer.param_groups[0]["lr"], train_iters=cfg.trainer.train_iters, gamma=cfg.trainer.gamma, seq_name=batch.seq_name, ) total_duration = time.time() - start_time_1 logging.info( f"[timing:{total_steps:06d}] " f"Total: {total_duration:>6.2f}s | " f"Data: {dataloader_duration:>6.2f}s | " f"Fwd: {fwd_duration:>6.2f}s | " f"Sync: {sync_duration:>6.2f}s | " f"Bwd: {bwd_duration:>6.2f}s | " ) if fabric.global_rank == 0: dataloader_durations.append(dataloader_duration) fwd_durations.append(fwd_duration) sync_durations.append(sync_duration) bwd_durations.append(bwd_duration) total_durations.append(total_duration) tb_writer.add_scalar(f"timing/step", total_duration, total_steps) tb_writer.add_scalar(f"timing/only_fwd", fwd_durations[-1], total_steps) tb_writer.add_scalar(f"timing/only_sync", sync_durations[-1], total_steps) tb_writer.add_scalar(f"timing/only_bwd", bwd_durations[-1], total_steps) tb_writer.add_scalar(f"timing/only_dataloader", dataloader_duration, total_steps) if len(total_durations) >= timing_log_freq: total_durations_np = np.array(total_durations) fwd_durations_np = np.array(fwd_durations) sync_durations_np = np.array(sync_durations) bwd_durations_np = np.array(bwd_durations) dataloader_durations_np = np.array(dataloader_durations) total_duration_mean = np.mean(total_durations_np) fwd_duration_mean = np.mean(fwd_durations_np) sync_duration_mean = np.mean(sync_durations_np) bwd_duration_mean = np.mean(bwd_durations_np) dataloader_duration_mean = np.mean(dataloader_durations_np) total_duration_median = np.median(total_durations_np) fwd_duration_median = np.median(fwd_durations_np) sync_duration_median = np.median(sync_durations_np) bwd_duration_median = np.median(bwd_durations_np) dataloader_duration_median = np.median(dataloader_durations_np) 
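# NOTE (editorial comment, not part of the original source): the deques appended to above buffer
# per-step timings; once `timing_log_freq` (100) entries have accumulated, the surrounding block
# summarizes them (mean/median written to TensorBoard, mean/median/std logged) and then clears the
# buffers. The median is reported alongside the mean, presumably so that occasional dataloader or
# sync stalls do not dominate the statistic.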
total_duration_std = np.std(total_durations_np) fwd_duration_std = np.std(fwd_durations_np) sync_duration_std = np.std(sync_durations_np) bwd_duration_std = np.std(bwd_durations_np) dataloader_duration_std = np.std(dataloader_durations_np) tb_writer.add_scalar("timing/step_mean", total_duration_mean, total_steps) tb_writer.add_scalar("timing/step_median", total_duration_median, total_steps) tb_writer.add_scalar("timing/only_fwd_mean", fwd_duration_mean, total_steps) tb_writer.add_scalar("timing/only_fwd_median", fwd_duration_median, total_steps) tb_writer.add_scalar("timing/only_sync_mean", sync_duration_mean, total_steps) tb_writer.add_scalar("timing/only_sync_median", sync_duration_median, total_steps) tb_writer.add_scalar("timing/only_bwd_mean", bwd_duration_mean, total_steps) tb_writer.add_scalar("timing/only_bwd_median", bwd_duration_median, total_steps) tb_writer.add_scalar("timing/only_dataloader_mean", dataloader_duration_mean, total_steps) tb_writer.add_scalar("timing/only_dataloader_median", dataloader_duration_median, total_steps) logging.info( f"[timing:total] " f"Mean: {total_duration_mean:>6.2f}s | " f"Median: {total_duration_median:>6.2f}s | " f"Std: {total_duration_std:6.2f}s" ) logging.info( f"[timing:fwd] " f"Mean: {fwd_duration_mean:>6.2f}s | " f"Median: {fwd_duration_median:>6.2f}s | " f"Std: {fwd_duration_std:6.2f}s" ) logging.info( f"[timing:sync] " f"Mean: {sync_duration_mean:>6.2f}s | " f"Median: {sync_duration_median:>6.2f}s | " f"Std: {sync_duration_std:6.2f}s" ) logging.info( f"[timing:bwd] " f"Mean: {bwd_duration_mean:>6.2f}s | " f"Median: {bwd_duration_median:>6.2f}s | " f"Std: {bwd_duration_std:6.2f}s" ) logging.info( f"[timing:datal] " f"Mean: {dataloader_duration_mean:>6.2f}s | " f"Median: {dataloader_duration_median:>6.2f}s | " f"Std: {dataloader_duration_std:6.2f}s" ) total_durations.clear() fwd_durations.clear() sync_durations.clear() bwd_durations.clear() dataloader_durations.clear() if total_steps > cfg.trainer.num_steps: should_keep_training = False break if fabric.global_rank == 0: tqdm_epoch.close() if fabric.global_rank == 0: tqdm_total_steps.close() logging.info("FINISHED TRAINING") save_path = f"{cfg.experiment_path}/model_final.pth" logging.info(f"Saving file {save_path}") state = AttributeDict( model=model, optimizer=optimizer, scheduler=scheduler, total_steps=total_steps, ) fabric.save(save_path, state) run_test_eval(cfg, evaluator, model, eval_dataloaders, tb_writer, total_steps) for thread in threads: thread.join() tb_writer.flush() tb_writer.close() fabric.barrier() if __name__ == "__main__": main() ================================================ FILE: mvtracker/cli/utils/__init__.py ================================================ from .pylogger import RankedLogger from .rich_utils import enforce_tags, print_config_tree from .helpers import extras, get_metric_value, task_wrapper ================================================ FILE: mvtracker/cli/utils/helpers.py ================================================ import faulthandler import warnings from functools import wraps from importlib.util import find_spec from typing import Any, Callable, Dict, Optional, Tuple import wandb from omegaconf import DictConfig from mvtracker.cli.utils import pylogger, rich_utils log = pylogger.RankedLogger(__name__, rank_zero_only=True) def extras(cfg: DictConfig) -> None: """Applies optional utilities before the task is started. 
Utilities: - Ignoring python warnings - Setting tags from command line - Rich config printing :param cfg: A DictConfig object containing the config tree. """ # return if no `extras` config if not cfg.get("extras"): log.warning("Extras config not found! ") return # disable python warnings if cfg.extras.get("ignore_warnings"): log.info("Disabling python warnings! ") warnings.filterwarnings("ignore") # prompt user to input tags from command line if none are provided in the config if cfg.extras.get("enforce_tags"): log.info("Enforcing tags! ") rich_utils.enforce_tags(cfg, save_to_file=True) # pretty print config tree using Rich library if cfg.extras.get("print_config"): log.info("Printing config tree with Rich! ") rich_utils.print_config_tree(cfg, print_order=None, resolve=True, save_to_file=True) if cfg.extras.get("enable_faulthandler_traceback"): log.info("Enabling faulthandler timeouts!") faulthandler.dump_traceback_later(timeout=cfg.extras.faulthandler_traceback_timeout, repeat=True) def task_wrapper(task_func: Callable) -> Callable: """Optional decorator that controls the failure behavior when executing the task function. This wrapper can be used to: - make sure loggers are closed even if the task function raises an exception (prevents multirun failure) - save the exception to a `.log` file - mark the run as failed with a dedicated file in the `logs/` folder (so we can find and rerun it later) - etc. (adjust depending on your needs) Example: ``` @utils.task_wrapper def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]: ... return metric_dict, object_dict ``` :param task_func: The task function to be wrapped. :return: The wrapped task function. """ def wrap(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]: # execute the task try: metric_dict, object_dict = task_func(cfg=cfg) # things to do if exception occurs except Exception as ex: # save exception to `.log` file log.exception("") # some hyperparameter combinations might be invalid or cause out-of-memory errors # so when using hparam search plugins like Optuna, you might want to disable # raising the below exception to avoid multirun failure raise ex # things to always do after either success or exception finally: # display output dir path in terminal log.info(f"Output dir: {cfg.paths.output_dir}") # always close wandb run (even if exception occurs so multirun won't fail) if find_spec("wandb"): # check if wandb is installed import wandb if wandb.run: log.info("Closing wandb!") wandb.finish() return metric_dict, object_dict return wrap def get_metric_value(metric_dict: Dict[str, Any], metric_name: Optional[str]) -> Optional[float]: """Safely retrieves value of the metric logged in LightningModule. :param metric_dict: A dict containing metric values. :param metric_name: If provided, the name of the metric to retrieve. :return: If a metric name was provided, the value of the metric. """ if not metric_name: log.info("Metric name is None! Skipping metric value retrieval...") return None if metric_name not in metric_dict: raise Exception( f"Metric value not found! \n" "Make sure metric name logged in LightningModule is correct!\n" "Make sure `optimized_metric` name in `hparams_search` config is correct!" ) metric_value = metric_dict[metric_name].item() log.info(f"Retrieved metric value! 
<{metric_name}={metric_value}>") return metric_value def maybe_close_wandb(fn: Callable) -> Callable: @wraps(fn) def wrapper(cfg, *args, **kwargs): try: return fn(cfg, *args, **kwargs) finally: if wandb.run is not None: wandb.finish() return wrapper ================================================ FILE: mvtracker/cli/utils/pylogger.py ================================================ import logging from typing import Mapping, Optional from lightning_utilities.core.rank_zero import rank_prefixed_message, rank_zero_only class RankedLogger(logging.LoggerAdapter): """A multi-GPU-friendly python command line logger.""" def __init__( self, name: str = __name__, rank_zero_only: bool = False, extra: Optional[Mapping[str, object]] = None, ) -> None: """Initializes a multi-GPU-friendly python command line logger that logs on all processes with their rank prefixed in the log message. :param name: The name of the logger. Default is ``__name__``. :param rank_zero_only: Whether to force all logs to only occur on the rank zero process. Default is `False`. :param extra: (Optional) A dict-like object which provides contextual information. See `logging.LoggerAdapter`. """ logger = logging.getLogger(name) super().__init__(logger=logger, extra=extra) self.rank_zero_only = rank_zero_only def log(self, level: int, msg: str, rank: Optional[int] = None, *args, **kwargs) -> None: """Delegate a log call to the underlying logger, after prefixing its message with the rank of the process it's being logged from. If `'rank'` is provided, then the log will only occur on that rank/process. :param level: The level to log at. Look at `logging.__init__.py` for more information. :param msg: The message to log. :param rank: The rank to log at. :param args: Additional args to pass to the underlying logging function. :param kwargs: Any additional keyword args to pass to the underlying logging function. """ if self.isEnabledFor(level): msg, kwargs = self.process(msg, kwargs) current_rank = getattr(rank_zero_only, "rank", None) if current_rank is None: raise RuntimeError("The `rank_zero_only.rank` needs to be set before use") msg = rank_prefixed_message(msg, current_rank) if self.rank_zero_only: if current_rank == 0: self.logger.log(level, msg, *args, **kwargs) else: if rank is None: self.logger.log(level, msg, *args, **kwargs) elif current_rank == rank: self.logger.log(level, msg, *args, **kwargs) ================================================ FILE: mvtracker/cli/utils/rich_utils.py ================================================ from pathlib import Path from typing import Sequence, Optional import rich import rich.syntax import rich.tree from hydra.core.hydra_config import HydraConfig from lightning_utilities.core.rank_zero import rank_zero_only from omegaconf import DictConfig, OmegaConf, open_dict from rich.prompt import Prompt from mvtracker.cli.utils import pylogger log = pylogger.RankedLogger(__name__, rank_zero_only=True) @rank_zero_only def print_config_tree( cfg: DictConfig, print_order: Optional[Sequence[str]] = ( "experiment_paths", "model", "predictor_settings", ), resolve: bool = False, save_to_file: bool = False, ) -> None: """Prints the contents of a DictConfig as a tree structure using the Rich library. :param cfg: A DictConfig composed by Hydra. :param print_order: Determines in what order config components are printed. :param resolve: Whether to resolve reference fields of DictConfig. :param save_to_file: Whether to export config to the hydra output folder. 
""" style = "italic cyan" tree = rich.tree.Tree("CONFIG", style=style, guide_style=style) queue = [] # add fields from `print_order` to queue if print_order is not None: for field in print_order: queue.append(field) if field in cfg else log.warning( f"Field '{field}' not found in config. Skipping '{field}' config printing..." ) # add all the other fields to queue (not specified in `print_order`) for field in cfg: if field not in queue: queue.append(field) # generate config tree from queue for field in queue: branch = tree.add(field, style=style, guide_style=style) config_group = cfg[field] if isinstance(config_group, DictConfig): branch_content = OmegaConf.to_yaml(config_group, resolve=resolve) else: branch_content = str(config_group) branch.add(rich.syntax.Syntax(branch_content, "yaml")) # print config tree rich.print(tree) # save config tree to file if save_to_file: with open(Path(HydraConfig.get().runtime.output_dir, "config_tree.log"), "w") as file: rich.print(tree, file=file) @rank_zero_only def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None: """Prompts user to input tags from command line if no tags are provided in config. :param cfg: A DictConfig composed by Hydra. :param save_to_file: Whether to export tags to the hydra output folder. Default is ``False``. """ if not cfg.get("tags"): if "id" in HydraConfig().cfg.hydra.job: raise ValueError("Specify tags before launching a multirun!") log.warning("No tags provided in config. Prompting user to input tags...") tags = Prompt.ask("Enter a list of comma separated tags", default="dev") tags = [t.strip() for t in tags.split(",") if t != ""] with open_dict(cfg): cfg.tags = tags log.info(f"Tags: {cfg.tags}") if save_to_file: with open(Path(cfg.paths.output_dir, "tags.log"), "w") as file: rich.print(cfg.tags, file=file) ================================================ FILE: mvtracker/datasets/__init__.py ================================================ from .dexycb_multiview_dataset import DexYCBMultiViewDataset from .kubric_multiview_dataset import KubricMultiViewDataset from .panoptic_studio_multiview_dataset import PanopticStudioMultiViewDataset from .tap_vid_datasets import TapVidDataset ================================================ FILE: mvtracker/datasets/dexycb_multiview_dataset.py ================================================ import logging import os import pathlib import re import time import warnings import cv2 import matplotlib import numpy as np import pandas as pd import torch import torch.nn.functional as F from scipy.spatial.transform import Rotation as R from torch.utils.data import Dataset from mvtracker.datasets.utils import Datapoint, transform_scene class DexYCBMultiViewDataset(Dataset): @staticmethod def from_name(dataset_name: str, dataset_root: str): """ Examples of datasets supported by this factory method: - "dex-ycb-multiview", - "dex-ycb-multiview-single", - "dex-ycb-multiview-removehand", - "dex-ycb-multiview-duster0123", - "dex-ycb-multiview-duster0123cleaned", - "dex-ycb-multiview-duster0123cleaned-views0123", - "dex-ycb-multiview-duster0123cleaned-views0123-novelviews45", - "dex-ycb-multiview-duster0123cleaned-views0123-novelviews45-removehand", - "dex-ycb-multiview-duster0123cleaned-views0123-novelviews45-removehand-single", - "dex-ycb-multiview-duster0123cleaned-views0123-novelviews45-removehand-2dpt-single", - "dex-ycb-multiview-duster0123cleaned-views0123-novelviews45-removehand-2dpt-single-cached", """ # Parse the dataset name, chunk by chunk non_parsed = 
dataset_name.replace("dex-ycb-multiview", "", 1) if non_parsed.startswith("-duster"): match = re.match(r"-duster(\d+)(cleaned)?", non_parsed) assert match is not None duster_views = list(map(int, match.group(1))) use_duster = True use_duster_cleaned = match.group(2) is not None non_parsed = non_parsed.replace(match.group(0), "", 1) else: use_duster = False use_duster_cleaned = False duster_views = None if non_parsed.startswith("-views"): match = re.match(r"-views(\d+)", non_parsed) assert match is not None views = list(map(int, match.group(1))) if duster_views is not None: assert all(v in duster_views for v in views) non_parsed = non_parsed.replace(match.group(0), "", 1) else: views = duster_views if non_parsed.startswith("-novelviews"): match = re.match(r"-novelviews(\d+)", non_parsed) assert match is not None novel_views = list(map(int, match.group(1))) non_parsed = non_parsed.replace(match.group(0), "", 1) else: novel_views = None if non_parsed.startswith("-removehand"): remove_hand = True non_parsed = non_parsed.replace("-removehand", "", 1) else: remove_hand = False if non_parsed.startswith("-single"): single_point = True non_parsed = non_parsed.replace("-single", "", 1) else: single_point = False if non_parsed.startswith("-2dpt"): eval_2dpt = True non_parsed = non_parsed.replace("-2dpt", "", 1) else: eval_2dpt = False if non_parsed.startswith("-cached"): use_cached_tracks = True non_parsed = non_parsed.replace("-cached", "", 1) else: use_cached_tracks = False assert non_parsed == "", f"Unparsed part of the dataset name: {non_parsed}" if views is None and duster_views is None: views = [0, 1, 2, 3] # Make the legacy "dex-ycb-multiview" name take the first 4 views (not all 8) return DexYCBMultiViewDataset( data_root=os.path.join(dataset_root, "dex-ycb-multiview"), views_to_return=views, novel_views=novel_views, remove_hand=remove_hand, use_duster_depths=use_duster, duster_views=duster_views, clean_duster_depths=use_duster_cleaned, traj_per_sample=384, seed=72, max_videos=10, perform_sanity_checks=False, use_cached_tracks=use_cached_tracks, ) def __init__( self, data_root, remove_hand=False, views_to_return=None, novel_views=None, use_duster_depths=False, clean_duster_depths=False, duster_views=None, traj_per_sample=768, seed=None, max_videos=None, perform_sanity_checks=False, use_cached_tracks=False, ): super().__init__() self.data_root = data_root self.remove_hand = remove_hand self.views_to_return = views_to_return self.novel_views = novel_views self.use_duster_depths = use_duster_depths self.clean_duster_depths = clean_duster_depths self.duster_views = duster_views self.traj_per_sample = traj_per_sample self.seed = seed self.perform_sanity_checks = perform_sanity_checks self.use_cached_tracks = use_cached_tracks self.cache_name = self._cache_key() self.seq_names = self._get_sequence_names(max_videos) self.getitem_calls = 0 def _get_sequence_names(self, max_videos): """ Fetch all valid sequence names from the dataset root. Args: max_videos (int): Limit the number of sequences to load. Returns: List[str]: Sorted list of valid sequence names. 
""" seq_names = [ fname for fname in os.listdir(self.data_root) if os.path.isdir(os.path.join(self.data_root, fname)) and not fname.startswith(".") and not fname.startswith("_") ] seq_names = sorted(seq_names) valid_seqs = [] for seq_name in seq_names: scene_path = os.path.join(self.data_root, seq_name) view_folders = [ d for d in os.listdir(scene_path) if os.path.isdir(os.path.join(scene_path, d)) and d.startswith("view_") ] if not view_folders: warnings.warn(f"Skipping {scene_path} because it has no views.") continue valid_seqs.append(seq_name) if max_videos is not None: valid_seqs = valid_seqs[:max_videos] print(f"Using {len(valid_seqs)} videos from {self.data_root}") return valid_seqs def _cache_key(self): name = f"cachedtracks--seed{self.seed}" if self.views_to_return is not None: name += f"-views{'_'.join(map(str, self.views_to_return))}" if self.traj_per_sample is not None: name += f"-n{self.traj_per_sample}" if self.remove_hand: name += "-removehand" return name + "--v1" # bump this if you change the selection policy def __len__(self): return len(self.seq_names) def __getitem__(self, index): start_time = time.time() sample = self._getitem_helper(index) self.getitem_calls += 1 if self.getitem_calls < 10: print(f"Loading {index:>06d} took {time.time() - start_time:.3f} sec. Getitem calls: {self.getitem_calls}") return sample, True def _getitem_helper(self, index): """ Helper function to load a single sample. Args: index (int): Index of the sample to load. Returns: CoTrackerData, bool: Sample data and success flag. """ if self.seed is None: seed = torch.randint(0, 2 ** 32 - 1, (1,)).item() else: seed = self.seed rnd_torch = torch.Generator().manual_seed(seed) rnd_np = np.random.RandomState(seed=seed) datapoint_path = os.path.join(self.data_root, self.seq_names[index]) views = {} view_folders = sorted([f for f in os.listdir(datapoint_path) if f.startswith("view_")]) if self.views_to_return is not None: views_to_return = self.views_to_return else: views_to_return = sorted(list(range(len(view_folders)))) views_to_load = views_to_return.copy() if self.novel_views is not None: views_to_load = list(set(views_to_load + self.novel_views)) for v in views_to_load: view_path = os.path.join(datapoint_path, view_folders[v]) # Load RGB images rgb_folder = os.path.join(view_path, "rgb") rgb_files = sorted(os.listdir(rgb_folder)) rgb_images = [cv2.imread(os.path.join(rgb_folder, f))[:, :, ::-1] for f in rgb_files] # Load depth maps depth_folder = os.path.join(view_path, "depth") depth_files = sorted(os.listdir(depth_folder)) depth_images = [cv2.imread(os.path.join(depth_folder, f), cv2.IMREAD_ANYDEPTH) for f in depth_files] # Load camera parameters camera_params_file = os.path.join(view_path, "intrinsics_extrinsics.npz") params = np.load(camera_params_file) intrinsics = params["intrinsics"][:3, :3] # Extract K extrinsics = params["extrinsics"][:3, :] # Extract R|t (world to camera) views[v] = { "rgb": np.stack(rgb_images), "depth": np.stack(depth_images), "intrinsics": intrinsics, "extrinsics": extrinsics, } rgbs = np.stack([views[v]["rgb"] for v in views_to_return]) n_views, n_frames, h, w, _ = rgbs.shape depths = np.stack([views[v]["depth"] for v in views_to_return])[..., None].astype(np.float32) / 1000 intrs = np.stack([views[v]["intrinsics"] for v in views_to_return])[:, None, :, :].repeat(n_frames, axis=1) extrs = np.stack([views[v]["extrinsics"] for v in views_to_return])[:, None, :, :].repeat(n_frames, axis=1) # Load novel views if they exist novel_rgbs = None novel_intrs = None novel_extrs = 
None if self.novel_views is not None: novel_rgbs = np.stack([views[v]["rgb"] for v in self.novel_views]) novel_intrs = np.stack([views[v]["intrinsics"] for v in self.novel_views])[:, None, :, :].repeat(n_frames, axis=1) novel_extrs = np.stack([views[v]["extrinsics"] for v in self.novel_views])[:, None, :, :].repeat(n_frames, axis=1) # Load Duster's features and estimated depths if they exist duster_views = self.duster_views if self.duster_views is not None else views_to_return duster_views_str = ''.join(str(v) for v in duster_views) duster_root = pathlib.Path(datapoint_path) / f'duster-views-{duster_views_str}' if self.use_duster_depths: assert duster_root.exists() and (duster_root / f"3d_model__{n_frames - 1:05d}__scene.npz").exists(), \ f"Duster root {duster_root} does not exist." feats = None feat_dim = None feat_stride = None depth_confs = None if duster_root.exists() and (duster_root / f"3d_model__{n_frames - 1:05d}__scene.npz").exists(): duster_depths = [] duster_confs = [] duster_feats = [] for frame_idx in range(n_frames): scene = np.load(duster_root / f"3d_model__{frame_idx:05d}__scene.npz") duster_depth = torch.from_numpy(scene["depths"]) duster_conf = torch.from_numpy(scene["confs"]) duster_msk = torch.from_numpy(scene["cleaned_mask"]) if self.clean_duster_depths: duster_depth = duster_depth * duster_msk duster_depth = F.interpolate(duster_depth[:, None], (h, w), mode='nearest') duster_depths.append(duster_depth[:, 0, :, :, None]) duster_conf = F.interpolate(duster_conf[:, None], (h, w), mode='nearest') duster_confs.append(duster_conf[:, 0, :, :, None]) if "feats" in scene: duster_feats.append(torch.from_numpy(scene["feats"])) duster_depths = torch.stack(duster_depths, dim=1).numpy() duster_confs = torch.stack(duster_confs, dim=1).numpy() if duster_feats: feats = torch.stack(duster_feats, dim=1).numpy() # Extract the correct views assert duster_depths.shape[0] == len(duster_views) duster_depths = duster_depths[[duster_views.index(v) for v in views_to_return]] duster_confs = duster_confs[[duster_views.index(v) for v in views_to_return]] if feats is not None: assert feats.shape[0] == len(duster_views) feats = feats[[duster_views.index(v) for v in views_to_return]] # Reshape the features if feats is not None: assert feats.ndim == 4 assert feats.shape[0] == n_views assert feats.shape[1] == n_frames feat_stride = np.round(np.sqrt(h * w / feats.shape[2])).astype(int) feat_dim = feats.shape[3] feats = feats.reshape(n_views, n_frames, h // feat_stride, w // feat_stride, feat_dim) # Replace the depths with the Duster depths, if configured so if self.use_duster_depths: depths = duster_depths depth_confs = duster_confs tracks_3d_file = os.path.join(datapoint_path, "tracks_3d.npz") tracks_3d_data = np.load(tracks_3d_file, allow_pickle=True) traj3d_world = tracks_3d_data["tracks_3d"] traj2d = tracks_3d_data["tracks_2d"][views_to_return] traj2d_w_z = np.concatenate((traj2d, tracks_3d_data["tracks_2d_z"][views_to_return][:, :, :, None]), axis=-1) visibility = tracks_3d_data["tracks_2d_visibilities"][views_to_return] # Label the trajectories according to: 0: hand, 1: moving ycb object, 2: static ycb objects object_id_to_name = tracks_3d_data["object_id_to_name"].item() traj_object_id = tracks_3d_data["object_ids"] for object_name in object_id_to_name.values(): assert object_name == "mano-right-hand" or object_name.startswith("ycb") avg_movement_per_object_id = {} for object_id in np.unique(traj_object_id): object_mask = traj_object_id == object_id object_traj = traj3d_world[:, object_mask] 
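# NOTE (editorial comment, not part of the original source): a sketch of the motion-based
# labelling used here and just below. The mean per-frame displacement of each object's 3D tracks
# decides whether it is the single moving YCB object (avg movement >= 1e-4) or a static one
# (< 1e-4), while the MANO right hand is labelled separately by name. Toy example with
# hypothetical numbers:
#   object_traj = np.zeros((10, 50, 3))            # (frames, tracks, xyz), all static
#   object_traj[5:, :, 0] = 0.01                   # shift x by 1 cm from frame 5 onwards
#   np.linalg.norm(object_traj[1:] - object_traj[:-1], axis=-1).mean()  # ~0.0011 >= 1e-4 -> dynamic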
avg_movement_per_object_id[object_id] = np.linalg.norm(object_traj[1:] - object_traj[:-1], axis=-1).mean() hand_id = {v: k for k, v in object_id_to_name.items()}["mano-right-hand"] dynamic_ycb_object_ids = [k for k, v in avg_movement_per_object_id.items() if v >= 1e-4 and k != hand_id] assert len(dynamic_ycb_object_ids) == 1 dynamic_ycb_object_id = dynamic_ycb_object_ids[0] static_ycb_object_ids = [k for k, v in avg_movement_per_object_id.items() if v < 1e-4 and k != hand_id] assert 1 + 1 + len(static_ycb_object_ids) == len(object_id_to_name) # remap object ids to 0: hand, 1: dynamic ycb object, 2: static ycb objects traj_object_id = ( 0 * (traj_object_id == hand_id) + 1 * (traj_object_id == dynamic_ycb_object_id) + 2 * np.isin(traj_object_id, static_ycb_object_ids) ) if self.remove_hand: traj3d_world = traj3d_world[:, traj_object_id > 0] traj2d = traj2d[:, :, traj_object_id > 0] traj2d_w_z = traj2d_w_z[:, :, traj_object_id > 0] visibility = visibility[:, :, traj_object_id > 0] traj_object_id = traj_object_id[traj_object_id > 0] n_tracks = traj3d_world.shape[1] assert rgbs.shape == (n_views, n_frames, h, w, 3) assert depths.shape == (n_views, n_frames, h, w, 1) assert depth_confs is None or depth_confs.shape == (n_views, n_frames, h, w, 1) assert feats is None or feats.shape == (n_views, n_frames, h // feat_stride, w // feat_stride, feat_dim) assert intrs.shape == (n_views, n_frames, 3, 3) assert extrs.shape == (n_views, n_frames, 3, 4) assert traj2d.shape == (n_views, n_frames, n_tracks, 2) assert visibility.shape == (n_views, n_frames, n_tracks) assert traj3d_world.shape == (n_frames, n_tracks, 3) assert traj_object_id.shape == (n_tracks,) if novel_rgbs is not None: assert novel_rgbs.shape == (len(self.novel_views), n_frames, h, w, 3) assert novel_intrs.shape == (len(self.novel_views), n_frames, 3, 3) assert novel_extrs.shape == (len(self.novel_views), n_frames, 3, 4) # Make sure our intrinsics and extrinsics work correctly point_3d_world = traj3d_world point_4d_world_homo = np.concatenate([point_3d_world, np.ones_like(point_3d_world[..., :1])], axis=-1) point_3d_camera = np.einsum('ABij,BCj->ABCi', extrs, point_4d_world_homo) if self.perform_sanity_checks: point_2d_pixel_homo = np.einsum('ABij,ABCj->ABCi', intrs, point_3d_camera) point_2d_pixel = point_2d_pixel_homo[..., :2] / point_2d_pixel_homo[..., 2:] point_2d_pixel_gt = traj2d point_2d_pixel_no_nan = np.nan_to_num(point_2d_pixel, nan=0) point_2d_pixel_gt_no_nan = np.nan_to_num(point_2d_pixel_gt, nan=0) assert np.allclose(point_2d_pixel_no_nan[0, :, 0, :], point_2d_pixel_no_nan[0, :, 0, :], atol=1), f"Proj. 
failed" assert np.allclose(point_2d_pixel_gt_no_nan, point_2d_pixel_gt_no_nan, atol=1), f"Point projection failed" assert np.allclose(point_3d_camera[..., 2:], traj2d_w_z[..., -1:], atol=1) # Convert everything to torch tensors rgbs = torch.from_numpy(rgbs).permute(0, 1, 4, 2, 3).float() depths = torch.from_numpy(depths).permute(0, 1, 4, 2, 3).float() depth_confs = torch.from_numpy(depth_confs).permute(0, 1, 4, 2, 3).float() if depth_confs is not None else None feats = torch.from_numpy(feats).permute(0, 1, 4, 2, 3).float() if feats is not None else None intrs = torch.from_numpy(intrs).float() extrs = torch.from_numpy(extrs).float() traj2d = torch.from_numpy(traj2d) traj2d_w_z = torch.from_numpy(traj2d_w_z) traj3d_world = torch.from_numpy(traj3d_world) traj_object_id = torch.from_numpy(traj_object_id) visibility = torch.from_numpy(visibility) if novel_rgbs is not None: novel_rgbs = torch.from_numpy(novel_rgbs).permute(0, 1, 4, 2, 3).float() novel_intrs = torch.from_numpy(novel_intrs).float() novel_extrs = torch.from_numpy(novel_extrs).float() # Track selection cache_root = os.path.join(self.data_root, self.seq_names[index], "cache") os.makedirs(cache_root, exist_ok=True) cache_file = os.path.join(cache_root, f"{self.cache_name}.npz") # Check if we can use cached tracks use_cache = bool(self.use_cached_tracks) and os.path.isfile(cache_file) if use_cache: cache = np.load(cache_file) inds_sampled = torch.from_numpy(cache["track_indices"]) traj2d_w_z = torch.from_numpy(cache["traj2d_w_z"]) traj3d_world = torch.from_numpy(cache["traj3d_world"]) traj_object_id = torch.from_numpy(cache["traj_object_id"]) visibility = torch.from_numpy(cache["visibility"]) valids = torch.from_numpy(cache["valids"]) query_points = torch.from_numpy(cache["query_points"]) # Otherwise, sample the tracks and create query points else: # Force query points on hand to appear later # This avoids querying when the GT hand reconstruction is severely lacking # Identify tracks that are invisible in the first frame across all views (as they are probably on the hand) invisible_at_first_frame = visibility[:, 0, :] == 0 invisible_at_first_frame = invisible_at_first_frame.unsqueeze(1).expand(-1, 5, -1) # Set visibility to 0 for the first 5 frames where the first frame was invisible visibility[:, 0:5, :] *= ~invisible_at_first_frame # Keep visible ones, set others to 0 # Sample the points to track visible_for_at_least_two_frames = visibility.any(0).sum(0) >= 2 hectic_visibility = ((visibility[:, :-1] & ~visibility[:, 1:]).sum(0) >= 3).any(0) valid_tracks = visible_for_at_least_two_frames & ~hectic_visibility valid_tracks = valid_tracks.nonzero(as_tuple=False)[:, 0] point_inds = torch.randperm(len(valid_tracks), generator=rnd_torch) traj_per_sample = self.traj_per_sample if self.traj_per_sample is not None else len(point_inds) assert len(point_inds) >= traj_per_sample point_inds = point_inds[:traj_per_sample] inds_sampled = valid_tracks[point_inds] n_tracks = len(inds_sampled) traj2d = traj2d[:, :, inds_sampled].float() traj2d_w_z = traj2d_w_z[:, :, inds_sampled].float() traj3d_world = traj3d_world[:, inds_sampled].float() traj_object_id = traj_object_id[inds_sampled] visibility = visibility[:, :, inds_sampled] valids = ~torch.isnan(traj2d).any(dim=-1).any(dim=0) # Create the query points gt_visibilities_any_view = visibility.any(dim=0) assert (gt_visibilities_any_view.sum(dim=0) >= 2).all(), "All points should be visible in least two frames." 
last_visible_index = (torch.arange(n_frames).unsqueeze(-1) * gt_visibilities_any_view).max(0).values assert gt_visibilities_any_view[last_visible_index[None, :], torch.arange(n_tracks)].all() gt_visibilities_any_view[last_visible_index[None, :], torch.arange(n_tracks)] = False assert (gt_visibilities_any_view.sum(dim=0) >= 1).all() n_non_first_point_appearance_queries = n_tracks // 4 n_first_point_appearance_queries = n_tracks - n_non_first_point_appearance_queries first_point_appearances = torch.argmax( gt_visibilities_any_view[..., -n_first_point_appearance_queries:].float(), dim=0) non_first_point_appearances = first_point_appearances.new_zeros((n_non_first_point_appearance_queries,)) for track_idx in range(n_tracks)[:n_non_first_point_appearance_queries]: # Randomly take a timestep where the point is visible non_zero_timesteps = torch.nonzero(gt_visibilities_any_view[:, track_idx] == 1) random_timestep = non_zero_timesteps[rnd_np.randint(len(non_zero_timesteps))].item() non_first_point_appearances[track_idx] = random_timestep query_points_t = torch.cat([non_first_point_appearances, first_point_appearances], dim=0) query_points_xyz_worldspace = traj3d_world[query_points_t, torch.arange(n_tracks)] query_points = torch.cat([query_points_t[:, None], query_points_xyz_worldspace], dim=1) assert gt_visibilities_any_view[query_points_t, torch.arange(n_tracks)].all() # Replace nans with zeros traj2d[torch.isnan(traj2d)] = 0 traj2d_w_z[torch.isnan(traj2d_w_z)] = 0 traj3d_world[torch.isnan(traj3d_world)] = 0 assert torch.isnan(visibility).sum() == 0 # Cache the selected tracks and query points if self.use_cached_tracks: logging.warn(f"Caching tracks for {self.seq_names[index]} at {os.path.abspath(cache_file)}") np.savez_compressed( cache_file, track_indices=inds_sampled.numpy(), traj2d_w_z=traj2d_w_z.numpy(), traj3d_world=traj3d_world.numpy(), traj_object_id=traj_object_id.numpy(), visibility=visibility.numpy(), valids=valids.numpy(), query_points=query_points.numpy(), ) # Normalize the scene to be similar to Kubric's scene scale = 6 rot_x = R.from_euler('x', 220, degrees=True).as_matrix() rot_y = R.from_euler('y', 3, degrees=True).as_matrix() rot_z = R.from_euler('z', -30, degrees=True).as_matrix() rot = torch.from_numpy(rot_z @ rot_y @ rot_x) translation = torch.tensor([0.0, 0.0, 0.5], dtype=torch.float32) ( depths_trans, extrs_trans, query_points_trans, traj3d_world_trans, traj2d_w_z_trans ) = transform_scene(scale, rot, translation, depths, extrs, query_points, traj3d_world, traj2d_w_z) novel_extrs_trans = transform_scene(scale, rot, translation, None, novel_extrs, None, None, None)[1] # rerun_viz_scene("nane/scene__no_transform/", rgbs, depths, intrs, extrs, traj3d_world, 0.1) # rerun_viz_scene("nane/scene_transformed/", rgbs, depths_trans, intrs, extrs_trans, traj3d_world_trans, 1) # # Use the auto scene normalization of generic scenes # from mvtracker.datasets.generic_scene_dataset import compute_auto_scene_normalization # scale, rot, translation = compute_auto_scene_normalization(depths, torch.ones_like(depths) * 100, extrs_trans, intrs) # scale = scale * T[0, 0].item() # print(f"{scale=}") # (depths_trans, extrs_trans, query_points_trans, traj3d_world_trans, traj2d_w_z_trans # ) = transform_scene(scale, rot, translation, depths_trans, extrs_trans, query_points_trans, traj3d_world_trans, traj2d_w_z_trans) # _, novel_extrs_trans, _, _, _ = transform_scene(scale, rot, translation, None, novel_extrs_trans, None, None, None) # 82.7 91.1 --> 80.8 89.1 segs = torch.ones((n_frames, 1, h, w)) # 
Dummy segmentation masks datapoint = Datapoint( video=rgbs, videodepth=depths_trans, videodepthconf=depth_confs.float() if depth_confs is not None else None, feats=feats, segmentation=segs, trajectory=traj2d_w_z_trans, trajectory_3d=traj3d_world_trans, trajectory_category=traj_object_id, visibility=visibility, valid=valids, seq_name=self.seq_names[index], intrs=intrs, extrs=extrs_trans, query_points=None, query_points_3d=query_points_trans, track_upscaling_factor=1 / scale, novel_video=novel_rgbs, novel_intrs=novel_intrs, novel_extrs=novel_extrs_trans, ) return datapoint def rerun_viz_scene(entity_prefix, rgbs, depths, intrs, extrs, tracks, radii_scale, viz_camera=False, viz_point_cloud=True, fps=12): import rerun as rr # Initialize Rerun rr.init(f"3dpt", recording_id="v0.16") rr.connect_tcp() V, T, _, H, W = rgbs.shape _, N, _ = tracks.shape assert rgbs.shape == (V, T, 3, H, W) assert depths.shape == (V, T, 1, H, W) assert intrs.shape == (V, T, 3, 3) assert extrs.shape == (V, T, 3, 4) assert tracks.shape == (T, N, 3) # Compute inverse intrinsics and extrinsics intrs_inv = torch.inverse(intrs.float()).type(intrs.dtype) extrs_square = torch.eye(4).to(extrs.device).repeat(V, T, 1, 1) extrs_square[:, :, :3, :] = extrs extrs_inv = torch.inverse(extrs_square.float()).type(extrs.dtype) assert intrs_inv.shape == (V, T, 3, 3) assert extrs_inv.shape == (V, T, 4, 4) for v in range(V): # Iterate over views for t in range(T): # Iterate over frames rr.set_time_seconds("frame", t / fps) # Log RGB image rgb_image = rgbs[v, t].permute(1, 2, 0).cpu().numpy() if viz_camera: rr.log(f"{entity_prefix}image/view-{v}/rgb", rr.Image(rgb_image)) # Log Depth map depth_map = depths[v, t, 0].cpu().numpy() if viz_camera: rr.log(f"{entity_prefix}image/view-{v}/depth", rr.DepthImage(depth_map, point_fill_ratio=0.2)) # Log Camera K = intrs[v, t].cpu().numpy() world_T_cam = np.eye(4) world_T_cam[:3, :3] = extrs_inv[v, t, :3, :3].cpu().numpy() world_T_cam[:3, 3] = extrs_inv[v, t, :3, 3].cpu().numpy() if viz_camera: rr.log(f"{entity_prefix}image/view-{v}", rr.Pinhole(image_from_camera=K, width=W, height=H)) rr.log(f"{entity_prefix}image/view-{v}", rr.Transform3D(translation=world_T_cam[:3, 3], mat3x3=world_T_cam[:3, :3])) # Generate and log point cloud colored by RGB values # Compute 3D points from depth map y, x = np.indices((H, W)) homo_pixel_coords = np.stack([x.ravel(), y.ravel(), np.ones_like(x).ravel()], axis=1).T depth_values = depth_map.ravel() cam_coords = (intrs_inv[v, t].cpu().numpy() @ homo_pixel_coords) * depth_values cam_coords = np.vstack((cam_coords, np.ones((1, cam_coords.shape[1])))) world_coords = (world_T_cam @ cam_coords)[:3].T # Filter out points with zero depth valid_mask = depth_values > 0 world_coords = world_coords[valid_mask] rgb_colors = rgb_image.reshape(-1, 3)[valid_mask].astype(np.uint8) # Log the point cloud if viz_point_cloud: rr.log(f"{entity_prefix}point_cloud/view-{v}", rr.Points3D(world_coords, colors=rgb_colors, radii=0.02 * radii_scale)) # Log 3D tracks x = tracks[0, :, 0] c = (x - x.min()) / (x.max() - x.min() + 1e-8) colors = (matplotlib.colormaps["gist_rainbow"](c)[:, :3] * 255).astype(np.uint8) for t in range(T): rr.set_time_seconds("frame", t / fps) rr.log( f"{entity_prefix}tracks/points", rr.Points3D(positions=tracks[t], colors=colors, radii=0.01 * radii_scale), ) if t > 0: strips = np.concatenate( [np.stack([tracks[:t, n], tracks[1:t + 1, n]], axis=-2) for n in range(N)], axis=0, ) strip_colors = np.concatenate( [np.repeat(colors[n][None], t, axis=0) for n in range(N)], axis=0, 
            )
            rr.log(
                f"{entity_prefix}tracks/lines",
                rr.LineStrips3D(strips=strips, colors=strip_colors, radii=0.005 * radii_scale),
            )


================================================
FILE: mvtracker/datasets/generic_scene_dataset.py
================================================
import logging
import os
import pickle
import sys
from contextlib import ExitStack
from typing import Tuple

import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from torch.nn.functional import interpolate
from torch.utils.data import Dataset
from torchvision import transforms as TF
from tqdm import tqdm

from mvtracker.datasets.utils import Datapoint, transform_scene, align_umeyama, apply_sim3_to_extrinsics


class GenericSceneDataset(Dataset):
    def __init__(
            self,
            dataset_dir,
            use_duster_depths=True,
            use_vggt_depths_with_aligned_cameras=False,
            use_vggt_depths_with_raw_cameras=False,
            use_monofusion_depths=False,
            use_moge2_depths=False,
            skip_depth_computation_if_cached=True,
            drop_first_n_frames=0,
            scene_normalization_mode="auto",  # "auto" | "manual" | "none"
            scene_normalization_auto_conf_thresh=4.8,
            scene_normalization_auto_target_radius=6.3,
            scene_normalization_auto_rescale_by_camera_radius=True,
            scene_normalization_manual_scale=None,  # Optional float
            scene_normalization_manual_rotation=None,  # Optional 3x3 torch.Tensor rotation matrix
            scene_normalization_manual_translation=None,  # Optional 3D torch.Tensor post-scale translation vector
            # E.g., the manual transform that translates up by 1.4 units and scales 2.5 times (was good for EgoExo4D):
            # scale = 2.5
            # translate_x = 0
            # translate_y = 0
            # translate_z = 1.4 * scale
            # T = torch.tensor([
            #     [scale, 0.0, 0.0, translate_x],
            #     [0.0, scale, 0.0, translate_y],
            #     [0.0, 0.0, scale, translate_z],
            #     [0.0, 0.0, 0.0, 1.0],
            # ], dtype=torch.float32)
            stream_viz_to_rerun=False,
    ):
        self.dataset_dir = dataset_dir
        self.use_duster_depths = use_duster_depths
        self.use_vggt_depths_with_aligned_cameras = use_vggt_depths_with_aligned_cameras
        self.use_vggt_depths_with_raw_cameras = use_vggt_depths_with_raw_cameras
        self.use_monofusion_depths = use_monofusion_depths
        self.use_moge2_depths = use_moge2_depths

        # --- Assert exclusive depth-source configuration ---
        # Exactly 0 or 1 of these should be True. (0 => fall back to the depths stored in the .pkl, if present.)
        depth_flags = (int(self.use_duster_depths)
                       + int(self.use_vggt_depths_with_aligned_cameras)
                       + int(self.use_vggt_depths_with_raw_cameras)
                       + int(self.use_monofusion_depths)
                       + int(self.use_moge2_depths))
        assert depth_flags <= 1, (
            "Misconfigured dataset: choose at most one depth source among "
            "`use_duster_depths`, `use_vggt_depths_with_aligned_cameras`, `use_vggt_depths_with_raw_cameras`, "
            "`use_monofusion_depths`, and `use_moge2_depths`."
) self.skip_depth_computation_if_cached = skip_depth_computation_if_cached self.drop_first_n_frames = drop_first_n_frames self.scene_normalization_mode = scene_normalization_mode self.scene_normalization_auto_conf_thresh = scene_normalization_auto_conf_thresh self.scene_normalization_auto_target_radius = scene_normalization_auto_target_radius self.scene_normalization_auto_rescale_by_camera_radius = scene_normalization_auto_rescale_by_camera_radius self.scene_normalization_manual_scale = scene_normalization_manual_scale self.scene_normalization_manual_rotation = scene_normalization_manual_rotation self.scene_normalization_manual_translation = scene_normalization_manual_translation self.stream_viz_to_rerun = stream_viz_to_rerun self.seq_names = sorted([ f.replace(".pkl", "") for f in os.listdir(dataset_dir) if f.endswith(".pkl") ]) assert self.seq_names, f"No sequences found in {dataset_dir}" def __len__(self): return len(self.seq_names) def __getitem__(self, idx): seq_name = self.seq_names[idx] pkl_path = os.path.join(self.dataset_dir, f"{seq_name}.pkl") with open(pkl_path, "rb") as f: data = pickle.load(f) ego_cam = data.get("ego_cam_name", None) rgbs_dict = data["rgbs"] intrs_dict = data["intrs"] extrs_dict = data["extrs"] depths_dict = data.get("depths", None) if ego_cam: rgbs_dict.pop(ego_cam) intrs_dict.pop(ego_cam) extrs_dict.pop(ego_cam) if depths_dict is not None: depths_dict.pop(ego_cam) cam_names = sorted(rgbs_dict.keys()) n_views = len(cam_names) n_frames, _, H, W = rgbs_dict[cam_names[0]].shape rgbs = torch.stack([torch.from_numpy(rgbs_dict[cam]) for cam in cam_names]) # [V, T, 3, H, W] intrs = torch.stack([torch.from_numpy(intrs_dict[cam]) for cam in cam_names]) # [V, 3, 3] intrs = intrs[:, None].expand(-1, n_frames, -1, -1) # [V, T, 3, 3] extr_list = [] for cam in cam_names: e = extrs_dict[cam] if e.ndim == 2: e = np.broadcast_to(e[None, ...], (n_frames, 3, 4)) extr_list.append(torch.from_numpy(e.copy())) extrs = torch.stack(extr_list) # [V, T, 3, 4] # ------- Depth selection & caching ------- if self.use_duster_depths: depth_root = os.path.join(self.dataset_dir, f"duster_depths__{seq_name}") if not os.path.exists(os.path.join(depth_root, f"3d_model__{n_frames - 1:05d}__scene.npz")): if "../duster" not in sys.path: sys.path.insert(0, "../duster") from scripts.egoexo4d_preprocessing import main_estimate_duster_depth pkl_path = os.path.join(self.dataset_dir, f"{seq_name}.pkl") # Re-enable autograd locally (overrides any surrounding no_grad/inference_mode) with ExitStack() as stack: stack.enter_context(torch.inference_mode(False)) stack.enter_context(torch.enable_grad()) main_estimate_duster_depth(pkl_path, depth_root, self.skip_depth_computation_if_cached) duster_depths, duster_confs = [], [] for t in range(n_frames): scene_path = os.path.join(depth_root, f"3d_model__{t:05d}__scene.npz") scene = np.load(scene_path) d = torch.from_numpy(scene["depths"]) # [V, H', W'] d = interpolate(d[:, None], size=(H, W), mode="nearest") # [V, 1, H, W] duster_depths.append(d) c = torch.from_numpy(scene["confs"]) c = interpolate(c[:, None], size=(H, W), mode="nearest") duster_confs.append(c) depths = torch.stack(duster_depths, dim=1) # [V, T, 1, H, W] depth_confs = torch.stack(duster_confs, dim=1) elif self.use_vggt_depths_with_aligned_cameras: depths, depth_confs, intrs, extrs = _ensure_vggt_aligned_cache_and_load( rgbs=rgbs, seq_name=seq_name, dataset_root=self.dataset_dir, extrs_gt=extrs, # your current GT world->cam vggt_cache_subdir="vggt_cache", 
skip_if_cached=self.skip_depth_computation_if_cached, model_id="facebook/VGGT-1B", ) elif self.use_vggt_depths_with_raw_cameras: # Only use VGGT’s own (raw) cameras and depths depths, depth_confs, intrs, extrs = _ensure_vggt_raw_cache_and_load( rgbs=rgbs, seq_name=seq_name, dataset_root=self.dataset_dir, vggt_cache_subdir="vggt_cache", skip_if_cached=self.skip_depth_computation_if_cached, model_id="facebook/VGGT-1B", ) elif self.use_monofusion_depths: # MonoFusion (Dust3r + FG/BG-heuristic + MoGE-2) with caching final_depths, final_confs = _ensure_monofusion_cache_and_load( rgbs=rgbs, seq_name=seq_name, dataset_root=self.dataset_dir, monofusion_cache_subdir="monofusion_cache", skip_if_cached=self.skip_depth_computation_if_cached, ) depths = final_depths depth_confs = final_confs elif self.use_moge2_depths: # Raw MoGe-2 (metric) with caching depths, depth_confs = _ensure_moge2_cache_and_load( rgbs=rgbs, seq_name=seq_name, dataset_root=self.dataset_dir, moge2_cache_subdir="moge2_cache", skip_if_cached=self.skip_depth_computation_if_cached, ) elif depths_dict is not None: depths = torch.stack([torch.from_numpy(depths_dict[cam]) for cam in cam_names]).unsqueeze(2) depth_confs = depths.new_zeros(depths.shape) depth_confs[depths > 0] = 1000 else: raise ValueError("No depths available/configured") # Sometimes the first frames are noisy, e.g., due to timesync calibration if self.drop_first_n_frames: assert type(self.drop_first_n_frames) == int n_frames -= self.drop_first_n_frames rgbs = rgbs[:, self.drop_first_n_frames:] depths = depths[:, self.drop_first_n_frames:] depth_confs = depth_confs[:, self.drop_first_n_frames:] intrs = intrs[:, self.drop_first_n_frames:] extrs = extrs[:, self.drop_first_n_frames:] if self.scene_normalization_mode == "auto": scale, translation = compute_auto_scene_normalization( depths, depth_confs, extrs, intrs, conf_thresh=self.scene_normalization_auto_conf_thresh, target_radius=self.scene_normalization_auto_target_radius, rescale_by_camera_radius=self.scene_normalization_auto_rescale_by_camera_radius, ) rot = torch.eye(3, dtype=torch.float32, device=depths.device) elif self.scene_normalization_mode == "manual": assert self.scene_normalization_manual_scale is not None assert self.scene_normalization_manual_rotation is not None assert self.scene_normalization_manual_translation is not None scale = self.scene_normalization_manual_scale rot = self.scene_normalization_manual_rotation.to(depths.device) translation = self.scene_normalization_manual_translation.to(depths.device) elif self.scene_normalization_mode == "none": scale = 1.0 rot = torch.eye(3, dtype=torch.float32, device=depths.device) translation = torch.tensor([0.0, 0.0, 0.0], dtype=torch.float32, device=depths.device) else: raise ValueError(f"Unknown scene_normalization_mode: {self.scene_normalization_mode}") depths_trans, extrs_trans, _, _, _ = transform_scene(scale, rot, translation, depths, extrs, None, None, None) assert rgbs.shape == (n_views, n_frames, 3, H, W) assert depths.shape == (n_views, n_frames, 1, H, W) assert depth_confs.shape == (n_views, n_frames, 1, H, W) assert intrs.shape == (n_views, n_frames, 3, 3) assert extrs.shape == (n_views, n_frames, 3, 4) assert extrs_trans.shape == (n_views, n_frames, 3, 4) if self.stream_viz_to_rerun: import rerun as rr from mvtracker.utils.visualizer_rerun import log_pointclouds_to_rerun rr.init(f"3dpt", recording_id="v0.16") rr.connect_tcp() log_pointclouds_to_rerun(f"generic-1-before-norm", idx, rgbs[None], depths[None], intrs[None], extrs[None], 
depth_confs[None], [1.0]) log_pointclouds_to_rerun(f"generic-2-after-norm", idx, rgbs[None], depths[None], intrs[None], extrs_trans[None], depth_confs[None], [1.0]) datapoint = Datapoint( video=rgbs.float(), videodepth=depths_trans.float(), videodepthconf=depth_confs.float(), feats=None, segmentation=torch.ones((n_views, n_frames, 1, H, W), dtype=torch.float32), trajectory=None, trajectory_3d=None, visibility=None, valid=None, seq_name=seq_name, intrs=intrs.float(), extrs=extrs_trans.float(), query_points=None, query_points_3d=None, trajectory_category=None, track_upscaling_factor=1.0, novel_video=None, novel_intrs=None, novel_extrs=None, ) return datapoint, True def compute_auto_scene_normalization( depths, depth_confs, extrs, intrs, conf_thresh=4.8, target_radius=6.3, rescale_by_camera_radius=True, ): V, T, _, H, W = depths.shape device = depths.device extrs_square = torch.eye(4, device=device)[None, None].repeat(V, T, 1, 1) extrs_square[:, :, :3, :] = extrs extrs_inv = torch.inverse(extrs_square.float()) intrs_inv = torch.inverse(intrs.float()) y, x = torch.meshgrid(torch.arange(H, device=device), torch.arange(W, device=device), indexing="ij") homog = torch.stack([x, y, torch.ones_like(x)], dim=-1).reshape(-1, 3).float() homog = homog[None].expand(V, -1, -1) pts_all = [] for v in range(V): d = depths[v, 0, 0] c = depth_confs[v, 0, 0] mask = (c > conf_thresh) & (d > 0) if mask.sum() < 100: continue d_flat = d.flatten() conf_mask = mask.flatten() intr_inv = intrs_inv[v, 0] extr_inv = extrs_inv[v, 0] cam_pts = (intr_inv @ homog[v].T).T * d_flat[:, None] cam_pts = cam_pts[conf_mask] cam_pts_h = torch.cat([cam_pts, torch.ones_like(cam_pts[:, :1])], dim=-1) world_pts = (extr_inv @ cam_pts_h.T).T[:, :3] pts_all.append(world_pts) pts_all = torch.cat(pts_all, dim=0) if pts_all.shape[0] < 100: raise RuntimeError("Too few valid points for normalization.") # --- Center scene --- centroid = pts_all.mean(dim=0) pts_centered = pts_all - centroid # --- Lift scene so floor is at z=0 --- floor_z = pts_centered[:, 2].quantile(0.12) # robust floor estimate pts_lifted = pts_centered.clone() pts_lifted[:, 2] -= floor_z # --- Compute scale --- if rescale_by_camera_radius: cam_centers = extrs[:, 0, :, 3] # (V, 3) cam_centers_centered = cam_centers - centroid # shift cam_centers_centered[:, 2] -= floor_z # lift cam_dists = cam_centers_centered.norm(dim=1) median_dist = cam_dists.median() scale = target_radius / median_dist else: scene_radius = pts_lifted.norm(dim=1).quantile(0.95) scale = target_radius / scene_radius # --- Compute translation (after scaling) --- translate = -scale * centroid translate[2] -= scale * floor_z # lift to z=0 return scale, translate def _ensure_moge2_cache_and_load(rgbs, seq_name, dataset_root, moge2_cache_subdir, skip_if_cached=True): """ Raw MoGe-2 depth (metric) with per-sequence caching. Returns (depths, confs) shaped [V,T,1,H,W] on CPU. 
""" V, T, _, H, W = rgbs.shape cache_root = os.path.join(dataset_root, moge2_cache_subdir, seq_name) os.makedirs(cache_root, exist_ok=True) depths_path = os.path.join(cache_root, "moge2_depths.npy") confs_path = os.path.join(cache_root, "moge2_confs.npy") if skip_if_cached and os.path.isfile(depths_path) and os.path.isfile(confs_path): d = torch.from_numpy(np.load(depths_path)).float() # [V,T,H,W] c = torch.from_numpy(np.load(confs_path)).float() # [V,T,H,W] return d.unsqueeze(2), c.unsqueeze(2) d = _moge_depths(seq_name, rgbs, cache_root) # [V,T,H,W], CPU float # Simple constant confidence for MoGe-2 c = torch.full_like(d, 100.0) np.save(depths_path, d.numpy()) np.save(confs_path, c.numpy()) return d.unsqueeze(2), c.unsqueeze(2) def _ensure_monofusion_cache_and_load(rgbs, seq_name, dataset_root, monofusion_cache_subdir, skip_if_cached=True): """ MONOFUSION: - Background mask: patch-change detector over temporal window (static -> BG) - DUSt3R depth: load per frame/view; build static background depth by BG-temporal-average. - MoGe-2 monocular depth per frame/view; align to background by affine (a,b). - Merge BG (DUSt3R static) with FG (aligned MoGe). - Cache final depths & confs. """ V, T, _, H, W = rgbs.shape cache_root = os.path.join(dataset_root, monofusion_cache_subdir, seq_name) os.makedirs(cache_root, exist_ok=True) final_depths_path = os.path.join(cache_root, "final_depths.npy") final_confs_path = os.path.join(cache_root, "final_confs.npy") if skip_if_cached and os.path.isfile(final_depths_path) and os.path.isfile(final_confs_path): fd = torch.from_numpy(np.load(final_depths_path)) # [V,T,H,W] fc = torch.from_numpy(np.load(final_confs_path)) # [V,T,H,W] return fd.unsqueeze(2), fc.unsqueeze(2) # ---- DUSt3R depths per frame/view ---- depth_root = os.path.join(dataset_root, f"duster_depths__{seq_name}") if not os.path.exists(os.path.join(depth_root, f"3d_model__{T - 1:05d}__scene.npz")): if "../duster" not in sys.path: sys.path.insert(0, "../duster") from scripts.egoexo4d_preprocessing import main_estimate_duster_depth pkl_path = os.path.join(dataset_root, f"{seq_name}.pkl") # Re-enable autograd locally (overrides any surrounding no_grad/inference_mode) with ExitStack() as stack: stack.enter_context(torch.inference_mode(False)) stack.enter_context(torch.enable_grad()) main_estimate_duster_depth(pkl_path, depth_root, skip_if_cached) duster_depths = [] for t in range(T): scene_path = os.path.join(depth_root, f"3d_model__{t:05d}__scene.npz") scene = np.load(scene_path) d = torch.from_numpy(scene["depths"]) # [V, H', W'] d = interpolate(d[:, None], size=(H, W), mode="nearest")[:, 0] # [V, H, W] duster_depths.append(d) duster_depths = torch.stack(duster_depths, dim=1) # [V, T, H, W] # ---- Background mask (patch-change) ---- compute_device = "cuda" if torch.cuda.is_available() else "cpu" bg_mask = _static_bg_mask_from_window(rgbs.to(compute_device)).cpu() # [V,T,H,W] bool # ---- Static background depth per camera via temporal average on BG pixels ---- V, T, _, _ = duster_depths.shape D_bg = torch.zeros((V, H, W), dtype=torch.float32) for v in range(V): valid = bg_mask[v] # [T,H,W] num = (duster_depths[v] * valid).sum(dim=0) den = valid.sum(dim=0).clamp_min(1) D_bg[v] = num / den # ---- MoGe-2 monocular depths per frame/view ---- moge_depths = _moge_depths(seq_name, rgbs, cache_root) # [V,T,H,W] # ---- Align MoGe to background (solve a,b on BG pixels) ---- compute_device = "cuda" if torch.cuda.is_available() else "cpu" moge_depths = moge_depths.to(compute_device, dtype=torch.float32) # 
[V,T,H,W] D_bg_exp = D_bg[:, None].expand_as(moge_depths).to(compute_device) # [V,T,H,W] bg_mask = bg_mask.to(compute_device) # [V,T,H,W] # Valid BG pixels valid = bg_mask & torch.isfinite(moge_depths) & (moge_depths > 0) \ & torch.isfinite(D_bg_exp) & (D_bg_exp > 0) # Flatten over pixels X = moge_depths.view(V, T, -1) # [V,T,HW] Y = D_bg_exp.view(V, T, -1) # [V,T,HW] M = valid.view(V, T, -1).float() # [V,T,HW] # Count valid pixels n = M.sum(dim=-1) # [V,T] min_bg = 200 if (n < min_bg).any(): bad = torch.nonzero(n < min_bg, as_tuple=False) raise RuntimeError( f"Too few background pixels in frames: {[(int(v), int(t)) for v, t in bad.tolist()]}" ) # Sufficient statistics sx = (X * M).sum(dim=-1) sy = (Y * M).sum(dim=-1) sxx = (X * X * M).sum(dim=-1) sxy = (X * Y * M).sum(dim=-1) # Closed-form least squares for a, b eps = 1e-8 mx = sx / n my = sy / n varx = sxx / n - mx * mx cov = sxy / n - mx * my a = cov / (varx + eps) # [V,T] b = my - a * mx # Apply alignment aligned_moge = (a[..., None] * X + b[..., None]).view(V, T, H, W) # Optionally save scale/shift scale = a.float().cpu() shift = b.float().cpu() # ---- Merge FG/BG ---- final_depths = torch.where(bg_mask, D_bg_exp, aligned_moge) # [V,T,H,W] # ---- Confidence map: high for BG, moderate for FG ---- final_confs = torch.zeros_like(final_depths) final_confs[bg_mask] = 1000.0 final_confs[~bg_mask] = 10.0 # ---- Cache results ---- np.save(final_depths_path, final_depths.cpu().numpy()) np.save(final_confs_path, final_confs.cpu().numpy()) np.save(os.path.join(cache_root, "scale.npy"), scale.cpu().numpy()) np.save(os.path.join(cache_root, "shift.npy"), shift.cpu().numpy()) return final_depths.unsqueeze(2).cpu(), final_confs.unsqueeze(2).cpu() def _static_bg_mask_from_window( rgbs: torch.Tensor, win: int = -1, r: int = 7, # spatial patch radius -> (2r+1)x(2r+1) diff_thresh: float = 10.0 # uint8 scale threshold ): """ Fast BG detector using 3D max-pooling over frame-to-frame diffs. 
""" V, T, C, H, W = rgbs.shape device = rgbs.device if T == 1: return torch.ones((V, T, H, W), dtype=torch.bool, device=device) if win == -1: win = T # 1) Frame-to-frame abs diff (channel-mean): boundaries of length T-1 x = rgbs.float() diffs = (x[:, 1:] - x[:, :-1]).abs().mean(dim=2) # [V, T-1, H, W] diffs = diffs.unsqueeze(1) # [V, 1, T-1, H, W] (N,C,D,H,W for 3D pool) # 2) 3D max pool over time & space: # - temporal kernel spans (2*win-1) boundaries # - spatial kernel spans (2r+1)x(2r+1) patch kt = max(1, 2 * win - 1) kh = kw = 2 * r + 1 pt = (kt - 1) // 2 ph = pw = r pooled = F.max_pool3d(diffs, kernel_size=(kt, kh, kw), stride=1, padding=(pt, ph, pw)) pooled = pooled[:, 0] # [V, T-1, H, W] # 3) Map boundary maxima back to frame centers (symmetric nearest-window approx) change = torch.zeros((V, T, H, W), device=device, dtype=pooled.dtype) change[:, 0] = pooled[:, 0] change[:, 1:-1] = torch.maximum(pooled[:, :-1], pooled[:, 1:]) change[:, -1] = pooled[:, -1] # 4) Threshold -> background bg_mask = (change < diff_thresh) return bg_mask def _moge_depths(seq_name, rgbs, cache_root, resize_to=512, batch_size=18): """Runs (and caches) MoGe-2; returns [V,T,H,W] float32 at native resolution.""" # pip install git+https://github.com/microsoft/MoGe.git from moge.model.v2 import MoGeModel as MoGe2Model depths_path = os.path.join(cache_root, "moge_depths.npy") if os.path.isfile(depths_path): logging.info(f"Loading cached MoGe-2 depths for {seq_name} from {depths_path}") return torch.from_numpy(np.load(depths_path)).float() V, T, C, H, W = rgbs.shape device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model = MoGe2Model.from_pretrained("Ruicheng/moge-2-vitl-normal").to(device).eval() if resize_to is None: h1, w1 = H, W else: if H >= W: h1, w1 = int(resize_to), max(1, round(resize_to * W / H)) else: w1, h1 = int(resize_to), max(1, round(resize_to * H / W)) imgs = rgbs.view(V * T, C, H, W).float() if (h1, w1) != (H, W): imgs = F.interpolate(imgs, size=(h1, w1), mode="bilinear", align_corners=False) imgs = (imgs / 255.0).to(device, non_blocking=True) # [N,3,h1,w1] out_small = torch.empty((V * T, h1, w1), dtype=torch.float32, device=device) with torch.inference_mode(), torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=(device.type == "cuda")): N = imgs.shape[0] for i in range(0, N, batch_size): chunk = imgs[i:i + batch_size] # [b,3,h1,w1] pred = model.infer(chunk) # expects batched input assert isinstance(pred, dict) and "depth" in pred, "MoGe-2 infer() must return dict with 'depth'." d = torch.as_tensor(pred["depth"], device=device) assert d.ndim == 3 and d.shape[0] == chunk.shape[0] and tuple(d.shape[1:]) == (h1, w1), \ f"Depth shape {tuple(d.shape)} != ({chunk.shape[0]},{h1},{w1})" out_small[i:i + chunk.shape[0]] = d if (h1, w1) != (H, W): out = F.interpolate(out_small[:, None], size=(H, W), mode="bilinear", align_corners=False)[:, 0] else: out = out_small out = out.clamp_min(0).view(V, T, H, W).cpu() np.save(depths_path, out.numpy()) return out def _ensure_vggt_raw_cache_and_load( rgbs: torch.Tensor, # uint8 [V,T,3,H,W] seq_name: str, dataset_root: str, vggt_cache_subdir: str = "vggt_cache", skip_if_cached: bool = True, model_id: str = "facebook/VGGT-1B", ): """ Run VGGT and cache RAW predictions (no alignment). 
Returns CPU float32 tensors: depths_raw [V,T,1,H,W] confs [V,T,1,H,W] (constant 100) intrs_raw [V,T,3,3] extrs_raw [V,T,3,4] (world->cam as predicted by VGGT) """ from mvtracker.models.core.vggt.models.vggt import VGGT from mvtracker.models.core.vggt.utils.pose_enc import pose_encoding_to_extri_intri assert rgbs.dtype == torch.uint8 and rgbs.ndim == 5 and rgbs.shape[2] == 3, "rgbs must be uint8 [V,T,3,H,W]" V, T, _, H, W = rgbs.shape cache_root = os.path.join(dataset_root, vggt_cache_subdir, seq_name) os.makedirs(cache_root, exist_ok=True) f_depths_raw = os.path.join(cache_root, "vggt_depths_raw.npy") # [V,T,H,W] f_confs = os.path.join(cache_root, "vggt_confs.npy") # [V,T,H,W] f_intr_raw = os.path.join(cache_root, "vggt_intrinsics_raw.npy") f_extr_raw = os.path.join(cache_root, "vggt_extrinsics_raw.npy") all_cached = all(os.path.isfile(p) for p in [f_depths_raw, f_confs, f_intr_raw, f_extr_raw]) if skip_if_cached and all_cached: depths_raw = torch.from_numpy(np.load(f_depths_raw)).float().unsqueeze(2) confs = torch.from_numpy(np.load(f_confs)).float().unsqueeze(2) intrs_raw = torch.from_numpy(np.load(f_intr_raw)).float() extrs_raw = torch.from_numpy(np.load(f_extr_raw)).float() return depths_raw, confs, intrs_raw, extrs_raw device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model = VGGT.from_pretrained(model_id).to(device).eval() amp_dtype = torch.bfloat16 if ( device.type == "cuda" and torch.cuda.get_device_capability()[0] >= 8) else torch.float16 def _compute_pad_to_518(H0: int, W0: int, target: int = 518) -> Tuple[int, int, int, int, int, int]: """ Mirror VGGT's load_and_preprocess_images(mode='pad') padding math so we can undo it. Returns: new_h, new_w, pad_top, pad_bottom, pad_left, pad_right """ # Make largest dim target, keep aspect, round smaller dim to /14*14, then pad to (target, target) if W0 >= H0: new_w = target new_h = int(round((H0 * (new_w / W0)) / 14.0) * 14) h_pad = max(0, target - new_h) w_pad = 0 else: new_h = target new_w = int(round((W0 * (new_h / H0)) / 14.0) * 14) h_pad = 0 w_pad = max(0, target - new_w) pad_top = h_pad // 2 pad_bottom = h_pad - pad_top pad_left = w_pad // 2 pad_right = w_pad - pad_left return new_h, new_w, pad_top, pad_bottom, pad_left, pad_right depths_raw_arr = torch.empty((V, T, H, W), dtype=torch.float32) confs_arr = torch.full((V, T, H, W), 100.0, dtype=torch.float32) intr_raw_arr = torch.empty((V, T, 3, 3), dtype=torch.float32) extr_raw_arr = torch.empty((V, T, 3, 4), dtype=torch.float32) with torch.no_grad(), torch.cuda.amp.autocast(enabled=(device.type == "cuda"), dtype=amp_dtype): for t in tqdm(range(T), desc=f"VGGT RAW {seq_name}", unit="f"): image_items = [rgbs[v, t].cpu() for v in range(V)] # each: [3,H,W] uint8 images = _vggt_load_and_preprocess_images(image_items, mode="pad").to(device)[None] # [1,V,3,518,518] tokens, ps_idx = model.aggregator(images) pose_enc = model.camera_head(tokens)[-1] extr_pred, intr_pred = pose_encoding_to_extri_intri(pose_enc, images.shape[-2:]) # [1,V,3,4],[1,V,3,3] depth_maps, _ = model.depth_head(tokens, images, ps_idx) # [1,V,518,518] # per-view: undo pad, resize back to (H0,W0), adjust intrinsics d_full_list, K_list = [], [] for v in range(V): H0, W0 = int(rgbs[v, t].shape[-2]), int(rgbs[v, t].shape[-1]) new_h, new_w, pt, pb, pl, pr = _compute_pad_to_518(H0, W0) # crop padding region out of the 518x518 depth d_small = depth_maps[0, v:v + 1, pt:518 - pb, pl:518 - pr] # [1,new_h,new_w] d_full_v = F.interpolate(d_small[:, None, :, :, 0], size=(H0, W0), mode="nearest")[:, 0] # [1,H0,W0] 
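                # Undo of VGGT's pad-to-518 preprocessing: e.g., for a 640x360 (W0 x H0) view,
                # _compute_pad_to_518 gives new_w=518, new_h=294, pad_top=pad_bottom=112,
                # pad_left=pad_right=0, so the crop above keeps rows 112..405 of the 518x518
                # canvas before resizing back to H0 x W0, and the intrinsics below are shifted
                # by (-pad_left, -pad_top) and then scaled by (W0/new_w, H0/new_h) ~ (1.236, 1.224).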
d_full_list.append(d_full_v.squeeze(0)) # adjust intrinsics: subtract removed pad, then scale to (H0,W0) K = intr_pred[0, v].detach().cpu().float().clone() K[0, 2] -= float(pl) K[1, 2] -= float(pt) S = torch.tensor([[W0 / float(new_w), 0.0, 0.0], [0.0, H0 / float(new_h), 0.0], [0.0, 0.0, 1.0]], dtype=torch.float32) K_list.append((S @ K).unsqueeze(0)) depths_raw_arr[:, t] = torch.stack(d_full_list, dim=0) intr_raw_arr[:, t] = torch.cat(K_list, dim=0) extr_raw_arr[:, t] = extr_pred[0].detach().cpu().float() # raw VGGT w2c # save raw cache np.save(f_depths_raw, depths_raw_arr.numpy()) np.save(f_confs, confs_arr.numpy()) np.save(f_intr_raw, intr_raw_arr.numpy()) np.save(f_extr_raw, extr_raw_arr.numpy()) return depths_raw_arr.unsqueeze(2), confs_arr.unsqueeze(2), intr_raw_arr, extr_raw_arr def _vggt_load_and_preprocess_images(image_items, mode="crop"): """ Same as VGGT loader, but accepts in-memory items as well. """ if len(image_items) == 0: raise ValueError("At least 1 image is required") # Validate mode if mode not in ["crop", "pad"]: raise ValueError("Mode must be either 'crop' or 'pad'") images = [] shapes = set() to_tensor = TF.ToTensor() target_size = 518 def _to_pil(item): # path if isinstance(item, str): img = Image.open(item) return img # numpy HWC if isinstance(item, np.ndarray): if item.ndim == 3 and item.shape[2] in (3, 4): if item.dtype != np.uint8: item = item.astype(np.uint8) return Image.fromarray(item) # torch CHW if torch.is_tensor(item): x = item if x.ndim == 3 and x.shape[0] in (3, 4): if x.dtype == torch.uint8: arr = x.permute(1, 2, 0).cpu().numpy() return Image.fromarray(arr) else: # assume float [0,1] arr = (x.clamp(0, 1) * 255.0).byte().permute(1, 2, 0).cpu().numpy() return Image.fromarray(arr) raise ValueError("Unsupported image item type/shape") for item in image_items: img = _to_pil(item) # If there's an alpha channel, blend onto white background: if img.mode == "RGBA": # Create white background background = Image.new("RGBA", img.size, (255, 255, 255, 255)) # Alpha composite onto the white background img = Image.alpha_composite(background, img) # Now convert to "RGB" (this step assigns white for transparent areas) img = img.convert("RGB") width, height = img.size if mode == "pad": # Make the largest dimension 518px while maintaining aspect ratio if width >= height: new_width = target_size new_height = round(height * (new_width / width) / 14) * 14 # Make divisible by 14 else: new_height = target_size new_width = round(width * (new_height / height) / 14) * 14 # Make divisible by 14 else: # mode == "crop" # Original behavior: set width to 518px new_width = target_size # Calculate height maintaining aspect ratio, divisible by 14 new_height = round(height * (new_width / width) / 14) * 14 # Resize with new dimensions (width, height) img = img.resize((new_width, new_height), Image.Resampling.BICUBIC) img = to_tensor(img) # Convert to tensor (0, 1) # Center crop height if it's larger than 518 (only in crop mode) if mode == "crop" and new_height > target_size: start_y = (new_height - target_size) // 2 img = img[:, start_y: start_y + target_size, :] # For pad mode, pad to make a square of target_size x target_size if mode == "pad": h_padding = target_size - img.shape[1] w_padding = target_size - img.shape[2] if h_padding > 0 or w_padding > 0: pad_top = h_padding // 2 pad_bottom = h_padding - pad_top pad_left = w_padding // 2 pad_right = w_padding - pad_left # Pad with white (value=1.0) img = torch.nn.functional.pad( img, (pad_left, pad_right, pad_top, pad_bottom), 
mode="constant", value=1.0 ) shapes.add((img.shape[1], img.shape[2])) images.append(img) # Check if we have different shapes # In theory our model can also work well with different shapes if len(shapes) > 1: print(f"Warning: Found images with different shapes: {shapes}") # Find maximum dimensions max_height = max(shape[0] for shape in shapes) max_width = max(shape[1] for shape in shapes) # Pad images if necessary padded_images = [] for img in images: h_padding = max_height - img.shape[1] w_padding = max_width - img.shape[2] if h_padding > 0 or w_padding > 0: pad_top = h_padding // 2 pad_bottom = h_padding - pad_top pad_left = w_padding // 2 pad_right = w_padding - pad_left img = torch.nn.functional.pad( img, (pad_left, pad_right, pad_top, pad_bottom), mode="constant", value=1.0 ) padded_images.append(img) images = padded_images images = torch.stack(images) # concatenate images # Ensure correct shape when single image if len(image_items) == 1: # Verify shape is (1, C, H, W) if images.dim() == 3: images = images.unsqueeze(0) return images def _ensure_vggt_aligned_cache_and_load( rgbs: torch.Tensor, # uint8 [V,T,3,H,W] seq_name: str, dataset_root: str, extrs_gt: torch.Tensor, # [V,T,3,4] GT world->cam vggt_cache_subdir: str = "vggt_cache", skip_if_cached: bool = True, model_id: str = "facebook/VGGT-1B", ): """ Ensure RAW VGGT cache exists (running VGGT if needed), then align VGGT cameras to GT via Umeyama (pred→gt) per frame. Returns CPU float32: depths_aligned [V,T,1,H,W] (RAW depths scaled by s) confs [V,T,1,H,W] (same constant 100 as RAW) intr_aligned [V,T,3,3] (equal to RAW intrinsics; alignment is Sim3 in world) extr_aligned [V,T,3,4] (VGGT w2c aligned to GT) """ # 1) Get RAW results (runs VGGT if needed) depths_raw, confs_raw, intr_raw, extr_raw = _ensure_vggt_raw_cache_and_load( rgbs=rgbs, seq_name=seq_name, dataset_root=dataset_root, vggt_cache_subdir=vggt_cache_subdir, skip_if_cached=skip_if_cached, model_id=model_id, ) # 2) Aligned cache file paths cache_root = os.path.join(dataset_root, vggt_cache_subdir, seq_name) f_depths_aln = os.path.join(cache_root, "vggt_depths_aligned.npy") f_intr_aln = os.path.join(cache_root, "vggt_intrinsics_aligned.npy") f_extr_aln = os.path.join(cache_root, "vggt_extrinsics_aligned.npy") # 3) If aligned already cached, return it if skip_if_cached and all(os.path.isfile(p) for p in [f_depths_aln, f_intr_aln, f_extr_aln]): depths_aln = torch.from_numpy(np.load(f_depths_aln)).float().unsqueeze(2) intr_aln = torch.from_numpy(np.load(f_intr_aln)).float() extr_aln = torch.from_numpy(np.load(f_extr_aln)).float() return depths_aln, confs_raw, intr_aln, extr_aln # 4) Compute alignment depths_raw_ = depths_raw.squeeze(2) # [V,T,H,W] V, T, H, W = depths_raw_.shape assert extrs_gt.shape[:2] == (V, T), "GT extrinsics must be [V,T,3,4]" depths_aln = depths_raw_.clone() intr_aln = intr_raw.clone() # intrinsics unchanged by world Sim3 extr_aln = extr_raw.clone() def _camera_center_from_affine_extr(extr): extr_sq = np.eye(4, dtype=np.float32)[None].repeat(extr.shape[0], 0) extr_sq[:, :3, :4] = extr extr_sq_inv = np.linalg.inv(extr_sq) return extr_sq_inv[:, :3, 3] for t in range(T): gt_w2c = extrs_gt[:, t].cpu().numpy() pred_w2c = extr_raw[:, t].cpu().numpy() s, R_align, t_align = align_umeyama( _camera_center_from_affine_extr(gt_w2c), _camera_center_from_affine_extr(pred_w2c), ) pred_w2c_aligned = apply_sim3_to_extrinsics(pred_w2c, s, R_align, t_align) extr_aln[:, t] = torch.from_numpy(np.array(pred_w2c_aligned)).float() # 5) Save aligned cache np.save(f_depths_aln, 
depths_aln.numpy()) np.save(f_intr_aln, intr_aln.numpy()) np.save(f_extr_aln, extr_aln.numpy()) return depths_aln.unsqueeze(2), confs_raw, intr_aln, extr_aln ================================================ FILE: mvtracker/datasets/kubric_multiview_dataset.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. import logging import os import pathlib import re import time import cv2 import kornia import numpy as np import torch import torch.nn.functional as F from PIL import Image from scipy.spatial.transform import Rotation as R from torch.utils.data import get_worker_info from torchvision.transforms import ColorJitter, GaussianBlur from torchvision.transforms import functional as F_torchvision from mvtracker.datasets.utils import Datapoint, read_json, read_tiff, read_png, transform_scene, add_camera_noise, \ aug_depth class KubricMultiViewDataset(torch.utils.data.Dataset): @staticmethod def from_name( dataset_name: str, dataset_root: str, training_args=None, fabric=None, just_return_kwargs: bool = False, subset: str = "test", ): """ Examples of evaluation datasets supported by this factory method: - kubric-multiview-v3 - kubric-multiview-v3-duster0123 - kubric-multiview-v3-duster01234567 - kubric-multiview-v3-duster01234567cleaned - kubric-multiview-v3-duster01234567cleaned-views012 - kubric-multiview-v3-duster01234567cleaned-views012-novelviews7 - kubric-multiview-v3-duster01234567cleaned-views012-novelviews7-overfit-on-training - kubric-multiview-v3-duster01234567cleaned-views012-novelviews7-overfit-on-training-single - kubric-multiview-v3-duster01234567cleaned-views012-novelviews7-overfit-on-training-2dpt-single - kubric-multiview-v3-duster01234567cleaned-views012-novelviews7-overfit-on-training-2dpt-single-cached - kubric-multiview-v3-noise1.23cm Example of a training dataset: - kubric-multiview-v3-training """ # Parse the dataset name, chunk by chunk non_parsed = dataset_name.replace("kubric-multiview-v3", "", 1) if non_parsed.startswith("-noise"): match = re.match(r"-noise([\d.]+)cm", non_parsed) assert match is not None depth_noise_std = float(match.group(1)) depth_noise_std = depth_noise_std / 13 # real-world cm to kubric's metric unit non_parsed = non_parsed.replace(match.group(0), "", 1) else: depth_noise_std = 0.0 if non_parsed.startswith("-duster"): match = re.match(r"-duster(\d+)(cleaned)?", non_parsed) assert match is not None duster_views = list(map(int, match.group(1))) use_duster = True use_duster_cleaned = match.group(2) is not None non_parsed = non_parsed.replace(match.group(0), "", 1) else: use_duster = False use_duster_cleaned = False duster_views = None if non_parsed.startswith("-views"): match = re.match(r"-views(\d+)", non_parsed) assert match is not None views = list(map(int, match.group(1))) if duster_views is not None: assert all(v in duster_views for v in views) non_parsed = non_parsed.replace(match.group(0), "", 1) else: views = duster_views if non_parsed.startswith("-novelviews"): match = re.match(r"-novelviews(\d+)", non_parsed) assert match is not None novel_views = list(map(int, match.group(1))) non_parsed = non_parsed.replace(match.group(0), "", 1) else: novel_views = None if non_parsed.startswith("-training"): training = True non_parsed = non_parsed.replace("-training", "", 1) assert training_args is not None assert fabric is not None else: training = False if 
non_parsed.startswith("-overfit-on-training"): overfit_on_train = True non_parsed = non_parsed.replace("-overfit-on-training", "", 1) assert not training, "Either ...-training or ...-overfit-on-training[-single][-2dpt]" assert training_args is not None expected_training_dset_name = (dataset_name.replace("-overfit-on-training", "-training") .replace("-single", "").replace("2dpt", "")) assert training_args.datasets.train.name == expected_training_dset_name, \ f"{expected_training_dset_name} != {training_args.datasets.train.name}" else: overfit_on_train = False if non_parsed.startswith("-single"): assert not training, "The single-point evaluation options is not relevant for a training dataset" single_point = True non_parsed = non_parsed.replace("-single", "", 1) else: single_point = False if non_parsed.startswith("-2dpt"): eval_2dpt = True non_parsed = non_parsed.replace("-2dpt", "", 1) else: eval_2dpt = False if non_parsed.startswith("-cached"): use_cached_tracks = True non_parsed = non_parsed.replace("-cached", "", 1) else: use_cached_tracks = False assert non_parsed == "", f"Unparsed part of the dataset name: {non_parsed}" kubric_kwargs = { "data_root": os.path.join(dataset_root, "kubric-multiview", subset), "seq_len": 24, "traj_per_sample": 512, "seed": 72, "sample_vis_1st_frame": False, "tune_per_scene": False, "max_videos": 30, "use_duster_depths": use_duster, "duster_views": duster_views, "clean_duster_depths": use_duster_cleaned, "views_to_return": views, "novel_views": novel_views, "num_views": -1 if views is not None else 4, "depth_noise_std": depth_noise_std, "ratio_dynamic": 0.5, "ratio_very_dynamic": 0.25, "use_cached_tracks": use_cached_tracks, } if training: kubric_kwargs["virtual_dataset_size"] = fabric.world_size * (training_args.trainer.num_steps + 1000) if training or overfit_on_train: kubric_kwargs["data_root"] = ( os.path.join(training_args.datasets.root, "kubric-multiview", "train") if not training_args.modes.debug else os.path.join(training_args.datasets.root, "kubric-multiview", "validation") ) kubric_kwargs["seq_len"] = training_args.datasets.train.sequence_len kubric_kwargs["traj_per_sample"] = training_args.datasets.train.traj_per_sample kubric_kwargs["max_depth"] = training_args.datasets.train.kubric_max_depth kubric_kwargs["tune_per_scene"] = training_args.modes.tune_per_scene if training: kubric_kwargs["max_videos"] = training_args.datasets.train.max_videos else: kubric_kwargs["max_videos"] = 30 kubric_kwargs["augmentation_probability"] = training_args.augmentations.probability kubric_kwargs["enable_rgb_augs"] = training_args.augmentations.rgb kubric_kwargs["enable_depth_augs"] = training_args.augmentations.depth kubric_kwargs["enable_cropping_augs"] = training_args.augmentations.cropping kubric_kwargs["aug_crop_size"] = training_args.augmentations.cropping_size kubric_kwargs["enable_variable_trajpersample_augs"] = training_args.augmentations.variable_trajpersample kubric_kwargs["enable_scene_transform_augs"] = training_args.augmentations.scene_transform kubric_kwargs["enable_camera_params_noise_augs"] = training_args.augmentations.camera_params_noise kubric_kwargs["enable_variable_depth_type_augs"] = training_args.augmentations.variable_depth_type kubric_kwargs["enable_variable_num_views_augs"] = training_args.augmentations.variable_num_views kubric_kwargs["normalize_scene_following_vggt"] = training_args.augmentations.normalize_scene_following_vggt kubric_kwargs["enable_variable_vggt_crop_size_augs"] = training_args.augmentations.variable_vggt_crop_size 
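            # For reference, the dataset-name suffix parsing above maps, e.g.,
            # "kubric-multiview-v3-duster0123cleaned-views012" to use_duster_depths=True,
            # duster_views=[0, 1, 2, 3], clean_duster_depths=True, views_to_return=[0, 1, 2],
            # while the surrounding kwargs here are mirrored from the training config
            # (mostly `training_args.augmentations.*`) for the training / overfit-on-training variants.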
kubric_kwargs["keep_principal_point_centered"] = training_args.augmentations.keep_principal_point_centered if training_args.modes.pretrain_only: kubric_kwargs["ratio_dynamic"] = 0.0 kubric_kwargs["ratio_very_dynamic"] = 0.0 if training_args.augmentations.variable_num_views: kubric_kwargs["num_views"] = None kubric_kwargs["views_to_return"] = None kubric_kwargs["duster_views"] = None kubric_kwargs["supported_duster_views_sets"] = [ [0, 1, 2, 3], [0, 1, 2, 3, 4, 5, 6, 7], ] if just_return_kwargs: return kubric_kwargs return KubricMultiViewDataset(**kubric_kwargs) def __init__( self, data_root, views_to_return=None, novel_views=None, use_duster_depths=False, clean_duster_depths=False, duster_views=None, supported_duster_views_sets=None, seq_len=24, num_views=4, traj_per_sample=768, max_depth=1000, sample_vis_1st_frame=False, ratio_dynamic=0.5, ratio_very_dynamic=0.25, depth_noise_std=0.0, augmentation_probability=0.0, enable_rgb_augs=False, enable_depth_augs=False, enable_cropping_augs=False, aug_crop_size=(384, 512), enable_variable_trajpersample_augs=False, enable_scene_transform_augs=False, enable_camera_params_noise_augs=False, enable_variable_depth_type_augs=False, enable_variable_num_views_augs=False, normalize_scene_following_vggt=False, enable_variable_vggt_crop_size_augs=False, keep_principal_point_centered=False, static_cropping=False, seed=None, tune_per_scene=False, max_videos=None, virtual_dataset_size=None, max_tracks_to_preload=18000, perform_sanity_checks=False, use_cached_tracks=False, ): super(KubricMultiViewDataset, self).__init__() self.data_root = data_root self.views_to_return = views_to_return self.novel_views = novel_views self.use_duster_depths = use_duster_depths self.clean_duster_depths = clean_duster_depths self.duster_views = duster_views self.supported_duster_views_sets = supported_duster_views_sets if self.use_duster_depths: assert self.duster_views is not None, "When using Duster depths, duster_views must be set." if self.supported_duster_views_sets is None: self.supported_duster_views_sets = [self.duster_views] self.seq_len = seq_len self.num_views = num_views self.traj_per_sample = traj_per_sample self.sample_vis_1st_frame = sample_vis_1st_frame self.ratio_dynamic = ratio_dynamic self.ratio_very_dynamic = ratio_very_dynamic self.seed = seed self.add_index_to_seed = not tune_per_scene self.perform_sanity_checks = perform_sanity_checks self.use_cached_tracks = use_cached_tracks self.cache_name = self._cache_key() self.max_tracks_to_preload = max_tracks_to_preload if self.traj_per_sample is not None and self.max_tracks_to_preload is not None: assert self.traj_per_sample <= self.max_tracks_to_preload, "We need to preload more tracks than we sample." self.depth_noise_std = depth_noise_std # Augmentation settings self.augmentation_probability = augmentation_probability if any([enable_rgb_augs, enable_depth_augs, enable_variable_trajpersample_augs, enable_scene_transform_augs, enable_camera_params_noise_augs, enable_variable_num_views_augs, enable_variable_depth_type_augs]): assert self.augmentation_probability > 0, "Augmentations are enabled, but augmentation probability is 0%." 
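        # Consistency checks: enabling any individual augmentation only makes sense with
        # augmentation_probability > 0, and (below) cached track selection cannot be combined
        # with augmentations, likely since augs change the views, crops, and geometry that the
        # cached track indices and query points were computed from.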
if self.augmentation_probability > 0: assert not self.use_cached_tracks, "caching tracks not supported with augs" self.enable_rgb_augs = enable_rgb_augs self.enable_depth_augs = enable_depth_augs self.enable_cropping_augs = enable_cropping_augs self.enable_variable_trajpersample_augs = enable_variable_trajpersample_augs self.enable_scene_transform_augs = enable_scene_transform_augs self.enable_camera_params_noise_augs = enable_camera_params_noise_augs self.enable_variable_num_views_augs = enable_variable_num_views_augs self.enable_variable_depth_type_augs = enable_variable_depth_type_augs self.enable_variable_depth_type_augs__depth_type_probability = { "gt": 0.70, "duster": 0.20, "duster_cleaned": 0.10, } # TODO: self.enable_seqlen_augs = enable_seqlen_augs if self.enable_variable_depth_type_augs: assert not self.use_duster_depths, "Cannot force depth type when using variable depth type augs." assert not self.clean_duster_depths, "Cannot force depth type when using variable depth type augs." self.enable_variable_num_views_augs__n_views_probability = { # v2 1: 0.20, 2: 0.10, 3: 0.10, 4: 0.25, 5: 0.10, 6: 0.25, # # v1 # 1: 0.20, # 2: 0.10, # 3: 0.10, # 4: 0.25, # 5: 0.10, # 6: 0.05, # 7: 0.05, # 8: 0.15, } self.enable_variable_num_views_augs__trajpersample_adjustment_factor = { 1: 1.00, 2: 1.00, 3: 1.00, 4: 1.00, 5: 0.40, 6: 0.25, } if self.enable_variable_num_views_augs: assert self.num_views is None, "Cannot use enable_variable_num_views_augs with num_views != None." assert self.views_to_return is None, "Cannot use enable_variable_num_views_augs with views_to_return." # photometric augmentation # TODO: "Override" ColorJitter and GaussianBlur to take in a random state # in forward pass so we can assure reproducibility. This affects # only training as augmentation is disabled during evaluation. 
self.photo_aug = ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.25 / 3.14) self.blur_aug = GaussianBlur(11, sigma=(0.1, 2.0)) self.blur_aug_prob = 0.25 self.color_aug_prob = 0.25 # occlusion augmentation self.eraser_aug_prob = 0.5 self.eraser_bounds = [2, 100] self.eraser_max = 10 # occlusion augmentation self.replace_aug_prob = 0.5 self.replace_bounds = [2, 100] self.replace_max = 10 # spatial augmentations self.crop_size = aug_crop_size self.normalize_scene_following_vggt = normalize_scene_following_vggt self.enable_variable_vggt_crop_size_augs = enable_variable_vggt_crop_size_augs self.keep_principal_point_centered = keep_principal_point_centered self.max_depth = max_depth self.pad_bounds = [0, 45] self.resize_lim = [0.8, 1.2] self.resize_delta = 0.15 self.max_crop_offset = 36 if static_cropping or tune_per_scene: self.pad_bounds = [0, 1] self.resize_lim = [1.0, 1.0] self.resize_delta = 0.0 self.max_crop_offset = 0 if self.keep_principal_point_centered: self.pad_bounds = [0, 45] self.resize_lim = [1.02, 1.25] self.resize_delta = None self.max_crop_offset = None if static_cropping or tune_per_scene: self.pad_bounds = [0, 1] self.resize_lim = [1.04, 1.04] self.seq_names = [ fname for fname in os.listdir(self.data_root) if os.path.isdir(os.path.join(self.data_root, fname)) and not fname.startswith(".") and not fname.startswith("_") ] self.seq_names = sorted(self.seq_names, key=lambda x: int(x)) seq_names_clean = [] for seq_name in self.seq_names: scene_path = os.path.join(self.data_root, seq_name) view_folders = [ d for d in os.listdir(scene_path) if os.path.isdir(os.path.join(scene_path, d)) and d.startswith('view_') ] if len(view_folders) == 0: logging.warning(f"Skipping {scene_path} because it has no views.") continue if self.num_views is not None and len(view_folders) < self.num_views: logging.warning(f"Skipping {scene_path} because it has {len(view_folders)} views (<{self.num_views}).") continue seq_names_clean.append(seq_name) self.seq_names = seq_names_clean if self.supported_duster_views_sets is not None: supported_duster_views_sets_cleaned = [] for s in self.supported_duster_views_sets: duster_views_str = ''.join(str(v) for v in s) if os.path.isdir(os.path.join(self.data_root, self.seq_names[0], f"duster-views-{duster_views_str}")): supported_duster_views_sets_cleaned.append(s) else: logging.warning(f"Skipping duster views set {s} because it does not exist.") self.supported_duster_views_sets = supported_duster_views_sets_cleaned if tune_per_scene: self.seq_names = self.seq_names[3:4] if max_videos is not None: self.seq_names = self.seq_names[:max_videos] logging.info("Using %d videos from %s" % (len(self.seq_names), self.data_root)) self.real_len = len(self.seq_names) if virtual_dataset_size is not None: self.virtual_len = virtual_dataset_size else: self.virtual_len = self.real_len logging.info(f"Real dataset size: {self.real_len}. 
Virtual dataset size: {self.virtual_len}.") self.getitem_calls = 0 def _cache_key(self): name = f"cachedtracks--seed{self.seed}-dynamic{self.ratio_dynamic}-verydynamic-{self.ratio_very_dynamic}" if self.views_to_return is not None: name += f"-views{'_'.join(map(str, self.views_to_return))}" if self.traj_per_sample is not None: name += f"-n{self.traj_per_sample}" if self.num_views is not None: name += f"-numviews{self.num_views}" if self.seq_len is not None: name += f"-t{self.seq_len}" if self.sample_vis_1st_frame: name += f"-sample_vis_1st_frame" return name + "--v1" # bump this if you change the selection policy def __len__(self): return self.virtual_len def __getitem__(self, index): index = index % self.real_len sample, gotit = self._getitem_helper(index) if not gotit: logging.warning("warning: sampling failed") # fake sample, so we can still collate num_views = self.num_views if self.num_views is not None else 4 h, w = 384, 512 traj_per_sample = self.traj_per_sample if self.traj_per_sample is not None else 768 sample = Datapoint( video=torch.zeros((num_views, self.seq_len, 3, h, w)), videodepth=torch.zeros((num_views, self.seq_len, 1, h, w)), segmentation=torch.zeros((num_views, self.seq_len, 1, h, w)), trajectory=torch.zeros((self.seq_len, traj_per_sample, 2)), visibility=torch.zeros((self.seq_len, traj_per_sample)), valid=torch.zeros((self.seq_len, traj_per_sample)), ) return sample, gotit def _getitem_helper(self, index): start_time_1 = time.time() gotit = True # Take a new seed from torch or use self.seed if set # The rest of the code will use generators initialized with this seed if self.seed is None: seed = torch.randint(0, 2 ** 32 - 1, (1,)).item() else: seed = self.seed if self.add_index_to_seed: seed += index rnd_torch = torch.Generator().manual_seed(seed) rnd_np = np.random.RandomState(seed=seed) # Load the data datapoint = KubricMultiViewDataset.getitem_raw_datapoint(os.path.join(self.data_root, self.seq_names[index])) traj3d_world = datapoint["tracks_3d"].numpy() tracks_segmentation_ids = datapoint["tracks_segmentation_ids"].numpy() tracked_objects = datapoint["tracked_objects"] camera_positions = datapoint["camera_positions"].numpy() lookat_positions = datapoint["lookat_positions"].numpy() views = datapoint["views"] # Take a random depth type, if enabled if self.enable_variable_depth_type_augs: assert self.use_duster_depths is False, "Cannot force depth type when using variable depth type augs." assert self.clean_duster_depths is False, "Cannot force depth type when using variable depth type augs." depth_type = rnd_np.choice( a=list(self.enable_variable_depth_type_augs__depth_type_probability.keys()), size=1, p=list(self.enable_variable_depth_type_augs__depth_type_probability.values()), )[0] use_duster_depths, clean_duster_depths = { "gt": (False, False), "duster": (True, False), "duster_cleaned": (True, True), }[depth_type] else: use_duster_depths = self.use_duster_depths clean_duster_depths = self.clean_duster_depths # Take a random number of views, if enabled all_views = sorted(list(range(len(views)))) if self.enable_variable_num_views_augs: assert self.num_views is None, "Cannot use enable_variable_num_views_augs with num_views != None." assert self.views_to_return is None, "Cannot use enable_variable_num_views_augs with views_to_return." 
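            # Sample the number of views from the configured probability table. When DUSt3R
            # depths are used, the sampled views must come from one of the precomputed
            # `supported_duster_views_sets`: pick the smallest set that still has at least
            # `num_views` views (randomly among ties) and sample the views from it; otherwise
            # sample the views uniformly from all available views.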
num_views = rnd_np.choice( a=list(self.enable_variable_num_views_augs__n_views_probability.keys()), size=1, p=list(self.enable_variable_num_views_augs__n_views_probability.values()), )[0] if use_duster_depths: num_views = min(num_views, max([len(s) for s in self.supported_duster_views_sets])) # Take only those that have the closest number of views that is greater or equal to num_views closest_num_views_in_supported_duster_views_set = min([ len(vs) for vs in self.supported_duster_views_sets if len(vs) >= num_views ]) supported_duster_views_sets = [ vs for vs in self.supported_duster_views_sets if len(vs) == closest_num_views_in_supported_duster_views_set ] duster_views = supported_duster_views_sets[rnd_np.randint(len(supported_duster_views_sets))] views_to_return = rnd_np.choice(duster_views, num_views, replace=False).tolist() else: views_to_return = rnd_np.choice(all_views, num_views, replace=False).tolist() duster_views = views_to_return else: num_views = self.num_views if self.views_to_return is not None: assert num_views == -1, "Cannot use views_to_return with num_views != -1." views_to_return = self.views_to_return elif use_duster_depths: if self.duster_views is not None: duster_views = self.duster_views else: # Take only those that have the closest number of views that is greater or equal to num_views closest_num_views_in_supported_duster_views_set = min([ len(vs) for vs in self.supported_duster_views_sets if len(vs) >= num_views ]) supported_duster_views_sets = [ vs for vs in self.supported_duster_views_sets if len(vs) == closest_num_views_in_supported_duster_views_set ] duster_views = supported_duster_views_sets[rnd_np.randint(len(supported_duster_views_sets))] views_to_return = duster_views else: if num_views == -1: # Take all views views_to_return = all_views elif num_views is None: # Randomly sample a number of views n = rnd_np.randint(min(3, len(views)), len(views) + 1) views_to_return = rnd_np.choice(all_views, n, replace=False).tolist() else: # Take a fixed number of views assert num_views > 0, "Fixed number of views must be positive." assert num_views <= len(views), f"Not enough views available (idx={index})." 
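                # Fixed number of views: sample `num_views` distinct views uniformly at random;
                # the DUSt3R view set then defaults to the returned views unless `duster_views`
                # was configured explicitly.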
views_to_return = rnd_np.choice(all_views, num_views, replace=False).tolist() if self.duster_views is not None: duster_views = self.duster_views else: duster_views = views_to_return # Extract only the data we need rgbs = np.stack([views[v]["rgba"][..., :3].numpy() for v in views_to_return]) depths = np.stack([views[v]["depth"].numpy() for v in views_to_return]) # segs = np.stack([views[v]["segmentation"].numpy() for v in views_to_return]) segs = np.ones(((rgbs.shape[0], rgbs.shape[1], rgbs.shape[2], rgbs.shape[3], 1)), dtype=np.float32) intrs = np.stack([views[v]["intrinsics"].numpy() for v in views_to_return]) intrs = intrs[:, None, :, :].repeat(rgbs.shape[1], axis=1) extrs = np.stack([views[v]["extrinsics"].numpy() for v in views_to_return]) traj2d = np.stack([views[v]["tracks_2d"].numpy() for v in views_to_return]) visibility = ~np.stack([views[v]["occlusion"].numpy() for v in views_to_return]) novel_rgbs = None novel_intrs = None novel_extrs = None if self.novel_views is not None: novel_rgbs = np.stack([views[v]["rgba"][..., :3].numpy() for v in self.novel_views]) novel_intrs = np.stack([views[v]["intrinsics"].numpy() for v in self.novel_views]) novel_intrs = novel_intrs[:, None, :, :].repeat(rgbs.shape[1], axis=1) novel_extrs = np.stack([views[v]["extrinsics"].numpy() for v in self.novel_views]) # Load Duster's features and estimated depths if they exist duster_views_str = ''.join(str(v) for v in duster_views) duster_root = pathlib.Path(self.data_root) / self.seq_names[index] / f'duster-views-{duster_views_str}' num_views, n_frames, h, w, _ = rgbs.shape feats = None feat_dim = None feat_stride = None duster_outputs_exist = duster_root.exists() and ( duster_root / f"3d_model__{n_frames - 1:05d}__scene.npz").exists() if use_duster_depths: assert duster_outputs_exist, "use_duster_depths --> duster_output_exist" if duster_outputs_exist: duster_depths = [] duster_feats = [] for frame_idx in range(n_frames): scene = np.load(duster_root / f"3d_model__{frame_idx:05d}__scene.npz") duster_depth = torch.from_numpy(scene["depths"]) duster_conf = torch.from_numpy(scene["confs"]) duster_msk = torch.from_numpy(scene["cleaned_mask"]) duster_feat = torch.from_numpy(scene["feats"]) if clean_duster_depths: ## Filter based on the confidence # conf_threshold = max(0.00001, min(0.1, torch.quantile(duster_conf.flatten(), 0.3).item())) # duster_depth = duster_depth * (duster_conf > conf_threshold) # Filter based on the mask duster_depth = duster_depth * duster_msk duster_depth = F.interpolate(duster_depth[:, None], (depths.shape[2], depths.shape[3]), mode='nearest') duster_depths.append(duster_depth[:, 0, :, :, None]) duster_feats.append(duster_feat) duster_depths = torch.stack(duster_depths, dim=1).numpy() feats = torch.stack(duster_feats, dim=1).numpy() # Extract the correct views assert duster_depths.shape[0] == feats.shape[0] == len(duster_views) duster_depths = duster_depths[[duster_views.index(v) for v in views_to_return]] feats = feats[[duster_views.index(v) for v in views_to_return]] # Reshape the features assert feats.ndim == 4 assert feats.shape[0] == num_views assert feats.shape[1] == n_frames feat_stride = np.round(np.sqrt(h * w / feats.shape[2])).astype(int) feat_dim = feats.shape[3] feats = feats.reshape(num_views, n_frames, h // feat_stride, w // feat_stride, feat_dim) # Replace the depths with the Duster depths, if configured so if use_duster_depths: depths = duster_depths start_time_2 = time.time() # Strategically select dynamic points to track visible_at_t_and_t_plus_1 = (visibility[:, :-1] 
& visibility[:, 1:]).any(0) movement = np.linalg.norm(traj3d_world[1:] - traj3d_world[:-1], axis=-1) movement[~visible_at_t_and_t_plus_1] = 0 movement = movement.sum(axis=0) assert np.isfinite(movement).all(), "Movement contains NaN or Inf values." static_threshold = 0.01 # < 1 cm dynamic_threshold = 0.1 # > 10 cm very_dynamic_threshold = 2.0 # > 2 m static_points = movement < static_threshold # 1 cm dynamic_points = movement > dynamic_threshold # 10 cm very_dynamic_points = movement > very_dynamic_threshold # 2 m if self.perform_sanity_checks: logging.info(f"Movement stats: " f"static: {static_points.sum()} ({static_points.mean() * 100:.2f}), " f"dynamic: {dynamic_points.sum()} ({dynamic_points.mean() * 100:.2f}), " f"very dynamic: {very_dynamic_points.sum()} ({very_dynamic_points.mean() * 100:.2f})" f"other: {(~static_points & ~dynamic_points & ~very_dynamic_points).sum()}") # Sample the points according to the desired ratios if possible max_tracks_to_preload = traj3d_world.shape[1] max_tracks_to_preload = min([ max_tracks_to_preload, int(dynamic_points.sum() / self.ratio_dynamic) if self.ratio_dynamic > 0 else max_tracks_to_preload, int(very_dynamic_points.sum() // self.ratio_very_dynamic) if self.ratio_very_dynamic > 0 else max_tracks_to_preload, int(static_points.sum() / (1 - self.ratio_dynamic - self.ratio_very_dynamic)), ]) if self.max_tracks_to_preload is not None: max_tracks_to_preload = min(max_tracks_to_preload, self.max_tracks_to_preload) n_dynamic = min(int(max_tracks_to_preload * self.ratio_dynamic), dynamic_points.sum()) n_very_dynamic = min(int(max_tracks_to_preload * self.ratio_very_dynamic), very_dynamic_points.sum()) n_static = max_tracks_to_preload - n_dynamic - n_very_dynamic dynamic_indices = rnd_np.choice(np.where(dynamic_points)[0], n_dynamic, replace=False) very_dynamic_indices = rnd_np.choice(np.where(very_dynamic_points)[0], n_very_dynamic, replace=False) static_indices = rnd_np.choice(np.where(static_points)[0], n_static, replace=False) selected_indices = np.concatenate([dynamic_indices, very_dynamic_indices, static_indices]) rnd_np.shuffle(selected_indices) traj3d_world = traj3d_world[:, selected_indices] traj2d = traj2d[:, :, selected_indices] visibility = visibility[:, :, selected_indices] tracks_segmentation_ids = tracks_segmentation_ids[selected_indices] if traj3d_world.shape[1] > max_tracks_to_preload: traj3d_world = traj3d_world[:, :max_tracks_to_preload] traj2d = traj2d[:, :, :max_tracks_to_preload] visibility = visibility[:, :, :max_tracks_to_preload] n_tracks = traj3d_world.shape[1] num_views, n_frames, h, w, _ = rgbs.shape assert n_frames >= self.seq_len assert rgbs.shape == (num_views, n_frames, h, w, 3) assert depths.shape == (num_views, n_frames, h, w, 1) assert segs.shape == (num_views, n_frames, h, w, 1) assert feats is None or feats.shape == (num_views, n_frames, h // feat_stride, w // feat_stride, feat_dim) assert intrs.shape == (num_views, n_frames, 3, 3) assert extrs.shape == (num_views, n_frames, 3, 4) assert traj2d.shape == (num_views, n_frames, n_tracks, 2) assert visibility.shape == (num_views, n_frames, n_tracks) assert traj3d_world.shape == (n_frames, n_tracks, 3) if novel_rgbs is not None: assert novel_rgbs.shape == (len(self.novel_views), n_frames, h, w, 3) assert novel_intrs.shape == (len(self.novel_views), n_frames, 3, 3) assert novel_extrs.shape == (len(self.novel_views), n_frames, 3, 4) if ((depths < 0.01) & (depths != 0)).mean() > 0.5: raise ValueError("Depth map might be invalid? 
Values that are too small will be ignored by SpaTracker, " "but found that more than half of non-zero depths are below 0.01 in the loaded depths.") # Make sure our intrinsics and extrinsics work correctly point_3d_world = traj3d_world point_4d_world_homo = np.concatenate([point_3d_world, np.ones_like(point_3d_world[..., :1])], axis=-1) point_3d_camera = np.einsum('ABij,BCj->ABCi', extrs, point_4d_world_homo) if self.perform_sanity_checks: point_2d_pixel_homo = np.einsum('ABij,ABCj->ABCi', intrs, point_3d_camera) point_2d_pixel = point_2d_pixel_homo[..., :2] / point_2d_pixel_homo[..., 2:] point_2d_pixel_gt = traj2d assert np.allclose(point_2d_pixel[0, :, 0, :], point_2d_pixel_gt[0, :, 0, :], atol=1e-3), f"Proj. failed" assert np.allclose(point_2d_pixel, point_2d_pixel_gt, atol=1e-3), f"Point projection failed" # Now save the z value in traj3d_camera as usual, just if needed traj3d_camera = point_3d_camera assert traj3d_camera.shape == (num_views, n_frames, n_tracks, 3) # Also sanity check that pix2cam is working correctly with the intrinsics if self.perform_sanity_checks: from mvtracker.models.core.spatracker.blocks import pix2cam xyz = np.concatenate([traj2d, traj3d_camera[..., 2:]], axis=-1) pix2cam_xyz = torch.from_numpy(xyz).double() pix2cam_intr = torch.from_numpy(intrs).double() traj_3d_repro = pix2cam(pix2cam_xyz, pix2cam_intr).numpy() assert np.allclose(traj3d_camera, traj_3d_repro, atol=0.1) # If the video is too long, randomly crop self.seq_len frames if self.seq_len < n_frames: start_ind = rnd_np.choice(n_frames - self.seq_len, 1)[0] rgbs = rgbs[:, start_ind: start_ind + self.seq_len] depths = depths[:, start_ind: start_ind + self.seq_len] segs = segs[:, start_ind: start_ind + self.seq_len] if feats is not None: feats = feats[:, start_ind: start_ind + self.seq_len] intrs = intrs[:, start_ind: start_ind + self.seq_len] extrs = extrs[:, start_ind: start_ind + self.seq_len] traj2d = traj2d[:, start_ind: start_ind + self.seq_len] visibility = visibility[:, start_ind: start_ind + self.seq_len] traj3d_camera = traj3d_camera[:, start_ind: start_ind + self.seq_len] traj3d_world = traj3d_world[start_ind: start_ind + self.seq_len] n_frames = self.seq_len # Add the z value to the traj2d traj2d_w_z = np.concatenate((traj2d[..., :], traj3d_camera[..., 2:]), axis=-1) start_time_3 = time.time() augment_this_datapoint = False if self.augmentation_probability > 0: augment_this_datapoint = rnd_np.rand() <= self.augmentation_probability if augment_this_datapoint and self.enable_rgb_augs: rgbs, visibility = self._add_photometric_augs(rgbs, traj2d_w_z, visibility, rnd_np) crop_size = self.crop_size if augment_this_datapoint and self.enable_variable_vggt_crop_size_augs: sizes = list(range(168, 518 + 14, 14)) # VIT-friendly sizes weights = np.array(sizes) ** 2 # Quadratic bias toward larger sizes probs = weights / weights.sum() shorter_side = rnd_np.choice(a=sizes, size=1, p=probs)[0] longer_side = max(crop_size) crop_size = (shorter_side, longer_side) if self.enable_cropping_augs and not self.keep_principal_point_centered: rgbs, depths, intrs, traj2d_w_z, visibility = self._add_cropping_augs( crop_size=crop_size, rgbs=rgbs, depths=depths, intrs=intrs, trajs=traj2d_w_z, visibles=visibility, ) h, w = rgbs.shape[-3:-1] if self.enable_cropping_augs and self.keep_principal_point_centered: rgbs, depths, intrs, traj2d_w_z, visibility = self._add_cropping_augs_with_pp_at_center( crop_size=crop_size, rgbs=rgbs, depths=depths, intrs=intrs, trajs=traj2d_w_z, visibles=visibility, ) h, w = rgbs.shape[-3:-1] 
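        # Side note (illustrative sketch, not part of the original pipeline): the cropping and
        # rescaling augmentations above keep the pinhole model consistent by updating the
        # intrinsics K together with the pixels -- resizing scales the focal lengths and the
        # principal point, while padding/cropping only shifts the principal point (cx, cy).
        # Minimal self-contained example with made-up numbers; all names prefixed with "_"
        # are hypothetical and local to this sketch:
        _K = np.array([[100.0, 0.0, 64.0],
                       [0.0, 100.0, 48.0],
                       [0.0, 0.0, 1.0]])
        _K_aug = _K.copy()
        _K_aug[0, :] *= 0.5   # resize width by 0.5 -> fx and cx scale
        _K_aug[1, :] *= 0.5   # resize height by 0.5 -> fy and cy scale
        _K_aug[0, 2] -= 10.0  # cropping 10 px from the left shifts cx
        _K_aug[1, 2] -= 5.0   # cropping 5 px from the top shifts cy
        assert np.allclose(_K_aug[:2, 2], [64.0 * 0.5 - 10.0, 48.0 * 0.5 - 5.0])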
depths[depths > self.max_depth] = 0.0 if augment_this_datapoint and self.enable_depth_augs: invalid_depth_mask = depths <= 0.0 depths = aug_depth( torch.from_numpy(depths).reshape(num_views * n_frames, 1, h, w), grid=(16, 16), scale=(0.99, 1.01), shift=(-0.001, 0.001), gn_kernel=(5, 5), gn_sigma=(2, 2), generator=rnd_torch, ).reshape(num_views, n_frames, h, w, 1).numpy() depths, visibility = self._rescale_and_erase_depth_patches(depths, traj2d_w_z, visibility, rnd_np) depths[invalid_depth_mask] = 0.0 # Restore invalid depths if self.depth_noise_std > 0.0: invalid_depth_mask = depths <= 0.0 noise = np.random.normal(loc=0.0, scale=self.depth_noise_std, size=depths.shape) depths = depths + noise.astype(depths.dtype) depths = np.clip(depths, 0.0, self.max_depth) depths[invalid_depth_mask] = 0.0 # Restore invalid depths rgbs = torch.from_numpy(rgbs).permute(0, 1, 4, 2, 3).float() depths = torch.from_numpy(depths).permute(0, 1, 4, 2, 3).float() segs = torch.from_numpy(segs).permute(0, 1, 4, 2, 3).float() feats = torch.from_numpy(feats).permute(0, 1, 4, 2, 3).float() if feats is not None else None intrs = torch.from_numpy(intrs).float() extrs = torch.from_numpy(extrs).float() visibility = torch.from_numpy(visibility) traj2d = torch.from_numpy(traj2d) traj2d_w_z = torch.from_numpy(traj2d_w_z) traj3d_camera = torch.from_numpy(traj3d_camera) traj3d_world = torch.from_numpy(traj3d_world) if novel_rgbs is not None: novel_rgbs = torch.from_numpy(novel_rgbs).permute(0, 1, 4, 2, 3).float() novel_intrs = torch.from_numpy(novel_intrs).float() novel_extrs = torch.from_numpy(novel_extrs).float() # Track selection cache_root = os.path.join(self.data_root, self.seq_names[index], "cache") os.makedirs(cache_root, exist_ok=True) cache_file = os.path.join(cache_root, f"{self.cache_name}.npz") # Check if we can use cached tracks use_cache = bool(self.use_cached_tracks) and os.path.isfile(cache_file) if use_cache: cache = np.load(cache_file) visible_inds_sampled = torch.from_numpy(cache["track_indices"]) traj2d_w_z = torch.from_numpy(cache["traj2d_w_z"]) traj3d_world = torch.from_numpy(cache["traj3d_world"]) visibility = torch.from_numpy(cache["visibility"]) valids = torch.from_numpy(cache["valids"]) query_points = torch.from_numpy(cache["query_points"]) # Otherwise, sample the tracks and create query points else: # Sample the points to track visibile_pts_first_frame_inds = (visibility.any(0)[0]).nonzero(as_tuple=False)[:, 0] if self.sample_vis_1st_frame: visibile_pts_inds = visibile_pts_first_frame_inds else: visibile_pts_mid_frame_inds = (visibility.any(0)[self.seq_len // 2]).nonzero(as_tuple=False)[:, 0] visibile_pts_inds = torch.cat((visibile_pts_first_frame_inds, visibile_pts_mid_frame_inds), dim=0) visibile_pts_inds = torch.unique(visibile_pts_inds) visible_for_at_least_two_frames = (visibility.any(0).sum(0) >= 2).nonzero(as_tuple=False)[:, 0] visibile_pts_inds = visibile_pts_inds[torch.isin(visibile_pts_inds, visible_for_at_least_two_frames)] point_inds = torch.randperm(len(visibile_pts_inds), generator=rnd_torch) traj_per_sample = self.traj_per_sample if self.traj_per_sample is not None else len(point_inds) if self.enable_variable_num_views_augs: adj_factor = self.enable_variable_num_views_augs__trajpersample_adjustment_factor.get(num_views, 1.0) traj_per_sample = int(traj_per_sample * adj_factor) if len(point_inds) == 0 or len(point_inds) < traj_per_sample // 4: gotit = False return None, gotit if augment_this_datapoint and self.enable_variable_trajpersample_augs: if index % 20 == 0: traj_per_sample = 
traj_per_sample // 8 elif index % 21 == 0: pass # keep the same number of trajectories else: low = max(1, traj_per_sample // 4) high = min(len(point_inds), traj_per_sample) + 1 traj_per_sample = torch.randint(low=low, high=high, size=(1,), generator=rnd_torch).item() else: traj_per_sample = min(len(point_inds), traj_per_sample) point_inds = point_inds[:traj_per_sample] logging.info( f"[i={index:04d};seq={self.seq_names[index]};seed={seed}]" f"Selected {len(point_inds)}/{len(visibile_pts_inds)} tracks. " f"{num_views=}. " f"{point_inds[0]=} max_depth={self.max_depth}." ) visible_inds_sampled = visibile_pts_inds[point_inds] n_tracks = len(visible_inds_sampled) traj2d = traj2d[:, :, visible_inds_sampled].float() traj2d_w_z = traj2d_w_z[:, :, visible_inds_sampled].float() traj3d_camera = traj3d_camera[:, :, visible_inds_sampled].float() traj3d_world = traj3d_world[:, visible_inds_sampled].float() visibility = visibility[:, :, visible_inds_sampled] valids = torch.ones((n_frames, n_tracks)) # Create the query points gt_visibilities_any_view = visibility.any(dim=0) assert (gt_visibilities_any_view.sum(dim=0) >= 2).all(), "All points should be visible in least two frames." last_visible_index = (torch.arange(n_frames).unsqueeze(-1) * gt_visibilities_any_view).max(0).values assert gt_visibilities_any_view[last_visible_index[None, :], torch.arange(n_tracks)].all() gt_visibilities_any_view[last_visible_index[None, :], torch.arange(n_tracks)] = False assert (gt_visibilities_any_view.sum(dim=0) >= 1).all() if self.sample_vis_1st_frame: n_non_first_point_appearance_queries = 0 n_first_point_appearance_queries = n_tracks else: n_non_first_point_appearance_queries = n_tracks // 4 n_first_point_appearance_queries = n_tracks - n_non_first_point_appearance_queries first_point_appearances = torch.argmax( gt_visibilities_any_view[..., -n_first_point_appearance_queries:].float(), dim=0) non_first_point_appearances = first_point_appearances.new_zeros((n_non_first_point_appearance_queries,)) for track_idx in range(n_tracks)[:n_non_first_point_appearance_queries]: # Randomly take a timestep where the point is visible non_zero_timesteps = torch.nonzero(gt_visibilities_any_view[:, track_idx] == 1) random_timestep = non_zero_timesteps[rnd_np.randint(len(non_zero_timesteps))].item() non_first_point_appearances[track_idx] = random_timestep query_points_t = torch.cat([non_first_point_appearances, first_point_appearances], dim=0) query_points_xyz_worldspace = traj3d_world[query_points_t, torch.arange(n_tracks)] query_points = torch.cat([query_points_t[:, None], query_points_xyz_worldspace], dim=1) assert gt_visibilities_any_view[query_points_t, torch.arange(n_tracks)].all() # Cache the selected tracks and query points if self.use_cached_tracks: logging.warn(f"Caching tracks for {self.seq_names[index]} at {os.path.abspath(cache_file)}") np.savez_compressed( cache_file, track_indices=visible_inds_sampled.numpy(), traj2d_w_z=traj2d_w_z.numpy(), traj3d_world=traj3d_world.numpy(), visibility=visibility.numpy(), valids=valids.numpy(), query_points=query_points.numpy(), ) # Apply a transform to the world space scale = 1.0 rot = torch.eye(3, dtype=torch.float32) translation = torch.tensor([0.0, 0.0, 0.0], dtype=torch.float32) if self.enable_scene_transform_augs: rot_x_angle = rnd_np.uniform(-15, 15) rot_y_angle = rnd_np.uniform(-15, 15) rot_z_angle = 0.0 scale = rnd_np.uniform(0.8, 1.5) translate_x = rnd_np.uniform(-2, 2) translate_y = rnd_np.uniform(-2, 2) translate_z = rnd_np.uniform(-2, 2) rot_x = R.from_euler('x', 
rot_x_angle, degrees=True).as_matrix() rot_y = R.from_euler('y', rot_y_angle, degrees=True).as_matrix() rot_z = R.from_euler('z', rot_z_angle, degrees=True).as_matrix() rot = rot_z @ rot_y @ rot_x T_rot = torch.eye(4) T_rot[:3, :3] = torch.from_numpy(rot) T_scale_and_translate = torch.tensor([ [scale, 0.0, 0.0, translate_x], [0.0, scale, 0.0, translate_y], [0.0, 0.0, scale, translate_z], [0.0, 0.0, 0.0, 1.0], ], dtype=torch.float32) T = T_scale_and_translate @ T_rot if self.normalize_scene_following_vggt: assert not self.enable_scene_transform_augs, "Cannot normalize scene with scene transform augs enabled." extrs_square = torch.eye(4, device=extrs.device)[None, None].repeat(num_views, n_frames, 1, 1) extrs_square[:, :, :3, :] = extrs extrs_inv = torch.inverse(extrs_square) intrs_inv = torch.inverse(intrs) y, x = torch.meshgrid( torch.arange(h, device=extrs.device), torch.arange(w, device=extrs.device), indexing="ij", ) homog = torch.stack([x, y, torch.ones_like(x)], dim=-1).float().reshape(-1, 3) homog = homog[None].expand(num_views, -1, -1) cam_points = torch.einsum("Vij, VNj->VNi", intrs_inv[:, 0], homog) * depths[:, 0].reshape(num_views, -1, 1) cam_points_h = torch.cat([cam_points, torch.ones_like(cam_points[..., :1])], dim=-1) world_points_h = torch.einsum("Vij, VNj->VNi", extrs_inv[:, 0], cam_points_h) world_points_in_first = torch.einsum("ij, VNj->VNi", extrs[0, 0], world_points_h) mask = (depths[:, 0] > 0).reshape(num_views, -1) valid_points = world_points_in_first[mask] avg_dist = valid_points.norm(dim=1).mean() scale = 1.0 / avg_dist depths *= scale traj3d_world *= scale traj3d_camera *= scale traj2d_w_z[..., 2] *= scale extrs[:, :, :3, 3] *= scale T_first_cam_to_origin = torch.eye(4, device=extrs.device) T_first_cam_to_origin[:3, :4] = extrs[0, 0] T = T_first_cam_to_origin ( depths_trans, extrs_trans, query_points_trans, traj3d_world_trans, traj2d_w_z_trans ) = transform_scene(scale, rot, translation, depths, extrs, query_points, traj3d_world, traj2d_w_z) novel_extrs_trans = transform_scene(scale, rot, translation, None, novel_extrs, None, None, None)[1] if self.enable_camera_params_noise_augs: intrs, extrs_trans = add_camera_noise( intrs=intrs.numpy(), extrs=extrs_trans.numpy(), noise_std_intr=0.001, noise_std_extr=0.001, rnd=rnd_np, ) intrs = torch.from_numpy(intrs) extrs_trans = torch.from_numpy(extrs_trans) # Dump non-normalized tracks to disk if self.augmentation_probability == 0.0 and not self.enable_variable_trajpersample_augs and seed is not None: num_views_str = self.num_views if self.num_views is not None else "none" views_str = ''.join(str(v) for v in self.views_to_return) if self.views_to_return is not None else "none" duster_views_str = ''.join(str(v) for v in self.duster_views) if self.duster_views is not None else "none" sample_identifier_str = ( f"seed-{seed:06d}" f"_tracks-{self.traj_per_sample}" f"_use-duster-depths-{self.use_duster_depths}" f"_clean-duster-depths-{self.clean_duster_depths}" f"_num-views-{num_views_str}" f"_views-{views_str}" f"_duster-views-{duster_views_str}" f"_ratio-dynamic-{self.ratio_dynamic}" f"_ratio-very-dynamic-{self.ratio_very_dynamic}" f"_aug-prob-{self.augmentation_probability}" f"_max-tracks-to-preload-{self.max_tracks_to_preload}" ) datapoint_path = os.path.join(self.data_root, self.seq_names[index]) dumped_path = os.path.join(datapoint_path, f"{sample_identifier_str}.npz") # if not os.path.exists(dumped_path): # logging.info(f"Dumping {dumped_path}") # np.savez( # dumped_path, # trajectories=traj3d_world.numpy(), # 
trajectories_pixelspace=traj2d.numpy(), # per_view_visibilities=visibility.numpy(), # query_points_3d=query_points.numpy(), # extrinsics=extrs.numpy(), # intrinsics=intrs.numpy(), # transform_that_would_have_been_applied=T, # ) datapoint = Datapoint( video=rgbs, videodepth=depths_trans, feats=feats, segmentation=segs, trajectory=traj2d_w_z_trans, trajectory_3d=traj3d_world_trans, visibility=visibility, valid=valids, seq_name=self.seq_names[index], intrs=intrs, extrs=extrs_trans, query_points=None, query_points_3d=query_points_trans, track_upscaling_factor=1 / scale, novel_video=novel_rgbs, novel_intrs=novel_intrs, novel_extrs=novel_extrs_trans, ) # Log timings start_time_4 = time.time() self.getitem_calls += 1 top_duration = start_time_2 - start_time_1 middle_duration = start_time_3 - start_time_2 bottom_duration = start_time_4 - start_time_3 total_duration = start_time_4 - start_time_1 logging.info(f"Loading {index:>06d} took {total_duration:>7.3f}s " f"[top:{top_duration:>7.3f}s, middle:{middle_duration:>7.3f}s, bottom:{bottom_duration:>7.3f}s] " f"Getitem calls: {self.getitem_calls:>6d}. " f"n_views={num_views}, {n_tracks=:>4d}, augmented={int(augment_this_datapoint)} {rgbs.shape=}") min_valid_depth_ratio_threshold = 0.1 valid_depth_ratio = (depths > 0).float().mean() if valid_depth_ratio < min_valid_depth_ratio_threshold: logging.warning(f"Skipping datapoint {index} due to too little valid depth values: " f"{valid_depth_ratio * 100:.1f}% (< {min_valid_depth_ratio_threshold * 100:.1f}%)") return None, False return datapoint, gotit @staticmethod def getitem_raw_datapoint(scene_path, perform_2d_projection_sanity_check=True): # Load global scene data tracks_3d = torch.from_numpy( np.load(os.path.join(scene_path, 'tracks_3d.npz'))['tracks_3d'], ) tracks_segmentation_ids = torch.from_numpy( np.load(os.path.join(scene_path, 'tracks_segmentation_ids.npz'))['tracks_segmentation_ids'], ) tracked_objects = read_json(os.path.join(scene_path, 'tracked_objects.json')) if os.path.exists(os.path.join(scene_path, 'views.npz')): # V2 (lookat fixed to 0) camera_positions = torch.from_numpy(np.load(os.path.join(scene_path, 'views.npz'))['views']) lookat_positions = 0. 
* camera_positions elif os.path.exists(os.path.join(scene_path, 'cameras.npz')): # V3 (with randomized lookat) camera_positions = torch.from_numpy(np.load(os.path.join(scene_path, 'cameras.npz'))['camera_positions']) lookat_positions = torch.from_numpy(np.load(os.path.join(scene_path, 'cameras.npz'))['lookat_positions']) else: raise ValueError("No camera data found: neither views.npz nor cameras.npz exist.") n_frames = tracks_3d.shape[0] n_tracks = tracks_3d.shape[1] n_views = camera_positions.shape[0] assert tracks_3d.shape == (n_frames, n_tracks, 3) assert tracks_segmentation_ids.shape == (n_tracks,) assert camera_positions.shape == (n_views, 3) assert lookat_positions.shape == (n_views, 3) # Initialize views data views_data = [] view_folders = [ d for d in os.listdir(scene_path) if os.path.isdir(os.path.join(scene_path, d)) and d.startswith('view_') ] view_folders = sorted(view_folders, key=lambda x: int(x.split('_')[-1])) for view_folder in view_folders: view_path = os.path.join(scene_path, view_folder) # Load per-view data view_data = { 'rgba': [], 'depth': [], # 'segmentation': [], } frame_files = sorted(os.listdir(view_path)) for frame_file in frame_files: if frame_file.startswith('rgba_'): view_data['rgba'].append(read_png(os.path.join(view_path, frame_file))) elif frame_file.startswith('depth_'): view_data['depth'].append(read_tiff(os.path.join(view_path, frame_file))) # elif frame_file.startswith('segmentation_'): # view_data['segmentation'].append(read_png(os.path.join(view_path, frame_file))) assert len(view_data['rgba']) == n_frames, f"{len(view_data['rgba'])}!={n_frames}" assert len(view_data['depth']) == n_frames, f"{len(view_data['depth'])}!={n_frames}" # assert len(view_data['segmentation']) == n_frames, f"{len(view_data['segmentation'])}!={n_frames}" # Convert lists to torch tensors for key in view_data: if view_data[key][0].dtype == np.uint16: view_data[key] = [a.astype(np.int32) for a in view_data[key]] view_data[key] = torch.stack([torch.from_numpy(np.array(img)) for img in view_data[key]]) # Load additional per-view data view_data.update({ 'tracks_2d': torch.from_numpy(np.load(os.path.join(view_path, 'tracks_2d.npz'))['tracks_2d']), 'occlusion': torch.from_numpy(np.load(os.path.join(view_path, 'tracks_2d.npz'))['occlusion']), 'data_ranges': "NOT LOADED", # read_json(os.path.join(view_path, 'data_ranges.json')), 'metadata': read_json(os.path.join(view_path, 'metadata.json')), 'events': "NOT LOADED", # read_json(os.path.join(view_path, 'events.json')), 'object_id_to_segmentation_id': read_json(os.path.join(view_path, 'object_id_to_segmentation_id.json')), }) # Extracting the intrinsics view_data['intrinsics'] = torch.tensor(view_data['metadata']['camera']['K'], dtype=torch.float64) assert view_data['intrinsics'].shape == (3, 3) # Extracting the extrinsics positions = torch.tensor(view_data['metadata']['camera']['positions'], dtype=torch.float64) quaternions = torch.tensor(view_data['metadata']['camera']['quaternions'], dtype=torch.float64) rotation_matrices = kornia.geometry.quaternion_to_rotation_matrix(quaternions) assert positions.shape == (n_frames, 3) assert quaternions.shape == (n_frames, 4) assert rotation_matrices.shape == (n_frames, 3, 3) extrinsics_inv = torch.zeros((n_frames, 4, 4), dtype=torch.float64) extrinsics_inv[:, :3, :3] = rotation_matrices extrinsics_inv[:, :3, 3] = positions extrinsics_inv[:, 3, 3] = 1 view_data['extrinsics'] = extrinsics_inv.inverse() assert torch.allclose(view_data['extrinsics'][:, 3, :3], torch.zeros(n_frames, 3, 
dtype=torch.float64)) assert torch.allclose(view_data['extrinsics'][:, 3, 3], torch.ones(n_frames, dtype=torch.float64)) view_data['extrinsics'] = view_data['extrinsics'][:, :3, :] # Change the intrinsics to the format w, h = view_data["metadata"]["metadata"]["resolution"] view_data['intrinsics'] = np.diag([w, h, 1]) @ view_data['intrinsics'].numpy() @ np.diag([1, -1, -1]) view_data['extrinsics'] = np.diag([1, -1, -1]) @ view_data['extrinsics'].numpy() view_data['intrinsics'] = torch.from_numpy(view_data['intrinsics']) view_data['extrinsics'] = torch.from_numpy(view_data['extrinsics']) # Project one point to the image plane to check if the extrinsics are correct if perform_2d_projection_sanity_check: point_3d_world = tracks_3d[0, 0] point_4d_world_homo = torch.cat([point_3d_world, torch.ones(1)]) point_2d_pixel = view_data['intrinsics'] @ view_data['extrinsics'][0] @ point_4d_world_homo point_2d_pixel = point_2d_pixel[:2] / point_2d_pixel[2] point_2d_pixel_gt = view_data["tracks_2d"][0, 0] assert torch.allclose(point_2d_pixel, point_2d_pixel_gt, atol=1e-3), f"Point projection failed" # The original depth is the euclidean distance from the camera # Compute the depth in z format instead (so the z coordinate in the camera space) view_data['depth'] = KubricMultiViewDataset.depth_from_euclidean_to_z( depth=view_data['depth'], sensor_width=view_data['metadata']['camera']['sensor_width'], focal_length=view_data['metadata']['camera']['focal_length'], ) # Sometimes the Kubric depths contains very high values of 10e9 # We will clip those to 10e3 to avoid problems with inf and nan larger_than_1000 = view_data['depth'] > 1000 if larger_than_1000.any(): logging.info(f"Datapoint {scene_path} has depths larger than 1000: " f"{view_data['depth'][larger_than_1000]}. 
" f"Replacing those by 0 to denote invalid depth and avoid inf and nan values later.") view_data['depth'][larger_than_1000] = 0 view_data['view_path'] = view_path views_data.append(view_data) datapoint = { "tracks_3d": tracks_3d, "tracks_segmentation_ids": tracks_segmentation_ids, "tracked_objects": tracked_objects, "camera_positions": camera_positions, "lookat_positions": lookat_positions, "views": views_data } return datapoint @staticmethod def depth_from_euclidean_to_z(depth, sensor_width, focal_length): n_frames, h, w, _ = depth.shape sensor_height = sensor_width / w * h pixel_centers_x = (np.arange(-w / 2, w / 2, dtype=np.float32) + 0.5) / w * sensor_width pixel_centers_y = (np.arange(-h / 2, h / 2, dtype=np.float32) + 0.5) / h * sensor_height # Calculate squared distance from the center of the image pixel_centers_x, pixel_centers_y = np.meshgrid(pixel_centers_x, pixel_centers_y, indexing="xy") squared_distance_from_center = np.square(pixel_centers_x) + np.square(pixel_centers_y) # Calculate rescaling factor for each pixel z_to_eucl_rescaling = np.sqrt(1 + squared_distance_from_center / focal_length ** 2) # Apply the rescaling to each depth value z_to_eucl_rescaling = np.expand_dims(z_to_eucl_rescaling, axis=-1) # Add a dimension for broadcasting depth_z = depth / z_to_eucl_rescaling return depth_z def _add_photometric_augs( self, rgbs, trajs, visibles, rndstate, eraser=True, replace=True, ): V, T, H, W, _ = rgbs.shape _, _, N, _ = trajs.shape assert rgbs.dtype == np.uint8 assert rgbs.shape == (V, T, H, W, 3) assert trajs.shape == (V, T, N, 3) assert visibles.shape == (V, T, N) rgbs = rgbs.copy() visibles = visibles.copy() if eraser: # eraser the specific region in the image for v in range(V): rgbs_view = rgbs[v] rgbs_view = [rgb.astype(np.float32) for rgb in rgbs_view] ############ eraser transform (per image after the first) ############ for i in range(1, T): if rndstate.rand() < self.eraser_aug_prob: for _ in range( rndstate.randint(1, self.eraser_max + 1) ): # number of times to occlude xc = rndstate.randint(0, W) yc = rndstate.randint(0, H) dx = rndstate.randint(self.eraser_bounds[0], self.eraser_bounds[1]) dy = rndstate.randint(self.eraser_bounds[0], self.eraser_bounds[1]) x0 = np.clip(xc - dx / 2, 0, W - 1).round().astype(np.int32) x1 = np.clip(xc + dx / 2, 0, W - 1).round().astype(np.int32) y0 = np.clip(yc - dy / 2, 0, H - 1).round().astype(np.int32) y1 = np.clip(yc + dy / 2, 0, H - 1).round().astype(np.int32) mean_color = np.mean(rgbs_view[i][y0:y1, x0:x1, :].reshape(-1, 3), axis=0) rgbs_view[i][y0:y1, x0:x1, :] = mean_color occ_inds = np.logical_and( np.logical_and(trajs[v, i, :, 0] >= x0, trajs[v, i, :, 0] < x1), np.logical_and(trajs[v, i, :, 1] >= y0, trajs[v, i, :, 1] < y1), ) visibles[v, i, occ_inds] = 0 rgbs_view = [rgb.astype(np.uint8) for rgb in rgbs_view] rgbs[v] = np.stack(rgbs_view) if replace: for v in range(V): rgbs_view = rgbs[v] rgbs_view_alt = [ np.array(self.photo_aug(Image.fromarray(rgb)), dtype=np.uint8) for rgb in rgbs_view ] rgbs_view_alt = [ np.array(self.photo_aug(Image.fromarray(rgb)), dtype=np.uint8) for rgb in rgbs_view_alt ] ############ replace transform (per image after the first) ############ rgbs_view = [rgb.astype(np.float32) for rgb in rgbs_view] rgbs_view_alt = [rgb.astype(np.float32) for rgb in rgbs_view_alt] for i in range(1, T): if rndstate.rand() < self.replace_aug_prob: for _ in range( rndstate.randint(1, self.replace_max + 1) ): # number of times to occlude xc = rndstate.randint(0, W) yc = rndstate.randint(0, H) dx = 
rndstate.randint(self.replace_bounds[0], self.replace_bounds[1]) dy = rndstate.randint(self.replace_bounds[0], self.replace_bounds[1]) x0 = np.clip(xc - dx / 2, 0, W - 1).round().astype(np.int32) x1 = np.clip(xc + dx / 2, 0, W - 1).round().astype(np.int32) y0 = np.clip(yc - dy / 2, 0, H - 1).round().astype(np.int32) y1 = np.clip(yc + dy / 2, 0, H - 1).round().astype(np.int32) wid = x1 - x0 hei = y1 - y0 y00 = rndstate.randint(0, H - hei) x00 = rndstate.randint(0, W - wid) fr = rndstate.randint(0, T) rep = rgbs_view_alt[fr][y00: y00 + hei, x00: x00 + wid, :] rgbs_view[i][y0:y1, x0:x1, :] = rep occ_inds = np.logical_and( np.logical_and(trajs[v, i, :, 0] >= x0, trajs[v, i, :, 0] < x1), np.logical_and(trajs[v, i, :, 1] >= y0, trajs[v, i, :, 1] < y1), ) visibles[v, i, occ_inds] = 0 rgbs_view = [rgb.astype(np.uint8) for rgb in rgbs_view] rgbs[v] = np.stack(rgbs_view) ############ photometric augmentation ############ if rndstate.rand() < self.color_aug_prob: # random per-frame amount of aug # but shared across all views for i in range(T): fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor = self.photo_aug.get_params( self.photo_aug.brightness, self.photo_aug.contrast, self.photo_aug.saturation, self.photo_aug.hue ) for v in range(V): rgb = rgbs[v, i] rgb = Image.fromarray(rgb) for fn_id in fn_idx: if fn_id == 0 and brightness_factor is not None: rgb = F_torchvision.adjust_brightness(rgb, brightness_factor) elif fn_id == 1 and contrast_factor is not None: rgb = F_torchvision.adjust_contrast(rgb, contrast_factor) elif fn_id == 2 and saturation_factor is not None: rgb = F_torchvision.adjust_saturation(rgb, saturation_factor) elif fn_id == 3 and hue_factor is not None: rgb = F_torchvision.adjust_hue(rgb, hue_factor) rgb = np.array(rgb, dtype=np.uint8) rgbs[v, i] = rgb if rndstate.rand() < self.blur_aug_prob: # random per-frame amount of blur # but shared across all views for i in range(T): sigma = self.blur_aug.get_params(self.blur_aug.sigma[0], self.blur_aug.sigma[1]) for v in range(V): rgb = rgbs[v, i] rgb = Image.fromarray(rgb) F_torchvision.gaussian_blur(rgb, self.blur_aug.kernel_size, [sigma, sigma]) rgb = np.array(rgb, dtype=np.uint8) rgbs[v, i] = rgb return rgbs, visibles def _add_cropping_augs(self, crop_size, rgbs, depths, intrs, trajs, visibles): V, T, H, W, _ = rgbs.shape _, _, N, _ = trajs.shape assert rgbs.dtype == np.uint8 assert depths.dtype == np.float32 assert rgbs.shape == (V, T, H, W, 3) assert depths.shape == (V, T, H, W, 1) assert intrs.shape == (V, T, 3, 3) assert trajs.shape == (V, T, N, 3) assert visibles.shape == (V, T, N) rgbs = rgbs.copy() depths = depths.copy() intrs = intrs.copy() trajs = trajs.copy() visibles = visibles.copy() ############ spatial transform ############ rgbs_new = np.zeros((V, T, crop_size[0], crop_size[1], 3), dtype=np.uint8) depths_new = np.zeros((V, T, crop_size[0], crop_size[1], 1), dtype=np.float32) for v in range(V): # padding pad_x0 = np.random.randint(self.pad_bounds[0], self.pad_bounds[1]) pad_x1 = np.random.randint(self.pad_bounds[0], self.pad_bounds[1]) pad_y0 = np.random.randint(self.pad_bounds[0], self.pad_bounds[1]) pad_y1 = np.random.randint(self.pad_bounds[0], self.pad_bounds[1]) rgbs_view = [np.pad(rgb, ((pad_y0, pad_y1), (pad_x0, pad_x1), (0, 0))) for rgb in rgbs[v]] depths_view = [np.pad(depth, ((pad_y0, pad_y1), (pad_x0, pad_x1), (0, 0))) for depth in depths[v]] intrs[v, :, 0, 2] += pad_x0 intrs[v, :, 1, 2] += pad_y0 trajs[v, :, :, 0] += pad_x0 trajs[v, :, :, 1] += pad_y0 H_padded, W_padded = 
rgbs_view[0].shape[:2] # scaling + stretching scale = np.random.uniform(self.resize_lim[0], self.resize_lim[1]) scale_x = scale scale_y = scale scale_delta_x = 0.0 scale_delta_y = 0.0 for t in range(T): if t == 1: scale_delta_x = np.random.uniform(-self.resize_delta, self.resize_delta) scale_delta_y = np.random.uniform(-self.resize_delta, self.resize_delta) elif t > 1: scale_delta_x = ( scale_delta_x * 0.8 + np.random.uniform(-self.resize_delta, self.resize_delta) * 0.2 ) scale_delta_y = ( scale_delta_y * 0.8 + np.random.uniform(-self.resize_delta, self.resize_delta) * 0.2 ) scale_x = scale_x + scale_delta_x scale_y = scale_y + scale_delta_y # bring h/w closer scale_xy = (scale_x + scale_y) * 0.5 scale_x = scale_x * 0.5 + scale_xy * 0.5 scale_y = scale_y * 0.5 + scale_xy * 0.5 # don't get too crazy scale_x = np.clip(scale_x, self.resize_lim[0], self.resize_lim[1]) scale_y = np.clip(scale_y, self.resize_lim[0], self.resize_lim[1]) H_new = int(H_padded * scale_y) W_new = int(W_padded * scale_x) # make it at least slightly bigger than the crop area, # so that the random cropping can add diversity H_new = np.clip(H_new, crop_size[0] + 10, None) W_new = np.clip(W_new, crop_size[1] + 10, None) # recompute scale in case we clipped scale_x = (W_new - 1) / float(W_padded - 1) scale_y = (H_new - 1) / float(H_padded - 1) rgbs_view[t] = cv2.resize(rgbs_view[t], (W_new, H_new), interpolation=cv2.INTER_LINEAR) depths_view[t] = cv2.resize(depths_view[t], (W_new, H_new), interpolation=cv2.INTER_NEAREST) intrs[v, t, 0, :] *= scale_x intrs[v, t, 1, :] *= scale_y trajs[v, t, :, 0] *= scale_x trajs[v, t, :, 1] *= scale_y ok_inds = visibles[v, 0, :] > 0 vis_trajs = trajs[v, :, ok_inds] # S,?,2 if vis_trajs.shape[0] > 0: mid_x = np.mean(vis_trajs[:, 0, 0]) mid_y = np.mean(vis_trajs[:, 0, 1]) else: mid_y = crop_size[0] // 2 mid_x = crop_size[1] // 2 x0 = int(mid_x - crop_size[1] // 2) y0 = int(mid_y - crop_size[0] // 2) offset_x = 0 offset_y = 0 for t in range(T): # on each frame, shift a bit more if t == 1: offset_x = np.random.randint(-self.max_crop_offset, self.max_crop_offset + 1) offset_y = np.random.randint(-self.max_crop_offset, self.max_crop_offset + 1) elif t > 1: offset_x = int( offset_x * 0.8 + np.random.randint(-self.max_crop_offset, self.max_crop_offset + 1) * 0.2 ) offset_y = int( offset_y * 0.8 + np.random.randint(-self.max_crop_offset, self.max_crop_offset + 1) * 0.2 ) x0 = x0 + offset_x y0 = y0 + offset_y H_new, W_new = rgbs_view[t].shape[:2] if H_new == crop_size[0]: y0 = 0 else: y0 = min(max(0, y0), H_new - crop_size[0] - 1) if W_new == crop_size[1]: x0 = 0 else: x0 = min(max(0, x0), W_new - crop_size[1] - 1) rgbs_view[t] = rgbs_view[t][y0: y0 + crop_size[0], x0: x0 + crop_size[1]] depths_view[t] = depths_view[t][y0: y0 + crop_size[0], x0: x0 + crop_size[1]] intrs[v, t, 0, 2] -= x0 intrs[v, t, 1, 2] -= y0 trajs[v, t, :, 0] -= x0 trajs[v, t, :, 1] -= y0 H_new = crop_size[0] W_new = crop_size[1] # # h flip # if self.do_flip and np.random.rand() < self.h_flip_prob: # rgbs_view = [rgb[:, ::-1] for rgb in rgbs_view] # depths_view = [depth[:, ::-1] for depth in depths_view] # intrs[v, :, 0, 2] = W_new - intrs[v, :, 0, 2] # trajs[v, :, :, 0] = W_new - trajs[v, :, :, 0] # # # v flip # if np.random.rand() < self.v_flip_prob: # rgbs_view = [rgb[::-1] for rgb in rgbs_view] # depths_view = [depth[::-1] for depth in depths_view] # intrs[v, :, 1, 2] = H_new - intrs[v, :, 1, 2] # trajs[v, :, :, 1] = H_new - trajs[v, :, :, 1] rgbs_new[v] = np.stack(rgbs_view) depths_new[v] = np.stack(depths_view)[..., None] 
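        # Side note (illustrative sketch, not part of the original augmentation): the per-frame
        # scale and crop-offset jitter used in the loop above is a smoothed random walk,
        # value_t = 0.8 * value_{t-1} + 0.2 * noise_t, so the virtual crop drifts smoothly over
        # time instead of jumping independently at every frame. Tiny standalone example with a
        # fixed RNG; all names prefixed with "_" are local to this sketch:
        _rng = np.random.RandomState(0)
        _jitter, _trace = 0.0, []
        for _ in range(5):
            _jitter = _jitter * 0.8 + _rng.uniform(-1.0, 1.0) * 0.2
            _trace.append(_jitter)
        assert all(abs(_j) <= 1.0 for _j in _trace)  # bounded by the noise range (fixed point of x = 0.8x + 0.2)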
visibles = (visibles & (trajs[..., 0] >= 0) & (trajs[..., 1] >= 0) & (trajs[..., 0] < crop_size[1]) & (trajs[..., 1] < crop_size[0])) return rgbs_new, depths_new, intrs, trajs, visibles def _add_cropping_augs_with_pp_at_center(self, crop_size, rgbs, depths, intrs, trajs, visibles): V, T, H, W, _ = rgbs.shape _, _, N, _ = trajs.shape assert rgbs.dtype == np.uint8 assert depths.dtype == np.float32 assert rgbs.shape == (V, T, H, W, 3) assert depths.shape == (V, T, H, W, 1) assert intrs.shape == (V, T, 3, 3) assert trajs.shape == (V, T, N, 3) assert visibles.shape == (V, T, N) rgbs = rgbs.copy() depths = depths.copy() intrs = intrs.copy() trajs = trajs.copy() visibles = visibles.copy() rgbs_new = np.zeros((V, T, crop_size[0], crop_size[1], 3), dtype=np.uint8) depths_new = np.zeros((V, T, crop_size[0], crop_size[1], 1), dtype=np.float32) for v in range(V): pad_x0 = pad_x1 = np.random.randint(self.pad_bounds[0], self.pad_bounds[1]) pad_y0 = pad_y1 = np.random.randint(self.pad_bounds[0], self.pad_bounds[1]) rgbs_view = [np.pad(rgb, ((pad_y0, pad_y1), (pad_x0, pad_x1), (0, 0))) for rgb in rgbs[v]] depths_view = [np.pad(depth, ((pad_y0, pad_y1), (pad_x0, pad_x1), (0, 0))) for depth in depths[v]] intrs[v, :, 0, 2] += pad_x0 intrs[v, :, 1, 2] += pad_y0 trajs[v, :, :, 0] += pad_x0 trajs[v, :, :, 1] += pad_y0 H_padded, W_padded = rgbs_view[0].shape[:2] scale_x = np.random.uniform(self.resize_lim[0], self.resize_lim[1]) scale_y = scale_x + np.random.uniform(-0.01, 0.01) scale_y = max(self.resize_lim[0], min(self.resize_lim[1], scale_y)) H_new = max(int(H_padded * scale_y) + int(H_padded * scale_y) % 2, crop_size[0] + 10) W_new = max(int(W_padded * scale_x) + int(W_padded * scale_x) % 2, crop_size[1] + 10) scale_x = W_new / W_padded scale_y = H_new / H_padded for t in range(T): rgbs_view[t] = cv2.resize(rgbs_view[t], (W_new, H_new), interpolation=cv2.INTER_LINEAR) depths_view[t] = cv2.resize(depths_view[t], (W_new, H_new), interpolation=cv2.INTER_NEAREST) intrs[v, :, 0, :] *= scale_x intrs[v, :, 1, :] *= scale_y trajs[v, :, :, 0] *= scale_x trajs[v, :, :, 1] *= scale_y for t in range(T): cx = intrs[v, t, 0, 2] cy = intrs[v, t, 1, 2] x0 = round(cx - crop_size[1] / 2) y0 = round(cy - crop_size[0] / 2) H_new, W_new = rgbs_view[t].shape[:2] assert x0 >= 0 assert y0 >= 0 assert (H_new - crop_size[0]) >= 0 assert (W_new - crop_size[1]) >= 0 assert (H_new - crop_size[0]) >= y0 assert (W_new - crop_size[1]) >= x0 rgbs_view[t] = rgbs_view[t][y0:y0 + crop_size[0], x0:x0 + crop_size[1]] depths_view[t] = depths_view[t][y0:y0 + crop_size[0], x0:x0 + crop_size[1]] intrs[v, t, 0, 2] -= x0 intrs[v, t, 1, 2] -= y0 trajs[v, t, :, 0] -= x0 trajs[v, t, :, 1] -= y0 # Assert principal point is centered assert rgbs_view[t].shape[0] == crop_size[0] assert rgbs_view[t].shape[1] == crop_size[1] assert np.allclose(intrs[v, t, 0, 2], crop_size[1] / 2, atol=0.01) assert np.allclose(intrs[v, t, 1, 2], crop_size[0] / 2, atol=0.01) rgbs_new[v] = np.stack(rgbs_view) depths_new[v] = np.stack(depths_view)[..., None] visibles = (visibles & (trajs[..., 0] >= 0) & (trajs[..., 1] >= 0) & (trajs[..., 0] < crop_size[1]) & (trajs[..., 1] < crop_size[0])) return rgbs_new, depths_new, intrs, trajs, visibles def _rescale_and_erase_depth_patches(self, depths, trajs, visibles, rndstate): V, T, H, W, _ = depths.shape _, _, N, _ = trajs.shape assert depths.dtype == np.float32 assert depths.shape == (V, T, H, W, 1) assert trajs.shape == (V, T, N, 3) assert visibles.shape == (V, T, N) depths = depths.copy() visibles = visibles.copy() ############ eraser 
transform (per image after the first) ############ for v in range(V): for i in range(1, T): if rndstate.rand() < self.eraser_aug_prob: n = rndstate.randint(1, self.eraser_max + 1) # number of times to occlude for _ in range(n): xc = rndstate.randint(0, W) yc = rndstate.randint(0, H) dx = rndstate.randint(self.eraser_bounds[0], self.eraser_bounds[1]) dy = rndstate.randint(self.eraser_bounds[0], self.eraser_bounds[1]) x0 = np.clip(xc - dx / 2, 0, W - 1).round().astype(np.int32) x1 = np.clip(xc + dx / 2, 0, W - 1).round().astype(np.int32) y0 = np.clip(yc - dy / 2, 0, H - 1).round().astype(np.int32) y1 = np.clip(yc + dy / 2, 0, H - 1).round().astype(np.int32) eraser_depth = { 0: depths[v, i, y0:y1, x0:x1].mean(), 1: depths[v, i, y0:y1, x0:x1].min(), 2: depths[v, i, y0:y1, x0:x1].max(), 3: 0, }[rndstate.choice([0, 1, 2, 3], p=[0.2, 0.1, 0.35, 0.35])] depths[v, i, y0:y1, x0:x1] = eraser_depth occ_inds = np.logical_and( np.logical_and(trajs[v, i, :, 0] >= x0, trajs[v, i, :, 0] < x1), np.logical_and(trajs[v, i, :, 1] >= y0, trajs[v, i, :, 1] < y1), ) visibles[v, i, occ_inds] = 0 ############ replace transform (per image after the first) ############ for v in range(V): for i in range(1, T): if rndstate.rand() < self.replace_aug_prob: n = rndstate.randint(1, self.replace_max + 1) # number of times to occlude for _ in range(n): xc = rndstate.randint(0, W) yc = rndstate.randint(0, H) dx = rndstate.randint(self.replace_bounds[0], self.replace_bounds[1]) dy = rndstate.randint(self.replace_bounds[0], self.replace_bounds[1]) x0 = np.clip(xc - dx / 2, 0, W - 1).round().astype(np.int32) x1 = np.clip(xc + dx / 2, 0, W - 1).round().astype(np.int32) y0 = np.clip(yc - dy / 2, 0, H - 1).round().astype(np.int32) y1 = np.clip(yc + dy / 2, 0, H - 1).round().astype(np.int32) wid = x1 - x0 hei = y1 - y0 y00 = rndstate.randint(0, H - hei) x00 = rndstate.randint(0, W - wid) v_rnd = rndstate.randint(0, V) i_rnd = rndstate.randint(0, T) depths[v, i, y0:y1, x0:x1] = depths[v_rnd, i_rnd, y00: y00 + hei, x00: x00 + wid] occ_inds = np.logical_and( np.logical_and(trajs[v, i, :, 0] >= x0, trajs[v, i, :, 0] < x1), np.logical_and(trajs[v, i, :, 1] >= y0, trajs[v, i, :, 1] < y1), ) visibles[v, i, occ_inds] = 0 return depths, visibles def _crop(self, rgbs, trajs, crop_size): T, N, _ = trajs.shape S = len(rgbs) H, W = rgbs[0].shape[:2] assert S == T ############ spatial transform ############ H_new = H W_new = W # simple random crop y0 = 0 if crop_size[0] >= H_new else (H_new - crop_size[0]) // 2 # np.random.randint(0, x0 = 0 if crop_size[1] >= W_new else np.random.randint(0, W_new - crop_size[1]) rgbs = [rgb[y0: y0 + crop_size[0], x0: x0 + crop_size[1]] for rgb in rgbs] trajs[:, :, 0] -= x0 trajs[:, :, 1] -= y0 return np.stack(rgbs), trajs ================================================ FILE: mvtracker/datasets/panoptic_studio_multiview_dataset.py ================================================ import logging import os import pathlib import re import time import warnings import cv2 import numpy as np import pandas as pd import torch import torch.nn.functional as F from scipy.spatial.transform import Rotation as R from torch.utils.data import Dataset from mvtracker.datasets.utils import Datapoint, transform_scene class PanopticStudioMultiViewDataset(Dataset): @staticmethod def from_name(dataset_name: str, dataset_root: str): """ Examples of datasets supported by this factory method: - panoptic-multiview - panoptic-multiview-views27_16_14_8 - panoptic-multiview-duster27_16_14_8 - panoptic-multiview-duster27_16_14_8cleaned - 
panoptic-multiview-duster27_16_14_8cleaned-views27_16 - panoptic-multiview-duster27_16_14_8cleaned-views27_16-novelviews1_4 - panoptic-multiview-duster27_16_14_8cleaned-views27_16-novelviews1_4-single - panoptic-multiview-duster27_16_14_8cleaned-views27_16-novelviews1_4-single-2dpt - panoptic-multiview-duster27_16_14_8cleaned-views27_16-novelviews1_4-single-2dpt-cached """ # Parse the dataset name, chunk by chunk non_parsed = dataset_name.replace("panoptic-multiview", "", 1) if non_parsed.startswith("-duster"): match = re.match(r"-duster((?:\d+_?)+)(cleaned)?", non_parsed) assert match is not None duster_views = list(map(int, match.group(1).split("_"))) use_duster = True use_duster_cleaned = match.group(2) is not None non_parsed = non_parsed.replace(match.group(0), "", 1) else: use_duster = False use_duster_cleaned = False duster_views = None if non_parsed.startswith("-views"): match = re.match(r"-views((?:\d+_?)+)", non_parsed) assert match is not None views = list(map(int, match.group(1).split("_"))) if duster_views is not None: assert all(v in duster_views for v in views) non_parsed = non_parsed.replace(match.group(0), "", 1) else: views = duster_views if non_parsed.startswith("-novelviews"): match = re.match(r"-novelviews((?:\d+_?)+)", non_parsed) assert match is not None novel_views = list(map(int, match.group(1).split("_"))) non_parsed = non_parsed.replace(match.group(0), "", 1) else: novel_views = None if non_parsed.startswith("-single"): single_point = True non_parsed = non_parsed.replace("-single", "", 1) else: single_point = False if non_parsed.startswith("-2dpt"): eval_2dpt = True non_parsed = non_parsed.replace("-2dpt", "", 1) else: eval_2dpt = False if non_parsed.startswith("-cached"): use_cached_tracks = True non_parsed = non_parsed.replace("-cached", "", 1) else: use_cached_tracks = False assert non_parsed == "", f"Unparsed part of the dataset name: {non_parsed}" return PanopticStudioMultiViewDataset( data_root=os.path.join(dataset_root, "panoptic-multiview"), views_to_return=views, novel_views=novel_views, use_duster_depths=use_duster, clean_duster_depths=use_duster_cleaned, traj_per_sample=384, seed=72, max_videos=6, perform_sanity_checks=False, use_cached_tracks=use_cached_tracks, ) def __init__( self, data_root, views_to_return=None, novel_views=None, use_duster_depths=False, clean_duster_depths=False, traj_per_sample=512, seed=None, max_videos=None, perform_sanity_checks=False, use_cached_tracks=False, ): super().__init__() self.data_root = data_root self.views_to_return = views_to_return self.novel_views = novel_views self.use_duster_depths = use_duster_depths self.clean_duster_depths = clean_duster_depths self.traj_per_sample = traj_per_sample self.seed = seed self.perform_sanity_checks = perform_sanity_checks self.use_cached_tracks = use_cached_tracks self.cache_name = self._cache_key() self.seq_names = self._get_sequence_names(max_videos) self.getitem_calls = 0 def _get_sequence_names(self, max_videos): """ Fetch all valid sequence names from the dataset root. Args: max_videos (int): Limit the number of sequences to load. Returns: List[str]: Sorted list of valid sequence names. 
""" seq_names = [ fname for fname in os.listdir(self.data_root) if os.path.isdir(os.path.join(self.data_root, fname)) and not fname.startswith(".") and not fname.startswith("_") ] seq_names = sorted(seq_names) valid_seqs = [] for seq_name in seq_names: scene_path = os.path.join(self.data_root, seq_name) if not os.path.exists(os.path.join(scene_path, "tapvid3d_annotations.npz")): warnings.warn(f"Skipping {scene_path} because it has no tapvid3d_annotations.npz labels file.") continue valid_seqs.append(seq_name) if max_videos is not None: valid_seqs = valid_seqs[:max_videos] print(f"Using {len(valid_seqs)} videos from {self.data_root}") return valid_seqs def _cache_key(self): name = f"cachedtracks--seed{self.seed}" if self.views_to_return is not None: name += f"-views{'_'.join(map(str, self.views_to_return))}" if self.traj_per_sample is not None: name += f"-n{self.traj_per_sample}" return name + "--v1" # bump this if you change the selection policy def __len__(self): return len(self.seq_names) def __getitem__(self, index): start_time = time.time() sample = self._getitem_helper(index) self.getitem_calls += 1 if self.getitem_calls < 10: print(f"Loading {index:>06d} took {time.time() - start_time:.3f} sec. Getitem calls: {self.getitem_calls}") return sample, True def _getitem_helper(self, index): """ Helper function to load a single sample. Args: index (int): Index of the sample to load. Returns: CoTrackerData, bool: Sample data and success flag. """ if self.seed is None: seed = torch.randint(0, 2 ** 32 - 1, (1,)).item() else: seed = self.seed rnd_torch = torch.Generator().manual_seed(seed) rnd_np = np.random.RandomState(seed=seed) datapoint_path = os.path.join(self.data_root, self.seq_names[index]) ims_path = os.path.join(datapoint_path, "ims") depths_path = os.path.join(datapoint_path, "dynamic3dgs_depth") tapvid3d_merged_annotations = np.load(os.path.join(datapoint_path, "tapvid3d_annotations.npz")) traj3d_world = tapvid3d_merged_annotations["trajectories"] traj2d = tapvid3d_merged_annotations["trajectories_pixelspace"] visibility = tapvid3d_merged_annotations["per_view_visibilities"] query_points_3d = tapvid3d_merged_annotations["query_points_3d"] extrs = tapvid3d_merged_annotations["extrinsics"] intrs = tapvid3d_merged_annotations["intrinsics"] views = {} view_folders = sorted([f for f in os.listdir(ims_path)], key=lambda x: int(x)) if self.views_to_return is not None: views_to_return = self.views_to_return else: views_to_return = sorted(list(range(len(view_folders)))) views_to_load = views_to_return.copy() if self.novel_views is not None: views_to_load = list(set(views_to_load + self.novel_views)) for v in views_to_load: rgb_folder = os.path.join(ims_path, str(v)) rgb_files = sorted(os.listdir(rgb_folder)) rgb_images = [cv2.imread(os.path.join(rgb_folder, f))[:, :, ::-1] for f in rgb_files] depth = np.load(os.path.join(depths_path, f"depths_{v:02d}.npy")) views[v] = { "rgb": np.stack(rgb_images), "depth": depth, } rgbs = np.stack([views[v]["rgb"] for v in views_to_return]) n_views, n_frames, h, w, _ = rgbs.shape depths = np.stack([views[v]["depth"] for v in views_to_return])[..., None].astype(np.float32) intrs = np.stack([intrs[v] for v in views_to_return])[:, None, :, :].repeat(n_frames, axis=1) extrs = np.stack([extrs[v][:3, :] for v in views_to_return])[:, None, :, :].repeat(n_frames, axis=1) visibility = visibility[views_to_return] traj2d = traj2d[views_to_return] # Load novel views if they exist novel_rgbs = None novel_intrs = None novel_extrs = None if self.novel_views is not None: 
            novel_rgbs = np.stack([views[v]["rgb"] for v in self.novel_views])
            novel_intrs = np.stack([tapvid3d_merged_annotations["intrinsics"][v] for v in self.novel_views])[:, None, :, :].repeat(n_frames, axis=1)
            novel_extrs = np.stack([tapvid3d_merged_annotations["extrinsics"][v][:3, :] for v in self.novel_views])[:, None, :, :].repeat(n_frames, axis=1)

        # Load Duster's features and estimated depths if they exist
        views_selection_str = '-'.join(str(v) for v in self.views_to_return)
        duster_root = pathlib.Path(datapoint_path) / f'duster-views-{views_selection_str}'
        if self.use_duster_depths:
            assert duster_root.exists(), f"Duster root {duster_root} does not exist."
            last_frame_scene_file = duster_root / f"3d_model__{n_frames - 1:05d}__scene.npz"
            assert last_frame_scene_file.exists(), f"Duster scene file {last_frame_scene_file} does not exist."
        feats = None
        feat_dim = None
        feat_stride = None
        if duster_root.exists() and (duster_root / f"3d_model__{n_frames - 1:05d}__scene.npz").exists():
            duster_depths = []
            duster_feats = []
            for frame_idx in range(n_frames):
                scene = np.load(duster_root / f"3d_model__{frame_idx:05d}__scene.npz")
                duster_depth = torch.from_numpy(scene["depths"])
                duster_conf = torch.from_numpy(scene["confs"])
                duster_msk = torch.from_numpy(scene["cleaned_mask"])
                duster_feat = torch.from_numpy(scene["feats"])
                if self.clean_duster_depths:
                    duster_depth = duster_depth * duster_msk
                duster_depth = F.interpolate(duster_depth[:, None], (h, w), mode='nearest')
                duster_depths.append(duster_depth[:, 0, :, :, None])
                duster_feats.append(duster_feat)
            feats = torch.stack(duster_feats, dim=1).numpy()
            assert feats.ndim == 4
            assert feats.shape[0] == n_views
            assert feats.shape[1] == n_frames
            feat_stride = np.round(np.sqrt(h * w / feats.shape[2])).astype(int)
            feat_dim = feats.shape[3]
            feats = feats.reshape(n_views, n_frames, h // feat_stride, w // feat_stride, feat_dim)

            # Replace the depths with the Duster depths, if configured so
            if self.use_duster_depths:
                depths = torch.stack(duster_depths, dim=1).numpy()

        n_tracks = traj3d_world.shape[1]
        assert rgbs.shape == (n_views, n_frames, h, w, 3)
        assert depths.shape == (n_views, n_frames, h, w, 1)
        assert feats is None or feats.shape == (n_views, n_frames, h // feat_stride, w // feat_stride, feat_dim)
        assert intrs.shape == (n_views, n_frames, 3, 3)
        assert extrs.shape == (n_views, n_frames, 3, 4)
        assert traj2d.shape == (n_views, n_frames, n_tracks, 2)
        assert visibility.shape == (n_views, n_frames, n_tracks)
        assert traj3d_world.shape == (n_frames, n_tracks, 3)
        if novel_rgbs is not None:
            assert novel_rgbs.shape == (len(self.novel_views), n_frames, h, w, 3)
            assert novel_intrs.shape == (len(self.novel_views), n_frames, 3, 3)
            assert novel_extrs.shape == (len(self.novel_views), n_frames, 3, 4)

        # Make sure our intrinsics and extrinsics work correctly
        point_3d_world = traj3d_world
        point_4d_world_homo = np.concatenate([point_3d_world, np.ones_like(point_3d_world[..., :1])], axis=-1)
        point_3d_camera = np.einsum('ABij,BCj->ABCi', extrs, point_4d_world_homo)
        if self.perform_sanity_checks:
            point_2d_pixel_homo = np.einsum('ABij,ABCj->ABCi', intrs, point_3d_camera)
            point_2d_pixel = point_2d_pixel_homo[..., :2] / point_2d_pixel_homo[..., 2:]
            point_2d_pixel_gt = traj2d
            point_2d_pixel_no_nan = np.nan_to_num(point_2d_pixel, nan=0)
            point_2d_pixel_gt_no_nan = np.nan_to_num(point_2d_pixel_gt, nan=0)
            # Compare the re-projected points against the ground-truth 2D tracks
            assert np.allclose(point_2d_pixel_no_nan[0, :, 0, :], point_2d_pixel_gt_no_nan[0, :, 0, :], atol=.01)
            assert np.allclose(point_2d_pixel_no_nan, point_2d_pixel_gt_no_nan, atol=.01), f"Point projection
failed" traj2d_w_z = np.concatenate([traj2d, point_3d_camera[..., 2:]], axis=-1) rgbs = torch.from_numpy(rgbs).permute(0, 1, 4, 2, 3).float() depths = torch.from_numpy(depths).permute(0, 1, 4, 2, 3).float() feats = torch.from_numpy(feats).permute(0, 1, 4, 2, 3).float() if feats is not None else None intrs = torch.from_numpy(intrs).float() extrs = torch.from_numpy(extrs).float() traj2d = torch.from_numpy(traj2d) traj2d_w_z = torch.from_numpy(traj2d_w_z) traj3d_world = torch.from_numpy(traj3d_world) visibility = torch.from_numpy(visibility) if novel_rgbs is not None: novel_rgbs = torch.from_numpy(novel_rgbs).permute(0, 1, 4, 2, 3).float() novel_intrs = torch.from_numpy(novel_intrs).float() novel_extrs = torch.from_numpy(novel_extrs).float() # Track selection cache_root = os.path.join(self.data_root, self.seq_names[index], "cache") os.makedirs(cache_root, exist_ok=True) cache_file = os.path.join(cache_root, f"{self.cache_name}.npz") # Check if we can use cached tracks use_cache = bool(self.use_cached_tracks) and os.path.isfile(cache_file) if use_cache: cache = np.load(cache_file) inds_sampled = torch.from_numpy(cache["track_indices"]) traj2d_w_z = torch.from_numpy(cache["traj2d_w_z"]) traj3d_world = torch.from_numpy(cache["traj3d_world"]) visibility = torch.from_numpy(cache["visibility"]) valids = torch.from_numpy(cache["valids"]) query_points = torch.from_numpy(cache["query_points"]) # Otherwise, sample the tracks and create query points else: # Prefer TAPVid-3D's merged query points when selecting the query points # First, denote the points in time before the query points appeared as non-visible # Second, choose the query points as the first appearance of the points in the selected views (which might be # later than in the TAPVid-3D annotations because the query might not be visible in the selected views) tapvid3d_merged_query_point_timestep = query_points_3d[:, 0].round().astype(int) visibility *= (np.arange(n_frames)[None, :, None] >= tapvid3d_merged_query_point_timestep[None, None, :]) # Sample the points to track visible_for_at_least_two_frames = visibility.any(0).sum(0) >= 2 valid_tracks = visible_for_at_least_two_frames valid_tracks = valid_tracks.nonzero(as_tuple=False)[:, 0] point_inds = torch.randperm(len(valid_tracks), generator=rnd_torch) traj_per_sample = self.traj_per_sample if self.traj_per_sample is not None else len(point_inds) assert len(point_inds) >= traj_per_sample point_inds = point_inds[:traj_per_sample] inds_sampled = valid_tracks[point_inds] n_tracks = len(inds_sampled) traj2d = traj2d[:, :, inds_sampled].float() traj2d_w_z = traj2d_w_z[:, :, inds_sampled].float() traj3d_world = traj3d_world[:, inds_sampled].float() visibility = visibility[:, :, inds_sampled] valids = ~torch.isnan(traj2d).any(dim=-1).any(dim=0) # Create the query points gt_visibilities_any_view = visibility.any(dim=0) assert (gt_visibilities_any_view.sum(dim=0) >= 2).all(), "All points should be visible in least two frames." 
last_visible_index = (torch.arange(n_frames).unsqueeze(-1) * gt_visibilities_any_view).max(0).values assert gt_visibilities_any_view[last_visible_index[None, :], torch.arange(n_tracks)].all() gt_visibilities_any_view[last_visible_index[None, :], torch.arange(n_tracks)] = False assert (gt_visibilities_any_view.sum(dim=0) >= 1).all() query_points_t = torch.argmax(gt_visibilities_any_view.float(), dim=0) query_points_xyz_worldspace = traj3d_world[query_points_t, torch.arange(n_tracks)] query_points = torch.cat([query_points_t[:, None], query_points_xyz_worldspace], dim=1) assert gt_visibilities_any_view[query_points_t, torch.arange(n_tracks)].all() # Replace nans with zeros traj2d[torch.isnan(traj2d)] = 0 traj2d_w_z[torch.isnan(traj2d_w_z)] = 0 traj3d_world[torch.isnan(traj3d_world)] = 0 assert torch.isnan(visibility).sum() == 0 # Cache the selected tracks and query points if self.use_cached_tracks: logging.warn(f"Caching tracks for {self.seq_names[index]} at {os.path.abspath(cache_file)}") np.savez_compressed( cache_file, track_indices=inds_sampled.numpy(), traj2d_w_z=traj2d_w_z.numpy(), traj3d_world=traj3d_world.numpy(), visibility=visibility.numpy(), valids=valids.numpy(), query_points=query_points.numpy(), ) # Normalize the scene to be similar to Kubric's scene scale = 2.5 rot_x = R.from_euler('x', -90, degrees=True).as_matrix() rot_y = R.from_euler('y', 0, degrees=True).as_matrix() rot_z = R.from_euler('z', 0, degrees=True).as_matrix() rot = torch.from_numpy(rot_z @ rot_y @ rot_x) translate = torch.tensor([0.0, 0.0, 0.0], dtype=torch.float32) ( depths_trans, extrs_trans, query_points_trans, traj3d_world_trans, traj2d_w_z_trans ) = transform_scene(scale, rot, translate, depths, extrs, query_points, traj3d_world, traj2d_w_z) novel_extrs_trans = transform_scene(scale, rot, translate, None, novel_extrs, None, None, None)[1] # # Use the auto scene normalization of generic scenes # from mvtracker.datasets.generic_scene_dataset import compute_auto_scene_normalization # scale, rot, translation = compute_auto_scene_normalization(depths, torch.ones_like(depths) * 100, extrs_trans, intrs) # scale = scale * T[0, 0].item() # print(f"{scale=}") # (depths_trans, extrs_trans, query_points_trans, traj3d_world_trans, traj2d_w_z_trans # ) = transform_scene(scale, rot, translation, depths_trans, extrs_trans, query_points_trans, traj3d_world_trans, traj2d_w_z_trans) # _, novel_extrs_trans, _, _, _ = transform_scene(scale, rot, translation, None, novel_extrs_trans, None, None, None) # 85.7 94.5 92.3 --> 86.0 94.8 92.2 # from mvtracker.datasets.dexycb_multiview_dataset import rerun_viz_scene # rerun_viz_scene("nane/pc__no_transform/", rgbs[:, ::20], depths[:, ::20], intrs[:, ::20], extrs[:, ::20], traj3d_world[:, ::20], 0.1) # rerun_viz_scene("nane/pc_transformed/", rgbs[:, ::20], depths[:, ::20], intrs[:, ::20], extrs_trans[:, ::20], traj3d_world_trans[:, ::20], 1) segs = torch.ones((n_frames, 1, h, w)) # Dummy segmentation masks datapoint = Datapoint( video=rgbs, videodepth=depths_trans, feats=feats, segmentation=segs, trajectory=traj2d_w_z_trans, trajectory_3d=traj3d_world_trans, trajectory_category=None, visibility=visibility, valid=valids, seq_name=self.seq_names[index], intrs=intrs, extrs=extrs_trans, query_points=None, query_points_3d=query_points_trans, track_upscaling_factor=1 / scale, novel_video=novel_rgbs, novel_intrs=novel_intrs, novel_extrs=novel_extrs_trans, ) return datapoint ================================================ FILE: mvtracker/datasets/tap_vid_datasets.py 
================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. import glob import io import logging import os import pickle import re import sys from pathlib import Path from typing import * import matplotlib import mediapy as media import numpy as np import rerun as rr import torch from PIL import Image from scipy.spatial.transform import Rotation as R from mvtracker.datasets.utils import Datapoint, transform_scene DatasetElement = Mapping[str, Mapping[str, Union[np.ndarray, str]]] def resize_video(video: np.ndarray, output_size: Tuple[int, int]) -> np.ndarray: """Resize a video to output_size.""" # If you have a GPU, consider replacing this with a GPU-enabled resize op, # such as a jitted jax.image.resize. It will make things faster. return media.resize_video(video, output_size) def sample_queries_first( target_occluded: np.ndarray, target_points: np.ndarray, frames: np.ndarray, ) -> Mapping[str, np.ndarray]: """Package a set of frames and tracks for use in TAPNet evaluations. Given a set of frames and tracks with no query points, use the first visible point in each track as the query. Args: target_occluded: Boolean occlusion flag, of shape [n_tracks, n_frames], where True indicates occluded. target_points: Position, of shape [n_tracks, n_frames, 2], where each point is [x,y] scaled between 0 and 1. frames: Video tensor, of shape [n_frames, height, width, 3]. Scaled between -1 and 1. Returns: A dict with the keys: video: Video tensor of shape [1, n_frames, height, width, 3] query_points: Query points of shape [1, n_queries, 3] where each point is [t, y, x] scaled to the range [-1, 1] target_points: Target points of shape [1, n_queries, n_frames, 2] where each point is [x, y] scaled to the range [-1, 1] """ valid = np.sum(~target_occluded, axis=1) > 0 target_points = target_points[valid, :] target_occluded = target_occluded[valid, :] query_points = [] for i in range(target_points.shape[0]): index = np.where(target_occluded[i] == 0)[0][0] x, y = target_points[i, index, 0], target_points[i, index, 1] query_points.append(np.array([index, x, y])) # [t, x, y] query_points = np.stack(query_points, axis=0) return { "video": frames[np.newaxis, ...], "query_points": query_points[np.newaxis, ...], "target_points": target_points[np.newaxis, ...], "occluded": target_occluded[np.newaxis, ...], } def sample_queries_strided( target_occluded: np.ndarray, target_points: np.ndarray, frames: np.ndarray, query_stride: int = 5, ) -> Mapping[str, np.ndarray]: """Package a set of frames and tracks for use in TAPNet evaluations. Given a set of frames and tracks with no query points, sample queries strided every query_stride frames, ignoring points that are not visible at the selected frames. Args: target_occluded: Boolean occlusion flag, of shape [n_tracks, n_frames], where True indicates occluded. target_points: Position, of shape [n_tracks, n_frames, 2], where each point is [x,y] scaled between 0 and 1. frames: Video tensor, of shape [n_frames, height, width, 3]. Scaled between -1 and 1. query_stride: When sampling query points, search for un-occluded points every query_stride frames and convert each one into a query. Returns: A dict with the keys: video: Video tensor of shape [1, n_frames, height, width, 3]. The video has floats scaled to the range [-1, 1]. 
query_points: Query points of shape [1, n_queries, 3] where each point is [t, y, x] scaled to the range [-1, 1]. target_points: Target points of shape [1, n_queries, n_frames, 2] where each point is [x, y] scaled to the range [-1, 1]. trackgroup: Index of the original track that each query point was sampled from. This is useful for visualization. """ tracks = [] occs = [] queries = [] trackgroups = [] total = 0 trackgroup = np.arange(target_occluded.shape[0]) for i in range(0, target_occluded.shape[1], query_stride): mask = target_occluded[:, i] == 0 query = np.stack( [ i * np.ones(target_occluded.shape[0:1]), target_points[:, i, 1], target_points[:, i, 0], ], axis=-1, ) queries.append(query[mask]) tracks.append(target_points[mask]) occs.append(target_occluded[mask]) trackgroups.append(trackgroup[mask]) total += np.array(np.sum(target_occluded[:, i] == 0)) return { "video": frames[np.newaxis, ...], "query_points": np.concatenate(queries, axis=0)[np.newaxis, ...], "target_points": np.concatenate(tracks, axis=0)[np.newaxis, ...], "occluded": np.concatenate(occs, axis=0)[np.newaxis, ...], "trackgroup": np.concatenate(trackgroups, axis=0)[np.newaxis, ...], } class TapVidDataset(torch.utils.data.Dataset): @staticmethod def from_name(dataset_name: str, dataset_root: str): """ Examples of datasets supported by this factory method: - tapvid2d-davis-nodepth - tapvid2d-davis-moge - tapvid2d-davis-zoedepth - tapvid2d-davis-videodepthanything - tapvid2d-davis-megasam - tapvid2d-davis-mogewithextrinsics - tapvid2d-davis-mogewithextrinsics-256x256 - tapvid2d-davis-mogewithextrinsics-256x256-single """ if dataset_name.startswith("tapvid2d-davis-"): # Parse the dataset name, chunk by chunk non_parsed = dataset_name.replace("tapvid2d-davis-", "", 1) # Extract depth estimator (until first possible resolution or single flag) match = re.match(r"([^-]+)", non_parsed) assert match is not None depth_estimator_name = match.group(1) non_parsed = non_parsed.replace(depth_estimator_name, "", 1) # Extract resolution resize_to = None match = re.search(r"-([0-9]+x[0-9]+)", non_parsed) if match: width, height = map(int, match.group(1).split("x")) resize_to = (width, height) non_parsed = non_parsed.replace(match.group(0), "", 1) # Check for single point flag single_point = "-single" in non_parsed non_parsed = non_parsed.replace("-single", "", 1) if single_point else non_parsed # Ensure no unparsed parts left assert non_parsed == "", f"Unparsed part of the dataset name: {non_parsed}" data_root = os.path.join(dataset_root, "tapvid_davis/tapvid_davis.pkl") return TapVidDataset( dataset_type="davis", data_root=data_root, resize_to=resize_to, queried_first=True, depth_estimator_name=depth_estimator_name, depth_estimator_batch_size=2, depth_estimator_device="cuda", stream_rerun_depth_viz=False, save_rerun_depth_viz=False, ) def __init__( self, data_root, dataset_type="davis", resize_to=(256, 256), queried_first=True, depth_estimator_name="moge-with-extrinsics", depth_estimator_batch_size=2, depth_estimator_device="cuda", stream_rerun_depth_viz=False, save_rerun_depth_viz=False, ): self.dataset_type = dataset_type self.resize_to = resize_to self.queried_first = queried_first if self.dataset_type == "kinetics": self.depth_cache_root = os.path.join(data_root, "depth_cache") else: self.depth_cache_root = os.path.join(os.path.dirname(data_root), "depth_cache") os.makedirs(self.depth_cache_root, exist_ok=True) if self.dataset_type == "kinetics": all_paths = glob.glob(os.path.join(data_root, "*_of_0010.pkl")) points_dataset = [] for 
pickle_path in all_paths: with open(pickle_path, "rb") as f: data = pickle.load(f) points_dataset = points_dataset + data self.points_dataset = points_dataset else: with open(data_root, "rb") as f: self.points_dataset = pickle.load(f) if self.dataset_type == "davis": self.video_names = list(self.points_dataset.keys()) logging.info("found %d unique videos in %s" % (len(self.points_dataset), data_root)) self.depth_estimator_name = depth_estimator_name self.depth_estimator_batch_size = depth_estimator_batch_size self.depth_estimator_device = depth_estimator_device self.stream_rerun_depth_viz = stream_rerun_depth_viz self.save_rerun_depth_viz = save_rerun_depth_viz # # Dummy call all items to generate rerun visualizations # self.stream_rerun_depth_viz = False # self.save_rerun_depth_viz = True # for i in tqdm(range(len(self.points_dataset))): # try: # self[i] # except Exception as e: # logging.error(f"Error processing video {i}: {e}") # logging.info(f"But we continue anyway") # continue # exit() def __getitem__(self, index): if self.dataset_type == "davis": video_name = self.video_names[index] else: video_name = index frames = self.points_dataset[video_name]["video"].copy() if isinstance(frames[0], bytes): # TAP-Vid is stored and JPEG bytes rather than `np.ndarray`s. def decode(frame): byteio = io.BytesIO(frame) img = Image.open(byteio) return np.array(img) frames = np.array([decode(frame) for frame in frames]) target_points = self.points_dataset[video_name]["points"].copy() if self.resize_to is not None: frames = resize_video(frames, self.resize_to) target_points *= np.array([self.resize_to[1] - 1, self.resize_to[0] - 1]) else: target_points *= np.array([frames.shape[2] - 1, frames.shape[1] - 1]) assert target_points[:, :, 0].min() >= 0 assert target_points[:, :, 0].max() <= frames.shape[2] - 1 assert target_points[:, :, 1].min() >= 0 assert target_points[:, :, 1].max() <= frames.shape[1] - 1 T, H, W, C = frames.shape N, T, D = target_points.shape target_occ = self.points_dataset[video_name]["occluded"].copy() if self.queried_first: converted = sample_queries_first(target_occ, target_points, frames) else: converted = sample_queries_strided(target_occ, target_points, frames) assert converted["target_points"].shape[1] == converted["query_points"].shape[1] trajs = (torch.from_numpy(converted["target_points"])[0].permute(1, 0, 2).float()) # T, N, D rgbs = torch.from_numpy(frames).permute(0, 3, 1, 2).float() visibles = torch.logical_not(torch.from_numpy(converted["occluded"]))[0].permute(1, 0) # T, N query_points_2d = torch.from_numpy(converted["query_points"])[0] # T, N # Let's estimate depths RIGHT HERE res = f"{H}x{W}" cached_file_zoedepth_nk = os.path.join(self.depth_cache_root, f"zoedepth_nk__{video_name}__{res}.npz") cached_file_moge = os.path.join(self.depth_cache_root, f"moge__{video_name}__{res}.npz") cached_file_megasam = os.path.join(self.depth_cache_root, f"megasam__{video_name}__{res}-v1.npz") if self.depth_estimator_name == "nodepth": depth = np.ones((T, H, W)) intrs = np.eye(3) * max(H, W) extrs = np.eye(4)[None].repeat(T, axis=0) elif self.depth_estimator_name == "zoedepth": depth = zoedepth_nk(rgbs, self.depth_estimator_batch_size, self.depth_estimator_device, cached_file_zoedepth_nk) _, intrs, _, _, _ = moge(rgbs, self.depth_estimator_batch_size, self.depth_estimator_device, cached_file_moge) extrs = np.eye(4)[None].repeat(T, axis=0) elif self.depth_estimator_name == "moge": depth, intrs, _, _, mask = moge(rgbs, self.depth_estimator_batch_size, self.depth_estimator_device, 
cached_file_moge) depth[~mask] = 0 extrs = np.eye(4)[None].repeat(T, axis=0) elif self.depth_estimator_name == "mogewithextrinsics": depth, intrs, extrs, _, mask = moge(rgbs, self.depth_estimator_batch_size, self.depth_estimator_device, cached_file_moge) depth[~mask] = 0 elif self.depth_estimator_name == "videodepthanything": raise NotImplementedError("videodepthanything is not implemented yet") elif self.depth_estimator_name == "megasam": try: depth, intrs, extrs = megasam( rgbs=rgbs, batch_size=self.depth_estimator_batch_size, device=self.depth_estimator_device, cached_file=cached_file_megasam, ) except Exception as e: logging.error(f"MegaSAM error for {video_name} ({rgbs.shape=}) (we will use moge depth instead): {e}") depth, intrs, extrs, _, mask = moge(rgbs, self.depth_estimator_batch_size, self.depth_estimator_device, cached_file_moge) depth[~mask] = 0 else: raise NotImplementedError depth = torch.from_numpy(depth).float() if intrs.ndim == 2: intrs = intrs[None].repeat(T, axis=0) intrs = torch.from_numpy(intrs).float() extrs_square = torch.from_numpy(extrs).float() extrs = extrs_square[:, :3, :] intrs_inv = torch.inverse(intrs) extrs_inv = torch.inverse(extrs_square) # Project trajectories to 3D trajs_depth = trajs.new_ones((T, N, 1)) * np.inf for t in range(T): # # V1: Not good enough, depths are jumping to the background near edges because of interpolation # trajs_depth[t] = bilinear_sample2d( # im=depth[t][None, None], # x=trajs[t, :, 0][None], # y=trajs[t, :, 1][None], # )[0].permute(1, 0).type(trajs_depth.dtype) # V2: Still not good, taking the closest pixel only (without interpolating) still has jumps at edges x_nearest = trajs[t, :, 0].round().long() y_nearest = trajs[t, :, 1].round().long() depth_nearest = depth[t].view(-1)[(y_nearest * W + x_nearest).view(-1)] depth_nearest = depth_nearest.view(1, -1).type(trajs_depth.dtype).permute(1, 0) trajs_depth[t] = depth_nearest # # V3: Taking the minimum depth value of the neighbors also fails when there are other things in front. 
# depth_pad = F.pad(depth[t][None, None], (1, 1, 1, 1), mode="replicate") # Pad to handle edges # depth_min = -F.max_pool2d(-depth_pad, kernel_size=9, stride=1) # Min pooling using negation # depth_min_sampled = depth_min[0, 0, trajs[t, :, 1].long(), trajs[t, :, 0].long()].type(trajs_depth.dtype) # trajs_depth[t] = depth_min_sampled[:, None] assert torch.all(torch.isfinite(trajs_depth)).item() trajs_camera = torch.einsum("Tij,TNj->TNi", intrs_inv, to_homogenous_torch(trajs)) * trajs_depth trajs_world = torch.einsum("Tij,TNj->TNi", extrs_inv, to_homogenous_torch(trajs_camera))[..., :3] trajs_3d = trajs_world trajs_w_z = torch.cat([trajs, trajs_depth], dim=2) # Project query points to 3D qp_t = query_points_2d[:, 0].float() qp_xyz_pixel = query_points_2d[:, 1:].float() qp_depth = qp_xyz_pixel.new_ones((N, 1)) * np.inf qp_xyz_world = qp_xyz_pixel.new_ones((N, 3)) * np.inf for t in range(T): qp_mask = qp_t == t if qp_mask.sum() == 0: continue # V2 depth interpolation x_nearest = qp_xyz_pixel[qp_mask, 0].round().long() y_nearest = qp_xyz_pixel[qp_mask, 1].round().long() depth_nearest = depth[t].view(-1)[(y_nearest * W + x_nearest).view(-1)] depth_nearest = depth_nearest.view(1, -1).type(trajs_depth.dtype).permute(1, 0) qp_depth[qp_mask] = depth_nearest qp_xyz_pixel_t = to_homogenous_torch(qp_xyz_pixel[qp_mask]) qp_xyz_camera_t = torch.einsum("ij,Nj->Ni", intrs_inv[t], qp_xyz_pixel_t) * qp_depth[qp_mask] qp_xyz_world_t = torch.einsum("ij,Nj->Ni", extrs_inv[t], to_homogenous_torch(qp_xyz_camera_t))[..., :3] qp_xyz_world[qp_mask] = qp_xyz_world_t assert torch.all(torch.isfinite(qp_depth)) assert torch.all(torch.isfinite(qp_xyz_world)) query_points_3d = torch.cat([qp_t[:, None], qp_xyz_world], dim=1) # Visualize the depth estimation in Rerun radii_scale = 0.1 streams = [] if self.stream_rerun_depth_viz: streams += [True] if self.save_rerun_depth_viz: streams += [False] for stream in streams: # depth_zoedepth = zoedepth_nk(rgbs, self.depth_estimator_batch_size, self.depth_estimator_device, # cached_file_zoedepth_nk) depth_moge, intrinsics_moge, w2c_moge, _, mask_moge = moge( rgbs=rgbs, batch_size=self.depth_estimator_batch_size, device=self.depth_estimator_device, cached_file=cached_file_moge, ) # TODO: But what intrinsics did Zoe really assume or use, if any? 
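# Descriptive note on the unprojection convention used for the point clouds logged below:
# for a pixel (u, v) with depth z, pixel-space intrinsics K, and camera-to-world pose
# c2w = inv(w2c), the code computes
#     x_cam   = z * inv(K) @ [u, v, 1]^T
#     x_world = c2w @ [x_cam, 1]^T   (dehomogenized)
# which is exactly the `(K_inv @ homo_pixel_coords) * depth_values` expression further down.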
K = intrinsics_moge K_inv = np.linalg.inv(K) rr.init("TAPVid-2D Estimated Depths", recording_id="v0.1") if stream: rr.connect_tcp() rr.log("world", rr.ViewCoordinates.RIGHT_HAND_Z_UP, static=True) rr.set_time_seconds("frame", 0) rr.log( "world/xyz", rr.Arrows3D( vectors=[[1, 0, 0], [0, 2, 0], [0, 0, 3]], colors=[[255, 0, 0], [0, 255, 0], [0, 0, 255]], ), ) for t in range(T): rr.set_time_seconds("frame", t / 12) rgb = rgbs[t].permute(1, 2, 0).numpy() # Log the depth used for 3D tracking rr.log(f"{video_name}/image/depth_for_tracking", rr.Pinhole( image_from_camera=intrs[t].numpy(), width=W, height=H, )) rr.log(f"{video_name}/image/depth_for_tracking", rr.Transform3D( translation=np.linalg.inv(extrs_square[t].numpy())[:3, 3], mat3x3=np.linalg.inv(extrs_square[t].numpy())[:3, :3], )) rr.log(f"{video_name}/image/depth_for_tracking/depth", rr.DepthImage( image=depth[t].numpy(), point_fill_ratio=0.2, )) rr.log(f"{video_name}/image/depth_for_tracking/rgb", rr.Image(rgb)) # Log all other depth maps # d_zoe = depth_zoedepth[t, 0] d_moge = depth_moge[t] c2w_moge = np.linalg.inv(w2c_moge[t]) # for name, archetype in [ # ("depth-zoe", rr.DepthImage(d_zoe, point_fill_ratio=0.2)), # ("depth-moge", rr.DepthImage(d_moge, point_fill_ratio=0.2)), # ("depth-moge-with-extrinsics", rr.DepthImage(d_moge, point_fill_ratio=0.2)), # ]: # rr.log(f"{video_name}/image/{name}", rr.Pinhole(image_from_camera=K, width=W, height=H)) # rr.log(f"{video_name}/image/{name}/{name}", archetype) # if name == "depth-moge-with-extrinsics": # transform = rr.Transform3D(translation=c2w_moge[:3, 3], mat3x3=c2w_moge[:3, :3]) # rr.log(f"{video_name}/image/{name}", transform) # Convert depth map to 3D point cloud y, x = np.indices((H, W)) homo_pixel_coords = np.stack([x.ravel(), y.ravel(), np.ones_like(x).ravel()], axis=1).T for _name, _depth, _w2c in [ ("used_for_tracking", depth[t].numpy(), extrs_square[t]), # ("zoe", d_zoe, None), ("moge", d_moge, w2c_moge[t]), ("moge-with-extrinsics", d_moge, w2c_moge[t]), ]: depth_values = _depth.ravel() cam_coords = (K_inv @ homo_pixel_coords) * depth_values if _w2c is None: world_coords = cam_coords.T else: world_coords = from_homogeneous( np.einsum("ij,Nj->Ni", np.linalg.inv(_w2c), to_homogeneous(cam_coords.T))) valid_mask = depth_values > 0 world_coords = world_coords[valid_mask] rgb_colors = rgb.reshape(-1, 3)[valid_mask].astype(np.uint8) rr.log(f"{video_name}/pointcloud/{_name}", rr.Points3D(world_coords, colors=rgb_colors, radii=0.001)) def log_tracks( tracks: np.ndarray, visibles: np.ndarray, query_timestep: np.ndarray, colors: np.ndarray, track_names=None, entity_format_str="{}", log_points=True, points_radii=0.03 * radii_scale, invisible_color=[0., 0., 0.], log_line_strips=True, max_strip_length_past=6, max_strip_length_future=1, hide_invisible_strips=True, strips_radii=0.0027 * radii_scale, log_error_lines=False, error_lines_radii=0.0042 * radii_scale, error_lines_color=[1., 0., 0.], gt_for_error_lines=None, ) -> None: """ Log tracks to Rerun. Parameters: tracks: Shape (T, N, 3), the 3D trajectories of points. visibles: Shape (T, N), boolean visibility mask for each point at each timestep. query_timestep: Shape (T, N), the frame index after which the tracks start. colors: Shape (N, 4), RGBA colors for each point. entity_prefix: String prefix for entity hierarchy in Rerun. entity_suffix: String suffix for entity hierarchy in Rerun. 
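Example (illustrative usage only, mirroring the call further below):
    log_tracks(
        tracks=trajs_3d_np,                # (T, N, 3) world-space trajectories
        visibles=visibles_np,              # (T, N) visibility mask
        query_timestep=query_timestep_np,  # (N,) frame index at which each track starts
        colors=track_color,                # (N, 4) RGBA colors in [0, 1]
        entity_format_str=f"{video_name}/tracks/{{}}",
    )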
""" T, N, _ = tracks.shape assert tracks.shape == (T, N, 3) assert visibles.shape == (T, N) assert query_timestep.shape == (N,) assert query_timestep.min() >= 0 assert query_timestep.max() < T assert colors.shape == (N, 4) for n in range(N): track_name = track_names[n] if track_names is not None else f"track-{n}" rr.log(entity_format_str.format(track_name, rr.Clear(recursive=True))) for t in range(query_timestep[n], T): rr.set_time_seconds("frame", t / 12) # Log the point (special handling for invisible points) if log_points: rr.log( entity_format_str.format(f"{track_name}/point"), rr.Points3D( positions=[tracks[t, n]], colors=[colors[n, :3]] if visibles[t, n] else [invisible_color], radii=points_radii, ), ) # Log line segments for visible tracks if log_line_strips and t > query_timestep[n]: strip_t_start = max(t - max_strip_length_past, query_timestep[n].item()) strip_t_end = min(t + max_strip_length_future, T - 1) if not hide_invisible_strips: strips = np.stack([ tracks[strip_t_start:strip_t_end, n], tracks[strip_t_start + 1:strip_t_end + 1, n], ], axis=-2) strips_visibility = visibles[strip_t_start + 1:strip_t_end + 1, n] strips_colors = np.where( strips_visibility[:, None], colors[None, n, :3], [invisible_color], ) else: point_sequence = tracks[strip_t_start:strip_t_end + 1, n] point_sequence_visible = point_sequence[visibles[strip_t_start:strip_t_end + 1, n]] strips = np.stack([point_sequence_visible[:-1], point_sequence_visible[1:]], axis=-2) strips_colors = colors[None, n, :3] rr.log( entity_format_str.format(f"{track_name}/line"), rr.LineStrips3D(strips=strips, colors=strips_colors, radii=strips_radii), ) if log_error_lines: assert gt_for_error_lines is not None strips = np.stack([ tracks[t, n], gt_for_error_lines[t, n], ], axis=-2) rr.log( entity_format_str.format(f"{track_name}/error"), rr.LineStrips3D(strips=strips, colors=error_lines_color, radii=error_lines_radii), ) # Log the tracks trajs_3d_np = trajs_3d.cpu().numpy() visibles_np = visibles.cpu().numpy() query_timestep_np = query_points_3d[:, 0].cpu().numpy().round().astype(int) cmap = matplotlib.colormaps["gist_rainbow"] norm = matplotlib.colors.Normalize(vmin=trajs_3d_np[..., 0].min(), vmax=trajs_3d_np[..., 0].max()) track_color = cmap(norm(trajs_3d_np[-1, :, 0])) # track_color = track_color * 0 + 1 # Just make all tracks white log_tracks( tracks=trajs_3d_np, visibles=visibles_np, query_timestep=query_timestep_np, colors=track_color, entity_format_str=f"{video_name}/tracks/{{}}", max_strip_length_future=0, ) if not stream: rr_rrd_path = os.path.join(self.depth_cache_root, f"rerun_viz__{video_name}.rrd") rr.save(rr_rrd_path) logging.info(f"Saved Rerun recording to: {os.path.abspath(rr_rrd_path)}") V = 1 rgbs = rgbs[None] trajs = trajs[None] trajs_w_z = trajs_w_z[None] trajs_3d = trajs_3d query_points_3d = query_points_3d visibles = visibles[None] depth = depth[None, :, None] feats = None intrs = intrs[None] extrs = extrs[None] assert rgbs.shape == (V, T, 3, H, W) assert depth.shape == (V, T, 1, H, W) assert feats is None assert intrs.shape == (V, T, 3, 3) assert extrs.shape == (V, T, 3, 4) assert trajs.shape == (V, T, N, 2) assert trajs_w_z.shape == (V, T, N, 3) assert visibles.shape == (V, T, N) assert trajs_3d.shape == (T, N, 3) assert query_points_3d.shape == (N, 4) # Normalize the scene to be similar to training scenes rot_x = R.from_euler('x', -90, degrees=True).as_matrix() rot_y = R.from_euler('y', 0, degrees=True).as_matrix() rot_z = R.from_euler('z', 0, degrees=True).as_matrix() rot = rot_z @ rot_y @ rot_x T_rot = 
torch.eye(4) T_rot[:3, :3] = torch.from_numpy(rot) ## V1: GT track-agnostic transformation # scale = 10 # translate_x = 0 # translate_y = -15 # translate_z = 2 # # T_scale_and_translate = torch.tensor([ # [scale, 0.0, 0.0, translate_x], # [0.0, scale, 0.0, translate_y], # [0.0, 0.0, scale, translate_z], # [0.0, 0.0, 0.0, 1.0], # ], dtype=torch.float32) ## V2: GT track-aware transformation # Rotate the 3D GT tracks first trajs_3d_homo = torch.cat([trajs_3d, torch.ones_like(trajs_3d[..., :1])], dim=-1) trajs_3d_rotated = torch.einsum('ij,TNj->TNi', T_rot, trajs_3d_homo)[..., :3] # Mask out non-visible points visible_mask = visibles[0] # (T, N) trajs_3d_visible = trajs_3d_rotated[visible_mask] # (V, 3) # Compute bbox over only visible points bbox_min = trajs_3d_visible.amin(dim=0) bbox_max = trajs_3d_visible.amax(dim=0) bbox_center = (bbox_min + bbox_max) / 2 bbox_size = bbox_max - bbox_min # Target bounds (half-extent of desired cube) target_bounds = torch.tensor([10.0, 10.0, 6.0]) scale = (target_bounds / bbox_size).min().item() translation = -bbox_center * scale rot = torch.from_numpy(rot) # Optional: clamp depth map if needed (max Z-depth defined in scaled space) logging.info(f"[datapoint_idx={index}] Scale={scale:.2f}, Translate={translation.tolist()}") # depth[depth > 50 / scale] = 50 / scale depth[depth > 20] = 20 # Apply to scene ( depth_trans, extrs_trans, query_points_3d_trans, trajs_3d_trans, trajs_w_z_trans ) = transform_scene(scale, rot, translation, depth, extrs, query_points_3d, trajs_3d, trajs_w_z) assert torch.allclose(trajs_w_z[..., :2], trajs_w_z_trans[..., :2]) gotit = True return Datapoint( video=rgbs, videodepth=depth_trans, feats=None, segmentation=torch.ones(T, 1, H, W).float(), trajectory=trajs_w_z_trans, trajectory_3d=trajs_3d_trans, visibility=visibles, valid=torch.ones((T, N)), seq_name=str(video_name), intrs=intrs, extrs=extrs_trans, query_points=query_points_2d, query_points_3d=query_points_3d_trans, ), gotit def __len__(self): return len(self.points_dataset) @torch.no_grad() def zoedepth_nk(rgbs, batch_size=2, device="cuda", cached_file=None): if cached_file is not None and os.path.exists(cached_file): return np.load(cached_file)["depth"] # needs timm==0.6.7, but megasam needs timm==1.0.15 model = torch.hub.load("isl-org/ZoeDepth", "ZoeD_NK", pretrained=True).to(device) model.eval() T, _, H, W = rgbs.shape depth = [] for i in range(0, T, batch_size): rgbs_i = rgbs[i:i + batch_size].to(device) / 255. 
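# ZoeDepth's `infer` expects float RGB normalized to [0, 1] (hence the division by 255 above);
# the clamp below keeps the predicted depth in a reasonable range before it is cached.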
depth_i = model.infer(rgbs_i).clamp(0.01, 65.0).cpu() depth.append(depth_i) depth = torch.cat(depth, dim=0).numpy()[:, 0] if cached_file is not None: np.savez(cached_file, depth=depth) del model torch.cuda.empty_cache() return depth def rigid_registration( p: np.ndarray, q: np.ndarray, w: np.ndarray = None, eps: float = 1e-12 ) -> Tuple[float, np.ndarray, np.ndarray]: from moge.utils.geometry_numpy import weighted_mean_numpy if w is None: w = np.ones(p.shape[0]) centroid_p = weighted_mean_numpy(p, w[:, None], axis=0) centroid_q = weighted_mean_numpy(q, w[:, None], axis=0) p_centered = p - centroid_p q_centered = q - centroid_q w = w / (np.sum(w) + eps) cov = (w[:, None] * p_centered).T @ q_centered U, S, Vh = np.linalg.svd(cov) R = Vh.T @ U.T if np.linalg.det(R) < 0: Vh[2, :] *= -1 R = Vh.T @ U.T scale = np.sum(S) / np.trace((w[:, None] * p_centered).T @ p_centered) t = centroid_q - scale * (centroid_p @ R.T) return scale, R, t def rigid_registration_ransac( p: np.ndarray, q: np.ndarray, w: np.ndarray = None, max_iters: int = 20, hypothetical_size: int = 10, inlier_thresh: float = 0.02 ) -> Tuple[Tuple[float, np.ndarray, np.ndarray], np.ndarray]: n = p.shape[0] if w is None: w = np.ones(p.shape[0]) best_score, best_inlines = 0., np.zeros(n, dtype=bool) best_solution = (np.array(1.), np.eye(3), np.zeros(3)) for _ in range(max_iters): maybe_inliers = np.random.choice(n, size=hypothetical_size, replace=False) try: s, R, t = rigid_registration(p[maybe_inliers], q[maybe_inliers], w[maybe_inliers]) except np.linalg.LinAlgError: continue transformed_p = s * p @ R.T + t errors = w * np.linalg.norm(transformed_p - q, axis=1) inliers = errors < inlier_thresh score = inlier_thresh * n - np.clip(errors, None, inlier_thresh).sum() if score > best_score: best_score, best_inlines = score, inliers best_solution = rigid_registration(p[inliers], q[inliers], w[inliers]) return best_solution, best_inlines def to_homogeneous(x): return np.concatenate([x, np.ones_like(x[..., :1])], axis=-1) def from_homogeneous(x, assert_homogeneous_part_is_equal_to_1=False, eps=0.001): if assert_homogeneous_part_is_equal_to_1: assert np.allclose(x[..., -1:], 1, atol=eps), f"Expected homogeneous part to be 1, got {x[..., -1:]}" return x[..., :-1] / x[..., -1:] def to_homogenous_torch(x): return torch.cat([x, torch.ones_like(x[..., :1])], axis=-1) @torch.no_grad() def moge(rgbs, batch_size=10, device="cuda", cached_file=None, intrinsics=None): if cached_file is not None and os.path.exists(cached_file): cached_data = np.load(cached_file) depths_with_normalized_scale = cached_data["depth"] points_in_world_space = cached_data["points"] w2c = cached_data["w2c"] intrinsics = cached_data["intrinsics"] mask = cached_data["mask"] return depths_with_normalized_scale, intrinsics, w2c, points_in_world_space, mask # git clone https://github.com/microsoft/MoGe.git ../moge # cd ../moge # git checkout dd158c0 sys.path.append("../moge") # TODO: Find a clean way to do this so that it is not hardcoded from moge.model import MoGeModel import utils3d model = MoGeModel.from_pretrained("Ruicheng/moge-vitl").to(device) T, _, H, W = rgbs.shape assert rgbs.shape == (T, 3, H, W) points = [] depth = [] mask = [] for rgb in rgbs: rgb = rgb.to(device) output = model.infer( image=rgb / 255, resolution_level=9, force_projection=True, apply_mask=True, fov_x=np.rad2deg(utils3d.intrinsics_to_fov(intrinsics)[0]) if intrinsics is not None else None, ) points.append(output["points"].cpu().numpy()) depth.append(output["depth"].cpu().numpy()) 
mask.append(output["mask"].cpu().numpy()) if intrinsics is None: intrinsics = output["intrinsics"].cpu().numpy() assert np.allclose(intrinsics, output["intrinsics"].cpu().numpy(), atol=0.01), "Intr. changed between frames" points = np.stack(points) depth = np.stack(depth) mask = np.stack(mask) intrinsics = np.diag([W, H, 1]) @ intrinsics # Assert we can reproduce the points from the depth maps already (should be enforced with force_projection=True) pixel_xy = np.stack(np.meshgrid(np.arange(W), np.arange(H)), axis=-1) pixel_xy_homo = to_homogeneous(pixel_xy) depthmap_camera_xyz = np.einsum('ij,HWj->HWi', np.linalg.inv(intrinsics), pixel_xy_homo) depthmap_camera_xyz = depthmap_camera_xyz[None, :, :, :] * depth[:, :, :, None] valid = mask & (depth > 0) assert np.allclose(points[valid], depthmap_camera_xyz[valid], atol=1, rtol=0.1) depths_with_normalized_scale = depth.copy() points_in_world_space = points.copy() w2c = np.eye(4)[None].repeat(T, axis=0) for t in range(1, T): valid_p = mask[t] & (depth[t] > 0) # & (depth[t] <= 4.20) # TODO: magic number here! valid_q = mask[t - 1] & (depth[t] > 0) # & (depth[t] <= 4.20) # TODO: magic number here! valid = valid_p & valid_q (scale, rotation, translation), inliers = rigid_registration_ransac( p=points[t][valid].reshape(-1, 3), q=points_in_world_space[t - 1][valid].reshape(-1, 3), w=(1 / depths_with_normalized_scale[t - 1][valid]).reshape(-1), max_iters=20, hypothetical_size=10, inlier_thresh=0.02 ) depths_with_normalized_scale[t] = scale * depths_with_normalized_scale[t] # Transforming points[t] -> points_in_world_space[t - 1] already tells us how to transform to the # world space since points_in_world_space[t - 1] had already been transformed to the world space points_in_world_space[t] = scale * points_in_world_space[t] @ rotation.T + translation # I prefer to use column vectors: Q = q.T, P = p.T # q = p @ R.T + t -> Q = R @ P + t.T # p = q @ R - t @ rotation -> P = R.T @ Q - R.T @ t.T w2c[t, :3, :3] = rotation.T w2c[t, :3, 3] = -rotation.T @ translation.T # Assert no nans assert not np.isnan(depths_with_normalized_scale).any() assert not np.isnan(w2c).any() assert np.allclose(w2c[:, 3, 3], 1) # Now let's make sure we can go from scale-normalized depth maps to the points in world space # Pixel --> Camera --> World pixel_xy = np.stack(np.meshgrid(np.arange(W), np.arange(H)), axis=-1) pixel_xy_homo = to_homogeneous(pixel_xy) depthmap_camera_xyz = np.einsum('ij,HWj->HWi', np.linalg.inv(intrinsics), pixel_xy_homo) depthmap_camera_xyz = depthmap_camera_xyz[None, :, :, :] * depths_with_normalized_scale[:, :, :, None] depthmap_camera_xyz_homo = to_homogeneous(depthmap_camera_xyz) depthmap_world_xyz_homo = np.einsum('Tij,THWj->THWi', np.linalg.inv(w2c), depthmap_camera_xyz_homo) depthmap_world_xyz = from_homogeneous(depthmap_world_xyz_homo) points_in_world_space_reproduced = depthmap_world_xyz valid = mask & (depths_with_normalized_scale > 0) assert np.allclose(points_in_world_space[valid], points_in_world_space_reproduced[valid], atol=0.1, rtol=0.1) if cached_file is not None: np.savez( cached_file, depth=depths_with_normalized_scale, points=points_in_world_space, w2c=w2c, intrinsics=intrinsics, mask=mask, ) return depths_with_normalized_scale, intrinsics, w2c, points_in_world_space, mask def megasam(rgbs: torch.Tensor, batch_size: int = 10, device: str = "cuda", cached_file: Optional[str] = None): if cached_file is not None and os.path.exists(cached_file): cached_data = np.load(cached_file) return ( cached_data["depths"].astype(np.float32), 
cached_data["intrinsics"].astype(np.float32), cached_data["extrinsics"].astype(np.float32), ) # else: # raise NotImplementedError("TMP ERR") T, C, H, W = rgbs.shape assert C == 3, "Expected shape (T, 3, H, W)" # Convert to NumPy format for MegaSAM (T, H, W, 3), uint8 [0, 255] rgbs_np = (rgbs.permute(0, 2, 3, 1).cpu().numpy()).astype(np.uint8) # git clone https://github.com/zbw001/TAPIP3D.git ../tapip3d # cd ../tapip3d # git checkout 8871375 sys.path.append("../tapip3d") from annotation.megasam import MegaSAMAnnotator megasam = MegaSAMAnnotator( script_path=Path("../tapip3d") / "third_party" / "megasam" / "inference.py", depth_model="moge", resolution=H * W ) megasam.to(device) depths, intrinsics, extrinsics = megasam.process_video( rgbs=rgbs_np, gt_intrinsics=None, return_raw_depths=False, ) if cached_file is not None: np.savez(cached_file, depths=depths, intrinsics=intrinsics, extrinsics=extrinsics) return depths, intrinsics, extrinsics ================================================ FILE: mvtracker/datasets/utils.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. import dataclasses import json import pathlib from dataclasses import dataclass from typing import Any, Optional, List import numpy as np import png import torch from torch.nn import functional as F from torchvision.transforms import functional as TF from mvtracker.utils.basic import to_homogeneous, from_homogeneous @dataclass(eq=False) class Datapoint: """ Dataclass for storing video tracks data. """ video: torch.Tensor # B, S, C, H, W segmentation: torch.Tensor # B, S, 1, H, W # optional data videodepth: Optional[torch.Tensor] = None # B, S, 1, H, W videodepthconf: Optional[torch.Tensor] = None # B, S, 1, H, W feats: Optional[torch.Tensor] = None # B, S, C, H_strided, W_strided valid: Optional[torch.Tensor] = None # B, S, N seq_name: Optional[List[str]] = None # B intrs: Optional[torch.Tensor] = torch.eye(3).unsqueeze(0) # B, 3, 3 query_points: Optional[torch.Tensor] = None # TapVID evaluation format query_points_3d: Optional[torch.Tensor] = None # TapVID evaluation format trajectory: Optional[torch.Tensor] = None # B, S, N, 2 visibility: Optional[torch.Tensor] = None # B, S, N trajectory_3d: Optional[torch.Tensor] = None # B, S, 4, 4 trajectory_category: Optional[torch.Tensor] = None # B, S, 1 extrs: Optional[torch.Tensor] = None # B, S, 4, 4 track_upscaling_factor: Optional[float] = 1.0 novel_video: Optional[torch.Tensor] = None # B, S, C, H, W novel_intrs: Optional[torch.Tensor] = torch.eye(3).unsqueeze(0) # B, 3, 3 novel_extrs: Optional[torch.Tensor] = None # B, S, 4, 4 def collate_fn(batch): gotit = [gotit for _, gotit in batch] video = torch.stack([b.video for b, _ in batch], dim=0) videodepth = torch.stack([b.videodepth for b, _ in batch], dim=0) segmentation = torch.stack([b.segmentation for b, _ in batch], dim=0) seq_name = [b.seq_name for b, _ in batch] intrs = torch.stack([b.intrs for b, _ in batch], dim=0) videodepthconf = ( torch.stack([b.videodepthconf for b, _ in batch], dim=0) if batch[0][0].videodepthconf is not None else None ) feats = ( torch.stack([b.feats for b, _ in batch], dim=0) if batch[0][0].feats is not None else None ) trajectory = ( torch.stack([b.trajectory for b, _ in batch], dim=0) if batch[0][0].trajectory is not None else None ) valid = ( torch.stack([b.valid for b, _ in batch], dim=0) if 
batch[0][0].valid is not None else None ) visibility = ( torch.stack([b.visibility for b, _ in batch], dim=0) if batch[0][0].visibility is not None else None ) trajectory_3d = ( torch.stack([b.trajectory_3d for b, _ in batch], dim=0) if batch[0][0].trajectory_3d is not None else None ) extrs = ( torch.stack([b.extrs for b, _ in batch], dim=0) if batch[0][0].extrs is not None else None ) query_points = ( torch.stack([b.query_points for b, _ in batch], dim=0) if batch[0][0].query_points is not None else None ) query_points_3d = ( torch.stack([b.query_points_3d for b, _ in batch], dim=0) if batch[0][0].query_points_3d is not None else None ) track_upscaling_factor = batch[0][0].track_upscaling_factor novel_video = None novel_intrs = None novel_extrs = None if batch[0][0].novel_video is not None: novel_video = torch.stack([b.novel_video for b, _ in batch], dim=0) novel_intrs = torch.stack([b.novel_intrs for b, _ in batch], dim=0) novel_extrs = torch.stack([b.novel_extrs for b, _ in batch], dim=0) return ( Datapoint( video=video, videodepth=videodepth, videodepthconf=videodepthconf, feats=feats, segmentation=segmentation, trajectory=trajectory, trajectory_3d=trajectory_3d, visibility=visibility, valid=valid, seq_name=seq_name, intrs=intrs, extrs=extrs, query_points=query_points, query_points_3d=query_points_3d, track_upscaling_factor=track_upscaling_factor, novel_video=novel_video, novel_intrs=novel_intrs, novel_extrs=novel_extrs ), gotit, ) def try_to_cuda(t: Any) -> Any: """ Try to move the input variable `t` to a cuda device. Args: t: Input. Returns: t_cuda: `t` moved to a cuda device, if supported. """ try: t = t.float().cuda() except AttributeError: pass return t def dataclass_to_cuda_(obj): """ Move all contents of a dataclass to cuda inplace if supported. Args: batch: Input dataclass. Returns: batch_cuda: `batch` moved to a cuda device, if supported. """ for f in dataclasses.fields(obj): setattr(obj, f.name, try_to_cuda(getattr(obj, f.name))) return obj def read_json(filename: str) -> Any: with open(filename, "r") as fp: return json.load(fp) def read_tiff(filename: str) -> np.ndarray: import imageio img = imageio.v2.imread(pathlib.Path(filename).read_bytes(), format="tiff") if img.ndim == 2: img = img[:, :, None] return img def read_png(filename: str, rescale_range=None) -> np.ndarray: png_reader = png.Reader(bytes=pathlib.Path(filename).read_bytes()) width, height, pngdata, info = png_reader.read() del png_reader bitdepth = info["bitdepth"] if bitdepth == 8: dtype = np.uint8 elif bitdepth == 16: dtype = np.uint16 else: raise NotImplementedError(f"Unsupported bitdepth: {bitdepth}") plane_count = info["planes"] pngdata = np.vstack(list(map(dtype, pngdata))) if rescale_range is not None: minv, maxv = rescale_range pngdata = pngdata / 2 ** bitdepth * (maxv - minv) + minv return pngdata.reshape((height, width, plane_count)) def transform_scene( transformation_scale: float = 1.0, transformation_rotation: torch.Tensor = torch.eye(3, dtype=torch.float32), transformation_translation: torch.Tensor = torch.tensor([0.0, 0.0, 0.0], dtype=torch.float32), depth: torch.Tensor = None, # [V,T,1,H,W] extrs: torch.Tensor = None, # [V,T,3,4] world->cam query_points: torch.Tensor = None, # [N,4] (t, x, y, z) in world traj3d_world: torch.Tensor = None, # [T,N,3] traj2d_w_z: torch.Tensor = None, # [V,T,N,3] (x_px, y_px, z_cam) ): """ Make the world space `transformation_scale` larger, then rotate it by `transformation_rotation`, then translate it by `transformation_translation`. 
In other words, apply the following transformation: X_world' = transformation_translation + transformation_rotation @ (transformation_scale * X_world). Implemented as: - depth (z_cam) *= scale - extrinsics: scale translation by 'scale', then right-multiply by rigid inverse - query/world trajectories: scale then rigid - traj2d_w_z: only z scaled; (x,y) unchanged """ is_rot_orthonormal = torch.allclose( transformation_rotation @ transformation_rotation.T, torch.eye(3, dtype=transformation_rotation.dtype, device=transformation_rotation.device), atol=1e-3, ) assert is_rot_orthonormal, "The rotation matrix should be orthonormal." Rt = torch.eye(4, dtype=transformation_rotation.dtype, device=transformation_rotation.device) Rt[:3, :3] = transformation_rotation Rt[:3, 3] = transformation_translation # Transform depth if depth is not None: depth_trans = depth * transformation_scale else: depth_trans = None # Transform extrinsics if extrs is not None: n_views, n_frames, _, _ = extrs.shape assert extrs.shape == (n_views, n_frames, 3, 4) src_dtype = extrs.dtype extrs = extrs.type(Rt.dtype) extrs_trans_square = torch.eye(4, dtype=extrs.dtype, device=extrs.device).repeat(n_views, n_frames, 1, 1) extrs_trans_square[:, :, :3, :3] = extrs[:, :, :3, :3] extrs_trans_square[:, :, :3, 3] = extrs[:, :, :3, 3] * transformation_scale extrs_trans_square = torch.einsum('ABki,ij->ABkj', extrs_trans_square, torch.inverse(Rt)) extrs_trans = extrs_trans_square[..., :3, :] extrs_trans = extrs_trans.type(src_dtype) else: extrs_trans = None # Transform query points if query_points is not None: n_tracks = query_points.shape[0] assert query_points.shape == (n_tracks, 4) src_dtype = query_points.dtype query_points = query_points.type(Rt.dtype) query_points_xyz_scaled_homo = to_homogeneous(query_points[..., 1:4] * transformation_scale) query_points_xyz_trans_homo = torch.einsum('ij,Nj->Ni', Rt, query_points_xyz_scaled_homo) query_points_xyz_trans = from_homogeneous(query_points_xyz_trans_homo) query_points_trans = torch.cat([query_points[..., :1], query_points_xyz_trans], dim=-1) query_points_trans = query_points_trans.type(src_dtype) else: query_points_trans = None # Transform 3D trajectories if traj3d_world is not None: n_frames, n_tracks, _ = traj3d_world.shape assert traj3d_world.shape == (n_frames, n_tracks, 3) src_dtype = traj3d_world.dtype traj3d_world = traj3d_world.type(Rt.dtype) traj3d_world_scaled_homo = to_homogeneous(traj3d_world * transformation_scale) traj3d_world_trans_homo = torch.einsum('ij,TNj->TNi', Rt, traj3d_world_scaled_homo) traj3d_world_trans = from_homogeneous(traj3d_world_trans_homo) traj3d_world_trans = traj3d_world_trans.type(src_dtype) else: traj3d_world_trans = None # Transform 2D+depth trajectories if traj2d_w_z is not None: n_views, n_frames, n_tracks, _ = traj2d_w_z.shape assert traj2d_w_z.shape == (n_views, n_frames, n_tracks, 3) traj2d_w_z_trans = traj2d_w_z.clone() traj2d_w_z_trans[:, :, :, 2] *= transformation_scale else: traj2d_w_z_trans = None return depth_trans, extrs_trans, query_points_trans, traj3d_world_trans, traj2d_w_z_trans def add_camera_noise(intrs, extrs, noise_std_intr=0.01, noise_std_extr=0.001, rnd=np.random): """ Add small Gaussian noise to intrinsic and extrinsic camera parameters. Args: intrs (np.ndarray): (V, T, 3, 3) intrinsic matrices. extrs (np.ndarray): (V, T, 3, 4) extrinsic matrices. noise_std_intr (float): Standard deviation of intrinsic matrix noise. noise_std_extr (float): Standard deviation of extrinsic matrix noise. 
rnd (module): Random number generator (e.g., np.random or torch). Returns: intrs (same type as input): Noisy intrinsic matrices. extrs (same type as input): Noisy extrinsic matrices. """ V, T, _, _ = intrs.shape assert isinstance(intrs, np.ndarray) assert intrs.shape == (V, T, 3, 3) assert extrs.shape == (V, T, 3, 4) intrs, extrs = intrs.copy(), extrs.copy() intrs += rnd.normal(0, noise_std_intr, size=intrs.shape) extrs += rnd.normal(0, noise_std_extr, size=extrs.shape) return intrs, extrs def aug_depth(depth, grid=(8, 8), scale=(0.7, 1.3), shift=(-0.1, 0.1), gn_kernel=(7, 7), gn_sigma=(2.0, 2.0), generator=None): """ Augment depth for training. """ B, T, H, W = depth.shape msk = (depth != 0) # fallback to global generator if none is provided gen = generator if generator is not None else torch.default_generator # generate scale and shift maps H_s, W_s = grid scale_map = (torch.rand(B, T, H_s, W_s, device=depth.device, generator=gen) * (scale[1] - scale[0]) + scale[0]) shift_map = (torch.rand(B, T, H_s, W_s, device=depth.device, generator=gen) * (shift[1] - shift[0]) + shift[0]) # scale and shift the depth map scale_map = F.interpolate(scale_map, (H, W), mode='bilinear', align_corners=True) shift_map = F.interpolate(shift_map, (H, W), mode='bilinear', align_corners=True) # local scale and shift the depth depth[msk] = (depth[msk] * scale_map[msk]) + shift_map[msk] * (depth[msk].mean()) # gaussian blur depth = TF.gaussian_blur(depth, kernel_size=gn_kernel, sigma=gn_sigma) depth[~msk] = 0 return depth def align_umeyama(model, data, known_scale=False, yaw_only=False): mu_M = model.mean(0) mu_D = data.mean(0) model_zerocentered = model - mu_M data_zerocentered = data - mu_D n = np.shape(model)[0] # correlation C = 1.0 / n * np.dot(model_zerocentered.transpose(), data_zerocentered) sigma2 = 1.0 / n * np.multiply(data_zerocentered, data_zerocentered).sum() U_svd, D_svd, V_svd = np.linalg.linalg.svd(C) D_svd = np.diag(D_svd) V_svd = np.transpose(V_svd) S = np.eye(3) if np.linalg.det(U_svd) * np.linalg.det(V_svd) < 0: S[2, 2] = -1 if yaw_only: rot_C = np.dot(data_zerocentered.transpose(), model_zerocentered) theta = get_best_yaw(rot_C) R = rot_z(theta) else: R = np.dot(U_svd, np.dot(S, np.transpose(V_svd))) if known_scale: s = 1 else: s = 1.0 / sigma2 * np.trace(np.dot(D_svd, S)) t = mu_M - s * np.dot(R, mu_D) return s, R, t def get_camera_center(extr): R = extr[:, :3] t = extr[:, 3] return -R.T @ t def apply_sim3_to_extrinsics(vggt_extrinsics, s, R_align, t_align): aligned_extrinsics = [] R_inv = R_align.T t_inv = -R_inv @ t_align / s for extr in vggt_extrinsics: extr_h = np.eye(4) extr_h[:3, :4] = extr sim3_inv = np.eye(4) sim3_inv[:3, :3] = R_inv / s sim3_inv[:3, 3] = t_inv aligned = extr_h @ sim3_inv aligned_extrinsics.append(aligned[:3, :]) return aligned_extrinsics def get_best_yaw(C): """ maximize trace(Rz(theta) * C) """ assert C.shape == (3, 3) A = C[0, 1] - C[1, 0] B = C[0, 0] + C[1, 1] theta = np.pi / 2 - np.arctan2(B, A) return theta def rot_z(theta): R = rotation_matrix(theta, [0, 0, 1]) R = R[0:3, 0:3] return R ================================================ FILE: mvtracker/evaluation/__init__.py ================================================ ================================================ FILE: mvtracker/evaluation/evaluator_3dpt.py ================================================ import json import logging import os import re import time import warnings from collections import namedtuple from typing import Iterable from typing import Optional import imageio import matplotlib.cm 
as cm import numpy as np import rerun as rr import torch from sklearn.cluster import KMeans from threadpoolctl import threadpool_limits from torch.utils.tensorboard import SummaryWriter from tqdm import tqdm from mvtracker.datasets.utils import dataclass_to_cuda_ from mvtracker.evaluation.metrics import compute_tapvid_metrics_original, evaluate_predictions from mvtracker.models.core.model_utils import world_space_to_pixel_xy_and_camera_z, \ pixel_xy_and_camera_z_to_world_space, init_pointcloud_from_rgbd from mvtracker.utils.visualizer_mp4 import log_mp4_track_viz from mvtracker.utils.visualizer_rerun import log_pointclouds_to_rerun, log_tracks_to_rerun class NumpyEncoder(json.JSONEncoder): def default(self, obj): if isinstance(obj, np.ndarray): if obj.size == 1: return obj.item() return obj.tolist() if isinstance(obj, np.integer): return int(obj) if isinstance(obj, np.floating): return float(obj) return json.JSONEncoder.default(self, obj) def kmeans_sample(pts, count): """ Given (N, 3) torch tensor of 3D points, return (count, 3) tensor of kmeans centers. """ if len(pts) <= count: return pts logging.info(f"Computing k-means (k={count}, N={len(pts)})...") start = time.time() with threadpool_limits(limits=1): pts_np = pts.detach().cpu().numpy() kmeans = KMeans(n_clusters=count, n_init='auto', random_state=0).fit(pts_np) duration = time.time() - start logging.info(f"K-means clustering completed in {duration:.2f} seconds.") centers = torch.tensor(kmeans.cluster_centers_, dtype=pts.dtype, device=pts.device) return centers def evaluate_3dpt( gt_tracks, gt_visibilities, pred_tracks, pred_visibilities, evaluation_setting, track_upscaling_factor, query_points=None, prefix="3dpt", verbose=True, add_per_track_results=True, ): n_frames, n_tracks, n_point_dim = gt_tracks.shape assert gt_tracks.shape == pred_tracks.shape assert gt_visibilities.shape == (n_frames, n_tracks) assert pred_visibilities.shape == (n_frames, n_tracks) if query_points is None: query_points_frame_id = gt_visibilities.argmax(axis=0) query_points_xyz = gt_tracks[query_points_frame_id, np.arange(gt_tracks.shape[1]), :] query_points = np.concatenate([query_points_frame_id[:, None], query_points_xyz], axis=-1) else: query_points_frame_id = query_points[:, 0].astype(int) query_points_xyz = query_points[:, 1:] if evaluation_setting == "kubric-multiview": assert n_point_dim == 3 distance_thresholds = [0.05, 0.1, 0.2, 0.4, 0.8] # The scale is non-metric survival_distance_threshold = 0.5 # 50 cm static_threshold = 0.01 # < 1 cm dynamic_threshold = 0.1 # > 10 cm very_dynamic_threshold = 2.0 # > 2 m elif evaluation_setting == "dexycb-multiview": assert n_point_dim == 3 distance_thresholds = [0.01, 0.02, 0.05, 0.1, 0.2] # 1 cm, 2 cm, 5 cm, 10 cm, 20 cm survival_distance_threshold = 0.1 # 10 cm static_threshold = 0.01 # < 1 cm dynamic_threshold = 0.1 # > 10 cm very_dynamic_threshold = 0.5 # > 50 cm elif evaluation_setting == "panoptic-multiview": assert n_point_dim == 3 distance_thresholds = [0.05, 0.10, 0.20, 0.40] # from 5 cm to 80 cm survival_distance_threshold = 1.0 # 1 m static_threshold = None dynamic_threshold = None very_dynamic_threshold = None elif evaluation_setting == "tapvid2d": assert n_point_dim == 2 distance_thresholds = [1, 2, 4, 8, 16] # pixels survival_distance_threshold = 50 static_threshold = None dynamic_threshold = None very_dynamic_threshold = None elif evaluation_setting == "2dpt_ablation": assert n_point_dim == 2 distance_thresholds = [1, 2, 4, 8, 16] # pixels survival_distance_threshold = 50 static_threshold = 1 
dynamic_threshold = 1 very_dynamic_threshold = 50 else: raise NotImplementedError if verbose: logging.info(f"n_frames: {n_frames}, n_tracks: {n_tracks}") logging.info(f"GT TRACKS (min, max): {gt_tracks.min()}, {gt_tracks.max()}") logging.info(f"query_poits_xyz (min, max): {query_points_xyz.min()}, {query_points_xyz.max()}") df_model, df_model_per_track = evaluate_predictions( gt_tracks * track_upscaling_factor, gt_visibilities, pred_tracks * track_upscaling_factor, ~pred_visibilities, np.concatenate([query_points[:, 0:1], query_points[:, 1:] * track_upscaling_factor], axis=-1), distance_thresholds=distance_thresholds, survival_distance_threshold=survival_distance_threshold, static_threshold=static_threshold, dynamic_threshold=dynamic_threshold, very_dynamic_threshold=very_dynamic_threshold, ) if verbose: logging.info(f"DF Model:\n{df_model}") logging.info(f"DF Model:\n{df_model.loc[['average_pts_within_thresh', 'survival']]}") # Save to results_dict results_dict = {} # For dynamic points, report all metrics for point_type in ["dynamic-static-mean", "dynamic", "very_dynamic", "static", "any"]: if f'all_{point_type}' not in df_model.columns: continue for metric in sorted(df_model.index): results_dict[f'{prefix}/model__{metric}__{point_type}'] = df_model.loc[metric, f'all_{point_type}'] # For other point types, report only selected metrics for point_type in []: if f'all_{point_type}' not in df_model.columns: continue for metric in ["average_pts_within_thresh", "survival", "occlusion_accuracy", "average_jaccard"]: results_dict[f'{prefix}/model__{metric}__{point_type}'] = df_model.loc[metric, f'all_{point_type}'] for k in results_dict: results_dict[k] = results_dict[k].item() if verbose: logging.info(f"3DPT results:\n{results_dict}") if add_per_track_results: results_dict[f'{prefix}/model__per_track_results'] = df_model_per_track return results_dict class Evaluator: def __init__( self, rerun_viz_indices: Optional[Iterable[int]] = None, forward_pass_log_indices: Optional[Iterable[int]] = None, mp4_track_viz_indices: Optional[Iterable[int]] = (0, 3, 4, 5), ) -> None: """ Initializes the Evaluator. Parameters ---------- rerun_viz_indices : Optional[Iterable[int]] Indices of datapoints for which rerun 3D visualizations should be saved. If None, no rerun visualizations will be logged. forward_pass_log_indices : Optional[Iterable[int]] Indices of datapoints for which debug logs from the model's forward pass should be saved. If None, no forward pass debug logs will be generated. mp4_track_viz_indices : Optional[Iterable[int]] Indices of datapoints for which 2D trajectory visualizations (MP4 videos) should be saved. If None, MP4 visualizations will not be generated. 
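Example (illustrative only; the argument values are made up):
    evaluator = Evaluator(rerun_viz_indices=[0], mp4_track_viz_indices=[0, 3])
    metrics = evaluator.evaluate_sequence(
        model, test_dataloader, dataset_name="kubric-multiview", log_dir="logs/eval",
    )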
""" self.rerun_viz_indices = rerun_viz_indices self.forward_pass_log_indices = forward_pass_log_indices self.mp4_track_viz_indices = mp4_track_viz_indices if self.rerun_viz_indices is None: self.rerun_viz_indices = [] if self.forward_pass_log_indices is None: self.forward_pass_log_indices = [] if self.mp4_track_viz_indices is None: self.mp4_track_viz_indices = [] @torch.no_grad() def evaluate_sequence( self, model, test_dataloader, dataset_name, log_dir, writer: Optional[SummaryWriter] = None, step: Optional[int] = 0, ): metrics = {} assert len(test_dataloader) > 0 total_fps = 0.0 count = 0 for datapoint_idx, datapoint in enumerate(tqdm(test_dataloader)): should_save_mp4_viz = datapoint_idx in self.mp4_track_viz_indices should_save_forward_pass_logs = datapoint_idx in self.forward_pass_log_indices should_save_rerun_viz = datapoint_idx in self.rerun_viz_indices # Hotfix for debugging: Load an edge-case datapoint directly from disk if False: # Batch 10060 datapoint = torch.load("logs/debug/ablation-E07/mvtracker-ptv3-512/crash_batch_step_010060.pt", map_location="cuda:0") # (datapoint.videodepth > 0).float().mean() --> 0 # Batch 8145 datapoint = torch.load("logs/ablation-E07/mvtracker-ptv3-512-2/crash_batch_step_008145.pt", map_location="cuda:0") datapoint.videodepth = datapoint.videodepth.clip(0.0, 1000.0) should_save_mp4_viz = True should_save_rerun_viz = True should_save_forward_pass_logs = False model.model.use_ptv3 = False if isinstance(datapoint, tuple) or isinstance(datapoint, list) and len(datapoint) == 2: datapoint, gotit = datapoint if not all(gotit): logging.warning("batch is None") continue if torch.cuda.is_available(): dataclass_to_cuda_(datapoint) device = torch.device("cuda") else: device = torch.device("cpu") # Per view data rgbs = datapoint.video depths = datapoint.videodepth depths_conf = datapoint.videodepthconf image_features = datapoint.feats intrs = datapoint.intrs extrs = datapoint.extrs gt_trajectories_2d_pixelspace_w_z_cameraspace = datapoint.trajectory gt_visibilities_per_view = datapoint.visibility query_points_2d = (datapoint.query_points.clone().float().to(device) if datapoint.query_points is not None else None) query_points_3d = (datapoint.query_points_3d.clone().float().to(device) if datapoint.query_points_3d is not None else None) # Non-per-view data gt_trajectories_3d_worldspace = datapoint.trajectory_3d valid_tracks_per_frame = datapoint.valid track_upscaling_factor = datapoint.track_upscaling_factor seq_name = datapoint.seq_name[0] # Novel view data novel_rgbs = datapoint.novel_video novel_intrs = datapoint.novel_intrs novel_extrs = datapoint.novel_extrs batch_size, num_views, num_frames, _, height, width = rgbs.shape # For generic datasets without labels, we will try sampling queries from depthmap points and around origin no_tracking_labels = False if query_points_2d is None and query_points_3d is None: no_tracking_labels = True assert batch_size == 1 assert gt_trajectories_2d_pixelspace_w_z_cameraspace is None assert gt_visibilities_per_view is None assert gt_trajectories_3d_worldspace is None assert valid_tracks_per_frame is None assert depths is not None assert depths_conf is not None # Config: (frame_idx, z_min, z_max, count) if "selfcap" in dataset_name: sampling_spec = [ (0, -0.1, 0.2, 1.8, 100, ""), (0, 0.2, 2.1, 1.8, 200, ""), # (0, 0.2, 2.1, 1.8, 200, "kmeans"), (36, 0.2, 2.1, 1.8, 200, ""), (120, 0.2, 2.1, 1.8, 200, ""), ] x0, y0, zmin, zmax, radius = 0.25, 0.7, -0.15, 3.6, 1.8 xyz, _ = init_pointcloud_from_rgbd( fmaps=depths_conf, depths=depths, 
intrs=intrs, extrs=extrs, stride=1, level=0, depth_interp_mode="N/A", ) x, y, z = xyz[..., 0], xyz[..., 1], xyz[..., 2] x -= x0 y -= y0 mask = (x ** 2 + y ** 2 < radius ** 2) & (z >= zmin) & (z <= zmax) mask = mask.reshape(batch_size, num_frames, num_views, height, width).permute(0, 2, 1, 3, 4) # depths[~mask[:, :, :, None, :, :]] = 0.0 # depths[depths_conf < 5] = 0.0 depths_conf[~mask[:, :, :, None, :, :]] = 2.0 elif "4d-dress" in dataset_name: sampling_spec = [ # (0, -10, +10, 10, 1500, ""), # (0, -10, +10, 10, 500, ""), (0, -10, +10, 10, 300, "kmeans"), # (72, -10, +10, 10, 500, "kmeans"), ] elif "hi4d" in dataset_name: sampling_spec = [ (0, -np.inf, +np.inf, np.inf, 1000, ""), ] else: sampling_spec = [ (0, -0.1, +4.2, 2.1, 1000, "kmeans"), ] depth_conf_threshold = 0.9 query_list = [] for t, zmin, zmax, radius, count, method in sampling_spec: if t >= num_frames: continue # skip invalid timestep dmap = depths[:, :, t:t + 1] conf = depths_conf[:, :, t:t + 1] xyz, c = init_pointcloud_from_rgbd( fmaps=conf, depths=dmap, intrs=intrs[:, :, t:t + 1], extrs=extrs[:, :, t:t + 1], stride=1, level=0, depth_interp_mode="N/A", ) xyz = xyz[0] # (N, 3) conf = c[0, :, 0] # (N,) valid = conf > depth_conf_threshold pts = xyz[valid] if pts.numel() == 0: continue x, y, z = pts[:, 0], pts[:, 1], pts[:, 2] mask = (x ** 2 + y ** 2 < radius ** 2) & (z >= zmin) & (z <= zmax) pts = pts[mask] if pts.numel() == 0: continue if len(pts) >= count: if method == "": pts = pts[torch.randperm(len(pts))[:count]] elif method == "kmeans": pts = kmeans_sample(pts, count) else: raise NotImplementedError t_col = torch.full((len(pts), 1), float(t), device=pts.device) query_list.append(torch.cat([t_col, pts], dim=1)) # Finalize query points query_points_3d = torch.cat(query_list, dim=0)[None] # (1, N, 4) # Dummy GT trajectory num_points = query_points_3d.shape[1] gt_trajectories_3d_worldspace = query_points_3d[:, None, :, 1:].repeat(1, num_frames, 1, 1) gt_trajectories_2d_pixelspace_w_z_cameraspace = torch.stack([ torch.cat(world_space_to_pixel_xy_and_camera_z( world_xyz=gt_trajectories_3d_worldspace[0], intrs=intrs[0, view_idx], extrs=extrs[0, view_idx], ), dim=-1) for view_idx in range(num_views) ], dim=0).unsqueeze(0) d = query_points_3d.device gt_visibilities_per_view = torch.ones((batch_size, num_views, num_frames, num_points), dtype=bool).to(d) valid_tracks_per_frame = torch.ones((batch_size, num_frames, num_points), dtype=bool).to(d) if no_tracking_labels and not any([should_save_mp4_viz, should_save_rerun_viz, should_save_forward_pass_logs]): continue # Assert shapes of per-view data num_points = gt_trajectories_2d_pixelspace_w_z_cameraspace.shape[3] assert depths is not None, "Depth is required for evaluation." 
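# Descriptive note: the asserts below spell out the tensor layout used throughout evaluation:
# images/depths are (batch, views, frames, channels, height, width), intrinsics are
# (batch, views, frames, 3, 3), and world-to-camera extrinsics are (batch, views, frames, 3, 4).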
assert rgbs.shape == (batch_size, num_views, num_frames, 3, height, width) assert depths.shape == (batch_size, num_views, num_frames, 1, height, width) assert depths_conf is None or depths_conf.shape == (batch_size, num_views, num_frames, 1, height, width) assert intrs.shape == (batch_size, num_views, num_frames, 3, 3) assert extrs.shape == (batch_size, num_views, num_frames, 3, 4) assert gt_trajectories_2d_pixelspace_w_z_cameraspace.shape == ( batch_size, num_views, num_frames, num_points, 3) assert gt_visibilities_per_view.shape == (batch_size, num_views, num_frames, num_points) # Assert shapes of non-per-view data assert query_points_3d.shape == (batch_size, num_points, 4) assert gt_trajectories_3d_worldspace.shape == (batch_size, num_frames, num_points, 3) assert valid_tracks_per_frame.shape == (batch_size, num_frames, num_points) # Dump the RGBs and depths to disk if should_save_rerun_viz: for v in range(num_views): rgb_path = os.path.join(log_dir, f"rgbs__{dataset_name}--seq-{datapoint_idx}__view-{v}.mp4") depth_path = os.path.join(log_dir, f"depths__{dataset_name}--seq-{datapoint_idx}__view-{v}.mp4") conf_path = os.path.join(log_dir, f"depth_confs__{dataset_name}--seq-{datapoint_idx}__view-{v}.mp4") # Precompute global min/max d_all = depths[0, v, :, 0].reshape(-1, height, width).cpu().numpy() d_min, d_max = d_all.min(), d_all.max() if depths_conf is not None: c_all = depths_conf[0, v, :, 0].reshape(-1, height, width).cpu().numpy() c_min, c_max = c_all.min(), c_all.max() # Colormaps depth_cmap = cm.get_cmap("turbo") conf_cmap = cm.get_cmap("inferno") rgb_video, depth_video, conf_video = [], [], [] for t in range(num_frames): rgb = (rgbs[0, v, t].permute(1, 2, 0).cpu().numpy()).astype(np.uint8) rgb_video.append(rgb) d = depths[0, v, t, 0].cpu().numpy() d_norm = (d - d_min) / (d_max - d_min + 1e-5) depth_color = (depth_cmap(d_norm)[..., :3] * 255).astype(np.uint8) depth_video.append(depth_color) if depths_conf is not None: c = depths_conf[0, v, t, 0].cpu().numpy() c_norm = (c - c_min) / (c_max - c_min + 1e-5) conf_color = (conf_cmap(c_norm)[..., :3] * 255).astype(np.uint8) conf_video.append(conf_color) if "selfcap-v1" in dataset_name: fps = 12 elif "4d-dress" in dataset_name or "egoexo4d" in dataset_name: fps = 30 else: fps = 12 imageio.mimsave(rgb_path, rgb_video, fps=fps) imageio.mimsave(depth_path, depth_video, fps=fps) if depths_conf is not None: imageio.mimsave(conf_path, conf_video, fps=fps) # Run the model fwd_kwargs = { "rgbs": rgbs, "depths": depths, "image_features": image_features, "query_points_3d": query_points_3d, "intrs": intrs, "extrs": extrs, "save_debug_logs": should_save_forward_pass_logs, "debug_logs_path": os.path.join( log_dir, f"forward_pass__eval_{dataset_name}_step-{step}_seq-{datapoint_idx}", ), "save_rerun_logs": should_save_rerun_viz, "save_rerun_logs_output_rrd_path": os.path.join( log_dir, f"rerun__{dataset_name}--seq-{datapoint_idx}--name-{seq_name}--fwd.rrd" ), } if "2dpt" in dataset_name: assert batch_size == 1 query_timestep = query_points_3d[0, :, 0].cpu().numpy().astype(int) query_points_view = gt_visibilities_per_view.argmax(dim=1)[0, query_timestep, torch.arange(num_points)] fwd_kwargs["query_points_view"] = query_points_view[None] start_time = time.time() if "shape_of_motion" in log_dir or "dynamic_3dgs" in log_dir: if "dynamic_3dgs" in log_dir: cached_output_path = os.path.join(log_dir, f"step-0_seq-{seq_name}_tracks.npz") else: cached_output_path = os.path.join(log_dir, f"step-{step}_seq-{seq_name}_tracks.npz") cached_output_path = 
re.sub(r"-novelviews\d+(_\d+)*", "", cached_output_path) assert os.path.exists(cached_output_path), cached_output_path cached_data = np.load(cached_output_path) if "dynamic_3dgs" in log_dir: results = { "traj_e": torch.from_numpy(cached_data["pred_trajectories_3d"]).to(device)[None], "vis_e": torch.from_numpy(cached_data["pred_visibilities_any_view"]).to(device).any(1), } else: results = { "traj_e": torch.from_numpy(cached_data["pred_trajectories_3d"]).to(device), "vis_e": torch.from_numpy(cached_data["pred_visibilities_any_view"]).to(device), } else: results = model(**fwd_kwargs) end_time = time.time() frames_processed = batch_size * num_frames elapsed = end_time - start_time fps = frames_processed / elapsed logging.info(f"[Datapoint {datapoint_idx}] FPS: {fps:.1f}") total_fps += fps count += 1 pred_trajectories = results["traj_e"] pred_visibilities = results["vis_e"] pred_trajectories_2d = results["traj2d_e"] if "traj2d_e" in results else None assert "strided" not in dataset_name, "Strided evaluation is not supported yet." # Determine the evaluation setting if "kubric" in dataset_name: evaluation_setting = "kubric-multiview" elif "panoptic-multiview" in dataset_name: evaluation_setting = "panoptic-multiview" elif "dex-ycb" in dataset_name: evaluation_setting = "dexycb-multiview" elif "tapvid2d" in dataset_name: evaluation_setting = "tapvid2d" elif no_tracking_labels: evaluation_setting = "no-tracking-labels" else: raise NotImplementedError # Invert the intrinsics and extrinsics matrices intrs_inv = torch.inverse(intrs.float()).type(intrs.dtype) extrs_square = torch.eye(4).to(extrs.device)[None].repeat(batch_size, num_views, num_frames, 1, 1) extrs_square[:, :, :, :3, :] = extrs extrs_inv = torch.inverse(extrs_square.float()).type(extrs.dtype) assert intrs_inv.shape == (batch_size, num_views, num_frames, 3, 3) assert extrs_inv.shape == (batch_size, num_views, num_frames, 4, 4) # Project the predictions to pixel space for visualization pred_trajectories_pixel_xy_camera_z_per_view = torch.stack([ torch.cat(world_space_to_pixel_xy_and_camera_z( world_xyz=pred_trajectories[0], intrs=intrs[0, view_idx], extrs=extrs[0, view_idx], ), dim=-1) for view_idx in range(num_views) ], dim=0) for view_idx in range(num_views): pred_trajectories_reproduced = pixel_xy_and_camera_z_to_world_space( pixel_xy=pred_trajectories_pixel_xy_camera_z_per_view[view_idx, :, :, :2], camera_z=pred_trajectories_pixel_xy_camera_z_per_view[view_idx, :, :, 2:], intrs_inv=intrs_inv[0, view_idx], extrs_inv=extrs_inv[0, view_idx], ) if not torch.allclose(pred_trajectories_reproduced, pred_trajectories, atol=1): warnings.warn(f"Reprojection of the predicted trajectories failed: " f"view_idx={view_idx}, " f"max_diff={torch.max(torch.abs(pred_trajectories_reproduced - pred_trajectories))}") pred_trajectories_pixel_xy_camera_z_per_view = pred_trajectories_pixel_xy_camera_z_per_view[None] # Compute 3D metrics gt_visibilities_any_view = gt_visibilities_per_view.any(dim=1) assert gt_visibilities_any_view.any(dim=1).all(), "All points should be visible in at least one view." 
per_track_results = None if evaluation_setting in ["kubric-multiview", "panoptic-multiview", "dexycb-multiview"]: eval_3dpt_results_dict = evaluate_3dpt( gt_tracks=gt_trajectories_3d_worldspace[0].cpu().numpy(), gt_visibilities=gt_visibilities_any_view[0].cpu().numpy(), query_points=query_points_3d[0].cpu().numpy(), pred_tracks=pred_trajectories[0].cpu().numpy(), pred_visibilities=pred_visibilities[0].cpu().numpy(), evaluation_setting=evaluation_setting, track_upscaling_factor=track_upscaling_factor, prefix=f"eval_{dataset_name}", add_per_track_results=should_save_rerun_viz, verbose=False, ) if should_save_rerun_viz: per_track_results = eval_3dpt_results_dict[f'eval_{dataset_name}/model__per_track_results'] del eval_3dpt_results_dict[f'eval_{dataset_name}/model__per_track_results'] metrics[datapoint_idx] = eval_3dpt_results_dict if "2dpt" in dataset_name: assert batch_size == 1 if pred_trajectories_2d is None: pred_trajectories_2d = pred_trajectories_pixel_xy_camera_z_per_view[:, :, :, :, :2] _rescale_to_256x256 = np.array([256, 256]) / np.array([width, height]) _metrics = {} for view_idx in range(num_views): track_mask = (query_points_view == view_idx).cpu().numpy() if track_mask.sum() == 0: continue _n_tracks = track_mask.sum() _gt_tracks = gt_trajectories_2d_pixelspace_w_z_cameraspace[0, view_idx, :, track_mask, :2] _gt_tracks = _gt_tracks.cpu().numpy() _gt_visibilities = gt_visibilities_per_view[0, view_idx, :, track_mask].cpu().bool().numpy() _query_t = query_timestep[track_mask] _query_xy = _gt_tracks[_query_t, np.arange(_n_tracks)] _query = np.concatenate([_query_t[:, None], _query_xy], axis=-1) _pred_tracks = pred_trajectories_2d[0, view_idx, :, track_mask].cpu().numpy() _pred_visibilities = np.zeros_like(_gt_visibilities) assert _gt_visibilities[_query_t, np.arange(_n_tracks)].all() eval_2dpt_results_dict = evaluate_3dpt( gt_tracks=_gt_tracks, gt_visibilities=_gt_visibilities, query_points=_query, pred_tracks=_pred_tracks, pred_visibilities=_pred_visibilities, evaluation_setting="2dpt_ablation", track_upscaling_factor=_rescale_to_256x256, prefix=f"eval_{dataset_name}", add_per_track_results=False, verbose=False, ) tapvid2d_original_metrics = compute_tapvid_metrics_original( query_points=np.concatenate([_query_t[:, None], _query_xy * _rescale_to_256x256], axis=-1), gt_occluded=~_gt_visibilities[None].transpose(0, 2, 1), gt_tracks=_gt_tracks[None].transpose(0, 2, 1, 3) * _rescale_to_256x256, pred_occluded=~_pred_visibilities[None].transpose(0, 2, 1), pred_tracks=_pred_tracks[None].transpose(0, 2, 1, 3) * _rescale_to_256x256, query_mode="first", ) tapvid2d_original_metrics = { f"eval_{dataset_name}/model__tapvid2d_{k}": (tapvid2d_original_metrics[k] * 100).round(2).item() for k in sorted(tapvid2d_original_metrics) } _metrics[view_idx] = {} _metrics[view_idx].update(eval_2dpt_results_dict) _metrics[view_idx].update(tapvid2d_original_metrics) _metrics[view_idx] = { k.replace("model__", "model__2dpt__"): v for k, v in _metrics[view_idx].items() if "jaccard" not in k and "occlusion" not in k } _metrics_avg = {} for k in _metrics[next(iter(_metrics.keys()))]: _metrics_avg[k] = np.mean([ _metrics[view_idx][k] for view_idx in _metrics if k in _metrics[view_idx] ]).round(2) metrics[datapoint_idx].update(_metrics_avg) for view_idx in _metrics: metrics[datapoint_idx].update({ f"{k}__view-{view_idx}": v for k, v in _metrics[view_idx].items() }) # Compute 2D metrics elif evaluation_setting in ["tapvid2d"]: assert num_views == 1 if pred_trajectories_2d is None: pred_trajectories_2d = 
pred_trajectories_pixel_xy_camera_z_per_view[:, :, :, :, :2] eval_2dpt_results_dict = evaluate_3dpt( gt_tracks=gt_trajectories_2d_pixelspace_w_z_cameraspace[0, 0, :, :, :2].cpu().numpy(), gt_visibilities=gt_visibilities_per_view[0, 0].cpu().bool().numpy(), query_points=query_points_2d[0].cpu().numpy(), pred_tracks=pred_trajectories_2d[0, 0].cpu().numpy(), pred_visibilities=pred_visibilities[0].cpu().numpy(), evaluation_setting=evaluation_setting, track_upscaling_factor=track_upscaling_factor, prefix=f"eval_{dataset_name}", add_per_track_results=should_save_rerun_viz, verbose=False, ) if should_save_rerun_viz: per_track_results = eval_2dpt_results_dict[f'eval_{dataset_name}/model__per_track_results'] del eval_2dpt_results_dict[f'eval_{dataset_name}/model__per_track_results'] metrics[datapoint_idx] = eval_2dpt_results_dict tapvid2d_original_metrics = compute_tapvid_metrics_original( query_points_2d[0].cpu().numpy(), torch.logical_not(gt_visibilities_per_view[:, 0].clone().permute(0, 2, 1)).cpu().numpy(), gt_trajectories_2d_pixelspace_w_z_cameraspace[:, 0, :, :, :2].clone().permute(0, 2, 1, 3).cpu().numpy(), torch.logical_not(pred_visibilities.clone().permute(0, 2, 1)).cpu().numpy(), pred_trajectories_2d[:, 0].permute(0, 2, 1, 3).cpu().numpy(), query_mode="first", ) tapvid2d_original_metrics = { f"eval_{dataset_name}/model__tapvid2d_{k}": (tapvid2d_original_metrics[k] * 100).round(2).item() for k in sorted(tapvid2d_original_metrics) } metrics[datapoint_idx].update(tapvid2d_original_metrics) elif evaluation_setting in ["no-tracking-labels"]: metrics[datapoint_idx] = {} np.savez( os.path.join(log_dir, f"step-{step}_seq-{seq_name}_tracks.npz"), gt_trajectories_2d=gt_trajectories_2d_pixelspace_w_z_cameraspace.cpu().numpy(), gt_trajectories_3d=gt_trajectories_3d_worldspace.cpu().numpy(), gt_visibilities_per_view=gt_visibilities_per_view.cpu().numpy(), gt_visibilities_any_view=gt_visibilities_any_view.cpu().numpy(), pred_trajectories_2d=pred_trajectories_pixel_xy_camera_z_per_view.cpu().numpy(), pred_trajectories_3d=pred_trajectories.cpu().numpy(), pred_visibilities_any_view=pred_visibilities.cpu().numpy(), query_points_2d=query_points_2d.cpu().numpy() if query_points_2d is not None else None, query_points_3d=query_points_3d.cpu().numpy(), track_upscaling_factor=track_upscaling_factor, ) # Visualize the results with rerun.io viz_fps = 30 if "panoptic" in dataset_name: viz_fps = 30 elif "dex" in dataset_name: viz_fps = 10 elif "kubric" in dataset_name: viz_fps = 12 if should_save_rerun_viz: # Log the visualizations to rerun if "mvtracker" in log_dir: method_id = 0 method_name = "MVTracker" elif "spatracker_mono" in log_dir: method_id = 1 method_name = "SpatialTrackerV1" elif "tapip3d" in log_dir: method_id = 2 method_name = "TAPIP3D" elif "spatracker_multi" in log_dir: method_id = 3 method_name = "Triplane" else: method_id = None method_name = "x" if "panoptic" in dataset_name: sphere_radius = 12 else: sphere_radius = 6.0 max_tracks = None if "dress" in dataset_name: max_tracks = 300 elif "panoptic" in dataset_name: max_tracks = 100 elif "kubric" in dataset_name or "dex-ycb" in dataset_name: max_tracks = 36 LogConfig = namedtuple("LogConfig", [ "suffix", "method_id", "max_tracks", "track_batch_size", "sphere_radius", "conf_thrs", "log_only_confident_pc", "memory_lightweight_logging" ]) log_configs = [ LogConfig( suffix="", method_id=None, max_tracks=None, track_batch_size=50, sphere_radius=None, conf_thrs=[1.0, 5.0], log_only_confident_pc=False, memory_lightweight_logging=False, ), LogConfig( 
suffix=".comparisons", method_id=method_id, max_tracks=100, track_batch_size=50, sphere_radius=None, conf_thrs=[1.0, 5.0], log_only_confident_pc=False, memory_lightweight_logging=True, ), LogConfig( suffix=".lightweight", method_id=None, max_tracks=max_tracks, track_batch_size=50, sphere_radius=sphere_radius, conf_thrs=[5.0], log_only_confident_pc=True, memory_lightweight_logging=True, ), LogConfig( suffix=".lightweight.comparisons", method_id=method_id, max_tracks=50, track_batch_size=50, sphere_radius=sphere_radius, conf_thrs=[5.0], log_only_confident_pc=True, memory_lightweight_logging=True, ), ] for cfg in log_configs: logfile_name = f"rerun__{dataset_name}--seq-{datapoint_idx}--name-{seq_name}--eval{cfg.suffix}.rrd" rr.init("3dpt", recording_id="v0.16") if cfg.method_id is None or cfg.method_id == 0: log_pointclouds_to_rerun( dataset_name=dataset_name, datapoint_idx=datapoint_idx, rgbs=rgbs, depths=depths, intrs=intrs, extrs=extrs, depths_conf=depths_conf, conf_thrs=cfg.conf_thrs, log_only_confident_pc=cfg.log_only_confident_pc, radii=-2.45, fps=viz_fps, bbox_crop=None, sphere_radius_crop=cfg.sphere_radius, sphere_center_crop=np.array([0, 0, 0]), log_rgb_image=not cfg.memory_lightweight_logging, log_depthmap_as_image_v1=False, log_depthmap_as_image_v2=False, log_camera_frustrum=True, log_rgb_pointcloud=True, ) log_tracks_to_rerun( dataset_name=dataset_name, datapoint_idx=datapoint_idx, predictor_name=method_name, gt_trajectories_3d_worldspace=None if no_tracking_labels else gt_trajectories_3d_worldspace, gt_visibilities_any_view=None if no_tracking_labels else gt_visibilities_any_view, query_points_3d=query_points_3d, pred_trajectories=pred_trajectories, pred_visibilities=pred_visibilities, per_track_results=per_track_results, radii_scale=1.0, fps=viz_fps, sphere_radius_crop=cfg.sphere_radius, sphere_center_crop=np.array([0, 0, 0]), log_per_interval_results=False, max_tracks_to_log=cfg.max_tracks, track_batch_size=cfg.track_batch_size, method_id=cfg.method_id, memory_lightweight_logging=cfg.memory_lightweight_logging, ) rr_rrd_path = os.path.join(log_dir, logfile_name) rr.save(rr_rrd_path) logging.info(f"Saved Rerun recording to: {rr_rrd_path}") # Visualize the results as mp4 if should_save_mp4_viz: log_mp4_track_viz( log_dir=log_dir, dataset_name=dataset_name, datapoint_idx=datapoint_idx, rgbs=rgbs, intrs=intrs, extrs=extrs, gt_trajectories=gt_trajectories_3d_worldspace, gt_visibilities=gt_visibilities_any_view, pred_trajectories=pred_trajectories, pred_visibilities=pred_visibilities, query_points_3d=query_points_3d, step=step, prefix="comparison__v4a-train__", max_tracks_to_visualize=36, max_individual_tracks_to_visualize=6, ) if novel_rgbs is not None: log_mp4_track_viz( log_dir=log_dir, dataset_name=dataset_name, datapoint_idx=datapoint_idx, rgbs=novel_rgbs, intrs=novel_intrs, extrs=novel_extrs, gt_trajectories=gt_trajectories_3d_worldspace, gt_visibilities=gt_visibilities_any_view, pred_trajectories=pred_trajectories, pred_visibilities=pred_visibilities, query_points_3d=query_points_3d, step=step, prefix="comparison__v4b-novel__", max_tracks_to_visualize=36, max_individual_tracks_to_visualize=0, ) metrics[datapoint_idx]["fps"] = fps try: params_total = sum(p.numel() for p in model.parameters()) params_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad) params_non_trainable = sum(p.numel() for p in model.parameters() if not p.requires_grad) metrics[datapoint_idx]["params_total"] = params_total metrics[datapoint_idx]["params_trainable"] = params_trainable 
metrics[datapoint_idx]["params_non_trainable"] = params_non_trainable except Exception as e: logging.info(f"Error calculating model parameters: {e}") # Compute average if count > 0: avg_fps = total_fps / count logging.info(f"\nAverage FPS across {count} datapoints: {avg_fps:.1f}") else: logging.warning("No datapoints were processed.") return metrics ================================================ FILE: mvtracker/evaluation/metrics.py ================================================ import logging import warnings from typing import Mapping import numpy as np import pandas as pd import torch def compute_metrics( query_points, gt_occluded, gt_tracks, pred_occluded, pred_tracks, distance_thresholds=[1, 2, 4, 8, 16], survival_distance_threshold=50, query_mode="first", ): n_batches, n_frames, n_points, n_point_dim = gt_tracks.shape # First, we compute the original TAP-Vid metrics tapvid_metrics = compute_tapvid_metrics(query_points, gt_occluded, gt_tracks, pred_occluded, pred_tracks, distance_thresholds, query_mode) # Compute distances only for visible points visible_mask = ~gt_occluded distances = torch.norm(pred_tracks - gt_tracks, dim=-1) distances[~visible_mask] = float('nan') distances[torch.arange(n_frames)[None, :, None] < query_points[:, :, 0].long()[:, None, :]] = float('nan') # Compute Median Trajectory Error (MTE) and Average Trajectory Error (ATE) for visible points mte_per_track = torch.nanmedian(distances, dim=1).values ate_per_track = torch.nanmean(distances, dim=1) assert torch.isnan(mte_per_track).sum() == 0 assert torch.isnan(ate_per_track).sum() == 0 # Compute Final Trajectory Error (FDE) for the last visible frame last_visible_idx = torch.argmax(visible_mask * np.arange(n_frames)[None, :, None], dim=1) fde_per_track = distances[torch.arange(n_batches)[:, None], last_visible_idx, torch.arange(n_points)] # Compute "Survival" rate for visible points tracking_failed = (distances > survival_distance_threshold) * visible_mask failure_index = tracking_failed.float().argmax(dim=1) failure_index[(~tracking_failed).all(dim=1)] = n_frames # If all points survived, survival is 1.0 survival_per_track = (failure_index - query_points[:, :, 0].long()) / (n_frames - query_points[:, :, 0].long()) assert mte_per_track.shape == ate_per_track.shape == survival_per_track.shape == fde_per_track.shape metrics = { 'mte_visible_per_track': mte_per_track, 'ate_visible_per_track': ate_per_track, 'fde_visible_per_track': fde_per_track, 'survival_per_track': survival_per_track, **tapvid_metrics, } return metrics def compute_tapvid_metrics( query_points, gt_occluded, gt_tracks, pred_occluded, pred_tracks, distance_thresholds, query_mode="first", ): """ Computes metrics from TAP-Vid (https://arxiv.org/abs/2211.03726) based on given ground truth and predictions. The computations are performed separately for each video in the batch. Parameters ---------- query_points : torch.Tensor Tensor of shape (n_batches, n_points, 3) representing the query points. gt_occluded : torch.Tensor Boolean tensor of shape (n_batches, n_frames, n_points) indicating if a point is occluded in the ground truth. gt_tracks : torch.Tensor Tensor of shape (n_batches, n_frames, n_points, n_point_dim) representing the ground truth tracks. pred_occluded : torch.Tensor Boolean tensor of shape (n_batches, n_frames, n_points) indicating if a point is occluded in the predictions. pred_tracks : torch.Tensor Tensor of shape (n_batches, n_frames, n_points, n_point_dim) representing the predicted tracks. 
query_mode : str, optional Either "first" or "strided", default is "first". Indicates how the query points are sampled. Returns ------- dict A dictionary containing: - 'occlusion_accuracy_per_track': Accuracy at predicting occlusion, per track. - 'pts_within_{x}_per_track' for x in [1, 2, 4, 8, 16]: Fraction of points predicted to be within the given pixel threshold, ignoring occlusion prediction, per track. - 'jaccard_{x}_per_track' for x in [1, 2, 4, 8, 16]: Jaccard metric for the given pixel threshold. Combines occlusion and point prediction accuracy, per track. - 'average_jaccard_per_track': Average Jaccard metric across thresholds, per track. - 'average_pts_within_thresh_per_track': Average fraction of points within threshold across thresholds, per track. """ metrics = {} # Check shapes. n_batches, n_frames, n_points, n_point_dim = gt_tracks.shape assert n_point_dim in [2, 3] assert query_points.shape == (n_batches, n_points, n_point_dim + 1) assert gt_occluded.shape == (n_batches, n_frames, n_points) assert gt_tracks.shape == (n_batches, n_frames, n_points, n_point_dim) assert pred_occluded.shape == (n_batches, n_frames, n_points) assert pred_tracks.shape == (n_batches, n_frames, n_points, n_point_dim) assert query_mode in ["first", "strided"] assert query_points.dtype == torch.float32 assert gt_occluded.dtype == torch.bool assert gt_tracks.dtype == torch.float32 assert pred_occluded.dtype == torch.bool assert pred_tracks.dtype == torch.float32 # Don't evaluate the query point. evaluation_points = torch.ones_like(gt_occluded, dtype=torch.bool) for batch_idx in range(n_batches): t = query_points[batch_idx, :, 0].long() evaluation_points[batch_idx, t, torch.arange(n_points)] = False # In first query mode, don't evaluate points before the query point. if query_mode == "first": t = query_points[:, :, 0].long() mask = torch.arange(n_frames).unsqueeze(-1) < t.unsqueeze(1) evaluation_points[mask] = False # Compute occlusion accuracy per track. occ_acc = ((pred_occluded == gt_occluded) & evaluation_points).float().sum(dim=1) / evaluation_points.sum(dim=1) metrics["occlusion_accuracy_per_track"] = occ_acc # Let's report the numbers separately for gt=0 and gt=1 numer0 = ((pred_occluded == gt_occluded) & (gt_occluded == 1) & evaluation_points).float().sum(dim=1) numer1 = ((pred_occluded == gt_occluded) & (gt_occluded == 0) & evaluation_points).float().sum(dim=1) denom0 = ((gt_occluded == 1) & evaluation_points).float().sum(dim=1) denom1 = ((gt_occluded == 0) & evaluation_points).float().sum(dim=1) occ_acc_for_vis0 = numer0 / denom0 occ_acc_for_vis1 = numer1 / denom1 metrics["occlusion_accuracy_for_vis0_per_track"] = occ_acc_for_vis0 metrics["occlusion_accuracy_for_vis1_per_track"] = occ_acc_for_vis1 # Compute position metrics per track. distances = torch.norm(pred_tracks - gt_tracks, dim=-1) thresholds = torch.tensor(distance_thresholds, device=distances.device) for thresh in thresholds: within_threshold = distances < thresh correct_positions = (within_threshold & ~gt_occluded & evaluation_points).float().sum(dim=1) visible_points = (~gt_occluded & evaluation_points).float().sum(dim=1) assert visible_points.min() > 0, "No visible points to evaluate. Make sure at least two timesteps were visible." 
metrics[f"pts_within_{thresh:.2f}_per_track"] = correct_positions / visible_points true_positives = (within_threshold & ~pred_occluded & ~gt_occluded & evaluation_points).float().sum(dim=1) gt_positives = (~gt_occluded & evaluation_points).float().sum(dim=1) false_positives = (~within_threshold & ~pred_occluded) | (~pred_occluded & gt_occluded) false_positives = (false_positives & evaluation_points).float().sum(dim=1) jaccard = true_positives / (gt_positives + false_positives) metrics[f"jaccard_{thresh:.2f}_per_track"] = jaccard metrics["average_jaccard_per_track"] = torch.stack([metrics[f"jaccard_{thresh:.2f}_per_track"] for thresh in thresholds], dim=-1).mean(dim=-1) metrics["average_pts_within_thresh_per_track"] = torch.stack([metrics[f"pts_within_{thresh:.2f}_per_track"] for thresh in thresholds], dim=-1).mean(dim=-1) # Assert no nans for k, v in metrics.items(): if k in ["occlusion_accuracy_for_vis0_per_track", "occlusion_accuracy_for_vis1_per_track"]: continue # They can have nans and will be handled later assert not torch.isnan(v).any(), f"NaN found in {k}" return metrics def compute_tapvid_metrics_original( query_points: np.ndarray, gt_occluded: np.ndarray, gt_tracks: np.ndarray, pred_occluded: np.ndarray, pred_tracks: np.ndarray, query_mode: str, ) -> Mapping[str, np.ndarray]: """Computes TAP-Vid metrics (Jaccard, Pts. Within Thresh, Occ. Acc.) See the TAP-Vid paper for details on the metric computation. All inputs are given in raster coordinates. The first three arguments should be the direct outputs of the reader: the 'query_points', 'occluded', and 'target_points'. The paper metrics assume these are scaled relative to 256x256 images. pred_occluded and pred_tracks are your algorithm's predictions. This function takes a batch of inputs, and computes metrics separately for each video. The metrics for the full benchmark are a simple mean of the metrics across the full set of videos. These numbers are between 0 and 1, but the paper multiplies them by 100 to ease reading. Args: query_points: The query points, an in the format [t, y, x]. Its size is [b, n, 3], where b is the batch size and n is the number of queries gt_occluded: A boolean array of shape [b, n, t], where t is the number of frames. True indicates that the point is occluded. gt_tracks: The target points, of shape [b, n, t, 2]. Each point is in the format [x, y] pred_occluded: A boolean array of predicted occlusions, in the same format as gt_occluded. pred_tracks: An array of track predictions from your algorithm, in the same format as gt_tracks. query_mode: Either 'first' or 'strided', depending on how queries are sampled. If 'first', we assume the prior knowledge that all points before the query point are occluded, and these are removed from the evaluation. Returns: A dict with the following keys: occlusion_accuracy: Accuracy at predicting occlusion. pts_within_{x} for x in [1, 2, 4, 8, 16]: Fraction of points predicted to be within the given pixel threshold, ignoring occlusion prediction. 
jaccard_{x} for x in [1, 2, 4, 8, 16]: Jaccard metric for the given threshold average_pts_within_thresh: average across pts_within_{x} average_jaccard: average across jaccard_{x} """ metrics = {} # Fixed bug is described in: # https://github.com/facebookresearch/co-tracker/issues/20 eye = np.eye(gt_tracks.shape[2], dtype=np.int32) if query_mode == "first": # evaluate frames after the query frame query_frame_to_eval_frames = np.cumsum(eye, axis=1) - eye elif query_mode == "strided": # evaluate all frames except the query frame query_frame_to_eval_frames = 1 - eye else: raise ValueError("Unknown query mode " + query_mode) query_frame = query_points[..., 0] query_frame = np.round(query_frame).astype(np.int32) evaluation_points = query_frame_to_eval_frames[query_frame] > 0 # Occlusion accuracy is simply how often the predicted occlusion equals the # ground truth. occ_acc = np.sum( np.equal(pred_occluded, gt_occluded) & evaluation_points, axis=(1, 2), ) / np.sum(evaluation_points) metrics["occlusion_accuracy"] = occ_acc # Next, convert the predictions and ground truth positions into pixel # coordinates. visible = np.logical_not(gt_occluded) pred_visible = np.logical_not(pred_occluded) all_frac_within = [] all_jaccard = [] for thresh in [1, 2, 4, 8, 16]: # True positives are points that are within the threshold and where both # the prediction and the ground truth are listed as visible. within_dist = np.sum( np.square(pred_tracks - gt_tracks), axis=-1, ) < np.square(thresh) is_correct = np.logical_and(within_dist, visible) # Compute the frac_within_threshold, which is the fraction of points # within the threshold among points that are visible in the ground truth, # ignoring whether they're predicted to be visible. count_correct = np.sum( is_correct & evaluation_points, axis=(1, 2), ) count_visible_points = np.sum(visible & evaluation_points, axis=(1, 2)) frac_correct = count_correct / count_visible_points metrics["pts_within_" + str(thresh)] = frac_correct all_frac_within.append(frac_correct) true_positives = np.sum( is_correct & pred_visible & evaluation_points, axis=(1, 2) ) # The denominator of the jaccard metric is the true positives plus # false positives plus false negatives. However, note that true positives # plus false negatives is simply the number of points in the ground truth # which is easier to compute than trying to compute all three quantities. # Thus we just add the number of points in the ground truth to the number # of false positives. # # False positives are simply points that are predicted to be visible, # but the ground truth is not visible or too far from the prediction. 
gt_positives = np.sum(visible & evaluation_points, axis=(1, 2)) false_positives = (~visible) & pred_visible false_positives = false_positives | ((~within_dist) & pred_visible) false_positives = np.sum(false_positives & evaluation_points, axis=(1, 2)) jaccard = true_positives / (gt_positives + false_positives) metrics["jaccard_" + str(thresh)] = jaccard all_jaccard.append(jaccard) metrics["average_jaccard"] = np.mean( np.stack(all_jaccard, axis=1), axis=1, ) metrics["average_pts_within_thresh"] = np.mean( np.stack(all_frac_within, axis=1), axis=1, ) return metrics def evaluate_predictions( gt_tracks, gt_visibilities, pred_tracks, pred_occluded, query_points=None, distance_thresholds=[0.01, 0.02, 0.04, 0.08, 0.16], # 1 cm, 2 cm, 4 cm, 8 cm, 16 cm survival_distance_threshold=0.5, # 50 cm static_threshold=0.01, # < 0.01 cm dynamic_threshold=0.1, # > 10 cm very_dynamic_threshold=2.0, # > 2 m ): n_frames, n_points, n_point_dim = gt_tracks.shape if query_points is None: warnings.warn("Query points are not provided. Using the first visible frame as query points.") query_points_t = np.argmax(gt_visibilities, axis=0) query_points_xyz = gt_tracks[query_points_t, np.arange(n_points)] query_points = np.concatenate([query_points_t[:, None], query_points_xyz], axis=-1) at_query_timestep_or_later = (np.arange(n_frames)[:, None] >= query_points[:, 0][None, :]) gt_visibilities = gt_visibilities.copy() * at_query_timestep_or_later movement = np.zeros(n_points) for point_idx in range(n_points): point_track = gt_tracks[gt_visibilities[:, point_idx], point_idx, :] movement[point_idx] = np.linalg.norm(point_track[1:] - point_track[:-1], axis=-1).sum() point_types = ["any"] static_points = None dynamic_points = None very_dynamic_points = None if static_threshold is not None: point_types += ["static"] static_points = movement < static_threshold if dynamic_threshold is not None: point_types += ["dynamic"] dynamic_points = movement > dynamic_threshold if very_dynamic_threshold is not None: point_types += ["very_dynamic"] very_dynamic_points = movement > very_dynamic_threshold mask_1 = gt_visibilities.sum(axis=0) >= 2 # At least two visible, the first one is a query results = {} results_per_track = {} for short_name, mask_a in [ ("all", mask_1), ]: for point_type in point_types: if point_type == "any": mask_b = np.ones_like(mask_a) elif point_type == "static": mask_b = static_points elif point_type == "dynamic": mask_b = dynamic_points elif point_type == "very_dynamic": mask_b = very_dynamic_points else: raise ValueError mask_ab = mask_a & mask_b short_name_ = f"{short_name}_{point_type}" if mask_ab.sum() == 0: logging.info(f"No points for {short_name_} (empty mask).") continue pred_tracks_ = pred_tracks[:, mask_ab, :][None] out_metrics_3d = compute_metrics( torch.from_numpy(query_points[mask_ab, :][None]).float(), torch.from_numpy(~gt_visibilities[:, mask_ab][None]), torch.from_numpy(gt_tracks[:, mask_ab, :][None]).float(), torch.from_numpy(pred_occluded[:, mask_ab][None]), torch.from_numpy(pred_tracks_).float(), distance_thresholds=distance_thresholds, survival_distance_threshold=survival_distance_threshold, query_mode="first", ) results[short_name_] = {} for k, v in out_metrics_3d.items(): assert "_per_track" in k results[short_name_][k.replace("_per_track", "")] = v.nanmean().item() * 100 results[short_name_]["n"] = mask_ab.sum() / n_points * 100 results[short_name_]["v"] = (gt_visibilities[:, mask_ab].sum() / mask_ab.sum() / n_frames) * 100 results_per_track[short_name_] = {} for k, v in out_metrics_3d.items(): 
assert v.ndim == 2 and v.shape[0] == 1 v = v[0] results_per_track[short_name_][k] = v.cpu().numpy() * 100 results_per_track[short_name_]["indices"] = np.where(mask_ab)[0] if "all_static" in results.keys() and "all_dynamic" in results.keys(): results["all_dynamic-static-mean"] = {} for k in results["all_static"].keys(): results["all_dynamic-static-mean"][k] = (results["all_dynamic"][k] + results["all_static"][k]) / 2 df = pd.DataFrame(results) df = df.round(2) df_per_track = pd.DataFrame(results_per_track) df_per_track = df_per_track.round(2) return df, df_per_track ================================================ FILE: mvtracker/models/__init__.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. ================================================ FILE: mvtracker/models/core/__init__.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. ================================================ FILE: mvtracker/models/core/copycat.py ================================================ import torch from torch import nn as nn class CopyCat(nn.Module): """ Dummy, no-movement baseline that always outputs the query points as the predicted points. """ def __init__(self): super().__init__() self.dummy_learnable_param = nn.Parameter(torch.zeros(1)) def forward( self, rgbs, depths, query_points, intrs, extrs, **kwargs, ): batch_size, num_views, num_frames, _, height, width = rgbs.shape _, num_points, _ = query_points.shape assert rgbs.shape == (batch_size, num_views, num_frames, 3, height, width) assert depths.shape == (batch_size, num_views, num_frames, 1, height, width) assert query_points.shape == (batch_size, num_points, 4) assert intrs.shape == (batch_size, num_views, num_frames, 3, 3) assert extrs.shape == (batch_size, num_views, num_frames, 3, 4) traj_e = query_points[:, None, :, 1:].repeat(1, num_frames, 1, 1) vis_e = query_points.new_ones((batch_size, num_frames, num_points)) results = { "traj_e": traj_e, "feat_init": None, "vis_e": vis_e, } return results ================================================ FILE: mvtracker/models/core/cotracker2/__init__.py ================================================ ================================================ FILE: mvtracker/models/core/cotracker2/blocks.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. 
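# Minimal reference sketch (illustrative only, not used by this module) of the factorized
# attention layout that EfficientUpdateFormer below relies on: track tokens of shape
# (B, N, T, C) attend over time by folding tracks into the batch dimension, and over
# space by folding time into the batch dimension. Shapes here are made up for the example.
#
#   import torch
#   B, N, T, C = 2, 8, 12, 16
#   tokens = torch.randn(B, N, T, C)
#   time_tokens = tokens.reshape(B * N, T, C)                       # per-track temporal attention
#   space_tokens = tokens.permute(0, 2, 1, 3).reshape(B * T, N, C)  # per-frame spatial attention
#   assert time_tokens.shape == (B * N, T, C) and space_tokens.shape == (B * T, N, C)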
import collections from itertools import repeat from typing import Callable import torch import torch.nn as nn import torch.nn.functional as F # From PyTorch internals def _ntuple(n): def parse(x): if isinstance(x, collections.abc.Iterable) and not isinstance(x, str): return tuple(x) return tuple(repeat(x, n)) return parse def exists(val): return val is not None def default(val, d): return val if exists(val) else d to_2tuple = _ntuple(2) class Mlp(nn.Module): """MLP as used in Vision Transformer, MLP-Mixer and related networks""" def __init__( self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, bias=True, drop=0.0, ): super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features bias = to_2tuple(bias) drop_probs = to_2tuple(drop) self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0]) self.act = act_layer() self.drop1 = nn.Dropout(drop_probs[0]) self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1]) self.drop2 = nn.Dropout(drop_probs[1]) def forward(self, x): x = self.fc1(x) x = self.act(x) x = self.drop1(x) x = self.fc2(x) x = self.drop2(x) return x class ResidualBlock(nn.Module): def __init__(self, in_planes, planes, norm_fn="group", stride=1): super(ResidualBlock, self).__init__() self.conv1 = nn.Conv2d( in_planes, planes, kernel_size=3, padding=1, stride=stride, padding_mode="zeros", ) self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, padding_mode="zeros") self.relu = nn.ReLU(inplace=True) num_groups = planes // 8 if norm_fn == "group": self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) if not stride == 1: self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) elif norm_fn == "batch": self.norm1 = nn.BatchNorm2d(planes) self.norm2 = nn.BatchNorm2d(planes) if not stride == 1: self.norm3 = nn.BatchNorm2d(planes) elif norm_fn == "instance": self.norm1 = nn.InstanceNorm2d(planes) self.norm2 = nn.InstanceNorm2d(planes) if not stride == 1: self.norm3 = nn.InstanceNorm2d(planes) elif norm_fn == "none": self.norm1 = nn.Sequential() self.norm2 = nn.Sequential() if not stride == 1: self.norm3 = nn.Sequential() if stride == 1: self.downsample = None else: self.downsample = nn.Sequential( nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3 ) def forward(self, x): y = x y = self.relu(self.norm1(self.conv1(y))) y = self.relu(self.norm2(self.conv2(y))) if self.downsample is not None: x = self.downsample(x) return self.relu(x + y) class BasicEncoder(nn.Module): def __init__(self, input_dim=3, output_dim=128, stride=4): super(BasicEncoder, self).__init__() self.stride = stride self.norm_fn = "instance" self.in_planes = output_dim // 2 self.norm1 = nn.InstanceNorm2d(self.in_planes) self.norm2 = nn.InstanceNorm2d(output_dim * 2) self.conv1 = nn.Conv2d( input_dim, self.in_planes, kernel_size=7, stride=2, padding=3, padding_mode="zeros", ) self.relu1 = nn.ReLU(inplace=True) self.layer1 = self._make_layer(output_dim // 2, stride=1) self.layer2 = self._make_layer(output_dim // 4 * 3, stride=2) self.layer3 = self._make_layer(output_dim, stride=2) self.layer4 = self._make_layer(output_dim, stride=2) self.conv2 = nn.Conv2d( output_dim * 3 + output_dim // 4, output_dim * 2, kernel_size=3, padding=1, padding_mode="zeros", ) self.relu2 = nn.ReLU(inplace=True) self.conv3 = nn.Conv2d(output_dim * 2, output_dim, kernel_size=1) for m in self.modules(): if isinstance(m, nn.Conv2d): 
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") elif isinstance(m, (nn.InstanceNorm2d)): if m.weight is not None: nn.init.constant_(m.weight, 1) if m.bias is not None: nn.init.constant_(m.bias, 0) def _make_layer(self, dim, stride=1): layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride) layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1) layers = (layer1, layer2) self.in_planes = dim return nn.Sequential(*layers) def forward(self, x): _, _, H, W = x.shape x = self.conv1(x) x = self.norm1(x) x = self.relu1(x) a = self.layer1(x) b = self.layer2(a) c = self.layer3(b) d = self.layer4(c) def _bilinear_intepolate(x): return F.interpolate( x, (H // self.stride, W // self.stride), mode="bilinear", align_corners=True, ) a = _bilinear_intepolate(a) b = _bilinear_intepolate(b) c = _bilinear_intepolate(c) d = _bilinear_intepolate(d) x = self.conv2(torch.cat([a, b, c, d], dim=1)) x = self.norm2(x) x = self.relu2(x) x = self.conv3(x) return x class Attention(nn.Module): def __init__(self, query_dim, context_dim=None, num_heads=8, dim_head=48, qkv_bias=False): super().__init__() inner_dim = dim_head * num_heads context_dim = default(context_dim, query_dim) self.scale = dim_head ** -0.5 self.heads = num_heads self.to_q = nn.Linear(query_dim, inner_dim, bias=qkv_bias) self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias=qkv_bias) self.to_out = nn.Linear(inner_dim, query_dim) def forward(self, x, context=None, attn_mask=None): B, N1, _ = x.shape h = self.heads q = self.to_q(x).reshape(B, N1, h, -1).permute(0, 2, 1, 3) context = default(context, x) k, v = self.to_kv(context).chunk(2, dim=-1) N2 = context.shape[1] k = k.reshape(B, N2, h, -1).permute(0, 2, 1, 3) v = v.reshape(B, N2, h, -1).permute(0, 2, 1, 3) sim = (q @ k.transpose(-2, -1)) * self.scale if attn_mask is not None: sim = sim.masked_fill(~attn_mask, float('-inf')) attn = sim.softmax(dim=-1) x = (attn @ v).transpose(1, 2).reshape(B, N1, -1) return self.to_out(x) class FlashAttention(nn.Module): def __init__(self, query_dim, context_dim=None, num_heads=8, dim_head=48, qkv_bias=False): super().__init__() inner_dim = dim_head * num_heads context_dim = default(context_dim, query_dim) self.num_heads = num_heads self.dim_head = dim_head self.to_q = nn.Linear(query_dim, inner_dim, bias=qkv_bias) self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias=qkv_bias) self.to_out = nn.Linear(inner_dim, query_dim) def forward(self, x, context=None, attn_mask=None): B, N1, _ = x.shape h = self.num_heads q = self.to_q(x).reshape(B, N1, h, self.dim_head).transpose(1, 2) context = default(context, x) k, v = self.to_kv(context).chunk(2, dim=-1) N2 = context.shape[1] k = k.reshape(B, N2, h, self.dim_head).transpose(1, 2) v = v.reshape(B, N2, h, self.dim_head).transpose(1, 2) x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask) x = x.transpose(1, 2).reshape(B, N1, -1) return self.to_out(x) class AttnBlock(nn.Module): def __init__( self, hidden_size, num_heads, mlp_ratio=4.0, attn_class: Callable[..., nn.Module] = Attention, **block_kwargs, ): super().__init__() self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) self.attn = attn_class(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs) self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) mlp_hidden_dim = int(hidden_size * mlp_ratio) approx_gelu = lambda: nn.GELU(approximate="tanh") self.mlp = Mlp( in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0, ) def 
forward(self, x, attn_mask=None): x = x + self.attn(self.norm1(x), attn_mask=attn_mask) x = x + self.mlp(self.norm2(x)) return x class CrossAttnBlock(nn.Module): def __init__( self, hidden_size, context_dim, num_heads, mlp_ratio=4.0, attn_class: Callable[..., nn.Module] = Attention, **block_kwargs, ): super().__init__() self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) self.norm_context = nn.LayerNorm(hidden_size) self.cross_attn = attn_class( query_dim=hidden_size, context_dim=context_dim, num_heads=num_heads, qkv_bias=True, **block_kwargs, ) self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) mlp_hidden_dim = int(hidden_size * mlp_ratio) approx_gelu = lambda: nn.GELU(approximate="tanh") self.mlp = Mlp( in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0, ) def forward(self, x, context, attn_mask=None): x = x + self.cross_attn(self.norm1(x), context=self.norm_context(context), attn_mask=attn_mask) x = x + self.mlp(self.norm2(x)) return x class EfficientUpdateFormer(nn.Module): """ Transformer model that updates track estimates. """ def __init__( self, space_depth=6, time_depth=6, input_dim=320, hidden_size=384, num_heads=8, output_dim=130, mlp_ratio=4.0, add_space_attn=True, num_virtual_tracks=64, attn_class: Callable[..., nn.Module] = Attention, linear_layer_for_vis_conf=False, ): super().__init__() self.out_channels = 2 self.num_heads = num_heads self.hidden_size = hidden_size self.add_space_attn = add_space_attn self.input_transform = torch.nn.Linear(input_dim, hidden_size, bias=True) self.linear_layer_for_vis_conf = linear_layer_for_vis_conf if self.linear_layer_for_vis_conf: self.flow_head = nn.Sequential( nn.Linear(hidden_size, output_dim, bias=True), nn.ReLU(inplace=True), nn.Linear(output_dim, output_dim, bias=True), nn.ReLU(inplace=True), nn.Linear(output_dim, output_dim - 2, bias=True) ) self.vis_conf_head = torch.nn.Linear(hidden_size, 2, bias=True) else: self.flow_head = nn.Sequential( nn.Linear(hidden_size, output_dim, bias=True), nn.ReLU(inplace=True), nn.Linear(output_dim, output_dim, bias=True), nn.ReLU(inplace=True), nn.Linear(output_dim, output_dim, bias=True) ) self.num_virtual_tracks = num_virtual_tracks self.virual_tracks = nn.Parameter(torch.randn(1, num_virtual_tracks, 1, hidden_size)) self.time_blocks = nn.ModuleList( [ AttnBlock( hidden_size, num_heads, mlp_ratio=mlp_ratio, attn_class=attn_class, ) for _ in range(time_depth) ] ) if add_space_attn: self.space_virtual_blocks = nn.ModuleList( [ AttnBlock( hidden_size, num_heads, mlp_ratio=mlp_ratio, attn_class=attn_class, ) for _ in range(space_depth) ] ) self.space_point2virtual_blocks = nn.ModuleList( [ CrossAttnBlock( hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio, attn_class=attn_class, ) for _ in range(space_depth) ] ) self.space_virtual2point_blocks = nn.ModuleList( [ CrossAttnBlock( hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio, attn_class=attn_class, ) for _ in range(space_depth) ] ) assert len(self.time_blocks) >= len(self.space_virtual2point_blocks) self.initialize_weights() def initialize_weights(self): def xavier_init(module): if isinstance(module, nn.Linear): torch.nn.init.xavier_uniform_(module.weight) if module.bias is not None: nn.init.constant_(module.bias, 0) def trunc_init(module): if isinstance(module, nn.Linear): torch.nn.init.trunc_normal_(module.weight, std=0.001) # Apply xavier to all except flow_head self.apply(xavier_init) # Then override flow_head with trunc_normal 
self.flow_head.apply(trunc_init) if self.linear_layer_for_vis_conf: self.vis_conf_head.apply(trunc_init) def forward(self, input_tensor, mask=None): tokens = self.input_transform(input_tensor) B, _, T, _ = tokens.shape virtual_tokens = self.virual_tracks.repeat(B, 1, T, 1) tokens = torch.cat([tokens, virtual_tokens], dim=1) _, N, _, _ = tokens.shape j = 0 for i in range(len(self.time_blocks)): time_tokens = tokens.contiguous().view(B * N, T, -1) # B N T C -> (B N) T C time_tokens = self.time_blocks[i](time_tokens) tokens = time_tokens.view(B, N, T, -1) # (B N) T C -> B N T C if self.add_space_attn and ( i % (len(self.time_blocks) // len(self.space_virtual_blocks)) == 0 ): space_tokens = ( tokens.permute(0, 2, 1, 3).contiguous().view(B * T, N, -1) ) # B N T C -> (B T) N C point_tokens = space_tokens[:, : N - self.num_virtual_tracks] virtual_tokens = space_tokens[:, N - self.num_virtual_tracks:] virtual_tokens = self.space_virtual2point_blocks[j]( virtual_tokens, point_tokens, attn_mask=mask ) virtual_tokens = self.space_virtual_blocks[j](virtual_tokens) point_tokens = self.space_point2virtual_blocks[j]( point_tokens, virtual_tokens, attn_mask=mask ) space_tokens = torch.cat([point_tokens, virtual_tokens], dim=1) tokens = space_tokens.view(B, T, N, -1).permute(0, 2, 1, 3) # (B T) N C -> B N T C j += 1 tokens = tokens[:, : N - self.num_virtual_tracks] flow = self.flow_head(tokens) if self.linear_layer_for_vis_conf: vis_conf = self.vis_conf_head(tokens) flow = torch.cat([flow, vis_conf], dim=-1) return flow ================================================ FILE: mvtracker/models/core/dpt/__init__.py ================================================ ================================================ FILE: mvtracker/models/core/dpt/base_model.py ================================================ import torch class BaseModel(torch.nn.Module): def load(self, path): """Load model from file. 
Args: path (str): file path """ parameters = torch.load(path, map_location=torch.device("cpu")) if "optimizer" in parameters: parameters = parameters["model"] self.load_state_dict(parameters) ================================================ FILE: mvtracker/models/core/dpt/blocks.py ================================================ import torch import torch.nn as nn from mvtracker.models.core.dpt.vit import ( _make_pretrained_vitb_rn50_384, _make_pretrained_vitl16_384, _make_pretrained_vitb16_384, _make_pretrained_vit_tiny ) def _make_encoder( backbone, features, use_pretrained, groups=1, expand=False, exportable=True, hooks=None, use_vit_only=False, use_readout="ignore", enable_attention_hooks=False, ): if backbone == "vitl16_384": pretrained = _make_pretrained_vitl16_384( use_pretrained, hooks=hooks, use_readout=use_readout, enable_attention_hooks=enable_attention_hooks, ) scratch = _make_scratch( [256, 512, 1024, 1024], features, groups=groups, expand=expand ) # ViT-L/16 - 85.0% Top1 (backbone) elif backbone == "vitb_rn50_384": pretrained = _make_pretrained_vitb_rn50_384( use_pretrained, hooks=hooks, use_vit_only=use_vit_only, use_readout=use_readout, enable_attention_hooks=enable_attention_hooks, ) scratch = _make_scratch( [256, 512, 768, 768], features, groups=groups, expand=expand ) # ViT-H/16 - 85.0% Top1 (backbone) elif backbone == "vitb16_384": pretrained = _make_pretrained_vitb16_384( use_pretrained, hooks=hooks, use_readout=use_readout, enable_attention_hooks=enable_attention_hooks, ) scratch = _make_scratch( [96, 192, 384, 768], features, groups=groups, expand=expand ) # ViT-B/16 - 84.6% Top1 (backbone) elif backbone == "resnext101_wsl": pretrained = _make_pretrained_resnext101_wsl(use_pretrained) scratch = _make_scratch( [256, 512, 1024, 2048], features, groups=groups, expand=expand ) # efficientnet_lite3 elif backbone == "vit_tiny_r_s16_p8_384": pretrained = _make_pretrained_vit_tiny( use_pretrained, hooks=hooks, use_readout=use_readout, enable_attention_hooks=enable_attention_hooks, ) scratch = _make_scratch( [96, 192, 384, 768], features, groups=groups, expand=expand ) else: print(f"Backbone '{backbone}' not implemented") assert False return pretrained, scratch def _make_scratch(in_shape, out_shape, groups=1, expand=False): scratch = nn.Module() out_shape1 = out_shape out_shape2 = out_shape out_shape3 = out_shape out_shape4 = out_shape if expand == True: out_shape1 = out_shape out_shape2 = out_shape * 2 out_shape3 = out_shape * 4 out_shape4 = out_shape * 8 scratch.layer1_rn = nn.Conv2d( in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups, ) scratch.layer2_rn = nn.Conv2d( in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups, ) scratch.layer3_rn = nn.Conv2d( in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups, ) scratch.layer4_rn = nn.Conv2d( in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups, ) return scratch def _make_resnet_backbone(resnet): pretrained = nn.Module() pretrained.layer1 = nn.Sequential( resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1 ) pretrained.layer2 = resnet.layer2 pretrained.layer3 = resnet.layer3 pretrained.layer4 = resnet.layer4 return pretrained def _make_pretrained_resnext101_wsl(use_pretrained): resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl") return _make_resnet_backbone(resnet) class Interpolate(nn.Module): """Interpolation module.""" def __init__(self, 
scale_factor, mode, align_corners=False): """Init. Args: scale_factor (float): scaling mode (str): interpolation mode """ super(Interpolate, self).__init__() self.interp = nn.functional.interpolate self.scale_factor = scale_factor self.mode = mode self.align_corners = align_corners def forward(self, x): """Forward pass. Args: x (tensor): input Returns: tensor: interpolated data """ x = self.interp( x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners, ) return x class ResidualConvUnit(nn.Module): """Residual convolution module.""" def __init__(self, features): """Init. Args: features (int): number of features """ super().__init__() self.conv1 = nn.Conv2d( features, features, kernel_size=3, stride=1, padding=1, bias=True ) self.conv2 = nn.Conv2d( features, features, kernel_size=3, stride=1, padding=1, bias=True ) self.relu = nn.ReLU(inplace=True) def forward(self, x): """Forward pass. Args: x (tensor): input Returns: tensor: output """ out = self.relu(x) out = self.conv1(out) out = self.relu(out) out = self.conv2(out) return out + x class FeatureFusionBlock(nn.Module): """Feature fusion block.""" def __init__(self, features): """Init. Args: features (int): number of features """ super(FeatureFusionBlock, self).__init__() self.resConfUnit1 = ResidualConvUnit(features) self.resConfUnit2 = ResidualConvUnit(features) def forward(self, *xs): """Forward pass. Returns: tensor: output """ output = xs[0] if len(xs) == 2: output += self.resConfUnit1(xs[1]) output = self.resConfUnit2(output) output = nn.functional.interpolate( output, scale_factor=2, mode="bilinear", align_corners=True ) return output class ResidualConvUnit_custom(nn.Module): """Residual convolution module.""" def __init__(self, features, activation, bn): """Init. Args: features (int): number of features """ super().__init__() self.bn = bn self.groups = 1 self.conv1 = nn.Conv2d( features, features, kernel_size=3, stride=1, padding=1, bias=not self.bn, groups=self.groups, ) self.conv2 = nn.Conv2d( features, features, kernel_size=3, stride=1, padding=1, bias=not self.bn, groups=self.groups, ) if self.bn == True: self.bn1 = nn.BatchNorm2d(features) self.bn2 = nn.BatchNorm2d(features) self.activation = activation self.skip_add = nn.quantized.FloatFunctional() def forward(self, x): """Forward pass. Args: x (tensor): input Returns: tensor: output """ out = self.activation(x) out = self.conv1(out) if self.bn == True: out = self.bn1(out) out = self.activation(out) out = self.conv2(out) if self.bn == True: out = self.bn2(out) if self.groups > 1: out = self.conv_merge(out) return self.skip_add.add(out, x) # return out + x class FeatureFusionBlock_custom(nn.Module): """Feature fusion block.""" def __init__( self, features, activation, deconv=False, bn=False, expand=False, align_corners=True, ): """Init. Args: features (int): number of features """ super(FeatureFusionBlock_custom, self).__init__() self.deconv = deconv self.align_corners = align_corners self.groups = 1 self.expand = expand out_features = features if self.expand == True: out_features = features // 2 self.out_conv = nn.Conv2d( features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1, ) self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn) self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn) self.skip_add = nn.quantized.FloatFunctional() def forward(self, *xs): """Forward pass. 
Returns: tensor: output """ output = xs[0] if len(xs) == 2: res = self.resConfUnit1(xs[1]) output = self.skip_add.add(output, res) # output += res output = self.resConfUnit2(output) output = nn.functional.interpolate( output, scale_factor=2, mode="bilinear", align_corners=self.align_corners ) output = self.out_conv(output) return output ================================================ FILE: mvtracker/models/core/dpt/midas_net.py ================================================ """MidashNet: Network for monocular depth estimation trained by mixing several datasets. This file contains code that is adapted from https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py """ import torch import torch.nn as nn from mvtracker.models.core.dpt.base_model import BaseModel from mvtracker.models.core.dpt.blocks import FeatureFusionBlock, Interpolate, _make_encoder class MidasNet_large(BaseModel): """Network for monocular depth estimation.""" def __init__(self, path=None, features=256, non_negative=True): """Init. Args: path (str, optional): Path to saved model. Defaults to None. features (int, optional): Number of features. Defaults to 256. backbone (str, optional): Backbone network for encoder. Defaults to resnet50 """ print("Loading weights: ", path) super(MidasNet_large, self).__init__() use_pretrained = False if path is None else True self.pretrained, self.scratch = _make_encoder( backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained ) self.scratch.refinenet4 = FeatureFusionBlock(features) self.scratch.refinenet3 = FeatureFusionBlock(features) self.scratch.refinenet2 = FeatureFusionBlock(features) self.scratch.refinenet1 = FeatureFusionBlock(features) self.scratch.output_conv = nn.Sequential( nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1), Interpolate(scale_factor=2, mode="bilinear"), nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1), nn.ReLU(True), nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), nn.ReLU(True) if non_negative else nn.Identity(), ) if path: self.load(path) def forward(self, x): """Forward pass. 
Args: x (tensor): input data (image) Returns: tensor: depth """ layer_1 = self.pretrained.layer1(x) layer_2 = self.pretrained.layer2(layer_1) layer_3 = self.pretrained.layer3(layer_2) layer_4 = self.pretrained.layer4(layer_3) layer_1_rn = self.scratch.layer1_rn(layer_1) layer_2_rn = self.scratch.layer2_rn(layer_2) layer_3_rn = self.scratch.layer3_rn(layer_3) layer_4_rn = self.scratch.layer4_rn(layer_4) path_4 = self.scratch.refinenet4(layer_4_rn) path_3 = self.scratch.refinenet3(path_4, layer_3_rn) path_2 = self.scratch.refinenet2(path_3, layer_2_rn) path_1 = self.scratch.refinenet1(path_2, layer_1_rn) out = self.scratch.output_conv(path_1) return torch.squeeze(out, dim=1) ================================================ FILE: mvtracker/models/core/dpt/models.py ================================================ import torch import torch.nn as nn import torch.nn.functional as F from mvtracker.models.core.dpt.base_model import BaseModel from mvtracker.models.core.dpt.blocks import ( FeatureFusionBlock_custom, Interpolate, _make_encoder, ) from mvtracker.models.core.dpt.vit import forward_vit def _make_fusion_block(features, use_bn): return FeatureFusionBlock_custom( features, nn.ReLU(False), deconv=False, bn=use_bn, expand=False, align_corners=True, ) class DPT(BaseModel): def __init__( self, head, features=256, backbone="vitb_rn50_384", readout="project", channels_last=False, use_bn=True, enable_attention_hooks=False, ): super(DPT, self).__init__() self.channels_last = channels_last hooks = { "vitb_rn50_384": [0, 1, 8, 11], "vitb16_384": [2, 5, 8, 11], "vitl16_384": [5, 11, 17, 23], "vit_tiny_r_s16_p8_384": [0, 1, 2, 3], } # Instantiate backbone and reassemble blocks self.pretrained, self.scratch = _make_encoder( backbone, features, False, # Set to true of you want to train from scratch, uses ImageNet weights groups=1, expand=False, exportable=False, hooks=hooks[backbone], use_readout=readout, enable_attention_hooks=enable_attention_hooks, ) self.scratch.refinenet1 = _make_fusion_block(features, use_bn) self.scratch.refinenet2 = _make_fusion_block(features, use_bn) self.scratch.refinenet3 = _make_fusion_block(features, use_bn) self.scratch.refinenet4 = _make_fusion_block(features, use_bn) self.scratch.output_conv = head self.proj_out = nn.Sequential( nn.Conv2d( 256 + 512 + 384 + 384, 256, kernel_size=3, padding=1, padding_mode="zeros", ), nn.BatchNorm2d(128 * 2), nn.ReLU(True), nn.Conv2d( 128 * 2, 128, kernel_size=3, padding=1, padding_mode="zeros", ) ) def forward(self, x, only_enc=False): if self.channels_last == True: x.contiguous(memory_format=torch.channels_last) if only_enc: layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x) a = (layer_1) b = ( F.interpolate( layer_2, scale_factor=2, mode="bilinear", align_corners=True, ) ) c = ( F.interpolate( layer_3, scale_factor=8, mode="bilinear", align_corners=True, ) ) d = ( F.interpolate( layer_4, scale_factor=16, mode="bilinear", align_corners=True, ) ) x = self.proj_out(torch.cat([a, b, c, d], dim=1)) return x else: layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x) layer_1_rn = self.scratch.layer1_rn(layer_1) layer_2_rn = self.scratch.layer2_rn(layer_2) layer_3_rn = self.scratch.layer3_rn(layer_3) layer_4_rn = self.scratch.layer4_rn(layer_4) path_4 = self.scratch.refinenet4(layer_4_rn) path_3 = self.scratch.refinenet3(path_4, layer_3_rn) path_2 = self.scratch.refinenet2(path_3, layer_2_rn) path_1 = self.scratch.refinenet1(path_2, layer_1_rn) _, _, H_out, W_out = path_1.size() path_2_up = 
F.interpolate(path_2, size=(H_out, W_out), mode="bilinear", align_corners=True) path_3_up = F.interpolate(path_3, size=(H_out, W_out), mode="bilinear", align_corners=True) path_4_up = F.interpolate(path_4, size=(H_out, W_out), mode="bilinear", align_corners=True) out = self.scratch.output_conv(path_1 + path_2_up + path_3_up + path_4_up) return out class DPTDepthModel(DPT): def __init__( self, path=None, non_negative=True, scale=1.0, shift=0.0, invert=False, **kwargs ): features = kwargs["features"] if "features" in kwargs else 256 self.scale = scale self.shift = shift self.invert = invert head = nn.Sequential( nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1), Interpolate(scale_factor=2, mode="bilinear", align_corners=True), nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1), nn.ReLU(True), nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), nn.ReLU(True) if non_negative else nn.Identity(), nn.Identity(), ) super().__init__(head, **kwargs) if path is not None: self.load(path) def forward(self, x): inv_depth = super().forward(x).squeeze(dim=1) if self.invert: depth = self.scale * inv_depth + self.shift depth[depth < 1e-8] = 1e-8 depth = 1.0 / depth return depth else: return inv_depth class DPTEncoder(DPT): def __init__( self, path=None, non_negative=True, scale=1.0, shift=0.0, invert=False, **kwargs ): features = kwargs["features"] if "features" in kwargs else 256 self.scale = scale self.shift = shift head = nn.Sequential( nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1), ) super().__init__(head, **kwargs) if path is not None: self.load(path) def forward(self, x): features = super().forward(x, only_enc=True).squeeze(dim=1) return features class DPTSegmentationModel(DPT): def __init__(self, num_classes, path=None, **kwargs): features = kwargs["features"] if "features" in kwargs else 256 kwargs["use_bn"] = True head = nn.Sequential( nn.Conv2d(features, features, kernel_size=3, padding=1, bias=False), nn.BatchNorm2d(features), nn.ReLU(True), nn.Dropout(0.1, False), nn.Conv2d(features, num_classes, kernel_size=1), Interpolate(scale_factor=2, mode="bilinear", align_corners=True), ) super().__init__(head, **kwargs) self.auxlayer = nn.Sequential( nn.Conv2d(features, features, kernel_size=3, padding=1, bias=False), nn.BatchNorm2d(features), nn.ReLU(True), nn.Dropout(0.1, False), nn.Conv2d(features, num_classes, kernel_size=1), ) if path is not None: self.load(path) ================================================ FILE: mvtracker/models/core/dpt/transforms.py ================================================ import cv2 import math import numpy as np def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA): """Rezise the sample to ensure the given size. Keeps aspect ratio. 
Args: sample (dict): sample size (tuple): image size Returns: tuple: new size """ shape = list(sample["disparity"].shape) if shape[0] >= size[0] and shape[1] >= size[1]: return sample scale = [0, 0] scale[0] = size[0] / shape[0] scale[1] = size[1] / shape[1] scale = max(scale) shape[0] = math.ceil(scale * shape[0]) shape[1] = math.ceil(scale * shape[1]) # resize sample["image"] = cv2.resize( sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method ) sample["disparity"] = cv2.resize( sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST ) sample["mask"] = cv2.resize( sample["mask"].astype(np.float32), tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST, ) sample["mask"] = sample["mask"].astype(bool) return tuple(shape) class Resize(object): """Resize sample to given size (width, height).""" def __init__( self, width, height, resize_target=True, keep_aspect_ratio=False, ensure_multiple_of=1, resize_method="lower_bound", image_interpolation_method=cv2.INTER_AREA, ): """Init. Args: width (int): desired output width height (int): desired output height resize_target (bool, optional): True: Resize the full sample (image, mask, target). False: Resize image only. Defaults to True. keep_aspect_ratio (bool, optional): True: Keep the aspect ratio of the input sample. Output sample might not have the given width and height, and resize behaviour depends on the parameter 'resize_method'. Defaults to False. ensure_multiple_of (int, optional): Output width and height is constrained to be multiple of this parameter. Defaults to 1. resize_method (str, optional): "lower_bound": Output will be at least as large as the given size. "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.) "minimal": Scale as least as possible. (Output size might be smaller than given size.) Defaults to "lower_bound". 
""" self.__width = width self.__height = height self.__resize_target = resize_target self.__keep_aspect_ratio = keep_aspect_ratio self.__multiple_of = ensure_multiple_of self.__resize_method = resize_method self.__image_interpolation_method = image_interpolation_method def constrain_to_multiple_of(self, x, min_val=0, max_val=None): y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int) if max_val is not None and y > max_val: y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int) if y < min_val: y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int) return y def get_size(self, width, height): # determine new height and width scale_height = self.__height / height scale_width = self.__width / width if self.__keep_aspect_ratio: if self.__resize_method == "lower_bound": # scale such that output size is lower bound if scale_width > scale_height: # fit width scale_height = scale_width else: # fit height scale_width = scale_height elif self.__resize_method == "upper_bound": # scale such that output size is upper bound if scale_width < scale_height: # fit width scale_height = scale_width else: # fit height scale_width = scale_height elif self.__resize_method == "minimal": # scale as least as possbile if abs(1 - scale_width) < abs(1 - scale_height): # fit width scale_height = scale_width else: # fit height scale_width = scale_height else: raise ValueError( f"resize_method {self.__resize_method} not implemented" ) if self.__resize_method == "lower_bound": new_height = self.constrain_to_multiple_of( scale_height * height, min_val=self.__height ) new_width = self.constrain_to_multiple_of( scale_width * width, min_val=self.__width ) elif self.__resize_method == "upper_bound": new_height = self.constrain_to_multiple_of( scale_height * height, max_val=self.__height ) new_width = self.constrain_to_multiple_of( scale_width * width, max_val=self.__width ) elif self.__resize_method == "minimal": new_height = self.constrain_to_multiple_of(scale_height * height) new_width = self.constrain_to_multiple_of(scale_width * width) else: raise ValueError(f"resize_method {self.__resize_method} not implemented") return (new_width, new_height) def __call__(self, sample): width, height = self.get_size( sample["image"].shape[1], sample["image"].shape[0] ) # resize sample sample["image"] = cv2.resize( sample["image"], (width, height), interpolation=self.__image_interpolation_method, ) if self.__resize_target: if "disparity" in sample: sample["disparity"] = cv2.resize( sample["disparity"], (width, height), interpolation=cv2.INTER_NEAREST, ) if "depth" in sample: sample["depth"] = cv2.resize( sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST ) sample["mask"] = cv2.resize( sample["mask"].astype(np.float32), (width, height), interpolation=cv2.INTER_NEAREST, ) sample["mask"] = sample["mask"].astype(bool) return sample class NormalizeImage(object): """Normlize image by given mean and std.""" def __init__(self, mean, std): self.__mean = mean self.__std = std def __call__(self, sample): sample["image"] = (sample["image"] - self.__mean) / self.__std return sample class PrepareForNet(object): """Prepare sample for usage as network input.""" def __init__(self): pass def __call__(self, sample): image = np.transpose(sample["image"], (2, 0, 1)) sample["image"] = np.ascontiguousarray(image).astype(np.float32) if "mask" in sample: sample["mask"] = sample["mask"].astype(np.float32) sample["mask"] = np.ascontiguousarray(sample["mask"]) if "disparity" in sample: disparity = 
sample["disparity"].astype(np.float32) sample["disparity"] = np.ascontiguousarray(disparity) if "depth" in sample: depth = sample["depth"].astype(np.float32) sample["depth"] = np.ascontiguousarray(depth) return sample ================================================ FILE: mvtracker/models/core/dpt/vit.py ================================================ import types import math import timm import torch import torch.nn as nn import torch.nn.functional as F activations = {} def get_activation(name): def hook(model, input, output): activations[name] = output return hook attention = {} def get_attention(name): def hook(module, input, output): x = input[0] B, N, C = x.shape qkv = ( module.qkv(x) .reshape(B, N, 3, module.num_heads, C // module.num_heads) .permute(2, 0, 3, 1, 4) ) q, k, v = ( qkv[0], qkv[1], qkv[2], ) # make torchscript happy (cannot use tensor as tuple) attn = (q @ k.transpose(-2, -1)) * module.scale attn = attn.softmax(dim=-1) # [:,:,1,1:] attention[name] = attn return hook def get_mean_attention_map(attn, token, shape): attn = attn[:, :, token, 1:] attn = attn.unflatten(2, torch.Size([shape[2] // 16, shape[3] // 16])).float() attn = torch.nn.functional.interpolate( attn, size=shape[2:], mode="bicubic", align_corners=False ).squeeze(0) all_attn = torch.mean(attn, 0) return all_attn class Slice(nn.Module): def __init__(self, start_index=1): super(Slice, self).__init__() self.start_index = start_index def forward(self, x): return x[:, self.start_index:] class AddReadout(nn.Module): def __init__(self, start_index=1): super(AddReadout, self).__init__() self.start_index = start_index def forward(self, x): if self.start_index == 2: readout = (x[:, 0] + x[:, 1]) / 2 else: readout = x[:, 0] return x[:, self.start_index:] + readout.unsqueeze(1) class ProjectReadout(nn.Module): def __init__(self, in_features, start_index=1): super(ProjectReadout, self).__init__() self.start_index = start_index self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU()) def forward(self, x): readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index:]) features = torch.cat((x[:, self.start_index:], readout), -1) return self.project(features) class Transpose(nn.Module): def __init__(self, dim0, dim1): super(Transpose, self).__init__() self.dim0 = dim0 self.dim1 = dim1 def forward(self, x): x = x.transpose(self.dim0, self.dim1) return x def forward_vit(pretrained, x): b, c, h, w = x.shape glob = pretrained.model.forward_flex(x) layer_1 = pretrained.activations["1"] layer_2 = pretrained.activations["2"] layer_3 = pretrained.activations["3"] layer_4 = pretrained.activations["4"] layer_1 = pretrained.act_postprocess1[0:2](layer_1) layer_2 = pretrained.act_postprocess2[0:2](layer_2) layer_3 = pretrained.act_postprocess3[0:2](layer_3) layer_4 = pretrained.act_postprocess4[0:2](layer_4) unflatten = nn.Sequential( nn.Unflatten( 2, torch.Size( [ h // pretrained.model.patch_size[1], w // pretrained.model.patch_size[0], ] ), ) ) if layer_1.ndim == 3: layer_1 = unflatten(layer_1) if layer_2.ndim == 3: layer_2 = unflatten(layer_2) if layer_3.ndim == 3: layer_3 = unflatten(layer_3) if layer_4.ndim == 3: layer_4 = unflatten(layer_4) layer_1 = pretrained.act_postprocess1[3: len(pretrained.act_postprocess1)](layer_1) layer_2 = pretrained.act_postprocess2[3: len(pretrained.act_postprocess2)](layer_2) layer_3 = pretrained.act_postprocess3[3: len(pretrained.act_postprocess3)](layer_3) layer_4 = pretrained.act_postprocess4[3: len(pretrained.act_postprocess4)](layer_4) return layer_1, layer_2, layer_3, 
layer_4 def _resize_pos_embed(self, posemb, gs_h, gs_w): posemb_tok, posemb_grid = ( posemb[:, : self.start_index], posemb[0, self.start_index:], ) gs_old = int(math.sqrt(len(posemb_grid))) posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2) posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear") posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1) posemb = torch.cat([posemb_tok, posemb_grid], dim=1) return posemb def forward_flex(self, x): b, c, h, w = x.shape pos_embed = self._resize_pos_embed( self.pos_embed, h // self.patch_size[1], w // self.patch_size[0] ) B = x.shape[0] if hasattr(self.patch_embed, "backbone"): x = self.patch_embed.backbone(x) if isinstance(x, (list, tuple)): x = x[-1] # last feature if backbone outputs list/tuple of features x = self.patch_embed.proj(x).flatten(2).transpose(1, 2) if getattr(self, "dist_token", None) is not None: cls_tokens = self.cls_token.expand( B, -1, -1 ) # stole cls_tokens impl from Phil Wang, thanks dist_token = self.dist_token.expand(B, -1, -1) x = torch.cat((cls_tokens, dist_token, x), dim=1) else: cls_tokens = self.cls_token.expand( B, -1, -1 ) # stole cls_tokens impl from Phil Wang, thanks x = torch.cat((cls_tokens, x), dim=1) x = x + pos_embed x = self.pos_drop(x) for blk in self.blocks: x = blk(x) x = self.norm(x) return x def get_readout_oper(vit_features, features, use_readout, start_index=1): if use_readout == "ignore": readout_oper = [Slice(start_index)] * len(features) elif use_readout == "add": readout_oper = [AddReadout(start_index)] * len(features) elif use_readout == "project": readout_oper = [ ProjectReadout(vit_features, start_index) for out_feat in features ] else: assert ( False ), "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'" return readout_oper def _make_vit_b16_backbone( model, features=[96, 192, 384, 768], size=[384, 384], hooks=[2, 5, 8, 11], vit_features=768, use_readout="ignore", start_index=1, enable_attention_hooks=False, ): pretrained = nn.Module() pretrained.model = model pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1")) pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2")) pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3")) pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4")) pretrained.activations = activations if enable_attention_hooks: pretrained.model.blocks[hooks[0]].attn.register_forward_hook( get_attention("attn_1") ) pretrained.model.blocks[hooks[1]].attn.register_forward_hook( get_attention("attn_2") ) pretrained.model.blocks[hooks[2]].attn.register_forward_hook( get_attention("attn_3") ) pretrained.model.blocks[hooks[3]].attn.register_forward_hook( get_attention("attn_4") ) pretrained.attention = attention readout_oper = get_readout_oper(vit_features, features, use_readout, start_index) # 32, 48, 136, 384 pretrained.act_postprocess1 = nn.Sequential( readout_oper[0], Transpose(1, 2), nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), nn.Conv2d( in_channels=vit_features, out_channels=features[0], kernel_size=1, stride=1, padding=0, ), nn.ConvTranspose2d( in_channels=features[0], out_channels=features[0], kernel_size=4, stride=4, padding=0, bias=True, dilation=1, groups=1, ), ) pretrained.act_postprocess2 = nn.Sequential( readout_oper[1], Transpose(1, 2), nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), nn.Conv2d( in_channels=vit_features, out_channels=features[1], kernel_size=1, 
stride=1, padding=0, ), nn.ConvTranspose2d( in_channels=features[1], out_channels=features[1], kernel_size=2, stride=2, padding=0, bias=True, dilation=1, groups=1, ), ) pretrained.act_postprocess3 = nn.Sequential( readout_oper[2], Transpose(1, 2), nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), nn.Conv2d( in_channels=vit_features, out_channels=features[2], kernel_size=1, stride=1, padding=0, ), ) pretrained.act_postprocess4 = nn.Sequential( readout_oper[3], Transpose(1, 2), nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), nn.Conv2d( in_channels=vit_features, out_channels=features[3], kernel_size=1, stride=1, padding=0, ), nn.Conv2d( in_channels=features[3], out_channels=features[3], kernel_size=3, stride=2, padding=1, ), ) pretrained.model.start_index = start_index pretrained.model.patch_size = [16, 16] # We inject this function into the VisionTransformer instances so that # we can use it with interpolated position embeddings without modifying the library source. pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model) pretrained.model._resize_pos_embed = types.MethodType( _resize_pos_embed, pretrained.model ) return pretrained def _make_vit_b_rn50_backbone( model, features=[256, 512, 768, 768], size=[384, 384], hooks=[0, 1, 8, 11], vit_features=384, use_vit_only=False, use_readout="ignore", start_index=1, enable_attention_hooks=False, ): pretrained = nn.Module() pretrained.model = model pretrained.model.patch_size = [32, 32] ps = pretrained.model.patch_size[0] if use_vit_only == True: pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1")) pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2")) else: pretrained.model.patch_embed.backbone.stages[0].register_forward_hook( get_activation("1") ) pretrained.model.patch_embed.backbone.stages[1].register_forward_hook( get_activation("2") ) pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3")) pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4")) if enable_attention_hooks: pretrained.model.blocks[2].attn.register_forward_hook(get_attention("attn_1")) pretrained.model.blocks[5].attn.register_forward_hook(get_attention("attn_2")) pretrained.model.blocks[8].attn.register_forward_hook(get_attention("attn_3")) pretrained.model.blocks[11].attn.register_forward_hook(get_attention("attn_4")) pretrained.attention = attention pretrained.activations = activations readout_oper = get_readout_oper(vit_features, features, use_readout, start_index) if use_vit_only == True: pretrained.act_postprocess1 = nn.Sequential( readout_oper[0], Transpose(1, 2), nn.Unflatten(2, torch.Size([size[0] // ps, size[1] // ps])), nn.Conv2d( in_channels=vit_features, out_channels=features[0], kernel_size=1, stride=1, padding=0, ), nn.ConvTranspose2d( in_channels=features[0], out_channels=features[0], kernel_size=4, stride=4, padding=0, bias=True, dilation=1, groups=1, ), ) pretrained.act_postprocess2 = nn.Sequential( readout_oper[1], Transpose(1, 2), nn.Unflatten(2, torch.Size([size[0] // ps, size[1] // ps])), nn.Conv2d( in_channels=vit_features, out_channels=features[1], kernel_size=1, stride=1, padding=0, ), nn.ConvTranspose2d( in_channels=features[1], out_channels=features[1], kernel_size=2, stride=2, padding=0, bias=True, dilation=1, groups=1, ), ) else: pretrained.act_postprocess1 = nn.Sequential( nn.Identity(), nn.Identity(), nn.Identity() ) pretrained.act_postprocess2 = nn.Sequential( nn.Identity(), nn.Identity(), nn.Identity() ) 
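        # In the hybrid ResNet+ViT configuration (use_vit_only=False), the first two hooked
        # activations come from the ResNet stages and are already 4D spatial feature maps,
        # so they need no readout/transpose/unflatten post-processing; the first two
        # post-process stages are therefore left as identity mappings.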
    pretrained.act_postprocess3 = nn.Sequential(
        readout_oper[2],
        Transpose(1, 2),
        nn.Unflatten(2, torch.Size([size[0] // ps, size[1] // ps])),
        nn.Conv2d(
            in_channels=vit_features,
            out_channels=features[2],
            kernel_size=1,
            stride=1,
            padding=0,
        ),
    )

    pretrained.act_postprocess4 = nn.Sequential(
        readout_oper[3],
        Transpose(1, 2),
        nn.Unflatten(2, torch.Size([size[0] // ps, size[1] // ps])),
        nn.Conv2d(
            in_channels=vit_features,
            out_channels=features[3],
            kernel_size=1,
            stride=1,
            padding=0,
        ),
        nn.Conv2d(
            in_channels=features[3],
            out_channels=features[3],
            kernel_size=3,
            stride=2,
            padding=1,
        ),
    )

    pretrained.model.start_index = start_index
    pretrained.model.patch_size = [32, 32]

    # We inject this function into the VisionTransformer instances so that
    # we can use it with interpolated position embeddings without modifying the library source.
    pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)

    # We inject this function into the VisionTransformer instances so that
    # we can use it with interpolated position embeddings without modifying the library source.
    pretrained.model._resize_pos_embed = types.MethodType(
        _resize_pos_embed, pretrained.model
    )

    return pretrained


def _make_pretrained_vitb_rn50_384(
    pretrained,
    use_readout="ignore",
    hooks=None,
    use_vit_only=False,
    enable_attention_hooks=False,
):
    # model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained)
    # model = timm.create_model("vit_tiny_r_s16_p8_384", pretrained=pretrained)
    model = timm.create_model("vit_small_r26_s32_384", pretrained=pretrained)

    hooks = [0, 1, 8, 11] if hooks is None else hooks
    return _make_vit_b_rn50_backbone(
        model,
        features=[128, 256, 384, 384],
        size=[384, 384],
        hooks=hooks,
        use_vit_only=use_vit_only,
        use_readout=use_readout,
        enable_attention_hooks=enable_attention_hooks,
    )


def _make_pretrained_vit_tiny(
    pretrained,
    use_readout="ignore",
    hooks=None,
    use_vit_only=False,
    enable_attention_hooks=False,
):
    # model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained)
    model = timm.create_model("vit_tiny_r_s16_p8_384", pretrained=pretrained)

    hooks = [0, 1, 8, 11] if hooks is None else hooks
    # NOTE: `_make_vit_tiny_backbone` is not defined in this file, so calling this helper
    # raises a NameError; it appears to be unfinished.
    return _make_vit_tiny_backbone(
        model,
        features=[256, 512, 768, 768],
        size=[384, 384],
        hooks=hooks,
        use_vit_only=use_vit_only,
        use_readout=use_readout,
        enable_attention_hooks=enable_attention_hooks,
    )


def _make_pretrained_vitl16_384(
    pretrained, use_readout="ignore", hooks=None, enable_attention_hooks=False
):
    model = timm.create_model("vit_large_patch16_384", pretrained=pretrained)

    hooks = [5, 11, 17, 23] if hooks is None else hooks
    return _make_vit_b16_backbone(
        model,
        features=[256, 512, 1024, 1024],
        hooks=hooks,
        vit_features=1024,
        use_readout=use_readout,
        enable_attention_hooks=enable_attention_hooks,
    )


def _make_pretrained_vitb16_384(
    pretrained, use_readout="ignore", hooks=None, enable_attention_hooks=False
):
    model = timm.create_model("vit_base_patch16_384", pretrained=pretrained)

    hooks = [2, 5, 8, 11] if hooks is None else hooks
    return _make_vit_b16_backbone(
        model,
        features=[96, 192, 384, 768],
        hooks=hooks,
        use_readout=use_readout,
        enable_attention_hooks=enable_attention_hooks,
    )


def _make_pretrained_deitb16_384(
    pretrained, use_readout="ignore", hooks=None, enable_attention_hooks=False
):
    model = timm.create_model("vit_deit_base_patch16_384", pretrained=pretrained)

    hooks = [2, 5, 8, 11] if hooks is None else hooks
    return _make_vit_b16_backbone(
        model,
        features=[96, 192, 384, 768],
        hooks=hooks,
        use_readout=use_readout,
        enable_attention_hooks=enable_attention_hooks,
    )


def _make_pretrained_deitb16_distil_384(
    pretrained, use_readout="ignore", hooks=None, enable_attention_hooks=False
):
    model = timm.create_model(
        "vit_deit_base_distilled_patch16_384", pretrained=pretrained
    )

    hooks = [2, 5, 8, 11] if hooks is None else hooks
    return _make_vit_b16_backbone(
        model,
        features=[96, 192, 384, 768],
        hooks=hooks,
        use_readout=use_readout,
        start_index=2,
        enable_attention_hooks=enable_attention_hooks,
    )


================================================
FILE: mvtracker/models/core/dynamic3dgs/LICENSE.md
================================================
## Notes on license:

The code in this repository (except in external.py) is licensed under the MIT licence.

However, for this code to run it uses the cuda rasterizer code from [here](https://github.com/JonathonLuiten/diff-gaussian-rasterization-w-depth), as well as various code in [external.py](./external.py) which has been taken or adapted from [here](https://github.com/graphdeco-inria/gaussian-splatting). These are required for this project, and for these a much more restrictive license from Inria applies which can be found [here](https://github.com/graphdeco-inria/gaussian-splatting/blob/main/LICENSE.md). This requires express permission (licensing agreements) from Inria for use in any commercial application, but is otherwise freely distributed for research and experimentation.

The MIT License for the code in this repository, where it applies (see above), is below:

## License:

Copyright (c) 2023 Jonathon Luiten

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
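The backbone constructors in mvtracker/models/core/dpt/vit.py above are consumed by the DPT models through forward_vit. The following is a minimal usage sketch, not part of the repository, assuming a timm version compatible with the forward hooks used here; it builds the hooked ViT-B/16 backbone and extracts the four intermediate feature maps that the DPT refinement blocks expect.

import torch

from mvtracker.models.core.dpt.vit import _make_pretrained_vitb16_384, forward_vit

# Build the hooked backbone; pretrained=False skips downloading ImageNet weights.
backbone = _make_pretrained_vitb16_384(False, use_readout="project")

# The input resolution must be divisible by the 16x16 patch size.
x = torch.randn(1, 3, 384, 384)
with torch.no_grad():
    layer_1, layer_2, layer_3, layer_4 = forward_vit(backbone, x)

# The four maps have 96/192/384/768 channels at strides 4/8/16/32 of the input,
# matching what the refinenet blocks built in models.py consume.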
================================================ FILE: mvtracker/models/core/dynamic3dgs/colormap.py ================================================ import numpy as np colormap = np.array([ # 0 , 0, 0, 0.5020, 0, 0, 0, 0.5020, 0, 0.5020, 0.5020, 0, 0, 0, 0.5020, 0.5020, 0, 0.5020, 0, 0.5020, 0.5020, # 0.5020, 0.5020, 0.5020, 0.2510, 0, 0, 0.7529, 0, 0, 0.2510, 0.5020, 0, 0.7529, 0.5020, 0, 0.2510, 0, 0.5020, 0.7529, 0, 0.5020, 0.2510, 0.5020, 0.5020, 0.7529, 0.5020, 0.5020, 0, 0.2510, 0, 0.5020, 0.2510, 0, 0, 0.7529, 0, 0.5020, 0.7529, 0, 0, 0.2510, 0.5020, 0.5020, 0.2510, 0.5020, 0, 0.7529, 0.5020, 0.5020, 0.7529, 0.5020, 0.2510, 0.2510, 0, 0.7529, 0.2510, 0, 0.2510, 0.7529, 0, 0.7529, 0.7529, 0, 0.2510, 0.2510, 0.5020, 0.7529, 0.2510, 0.5020, 0.2510, 0.7529, 0.5020, 0.7529, 0.7529, 0.5020, 0, 0, 0.2510, 0.5020, 0, 0.2510, 0, 0.5020, 0.2510, 0.5020, 0.5020, 0.2510, 0, 0, 0.7529, 0.5020, 0, 0.7529, 0, 0.5020, 0.7529, 0.5020, 0.5020, 0.7529, 0.2510, 0, 0.2510, 0.7529, 0, 0.2510, 0.2510, 0.5020, 0.2510, 0.7529, 0.5020, 0.2510, 0.2510, 0, 0.7529, 0.7529, 0, 0.7529, 0.2510, 0.5020, 0.7529, 0.7529, 0.5020, 0.7529, 0, 0.2510, 0.2510, 0.5020, 0.2510, 0.2510, 0, 0.7529, 0.2510, 0.5020, 0.7529, 0.2510, 0, 0.2510, 0.7529, 0.5020, 0.2510, 0.7529, 0, 0.7529, 0.7529, 0.5020, 0.7529, 0.7529, # 0.2510, 0.2510, 0.2510, 0.7529, 0.2510, 0.2510, 0.2510, 0.7529, 0.2510, 0.7529, 0.7529, 0.2510, 0.2510, 0.2510, 0.7529, 0.7529, 0.2510, 0.7529, 0.2510, 0.7529, 0.7529, # 0.7529, 0.7529, 0.7529, 0.1255, 0, 0, 0.6275, 0, 0, 0.1255, 0.5020, 0, 0.6275, 0.5020, 0, 0.1255, 0, 0.5020, 0.6275, 0, 0.5020, 0.1255, 0.5020, 0.5020, 0.6275, 0.5020, 0.5020, 0.3765, 0, 0, 0.8784, 0, 0, 0.3765, 0.5020, 0, 0.8784, 0.5020, 0, 0.3765, 0, 0.5020, 0.8784, 0, 0.5020, 0.3765, 0.5020, 0.5020, 0.8784, 0.5020, 0.5020, 0.1255, 0.2510, 0, 0.6275, 0.2510, 0, 0.1255, 0.7529, 0, 0.6275, 0.7529, 0, 0.1255, 0.2510, 0.5020, 0.6275, 0.2510, 0.5020, 0.1255, 0.7529, 0.5020, 0.6275, 0.7529, 0.5020, 0.3765, 0.2510, 0, 0.8784, 0.2510, 0, 0.3765, 0.7529, 0, 0.8784, 0.7529, 0, 0.3765, 0.2510, 0.5020, 0.8784, 0.2510, 0.5020, 0.3765, 0.7529, 0.5020, 0.8784, 0.7529, 0.5020, 0.1255, 0, 0.2510, 0.6275, 0, 0.2510, 0.1255, 0.5020, 0.2510, 0.6275, 0.5020, 0.2510, 0.1255, 0, 0.7529, 0.6275, 0, 0.7529, 0.1255, 0.5020, 0.7529, 0.6275, 0.5020, 0.7529, 0.3765, 0, 0.2510, 0.8784, 0, 0.2510, 0.3765, 0.5020, 0.2510, 0.8784, 0.5020, 0.2510, 0.3765, 0, 0.7529, 0.8784, 0, 0.7529, 0.3765, 0.5020, 0.7529, 0.8784, 0.5020, 0.7529, 0.1255, 0.2510, 0.2510, 0.6275, 0.2510, 0.2510, 0.1255, 0.7529, 0.2510, 0.6275, 0.7529, 0.2510, 0.1255, 0.2510, 0.7529, 0.6275, 0.2510, 0.7529, 0.1255, 0.7529, 0.7529, 0.6275, 0.7529, 0.7529, 0.3765, 0.2510, 0.2510, 0.8784, 0.2510, 0.2510, 0.3765, 0.7529, 0.2510, 0.8784, 0.7529, 0.2510, 0.3765, 0.2510, 0.7529, 0.8784, 0.2510, 0.7529, 0.3765, 0.7529, 0.7529, 0.8784, 0.7529, 0.7529, 0, 0.1255, 0, 0.5020, 0.1255, 0, 0, 0.6275, 0, 0.5020, 0.6275, 0, 0, 0.1255, 0.5020, 0.5020, 0.1255, 0.5020, 0, 0.6275, 0.5020, 0.5020, 0.6275, 0.5020, 0.2510, 0.1255, 0, 0.7529, 0.1255, 0, 0.2510, 0.6275, 0, 0.7529, 0.6275, 0, 0.2510, 0.1255, 0.5020, 0.7529, 0.1255, 0.5020, 0.2510, 0.6275, 0.5020, 0.7529, 0.6275, 0.5020, 0, 0.3765, 0, 0.5020, 0.3765, 0, 0, 0.8784, 0, 0.5020, 0.8784, 0, 0, 0.3765, 0.5020, 0.5020, 0.3765, 0.5020, 0, 0.8784, 0.5020, 0.5020, 0.8784, 0.5020, 0.2510, 0.3765, 0, 0.7529, 0.3765, 0, 0.2510, 0.8784, 0, 0.7529, 0.8784, 0, 0.2510, 0.3765, 0.5020, 0.7529, 0.3765, 0.5020, 0.2510, 0.8784, 0.5020, 0.7529, 0.8784, 0.5020, 0, 0.1255, 0.2510, 0.5020, 0.1255, 
0.2510, 0, 0.6275, 0.2510, 0.5020, 0.6275, 0.2510, 0, 0.1255, 0.7529, 0.5020, 0.1255, 0.7529, 0, 0.6275, 0.7529, 0.5020, 0.6275, 0.7529, 0.2510, 0.1255, 0.2510, 0.7529, 0.1255, 0.2510, 0.2510, 0.6275, 0.2510, 0.7529, 0.6275, 0.2510, 0.2510, 0.1255, 0.7529, 0.7529, 0.1255, 0.7529, 0.2510, 0.6275, 0.7529, 0.7529, 0.6275, 0.7529, 0, 0.3765, 0.2510, 0.5020, 0.3765, 0.2510, 0, 0.8784, 0.2510, 0.5020, 0.8784, 0.2510, 0, 0.3765, 0.7529, 0.5020, 0.3765, 0.7529, 0, 0.8784, 0.7529, 0.5020, 0.8784, 0.7529, 0.2510, 0.3765, 0.2510, 0.7529, 0.3765, 0.2510, 0.2510, 0.8784, 0.2510, 0.7529, 0.8784, 0.2510, 0.2510, 0.3765, 0.7529, 0.7529, 0.3765, 0.7529, 0.2510, 0.8784, 0.7529, 0.7529, 0.8784, 0.7529, 0.1255, 0.1255, 0, 0.6275, 0.1255, 0, 0.1255, 0.6275, 0, 0.6275, 0.6275, 0, 0.1255, 0.1255, 0.5020, 0.6275, 0.1255, 0.5020, 0.1255, 0.6275, 0.5020, 0.6275, 0.6275, 0.5020, 0.3765, 0.1255, 0, 0.8784, 0.1255, 0, 0.3765, 0.6275, 0, 0.8784, 0.6275, 0, 0.3765, 0.1255, 0.5020, 0.8784, 0.1255, 0.5020, 0.3765, 0.6275, 0.5020, 0.8784, 0.6275, 0.5020, 0.1255, 0.3765, 0, 0.6275, 0.3765, 0, 0.1255, 0.8784, 0, 0.6275, 0.8784, 0, 0.1255, 0.3765, 0.5020, 0.6275, 0.3765, 0.5020, 0.1255, 0.8784, 0.5020, 0.6275, 0.8784, 0.5020, 0.3765, 0.3765, 0, 0.8784, 0.3765, 0, 0.3765, 0.8784, 0, 0.8784, 0.8784, 0, 0.3765, 0.3765, 0.5020, 0.8784, 0.3765, 0.5020, 0.3765, 0.8784, 0.5020, 0.8784, 0.8784, 0.5020, 0.1255, 0.1255, 0.2510, 0.6275, 0.1255, 0.2510, 0.1255, 0.6275, 0.2510, 0.6275, 0.6275, 0.2510, 0.1255, 0.1255, 0.7529, 0.6275, 0.1255, 0.7529, 0.1255, 0.6275, 0.7529, 0.6275, 0.6275, 0.7529, 0.3765, 0.1255, 0.2510, 0.8784, 0.1255, 0.2510, 0.3765, 0.6275, 0.2510, 0.8784, 0.6275, 0.2510, 0.3765, 0.1255, 0.7529, 0.8784, 0.1255, 0.7529, 0.3765, 0.6275, 0.7529, 0.8784, 0.6275, 0.7529, 0.1255, 0.3765, 0.2510, 0.6275, 0.3765, 0.2510, 0.1255, 0.8784, 0.2510, 0.6275, 0.8784, 0.2510, 0.1255, 0.3765, 0.7529, 0.6275, 0.3765, 0.7529, 0.1255, 0.8784, 0.7529, 0.6275, 0.8784, 0.7529, 0.3765, 0.3765, 0.2510, 0.8784, 0.3765, 0.2510, 0.3765, 0.8784, 0.2510, 0.8784, 0.8784, 0.2510, 0.3765, 0.3765, 0.7529, 0.8784, 0.3765, 0.7529, 0.3765, 0.8784, 0.7529, 0.8784, 0.8784, 0.7529, # 1.0, 1.0, 1.0, ]).reshape(-1, 3) ================================================ FILE: mvtracker/models/core/dynamic3dgs/export_depths_from_pretrained_checkpoint.py ================================================ import json import os from pathlib import Path import numpy as np import torch from PIL import Image from diff_gaussian_rasterization import GaussianRasterizer as Renderer from tqdm import tqdm from .helpers import setup_camera def load_scene_data(params_path, seg_as_col=False): """Load 3D scene data from file.""" params = dict(np.load(params_path, allow_pickle=True)) params = {k: torch.tensor(v).cuda().float() for k, v in params.items()} is_fg = params['seg_colors'][:, 0] > 0.5 scene_data = [] for t in range(len(params['means3D'])): rendervar = { 'means3D': params['means3D'][t], 'colors_precomp': params['rgb_colors'][t] if not seg_as_col else params['seg_colors'], 'rotations': torch.nn.functional.normalize(params['unnorm_rotations'][t]), 'opacities': torch.sigmoid(params['logit_opacities']), 'scales': torch.exp(params['log_scales']), 'means2D': torch.zeros_like(params['means3D'][0], device="cuda") } scene_data.append(rendervar) return scene_data, is_fg def render(w, h, k, w2c, timestep_data, near=0.01, far=100.0): """Render scene using Gaussian Rasterization.""" with torch.no_grad(): cam = setup_camera(w, h, k, w2c, near, far) im, _, depth = 
Renderer(raster_settings=cam)(**timestep_data) return im, depth def export_depth(scene_root, output_root, checkpoint_path): scene_data, is_fg = load_scene_data(os.path.join(checkpoint_path, "params.npz")) md_train = json.load(open(os.path.join(scene_root, "train_meta.json"), "r")) md_test = json.load(open(os.path.join(scene_root, "test_meta.json"), "r")) views = sorted(list(set(md_train["cam_id"][0]) | set(md_test["cam_id"][0]))) assert list(range(31)) == views, "We expect exactly 31 views: from 0 to 30." n_frames = len(md_train['fn']) n_views = len(views) # Check that the selected views are in the training set view_paths = [] for view_idx in views: view_path = scene_root / "ims" / f"{view_idx}" assert view_path.exists() view_paths.append(view_path) frame_paths = [sorted(view_path.glob("*.jpg")) for view_path in view_paths] assert all(len(frame_paths[v]) == n_frames for v in range(n_views)) assert len(scene_data) == n_frames # Load the camera parameters fx, fy, cx, cy, extrinsics = [], [], [], [], [] for view_idx in views: fx_current, fy_current, cx_current, cy_current, extrinsics_current = [], [], [], [], [] for t in range(n_frames): if view_idx in md_train['cam_id'][t]: md = md_train elif view_idx in md_test['cam_id'][t]: md = md_test else: raise ValueError(f"Camera {view_idx} not found in any of the meta files") view_idx_in_array = md['cam_id'][t].index(view_idx) k = md['k'][t][view_idx_in_array] w2c = np.array(md['w2c'][t][view_idx_in_array]) fx_current.append(k[0][0]) fy_current.append(k[1][1]) cx_current.append(k[0][2]) cy_current.append(k[1][2]) extrinsics_current.append(w2c) assert all(np.equal(fx_current[0], fx_current[t]).all() for t in range(1, n_frames)) assert all(np.equal(fy_current[0], fy_current[t]).all() for t in range(1, n_frames)) assert all(np.equal(cx_current[0], cx_current[t]).all() for t in range(1, n_frames)) assert all(np.equal(cy_current[0], cy_current[t]).all() for t in range(1, n_frames)) assert all(np.equal(extrinsics_current[0], extrinsics_current[t]).all() for t in range(1, n_frames)) fx.append(fx_current[0]) fy.append(fy_current[0]) cx.append(cx_current[0]) cy.append(cy_current[0]) extrinsics.append(extrinsics_current[0]) fx = torch.tensor(fx).float() fy = torch.tensor(fy).float() cx = torch.tensor(cx).float() cy = torch.tensor(cy).float() k = torch.eye(3).float()[None].repeat(n_views, 1, 1) k[:, 0, 0] = fx k[:, 1, 1] = fy k[:, 0, 2] = cx k[:, 1, 2] = cy extrinsics = torch.from_numpy(np.stack(extrinsics)).float() # Render and save the depths os.makedirs(output_root, exist_ok=True) rgbs = np.stack([ np.stack([ np.array(Image.open(frame_paths[v][t])) for t in range(n_frames) ]) for v in range(n_views) ]) h, w = rgbs.shape[2], rgbs.shape[3] for v, view_idx in enumerate(views): depths = [] for t in range(n_frames): im, depth = render(w, h, k[v].numpy(), extrinsics[v].numpy(), scene_data[t]) depths.append(depth.cpu().numpy()[0]) depths = np.stack(depths) np.save(output_root / f"depths_{view_idx:02d}.npy", depths) if __name__ == "__main__": print("Exporting depths from pretrained checkpoints") for sequence_name in tqdm(["basketball", "boxes", "football", "juggle", "softball", "tennis"]): scene_root = Path(f"./datasets/panoptic_d3dgs/{sequence_name}") output_path = Path(f"./datasets/panoptic_d3dgs/{sequence_name}/dynamic3dgs_depth") checkpoint_path = Path(f"./dynamic3dgs/output/pretrained/{sequence_name}") export_depth(scene_root, output_path, checkpoint_path) ================================================ FILE: mvtracker/models/core/dynamic3dgs/external.py 
================================================ """ # Copyright (C) 2023, Inria # GRAPHDECO research group, https://team.inria.fr/graphdeco # All rights reserved. # # This software is free for non-commercial, research and evaluation use # under the terms of the LICENSE.md file found here: # https://github.com/graphdeco-inria/gaussian-splatting/blob/main/LICENSE.md # # For inquiries contact george.drettakis@inria.fr ####################################################################################################################### ##### NOTE: CODE IN THIS FILE IS NOT INCLUDED IN THE OVERALL PROJECT'S MIT LICENSE ##### ##### USE OF THIS CODE FOLLOWS THE COPYRIGHT NOTICE ABOVE ##### ####################################################################################################################### """ import torch import torch.nn.functional as func from math import exp from torch.autograd import Variable def build_rotation(q): norm = torch.sqrt(q[:, 0] * q[:, 0] + q[:, 1] * q[:, 1] + q[:, 2] * q[:, 2] + q[:, 3] * q[:, 3]) q = q / norm[:, None] rot = torch.zeros((q.size(0), 3, 3), device='cuda') r = q[:, 0] x = q[:, 1] y = q[:, 2] z = q[:, 3] rot[:, 0, 0] = 1 - 2 * (y * y + z * z) rot[:, 0, 1] = 2 * (x * y - r * z) rot[:, 0, 2] = 2 * (x * z + r * y) rot[:, 1, 0] = 2 * (x * y + r * z) rot[:, 1, 1] = 1 - 2 * (x * x + z * z) rot[:, 1, 2] = 2 * (y * z - r * x) rot[:, 2, 0] = 2 * (x * z - r * y) rot[:, 2, 1] = 2 * (y * z + r * x) rot[:, 2, 2] = 1 - 2 * (x * x + y * y) return rot def calc_mse(img1, img2): return ((img1 - img2) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True) def calc_psnr(img1, img2): mse = ((img1 - img2) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True) return 20 * torch.log10(1.0 / torch.sqrt(mse)) def gaussian(window_size, sigma): gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)]) return gauss / gauss.sum() def create_window(window_size, channel): _1D_window = gaussian(window_size, 1.5).unsqueeze(1) _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) return window def calc_ssim(img1, img2, window_size=11, size_average=True): channel = img1.size(-3) window = create_window(window_size, channel) if img1.is_cuda: window = window.cuda(img1.get_device()) window = window.type_as(img1) return _ssim(img1, img2, window, window_size, channel, size_average) def _ssim(img1, img2, window, window_size, channel, size_average=True): mu1 = func.conv2d(img1, window, padding=window_size // 2, groups=channel) mu2 = func.conv2d(img2, window, padding=window_size // 2, groups=channel) mu1_sq = mu1.pow(2) mu2_sq = mu2.pow(2) mu1_mu2 = mu1 * mu2 sigma1_sq = func.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq sigma2_sq = func.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq sigma12 = func.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2 c1 = 0.01 ** 2 c2 = 0.03 ** 2 ssim_map = ((2 * mu1_mu2 + c1) * (2 * sigma12 + c2)) / ((mu1_sq + mu2_sq + c1) * (sigma1_sq + sigma2_sq + c2)) if size_average: return ssim_map.mean() else: return ssim_map.mean(1).mean(1).mean(1) def accumulate_mean2d_gradient(variables): variables['means2D_gradient_accum'][variables['seen']] += torch.norm( variables['means2D'].grad[variables['seen'], :2], dim=-1) variables['denom'][variables['seen']] += 1 return variables def update_params_and_optimizer(new_params, 
params, optimizer): for k, v in new_params.items(): group = [x for x in optimizer.param_groups if x["name"] == k][0] stored_state = optimizer.state.get(group['params'][0], None) stored_state["exp_avg"] = torch.zeros_like(v) stored_state["exp_avg_sq"] = torch.zeros_like(v) del optimizer.state[group['params'][0]] group["params"][0] = torch.nn.Parameter(v.requires_grad_(True)) optimizer.state[group['params'][0]] = stored_state params[k] = group["params"][0] return params def cat_params_to_optimizer(new_params, params, optimizer): for k, v in new_params.items(): group = [g for g in optimizer.param_groups if g['name'] == k][0] stored_state = optimizer.state.get(group['params'][0], None) if stored_state is not None: stored_state["exp_avg"] = torch.cat((stored_state["exp_avg"], torch.zeros_like(v)), dim=0) stored_state["exp_avg_sq"] = torch.cat((stored_state["exp_avg_sq"], torch.zeros_like(v)), dim=0) del optimizer.state[group['params'][0]] group["params"][0] = torch.nn.Parameter(torch.cat((group["params"][0], v), dim=0).requires_grad_(True)) optimizer.state[group['params'][0]] = stored_state params[k] = group["params"][0] else: group["params"][0] = torch.nn.Parameter(torch.cat((group["params"][0], v), dim=0).requires_grad_(True)) params[k] = group["params"][0] return params def remove_points(to_remove, params, variables, optimizer): to_keep = ~to_remove keys = [k for k in params.keys() if k not in ['cam_m', 'cam_c']] for k in keys: group = [g for g in optimizer.param_groups if g['name'] == k][0] stored_state = optimizer.state.get(group['params'][0], None) if stored_state is not None: stored_state["exp_avg"] = stored_state["exp_avg"][to_keep] stored_state["exp_avg_sq"] = stored_state["exp_avg_sq"][to_keep] del optimizer.state[group['params'][0]] group["params"][0] = torch.nn.Parameter((group["params"][0][to_keep].requires_grad_(True))) optimizer.state[group['params'][0]] = stored_state params[k] = group["params"][0] else: group["params"][0] = torch.nn.Parameter(group["params"][0][to_keep].requires_grad_(True)) params[k] = group["params"][0] variables['means2D_gradient_accum'] = variables['means2D_gradient_accum'][to_keep] variables['denom'] = variables['denom'][to_keep] variables['max_2D_radius'] = variables['max_2D_radius'][to_keep] return params, variables def inverse_sigmoid(x): return torch.log(x / (1 - x)) def densify(params, variables, optimizer, i): if i <= 5000: variables = accumulate_mean2d_gradient(variables) grad_thresh = 0.0002 if (i >= 500) and (i % 100 == 0): grads = variables['means2D_gradient_accum'] / variables['denom'] grads[grads.isnan()] = 0.0 to_clone = torch.logical_and(grads >= grad_thresh, ( torch.max(torch.exp(params['log_scales']), dim=1).values <= 0.01 * variables['scene_radius'])) new_params = {k: v[to_clone] for k, v in params.items() if k not in ['cam_m', 'cam_c']} params = cat_params_to_optimizer(new_params, params, optimizer) num_pts = params['means3D'].shape[0] padded_grad = torch.zeros(num_pts, device="cuda") padded_grad[:grads.shape[0]] = grads to_split = torch.logical_and(padded_grad >= grad_thresh, torch.max(torch.exp(params['log_scales']), dim=1).values > 0.01 * variables[ 'scene_radius']) n = 2 # number to split into new_params = {k: v[to_split].repeat(n, 1) for k, v in params.items() if k not in ['cam_m', 'cam_c']} stds = torch.exp(params['log_scales'])[to_split].repeat(n, 1) means = torch.zeros((stds.size(0), 3), device="cuda") samples = torch.normal(mean=means, std=stds) rots = build_rotation(params['unnorm_rotations'][to_split]).repeat(n, 1, 1) 
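            # Split step: each selected Gaussian is duplicated n times; the copies'
            # centers are offset by samples drawn from the Gaussian itself (std = its
            # scale, rotated into the world frame), and the copies' scales are shrunk
            # by a factor of 0.8 * n before being added back to the optimizer.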
new_params['means3D'] += torch.bmm(rots, samples.unsqueeze(-1)).squeeze(-1) new_params['log_scales'] = torch.log(torch.exp(new_params['log_scales']) / (0.8 * n)) params = cat_params_to_optimizer(new_params, params, optimizer) num_pts = params['means3D'].shape[0] variables['means2D_gradient_accum'] = torch.zeros(num_pts, device="cuda") variables['denom'] = torch.zeros(num_pts, device="cuda") variables['max_2D_radius'] = torch.zeros(num_pts, device="cuda") to_remove = torch.cat((to_split, torch.zeros(n * to_split.sum(), dtype=torch.bool, device="cuda"))) params, variables = remove_points(to_remove, params, variables, optimizer) remove_threshold = 0.25 if i == 5000 else 0.005 to_remove = (torch.sigmoid(params['logit_opacities']) < remove_threshold).squeeze() if i >= 3000: big_points_ws = torch.exp(params['log_scales']).max(dim=1).values > 0.1 * variables['scene_radius'] to_remove = torch.logical_or(to_remove, big_points_ws) params, variables = remove_points(to_remove, params, variables, optimizer) torch.cuda.empty_cache() if i > 0 and i % 3000 == 0: new_params = {'logit_opacities': inverse_sigmoid(torch.ones_like(params['logit_opacities']) * 0.01)} params = update_params_and_optimizer(new_params, params, optimizer) return params, variables ================================================ FILE: mvtracker/models/core/dynamic3dgs/helpers.py ================================================ import os import numpy as np import open3d as o3d import torch from diff_gaussian_rasterization import GaussianRasterizationSettings as Camera def setup_camera(w, h, k, w2c, near=0.01, far=100): fx, fy, cx, cy = k[0][0], k[1][1], k[0][2], k[1][2] w2c = torch.tensor(w2c).cuda().float() cam_center = torch.inverse(w2c)[:3, 3] w2c = w2c.unsqueeze(0).transpose(1, 2) opengl_proj = torch.tensor([[2 * fx / w, 0.0, -(w - 2 * cx) / w, 0.0], [0.0, 2 * fy / h, -(h - 2 * cy) / h, 0.0], [0.0, 0.0, far / (far - near), -(far * near) / (far - near)], [0.0, 0.0, 1.0, 0.0]]).cuda().float().unsqueeze(0).transpose(1, 2) full_proj = w2c.bmm(opengl_proj) cam = Camera( image_height=h, image_width=w, tanfovx=w / (2 * fx), tanfovy=h / (2 * fy), bg=torch.tensor([0, 0, 0], dtype=torch.float32, device="cuda"), scale_modifier=1.0, viewmatrix=w2c, projmatrix=full_proj, sh_degree=0, campos=cam_center, prefiltered=False ) return cam def params2rendervar(params): rendervar = { 'means3D': params['means3D'], 'colors_precomp': params['rgb_colors'], 'rotations': torch.nn.functional.normalize(params['unnorm_rotations']), 'opacities': torch.sigmoid(params['logit_opacities']), 'scales': torch.exp(params['log_scales']), 'means2D': torch.zeros_like(params['means3D'], requires_grad=True, device="cuda") + 0 } return rendervar def l1_loss_v1(x, y): return torch.abs((x - y)).mean() def l1_loss_v2(x, y): return (torch.abs(x - y).sum(-1)).mean() def weighted_l2_loss_v1(x, y, w): return torch.sqrt(((x - y) ** 2) * w + 1e-20).mean() def weighted_l2_loss_v2(x, y, w): return torch.sqrt(((x - y) ** 2).sum(-1) * w + 1e-20).mean() def quat_mult(q1, q2): w1, x1, y1, z1 = q1.T w2, x2, y2, z2 = q2.T w = w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2 x = w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2 y = w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2 z = w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2 return torch.stack([w, x, y, z]).T def o3d_knn(pts, num_knn): indices = [] sq_dists = [] pcd = o3d.geometry.PointCloud() pcd.points = o3d.utility.Vector3dVector(np.ascontiguousarray(pts, np.float64)) pcd_tree = o3d.geometry.KDTreeFlann(pcd) for p in pcd.points: [_, i, d] = pcd_tree.search_knn_vector_3d(p, 
num_knn + 1) indices.append(i[1:]) sq_dists.append(d[1:]) return np.array(sq_dists), np.array(indices) def params2cpu(params, is_initial_timestep): if is_initial_timestep: res = {k: v.detach().cpu().contiguous().numpy() for k, v in params.items()} else: res = {k: v.detach().cpu().contiguous().numpy() for k, v in params.items() if k in ['means3D', 'rgb_colors', 'unnorm_rotations']} return res def save_params(output_params, seq, exp): to_save = {} for k in output_params[0].keys(): if k in output_params[1].keys(): to_save[k] = np.stack([params[k] for params in output_params]) else: to_save[k] = output_params[0][k] os.makedirs(f"./output/{exp}/{seq}", exist_ok=True) np.savez(f"./output/{exp}/{seq}/params", **to_save) ================================================ FILE: mvtracker/models/core/dynamic3dgs/merge_tapvid3d_per_camera_annotations.py ================================================ import json import os import warnings from pathlib import Path import matplotlib import numpy as np import rerun as rr import torch from PIL import Image from diff_gaussian_rasterization import GaussianRasterizer as Renderer from tqdm import tqdm from .helpers import setup_camera from .visualize import log_tracks_to_rerun def to_homogeneous(x): return np.concatenate([x, np.ones_like(x[..., :1])], axis=-1) def from_homogeneous(x, assert_homogeneous_part_is_equal_to_1=False, eps=0.001): if assert_homogeneous_part_is_equal_to_1: assert np.allclose(x[..., -1:], 1, atol=eps), f"Expected homogeneous part to be 1, got {x[..., -1:]}" return x[..., :-1] / x[..., -1:] def load_scene_data(params_path, seg_as_col=False): """Load 3D scene data from file.""" params = dict(np.load(params_path, allow_pickle=True)) params = {k: torch.tensor(v).cuda().float() for k, v in params.items()} is_fg = params['seg_colors'][:, 0] > 0.5 scene_data = [] for t in range(len(params['means3D'])): rendervar = { 'means3D': params['means3D'][t], 'colors_precomp': params['rgb_colors'][t] if not seg_as_col else params['seg_colors'], 'rotations': torch.nn.functional.normalize(params['unnorm_rotations'][t]), 'opacities': torch.sigmoid(params['logit_opacities']), 'scales': torch.exp(params['log_scales']), 'means2D': torch.zeros_like(params['means3D'][0], device="cuda") } scene_data.append(rendervar) return scene_data, is_fg def render(h, w, k, w2c, timestep_data, near=0.01, far=100.0): """Render scene using Gaussian Rasterization.""" with torch.no_grad(): cam = setup_camera(w, h, k, w2c, near, far) im, _, depth = Renderer(raster_settings=cam)(**timestep_data) return im, depth def merge_annotations( scene_root, checkpoint_path, tapvid3d_annotation_paths, nearest_neighbor_distance_threshold_for_visibility=0.015, skip_if_output_already_exists=False, assert_query_points_project_to_trajectories_in_tapvid3d_annotation=False, rerun_logging=False, rerun_stream_only=False, rerun_views_to_viz=(27, 16, 1), rerun_log_rgb=True, rerun_log_d3dgs_rgb=False, rerun_log_d3dgs_depth=False, rerun_log_d3dgs_point_cloud=True, rerun_log_tracks=True, rerun_log_n_skip_t=1, ): output_annotation_path = scene_root / "tapvid3d_annotations.npz" if skip_if_output_already_exists and output_annotation_path.exists(): print(f"Output file {output_annotation_path} already exists, skipping.") return scene_data, is_fg = load_scene_data(os.path.join(checkpoint_path, "params.npz")) md_train = json.load(open(os.path.join(scene_root, "train_meta.json"), "r")) md_test = json.load(open(os.path.join(scene_root, "test_meta.json"), "r")) views = sorted(list(set(md_train["cam_id"][0]) | 
set(md_test["cam_id"][0]))) assert list(range(31)) == views, "We expect exactly 31 views: from 0 to 30." n_frames = len(md_train['fn']) n_views = len(views) # Check that the selected views are in the training set view_paths = [] for view_idx in views: view_path = scene_root / "ims" / f"{view_idx}" assert view_path.exists() view_paths.append(view_path) frame_paths = [sorted(view_path.glob("*.jpg")) for view_path in view_paths] assert all(len(frame_paths[v]) == n_frames for v in range(n_views)) assert len(scene_data) == n_frames # Load the camera parameters fx, fy, cx, cy, extrinsics = [], [], [], [], [] for view_idx in views: fx_current, fy_current, cx_current, cy_current, extrinsics_current = [], [], [], [], [] for t in range(n_frames): if view_idx in md_train['cam_id'][t]: md = md_train elif view_idx in md_test['cam_id'][t]: md = md_test else: raise ValueError(f"Camera {view_idx} not found in any of the meta files") view_idx_in_array = md['cam_id'][t].index(view_idx) k = md['k'][t][view_idx_in_array] w2c = np.array(md['w2c'][t][view_idx_in_array]) fx_current.append(k[0][0]) fy_current.append(k[1][1]) cx_current.append(k[0][2]) cy_current.append(k[1][2]) extrinsics_current.append(w2c) assert all(np.equal(fx_current[0], fx_current[t]).all() for t in range(1, n_frames)) assert all(np.equal(fy_current[0], fy_current[t]).all() for t in range(1, n_frames)) assert all(np.equal(cx_current[0], cx_current[t]).all() for t in range(1, n_frames)) assert all(np.equal(cy_current[0], cy_current[t]).all() for t in range(1, n_frames)) assert all(np.equal(extrinsics_current[0], extrinsics_current[t]).all() for t in range(1, n_frames)) fx.append(fx_current[0]) fy.append(fy_current[0]) cx.append(cx_current[0]) cy.append(cy_current[0]) extrinsics.append(extrinsics_current[0]) k = np.eye(3).astype(np.float64)[None].repeat(n_views, 0) k[:, 0, 0] = fx k[:, 1, 1] = fy k[:, 0, 2] = cx k[:, 1, 2] = cy extrinsics = np.stack(extrinsics).astype(np.float64) k_inv = np.linalg.inv(k) extrinsics_inv = np.linalg.inv(extrinsics) # Render imgs and depths rgbs = np.stack([ np.stack([ np.array(Image.open(frame_paths[v][t])) for t in range(n_frames) ]) for v in range(n_views) ]) h, w = rgbs.shape[2], rgbs.shape[3] d3dgs_rgbs = [] d3dgs_depths = [] for v, view_idx in enumerate(views): for t in range(n_frames): im, depth = render(h, w, k[v], extrinsics[v], scene_data[t]) d3dgs_rgbs.append(im.cpu().numpy().transpose(1, 2, 0)) d3dgs_depths.append(depth.cpu().numpy()[0]) d3dgs_rgbs = np.stack(d3dgs_rgbs).reshape(n_views, n_frames, h, w, 3) d3dgs_depths = np.stack(d3dgs_depths).reshape(n_views, n_frames, h, w) assert rgbs.shape == (n_views, n_frames, h, w, 3) assert d3dgs_rgbs.shape == (n_views, n_frames, h, w, 3) assert d3dgs_depths.shape == (n_views, n_frames, h, w) # Merge TAP-Vid3D annotations merged_trajectories = [] merged_trajectories_pixelspace = [] merged_per_view_visibilities = [] merged_query_points_3d = [] for tapvid3d_annotation_path in tqdm(tapvid3d_annotation_paths): annotation = np.load(tapvid3d_annotation_path) queries_xyt = annotation["queries_xyt"] tracks_XYZ = annotation["tracks_XYZ"] visibility = annotation["visibility"] fx_fy_cx_cy = annotation["fx_fy_cx_cy"] images_jpeg_bytes = annotation["images_jpeg_bytes"] _, cam_id = os.path.basename(tapvid3d_annotation_path)[:-4].split("_") cam_id = int(cam_id) assert cam_id == views.index(cam_id) n_tracks, _ = queries_xyt.shape assert cam_id in views assert queries_xyt.shape == (n_tracks, 3) assert fx_fy_cx_cy.shape == (4,) assert images_jpeg_bytes.shape == (n_frames,) 
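        # Per-camera TAP-Vid3D annotation fields (all expressed in this camera's frame):
        #   queries_xyt:       (n_tracks, 3) query pixel (x, y) and query frame index t
        #   tracks_XYZ:        (n_frames, n_tracks, 3) camera-space 3D trajectories
        #   visibility:        (n_frames, n_tracks) per-frame visibility flags
        #   fx_fy_cx_cy:       (4,) pinhole intrinsics of this camera
        #   images_jpeg_bytes: (n_frames,) JPEG-encoded frames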
assert tracks_XYZ.shape == (n_frames, n_tracks, 3) assert visibility.shape == (n_frames, n_tracks) assert np.allclose(fx_fy_cx_cy, [fx[cam_id], fy[cam_id], cx[cam_id], cy[cam_id]]) # Project the tracks to the world space cam_coords_homo = to_homogeneous(tracks_XYZ) world_coords_homo = np.einsum("ij,SNj->SNi", extrinsics_inv[cam_id], cam_coords_homo) world_coords = from_homogeneous(world_coords_homo, assert_homogeneous_part_is_equal_to_1=True) # Project query points to 3D to verify we can reproduce the camera space points qp_t = queries_xyt[:, 2].astype(np.int32) qp_xy_pixel = queries_xyt[:, :2].astype(np.float32) qp_depth = np.ones((n_tracks, 1), dtype=np.float32) * np.inf qp_xyz_camera = np.ones((n_tracks, 3), dtype=np.float32) * np.inf qp_xyz_world = np.ones((n_tracks, 3), dtype=np.float32) * np.inf for t in range(n_frames): qp_mask = qp_t == t if qp_mask.sum() == 0: continue # V2 depth interpolation x_nearest = qp_xy_pixel[qp_mask, 0].round().astype(np.int32).clip(0, w - 1) y_nearest = qp_xy_pixel[qp_mask, 1].round().astype(np.int32).clip(0, h - 1) depth_nearest = d3dgs_depths[cam_id, t].reshape(-1)[ (y_nearest * w + x_nearest).reshape(-1)] depth_nearest = depth_nearest.reshape(-1, 1) qp_depth[qp_mask] = depth_nearest qp_xyz_pixel_t = np.concatenate([qp_xy_pixel[qp_mask], np.ones_like(qp_xy_pixel[qp_mask][..., :1])], axis=1) qp_xyz_camera_t = np.einsum("ij,Nj->Ni", k_inv[cam_id], qp_xyz_pixel_t) * qp_depth[qp_mask] qp_xyz_world_t = np.einsum("ij,Nj->Ni", extrinsics_inv[cam_id], np.concatenate([qp_xyz_camera_t, np.ones_like(qp_xyz_camera_t[..., :1])], axis=1))[:, :3] qp_xyz_camera[qp_mask] = qp_xyz_camera_t qp_xyz_world[qp_mask] = qp_xyz_world_t assert np.all(np.isfinite(qp_depth)) assert np.all(np.isfinite(qp_xyz_camera)) assert np.all(np.isfinite(qp_xyz_world)) # Verify that the query points are close to the tracks in the world space qp_projection_diff = np.linalg.norm( qp_xyz_camera - tracks_XYZ[queries_xyt[:, 2].astype(np.int32), np.arange(n_tracks)], axis=1) repro1 = np.percentile(qp_projection_diff, 80) < 1 repro2 = qp_projection_diff.mean() < 0.1 if not repro1 or not repro2: warnings.warn(f"Projecting query points to match tracks in camera space failed. " f"Differences: max={qp_projection_diff.max():0.3f}, " f"mean={qp_projection_diff.mean():0.3f}, " f"median={np.percentile(qp_projection_diff, 50):0.3f}, " f"p80={np.percentile(qp_projection_diff, 80):0.3f}") if assert_query_points_project_to_trajectories_in_tapvid3d_annotation: assert repro1 assert repro2 # Verify that the projected tracks are close to the query points in pixel space cam_coords_per_view = from_homogeneous(np.einsum("Vij,SNj->VSNi", extrinsics, world_coords_homo), True) pixel_coords_per_view = from_homogeneous(np.einsum("Vij,VSNj->VSNi", k, cam_coords_per_view)) diff = np.linalg.norm(qp_xy_pixel - pixel_coords_per_view[cam_id][qp_t, np.arange(n_tracks)], axis=-1) repro3 = np.percentile(diff, 80) < 0.1 # The xy pixel query from queries_xyz in the raw labels sometimes doesn't match the tracks_XYZ in camera space. # In the merged labels, we will not use the queries_xyz, but just directly work with the tracks_XYZ and their # projections (where pixel-space projections are needed). if not repro3: warnings.warn(f"Projecting tracks to pixel space to match query points failed. " f"Max diff: {diff.max()}. Mean diff: {diff.mean()}. Median diff: {np.percentile(diff, 50)}. 
" f"Percentile 80: {np.percentile(diff, 80)}.") if assert_query_points_project_to_trajectories_in_tapvid3d_annotation: assert repro3 # import matplotlib.pyplot as plt # plt.imshow(rgbs[v, qp_t[0]]) # plt.scatter(qp_xy_pixel[0, 0], qp_xy_pixel[0, 1], color="red") # plt.scatter(pixel_coords_per_view[cam_id][qp_t[0], 0, 0], pixel_coords_per_view[cam_id][qp_t[0], 0, 1], color="green") # plt.show() # Compute the distance from the trajectories to their nearest depthmap neighbors depthmap_nearest_neighbor_distance = np.ones((n_views, n_frames, n_tracks), dtype=np.float32) * np.inf k_inv_torch = torch.from_numpy(k_inv).cuda() extrinsics_inv_torch = torch.from_numpy(extrinsics_inv).cuda() pixel_coords_per_view_round_torch = torch.from_numpy(pixel_coords_per_view.round().astype(int)).cuda() world_coords_torch = torch.from_numpy(world_coords).cuda() for v, view_idx in enumerate(views): for t in range(n_frames): # Project depths to world space # Pixel --> Camera --> World pixel_xy = torch.stack(torch.meshgrid(torch.arange(w), torch.arange(h), indexing="xy"), dim=-1).cuda() pixel_xy = pixel_xy.type(k_inv_torch.dtype) pixel_xy_homo = torch.cat([pixel_xy, torch.ones_like(pixel_xy[..., :1])], dim=-1) depthmap_camera_xyz = torch.einsum("ij,hwj->hwi", k_inv_torch[v], pixel_xy_homo) depthmap_camera_xyz *= torch.tensor(d3dgs_depths[v, t], device="cuda", dtype=torch.float32)[..., None] depthmap_camera_xyz_homo = torch.cat( [depthmap_camera_xyz, torch.ones_like(depthmap_camera_xyz[..., :1])], dim=-1) depthmap_world_xyz_homo = torch.einsum("ij,hwj->hwi", extrinsics_inv_torch[v], depthmap_camera_xyz_homo) depthmap_world_xyz = depthmap_world_xyz_homo[..., :-1] / depthmap_world_xyz_homo[..., -1:] radius = 3 xmin = (pixel_coords_per_view_round_torch[v, t, :, 0] - radius).clip(min=0, max=w - 1 - 2 * radius) ymin = (pixel_coords_per_view_round_torch[v, t, :, 1] - radius).clip(min=0, max=h - 1 - 2 * radius) offsets = torch.arange(0, 2 * radius + 1, device="cuda") x_offsets, y_offsets = torch.meshgrid(offsets, offsets, indexing="ij") x_offsets = x_offsets.reshape(-1) y_offsets = y_offsets.reshape(-1) x_indices = (xmin[:, None] + x_offsets[None, :]).long() y_indices = (ymin[:, None] + y_offsets[None, :]).long() neighbors = depthmap_world_xyz[y_indices, x_indices] nearest_dist = torch.linalg.norm(neighbors - world_coords_torch[t][:, None, :], dim=-1).min(dim=-1)[0] depthmap_nearest_neighbor_distance[v, t, :] = nearest_dist.cpu().numpy() assert not np.isinf(depthmap_nearest_neighbor_distance).any() # Compute whether the projected trajectory is within the HxW frame of a view within_frame = ((pixel_coords_per_view[..., 0] >= 0) & (pixel_coords_per_view[..., 0] < w) & (pixel_coords_per_view[..., 1] >= 0) & (pixel_coords_per_view[..., 1] < h)) # If nearest neighbor in depth is less than X cm away, consider the point as visible in that view # Furthermore if the projected pixel space location is out of the frame, the point is not visible per_view_visibility = depthmap_nearest_neighbor_distance <= nearest_neighbor_distance_threshold_for_visibility per_view_visibility = per_view_visibility & within_frame valid_tracks_mask = (per_view_visibility[cam_id] == visibility).mean(0) > 0.7 valid_tracks_indices = np.where(valid_tracks_mask)[0] assert (per_view_visibility[cam_id] == visibility)[:, valid_tracks_mask].mean() > 0.8 query_points_3d_t = np.max(np.stack([qp_t, per_view_visibility[cam_id].argmax(0)], axis=1), axis=1) query_points_3d_xyz = world_coords[query_points_3d_t, np.arange(n_tracks)] query_points_3d = 
np.concatenate([query_points_3d_t[:, None], query_points_3d_xyz[:, :]], axis=1) merged_trajectories.append(world_coords[:, valid_tracks_indices, :]) merged_trajectories_pixelspace.append(pixel_coords_per_view[:, :, valid_tracks_indices, :]) merged_per_view_visibilities.append(per_view_visibility[:, :, valid_tracks_indices]) merged_query_points_3d.append(query_points_3d[valid_tracks_indices]) # print(f"VERBOSE LOGS: varying the distance threshold for cam_id={cam_id}") # for d in [0.001, 0.005, 0.01, 0.011, 0.012, 0.013, 0.014, 0.015, 0.016, 0.017, 0.019, # 0.020, 0.021, 0.022, 0.023, 0.024, 0.025, 0.026, 0.027, 0.028, 0.029, 0.030, 0.035, 0.04, 0.05]: # per_view_visibility = (depthmap_nearest_neighbor_distance <= d) & within_frame # print(f" --> dist={d:0.3f} " # f"v1={per_view_visibility[cam_id].mean() * 100:.1f} " # f"v2={visibility.mean() * 100:.1f} " # f"acc={(per_view_visibility[cam_id] == visibility).mean() * 100:.1f}") # per_view_visibility = depthmap_nearest_neighbor_distance <= nearest_neighbor_distance_threshold_for_visibility # per_view_visibility = per_view_visibility & within_frame # print(f"dist={nearest_neighbor_distance_threshold_for_visibility:0.3f} " # f"v1={per_view_visibility[cam_id].mean() * 100:.1f} " # f"v2={visibility.mean() * 100:.1f} " # f"acc={(per_view_visibility[cam_id] == visibility).mean() * 100:.1f}") # # if cam_id != 16: # continue # # rr.init("reconstruction", recording_id="v0.1") # rr.connect_tcp() # rr.log("/", rr.ViewCoordinates.LEFT_HAND_Y_DOWN, static=True) # rr.set_time_seconds("frame", 0) # rr.log("world/xyz", rr.Arrows3D(vectors=[[1, 0, 0], [0, 2, 0], [0, 0, 3]], # colors=[[255, 0, 0], [0, 255, 0], [0, 0, 255]])) # # rr.log(f"debug/qp_xyz_camera", # rr.Points3D(world_coords[queries_xyt[:, 2].astype(np.int32), np.arange(n_tracks)], # colors=np.ones_like(qp_xyz_camera) * [0, 1, 0], radii=0.01)) # rr.log(f"debug/qp_xyz_camera_reproj", # rr.Points3D(qp_xyz_world, colors=np.ones_like(qp_xyz_camera) * [0, 0, 1], radii=0.01)) # strips = np.stack([world_coords[queries_xyt[:, 2].astype(np.int32), np.arange(n_tracks)], qp_xyz_world], axis=1) # rr.log("debug/qp_xyz_error_line", rr.LineStrips3D(strips=strips, colors=np.array([1., 0, 0]), radii=0.003)) # # seq = os.path.basename(scene_root) # for t in range(0, n_frames, rerun_log_n_skip_t): # for v in rerun_views_to_viz: # rr.set_time_seconds("frame", t / 30) # depth_values = d3dgs_depths[v, t].ravel() # valid_mask = depth_values > 0 # y, x = np.indices((h, w)) # homo_pixel_coords = np.stack([x.ravel(), y.ravel(), np.ones_like(x).ravel()], axis=1).T # cam_coords = (k_inv[v] @ homo_pixel_coords) * depth_values # cam_coords = np.vstack((cam_coords, np.ones((1, cam_coords.shape[1])))) # world_coords_ = (extrinsics_inv[v] @ cam_coords)[:3].T # world_coords_ = world_coords_[valid_mask] # rgb_colors = rgbs[v, t].reshape(-1, 3)[valid_mask].astype(np.uint8) # rr.log(f"{seq}/dyn-3dgs-point-cloud/view-{v}", # rr.Points3D(world_coords_, colors=rgb_colors, radii=0.004)) # cmap = matplotlib.colormaps["gist_rainbow"] # norm = matplotlib.colors.Normalize(vmin=world_coords[..., 0].min(), vmax=world_coords[..., 0].max()) # track_colors = cmap(norm(world_coords[-1, :, 0])) # log_tracks_to_rerun( # tracks=world_coords, # visibles=visibility, # query_timestep=np.zeros(n_tracks, dtype=np.int32), # colors=track_colors, # track_names=[f"track-{i:02d}" for i in range(n_tracks)], # entity_format_str=f"debug/tapvid3d-tracks-visGT/{{}}", # invisible_color=[0.3, 0.3, 0.3], # ) # log_tracks_to_rerun( # tracks=world_coords, # 
visibles=per_view_visibility[views.index(16)], # query_timestep=np.zeros(n_tracks, dtype=np.int32), # colors=track_colors, # track_names=[f"track-{i:02d}" for i in range(n_tracks)], # entity_format_str=f"debug/tapvid3d-tracks-vis16-v2/{{}}", # invisible_color=[0.3, 0.3, 0.3], # ) # log_tracks_to_rerun( # tracks=world_coords, # visibles=per_view_visibility[views.index(27)], # query_timestep=np.zeros(n_tracks, dtype=np.int32), # colors=track_colors, # track_names=[f"track-{i:02d}" for i in range(n_tracks)], # entity_format_str=f"debug/tapvid3d-tracks-vis27/{{}}", # invisible_color=[0.3, 0.3, 0.3], # ) # exit() merged_trajectories = np.concatenate(merged_trajectories, axis=1) merged_trajectories_pixelspace = np.concatenate(merged_trajectories_pixelspace, axis=2) merged_per_view_visibilities = np.concatenate(merged_per_view_visibilities, axis=2) merged_query_points_3d = np.concatenate(merged_query_points_3d, axis=0) # Remove duplicates from the merged trajectories from sklearn.cluster import DBSCAN flat_trajectories = merged_trajectories.transpose(1, 0, 2).reshape(-1, n_frames * 3) dbscan = DBSCAN(eps=0.01, min_samples=1, metric='euclidean') labels = dbscan.fit_predict(flat_trajectories) _, unique_indices = np.unique(labels, return_index=True) unique_indices = np.sort(unique_indices) merged_trajectories = merged_trajectories[:, unique_indices, :] merged_trajectories_pixelspace = merged_trajectories_pixelspace[:, :, unique_indices, :] merged_per_view_visibilities = merged_per_view_visibilities[:, :, unique_indices] merged_query_points_3d = merged_query_points_3d[unique_indices, :] n_tracks = merged_trajectories.shape[1] assert merged_trajectories.shape == (n_frames, n_tracks, 3) assert merged_trajectories_pixelspace.shape == (n_views, n_frames, n_tracks, 2) assert merged_per_view_visibilities.shape == (n_views, n_frames, n_tracks) assert merged_query_points_3d.shape == (n_tracks, 4) # Shuffle the tracks np.random.seed(72) track_perm = np.random.permutation(n_tracks) shuffled_trajectories = merged_trajectories[:, track_perm, :] shuffled_trajectories_pixelspace = merged_trajectories_pixelspace[:, :, track_perm, :] shuffled_per_view_visibilities = merged_per_view_visibilities[:, :, track_perm] shuffled_query_points_3d = merged_query_points_3d[track_perm, :] # Save the merged annotations np.savez( output_annotation_path, trajectories=shuffled_trajectories, trajectories_pixelspace=shuffled_trajectories_pixelspace, per_view_visibilities=shuffled_per_view_visibilities, query_points_3d=shuffled_query_points_3d, intrinsics=k, extrinsics=extrinsics, ) print(f"Saved merged annotations to {output_annotation_path}") if rerun_logging: rr.init("reconstruction", recording_id="v0.1") if rerun_stream_only: rr.connect_tcp() rr.set_time_seconds("frame", 0) rr.log("/", rr.ViewCoordinates.LEFT_HAND_Y_DOWN, static=True) rr.log("world/xyz", rr.Arrows3D(vectors=[[1, 0, 0], [0, 2, 0], [0, 0, 3]], colors=[[255, 0, 0], [0, 255, 0], [0, 0, 255]])) seq = os.path.basename(scene_root) for t in range(0, n_frames, rerun_log_n_skip_t): for v in rerun_views_to_viz: rr.set_time_seconds("frame", t / 30) if rerun_log_rgb: rr.log(f"{seq}/rgb/view-{views[v]}/rgb", rr.Image(rgbs[v, t])) rr.log(f"{seq}/rgb/view-{views[v]}", rr.Pinhole(image_from_camera=k[v], width=w, height=h)) rr.log(f"{seq}/rgb/view-{views[v]}", rr.Transform3D(translation=extrinsics_inv[v, :3, 3], mat3x3=extrinsics_inv[v, :3, :3])) if rerun_log_d3dgs_rgb: rr.log(f"{seq}/dyn-3dgs-rgb/view-{views[v]}/rgb", rr.Image(d3dgs_rgbs[v, t])) 
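# Camera-logging convention for the rr.Pinhole / rr.Transform3D calls in this
# block: the pinhole gets the pixel-space intrinsics k[v], while the transform
# gets the camera-to-world pose (translation = camera center, mat3x3 = camera-
# to-world rotation). That is why extrinsics_inv rather than extrinsics is
# logged for each view.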
rr.log(f"{seq}/dyn-3dgs-rgb/view-{views[v]}", rr.Pinhole(image_from_camera=k[v], width=w, height=h)) rr.log(f"{seq}/dyn-3dgs-rgb/view-{views[v]}", rr.Transform3D(translation=extrinsics_inv[v, :3, 3], mat3x3=extrinsics_inv[v, :3, :3])) if rerun_log_d3dgs_depth: rr.log(f"{seq}/dyn-3dgs-depth/view-{views[v]}/depth", rr.DepthImage(d3dgs_depths[v, t], point_fill_ratio=0.2)) rr.log(f"{seq}/dyn-3dgs-depth/view-{views[v]}", rr.Pinhole(image_from_camera=k[v], width=w, height=h)) rr.log(f"{seq}/dyn-3dgs-depth/view-{views[v]}", rr.Transform3D(translation=extrinsics_inv[v, :3, 3], mat3x3=extrinsics_inv[v, :3, :3])) if rerun_log_d3dgs_point_cloud: depth_values = d3dgs_depths[v, t].ravel() valid_mask = depth_values > 0 y, x = np.indices((h, w)) homo_pixel_coords = np.stack([x.ravel(), y.ravel(), np.ones_like(x).ravel()], axis=1).T cam_coords = (k_inv[v] @ homo_pixel_coords) * depth_values cam_coords = np.vstack((cam_coords, np.ones((1, cam_coords.shape[1])))) world_coords = (extrinsics_inv[v] @ cam_coords)[:3].T world_coords = world_coords[valid_mask] rgb_colors = rgbs[v, t].reshape(-1, 3)[valid_mask].astype(np.uint8) rr.log(f"{seq}/dyn-3dgs-point-cloud/view-{v}", rr.Points3D(world_coords, colors=rgb_colors, radii=0.004)) if rerun_log_tracks: raw_tracks = np.stack([data['means3D'][is_fg][::200].contiguous().cpu().numpy() for data in scene_data]) n_tracks_raw = raw_tracks.shape[1] cmap = matplotlib.colormaps["gist_rainbow"] norm = matplotlib.colors.Normalize(vmin=raw_tracks[..., 0].min(), vmax=raw_tracks[..., 0].max()) track_colors = cmap(norm(raw_tracks[-1, :, 0])) log_tracks_to_rerun( tracks=raw_tracks, visibles=np.ones((n_frames, n_tracks_raw), dtype=bool), query_timestep=np.zeros(n_tracks_raw, dtype=np.int32), colors=track_colors, track_names=[f"track-{i:02d}" for i in range(n_tracks_raw)], entity_format_str=f"{seq}/dyn-3dgs-raw-tracks/{{}}", invisible_color=[0.3, 0.3, 0.3], ) cmap = matplotlib.colormaps["gist_rainbow"] norm = matplotlib.colors.Normalize(vmin=shuffled_trajectories[..., 0].min(), vmax=shuffled_trajectories[..., 0].max()) track_colors = cmap(norm(shuffled_trajectories[-1, :, 0])) batch_size = 50 max_tracks = 500 for v in rerun_views_to_viz: for tracks_batch_start in range(0, max_tracks, batch_size): tracks_batch_end = min(tracks_batch_start + batch_size, n_tracks) log_tracks_to_rerun( tracks=shuffled_trajectories[:, tracks_batch_start:tracks_batch_end], visibles=shuffled_per_view_visibilities[v, :, tracks_batch_start:tracks_batch_end], query_timestep=shuffled_query_points_3d[:, 0][tracks_batch_start:tracks_batch_end].astype(int), colors=track_colors[tracks_batch_start:tracks_batch_end], track_names=[f"track-{i:02d}" for i in range(tracks_batch_start, tracks_batch_end)], entity_format_str=f"{seq}/tapvid3d-tracks/view-{v}-visiblity/{tracks_batch_start}-{tracks_batch_end}/{{}}", invisible_color=[0.3, 0.3, 0.3], ) if not rerun_stream_only: rr_rrd_path = scene_root / "rerun_tapvid3d_labels.rrd" rr.save(rr_rrd_path) print(f"Saved Rerun recording to: {os.path.abspath(rr_rrd_path)}") if __name__ == "__main__": print("Merging TAP-Vid3D per-camera annotations.") for sequence_name in tqdm(["basketball", "boxes", "football", "juggle", "softball", "tennis"]): scene_root = Path(f"./datasets/panoptic_d3dgs/{sequence_name}") checkpoint_path = Path(f"./dynamic3dgs/output/pretrained/{sequence_name}") tapvid3d_annotation_paths = list(Path(f"./datasets/tapvid3d_dataset/pstudio").glob(f"{sequence_name}_*.npz")) merge_annotations( scene_root, checkpoint_path, tapvid3d_annotation_paths, 
skip_if_output_already_exists=True, rerun_logging=True ) ================================================ FILE: mvtracker/models/core/dynamic3dgs/metadata_dexycb.py ================================================ import json import os from collections import defaultdict import numpy as np # Configurable parameters BASE_PATH = "." IMAGE_WIDTH = 640 IMAGE_HEIGHT = 480 SELECTED_CAMS = [0, 1, 2, 3] OUTPUT_NAME = "0123_metadata" # Filter sequences sequences = [f for f in os.listdir(BASE_PATH) if f.startswith("2020")] print(sequences) for sequence in sequences: sequence_path = os.path.join(BASE_PATH, sequence) view_folders = [ f for f in os.listdir(sequence_path) if f.startswith("view_") and f[-2:].isdigit() ] if not view_folders: continue example_view_path = os.path.join(sequence_path, view_folders[0]) frame_files = [ fname for fname in os.listdir(example_view_path) if fname.endswith(".png") and fname[:-4].isdigit() ] num_timesteps = len(frame_files) print(f"{sequence}: Found {num_timesteps} frames in {view_folders[0]}") combined_data = defaultdict( lambda: defaultdict( lambda: {"cam_id": 0, "w": 0, "h": 0, "k": [], "w2c": [], "fn": []} ) ) for time_step in range(num_timesteps): for view_folder in view_folders: view_folder_path = os.path.join(sequence_path, view_folder) if not os.path.exists(view_folder_path): print(f"Skipping {view_folder_path}") continue cam_id = int(view_folder[-2:]) if SELECTED_CAMS != [] and cam_id not in SELECTED_CAMS: continue data_path = os.path.join(view_folder_path, "intrinsics_extrinsics.npz") if not os.path.exists(data_path): print(f"Missing intrinsics_extrinsics.npz in {view_folder_path}") continue data = np.load(data_path) k = data["intrinsics"][:3, :3] w2c = data["extrinsics"][:3, :] w2c = np.vstack([w2c, np.array([0, 0, 0, 1])]) frame_name = f"{cam_id}/{str(time_step).zfill(5)}.png" cam_info = combined_data[time_step][str(cam_id)] cam_info["cam_id"] = cam_id cam_info["w"] = IMAGE_WIDTH cam_info["h"] = IMAGE_HEIGHT cam_info["k"] = k.tolist() cam_info["w2c"] = w2c.tolist() cam_info["fn"] = frame_name output_path = os.path.join(sequence_path, "metadata.json") with open(output_path, "w") as f: json.dump(dict(combined_data), f, indent=4) print(f"Saved metadata for {sequence}") ================================================ FILE: mvtracker/models/core/dynamic3dgs/metadata_kubric.py ================================================ import json import os from collections import defaultdict import kornia import numpy as np import torch BASE_PATH = "." 
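# This script walks every Kubric sequence under BASE_PATH, reads each view's
# metadata.json (normalized intrinsics "K", per-frame camera "quaternions" and
# "positions"), and writes a combined {OUTPUT_NAME}.json keyed by timestep and
# camera id. Two conversions happen below:
#   * the camera-to-world pose is assembled from quaternion + position, inverted
#     to a 4x4 world-to-camera matrix, and only the frame-0 pose is kept (the
#     cameras are treated as static);
#   * the normalized intrinsics are scaled to pixel units and combined with a
#     diag([1, -1, -1]) axis flip, which we read as converting from Kubric's
#     OpenGL-style camera (looking down -Z) to the OpenCV-style convention used
#     downstream -- an interpretation of the code, not something stated in it.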
IMAGE_WIDTH = 512 IMAGE_HEIGHT = 512 NUM_TIMESTEPS = 24 SELECTED_CAMS = [0, 1, 2, 3] OUTPUT_NAME = "0123_metadata" # Filter valid sequences sequences = [f for f in os.listdir(BASE_PATH)] for sequence in sequences: sequence_path = os.path.join(BASE_PATH, sequence) view_folders = [ f for f in os.listdir(sequence_path) if f.startswith("view_") and f[-1:].isdigit() ] combined_data = defaultdict( lambda: defaultdict( lambda: { "cam_id": 0, "w": 0, "h": 0, "k": [], "w2c": [], "fn": [], "sensor_width": 0, "focal_length": 0, } ) ) if not view_folders: continue first_valid_view = None for vf in view_folders: cam_id = int(vf[-1:]) if not SELECTED_CAMS or cam_id in SELECTED_CAMS: first_valid_view = vf break if first_valid_view is None: continue example_path = os.path.join(sequence_path, first_valid_view) all_frames = [ f for f in os.listdir(example_path) if f.endswith(".png") and f[:-4].isdigit() and f.startswith("rgba") ] num_timesteps = len(all_frames) for time_step in range(NUM_TIMESTEPS): for view_folder in view_folders: print(f"Processing {sequence}/{view_folder}, time step {time_step}") view_folder_path = os.path.join(sequence_path, view_folder) if not os.path.exists(view_folder_path): continue cam_id = int(view_folder[-1:]) if SELECTED_CAMS != [] and cam_id not in SELECTED_CAMS: continue with open(os.path.join(view_folder_path, "metadata.json"), "r") as f: data = json.load(f) cam_data = data["camera"] k = cam_data["K"] quaternions = torch.tensor(cam_data["quaternions"]) positions = torch.tensor(cam_data["positions"]) rot_matrices = kornia.geometry.quaternion_to_rotation_matrix(quaternions) ext_inv = torch.eye(4).repeat(NUM_TIMESTEPS, 1, 1) ext_inv[:, :3, :3] = rot_matrices ext_inv[:, :3, 3] = positions ext = ext_inv.inverse()[:, :3, :] ext = np.diag([1, -1, -1]) @ ext.numpy() w2c = ext[0].tolist() w2c.append([0, 0, 0, 1]) intrinsics = ( np.diag([IMAGE_WIDTH, IMAGE_HEIGHT, 1]) @ np.array(k) @ np.diag([1, -1, -1]) ) frame_name = f"{cam_id}/{str(time_step).zfill(5)}.png" cam_info = combined_data[time_step][str(cam_id)] cam_info["cam_id"] = cam_id cam_info["w"] = IMAGE_WIDTH cam_info["h"] = IMAGE_HEIGHT cam_info["k"] = intrinsics.tolist() cam_info["w2c"] = w2c cam_info["fn"] = frame_name cam_info["sensor_width"] = cam_data["sensor_width"] cam_info["focal_length"] = cam_data["focal_length"] output_path = os.path.join(sequence_path, f"{OUTPUT_NAME}.json") with open(output_path, "w") as f: json.dump(dict(combined_data), f, indent=4) print(f"Saved metadata for {sequence}") ================================================ FILE: mvtracker/models/core/dynamic3dgs/reorganize_dexycb.py ================================================ import os source_roots = [f for f in os.listdir(".") if f.startswith("2020")] import os import shutil source_roots = [f for f in os.listdir(".") if f.startswith("2020")] for source_root in source_roots: target_root = source_root ims_target = os.path.join(target_root, "ims") seg_target = os.path.join(target_root, "seg") depths_target = os.path.join(target_root, "depths") for target in [ims_target, seg_target, depths_target]: os.makedirs(target, exist_ok=True) for i in range(8): # view_00 to view_07 view_folder = os.path.join(source_root, f"view_{i:02d}") ims_source = os.path.join(view_folder, "rgb") ims_dest = os.path.join(ims_target, str(i)) if os.path.exists(ims_source): shutil.copytree(ims_source, ims_dest, dirs_exist_ok=True) mask_source = os.path.join(view_folder, "mask") seg_dest = os.path.join(seg_target, str(i)) if os.path.exists(mask_source): 
shutil.copytree(mask_source, seg_dest, dirs_exist_ok=True) depth_source = os.path.join(view_folder, "depth") depth_dest = os.path.join(depths_target, str(i)) if os.path.exists(depth_source): shutil.copytree(depth_source, depth_dest, dirs_exist_ok=True) print("Copying complete!") ================================================ FILE: mvtracker/models/core/dynamic3dgs/test.py ================================================ import json import os import numpy as np import torch import torchvision from PIL import Image from diff_gaussian_rasterization import GaussianRasterizer as Renderer from tqdm import tqdm from external import calc_psnr, calc_ssim from helpers import setup_camera TEST_CAMS = [0, 10, 15, 30] def load_saved_params(seq, exp): """Load saved parameters for testing.""" params_path = f"./output/{exp}/{seq}/params.npz" params = np.load(params_path) params = {k: torch.tensor(v).cuda().float() for k, v in params.items()} return params def prepare_test_dataset(t, md, seq, exclude_cam_ids): """Prepare dataset for the given timestep, excluding specific camera IDs.""" dataset = [] used_cam_ids = [] for c in range(len(md["fn"][t])): cam_id = md["cam_id"][t][c] # if cam_id in exclude_cam_ids: # continue # ONLY USE THE SPECIFIC CAMS if cam_id not in TEST_CAMS: continue w, h, k, w2c = md["w"], md["h"], md["k"][t][c], md["w2c"][t][c] cam = setup_camera(w, h, k, w2c, near=1.0, far=100) fn = md["fn"][t][c] im_path = f"./data/{seq}/ims/{fn}" im = np.array(Image.open(im_path)) / 255.0 im = torch.tensor(im).float().cuda().permute(2, 0, 1) dataset.append({"cam": cam, "im": im, "id": cam_id}) used_cam_ids.append(cam_id) return dataset, used_cam_ids def render_image(cam, rendervar): """Render an image using the given camera and render variables.""" with torch.no_grad(): im, _, _ = Renderer(raster_settings=cam)(**rendervar) return im def test(seq, exp, exclude_cam_ids=[]): """Test saved parameters on a dataset and report metrics.""" print(f"Testing sequence: {seq}, experiment: {exp}") # Load metadata and saved parameters md = json.load(open(f"./data/{seq}/test_meta.json", "r")) # metadata params = load_saved_params(seq, exp) # Prepare output paths render_path = f"./output/{exp}/{seq}/renders" results_path = f"./output/{exp}_metrics_test.csv" os.makedirs(render_path, exist_ok=True) if not os.path.exists(results_path): with open(results_path, "w") as f: f.write("Sequence,Experiment,Timestep,Camera ID,PSNR,SSIM\n") num_timesteps = len(md["fn"]) psnrs, ssims = [], [] used_cameras = [] for t in tqdm(range(num_timesteps), desc="Testing timesteps"): dataset, used_cam_ids = prepare_test_dataset(t, md, seq, exclude_cam_ids) used_cameras.extend(used_cam_ids) rendervar = { "means3D": params["means3D"][t], "colors_precomp": params["rgb_colors"][t], "rotations": torch.nn.functional.normalize(params["unnorm_rotations"][t]), "opacities": torch.sigmoid(params["logit_opacities"]), "scales": torch.exp(params["log_scales"]), "means2D": torch.zeros_like(params["means3D"][t], device="cuda"), } for camera in dataset: im_rendered = render_image(camera["cam"], rendervar) gt = camera["im"] # Save rendered and ground truth images idx = camera["id"] torchvision.utils.save_image( im_rendered, f"{render_path}/t{t:03d}_c{idx:02d}_rendered.png" ) torchvision.utils.save_image( gt, f"{render_path}/t{t:03d}_c{idx:02d}_gt.png" ) # Compute metrics psnr_val = calc_psnr(im_rendered, gt).mean().item() ssim_val = calc_ssim(im_rendered, gt).mean().item() psnrs.append(psnr_val) ssims.append(ssim_val) # Save metrics with open(results_path, 
"a") as f: f.write(f"{seq},{exp},{t},{idx},{psnr_val:.4f},{ssim_val:.4f}\n") print(f"Used cameras: {sorted(set(used_cameras))}") print(f"Average PSNR: {np.mean(psnrs):.4f}, Average SSIM: {np.mean(ssims):.4f}") if __name__ == "__main__": exp_name = "testing_init_pt" training_cam_ids = [1, 4, 7, 11, 17, 20, 23, 26, 29] # Cameras used during training # for sequence in ["basketball", "boxes", "football"]: for sequence in ["basketball"]: test(sequence, exp_name, exclude_cam_ids=training_cam_ids) ================================================ FILE: mvtracker/models/core/dynamic3dgs/track_2d.py ================================================ import json import os import numpy as np import torch from diff_gaussian_rasterization import GaussianRasterizer as Renderer from tqdm import tqdm from external import build_rotation from helpers import setup_camera REMOVE_BACKGROUND = False w, h = 640, 360 near, far = 0.01, 100.0 def gaussian_influence(point, gaussians): """ Computes the most influential Gaussian for a given 3D point. Args: point (torch.Tensor): 3D point (shape: [3]). gaussians (dict): Dictionary containing: - "means3D": [N, 3] Gaussian means. - "scales": [N, 3] Gaussian scales. - "opacities": [N, 1] Gaussian opacities. - "rotations": [N, 4] Gaussian quaternion rotations. Returns: int: Index of the most influential Gaussian. """ # print(f"Query point: {point}") means = gaussians["means3D"] # [N, 3] scales = gaussians["scales"] # [N, 3] opacities = gaussians["opacities"] # [N, 1] rotations = gaussians["rotations"] # [N, 4] sigmoid_opacities = opacities.squeeze() diff = point - means # [N, 3] R = build_rotation(rotations) # [N, 3, 3] S = torch.diag_embed(scales) # [N, 3, 3] cov = R @ S @ S.transpose(-1, -2) @ R.transpose(-1, -2) # [N, 3, 3] try: cov_inv = torch.inverse(cov) # [N, 3, 3] diff = diff.unsqueeze(1) # [N, 1, 3] # -1/2 * (x - mu)^T * cov^-1 * (x - mu) mahalanobis = ( -0.5 * torch.matmul( diff, torch.matmul(cov_inv, diff.transpose(-1, -2)) ).squeeze() ) # [N] # Gaussian influences influences = sigmoid_opacities * torch.exp(mahalanobis) # [N] most_influential_idx = torch.argmax(influences).item() return most_influential_idx except RuntimeError as e: print(f"Error in computation: {e}") return -1 def render_depth(timestep_data, w2c, k): """ Renders a depth map using the Gaussian parameters. Args: timestep_data (dict): Scene data for the specific timestep. Returns: torch.Tensor: Depth map. 
""" with torch.no_grad(): cam = setup_camera(w, h, k, w2c, near, far) ( im, _, depth, ) = Renderer(raster_settings=cam)(**timestep_data) if depth.dim() == 3 and depth.size(0) == 1: # Shape (1, H, W) depth = depth.squeeze(0) return depth def load_scene_data(seq, exp, seg_as_col=False): params = dict(np.load(f"./output/{exp}/{seq}/params.npz")) params = {k: torch.tensor(v).cuda().float() for k, v in params.items()} is_fg = params["seg_colors"][:, 0] > 0.5 scene_data = [] for t in range(len(params["means3D"])): rendervar = { "means3D": params["means3D"][t], "colors_precomp": params["rgb_colors"][t] if not seg_as_col else params["seg_colors"], "rotations": torch.nn.functional.normalize(params["unnorm_rotations"][t]), "opacities": torch.sigmoid(params["logit_opacities"]), "scales": torch.exp(params["log_scales"]), "means2D": torch.zeros_like(params["means3D"][0], device="cuda"), } if REMOVE_BACKGROUND: rendervar = {k: v[is_fg] for k, v in rendervar.items()} scene_data.append(rendervar) if REMOVE_BACKGROUND: is_fg = is_fg[is_fg] return ( scene_data, is_fg, ) def unproject_2d_to_3d(query_pt, depth_map, intrinsics): """ Unproject a 2D point to 3D. """ x, y = query_pt z = depth_map[y, x] fx, fy = intrinsics[0, 0], intrinsics[1, 1] cx, cy = intrinsics[0, 2], intrinsics[1, 2] X = (x - cx) * z / fx Y = (y - cy) * z / fy Z = z return torch.tensor([X, Y, Z], dtype=torch.float32).cuda() def load_camera_params(dataset_path, seq, cam_id_g): cam_params = f"{dataset_path}/{seq}/merged_by_timestamp.json" with open(cam_params, "r") as f: cam_params = json.load(f) for timestamp, cameras in cam_params.items(): for cam_id, cam_data in cameras.items(): if int(cam_id) == int(cam_id_g): return np.array(cam_data["w2c"]), np.array(cam_data["k"]) return None, None def c2w_convert(point_3d, w2c): point_3d_h = np.append(point_3d.cpu().numpy(), 1).reshape(4, 1) c2w = np.linalg.inv(w2c) point_cam = c2w @ point_3d_h return torch.tensor(point_cam[:3].flatten(), dtype=torch.float32).cuda() def w2c_convert(point_3d_h, w2c): point_3d = np.append(point_3d_h.cpu().numpy(), 1).reshape(4, 1) point_cam = w2c @ point_3d return torch.tensor(point_cam[:3].flatten(), dtype=torch.float32).cuda() def track_query_point(scene_data, query_point, depth_map, w2c, k, t_given=0): """ Tracks the 3D trajectory of a 2D query point across all frames. Args: scene_data (list): Scene data for all frames. query_point (tuple): Initial 2D query point (x, y). intrinsics (torch.Tensor): Camera intrinsics. t_start (int): Starting frame index. Returns: list: A list of 3D points (numpy arrays) across all timestamps. """ trajectory = [] opacities = [] point_3d = unproject_2d_to_3d(query_point, depth_map, k) point_3d_gaussian = c2w_convert(point_3d, w2c) gaussians = scene_data[t_given] gaussian_idx = gaussian_influence(point_3d_gaussian, gaussians) for t in range(0, len(scene_data)): gaussians = scene_data[t] gaussian = {k: v[gaussian_idx] for k, v in gaussians.items()} point_3d_gaussian = gaussian["means3D"] point_3d = w2c_convert(point_3d_gaussian, w2c) trajectory.append(point_3d) opacities.append(gaussian["opacities"]) return trajectory if __name__ == "__main__": exp = "exp_init_1-7-14-20" exp = "exp_merged_cleaned_pt_1-7-14-20" tapvid3d_dir = "./datasets/tapvid3d_dataset/pstudio" dataset_path = "./datasets/panoptic_d3dgs" # read the .npz files under directory npz_files = [ f for f in os.listdir(tapvid3d_dir) if f.endswith(".npz") and "basketball" in f and "_1." 
in f ] file_avg_distances = {} # for each .npz file, it has following naming: {seq}_{cam_id}.npz for npz_file in tqdm(npz_files): seq, cam_id = npz_file.split(".")[0].split("_") # load tapvid3d gt_file = f"{tapvid3d_dir}/{npz_file}" print(f"Loading {gt_file}") data = np.load(gt_file) print(data.files) queries_xyt = data["queries_xyt"] print("quries_xyt:", queries_xyt) gt_trajectories = data["tracks_XYZ"] trajectories = [] for query in tqdm(queries_xyt): # round to nearest integer q_x = round(query[0]) q_y = round(query[1]) query_point = (q_x, q_y) t_given = int(query[2]) - 1 # Load the scene data scene_data, _ = load_scene_data(seq, exp) w2c, k = load_camera_params(dataset_path, seq, cam_id) depth_map = render_depth(scene_data[t_given], w2c, k) # Track the query point across all timestamps trajectory = track_query_point( scene_data, query_point, depth_map, w2c, k, t_given=t_given ) trajectories.append(torch.stack(trajectory).cpu().numpy()) # save the trajectories # np.savez( # "{exp}_{seq}_{cam_id}_trajectories.npz", # trajectories=trajectories.cpu().numpy(), # ) # print(f"Trajectories for {seq}_{cam_id} saved.") distances = [] for i, query in enumerate(queries_xyt): t_given = int(query[2]) gt_traj = gt_trajectories[ :, i ] # Extract ground truth trajectory for this query exp_traj = trajectories[i] # Our computed trajectory # Compute Euclidean distances for each timestamp per_frame_distances = np.linalg.norm(gt_traj - exp_traj, axis=1) avg_distance = np.mean(per_frame_distances) sum_distance = np.sum(per_frame_distances) distances.append(avg_distance) print(f"avg distance for {npz_file}: {np.mean(distances)}") file_avg_distances[npz_file] = np.mean(distances) print("Average distances per file:") print(file_avg_distances) print("Overall average distance:", np.mean(list(file_avg_distances.values()))) ================================================ FILE: mvtracker/models/core/dynamic3dgs/track_3d.py ================================================ import os import cv2 import numpy as np import torch from tqdm import tqdm from external import build_rotation REMOVE_BACKGROUND = False device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") w, h = 512, 512 near, far = 0.01, 100.0 from mvtracker.evaluation.evaluator_3dpt import evaluate_3dpt def load_scene_data(seq, exp, seg_as_col=False): params = dict(np.load(f"./output/{exp}/{seq}/params.npz")) params = {k: torch.tensor(v, device=device).float() for k, v in params.items()} is_fg = params["seg_colors"][:, 0] > 0.5 scene_data = [] for t in range(len(params["means3D"])): rendervar = { "means3D": params["means3D"][t], "colors_precomp": params["rgb_colors"][t] if not seg_as_col else params["seg_colors"], "rotations": params["unnorm_rotations"][t], "opacities": torch.sigmoid(params["logit_opacities"]), "scales": torch.exp(params["log_scales"]), "means2D": torch.zeros_like(params["means3D"][0], device=device), } if REMOVE_BACKGROUND: rendervar = {k: v[is_fg] for k, v in rendervar.items()} scene_data.append(rendervar) if REMOVE_BACKGROUND: is_fg = is_fg[is_fg] return scene_data, is_fg def load_depth_maps(dataset_path, seq, cam_ids): depth_maps = {} for cam_id in cam_ids: depth_dir = f"{dataset_path}/{seq}/depths/{cam_id}/" depth_maps[cam_id] = [] for frame_idx in sorted(os.listdir(depth_dir)): depth_path = os.path.join(depth_dir, frame_idx) depth_map = ( cv2.imread(depth_path, cv2.IMREAD_UNCHANGED).astype(np.float32) / 1000.0 ) depth_maps[cam_id].append(torch.tensor(depth_map, device=device)) depth_maps[cam_id] = 
torch.stack(depth_maps[cam_id]) return depth_maps def preload_camera_data(dataset_path, seq, cam_ids): cam_params_path = f"{dataset_path}/{seq}/metadata.json" with open(cam_params_path, "r") as f: cam_params = json.load(f) preloaded_cameras = {} for cam_id in cam_ids: for timestamp, cameras in cam_params.items(): if str(cam_id) in cameras: preloaded_cameras[cam_id] = ( torch.tensor( cameras[str(cam_id)]["w2c"], dtype=torch.float32 ).cuda(), torch.tensor(cameras[str(cam_id)]["k"], dtype=torch.float32).cuda(), ) break # We only need one instance per camera return preloaded_cameras def gaussian_influence(point, gaussians): """ Computes the most influential Gaussian for a given 3D point. Args: point (torch.Tensor): 3D point (shape: [3]). gaussians (dict): Dictionary containing: - "means3D": [N, 3] Gaussian means. - "scales": [N, 3] Gaussian scales. - "opacities": [N, 1] Gaussian opacities. - "rotations": [N, 4] Gaussian quaternion rotations. Returns: int: Index of the most influential Gaussian. """ # print(f"Query point: {point}") means = gaussians["means3D"] # [N, 3] scales = gaussians["scales"] # [N, 3] opacities = gaussians["opacities"] # [N, 1] rotations = gaussians["rotations"] # [N, 4] sigmoid_opacities = opacities.squeeze() diff = point - means # [N, 3] R = build_rotation(rotations) # [N, 3, 3] S = torch.diag_embed(scales) # [N, 3, 3] cov = R @ S @ S.transpose(-1, -2) @ R.transpose(-1, -2) # [N, 3, 3] try: cov_inv = torch.inverse(cov) # [N, 3, 3] diff = diff.unsqueeze(1) # [N, 1, 3] # -1/2 * (x - mu)^T * cov^-1 * (x - mu) mahalanobis = ( -0.5 * torch.matmul( diff, torch.matmul(cov_inv, diff.transpose(-1, -2)) ).squeeze() ) # [N] # Gaussian influences influences = sigmoid_opacities * torch.exp(mahalanobis) # [N] most_influential_idx = torch.argmax(influences).item() print("Most influnce:", influences[most_influential_idx]) return most_influential_idx, influences[most_influential_idx] except RuntimeError as e: print(f"Error in computation: {e}") return -1 def get_visibilities( point_3d, cam_ids, t, depth_maps, preloaded_cameras, th=0.02, ): visibilities = [] for cam_id in cam_ids: if cam_id not in preloaded_cameras: continue w2c, intrinsics = preloaded_cameras[cam_id] point_cam = torch.matmul( w2c, torch.cat([point_3d, torch.tensor([1.0], device=point_3d.device)]) )[:3] X, Y, Z = point_cam if Z <= 0: continue x = int((X * intrinsics[0, 0]) / Z + intrinsics[0, 2]) y = int((Y * intrinsics[1, 1]) / Z + intrinsics[1, 2]) if not ( 0 <= x < depth_maps[cam_id].shape[2] and 0 <= y < depth_maps[cam_id].shape[1] ): continue depth_at_pixel = depth_maps[cam_id][t, y, x] depth_diff = Z - depth_at_pixel visibilities.append(0 <= depth_diff <= th) return visibilities def track_query_point( scene_data, query_point, cam_ids, t_given, depth_maps, preloaded_cameras, threshold=0.02, ): """ Tracks the 3D trajectory of a 3D query point across all frames. Args: scene_data (list): Scene data for all frames. query_point (tuple): Initial 2D query point (x, y). intrinsics (torch.Tensor): Camera intrinsics. t_start (int): Starting frame index. Returns: list: A list of 3D points (numpy arrays) across all timestamps. 
""" trajectory = [] visibilities = [] gaussians = scene_data[t_given] gaussian_idx, influence = gaussian_influence(query_point, gaussians) for t in range(0, len(scene_data)): gaussians = scene_data[t] gaussian = {k: v[gaussian_idx] for k, v in gaussians.items()} point_3d_gaussian = gaussian["means3D"] trajectory.append(point_3d_gaussian) visibility = get_visibilities( point_3d_gaussian, cam_ids, t, depth_maps, preloaded_cameras, threshold ) visibilities.append(torch.tensor(visibility)) # print ratio of visibilities for each cam: visibity has shape n_frames * cam # print("Visibility ratio for each camera:") # print(np.array(visibility).sum(axis=0) / len(visibility)) return trajectory, visibilities if __name__ == "__main__": exp = "exp_use_duster_views_0123" sequences = [ "20200709-subject-01__20200709_141754", "20200813-subject-02__20200813_145653", "20200903-subject-04__20200903_104428", "20200820-subject-03__20200820_135841", "20200908-subject-05__20200908_144409", "20200918-subject-06__20200918_114117", "20200928-subject-07__20200928_144906", "20201002-subject-08__20201002_110227", "20201015-subject-09__20201015_144721", "20201022-subject-10__20201022_112651", ] dataset_path = "./datasets/dex_formatted/neus_nsubsample-3" remove_hand = False use_duster = True cleaned_duster = False views = "0123" tracks_path = f"seed-000072_remove-hand-{remove_hand}_tracks-384_use-duster-depths-{use_duster}_clean-duster-depths-{cleaned_duster}_views-{views}_duster-views-{views}.npz" # sequences = ["basketball"] # cam_ids = [27, 16, 14, 8] # cam_ids = [0, 1, 2, 3, 4, 5, 6, 7] cam_ids = [0, 1, 2, 3] for seq in sequences: merged_path = f"{dataset_path}/{seq}/{tracks_path}" # Load scene data scene_data, is_fg = load_scene_data(seq, exp, s=1) # scene_data = [] print("Scene data loaded.") depth_maps = load_depth_maps(dataset_path, seq, cam_ids) preloaded_cameras = preload_camera_data(dataset_path, seq, cam_ids) load_tapvid3d = np.load(merged_path) query_points = load_tapvid3d["query_points_3d"] predictions_file = f"./output/{exp}/{seq}/predictions.npz" if True: THRESHOLD = 0.02 predictions = [] visibilities = [] for i, query_point in tqdm(enumerate(query_points), desc="Query points"): # print("Query point:", query_point) given_time = query_point[0] # to int # query_point = query_point.astype(int) given_time = int(given_time) qp = query_point[1:] # convert it to Torch tensor # torch.tensor([X, Y, Z], dtype=torch.float32).cuda() qp = torch.tensor(qp, dtype=torch.float32).cuda() trajectory, visiblity_d = track_query_point( scene_data, qp, cam_ids, given_time, depth_maps, preloaded_cameras, THRESHOLD, ) # trajectory = trajectory.cpu().numpy() predictions.append(torch.stack(trajectory).cpu().numpy()) visibilities.append(torch.stack(visiblity_d).cpu().numpy()) # pred shape is: n_queries, n_frames, 3 # convert it to n_frames, n_queries, 3 predictions = np.array(predictions) predictions = np.transpose(predictions, (1, 0, 2)) visibilities = np.array(visibilities) visibilities = np.transpose(visibilities, (2, 1, 0)) preds_file = f"./output/{exp}/{seq}/predictions_threshold_{THRESHOLD}.npz" np.savez( preds_file, predictions=predictions, visibilities=visibilities, ) print(f"Results saved for threshold {THRESHOLD} at: {preds_file}") # Load the ground truth query_points = load_tapvid3d["query_points_3d"] query_points = query_points[None, ...] # batch * num tracks * 4 gt_visibilities = load_tapvid3d["per_view_visibilities"] gt_visibilities = gt_visibilities[ None, ... 
] # batch * view * num frames * num tracks # convert all of them to false gt_tracks = load_tapvid3d["trajectories"] gt_tracks = gt_tracks[None, ...] # batch * num frames * num tracks * 3 # pred_visibilities = visibilities[None, ...] # pred_visibilities_t = visibilities_i[None, ...] pred_tracks = predictions[None, ...] # print all dimensions for debugging print("query_points:", query_points.shape) print("gt_occluded:", gt_visibilities.shape) print("gt_tracks:", gt_tracks.shape) print("pred_occluded:", gt_visibilities.shape) print("pred_tracks:", pred_tracks.shape) gt_visibilities_any_view = gt_visibilities.any(axis=1) pred_visibilities = visibilities[None, ...] pred_visibilities_any_view = pred_visibilities.any(axis=1) print("EXP: ", exp) print("SEQ: ", seq) print("Evaluating ... ") metrics_2 = evaluate_3dpt( gt_tracks[0], gt_visibilities_any_view[0], pred_tracks[0], pred_visibilities_any_view[0], evaluation_setting="dex-ycb-multiview", query_points=query_points[0], track_upscaling_factor=1, verbose=True, ) # Save evaluation results results_file = f"./output/{exp}/{seq}/results_threshold_{THRESHOLD}.txt" with open(results_file, "w") as f: f.write(f"Exp: {exp}\n") f.write(f"Seq: {seq}\n") f.write(f"Threshold: {THRESHOLD}\n") f.write(str(metrics_2)) f.write("\n") print(f"Results saved at: {results_file}") print("Done.") ================================================ FILE: mvtracker/models/core/dynamic3dgs/train.py ================================================ import copy import json import os from random import randint import numpy as np import torch from PIL import Image from diff_gaussian_rasterization import GaussianRasterizer as Renderer from tqdm import tqdm from external import calc_ssim, calc_psnr, build_rotation, densify, update_params_and_optimizer from helpers import setup_camera, l1_loss_v1, l1_loss_v2, weighted_l2_loss_v1, weighted_l2_loss_v2, quat_mult, \ o3d_knn, params2rendervar, params2cpu, save_params def get_dataset(t, md, seq): dataset = [] for c in range(len(md['fn'][t])): w, h, k, w2c = md['w'], md['h'], md['k'][t][c], md['w2c'][t][c] cam = setup_camera(w, h, k, w2c, near=1.0, far=100) fn = md['fn'][t][c] im = np.array(copy.deepcopy(Image.open(f"./data/{seq}/ims/{fn}"))) im = torch.tensor(im).float().cuda().permute(2, 0, 1) / 255 seg = np.array(copy.deepcopy(Image.open(f"./data/{seq}/seg/{fn.replace('.jpg', '.png')}"))).astype(np.float32) seg = torch.tensor(seg).float().cuda() seg_col = torch.stack((seg, torch.zeros_like(seg), 1 - seg)) dataset.append({'cam': cam, 'im': im, 'seg': seg_col, 'id': c}) return dataset def get_batch(todo_dataset, dataset): if not todo_dataset: todo_dataset = dataset.copy() curr_data = todo_dataset.pop(randint(0, len(todo_dataset) - 1)) return curr_data def initialize_params(seq, md): init_pt_cld = np.load(f"./data/{seq}/init_pt_cld.npz")["data"] seg = init_pt_cld[:, 6] max_cams = 50 sq_dist, _ = o3d_knn(init_pt_cld[:, :3], 3) mean3_sq_dist = sq_dist.mean(-1).clip(min=0.0000001) params = { 'means3D': init_pt_cld[:, :3], 'rgb_colors': init_pt_cld[:, 3:6], 'seg_colors': np.stack((seg, np.zeros_like(seg), 1 - seg), -1), 'unnorm_rotations': np.tile([1, 0, 0, 0], (seg.shape[0], 1)), 'logit_opacities': np.zeros((seg.shape[0], 1)), 'log_scales': np.tile(np.log(np.sqrt(mean3_sq_dist))[..., None], (1, 3)), 'cam_m': np.zeros((max_cams, 3)), 'cam_c': np.zeros((max_cams, 3)), } params = {k: torch.nn.Parameter(torch.tensor(v).cuda().float().contiguous().requires_grad_(True)) for k, v in params.items()} cam_centers = np.linalg.inv(md['w2c'][0])[:, :3, 
3] # Get scene radius scene_radius = 1.1 * np.max(np.linalg.norm(cam_centers - np.mean(cam_centers, 0)[None], axis=-1)) variables = {'max_2D_radius': torch.zeros(params['means3D'].shape[0]).cuda().float(), 'scene_radius': scene_radius, 'means2D_gradient_accum': torch.zeros(params['means3D'].shape[0]).cuda().float(), 'denom': torch.zeros(params['means3D'].shape[0]).cuda().float()} return params, variables def initialize_optimizer(params, variables): lrs = { 'means3D': 0.00016 * variables['scene_radius'], 'rgb_colors': 0.0025, 'seg_colors': 0.0, 'unnorm_rotations': 0.001, 'logit_opacities': 0.05, 'log_scales': 0.001, 'cam_m': 1e-4, 'cam_c': 1e-4, } param_groups = [{'params': [v], 'name': k, 'lr': lrs[k]} for k, v in params.items()] return torch.optim.Adam(param_groups, lr=0.0, eps=1e-15) def get_loss(params, curr_data, variables, is_initial_timestep): losses = {} rendervar = params2rendervar(params) rendervar['means2D'].retain_grad() im, radius, _, = Renderer(raster_settings=curr_data['cam'])(**rendervar) curr_id = curr_data['id'] im = torch.exp(params['cam_m'][curr_id])[:, None, None] * im + params['cam_c'][curr_id][:, None, None] losses['im'] = 0.8 * l1_loss_v1(im, curr_data['im']) + 0.2 * (1.0 - calc_ssim(im, curr_data['im'])) variables['means2D'] = rendervar['means2D'] # Gradient only accum from colour render for densification segrendervar = params2rendervar(params) segrendervar['colors_precomp'] = params['seg_colors'] seg, _, _, = Renderer(raster_settings=curr_data['cam'])(**segrendervar) losses['seg'] = 0.8 * l1_loss_v1(seg, curr_data['seg']) + 0.2 * (1.0 - calc_ssim(seg, curr_data['seg'])) if not is_initial_timestep: is_fg = (params['seg_colors'][:, 0] > 0.5).detach() fg_pts = rendervar['means3D'][is_fg] fg_rot = rendervar['rotations'][is_fg] rel_rot = quat_mult(fg_rot, variables["prev_inv_rot_fg"]) rot = build_rotation(rel_rot) neighbor_pts = fg_pts[variables["neighbor_indices"]] curr_offset = neighbor_pts - fg_pts[:, None] curr_offset_in_prev_coord = (rot.transpose(2, 1)[:, None] @ curr_offset[:, :, :, None]).squeeze(-1) losses['rigid'] = weighted_l2_loss_v2(curr_offset_in_prev_coord, variables["prev_offset"], variables["neighbor_weight"]) losses['rot'] = weighted_l2_loss_v2(rel_rot[variables["neighbor_indices"]], rel_rot[:, None], variables["neighbor_weight"]) curr_offset_mag = torch.sqrt((curr_offset ** 2).sum(-1) + 1e-20) losses['iso'] = weighted_l2_loss_v1(curr_offset_mag, variables["neighbor_dist"], variables["neighbor_weight"]) losses['floor'] = torch.clamp(fg_pts[:, 1], min=0).mean() bg_pts = rendervar['means3D'][~is_fg] bg_rot = rendervar['rotations'][~is_fg] losses['bg'] = l1_loss_v2(bg_pts, variables["init_bg_pts"]) + l1_loss_v2(bg_rot, variables["init_bg_rot"]) losses['soft_col_cons'] = l1_loss_v2(params['rgb_colors'], variables["prev_col"]) loss_weights = {'im': 1.0, 'seg': 3.0, 'rigid': 4.0, 'rot': 4.0, 'iso': 2.0, 'floor': 2.0, 'bg': 20.0, 'soft_col_cons': 0.01} loss = sum([loss_weights[k] * v for k, v in losses.items()]) seen = radius > 0 variables['max_2D_radius'][seen] = torch.max(radius[seen], variables['max_2D_radius'][seen]) variables['seen'] = seen return loss, variables def initialize_per_timestep(params, variables, optimizer): pts = params['means3D'] rot = torch.nn.functional.normalize(params['unnorm_rotations']) new_pts = pts + (pts - variables["prev_pts"]) new_rot = torch.nn.functional.normalize(rot + (rot - variables["prev_rot"])) is_fg = params['seg_colors'][:, 0] > 0.5 prev_inv_rot_fg = rot[is_fg] prev_inv_rot_fg[:, 1:] = -1 * prev_inv_rot_fg[:, 1:] 
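# Negating the vector part of a unit quaternion gives its conjugate, which for
# unit quaternions equals the inverse; prev_inv_rot_fg therefore caches the
# inverse of the current foreground rotations so the next timestep can form
# relative rotations via quat_mult(fg_rot, prev_inv_rot_fg) for the rigidity
# and rotation-consistency losses. The new_pts / new_rot computed above are a
# constant-velocity extrapolation used to initialize the next timestep.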
fg_pts = pts[is_fg] prev_offset = fg_pts[variables["neighbor_indices"]] - fg_pts[:, None] variables['prev_inv_rot_fg'] = prev_inv_rot_fg.detach() variables['prev_offset'] = prev_offset.detach() variables["prev_col"] = params['rgb_colors'].detach() variables["prev_pts"] = pts.detach() variables["prev_rot"] = rot.detach() new_params = {'means3D': new_pts, 'unnorm_rotations': new_rot} params = update_params_and_optimizer(new_params, params, optimizer) return params, variables def initialize_post_first_timestep(params, variables, optimizer, num_knn=20): is_fg = params['seg_colors'][:, 0] > 0.5 init_fg_pts = params['means3D'][is_fg] init_bg_pts = params['means3D'][~is_fg] init_bg_rot = torch.nn.functional.normalize(params['unnorm_rotations'][~is_fg]) neighbor_sq_dist, neighbor_indices = o3d_knn(init_fg_pts.detach().cpu().numpy(), num_knn) neighbor_weight = np.exp(-2000 * neighbor_sq_dist) neighbor_dist = np.sqrt(neighbor_sq_dist) variables["neighbor_indices"] = torch.tensor(neighbor_indices).cuda().long().contiguous() variables["neighbor_weight"] = torch.tensor(neighbor_weight).cuda().float().contiguous() variables["neighbor_dist"] = torch.tensor(neighbor_dist).cuda().float().contiguous() variables["init_bg_pts"] = init_bg_pts.detach() variables["init_bg_rot"] = init_bg_rot.detach() variables["prev_pts"] = params['means3D'].detach() variables["prev_rot"] = torch.nn.functional.normalize(params['unnorm_rotations']).detach() params_to_fix = ['logit_opacities', 'log_scales', 'cam_m', 'cam_c'] for param_group in optimizer.param_groups: if param_group["name"] in params_to_fix: param_group['lr'] = 0.0 return variables def report_progress(params, data, i, progress_bar, every_i=100): if i % every_i == 0: im, _, _, = Renderer(raster_settings=data['cam'])(**params2rendervar(params)) curr_id = data['id'] im = torch.exp(params['cam_m'][curr_id])[:, None, None] * im + params['cam_c'][curr_id][:, None, None] psnr = calc_psnr(im, data['im']).mean() progress_bar.set_postfix({"train img 0 PSNR": f"{psnr:.{7}f}"}) progress_bar.update(every_i) def train(seq, exp): if os.path.exists(f"./output/{exp}/{seq}"): print(f"Experiment '{exp}' for sequence '{seq}' already exists. 
Exiting.") return md = json.load(open(f"./data/{seq}/train_meta.json", 'r')) # metadata num_timesteps = len(md['fn']) params, variables = initialize_params(seq, md) optimizer = initialize_optimizer(params, variables) output_params = [] for t in range(num_timesteps): dataset = get_dataset(t, md, seq) todo_dataset = [] is_initial_timestep = (t == 0) if not is_initial_timestep: params, variables = initialize_per_timestep(params, variables, optimizer) num_iter_per_timestep = 10000 if is_initial_timestep else 2000 progress_bar = tqdm(range(num_iter_per_timestep), desc=f"timestep {t}") for i in range(num_iter_per_timestep): curr_data = get_batch(todo_dataset, dataset) loss, variables = get_loss(params, curr_data, variables, is_initial_timestep) loss.backward() with torch.no_grad(): report_progress(params, dataset[0], i, progress_bar) if is_initial_timestep: params, variables = densify(params, variables, optimizer, i) optimizer.step() optimizer.zero_grad(set_to_none=True) progress_bar.close() output_params.append(params2cpu(params, is_initial_timestep)) if is_initial_timestep: variables = initialize_post_first_timestep(params, variables, optimizer) save_params(output_params, seq, exp) if __name__ == "__main__": exp_name = "exp1" for sequence in ["basketball", "boxes", "football", "juggle", "softball", "tennis"]: train(sequence, exp_name) torch.cuda.empty_cache() ================================================ FILE: mvtracker/models/core/dynamic3dgs/visualize.py ================================================ import json import os from pathlib import Path import matplotlib import numpy as np import rerun as rr import torch from PIL import Image from diff_gaussian_rasterization import GaussianRasterizer as Renderer from .helpers import setup_camera RENDER_MODE = 'color' # 'color', 'depth' or 'centers' # RENDER_MODE = 'depth' # 'color', 'depth' or 'centers' # RENDER_MODE = 'centers' # 'color', 'depth' or 'centers' REMOVE_BACKGROUND = False # False or True # REMOVE_BACKGROUND = True # False or True FORCE_LOOP = False # False or True # FORCE_LOOP = True # False or True w, h = 640, 360 near, far = 0.01, 100.0 traj_frac = 200 # 0.5% of points # VIEWS = [1, 2, 3, 4, 5, 6, 7, 8, 9, 11, 12, 13, 14, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29] VIEWS = [1, 14] log_rgb = True log_d3dgs_rgb = False log_d3dgs_depth = False log_d3dgs_point_cloud = True log_tracks = True log_n_skip_view = 1 log_n_skip_t = 1 def load_scene_data(params_path, seg_as_col=False): """Load 3D scene data from file.""" params = dict(np.load(params_path, allow_pickle=True)) params = {k: torch.tensor(v).cuda().float() for k, v in params.items()} is_fg = params['seg_colors'][:, 0] > 0.5 scene_data = [] for t in range(len(params['means3D'])): rendervar = { 'means3D': params['means3D'][t], 'colors_precomp': params['rgb_colors'][t] if not seg_as_col else params['seg_colors'], 'rotations': torch.nn.functional.normalize(params['unnorm_rotations'][t]), 'opacities': torch.sigmoid(params['logit_opacities']), 'scales': torch.exp(params['log_scales']), 'means2D': torch.zeros_like(params['means3D'][0], device="cuda") } if REMOVE_BACKGROUND: rendervar = {k: v[is_fg] for k, v in rendervar.items()} scene_data.append(rendervar) if REMOVE_BACKGROUND: is_fg = is_fg[is_fg] return scene_data, is_fg def render(w2c, k, timestep_data): """Render scene using Gaussian Rasterization.""" with torch.no_grad(): cam = setup_camera(w, h, k, w2c, near, far) im, _, depth = Renderer(raster_settings=cam)(**timestep_data) return im, depth def log_tracks_to_rerun( 
tracks: np.ndarray, visibles: np.ndarray, query_timestep: np.ndarray, colors: np.ndarray, track_names=None, entity_format_str="{}", log_points=True, points_radii=0.01, invisible_color=[0., 0., 0.], log_line_strips=True, max_strip_length_past=30, max_strip_length_future=1, strips_radii=0.001, log_error_lines=False, error_lines_radii=0.0042, error_lines_color=[1., 0., 0.], gt_for_error_lines=None, fps=30, ) -> None: """ Log tracks to Rerun. Parameters: tracks: Shape (T, N, 3), the 3D trajectories of points. visibles: Shape (T, N), boolean visibility mask for each point at each timestep. query_timestep: Shape (T, N), the frame index after which the tracks start. colors: Shape (N, 4), RGBA colors for each point. entity_prefix: String prefix for entity hierarchy in Rerun. entity_suffix: String suffix for entity hierarchy in Rerun. """ T, N, _ = tracks.shape assert tracks.shape == (T, N, 3) assert visibles.shape == (T, N) assert query_timestep.shape == (N,) assert query_timestep.min() >= 0 assert query_timestep.max() < T assert colors.shape == (N, 4) for n in range(N): track_name = track_names[n] if track_names is not None else f"track-{n}" rr.log(entity_format_str.format(track_name), rr.Clear(recursive=True)) for t in range(query_timestep[n], T): rr.set_time_seconds("frame", t / fps) # Log the point (special handling for invisible points) if log_points: rr.log( entity_format_str.format(f"{track_name}/point"), rr.Points3D( positions=[tracks[t, n]], colors=[colors[n, :3]] if visibles[t, n] else [invisible_color], radii=points_radii, ), ) # Log line segments for visible tracks if log_line_strips and t > query_timestep[n]: strip_t_start = max(t - max_strip_length_past, query_timestep[n].item()) strip_t_end = min(t + max_strip_length_future, T - 1) strips = np.stack([ tracks[strip_t_start:strip_t_end, n], tracks[strip_t_start + 1:strip_t_end + 1, n], ], axis=-2) strips_visibility = visibles[strip_t_start + 1:strip_t_end + 1, n] strips_colors = np.where( strips_visibility[:, None], colors[None, n, :3], [invisible_color], ) rr.log( entity_format_str.format(f"{track_name}/line"), rr.LineStrips3D(strips=strips, colors=strips_colors, radii=strips_radii), ) if log_error_lines: assert gt_for_error_lines is not None strips = np.stack([ tracks[t, n], gt_for_error_lines[t, n], ], axis=-2) rr.log( entity_format_str.format(f"{track_name}/error"), rr.LineStrips3D(strips=strips, colors=error_lines_color, radii=error_lines_radii), ) def visualize(seq, exp): """Visualize 3D Gaussian Splatting using Rerun.""" scene_root = Path(f"../datasets/panoptic_d3dgs/{seq}") output_root = Path(f"./output/{exp}/{seq}") scene_data, is_fg = load_scene_data(os.path.join(output_root, "params.npz")) md = json.load(open(os.path.join(scene_root, "train_meta.json"), "r")) n_frames = len(md['fn']) n_views = len(VIEWS) # Check that the selected views are in the training set view_paths = [] for view_idx in VIEWS: view_path = scene_root / "ims" / f"{view_idx}" assert view_idx in md["cam_id"][0], f"Camera {view_idx} is not in the training set" assert view_path.exists() view_paths.append(view_path) frame_paths = [sorted(view_path.glob("*.jpg")) for view_path in view_paths] assert all(len(frame_paths[v]) == n_frames for v in range(len(VIEWS))) assert len(scene_data) == n_frames # Create the output directory views_selection_str = '-'.join(str(v) for v in VIEWS) output_path = scene_root / f'dynamic3dgs-views-{views_selection_str}' os.makedirs(output_path, exist_ok=True) # Load the camera parameters fx, fy, cx, cy, extrinsics = [], [], [], [], [] 
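# The loop below gathers each selected view's intrinsics and world-to-camera
# extrinsics from train_meta.json and asserts that they are constant over time,
# so a single K and w2c per view can be reused for every frame when rendering
# depths and unprojecting point clouds.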
for view_idx in VIEWS: fx_current, fy_current, cx_current, cy_current, extrinsics_current = [], [], [], [], [] for t in range(n_frames): view_idx_in_array = md['cam_id'][t].index(view_idx) k = md['k'][t][view_idx_in_array] w2c = np.array(md['w2c'][t][view_idx_in_array]) fx_current.append(k[0][0]) fy_current.append(k[1][1]) cx_current.append(k[0][2]) cy_current.append(k[1][2]) extrinsics_current.append(w2c) assert all(np.equal(fx_current[0], fx_current[t]).all() for t in range(1, n_frames)) assert all(np.equal(fy_current[0], fy_current[t]).all() for t in range(1, n_frames)) assert all(np.equal(cx_current[0], cx_current[t]).all() for t in range(1, n_frames)) assert all(np.equal(cy_current[0], cy_current[t]).all() for t in range(1, n_frames)) assert all(np.equal(extrinsics_current[0], extrinsics_current[t]).all() for t in range(1, n_frames)) fx.append(fx_current[0]) fy.append(fy_current[0]) cx.append(cx_current[0]) cy.append(cy_current[0]) extrinsics.append(extrinsics_current[0]) fx = torch.tensor(fx).float() fy = torch.tensor(fy).float() cx = torch.tensor(cx).float() cy = torch.tensor(cy).float() k = torch.eye(3).float()[None].repeat(n_views, 1, 1) k[:, 0, 0] = fx k[:, 1, 1] = fy k[:, 0, 2] = cx k[:, 1, 2] = cy extrinsics = torch.from_numpy(np.stack(extrinsics)).float() k_inv = torch.inverse(k) extrinsics_inv = torch.inverse(extrinsics) # Render the depths rgbs = np.stack([ np.stack([ np.array(Image.open(frame_paths[v][t])) for t in range(n_frames) ]) for v in range(n_views) ]) h, w = rgbs.shape[2], rgbs.shape[3] d3dgs_rgbs = [] d3dgs_depths = [] for v, view_idx in enumerate(VIEWS): for t in range(n_frames): im, depth = render(extrinsics[v].numpy(), k[v].numpy(), scene_data[t]) d3dgs_rgbs.append(im.cpu().numpy().transpose(1, 2, 0)) d3dgs_depths.append(depth.cpu().numpy()[0]) d3dgs_rgbs = np.stack(d3dgs_rgbs).reshape(n_views, n_frames, h, w, 3) d3dgs_depths = np.stack(d3dgs_depths).reshape(n_views, n_frames, h, w) assert rgbs.shape == (n_views, n_frames, h, w, 3) assert d3dgs_rgbs.shape == (n_views, n_frames, h, w, 3) assert d3dgs_depths.shape == (n_views, n_frames, h, w) gt_tracks = np.stack([data['means3D'][is_fg][::traj_frac].contiguous().cpu().numpy() for data in scene_data]) n_tracks = gt_tracks.shape[1] gt_vis = np.ones((n_frames, n_tracks), dtype=bool) query_timestep = gt_vis.argmin(0) assert gt_tracks.shape == (n_frames, n_tracks, 3) assert gt_vis.shape == (n_frames, n_tracks) cmap = matplotlib.colormaps["gist_rainbow"] norm = matplotlib.colors.Normalize(vmin=gt_tracks[..., 0].min(), vmax=gt_tracks[..., 0].max()) track_colors = cmap(norm(gt_tracks[-1, :, 0])) assert track_colors.shape == (n_tracks, 4) rr.init("reconstruction", recording_id="v0.1") rr.connect_tcp() rr.set_time_seconds("frame", 0) rr.log("world", rr.ViewCoordinates.RIGHT_HAND_Z_UP, static=True) rr.log("world/xyz", rr.Arrows3D(vectors=[[1, 0, 0], [0, 2, 0], [0, 0, 3]], colors=[[255, 0, 0], [0, 255, 0], [0, 0, 255]])) for t in range(0, n_frames, log_n_skip_t): for v in range(0, n_views, log_n_skip_view): rr.set_time_seconds("frame", t / 30) if log_rgb: rr.log(f"{seq}/rgb/view-{VIEWS[v]}/rgb", rr.Image(rgbs[v, t])) rr.log(f"{seq}/rgb/view-{VIEWS[v]}", rr.Pinhole(image_from_camera=k[v].numpy(), width=w, height=h)) rr.log(f"{seq}/rgb/view-{VIEWS[v]}", rr.Transform3D(translation=extrinsics_inv[v, :3, 3].numpy(), mat3x3=extrinsics_inv[v, :3, :3].numpy())) if log_d3dgs_rgb: rr.log(f"{seq}/dyn-3dgs-rgb/view-{VIEWS[v]}/rgb", rr.Image(d3dgs_rgbs[v, t])) rr.log(f"{seq}/dyn-3dgs-rgb/view-{VIEWS[v]}", 
rr.Pinhole(image_from_camera=k[v].numpy(), width=w, height=h)) rr.log(f"{seq}/dyn-3dgs-rgb/view-{VIEWS[v]}", rr.Transform3D(translation=extrinsics_inv[v, :3, 3].numpy(), mat3x3=extrinsics_inv[v, :3, :3].numpy())) if log_d3dgs_depth: rr.log(f"{seq}/dyn-3dgs-depth/view-{VIEWS[v]}/depth", rr.DepthImage(d3dgs_depths[v, t], point_fill_ratio=0.2)) rr.log(f"{seq}/dyn-3dgs-depth/view-{VIEWS[v]}", rr.Pinhole(image_from_camera=k[v].numpy(), width=w, height=h)) rr.log(f"{seq}/dyn-3dgs-depth/view-{VIEWS[v]}", rr.Transform3D(translation=extrinsics_inv[v, :3, 3].numpy(), mat3x3=extrinsics_inv[v, :3, :3].numpy())) if log_d3dgs_point_cloud: y, x = np.indices((h, w)) homo_pixel_coords = np.stack([x.ravel(), y.ravel(), np.ones_like(x).ravel()], axis=1).T depth_values = d3dgs_depths[v, t].ravel() cam_coords = (k_inv[v] @ homo_pixel_coords) * depth_values cam_coords = np.vstack((cam_coords, np.ones((1, cam_coords.shape[1])))) world_coords = (extrinsics_inv[v] @ cam_coords)[:3].T valid_mask = depth_values > 0 world_coords = world_coords[valid_mask] rgb_colors = rgbs[v, t].reshape(-1, 3)[valid_mask].astype(np.uint8) rr.log(f"{seq}/dyn-3dgs-point-cloud/view-{v}", rr.Points3D(world_coords, colors=rgb_colors, radii=0.01)) if log_tracks: for tracks_batch_start in range(0, n_tracks, 100): tracks_batch_end = min(tracks_batch_start + 100, n_tracks) log_tracks_to_rerun( tracks=gt_tracks[:, tracks_batch_start:tracks_batch_end], visibles=gt_vis[:, tracks_batch_start:tracks_batch_end], query_timestep=query_timestep[tracks_batch_start:tracks_batch_end], colors=track_colors[tracks_batch_start:tracks_batch_end], track_names=[f"track-{i:02d}" for i in range(tracks_batch_start, tracks_batch_end)], entity_format_str=f"{seq}/dyn-3dgs-tracks/{tracks_batch_start}-{tracks_batch_end}/{{}}", invisible_color=[0.3, 0.3, 0.3], ) print("Done with visualization.") if __name__ == "__main__": exp_name = "pretrained" for sequence in ["basketball", "boxes", "football", "juggle", "softball", "tennis"]: visualize(sequence, exp_name) ================================================ FILE: mvtracker/models/core/embeddings.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. 
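# Sine-cosine positional-embedding helpers. The 2D/3D variants split the
# embedding dimension evenly across the spatial axes and encode each coordinate
# with concatenated sin/cos frequencies via get_1d_sincos_pos_embed_from_grid.
#
# Minimal usage sketch (illustrative only; shapes follow the functions below):
#
#   import numpy as np
#   pe = get_2d_sincos_pos_embed(embed_dim=32, grid_size=(8, 8))
#   assert pe.shape == (8 * 8, 32)   # one 32-d embedding per grid cell
#   pe1d = get_1d_sincos_pos_embed_from_grid(32, np.arange(10, dtype=np.float32))
#   assert pe1d.shape == (10, 32)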
import numpy as np import torch def get_3d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0): """ grid_size: int of the grid height and width return: pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) """ if isinstance(grid_size, tuple): grid_size_h, grid_size_w = grid_size else: grid_size_h = grid_size_w = grid_size grid_h = np.arange(grid_size_h, dtype=np.float32) grid_w = np.arange(grid_size_w, dtype=np.float32) grid = np.meshgrid(grid_w, grid_h) # here w goes first grid = np.stack(grid, axis=0) grid = grid.reshape([2, 1, grid_size_h, grid_size_w]) pos_embed = get_3d_sincos_pos_embed_from_grid(embed_dim, grid) if cls_token and extra_tokens > 0: pos_embed = np.concatenate( [np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0 ) return pos_embed def get_3d_sincos_pos_embed_from_grid(embed_dim, grid): assert embed_dim % 3 == 0 # use half of dimensions to encode grid_h B, S, N, _ = grid.shape gridx = grid[..., 0].view(B * S * N).detach().cpu().numpy() gridy = grid[..., 1].view(B * S * N).detach().cpu().numpy() gridz = grid[..., 2].view(B * S * N).detach().cpu().numpy() emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 3, gridx) # (N, D/3) emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 3, gridy) # (N, D/3) emb_z = get_1d_sincos_pos_embed_from_grid(embed_dim // 3, gridz) # (N, D/3) emb = np.concatenate([emb_h, emb_w, emb_z], axis=1) # (N, D) emb = torch.from_numpy(emb).to(grid.device) return emb.view(B, S, N, embed_dim) def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0): """ grid_size: int of the grid height and width return: pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) """ if isinstance(grid_size, tuple): grid_size_h, grid_size_w = grid_size else: grid_size_h = grid_size_w = grid_size grid_h = np.arange(grid_size_h, dtype=np.float32) grid_w = np.arange(grid_size_w, dtype=np.float32) grid = np.meshgrid(grid_w, grid_h) # here w goes first grid = np.stack(grid, axis=0) grid = grid.reshape([2, 1, grid_size_h, grid_size_w]) pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) if cls_token and extra_tokens > 0: pos_embed = np.concatenate( [np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0 ) return pos_embed def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): assert embed_dim % 2 == 0 # use half of dimensions to encode grid_h emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) return emb def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): """ embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D) """ assert embed_dim % 2 == 0 omega = np.arange(embed_dim // 2, dtype=np.float64) omega /= embed_dim / 2.0 omega = 1.0 / 10000 ** omega # (D/2,) pos = pos.reshape(-1) # (M,) out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product emb_sin = np.sin(out) # (M, D/2) emb_cos = np.cos(out) # (M, D/2) emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) return emb def get_2d_embedding(xy, C, cat_coords=True): B, N, D = xy.shape assert D == 2 x = xy[:, :, 0:1] y = xy[:, :, 1:2] div_term = ( torch.arange(0, C, 2, device=xy.device, dtype=torch.float32) * (1000.0 / C) ).reshape(1, 1, int(C / 2)) pe_x = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32) pe_y = torch.zeros(B, N, C, 
device=xy.device, dtype=torch.float32) pe_x[:, :, 0::2] = torch.sin(x * div_term) pe_x[:, :, 1::2] = torch.cos(x * div_term) pe_y[:, :, 0::2] = torch.sin(y * div_term) pe_y[:, :, 1::2] = torch.cos(y * div_term) pe = torch.cat([pe_x, pe_y], dim=2) # B, N, C*3 if cat_coords: pe = torch.cat([xy, pe], dim=2) # B, N, C*3+3 return pe def get_3d_embedding(xyz, C, cat_coords=True): B, N, D = xyz.shape assert D == 3 x = xyz[:, :, 0:1] y = xyz[:, :, 1:2] z = xyz[:, :, 2:3] div_term = ( torch.arange(0, C, 2, device=xyz.device, dtype=torch.float32) * (1000.0 / C) ).reshape(1, 1, int(C / 2)) pe_x = torch.zeros(B, N, C, device=xyz.device, dtype=torch.float32) pe_y = torch.zeros(B, N, C, device=xyz.device, dtype=torch.float32) pe_z = torch.zeros(B, N, C, device=xyz.device, dtype=torch.float32) pe_x[:, :, 0::2] = torch.sin(x * div_term) pe_x[:, :, 1::2] = torch.cos(x * div_term) pe_y[:, :, 0::2] = torch.sin(y * div_term) pe_y[:, :, 1::2] = torch.cos(y * div_term) pe_z[:, :, 0::2] = torch.sin(z * div_term) pe_z[:, :, 1::2] = torch.cos(z * div_term) pe = torch.cat([pe_x, pe_y, pe_z], dim=2) # B, N, C*3 if cat_coords: pe = torch.cat([pe, xyz], dim=2) # B, N, C*3+3 return pe def get_4d_embedding(xyzw, C, cat_coords=True): B, N, D = xyzw.shape assert D == 4 x = xyzw[:, :, 0:1] y = xyzw[:, :, 1:2] z = xyzw[:, :, 2:3] w = xyzw[:, :, 3:4] div_term = ( torch.arange(0, C, 2, device=xyzw.device, dtype=torch.float32) * (1000.0 / C) ).reshape(1, 1, int(C / 2)) pe_x = torch.zeros(B, N, C, device=xyzw.device, dtype=torch.float32) pe_y = torch.zeros(B, N, C, device=xyzw.device, dtype=torch.float32) pe_z = torch.zeros(B, N, C, device=xyzw.device, dtype=torch.float32) pe_w = torch.zeros(B, N, C, device=xyzw.device, dtype=torch.float32) pe_x[:, :, 0::2] = torch.sin(x * div_term) pe_x[:, :, 1::2] = torch.cos(x * div_term) pe_y[:, :, 0::2] = torch.sin(y * div_term) pe_y[:, :, 1::2] = torch.cos(y * div_term) pe_z[:, :, 0::2] = torch.sin(z * div_term) pe_z[:, :, 1::2] = torch.cos(z * div_term) pe_w[:, :, 0::2] = torch.sin(w * div_term) pe_w[:, :, 1::2] = torch.cos(w * div_term) pe = torch.cat([pe_x, pe_y, pe_z, pe_w], dim=2) # B, N, C*3 if cat_coords: pe = torch.cat([pe, xyzw], dim=2) # B, N, C*3+3 return pe import torch.nn as nn class Embedder_Fourier(nn.Module): def __init__(self, input_dim, max_freq_log2, N_freqs, log_sampling=True, include_input=True, periodic_fns=(torch.sin, torch.cos)): ''' :param input_dim: dimension of input to be embedded :param max_freq_log2: log2 of max freq; min freq is 1 by default :param N_freqs: number of frequency bands :param log_sampling: if True, frequency bands are linerly sampled in log-space :param include_input: if True, raw input is included in the embedding :param periodic_fns: periodic functions used to embed input ''' super(Embedder_Fourier, self).__init__() self.input_dim = input_dim self.include_input = include_input self.periodic_fns = periodic_fns self.out_dim = 0 if self.include_input: self.out_dim += self.input_dim self.out_dim += self.input_dim * N_freqs * len(self.periodic_fns) if log_sampling: self.freq_bands = 2. ** torch.linspace(0., max_freq_log2, N_freqs) else: self.freq_bands = torch.linspace( 2. ** 0., 2. 
** max_freq_log2, N_freqs) self.freq_bands = self.freq_bands.numpy().tolist() def forward(self, input: torch.Tensor, rescale: float = 1.0): ''' :param input: tensor of shape [..., self.input_dim] :return: tensor of shape [..., self.out_dim] ''' assert (input.shape[-1] == self.input_dim) out = [] if self.include_input: out.append(input / rescale) for i in range(len(self.freq_bands)): freq = self.freq_bands[i] for p_fn in self.periodic_fns: out.append(p_fn(input.float() * freq).type_as(input)) out = torch.cat(out, dim=-1) assert not input.isnan().any(), f"Found NaN in input" assert not out.isnan().any(), f"Found NaN in output" assert (out.shape[-1] == self.out_dim) return out ================================================ FILE: mvtracker/models/core/loftr/__init__.py ================================================ from .transformer import LocalFeatureTransformer ================================================ FILE: mvtracker/models/core/loftr/linear_attention.py ================================================ """ Linear Transformer proposed in "Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention" Modified from: https://github.com/idiap/fast-transformers/blob/master/fast_transformers/attention/linear_attention.py """ import torch from torch.nn import Module, Dropout def elu_feature_map(x): return torch.nn.functional.elu(x) + 1 class LinearAttention(Module): def __init__(self, eps=1e-6): super().__init__() self.feature_map = elu_feature_map self.eps = eps def forward(self, queries, keys, values, q_mask=None, kv_mask=None): """ Multi-Head linear attention proposed in "Transformers are RNNs" Args: queries: [N, L, H, D] keys: [N, S, H, D] values: [N, S, H, D] q_mask: [N, L] kv_mask: [N, S] Returns: queried_values: (N, L, H, D) """ Q = self.feature_map(queries) K = self.feature_map(keys) # set padded position to zero if q_mask is not None: Q = Q * q_mask[:, :, None, None] if kv_mask is not None: K = K * kv_mask[:, :, None, None] values = values * kv_mask[:, :, None, None] v_length = values.size(1) values = values / v_length # prevent fp16 overflow KV = torch.einsum("nshd,nshv->nhdv", K, values) # (S,D)' @ S,V Z = 1 / (torch.einsum("nlhd,nhd->nlh", Q, K.sum(dim=1)) + self.eps) queried_values = torch.einsum("nlhd,nhdv,nlh->nlhv", Q, KV, Z) * v_length return queried_values.contiguous() class FullAttention(Module): def __init__(self, use_dropout=False, attention_dropout=0.1): super().__init__() self.use_dropout = use_dropout self.dropout = Dropout(attention_dropout) def forward(self, queries, keys, values, q_mask=None, kv_mask=None): """ Multi-head scaled dot-product attention, a.k.a full attention. Args: queries: [N, L, H, D] keys: [N, S, H, D] values: [N, S, H, D] q_mask: [N, L] kv_mask: [N, S] Returns: queried_values: (N, L, H, D) """ # Compute the unnormalized attention and apply the masks QK = torch.einsum("nlhd,nshd->nlsh", queries, keys) if kv_mask is not None: QK.masked_fill_(~(q_mask[:, :, None, None] * kv_mask[:, None, :, None]), float('-inf')) # Compute the attention and the weighted average softmax_temp = 1. 
/ queries.size(3) ** .5 # sqrt(D) A = torch.softmax(softmax_temp * QK, dim=2) if self.use_dropout: A = self.dropout(A) queried_values = torch.einsum("nlsh,nshd->nlhd", A, values) return queried_values.contiguous() ================================================ FILE: mvtracker/models/core/loftr/transformer.py ================================================ ''' modified from https://github.com/zju3dv/LoFTR/blob/master/src/loftr/loftr_module/transformer.py ''' import copy import torch import torch.nn as nn import torch.nn.functional as F from torch.nn import Module, Dropout def elu_feature_map(x): return torch.nn.functional.elu(x) + 1 class FullAttention(Module): def __init__(self, use_dropout=False, attention_dropout=0.1): super().__init__() self.use_dropout = use_dropout self.dropout = Dropout(attention_dropout) def forward(self, queries, keys, values, q_mask=None, kv_mask=None): """ Multi-head scaled dot-product attention, a.k.a full attention. Args: queries: [N, L, H, D] keys: [N, S, H, D] values: [N, S, H, D] q_mask: [N, L] kv_mask: [N, S] Returns: queried_values: (N, L, H, D) """ # Compute the unnormalized attention and apply the masks # QK = torch.einsum("nlhd,nshd->nlsh", queries, keys) # if kv_mask is not None: # QK.masked_fill_(~(q_mask[:, :, None, None] * kv_mask[:, None, :, None]), float(-1e12)) # softmax_temp = 1. / queries.size(3)**.5 # sqrt(D) # A = torch.softmax(softmax_temp * QK, dim=2) # if self.use_dropout: # A = self.dropout(A) # queried_values_ = torch.einsum("nlsh,nshd->nlhd", A, values) # Compute the attention and the weighted average input_args = [x.half().contiguous() for x in [queries.permute(0, 2, 1, 3), keys.permute(0, 2, 1, 3), values.permute(0, 2, 1, 3)]] queried_values = F.scaled_dot_product_attention(*input_args).permute(0, 2, 1, 3).float() # type: ignore return queried_values.contiguous() class TransformerEncoderLayer(nn.Module): def __init__(self, d_model, nhead, ): super(TransformerEncoderLayer, self).__init__() self.dim = d_model // nhead self.nhead = nhead # multi-head attention self.q_proj = nn.Linear(d_model, d_model, bias=False) self.k_proj = nn.Linear(d_model, d_model, bias=False) self.v_proj = nn.Linear(d_model, d_model, bias=False) self.attention = FullAttention() self.merge = nn.Linear(d_model, d_model, bias=False) # feed-forward network self.mlp = nn.Sequential( nn.Linear(d_model * 2, d_model * 2, bias=False), nn.ReLU(True), nn.Linear(d_model * 2, d_model, bias=False), ) # norm and dropout self.norm1 = nn.LayerNorm(d_model) self.norm2 = nn.LayerNorm(d_model) def forward(self, x, source, x_mask=None, source_mask=None): """ Args: x (torch.Tensor): [N, L, C] source (torch.Tensor): [N, S, C] x_mask (torch.Tensor): [N, L] (optional) source_mask (torch.Tensor): [N, S] (optional) """ bs = x.size(0) query, key, value = x, source, source # multi-head attention query = self.q_proj(query).view(bs, -1, self.nhead, self.dim) # [N, L, (H, D)] key = self.k_proj(key).view(bs, -1, self.nhead, self.dim) # [N, S, (H, D)] value = self.v_proj(value).view(bs, -1, self.nhead, self.dim) message = self.attention(query, key, value, q_mask=x_mask, kv_mask=source_mask) # [N, L, (H, D)] message = self.merge(message.view(bs, -1, self.nhead * self.dim)) # [N, L, C] message = self.norm1(message) # feed-forward network message = self.mlp(torch.cat([x, message], dim=2)) message = self.norm2(message) return x + message class LocalFeatureTransformer(nn.Module): """A Local Feature Transformer module.""" def __init__(self, config): super(LocalFeatureTransformer, self).__init__() 
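        # One TransformerEncoderLayer is deep-copied per entry in config['layer_names']; in forward(),
        # 'self' entries attend within each feature map and 'cross' entries attend across the two maps.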
self.config = config self.d_model = config['d_model'] self.nhead = config['nhead'] self.layer_names = config['layer_names'] encoder_layer = TransformerEncoderLayer(config['d_model'], config['nhead']) self.layers = nn.ModuleList([copy.deepcopy(encoder_layer) for _ in range(len(self.layer_names))]) self._reset_parameters() def _reset_parameters(self): for p in self.parameters(): if p.dim() > 1: nn.init.xavier_uniform_(p) def forward(self, feat0, feat1, mask0=None, mask1=None): """ Args: feat0 (torch.Tensor): [N, L, C] feat1 (torch.Tensor): [N, S, C] mask0 (torch.Tensor): [N, L] (optional) mask1 (torch.Tensor): [N, S] (optional) """ assert self.d_model == feat0.size(2), "the feature number of src and transformer must be equal" for layer, name in zip(self.layers, self.layer_names): if name == 'self': feat0 = layer(feat0, feat0, mask0, mask0) feat1 = layer(feat1, feat1, mask1, mask1) elif name == 'cross': feat0 = layer(feat0, feat1, mask0, mask1) feat1 = layer(feat1, feat0, mask1, mask0) else: raise KeyError return feat0, feat1 ================================================ FILE: mvtracker/models/core/losses.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. import torch import torch.nn.functional as F from mvtracker.models.core.model_utils import reduce_masked_mean EPS = 1e-6 sigma = 3 x_grid = torch.arange(-7, 8, 1) y_grid = torch.arange(-7, 8, 1) x_grid, y_grid = torch.meshgrid(x_grid, y_grid, indexing="ij") gridxy = torch.stack([x_grid, y_grid], dim=-1).float() gs_kernel = torch.exp(-torch.sum(gridxy ** 2, dim=-1) / (2 * sigma ** 2)) def balanced_ce_loss(pred, gt, valid=None): total_balanced_loss = 0.0 for j in range(len(gt)): B, S, N = gt[j].shape # pred and gt are the same shape for (a, b) in zip(pred[j].size(), gt[j].size()): assert a == b # some shape mismatch! # if valid is not None: for (a, b) in zip(pred[j].size(), valid[j].size()): assert a == b # some shape mismatch! 
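        # Balanced binary cross-entropy on logits: gt > 0.95 counts as positive, gt < 0.05 as negative,
        # and intermediate values are ignored. The loss log(1 + exp(-label * pred)) is evaluated in the
        # numerically stable form b + log(exp(-b) + exp(a - b)) with a = -label * pred, b = relu(a),
        # and the positive and negative means are summed so both classes contribute equally.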
pos = (gt[j] > 0.95).float() neg = (gt[j] < 0.05).float() label = pos * 2.0 - 1.0 a = -label * pred[j] b = F.relu(a) loss = b + torch.log(torch.exp(-b) + torch.exp(a - b)) pos_loss = reduce_masked_mean(loss, pos * valid[j]) neg_loss = reduce_masked_mean(loss, neg * valid[j]) balanced_loss = pos_loss + neg_loss total_balanced_loss += balanced_loss return total_balanced_loss def sequence_loss_3d(flow_preds, flow_gt, vis, valids, gamma=0.8, dmin=0.1, dmax=65, Dz=128): """Loss function defined over sequence of flow predictions with z component post-processing""" total_flow_loss = 0.0 J = len(flow_gt) for j in range(J): B, S, N, D = flow_gt[j].shape assert D == 3 B, S1, N = vis[j].shape B, S2, N = valids[j].shape assert S == S1 assert S == S2 n_predictions = len(flow_preds[j]) flow_loss = 0.0 for i in range(n_predictions): i_weight = gamma ** (n_predictions - i - 1) flow_pred = flow_preds[j][i] flow_gt_j = flow_gt[j].clone() flow_pred[..., 2] = (flow_pred[..., 2] - dmin) / (dmax - dmin) * Dz flow_gt_j[..., 2] = (flow_gt_j[..., 2] - dmin) / (dmax - dmin) * Dz i_loss = (flow_pred - flow_gt_j).abs() # B, S, N, 3 i_loss = torch.mean(i_loss, dim=3) # B, S, N flow_loss += i_weight * reduce_masked_mean(i_loss, valids[j]) flow_loss = flow_loss / n_predictions total_flow_loss += flow_loss / float(J) return total_flow_loss ================================================ FILE: mvtracker/models/core/model_utils.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. import warnings from typing import Tuple, Optional import torch from easydict import EasyDict as edict from torch.nn import functional as F from mvtracker.utils.basic import to_homogeneous, from_homogeneous EPS = 1e-6 def smart_cat(tensor1, tensor2, dim): if tensor1 is None: return tensor2 return torch.cat([tensor1, tensor2], dim=dim) def normalize_single(d): # d is a whatever shape torch tensor dmin = torch.min(d) dmax = torch.max(d) d = (d - dmin) / (EPS + (dmax - dmin)) return d def normalize(d): # d is B x whatever. normalize within each element of the batch out = torch.zeros(d.size()) if d.is_cuda: out = out.cuda() B = list(d.size())[0] for b in list(range(B)): out[b] = normalize_single(d[b]) return out def meshgrid2d(B, Y, X, stack=False, norm=False, device="cuda"): # returns a meshgrid sized B x Y x X grid_y = torch.linspace(0.0, Y - 1, Y, device=torch.device(device)) grid_y = torch.reshape(grid_y, [1, Y, 1]) grid_y = grid_y.repeat(B, 1, X) grid_x = torch.linspace(0.0, X - 1, X, device=torch.device(device)) grid_x = torch.reshape(grid_x, [1, 1, X]) grid_x = grid_x.repeat(B, Y, 1) if stack: # note we stack in xy order # (see https://pytorch.org/docs/stable/nn.functional.html#torch.nn.functional.grid_sample) grid = torch.stack([grid_x, grid_y], dim=-1) return grid else: return grid_y, grid_x def reduce_masked_mean(x, mask, dim=None, keepdim=False): # x and mask are the same shape, or at least broadcastably so < actually it's safer if you disallow broadcasting # returns shape-1 # axis can be a list of axes for (a, b) in zip(x.size(), mask.size()): assert a == b # some shape mismatch! 
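    # Mean of x over the entries where mask is non-zero (optionally along `dim`);
    # EPS in the denominator keeps the result finite when the mask selects nothing.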
prod = x * mask if dim is None: numer = torch.sum(prod) denom = EPS + torch.sum(mask) else: numer = torch.sum(prod, dim=dim, keepdim=keepdim) denom = EPS + torch.sum(mask, dim=dim, keepdim=keepdim) mean = numer / denom return mean def bilinear_sample2d(im, x, y, return_inbounds=False): # x and y are each B, N # output is B, C, N if len(im.shape) == 5: B, N, C, H, W = list(im.shape) else: B, C, H, W = list(im.shape) N = list(x.shape)[1] x = x.float() y = y.float() H_f = torch.tensor(H, dtype=torch.float32) W_f = torch.tensor(W, dtype=torch.float32) # inbound_mask = (x>-0.5).float()*(y>-0.5).float()*(x -0.5).byte() & (x < float(W_f - 0.5)).byte() y_valid = (y > -0.5).byte() & (y < float(H_f - 0.5)).byte() inbounds = (x_valid & y_valid).float() inbounds = inbounds.reshape( B, N ) # something seems wrong here for B>1; i'm getting an error here (or downstream if i put -1) return output, inbounds return output # B, C, N def procrustes_analysis(X0, X1, Weight): # [B,N,3] # translation t0 = X0.mean(dim=1, keepdim=True) t1 = X1.mean(dim=1, keepdim=True) X0c = X0 - t0 X1c = X1 - t1 # scale # s0 = (X0c**2).sum(dim=-1).mean().sqrt() # s1 = (X1c**2).sum(dim=-1).mean().sqrt() # X0cs = X0c/s0 # X1cs = X1c/s1 # rotation (use double for SVD, float loses precision) U, _, V = (X0c.t() @ X1c).double().svd(some=True) R = (U @ V.t()).float() if R.det() < 0: R[2] *= -1 # align X1 to X0: X1to0 = (X1-t1)/@R.t()+t0 se3 = edict(t0=t0[0], t1=t1[0], R=R) return se3 def bilinear_sampler(input, coords, align_corners=True, padding_mode="border"): r"""Sample a tensor using bilinear interpolation `bilinear_sampler(input, coords)` samples a tensor :attr:`input` at coordinates :attr:`coords` using bilinear interpolation. It is the same as `torch.nn.functional.grid_sample()` but with a different coordinate convention. The input tensor is assumed to be of shape :math:`(B, C, H, W)`, where :math:`B` is the batch size, :math:`C` is the number of channels, :math:`H` is the height of the image, and :math:`W` is the width of the image. The tensor :attr:`coords` of shape :math:`(B, H_o, W_o, 2)` is interpreted as an array of 2D point coordinates :math:`(x_i,y_i)`. Alternatively, the input tensor can be of size :math:`(B, C, T, H, W)`, in which case sample points are triplets :math:`(t_i,x_i,y_i)`. Note that in this case the order of the components is slightly different from `grid_sample()`, which would expect :math:`(x_i,y_i,t_i)`. If `align_corners` is `True`, the coordinate :math:`x` is assumed to be in the range :math:`[0,W-1]`, with 0 corresponding to the center of the left-most image pixel :math:`W-1` to the center of the right-most pixel. If `align_corners` is `False`, the coordinate :math:`x` is assumed to be in the range :math:`[0,W]`, with 0 corresponding to the left edge of the left-most pixel :math:`W` to the right edge of the right-most pixel. Similar conventions apply to the :math:`y` for the range :math:`[0,H-1]` and :math:`[0,H]` and to :math:`t` for the range :math:`[0,T-1]` and :math:`[0,T]`. Args: input (Tensor): batch of input images. coords (Tensor): batch of coordinates. align_corners (bool, optional): Coordinate convention. Defaults to `True`. padding_mode (str, optional): Padding mode. Defaults to `"border"`. Returns: Tensor: sampled points. 
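    Example (illustrative sketch; the feature map and points below are made up)::

        feat = torch.randn(1, 64, 48, 64)            # (B, C, H, W)
        pts = torch.tensor([[[[10.5, 20.0]]]])       # (B, 1, 1, 2) as (x, y) pixel coords
        out = bilinear_sampler(feat, pts)            # -> (1, 64, 1, 1)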
""" sizes = input.shape[2:] assert len(sizes) in [2, 3] if len(sizes) == 3: # t x y -> x y t to match dimensions T H W in grid_sample coords = coords[..., [1, 2, 0]] if align_corners: coords = coords * torch.tensor( [2 / max(size - 1, 1) for size in reversed(sizes)], device=coords.device ) else: coords = coords * torch.tensor([2 / size for size in reversed(sizes)], device=coords.device) coords -= 1 return F.grid_sample(input, coords, align_corners=align_corners, padding_mode=padding_mode) def sample_features4d(input, coords): r"""Sample spatial features `sample_features4d(input, coords)` samples the spatial features :attr:`input` represented by a 4D tensor :math:`(B, C, H, W)`. The field is sampled at coordinates :attr:`coords` using bilinear interpolation. :attr:`coords` is assumed to be of shape :math:`(B, R, 3)`, where each sample has the format :math:`(x_i, y_i)`. This uses the same convention as :func:`bilinear_sampler` with `align_corners=True`. The output tensor has one feature per point, and has shape :math:`(B, R, C)`. Args: input (Tensor): spatial features. coords (Tensor): points. Returns: Tensor: sampled features. """ B, _, _, _ = input.shape # B R 2 -> B R 1 2 coords = coords.unsqueeze(2) # B C R 1 feats = bilinear_sampler(input, coords) return feats.permute(0, 2, 1, 3).view( B, -1, feats.shape[1] * feats.shape[3] ) # B C R 1 -> B R C def sample_features5d(input, coords): r"""Sample spatio-temporal features `sample_features5d(input, coords)` works in the same way as :func:`sample_features4d` but for spatio-temporal features and points: :attr:`input` is a 5D tensor :math:`(B, T, C, H, W)`, :attr:`coords` is a :math:`(B, R1, R2, 3)` tensor of spatio-temporal point :math:`(t_i, x_i, y_i)`. The output tensor has shape :math:`(B, R1, R2, C)`. Args: input (Tensor): spatio-temporal features. coords (Tensor): spatio-temporal points. Returns: Tensor: sampled features. 
""" B, T, _, _, _ = input.shape # B T C H W -> B C T H W input = input.permute(0, 2, 1, 3, 4) # B R1 R2 3 -> B R1 R2 1 3 coords = coords.unsqueeze(3) # B C R1 R2 1 feats = bilinear_sampler(input, coords) return feats.permute(0, 2, 3, 1, 4).view( B, feats.shape[2], feats.shape[3], feats.shape[1] ) # B C R1 R2 1 -> B R1 R2 C def pixel_xy_and_camera_z_to_world_space(pixel_xy, camera_z, intrs_inv, extrs_inv): num_frames, num_points, _ = pixel_xy.shape assert pixel_xy.shape == (num_frames, num_points, 2) assert camera_z.shape == (num_frames, num_points, 1) assert intrs_inv.shape == (num_frames, 3, 3) assert extrs_inv.shape == (num_frames, 4, 4) pixel_xy_homo = torch.cat([pixel_xy, pixel_xy.new_ones(pixel_xy[..., :1].shape)], -1) camera_xyz = torch.einsum('Aij,ABj->ABi', intrs_inv, pixel_xy_homo) * camera_z camera_xyz_homo = torch.cat([camera_xyz, camera_xyz.new_ones(camera_xyz[..., :1].shape)], -1) world_xyz_homo = torch.einsum('Aij,ABj->ABi', extrs_inv, camera_xyz_homo) if not torch.allclose( world_xyz_homo[..., -1], world_xyz_homo.new_ones(world_xyz_homo[..., -1].shape), atol=0.1, ): warnings.warn(f"pixel_xy_and_camera_z_to_world_space found some homo coordinates not close to 1: " f"the homo values are in {world_xyz_homo[..., -1].min()} – {world_xyz_homo[..., -1].max()}") world_xyz = world_xyz_homo[..., :-1] assert world_xyz.shape == (num_frames, num_points, 3) return world_xyz def world_space_to_pixel_xy_and_camera_z(world_xyz, intrs, extrs): num_frames, num_points, _ = world_xyz.shape assert world_xyz.shape == (num_frames, num_points, 3) assert intrs.shape == (num_frames, 3, 3) assert extrs.shape == (num_frames, 3, 4) world_xyz_homo = torch.cat([world_xyz, world_xyz.new_ones(world_xyz[..., :1].shape)], -1) camera_xyz = torch.einsum('Aij,ABj->ABi', extrs, world_xyz_homo) camera_z = camera_xyz[..., -1:] pixel_xy_homo = torch.einsum('Aij,ABj->ABi', intrs, camera_xyz) pixel_xy = pixel_xy_homo[..., :2] / pixel_xy_homo[..., -1:] assert pixel_xy.shape == (num_frames, num_points, 2) assert camera_z.shape == (num_frames, num_points, 1) return pixel_xy, camera_z def get_points_on_a_grid( size: int, extent: Tuple[float, ...], center: Optional[Tuple[float, ...]] = None, device: Optional[torch.device] = torch.device("cpu"), ): r"""Get a grid of points covering a rectangular region `get_points_on_a_grid(size, extent)` generates a :attr:`size` by :attr:`size` grid fo points distributed to cover a rectangular area specified by `extent`. The `extent` is a pair of integer :math:`(H,W)` specifying the height and width of the rectangle. Optionally, the :attr:`center` can be specified as a pair :math:`(c_y,c_x)` specifying the vertical and horizontal center coordinates. The center defaults to the middle of the extent. Points are distributed uniformly within the rectangle leaving a margin :math:`m=W/64` from the border. It returns a :math:`(1, \text{size} \times \text{size}, 2)` tensor of points :math:`P_{ij}=(x_i, y_i)` where .. math:: P_{ij} = \left( c_x + m -\frac{W}{2} + \frac{W - 2m}{\text{size} - 1}\, j,~ c_y + m -\frac{H}{2} + \frac{H - 2m}{\text{size} - 1}\, i \right) Points are returned in row-major order. Args: size (int): grid size. extent (tuple): height and with of the grid extent. center (tuple, optional): grid center. device (str, optional): Defaults to `"cpu"`. Returns: Tensor: grid. 
""" if size == 1: return torch.tensor([extent[1] / 2, extent[0] / 2], device=device)[None, None] if center is None: center = [extent[0] / 2, extent[1] / 2] margin = extent[1] / 64 range_y = (margin - extent[0] / 2 + center[0], extent[0] / 2 + center[0] - margin) range_x = (margin - extent[1] / 2 + center[1], extent[1] / 2 + center[1] - margin) grid_y, grid_x = torch.meshgrid( torch.linspace(*range_y, size, device=device), torch.linspace(*range_x, size, device=device), indexing="ij", ) return torch.stack([grid_x, grid_y], dim=-1).reshape(1, -1, 2) def init_pointcloud_from_rgbd( fmaps: torch.Tensor, depths: torch.Tensor, intrs: torch.Tensor, extrs: torch.Tensor, stride=4, level=0, depth_interp_mode='nearest', return_validity_mask=False, ): B, V, S, C, H, W = fmaps.shape assert fmaps.shape == (B, V, S, C, H, W) assert depths.shape == (B, V, S, 1, H, W) assert intrs.shape == (B, V, S, 3, 3) assert extrs.shape == (B, V, S, 3, 4) # Pool the fmaps and depths to the desired pyramid level fmaps = fmaps.reshape(B * V * S, C, H, W) depths = depths.reshape(B * V * S, 1, H, W) for i in range(level): fmaps = F.avg_pool2d(fmaps, 2, stride=2) if depth_interp_mode == 'avg': depths = F.avg_pool2d(depths, 2, stride=2) elif depth_interp_mode == 'nearest': depths = F.interpolate(depths, scale_factor=0.5, mode='nearest') else: raise NotImplementedError H = H // 2 ** level W = W // 2 ** level fmaps = fmaps.reshape(B, V, S, C, H, W) depths = depths.reshape(B, V, S, 1, H, W) stride = stride * 2 ** level # Invert intrinsics and extrinsics intrs_inv = torch.inverse(intrs.float()).type(intrs.dtype) extrs_square = torch.eye(4).to(extrs.device)[None].repeat(B, V, S, 1, 1) extrs_square[:, :, :, :3, :] = extrs extrs_inv = torch.inverse(extrs_square.float()).type(extrs.dtype) assert intrs_inv.shape == (B, V, S, 3, 3) assert extrs_inv.shape == (B, V, S, 4, 4) # Pixel --> Camera --> World pixel_xy = torch.stack(torch.meshgrid( (torch.arange(0, H) + 0.5) * stride - 0.5, (torch.arange(0, W) + 0.5) * stride - 0.5, indexing="ij", )[::-1], dim=-1) pixel_xy = pixel_xy.to(device=fmaps.device, dtype=fmaps.dtype) pixel_xy_homo = to_homogeneous(pixel_xy) depthmap_camera_xyz = torch.einsum('BVSij,HWj->BVSHWi', intrs_inv, pixel_xy_homo) depthmap_camera_xyz = depthmap_camera_xyz * depths[..., 0, :, :, None] depthmap_camera_xyz_homo = to_homogeneous(depthmap_camera_xyz) depthmap_world_xyz_homo = torch.einsum('BVSij,BVSHWj->BVSHWi', extrs_inv, depthmap_camera_xyz_homo) depthmap_world_xyz = from_homogeneous(depthmap_world_xyz_homo) pointcloud_xyz = depthmap_world_xyz.permute(0, 2, 1, 3, 4, 5).reshape(B * S, V * H * W, 3) pointcloud_fvec = fmaps.permute(0, 2, 1, 4, 5, 3).reshape(B * S, V * H * W, C) if return_validity_mask: pointcloud_valid_mask = depths.permute(0, 2, 1, 3, 4, 5).reshape(B * S, V * H * W) > 0 return pointcloud_xyz, pointcloud_fvec, pointcloud_valid_mask return pointcloud_xyz, pointcloud_fvec def save_pointcloud_to_ply(filename, points, colors, edges=None): with open(filename, 'w') as ply_file: ply_file.write("ply\nformat ascii 1.0\n") ply_file.write(f"element vertex {len(points)}\n") ply_file.write("property float x\nproperty float y\nproperty float z\n") ply_file.write("property uchar red\nproperty uchar green\nproperty uchar blue\n") if edges is not None: ply_file.write(f"element edge {len(edges)}\n") ply_file.write("property int vertex1\nproperty int vertex2\n") ply_file.write("end_header\n") # Write vertices (points with colors) for point, color in zip(points, colors): ply_file.write(f"{point[0]} {point[1]} {point[2]} 
{color[0]} {color[1]} {color[2]}\n") # Write edges (if provided) if edges is not None: for edge in edges: ply_file.write(f"{edge[0]} {edge[1]}\n") ================================================ FILE: mvtracker/models/core/monocular_baselines.py ================================================ import logging import sys import warnings from typing import Tuple import numpy as np import torch import torch.nn.functional as F from torch import nn as nn from mvtracker.datasets.utils import transform_scene from mvtracker.models.core.model_utils import bilinear_sample2d, pixel_xy_and_camera_z_to_world_space from mvtracker.utils.visualizer_mp4 import Visualizer class CoTrackerOfflineWrapper(nn.Module): def __init__(self, model_name="cotracker3_offline", grid_size=10): super(CoTrackerOfflineWrapper, self).__init__() self.grid_size = grid_size self.cotracker = torch.hub.load("facebookresearch/co-tracker", model_name) def forward(self, rgbs, queries, **kwargs): T, _, H, W = rgbs.shape N, _ = queries.shape assert rgbs.shape == (T, 3, H, W) assert queries.shape == (N, 3) # Forward pass: https://github.com/facebookresearch/co-tracker/blob/82e02e8029753ad4ef13cf06be7f4fc5facdda4d/cotracker/predictor.py#L36 pred_tracks, pred_visibility = self.cotracker( video=rgbs[None].float(), queries=queries[None].float(), grid_size=self.grid_size, ) return {"traj_2d": pred_tracks[0], "vis": pred_visibility[0]} class CoTrackerOnlineWrapper(nn.Module): def __init__(self, model_name="cotracker3_online", grid_size=10): super(CoTrackerOnlineWrapper, self).__init__() self.grid_size = grid_size self.cotracker = torch.hub.load("facebookresearch/co-tracker", model_name) def forward(self, rgbs, queries, **kwargs): T, _, H, W = rgbs.shape N, _ = queries.shape assert rgbs.shape == (T, 3, H, W) assert queries.shape == (N, 3) # Forward pass: https://github.com/facebookresearch/co-tracker/blob/82e02e8029753ad4ef13cf06be7f4fc5facdda4d/cotracker/predictor.py#L230 self.cotracker( video_chunk=rgbs[None].float(), queries=queries[None].float(), grid_size=self.grid_size, is_first_step=True, ) for t in range(0, T - self.cotracker.step, self.cotracker.step): pred_tracks, pred_visibility = self.cotracker(video_chunk=rgbs[None, t: t + self.cotracker.step * 2]) return {"traj_2d": pred_tracks[0], "vis": pred_visibility[0]} class SpaTrackerV2Wrapper(nn.Module): """ Environment setup: ```bash git clone https://github.com/henry123-boy/SpaTrackerV2.git ../spatialtrackerv2 cd ../spatialtrackerv2 git checkout 1673230 git submodule update --init --recursive pip install pycolmap==3.11.1 pip install git+https://github.com/EasternJournalist/utils3d.git#egg=utils3d pip install pyceres==2.4 # Update the threshold for weighted_procrustes_torch from 1e-3 to 5e-3 sed -i 's/(torch.det(R) - 1).abs().max() < 1e-3/(torch.det(R) - 1).abs().max() < 5e-3/' ./models/SpaTrackV2/models/tracker3D/spatrack_modules/utils.py # Verify the change: this should print a line with 5e-3 cat ./models/SpaTrackV2/models/tracker3D/spatrack_modules/utils.py | grep "(torch.det(R) - 1).abs().max()" ``` """ def __init__( self, model_type="offline", vo_points=756, ): super(SpaTrackerV2Wrapper, self).__init__() sys.path.append("../spatialtrackerv2/") from models.SpaTrackV2.models.predictor import Predictor if model_type == "offline": self.model = Predictor.from_pretrained("Yuxihenry/SpatialTrackerV2-Offline") elif model_type == "online": self.model = Predictor.from_pretrained("Yuxihenry/SpatialTrackerV2-Online") else: raise ValueError(f"Unknown model_type: {model_type}") 
self.model.spatrack.track_num = vo_points # the track_num is the number of points in the grid self.model.eval() self.model.to("cuda") def forward(self, rgbs, depths, queries, queries_xyz_worldspace, intrs, extrs, **kwargs): T, _, H, W = rgbs.shape N, _ = queries.shape assert rgbs.shape == (T, 3, H, W) assert depths.shape == (T, 1, H, W) assert intrs.shape == (T, 3, 3) assert extrs.shape == (T, 3, 4) assert queries.shape == (N, 3) assert queries_xyz_worldspace.shape == (N, 4) extrs_square = torch.eye(4).to(extrs.device)[None].repeat(T, 1, 1) extrs_square[:, :3, :] = extrs # Transform the extrinsics so that the camera is in the origin, and later revert. transform = extrs_square[0] transform_inv = torch.inverse(transform) extrs, queries_xyz_worldspace = extrs.clone(), queries_xyz_worldspace.clone() ( _, extrs, queries_xyz_worldspace, _, _ ) = transform_scene(1, transform[:3, :3], transform[:3, 3], None, extrs[None], queries_xyz_worldspace) extrs = extrs[0] extrs_square[:, :3, :] = extrs # Check if the camera is fixed extrs_delta = torch.linalg.norm(extrs - extrs[0], dim=(1, 2)) fixed_cam = (extrs_delta < 1e-3).all().item() # Run inference extrs_inv = torch.inverse(extrs_square) with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16): ( c2w_traj, intrs, point_map, conf_depth, track3d_pred, track2d_pred, vis_pred, conf_pred, video ) = self.model.forward(rgbs.cpu(), depth=depths.squeeze(1).cpu().numpy(), intrs=intrs.cpu(), extrs=extrs_inv.cpu().numpy(), queries=queries.cpu().numpy(), queries_3d=queries_xyz_worldspace.cpu().numpy(), fps=1, full_point=True, iters_track=4, query_no_BA=True, fixed_cam=fixed_cam, stage=1, unc_metric=None, support_frame=T - 1, replace_ratio=0.2) trajectories_3d = ( torch.einsum("tij,tnj->tni", c2w_traj[:, :3, :3].to(track3d_pred.device), track3d_pred[:, :, :3]) + c2w_traj[:, :3, 3][:, None, :].to(track3d_pred.device) ) ( _, _, _, trajectories_3d, _ ) = transform_scene(1, transform_inv[:3, :3], transform_inv[:3, 3], None, None, None, trajectories_3d, None) visibilities = vis_pred.squeeze(2) assert trajectories_3d.shape == (T, N, 3) assert visibilities.shape == (T, N) return {"traj_2d": None, "traj_3d_worldspace": trajectories_3d, "vis": visibilities} class LocoTrackWrapper(nn.Module): """ Environment setup: ```sh git clone https://github.com/cvlab-kaist/locotrack ../locotrack cd ../locotrack find ./locotrack_pytorch -type f -name "*.py" -exec sed -i 's/\bimport models\b/import locotrack_pytorch.models/g' {} \; find ./locotrack_pytorch -type f -name "*.py" -exec sed -i 's/\bfrom models\b/from locotrack_pytorch.models/g' {} \; cd ../spatialtracker ``` """ def __init__(self, model_size="base"): super(LocoTrackWrapper, self).__init__() sys.path.append("../locotrack") from locotrack_pytorch.models.locotrack_model import load_model self.model = load_model(model_size=model_size).cuda() self.model.eval() def forward(self, rgbs, queries, **kwargs): T, _, H, W = rgbs.shape N, _ = queries.shape assert (H, W) == (256, 256), f"LocoTrack only supports (256, 256) images, but got ({H}, {W})" assert rgbs.shape == (T, 3, H, W) assert queries.shape == (N, 3) # Forward pass: https://github.com/cvlab-kaist/locotrack/blob/6f3f9cad46b06c3de9c38fbf21006271056baf45/locotrack_pytorch/models/locotrack_model.py#L323 video = (rgbs.permute(0, 2, 3, 1)[None] / 255.0) * 2 - 1 queries_tyx = torch.stack([queries[:, 0], queries[:, 2], queries[:, 1]], dim=1)[None] # queries_tyx = queries_tyx / torch.tensor([1, H, W], dtype=queries_tyx.dtype, device=queries_tyx.device) with torch.no_grad(): 
output = self.model(video=video, query_points=queries_tyx) pred_occ = torch.sigmoid(output['occlusion']) if 'expected_dist' in output: pred_occ = 1 - (1 - pred_occ) * (1 - torch.sigmoid(output['expected_dist'])) pred_occ = (pred_occ > 0.5)[0] trajectories_2d = output['tracks'][0].permute(1, 0, 2) # trajectories_2d[..., 0] *= W # trajectories_2d[..., 1] *= H visibilities = ~pred_occ.permute(1, 0) if torch.isnan(trajectories_2d).any(): warnings.warn( f"Found {torch.isnan(trajectories_2d).sum()}/{trajectories_2d.numel()} NaN values in trajectories_2d. Setting them to 0.") trajectories_2d[trajectories_2d.isnan()] = 0 if torch.isnan(visibilities).any(): warnings.warn( f"Found {torch.isnan(visibilities).sum()}/{visibilities.numel()} NaN values in visibilities. Setting them to 1.") visibilities[visibilities.isnan()] = 1 return {"traj_2d": trajectories_2d, "vis": visibilities} class TAPTRWrapper(nn.Module): pass class TAPIRWrapper(nn.Module): pass class PIPSWrapper(nn.Module): pass class PIPSPlusPlusWrapper(nn.Module): pass class SceneTrackerWrapper(nn.Module): """ Environment setup: ```sh wget --directory-prefix=checkpoints https://huggingface.co/wwcreator/SceneTracker/resolve/main/scenetracker-odyssey-200k.pth git clone https://github.com/wwsource/SceneTracker.git ../scenetracker python eval.py experiment_path=logs/scenetracker model=scenetracker ``` """ def __init__( self, ckpt="checkpoints/scenetracker-odyssey-200k.pth", return_2d_track=False, ): super(SceneTrackerWrapper, self).__init__() sys.path.append("../scenetracker/") from model.model_scenetracker import SceneTracker model = SceneTracker() pre_replace_list = [['module.', '']] checkpoint = torch.load(ckpt) for l in pre_replace_list: checkpoint = {k.replace(l[0], l[1]): v for k, v in checkpoint.items()} model.load_state_dict(checkpoint, strict=True) model.eval().cuda() self.return_2d_track = return_2d_track self.model = model def forward(self, rgbs, depths, queries_with_z, **kwargs): T, _, H, W = rgbs.shape N, _ = queries_with_z.shape assert rgbs.shape == (T, 3, H, W) assert depths.shape == (T, 1, H, W) assert queries_with_z.shape == (N, 4) trajs_uv_e, trajs_z_e, _, _ = self.model.infer( self.model, input_list=[ rgbs[None].float(), depths[None].float(), queries_with_z[None].float(), ], iters=4, is_train=False, ) trajectories_2d = trajs_uv_e[0].type(queries_with_z.dtype) trajectories_z = trajs_z_e[0].type(queries_with_z.dtype) visibilities = torch.zeros_like(trajectories_2d[..., 0], dtype=torch.bool) if self.return_2d_track: return {"traj_2d": trajectories_2d, "vis": visibilities} else: return {"traj_2d": trajectories_2d, "traj_z": trajectories_z, "vis": visibilities} class DELTAWrapper(nn.Module): """ Environment setup: ```sh mkdir -p ./checkpoints/ gdown --fuzzy https://drive.google.com/file/d/18d5M3nl3AxbG4ZkT7wssvMXZXbmXrnjz/view?usp=sharing -O ./checkpoints/ # 3D ckpt gdown --fuzzy https://drive.google.com/file/d/1S_T7DzqBXMtr0voRC_XUGn1VTnPk_7Rm/view?usp=sharing -O ./checkpoints/ # 2D ckpt git clone --recursive https://github.com/snap-research/DELTA_densetrack3d ../delta pip install jaxtyping python eval.py experiment_path=logs/delta model=delta ``` """ def __init__( self, ckpt="checkpoints/densetrack3d.pth", upsample_factor=4, grid_size=20, return_2d_track=False, ): super(DELTAWrapper, self).__init__() self.grid_size = grid_size self.return_2d_track = return_2d_track sys.path.append("../delta") from densetrack3d.models.densetrack3d.densetrack3d import DenseTrack3D from densetrack3d.models.predictor.predictor import Predictor3D model 
= DenseTrack3D( stride=4, window_len=16, add_space_attn=True, num_virtual_tracks=64, model_resolution=(384, 512), upsample_factor=upsample_factor ) with open(ckpt, "rb") as f: state_dict = torch.load(f, map_location="cpu") if "model" in state_dict: state_dict = state_dict["model"] model.load_state_dict(state_dict, strict=False) predictor = Predictor3D(model=model) predictor = predictor.eval().cuda() self.model = model self.predictor = predictor def forward(self, rgbs, depths, queries, **kwargs): T, _, H, W = rgbs.shape N, _ = queries.shape assert rgbs.shape == (T, 3, H, W) assert depths.shape == (T, 1, H, W) assert queries.shape == (N, 3) out_dict = self.predictor( rgbs[None], depths[None], queries=queries[None], segm_mask=None, grid_size=self.grid_size, grid_query_frame=0, backward_tracking=False, predefined_intrs=None ) trajectories_2d = out_dict["trajs_uv"][0] trajectories_z = out_dict["trajs_depth"][0] trajectories_3d = out_dict["trajs_3d_dict"]["coords"][0] visibilities = out_dict["vis"][0] if self.return_2d_track: return {"traj_2d": trajectories_2d, "vis": visibilities} else: return {"traj_2d": trajectories_2d, "traj_z": trajectories_z, "vis": visibilities} class TAPIP3DWrapper(nn.Module): """ Environment setup: ```sh wget --directory-prefix=checkpoints https://huggingface.co/zbww/tapip3d/resolve/main/tapip3d_final.pth git clone git@github.com:zbw001/TAPIP3D.git ../tapip3d cd ../tapip3d git checkout 9359ae236f16a58a103dc1c55ad1919360dc6f8b cd third_party/pointops2 LIBRARY_PATH=$CONDA_PREFIX/lib:$LIBRARY_PATH python setup.py install cd ../.. """ def __init__( self, ckpt="checkpoints/tapip3d_final.pth", num_iters=6, grid_size=8, resolution_factor=2, transform_to_camera_space=False, ): super(TAPIP3DWrapper, self).__init__() self.num_iters = num_iters self.support_grid_size = grid_size self.resolution_factor = resolution_factor self.transform_to_camera_space = transform_to_camera_space sys.path.append("../tapip3d") from utils.inference_utils import load_model self.model = load_model(ckpt) self.model.cuda() inference_res = ( int(self.model.image_size[0] * np.sqrt(self.resolution_factor)), int(self.model.image_size[1] * np.sqrt(self.resolution_factor)), ) self.model.set_image_size(inference_res) def forward(self, rgbs, depths, intrs, extrs, queries_xyz_worldspace, **kwargs): T, _, H, W = rgbs.shape N, _ = queries_xyz_worldspace.shape assert rgbs.shape == (T, 3, H, W) assert depths.shape == (T, 1, H, W) assert intrs.shape == (T, 3, 3) assert extrs.shape == (T, 3, 4) assert queries_xyz_worldspace.shape == (N, 4) extrs_square = torch.eye(4).to(extrs.device)[None].repeat(T, 1, 1) extrs_square[:, :3, :] = extrs # Transform the extrinsics (and query points) so that # the camera is in the origin, and later revert. # But it's about the same performance either way. 
if self.transform_to_camera_space: T = extrs_square[0] T_inv = torch.inverse(T) extrs = extrs.clone() ( _, extrs, queries_xyz_worldspace, _, _ ) = transform_scene(1, T[:3, :3], T[:3, 3], None, extrs[None], queries_xyz_worldspace, None, None) extrs = extrs[0] extrs_square[:, :3, :] = extrs # Run inference with torch.autocast("cuda", dtype=torch.bfloat16): trajectories_3d, visibilities = TAPIP3DWrapper.inference( model=self.model, video=rgbs / 255.0, depths=depths.squeeze(1), intrinsics=intrs, extrinsics=extrs_square, query_point=queries_xyz_worldspace, num_iters=self.num_iters, grid_size=self.support_grid_size, ) if self.transform_to_camera_space: ( _, _, _, trajectories_3d, _ ) = transform_scene(1, T_inv[:3, :3], T_inv[:3, 3], None, None, None, trajectories_3d, None) if N == 1: trajectories_3d = trajectories_3d.unsqueeze(1) visibilities = visibilities.unsqueeze(1) assert trajectories_3d.shape == (T, N, 3) assert visibilities.shape == (T, N) return {"traj_2d": None, "traj_3d_worldspace": trajectories_3d.clone(), "vis": visibilities.clone()} @staticmethod @torch.no_grad() def inference( *, model: torch.nn.Module, video: torch.Tensor, depths: torch.Tensor, intrinsics: torch.Tensor, extrinsics: torch.Tensor, query_point: torch.Tensor, num_iters: int = 6, grid_size: int = 8, bidrectional: bool = True, vis_threshold=None, ) -> Tuple[torch.Tensor, torch.Tensor]: from utils.inference_utils import _inference_with_grid from einops import repeat _depths = depths.clone() _depths = _depths[_depths > 0].reshape(-1) q25 = torch.kthvalue(_depths, int(0.25 * len(_depths))).values q75 = torch.kthvalue(_depths, int(0.75 * len(_depths))).values iqr = q75 - q25 _depth_roi = torch.tensor( [1e-7, (q75 + 1.5 * iqr).item()], dtype=torch.float32, device=video.device ) T, C, H, W = video.shape assert depths.shape == (T, H, W) N = query_point.shape[0] model.set_image_size((H, W)) preds, _ = _inference_with_grid( model=model, video=video[None], depths=depths[None], intrinsics=intrinsics[None], extrinsics=extrinsics[None], query_point=query_point[None], num_iters=num_iters, depth_roi=_depth_roi, grid_size=grid_size ) if bidrectional and not model.bidirectional and (query_point[..., 0] > 0).any(): preds_backward, _ = _inference_with_grid( model=model, video=video[None].flip(dims=(1,)), depths=depths[None].flip(dims=(1,)), intrinsics=intrinsics[None].flip(dims=(1,)), extrinsics=extrinsics[None].flip(dims=(1,)), query_point=torch.cat([T - 1 - query_point[..., :1], query_point[..., 1:]], dim=-1)[None], num_iters=num_iters, depth_roi=_depth_roi, grid_size=grid_size, ) preds.coords = torch.where( repeat(torch.arange(T, device=video.device), 't -> b t n 3', b=1, n=N) < repeat( query_point[..., 0][None], 'b n -> b t n 3', t=T, n=N), preds_backward.coords.flip(dims=(1,)), preds.coords ) preds.visibs = torch.where( repeat(torch.arange(T, device=video.device), 't -> b t n', b=1, n=N) < repeat( query_point[..., 0][None], 'b n -> b t n', t=T, n=N), preds_backward.visibs.flip(dims=(1,)), preds.visibs ) coords, visib_logits = preds.coords, preds.visibs visibs = torch.sigmoid(visib_logits) if vis_threshold is not None: visibs = visibs >= vis_threshold return coords.squeeze(), visibs.squeeze() class MonocularToMultiViewAdapter(nn.Module): def __init__(self, model, **kwargs): super(MonocularToMultiViewAdapter, self).__init__() self.model = model def forward( self, rgbs, depths, query_points, intrs, extrs, save_debug_logs=False, debug_logs_path="", query_points_view=None, **kwargs, ): batch_size, num_views, num_frames, _, height, width = 
rgbs.shape _, num_points, _ = query_points.shape assert rgbs.shape == (batch_size, num_views, num_frames, 3, height, width) assert depths.shape == (batch_size, num_views, num_frames, 1, height, width) assert query_points.shape == (batch_size, num_points, 4) assert intrs.shape == (batch_size, num_views, num_frames, 3, 3) assert extrs.shape == (batch_size, num_views, num_frames, 3, 4) # Project the queries to each view query_points_t = query_points[:, :, :1].long() query_points_xyz_worldspace = query_points[:, :, 1:] query_points_xy_pixelspace_per_view = query_points.new_zeros((batch_size, num_views, num_points, 2)) query_points_z_cameraspace_per_view = query_points.new_zeros((batch_size, num_views, num_points, 1)) for batch_idx in range(batch_size): for t in query_points_t[batch_idx].unique(): query_points_t_mask = query_points_t[batch_idx].squeeze(-1) == t point_3d_world = query_points_xyz_worldspace[batch_idx][query_points_t_mask] # World to camera space point_4d_world_homo = torch.cat( [point_3d_world, point_3d_world.new_ones(point_3d_world[..., :1].shape)], -1) point_3d_camera = torch.einsum('Aij,Bj->ABi', extrs[batch_idx, :, t, :, :], point_4d_world_homo[:, :]) # Camera to pixel space point_2d_pixel_homo = torch.einsum('Aij,ABj->ABi', intrs[batch_idx, :, t, :, :], point_3d_camera[:, :]) point_2d_pixel = point_2d_pixel_homo[..., :2] / point_2d_pixel_homo[..., 2:] query_points_xy_pixelspace_per_view[batch_idx, :, query_points_t_mask] = point_2d_pixel query_points_z_cameraspace_per_view[batch_idx, :, query_points_t_mask] = point_3d_camera[..., -1:] # Estimate occlusion mask in each view based on depth maps query_points_depth_in_view = query_points.new_zeros((batch_size, num_views, num_points, 1)) for batch_idx in range(batch_size): for view_idx in range(num_views): for t in query_points_t[batch_idx].unique(): query_points_t_mask = query_points_t[batch_idx].squeeze(-1) == t interpolated_depth = bilinear_sample2d( im=depths[batch_idx, view_idx, t][None], x=query_points_xy_pixelspace_per_view[batch_idx, view_idx, query_points_t_mask, 0][None], y=query_points_xy_pixelspace_per_view[batch_idx, view_idx, query_points_t_mask, 1][None], )[0].permute(1, 0).type(query_points.dtype) query_points_depth_in_view[batch_idx, view_idx, query_points_t_mask] = interpolated_depth query_points_depth_in_view_masked = query_points_depth_in_view.clone() query_points_outside_of_view_box = ( (query_points_xy_pixelspace_per_view[..., 0] < 0) | (query_points_xy_pixelspace_per_view[..., 0] >= width) | (query_points_xy_pixelspace_per_view[..., 1] < 0) | (query_points_xy_pixelspace_per_view[..., 1] >= height) | (query_points_z_cameraspace_per_view[..., 0] < 0) ) if query_points_outside_of_view_box.all(1).any(): warnings.warn(f"There are some query points that are outside of the frame of every view: " f"{query_points_xy_pixelspace_per_view[query_points_outside_of_view_box.all(1)[:, None, :].repeat(1, num_views, 1)].reshape(num_views, -1, 2).permute(1, 0, 2)}") query_points_depth_in_view_masked[query_points_outside_of_view_box] = -1e4 query_points_best_visibility_view = ( query_points_depth_in_view_masked - query_points_z_cameraspace_per_view).argmax(1) query_points_best_visibility_view = query_points_best_visibility_view.squeeze(-1) if query_points_view is not None: query_points_best_visibility_view = query_points_view logging.info(f"Using the provided query_points_view instead of the estimated one") assert batch_size == 1, "Batch size > 1 is not supported yet" batch_idx = 0 # Call the 2D tracker for each view 
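        # Each query is routed to the single view in which it is estimated to be best visible
        # (query_points_best_visibility_view above, derived from the per-view depth maps), tracked there
        # with the wrapped monocular model, and then lifted back to world space further below using that
        # view's depth, intrinsics, and extrinsics.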
traj_e_per_view = {} vis_e_per_view = {} for view_idx in range(num_views): track_mask = query_points_best_visibility_view[batch_idx] == view_idx if track_mask.sum() == 0: continue view_rgbs = rgbs[batch_idx, view_idx] view_depths = depths[batch_idx, view_idx] view_intrs = intrs[batch_idx, view_idx] view_extrs = extrs[batch_idx, view_idx] view_query_points = torch.concat([ query_points_t[batch_idx, :, :][track_mask], query_points_xy_pixelspace_per_view[batch_idx, view_idx, :, :][track_mask], ], dim=-1) view_query_points_with_z = torch.concat([ query_points_t[batch_idx, :, :][track_mask], query_points_xy_pixelspace_per_view[batch_idx, view_idx, :, :][track_mask], query_points_z_cameraspace_per_view[batch_idx, view_idx, :][track_mask], ], dim=-1) view_query_points_xyz_worldspace = torch.concat([ query_points_t[batch_idx, :, :][track_mask], query_points_xyz_worldspace[batch_idx, :][track_mask], ], dim=-1) results = self.model( rgbs=view_rgbs, depths=view_depths, intrs=view_intrs, extrs=view_extrs, queries=view_query_points, queries_with_z=view_query_points_with_z, queries_xyz_worldspace=view_query_points_xyz_worldspace, ) view_traj_e = results["traj_2d"] view_vis_e = results["vis"] if save_debug_logs and view_traj_e is not None: visualizer = Visualizer( save_dir=debug_logs_path, pad_value=16, fps=12, show_first_frame=0, tracks_leave_trace=3, ) visualizer.visualize( video=view_rgbs[None].cpu(), tracks=view_traj_e[None].cpu(), visibility=view_vis_e[None].cpu(), filename=f"view_{view_idx}.mp4", query_frame=query_points_t[batch_idx, :, 0][track_mask][None], save_video=True, ) # Project the trajectories to the world space if "traj_3d_worldspace" in results: view_traj_e = results["traj_3d_worldspace"] else: if "traj_z" in results: view_camera_z = results["traj_z"] else: view_camera_z = bilinear_sampler(view_depths, view_traj_e.reshape(num_frames, -1, 1, 2))[:, 0, :, :] view_intrs = intrs[batch_idx, view_idx] view_extrs = extrs[batch_idx, view_idx] intrs_inv = torch.inverse(view_intrs.float()) view_extrs_square = torch.eye(4).to(view_extrs.device)[None].repeat(num_frames, 1, 1) view_extrs_square[:, :3, :] = view_extrs extrs_inv = torch.inverse(view_extrs_square.float()) view_traj_e = pixel_xy_and_camera_z_to_world_space( pixel_xy=view_traj_e[..., :].float(), camera_z=view_camera_z.float(), intrs_inv=intrs_inv, extrs_inv=extrs_inv, ) # Set the trajectory to (0,0,0) for the timesteps before the query timestep for point_idx, t in enumerate(query_points_t[batch_idx, :, :].squeeze(-1)[track_mask]): view_traj_e[:t, point_idx, :] = 0.0 traj_e_per_view[view_idx] = view_traj_e[None] vis_e_per_view[view_idx] = view_vis_e[None] # Merging the results from all views views_to_keep = list(traj_e_per_view.keys()) traj_e = torch.cat([traj_e_per_view[view_idx] for view_idx in views_to_keep], dim=2) vis_e = torch.cat([vis_e_per_view[view_idx] for view_idx in views_to_keep], dim=2) # Sort the traj_e and vis_e based on the original indices, since concatenating the results from all views # will first put the results from the first view, then the results from the second view, and so on. # But we want to keep the trajectories order to match the original query points order. 
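        # sort_inds collects, view by view, the original indices of the queries routed to each view;
        # argsort of their concatenation yields the inverse permutation that restores the input order.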
sort_inds = [] for view_idx in views_to_keep: track_mask = query_points_best_visibility_view[batch_idx] == view_idx if track_mask.sum() == 0: continue global_indices = torch.nonzero(track_mask).squeeze(-1) sort_inds += [global_indices] sort_inds = torch.cat(sort_inds, dim=0) inv_sort_inds = torch.argsort(sort_inds, dim=0) # Use the inv_sort_inds to sort the traj_e and vis_e traj_e = traj_e[:, :, inv_sort_inds] vis_e = vis_e[:, :, inv_sort_inds] # Save to results results = {"traj_e": traj_e, "vis_e": vis_e} return results # From https://github.com/facebookresearch/co-tracker/blob/82e02e8029753ad4ef13cf06be7f4fc5facdda4d/cotracker/models/core/model_utils.py#L286 def bilinear_sampler(input, coords, align_corners=True, padding_mode="border"): r"""Sample a tensor using bilinear interpolation `bilinear_sampler(input, coords)` samples a tensor :attr:`input` at coordinates :attr:`coords` using bilinear interpolation. It is the same as `torch.nn.functional.grid_sample()` but with a different coordinate convention. The input tensor is assumed to be of shape :math:`(B, C, H, W)`, where :math:`B` is the batch size, :math:`C` is the number of channels, :math:`H` is the height of the image, and :math:`W` is the width of the image. The tensor :attr:`coords` of shape :math:`(B, H_o, W_o, 2)` is interpreted as an array of 2D point coordinates :math:`(x_i,y_i)`. Alternatively, the input tensor can be of size :math:`(B, C, T, H, W)`, in which case sample points are triplets :math:`(t_i,x_i,y_i)`. Note that in this case the order of the components is slightly different from `grid_sample()`, which would expect :math:`(x_i,y_i,t_i)`. If `align_corners` is `True`, the coordinate :math:`x` is assumed to be in the range :math:`[0,W-1]`, with 0 corresponding to the center of the left-most image pixel :math:`W-1` to the center of the right-most pixel. If `align_corners` is `False`, the coordinate :math:`x` is assumed to be in the range :math:`[0,W]`, with 0 corresponding to the left edge of the left-most pixel :math:`W` to the right edge of the right-most pixel. Similar conventions apply to the :math:`y` for the range :math:`[0,H-1]` and :math:`[0,H]` and to :math:`t` for the range :math:`[0,T-1]` and :math:`[0,T]`. Args: input (Tensor): batch of input images. coords (Tensor): batch of coordinates. align_corners (bool, optional): Coordinate convention. Defaults to `True`. padding_mode (str, optional): Padding mode. Defaults to `"border"`. Returns: Tensor: sampled points. """ sizes = input.shape[2:] assert len(sizes) in [2, 3] if len(sizes) == 3: # t x y -> x y t to match dimensions T H W in grid_sample coords = coords[..., [1, 2, 0]] if align_corners: coords = coords * torch.tensor( [2 / max(size - 1, 1) for size in reversed(sizes)], device=coords.device ) else: coords = coords * torch.tensor( [2 / size for size in reversed(sizes)], device=coords.device ) coords -= 1 return F.grid_sample( input, coords, align_corners=align_corners, padding_mode=padding_mode ) ================================================ FILE: mvtracker/models/core/mvtracker/__init__.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. 
================================================ FILE: mvtracker/models/core/mvtracker/mvtracker.py ================================================ import logging import os from collections import defaultdict from typing import Optional, Callable import numpy as np import pandas as pd import torch from einops import rearrange from torch import nn as nn from mvtracker.datasets.utils import transform_scene from mvtracker.models.core.cotracker2.blocks import Attention, FlashAttention from mvtracker.models.core.cotracker2.blocks import EfficientUpdateFormer from mvtracker.models.core.embeddings import ( get_3d_sincos_pos_embed_from_grid, get_1d_sincos_pos_embed_from_grid, get_3d_embedding, ) from mvtracker.models.core.model_utils import smart_cat, init_pointcloud_from_rgbd, save_pointcloud_to_ply from mvtracker.models.core.spatracker.blocks import BasicEncoder from mvtracker.utils.basic import time_now # ---------- KNN backends ---------- def _knn_pointops(k: int, xyz_ref: torch.Tensor, xyz_query: torch.Tensor): """ Efficient batched KNN using pointops library. This is slightly faster than torch.cdist + torch.topk and uses less memory: Example:: Benchmarking KNN with different methods (HALF_PRECISION=True): torch.cdist+torch.topk | Avg Time: 0.008380 s | Peak Memory: 1151.19 MB (min: 1151.19, max: 1151.19) pointops.knn_query | Avg Time: 0.007477 s | Peak Memory: 47.22 MB (min: 47.22, max: 47.22) Benchmarking KNN with different methods (HALF_PRECISION=False): torch.cdist+torch.topk | Avg Time: 0.014090 s | Peak Memory: 2249.88 MB (min: 2249.88, max: 2249.88) pointops.knn_query | Avg Time: 0.007368 s | Peak Memory: 43.62 MB (min: 43.62, max: 43.62) Args: xyz_ref (Tensor): (B, N, 3) xyz_query (Tensor): (B, M, 3) Returns: Tuple[Tensor, Tensor]: - dist (Tensor): (B, M, k) - idx (Tensor): (B, M, k) int32 — indices into dimension N """ # Fallback if tensors are not on CUDA if not xyz_ref.is_cuda: return _knn_torch(k, xyz_ref, xyz_query) from pointops import knn_query B, N, _ = xyz_ref.shape _, M, _ = xyz_query.shape orig_dtype = xyz_ref.dtype xyz_ref_flat = xyz_ref.contiguous().view(B * N, 3).to(torch.float32) xyz_query_flat = xyz_query.contiguous().view(B * M, 3).to(torch.float32) offset = torch.arange(1, B + 1, device=xyz_ref.device) * N new_offset = torch.arange(1, B + 1, device=xyz_query.device) * M idx, dists = knn_query(k, xyz_ref_flat, offset, xyz_query_flat, new_offset) # Remap global indices to local per-batch idx = idx.view(B, M, k) idx = idx - (torch.arange(B, device=idx.device).view(B, 1, 1) * N) dists = dists.view(B, M, k).to(orig_dtype) return dists, idx def _knn_torch(k: int, xyz_ref: torch.Tensor, xyz_query: torch.Tensor): """Fallback KNN using torch.cdist + topk.""" dists = torch.cdist(xyz_query, xyz_ref, p=2) # (B, M, N) sorted_dists, indices = torch.topk(dists, k, dim=-1, largest=False, sorted=True) return sorted_dists, indices # Select backend once (safe if pointops missing). 
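# Both backends share the same calling convention (illustrative shapes; the pointops path additionally assumes
# CUDA tensors and itself falls back to _knn_torch otherwise):
#   dists, idx = knn(16, xyz_ref=torch.rand(2, 8192, 3), xyz_query=torch.rand(2, 500, 3))
#   dists.shape == (2, 500, 16) and idx.shape == (2, 500, 16), with idx indexing the N axis of xyz_ref.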
try: import importlib importlib.import_module("pointops") knn = _knn_pointops except Exception: logging.warning("pointops not found, falling back to slower KNN implementation.") knn = _knn_torch class MVTracker(nn.Module): def __init__( self, sliding_window_len=12, stride=4, normalize_scene_in_fwd_pass=False, fmaps_dim=128, add_space_attn=True, num_heads=6, hidden_size=384, space_depth=6, time_depth=6, num_virtual_tracks=64, use_flash_attention=True, corr_n_groups=1, corr_n_levels=4, corr_neighbors=16, corr_add_neighbor_offset=True, corr_add_neighbor_xyz=False, corr_filter_invalid_depth=False, ): super().__init__() self.S = sliding_window_len self.stride = stride self.normalize_scene_in_fwd_pass = normalize_scene_in_fwd_pass self.latent_dim = fmaps_dim self.flow_embed_dim = 64 self.b_latent_dim = self.latent_dim // 3 self.corr_n_groups = corr_n_groups self.corr_n_levels = corr_n_levels self.corr_neighbors = corr_neighbors self.corr_pos_emb_size = 0 self.corr_add_neighbor_offset = corr_add_neighbor_offset self.corr_add_neighbor_xyz = corr_add_neighbor_xyz self.corr_filter_invalid_depth = corr_filter_invalid_depth self.add_space_attn = add_space_attn self.updateformer_input_dim = ( # The positional encoding of the 3D flow from t=i to t=0 + (self.flow_embed_dim + 1) * 3 # The correlation features (LRR) for the three planes (xy, yz, xz), concatenated + self.corr_neighbors * self.corr_n_levels * (self.corr_n_groups + 3 * self.corr_add_neighbor_offset + 3 * self.corr_add_neighbor_xyz + self.corr_pos_emb_size) # The features of the tracked points, one for each of the three planes + self.latent_dim # The visibility mask + 1 # The whether-the-point-is-tracked mask + 1 ) # Feature encoder self.fnet = BasicEncoder( input_dim=3, output_dim=self.latent_dim, norm_fn="instance", dropout=0, stride=self.stride, Embed3D=False, ) # Transformer for iterative updates self.updateformer_hidden_size = hidden_size self.updateformer = EfficientUpdateFormer( space_depth=space_depth, time_depth=time_depth, input_dim=self.updateformer_input_dim, hidden_size=hidden_size, num_heads=num_heads, output_dim=3 + self.latent_dim, mlp_ratio=4.0, add_space_attn=add_space_attn, num_virtual_tracks=num_virtual_tracks, attn_class=FlashAttention if use_flash_attention else Attention, linear_layer_for_vis_conf=False, ) # Feature update + visibility self.ffeats_norm = nn.GroupNorm(1, self.latent_dim) self.ffeats_updater = nn.Sequential(nn.Linear(self.latent_dim, self.latent_dim), nn.GELU()) self.vis_predictor = nn.Sequential(nn.Linear(self.latent_dim, 1)) self.stats_pyramid = None self.stats_depth = None def fnet_fwd(self, rgbs_normalized, image_features=None): b, v, t, _, h, w = rgbs_normalized.shape rgbs_normalized = rgbs_normalized.reshape(-1, 3, h, w) return self.fnet(rgbs_normalized) def init_stats(self): self.stats_pyramid = defaultdict(list) self.stats_depth = [] def consume_stats(self): # Per-pyramid-level summary of neighbor distances level_to_norms = defaultdict(list) for (level, _), norm_lists in self.stats_pyramid.items(): level_to_norms[level].extend(norm_lists) level_summary = [] for level, norm_lists in level_to_norms.items(): norms = np.concatenate(norm_lists).astype(float) stats = pd.Series(norms).describe(percentiles=[.25, .5, .75]) level_summary.append({ "level": level, "count": int(stats["count"]), "mean": round(float(stats["mean"] * 100), 1), "std": round(float(stats["std"] * 100), 1), "min": round(float(stats["min"] * 100), 1), "25%": round(float(stats["25%"] * 100), 1), "50%": round(float(stats["50%"] * 100), 
1), "75%": round(float(stats["75%"] * 100), 1), "max": round(float(stats["max"] * 100), 1), }) df_level_summary = pd.DataFrame(level_summary).sort_values("level") logging.info(f"Neighbor distances across pyramid levels:\n{df_level_summary}") # Per-pyramid-level and per-iteration summary of neighbor distances summary = [] for (level, it), norm_lists in self.stats_pyramid.items(): norms = np.concatenate(norm_lists).astype(float) stats = pd.Series(norms).describe(percentiles=[.25, .5, .75]) summary.append({ "level": level, "iteration": it, "count": int(stats["count"]), "mean": round(float(stats["mean"] * 100), 1), "std": round(float(stats["std"] * 100), 1), "min": round(float(stats["min"] * 100), 1), "25%": round(float(stats["25%"] * 100), 1), "50%": round(float(stats["50%"] * 100), 1), "75%": round(float(stats["75%"] * 100), 1), "max": round(float(stats["max"] * 100), 1), }) df_summary = pd.DataFrame(summary).sort_values(["level", "iteration"]) logging.info(f"Neighbor distances across pyramid levels and iterations (in cm):\n{(df_summary)}") # Valid vs invalid depth stats depth_stats = pd.Series(self.stats_depth).describe(percentiles=[.25, .5, .75]).astype(float).round(1) logging.info(f"Depth stats (valid vs invalid):\n{depth_stats}") self.stats_pyramid = None self.stats_depth = None def forward_iteration( self, fmaps, depths, intrs, extrs, coords_init, vis_init, track_mask, iters=4, feat_init=None, save_debug_logs=False, debug_logs_path="", debug_logs_prefix="", debug_logs_window_idx=None, save_rerun_logs: bool = False, rerun_fmap_coloring_fn: Optional[Callable] = None, ): B, V, S, D, H, W = fmaps.shape N = coords_init.shape[2] device = fmaps.device if coords_init.shape[1] < S: coords = torch.cat([coords_init, coords_init[:, -1].repeat(1, S - coords_init.shape[1], 1, 1)], dim=1) vis_init = torch.cat([vis_init, vis_init[:, -1].repeat(1, S - vis_init.shape[1], 1, 1)], dim=1) else: coords = coords_init.clone() if track_mask.shape[1] < S: track_mask = torch.cat([ track_mask, torch.zeros_like(track_mask[:, 0]).repeat(1, S - track_mask.shape[1], 1, 1), ], dim=1) assert B == 1 assert D == self.latent_dim assert fmaps.shape == (B, V, S, D, H, W) assert depths.shape == (B, V, S, 1, H, W) assert intrs.shape == (B, V, S, 3, 3) assert extrs.shape == (B, V, S, 3, 4) assert coords.shape == (B, S, N, 3) assert vis_init.shape == (B, S, N, 1) assert track_mask.shape == (B, S, N, 1) assert feat_init is None or feat_init.shape == (B, S, N, self.latent_dim) assert track_mask.any(1).all(), "All points should be requested to be tracked at least for one frame" intrs_inv = torch.inverse(intrs.float()).type(intrs.dtype) extrs_square = torch.eye(4).to(extrs.device)[None].repeat(B, V, S, 1, 1) extrs_square[:, :, :, :3, :] = extrs extrs_inv = torch.inverse(extrs_square.float()).type(extrs.dtype) assert intrs_inv.shape == (B, V, S, 3, 3) assert extrs_square.shape == (B, V, S, 4, 4) assert extrs_inv.shape == (B, V, S, 4, 4) fcorr_fns = {} for lvl in range(self.corr_n_levels): pc = init_pointcloud_from_rgbd( fmaps=fmaps, depths=depths, intrs=intrs, extrs=extrs, stride=self.stride, level=lvl, return_validity_mask=self.corr_filter_invalid_depth or save_rerun_logs, ) if self.corr_filter_invalid_depth or save_rerun_logs: pc_xyz, pc_fvec, pc_valid = pc else: pc_xyz, pc_fvec = pc pc_valid = None fcorr_fns[lvl] = PointcloudCorrBlock( k=self.corr_neighbors, groups=self.corr_n_groups, xyz=pc_xyz, fvec=pc_fvec, filter_invalid=self.corr_filter_invalid_depth, valid=pc_valid, corr_add_neighbor_offset=self.corr_add_neighbor_offset, 
corr_add_neighbor_xyz=self.corr_add_neighbor_xyz, rerun_fmap_coloring_fn=rerun_fmap_coloring_fn, ) # Positional/time embeddings (keep shapes identical to before) embed_dim = self.updateformer_input_dim if embed_dim % 6 != 0: embed_dim += 6 - (embed_dim % 6) pos_embed = get_3d_sincos_pos_embed_from_grid(embed_dim, coords[:, 0:1]).float()[:, 0].permute(0, 2, 1) if embed_dim > self.updateformer_input_dim: pos_embed = pos_embed[:, :self.updateformer_input_dim, :] pos_embed = rearrange(pos_embed, "b e n -> (b n) e").unsqueeze(1) times_ = torch.linspace(0, S - 1, S).reshape(1, S, 1) / S embed_dim = self.updateformer_input_dim if embed_dim % 2 != 0: embed_dim += 2 - (embed_dim % 2) times_embed = ( torch.from_numpy(get_1d_sincos_pos_embed_from_grid(embed_dim, times_[0]))[None] .repeat(B, 1, 1) .float() .to(device) ) if embed_dim > self.updateformer_input_dim: times_embed = times_embed[:, :, :self.updateformer_input_dim] coord_predictions = [] ffeats = feat_init.clone() track_mask_and_vis = torch.cat([track_mask, vis_init], dim=3).permute(0, 2, 1, 3).reshape(B * N, S, 2) for it in range(iters): coords = coords.detach() # Sample correlation features around each point fcorrs = [] for lvl in range(self.corr_n_levels): fcorr_fn = fcorr_fns[lvl] fcorrs_level = ( fcorr_fn .corr_sample( targets=ffeats.reshape(B * S, N, self.latent_dim), coords_world_xyz=coords.reshape(B * S, N, 3), save_debug_logs=False, debug_logs_path=debug_logs_path, debug_logs_prefix=debug_logs_prefix + f"__iter_{it}__pyramid_level_{lvl}", save_rerun_logs=save_rerun_logs, ) .reshape(B, S, N, -1) ) fcorrs.append(fcorrs_level) if self.stats_pyramid is not None: self.stats_pyramid[(lvl, it)] += [ np.linalg.norm(fcorrs_level.reshape(-1, 4)[:, 1:].detach().cpu().numpy(), axis=-1) ] fcorrs = torch.cat(fcorrs, dim=-1) LRR = fcorrs.shape[3] fcorrs_ = fcorrs.permute(0, 2, 1, 3).reshape(B * N, S, LRR) # Flow embedding flows_ = (coords - coords[:, 0:1]).permute(0, 2, 1, 3).reshape(B * N, S, 3) flows_ = get_3d_embedding(flows_, self.flow_embed_dim, cat_coords=True) ffeats_ = ffeats.permute(0, 2, 1, 3).reshape(B * N, S, self.latent_dim) transformer_input = torch.cat([flows_, fcorrs_, ffeats_, track_mask_and_vis], dim=2) assert transformer_input.shape[-1] == pos_embed.shape[-1] x = transformer_input + pos_embed + times_embed x = rearrange(x, "(b n) t d -> b n t d", b=B) delta = self.updateformer(x) delta = rearrange(delta, " b n t d -> (b n) t d") d_coord = delta[:, :, :3].reshape(B, N, S, 3).permute(0, 2, 1, 3) d_feats = delta[:, :, 3:self.latent_dim + 3] d_feats = self.ffeats_norm(d_feats.view(-1, self.latent_dim)) d_feats = self.ffeats_updater(d_feats).view(B, N, S, self.latent_dim).permute(0, 2, 1, 3) coords = coords + d_coord ffeats = ffeats + d_feats if torch.isnan(coords).any(): logging.error("Got NaN values in coords, perhaps the training exploded") import ipdb ipdb.set_trace() coord_predictions.append(coords.clone()) vis_e = self.vis_predictor(ffeats.reshape(B * S * N, self.latent_dim)).reshape(B, S, N) return coord_predictions, vis_e, feat_init def forward( self, rgbs, depths, query_points, intrs, extrs, iters=4, image_features=None, is_train=False, save_debug_logs=False, debug_logs_path="", save_rerun_logs: bool = False, save_rerun_logs_output_rrd_path: Optional[str] = None, **kwargs, ): device = extrs.device if save_debug_logs: if kwargs: logging.info(f"Unused kwargs: {kwargs.keys()}") batch_size, num_views, num_frames, _, height, width = rgbs.shape _, num_points, _ = query_points.shape logging.info(f"FWD pass: {num_views=} 
{num_frames=} {num_points=} "
                     f"{height=} {width=} {iters=} {rgbs.dtype=}")
        # I made a video tutorial here if it is easier to follow: https://www.youtube.com/watch?v=dQw4w9WgXcQ
        assert rgbs.shape == (batch_size, num_views, num_frames, 3, height, width)
        assert depths.shape == (batch_size, num_views, num_frames, 1, height, width)
        assert query_points.shape == (batch_size, num_points, 4)
        assert intrs.shape == (batch_size, num_views, num_frames, 3, 3)
        assert extrs.shape == (batch_size, num_views, num_frames, 3, 4)

        if save_debug_logs:
            os.makedirs(debug_logs_path, exist_ok=True)
        if save_rerun_logs:
            assert save_rerun_logs_output_rrd_path is not None
            import rerun as rr
            rr.init("3dpt", recording_id="v0.16")
            rr.set_time_seconds("frame", 0)
        if self.stats_depth is not None:
            self.stats_depth += [(depths == 0).float().mean().item() * 100]

        # Scene normalization (optional): Rigid transformation to center first camera and rescale the scene like VGGT
        qp_range_before = np.stack([
            query_points[0, :, 1:].min(0).values.cpu().numpy().round(2),
            query_points[0, :, 1:].max(0).values.cpu().numpy().round(2),
        ])
        if self.normalize_scene_in_fwd_pass:
            assert batch_size == 1, "VGGT normalization assumes batch size 1"
            max_depth = 24
            _d = depths.clone()
            _d[_d < max_depth] = max_depth
            T_scale, T_rot, T_translation = compute_vggt_scene_normalization_transform(
                _d[0], extrs[0].to(_d.device), intrs[0].to(_d.device)
            )
            T_scale_inv = 1 / T_scale
            T_rot_inv = T_rot.transpose(0, 1)
            T_translation_inv = -T_translation @ T_rot_inv
            query_points, extrs = query_points[0], extrs[0]  # Remove batch dimension
            # NOTE: The original call passed an undefined `T`; pass the scale/rotation/translation computed
            # above instead, following the argument order of the inverse transform_scene calls further below.
            extrs, query_points, _, _ = transform_scene(
                T_scale, T_rot, T_translation, extrs, query_points, None, None, None
            )
            query_points, extrs = query_points[None], extrs[None]  # Add batch dimension
        qp_range_after = np.stack([
            query_points[0, :, 1:].min(0).values.cpu().numpy().round(2),
            query_points[0, :, 1:].max(0).values.cpu().numpy().round(2),
        ])
        if save_debug_logs:
            logging.info(f"Query points range before normalization:\n{qp_range_before}")
            logging.info(f"Query points range after normalization: \n{qp_range_after}")

        self.is_train = is_train

        # Unpack the query points
        query_points_t = query_points[:, :, :1].long()
        query_points_xyz_worldspace = query_points[:, :, 1:]

        # Invert intrinsics and extrinsics
        intrs_inv = torch.inverse(intrs.float()).type(intrs.dtype)
        extrs_square = torch.eye(4).to(extrs.device)[None].repeat(batch_size, num_views, num_frames, 1, 1)
        extrs_square[:, :, :, :3, :] = extrs
        extrs_inv = torch.inverse(extrs_square.float()).type(extrs.dtype)

        # Interpolate the rgbs and depthmaps to the stride of the SpaTracker
        strided_height = height // self.stride
        strided_width = width // self.stride

        # Build the per-frame track mask: each point is tracked only from its query timestep onward
        assert batch_size == 1, "Batch size > 1 is not supported yet"
        query_points_t = query_points_t.squeeze(0).squeeze(-1)  # BN1 --> N
        ind_array = torch.arange(num_frames, device=query_points.device)
        ind_array = ind_array[None, :, None].repeat(batch_size, 1, num_points)
        track_mask = (ind_array >= query_points_t[None, None, :]).unsqueeze(-1)  # TODO: >= or >?
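        # Worked example of the mask construction above (illustrative values, not from the model): with
        # num_frames = 4 and query_points_t = [0, 2], ind_array[0] equals [[0, 0], [1, 1], [2, 2], [3, 3]], so
        # track_mask[0, :, :, 0] = [[1, 0], [1, 0], [1, 1], [1, 1]], i.e. point 0 is tracked from frame 0 and
        # point 1 only from its query frame 2 onward.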
# Prepare the initial coordinates and visibility coords_init = query_points_xyz_worldspace.unsqueeze(1).repeat(1, self.S, 1, 1) vis_init = query_points.new_ones((batch_size, self.S, num_points, 1)) * 10 # Sort the queries via their first appeared time _, sort_inds = torch.sort(query_points_t, dim=0, descending=False) inv_sort_inds = torch.argsort(sort_inds, dim=0) assert torch.allclose(query_points_t, query_points_t[sort_inds][inv_sort_inds]) query_points_t_ = query_points_t[sort_inds] query_points_xyz_worldspace_ = query_points_xyz_worldspace[..., sort_inds, :] coords_init_ = coords_init[..., sort_inds, :].clone() vis_init_ = vis_init[:, :, sort_inds].clone() track_mask_ = track_mask[:, :, sort_inds].clone() # Delete the unsorted variables (for safety) del coords_init, vis_init, query_points_t, query_points, query_points_xyz_worldspace, track_mask # Placeholders for the results (for the sorted points) traj_e_ = coords_init_.new_zeros((batch_size, num_frames, num_points, 3)) vis_e_ = coords_init_.new_zeros((batch_size, num_frames, num_points)) w_idx_start = query_points_t_.min() p_idx_start = 0 vis_predictions = [] coord_predictions = [] p_idx_end_list = [] fmaps_seq, depths_seq, feat_init, rerun_fmap_coloring_fn = None, None, None, None while w_idx_start < num_frames - self.S // 2: curr_wind_points = torch.nonzero(query_points_t_ < w_idx_start + self.S) assert curr_wind_points.shape[0] > 0 p_idx_end = curr_wind_points[-1].item() + 1 p_idx_end_list.append(p_idx_end) intrs_seq = intrs[:, :, w_idx_start:w_idx_start + self.S] extrs_seq = extrs[:, :, w_idx_start:w_idx_start + self.S] # Compute fmaps and interpolated depth on a rolling basis # to reduce peak GPU memory consumption, but don't recompute # for the overlapping part of a window if fmaps_seq is None: assert depths_seq is None new_seq_t0 = w_idx_start else: fmaps_seq = fmaps_seq[:, :, self.S // 2:] depths_seq = depths_seq[:, :, self.S // 2:] new_seq_t0 = w_idx_start + self.S // 2 new_seq_t1 = w_idx_start + self.S _depths_seq_new = nn.functional.interpolate( input=depths[:, :, new_seq_t0:new_seq_t1].to(device).reshape(-1, 1, height, width), scale_factor=1.0 / self.stride, mode="nearest", ).reshape(batch_size, num_views, -1, 1, strided_height, strided_width) depths_seq = smart_cat(depths_seq, _depths_seq_new, dim=2) _fmaps_seq_new = self.fnet_fwd( (2 * (rgbs[:, :, new_seq_t0: new_seq_t1].to(device) / 255.0) - 1.0), image_features, ) _fmaps_seq_new = nn.functional.interpolate( input=_fmaps_seq_new, size=(strided_height, strided_width), mode="bilinear", ).reshape(batch_size, num_views, -1, self.latent_dim, strided_height, strided_width) fmaps_seq = smart_cat(fmaps_seq, _fmaps_seq_new, dim=2) if save_rerun_logs and rerun_fmap_coloring_fn is None: valid_depths_mask = depths_seq.detach().cpu().squeeze(3) > 0 fvec_flat = fmaps_seq.detach().cpu().permute(0, 1, 2, 4, 5, 3)[valid_depths_mask].numpy() from sklearn.decomposition import PCA reducer = PCA(n_components=3) reducer.fit(fvec_flat) fvec_reduced = reducer.transform(fvec_flat) reducer_min = fvec_reduced.min(axis=0) reducer_max = fvec_reduced.max(axis=0) def fvec_to_rgb(fvec): input_shape = fvec.shape assert input_shape[-1] == self.latent_dim fvec_reduced = reducer.transform(fvec.reshape(-1, self.latent_dim)) fvec_reduced = np.clip(fvec_reduced, reducer_min[None, :], reducer_max[None, :]) fvec_reduced_rescaled = (fvec_reduced - reducer_min) / (reducer_max - reducer_min) fvec_reduced_rgb = (fvec_reduced_rescaled * 255).astype(int) fvec_reduced_rgb = fvec_reduced_rgb.reshape(input_shape[:-1] 
+ (3,)) return fvec_reduced_rgb rerun_fmap_coloring_fn = fvec_to_rgb S_local = fmaps_seq.shape[2] if S_local < self.S: diff = self.S - S_local fmaps_seq = torch.cat([fmaps_seq, fmaps_seq[:, :, -1:].repeat(1, 1, diff, 1, 1, 1)], 2) depths_seq = torch.cat([depths_seq, depths_seq[:, :, -1:].repeat(1, 1, diff, 1, 1, 1)], 2) intrs_seq = torch.cat([intrs_seq, intrs_seq[:, :, -1:].repeat(1, 1, diff, 1, 1)], 2) extrs_seq = torch.cat([extrs_seq, extrs_seq[:, :, -1:].repeat(1, 1, diff, 1, 1)], 2) # Compute the feature vector initialization for the new query points if p_idx_end - p_idx_start > 0: rgbd_xyz, rgbd_fvec = init_pointcloud_from_rgbd( fmaps=_fmaps_seq_new, depths=_depths_seq_new, intrs=intrs[:, :, new_seq_t0:new_seq_t1], extrs=extrs[:, :, new_seq_t0:new_seq_t1], stride=self.stride, ) new_num_frames = _fmaps_seq_new.shape[2] rgbd_xyz = rgbd_xyz.reshape(batch_size, new_num_frames, num_views, strided_height * strided_width, 3) rgbd_fvec = rgbd_fvec.reshape(batch_size, new_num_frames, num_views, strided_height * strided_width, self.latent_dim) _feat_init_new = torch.zeros(batch_size, p_idx_end - p_idx_start, self.latent_dim, device=_fmaps_seq_new.device, dtype=_fmaps_seq_new.dtype) assert batch_size == 1 assert ((query_points_t_[p_idx_start:p_idx_end] > new_seq_t0) | (query_points_t_[p_idx_start:p_idx_end] < new_seq_t1)).all() batch_idx = 0 for t in range(new_seq_t0, new_seq_t1): query_mask = query_points_t_[p_idx_start:p_idx_end] == t if query_mask.sum() == 0: continue query_points_world = query_points_xyz_worldspace_[batch_idx, p_idx_start:p_idx_end][query_mask] rgbd_xyz_current = rgbd_xyz[batch_idx, t - new_seq_t0].reshape(-1, 3) # Combine views for frame rgbd_fvec_current = rgbd_fvec[batch_idx, t - new_seq_t0].reshape(-1, self.latent_dim) k = 1 neighbor_dists, neighbor_indices = knn(k, rgbd_xyz_current[None], query_points_world[None]) assert k == 1, "If k > 1, the code below should be modified to handle multiple neighbors -- how to combine the features of multiple neighbors?" 
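# Descriptive note on the shapes involved: with k == 1, neighbor_indices has shape (1, num_new_queries, 1), so
# the indexing below picks, for each new query point, its single nearest neighbor in the fused multi-view point
# cloud of frame t and copies that neighbor's feature vector as the query's initial track feature.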
neighbor_xyz = rgbd_xyz_current[neighbor_indices[0, :, 0]] neighbor_fvec = rgbd_fvec_current[neighbor_indices[0, :, 0]] _feat_init_new[batch_idx, query_mask] = neighbor_fvec feat_init = smart_cat(feat_init, _feat_init_new.repeat(1, self.S, 1, 1), dim=2) # Update the initial coordinates and visibility for non-first windows if p_idx_start > 0: last_coords = coords[-1][:, self.S // 2:].clone() # Take the predicted coords from the last window coords_init_[:, : self.S // 2, :p_idx_start] = last_coords coords_init_[:, self.S // 2:, :p_idx_start] = last_coords[:, -1].repeat(1, self.S // 2, 1, 1) last_vis = vis[:, self.S // 2:][..., None] vis_init_[:, : self.S // 2, :p_idx_start] = last_vis vis_init_[:, self.S // 2:, :p_idx_start] = last_vis[:, -1].repeat(1, self.S // 2, 1, 1) track_mask_current = track_mask_[:, w_idx_start: w_idx_start + self.S, :p_idx_end] if S_local < self.S: track_mask_current = torch.cat([ track_mask_current, track_mask_current[:, -1:].repeat(1, self.S - S_local, 1, 1), ], 1) coords, vis, _ = self.forward_iteration( fmaps=fmaps_seq, depths=depths_seq, intrs=intrs_seq, extrs=extrs_seq, coords_init=coords_init_[:, :, :p_idx_end], feat_init=feat_init[:, :, :p_idx_end], vis_init=vis_init_[:, :, :p_idx_end], track_mask=track_mask_current, iters=iters, save_debug_logs=save_debug_logs, debug_logs_path=debug_logs_path, debug_logs_prefix=f"__widx-{w_idx_start}_pidx-{p_idx_start}-{p_idx_end}", debug_logs_window_idx=w_idx_start, save_rerun_logs=save_rerun_logs, rerun_fmap_coloring_fn=rerun_fmap_coloring_fn, ) if is_train: coord_predictions.append([ coord[:, :S_local] if not self.normalize_scene_in_fwd_pass else transform_scene(T_scale_inv, T_rot_inv, T_translation_inv, None, None, None, coord[:, :S_local][0], None)[2][None] for coord in coords ]) vis_predictions.append(vis[:, :S_local]) traj_e_[:, w_idx_start:w_idx_start + self.S, :p_idx_end] = coords[-1][:, :S_local] vis_e_[:, w_idx_start:w_idx_start + self.S, :p_idx_end] = torch.sigmoid(vis[:, :S_local]) track_mask_[:, : w_idx_start + self.S, :p_idx_end] = 0.0 w_idx_start = w_idx_start + self.S // 2 p_idx_start = p_idx_end if save_debug_logs: import gpustat torch.cuda.empty_cache() logging.info(f"Forward pass GPU usage: {gpustat.new_query()}") if save_rerun_logs: import rerun as rr rr.save(save_rerun_logs_output_rrd_path) logging.info(f"Saved Rerun recording to: {os.path.abspath(save_rerun_logs_output_rrd_path)}.") traj_e = traj_e_[:, :, inv_sort_inds] vis_e = vis_e_[:, :, inv_sort_inds] # Un-normalize the scene if self.normalize_scene_in_fwd_pass: traj_e = transform_scene(T_scale_inv, T_rot_inv, T_translation_inv, None, None, None, traj_e[0], None)[2][None] results = { "traj_e": traj_e, "feat_init": feat_init, "vis_e": vis_e, } if self.is_train: results["train_data"] = { "vis_predictions": vis_predictions, "coord_predictions": coord_predictions, "attn_predictions": None, "p_idx_end_list": p_idx_end_list, "sort_inds": sort_inds, "Rigid_ln_total": None, } return results def compute_vggt_scene_normalization_transform(depths, extrs, intrs): V, T, _, H, W = depths.shape device = depths.device extrs_square = torch.eye(4, device=device)[None, None].repeat(V, T, 1, 1) extrs_square[:, :, :3, :] = extrs extrs_inv = torch.inverse(extrs_square.float()).type(extrs.dtype) intrs_inv = torch.inverse(intrs.float()).type(intrs.dtype) y, x = torch.meshgrid( torch.arange(H, device=device), torch.arange(W, device=device), indexing="ij" ) homog = torch.stack([x, y, torch.ones_like(x)], dim=-1).float().reshape(-1, 3) homog = homog[None].expand(V, -1, 
-1).type(depths.dtype) cam_points = torch.einsum("vij,vnj->vni", intrs_inv[:, 0], homog) * depths[:, 0].reshape(V, -1, 1) cam_points_h = torch.cat([cam_points, torch.ones_like(cam_points[..., :1])], dim=-1) world_points_h = torch.einsum("vij,vnj->vni", extrs_inv[:, 0], cam_points_h) world_points_in_first = torch.einsum("ij,vnj->vni", extrs[0, 0], world_points_h) mask = (depths[:, 0] > 0).reshape(V, -1) valid_points = world_points_in_first[mask] avg_dist = valid_points.norm(dim=1).mean() scale = 1.0 / avg_dist rot = extrs[0, 0, :3, :3] translation = extrs[0, 0, :3, 3] * scale return scale, rot, translation class PointcloudCorrBlock: def __init__( self, k: int, groups, xyz: torch.Tensor, fvec: torch.Tensor, corr_add_neighbor_offset: bool, corr_add_neighbor_xyz: bool, filter_invalid: bool = False, valid: Optional[torch.Tensor] = None, rerun_fmap_coloring_fn: Optional[Callable] = None, ): self.B, self.N, self.C = fvec.shape assert xyz.shape == (self.B, self.N, 3) assert fvec.shape == (self.B, self.N, self.C) assert k <= self.N, "k should be less than or equal to N" assert groups <= self.C, "number of correlation groups should not be larger than the number of channels" assert self.C % groups == 0, "number of channels must be divisible by the number of groups (for convenience)" assert not filter_invalid or valid is not None self.k = k self.groups = groups self.xyz = xyz self.fvec = fvec self.corr_add_neighbor_offset = corr_add_neighbor_offset self.corr_add_neighbor_xyz = corr_add_neighbor_xyz self.filter_invalid = filter_invalid self.valid = valid self.rerun_fmap_coloring_fn = rerun_fmap_coloring_fn def corr_sample( self, targets: torch.Tensor, coords_world_xyz: torch.Tensor, save_debug_logs=False, debug_logs_path=".", debug_logs_prefix="corr", save_rerun_logs=False, ): # Check inputs _, M, _ = targets.shape assert targets.shape == (self.B, M, self.C) assert coords_world_xyz.shape == (self.B, M, 3) # Find neighbors for each of the N target points if not self.filter_invalid: neighbor_dists, neighbor_indices = knn(self.k, self.xyz, coords_world_xyz) else: neighbor_dists = [] neighbor_indices = [] for xyz_i, valid_i, coords_world_xyz_i in zip(self.xyz, self.valid, coords_world_xyz): xyz_i = xyz_i[valid_i] neighbor_dists_i, neighbor_indices_i = knn(self.k, xyz_i[None], coords_world_xyz_i[None]) neighbor_dists.append(neighbor_dists_i) neighbor_indices.append(neighbor_indices_i) neighbor_dists = torch.cat(neighbor_dists) neighbor_indices = torch.cat(neighbor_indices) batch_idx = torch.arange(self.B, device=self.xyz.device)[:, None, None] neighbor_xyz = self.xyz[batch_idx, neighbor_indices] neighbor_fvec = self.fvec[batch_idx, neighbor_indices] # Compute the local correlations targets_grouped = targets.view(self.B, M, self.groups, -1) neighbor_fvec_grouped = neighbor_fvec.view(self.B, M, self.k, self.groups, -1) corrs = torch.einsum('BMGc,BMKGc->BMKG', targets_grouped, neighbor_fvec_grouped) corrs = corrs / ((self.C / self.groups) ** 0.5) output = corrs # Append the distance/direction features to the correlation neighbor_offset_in_world_xyz = neighbor_xyz - coords_world_xyz[..., None, :] if self.corr_add_neighbor_offset: output = torch.cat([corrs, neighbor_offset_in_world_xyz], -1) # Append the neighbor xyz to the correlation if self.corr_add_neighbor_xyz: output = torch.cat([output, neighbor_xyz], -1) if save_debug_logs: from sklearn.decomposition import PCA fvec_flat = self.fvec.reshape(-1, self.C).detach().cpu().numpy() reducer = PCA(n_components=3) reducer.fit(fvec_flat) fvec_reduced = 
reducer.transform(fvec_flat) reducer_min = fvec_reduced.min(axis=0) reducer_max = fvec_reduced.max(axis=0) def fvec_to_rgb(fvec): fvec_reduced = reducer.transform(fvec) fvec_reduced_rescaled = (fvec_reduced - reducer_min) / (reducer_max - reducer_min) fvec_reduced_rgb = (fvec_reduced_rescaled * 255).astype(int) return fvec_reduced_rgb for b in [0, self.B - 1]: # Save all points xyz = self.xyz[b].detach().cpu().numpy() xyz_colors = fvec_to_rgb(self.fvec[b].detach().cpu().numpy()) save_pointcloud_to_ply(os.path.join(debug_logs_path, f"{time_now()}{debug_logs_prefix}_all_b{b}.ply"), xyz, xyz_colors) for n in range(3): neighbors = neighbor_xyz[b, n].detach().cpu().numpy() neighbors_colors = fvec_to_rgb(neighbor_fvec[b, n].detach().cpu().numpy()) save_pointcloud_to_ply( os.path.join(debug_logs_path, f"{time_now()}{debug_logs_prefix}_neighbors_b{b}_n{n}.ply"), neighbors, neighbors_colors) for n in range(3): neighbors = neighbor_xyz[b, n].detach().cpu().numpy() neighbors_colors = fvec_to_rgb(neighbor_fvec[b, n].detach().cpu().numpy()) query_point = coords_world_xyz[b, n].detach().cpu().numpy() query_point_color = fvec_to_rgb(targets[b, n].detach().cpu().numpy().reshape(1, -1)) combined_points = np.vstack([query_point, neighbors]) combined_colors = np.vstack([query_point_color, neighbors_colors]) query_point_index = 0 neighbor_indices = np.arange(1, len(neighbors) + 1) edges = np.array([[query_point_index, i] for i in neighbor_indices]) save_pointcloud_to_ply(os.path.join(debug_logs_path, f"{time_now()}{debug_logs_prefix}_query_b{b}_n{n}_with_edges.ply"), combined_points, combined_colors, edges=edges) # Visualize the results with rerun.io if save_rerun_logs: import rerun as rr import re assert self.C > 1 rerun_fps = 30 log_feature_maps = True log_knn_neighbors = False knn_line_coloring = "static" knn_neighbors_to_log = 6 logging.info(f"rerun for {debug_logs_prefix} started") ## Mask out target scene area # xyz = self.xyz.detach().cpu().numpy() # bbox = np.array([[-4, 4], [-3, 3.7], [1.2, 5.2]]) # Softball bbox # mask = ( # (xyz[..., 0] > bbox[0, 0]) # & (xyz[..., 0] < bbox[0, 1]) # & (xyz[..., 1] > bbox[1, 0]) # & (xyz[..., 1] < bbox[1, 1]) # & (xyz[..., 2] > bbox[2, 0]) # & (xyz[..., 2] < bbox[2, 1]) # ) xyz = self.xyz.detach().cpu().numpy() mask = np.ones_like(xyz[..., 0]).astype(bool) if self.valid is not None: mask = self.valid.detach().cpu().numpy() # PCA-based feature coloring if self.rerun_fmap_coloring_fn is None: fvec_flat = self.fvec.detach().cpu().numpy()[mask] from sklearn.decomposition import PCA reducer = PCA(n_components=3) reducer.fit(fvec_flat) fvec_reduced = reducer.transform(fvec_flat) reducer_min = fvec_reduced.min(axis=0) reducer_max = fvec_reduced.max(axis=0) def fvec_to_rgb(fvec): input_shape = fvec.shape assert input_shape[-1] == self.C fvec_reduced = reducer.transform(fvec.reshape(-1, self.C)) fvec_reduced = np.clip(fvec_reduced, reducer_min[None, :], reducer_max[None, :]) fvec_reduced_rescaled = (fvec_reduced - reducer_min) / (reducer_max - reducer_min) fvec_reduced_rgb = (fvec_reduced_rescaled * 255).astype(int) fvec_reduced_rgb = fvec_reduced_rgb.reshape(input_shape[:-1] + (3,)) return fvec_reduced_rgb self.rerun_fmap_coloring_fn = fvec_to_rgb fvec_colors = self.rerun_fmap_coloring_fn(self.fvec.detach().cpu().numpy()) targets_colors = self.rerun_fmap_coloring_fn(targets.detach().cpu().numpy()) neighbor_fvec_colors = self.rerun_fmap_coloring_fn(neighbor_fvec.detach().cpu().numpy()) import re pattern = r'__widx-(\d+)_pidx-(\d+)-(\d+)__iter_(\d+)__pyramid_level_(\d+)' 
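# Example of a prefix this pattern parses (illustrative string, not from a real run):
# "__widx-12_pidx-0-64__iter_2__pyramid_level_1" -> t_start=12, pidx_start=0, pidx_end=64, iteration=2,
# pyramid_level=1, matching how debug_logs_prefix is assembled in forward() and forward_iteration().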
match = re.search(pattern, debug_logs_prefix) assert match t_start = int(match.group(1)) pidx_start = int(match.group(2)) pidx_end = int(match.group(3)) iteration = int(match.group(4)) pyramid_level = int(match.group(5)) # # Log fmaps as images for the pipeline figure # import os # from PIL import Image # png_outdir = os.path.join(debug_logs_path, "feature_maps_pngs_2") # os.makedirs(png_outdir, exist_ok=True) # if pyramid_level == 0 and iteration == 0: # for b in range(self.B): # t = t_start + b # for v in range(8): # fvec_rgb_uint8 = fvec_colors[b].reshape(8, 96, 128, 3)[v].astype(np.uint8) # fname = f"fmap__view{v:02d}__frame{t:05d}.png" # fpath = os.path.join(png_outdir, fname) # Image.fromarray(fvec_rgb_uint8).save(fpath) # Log feature map points # if log_feature_maps and pyramid_level in [0, 1, 2, 3] and iteration == 0: if log_feature_maps and pyramid_level in [0] and iteration == 0: if t_start > 0: bs = range(self.B) else: bs = range(self.B // 2, self.B) for b in bs: rr.set_time_seconds("frame", (t_start + b) / rerun_fps) rr.log(f"fmaps/pyramid-{pyramid_level}", rr.Points3D( xyz[b][mask[b]], colors=fvec_colors[b][mask[b]], radii=0.042, # radii=-2.53, )) # Log neighbors if log_knn_neighbors and pyramid_level in [0, 1, 2, 3] and iteration in [0]: for b in range(self.B): rr.set_time_seconds("frame", (t_start + b) / rerun_fps) for n in range(min(neighbor_xyz.shape[1], knn_neighbors_to_log)): # Iterate over queries prefix = f"knn/track-{n:03d}/iter-{iteration}/pyramid-{pyramid_level}" rr.log(f"{prefix}/queries", rr.Points3D( coords_world_xyz[b, n].cpu().numpy(), colors=targets_colors[b, n], radii=0.072, # radii=-9.0, )) rr.log(f"{prefix}/neighbors", rr.Points3D( neighbor_xyz[b, n].cpu().numpy(), colors=neighbor_fvec_colors[b, n], radii=0.054, # radii=-5.0, )) if knn_line_coloring == "correlation": # Compute correlation strength for line coloring corr_strength = corrs[b, n,].squeeze(-1).cpu().numpy() corr_strength_normalized = (corr_strength / corr_strength.max()) * 1.0 + 0.0 line_colors = (corr_strength_normalized[:, None] * np.array([9, 208, 239])).astype(int) line_colors = np.hstack([line_colors, np.full((line_colors.shape[0], 1), 204)]) # RGBA 80% elif knn_line_coloring == "static": # Make the lines sun flower yellow (241, 196, 15) line_colors = np.array([241, 196, 15])[None].repeat(self.k, 0).astype(int) # Draw edges between query and its neighbors strips = np.stack([ coords_world_xyz[b, n].cpu().numpy()[None].repeat(neighbor_xyz.shape[2], axis=0), neighbor_xyz[b, n].cpu().numpy(), ], axis=-2) rr.log(f"{prefix}/arrows", rr.Arrows3D( origins=strips[:, 0], vectors=strips[:, 1] - strips[:, 0], colors=line_colors, radii=0.016, # radii=-1.2, )) logging.info(f"rerun for {debug_logs_prefix} done") return output ================================================ FILE: mvtracker/models/core/ptv3/__init__.py ================================================ ================================================ FILE: mvtracker/models/core/ptv3/model.py ================================================ """ Point Transformer - V3 Mode1 Pointcept detached version Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com) Please cite our work if the code is helpful to you. 
""" import sys from collections import OrderedDict from functools import partial import math import spconv.pytorch as spconv import torch import torch.nn as nn import torch_scatter from addict import Dict from timm.models.layers import DropPath try: import flash_attn except ImportError: flash_attn = None from .serialization import encode @torch.inference_mode() def offset2bincount(offset): return torch.diff( offset, prepend=torch.tensor([0], device=offset.device, dtype=torch.long) ) @torch.inference_mode() def offset2batch(offset): bincount = offset2bincount(offset) return torch.arange( len(bincount), device=offset.device, dtype=torch.long ).repeat_interleave(bincount) @torch.inference_mode() def batch2offset(batch): return torch.cumsum(batch.bincount(), dim=0).long() class Point(Dict): """ Point Structure of Pointcept A Point (point cloud) in Pointcept is a dictionary that contains various properties of a batched point cloud. The property with the following names have a specific definition as follows: - "coord": original coordinate of point cloud; - "grid_coord": grid coordinate for specific grid size (related to GridSampling); Point also support the following optional attributes: - "offset": if not exist, initialized as batch size is 1; - "batch": if not exist, initialized as batch size is 1; - "feat": feature of point cloud, default input of model; - "grid_size": Grid size of point cloud (related to GridSampling); (related to Serialization) - "serialized_depth": depth of serialization, 2 ** depth * grid_size describe the maximum of point cloud range; - "serialized_code": a list of serialization codes; - "serialized_order": a list of serialization order determined by code; - "serialized_inverse": a list of inverse mapping determined by code; (related to Sparsify: SpConv) - "sparse_shape": Sparse shape for Sparse Conv Tensor; - "sparse_conv_feat": SparseConvTensor init with information provide by Point; """ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) # If one of "offset" or "batch" do not exist, generate by the existing one if "batch" not in self.keys() and "offset" in self.keys(): self["batch"] = offset2batch(self.offset) elif "offset" not in self.keys() and "batch" in self.keys(): self["offset"] = batch2offset(self.batch) def serialization(self, order="z", depth=None, shuffle_orders=False): """ Point Cloud Serialization relay on ["grid_coord" or "coord" + "grid_size", "batch", "feat"] """ assert "batch" in self.keys() if "grid_coord" not in self.keys(): # if you don't want to operate GridSampling in data augmentation, # please add the following augmentation into your pipline: # dict(type="Copy", keys_dict={"grid_size": 0.01}), # (adjust `grid_size` to what your want) assert {"grid_size", "coord"}.issubset(self.keys()) self["grid_coord"] = torch.div( self.coord - self.coord.min(0)[0], self.grid_size, rounding_mode="trunc" ).int() if depth is None: # Adaptive measure the depth of serialization cube (length = 2 ^ depth) depth = int(self.grid_coord.max()).bit_length() self["serialized_depth"] = depth # Maximum bit length for serialization code is 63 (int64) assert depth * 3 + len(self.offset).bit_length() <= 63 # Here we follow OCNN and set the depth limitation to 16 (48bit) for the point position. # Although depth is limited to less than 16, we can encode a 655.36^3 (2^16 * 0.01) meter^3 # cube with a grid size of 0.01 meter. We consider it is enough for the current stage. # We can unlock the limitation by optimizing the z-order encoding function if necessary. 
assert depth <= 16 # The serialization codes are arranged as following structures: # [Order1 ([n]), # Order2 ([n]), # ... # OrderN ([n])] (k, n) code = [ encode(self.grid_coord, self.batch, depth, order=order_) for order_ in order ] code = torch.stack(code) order = torch.argsort(code) inverse = torch.zeros_like(order).scatter_( dim=1, index=order, src=torch.arange(0, code.shape[1], device=order.device).repeat( code.shape[0], 1 ), ) if shuffle_orders: perm = torch.randperm(code.shape[0]) code = code[perm] order = order[perm] inverse = inverse[perm] self["serialized_code"] = code self["serialized_order"] = order self["serialized_inverse"] = inverse def sparsify(self, pad=96): """ Point Cloud Sparsification Point cloud is sparse, here we use "sparsify" to specifically refer to preparing "spconv.SparseConvTensor" for SpConv. relay on ["grid_coord" or "coord" + "grid_size", "batch", "feat"] pad: padding sparse for sparse shape. """ assert {"feat", "batch"}.issubset(self.keys()) if "grid_coord" not in self.keys(): # if you don't want to operate GridSampling in data augmentation, # please add the following augmentation into your pipline: # dict(type="Copy", keys_dict={"grid_size": 0.01}), # (adjust `grid_size` to what your want) assert {"grid_size", "coord"}.issubset(self.keys()) self["grid_coord"] = torch.div( self.coord - self.coord.min(0)[0], self.grid_size, rounding_mode="trunc" ).int() if "sparse_shape" in self.keys(): sparse_shape = self.sparse_shape else: sparse_shape = torch.add( torch.max(self.grid_coord, dim=0).values, pad ).tolist() sparse_conv_feat = spconv.SparseConvTensor( features=self.feat, indices=torch.cat( [self.batch.unsqueeze(-1).int(), self.grid_coord.int()], dim=1 ).contiguous(), spatial_shape=sparse_shape, batch_size=self.batch[-1].tolist() + 1, ) self["sparse_shape"] = sparse_shape self["sparse_conv_feat"] = sparse_conv_feat class PointModule(nn.Module): r"""PointModule placeholder, all module subclass from this will take Point in PointSequential. """ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) class PointSequential(PointModule): r"""A sequential container. Modules will be added to it in the order they are passed in the constructor. Alternatively, an ordered dict of modules can also be passed in. 
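    Example (illustrative sketch, assuming a Point that already carries a "feat" tensor and, for spconv
    modules, a "sparse_conv_feat"):
        seq = PointSequential(nn.Linear(32, 32), act=nn.GELU())
        point = seq(point)  # PointModule, spconv and plain nn.Module members are each dispatched in forward()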
""" def __init__(self, *args, **kwargs): super().__init__() if len(args) == 1 and isinstance(args[0], OrderedDict): for key, module in args[0].items(): self.add_module(key, module) else: for idx, module in enumerate(args): self.add_module(str(idx), module) for name, module in kwargs.items(): if sys.version_info < (3, 6): raise ValueError("kwargs only supported in py36+") if name in self._modules: raise ValueError("name exists.") self.add_module(name, module) def __getitem__(self, idx): if not (-len(self) <= idx < len(self)): raise IndexError("index {} is out of range".format(idx)) if idx < 0: idx += len(self) it = iter(self._modules.values()) for i in range(idx): next(it) return next(it) def __len__(self): return len(self._modules) def add(self, module, name=None): if name is None: name = str(len(self._modules)) if name in self._modules: raise KeyError("name exists") self.add_module(name, module) def forward(self, input): for k, module in self._modules.items(): # Point module if isinstance(module, PointModule): input = module(input) # Spconv module elif spconv.modules.is_spconv_module(module): if isinstance(input, Point): input.sparse_conv_feat = module(input.sparse_conv_feat) input.feat = input.sparse_conv_feat.features else: input = module(input) # PyTorch module else: if isinstance(input, Point): input.feat = module(input.feat) if "sparse_conv_feat" in input.keys(): input.sparse_conv_feat = input.sparse_conv_feat.replace_feature( input.feat ) elif isinstance(input, spconv.SparseConvTensor): if input.indices.shape[0] != 0: input = input.replace_feature(module(input.features)) else: input = module(input) return input class PDNorm(PointModule): def __init__( self, num_features, norm_layer, context_channels=256, conditions=("ScanNet", "S3DIS", "Structured3D"), decouple=True, adaptive=False, ): super().__init__() self.conditions = conditions self.decouple = decouple self.adaptive = adaptive if self.decouple: self.norm = nn.ModuleList([norm_layer(num_features) for _ in conditions]) else: self.norm = norm_layer if self.adaptive: self.modulation = nn.Sequential( nn.SiLU(), nn.Linear(context_channels, 2 * num_features, bias=True) ) def forward(self, point): assert {"feat", "condition"}.issubset(point.keys()) if isinstance(point.condition, str): condition = point.condition else: condition = point.condition[0] if self.decouple: assert condition in self.conditions norm = self.norm[self.conditions.index(condition)] else: norm = self.norm point.feat = norm(point.feat) if self.adaptive: assert "context" in point.keys() shift, scale = self.modulation(point.context).chunk(2, dim=1) point.feat = point.feat * (1.0 + scale) + shift return point class RPE(torch.nn.Module): def __init__(self, patch_size, num_heads): super().__init__() self.patch_size = patch_size self.num_heads = num_heads self.pos_bnd = int((4 * patch_size) ** (1 / 3) * 2) self.rpe_num = 2 * self.pos_bnd + 1 self.rpe_table = torch.nn.Parameter(torch.zeros(3 * self.rpe_num, num_heads)) torch.nn.init.trunc_normal_(self.rpe_table, std=0.02) def forward(self, coord): idx = ( coord.clamp(-self.pos_bnd, self.pos_bnd) # clamp into bnd + self.pos_bnd # relative position to positive index + torch.arange(3, device=coord.device) * self.rpe_num # x, y, z stride ) out = self.rpe_table.index_select(0, idx.reshape(-1)) out = out.view(idx.shape + (-1,)).sum(3) out = out.permute(0, 3, 1, 2) # (N, K, K, H) -> (N, H, K, K) return out class SerializedAttention(PointModule): def __init__( self, channels, num_heads, patch_size, qkv_bias=True, qk_scale=None, 
attn_drop=0.0, proj_drop=0.0, order_index=0, enable_rpe=False, enable_flash=True, upcast_attention=True, upcast_softmax=True, ): super().__init__() assert channels % num_heads == 0 self.channels = channels self.num_heads = num_heads self.scale = qk_scale or (channels // num_heads) ** -0.5 self.order_index = order_index self.upcast_attention = upcast_attention self.upcast_softmax = upcast_softmax self.enable_rpe = enable_rpe self.enable_flash = enable_flash if enable_flash: assert ( enable_rpe is False ), "Set enable_rpe to False when enable Flash Attention" assert ( upcast_attention is False ), "Set upcast_attention to False when enable Flash Attention" assert ( upcast_softmax is False ), "Set upcast_softmax to False when enable Flash Attention" assert flash_attn is not None, "Make sure flash_attn is installed." self.patch_size = patch_size self.attn_drop = attn_drop else: # when disable flash attention, we still don't want to use mask # consequently, patch size will auto set to the # min number of patch_size_max and number of points self.patch_size_max = patch_size self.patch_size = 0 self.attn_drop = torch.nn.Dropout(attn_drop) self.qkv = torch.nn.Linear(channels, channels * 3, bias=qkv_bias) self.proj = torch.nn.Linear(channels, channels) self.proj_drop = torch.nn.Dropout(proj_drop) self.softmax = torch.nn.Softmax(dim=-1) self.rpe = RPE(patch_size, num_heads) if self.enable_rpe else None @torch.no_grad() def get_rel_pos(self, point, order): K = self.patch_size rel_pos_key = f"rel_pos_{self.order_index}" if rel_pos_key not in point.keys(): grid_coord = point.grid_coord[order] grid_coord = grid_coord.reshape(-1, K, 3) point[rel_pos_key] = grid_coord.unsqueeze(2) - grid_coord.unsqueeze(1) return point[rel_pos_key] @torch.no_grad() def get_padding_and_inverse(self, point): pad_key = "pad" unpad_key = "unpad" cu_seqlens_key = "cu_seqlens_key" if ( pad_key not in point.keys() or unpad_key not in point.keys() or cu_seqlens_key not in point.keys() ): offset = point.offset bincount = offset2bincount(offset) bincount_pad = ( torch.div( bincount + self.patch_size - 1, self.patch_size, rounding_mode="trunc", ) * self.patch_size ) # only pad point when num of points larger than patch_size mask_pad = bincount > self.patch_size bincount_pad = ~mask_pad * bincount + mask_pad * bincount_pad _offset = nn.functional.pad(offset, (1, 0)) _offset_pad = nn.functional.pad(torch.cumsum(bincount_pad, dim=0), (1, 0)) pad = torch.arange(_offset_pad[-1], device=offset.device) unpad = torch.arange(_offset[-1], device=offset.device) cu_seqlens = [] for i in range(len(offset)): unpad[_offset[i]: _offset[i + 1]] += _offset_pad[i] - _offset[i] if bincount[i] != bincount_pad[i]: pad[ _offset_pad[i + 1] - self.patch_size + (bincount[i] % self.patch_size): _offset_pad[i + 1] ] = pad[ _offset_pad[i + 1] - 2 * self.patch_size + (bincount[i] % self.patch_size): _offset_pad[i + 1] - self.patch_size ] pad[_offset_pad[i]: _offset_pad[i + 1]] -= _offset_pad[i] - _offset[i] cu_seqlens.append( torch.arange( _offset_pad[i], _offset_pad[i + 1], step=self.patch_size, dtype=torch.int32, device=offset.device, ) ) point[pad_key] = pad point[unpad_key] = unpad point[cu_seqlens_key] = nn.functional.pad( torch.concat(cu_seqlens), (0, 1), value=_offset_pad[-1] ) return point[pad_key], point[unpad_key], point[cu_seqlens_key] def forward(self, point): if not self.enable_flash: self.patch_size = min( offset2bincount(point.offset).min().tolist(), self.patch_size_max ) H = self.num_heads K = self.patch_size C = self.channels pad, unpad, 
cu_seqlens = self.get_padding_and_inverse(point) order = point.serialized_order[self.order_index][pad] inverse = unpad[point.serialized_inverse[self.order_index]] # padding and reshape feat and batch for serialized point patch qkv = self.qkv(point.feat)[order] if not self.enable_flash: # encode and reshape qkv: (N', K, 3, H, C') => (3, N', H, K, C') q, k, v = ( qkv.reshape(-1, K, 3, H, C // H).permute(2, 0, 3, 1, 4).unbind(dim=0) ) # attn if self.upcast_attention: q = q.float() k = k.float() attn = (q * self.scale) @ k.transpose(-2, -1) # (N', H, K, K) if self.enable_rpe: attn = attn + self.rpe(self.get_rel_pos(point, order)) if self.upcast_softmax: attn = attn.float() attn = self.softmax(attn) attn = self.attn_drop(attn).to(qkv.dtype) feat = (attn @ v).transpose(1, 2).reshape(-1, C) else: feat = flash_attn.flash_attn_varlen_qkvpacked_func( qkv.half().reshape(-1, 3, H, C // H), cu_seqlens, max_seqlen=self.patch_size, dropout_p=self.attn_drop if self.training else 0, softmax_scale=self.scale, ).reshape(-1, C) feat = feat.to(qkv.dtype) feat = feat[inverse] # ffn feat = self.proj(feat) feat = self.proj_drop(feat) point.feat = feat return point class MLP(nn.Module): def __init__( self, in_channels, hidden_channels=None, out_channels=None, act_layer=nn.GELU, drop=0.0, ): super().__init__() out_channels = out_channels or in_channels hidden_channels = hidden_channels or in_channels self.fc1 = nn.Linear(in_channels, hidden_channels) self.act = act_layer() self.fc2 = nn.Linear(hidden_channels, out_channels) self.drop = nn.Dropout(drop) def forward(self, x): x = self.fc1(x) x = self.act(x) x = self.drop(x) x = self.fc2(x) x = self.drop(x) return x class Block(PointModule): def __init__( self, channels, num_heads, patch_size=48, mlp_ratio=4.0, qkv_bias=True, qk_scale=None, attn_drop=0.0, proj_drop=0.0, drop_path=0.0, norm_layer=nn.LayerNorm, act_layer=nn.GELU, pre_norm=True, order_index=0, cpe_indice_key=None, enable_rpe=False, enable_flash=True, upcast_attention=True, upcast_softmax=True, ): super().__init__() self.channels = channels self.pre_norm = pre_norm self.cpe = PointSequential( spconv.SubMConv3d( channels, channels, kernel_size=3, bias=True, indice_key=cpe_indice_key, ), nn.Linear(channels, channels), norm_layer(channels), ) self.norm1 = PointSequential(norm_layer(channels)) self.attn = SerializedAttention( channels=channels, patch_size=patch_size, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=proj_drop, order_index=order_index, enable_rpe=enable_rpe, enable_flash=enable_flash, upcast_attention=upcast_attention, upcast_softmax=upcast_softmax, ) self.norm2 = PointSequential(norm_layer(channels)) self.mlp = PointSequential( MLP( in_channels=channels, hidden_channels=int(channels * mlp_ratio), out_channels=channels, act_layer=act_layer, drop=proj_drop, ) ) self.drop_path = PointSequential( DropPath(drop_path) if drop_path > 0.0 else nn.Identity() ) def forward(self, point: Point): shortcut = point.feat point = self.cpe(point) point.feat = shortcut + point.feat shortcut = point.feat if self.pre_norm: point = self.norm1(point) point = self.drop_path(self.attn(point)) point.feat = shortcut + point.feat if not self.pre_norm: point = self.norm1(point) shortcut = point.feat if self.pre_norm: point = self.norm2(point) point = self.drop_path(self.mlp(point)) point.feat = shortcut + point.feat if not self.pre_norm: point = self.norm2(point) point.sparse_conv_feat = point.sparse_conv_feat.replace_feature(point.feat) return point class 
SerializedPooling(PointModule): def __init__( self, in_channels, out_channels, stride=2, norm_layer=None, act_layer=None, reduce="max", shuffle_orders=True, traceable=True, # record parent and cluster ): super().__init__() self.in_channels = in_channels self.out_channels = out_channels assert stride == 2 ** (math.ceil(stride) - 1).bit_length() # 2, 4, 8 # TODO: add support to grid pool (any stride) self.stride = stride assert reduce in ["sum", "mean", "min", "max"] self.reduce = reduce self.shuffle_orders = shuffle_orders self.traceable = traceable self.proj = nn.Linear(in_channels, out_channels) if norm_layer is not None: self.norm = PointSequential(norm_layer(out_channels)) if act_layer is not None: self.act = PointSequential(act_layer()) def forward(self, point: Point): pooling_depth = (math.ceil(self.stride) - 1).bit_length() if pooling_depth > point.serialized_depth: pooling_depth = 0 assert { "serialized_code", "serialized_order", "serialized_inverse", "serialized_depth", }.issubset( point.keys() ), "Run point.serialization() point cloud before SerializedPooling" code = point.serialized_code >> pooling_depth * 3 code_, cluster, counts = torch.unique( code[0], sorted=True, return_inverse=True, return_counts=True, ) # indices of point sorted by cluster, for torch_scatter.segment_csr _, indices = torch.sort(cluster) # index pointer for sorted point, for torch_scatter.segment_csr idx_ptr = torch.cat([counts.new_zeros(1), torch.cumsum(counts, dim=0)]) # head_indices of each cluster, for reduce attr e.g. code, batch head_indices = indices[idx_ptr[:-1]] # generate down code, order, inverse code = code[:, head_indices] order = torch.argsort(code) inverse = torch.zeros_like(order).scatter_( dim=1, index=order, src=torch.arange(0, code.shape[1], device=order.device).repeat( code.shape[0], 1 ), ) if self.shuffle_orders: perm = torch.randperm(code.shape[0]) code = code[perm] order = order[perm] inverse = inverse[perm] # collect information point_dict = Dict( feat=torch_scatter.segment_csr( self.proj(point.feat)[indices], idx_ptr, reduce=self.reduce ), coord=torch_scatter.segment_csr( point.coord[indices], idx_ptr, reduce="mean" ), grid_coord=point.grid_coord[head_indices] >> pooling_depth, serialized_code=code, serialized_order=order, serialized_inverse=inverse, serialized_depth=point.serialized_depth - pooling_depth, batch=point.batch[head_indices], ) if "condition" in point.keys(): point_dict["condition"] = point.condition if "context" in point.keys(): point_dict["context"] = point.context if self.traceable: point_dict["pooling_inverse"] = cluster point_dict["pooling_parent"] = point point = Point(point_dict) if self.norm is not None: point = self.norm(point) if self.act is not None: point = self.act(point) point.sparsify() return point class SerializedUnpooling(PointModule): def __init__( self, in_channels, skip_channels, out_channels, norm_layer=None, act_layer=None, traceable=False, # record parent and cluster ): super().__init__() self.proj = PointSequential(nn.Linear(in_channels, out_channels)) self.proj_skip = PointSequential(nn.Linear(skip_channels, out_channels)) if norm_layer is not None: self.proj.add(norm_layer(out_channels)) self.proj_skip.add(norm_layer(out_channels)) if act_layer is not None: self.proj.add(act_layer()) self.proj_skip.add(act_layer()) self.traceable = traceable def forward(self, point): assert "pooling_parent" in point.keys() assert "pooling_inverse" in point.keys() parent = point.pop("pooling_parent") inverse = point.pop("pooling_inverse") point = 
self.proj(point) parent = self.proj_skip(parent) parent.feat = parent.feat + point.feat[inverse] if self.traceable: parent["unpooling_parent"] = point return parent class Embedding(PointModule): def __init__( self, in_channels, embed_channels, norm_layer=None, act_layer=None, ): super().__init__() self.in_channels = in_channels self.embed_channels = embed_channels # TODO: check remove spconv self.stem = PointSequential( conv=spconv.SubMConv3d( in_channels, embed_channels, kernel_size=5, padding=1, bias=False, indice_key="stem", ) ) if norm_layer is not None: self.stem.add(norm_layer(embed_channels), name="norm") if act_layer is not None: self.stem.add(act_layer(), name="act") def forward(self, point: Point): point = self.stem(point) return point class PointTransformerV3(PointModule): def __init__( self, in_channels=6, order=("z", "z-trans", "hilbert", "hilbert-trans"), stride=(2, 2, 2, 2), enc_depths=(2, 2, 2, 6, 2), enc_channels=(32, 64, 128, 256, 512), enc_num_head=(2, 4, 8, 16, 32), enc_patch_size=(1024, 1024, 1024, 1024, 1024), dec_depths=(2, 2, 2, 2), dec_channels=(64, 64, 128, 256), dec_num_head=(4, 4, 8, 16), dec_patch_size=(1024, 1024, 1024, 1024), mlp_ratio=4, qkv_bias=True, qk_scale=None, attn_drop=0.0, proj_drop=0.0, drop_path=0.3, pre_norm=True, shuffle_orders=True, enable_rpe=False, enable_flash=True, upcast_attention=False, upcast_softmax=False, cls_mode=False, pdnorm_bn=False, pdnorm_ln=False, pdnorm_decouple=True, pdnorm_adaptive=False, pdnorm_affine=True, pdnorm_conditions=("ScanNet", "S3DIS", "Structured3D"), ): super().__init__() self.num_stages = len(enc_depths) self.order = [order] if isinstance(order, str) else order self.cls_mode = cls_mode self.shuffle_orders = shuffle_orders assert self.num_stages == len(stride) + 1 assert self.num_stages == len(enc_depths) assert self.num_stages == len(enc_channels) assert self.num_stages == len(enc_num_head) assert self.num_stages == len(enc_patch_size) assert self.cls_mode or self.num_stages == len(dec_depths) + 1 assert self.cls_mode or self.num_stages == len(dec_channels) + 1 assert self.cls_mode or self.num_stages == len(dec_num_head) + 1 assert self.cls_mode or self.num_stages == len(dec_patch_size) + 1 # norm layers if pdnorm_bn: bn_layer = partial( PDNorm, norm_layer=partial( nn.BatchNorm1d, eps=1e-3, momentum=0.01, affine=pdnorm_affine ), conditions=pdnorm_conditions, decouple=pdnorm_decouple, adaptive=pdnorm_adaptive, ) else: bn_layer = partial(nn.BatchNorm1d, eps=1e-3, momentum=0.01) if pdnorm_ln: ln_layer = partial( PDNorm, norm_layer=partial(nn.LayerNorm, elementwise_affine=pdnorm_affine), conditions=pdnorm_conditions, decouple=pdnorm_decouple, adaptive=pdnorm_adaptive, ) else: ln_layer = nn.LayerNorm # activation layers act_layer = nn.GELU self.embedding = Embedding( in_channels=in_channels, embed_channels=enc_channels[0], norm_layer=bn_layer, act_layer=act_layer, ) # encoder enc_drop_path = [ x.item() for x in torch.linspace(0, drop_path, sum(enc_depths)) ] self.enc = PointSequential() for s in range(self.num_stages): enc_drop_path_ = enc_drop_path[ sum(enc_depths[:s]): sum(enc_depths[: s + 1]) ] enc = PointSequential() if s > 0: enc.add( SerializedPooling( in_channels=enc_channels[s - 1], out_channels=enc_channels[s], stride=stride[s - 1], norm_layer=bn_layer, act_layer=act_layer, ), name="down", ) for i in range(enc_depths[s]): enc.add( Block( channels=enc_channels[s], num_heads=enc_num_head[s], patch_size=enc_patch_size[s], mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, 
proj_drop=proj_drop, drop_path=enc_drop_path_[i], norm_layer=ln_layer, act_layer=act_layer, pre_norm=pre_norm, order_index=i % len(self.order), cpe_indice_key=f"stage{s}", enable_rpe=enable_rpe, enable_flash=enable_flash, upcast_attention=upcast_attention, upcast_softmax=upcast_softmax, ), name=f"block{i}", ) if len(enc) != 0: self.enc.add(module=enc, name=f"enc{s}") # decoder if not self.cls_mode: dec_drop_path = [ x.item() for x in torch.linspace(0, drop_path, sum(dec_depths)) ] self.dec = PointSequential() dec_channels = list(dec_channels) + [enc_channels[-1]] for s in reversed(range(self.num_stages - 1)): dec_drop_path_ = dec_drop_path[ sum(dec_depths[:s]): sum(dec_depths[: s + 1]) ] dec_drop_path_.reverse() dec = PointSequential() dec.add( SerializedUnpooling( in_channels=dec_channels[s + 1], skip_channels=enc_channels[s], out_channels=dec_channels[s], norm_layer=bn_layer, act_layer=act_layer, ), name="up", ) for i in range(dec_depths[s]): dec.add( Block( channels=dec_channels[s], num_heads=dec_num_head[s], patch_size=dec_patch_size[s], mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=proj_drop, drop_path=dec_drop_path_[i], norm_layer=ln_layer, act_layer=act_layer, pre_norm=pre_norm, order_index=i % len(self.order), cpe_indice_key=f"stage{s}", enable_rpe=enable_rpe, enable_flash=enable_flash, upcast_attention=upcast_attention, upcast_softmax=upcast_softmax, ), name=f"block{i}", ) self.dec.add(module=dec, name=f"dec{s}") def forward(self, data_dict): """ A data_dict is a dictionary containing properties of a batched point cloud. It should contain the following properties for PTv3: 1. "feat": feature of point cloud 2. "grid_coord": discrete coordinate after grid sampling (voxelization) or "coord" + "grid_size" 3. 
"offset" or "batch": https://github.com/Pointcept/Pointcept?tab=readme-ov-file#offset """ point = Point(data_dict) point.serialization(order=self.order, shuffle_orders=self.shuffle_orders) point.sparsify() point = self.embedding(point) point = self.enc(point) if not self.cls_mode: point = self.dec(point) return point ================================================ FILE: mvtracker/models/core/ptv3/serialization/__init__.py ================================================ from .default import ( encode, decode, z_order_encode, z_order_decode, hilbert_encode, hilbert_decode, ) ================================================ FILE: mvtracker/models/core/ptv3/serialization/default.py ================================================ import torch from .hilbert import decode as hilbert_decode_ from .hilbert import encode as hilbert_encode_ from .z_order import key2xyz as z_order_decode_ from .z_order import xyz2key as z_order_encode_ @torch.inference_mode() def encode(grid_coord, batch=None, depth=16, order="z"): assert order in {"z", "z-trans", "hilbert", "hilbert-trans"} if order == "z": code = z_order_encode(grid_coord, depth=depth) elif order == "z-trans": code = z_order_encode(grid_coord[:, [1, 0, 2]], depth=depth) elif order == "hilbert": code = hilbert_encode(grid_coord, depth=depth) elif order == "hilbert-trans": code = hilbert_encode(grid_coord[:, [1, 0, 2]], depth=depth) else: raise NotImplementedError if batch is not None: batch = batch.long() code = batch << depth * 3 | code return code @torch.inference_mode() def decode(code, depth=16, order="z"): assert order in {"z", "hilbert"} batch = code >> depth * 3 code = code & ((1 << depth * 3) - 1) if order == "z": grid_coord = z_order_decode(code, depth=depth) elif order == "hilbert": grid_coord = hilbert_decode(code, depth=depth) else: raise NotImplementedError return grid_coord, batch def z_order_encode(grid_coord: torch.Tensor, depth: int = 16): x, y, z = grid_coord[:, 0].long(), grid_coord[:, 1].long(), grid_coord[:, 2].long() # we block the support to batch, maintain batched code in Point class code = z_order_encode_(x, y, z, b=None, depth=depth) return code def z_order_decode(code: torch.Tensor, depth): x, y, z = z_order_decode_(code, depth=depth) grid_coord = torch.stack([x, y, z], dim=-1) # (N, 3) return grid_coord def hilbert_encode(grid_coord: torch.Tensor, depth: int = 16): return hilbert_encode_(grid_coord, num_dims=3, num_bits=depth) def hilbert_decode(code: torch.Tensor, depth: int = 16): return hilbert_decode_(code, num_dims=3, num_bits=depth) ================================================ FILE: mvtracker/models/core/ptv3/serialization/hilbert.py ================================================ """ Hilbert Order Modified from https://github.com/PrincetonLIPS/numpy-hilbert-curve Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com), Kaixin Xu Please cite our work if the code is helpful to you. """ import torch def right_shift(binary, k=1, axis=-1): """Right shift an array of binary values. Parameters: ----------- binary: An ndarray of binary values. k: The number of bits to shift. Default 1. axis: The axis along which to shift. Default -1. Returns: -------- Returns an ndarray with zero prepended and the ends truncated, along whatever axis was specified.""" # If we're shifting the whole thing, just return zeros. if binary.shape[axis] <= k: return torch.zeros_like(binary) # Determine the padding pattern. # padding = [(0,0)] * len(binary.shape) # padding[axis] = (k,0) # Determine the slicing pattern to eliminate just the last one. 
slicing = [slice(None)] * len(binary.shape) slicing[axis] = slice(None, -k) shifted = torch.nn.functional.pad( binary[tuple(slicing)], (k, 0), mode="constant", value=0 ) return shifted def binary2gray(binary, axis=-1): """Convert an array of binary values into Gray codes. This uses the classic X ^ (X >> 1) trick to compute the Gray code. Parameters: ----------- binary: An ndarray of binary values. axis: The axis along which to compute the gray code. Default=-1. Returns: -------- Returns an ndarray of Gray codes. """ shifted = right_shift(binary, axis=axis) # Do the X ^ (X >> 1) trick. gray = torch.logical_xor(binary, shifted) return gray def gray2binary(gray, axis=-1): """Convert an array of Gray codes back into binary values. Parameters: ----------- gray: An ndarray of gray codes. axis: The axis along which to perform Gray decoding. Default=-1. Returns: -------- Returns an ndarray of binary values. """ # Loop the log2(bits) number of times necessary, with shift and xor. shift = 2 ** (torch.Tensor([gray.shape[axis]]).log2().ceil().int() - 1) while shift > 0: gray = torch.logical_xor(gray, right_shift(gray, shift)) shift = torch.div(shift, 2, rounding_mode="floor") return gray def encode(locs, num_dims, num_bits): """Decode an array of locations in a hypercube into a Hilbert integer. This is a vectorized-ish version of the Hilbert curve implementation by John Skilling as described in: Skilling, J. (2004, April). Programming the Hilbert curve. In AIP Conference Proceedings (Vol. 707, No. 1, pp. 381-387). American Institute of Physics. Params: ------- locs - An ndarray of locations in a hypercube of num_dims dimensions, in which each dimension runs from 0 to 2**num_bits-1. The shape can be arbitrary, as long as the last dimension of the same has size num_dims. num_dims - The dimensionality of the hypercube. Integer. num_bits - The number of bits for each dimension. Integer. Returns: -------- The output is an ndarray of uint64 integers with the same shape as the input, excluding the last dimension, which needs to be num_dims. """ # Keep around the original shape for later. orig_shape = locs.shape bitpack_mask = 1 << torch.arange(0, 8).to(locs.device) bitpack_mask_rev = bitpack_mask.flip(-1) if orig_shape[-1] != num_dims: raise ValueError( """ The shape of locs was surprising in that the last dimension was of size %d, but num_dims=%d. These need to be equal. """ % (orig_shape[-1], num_dims) ) if num_dims * num_bits > 63: raise ValueError( """ num_dims=%d and num_bits=%d for %d bits total, which can't be encoded into a int64. Are you sure you need that many points on your Hilbert curve? """ % (num_dims, num_bits, num_dims * num_bits) ) # Treat the location integers as 64-bit unsigned and then split them up into # a sequence of uint8s. Preserve the association by dimension. locs_uint8 = locs.long().view(torch.uint8).reshape((-1, num_dims, 8)).flip(-1) # Now turn these into bits and truncate to num_bits. gray = ( locs_uint8.unsqueeze(-1) .bitwise_and(bitpack_mask_rev) .ne(0) .byte() .flatten(-2, -1)[..., -num_bits:] ) # Run the decoding process the other way. # Iterate forwards through the bits. for bit in range(0, num_bits): # Iterate forwards through the dimensions. for dim in range(0, num_dims): # Identify which ones have this bit active. mask = gray[:, dim, bit] # Where this bit is on, invert the 0 dimension for lower bits. gray[:, 0, bit + 1:] = torch.logical_xor( gray[:, 0, bit + 1:], mask[:, None] ) # Where the bit is off, exchange the lower bits with the 0 dimension. 
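            # `to_flip` marks the lower-bit positions where dimension 0 and
            # dimension `dim` differ, restricted to points whose current bit is
            # off; XOR-ing both rows with it swaps exactly those bits.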
to_flip = torch.logical_and( torch.logical_not(mask[:, None]).repeat(1, gray.shape[2] - bit - 1), torch.logical_xor(gray[:, 0, bit + 1:], gray[:, dim, bit + 1:]), ) gray[:, dim, bit + 1:] = torch.logical_xor( gray[:, dim, bit + 1:], to_flip ) gray[:, 0, bit + 1:] = torch.logical_xor(gray[:, 0, bit + 1:], to_flip) # Now flatten out. gray = gray.swapaxes(1, 2).reshape((-1, num_bits * num_dims)) # Convert Gray back to binary. hh_bin = gray2binary(gray) # Pad back out to 64 bits. extra_dims = 64 - num_bits * num_dims padded = torch.nn.functional.pad(hh_bin, (extra_dims, 0), "constant", 0) # Convert binary values into uint8s. hh_uint8 = ( (padded.flip(-1).reshape((-1, 8, 8)) * bitpack_mask) .sum(2) .squeeze() .type(torch.uint8) ) # Convert uint8s into uint64s. hh_uint64 = hh_uint8.view(torch.int64).squeeze() return hh_uint64 def decode(hilberts, num_dims, num_bits): """Decode an array of Hilbert integers into locations in a hypercube. This is a vectorized-ish version of the Hilbert curve implementation by John Skilling as described in: Skilling, J. (2004, April). Programming the Hilbert curve. In AIP Conference Proceedings (Vol. 707, No. 1, pp. 381-387). American Institute of Physics. Params: ------- hilberts - An ndarray of Hilbert integers. Must be an integer dtype and cannot have fewer bits than num_dims * num_bits. num_dims - The dimensionality of the hypercube. Integer. num_bits - The number of bits for each dimension. Integer. Returns: -------- The output is an ndarray of unsigned integers with the same shape as hilberts but with an additional dimension of size num_dims. """ if num_dims * num_bits > 64: raise ValueError( """ num_dims=%d and num_bits=%d for %d bits total, which can't be encoded into a uint64. Are you sure you need that many points on your Hilbert curve? """ % (num_dims, num_bits) ) # Handle the case where we got handed a naked integer. hilberts = torch.atleast_1d(hilberts) # Keep around the shape for later. orig_shape = hilberts.shape bitpack_mask = 2 ** torch.arange(0, 8).to(hilberts.device) bitpack_mask_rev = bitpack_mask.flip(-1) # Treat each of the hilberts as a s equence of eight uint8. # This treats all of the inputs as uint64 and makes things uniform. hh_uint8 = ( hilberts.ravel().type(torch.int64).view(torch.uint8).reshape((-1, 8)).flip(-1) ) # Turn these lists of uints into lists of bits and then truncate to the size # we actually need for using Skilling's procedure. hh_bits = ( hh_uint8.unsqueeze(-1) .bitwise_and(bitpack_mask_rev) .ne(0) .byte() .flatten(-2, -1)[:, -num_dims * num_bits:] ) # Take the sequence of bits and Gray-code it. gray = binary2gray(hh_bits) # There has got to be a better way to do this. # I could index them differently, but the eventual packbits likes it this way. gray = gray.reshape((-1, num_bits, num_dims)).swapaxes(1, 2) # Iterate backwards through the bits. for bit in range(num_bits - 1, -1, -1): # Iterate backwards through the dimensions. for dim in range(num_dims - 1, -1, -1): # Identify which ones have this bit active. mask = gray[:, dim, bit] # Where this bit is on, invert the 0 dimension for lower bits. gray[:, 0, bit + 1:] = torch.logical_xor( gray[:, 0, bit + 1:], mask[:, None] ) # Where the bit is off, exchange the lower bits with the 0 dimension. 
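            # Same masked XOR-swap as in `encode` above, applied while walking
            # the bits and dimensions in reverse order so that it inverts the
            # forward transform.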
to_flip = torch.logical_and( torch.logical_not(mask[:, None]), torch.logical_xor(gray[:, 0, bit + 1:], gray[:, dim, bit + 1:]), ) gray[:, dim, bit + 1:] = torch.logical_xor( gray[:, dim, bit + 1:], to_flip ) gray[:, 0, bit + 1:] = torch.logical_xor(gray[:, 0, bit + 1:], to_flip) # Pad back out to 64 bits. extra_dims = 64 - num_bits padded = torch.nn.functional.pad(gray, (extra_dims, 0), "constant", 0) # Now chop these up into blocks of 8. locs_chopped = padded.flip(-1).reshape((-1, num_dims, 8, 8)) # Take those blocks and turn them unto uint8s. # from IPython import embed; embed() locs_uint8 = (locs_chopped * bitpack_mask).sum(3).squeeze().type(torch.uint8) # Finally, treat these as uint64s. flat_locs = locs_uint8.view(torch.int64) # Return them in the expected shape. return flat_locs.reshape((*orig_shape, num_dims)) ================================================ FILE: mvtracker/models/core/ptv3/serialization/z_order.py ================================================ # -------------------------------------------------------- # Octree-based Sparse Convolutional Neural Networks # Copyright (c) 2022 Peng-Shuai Wang # Licensed under The MIT License [see LICENSE for details] # Written by Peng-Shuai Wang # -------------------------------------------------------- from typing import Optional, Union import torch class KeyLUT: def __init__(self): r256 = torch.arange(256, dtype=torch.int64) r512 = torch.arange(512, dtype=torch.int64) zero = torch.zeros(256, dtype=torch.int64) device = torch.device("cpu") self._encode = { device: ( self.xyz2key(r256, zero, zero, 8), self.xyz2key(zero, r256, zero, 8), self.xyz2key(zero, zero, r256, 8), ) } self._decode = {device: self.key2xyz(r512, 9)} def encode_lut(self, device=torch.device("cpu")): if device not in self._encode: cpu = torch.device("cpu") self._encode[device] = tuple(e.to(device) for e in self._encode[cpu]) return self._encode[device] def decode_lut(self, device=torch.device("cpu")): if device not in self._decode: cpu = torch.device("cpu") self._decode[device] = tuple(e.to(device) for e in self._decode[cpu]) return self._decode[device] def xyz2key(self, x, y, z, depth): key = torch.zeros_like(x) for i in range(depth): mask = 1 << i key = ( key | ((x & mask) << (2 * i + 2)) | ((y & mask) << (2 * i + 1)) | ((z & mask) << (2 * i + 0)) ) return key def key2xyz(self, key, depth): x = torch.zeros_like(key) y = torch.zeros_like(key) z = torch.zeros_like(key) for i in range(depth): x = x | ((key & (1 << (3 * i + 2))) >> (2 * i + 2)) y = y | ((key & (1 << (3 * i + 1))) >> (2 * i + 1)) z = z | ((key & (1 << (3 * i + 0))) >> (2 * i + 0)) return x, y, z _key_lut = KeyLUT() def xyz2key( x: torch.Tensor, y: torch.Tensor, z: torch.Tensor, b: Optional[Union[torch.Tensor, int]] = None, depth: int = 16, ): r"""Encodes :attr:`x`, :attr:`y`, :attr:`z` coordinates to the shuffled keys based on pre-computed look up tables. The speed of this function is much faster than the method based on for-loop. Args: x (torch.Tensor): The x coordinate. y (torch.Tensor): The y coordinate. z (torch.Tensor): The z coordinate. b (torch.Tensor or int): The batch index of the coordinates, and should be smaller than 32768. If :attr:`b` is :obj:`torch.Tensor`, the size of :attr:`b` must be the same as :attr:`x`, :attr:`y`, and :attr:`z`. depth (int): The depth of the shuffled key, and must be smaller than 17 (< 17). 
""" EX, EY, EZ = _key_lut.encode_lut(x.device) x, y, z = x.long(), y.long(), z.long() mask = 255 if depth > 8 else (1 << depth) - 1 key = EX[x & mask] | EY[y & mask] | EZ[z & mask] if depth > 8: mask = (1 << (depth - 8)) - 1 key16 = EX[(x >> 8) & mask] | EY[(y >> 8) & mask] | EZ[(z >> 8) & mask] key = key16 << 24 | key if b is not None: b = b.long() key = b << 48 | key return key def key2xyz(key: torch.Tensor, depth: int = 16): r"""Decodes the shuffled key to :attr:`x`, :attr:`y`, :attr:`z` coordinates and the batch index based on pre-computed look up tables. Args: key (torch.Tensor): The shuffled key. depth (int): The depth of the shuffled key, and must be smaller than 17 (< 17). """ DX, DY, DZ = _key_lut.decode_lut(key.device) x, y, z = torch.zeros_like(key), torch.zeros_like(key), torch.zeros_like(key) b = key >> 48 key = key & ((1 << 48) - 1) n = (depth + 2) // 3 for i in range(n): k = key >> (i * 9) & 511 x = x | (DX[k] << (i * 3)) y = y | (DY[k] << (i * 3)) z = z | (DZ[k] << (i * 3)) return x, y, z, b ================================================ FILE: mvtracker/models/core/shape-of-motion/.gitignore ================================================ *.pth *.npy *.mp4 outputs/ work_dirs/ *__pycache__* .vscode/ .envrc .bak/ datasets/ preproc/checkpoints preproc/checkpoints/ ================================================ FILE: mvtracker/models/core/shape-of-motion/.gitmodules ================================================ [submodule "preproc/tapnet"] path = preproc/tapnet url = https://github.com/google-deepmind/tapnet.git [submodule "preproc/DROID-SLAM"] path = preproc/DROID-SLAM url = https://github.com/princeton-vl/DROID-SLAM.git [submodule "preproc/UniDepth"] path = preproc/UniDepth url = https://github.com/lpiccinelli-eth/UniDepth.git ================================================ FILE: mvtracker/models/core/shape-of-motion/LICENSE ================================================ MIT License Copyright (c) 2024 Vickie Ye Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
================================================ FILE: mvtracker/models/core/shape-of-motion/README.md ================================================ # Shape of Motion: 4D Reconstruction from a Single Video **[Project Page](https://shape-of-motion.github.io/) | [Arxiv](https://arxiv.org/abs/2407.13764)** [Qianqian Wang](https://qianqianwang68.github.io/)1,2*, [Vickie Ye](https://people.eecs.berkeley.edu/~vye/)1\*, [Hang Gao](https://hangg7.com/)1\*, [Jake Austin](https://www.linkedin.com/in/jakeaustin4701)1, [Zhengqi Li](https://zhengqili.github.io/)2, [Angjoo Kanazawa](https://people.eecs.berkeley.edu/~kanazawa/)1 1UC Berkeley   2Google Research \* Equal Contribution ## Installation ``` git clone --recurse-submodules https://github.com/vye16/shape-of-motion cd shape-of-motion/ conda create -n som python=3.10 conda activate som ``` Update `requirements.txt` with correct CUDA version for PyTorch and cuUML, i.e., replacing `cu122` and `cu12` with your CUDA version. ``` pip install -r requirements.txt pip install git+https://github.com/nerfstudio-project/gsplat.git ``` ## Usage ### Preprocessing We depend on the third-party libraries in `preproc` to generate depth maps, object masks, camera estimates, and 2D tracks. Please follow the guide in the [preprocessing README](./preproc/README.md). ### Fitting to a Video ```python python run_training.py \ --work-dir \ data:davis \ --data.seq-name horsejump-low ``` ## Evaluation on iPhone Dataset First, download our processed iPhone dataset from [this](https://drive.google.com/drive/folders/1xJaFS_3027crk7u36cue7BseAX80abRe?usp=sharing) link. To train on a sequence, e.g., *paper-windmill*, run: ```python python run_training.py \ --work-dir \ --port \ data:iphone \ --data.data-dir ``` After optimization, the numerical result can be evaluated via: ``` PYTHONPATH='.' 
python scripts/evaluate_iphone.py \ --data_dir \ --result_dir \ --seq_names paper-windmill ``` ## Citation ``` @inproceedings{som2024, title = {Shape of Motion: 4D Reconstruction from a Single Video}, author = {Wang, Qianqian and Ye, Vickie and Gao, Hang and Austin, Jake and Li, Zhengqi and Kanazawa, Angjoo}, journal = {arXiv preprint arXiv:2407.13764}, year = {2024} } ``` ================================================ FILE: mvtracker/models/core/shape-of-motion/flow3d/__init__.py ================================================ ================================================ FILE: mvtracker/models/core/shape-of-motion/flow3d/configs.py ================================================ from dataclasses import dataclass @dataclass class FGLRConfig: means: float = 1.6e-4 opacities: float = 1e-2 scales: float = 5e-3 quats: float = 1e-3 colors: float = 1e-2 motion_coefs: float = 1e-2 @dataclass class BGLRConfig: means: float = 1.6e-4 opacities: float = 5e-2 scales: float = 5e-3 quats: float = 1e-3 colors: float = 1e-2 @dataclass class MotionLRConfig: rots: float = 1.6e-4 transls: float = 1.6e-4 @dataclass class SceneLRConfig: fg: FGLRConfig bg: BGLRConfig motion_bases: MotionLRConfig @dataclass class LossesConfig: w_rgb: float = 1.0 w_depth_reg: float = 0.5 w_depth_const: float = 0.1 w_depth_grad: float = 1 w_track: float = 2.0 w_mask: float = 1.0 w_smooth_bases: float = 0.1 w_smooth_tracks: float = 2.0 w_scale_var: float = 0.01 w_z_accel: float = 1.0 @dataclass class OptimizerConfig: max_steps: int = 5000 ## Adaptive gaussian control warmup_steps: int = 200 control_every: int = 100 reset_opacity_every_n_controls: int = 30 stop_control_by_screen_steps: int = 4000 stop_control_steps: int = 4000 ### Densify. densify_xys_grad_threshold: float = 0.0002 densify_scale_threshold: float = 0.01 densify_screen_threshold: float = 0.05 stop_densify_steps: int = 15000 ### Cull. 
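    # Gaussians are pruned during adaptive density control once they cross
    # these opacity / world-scale / screen-size thresholds.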
cull_opacity_threshold: float = 0.1 cull_scale_threshold: float = 0.5 cull_screen_threshold: float = 0.15 ================================================ FILE: mvtracker/models/core/shape-of-motion/flow3d/data/__init__.py ================================================ from dataclasses import asdict, replace from torch.utils.data import Dataset from .base_dataset import BaseDataset from .panoptic_dataset import PanopticDataConfig, PanopticStudioDatasetSoM from .casual_dataset import CasualDataset, CustomDataConfig, DavisDataConfig from .iphone_dataset import ( iPhoneDataConfig, iPhoneDataset, iPhoneDatasetKeypointView, iPhoneDatasetVideoView, ) def get_train_val_datasets( data_cfg: iPhoneDataConfig | DavisDataConfig | CustomDataConfig, load_val: bool ) -> tuple[BaseDataset, Dataset | None, Dataset | None, Dataset | None]: train_video_view = None val_img_dataset = None val_kpt_dataset = None if isinstance(data_cfg, iPhoneDataConfig): train_dataset = iPhoneDataset(**asdict(data_cfg)) train_video_view = iPhoneDatasetVideoView(train_dataset) if load_val: val_img_dataset = ( iPhoneDataset( **asdict(replace(data_cfg, split="val", load_from_cache=True)) ) if train_dataset.has_validation else None ) val_kpt_dataset = iPhoneDatasetKeypointView(train_dataset) elif isinstance(data_cfg, DavisDataConfig) or isinstance( data_cfg, CustomDataConfig ): train_dataset = CasualDataset(**asdict(data_cfg)) elif isinstance(data_cfg, PanopticDataConfig): train_dataset = PanopticStudioDatasetSoM(**asdict(data_cfg)) print("PANOPTIC IS LOADED.") else: raise ValueError(f"Unknown data config: {data_cfg}") return train_dataset, train_video_view, val_img_dataset, val_kpt_dataset ================================================ FILE: mvtracker/models/core/shape-of-motion/flow3d/data/base_dataset.py ================================================ from abc import abstractmethod import torch from torch.utils.data import Dataset, default_collate class BaseDataset(Dataset): @property @abstractmethod def num_frames(self) -> int: ... @property def keyframe_idcs(self) -> torch.Tensor: return torch.arange(self.num_frames) @abstractmethod def get_w2cs(self) -> torch.Tensor: ... @abstractmethod def get_Ks(self) -> torch.Tensor: ... @abstractmethod def get_image(self, index: int) -> torch.Tensor: ... @abstractmethod def get_depth(self, index: int) -> torch.Tensor: ... @abstractmethod def get_mask(self, index: int) -> torch.Tensor: ... def get_img_wh(self) -> tuple[int, int]: ... @abstractmethod def get_tracks_3d( self, num_samples: int, **kwargs ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """ Returns 3D tracks: coordinates (N, T, 3), visibles (N, T), invisibles (N, T), confidences (N, T), colors (N, 3) """ ... @abstractmethod def get_bkgd_points( self, num_samples: int, **kwargs ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Returns background points: coordinates (N, 3), normals (N, 3), colors (N, 3) """ ... # @staticmethod # def train_collate_fn(batch): # collated = {} # for k in batch[0]: # if k not in [ # "query_tracks_2d", # "target_ts", # "target_w2cs", # "target_Ks", # "target_tracks_2d", # "target_visibles", # "target_track_depths", # "target_invisibles", # "target_confidences", # ]: # collated[k] = default_collate([sample[k] for sample in batch]) # else: # collated[k] = [sample[k] for sample in batch] # return collated @staticmethod def train_collate_fn(batch): """ Collate function that correctly batches data when each sample consists of multiple views. 
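        Each element of `batch` is a list of per-view sample dicts of equal
        length. The batch is transposed so that dicts belonging to the same
        view are grouped together, each group is collated with
        `default_collate`, and the ragged fields (query/target tracks, target
        timestamps, cameras, visibilities, depths, confidences) are kept as
        plain Python lists. The return value is a list with one collated dict
        per view.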
""" # Step 1: Transpose the batch to group by views # If batch contains 4 views per sample, `batch` is a list of lists: [ [view_1, view_2, view_3, view_4], [view_1, view_2, view_3, view_4], ... ] # We want to group all view_1's together, all view_2's together, etc. num_views = len(batch[0]) # Assumes each sample has the same number of views batch_per_view = list(zip(*batch)) # Transposes list-of-lists structure collated_views = [] # Step 2: Collate each view separately for view_batch in batch_per_view: collated = {} for k in view_batch[0]: # Iterate over keys in the dictionary if k not in [ "query_tracks_2d", "target_ts", "target_w2cs", "target_Ks", "target_tracks_2d", "target_visibles", "target_track_depths", "target_invisibles", "target_confidences", ]: collated[k] = default_collate([sample[k] for sample in view_batch]) else: collated[k] = [sample[k] for sample in view_batch] # Keep list format collated_views.append(collated) return collated_views # List of collated dictionaries, one per view ================================================ FILE: mvtracker/models/core/shape-of-motion/flow3d/data/casual_dataset.py ================================================ import os from dataclasses import dataclass from functools import partial from typing import Literal, cast import cv2 import imageio import numpy as np import torch import torch.nn.functional as F import tyro from loguru import logger as guru from roma import roma from tqdm import tqdm from flow3d.data.base_dataset import BaseDataset from flow3d.data.utils import ( UINT16_MAX, SceneNormDict, get_tracks_3d_for_query_frame, median_filter_2d, normal_from_depth_image, normalize_coords, parse_tapir_track_info, ) from flow3d.transforms import rt_to_mat4 @dataclass class DavisDataConfig: seq_name: str root_dir: str start: int = 0 end: int = -1 res: str = "480p" image_type: str = "JPEGImages" mask_type: str = "Annotations" depth_type: Literal[ "aligned_depth_anything", "aligned_depth_anything_v2", "depth_anything", "depth_anything_v2", "unidepth_disp", ] = "aligned_depth_anything" camera_type: Literal["droid_recon"] = "droid_recon" track_2d_type: Literal["bootstapir", "tapir"] = "bootstapir" mask_erosion_radius: int = 3 scene_norm_dict: tyro.conf.Suppress[SceneNormDict | None] = None num_targets_per_frame: int = 4 load_from_cache: bool = False @dataclass class CustomDataConfig: seq_name: str root_dir: str start: int = 0 end: int = -1 res: str = "" image_type: str = "images" mask_type: str = "masks" depth_type: Literal[ "aligned_depth_anything", "aligned_depth_anything_v2", "depth_anything", "depth_anything_v2", "unidepth_disp", ] = "aligned_depth_anything" camera_type: Literal["droid_recon"] = "droid_recon" track_2d_type: Literal["bootstapir", "tapir"] = "bootstapir" mask_erosion_radius: int = 7 scene_norm_dict: tyro.conf.Suppress[SceneNormDict | None] = None num_targets_per_frame: int = 4 load_from_cache: bool = False class CasualDataset(BaseDataset): def __init__( self, seq_name: str, root_dir: str, start: int = 0, end: int = -1, res: str = "480p", image_type: str = "JPEGImages", mask_type: str = "Annotations", depth_type: Literal[ "aligned_depth_anything", "aligned_depth_anything_v2", "depth_anything", "depth_anything_v2", "unidepth_disp", ] = "aligned_depth_anything", camera_type: Literal["droid_recon"] = "droid_recon", track_2d_type: Literal["bootstapir", "tapir"] = "bootstapir", mask_erosion_radius: int = 3, scene_norm_dict: SceneNormDict | None = None, num_targets_per_frame: int = 4, load_from_cache: bool = False, **_, ): 
super().__init__() self.seq_name = seq_name self.root_dir = root_dir self.res = res self.depth_type = depth_type self.num_targets_per_frame = num_targets_per_frame self.load_from_cache = load_from_cache self.has_validation = False self.mask_erosion_radius = mask_erosion_radius self.img_dir = f"{root_dir}/{image_type}/{res}/{seq_name}" self.img_ext = os.path.splitext(os.listdir(self.img_dir)[0])[1] self.depth_dir = f"{root_dir}/{depth_type}/{res}/{seq_name}" self.mask_dir = f"{root_dir}/{mask_type}/{res}/{seq_name}" self.tracks_dir = f"{root_dir}/{track_2d_type}/{res}/{seq_name}" self.cache_dir = f"{root_dir}/flow3d_preprocessed/{res}/{seq_name}" # self.cache_dir = f"datasets/davis/flow3d_preprocessed/{res}/{seq_name}" frame_names = [os.path.splitext(p)[0] for p in sorted(os.listdir(self.img_dir))] if end == -1: end = len(frame_names) self.start = start self.end = end self.frame_names = frame_names[start:end] self.imgs: list[torch.Tensor | None] = [None for _ in self.frame_names] self.depths: list[torch.Tensor | None] = [None for _ in self.frame_names] self.masks: list[torch.Tensor | None] = [None for _ in self.frame_names] # load cameras if camera_type == "droid_recon": img = self.get_image(0) H, W = img.shape[:2] w2cs, Ks, tstamps = load_cameras( f"{root_dir}/{camera_type}/{seq_name}.npy", H, W ) else: raise ValueError(f"Unknown camera type: {camera_type}") assert ( len(frame_names) == len(w2cs) == len(Ks) ), f"{len(frame_names)}, {len(w2cs)}, {len(Ks)}" self.w2cs = w2cs[start:end] self.Ks = Ks[start:end] tmask = (tstamps >= start) & (tstamps < end) self._keyframe_idcs = tstamps[tmask] - start self.scale = 1 if scene_norm_dict is None: cached_scene_norm_dict_path = os.path.join( self.cache_dir, "scene_norm_dict.pth" ) if os.path.exists(cached_scene_norm_dict_path) and self.load_from_cache: guru.info("loading cached scene norm dict...") scene_norm_dict = torch.load( os.path.join(self.cache_dir, "scene_norm_dict.pth") ) else: tracks_3d = self.get_tracks_3d(5000, step=self.num_frames // 10)[0] scale, transfm = compute_scene_norm(tracks_3d, self.w2cs) scene_norm_dict = SceneNormDict(scale=scale, transfm=transfm) os.makedirs(self.cache_dir, exist_ok=True) torch.save(scene_norm_dict, cached_scene_norm_dict_path) # transform cameras self.scene_norm_dict = cast(SceneNormDict, scene_norm_dict) self.scale = self.scene_norm_dict["scale"] transform = self.scene_norm_dict["transfm"] guru.info(f"scene norm {self.scale=}, {transform=}") self.w2cs = torch.einsum("nij,jk->nik", self.w2cs, torch.linalg.inv(transform)) self.w2cs[:, :3, 3] /= self.scale @property def num_frames(self) -> int: return len(self.frame_names) @property def keyframe_idcs(self) -> torch.Tensor: return self._keyframe_idcs def __len__(self): return len(self.frame_names) def get_w2cs(self) -> torch.Tensor: return self.w2cs def get_Ks(self) -> torch.Tensor: return self.Ks def get_img_wh(self) -> tuple[int, int]: return self.get_image(0).shape[1::-1] def get_image(self, index) -> torch.Tensor: if self.imgs[index] is None: self.imgs[index] = self.load_image(index) img = cast(torch.Tensor, self.imgs[index]) return img def get_mask(self, index) -> torch.Tensor: if self.masks[index] is None: self.masks[index] = self.load_mask(index) mask = cast(torch.Tensor, self.masks[index]) return mask def get_depth(self, index) -> torch.Tensor: if self.depths[index] is None: self.depths[index] = self.load_depth(index) return self.depths[index] / self.scale def load_image(self, index) -> torch.Tensor: path = 
f"{self.img_dir}/{self.frame_names[index]}{self.img_ext}" return torch.from_numpy(imageio.imread(path)).float() / 255.0 def load_mask(self, index) -> torch.Tensor: path = f"{self.mask_dir}/{self.frame_names[index]}.png" r = self.mask_erosion_radius mask = imageio.imread(path) fg_mask = mask.reshape((*mask.shape[:2], -1)).max(axis=-1) > 0 bg_mask = ~fg_mask fg_mask_erode = cv2.erode( fg_mask.astype(np.uint8), np.ones((r, r), np.uint8), iterations=1 ) bg_mask_erode = cv2.erode( bg_mask.astype(np.uint8), np.ones((r, r), np.uint8), iterations=1 ) out_mask = np.zeros_like(fg_mask, dtype=np.float32) out_mask[bg_mask_erode > 0] = -1 out_mask[fg_mask_erode > 0] = 1 return torch.from_numpy(out_mask).float() def load_depth(self, index) -> torch.Tensor: path = f"{self.depth_dir}/{self.frame_names[index]}.npy" disp = np.load(path) depth = 1.0 / np.clip(disp, a_min=1e-6, a_max=1e6) depth = torch.from_numpy(depth).float() depth = median_filter_2d(depth[None, None], 11, 1)[0, 0] return depth def load_target_tracks( self, query_index: int, target_indices: list[int], dim: int = 1 ): """ tracks are 2d, occs and uncertainties :param dim (int), default 1: dimension to stack the time axis return (N, T, 4) if dim=1, (T, N, 4) if dim=0 """ q_name = self.frame_names[query_index] all_tracks = [] for ti in target_indices: t_name = self.frame_names[ti] path = f"{self.tracks_dir}/{q_name}_{t_name}.npy" tracks = np.load(path).astype(np.float32) all_tracks.append(tracks) return torch.from_numpy(np.stack(all_tracks, axis=dim)) def get_tracks_3d( self, num_samples: int, start: int = 0, end: int = -1, step: int = 1, **kwargs ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: num_frames = self.num_frames if end < 0: end = num_frames + 1 + end query_idcs = list(range(start, end, step)) target_idcs = list(range(start, end, step)) masks = torch.stack([self.get_mask(i) for i in target_idcs], dim=0) fg_masks = (masks == 1).float() depths = torch.stack([self.get_depth(i) for i in target_idcs], dim=0) inv_Ks = torch.linalg.inv(self.Ks[target_idcs]) c2ws = torch.linalg.inv(self.w2cs[target_idcs]) num_per_query_frame = int(np.ceil(num_samples / len(query_idcs))) cur_num = 0 tracks_all_queries = [] for q_idx in query_idcs: # (N, T, 4) tracks_2d = self.load_target_tracks(q_idx, target_idcs) num_sel = int( min(num_per_query_frame, num_samples - cur_num, len(tracks_2d)) ) if num_sel < len(tracks_2d): sel_idcs = np.random.choice(len(tracks_2d), num_sel, replace=False) tracks_2d = tracks_2d[sel_idcs] cur_num += tracks_2d.shape[0] img = self.get_image(q_idx) tidx = target_idcs.index(q_idx) tracks_tuple = get_tracks_3d_for_query_frame( tidx, img, tracks_2d, depths, fg_masks, inv_Ks, c2ws ) tracks_all_queries.append(tracks_tuple) tracks_3d, colors, visibles, invisibles, confidences = map( partial(torch.cat, dim=0), zip(*tracks_all_queries) ) return tracks_3d, visibles, invisibles, confidences, colors def get_bkgd_points( self, num_samples: int, use_kf_tstamps: bool = True, stride: int = 8, down_rate: int = 8, min_per_frame: int = 64, **kwargs, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: start = 0 end = self.num_frames H, W = self.get_image(0).shape[:2] grid = torch.stack( torch.meshgrid( torch.arange(0, W, dtype=torch.float32), torch.arange(0, H, dtype=torch.float32), indexing="xy", ), dim=-1, ) if use_kf_tstamps: query_idcs = self.keyframe_idcs.tolist() else: num_query_frames = self.num_frames // stride query_endpts = torch.linspace(start, end, num_query_frames + 1) query_idcs = ((query_endpts[:-1] 
+ query_endpts[1:]) / 2).long().tolist() bg_geometry = [] print(f"{query_idcs=}") for query_idx in tqdm(query_idcs, desc="Loading bkgd points", leave=False): img = self.get_image(query_idx) depth = self.get_depth(query_idx) bg_mask = self.get_mask(query_idx) < 0 bool_mask = (bg_mask * (depth > 0)).to(torch.bool) w2c = self.w2cs[query_idx] K = self.Ks[query_idx] # get the bounding box of previous points that reproject into frame # inefficient but works for now bmax_x, bmax_y, bmin_x, bmin_y = 0, 0, W, H for p3d, _, _ in bg_geometry: if len(p3d) < 1: continue # reproject into current frame p2d = torch.einsum( "ij,jk,pk->pi", K, w2c[:3], F.pad(p3d, (0, 1), value=1.0) ) p2d = p2d[:, :2] / p2d[:, 2:].clamp(min=1e-6) xmin, xmax = p2d[:, 0].min().item(), p2d[:, 0].max().item() ymin, ymax = p2d[:, 1].min().item(), p2d[:, 1].max().item() bmin_x = min(bmin_x, int(xmin)) bmin_y = min(bmin_y, int(ymin)) bmax_x = max(bmax_x, int(xmax)) bmax_y = max(bmax_y, int(ymax)) # don't include points that are covered by previous points bmin_x = max(0, bmin_x) bmin_y = max(0, bmin_y) bmax_x = min(W, bmax_x) bmax_y = min(H, bmax_y) overlap_mask = torch.ones_like(bool_mask) overlap_mask[bmin_y:bmax_y, bmin_x:bmax_x] = 0 bool_mask &= overlap_mask if bool_mask.sum() < min_per_frame: guru.debug(f"skipping {query_idx=}") continue points = ( torch.einsum( "ij,pj->pi", torch.linalg.inv(K), F.pad(grid[bool_mask], (0, 1), value=1.0), ) * depth[bool_mask][:, None] ) points = torch.einsum( "ij,pj->pi", torch.linalg.inv(w2c)[:3], F.pad(points, (0, 1), value=1.0) ) point_normals = normal_from_depth_image(depth, K, w2c)[bool_mask] point_colors = img[bool_mask] num_sel = max(len(points) // down_rate, min_per_frame) sel_idcs = np.random.choice(len(points), num_sel, replace=False) points = points[sel_idcs] point_normals = point_normals[sel_idcs] point_colors = point_colors[sel_idcs] guru.debug(f"{query_idx=} {points.shape=}") bg_geometry.append((points, point_normals, point_colors)) bg_points, bg_normals, bg_colors = map( partial(torch.cat, dim=0), zip(*bg_geometry) ) if len(bg_points) > num_samples: sel_idcs = np.random.choice(len(bg_points), num_samples, replace=False) bg_points = bg_points[sel_idcs] bg_normals = bg_normals[sel_idcs] bg_colors = bg_colors[sel_idcs] return bg_points, bg_normals, bg_colors def __getitem__(self, index: int): index = np.random.randint(0, self.num_frames) data = { # (). "frame_names": self.frame_names[index], # (). "ts": torch.tensor(index), # (4, 4). "w2cs": self.w2cs[index], # (3, 3). "Ks": self.Ks[index], # (H, W, 3). "imgs": self.get_image(index), "depths": self.get_depth(index), } tri_mask = self.get_mask(index) valid_mask = tri_mask != 0 # not fg or bg mask = tri_mask == 1 # fg mask data["masks"] = mask.float() data["valid_masks"] = valid_mask.float() # (P, 2) query_tracks = self.load_target_tracks(index, [index])[:, 0, :2] target_inds = torch.from_numpy( np.random.choice( self.num_frames, (self.num_targets_per_frame,), replace=False ) ) # (N, P, 4) target_tracks = self.load_target_tracks(index, target_inds.tolist(), dim=0) data["query_tracks_2d"] = query_tracks data["target_ts"] = target_inds data["target_w2cs"] = self.w2cs[target_inds] data["target_Ks"] = self.Ks[target_inds] data["target_tracks_2d"] = target_tracks[..., :2] # (N, P). 
( data["target_visibles"], data["target_invisibles"], data["target_confidences"], ) = parse_tapir_track_info(target_tracks[..., 2], target_tracks[..., 3]) # (N, H, W) target_depths = torch.stack([self.get_depth(i) for i in target_inds], dim=0) H, W = target_depths.shape[-2:] data["target_track_depths"] = F.grid_sample( target_depths[:, None], normalize_coords(target_tracks[..., None, :2], H, W), align_corners=True, padding_mode="border", )[:, 0, :, 0] return data def load_cameras( path: str, H: int, W: int ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: assert os.path.exists(path), f"Camera file {path} does not exist." recon = np.load(path, allow_pickle=True).item() guru.debug(f"{recon.keys()=}") traj_c2w = recon["traj_c2w"] # (N, 4, 4) h, w = recon["img_shape"] sy, sx = H / h, W / w traj_w2c = np.linalg.inv(traj_c2w) fx, fy, cx, cy = recon["intrinsics"] # (4,) K = np.array([[fx * sx, 0, cx * sx], [0, fy * sy, cy * sy], [0, 0, 1]]) # (3, 3) Ks = np.tile(K[None, ...], (len(traj_c2w), 1, 1)) # (N, 3, 3) kf_tstamps = recon["tstamps"].astype("int") return ( torch.from_numpy(traj_w2c).float(), torch.from_numpy(Ks).float(), torch.from_numpy(kf_tstamps), ) def compute_scene_norm( X: torch.Tensor, w2cs: torch.Tensor ) -> tuple[float, torch.Tensor]: """ :param X: [N*T, 3] :param w2cs: [N, 4, 4] """ X = X.reshape(-1, 3) scene_center = X.mean(dim=0) X = X - scene_center[None] min_scale = X.quantile(0.05, dim=0) max_scale = X.quantile(0.95, dim=0) scale = (max_scale - min_scale).max().item() / 2.0 original_up = -F.normalize(w2cs[:, 1, :3].mean(0), dim=-1) target_up = original_up.new_tensor([0.0, 0.0, 1.0]) R = roma.rotvec_to_rotmat( F.normalize(original_up.cross(target_up), dim=-1) * original_up.dot(target_up).acos_() ) transfm = rt_to_mat4(R, torch.einsum("ij,j->i", -R, scene_center)) return scale, transfm if __name__ == "__main__": d = CasualDataset("bear", "/shared/vye/datasets/DAVIS", camera_type="droid_recon") ================================================ FILE: mvtracker/models/core/shape-of-motion/flow3d/data/colmap.py ================================================ import os import struct from dataclasses import dataclass from pathlib import Path from typing import Dict, Union import numpy as np def get_colmap_camera_params(colmap_dir, img_files): cameras = read_cameras_binary(colmap_dir + "/cameras.bin") images = read_images_binary(colmap_dir + "/images.bin") colmap_image_idcs = {v.name: k for k, v in images.items()} img_names = [os.path.basename(img_file) for img_file in img_files] num_imgs = len(img_names) K_all = np.zeros((num_imgs, 4, 4)) extrinsics_all = np.zeros((num_imgs, 4, 4)) for idx, name in enumerate(img_names): key = colmap_image_idcs[name] image = images[key] assert image.name == name K, extrinsics = get_intrinsics_extrinsics(image, cameras) K_all[idx] = K extrinsics_all[idx] = extrinsics return K_all, extrinsics_all @dataclass(frozen=True) class CameraModel: model_id: int model_name: str num_params: int @dataclass(frozen=True) class Camera: id: int model: str width: int height: int params: np.ndarray @dataclass(frozen=True) class BaseImage: id: int qvec: np.ndarray tvec: np.ndarray camera_id: int name: str xys: np.ndarray point3D_ids: np.ndarray @dataclass(frozen=True) class Point3D: id: int xyz: np.ndarray rgb: np.ndarray error: Union[float, np.ndarray] image_ids: np.ndarray point2D_idxs: np.ndarray class Image(BaseImage): def qvec2rotmat(self): return qvec2rotmat(self.qvec) CAMERA_MODELS = { CameraModel(model_id=0, model_name="SIMPLE_PINHOLE", num_params=3), 
CameraModel(model_id=1, model_name="PINHOLE", num_params=4), CameraModel(model_id=2, model_name="SIMPLE_RADIAL", num_params=4), CameraModel(model_id=3, model_name="RADIAL", num_params=5), CameraModel(model_id=4, model_name="OPENCV", num_params=8), CameraModel(model_id=5, model_name="OPENCV_FISHEYE", num_params=8), CameraModel(model_id=6, model_name="FULL_OPENCV", num_params=12), CameraModel(model_id=7, model_name="FOV", num_params=5), CameraModel(model_id=8, model_name="SIMPLE_RADIAL_FISHEYE", num_params=4), CameraModel(model_id=9, model_name="RADIAL_FISHEYE", num_params=5), CameraModel(model_id=10, model_name="THIN_PRISM_FISHEYE", num_params=12), } CAMERA_MODEL_IDS = dict( [(camera_model.model_id, camera_model) for camera_model in CAMERA_MODELS] ) def read_next_bytes(fid, num_bytes, format_char_sequence, endian_character="<"): """Read and unpack the next bytes from a binary file. :param fid: :param num_bytes: Sum of combination of {2, 4, 8}, e.g. 2, 6, 16, 30, etc. :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}. :param endian_character: Any of {@, =, <, >, !} :return: Tuple of read and unpacked values. """ data = fid.read(num_bytes) return struct.unpack(endian_character + format_char_sequence, data) def read_cameras_text(path: Union[str, Path]) -> Dict[int, Camera]: """ see: src/base/reconstruction.cc void Reconstruction::WriteCamerasText(const std::string& path) void Reconstruction::ReadCamerasText(const std::string& path) """ cameras = {} with open(path, "r") as fid: while True: line = fid.readline() if not line: break line = line.strip() if len(line) > 0 and line[0] != "#": elems = line.split() camera_id = int(elems[0]) model = elems[1] width = int(elems[2]) height = int(elems[3]) params = np.array(tuple(map(float, elems[4:]))) cameras[camera_id] = Camera( id=camera_id, model=model, width=width, height=height, params=params ) return cameras def read_cameras_binary(path_to_model_file: Union[str, Path]) -> Dict[int, Camera]: """ see: src/base/reconstruction.cc void Reconstruction::WriteCamerasBinary(const std::string& path) void Reconstruction::ReadCamerasBinary(const std::string& path) """ cameras = {} with open(path_to_model_file, "rb") as fid: num_cameras = read_next_bytes(fid, 8, "Q")[0] for camera_line_index in range(num_cameras): camera_properties = read_next_bytes( fid, num_bytes=24, format_char_sequence="iiQQ" ) camera_id = camera_properties[0] model_id = camera_properties[1] model_name = CAMERA_MODEL_IDS[camera_properties[1]].model_name width = camera_properties[2] height = camera_properties[3] num_params = CAMERA_MODEL_IDS[model_id].num_params params = read_next_bytes( fid, num_bytes=8 * num_params, format_char_sequence="d" * num_params ) cameras[camera_id] = Camera( id=camera_id, model=model_name, width=width, height=height, params=np.array(params), ) assert len(cameras) == num_cameras return cameras def read_images_text(path: Union[str, Path]) -> Dict[int, Image]: """ see: src/base/reconstruction.cc void Reconstruction::ReadImagesText(const std::string& path) void Reconstruction::WriteImagesText(const std::string& path) """ images = {} with open(path, "r") as fid: while True: line = fid.readline() if not line: break line = line.strip() if len(line) > 0 and line[0] != "#": elems = line.split() image_id = int(elems[0]) qvec = np.array(tuple(map(float, elems[1:5]))) tvec = np.array(tuple(map(float, elems[5:8]))) camera_id = int(elems[8]) image_name = elems[9] elems = fid.readline().split() xys = np.column_stack( [tuple(map(float, elems[0::3])), 
tuple(map(float, elems[1::3]))] ) point3D_ids = np.array(tuple(map(int, elems[2::3]))) images[image_id] = Image( id=image_id, qvec=qvec, tvec=tvec, camera_id=camera_id, name=image_name, xys=xys, point3D_ids=point3D_ids, ) return images def read_images_binary(path_to_model_file: Union[str, Path]) -> Dict[int, Image]: """ see: src/base/reconstruction.cc void Reconstruction::ReadImagesBinary(const std::string& path) void Reconstruction::WriteImagesBinary(const std::string& path) """ images = {} with open(path_to_model_file, "rb") as fid: num_reg_images = read_next_bytes(fid, 8, "Q")[0] for image_index in range(num_reg_images): binary_image_properties = read_next_bytes( fid, num_bytes=64, format_char_sequence="idddddddi" ) image_id = binary_image_properties[0] qvec = np.array(binary_image_properties[1:5]) tvec = np.array(binary_image_properties[5:8]) camera_id = binary_image_properties[8] image_name = "" current_char = read_next_bytes(fid, 1, "c")[0] while current_char != b"\x00": # look for the ASCII 0 entry image_name += current_char.decode("utf-8") current_char = read_next_bytes(fid, 1, "c")[0] num_points2D = read_next_bytes(fid, num_bytes=8, format_char_sequence="Q")[ 0 ] x_y_id_s = read_next_bytes( fid, num_bytes=24 * num_points2D, format_char_sequence="ddq" * num_points2D, ) xys = np.column_stack( [tuple(map(float, x_y_id_s[0::3])), tuple(map(float, x_y_id_s[1::3]))] ) point3D_ids = np.array(tuple(map(int, x_y_id_s[2::3]))) images[image_id] = Image( id=image_id, qvec=qvec, tvec=tvec, camera_id=camera_id, name=image_name, xys=xys, point3D_ids=point3D_ids, ) return images def read_points3D_text(path: Union[str, Path]): """ see: src/base/reconstruction.cc void Reconstruction::ReadPoints3DText(const std::string& path) void Reconstruction::WritePoints3DText(const std::string& path) """ points3D = {} with open(path, "r") as fid: while True: line = fid.readline() if not line: break line = line.strip() if len(line) > 0 and line[0] != "#": elems = line.split() point3D_id = int(elems[0]) xyz = np.array(tuple(map(float, elems[1:4]))) rgb = np.array(tuple(map(int, elems[4:7]))) error = float(elems[7]) image_ids = np.array(tuple(map(int, elems[8::2]))) point2D_idxs = np.array(tuple(map(int, elems[9::2]))) points3D[point3D_id] = Point3D( id=point3D_id, xyz=xyz, rgb=rgb, error=error, image_ids=image_ids, point2D_idxs=point2D_idxs, ) return points3D def read_points3d_binary(path_to_model_file: Union[str, Path]) -> Dict[int, Point3D]: """ see: src/base/reconstruction.cc void Reconstruction::ReadPoints3DBinary(const std::string& path) void Reconstruction::WritePoints3DBinary(const std::string& path) """ points3D = {} with open(path_to_model_file, "rb") as fid: num_points = read_next_bytes(fid, 8, "Q")[0] for point_line_index in range(num_points): binary_point_line_properties = read_next_bytes( fid, num_bytes=43, format_char_sequence="QdddBBBd" ) point3D_id = binary_point_line_properties[0] xyz = np.array(binary_point_line_properties[1:4]) rgb = np.array(binary_point_line_properties[4:7]) error = np.array(binary_point_line_properties[7]) track_length = read_next_bytes(fid, num_bytes=8, format_char_sequence="Q")[ 0 ] track_elems = read_next_bytes( fid, num_bytes=8 * track_length, format_char_sequence="ii" * track_length, ) image_ids = np.array(tuple(map(int, track_elems[0::2]))) point2D_idxs = np.array(tuple(map(int, track_elems[1::2]))) points3D[point3D_id] = Point3D( id=point3D_id, xyz=xyz, rgb=rgb, error=error, image_ids=image_ids, point2D_idxs=point2D_idxs, ) return points3D def qvec2rotmat(qvec): 
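    """Convert a COLMAP unit quaternion to a 3x3 rotation matrix.

    COLMAP stores rotations scalar-first, i.e. qvec = (qw, qx, qy, qz); the
    resulting R, together with tvec, defines the world-to-camera transform
    assembled in get_intrinsics_extrinsics below.
    """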
return np.array( [ [ 1 - 2 * qvec[2] ** 2 - 2 * qvec[3] ** 2, 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3], 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2], ], [ 2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3], 1 - 2 * qvec[1] ** 2 - 2 * qvec[3] ** 2, 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1], ], [ 2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2], 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1], 1 - 2 * qvec[1] ** 2 - 2 * qvec[2] ** 2, ], ] ) def get_intrinsics_extrinsics(img, cameras): # world to cam transformation R = qvec2rotmat(img.qvec) # translation t = img.tvec cam = cameras[img.camera_id] if cam.model in ("SIMPLE_PINHOLE", "SIMPLE_RADIAL", "RADIAL"): fx = fy = cam.params[0] cx = cam.params[1] cy = cam.params[2] elif cam.model in ( "PINHOLE", "OPENCV", "OPENCV_FISHEYE", "FULL_OPENCV", ): fx = cam.params[0] fy = cam.params[1] cx = cam.params[2] cy = cam.params[3] else: raise Exception("Camera model not supported") # intrinsics K = np.identity(4) K[0, 0] = fx K[1, 1] = fy K[0, 2] = cx K[1, 2] = cy extrinsics = np.eye(4) extrinsics[:3, :3] = R extrinsics[:3, 3] = t return K, extrinsics ================================================ FILE: mvtracker/models/core/shape-of-motion/flow3d/data/iphone_dataset.py ================================================ import json import os import os.path as osp from dataclasses import dataclass from glob import glob from itertools import product from typing import Literal import imageio.v3 as iio import numpy as np import roma import torch import torch.nn.functional as F import tyro from loguru import logger as guru from torch.utils.data import Dataset from tqdm import tqdm from flow3d.data.base_dataset import BaseDataset from flow3d.data.colmap import get_colmap_camera_params from flow3d.data.utils import ( SceneNormDict, masked_median_blur, normal_from_depth_image, normalize_coords, parse_tapir_track_info, ) from flow3d.transforms import rt_to_mat4 @dataclass class iPhoneDataConfig: data_dir: str start: int = 0 end: int = -1 split: Literal["train", "val"] = "train" depth_type: Literal[ "midas", "depth_anything", "lidar", "depth_anything_colmap", ] = "depth_anything_colmap" camera_type: Literal["original", "refined"] = "refined" use_median_filter: bool = False num_targets_per_frame: int = 4 scene_norm_dict: tyro.conf.Suppress[SceneNormDict | None] = None load_from_cache: bool = False skip_load_imgs: bool = False class iPhoneDataset(BaseDataset): def __init__( self, data_dir: str, start: int = 0, end: int = -1, factor: int = 1, split: Literal["train", "val"] = "train", depth_type: Literal[ "midas", "depth_anything", "lidar", "depth_anything_colmap", ] = "depth_anything_colmap", camera_type: Literal["original", "refined"] = "refined", use_median_filter: bool = False, num_targets_per_frame: int = 1, scene_norm_dict: SceneNormDict | None = None, load_from_cache: bool = False, skip_load_imgs: bool = False, **_, ): super().__init__() print(skip_load_imgs) self.data_dir = data_dir self.training = split == "train" self.split = split self.factor = factor self.start = start self.end = end self.depth_type = depth_type self.camera_type = camera_type self.use_median_filter = use_median_filter self.num_targets_per_frame = num_targets_per_frame self.scene_norm_dict = scene_norm_dict self.load_from_cache = load_from_cache self.cache_dir = osp.join(data_dir, "flow3d_preprocessed", "cache") os.makedirs(self.cache_dir, exist_ok=True) # Test if the current data has validation set. 
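        # (An empty "frame_names" list in splits/val.json means there are no
        # held-out validation frames for this capture.)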
with open(osp.join(data_dir, "splits", "val.json")) as f: split_dict = json.load(f) self.has_validation = len(split_dict["frame_names"]) > 0 # Load metadata. with open(osp.join(data_dir, "splits", f"{split}.json")) as f: split_dict = json.load(f) full_len = len(split_dict["frame_names"]) end = min(end, full_len) if end > 0 else full_len self.end = end self.frame_names = split_dict["frame_names"][start:end] time_ids = [t for t in split_dict["time_ids"] if t >= start and t < end] self.time_ids = torch.tensor(time_ids) - start guru.info(f"{self.time_ids.min()=} {self.time_ids.max()=}") # with open(osp.join(data_dir, "dataset.json")) as f: # dataset_dict = json.load(f) # self.num_frames = dataset_dict["num_exemplars"] guru.info(f"{self.num_frames=}") with open(osp.join(data_dir, "extra.json")) as f: extra_dict = json.load(f) self.fps = float(extra_dict["fps"]) # Load cameras. if self.camera_type == "original": Ks, w2cs = [], [] for frame_name in self.frame_names: with open(osp.join(data_dir, "camera", f"{frame_name}.json")) as f: camera_dict = json.load(f) focal_length = camera_dict["focal_length"] principal_point = camera_dict["principal_point"] Ks.append( [ [focal_length, 0.0, principal_point[0]], [0.0, focal_length, principal_point[1]], [0.0, 0.0, 1.0], ] ) orientation = np.array(camera_dict["orientation"]) position = np.array(camera_dict["position"]) w2cs.append( np.block( [ [orientation, -orientation @ position[:, None]], [np.zeros((1, 3)), np.ones((1, 1))], ] ).astype(np.float32) ) self.Ks = torch.tensor(Ks) self.Ks[:, :2] /= factor self.w2cs = torch.from_numpy(np.array(w2cs)) elif self.camera_type == "refined": Ks, w2cs = get_colmap_camera_params( osp.join(data_dir, "flow3d_preprocessed/colmap/sparse/"), [frame_name + ".png" for frame_name in self.frame_names], ) self.Ks = torch.from_numpy(Ks[:, :3, :3].astype(np.float32)) self.Ks[:, :2] /= factor self.w2cs = torch.from_numpy(w2cs.astype(np.float32)) if not skip_load_imgs: # Load images. imgs = torch.from_numpy( np.array( [ iio.imread( osp.join(self.data_dir, f"rgb/{factor}x/{frame_name}.png") ) for frame_name in tqdm( self.frame_names, desc=f"Loading {self.split} images", leave=False, ) ], ) ) self.imgs = imgs[..., :3] / 255.0 self.valid_masks = imgs[..., 3] / 255.0 # Load masks. self.masks = ( torch.from_numpy( np.array( [ iio.imread( osp.join( self.data_dir, "flow3d_preprocessed/track_anything/", f"{factor}x/{frame_name}.png", ) ) for frame_name in tqdm( self.frame_names, desc=f"Loading {self.split} masks", leave=False, ) ], ) ) / 255.0 ) if self.training: # Load depths. def load_depth(frame_name): if self.depth_type == "lidar": depth = np.load( osp.join( self.data_dir, f"depth/{factor}x/{frame_name}.npy", ) )[..., 0] else: depth = np.load( osp.join( self.data_dir, f"flow3d_preprocessed/aligned_{self.depth_type}/", f"{factor}x/{frame_name}.npy", ) ) depth[depth < 1e-3] = 1e-3 depth = 1.0 / depth return depth self.depths = torch.from_numpy( np.array( [ load_depth(frame_name) for frame_name in tqdm( self.frame_names, desc=f"Loading {self.split} depths", leave=False, ) ], np.float32, ) ) max_depth_values_per_frame = self.depths.reshape( self.num_frames, -1 ).max(1)[0] max_depth_value = max_depth_values_per_frame.median() * 2.5 print("max_depth_value", max_depth_value) self.depths = torch.clamp(self.depths, 0, max_depth_value) # Median filter depths. # NOTE(hangg): This operator is very expensive. 
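            # It median-blurs each depth map on the GPU inside the
            # (foreground * valid * depth > 0) mask and writes the filtered
            # values back only where the foreground mask is set, leaving the
            # background depth untouched. Disabled by default
            # (use_median_filter=False).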
if self.use_median_filter: for i in tqdm( range(self.num_frames), desc="Processing depths", leave=False ): depth = masked_median_blur( self.depths[[i]].unsqueeze(1).to("cuda"), ( self.masks[[i]] * self.valid_masks[[i]] * (self.depths[[i]] > 0) ) .unsqueeze(1) .to("cuda"), )[0, 0].cpu() self.depths[i] = depth * self.masks[i] + self.depths[i] * ( 1 - self.masks[i] ) # Load the query pixels from 2D tracks. self.query_tracks_2d = [ torch.from_numpy( np.load( osp.join( self.data_dir, "flow3d_preprocessed/2d_tracks/", f"{factor}x/{frame_name}_{frame_name}.npy", ) ).astype(np.float32) ) for frame_name in self.frame_names ] guru.info( f"{len(self.query_tracks_2d)=} {self.query_tracks_2d[0].shape=}" ) # Load sam features. # sam_feat_dir = osp.join( # data_dir, f"flow3d_preprocessed/sam_features/{factor}x" # ) # assert osp.exists(sam_feat_dir), f"SAM features not exist!" # sam_features, original_size, input_size = load_sam_features( # sam_feat_dir, self.frame_names # ) # guru.info(f"{sam_features.shape=} {original_size=} {input_size=}") # self.sam_features = sam_features # self.sam_original_size = original_size # self.sam_input_size = input_size else: # Load covisible masks. self.covisible_masks = ( torch.from_numpy( np.array( [ iio.imread( osp.join( self.data_dir, "flow3d_preprocessed/covisible/", f"{factor}x/{split}/{frame_name}.png", ) ) for frame_name in tqdm( self.frame_names, desc=f"Loading {self.split} covisible masks", leave=False, ) ], ) ) / 255.0 ) if self.scene_norm_dict is None: cached_scene_norm_dict_path = osp.join( self.cache_dir, "scene_norm_dict.pth" ) if osp.exists(cached_scene_norm_dict_path) and self.load_from_cache: print("loading cached scene norm dict...") self.scene_norm_dict = torch.load( osp.join(self.cache_dir, "scene_norm_dict.pth") ) elif self.training: # Compute the scene scale and transform for normalization. # Normalize the scene based on the foreground 3D tracks. subsampled_tracks_3d = self.get_tracks_3d( num_samples=10000, step=self.num_frames // 10, show_pbar=False )[0] scene_center = subsampled_tracks_3d.mean((0, 1)) tracks_3d_centered = subsampled_tracks_3d - scene_center min_scale = tracks_3d_centered.quantile(0.05, dim=0) max_scale = tracks_3d_centered.quantile(0.95, dim=0) scale = torch.max(max_scale - min_scale).item() / 2.0 original_up = -F.normalize(self.w2cs[:, 1, :3].mean(0), dim=-1) target_up = original_up.new_tensor([0.0, 0.0, 1.0]) R = roma.rotvec_to_rotmat( F.normalize(original_up.cross(target_up, dim=-1), dim=-1) * original_up.dot(target_up).acos_() ) transfm = rt_to_mat4(R, torch.einsum("ij,j->i", -R, scene_center)) self.scene_norm_dict = SceneNormDict(scale=scale, transfm=transfm) torch.save(self.scene_norm_dict, cached_scene_norm_dict_path) else: raise ValueError("scene_norm_dict must be provided for validation.") # Normalize the scene. 
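# Editorial sketch of how the scale/transform computed above are applied (this mirrors the code that
# follows; transfm is a 4x4 rigid transform, scale a scalar):
#
#     w2c_norm = w2c @ inv(transfm)    # cameras follow the inverse of the world transform
#     w2c_norm[:3, 3] /= scale         # camera translations are rescaled
#     depth_norm = depth / scale       # depths rescale together with the scene
#
# so a world point p maps to (transfm @ [p, 1])[:3] / scale in the normalized frame.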
scale = self.scene_norm_dict["scale"] transfm = self.scene_norm_dict["transfm"] self.w2cs = self.w2cs @ torch.linalg.inv(transfm) self.w2cs[:, :3, 3] /= scale if self.training and not skip_load_imgs: self.depths /= scale if not skip_load_imgs: guru.info( f"{self.imgs.shape=} {self.valid_masks.shape=} {self.masks.shape=}" ) @property def num_frames(self) -> int: return len(self.frame_names) def __len__(self): return self.imgs.shape[0] def get_w2cs(self) -> torch.Tensor: return self.w2cs def get_Ks(self) -> torch.Tensor: return self.Ks def get_image(self, index: int) -> torch.Tensor: return self.imgs[index] def get_depth(self, index: int) -> torch.Tensor: return self.depths[index] def get_mask(self, index: int) -> torch.Tensor: return self.masks[index] def get_img_wh(self) -> tuple[int, int]: return iio.imread( osp.join(self.data_dir, f"rgb/{self.factor}x/{self.frame_names[0]}.png") ).shape[1::-1] # def get_sam_features(self) -> list[torch.Tensor, tuple[int, int], tuple[int, int]]: # return self.sam_features, self.sam_original_size, self.sam_input_size def get_tracks_3d( self, num_samples: int, step: int = 1, show_pbar: bool = True, **kwargs ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """Get 3D tracks from the dataset. Args: num_samples (int | None): The number of samples to fetch. If None, fetch all samples. If not None, fetch roughly a same number of samples across each frame. Note that this might result in number of samples less than what is specified. step (int): The step to temporally subsample the track. """ assert ( self.split == "train" ), "fetch_tracks_3d is only available for the training split." cached_track_3d_path = osp.join(self.cache_dir, f"tracks_3d_{num_samples}.pth") if osp.exists(cached_track_3d_path) and step == 1 and self.load_from_cache: print("loading cached 3d tracks data...") start, end = self.start, self.end cached_track_3d_data = torch.load(cached_track_3d_path) tracks_3d, visibles, invisibles, confidences, track_colors = ( cached_track_3d_data["tracks_3d"][:, start:end], cached_track_3d_data["visibles"][:, start:end], cached_track_3d_data["invisibles"][:, start:end], cached_track_3d_data["confidences"][:, start:end], cached_track_3d_data["track_colors"], ) return tracks_3d, visibles, invisibles, confidences, track_colors # Load 2D tracks. raw_tracks_2d = [] candidate_frames = list(range(0, self.num_frames, step)) num_sampled_frames = len(candidate_frames) for i in ( tqdm(candidate_frames, desc="Loading 2D tracks", leave=False) if show_pbar else candidate_frames ): curr_num_samples = self.query_tracks_2d[i].shape[0] num_samples_per_frame = ( int(np.floor(num_samples / num_sampled_frames)) if i != candidate_frames[-1] else num_samples - (num_sampled_frames - 1) * int(np.floor(num_samples / num_sampled_frames)) ) if num_samples_per_frame < curr_num_samples: track_sels = np.random.choice( curr_num_samples, (num_samples_per_frame,), replace=False ) else: track_sels = np.arange(0, curr_num_samples) curr_tracks_2d = [] for j in range(0, self.num_frames, step): if i == j: target_tracks_2d = self.query_tracks_2d[i] else: target_tracks_2d = torch.from_numpy( np.load( osp.join( self.data_dir, "flow3d_preprocessed/2d_tracks/", f"{self.factor}x/" f"{self.frame_names[i]}_" f"{self.frame_names[j]}.npy", ) ).astype(np.float32) ) curr_tracks_2d.append(target_tracks_2d[track_sels]) raw_tracks_2d.append(torch.stack(curr_tracks_2d, dim=1)) guru.info(f"{step=} {len(raw_tracks_2d)=} {raw_tracks_2d[0].shape=}") # Process 3D tracks. 
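# Editorial note: each 2D track array loaded above appears to carry 4 channels per point and frame,
# (x, y, occlusion_logit, expected_dist_logit), as produced by TAPIR-style trackers. The processing
# below splits the channels and converts the last two into visibility/confidence via
# parse_tapir_track_info (defined in flow3d/data/utils.py), roughly:
#
#     visibility = 1 - sigmoid(occlusion)
#     confidence = 1 - sigmoid(expected_dist)
#     visible    = visibility * confidence > 0.5
#     invisible  = (1 - visibility) * confidence > 0.5
#     confidence = 0 where neither holds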
inv_Ks = torch.linalg.inv(self.Ks)[::step] c2ws = torch.linalg.inv(self.w2cs)[::step] H, W = self.imgs.shape[1:3] filtered_tracks_3d, filtered_visibles, filtered_track_colors = [], [], [] filtered_invisibles, filtered_confidences = [], [] masks = self.masks * self.valid_masks * (self.depths > 0) masks = (masks > 0.5).float() for i, tracks_2d in enumerate(raw_tracks_2d): tracks_2d = tracks_2d.swapdims(0, 1) tracks_2d, occs, dists = ( tracks_2d[..., :2], tracks_2d[..., 2], tracks_2d[..., 3], ) # visibles = postprocess_occlusions(occs, dists) visibles, invisibles, confidences = parse_tapir_track_info(occs, dists) # Unproject 2D tracks to 3D. track_depths = F.grid_sample( self.depths[::step, None], normalize_coords(tracks_2d[..., None, :], H, W), align_corners=True, padding_mode="border", )[:, 0] tracks_3d = ( torch.einsum( "nij,npj->npi", inv_Ks, F.pad(tracks_2d, (0, 1), value=1.0), ) * track_depths ) tracks_3d = torch.einsum( "nij,npj->npi", c2ws, F.pad(tracks_3d, (0, 1), value=1.0) )[..., :3] # Filter out out-of-mask tracks. is_in_masks = ( F.grid_sample( masks[::step, None], normalize_coords(tracks_2d[..., None, :], H, W), align_corners=True, ).squeeze() == 1 ) visibles *= is_in_masks invisibles *= is_in_masks confidences *= is_in_masks.float() # Get track's color from the query frame. track_colors = ( F.grid_sample( self.imgs[i * step : i * step + 1].permute(0, 3, 1, 2), normalize_coords(tracks_2d[i : i + 1, None, :], H, W), align_corners=True, padding_mode="border", ) .squeeze() .T ) # at least visible 5% of the time, otherwise discard visible_counts = visibles.sum(0) valid = visible_counts >= min( int(0.05 * self.num_frames), visible_counts.float().quantile(0.1).item(), ) filtered_tracks_3d.append(tracks_3d[:, valid]) filtered_visibles.append(visibles[:, valid]) filtered_invisibles.append(invisibles[:, valid]) filtered_confidences.append(confidences[:, valid]) filtered_track_colors.append(track_colors[valid]) filtered_tracks_3d = torch.cat(filtered_tracks_3d, dim=1).swapdims(0, 1) filtered_visibles = torch.cat(filtered_visibles, dim=1).swapdims(0, 1) filtered_invisibles = torch.cat(filtered_invisibles, dim=1).swapdims(0, 1) filtered_confidences = torch.cat(filtered_confidences, dim=1).swapdims(0, 1) filtered_track_colors = torch.cat(filtered_track_colors, dim=0) if step == 1: torch.save( { "tracks_3d": filtered_tracks_3d, "visibles": filtered_visibles, "invisibles": filtered_invisibles, "confidences": filtered_confidences, "track_colors": filtered_track_colors, }, cached_track_3d_path, ) return ( filtered_tracks_3d, filtered_visibles, filtered_invisibles, filtered_confidences, filtered_track_colors, ) def get_bkgd_points( self, num_samples: int, **kwargs ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: H, W = self.imgs.shape[1:3] grid = torch.stack( torch.meshgrid( torch.arange(W, dtype=torch.float32), torch.arange(H, dtype=torch.float32), indexing="xy", ), dim=-1, ) candidate_frames = list(range(self.num_frames)) num_sampled_frames = len(candidate_frames) bkgd_points, bkgd_point_normals, bkgd_point_colors = [], [], [] for i in tqdm(candidate_frames, desc="Loading bkgd points", leave=False): img = self.imgs[i] depth = self.depths[i] bool_mask = ((1.0 - self.masks[i]) * self.valid_masks[i] * (depth > 0)).to( torch.bool ) w2c = self.w2cs[i] K = self.Ks[i] points = ( torch.einsum( "ij,pj->pi", torch.linalg.inv(K), F.pad(grid[bool_mask], (0, 1), value=1.0), ) * depth[bool_mask][:, None] ) points = torch.einsum( "ij,pj->pi", torch.linalg.inv(w2c)[:3], F.pad(points, (0, 1), value=1.0) ) 
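# Editorial sketch of the unprojection pattern used above (for 2D tracks and background pixels alike):
# a pixel (u, v) with depth d is lifted to camera space via the inverse intrinsics and then to world
# space via the camera-to-world matrix. For a single point, in plain numpy:
#
#     import numpy as np
#     x_cam = np.linalg.inv(K) @ np.array([u, v, 1.0]) * d     # camera-space point
#     x_world = (c2w @ np.append(x_cam, 1.0))[:3]              # world-space point
#
# The einsum calls above do the same thing batched over frames and points/pixels.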
point_normals = normal_from_depth_image(depth, K, w2c)[bool_mask] point_colors = img[bool_mask] curr_num_samples = points.shape[0] num_samples_per_frame = ( int(np.floor(num_samples / num_sampled_frames)) if i != candidate_frames[-1] else num_samples - (num_sampled_frames - 1) * int(np.floor(num_samples / num_sampled_frames)) ) if num_samples_per_frame < curr_num_samples: point_sels = np.random.choice( curr_num_samples, (num_samples_per_frame,), replace=False ) else: point_sels = np.arange(0, curr_num_samples) bkgd_points.append(points[point_sels]) bkgd_point_normals.append(point_normals[point_sels]) bkgd_point_colors.append(point_colors[point_sels]) bkgd_points = torch.cat(bkgd_points, dim=0) bkgd_point_normals = torch.cat(bkgd_point_normals, dim=0) bkgd_point_colors = torch.cat(bkgd_point_colors, dim=0) return bkgd_points, bkgd_point_normals, bkgd_point_colors def get_video_dataset(self) -> Dataset: return iPhoneDatasetVideoView(self) def __getitem__(self, index: int): if self.training: index = np.random.randint(0, self.num_frames) data = { # (). "frame_names": self.frame_names[index], # (). "ts": self.time_ids[index], # (4, 4). "w2cs": self.w2cs[index], # (3, 3). "Ks": self.Ks[index], # (H, W, 3). "imgs": self.imgs[index], # (H, W). "valid_masks": self.valid_masks[index], # (H, W). "masks": self.masks[index], } if self.training: # (H, W). data["depths"] = self.depths[index] # (P, 2). data["query_tracks_2d"] = self.query_tracks_2d[index][:, :2] target_inds = torch.from_numpy( np.random.choice( self.num_frames, (self.num_targets_per_frame,), replace=False ) ) # (N, P, 4). target_tracks_2d = torch.stack( [ torch.from_numpy( np.load( osp.join( self.data_dir, "flow3d_preprocessed/2d_tracks/", f"{self.factor}x/" f"{self.frame_names[index]}_" f"{self.frame_names[target_index.item()]}.npy", ) ).astype(np.float32) ) for target_index in target_inds ], dim=0, ) # (N,). target_ts = self.time_ids[target_inds] data["target_ts"] = target_ts # (N, 4, 4). data["target_w2cs"] = self.w2cs[target_ts] # (N, 3, 3). data["target_Ks"] = self.Ks[target_ts] # (N, P, 2). data["target_tracks_2d"] = target_tracks_2d[..., :2] # (N, P). ( data["target_visibles"], data["target_invisibles"], data["target_confidences"], ) = parse_tapir_track_info( target_tracks_2d[..., 2], target_tracks_2d[..., 3] ) # (N, P). data["target_track_depths"] = F.grid_sample( self.depths[target_inds, None], normalize_coords( target_tracks_2d[..., None, :2], self.imgs.shape[1], self.imgs.shape[2], ), align_corners=True, padding_mode="border", )[:, 0, :, 0] else: # (H, W). data["covisible_masks"] = self.covisible_masks[index] return data def preprocess(self, data): return data class iPhoneDatasetKeypointView(Dataset): """Return a dataset view of the annotated keypoints.""" def __init__(self, dataset: iPhoneDataset): super().__init__() self.dataset = dataset assert self.dataset.split == "train" # Load 2D keypoints. keypoint_paths = sorted( glob(osp.join(self.dataset.data_dir, "keypoint/2x/train/0_*.json")) ) keypoints = [] for keypoint_path in keypoint_paths: with open(keypoint_path) as f: keypoints.append(json.load(f)) time_ids = [ int(osp.basename(p).split("_")[1].split(".")[0]) for p in keypoint_paths ] # only use time ids that are in the dataset. 
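# Editorial note: the keypoint view evaluates all ordered pairs of annotated time ids, built with
# itertools.product below. For example:
#
#     from itertools import product
#     list(product([0, 50, 100], repeat=2))
#     # -> [(0, 0), (0, 50), (0, 100), (50, 0), (50, 50), (50, 100), (100, 0), (100, 50), (100, 100)]
#
# The keypoints themselves appear to be annotated at the 2x downsampled resolution (note the
# "keypoint/2x" path), hence the 2.0 / factor rescaling of their xy coordinates.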
start = self.dataset.start time_ids = [t - start for t in time_ids if t - start in self.dataset.time_ids] self.time_ids = torch.tensor(time_ids) self.time_pairs = torch.tensor(list(product(self.time_ids, repeat=2))) self.index_pairs = torch.tensor( list(product(range(len(self.time_ids)), repeat=2)) ) self.keypoints = torch.tensor(keypoints, dtype=torch.float32) self.keypoints[..., :2] *= 2.0 / self.dataset.factor def __len__(self): return len(self.time_pairs) def __getitem__(self, index: int): ts = self.time_pairs[index] return { "ts": ts, "w2cs": self.dataset.w2cs[ts], "Ks": self.dataset.Ks[ts], "imgs": self.dataset.imgs[ts], "keypoints": self.keypoints[self.index_pairs[index]], } class iPhoneDatasetVideoView(Dataset): """Return a dataset view of the video trajectory.""" def __init__(self, dataset: iPhoneDataset): super().__init__() self.dataset = dataset self.fps = self.dataset.fps assert self.dataset.split == "train" def __len__(self): return self.dataset.num_frames def __getitem__(self, index): return { "frame_names": self.dataset.frame_names[index], "ts": index, "w2cs": self.dataset.w2cs[index], "Ks": self.dataset.Ks[index], "imgs": self.dataset.imgs[index], "depths": self.dataset.depths[index], "masks": self.dataset.masks[index], } """ class iPhoneDataModule(BaseDataModule[iPhoneDataset]): def __init__( self, data_dir: str, factor: int = 1, start: int = 0, end: int = -1, depth_type: Literal[ "midas", "depth_anything", "lidar", "depth_anything_colmap", ] = "depth_anything_colmap", camera_type: Literal["original", "refined"] = "refined", use_median_filter: bool = False, num_targets_per_frame: int = 1, load_from_cache: bool = False, **kwargs, ): super().__init__(dataset_cls=iPhoneDataset, **kwargs) self.data_dir = data_dir self.start = start self.end = end self.factor = factor self.depth_type = depth_type self.camera_type = camera_type self.use_median_filter = use_median_filter self.num_targets_per_frame = num_targets_per_frame self.load_from_cache = load_from_cache self.val_loader_tasks = ["img", "keypoint"] def setup(self, *_, **__) -> None: guru.info("Loading train dataset...") self.train_dataset = self.dataset_cls( data_dir=self.data_dir, training=True, split="train", start=self.start, end=self.end, factor=self.factor, depth_type=self.depth_type, # type: ignore camera_type=self.camera_type, # type: ignore use_median_filter=self.use_median_filter, num_targets_per_frame=self.num_targets_per_frame, max_steps=self.max_steps * self.batch_size, load_from_cache=self.load_from_cache, ) if self.train_dataset.has_validation: guru.info("Loading val dataset...") self.val_dataset = self.dataset_cls( data_dir=self.data_dir, training=False, split="val", start=self.start, end=self.end, factor=self.factor, depth_type=self.depth_type, # type: ignore camera_type=self.camera_type, # type: ignore use_median_filter=self.use_median_filter, scene_norm_dict=self.train_dataset.scene_norm_dict, load_from_cache=self.load_from_cache, ) else: # Dummy validation set. 
self.val_dataset = TensorDataset(torch.zeros(0)) # type: ignore self.keypoint_dataset = iPhoneDatasetKeypointView(self.train_dataset) self.video_dataset = self.train_dataset.get_video_dataset() guru.success("Loading finished!") def train_dataloader(self) -> DataLoader: return DataLoader( self.train_dataset, batch_size=self.batch_size, num_workers=self.num_workers, collate_fn=iPhoneDataset.train_collate_fn, ) def val_dataloader(self) -> list[DataLoader]: return [DataLoader(self.val_dataset), DataLoader(self.keypoint_dataset)] """ ================================================ FILE: mvtracker/models/core/shape-of-motion/flow3d/data/panoptic_dataset.py ================================================ import os from dataclasses import dataclass from functools import partial from typing import Literal, cast import cv2 import imageio import numpy as np import torch import torch.nn.functional as F import tyro from loguru import logger as guru from roma import roma from tqdm import tqdm from flow3d.data.base_dataset import BaseDataset from flow3d.data.utils import ( UINT16_MAX, SceneNormDict, get_tracks_3d_for_query_frame, median_filter_2d, normal_from_depth_image, normalize_coords, parse_tapir_track_info, ) from flow3d.transforms import rt_to_mat4 import sys sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../"))) import models.spatracker.datasets.utils as dataset_utils from models.spatracker.datasets.panoptic_studio_multiview_dataset import PanopticStudioMultiViewDataset from torch.utils.data import default_collate @dataclass class PanopticDataConfig: seq_name: str root_dir: str start: int = 0 end: int = -1 res: str = "" image_type: str = "images" mask_type: str = "masks" depth_type: Literal[ "aligned_depth_anything", "aligned_depth_anything_v2", "depth_anything", "depth_anything_v2", "unidepth_disp", ] = "aligned_depth_anything" camera_type: Literal["droid_recon"] = "droid_recon" track_2d_type: Literal["bootstapir", "tapir"] = "bootstapir" mask_erosion_radius: int = 7 scene_norm_dict: tyro.conf.Suppress[SceneNormDict | None] = None num_targets_per_frame: int = 4 load_from_cache: bool = False class PanopticStudioDatasetSoM(BaseDataset): def __init__( self, seq_name: str, root_dir: str, res: str = "480p", depth_type: Literal[ "aligned_depth_anything", "aligned_depth_anything_v2", "depth_anything", "depth_anything_v2", "unidepth_disp", ] = "aligned_depth_anything", mask_erosion_radius: int = 0, scene_norm_dict: SceneNormDict | None = None, num_targets_per_frame: int = 4, load_from_cache: bool = False, **_, ): super().__init__() self.seq_name = seq_name self.root_dir = root_dir self.res = res self.depth_type = depth_type self.num_targets_per_frame = num_targets_per_frame self.load_from_cache = load_from_cache self.has_validation = False self.mask_erosion_radius = mask_erosion_radius ####################################################################### self.views_to_return = [1, 7, 14, 20] datasets_root = "/cluster/scratch/egundogdu/datasets/" panoptic_kwargs = { "data_root": os.path.join(datasets_root, "panoptic_d3dgs"), "traj_per_sample": 384, "seed": 72, "max_videos": 1, "perform_sanity_checks": False, "views_to_return": [1, 7, 14, 20], "use_duster_depths": False, "clean_duster_depths": False, } self.panoptic_spatial_dataset = PanopticStudioMultiViewDataset(**panoptic_kwargs) datapoint = self.panoptic_spatial_dataset.__getitem__(0) if isinstance(datapoint, tuple): datapoint, gotit = datapoint assert gotit if torch.cuda.is_available(): 
dataset_utils.dataclass_to_cuda_(datapoint) device = torch.device("cuda") else: device = torch.device("cpu") self.img_dir_view_1 = os.path.join(datasets_root, "panoptic_d3dgs", "basketball", "ims", "1") self.frame_names = [os.path.splitext(p)[0] for p in sorted(os.listdir(self.img_dir_view_1))] # Per view data self.rgbs = datapoint.video self.depths = datapoint.videodepth self.image_features = datapoint.feats self.intrs = datapoint.intrs self.extrs = datapoint.extrs self.gt_trajectories_2d_pixelspace_w_z_cameraspace = datapoint.trajectory self.gt_visibilities_per_view = datapoint.visibility self.query_points_2d = (datapoint.query_points.clone().float().to(device) if datapoint.query_points is not None else None) self.query_points_3d = datapoint.query_points_3d.clone().float().to(device) # Non-per-view data self.gt_trajectories_3d_worldspace = datapoint.trajectory_3d self.valid_tracks_per_frame = datapoint.valid self.track_upscaling_factor = datapoint.track_upscaling_factor print(self.rgbs.shape) num_views, num_frames, _, height, width = self.rgbs.shape num_points = self.gt_trajectories_2d_pixelspace_w_z_cameraspace.shape[2] self.rgbs = self.rgbs.permute(0, 1, 3, 4, 2).cpu() self.depths = self.depths.permute(0, 1, 3, 4, 2).cpu() # Assert shapes of per-view data assert self.depths is not None, "Depth is required for evaluation." assert self.rgbs.shape == (num_views, num_frames, height, width, 3) assert self.depths.shape == (num_views, num_frames, height, width, 1) assert self.intrs.shape == (num_views, num_frames, 3, 3) assert self.extrs.shape == (num_views, num_frames, 3, 4) assert self.gt_trajectories_2d_pixelspace_w_z_cameraspace.shape == ( num_views, num_frames, num_points, 3) assert self.gt_visibilities_per_view.shape == (num_views, num_frames, num_points) # Assert shapes of non-per-view data assert self.query_points_3d.shape == (num_points, 4) assert self.gt_trajectories_3d_worldspace.shape == (num_frames, num_points, 3) assert self.valid_tracks_per_frame.shape == (num_frames, num_points) self.w2cs = torch.eye(4).expand(num_views, num_frames, 4, 4).clone() self.w2cs[:, :, :3, :] = self.extrs.squeeze(0).cpu() # (n_views, n_frames, 4, 4) self.Ks = self.intrs.squeeze(0).cpu() # (n_views, n_frames, 3, 3) ###### normalization... 
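# Editorial sketch of compute_scene_norm (defined at the bottom of this file), which the normalization
# step below relies on: the scale is half of the largest per-axis extent between the 5% and 95%
# quantiles of the centered foreground tracks, and the rigid transform rotates the mean camera "up"
# direction (the negated mean of row 1 of the w2c matrices) onto world +z before recentering:
#
#     scale = (X.quantile(0.95, dim=0) - X.quantile(0.05, dim=0)).max() / 2
#     transfm = rt_to_mat4(R, -R @ scene_center)
#
# Each view's w2c is then right-multiplied by inv(transfm) and its translation divided by scale.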
self.scale = 1 tracks_3d = self.get_tracks_3d(5000, step=num_frames // 10)[0] scale, transfm = compute_scene_norm(tracks_3d, self.w2cs) scene_norm_dict = SceneNormDict(scale=scale, transfm=transfm) # transform cameras self.scene_norm_dict = cast(SceneNormDict, scene_norm_dict) self.scale = self.scene_norm_dict["scale"] transform = self.scene_norm_dict["transfm"] guru.info(f"scene norm {self.scale=}, {transform=}") for v in range(num_views): self.w2cs[v] = torch.einsum("nij,jk->nik", self.w2cs[v], torch.linalg.inv(transform)) self.w2cs[v, :, :3, 3] /= self.scale @property def num_frames(self) -> int: return len(self.frame_names) @property def keyframe_idcs(self) -> torch.Tensor: # return self._keyframe_idcs return np.array(range(10,140,10)) def __len__(self): return len(self.frame_names) def get_w2cs(self, view_index=0) -> torch.Tensor: return self.w2cs[view_index].cpu().to(torch.float32) def get_Ks(self, view_index=0) -> torch.Tensor: return self.Ks[view_index].cpu().to(torch.float32) def get_img_wh(self) -> tuple[int, int]: return self.get_image(0).shape[1::-1] def get_image(self, index, view_index=0) -> torch.Tensor: return self.rgbs[view_index][index].cpu().to(torch.float32) / 255.0 def get_mask(self, index, view_index=0) -> torch.Tensor: view = self.views_to_return[view_index] mask = self.load_mask(index, view) mask = cast(torch.Tensor, mask) return mask.cpu().to(torch.float32) def get_depth(self, index, view=0) -> torch.Tensor: # return self.load_depth(index, view) / self.scales[view] return self.load_depth(index, view).cpu().to(torch.float32) / self.scale def load_mask(self, index, view=0) -> torch.Tensor: # self.mask_dir = "/cluster/scratch/egundogdu/datasets/panoptic_d3dgs/basketball/seg" self.mask_dir = "/cluster/home/egundogdu/projects/vlg-lab/spatialtracker/shape-of-motion/panoptic_masks" path = f"{self.mask_dir}/{view}/{self.frame_names[index]}.png" r = self.mask_erosion_radius mask = imageio.imread(path) fg_mask = mask.reshape((*mask.shape[:2], -1)).max(axis=-1) > 0 bg_mask = ~fg_mask fg_mask_erode = cv2.erode( fg_mask.astype(np.uint8), np.ones((r, r), np.uint8), iterations=1 ) bg_mask_erode = cv2.erode( bg_mask.astype(np.uint8), np.ones((r, r), np.uint8), iterations=1 ) out_mask = np.zeros_like(fg_mask, dtype=np.float32) out_mask[bg_mask_erode > 0] = -1 out_mask[fg_mask_erode > 0] = 1 return torch.from_numpy(out_mask).float() def load_depth(self, index, view=0) -> torch.Tensor: depth = self.depths[view][index] depth = depth.permute(2, 0, 1).unsqueeze(0) depth = median_filter_2d(depth, 11, 1)[0, 0] return depth.squeeze(0) ##################################### def get_foreground_points( self, num_samples: int, use_kf_tstamps: bool = False, stride: int = 4, down_rate: int = 8, min_per_frame: int = 64, **kwargs, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: start = 0 end = self.num_frames H, W = self.rgbs.shape[2:4] # Get height & width from rgbs shape # Create pixel grid grid = torch.stack( torch.meshgrid( torch.arange(0, W, dtype=torch.float32), torch.arange(0, H, dtype=torch.float32), indexing="xy", ), dim=-1, ) # Shape: (H, W, 2) if use_kf_tstamps: query_idcs = self.keyframe_idcs.tolist() else: num_query_frames = self.num_frames // stride query_endpts = torch.linspace(start, end, num_query_frames + 1) query_idcs = ((query_endpts[:-1] + query_endpts[1:]) / 2).long().tolist() bg_geometry = [] print(f"{query_idcs=}") # for v in range(self.rgbs.shape[0]): # Iterate over views for query_idx in tqdm(query_idcs, desc=f"Loading foreground points (view)", leave=False): for v 
in [0, 1, 2, 3]: img = self.get_image(query_idx, v).cpu().numpy() # Shape: (H, W, 3) height, width = img.shape[0], img.shape[1] depth = self.get_depth(query_idx, v).cpu().numpy() mask = self.get_mask(query_idx, v).cpu().numpy() < 0 # Shape: (H, W) valid_mask = (~mask * (depth > 0)).ravel() w2c = self.w2cs[v, query_idx].cpu().numpy() c2w = np.linalg.inv(w2c) k = self.Ks[v, query_idx].cpu().numpy() k_inv = np.linalg.inv(k) y, x = np.indices((height, width)) homo_pixel_coords = np.stack([x.ravel(), y.ravel(), np.ones_like(x).ravel()], axis=1).T cam_coords = (k_inv @ homo_pixel_coords) * depth.ravel() cam_coords = np.vstack((cam_coords, np.ones((1, cam_coords.shape[1])))) world_coords = (c2w @ cam_coords)[:3].T world_coords = world_coords[valid_mask] rgb_colors = img.reshape(-1, 3)[valid_mask].astype(np.uint8) bg_geometry.append((torch.from_numpy(world_coords), torch.from_numpy(world_coords), torch.from_numpy(rgb_colors))) rr.set_time_seconds("frame", query_idx / 30) rr.log(f"world/points/view_{v}_foreground", rr.Points3D(positions=world_coords, colors=rgb_colors * 255.0)) # tmp_img = img.clone() # tmp_img[~bool_mask] = 1 # img_8bit = (tmp_img.reshape(self.rgbs[v, query_idx].shape).cpu().numpy() * 255).astype(np.uint8) # datasets_root = f"/cluster/scratch/egundogdu/datasets/view{v}_frame{query_idx}.png" # cv2.imwrite(datasets_root, img_8bit[..., ::-1]) # print(f"Saved {datasets_root}") # img_8bit = (depth.cpu().numpy() * 255).astype(np.uint8) # datasets_root = f"/cluster/scratch/egundogdu/datasets/depth_view{v}_frame{query_idx}.png" # cv2.imwrite(datasets_root, img_8bit[..., ::-1]) # print(f"Saved {datasets_root}") # img_8bit = (bool_mask.cpu().numpy() * 255).astype(np.uint8) # datasets_root = f"/cluster/scratch/egundogdu/datasets/bool_mask_view{v}_frame{query_idx}.png" # cv2.imwrite(datasets_root, img_8bit[..., ::-1]) # print(f"Saved {datasets_root}") bg_points, bg_normals, bg_colors = map( partial(torch.cat, dim=0), zip(*bg_geometry) ) # Final downsampling # doesnt use texture-based prob sampling # TODO: add texture information to sample from a probability if len(bg_points) > num_samples: sel_idcs = np.random.choice(len(bg_points), num_samples, replace=False) bg_points = bg_points[sel_idcs] bg_normals = bg_normals[sel_idcs] bg_colors = bg_colors[sel_idcs] return bg_points, bg_normals, bg_colors def get_bkgd_points( self, num_samples: int, use_kf_tstamps: bool = False, stride: int = 8, down_rate: int = 8, min_per_frame: int = 64, **kwargs, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: start = 0 end = self.num_frames H, W = self.rgbs.shape[2:4] # Get height & width from rgbs shape # Create pixel grid grid = torch.stack( torch.meshgrid( torch.arange(0, W, dtype=torch.float32), torch.arange(0, H, dtype=torch.float32), indexing="xy", ), dim=-1, ) # Shape: (H, W, 2) if use_kf_tstamps: query_idcs = self.keyframe_idcs.tolist() else: num_query_frames = self.num_frames // stride query_endpts = torch.linspace(start, end, num_query_frames + 1) query_idcs = ((query_endpts[:-1] + query_endpts[1:]) / 2).long().tolist() bg_geometry = [] print(f"{query_idcs=}") view_index_list = [0, 1, 2, 3] # for v in range(self.rgbs.shape[0]): # Iterate over views for query_idx in tqdm(query_idcs, desc=f"Loading bkgd points (view)", leave=False): for v in view_index_list: img = self.get_image(query_idx, v).cpu().numpy() # Shape: (H, W, 3) height, width = img.shape[0], img.shape[1] depth = self.get_depth(query_idx, v).cpu().numpy() mask = self.get_mask(query_idx, v).cpu().numpy() < 0 # Shape: (H, W) valid_mask = 
(mask * (depth > 0)).ravel() # valid_mask = depth.ravel() > 0 w2c = self.w2cs[v, query_idx].cpu().numpy() c2w = np.linalg.inv(w2c) k = self.Ks[v, query_idx].cpu().numpy() k_inv = np.linalg.inv(k) y, x = np.indices((height, width)) homo_pixel_coords = np.stack([x.ravel(), y.ravel(), np.ones_like(x).ravel()], axis=1).T cam_coords = (k_inv @ homo_pixel_coords) * depth.ravel() cam_coords = np.vstack((cam_coords, np.ones((1, cam_coords.shape[1])))) world_coords = (c2w @ cam_coords)[:3].T world_coords = world_coords[valid_mask] rgb_colors = img.reshape(-1, 3)[valid_mask] bg_geometry.append((torch.from_numpy(world_coords).to(torch.float32), torch.from_numpy(world_coords).to(torch.float32), torch.from_numpy(rgb_colors).to(torch.float32))) bg_points, bg_normals, bg_colors = map( partial(torch.cat, dim=0), zip(*bg_geometry) ) # Final downsampling # doesnt use texture-based prob sampling # TODO: add texture information to sample from a probability if len(bg_points) > num_samples: sel_idcs = np.random.choice(len(bg_points), num_samples, replace=False) bg_points = bg_points[sel_idcs] bg_normals = bg_normals[sel_idcs] bg_colors = bg_colors[sel_idcs] return bg_points, bg_normals, bg_colors ##################################### def load_target_tracks( self, query_index: int, target_indices: list[int], view_index=0, dim: int = 1 ): """ tracks are 2d, occs and uncertainties :param dim (int), default 1: dimension to stack the time axis return (N, T, 4) if dim=1, (T, N, 4) if dim=0 """ view = self.views_to_return[view_index] q_name = self.frame_names[query_index] all_tracks = [] for ti in target_indices: t_name = self.frame_names[ti] # path = f"/cluster/scratch/egundogdu/datasets/panoptic_d3dgs/basketball/tracks_tapvid_som/{view}/{q_name}_{t_name}.npy" path = f"/cluster/home/egundogdu/projects/vlg-lab/spatialtracker/shape-of-motion/panoptic_tracks/{view}/{q_name}_{t_name}.npy" tracks = np.load(path).astype(np.float32) all_tracks.append(tracks) return torch.from_numpy(np.stack(all_tracks, axis=dim)) def get_tracks_3d( self, num_samples: int, start: int = 0, end: int = -1, step: int = 1, **kwargs ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: num_frames = self.num_frames if end < 0: end = num_frames + 1 + end query_idcs = list(range(start, end, step)) target_idcs = list(range(start, end, step)) num_per_query_frame = int(np.ceil(num_samples / len(query_idcs) / 8)) cur_num = 0 tracks_all_queries = [] view_index_list = [0, 1, 2, 3] precomputed_data = {} for v in view_index_list: masks = torch.stack([self.get_mask(i, v).cpu() for i in target_idcs], dim=0) fg_masks = (masks == 1).float() depths = torch.stack([self.get_depth(i, v).cpu() for i in target_idcs], dim=0) inv_Ks = torch.linalg.inv(self.Ks[v][target_idcs].cpu()) c2ws = torch.linalg.inv(self.w2cs[v][target_idcs].cpu()) precomputed_data[v] = (fg_masks, depths, inv_Ks, c2ws) for q_idx in tqdm(query_idcs, desc=f"Loading 3d tracks points", leave=False): for v in view_index_list: # # masks = torch.stack([self.get_mask(i, v) for i in target_idcs], dim=0) # # fg_masks = (masks == 1).float() # # depths = torch.stack([self.get_depth(i, v) for i in target_idcs], dim=0) # inv_Ks = torch.linalg.inv(self.Ks[v][target_idcs]) # c2ws = torch.linalg.inv(self.w2cs[v][target_idcs]) fg_masks, depths, inv_Ks, c2ws = precomputed_data[v] # (N, T, 4) # print(q_idx, len(query_idcs), "cur: ", cur_num) tracks_2d = self.load_target_tracks(q_idx, target_idcs, v).cpu() num_sel = int( min(num_per_query_frame, num_samples - cur_num, len(tracks_2d)) ) if 
num_sel < len(tracks_2d): sel_idcs = np.random.choice(len(tracks_2d), num_sel, replace=False) tracks_2d = tracks_2d[sel_idcs] cur_num += tracks_2d.shape[0] img = self.get_image(q_idx, v).cpu() tidx = target_idcs.index(q_idx) tracks_tuple = get_tracks_3d_for_query_frame( tidx, img, tracks_2d, depths, fg_masks, inv_Ks, c2ws ) tracks_all_queries.append(tracks_tuple) tracks_3d, colors, visibles, invisibles, confidences = map( partial(torch.cat, dim=0), zip(*tracks_all_queries) ) return tracks_3d, visibles, invisibles, confidences, colors def train_collate_fn(self, batch): """ Collate function that correctly batches data when each sample consists of multiple views. """ # Step 1: Transpose the batch to group by views # If batch contains 4 views per sample, `batch` is a list of lists: [ [view_1, view_2, view_3, view_4], [view_1, view_2, view_3, view_4], ... ] # We want to group all view_1's together, all view_2's together, etc. num_views = len(batch[0]) # Assumes each sample has the same number of views batch_per_view = list(zip(*batch)) # Transposes list-of-lists structure collated_views = [] # Step 2: Collate each view separately for view_batch in batch_per_view: collated = {} for k in view_batch[0]: # Iterate over keys in the dictionary if k not in [ "query_tracks_2d", "target_ts", "target_w2cs", "target_Ks", "target_tracks_2d", "target_visibles", "target_track_depths", "target_invisibles", "target_confidences", ]: collated[k] = default_collate([sample[k] for sample in view_batch]) else: collated[k] = [sample[k] for sample in view_batch] # Keep list format collated_views.append(collated) return collated_views # List of collated dictionaries, one per view # def __getitem__(self, index: int, view=0): # index = np.random.randint(0, self.num_frames) # data = { # # (). # "frame_names": self.frame_names[index], # # (). # "ts": torch.tensor(index), # # (4, 4). # "w2cs": self.w2cs[view][index], # # (3, 3). # "Ks": self.Ks[view][index], # # (H, W, 3). # "imgs": self.get_image(index, view), # "depths": self.get_depth(index, view), # } # tri_mask = self.get_mask(index, view) # valid_mask = tri_mask != 0 # not fg or bg # mask = tri_mask == 1 # fg mask # data["masks"] = mask.float() # data["valid_masks"] = valid_mask.float() # # (P, 2) # query_tracks = self.load_target_tracks(index, [index], view_index=view)[:, 0, :2] # target_inds = torch.from_numpy( # np.random.choice( # self.num_frames, (self.num_targets_per_frame,), replace=False # ) # ) # # (N, P, 4) # target_tracks = self.load_target_tracks(index, target_inds.tolist(), view_index=view, dim=0) # data["query_tracks_2d"] = query_tracks # data["target_ts"] = target_inds # data["target_w2cs"] = self.w2cs[view][target_inds] # data["target_Ks"] = self.Ks[view][target_inds] # data["target_tracks_2d"] = target_tracks[..., :2] # # (N, P). 
# ( # data["target_visibles"], # data["target_invisibles"], # data["target_confidences"], # ) = parse_tapir_track_info(target_tracks[..., 2], target_tracks[..., 3]) # # (N, H, W) # target_depths = torch.stack([self.get_depth(i, view) for i in target_inds], dim=0) # H, W = target_depths.shape[-2:] # data["target_track_depths"] = F.grid_sample( # target_depths[:, None], # normalize_coords(target_tracks[..., None, :2], H, W), # align_corners=True, # padding_mode="border", # )[:, 0, :, 0] # return data def get_batches(self, batch_size): num_batches = self.num_frames // batch_size # Determine number of batches train_collated_merged_data = [] for _ in range(num_batches): train_collated_merged_data.append(self.__getitem_as_batch__(batch_size)) return train_collated_merged_data def __getitem_as_batch__(self, batch_size): # index = np.random.randint(0, self.num_frames) if batch_size > self.num_frames: index = np.random.choice(self.num_frames, batch_size, replace=True) # Sample with replacement else: index = np.random.choice(self.num_frames, batch_size, replace=False) # Sample without replacement merged_data = [] for i in tqdm(index): view_data = [] for view in [0, 1, 2, 3]: view_data.append(self.__getitem_single_view__(i, view)) merged_data.append(view_data) return self.train_collate_fn(merged_data) def __getitem_single_view__(self, index: int, view: int): index = np.random.randint(0, self.num_frames) data = { # (). "frame_names": self.frame_names[index], # (). "ts": torch.tensor(index), # (4, 4). "w2cs": self.w2cs[view][index], # (3, 3). "Ks": self.Ks[view][index], # (H, W, 3). "imgs": self.get_image(index, view), "depths": self.get_depth(index, view), } tri_mask = self.get_mask(index, view) valid_mask = tri_mask != 0 # not fg or bg mask = tri_mask == 1 # fg mask data["masks"] = mask.float() data["valid_masks"] = valid_mask.float() # (P, 2) query_tracks = self.load_target_tracks(index, [index], view_index=view)[:, 0, :2] target_inds = torch.from_numpy( np.random.choice( self.num_frames, (self.num_targets_per_frame,), replace=False ) ) # (N, P, 4) target_tracks = self.load_target_tracks(index, target_inds.tolist(), view_index=view, dim=0) data["query_tracks_2d"] = query_tracks data["target_ts"] = target_inds data["target_w2cs"] = self.w2cs[view][target_inds] data["target_Ks"] = self.Ks[view][target_inds] data["target_tracks_2d"] = target_tracks[..., :2] # (N, P). ( data["target_visibles"], data["target_invisibles"], data["target_confidences"], ) = parse_tapir_track_info(target_tracks[..., 2], target_tracks[..., 3]) # (N, H, W) target_depths = torch.stack([self.get_depth(i, view) for i in target_inds], dim=0) H, W = target_depths.shape[-2:] data["target_track_depths"] = F.grid_sample( target_depths[:, None], normalize_coords(target_tracks[..., None, :2], H, W), align_corners=True, padding_mode="border", )[:, 0, :, 0] return data def __getitem__(self, index: int): index = np.random.randint(0, self.num_frames) merged_data = [] for view in [0, 1, 2, 3]: data = { # (). "frame_names": self.frame_names[index], # (). "ts": torch.tensor(index), # (4, 4). "w2cs": self.w2cs[view][index], # (3, 3). "Ks": self.Ks[view][index], # (H, W, 3). 
"imgs": self.get_image(index, view), "depths": self.get_depth(index, view), } tri_mask = self.get_mask(index, view) valid_mask = tri_mask != 0 # not fg or bg mask = tri_mask == 1 # fg mask data["masks"] = mask.float() data["valid_masks"] = valid_mask.float() # (P, 2) query_tracks = self.load_target_tracks(index, [index], view_index=view)[:, 0, :2] target_inds = torch.from_numpy( np.random.choice( self.num_frames, (self.num_targets_per_frame,), replace=False ) ) # (N, P, 4) target_tracks = self.load_target_tracks(index, target_inds.tolist(), view_index=view, dim=0) data["query_tracks_2d"] = query_tracks data["target_ts"] = target_inds data["target_w2cs"] = self.w2cs[view][target_inds] data["target_Ks"] = self.Ks[view][target_inds] data["target_tracks_2d"] = target_tracks[..., :2] # (N, P). ( data["target_visibles"], data["target_invisibles"], data["target_confidences"], ) = parse_tapir_track_info(target_tracks[..., 2], target_tracks[..., 3]) # (N, H, W) target_depths = torch.stack([self.get_depth(i, view) for i in target_inds], dim=0) H, W = target_depths.shape[-2:] data["target_track_depths"] = F.grid_sample( target_depths[:, None], normalize_coords(target_tracks[..., None, :2], H, W), align_corners=True, padding_mode="border", )[:, 0, :, 0] merged_data.append(data) return merged_data def compute_scene_norm( X: torch.Tensor, w2cs: torch.Tensor ) -> tuple[float, torch.Tensor]: """ :param X: [N*T, 3] # :param w2cs: [N, 4, 4] :param w2cs: [n_views, N, 4, 4] """ X = X.reshape(-1, 3) scene_center = X.mean(dim=0) X = X - scene_center[None] min_scale = X.quantile(0.05, dim=0) max_scale = X.quantile(0.95, dim=0) scale = (max_scale - min_scale).max().item() / 2.0 original_up = -F.normalize(w2cs[:, :, 1, :3].mean(dim=(0,1)), dim=-1) target_up = original_up.new_tensor([0.0, 0.0, 1.0]) R = roma.rotvec_to_rotmat( F.normalize(original_up.cross(target_up), dim=-1) * original_up.dot(target_up).acos_() ) transfm = rt_to_mat4(R, torch.einsum("ij,j->i", -R, scene_center)) return scale, transfm # import rerun as rr if __name__ == "__main__": # rr.init("3dpt", recording_id="v0.1") # rr.connect_tcp("0.0.0.0:9876") # rr.set_time_seconds("frame", 0) # rr.log("world/xyz", rr.Arrows3D(vectors=[[1, 0, 0], [0, 2, 0], [0, 0, 3]], # colors=[[255, 0, 0], [0, 255, 0], [0, 0, 255]])) d = PanopticStudioDatasetSoM("", "", camera_type="") batch = d.__getitem_as_batch__(150) import ipdb ipdb.set_trace() # print(d["imgs"]) # # Get background points # points, normals, colors = d.get_bkgd_points(num_samples=100_000) # print(points.dtype) # rr.set_time_seconds("frame", 0) # rr.log(f"world/points/final_background", rr.Points3D(positions=points, colors=colors * 255.0)) # print("Done.") # # # Get foreground points # points, normals, colors = d.get_foreground_points(num_samples=40_000) # rr.set_time_seconds("frame", 0) # rr.log(f"world/points/final_foreground", rr.Points3D(positions=points, colors=colors * 255.0)) # print("Done.") # # tracks_2d = d.load_target_tracks(0, [0,1,2,3,4], 1) # # print(tracks_2d.dtype) # # # tracks_3d, visibles, invisibles, confidences, colors = d.get_tracks_3d(40000) # # # colors = (colors * 255.0) # # # print( # # # f"{tracks_3d.shape=} {visibles.shape=} " # # # f"{invisibles.shape=} {confidences.shape=} " # # # f"{colors.shape=}" # # # ) # # # # Loop through 150 frames and log the corresponding points # # # num_frames = tracks_3d.shape[1] # 150 frames # # # for frame_idx in range(num_frames): # # # rr.set_time_seconds("frame", frame_idx) # # # # Get the 3D positions for the current frame # # # 
frame_tracks = tracks_3d[:, frame_idx, :] # Shape: (35418, 3) # # # frame_visibles = visibles[:, frame_idx] # Visibility mask # # # # Filter only visible points # # # visible_tracks = frame_tracks[frame_visibles > 0] # # # visible_colors = colors[frame_visibles > 0] # # # rr.set_time_seconds("frame", frame_idx / 30) # # # rr.log(f"world/tracks_3d", rr.Points3D(positions=visible_tracks, colors=visible_colors)) ================================================ FILE: mvtracker/models/core/shape-of-motion/flow3d/data/utils.py ================================================ from typing import List, Optional, Tuple, TypedDict import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from torch.nn.modules.utils import _pair, _quadruple UINT16_MAX = 65535 class SceneNormDict(TypedDict): scale: float transfm: torch.Tensor def to_device(batch, device): if isinstance(batch, dict): return {k: to_device(v, device) for k, v in batch.items()} if isinstance(batch, (list, tuple)): return [to_device(v, device) for v in batch] if isinstance(batch, torch.Tensor): return batch.to(device) return batch def normalize_coords(coords, h, w): assert coords.shape[-1] == 2 return coords / torch.tensor([w - 1.0, h - 1.0], device=coords.device) * 2 - 1.0 def postprocess_occlusions(occlusions, expected_dist): """Postprocess occlusions to boolean visible flag. Args: occlusions: [-inf, inf], np.float32 expected_dist:, [-inf, inf], np.float32 Returns: visibles: bool """ def sigmoid(x): if x.dtype == np.ndarray: return 1 / (1 + np.exp(-x)) else: return torch.sigmoid(x) visibles = (1 - sigmoid(occlusions)) * (1 - sigmoid(expected_dist)) > 0.5 return visibles def parse_tapir_track_info(occlusions, expected_dist): """ return: valid_visible: mask of visible & confident points valid_invisible: mask of invisible & confident points confidence: clamped confidence scores (all < 0.5 -> 0) """ visiblility = 1 - F.sigmoid(occlusions) confidence = 1 - F.sigmoid(expected_dist) valid_visible = visiblility * confidence > 0.5 valid_invisible = (1 - visiblility) * confidence > 0.5 # set all confidence < 0.5 to 0 confidence = confidence * (valid_visible | valid_invisible).float() return valid_visible, valid_invisible, confidence def get_tracks_3d_for_query_frame( query_index: int, query_img: torch.Tensor, tracks_2d: torch.Tensor, depths: torch.Tensor, masks: torch.Tensor, inv_Ks: torch.Tensor, c2ws: torch.Tensor, ): """ :param query_index (int) :param query_img [H, W, 3] :param tracks_2d [N, T, 4] :param depths [T, H, W] :param masks [T, H, W] :param inv_Ks [T, 3, 3] :param c2ws [T, 4, 4] returns ( tracks_3d [N, T, 3] track_colors [N, 3] visibles [N, T] invisibles [N, T] confidences [N, T] ) """ T, H, W = depths.shape query_img = query_img[None].permute(0, 3, 1, 2) # (1, 3, H, W) tracks_2d = tracks_2d.swapaxes(0, 1) # (T, N, 4) tracks_2d, occs, dists = ( tracks_2d[..., :2], tracks_2d[..., 2], tracks_2d[..., 3], ) # visibles = postprocess_occlusions(occs, dists) # (T, N), (T, N), (T, N) visibles, invisibles, confidences = parse_tapir_track_info(occs, dists) # Unproject 2D tracks to 3D. # (T, 1, H, W), (T, 1, N, 2) -> (T, 1, 1, N) track_depths = F.grid_sample( depths[:, None], normalize_coords(tracks_2d[:, None], H, W), align_corners=True, padding_mode="border", )[:, 0, 0] tracks_3d = ( torch.einsum( "nij,npj->npi", inv_Ks, F.pad(tracks_2d, (0, 1), value=1.0), ) * track_depths[..., None] ) tracks_3d = torch.einsum("nij,npj->npi", c2ws, F.pad(tracks_3d, (0, 1), value=1.0))[ ..., :3 ] # Filter out out-of-mask tracks. 
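# Editorial note: normalize_coords (defined above) maps pixel coordinates into the [-1, 1] range that
# F.grid_sample expects with align_corners=True, i.e. x_norm = 2 * x / (W - 1) - 1 and likewise for y.
# A quick check with made-up numbers:
#
#     import torch
#     coords = torch.tensor([[0.0, 0.0], [639.0, 479.0]])   # top-left and bottom-right pixels
#     normalize_coords(coords, 480, 640)
#     # -> tensor([[-1., -1.], [ 1.,  1.]])
#
# The filtering below samples the foreground mask at the normalized track locations and keeps only
# tracks whose sampled values equal 1.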
# (T, 1, H, W), (T, 1, N, 2) -> (T, 1, 1, N) is_in_masks = ( F.grid_sample( masks[:, None], normalize_coords(tracks_2d[:, None], H, W), align_corners=True, )[:, 0, 0] == 1 ) visibles *= is_in_masks invisibles *= is_in_masks confidences *= is_in_masks.float() # valid if in the fg mask at least 40% of the time # in_mask_counts = is_in_masks.sum(0) # t = 0.25 # thresh = min(t * T, in_mask_counts.float().quantile(t).item()) # valid = in_mask_counts > thresh valid = is_in_masks[query_index] # valid if visible 5% of the time visible_counts = visibles.sum(0) valid = valid & ( visible_counts >= min( int(0.05 * T), visible_counts.float().quantile(0.1).item(), ) ) # Get track's color from the query frame. # (1, 3, H, W), (1, 1, N, 2) -> (1, 3, 1, N) -> (N, 3) track_colors = F.grid_sample( query_img, normalize_coords(tracks_2d[query_index : query_index + 1, None], H, W), align_corners=True, padding_mode="border", )[0, :, 0].T return ( tracks_3d[:, valid].swapdims(0, 1), track_colors[valid], visibles[:, valid].swapdims(0, 1), invisibles[:, valid].swapdims(0, 1), confidences[:, valid].swapdims(0, 1), ) def _get_padding(x, k, stride, padding, same: bool): if same: ih, iw = x.size()[2:] if ih % stride[0] == 0: ph = max(k[0] - stride[0], 0) else: ph = max(k[0] - (ih % stride[0]), 0) if iw % stride[1] == 0: pw = max(k[1] - stride[1], 0) else: pw = max(k[1] - (iw % stride[1]), 0) pl = pw // 2 pr = pw - pl pt = ph // 2 pb = ph - pt padding = (pl, pr, pt, pb) else: padding = padding return padding def median_filter_2d(x, kernel_size=3, stride=1, padding=1, same: bool = True): """ :param x [B, C, H, W] """ k = _pair(kernel_size) stride = _pair(stride) # convert to tuple padding = _quadruple(padding) # convert to l, r, t, b # using existing pytorch functions and tensor ops so that we get autograd, # would likely be more efficient to implement from scratch at C/Cuda level x = F.pad(x, _get_padding(x, k, stride, padding, same), mode="reflect") x = x.unfold(2, k[0], stride[0]).unfold(3, k[1], stride[1]) x = x.contiguous().view(x.size()[:4] + (-1,)).median(dim=-1)[0] return x def masked_median_blur(image, mask, kernel_size=11): """ Args: image: [B, C, H, W] mask: [B, C, H, W] kernel_size: int """ assert image.shape == mask.shape if not isinstance(image, torch.Tensor): raise TypeError(f"Input type is not a torch.Tensor. Got {type(image)}") if not len(image.shape) == 4: raise ValueError(f"Invalid input shape, we expect BxCxHxW. 
Got: {image.shape}") padding: Tuple[int, int] = _compute_zero_padding((kernel_size, kernel_size)) # prepare kernel kernel: torch.Tensor = get_binary_kernel2d((kernel_size, kernel_size)).to(image) b, c, h, w = image.shape # map the local window to single vector features: torch.Tensor = F.conv2d( image.reshape(b * c, 1, h, w), kernel, padding=padding, stride=1 ) masks: torch.Tensor = F.conv2d( mask.reshape(b * c, 1, h, w), kernel, padding=padding, stride=1 ) features = features.view(b, c, -1, h, w).permute( 0, 1, 3, 4, 2 ) # BxCxxHxWx(K_h * K_w) min_value, max_value = features.min(), features.max() masks = masks.view(b, c, -1, h, w).permute(0, 1, 3, 4, 2) # BxCxHxWx(K_h * K_w) index_invalid = (1 - masks).nonzero(as_tuple=True) index_b, index_c, index_h, index_w, index_k = index_invalid features[(index_b[::2], index_c[::2], index_h[::2], index_w[::2], index_k[::2])] = ( min_value ) features[ (index_b[1::2], index_c[1::2], index_h[1::2], index_w[1::2], index_k[1::2]) ] = max_value # compute the median along the feature axis median: torch.Tensor = torch.median(features, dim=-1)[0] return median def _compute_zero_padding(kernel_size: Tuple[int, int]) -> Tuple[int, int]: r"""Utility function that computes zero padding tuple.""" computed: List[int] = [(k - 1) // 2 for k in kernel_size] return computed[0], computed[1] def get_binary_kernel2d( window_size: tuple[int, int] | int, *, device: Optional[torch.device] = None, dtype: torch.dtype = torch.float32, ) -> torch.Tensor: """ from kornia Create a binary kernel to extract the patches. If the window size is HxW will create a (H*W)x1xHxW kernel. """ ky, kx = _unpack_2d_ks(window_size) window_range = kx * ky kernel = torch.zeros((window_range, window_range), device=device, dtype=dtype) idx = torch.arange(window_range, device=device) kernel[idx, idx] += 1.0 return kernel.view(window_range, 1, ky, kx) def _unpack_2d_ks(kernel_size: tuple[int, int] | int) -> tuple[int, int]: if isinstance(kernel_size, int): ky = kx = kernel_size else: assert len(kernel_size) == 2, "2D Kernel size should have a length of 2." ky, kx = kernel_size ky = int(ky) kx = int(kx) return (ky, kx) ## Functions from GaussianShader. def ndc_2_cam(ndc_xyz, intrinsic, W, H): inv_scale = torch.tensor([[W - 1, H - 1]], device=ndc_xyz.device) cam_z = ndc_xyz[..., 2:3] cam_xy = ndc_xyz[..., :2] * inv_scale * cam_z cam_xyz = torch.cat([cam_xy, cam_z], dim=-1) cam_xyz = cam_xyz @ torch.inverse(intrinsic[0, ...].t()) return cam_xyz def depth2point_cam(sampled_depth, ref_intrinsic): B, N, C, H, W = sampled_depth.shape valid_z = sampled_depth valid_x = torch.arange(W, dtype=torch.float32, device=sampled_depth.device) / ( W - 1 ) valid_y = torch.arange(H, dtype=torch.float32, device=sampled_depth.device) / ( H - 1 ) valid_y, valid_x = torch.meshgrid(valid_y, valid_x, indexing="ij") # B,N,H,W valid_x = valid_x[None, None, None, ...].expand(B, N, C, -1, -1) valid_y = valid_y[None, None, None, ...].expand(B, N, C, -1, -1) ndc_xyz = torch.stack([valid_x, valid_y, valid_z], dim=-1).view( B, N, C, H, W, 3 ) # 1, 1, 5, 512, 640, 3 cam_xyz = ndc_2_cam(ndc_xyz, ref_intrinsic, W, H) # 1, 1, 5, 512, 640, 3 return ndc_xyz, cam_xyz def depth2point_world(depth_image, intrinsic_matrix, extrinsic_matrix): # depth_image: (H, W), intrinsic_matrix: (3, 3), extrinsic_matrix: (4, 4) _, xyz_cam = depth2point_cam( depth_image[None, None, None, ...], intrinsic_matrix[None, ...] 
) xyz_cam = xyz_cam.reshape(-1, 3) xyz_world = torch.cat( [xyz_cam, torch.ones_like(xyz_cam[..., 0:1])], dim=-1 ) @ torch.inverse(extrinsic_matrix).transpose(0, 1) xyz_world = xyz_world[..., :3] return xyz_world def depth_pcd2normal(xyz): hd, wd, _ = xyz.shape bottom_point = xyz[..., 2:hd, 1 : wd - 1, :] top_point = xyz[..., 0 : hd - 2, 1 : wd - 1, :] right_point = xyz[..., 1 : hd - 1, 2:wd, :] left_point = xyz[..., 1 : hd - 1, 0 : wd - 2, :] left_to_right = right_point - left_point bottom_to_top = top_point - bottom_point xyz_normal = torch.cross(left_to_right, bottom_to_top, dim=-1) xyz_normal = torch.nn.functional.normalize(xyz_normal, p=2, dim=-1) xyz_normal = torch.nn.functional.pad( xyz_normal.permute(2, 0, 1), (1, 1, 1, 1), mode="constant" ).permute(1, 2, 0) return xyz_normal def normal_from_depth_image(depth, intrinsic_matrix, extrinsic_matrix): # depth: (H, W), intrinsic_matrix: (3, 3), extrinsic_matrix: (4, 4) # xyz_normal: (H, W, 3) xyz_world = depth2point_world(depth, intrinsic_matrix, extrinsic_matrix) # (HxW, 3) xyz_world = xyz_world.reshape(*depth.shape, 3) xyz_normal = depth_pcd2normal(xyz_world) return xyz_normal ================================================ FILE: mvtracker/models/core/shape-of-motion/flow3d/init_utils.py ================================================ import time from typing import Literal import cupy as cp import imageio.v3 as iio import numpy as np # from pytorch3d.ops import sample_farthest_points import roma import torch import torch.nn.functional as F from cuml import HDBSCAN, KMeans from loguru import logger as guru from matplotlib.pyplot import get_cmap from tqdm import tqdm from viser import ViserServer from flow3d.loss_utils import ( compute_accel_loss, compute_se3_smoothness_loss, compute_z_acc_loss, get_weights_for_procrustes, knn, masked_l1_loss, ) from flow3d.params import GaussianParams, MotionBases from flow3d.tensor_dataclass import StaticObservations, TrackObservations from flow3d.transforms import cont_6d_to_rmat, rt_to_mat4, solve_procrustes from flow3d.vis.utils import draw_keypoints_video, get_server, project_2d_tracks def init_fg_from_tracks_3d( cano_t: int, tracks_3d: TrackObservations, motion_coefs: torch.Tensor ) -> GaussianParams: """ using dataclasses individual tensors so we know they're consistent and are always masked/filtered together """ num_fg = tracks_3d.xyz.shape[0] # Initialize gaussian colors. colors = torch.logit(tracks_3d.colors) # Initialize gaussian scales: find the average of the three nearest # neighbors in the first frame for each point and use that as the # scale. dists, _ = knn(tracks_3d.xyz[:, cano_t], 3) dists = torch.from_numpy(dists) scales = dists.mean(dim=-1, keepdim=True) scales = scales.clamp(torch.quantile(scales, 0.05), torch.quantile(scales, 0.95)) scales = torch.log(scales.repeat(1, 3)) # Initialize gaussian means. means = tracks_3d.xyz[:, cano_t] # Initialize gaussian orientations as random. quats = torch.rand(num_fg, 4) # Initialize gaussian opacities. 
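# Editorial note: the gaussian parameters are stored pre-activation so the optimizer works in an
# unconstrained space: colors and opacities as logits (sigmoid presumably recovers values in (0, 1)
# downstream) and scales as logs (exp recovers positive extents). For example:
#
#     import torch
#     torch.sigmoid(torch.logit(torch.tensor(0.7)))   # -> 0.7, the intended initial opacity
#
# The per-gaussian scale is initialized from the mean distance to its 3 nearest neighbours in the
# canonical frame, clamped to its 5%-95% quantile range before the log is taken.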
opacities = torch.logit(torch.full((num_fg,), 0.7)) gaussians = GaussianParams(means, quats, scales, colors, opacities, motion_coefs) return gaussians def init_bg( points: StaticObservations, ) -> GaussianParams: """ using dataclasses instead of individual tensors so we know they're consistent and are always masked/filtered together """ num_init_bg_gaussians = points.xyz.shape[0] bg_scene_center = points.xyz.mean(0) bg_points_centered = points.xyz - bg_scene_center bg_min_scale = bg_points_centered.quantile(0.05, dim=0) bg_max_scale = bg_points_centered.quantile(0.95, dim=0) bg_scene_scale = torch.max(bg_max_scale - bg_min_scale).item() / 2.0 bkdg_colors = torch.logit(points.colors) # Initialize gaussian scales: find the average of the three nearest # neighbors in the first frame for each point and use that as the # scale. dists, _ = knn(points.xyz, 3) dists = torch.from_numpy(dists) bg_scales = dists.mean(dim=-1, keepdim=True) bkdg_scales = torch.log(bg_scales.repeat(1, 3)) bg_means = points.xyz # Initialize gaussian orientations by normals. local_normals = points.normals.new_tensor([[0.0, 0.0, 1.0]]).expand_as( points.normals ) angles = torch.clamp((local_normals * points.normals).sum(-1, keepdim=True), -1.0, 1.0).acos_() # bg_quats = roma.rotvec_to_unitquat( # F.normalize(local_normals.cross(points.normals), dim=-1) # * (local_normals * points.normals).sum(-1, keepdim=True).acos_() # ).roll(1, dims=-1) bg_quats = roma.rotvec_to_unitquat( F.normalize(local_normals.cross(points.normals), dim=-1) * angles ).roll(1, dims=-1) bg_opacities = torch.logit(torch.full((num_init_bg_gaussians,), 0.7)) gaussians = GaussianParams( bg_means, bg_quats, bkdg_scales, bkdg_colors, bg_opacities, scene_center=bg_scene_center, scene_scale=bg_scene_scale, ) return gaussians def init_motion_params_with_procrustes( tracks_3d: TrackObservations, num_bases: int, rot_type: Literal["quat", "6d"], cano_t: int, cluster_init_method: str = "kmeans", min_mean_weight: float = 0.1, vis: bool = False, port: int | None = None, ) -> tuple[MotionBases, torch.Tensor, TrackObservations]: device = tracks_3d.xyz.device num_frames = tracks_3d.xyz.shape[1] # sample centers and get initial se3 motion bases by solving procrustes means_cano = tracks_3d.xyz[:, cano_t].clone() # [num_gaussians, 3] # remove outliers scene_center = means_cano.median(dim=0).values print(f"{scene_center=}") dists = torch.norm(means_cano - scene_center, dim=-1) dists_th = torch.quantile(dists, 0.95) valid_mask = dists < dists_th # remove tracks that are not visible in any frame valid_mask = valid_mask & tracks_3d.visibles.any(dim=1) print(f"{valid_mask.sum()=}") tracks_3d = tracks_3d.filter_valid(valid_mask) if vis and port is not None: server = get_server(port) try: pts = tracks_3d.xyz.cpu().numpy() clrs = tracks_3d.colors.cpu().numpy() while True: for t in range(num_frames): server.scene.add_point_cloud("points", pts[:, t], clrs) time.sleep(0.3) except KeyboardInterrupt: pass means_cano = means_cano[valid_mask] sampled_centers, num_bases, labels = sample_initial_bases_centers( cluster_init_method, cano_t, tracks_3d, num_bases ) # assign each point to the label to compute the cluster weight ids, counts = labels.unique(return_counts=True) ids = ids[counts > 100] num_bases = len(ids) sampled_centers = sampled_centers[:, ids] print(f"{num_bases=} {sampled_centers.shape=}") # compute basis weights from the distance to the cluster centers dists2centers = torch.norm(means_cano[:, None] - sampled_centers, dim=-1) motion_coefs = 10 * torch.exp(-dists2centers) 
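# Editorial note: the motion coefficients above act as a soft assignment of each gaussian to the
# motion bases, decaying exponentially with the distance to the sampled cluster centers (and
# presumably normalized later by the model, e.g. inside fg.get_coefs()). Illustrative values:
#
#     10 * exp(-0.0)  = 10.0    # a gaussian sitting on a cluster center
#     10 * exp(-2.3) ~=  1.0    # a gaussian ~2.3 scene units away contributes far less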
    init_rots, init_ts = [], []

    if rot_type == "quat":
        id_rot = torch.tensor([1.0, 0.0, 0.0, 0.0], device=device)
        rot_dim = 4
    else:
        id_rot = torch.tensor([1.0, 0.0, 0.0, 0.0, 1.0, 0.0], device=device)
        rot_dim = 6

    init_rots = id_rot.reshape(1, 1, rot_dim).repeat(num_bases, num_frames, 1)
    init_ts = torch.zeros(num_bases, num_frames, 3, device=device)
    errs_before = np.full((num_bases, num_frames), -1.0)
    errs_after = np.full((num_bases, num_frames), -1.0)

    tgt_ts = list(range(cano_t - 1, -1, -1)) + list(range(cano_t, num_frames))
    print(f"{tgt_ts=}")
    skipped_ts = {}
    for n, cluster_id in enumerate(ids):
        mask_in_cluster = labels == cluster_id
        cluster = tracks_3d.xyz[mask_in_cluster].transpose(
            0, 1
        )  # [num_frames, n_pts, 3]
        visibilities = tracks_3d.visibles[mask_in_cluster].swapaxes(
            0, 1
        )  # [num_frames, n_pts]
        confidences = tracks_3d.confidences[mask_in_cluster].swapaxes(
            0, 1
        )  # [num_frames, n_pts]
        weights = get_weights_for_procrustes(cluster, visibilities)
        prev_t = cano_t
        cluster_skip_ts = []
        for cur_t in tgt_ts:
            # compute pairwise transform from cano_t
            procrustes_weights = (
                weights[cano_t]
                * weights[cur_t]
                * (confidences[cano_t] + confidences[cur_t])
                / 2
            )
            if procrustes_weights.sum() < min_mean_weight * num_frames:
                init_rots[n, cur_t] = init_rots[n, prev_t]
                init_ts[n, cur_t] = init_ts[n, prev_t]
                cluster_skip_ts.append(cur_t)
            else:
                se3, (err, err_before) = solve_procrustes(
                    cluster[cano_t],
                    cluster[cur_t],
                    weights=procrustes_weights,
                    enforce_se3=True,
                    rot_type=rot_type,
                )
                init_rot, init_t, _ = se3
                assert init_rot.shape[-1] == rot_dim
                # double cover
                if rot_type == "quat" and torch.linalg.norm(
                    init_rot - init_rots[n][prev_t]
                ) > torch.linalg.norm(-init_rot - init_rots[n][prev_t]):
                    init_rot = -init_rot
                init_rots[n, cur_t] = init_rot
                init_ts[n, cur_t] = init_t
                if np.isnan(err):
                    print(f"{cur_t=} {err=}")
                    print(f"{procrustes_weights.isnan().sum()=}")
                if np.isnan(err_before):
                    print(f"{cur_t=} {err_before=}")
                    print(f"{procrustes_weights.isnan().sum()=}")
                errs_after[n, cur_t] = err
                errs_before[n, cur_t] = err_before
            prev_t = cur_t
        skipped_ts[cluster_id.item()] = cluster_skip_ts

    guru.info(f"{skipped_ts=}")
    guru.info(
        "procrustes init median error: {:.5f} => {:.5f}".format(
            np.median(errs_before[errs_before > 0]),
            np.median(errs_after[errs_after > 0]),
        )
    )
    guru.info(
        "procrustes init mean error: {:.5f} => {:.5f}".format(
            np.mean(errs_before[errs_before > 0]), np.mean(errs_after[errs_after > 0])
        )
    )
    guru.info(f"{init_rots.shape=}, {init_ts.shape=}, {motion_coefs.shape=}")

    if vis:
        server = get_server(port)
        center_idcs = torch.argmin(dists2centers, dim=0)
        print(f"{dists2centers.shape=} {center_idcs.shape=}")
        vis_se3_init_3d(server, init_rots, init_ts, means_cano[center_idcs])
        vis_tracks_3d(server, tracks_3d.xyz[center_idcs].numpy(), name="center_tracks")
        import ipdb

        ipdb.set_trace()

    bases = MotionBases(init_rots, init_ts)
    return bases, motion_coefs, tracks_3d


def run_initial_optim(
    fg: GaussianParams,
    bases: MotionBases,
    tracks_3d: TrackObservations,
    Ks: torch.Tensor,
    w2cs: torch.Tensor,
    num_iters: int = 1000,
    use_depth_range_loss: bool = False,
):
    """
    :param motion_rots: [num_bases, num_frames, 4|6]
    :param motion_transls: [num_bases, num_frames, 3]
    :param motion_coefs: [num_bases, num_frames]
    :param means: [num_gaussians, 3]
    """
    optimizer = torch.optim.Adam(
        [
            {"params": bases.params["rots"], "lr": 1e-2},
            {"params": bases.params["transls"], "lr": 3e-2},
            {"params": fg.params["motion_coefs"], "lr": 1e-2},
            {"params": fg.params["means"], "lr": 1e-3},
        ],
    )
    scheduler = torch.optim.lr_scheduler.ExponentialLR(
optimizer, gamma=0.1 ** (1 / num_iters) ) G = fg.params.means.shape[0] num_frames = bases.num_frames device = bases.params["rots"].device w_smooth_func = lambda i, min_v, max_v, th: ( min_v if i <= th else (max_v - min_v) * (i - th) / (num_iters - th) + min_v ) gt_2d, gt_depth = project_2d_tracks( tracks_3d.xyz.swapaxes(0, 1), Ks, w2cs, return_depth=True ) # (G, T, 2) gt_2d = gt_2d.swapaxes(0, 1) # (G, T) gt_depth = gt_depth.swapaxes(0, 1) ts = torch.arange(0, num_frames, device=device) ts_clamped = torch.clamp(ts, min=1, max=num_frames - 2) ts_neighbors = torch.cat((ts_clamped - 1, ts_clamped, ts_clamped + 1)) # i (3B,) pbar = tqdm(range(0, num_iters)) for i in pbar: coefs = fg.get_coefs() transfms = bases.compute_transforms(ts, coefs) positions = torch.einsum( "pnij,pj->pni", transfms, F.pad(fg.params["means"], (0, 1), value=1.0), ) loss = 0.0 track_3d_loss = masked_l1_loss( positions, tracks_3d.xyz, (tracks_3d.visibles.float() * tracks_3d.confidences)[..., None], ) loss += track_3d_loss * 1.0 pred_2d, pred_depth = project_2d_tracks( positions.swapaxes(0, 1), Ks, w2cs, return_depth=True ) pred_2d = pred_2d.swapaxes(0, 1) pred_depth = pred_depth.swapaxes(0, 1) loss_2d = ( masked_l1_loss( pred_2d, gt_2d, (tracks_3d.invisibles.float() * tracks_3d.confidences)[..., None], quantile=0.95, ) / Ks[0, 0, 0] ) loss += 0.5 * loss_2d if use_depth_range_loss: near_depths = torch.quantile(gt_depth, 0.0, dim=0, keepdim=True) far_depths = torch.quantile(gt_depth, 0.98, dim=0, keepdim=True) loss_depth_in_range = 0 if (pred_depth < near_depths).any(): loss_depth_in_range += (near_depths - pred_depth)[ pred_depth < near_depths ].mean() if (pred_depth > far_depths).any(): loss_depth_in_range += (pred_depth - far_depths)[ pred_depth > far_depths ].mean() loss += loss_depth_in_range * w_smooth_func(i, 0.05, 0.5, 400) motion_coef_sparse_loss = 1 - (coefs**2).sum(dim=-1).mean() loss += motion_coef_sparse_loss * 0.01 # motion basis should be smooth. 
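        # (Illustrative note, not in the original file.) compute_se3_smoothness_loss
        # penalizes the discrete acceleration 2*x_t - x_{t-1} - x_{t+1} of each
        # basis trajectory (rotation terms weighted 1.0, translation terms 2.0),
        # and the weight below stays at 0.01 until iteration 400, then ramps
        # linearly toward 0.1 (see w_smooth_func above).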
w_smooth = w_smooth_func(i, 0.01, 0.1, 400) small_acc_loss = compute_se3_smoothness_loss( bases.params["rots"], bases.params["transls"] ) loss += small_acc_loss * w_smooth small_acc_loss_tracks = compute_accel_loss(positions) loss += small_acc_loss_tracks * w_smooth * 0.5 transfms_nbs = bases.compute_transforms(ts_neighbors, coefs) means_nbs = torch.einsum( "pnij,pj->pni", transfms_nbs, F.pad(fg.params["means"], (0, 1), value=1.0) ) # (G, 3n, 3) means_nbs = means_nbs.reshape(means_nbs.shape[0], 3, -1, 3) # [G, 3, n, 3] z_accel_loss = compute_z_acc_loss(means_nbs, w2cs) loss += z_accel_loss * 0.1 optimizer.zero_grad() loss.backward() optimizer.step() scheduler.step() pbar.set_description( f"{loss.item():.3f} " f"{track_3d_loss.item():.3f} " f"{motion_coef_sparse_loss.item():.3f} " f"{small_acc_loss.item():.3f} " f"{small_acc_loss_tracks.item():.3f} " f"{z_accel_loss.item():.3f} " ) def random_quats(N: int) -> torch.Tensor: u = torch.rand(N, 1) v = torch.rand(N, 1) w = torch.rand(N, 1) quats = torch.cat( [ torch.sqrt(1.0 - u) * torch.sin(2.0 * np.pi * v), torch.sqrt(1.0 - u) * torch.cos(2.0 * np.pi * v), torch.sqrt(u) * torch.sin(2.0 * np.pi * w), torch.sqrt(u) * torch.cos(2.0 * np.pi * w), ], -1, ) return quats def compute_means(ts, fg: GaussianParams, bases: MotionBases): transfms = bases.compute_transforms(ts, fg.get_coefs()) means = torch.einsum( "pnij,pj->pni", transfms, F.pad(fg.params["means"], (0, 1), value=1.0), ) return means def vis_init_params( server, fg: GaussianParams, bases: MotionBases, name="init_params", num_vis: int = 100, ): idcs = np.random.choice(fg.num_gaussians, num_vis) labels = np.linspace(0, 1, num_vis) ts = torch.arange(bases.num_frames, device=bases.params["rots"].device) with torch.no_grad(): pred_means = compute_means(ts, fg, bases) vis_means = pred_means[idcs].detach().cpu().numpy() vis_tracks_3d(server, vis_means, labels, name=name) @torch.no_grad() def vis_se3_init_3d(server, init_rots, init_ts, basis_centers): """ :param init_rots: [num_bases, num_frames, 4|6] :param init_ts: [num_bases, num_frames, 3] :param basis_centers: [num_bases, 3] """ # visualize the initial centers across time rot_dim = init_rots.shape[-1] assert rot_dim in [4, 6] num_bases = init_rots.shape[0] assert init_ts.shape[0] == num_bases assert basis_centers.shape[0] == num_bases labels = np.linspace(0, 1, num_bases) if rot_dim == 4: quats = F.normalize(init_rots, dim=-1, p=2) rmats = roma.unitquat_to_rotmat(quats.roll(-1, dims=-1)) else: rmats = cont_6d_to_rmat(init_rots) transls = init_ts transfms = rt_to_mat4(rmats, transls) center_tracks3d = torch.einsum( "bnij,bj->bni", transfms, F.pad(basis_centers, (0, 1), value=1.0) )[..., :3] vis_tracks_3d(server, center_tracks3d.cpu().numpy(), labels, name="se3_centers") @torch.no_grad() def vis_tracks_2d_video( path, imgs: np.ndarray, tracks_3d: np.ndarray, Ks: np.ndarray, w2cs: np.ndarray, occs=None, radius: int = 3, ): num_tracks = tracks_3d.shape[0] labels = np.linspace(0, 1, num_tracks) cmap = get_cmap("gist_rainbow") colors = cmap(labels)[:, :3] tracks_2d = ( project_2d_tracks(tracks_3d.swapaxes(0, 1), Ks, w2cs).cpu().numpy() # type: ignore ) frames = np.asarray( draw_keypoints_video(imgs, tracks_2d, colors, occs, radius=radius) ) iio.imwrite(path, frames, fps=15) def vis_tracks_3d( server: ViserServer, vis_tracks: np.ndarray, vis_label: np.ndarray | None = None, name: str = "tracks", ): """ :param vis_tracks (np.ndarray): (N, T, 3) :param vis_label (np.ndarray): (N) """ cmap = get_cmap("gist_rainbow") if vis_label is None: vis_label = 
np.linspace(0, 1, len(vis_tracks)) colors = cmap(np.asarray(vis_label))[:, :3] guru.info(f"{colors.shape=}, {vis_tracks.shape=}") N, T = vis_tracks.shape[:2] vis_tracks = np.asarray(vis_tracks) for i in range(N): server.scene.add_spline_catmull_rom( f"/{name}/{i}/spline", vis_tracks[i], color=colors[i], segments=T - 1 ) server.scene.add_point_cloud( f"/{name}/{i}/start", vis_tracks[i, [0]], colors=colors[i : i + 1], point_size=0.05, point_shape="circle", ) server.scene.add_point_cloud( f"/{name}/{i}/end", vis_tracks[i, [-1]], colors=colors[i : i + 1], point_size=0.05, point_shape="diamond", ) def sample_initial_bases_centers( mode: str, cano_t: int, tracks_3d: TrackObservations, num_bases: int ): """ :param mode: "farthest" | "hdbscan" | "kmeans" :param tracks_3d: [G, T, 3] :param cano_t: canonical index :param num_bases: number of SE3 bases """ assert mode in ["farthest", "hdbscan", "kmeans"] means_canonical = tracks_3d.xyz[:, cano_t].clone() # if mode == "farthest": # vis_mask = tracks_3d.visibles[:, cano_t] # sampled_centers, _ = sample_farthest_points( # means_canonical[vis_mask][None], # K=num_bases, # random_start_point=True, # ) # [1, num_bases, 3] # dists2centers = torch.norm(means_canonical[:, None] - sampled_centers, dim=-1).T # return sampled_centers, num_bases, dists2centers # linearly interpolate missing 3d points xyz = cp.asarray(tracks_3d.xyz) print(f"{xyz.shape=}") visibles = cp.asarray(tracks_3d.visibles) num_tracks = xyz.shape[0] xyz_interp = batched_interp_masked(xyz, visibles) # num_vis = 50 # server = get_server(port=8890) # idcs = np.random.choice(num_tracks, num_vis) # labels = np.linspace(0, 1, num_vis) # vis_tracks_3d(server, tracks_3d.xyz[idcs].get(), labels, name="raw_tracks") # vis_tracks_3d(server, xyz_interp[idcs].get(), labels, name="interp_tracks") # import ipdb; ipdb.set_trace() velocities = xyz_interp[:, 1:] - xyz_interp[:, :-1] vel_dirs = ( velocities / (cp.linalg.norm(velocities, axis=-1, keepdims=True) + 1e-5) ).reshape((num_tracks, -1)) # [num_bases, num_gaussians] if mode == "kmeans": model = KMeans(n_clusters=num_bases) else: model = HDBSCAN(min_cluster_size=20, max_cluster_size=num_tracks // 4) model.fit(vel_dirs) labels = model.labels_ num_bases = labels.max().item() + 1 sampled_centers = torch.stack( [ means_canonical[torch.tensor(labels == i)].median(dim=0).values for i in range(num_bases) ] )[None] print("number of {} clusters: ".format(mode), num_bases) return sampled_centers, num_bases, torch.tensor(labels) def interp_masked(vals: cp.ndarray, mask: cp.ndarray, pad: int = 1) -> cp.ndarray: """ hacky way to interpolate batched with cupy by concatenating the batches and pad with dummy values :param vals: [B, M, *] :param mask: [B, M] """ assert mask.ndim == 2 assert vals.shape[:2] == mask.shape B, M = mask.shape # get the first and last valid values for each track sh = vals.shape[2:] vals = vals.reshape((B, M, -1)) D = vals.shape[-1] first_val_idcs = cp.argmax(mask, axis=-1) last_val_idcs = M - 1 - cp.argmax(cp.flip(mask, axis=-1), axis=-1) bidcs = cp.arange(B) v0 = vals[bidcs, first_val_idcs][:, None] v1 = vals[bidcs, last_val_idcs][:, None] m0 = mask[bidcs, first_val_idcs][:, None] m1 = mask[bidcs, last_val_idcs][:, None] if pad > 1: v0 = cp.tile(v0, [1, pad, 1]) v1 = cp.tile(v1, [1, pad, 1]) m0 = cp.tile(m0, [1, pad]) m1 = cp.tile(m1, [1, pad]) vals_pad = cp.concatenate([v0, vals, v1], axis=1) mask_pad = cp.concatenate([m0, mask, m1], axis=1) M_pad = vals_pad.shape[1] vals_flat = vals_pad.reshape((B * M_pad, -1)) mask_flat = 
mask_pad.reshape((B * M_pad,)) idcs = cp.where(mask_flat)[0] cx = cp.arange(B * M_pad) out = cp.zeros((B * M_pad, D), dtype=vals_flat.dtype) for d in range(D): out[:, d] = cp.interp(cx, idcs, vals_flat[idcs, d]) out = out.reshape((B, M_pad, *sh))[:, pad:-pad] return out def batched_interp_masked( vals: cp.ndarray, mask: cp.ndarray, batch_num: int = 4096, batch_time: int = 64 ): assert mask.ndim == 2 B, M = mask.shape out = cp.zeros_like(vals) for b in tqdm(range(0, B, batch_num), leave=False): for m in tqdm(range(0, M, batch_time), leave=False): x = interp_masked( vals[b : b + batch_num, m : m + batch_time], mask[b : b + batch_num, m : m + batch_time], ) # (batch_num, batch_time, *) out[b : b + batch_num, m : m + batch_time] = x return out ================================================ FILE: mvtracker/models/core/shape-of-motion/flow3d/loss_utils.py ================================================ import numpy as np import torch import torch.nn.functional as F from sklearn.neighbors import NearestNeighbors def masked_mse_loss(pred, gt, mask=None, normalize=True, quantile: float = 1.0): if mask is None: return trimmed_mse_loss(pred, gt, quantile) else: sum_loss = F.mse_loss(pred, gt, reduction="none").mean(dim=-1, keepdim=True) quantile_mask = ( (sum_loss < torch.quantile(sum_loss, quantile)).squeeze(-1) if quantile < 1 else torch.ones_like(sum_loss, dtype=torch.bool).squeeze(-1) ) ndim = sum_loss.shape[-1] if normalize: return torch.sum((sum_loss * mask)[quantile_mask]) / ( ndim * torch.sum(mask[quantile_mask]) + 1e-8 ) else: return torch.mean((sum_loss * mask)[quantile_mask]) def masked_l1_loss(pred, gt, mask=None, normalize=True, quantile: float = 1.0): if mask is None: return trimmed_l1_loss(pred, gt, quantile) else: sum_loss = F.l1_loss(pred, gt, reduction="none").mean(dim=-1, keepdim=True) quantile_mask = ( (sum_loss < torch.quantile(sum_loss, quantile)).squeeze(-1) if quantile < 1 else torch.ones_like(sum_loss, dtype=torch.bool).squeeze(-1) ) ndim = sum_loss.shape[-1] if normalize: return torch.sum((sum_loss * mask)[quantile_mask]) / ( ndim * torch.sum(mask[quantile_mask]) + 1e-8 ) else: return torch.mean((sum_loss * mask)[quantile_mask]) def masked_huber_loss(pred, gt, delta, mask=None, normalize=True): if mask is None: return F.huber_loss(pred, gt, delta=delta) else: sum_loss = F.huber_loss(pred, gt, delta=delta, reduction="none") ndim = sum_loss.shape[-1] if normalize: return torch.sum(sum_loss * mask) / (ndim * torch.sum(mask) + 1e-8) else: return torch.mean(sum_loss * mask) def trimmed_mse_loss(pred, gt, quantile=0.9): loss = F.mse_loss(pred, gt, reduction="none").mean(dim=-1) loss_at_quantile = torch.quantile(loss, quantile) trimmed_loss = loss[loss < loss_at_quantile].mean() return trimmed_loss def trimmed_l1_loss(pred, gt, quantile=0.9): loss = F.l1_loss(pred, gt, reduction="none").mean(dim=-1) loss_at_quantile = torch.quantile(loss, quantile) trimmed_loss = loss[loss < loss_at_quantile].mean() return trimmed_loss def compute_gradient_loss(pred, gt, mask, quantile=0.98): """ Compute gradient loss pred: (batch_size, H, W, D) or (batch_size, H, W) gt: (batch_size, H, W, D) or (batch_size, H, W) mask: (batch_size, H, W), bool or float """ # NOTE: messy need to be cleaned up mask_x = mask[:, :, 1:] * mask[:, :, :-1] mask_y = mask[:, 1:, :] * mask[:, :-1, :] pred_grad_x = pred[:, :, 1:] - pred[:, :, :-1] pred_grad_y = pred[:, 1:, :] - pred[:, :-1, :] gt_grad_x = gt[:, :, 1:] - gt[:, :, :-1] gt_grad_y = gt[:, 1:, :] - gt[:, :-1, :] loss = masked_l1_loss( pred_grad_x[mask_x][..., 
None], gt_grad_x[mask_x][..., None], quantile=quantile ) + masked_l1_loss( pred_grad_y[mask_y][..., None], gt_grad_y[mask_y][..., None], quantile=quantile ) return loss def knn(x: torch.Tensor, k: int) -> tuple[np.ndarray, np.ndarray]: x = x.cpu().numpy() knn_model = NearestNeighbors( n_neighbors=k + 1, algorithm="auto", metric="euclidean" ).fit(x) distances, indices = knn_model.kneighbors(x) return distances[:, 1:].astype(np.float32), indices[:, 1:].astype(np.float32) def get_weights_for_procrustes(clusters, visibilities=None): clusters_median = clusters.median(dim=-2, keepdim=True)[0] dists2clusters_center = torch.norm(clusters - clusters_median, dim=-1) dists2clusters_center /= dists2clusters_center.median(dim=-1, keepdim=True)[0] weights = torch.exp(-dists2clusters_center) weights /= weights.mean(dim=-1, keepdim=True) + 1e-6 if visibilities is not None: weights *= visibilities.float() + 1e-6 invalid = dists2clusters_center > np.quantile( dists2clusters_center.cpu().numpy(), 0.9 ) invalid |= torch.isnan(weights) weights[invalid] = 0 return weights def compute_z_acc_loss(means_ts_nb: torch.Tensor, w2cs: torch.Tensor): """ :param means_ts (G, 3, B, 3) :param w2cs (B, 4, 4) return (float) """ camera_center_t = torch.linalg.inv(w2cs)[:, :3, 3] # (B, 3) ray_dir = F.normalize( means_ts_nb[:, 1] - camera_center_t, p=2.0, dim=-1 ) # [G, B, 3] # acc = 2 * means[:, 1] - means[:, 0] - means[:, 2] # [G, B, 3] # acc_loss = (acc * ray_dir).sum(dim=-1).abs().mean() acc_loss = ( ((means_ts_nb[:, 1] - means_ts_nb[:, 0]) * ray_dir).sum(dim=-1) ** 2 ).mean() + ( ((means_ts_nb[:, 2] - means_ts_nb[:, 1]) * ray_dir).sum(dim=-1) ** 2 ).mean() return acc_loss def compute_se3_smoothness_loss( rots: torch.Tensor, transls: torch.Tensor, weight_rot: float = 1.0, weight_transl: float = 2.0, ): """ central differences :param motion_transls (K, T, 3) :param motion_rots (K, T, 6) """ r_accel_loss = compute_accel_loss(rots) t_accel_loss = compute_accel_loss(transls) return r_accel_loss * weight_rot + t_accel_loss * weight_transl def compute_accel_loss(transls): accel = 2 * transls[:, 1:-1] - transls[:, :-2] - transls[:, 2:] loss = accel.norm(dim=-1).mean() return loss ================================================ FILE: mvtracker/models/core/shape-of-motion/flow3d/metrics.py ================================================ from typing import Literal import numpy as np import torch import torch.nn.functional as F from torchmetrics.functional.image.lpips import _NoTrainLpips from torchmetrics.image import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure from torchmetrics.metric import Metric from torchmetrics.utilities import dim_zero_cat from torchmetrics.utilities.imports import _TORCHVISION_AVAILABLE def compute_psnr( preds: torch.Tensor, targets: torch.Tensor, masks: torch.Tensor | None = None, ) -> float: """ Args: preds (torch.Tensor): (..., 3) predicted images in [0, 1]. targets (torch.Tensor): (..., 3) target images in [0, 1]. masks (torch.Tensor | None): (...,) optional binary masks where the 1-regions will be taken into account. Returns: psnr (float): Peak signal-to-noise ratio. """ if masks is None: masks = torch.ones_like(preds[..., 0]) return ( -10.0 * torch.log( F.mse_loss( preds * masks[..., None], targets * masks[..., None], reduction="sum", ) / masks.sum().clamp(min=1.0) / 3.0 ) / np.log(10.0) ).item() def compute_pose_errors( preds: torch.Tensor, targets: torch.Tensor ) -> tuple[float, float, float]: """ Args: preds: (N, 4, 4) predicted camera poses. targets: (N, 4, 4) target camera poses. 
    Returns:
        ate (float): Absolute trajectory error.
        rpe_t (float): Relative pose error in translation.
        rpe_r (float): Relative pose error in rotation (degree).
    """
    # Compute ATE.
    ate = torch.linalg.norm(preds[:, :3, -1] - targets[:, :3, -1], dim=-1).mean().item()
    # Compute RPE_t and RPE_r.
    # NOTE(hangg): It's important to use numpy here for the accuracy of RPE_r.
    # torch has numerical issues for acos when the value is close to 1.0, i.e.
    # RPE_r is supposed to be very small, and will result in artificially large
    # error.
    preds = preds.detach().cpu().numpy()
    targets = targets.detach().cpu().numpy()
    pred_rels = np.linalg.inv(preds[:-1]) @ preds[1:]
    target_rels = np.linalg.inv(targets[:-1]) @ targets[1:]
    error_rels = np.linalg.inv(target_rels) @ pred_rels
    traces = error_rels[:, :3, :3].trace(axis1=-2, axis2=-1)
    rpe_t = np.linalg.norm(error_rels[:, :3, -1], axis=-1).mean().item()
    rpe_r = (
        np.arccos(np.clip((traces - 1.0) / 2.0, -1.0, 1.0)).mean().item()
        / np.pi
        * 180.0
    )
    return ate, rpe_t, rpe_r


class mPSNR(PeakSignalNoiseRatio):
    sum_squared_error: list[torch.Tensor]
    total: list[torch.Tensor]

    def __init__(self, **kwargs) -> None:
        super().__init__(
            data_range=1.0,
            base=10.0,
            dim=None,
            reduction="elementwise_mean",
            **kwargs,
        )
        self.add_state("sum_squared_error", default=[], dist_reduce_fx="cat")
        self.add_state("total", default=[], dist_reduce_fx="cat")

    def __len__(self) -> int:
        return len(self.total)

    def update(
        self,
        preds: torch.Tensor,
        targets: torch.Tensor,
        masks: torch.Tensor | None = None,
    ):
        """Update state with predictions and targets.

        Args:
            preds (torch.Tensor): (..., 3) float32 predicted images.
            targets (torch.Tensor): (..., 3) float32 target images.
            masks (torch.Tensor | None): (...,) optional binary masks where the
                1-regions will be taken into account.
        """
        if masks is None:
            masks = torch.ones_like(preds[..., 0])
        self.sum_squared_error.append(
            torch.sum(torch.pow((preds - targets) * masks[..., None], 2))
        )
        self.total.append(masks.sum().to(torch.int64) * 3)

    def compute(self) -> torch.Tensor:
        """Compute peak signal-to-noise ratio over state."""
        sum_squared_error = dim_zero_cat(self.sum_squared_error)
        total = dim_zero_cat(self.total)
        return -10.0 * torch.log(sum_squared_error / total).mean() / np.log(10.0)


class mSSIM(StructuralSimilarityIndexMeasure):
    similarity: list

    def __init__(self, **kwargs) -> None:
        super().__init__(
            reduction=None,
            data_range=1.0,
            return_full_image=False,
            **kwargs,
        )
        assert isinstance(self.sigma, float)

    def __len__(self) -> int:
        return sum([s.shape[0] for s in self.similarity])

    def update(
        self,
        preds: torch.Tensor,
        targets: torch.Tensor,
        masks: torch.Tensor | None = None,
    ):
        """Update state with predictions and targets.

        Args:
            preds (torch.Tensor): (B, H, W, 3) float32 predicted images.
            targets (torch.Tensor): (B, H, W, 3) float32 target images.
            masks (torch.Tensor | None): (B, H, W) optional binary masks where
                the 1-regions will be taken into account.
        """
        if masks is None:
            masks = torch.ones_like(preds[..., 0])

        # Construct a 1D Gaussian blur filter.
        assert isinstance(self.kernel_size, int)
        hw = self.kernel_size // 2
        shift = (2 * hw - self.kernel_size + 1) / 2
        assert isinstance(self.sigma, float)
        f_i = (
            (torch.arange(self.kernel_size, device=preds.device) - hw + shift)
            / self.sigma
        ) ** 2
        filt = torch.exp(-0.5 * f_i)
        filt /= torch.sum(filt)

        # Blur in x and y (faster than the 2D convolution).
        def convolve2d(z, m, f):
            # z: (B, H, W, C), m: (B, H, W), f: (Hf, Wf).
z = z.permute(0, 3, 1, 2) m = m[:, None] f = f[None, None].expand(z.shape[1], -1, -1, -1) z_ = torch.nn.functional.conv2d( z * m, f, padding="valid", groups=z.shape[1] ) m_ = torch.nn.functional.conv2d(m, torch.ones_like(f[:1]), padding="valid") return torch.where( m_ != 0, z_ * torch.ones_like(f).sum() / (m_ * z.shape[1]), 0 ).permute(0, 2, 3, 1), (m_ != 0)[:, 0].to(z.dtype) filt_fn1 = lambda z, m: convolve2d(z, m, filt[:, None]) filt_fn2 = lambda z, m: convolve2d(z, m, filt[None, :]) filt_fn = lambda z, m: filt_fn1(*filt_fn2(z, m)) mu0 = filt_fn(preds, masks)[0] mu1 = filt_fn(targets, masks)[0] mu00 = mu0 * mu0 mu11 = mu1 * mu1 mu01 = mu0 * mu1 sigma00 = filt_fn(preds**2, masks)[0] - mu00 sigma11 = filt_fn(targets**2, masks)[0] - mu11 sigma01 = filt_fn(preds * targets, masks)[0] - mu01 # Clip the variances and covariances to valid values. # Variance must be non-negative: sigma00 = sigma00.clamp(min=0.0) sigma11 = sigma11.clamp(min=0.0) sigma01 = torch.sign(sigma01) * torch.minimum( torch.sqrt(sigma00 * sigma11), torch.abs(sigma01) ) assert isinstance(self.data_range, float) c1 = (self.k1 * self.data_range) ** 2 c2 = (self.k2 * self.data_range) ** 2 numer = (2 * mu01 + c1) * (2 * sigma01 + c2) denom = (mu00 + mu11 + c1) * (sigma00 + sigma11 + c2) ssim_map = numer / denom self.similarity.append(ssim_map.mean(dim=(1, 2, 3))) def compute(self) -> torch.Tensor: """Compute final SSIM metric.""" return torch.cat(self.similarity).mean() class mLPIPS(Metric): sum_scores: list[torch.Tensor] total: list[torch.Tensor] def __init__( self, net_type: Literal["vgg", "alex", "squeeze"] = "alex", **kwargs, ): super().__init__(**kwargs) if not _TORCHVISION_AVAILABLE: raise ModuleNotFoundError( "LPIPS metric requires that torchvision is installed." " Either install as `pip install torchmetrics[image]` or `pip install torchvision`." ) valid_net_type = ("vgg", "alex", "squeeze") if net_type not in valid_net_type: raise ValueError( f"Argument `net_type` must be one of {valid_net_type}, but got {net_type}." ) self.net = _NoTrainLpips(net=net_type, spatial=True) self.add_state("sum_scores", [], dist_reduce_fx="cat") self.add_state("total", [], dist_reduce_fx="cat") def __len__(self) -> int: return len(self.total) def update( self, preds: torch.Tensor, targets: torch.Tensor, masks: torch.Tensor | None = None, ): """Update internal states with lpips scores. Args: preds (torch.Tensor): (B, H, W, 3) float32 predicted images. targets (torch.Tensor): (B, H, W, 3) float32 target images. masks (torch.Tensor | None): (B, H, W) optional float32 binary masks where the 1-regions will be taken into account. """ if masks is None: masks = torch.ones_like(preds[..., 0]) scores = self.net( (preds * masks[..., None]).permute(0, 3, 1, 2), (targets * masks[..., None]).permute(0, 3, 1, 2), normalize=True, ) self.sum_scores.append((scores * masks[:, None]).sum()) self.total.append(masks.sum().to(torch.int64)) def compute(self) -> torch.Tensor: """Compute final perceptual similarity metric.""" return ( torch.tensor(self.sum_scores, device=self.device) / torch.tensor(self.total, device=self.device) ).mean() class PCK(Metric): correct: list[torch.Tensor] total: list[int] def __init__(self, **kwargs): super().__init__(**kwargs) self.add_state("correct", default=[], dist_reduce_fx="cat") self.add_state("total", default=[], dist_reduce_fx="cat") def __len__(self) -> int: return len(self.total) def update(self, preds: torch.Tensor, targets: torch.Tensor, threshold: float): """Update internal states with PCK scores. 
Args: preds (torch.Tensor): (N, 2) predicted 2D keypoints. targets (torch.Tensor): (N, 2) targets 2D keypoints. threshold (float): PCK threshold. """ self.correct.append( (torch.linalg.norm(preds - targets, dim=-1) < threshold).sum() ) self.total.append(preds.shape[0]) def compute(self) -> torch.Tensor: """Compute PCK over state.""" return ( torch.tensor(self.correct, device=self.device) / torch.clamp(torch.tensor(self.total, device=self.device), min=1e-8) ).mean() ================================================ FILE: mvtracker/models/core/shape-of-motion/flow3d/params.py ================================================ import math import torch import torch.nn as nn import torch.nn.functional as F from flow3d.transforms import cont_6d_to_rmat class GaussianParams(nn.Module): def __init__( self, means: torch.Tensor, quats: torch.Tensor, scales: torch.Tensor, colors: torch.Tensor, opacities: torch.Tensor, motion_coefs: torch.Tensor | None = None, scene_center: torch.Tensor | None = None, scene_scale: torch.Tensor | float = 1.0, ): super().__init__() if not check_gaussian_sizes( means, quats, scales, colors, opacities, motion_coefs ): import ipdb ipdb.set_trace() params_dict = { "means": nn.Parameter(means), "quats": nn.Parameter(quats), "scales": nn.Parameter(scales), "colors": nn.Parameter(colors), "opacities": nn.Parameter(opacities), } if motion_coefs is not None: params_dict["motion_coefs"] = nn.Parameter(motion_coefs) self.params = nn.ParameterDict(params_dict) self.quat_activation = lambda x: F.normalize(x, dim=-1, p=2) self.color_activation = torch.sigmoid self.scale_activation = torch.exp self.opacity_activation = torch.sigmoid self.motion_coef_activation = lambda x: F.softmax(x, dim=-1) if scene_center is None: scene_center = torch.zeros(3, device=means.device) self.register_buffer("scene_center", scene_center) self.register_buffer("scene_scale", torch.as_tensor(scene_scale)) @staticmethod def init_from_state_dict(state_dict, prefix="params."): req_keys = ["means", "quats", "scales", "colors", "opacities"] assert all(f"{prefix}{k}" in state_dict for k in req_keys) args = { "motion_coefs": None, "scene_center": torch.zeros(3), "scene_scale": torch.tensor(1.0), } for k in req_keys + list(args.keys()): if f"{prefix}{k}" in state_dict: args[k] = state_dict[f"{prefix}{k}"] return GaussianParams(**args) @property def num_gaussians(self) -> int: return self.params["means"].shape[0] def get_colors(self) -> torch.Tensor: return self.color_activation(self.params["colors"]) def get_scales(self) -> torch.Tensor: return self.scale_activation(self.params["scales"]) def get_opacities(self) -> torch.Tensor: return self.opacity_activation(self.params["opacities"]) def get_quats(self) -> torch.Tensor: return self.quat_activation(self.params["quats"]) def get_coefs(self) -> torch.Tensor: assert "motion_coefs" in self.params return self.motion_coef_activation(self.params["motion_coefs"]) def densify_params(self, should_split, should_dup): """ densify gaussians """ updated_params = {} for name, x in self.params.items(): x_dup = x[should_dup] x_split = x[should_split].repeat([2] + [1] * (x.ndim - 1)) if name == "scales": x_split -= math.log(1.6) x_new = nn.Parameter(torch.cat([x[~should_split], x_dup, x_split], dim=0)) updated_params[name] = x_new self.params[name] = x_new return updated_params def cull_params(self, should_cull): """ cull gaussians """ updated_params = {} for name, x in self.params.items(): x_new = nn.Parameter(x[~should_cull]) updated_params[name] = x_new self.params[name] = x_new 
return updated_params def reset_opacities(self, new_val): """ reset all opacities to new_val """ self.params["opacities"].data.fill_(new_val) updated_params = {"opacities": self.params["opacities"]} return updated_params class MotionBases(nn.Module): def __init__(self, rots, transls): super().__init__() self.num_frames = rots.shape[1] self.num_bases = rots.shape[0] assert check_bases_sizes(rots, transls) self.params = nn.ParameterDict( { "rots": nn.Parameter(rots), "transls": nn.Parameter(transls), } ) @staticmethod def init_from_state_dict(state_dict, prefix="params."): param_keys = ["rots", "transls"] assert all(f"{prefix}{k}" in state_dict for k in param_keys) args = {k: state_dict[f"{prefix}{k}"] for k in param_keys} return MotionBases(**args) def compute_transforms(self, ts: torch.Tensor, coefs: torch.Tensor) -> torch.Tensor: """ :param ts (B) :param coefs (G, K) returns transforms (G, B, 3, 4) """ transls = self.params["transls"][:, ts] # (K, B, 3) rots = self.params["rots"][:, ts] # (K, B, 6) transls = torch.einsum("pk,kni->pni", coefs, transls) rots = torch.einsum("pk,kni->pni", coefs, rots) # (G, B, 6) rotmats = cont_6d_to_rmat(rots) # (K, B, 3, 3) return torch.cat([rotmats, transls[..., None]], dim=-1) def check_gaussian_sizes( means: torch.Tensor, quats: torch.Tensor, scales: torch.Tensor, colors: torch.Tensor, opacities: torch.Tensor, motion_coefs: torch.Tensor | None = None, ) -> bool: dims = means.shape[:-1] leading_dims_match = ( quats.shape[:-1] == dims and scales.shape[:-1] == dims and colors.shape[:-1] == dims and opacities.shape == dims ) if motion_coefs is not None and motion_coefs.numel() > 0: leading_dims_match &= motion_coefs.shape[:-1] == dims dims_correct = ( means.shape[-1] == 3 and (quats.shape[-1] == 4) and (scales.shape[-1] == 3) and (colors.shape[-1] == 3) ) return leading_dims_match and dims_correct def check_bases_sizes(motion_rots: torch.Tensor, motion_transls: torch.Tensor) -> bool: return ( motion_rots.shape[-1] == 6 and motion_transls.shape[-1] == 3 and motion_rots.shape[:-2] == motion_transls.shape[:-2] ) ================================================ FILE: mvtracker/models/core/shape-of-motion/flow3d/renderer.py ================================================ import numpy as np import torch import torch.nn.functional as F from loguru import logger as guru from nerfview import CameraState from flow3d.scene_model import SceneModel from flow3d.vis.utils import draw_tracks_2d_th, get_server from flow3d.vis.viewer import DynamicViewer class Renderer: def __init__( self, model: SceneModel, device: torch.device, # Logging. 
work_dir: str, port: int | None = None, ): self.device = device self.model = model self.num_frames = model.num_frames self.work_dir = work_dir self.global_step = 0 self.epoch = 0 self.viewer = None if port is not None: server = get_server(port=port) self.viewer = DynamicViewer( server, self.render_fn, model.num_frames, work_dir, mode="rendering" ) self.tracks_3d = self.model.compute_poses_fg( # torch.arange(max(0, t - 20), max(1, t), device=self.device), torch.arange(self.num_frames, device=self.device), inds=torch.arange(10, device=self.device), )[0] @staticmethod def init_from_checkpoint( path: str, device: torch.device, *args, **kwargs ) -> "Renderer": guru.info(f"Loading checkpoint from {path}") ckpt = torch.load(path) state_dict = ckpt["model"] model = SceneModel.init_from_state_dict(state_dict) model = model.to(device) renderer = Renderer(model, device, *args, **kwargs) renderer.global_step = ckpt.get("global_step", 0) renderer.epoch = ckpt.get("epoch", 0) return renderer @torch.inference_mode() def render_fn(self, camera_state: CameraState, img_wh: tuple[int, int]): if self.viewer is None: return np.full((img_wh[1], img_wh[0], 3), 255, dtype=np.uint8) W, H = img_wh focal = 0.5 * H / np.tan(0.5 * camera_state.fov).item() K = torch.tensor( [[focal, 0.0, W / 2.0], [0.0, focal, H / 2.0], [0.0, 0.0, 1.0]], device=self.device, ) w2c = torch.linalg.inv( torch.from_numpy(camera_state.c2w.astype(np.float32)).to(self.device) ) t = ( int(self.viewer._playback_guis[0].value) if not self.viewer._canonical_checkbox.value else None ) self.model.training = False img = self.model.render(t, w2c[None], K[None], img_wh)["img"][0] if not self.viewer._render_track_checkbox.value: img = (img.cpu().numpy() * 255.0).astype(np.uint8) else: assert t is not None tracks_3d = self.tracks_3d[:, max(0, t - 20) : max(1, t)] tracks_2d = torch.einsum( "ij,jk,nbk->nbi", K, w2c[:3], F.pad(tracks_3d, (0, 1), value=1.0) ) tracks_2d = tracks_2d[..., :2] / tracks_2d[..., 2:] img = draw_tracks_2d_th(img, tracks_2d) return img ================================================ FILE: mvtracker/models/core/shape-of-motion/flow3d/scene_model.py ================================================ import roma import torch import torch.nn as nn import torch.nn.functional as F from gsplat.rendering import rasterization from torch import Tensor from flow3d.params import GaussianParams, MotionBases class SceneModel(nn.Module): def __init__( self, Ks: Tensor, w2cs: Tensor, fg_params: GaussianParams, motion_bases: MotionBases, bg_params: GaussianParams | None = None, ): super().__init__() self.num_frames = motion_bases.num_frames self.fg = fg_params self.motion_bases = motion_bases self.bg = bg_params scene_scale = 1.0 if bg_params is None else bg_params.scene_scale self.register_buffer("bg_scene_scale", torch.as_tensor(scene_scale)) self.register_buffer("Ks", Ks) self.register_buffer("w2cs", w2cs) self._current_xys = None self._current_radii = None self._current_img_wh = None @property def num_gaussians(self) -> int: return self.num_bg_gaussians + self.num_fg_gaussians @property def num_bg_gaussians(self) -> int: return self.bg.num_gaussians if self.bg is not None else 0 @property def num_fg_gaussians(self) -> int: return self.fg.num_gaussians @property def num_motion_bases(self) -> int: return self.motion_bases.num_bases @property def has_bg(self) -> bool: return self.bg is not None def compute_poses_bg(self) -> tuple[torch.Tensor, torch.Tensor]: """ Returns: means: (G, B, 3) quats: (G, B, 4) """ assert self.bg is not None return 
self.bg.params["means"], self.bg.get_quats() def compute_transforms( self, ts: torch.Tensor, inds: torch.Tensor | None = None ) -> torch.Tensor: coefs = self.fg.get_coefs() # (G, K) if inds is not None: coefs = coefs[inds] transfms = self.motion_bases.compute_transforms(ts, coefs) # (G, B, 3, 4) return transfms def compute_poses_fg( self, ts: torch.Tensor | None, inds: torch.Tensor | None = None ) -> tuple[torch.Tensor, torch.Tensor]: """ :returns means: (G, B, 3), quats: (G, B, 4) """ means = self.fg.params["means"] # (G, 3) quats = self.fg.get_quats() # (G, 4) if inds is not None: means = means[inds] quats = quats[inds] if ts is not None: transfms = self.compute_transforms(ts, inds) # (G, B, 3, 4) means = torch.einsum( "pnij,pj->pni", transfms, F.pad(means, (0, 1), value=1.0), ) quats = roma.quat_xyzw_to_wxyz( ( roma.quat_product( roma.rotmat_to_unitquat(transfms[..., :3, :3]), roma.quat_wxyz_to_xyzw(quats[:, None]), ) ) ) quats = F.normalize(quats, p=2, dim=-1) else: means = means[:, None] quats = quats[:, None] return means, quats def compute_poses_all( self, ts: torch.Tensor | None ) -> tuple[torch.Tensor, torch.Tensor]: means, quats = self.compute_poses_fg(ts) if self.has_bg: bg_means, bg_quats = self.compute_poses_bg() means = torch.cat( [means, bg_means[:, None].expand(-1, means.shape[1], -1)], dim=0 ).contiguous() quats = torch.cat( [quats, bg_quats[:, None].expand(-1, means.shape[1], -1)], dim=0 ).contiguous() return means, quats def get_colors_all(self) -> torch.Tensor: colors = self.fg.get_colors() if self.bg is not None: colors = torch.cat([colors, self.bg.get_colors()], dim=0).contiguous() return colors def get_scales_all(self) -> torch.Tensor: scales = self.fg.get_scales() if self.bg is not None: scales = torch.cat([scales, self.bg.get_scales()], dim=0).contiguous() return scales def get_opacities_all(self) -> torch.Tensor: """ :returns colors: (G, 3), scales: (G, 3), opacities: (G, 1) """ opacities = self.fg.get_opacities() if self.bg is not None: opacities = torch.cat( [opacities, self.bg.get_opacities()], dim=0 ).contiguous() return opacities @staticmethod def init_from_state_dict(state_dict, prefix=""): fg = GaussianParams.init_from_state_dict( state_dict, prefix=f"{prefix}fg.params." ) bg = None if any("bg." in k for k in state_dict): bg = GaussianParams.init_from_state_dict( state_dict, prefix=f"{prefix}bg.params." ) motion_bases = MotionBases.init_from_state_dict( state_dict, prefix=f"{prefix}motion_bases.params." ) Ks = state_dict[f"{prefix}Ks"] w2cs = state_dict[f"{prefix}w2cs"] return SceneModel(Ks, w2cs, fg, motion_bases, bg) def render( self, # A single time instance for view rendering. t: int | None, w2cs: torch.Tensor, # (C, 4, 4) Ks: torch.Tensor, # (C, 3, 3) img_wh: tuple[int, int], # Multiple time instances for track rendering: (B,). 
target_ts: torch.Tensor | None = None, # (B) target_w2cs: torch.Tensor | None = None, # (B, 4, 4) bg_color: torch.Tensor | float = 1.0, colors_override: torch.Tensor | None = None, means: torch.Tensor | None = None, quats: torch.Tensor | None = None, target_means: torch.Tensor | None = None, return_color: bool = True, return_depth: bool = False, return_mask: bool = False, fg_only: bool = False, filter_mask: torch.Tensor | None = None, ) -> dict: device = w2cs.device C = w2cs.shape[0] W, H = img_wh pose_fnc = self.compute_poses_fg if fg_only else self.compute_poses_all N = self.num_fg_gaussians if fg_only else self.num_gaussians if means is None or quats is None: means, quats = pose_fnc( torch.tensor([t], device=device) if t is not None else None ) means = means[:, 0] quats = quats[:, 0] if colors_override is None: if return_color: colors_override = ( self.fg.get_colors() if fg_only else self.get_colors_all() ) else: colors_override = torch.zeros(N, 0, device=device) D = colors_override.shape[-1] scales = self.fg.get_scales() if fg_only else self.get_scales_all() opacities = self.fg.get_opacities() if fg_only else self.get_opacities_all() if isinstance(bg_color, float): bg_color = torch.full((C, D), bg_color, device=device) assert isinstance(bg_color, torch.Tensor) mode = "RGB" ds_expected = {"img": D} if return_mask: if self.has_bg and not fg_only: mask_values = torch.zeros((self.num_gaussians, 1), device=device) mask_values[: self.num_fg_gaussians] = 1.0 else: mask_values = torch.ones((self.num_fg_gaussians, 1), device=device) colors_override = torch.cat([colors_override, mask_values], dim=-1) bg_color = torch.cat([bg_color, torch.zeros(C, 1, device=device)], dim=-1) ds_expected["mask"] = 1 B = 0 if target_ts is not None: B = target_ts.shape[0] if target_means is None: target_means, _ = pose_fnc(target_ts) # [G, B, 3] if target_w2cs is not None: target_means = torch.einsum( "bij,pbj->pbi", target_w2cs[:, :3], F.pad(target_means, (0, 1), value=1.0), ) track_3d_vals = target_means.flatten(-2) # (G, B * 3) d_track = track_3d_vals.shape[-1] colors_override = torch.cat([colors_override, track_3d_vals], dim=-1) bg_color = torch.cat( [bg_color, torch.zeros(C, track_3d_vals.shape[-1], device=device)], dim=-1, ) ds_expected["tracks_3d"] = d_track assert colors_override.shape[-1] == sum(ds_expected.values()) assert bg_color.shape[-1] == sum(ds_expected.values()) if return_depth: mode = "RGB+ED" ds_expected["depth"] = 1 if filter_mask is not None: assert filter_mask.shape == (N,) means = means[filter_mask] quats = quats[filter_mask] scales = scales[filter_mask] opacities = opacities[filter_mask] colors_override = colors_override[filter_mask] render_colors, alphas, info = rasterization( means=means, quats=quats, scales=scales, opacities=opacities, colors=colors_override, backgrounds=bg_color, viewmats=w2cs, # [C, 4, 4] Ks=Ks, # [C, 3, 3] width=W, height=H, packed=False, render_mode=mode, ) # Populate the current data for adaptive gaussian control. if self.training and info["means2d"].requires_grad: self._current_xys = info["means2d"] self._current_radii = info["radii"] self._current_img_wh = img_wh # We want to be able to access to xys' gradients later in a # torch.no_grad context. 
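            # (Illustrative note, not in the original file.) means2d is a
            # non-leaf tensor produced by rasterization, so autograd would
            # normally discard its .grad after backward; retain_grad() keeps it
            # so the trainer's adaptive density control can later read the
            # screen-space gradient norms under torch.no_grad.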
self._current_xys.retain_grad() assert render_colors.shape[-1] == sum(ds_expected.values()) outputs = torch.split(render_colors, list(ds_expected.values()), dim=-1) out_dict = {} for i, (name, dim) in enumerate(ds_expected.items()): x = outputs[i] assert x.shape[-1] == dim, f"{x.shape[-1]=} != {dim=}" if name == "tracks_3d": x = x.reshape(C, H, W, B, 3) out_dict[name] = x out_dict["acc"] = alphas return out_dict ================================================ FILE: mvtracker/models/core/shape-of-motion/flow3d/tensor_dataclass.py ================================================ from dataclasses import dataclass from typing import Callable, TypeVar import torch from typing_extensions import Self TensorDataclassT = TypeVar("T", bound="TensorDataclass") class TensorDataclass: """A lighter version of nerfstudio's TensorDataclass: https://github.com/nerfstudio-project/nerfstudio/blob/main/nerfstudio/utils/tensor_dataclass.py """ def __getitem__(self, key) -> Self: return self.map(lambda x: x[key]) def to(self, device: torch.device | str) -> Self: """Move the tensors in the dataclass to the given device. Args: device: The device to move to. Returns: A new dataclass. """ return self.map(lambda x: x.to(device)) def map(self, fn: Callable[[torch.Tensor], torch.Tensor]) -> Self: """Apply a function to all tensors in the dataclass. Also recurses into lists, tuples, and dictionaries. Args: fn: The function to apply to each tensor. Returns: A new dataclass. """ MapT = TypeVar("MapT") def _map_impl( fn: Callable[[torch.Tensor], torch.Tensor], val: MapT, ) -> MapT: if isinstance(val, torch.Tensor): return fn(val) elif isinstance(val, TensorDataclass): return type(val)(**_map_impl(fn, vars(val))) elif isinstance(val, (list, tuple)): return type(val)(_map_impl(fn, v) for v in val) elif isinstance(val, dict): assert type(val) is dict # No subclass support. 
return {k: _map_impl(fn, v) for k, v in val.items()} # type: ignore else: return val return _map_impl(fn, self) @dataclass class TrackObservations(TensorDataclass): xyz: torch.Tensor visibles: torch.Tensor invisibles: torch.Tensor confidences: torch.Tensor colors: torch.Tensor def check_sizes(self) -> bool: dims = self.xyz.shape[:-1] return ( self.visibles.shape == dims and self.invisibles.shape == dims and self.confidences.shape == dims and self.colors.shape[:-1] == dims[:-1] and self.xyz.shape[-1] == 3 and self.colors.shape[-1] == 3 ) def filter_valid(self, valid_mask: torch.Tensor) -> Self: return self.map(lambda x: x[valid_mask]) @dataclass class StaticObservations(TensorDataclass): xyz: torch.Tensor normals: torch.Tensor colors: torch.Tensor def check_sizes(self) -> bool: dims = self.xyz.shape return self.normals.shape == dims and self.colors.shape == dims def filter_valid(self, valid_mask: torch.Tensor) -> Self: return self.map(lambda x: x[valid_mask]) ================================================ FILE: mvtracker/models/core/shape-of-motion/flow3d/trainer.py ================================================ import functools import time from dataclasses import asdict from typing import cast import numpy as np import torch import torch.nn.functional as F from loguru import logger as guru from nerfview import CameraState from pytorch_msssim import SSIM from torch.utils.tensorboard import SummaryWriter # type: ignore from flow3d.configs import LossesConfig, OptimizerConfig, SceneLRConfig from flow3d.loss_utils import ( compute_gradient_loss, compute_se3_smoothness_loss, compute_z_acc_loss, masked_l1_loss, ) from flow3d.metrics import PCK, mLPIPS, mPSNR, mSSIM from flow3d.scene_model import SceneModel from flow3d.vis.utils import get_server from flow3d.vis.viewer import DynamicViewer class Trainer: def __init__( self, model: SceneModel, device: torch.device, lr_cfg: SceneLRConfig, losses_cfg: LossesConfig, optim_cfg: OptimizerConfig, # Logging. 
work_dir: str, port: int | None = None, log_every: int = 10, checkpoint_every: int = 200, validate_every: int = 500, validate_video_every: int = 1000, validate_viewer_assets_every: int = 100, ): self.device = device self.log_every = log_every self.checkpoint_every = checkpoint_every self.validate_every = validate_every self.validate_video_every = validate_video_every self.validate_viewer_assets_every = validate_viewer_assets_every self.model = model self.num_frames = model.num_frames self.lr_cfg = lr_cfg self.losses_cfg = losses_cfg self.optim_cfg = optim_cfg self.reset_opacity_every = ( self.optim_cfg.reset_opacity_every_n_controls * self.optim_cfg.control_every ) self.optimizers, self.scheduler = self.configure_optimizers() # running stats for adaptive density control self.running_stats = { "xys_grad_norm_acc": torch.zeros(self.model.num_gaussians, device=device), "vis_count": torch.zeros( self.model.num_gaussians, device=device, dtype=torch.int64 ), "max_radii": torch.zeros(self.model.num_gaussians, device=device), } self.work_dir = work_dir self.writer = SummaryWriter(log_dir=work_dir) self.global_step = 0 self.epoch = 0 self.viewer = None if port is not None: server = get_server(port=port) self.viewer = DynamicViewer( server, self.render_fn, model.num_frames, work_dir, mode="training" ) # metrics self.ssim = SSIM(data_range=1.0, size_average=True, channel=3) self.psnr_metric = mPSNR() self.ssim_metric = mSSIM() self.lpips_metric = mLPIPS() self.pck_metric = PCK() self.bg_psnr_metric = mPSNR() self.fg_psnr_metric = mPSNR() self.bg_ssim_metric = mSSIM() self.fg_ssim_metric = mSSIM() self.bg_lpips_metric = mLPIPS() self.fg_lpips_metric = mLPIPS() def set_epoch(self, epoch: int): self.epoch = epoch def save_checkpoint(self, path: str): model_dict = self.model.state_dict() optimizer_dict = {k: v.state_dict() for k, v in self.optimizers.items()} scheduler_dict = {k: v.state_dict() for k, v in self.scheduler.items()} ckpt = { "model": model_dict, "optimizers": optimizer_dict, "schedulers": scheduler_dict, "global_step": self.global_step, "epoch": self.epoch, } torch.save(ckpt, path) guru.info(f"Saved checkpoint at {self.global_step=} to {path}") @staticmethod def init_from_checkpoint( path: str, device: torch.device, *args, **kwargs ) -> tuple["Trainer", int]: guru.info(f"Loading checkpoint from {path}") ckpt = torch.load(path) state_dict = ckpt["model"] model = SceneModel.init_from_state_dict(state_dict) model = model.to(device) trainer = Trainer(model, device, *args, **kwargs) if "optimizers" in ckpt: trainer.load_checkpoint_optimizers(ckpt["optimizers"]) if "schedulers" in ckpt: trainer.load_checkpoint_schedulers(ckpt["schedulers"]) trainer.global_step = ckpt.get("global_step", 0) start_epoch = ckpt.get("epoch", 0) trainer.set_epoch(start_epoch) return trainer, start_epoch def load_checkpoint_optimizers(self, opt_ckpt): for k, v in self.optimizers.items(): v.load_state_dict(opt_ckpt[k]) def load_checkpoint_schedulers(self, sched_ckpt): for k, v in self.scheduler.items(): v.load_state_dict(sched_ckpt[k]) @torch.inference_mode() def render_fn(self, camera_state: CameraState, img_wh: tuple[int, int]): W, H = img_wh focal = 0.5 * H / np.tan(0.5 * camera_state.fov).item() K = torch.tensor( [[focal, 0.0, W / 2.0], [0.0, focal, H / 2.0], [0.0, 0.0, 1.0]], device=self.device, ) w2c = torch.linalg.inv( torch.from_numpy(camera_state.c2w.astype(np.float32)).to(self.device) ) t = 0 if self.viewer is not None: t = ( int(self.viewer._playback_guis[0].value) if not self.viewer._canonical_checkbox.value 
else None ) self.model.training = False img = self.model.render(t, w2c[None], K[None], img_wh)["img"][0] return (img.cpu().numpy() * 255.0).astype(np.uint8) def train_step(self, batch): if self.viewer is not None: while self.viewer.state.status == "paused": time.sleep(0.1) self.viewer.lock.acquire() multi_loss = 0.0 # import ipdb # ipdb.set_trace() for view_index in [0, 1, 2, 3]: view_data = batch[view_index] loss, stats, num_rays_per_step, num_rays_per_sec = self.compute_losses(view_data) if loss.isnan(): guru.info(f"Loss is NaN at step {self.global_step}!!") import ipdb ipdb.set_trace() multi_loss += loss / 4 multi_loss.backward() # loss.backward() for opt in self.optimizers.values(): opt.step() opt.zero_grad(set_to_none=True) for sched in self.scheduler.values(): sched.step() self.log_dict(stats) self.global_step += 1 self.run_control_steps() if self.viewer is not None: self.viewer.lock.release() self.viewer.state.num_train_rays_per_sec = num_rays_per_sec if self.viewer.mode == "training": self.viewer.update(self.global_step, num_rays_per_step) if self.global_step % self.checkpoint_every == 0: self.save_checkpoint(f"{self.work_dir}/checkpoints/last.ckpt") # return loss.item() return multi_loss.item() def compute_losses(self, batch): self.model.training = True B = batch["imgs"].shape[0] W, H = img_wh = batch["imgs"].shape[2:0:-1] N = batch["target_ts"][0].shape[0] # (B,). ts = batch["ts"] # (B, 4, 4). w2cs = batch["w2cs"] # (B, 3, 3). Ks = batch["Ks"] # (B, H, W, 3). imgs = batch["imgs"] # (B, H, W). valid_masks = batch.get("valid_masks", torch.ones_like(batch["imgs"][..., 0])) # (B, H, W). masks = batch["masks"] masks *= valid_masks # (B, H, W). depths = batch["depths"] # [(P, 2), ...]. query_tracks_2d = batch["query_tracks_2d"] # [(N,), ...]. target_ts = batch["target_ts"] # [(N, 4, 4), ...]. target_w2cs = batch["target_w2cs"] # [(N, 3, 3), ...]. target_Ks = batch["target_Ks"] # [(N, P, 2), ...]. target_tracks_2d = batch["target_tracks_2d"] # [(N, P), ...]. target_visibles = batch["target_visibles"] # [(N, P), ...]. target_invisibles = batch["target_invisibles"] # [(N, P), ...]. target_confidences = batch["target_confidences"] # [(N, P), ...]. target_track_depths = batch["target_track_depths"] _tic = time.time() # (B, G, 3). means, quats = self.model.compute_poses_all(ts) # (G, B, 3), (G, B, 4) device = means.device means = means.transpose(0, 1) quats = quats.transpose(0, 1) # [(N, G, 3), ...]. target_ts_vec = torch.cat(target_ts) # (B * N, G, 3). target_means, _ = self.model.compute_poses_all(target_ts_vec) target_means = target_means.transpose(0, 1) target_mean_list = target_means.split(N) num_frames = self.model.num_frames loss = 0.0 bg_colors = [] rendered_all = [] self._batched_xys = [] self._batched_radii = [] self._batched_img_wh = [] for i in range(B): bg_color = torch.ones(1, 3, device=device) # import ipdb # ipdb.set_trace() rendered = self.model.render( ts[i].item(), w2cs[None, i], Ks[None, i], img_wh, target_ts=target_ts[i], target_w2cs=target_w2cs[i], bg_color=bg_color, means=means[i], quats=quats[i], target_means=target_mean_list[i].transpose(0, 1), return_depth=True, return_mask=self.model.has_bg, ) rendered_all.append(rendered) bg_colors.append(bg_color) if ( self.model._current_xys is not None and self.model._current_radii is not None and self.model._current_img_wh is not None ): self._batched_xys.append(self.model._current_xys) self._batched_radii.append(self.model._current_radii) self._batched_img_wh.append(self.model._current_img_wh) # Necessary to make viewer work. 
num_rays_per_step = H * W * B num_rays_per_sec = num_rays_per_step / (time.time() - _tic) # (B, H, W, N, *). rendered_all = { key: ( torch.cat([out_dict[key] for out_dict in rendered_all], dim=0) if rendered_all[0][key] is not None else None ) for key in rendered_all[0] } bg_colors = torch.cat(bg_colors, dim=0) # Compute losses. # (B * N). frame_intervals = (ts.repeat_interleave(N) - target_ts_vec).abs() if not self.model.has_bg: imgs = ( imgs * masks[..., None] + (1.0 - masks[..., None]) * bg_colors[:, None, None] ) else: imgs = ( imgs * valid_masks[..., None] + (1.0 - valid_masks[..., None]) * bg_colors[:, None, None] ) # (P_all, 2). tracks_2d = torch.cat([x.reshape(-1, 2) for x in target_tracks_2d], dim=0) # (P_all,) visibles = torch.cat([x.reshape(-1) for x in target_visibles], dim=0) # (P_all,) confidences = torch.cat([x.reshape(-1) for x in target_confidences], dim=0) # RGB loss. rendered_imgs = cast(torch.Tensor, rendered_all["img"]) if self.model.has_bg: rendered_imgs = ( rendered_imgs * valid_masks[..., None] + (1.0 - valid_masks[..., None]) * bg_colors[:, None, None] ) # import cv2 # print(imgs[0].shape) # print(imgs[0].max()) # cv2.imwrite("/cluster/scratch/egundogdu/rendered_image.jpg", ((rendered_imgs[0]*255.0).cpu().detach().numpy()).astype(np.uint8)) # if True: # import ipdb # ipdb.set_trace() rgb_loss = 0.8 * F.l1_loss(rendered_imgs, imgs) + 0.2 * ( 1 - self.ssim(rendered_imgs.permute(0, 3, 1, 2), imgs.permute(0, 3, 1, 2)) ) loss += rgb_loss * self.losses_cfg.w_rgb # Mask loss. if not self.model.has_bg: mask_loss = F.mse_loss(rendered_all["acc"], masks[..., None]) # type: ignore else: mask_loss = F.mse_loss( rendered_all["acc"], torch.ones_like(rendered_all["acc"]) # type: ignore ) + masked_l1_loss( rendered_all["mask"], masks[..., None], quantile=0.98, # type: ignore ) loss += mask_loss * self.losses_cfg.w_mask # (B * N, H * W, 3). pred_tracks_3d = ( rendered_all["tracks_3d"].permute(0, 3, 1, 2, 4).reshape(-1, H * W, 3) # type: ignore ) pred_tracks_2d = torch.einsum( "bij,bpj->bpi", torch.cat(target_Ks), pred_tracks_3d ) # (B * N, H * W, 1). mapped_depth = torch.clamp(pred_tracks_2d[..., 2:], min=1e-6) # (B * N, H * W, 2). pred_tracks_2d = pred_tracks_2d[..., :2] / mapped_depth # (B * N). w_interval = torch.exp(-2 * frame_intervals / num_frames) # w_track_loss = min(1, (self.max_steps - self.global_step) / 6000) track_weights = confidences[..., None] * w_interval # (B, H, W). masks_flatten = torch.zeros_like(masks) for i in range(B): # This takes advantage of the fact that the query 2D tracks are # always on the grid. query_pixels = query_tracks_2d[i].to(torch.int64) masks_flatten[i, query_pixels[:, 1], query_pixels[:, 0]] = 1.0 # (B * N, H * W). 
masks_flatten = ( masks_flatten.reshape(-1, H * W).tile(1, N).reshape(-1, H * W) > 0.5 ) track_2d_loss = masked_l1_loss( pred_tracks_2d[masks_flatten][visibles], tracks_2d[visibles], mask=track_weights[visibles], quantile=0.98, ) / max(H, W) loss += track_2d_loss * self.losses_cfg.w_track depth_masks = ( masks[..., None] if not self.model.has_bg else valid_masks[..., None] ) pred_depth = cast(torch.Tensor, rendered_all["depth"]) pred_disp = 1.0 / (pred_depth + 1e-5) tgt_disp = 1.0 / (depths[..., None] + 1e-5) depth_loss = masked_l1_loss( pred_disp, tgt_disp, mask=depth_masks, quantile=0.98, ) # depth_loss = cauchy_loss_with_uncertainty( # pred_disp.squeeze(-1), # tgt_disp.squeeze(-1), # depth_masks.squeeze(-1), # self.depth_uncertainty_activation(self.depth_uncertainties)[ts], # bias=1e-3, # ) loss += depth_loss * self.losses_cfg.w_depth_reg # mapped depth loss (using cached depth with EMA) # mapped_depth_loss = 0.0 mapped_depth_gt = torch.cat([x.reshape(-1) for x in target_track_depths], dim=0) mapped_depth_loss = masked_l1_loss( 1 / (mapped_depth[masks_flatten][visibles] + 1e-5), 1 / (mapped_depth_gt[visibles, None] + 1e-5), track_weights[visibles], ) loss += mapped_depth_loss * self.losses_cfg.w_depth_const # depth_gradient_loss = 0.0 depth_gradient_loss = compute_gradient_loss( pred_disp, tgt_disp, mask=depth_masks > 0.5, quantile=0.95, ) # depth_gradient_loss = compute_gradient_loss( # pred_disps, # ref_disps, # mask=depth_masks.squeeze(-1) > 0.5, # c=depth_uncertainty.detach(), # mode="l1", # bias=1e-3, # ) loss += depth_gradient_loss * self.losses_cfg.w_depth_grad # bases should be smooth. small_accel_loss = compute_se3_smoothness_loss( self.model.motion_bases.params["rots"], self.model.motion_bases.params["transls"], ) loss += small_accel_loss * self.losses_cfg.w_smooth_bases # tracks should be smooth ts = torch.clamp(ts, min=1, max=num_frames - 2) ts_neighbors = torch.cat((ts - 1, ts, ts + 1)) transfms_nbs = self.model.compute_transforms(ts_neighbors) # (G, 3n, 3, 4) means_fg_nbs = torch.einsum( "pnij,pj->pni", transfms_nbs, F.pad(self.model.fg.params["means"], (0, 1), value=1.0), ) means_fg_nbs = means_fg_nbs.reshape( means_fg_nbs.shape[0], 3, -1, 3 ) # [G, 3, n, 3] if self.losses_cfg.w_smooth_tracks > 0: small_accel_loss_tracks = 0.5 * ( (2 * means_fg_nbs[:, 1:-1] - means_fg_nbs[:, :-2] - means_fg_nbs[:, 2:]) .norm(dim=-1) .mean() ) loss += small_accel_loss_tracks * self.losses_cfg.w_smooth_tracks # Constrain the std of scales. # TODO: do we want to penalize before or after exp? loss += ( self.losses_cfg.w_scale_var * torch.var(self.model.fg.params["scales"], dim=-1).mean() ) if self.model.bg is not None: loss += ( self.losses_cfg.w_scale_var * torch.var(self.model.bg.params["scales"], dim=-1).mean() ) # # sparsity loss # loss += 0.01 * self.opacity_activation(self.opacities).abs().mean() # Acceleration along ray direction should be small. z_accel_loss = compute_z_acc_loss(means_fg_nbs, w2cs) loss += self.losses_cfg.w_z_accel * z_accel_loss # Prepare stats for logging. 
stats = { "train/loss": loss.item(), "train/rgb_loss": rgb_loss.item(), "train/mask_loss": mask_loss.item(), "train/depth_loss": depth_loss.item(), "train/depth_gradient_loss": depth_gradient_loss.item(), "train/mapped_depth_loss": mapped_depth_loss.item(), "train/track_2d_loss": track_2d_loss.item(), "train/small_accel_loss": small_accel_loss.item(), "train/z_acc_loss": z_accel_loss.item(), "train/num_gaussians": self.model.num_gaussians, "train/num_fg_gaussians": self.model.num_fg_gaussians, "train/num_bg_gaussians": self.model.num_bg_gaussians, } # Compute metrics. with torch.no_grad(): psnr = self.psnr_metric( rendered_imgs, imgs, masks if not self.model.has_bg else valid_masks ) self.psnr_metric.reset() stats["train/psnr"] = psnr if self.model.has_bg: bg_psnr = self.bg_psnr_metric(rendered_imgs, imgs, 1.0 - masks) fg_psnr = self.fg_psnr_metric(rendered_imgs, imgs, masks) self.bg_psnr_metric.reset() self.fg_psnr_metric.reset() stats["train/bg_psnr"] = bg_psnr stats["train/fg_psnr"] = fg_psnr stats.update( **{ "train/num_rays_per_sec": num_rays_per_sec, "train/num_rays_per_step": float(num_rays_per_step), } ) # print(stats) return loss, stats, num_rays_per_step, num_rays_per_sec def log_dict(self, stats: dict): for k, v in stats.items(): self.writer.add_scalar(k, v, self.global_step) def run_control_steps(self): global_step = self.global_step # Adaptive gaussian control. cfg = self.optim_cfg num_frames = self.model.num_frames ready = self._prepare_control_step() if ( ready and global_step > cfg.warmup_steps and global_step % cfg.control_every == 0 and global_step < cfg.stop_control_steps ): if ( global_step < cfg.stop_densify_steps and global_step % self.reset_opacity_every > num_frames ): self._densify_control_step(global_step) if global_step % self.reset_opacity_every > min(3 * num_frames, 1000): self._cull_control_step(global_step) if global_step % self.reset_opacity_every == 0: self._reset_opacity_control_step() # Reset stats after every control. for k in self.running_stats: self.running_stats[k].zero_() @torch.no_grad() def _prepare_control_step(self) -> bool: # Prepare for adaptive gaussian control based on the current stats. 
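# The accumulation below aggregates, per Gaussian and across all rendered
# views, (1) the norm of the screen-space positional gradient rescaled by the
# image width/height, (2) a visibility count, and (3) the maximum observed
# screen-space radius relative to the image size. These running statistics
# feed the densify / cull decisions triggered by run_control_steps above.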
if not ( self.model._current_radii is not None and self.model._current_xys is not None ): guru.warning("Model not training, skipping control step preparation") return False batch_size = len(self._batched_xys) # these quantities are for each rendered view and have shapes (C, G, *) # must be aggregated over all views for _current_xys, _current_radii, _current_img_wh in zip( self._batched_xys, self._batched_radii, self._batched_img_wh ): sel = _current_radii > 0 gidcs = torch.where(sel)[1] # normalize grads to [-1, 1] screen space xys_grad = _current_xys.grad.clone() xys_grad[..., 0] *= _current_img_wh[0] / 2.0 * batch_size xys_grad[..., 1] *= _current_img_wh[1] / 2.0 * batch_size self.running_stats["xys_grad_norm_acc"].index_add_( 0, gidcs, xys_grad[sel].norm(dim=-1) ) self.running_stats["vis_count"].index_add_( 0, gidcs, torch.ones_like(gidcs, dtype=torch.int64) ) max_radii = torch.maximum( self.running_stats["max_radii"].index_select(0, gidcs), _current_radii[sel] / max(_current_img_wh), ) self.running_stats["max_radii"].index_put((gidcs,), max_radii) return True @torch.no_grad() def _densify_control_step(self, global_step): assert (self.running_stats["vis_count"] > 0).any() cfg = self.optim_cfg xys_grad_avg = self.running_stats["xys_grad_norm_acc"] / self.running_stats[ "vis_count" ].clamp_min(1) is_grad_too_high = xys_grad_avg > cfg.densify_xys_grad_threshold # Split gaussians. scales = self.model.get_scales_all() is_scale_too_big = scales.amax(dim=-1) > cfg.densify_scale_threshold if global_step < cfg.stop_control_by_screen_steps: is_radius_too_big = ( self.running_stats["max_radii"] > cfg.densify_screen_threshold ) else: is_radius_too_big = torch.zeros_like(is_grad_too_high, dtype=torch.bool) should_split = is_grad_too_high & (is_scale_too_big | is_radius_too_big) should_dup = is_grad_too_high & ~is_scale_too_big num_fg = self.model.num_fg_gaussians should_fg_split = should_split[:num_fg] num_fg_splits = int(should_fg_split.sum().item()) should_fg_dup = should_dup[:num_fg] num_fg_dups = int(should_fg_dup.sum().item()) should_bg_split = should_split[num_fg:] num_bg_splits = int(should_bg_split.sum().item()) should_bg_dup = should_dup[num_fg:] num_bg_dups = int(should_bg_dup.sum().item()) fg_param_map = self.model.fg.densify_params(should_fg_split, should_fg_dup) for param_name, new_params in fg_param_map.items(): full_param_name = f"fg.params.{param_name}" optimizer = self.optimizers[full_param_name] dup_in_optim( optimizer, [new_params], should_fg_split, num_fg_splits * 2 + num_fg_dups, ) if self.model.bg is not None: bg_param_map = self.model.bg.densify_params(should_bg_split, should_bg_dup) for param_name, new_params in bg_param_map.items(): full_param_name = f"bg.params.{param_name}" optimizer = self.optimizers[full_param_name] dup_in_optim( optimizer, [new_params], should_bg_split, num_bg_splits * 2 + num_bg_dups, ) # update running stats for k, v in self.running_stats.items(): v_fg, v_bg = v[:num_fg], v[num_fg:] new_v = torch.cat( [ v_fg[~should_fg_split], v_fg[should_fg_dup], v_fg[should_fg_split].repeat(2), v_bg[~should_bg_split], v_bg[should_bg_dup], v_bg[should_bg_split].repeat(2), ], dim=0, ) self.running_stats[k] = new_v guru.info( f"Split {should_split.sum().item()} gaussians, " f"Duplicated {should_dup.sum().item()} gaussians, " f"{self.model.num_gaussians} gaussians left" ) @torch.no_grad() def _cull_control_step(self, global_step): # Cull gaussians. 
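# A Gaussian is culled when its opacity falls below cull_opacity_threshold,
# when (after the first opacity reset) its largest scale exceeds
# cull_scale_threshold (relaxed by bg_scene_scale for background Gaussians),
# or, while screen-space control is still active, when its maximum observed
# screen radius exceeds cull_screen_threshold. Culled Gaussians are removed
# from the model parameters, the matching optimizer state, and running_stats.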
cfg = self.optim_cfg opacities = self.model.get_opacities_all() device = opacities.device is_opacity_too_small = opacities < cfg.cull_opacity_threshold is_radius_too_big = torch.zeros_like(is_opacity_too_small, dtype=torch.bool) is_scale_too_big = torch.zeros_like(is_opacity_too_small, dtype=torch.bool) cull_scale_threshold = ( torch.ones(len(is_scale_too_big), device=device) * cfg.cull_scale_threshold ) num_fg = self.model.num_fg_gaussians cull_scale_threshold[num_fg:] *= self.model.bg_scene_scale if global_step > self.reset_opacity_every: scales = self.model.get_scales_all() is_scale_too_big = scales.amax(dim=-1) > cull_scale_threshold if global_step < cfg.stop_control_by_screen_steps: is_radius_too_big = ( self.running_stats["max_radii"] > cfg.cull_screen_threshold ) should_cull = is_opacity_too_small | is_radius_too_big | is_scale_too_big should_fg_cull = should_cull[:num_fg] should_bg_cull = should_cull[num_fg:] fg_param_map = self.model.fg.cull_params(should_fg_cull) for param_name, new_params in fg_param_map.items(): full_param_name = f"fg.params.{param_name}" optimizer = self.optimizers[full_param_name] remove_from_optim(optimizer, [new_params], should_fg_cull) if self.model.bg is not None: bg_param_map = self.model.bg.cull_params(should_bg_cull) for param_name, new_params in bg_param_map.items(): full_param_name = f"bg.params.{param_name}" optimizer = self.optimizers[full_param_name] remove_from_optim(optimizer, [new_params], should_bg_cull) # update running stats for k, v in self.running_stats.items(): self.running_stats[k] = v[~should_cull] guru.info( f"Culled {should_cull.sum().item()} gaussians, " f"{self.model.num_gaussians} gaussians left" ) @torch.no_grad() def _reset_opacity_control_step(self): # Reset gaussian opacities. new_val = torch.logit(torch.tensor(0.8 * self.optim_cfg.cull_opacity_threshold)) for part in ["fg", "bg"]: part_params = getattr(self.model, part).reset_opacities(new_val) # Modify optimizer states by new assignment. for param_name, new_params in part_params.items(): full_param_name = f"{part}.params.{param_name}" optimizer = self.optimizers[full_param_name] reset_in_optim(optimizer, [new_params]) guru.info("Reset opacities") def configure_optimizers(self): def _exponential_decay(step, *, lr_init, lr_final): t = np.clip(step / self.optim_cfg.max_steps, 0.0, 1.0) lr = np.exp(np.log(lr_init) * (1 - t) + np.log(lr_final) * t) return lr / lr_init lr_dict = asdict(self.lr_cfg) optimizers = {} schedulers = {} # named parameters will be [part].params.[field] # e.g. 
fg.params.means # lr config is a nested dict for each fg/bg part for name, params in self.model.named_parameters(): part, _, field = name.split(".") lr = lr_dict[part][field] optim = torch.optim.Adam([{"params": params, "lr": lr, "name": name}]) if "scales" in name: fnc = functools.partial(_exponential_decay, lr_final=0.1 * lr) else: fnc = lambda _, **__: 1.0 optimizers[name] = optim schedulers[name] = torch.optim.lr_scheduler.LambdaLR( optim, functools.partial(fnc, lr_init=lr) ) return optimizers, schedulers def dup_in_optim(optimizer, new_params: list, should_dup: torch.Tensor, num_dups: int): assert len(optimizer.param_groups) == len(new_params) for i, p_new in enumerate(new_params): old_params = optimizer.param_groups[i]["params"][0] param_state = optimizer.state[old_params] if len(param_state) == 0: return for key in param_state: if key == "step": continue p = param_state[key] param_state[key] = torch.cat( [p[~should_dup], p.new_zeros(num_dups, *p.shape[1:])], dim=0, ) del optimizer.state[old_params] optimizer.state[p_new] = param_state optimizer.param_groups[i]["params"] = [p_new] del old_params torch.cuda.empty_cache() def remove_from_optim(optimizer, new_params: list, _should_cull: torch.Tensor): assert len(optimizer.param_groups) == len(new_params) for i, p_new in enumerate(new_params): old_params = optimizer.param_groups[i]["params"][0] param_state = optimizer.state[old_params] if len(param_state) == 0: return for key in param_state: if key == "step": continue param_state[key] = param_state[key][~_should_cull] del optimizer.state[old_params] optimizer.state[p_new] = param_state optimizer.param_groups[i]["params"] = [p_new] del old_params torch.cuda.empty_cache() def reset_in_optim(optimizer, new_params: list): assert len(optimizer.param_groups) == len(new_params) for i, p_new in enumerate(new_params): old_params = optimizer.param_groups[i]["params"][0] param_state = optimizer.state[old_params] if len(param_state) == 0: return for key in param_state: param_state[key] = torch.zeros_like(param_state[key]) del optimizer.state[old_params] optimizer.state[p_new] = param_state optimizer.param_groups[i]["params"] = [p_new] del old_params torch.cuda.empty_cache() ================================================ FILE: mvtracker/models/core/shape-of-motion/flow3d/trajectories.py ================================================ import numpy as np import roma import torch import torch.nn.functional as F from .transforms import rt_to_mat4 def get_avg_w2c(w2cs: torch.Tensor): c2ws = torch.linalg.inv(w2cs) # 1. Compute the center center = c2ws[:, :3, -1].mean(0) # 2. Compute the z axis z = F.normalize(c2ws[:, :3, 2].mean(0), dim=-1) # 3. Compute axis y' (no need to normalize as it's not the final output) y_ = c2ws[:, :3, 1].mean(0) # (3) # 4. Compute the x axis x = F.normalize(torch.cross(y_, z, dim=-1), dim=-1) # (3) # 5. Compute the y axis (as z and x are normalized, y is already of norm 1) y = torch.cross(z, x, dim=-1) # (3) avg_c2w = rt_to_mat4(torch.stack([x, y, z], 1), center) avg_w2c = torch.linalg.inv(avg_c2w) return avg_w2c def get_lookat(origins: torch.Tensor, viewdirs: torch.Tensor) -> torch.Tensor: """Triangulate a set of rays to find a single lookat point. Args: origins (torch.Tensor): A (N, 3) array of ray origins. viewdirs (torch.Tensor): A (N, 3) array of ray view directions. Returns: torch.Tensor: A (3,) lookat point. 
""" viewdirs = torch.nn.functional.normalize(viewdirs, dim=-1) eye = torch.eye(3, device=origins.device, dtype=origins.dtype)[None] # Calculate projection matrix I - rr^T I_min_cov = eye - (viewdirs[..., None] * viewdirs[..., None, :]) # Compute sum of projections sum_proj = I_min_cov.matmul(origins[..., None]).sum(dim=-3) # Solve for the intersection point using least squares lookat = torch.linalg.lstsq(I_min_cov.sum(dim=-3), sum_proj).solution[..., 0] # Check NaNs. assert not torch.any(torch.isnan(lookat)) return lookat def get_lookat_w2cs(positions: torch.Tensor, lookat: torch.Tensor, up: torch.Tensor): """ Args: positions: (N, 3) tensor of camera positions lookat: (3,) tensor of lookat point up: (3,) tensor of up vector Returns: w2cs: (N, 3, 3) tensor of world to camera rotation matrices """ forward_vectors = F.normalize(lookat - positions, dim=-1) right_vectors = F.normalize(torch.cross(forward_vectors, up[None], dim=-1), dim=-1) down_vectors = F.normalize( torch.cross(forward_vectors, right_vectors, dim=-1), dim=-1 ) Rs = torch.stack([right_vectors, down_vectors, forward_vectors], dim=-1) w2cs = torch.linalg.inv(rt_to_mat4(Rs, positions)) return w2cs def get_arc_w2cs( ref_w2c: torch.Tensor, lookat: torch.Tensor, up: torch.Tensor, num_frames: int, degree: float, **_, ) -> torch.Tensor: ref_position = torch.linalg.inv(ref_w2c)[:3, 3] thetas = ( torch.sin( torch.linspace(0.0, torch.pi * 2.0, num_frames + 1, device=ref_w2c.device)[ :-1 ] ) * (degree / 2.0) / 180.0 * torch.pi ) positions = torch.einsum( "nij,j->ni", roma.rotvec_to_rotmat(thetas[:, None] * up[None]), ref_position - lookat, ) return get_lookat_w2cs(positions, lookat, up) def get_lemniscate_w2cs( ref_w2c: torch.Tensor, lookat: torch.Tensor, up: torch.Tensor, num_frames: int, degree: float, **_, ) -> torch.Tensor: ref_c2w = torch.linalg.inv(ref_w2c) a = torch.linalg.norm(ref_c2w[:3, 3] - lookat) * np.tan(degree / 360 * np.pi) # Lemniscate curve in camera space. Starting at the origin. thetas = ( torch.linspace(0, 2 * torch.pi, num_frames + 1, device=ref_w2c.device)[:-1] + torch.pi / 2 ) positions = torch.stack( [ a * torch.cos(thetas) / (1 + torch.sin(thetas) ** 2), a * torch.cos(thetas) * torch.sin(thetas) / (1 + torch.sin(thetas) ** 2), torch.zeros(num_frames, device=ref_w2c.device), ], dim=-1, ) # Transform to world space. positions = torch.einsum( "ij,nj->ni", ref_c2w[:3], F.pad(positions, (0, 1), value=1.0) ) return get_lookat_w2cs(positions, lookat, up) def get_spiral_w2cs( ref_w2c: torch.Tensor, lookat: torch.Tensor, up: torch.Tensor, num_frames: int, rads: float | torch.Tensor, zrate: float, rots: int, **_, ) -> torch.Tensor: ref_c2w = torch.linalg.inv(ref_w2c) thetas = torch.linspace( 0, 2 * torch.pi * rots, num_frames + 1, device=ref_w2c.device )[:-1] # Spiral curve in camera space. Starting at the origin. if isinstance(rads, torch.Tensor): rads = rads.reshape(-1, 3).to(ref_w2c.device) positions = ( torch.stack( [ torch.cos(thetas), -torch.sin(thetas), -torch.sin(thetas * zrate), ], dim=-1, ) * rads ) # Transform to world space. 
positions = torch.einsum( "ij,nj->ni", ref_c2w[:3], F.pad(positions, (0, 1), value=1.0) ) return get_lookat_w2cs(positions, lookat, up) def get_wander_w2cs(ref_w2c, focal_length, num_frames, **_): device = ref_w2c.device c2w = np.linalg.inv(ref_w2c.detach().cpu().numpy()) max_disp = 48.0 max_trans = max_disp / focal_length output_poses = [] for i in range(num_frames): x_trans = max_trans * np.sin(2.0 * np.pi * float(i) / float(num_frames)) y_trans = 0.0 z_trans = max_trans * np.cos(2.0 * np.pi * float(i) / float(num_frames)) / 2.0 i_pose = np.concatenate( [ np.concatenate( [ np.eye(3), np.array([x_trans, y_trans, z_trans])[:, np.newaxis], ], axis=1, ), np.array([0.0, 0.0, 0.0, 1.0])[np.newaxis, :], ], axis=0, ) i_pose = np.linalg.inv(i_pose) ref_pose = np.concatenate( [c2w[:3, :4], np.array([0.0, 0.0, 0.0, 1.0])[np.newaxis, :]], axis=0 ) render_pose = np.dot(ref_pose, i_pose) output_poses.append(render_pose) output_poses = torch.from_numpy(np.array(output_poses, dtype=np.float32)).to(device) w2cs = torch.linalg.inv(output_poses) return w2cs ================================================ FILE: mvtracker/models/core/shape-of-motion/flow3d/transforms.py ================================================ from typing import Literal import roma import torch import torch.nn.functional as F def rt_to_mat4( R: torch.Tensor, t: torch.Tensor, s: torch.Tensor | None = None ) -> torch.Tensor: """ Args: R (torch.Tensor): (..., 3, 3). t (torch.Tensor): (..., 3). s (torch.Tensor): (...,). Returns: torch.Tensor: (..., 4, 4) """ mat34 = torch.cat([R, t[..., None]], dim=-1) if s is None: bottom = ( mat34.new_tensor([[0.0, 0.0, 0.0, 1.0]]) .reshape((1,) * (mat34.dim() - 2) + (1, 4)) .expand(mat34.shape[:-2] + (1, 4)) ) else: bottom = F.pad(1.0 / s[..., None, None], (3, 0), value=0.0) mat4 = torch.cat([mat34, bottom], dim=-2) return mat4 def rmat_to_cont_6d(matrix): """ :param matrix (*, 3, 3) :returns 6d vector (*, 6) """ return torch.cat([matrix[..., 0], matrix[..., 1]], dim=-1) def cont_6d_to_rmat(cont_6d): """ :param 6d vector (*, 6) :returns matrix (*, 3, 3) """ x1 = cont_6d[..., 0:3] y1 = cont_6d[..., 3:6] x = F.normalize(x1, dim=-1) y = F.normalize(y1 - (y1 * x).sum(dim=-1, keepdim=True) * x, dim=-1) z = torch.linalg.cross(x, y, dim=-1) return torch.stack([x, y, z], dim=-1) def solve_procrustes( src: torch.Tensor, dst: torch.Tensor, weights: torch.Tensor | None = None, enforce_se3: bool = False, rot_type: Literal["quat", "mat", "6d"] = "quat", ): """ Solve the Procrustes problem to align two point clouds, by solving the following problem: min_{s, R, t} || s * (src @ R.T + t) - dst ||_2, s.t. R.T @ R = I and det(R) = 1. Args: src (torch.Tensor): (N, 3). dst (torch.Tensor): (N, 3). weights (torch.Tensor | None): (N,), optional weights for alignment. enforce_se3 (bool): Whether to enforce the transfm to be SE3. Returns: sim3 (tuple[torch.Tensor, torch.Tensor, torch.Tensor]): q (torch.Tensor): (4,), rotation component in quaternion of WXYZ format. t (torch.Tensor): (3,), translation component. s (torch.Tensor): (), scale component. error (torch.Tensor): (), average L2 distance after alignment. """ # Compute weights. if weights is None: weights = src.new_ones(src.shape[0]) weights = weights[:, None] / weights.sum() # Normalize point positions. src_mean = (src * weights).sum(dim=0) dst_mean = (dst * weights).sum(dim=0) src_cent = src - src_mean dst_cent = dst - dst_mean # Normalize point scales. 
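# The remainder of this function follows the standard weighted Procrustes /
# Umeyama recipe: after removing the weighted centroids (above), both point
# sets are optionally rescaled to unit RMS size, an SVD of the weighted
# covariance dst^T W src yields the rotation (with a diagonal sign flip to
# rule out reflections, so det(R) = +1), and the scale and translation are
# recovered from the centroid statistics. With enforce_se3=True the scale is
# pinned to 1 and a purely rigid transform is returned.
#
# Minimal usage sketch (hypothetical tensors, kept as a comment so the module
# itself is unchanged; the exact numbers are only illustrative):
#
#   src = torch.randn(100, 3)
#   R_gt = roma.rotvec_to_rotmat(torch.tensor([0.0, 0.0, 0.3]))
#   dst = 2.0 * src @ R_gt.T + torch.tensor([1.0, 0.0, 0.0])
#   (rot, t, s), (err, err_before) = solve_procrustes(src, dst)
#   # err should be close to 0, s close to 2, and rot a WXYZ quaternion.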
if not enforce_se3: src_scale = (src_cent**2 * weights).sum(dim=-1).mean().sqrt() dst_scale = (dst_cent**2 * weights).sum(dim=-1).mean().sqrt() else: src_scale = dst_scale = src.new_tensor(1.0) src_scaled = src_cent / src_scale dst_scaled = dst_cent / dst_scale # Compute the matrix for the singular value decomposition (SVD). matrix = (weights * dst_scaled).T @ src_scaled U, _, Vh = torch.linalg.svd(matrix) # Special reflection case. S = torch.eye(3, device=src.device) if torch.det(U) * torch.det(Vh) < 0: S[2, 2] = -1 R = U @ S @ Vh # Compute the transformation. if rot_type == "quat": rot = roma.rotmat_to_unitquat(R).roll(1, dims=-1) elif rot_type == "6d": rot = rmat_to_cont_6d(R) else: rot = R s = dst_scale / src_scale t = dst_mean / s - src_mean @ R.T sim3 = rot, t, s # Debug: error. procrustes_dst = torch.einsum( "ij,nj->ni", rt_to_mat4(R, t, s), F.pad(src, (0, 1), value=1.0) ) procrustes_dst = procrustes_dst[:, :3] / procrustes_dst[:, 3:] error_before = (torch.linalg.norm(dst - src, dim=-1) * weights[:, 0]).sum() error = (torch.linalg.norm(dst - procrustes_dst, dim=-1) * weights[:, 0]).sum() # print(f"Procrustes error: {error_before} -> {error}") # if error_before < error: # print("Something is wrong.") # __import__("ipdb").set_trace() return sim3, (error.item(), error_before.item()) ================================================ FILE: mvtracker/models/core/shape-of-motion/flow3d/validator.py ================================================ import functools import os import os.path as osp import time from dataclasses import asdict from typing import cast import imageio as iio import numpy as np import torch import torch.nn.functional as F from loguru import logger as guru from nerfview import CameraState, Viewer from pytorch_msssim import SSIM from torch.utils.data import DataLoader, Dataset from torch.utils.tensorboard import SummaryWriter from tqdm import tqdm from flow3d.configs import LossesConfig, OptimizerConfig, SceneLRConfig from flow3d.data.utils import normalize_coords, to_device from flow3d.metrics import PCK, mLPIPS, mPSNR, mSSIM from flow3d.scene_model import SceneModel from flow3d.vis.utils import ( apply_depth_colormap, make_video_divisble, plot_correspondences, ) class Validator: def __init__( self, model: SceneModel, device: torch.device, train_loader: DataLoader | None, val_img_loader: DataLoader | None, val_kpt_loader: DataLoader | None, save_dir: str, ): self.model = model self.device = device self.train_loader = train_loader self.val_img_loader = val_img_loader self.val_kpt_loader = val_kpt_loader self.save_dir = save_dir self.has_bg = self.model.has_bg # metrics self.ssim = SSIM(data_range=1.0, size_average=True, channel=3) self.psnr_metric = mPSNR() self.ssim_metric = mSSIM() self.lpips_metric = mLPIPS().to(device) self.fg_psnr_metric = mPSNR() self.fg_ssim_metric = mSSIM() self.fg_lpips_metric = mLPIPS().to(device) self.bg_psnr_metric = mPSNR() self.bg_ssim_metric = mSSIM() self.bg_lpips_metric = mLPIPS().to(device) self.pck_metric = PCK() def reset_metrics(self): self.psnr_metric.reset() self.ssim_metric.reset() self.lpips_metric.reset() self.fg_psnr_metric.reset() self.fg_ssim_metric.reset() self.fg_lpips_metric.reset() self.bg_psnr_metric.reset() self.bg_ssim_metric.reset() self.bg_lpips_metric.reset() self.pck_metric.reset() @torch.no_grad() def validate(self): self.reset_metrics() metric_imgs = self.validate_imgs() or {} metric_kpts = self.validate_keypoints() or {} return {**metric_imgs, **metric_kpts} @torch.no_grad() def validate_imgs(self): 
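# Renders every validation frame from its ground-truth camera, accumulates
# PSNR / SSIM / LPIPS (with separate foreground / background splits when a
# background model is present), and writes the rendered RGB frames to
# <save_dir>/results/rgb/ before returning the aggregated "val/..." metrics.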
guru.info("rendering validation images...") if self.val_img_loader is None: return for batch in tqdm(self.val_img_loader, desc="render val images"): batch = to_device(batch, self.device) frame_name = batch["frame_names"][0] t = batch["ts"][0] # (1, 4, 4). w2c = batch["w2cs"] # (1, 3, 3). K = batch["Ks"] # (1, H, W, 3). img = batch["imgs"] # (1, H, W). valid_mask = batch.get( "valid_masks", torch.ones_like(batch["imgs"][..., 0]) ) # (1, H, W). fg_mask = batch["masks"] # (H, W). covisible_mask = batch.get( "covisible_masks", torch.ones_like(fg_mask)[None], ) W, H = img_wh = img[0].shape[-2::-1] rendered = self.model.render(t, w2c, K, img_wh, return_depth=True) # Compute metrics. valid_mask *= covisible_mask fg_valid_mask = fg_mask * valid_mask bg_valid_mask = (1 - fg_mask) * valid_mask main_valid_mask = valid_mask if self.has_bg else fg_valid_mask self.psnr_metric.update(rendered["img"], img, main_valid_mask) self.ssim_metric.update(rendered["img"], img, main_valid_mask) self.lpips_metric.update(rendered["img"], img, main_valid_mask) if self.has_bg: self.fg_psnr_metric.update(rendered["img"], img, fg_valid_mask) self.fg_ssim_metric.update(rendered["img"], img, fg_valid_mask) self.fg_lpips_metric.update(rendered["img"], img, fg_valid_mask) self.bg_psnr_metric.update(rendered["img"], img, bg_valid_mask) self.bg_ssim_metric.update(rendered["img"], img, bg_valid_mask) self.bg_lpips_metric.update(rendered["img"], img, bg_valid_mask) # Dump results. results_dir = osp.join(self.save_dir, "results", "rgb") os.makedirs(results_dir, exist_ok=True) iio.imwrite( osp.join(results_dir, f"{frame_name}.png"), (rendered["img"][0].cpu().numpy() * 255).astype(np.uint8), ) return { "val/psnr": self.psnr_metric.compute(), "val/ssim": self.ssim_metric.compute(), "val/lpips": self.lpips_metric.compute(), "val/fg_psnr": self.fg_psnr_metric.compute(), "val/fg_ssim": self.fg_ssim_metric.compute(), "val/fg_lpips": self.fg_lpips_metric.compute(), "val/bg_psnr": self.bg_psnr_metric.compute(), "val/bg_ssim": self.bg_ssim_metric.compute(), "val/bg_lpips": self.bg_lpips_metric.compute(), } @torch.no_grad() def validate_keypoints(self): if self.val_kpt_loader is None: return pred_keypoints_3d_all = [] time_ids = self.val_kpt_loader.dataset.time_ids.tolist() h, w = self.val_kpt_loader.dataset.dataset.imgs.shape[1:3] pred_train_depths = np.zeros((len(time_ids), h, w)) for batch in tqdm(self.val_kpt_loader, desc="render val keypoints"): batch = to_device(batch, self.device) # (2,). ts = batch["ts"][0] # (2, 4, 4). w2cs = batch["w2cs"][0] # (2, 3, 3). Ks = batch["Ks"][0] # (2, H, W, 3). imgs = batch["imgs"][0] # (2, P, 3). keypoints = batch["keypoints"][0] # (P,) keypoint_masks = (keypoints[..., -1] > 0.5).all(dim=0) src_keypoints, target_keypoints = keypoints[:, keypoint_masks, :2] W, H = img_wh = imgs.shape[-2:0:-1] rendered = self.model.render( ts[0].item(), w2cs[:1], Ks[:1], img_wh, target_ts=ts[1:], target_w2cs=w2cs[1:], return_depth=True, ) pred_tracks_3d = rendered["tracks_3d"][0, ..., 0, :] pred_tracks_2d = torch.einsum("ij,hwj->hwi", Ks[1], pred_tracks_3d) pred_tracks_2d = pred_tracks_2d[..., :2] / torch.clamp( pred_tracks_2d[..., -1:], min=1e-6 ) pred_keypoints = F.grid_sample( pred_tracks_2d[None].permute(0, 3, 1, 2), normalize_coords(src_keypoints, H, W)[None, None], align_corners=True, ).permute(0, 2, 3, 1)[0, 0] # Compute metrics. 
self.pck_metric.update(pred_keypoints, target_keypoints, max(img_wh) * 0.05) padded_keypoints_3d = torch.zeros_like(keypoints[0]) pred_keypoints_3d = F.grid_sample( pred_tracks_3d[None].permute(0, 3, 1, 2), normalize_coords(src_keypoints, H, W)[None, None], align_corners=True, ).permute(0, 2, 3, 1)[0, 0] # Transform 3D keypoints back to world space. pred_keypoints_3d = torch.einsum( "ij,pj->pi", torch.linalg.inv(w2cs[1])[:3], F.pad(pred_keypoints_3d, (0, 1), value=1.0), ) padded_keypoints_3d[keypoint_masks] = pred_keypoints_3d # Cache predicted keypoints. pred_keypoints_3d_all.append(padded_keypoints_3d.cpu().numpy()) pred_train_depths[time_ids.index(ts[0].item())] = ( rendered["depth"][0, ..., 0].cpu().numpy() ) # Dump unified results. all_Ks = self.val_kpt_loader.dataset.dataset.Ks all_w2cs = self.val_kpt_loader.dataset.dataset.w2cs keypoint_result_dict = { "Ks": all_Ks[time_ids].cpu().numpy(), "w2cs": all_w2cs[time_ids].cpu().numpy(), "pred_keypoints_3d": np.stack(pred_keypoints_3d_all, 0), "pred_train_depths": pred_train_depths, } results_dir = osp.join(self.save_dir, "results") os.makedirs(results_dir, exist_ok=True) np.savez( osp.join(results_dir, "keypoints.npz"), **keypoint_result_dict, ) guru.info( f"Dumped keypoint results to {results_dir=} {keypoint_result_dict['pred_keypoints_3d'].shape=}" ) return {"val/pck": self.pck_metric.compute()} @torch.no_grad() def save_train_videos(self, epoch: int): if self.train_loader is None: return video_dir = osp.join(self.save_dir, "videos", f"epoch_{epoch:04d}") os.makedirs(video_dir, exist_ok=True) fps = getattr(self.train_loader.dataset.dataset, "fps", 15.0) # Render video. video = [] ref_pred_depths = [] masks = [] depth_min, depth_max = 1e6, 0 for batch_idx, batch in enumerate( tqdm(self.train_loader, desc="Rendering video", leave=False) ): batch = { k: v.to(self.device) if isinstance(v, torch.Tensor) else v for k, v in batch.items() } # (). t = batch["ts"][0] # (4, 4). w2c = batch["w2cs"][0] # (3, 3). K = batch["Ks"][0] # (H, W, 3). img = batch["imgs"][0] # (H, W). depth = batch["depths"][0] img_wh = img.shape[-2::-1] rendered = self.model.render( t, w2c[None], K[None], img_wh, return_depth=True, return_mask=True ) # Putting results onto CPU since it will consume unnecessarily # large GPU memory for long sequence OW. video.append(torch.cat([img, rendered["img"][0]], dim=1).cpu()) ref_pred_depth = torch.cat( (depth[..., None], rendered["depth"][0]), dim=1 ).cpu() ref_pred_depths.append(ref_pred_depth) depth_min = min(depth_min, ref_pred_depth.min().item()) depth_max = max(depth_max, ref_pred_depth.quantile(0.99).item()) if rendered["mask"] is not None: masks.append(rendered["mask"][0].cpu().squeeze(-1)) # rgb video video = torch.stack(video, dim=0) iio.mimwrite( osp.join(video_dir, "rgbs.mp4"), make_video_divisble((video.numpy() * 255).astype(np.uint8)), fps=fps, ) # depth video depth_video = torch.stack( [ apply_depth_colormap( ref_pred_depth, near_plane=depth_min, far_plane=depth_max ) for ref_pred_depth in ref_pred_depths ], dim=0, ) iio.mimwrite( osp.join(video_dir, "depths.mp4"), make_video_divisble((depth_video.numpy() * 255).astype(np.uint8)), fps=fps, ) if len(masks) > 0: # mask video mask_video = torch.stack(masks, dim=0) iio.mimwrite( osp.join(video_dir, "masks.mp4"), make_video_divisble((mask_video.numpy() * 255).astype(np.uint8)), fps=fps, ) # Render 2D track video. 
tracks_2d, target_imgs = [], [] sample_interval = 10 batch0 = { k: v.to(self.device) if isinstance(v, torch.Tensor) else v for k, v in self.train_loader.dataset[0].items() } # (). t = batch0["ts"] # (4, 4). w2c = batch0["w2cs"] # (3, 3). K = batch0["Ks"] # (H, W, 3). img = batch0["imgs"] # (H, W). bool_mask = batch0["masks"] > 0.5 img_wh = img.shape[-2::-1] for batch in tqdm( self.train_loader, desc="Rendering 2D track video", leave=False ): batch = { k: v.to(self.device) if isinstance(v, torch.Tensor) else v for k, v in batch.items() } # Putting results onto CPU since it will consume unnecessarily # large GPU memory for long sequence OW. # (1, H, W, 3). target_imgs.append(batch["imgs"].cpu()) # (1,). target_ts = batch["ts"] # (1, 4, 4). target_w2cs = batch["w2cs"] # (1, 3, 3). target_Ks = batch["Ks"] rendered = self.model.render( t, w2c[None], K[None], img_wh, target_ts=target_ts, target_w2cs=target_w2cs, ) pred_tracks_3d = rendered["tracks_3d"][0][ ::sample_interval, ::sample_interval ][bool_mask[::sample_interval, ::sample_interval]].swapaxes(0, 1) pred_tracks_2d = torch.einsum("bij,bpj->bpi", target_Ks, pred_tracks_3d) pred_tracks_2d = pred_tracks_2d[..., :2] / torch.clamp( pred_tracks_2d[..., 2:], min=1e-6 ) tracks_2d.append(pred_tracks_2d.cpu()) tracks_2d = torch.cat(tracks_2d, dim=0) target_imgs = torch.cat(target_imgs, dim=0) track_2d_video = plot_correspondences( target_imgs.numpy(), tracks_2d.numpy(), query_id=cast(int, t), ) iio.mimwrite( osp.join(video_dir, "tracks_2d.mp4"), make_video_divisble(np.stack(track_2d_video, 0)), fps=fps, ) # Render motion coefficient video. with torch.random.fork_rng(): torch.random.manual_seed(0) motion_coef_colors = torch.pca_lowrank( self.model.fg.get_coefs()[None], q=3, )[0][0] motion_coef_colors = (motion_coef_colors - motion_coef_colors.min(0)[0]) / ( motion_coef_colors.max(0)[0] - motion_coef_colors.min(0)[0] ) motion_coef_colors = F.pad( motion_coef_colors, (0, 0, 0, self.model.bg.num_gaussians), value=0.5 ) video = [] for batch in tqdm( self.train_loader, desc="Rendering motion coefficient video", leave=False ): batch = { k: v.to(self.device) if isinstance(v, torch.Tensor) else v for k, v in batch.items() } # (). t = batch["ts"][0] # (4, 4). w2c = batch["w2cs"][0] # (3, 3). K = batch["Ks"][0] # (3, 3). img = batch["imgs"][0] img_wh = img.shape[-2::-1] rendered = self.model.render( t, w2c[None], K[None], img_wh, colors_override=motion_coef_colors ) # Putting results onto CPU since it will consume unnecessarily # large GPU memory for long sequence OW. 
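# In this motion-coefficient video, each foreground Gaussian is colored by the
# first three PCA components of its motion-basis coefficients (normalized to
# [0, 1]), with background Gaussians padded to a constant 0.5 gray, so that
# regions sharing the same motion bases appear in similar colors across frames.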
video.append(torch.cat([img, rendered["img"][0]], dim=1).cpu()) video = torch.stack(video, dim=0) iio.mimwrite( osp.join(video_dir, "motion_coefs.mp4"), make_video_divisble((video.numpy() * 255).astype(np.uint8)), fps=fps, ) ================================================ FILE: mvtracker/models/core/shape-of-motion/flow3d/vis/__init__.py ================================================ ================================================ FILE: mvtracker/models/core/shape-of-motion/flow3d/vis/playback_panel.py ================================================ import threading import time import viser def add_gui_playback_group( server: viser.ViserServer, num_frames: int, min_fps: float = 1.0, max_fps: float = 60.0, fps_step: float = 0.1, initial_fps: float = 10.0, ): gui_timestep = server.gui.add_slider( "Timestep", min=0, max=num_frames - 1, step=1, initial_value=0, disabled=True, ) gui_next_frame = server.gui.add_button("Next Frame") gui_prev_frame = server.gui.add_button("Prev Frame") gui_playing_pause = server.gui.add_button("Pause") gui_playing_pause.visible = False gui_playing_resume = server.gui.add_button("Resume") gui_framerate = server.gui.add_slider( "FPS", min=min_fps, max=max_fps, step=fps_step, initial_value=initial_fps ) # Frame step buttons. @gui_next_frame.on_click def _(_) -> None: gui_timestep.value = (gui_timestep.value + 1) % num_frames @gui_prev_frame.on_click def _(_) -> None: gui_timestep.value = (gui_timestep.value - 1) % num_frames # Disable frame controls when we're playing. def _toggle_gui_playing(_): gui_playing_pause.visible = not gui_playing_pause.visible gui_playing_resume.visible = not gui_playing_resume.visible gui_timestep.disabled = gui_playing_pause.visible gui_next_frame.disabled = gui_playing_pause.visible gui_prev_frame.disabled = gui_playing_pause.visible gui_playing_pause.on_click(_toggle_gui_playing) gui_playing_resume.on_click(_toggle_gui_playing) # Create a thread to update the timestep indefinitely. def _update_timestep(): while True: if gui_playing_pause.visible: gui_timestep.value = (gui_timestep.value + 1) % num_frames time.sleep(1 / gui_framerate.value) threading.Thread(target=_update_timestep, daemon=True).start() return ( gui_timestep, gui_next_frame, gui_prev_frame, gui_playing_pause, gui_playing_resume, gui_framerate, ) ================================================ FILE: mvtracker/models/core/shape-of-motion/flow3d/vis/render_panel.py ================================================ # Copyright 2022 the Regents of the University of California, Nerfstudio Team and contributors. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
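# This module implements the interactive render-path tab of the viser viewer
# (adapted from the Nerfstudio render panel, per the notice above): the user
# drops camera keyframes, the panel interpolates position, orientation, FOV,
# and time between them with Kochanek-Bartels splines (re-parameterized over
# per-transition durations with a monotone PCHIP spline), and the resulting
# trajectory can be previewed live, saved to JSON, and loaded back.
#
# Minimal wiring sketch (hypothetical paths, kept as a comment since the
# module is imported by the viewer rather than run directly):
#
#   import viser
#   from pathlib import Path
#   server = viser.ViserServer()
#   render_tab_state = populate_render_tab(server, Path("./camera_paths"), None)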
from __future__ import annotations import colorsys import dataclasses import datetime import json import threading import time from pathlib import Path from typing import Dict, List, Literal, Optional, Tuple import numpy as np import scipy import splines import splines.quaternion import viser import viser.transforms as tf VISER_SCALE_RATIO = 10.0 @dataclasses.dataclass class Keyframe: time: float position: np.ndarray wxyz: np.ndarray override_fov_enabled: bool override_fov_rad: float aspect: float override_transition_enabled: bool override_transition_sec: Optional[float] @staticmethod def from_camera(time: float, camera: viser.CameraHandle, aspect: float) -> Keyframe: return Keyframe( time, camera.position, camera.wxyz, override_fov_enabled=False, override_fov_rad=camera.fov, aspect=aspect, override_transition_enabled=False, override_transition_sec=None, ) class CameraPath: def __init__( self, server: viser.ViserServer, duration_element: viser.GuiInputHandle[float] ): self._server = server self._keyframes: Dict[int, Tuple[Keyframe, viser.CameraFrustumHandle]] = {} self._keyframe_counter: int = 0 self._spline_nodes: List[viser.SceneNodeHandle] = [] self._camera_edit_panel: Optional[viser.Gui3dContainerHandle] = None self._orientation_spline: Optional[splines.quaternion.KochanekBartels] = None self._position_spline: Optional[splines.KochanekBartels] = None self._fov_spline: Optional[splines.KochanekBartels] = None self._time_spline: Optional[splines.KochanekBartels] = None self._keyframes_visible: bool = True self._duration_element = duration_element # These parameters should be overridden externally. self.loop: bool = False self.framerate: float = 30.0 self.tension: float = 0.5 # Tension / alpha term. self.default_fov: float = 0.0 self.default_transition_sec: float = 0.0 self.show_spline: bool = True def set_keyframes_visible(self, visible: bool) -> None: self._keyframes_visible = visible for keyframe in self._keyframes.values(): keyframe[1].visible = visible def add_camera( self, keyframe: Keyframe, keyframe_index: Optional[int] = None ) -> None: """Add a new camera, or replace an old one if `keyframe_index` is passed in.""" server = self._server # Add a keyframe if we aren't replacing an existing one. 
if keyframe_index is None: keyframe_index = self._keyframe_counter self._keyframe_counter += 1 print( f"{keyframe.wxyz=} {keyframe.position=} {keyframe_index=} {keyframe.aspect=}" ) frustum_handle = server.scene.add_camera_frustum( f"/render_cameras/{keyframe_index}", fov=( keyframe.override_fov_rad if keyframe.override_fov_enabled else self.default_fov ), aspect=keyframe.aspect, scale=0.1, color=(200, 10, 30), wxyz=keyframe.wxyz, position=keyframe.position, visible=self._keyframes_visible, ) self._server.scene.add_icosphere( f"/render_cameras/{keyframe_index}/sphere", radius=0.03, color=(200, 10, 30), ) @frustum_handle.on_click def _(_) -> None: if self._camera_edit_panel is not None: self._camera_edit_panel.remove() self._camera_edit_panel = None with server.scene.add_3d_gui_container( "/camera_edit_panel", position=keyframe.position, ) as camera_edit_panel: self._camera_edit_panel = camera_edit_panel override_fov = server.gui.add_checkbox( "Override FOV", initial_value=keyframe.override_fov_enabled ) override_fov_degrees = server.gui.add_slider( "Override FOV (degrees)", 5.0, 175.0, step=0.1, initial_value=keyframe.override_fov_rad * 180.0 / np.pi, disabled=not keyframe.override_fov_enabled, ) delete_button = server.gui.add_button( "Delete", color="red", icon=viser.Icon.TRASH ) go_to_button = server.gui.add_button("Go to") close_button = server.gui.add_button("Close") @override_fov.on_update def _(_) -> None: keyframe.override_fov_enabled = override_fov.value override_fov_degrees.disabled = not override_fov.value self.add_camera(keyframe, keyframe_index) @override_fov_degrees.on_update def _(_) -> None: keyframe.override_fov_rad = override_fov_degrees.value / 180.0 * np.pi self.add_camera(keyframe, keyframe_index) @delete_button.on_click def _(event: viser.GuiEvent) -> None: assert event.client is not None with event.client.gui.add_modal("Confirm") as modal: event.client.gui.add_markdown("Delete keyframe?") confirm_button = event.client.gui.add_button( "Yes", color="red", icon=viser.Icon.TRASH ) exit_button = event.client.gui.add_button("Cancel") @confirm_button.on_click def _(_) -> None: assert camera_edit_panel is not None keyframe_id = None for i, keyframe_tuple in self._keyframes.items(): if keyframe_tuple[1] is frustum_handle: keyframe_id = i break assert keyframe_id is not None self._keyframes.pop(keyframe_id) frustum_handle.remove() camera_edit_panel.remove() self._camera_edit_panel = None modal.close() self.update_spline() @exit_button.on_click def _(_) -> None: modal.close() @go_to_button.on_click def _(event: viser.GuiEvent) -> None: assert event.client is not None client = event.client T_world_current = tf.SE3.from_rotation_and_translation( tf.SO3(client.camera.wxyz), client.camera.position ) T_world_target = tf.SE3.from_rotation_and_translation( tf.SO3(keyframe.wxyz), keyframe.position ) @ tf.SE3.from_translation(np.array([0.0, 0.0, -0.5])) T_current_target = T_world_current.inverse() @ T_world_target for j in range(10): T_world_set = T_world_current @ tf.SE3.exp( T_current_target.log() * j / 9.0 ) # Important bit: we atomically set both the orientation and the position # of the camera. 
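# The "Go to" handler animates the client camera from its current pose to a
# pose slightly behind the selected keyframe by interpolating in SE(3): the
# relative transform is mapped to its Lie-algebra log, scaled by j / 9 over
# ten steps, exponentiated back, and applied at roughly 30 updates per second.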
with client.atomic(): client.camera.wxyz = T_world_set.rotation().wxyz client.camera.position = T_world_set.translation() time.sleep(1.0 / 30.0) @close_button.on_click def _(_) -> None: assert camera_edit_panel is not None camera_edit_panel.remove() self._camera_edit_panel = None self._keyframes[keyframe_index] = (keyframe, frustum_handle) def update_aspect(self, aspect: float) -> None: for keyframe_index, frame in self._keyframes.items(): frame = dataclasses.replace(frame[0], aspect=aspect) self.add_camera(frame, keyframe_index=keyframe_index) def get_aspect(self) -> float: """Get W/H aspect ratio, which is shared across all keyframes.""" assert len(self._keyframes) > 0 return next(iter(self._keyframes.values()))[0].aspect def reset(self) -> None: for frame in self._keyframes.values(): print(f"removing {frame[1]}") frame[1].remove() self._keyframes.clear() self.update_spline() print("camera path reset") def spline_t_from_t_sec(self, time: np.ndarray) -> np.ndarray: """From a time value in seconds, compute a t value for our geometric spline interpolation. An increment of 1 for the latter will move the camera forward by one keyframe. We use a PCHIP spline here to guarantee monotonicity. """ transition_times_cumsum = self.compute_transition_times_cumsum() spline_indices = np.arange(transition_times_cumsum.shape[0]) if self.loop: # In the case of a loop, we pad the spline to match the start/end # slopes. interpolator = scipy.interpolate.PchipInterpolator( x=np.concatenate( [ [-(transition_times_cumsum[-1] - transition_times_cumsum[-2])], transition_times_cumsum, transition_times_cumsum[-1:] + transition_times_cumsum[1:2], ], axis=0, ), y=np.concatenate( [[-1], spline_indices, [spline_indices[-1] + 1]], axis=0 ), ) else: interpolator = scipy.interpolate.PchipInterpolator( x=transition_times_cumsum, y=spline_indices ) # Clip to account for floating point error. 
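# Illustrative example (numbers chosen for clarity, not taken from the code):
# with two keyframe transitions of 2 s each, transition_times_cumsum is
# [0, 2, 4] and the spline indices are [0, 1, 2]. PCHIP is monotone and
# reproduces data that already lies on a line, so a query at t = 3 s maps to
# spline parameter 1.5, i.e. halfway between the second and third keyframes.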
return np.clip(interpolator(time), 0, spline_indices[-1]) def interpolate_pose_and_fov_rad( self, normalized_t: float ) -> Optional[Tuple[tf.SE3, float, float]]: if len(self._keyframes) < 2: return None self._time_spline = splines.KochanekBartels( [keyframe[0].time for keyframe in self._keyframes.values()], tcb=(self.tension, 0.0, 0.0), endconditions="closed" if self.loop else "natural", ) self._fov_spline = splines.KochanekBartels( [ ( keyframe[0].override_fov_rad if keyframe[0].override_fov_enabled else self.default_fov ) for keyframe in self._keyframes.values() ], tcb=(self.tension, 0.0, 0.0), endconditions="closed" if self.loop else "natural", ) assert self._orientation_spline is not None assert self._position_spline is not None assert self._fov_spline is not None assert self._time_spline is not None max_t = self.compute_duration() t = max_t * normalized_t spline_t = float(self.spline_t_from_t_sec(np.array(t))) quat = self._orientation_spline.evaluate(spline_t) assert isinstance(quat, splines.quaternion.UnitQuaternion) return ( tf.SE3.from_rotation_and_translation( tf.SO3(np.array([quat.scalar, *quat.vector])), self._position_spline.evaluate(spline_t), ), float(self._fov_spline.evaluate(spline_t)), float(self._time_spline.evaluate(spline_t)), ) def update_spline(self) -> None: num_frames = int(self.compute_duration() * self.framerate) keyframes = list(self._keyframes.values()) if num_frames <= 0 or not self.show_spline or len(keyframes) < 2: for node in self._spline_nodes: node.remove() self._spline_nodes.clear() return transition_times_cumsum = self.compute_transition_times_cumsum() self._orientation_spline = splines.quaternion.KochanekBartels( [ splines.quaternion.UnitQuaternion.from_unit_xyzw( np.roll(keyframe[0].wxyz, shift=-1) ) for keyframe in keyframes ], tcb=(self.tension, 0.0, 0.0), endconditions="closed" if self.loop else "natural", ) self._position_spline = splines.KochanekBartels( [keyframe[0].position for keyframe in keyframes], tcb=(self.tension, 0.0, 0.0), endconditions="closed" if self.loop else "natural", ) # Update visualized spline. points_array = self._position_spline.evaluate( self.spline_t_from_t_sec( np.linspace(0, transition_times_cumsum[-1], num_frames) ) ) colors_array = np.array( [ colorsys.hls_to_rgb(h, 0.5, 1.0) for h in np.linspace(0.0, 1.0, len(points_array)) ] ) # Clear prior spline nodes. 
for node in self._spline_nodes: node.remove() self._spline_nodes.clear() self._spline_nodes.append( self._server.scene.add_spline_catmull_rom( "/render_camera_spline", positions=points_array, color=(220, 220, 220), closed=self.loop, line_width=1.0, segments=points_array.shape[0] + 1, ) ) self._spline_nodes.append( self._server.scene.add_point_cloud( "/render_camera_spline/points", points=points_array, colors=colors_array, point_size=0.04, ) ) def make_transition_handle(i: int) -> None: assert self._position_spline is not None transition_pos = self._position_spline.evaluate( float( self.spline_t_from_t_sec( (transition_times_cumsum[i] + transition_times_cumsum[i + 1]) / 2.0, ) ) ) transition_sphere = self._server.scene.add_icosphere( f"/render_camera_spline/transition_{i}", radius=0.04, color=(255, 0, 0), position=transition_pos, ) self._spline_nodes.append(transition_sphere) @transition_sphere.on_click def _(_) -> None: server = self._server if self._camera_edit_panel is not None: self._camera_edit_panel.remove() self._camera_edit_panel = None keyframe_index = (i + 1) % len(self._keyframes) keyframe = keyframes[keyframe_index][0] with server.scene.add_3d_gui_container( "/camera_edit_panel", position=transition_pos, ) as camera_edit_panel: self._camera_edit_panel = camera_edit_panel override_transition_enabled = server.gui.add_checkbox( "Override transition", initial_value=keyframe.override_transition_enabled, ) override_transition_sec = server.gui.add_number( "Override transition (sec)", initial_value=( keyframe.override_transition_sec if keyframe.override_transition_sec is not None else self.default_transition_sec ), min=0.001, max=30.0, step=0.001, disabled=not override_transition_enabled.value, ) close_button = server.gui.add_button("Close") @override_transition_enabled.on_update def _(_) -> None: keyframe.override_transition_enabled = ( override_transition_enabled.value ) override_transition_sec.disabled = ( not override_transition_enabled.value ) self._duration_element.value = self.compute_duration() @override_transition_sec.on_update def _(_) -> None: keyframe.override_transition_sec = override_transition_sec.value self._duration_element.value = self.compute_duration() @close_button.on_click def _(_) -> None: assert camera_edit_panel is not None camera_edit_panel.remove() self._camera_edit_panel = None (num_transitions_plus_1,) = transition_times_cumsum.shape for i in range(num_transitions_plus_1 - 1): make_transition_handle(i) # for i in range(transition_times.shape[0]) def compute_duration(self) -> float: """Compute the total duration of the trajectory.""" total = 0.0 for i, (keyframe, frustum) in enumerate(self._keyframes.values()): if i == 0 and not self.loop: continue del frustum total += ( keyframe.override_transition_sec if keyframe.override_transition_enabled and keyframe.override_transition_sec is not None else self.default_transition_sec ) return total def compute_transition_times_cumsum(self) -> np.ndarray: """Compute the total duration of the trajectory.""" total = 0.0 out = [0.0] for i, (keyframe, frustum) in enumerate(self._keyframes.values()): if i == 0: continue del frustum total += ( keyframe.override_transition_sec if keyframe.override_transition_enabled and keyframe.override_transition_sec is not None else self.default_transition_sec ) out.append(total) if self.loop: keyframe = next(iter(self._keyframes.values()))[0] total += ( keyframe.override_transition_sec if keyframe.override_transition_enabled and keyframe.override_transition_sec is not None else 
self.default_transition_sec ) out.append(total) return np.array(out) @dataclasses.dataclass class RenderTabState: """Useful GUI handles exposed by the render tab.""" preview_render: bool preview_fov: float preview_aspect: float preview_camera_type: Literal["Perspective", "Fisheye", "Equirectangular"] def populate_render_tab( server: viser.ViserServer, datapath: Path, gui_timestep_handle: viser.GuiInputHandle[int] | None, ) -> RenderTabState: render_tab_state = RenderTabState( preview_render=False, preview_fov=0.0, preview_aspect=1.0, preview_camera_type="Perspective", ) fov_degrees = server.gui.add_slider( "Default FOV", initial_value=75.0, min=0.1, max=175.0, step=0.01, hint="Field-of-view for rendering, which can also be overridden on a per-keyframe basis.", ) @fov_degrees.on_update def _(_) -> None: fov_radians = fov_degrees.value / 180.0 * np.pi for client in server.get_clients().values(): client.camera.fov = fov_radians camera_path.default_fov = fov_radians # Updating the aspect ratio will also re-render the camera frustums. # Could rethink this. camera_path.update_aspect(resolution.value[0] / resolution.value[1]) compute_and_update_preview_camera_state() resolution = server.gui.add_vector2( "Resolution", initial_value=(1920, 1080), min=(50, 50), max=(10_000, 10_000), step=1, hint="Render output resolution in pixels.", ) @resolution.on_update def _(_) -> None: camera_path.update_aspect(resolution.value[0] / resolution.value[1]) compute_and_update_preview_camera_state() camera_type = server.gui.add_dropdown( "Camera type", ("Perspective", "Fisheye", "Equirectangular"), initial_value="Perspective", hint="Camera model to render with. This is applied to all keyframes.", ) add_button = server.gui.add_button( "Add Keyframe", icon=viser.Icon.PLUS, hint="Add a new keyframe at the current pose.", ) @add_button.on_click def _(event: viser.GuiEvent) -> None: assert event.client_id is not None camera = server.get_clients()[event.client_id].camera pose = tf.SE3.from_rotation_and_translation( tf.SO3(camera.wxyz), camera.position ) print(f"client {event.client_id} at {camera.position} {camera.wxyz}") print(f"camera pose {pose.as_matrix()}") if gui_timestep_handle is not None: print(f"timestep {gui_timestep_handle.value}") # Add this camera to the path. time = 0 if gui_timestep_handle is not None: time = gui_timestep_handle.value camera_path.add_camera( Keyframe.from_camera( time, camera, aspect=resolution.value[0] / resolution.value[1], ), ) duration_number.value = camera_path.compute_duration() camera_path.update_spline() clear_keyframes_button = server.gui.add_button( "Clear Keyframes", icon=viser.Icon.TRASH, hint="Remove all keyframes from the render path.", ) @clear_keyframes_button.on_click def _(event: viser.GuiEvent) -> None: assert event.client_id is not None client = server.get_clients()[event.client_id] with client.atomic(), client.gui.add_modal("Confirm") as modal: client.gui.add_markdown("Clear all keyframes?") confirm_button = client.gui.add_button( "Yes", color="red", icon=viser.Icon.TRASH ) exit_button = client.gui.add_button("Cancel") @confirm_button.on_click def _(_) -> None: camera_path.reset() modal.close() duration_number.value = camera_path.compute_duration() # Clear move handles. if len(transform_controls) > 0: for t in transform_controls: t.remove() transform_controls.clear() return @exit_button.on_click def _(_) -> None: modal.close() loop = server.gui.add_checkbox( "Loop", False, hint="Add a segment between the first and last keyframes." 
) @loop.on_update def _(_) -> None: camera_path.loop = loop.value duration_number.value = camera_path.compute_duration() tension_slider = server.gui.add_slider( "Spline tension", min=0.0, max=1.0, initial_value=0.0, step=0.01, hint="Tension parameter for adjusting smoothness of spline interpolation.", ) @tension_slider.on_update def _(_) -> None: camera_path.tension = tension_slider.value camera_path.update_spline() move_checkbox = server.gui.add_checkbox( "Move keyframes", initial_value=False, hint="Toggle move handles for keyframes in the scene.", ) transform_controls: List[viser.SceneNodeHandle] = [] @move_checkbox.on_update def _(event: viser.GuiEvent) -> None: # Clear move handles when toggled off. if move_checkbox.value is False: for t in transform_controls: t.remove() transform_controls.clear() return def _make_transform_controls_callback( keyframe: Tuple[Keyframe, viser.SceneNodeHandle], controls: viser.TransformControlsHandle, ) -> None: @controls.on_update def _(_) -> None: keyframe[0].wxyz = controls.wxyz keyframe[0].position = controls.position keyframe[1].wxyz = controls.wxyz keyframe[1].position = controls.position camera_path.update_spline() # Show move handles. assert event.client is not None for keyframe_index, keyframe in camera_path._keyframes.items(): controls = event.client.scene.add_transform_controls( f"/keyframe_move/{keyframe_index}", scale=0.4, wxyz=keyframe[0].wxyz, position=keyframe[0].position, ) transform_controls.append(controls) _make_transform_controls_callback(keyframe, controls) show_keyframe_checkbox = server.gui.add_checkbox( "Show keyframes", initial_value=True, hint="Show keyframes in the scene.", ) @show_keyframe_checkbox.on_update def _(_: viser.GuiEvent) -> None: camera_path.set_keyframes_visible(show_keyframe_checkbox.value) show_spline_checkbox = server.gui.add_checkbox( "Show spline", initial_value=True, hint="Show camera path spline in the scene.", ) @show_spline_checkbox.on_update def _(_) -> None: camera_path.show_spline = show_spline_checkbox.value camera_path.update_spline() playback_folder = server.gui.add_folder("Playback") with playback_folder: play_button = server.gui.add_button("Play", icon=viser.Icon.PLAYER_PLAY) pause_button = server.gui.add_button( "Pause", icon=viser.Icon.PLAYER_PAUSE, visible=False ) preview_render_button = server.gui.add_button( "Preview Render", hint="Show a preview of the render in the viewport." 
) preview_render_stop_button = server.gui.add_button( "Exit Render Preview", color="red", visible=False ) transition_sec_number = server.gui.add_number( "Transition (sec)", min=0.001, max=30.0, step=0.001, initial_value=2.0, hint="Time in seconds between each keyframe, which can also be overridden on a per-transition basis.", ) framerate_number = server.gui.add_number( "FPS", min=0.1, max=240.0, step=1e-2, initial_value=30.0 ) framerate_buttons = server.gui.add_button_group("", ("24", "30", "60")) duration_number = server.gui.add_number( "Duration (sec)", min=0.0, max=1e8, step=0.001, initial_value=0.0, disabled=True, ) @framerate_buttons.on_click def _(_) -> None: framerate_number.value = float(framerate_buttons.value) @transition_sec_number.on_update def _(_) -> None: camera_path.default_transition_sec = transition_sec_number.value duration_number.value = camera_path.compute_duration() def get_max_frame_index() -> int: return max(1, int(framerate_number.value * duration_number.value) - 1) preview_camera_handle: Optional[viser.SceneNodeHandle] = None def remove_preview_camera() -> None: nonlocal preview_camera_handle if preview_camera_handle is not None: preview_camera_handle.remove() preview_camera_handle = None def compute_and_update_preview_camera_state() -> ( Optional[Tuple[tf.SE3, float, float]] ): """Update the render tab state with the current preview camera pose. Returns current camera pose + FOV if available.""" if preview_frame_slider is None: return maybe_pose_and_fov_rad_and_time = camera_path.interpolate_pose_and_fov_rad( preview_frame_slider.value / get_max_frame_index() ) if maybe_pose_and_fov_rad_and_time is None: remove_preview_camera() return pose, fov_rad, time = maybe_pose_and_fov_rad_and_time render_tab_state.preview_fov = fov_rad render_tab_state.preview_aspect = camera_path.get_aspect() render_tab_state.preview_camera_type = camera_type.value if gui_timestep_handle is not None: gui_timestep_handle.value = int(time) return pose, fov_rad, time def add_preview_frame_slider() -> Optional[viser.GuiInputHandle[int]]: """Helper for creating the current frame # slider. This is removed and re-added anytime the `max` value changes.""" with playback_folder: preview_frame_slider = server.gui.add_slider( "Preview frame", min=0, max=get_max_frame_index(), step=1, initial_value=0, # Place right after the pause button. order=preview_render_stop_button.order + 0.01, disabled=get_max_frame_index() == 1, ) play_button.disabled = preview_frame_slider.disabled preview_render_button.disabled = preview_frame_slider.disabled @preview_frame_slider.on_update def _(_) -> None: nonlocal preview_camera_handle maybe_pose_and_fov_rad_and_time = compute_and_update_preview_camera_state() if maybe_pose_and_fov_rad_and_time is None: return pose, fov_rad, time = maybe_pose_and_fov_rad_and_time preview_camera_handle = server.scene.add_camera_frustum( "/preview_camera", fov=fov_rad, aspect=resolution.value[0] / resolution.value[1], scale=0.35, wxyz=pose.rotation().wxyz, position=pose.translation(), color=(10, 200, 30), ) if render_tab_state.preview_render: for client in server.get_clients().values(): client.camera.wxyz = pose.rotation().wxyz client.camera.position = pose.translation() if gui_timestep_handle is not None: gui_timestep_handle.value = int(time) return preview_frame_slider # We back up the camera poses before and after we start previewing renders. 
camera_pose_backup_from_id: Dict[int, tuple] = {} @preview_render_button.on_click def _(_) -> None: render_tab_state.preview_render = True preview_render_button.visible = False preview_render_stop_button.visible = True maybe_pose_and_fov_rad_and_time = compute_and_update_preview_camera_state() if maybe_pose_and_fov_rad_and_time is None: remove_preview_camera() return pose, fov, time = maybe_pose_and_fov_rad_and_time del fov # Hide all scene nodes when we're previewing the render. server.scene.set_global_visibility(True) # Back up and then set camera poses. for client in server.get_clients().values(): camera_pose_backup_from_id[client.client_id] = ( client.camera.position, client.camera.look_at, client.camera.up_direction, ) client.camera.wxyz = pose.rotation().wxyz client.camera.position = pose.translation() if gui_timestep_handle is not None: gui_timestep_handle.value = int(time) @preview_render_stop_button.on_click def _(_) -> None: render_tab_state.preview_render = False preview_render_button.visible = True preview_render_stop_button.visible = False # Revert camera poses. for client in server.get_clients().values(): if client.client_id not in camera_pose_backup_from_id: continue cam_position, cam_look_at, cam_up = camera_pose_backup_from_id.pop( client.client_id ) client.camera.position = cam_position client.camera.look_at = cam_look_at client.camera.up_direction = cam_up client.flush() # Un-hide scene nodes. server.scene.set_global_visibility(True) preview_frame_slider = add_preview_frame_slider() # Update the # of frames. @duration_number.on_update @framerate_number.on_update def _(_) -> None: remove_preview_camera() # Will be re-added when slider is updated. nonlocal preview_frame_slider old = preview_frame_slider assert old is not None preview_frame_slider = add_preview_frame_slider() if preview_frame_slider is not None: old.remove() else: preview_frame_slider = old camera_path.framerate = framerate_number.value camera_path.update_spline() # Play the camera trajectory when the play button is pressed. @play_button.on_click def _(_) -> None: play_button.visible = False pause_button.visible = True def play() -> None: while not play_button.visible: max_frame = int(framerate_number.value * duration_number.value) if max_frame > 0: assert preview_frame_slider is not None preview_frame_slider.value = ( preview_frame_slider.value + 1 ) % max_frame time.sleep(1.0 / framerate_number.value) threading.Thread(target=play).start() # Play the camera trajectory when the play button is pressed. @pause_button.on_click def _(_) -> None: play_button.visible = True pause_button.visible = False # add button for loading existing path load_camera_path_button = server.gui.add_button( "Load Path", icon=viser.Icon.FOLDER_OPEN, hint="Load an existing camera path." 
) @load_camera_path_button.on_click def _(event: viser.GuiEvent) -> None: assert event.client is not None camera_path_dir = datapath.parent camera_path_dir.mkdir(parents=True, exist_ok=True) preexisting_camera_paths = list(camera_path_dir.glob("*.json")) preexisting_camera_filenames = [p.name for p in preexisting_camera_paths] with event.client.gui.add_modal("Load Path") as modal: if len(preexisting_camera_filenames) == 0: event.client.gui.add_markdown("No existing paths found") else: event.client.gui.add_markdown("Select existing camera path:") camera_path_dropdown = event.client.gui.add_dropdown( label="Camera Path", options=[str(p) for p in preexisting_camera_filenames], initial_value=str(preexisting_camera_filenames[0]), ) load_button = event.client.gui.add_button("Load") @load_button.on_click def _(_) -> None: # load the json file json_path = datapath / camera_path_dropdown.value with open(json_path, "r") as f: json_data = json.load(f) keyframes = json_data["keyframes"] camera_path.reset() for i in range(len(keyframes)): frame = keyframes[i] pose = tf.SE3.from_matrix( np.array(frame["matrix"]).reshape(4, 4) ) # apply the x rotation by 180 deg pose = tf.SE3.from_rotation_and_translation( pose.rotation() @ tf.SO3.from_x_radians(np.pi), pose.translation(), ) camera_path.add_camera( Keyframe( frame["time"], position=pose.translation(), wxyz=pose.rotation().wxyz, # There are some floating point conversions between degrees and radians, so the fov and # default_Fov values will not be exactly matched. override_fov_enabled=abs( frame["fov"] - json_data.get("default_fov", 0.0) ) > 1e-3, override_fov_rad=frame["fov"] / 180.0 * np.pi, aspect=frame["aspect"], override_transition_enabled=frame.get( "override_transition_enabled", None ), override_transition_sec=frame.get( "override_transition_sec", None ), ) ) transition_sec_number.value = json_data.get( "default_transition_sec", 0.5 ) # update the render name camera_path_name.value = json_path.stem camera_path.update_spline() modal.close() cancel_button = event.client.gui.add_button("Cancel") @cancel_button.on_click def _(_) -> None: modal.close() # set the initial value to the current date-time string now = datetime.datetime.now() camera_path_name = server.gui.add_text( "Camera path name", initial_value=now.strftime("%Y-%m-%d %H:%M:%S"), hint="Name of the render", ) save_path_button = server.gui.add_button( "Save Camera Path", color="green", icon=viser.Icon.FILE_EXPORT, hint="Save the camera path to json.", ) reset_up_button = server.gui.add_button( "Reset Up Direction", icon=viser.Icon.ARROW_BIG_UP_LINES, color="gray", hint="Set the up direction of the camera orbit controls to the camera's current up direction.", ) @reset_up_button.on_click def _(event: viser.GuiEvent) -> None: assert event.client is not None event.client.camera.up_direction = tf.SO3(event.client.camera.wxyz) @ np.array( [0.0, -1.0, 0.0] ) @save_path_button.on_click def _(event: viser.GuiEvent) -> None: assert event.client is not None num_frames = int(framerate_number.value * duration_number.value) json_data = {} # json data has the properties: # keyframes: list of keyframes with # matrix : flattened 4x4 matrix # fov: float in degrees # aspect: float # camera_type: string of camera type # render_height: int # render_width: int # fps: int # seconds: float # is_cycle: bool # smoothness_value: float # camera_path: list of frames with properties # camera_to_world: flattened 4x4 matrix # fov: float in degrees # aspect: float # first populate the keyframes: keyframes = [] for keyframe, 
dummy in camera_path._keyframes.values(): pose = tf.SE3.from_rotation_and_translation( tf.SO3(keyframe.wxyz), keyframe.position ) keyframes.append( { "matrix": pose.as_matrix().flatten().tolist(), "fov": ( np.rad2deg(keyframe.override_fov_rad) if keyframe.override_fov_enabled else fov_degrees.value ), "aspect": keyframe.aspect, "override_transition_enabled": keyframe.override_transition_enabled, "override_transition_sec": keyframe.override_transition_sec, } ) json_data["default_fov"] = fov_degrees.value json_data["default_transition_sec"] = transition_sec_number.value json_data["keyframes"] = keyframes json_data["camera_type"] = camera_type.value.lower() json_data["render_height"] = resolution.value[1] json_data["render_width"] = resolution.value[0] json_data["fps"] = framerate_number.value json_data["seconds"] = duration_number.value json_data["is_cycle"] = loop.value json_data["smoothness_value"] = tension_slider.value def get_intrinsics(W, H, fov): focal = 0.5 * H / np.tan(0.5 * fov) return np.array( [[focal, 0.0, 0.5 * W], [0.0, focal, 0.5 * H], [0.0, 0.0, 1.0]] ) # now populate the camera path: camera_path_list = [] for i in range(num_frames): maybe_pose_and_fov_and_time = camera_path.interpolate_pose_and_fov_rad( i / num_frames ) if maybe_pose_and_fov_and_time is None: return pose, fov, time = maybe_pose_and_fov_and_time H = resolution.value[1] W = resolution.value[0] K = get_intrinsics(W, H, fov) # rotate the axis of the camera 180 about x axis w2c = pose.inverse().as_matrix() camera_path_list.append( { "time": time, "w2c": w2c.flatten().tolist(), "K": K.flatten().tolist(), "img_wh": (W, H), } ) json_data["camera_path"] = camera_path_list # now write the json file out_name = camera_path_name.value json_outfile = datapath / f"{out_name}.json" datapath.mkdir(parents=True, exist_ok=True) print(f"writing to {json_outfile}") with open(json_outfile.absolute(), "w") as outfile: json.dump(json_data, outfile) camera_path = CameraPath(server, duration_number) camera_path.default_fov = fov_degrees.value / 180.0 * np.pi camera_path.default_transition_sec = transition_sec_number.value return render_tab_state if __name__ == "__main__": populate_render_tab( server=viser.ViserServer(), datapath=Path("."), gui_timestep_handle=None, ) while True: time.sleep(10.0) ================================================ FILE: mvtracker/models/core/shape-of-motion/flow3d/vis/utils.py ================================================ import colorsys from typing import cast import cv2 import numpy as np import nvdiffrast.torch as dr import torch import torch.nn.functional as F from matplotlib import colormaps from viser import ViserServer class Singleton(type): _instances = {} def __call__(cls, *args, **kwargs): if cls not in cls._instances: cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs) return cls._instances[cls] class VisManager(metaclass=Singleton): _servers = {} def get_server(port: int | None = None) -> ViserServer: manager = VisManager() if port is None: avail_ports = list(manager._servers.keys()) port = avail_ports[0] if len(avail_ports) > 0 else 8890 if port not in manager._servers: manager._servers[port] = ViserServer(port=port, verbose=False) return manager._servers[port] def project_2d_tracks(tracks_3d_w, Ks, T_cw, return_depth=False): """ :param tracks_3d_w (torch.Tensor): (T, N, 3) :param Ks (torch.Tensor): (T, 3, 3) :param T_cw (torch.Tensor): (T, 4, 4) :returns tracks_2d (torch.Tensor): (T, N, 2) """ tracks_3d_c = torch.einsum( "tij,tnj->tni", T_cw, F.pad(tracks_3d_w, (0, 1), 
value=1) )[..., :3] tracks_3d_v = torch.einsum("tij,tnj->tni", Ks, tracks_3d_c) if return_depth: return ( tracks_3d_v[..., :2] / torch.clamp(tracks_3d_v[..., 2:], min=1e-5), tracks_3d_v[..., 2], ) return tracks_3d_v[..., :2] / torch.clamp(tracks_3d_v[..., 2:], min=1e-5) def draw_keypoints_video( imgs, kps, colors=None, occs=None, cmap: str = "gist_rainbow", radius: int = 3 ): """ :param imgs (np.ndarray): (T, H, W, 3) uint8 [0, 255] :param kps (np.ndarray): (N, T, 2) :param colors (np.ndarray): (N, 3) float [0, 1] :param occ (np.ndarray): (N, T) bool return out_frames (T, H, W, 3) """ if colors is None: label = np.linspace(0, 1, kps.shape[0]) colors = np.asarray(colormaps.get_cmap(cmap)(label))[..., :3] out_frames = [] for t in range(len(imgs)): occ = occs[:, t] if occs is not None else None vis = draw_keypoints_cv2(imgs[t], kps[:, t], colors, occ, radius=radius) out_frames.append(vis) return out_frames def draw_keypoints_cv2(img, kps, colors=None, occs=None, radius=3): """ :param img (H, W, 3) :param kps (N, 2) :param occs (N) :param colors (N, 3) from 0 to 1 """ out_img = img.copy() kps = kps.round().astype("int").tolist() if colors is not None: colors = (255 * colors).astype("int").tolist() for n in range(len(kps)): kp = kps[n] color = colors[n] if colors is not None else (255, 0, 0) thickness = -1 if occs is None or occs[n] == 0 else 1 out_img = cv2.circle(out_img, kp, radius, color, thickness, cv2.LINE_AA) return out_img def draw_tracks_2d( img: torch.Tensor, tracks_2d: torch.Tensor, track_point_size: int = 2, track_line_width: int = 1, cmap_name: str = "gist_rainbow", ): cmap = colormaps.get_cmap(cmap_name) # (H, W, 3). img_np = (img.cpu().numpy() * 255.0).astype(np.uint8) # (P, N, 2). tracks_2d_np = tracks_2d.cpu().numpy() num_tracks, num_frames = tracks_2d_np.shape[:2] canvas = img_np.copy() for i in range(num_frames - 1): alpha = max(1 - 0.9 * ((num_frames - 1 - i) / (num_frames * 0.99)), 0.1) img_curr = canvas.copy() for j in range(num_tracks): color = tuple(np.array(cmap(j / max(1, float(num_tracks - 1)))[:3]) * 255) color_alpha = 1 hsv = colorsys.rgb_to_hsv(color[0], color[1], color[2]) color = colorsys.hsv_to_rgb(hsv[0], hsv[1] * color_alpha, hsv[2]) pt1 = tracks_2d_np[j, i] pt2 = tracks_2d_np[j, i + 1] p1 = (int(round(pt1[0])), int(round(pt1[1]))) p2 = (int(round(pt2[0])), int(round(pt2[1]))) img_curr = cv2.line( img_curr, p1, p2, color, thickness=track_line_width, lineType=cv2.LINE_AA, ) canvas = cv2.addWeighted(img_curr, alpha, canvas, 1 - alpha, 0) for j in range(num_tracks): color = tuple(np.array(cmap(j / max(1, float(num_tracks - 1)))[:3]) * 255) pt = tracks_2d_np[j, -1] pt = (int(round(pt[0])), int(round(pt[1]))) canvas = cv2.circle( canvas, pt, track_point_size, color, thickness=-1, lineType=cv2.LINE_AA, ) return canvas def generate_line_verts_faces(starts, ends, line_width): """ Args: starts: (P, N, 2). ends: (P, N, 2). line_width: int. Returns: verts: (P * N * 4, 2). faces: (P * N * 2, 3). 
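Each of the P*N segments is expanded into a rectangle of width `line_width`: two vertices per endpoint, offset along the segment normal, connected by two triangles.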
""" P, N, _ = starts.shape directions = F.normalize(ends - starts, dim=-1) deltas = ( torch.cat([-directions[..., 1:], directions[..., :1]], dim=-1) * line_width / 2.0 ) v0 = starts + deltas v1 = starts - deltas v2 = ends + deltas v3 = ends - deltas verts = torch.stack([v0, v1, v2, v3], dim=-2) verts = verts.reshape(-1, 2) faces = [] for p in range(P): for n in range(N): base_index = p * N * 4 + n * 4 # Two triangles per rectangle: (0, 1, 2) and (2, 1, 3) faces.append([base_index, base_index + 1, base_index + 2]) faces.append([base_index + 2, base_index + 1, base_index + 3]) faces = torch.as_tensor(faces, device=starts.device) return verts, faces def generate_point_verts_faces(points, point_size, num_segments=10): """ Args: points: (P, 2). point_size: int. num_segments: int. Returns: verts: (P * (num_segments + 1), 2). faces: (P * num_segments, 3). """ P, _ = points.shape angles = torch.linspace(0, 2 * torch.pi, num_segments + 1, device=points.device)[ ..., :-1 ] unit_circle = torch.stack([torch.cos(angles), torch.sin(angles)], dim=1) scaled_circles = (point_size / 2.0) * unit_circle scaled_circles = scaled_circles[None].repeat(P, 1, 1) verts = points[:, None] + scaled_circles verts = torch.cat([verts, points[:, None]], dim=1) verts = verts.reshape(-1, 2) faces = F.pad( torch.as_tensor( [[i, (i + 1) % num_segments] for i in range(num_segments)], device=points.device, ), (0, 1), value=num_segments, ) faces = faces[None, :] + torch.arange(P, device=points.device)[:, None, None] * ( num_segments + 1 ) faces = faces.reshape(-1, 3) return verts, faces def pixel_to_verts_clip(pixels, img_wh, z: float | torch.Tensor = 0.0, w=1.0): verts_clip = pixels / pixels.new_tensor(img_wh) * 2.0 - 1.0 w = torch.full_like(verts_clip[..., :1], w) verts_clip = torch.cat([verts_clip, z * w, w], dim=-1) return verts_clip def draw_tracks_2d_th( img: torch.Tensor, tracks_2d: torch.Tensor, track_point_size: int = 5, track_point_segments: int = 16, track_line_width: int = 2, cmap_name: str = "gist_rainbow", ): cmap = colormaps.get_cmap(cmap_name) CTX = dr.RasterizeCudaContext() W, H = img.shape[1], img.shape[0] if W % 8 != 0 or H % 8 != 0: # Make sure img is divisible by 8. img = F.pad( img, ( 0, 0, 0, 8 - W % 8 if W % 8 != 0 else 0, 0, 8 - H % 8 if H % 8 != 0 else 0, ), value=0.0, ) num_tracks, num_frames = tracks_2d.shape[:2] track_colors = torch.tensor( [cmap(j / max(1, float(num_tracks - 1)))[:3] for j in range(num_tracks)], device=img.device, ).float() # Generate line verts. verts_l, faces_l = generate_line_verts_faces( tracks_2d[:, :-1], tracks_2d[:, 1:], track_line_width ) # Generate point verts. 
verts_p, faces_p = generate_point_verts_faces( tracks_2d[:, -1], track_point_size, track_point_segments ) verts = torch.cat([verts_l, verts_p], dim=0) faces = torch.cat([faces_l, faces_p + len(verts_l)], dim=0) vert_colors = torch.cat( [ ( track_colors[:, None] .repeat_interleave(4 * (num_frames - 1), dim=1) .reshape(-1, 3) ), ( track_colors[:, None] .repeat_interleave(track_point_segments + 1, dim=1) .reshape(-1, 3) ), ], dim=0, ) track_zs = torch.linspace(0.0, 1.0, num_tracks, device=img.device)[:, None] vert_zs = torch.cat( [ ( track_zs[:, None] .repeat_interleave(4 * (num_frames - 1), dim=1) .reshape(-1, 1) ), ( track_zs[:, None] .repeat_interleave(track_point_segments + 1, dim=1) .reshape(-1, 1) ), ], dim=0, ) track_alphas = torch.linspace( max(0.1, 1.0 - (num_frames - 1) * 0.1), 1.0, num_frames, device=img.device ) vert_alphas = torch.cat( [ ( track_alphas[None, :-1, None] .repeat_interleave(num_tracks, dim=0) .repeat_interleave(4, dim=-2) .reshape(-1, 1) ), ( track_alphas[None, -1:, None] .repeat_interleave(num_tracks, dim=0) .repeat_interleave(track_point_segments + 1, dim=-2) .reshape(-1, 1) ), ], dim=0, ) # Small trick to always render one track in front of the other. verts_clip = pixel_to_verts_clip(verts, (img.shape[1], img.shape[0]), vert_zs) faces_int32 = faces.to(torch.int32) rast, _ = cast( tuple, dr.rasterize(CTX, verts_clip[None], faces_int32, (img.shape[0], img.shape[1])), ) rgba = cast( torch.Tensor, dr.interpolate( torch.cat([vert_colors, vert_alphas], dim=-1).contiguous(), rast, faces_int32, ), )[0] rgba = cast(torch.Tensor, dr.antialias(rgba, rast, verts_clip, faces_int32))[ 0 ].clamp(0, 1) # Compose. color = rgba[..., :-1] * rgba[..., -1:] + (1.0 - rgba[..., -1:]) * img # Unpad. color = color[:H, :W] return (color.cpu().numpy() * 255.0).astype(np.uint8) def make_video_divisble( video: torch.Tensor | np.ndarray, block_size=16 ) -> torch.Tensor | np.ndarray: H, W = video.shape[1:3] H_new = H - H % block_size W_new = W - W % block_size return video[:, :H_new, :W_new] def apply_float_colormap(img: torch.Tensor, colormap: str = "turbo") -> torch.Tensor: """Convert single channel to a color img. Args: img (torch.Tensor): (..., 1) float32 single channel image. colormap (str): Colormap for img. Returns: (..., 3) colored img with colors in [0, 1]. """ img = torch.nan_to_num(img, 0) if colormap == "gray": return img.repeat(1, 1, 3) img_long = (img * 255).long() img_long_min = torch.min(img_long) img_long_max = torch.max(img_long) assert img_long_min >= 0, f"the min value is {img_long_min}" assert img_long_max <= 255, f"the max value is {img_long_max}" return torch.tensor( colormaps[colormap].colors, # type: ignore device=img.device, )[img_long[..., 0]] def apply_depth_colormap( depth: torch.Tensor, acc: torch.Tensor | None = None, near_plane: float | None = None, far_plane: float | None = None, ) -> torch.Tensor: """Converts a depth image to color for easier analysis. Args: depth (torch.Tensor): (..., 1) float32 depth. acc (torch.Tensor | None): (..., 1) optional accumulation mask. near_plane: Closest depth to consider. If None, use min image value. far_plane: Furthest depth to consider. If None, use max image value. Returns: (..., 3) colored depth image with colors in [0, 1]. 
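Depth is min-max normalized between near_plane and far_plane, clipped to [0, 1], and mapped through the "turbo" colormap; where `acc` is given, the result is alpha-blended towards white.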
""" near_plane = near_plane or float(torch.min(depth)) far_plane = far_plane or float(torch.max(depth)) depth = (depth - near_plane) / (far_plane - near_plane + 1e-10) depth = torch.clip(depth, 0.0, 1.0) img = apply_float_colormap(depth, colormap="turbo") if acc is not None: img = img * acc + (1.0 - acc) return img def float2uint8(x): return (255.0 * x).astype(np.uint8) def uint82float(img): return np.ascontiguousarray(img) / 255.0 def drawMatches( img1, img2, kp1, kp2, num_vis=200, center=None, idx_vis=None, radius=2, seed=1234, mask=None, ): num_pts = len(kp1) if idx_vis is None: if num_vis < num_pts: rng = np.random.RandomState(seed) idx_vis = rng.choice(num_pts, num_vis, replace=False) else: idx_vis = np.arange(num_pts) kp1_vis = kp1[idx_vis] kp2_vis = kp2[idx_vis] h1, w1 = img1.shape[:2] h2, w2 = img2.shape[:2] kp1_vis[:, 0] = np.clip(kp1_vis[:, 0], a_min=0, a_max=w1 - 1) kp1_vis[:, 1] = np.clip(kp1_vis[:, 1], a_min=0, a_max=h1 - 1) kp2_vis[:, 0] = np.clip(kp2_vis[:, 0], a_min=0, a_max=w2 - 1) kp2_vis[:, 1] = np.clip(kp2_vis[:, 1], a_min=0, a_max=h2 - 1) img1 = float2uint8(img1) img2 = float2uint8(img2) if center is None: center = np.median(kp1, axis=0) set_max = range(128) colors = {m: i for i, m in enumerate(set_max)} hsv = colormaps.get_cmap("hsv") colors = { m: (255 * np.array(hsv(i / float(len(colors))))[:3][::-1]).astype(np.int32) for m, i in colors.items() } if mask is not None: ind = np.argsort(mask)[::-1] kp1_vis = kp1_vis[ind] kp2_vis = kp2_vis[ind] mask = mask[ind] for i, (pt1, pt2) in enumerate(zip(kp1_vis, kp2_vis)): # random_color = tuple(np.random.randint(low=0, high=255, size=(3,)).tolist()) coord_angle = np.arctan2(pt1[1] - center[1], pt1[0] - center[0]) corr_color = np.int32(64 * coord_angle / np.pi) % 128 color = tuple(colors[corr_color].tolist()) if ( (pt1[0] <= w1 - 1) and (pt1[0] >= 0) and (pt1[1] <= h1 - 1) and (pt1[1] >= 0) ): img1 = cv2.circle( img1, (int(pt1[0]), int(pt1[1])), radius, color, -1, cv2.LINE_AA ) if ( (pt2[0] <= w2 - 1) and (pt2[0] >= 0) and (pt2[1] <= h2 - 1) and (pt2[1] >= 0) ): if mask is not None and mask[i]: continue # img2 = cv2.drawMarker(img2, (int(pt2[0]), int(pt2[1])), color, markerType=cv2.MARKER_CROSS, # markerSize=int(5*radius), thickness=int(radius/2), line_type=cv2.LINE_AA) else: img2 = cv2.circle( img2, (int(pt2[0]), int(pt2[1])), radius, color, -1, cv2.LINE_AA ) out = np.concatenate([img1, img2], axis=1) return out def plot_correspondences( rgbs, kpts, query_id=0, masks=None, num_vis=1000000, radius=3, seed=1234 ): num_rgbs = len(rgbs) rng = np.random.RandomState(seed) permutation = rng.permutation(kpts.shape[1]) kpts = kpts[:, permutation, :][:, :num_vis] if masks is not None: masks = masks[:, permutation][:, :num_vis] rgbq = rgbs[query_id] # [h, w, 3] kptsq = kpts[query_id] # [n, 2] frames = [] for i in range(num_rgbs): rgbi = rgbs[i] kptsi = kpts[i] if masks is not None: maski = masks[i] else: maski = None frame = drawMatches( rgbq, rgbi, kptsq, kptsi, mask=maski, num_vis=num_vis, radius=radius, seed=seed, ) frames.append(frame) return frames ================================================ FILE: mvtracker/models/core/shape-of-motion/flow3d/vis/viewer.py ================================================ from pathlib import Path from typing import Callable, Literal, Optional, Tuple, Union import numpy as np from jaxtyping import Float32, UInt8 from nerfview import CameraState, Viewer from viser import Icon, ViserServer from flow3d.vis.playback_panel import add_gui_playback_group from flow3d.vis.render_panel import 
populate_render_tab class DynamicViewer(Viewer): def __init__( self, server: ViserServer, render_fn: Callable[ [CameraState, Tuple[int, int]], Union[ UInt8[np.ndarray, "H W 3"], Tuple[UInt8[np.ndarray, "H W 3"], Optional[Float32[np.ndarray, "H W"]]], ], ], num_frames: int, work_dir: str, mode: Literal["rendering", "training"] = "rendering", ): self.num_frames = num_frames self.work_dir = Path(work_dir) super().__init__(server, render_fn, mode) def _define_guis(self): super()._define_guis() server = self.server self._time_folder = server.gui.add_folder("Time") with self._time_folder: self._playback_guis = add_gui_playback_group( server, num_frames=self.num_frames, initial_fps=15.0, ) self._playback_guis[0].on_update(self.rerender) self._canonical_checkbox = server.gui.add_checkbox("Canonical", False) self._canonical_checkbox.on_update(self.rerender) _cached_playback_disabled = [] def _toggle_gui_playing(event): if event.target.value: nonlocal _cached_playback_disabled _cached_playback_disabled = [ gui.disabled for gui in self._playback_guis ] target_disabled = [True] * len(self._playback_guis) else: target_disabled = _cached_playback_disabled for gui, disabled in zip(self._playback_guis, target_disabled): gui.disabled = disabled self._canonical_checkbox.on_update(_toggle_gui_playing) self._render_track_checkbox = server.gui.add_checkbox("Render tracks", False) self._render_track_checkbox.on_update(self.rerender) tabs = server.gui.add_tab_group() with tabs.add_tab("Render", Icon.CAMERA): self.render_tab_state = populate_render_tab( server, Path(self.work_dir) / "camera_paths", self._playback_guis[0] ) ================================================ FILE: mvtracker/models/core/shape-of-motion/launch_davis.py ================================================ import os import subprocess from concurrent.futures import ProcessPoolExecutor import tyro def main( devices: list[int], seqs: list[str] | None, work_root: str, davis_root: str = "/shared/vye/datasets/DAVIS", image_name: str = "JPEGImages", res: str = "480p", depth_type: str = "aligned_depth_anything", ): img_dir = f"{davis_root}/{image_name}/{res}" if seqs is None: seqs = sorted(os.listdir(img_dir)) with ProcessPoolExecutor() as exc: for i, seq_name in enumerate(seqs): device = devices[i % len(devices)] cmd = ( f"CUDA_VISIBLE_DEVICES={device} python run_training.py " f"--work-dir {work_root}/{seq_name} data:davis " f"--data.seq_name {seq_name} --data.root_dir {davis_root} " f"--data.res {res} --data.depth_type {depth_type}" ) print(cmd) exc.submit(subprocess.call, cmd, shell=True) if __name__ == "__main__": tyro.cli(main) ================================================ FILE: mvtracker/models/core/spatracker/__init__.py ================================================ ================================================ FILE: mvtracker/models/core/spatracker/blocks.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. 
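# Building blocks used by SpaTracker: MLP and residual conv blocks, the BasicEncoder /
# DeeperBasicEncoder feature extractors, a multi-scale correlation pyramid (CorrBlock),
# attention blocks, and pixel<->camera helpers (pix2cam / cam2pix).
#
# Minimal usage sketch (an illustration, not part of the original file; the input
# shape is an arbitrary example): BasicEncoder maps a (B, 3, H, W) image batch to
# (B, output_dim, H // stride, W // stride) features:
#
#   enc = BasicEncoder(input_dim=3, output_dim=128, stride=8)
#   feats = enc(torch.zeros(1, 3, 256, 256))  # -> shape (1, 128, 32, 32)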
import collections from itertools import repeat import torch import torch.nn as nn import torch.nn.functional as F from einops import rearrange # From PyTorch internals def _ntuple(n): def parse(x): if isinstance(x, collections.abc.Iterable) and not isinstance(x, str): return tuple(x) return tuple(repeat(x, n)) return parse def exists(val): return val is not None def default(val, d): return val if exists(val) else d to_2tuple = _ntuple(2) class Mlp(nn.Module): """MLP as used in Vision Transformer, MLP-Mixer and related networks""" def __init__( self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, bias=True, drop=0.0, ): super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features bias = to_2tuple(bias) drop_probs = to_2tuple(drop) self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0]) self.act = act_layer() self.drop1 = nn.Dropout(drop_probs[0]) self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1]) self.drop2 = nn.Dropout(drop_probs[1]) def forward(self, x): x = self.fc1(x) x = self.act(x) x = self.drop1(x) x = self.fc2(x) x = self.drop2(x) return x class ResidualBlock(nn.Module): def __init__(self, in_planes, planes, norm_fn="group", stride=1): super(ResidualBlock, self).__init__() self.conv1 = nn.Conv2d( in_planes, planes, kernel_size=3, padding=1, stride=stride, padding_mode="zeros", ) self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, padding_mode="zeros") self.relu = nn.ReLU(inplace=True) num_groups = planes // 8 if norm_fn == "group": self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) if not stride == 1: self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) elif norm_fn == "batch": self.norm1 = nn.BatchNorm2d(planes) self.norm2 = nn.BatchNorm2d(planes) if not stride == 1: self.norm3 = nn.BatchNorm2d(planes) elif norm_fn == "instance": self.norm1 = nn.InstanceNorm2d(planes) self.norm2 = nn.InstanceNorm2d(planes) if not stride == 1: self.norm3 = nn.InstanceNorm2d(planes) elif norm_fn == "none": self.norm1 = nn.Sequential() self.norm2 = nn.Sequential() if not stride == 1: self.norm3 = nn.Sequential() if stride == 1: self.downsample = None else: self.downsample = nn.Sequential( nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3 ) def forward(self, x): y = x y = self.relu(self.norm1(self.conv1(y))) y = self.relu(self.norm2(self.conv2(y))) if self.downsample is not None: x = self.downsample(x) return self.relu(x + y) class BasicEncoder(nn.Module): def __init__( self, input_dim=3, output_dim=128, stride=8, norm_fn="batch", dropout=0.0, Embed3D=False ): super(BasicEncoder, self).__init__() self.stride = stride self.norm_fn = norm_fn self.in_planes = 64 if self.norm_fn == "group": self.norm1 = nn.GroupNorm(num_groups=8, num_channels=self.in_planes) self.norm2 = nn.GroupNorm(num_groups=8, num_channels=output_dim * 2) elif self.norm_fn == "batch": self.norm1 = nn.BatchNorm2d(self.in_planes) self.norm2 = nn.BatchNorm2d(output_dim * 2) elif self.norm_fn == "instance": self.norm1 = nn.InstanceNorm2d(self.in_planes) self.norm2 = nn.InstanceNorm2d(output_dim * 2) elif self.norm_fn == "none": self.norm1 = nn.Sequential() self.conv1 = nn.Conv2d( input_dim, self.in_planes, kernel_size=7, stride=2, padding=3, padding_mode="zeros", ) self.relu1 = nn.ReLU(inplace=True) self.shallow = False if self.shallow: self.layer1 = self._make_layer(64, stride=1) self.layer2 = 
self._make_layer(96, stride=2) self.layer3 = self._make_layer(128, stride=2) self.conv2 = nn.Conv2d(128 + 96 + 64, output_dim, kernel_size=1) else: if Embed3D: self.conv_fuse = nn.Conv2d(64 + 63, self.in_planes, kernel_size=3, padding=1) self.layer1 = self._make_layer(64, stride=1) self.layer2 = self._make_layer(96, stride=2) self.layer3 = self._make_layer(128, stride=2) self.layer4 = self._make_layer(128, stride=2) # TODO: Add 2 layers. # self.layer5 = self._make_layer(128, stride=1) # self.layer6 = self._make_layer(128, stride=1) self.conv2 = nn.Conv2d( 128 + 128 + 96 + 64, output_dim * 2, kernel_size=3, padding=1, padding_mode="zeros", ) self.relu2 = nn.ReLU(inplace=True) self.conv3 = nn.Conv2d(output_dim * 2, output_dim, kernel_size=1) self.dropout = None if dropout > 0: self.dropout = nn.Dropout2d(p=dropout) for m in self.modules(): if isinstance(m, nn.Conv2d): nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): if m.weight is not None: nn.init.constant_(m.weight, 1) if m.bias is not None: nn.init.constant_(m.bias, 0) def _make_layer(self, dim, stride=1): layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride) layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1) layers = (layer1, layer2) self.in_planes = dim return nn.Sequential(*layers) def forward(self, x, feat_PE=None): _, _, H, W = x.shape x = self.conv1(x) x = self.norm1(x) x = self.relu1(x) if self.shallow: a = self.layer1(x) b = self.layer2(a) c = self.layer3(b) a = F.interpolate( a, (H // self.stride, W // self.stride), mode="bilinear", align_corners=True, ) b = F.interpolate( b, (H // self.stride, W // self.stride), mode="bilinear", align_corners=True, ) c = F.interpolate( c, (H // self.stride, W // self.stride), mode="bilinear", align_corners=True, ) x = self.conv2(torch.cat([a, b, c], dim=1)) else: if feat_PE is not None: x = self.conv_fuse(torch.cat([x, feat_PE], dim=1)) a = self.layer1(x) else: a = self.layer1(x) b = self.layer2(a) c = self.layer3(b) d = self.layer4(c) a = F.interpolate( a, (H // self.stride, W // self.stride), mode="bilinear", align_corners=True, ) b = F.interpolate( b, (H // self.stride, W // self.stride), mode="bilinear", align_corners=True, ) c = F.interpolate( c, (H // self.stride, W // self.stride), mode="bilinear", align_corners=True, ) d = F.interpolate( d, (H // self.stride, W // self.stride), mode="bilinear", align_corners=True, ) x = self.conv2(torch.cat([a, b, c, d], dim=1)) x = self.norm2(x) x = self.relu2(x) x = self.conv3(x) if self.training and self.dropout is not None: x = self.dropout(x) return x class DeeperBasicEncoder(nn.Module): def __init__( self, input_dim=3, output_dim=128, stride=8, norm_fn="batch", dropout=0.0 ): super(DeeperBasicEncoder, self).__init__() self.stride = stride self.norm_fn = norm_fn self.in_planes = 64 if self.norm_fn == "group": self.norm1 = nn.GroupNorm(num_groups=8, num_channels=self.in_planes) self.norm2 = nn.GroupNorm(num_groups=8, num_channels=output_dim * 2) elif self.norm_fn == "batch": self.norm1 = nn.BatchNorm2d(self.in_planes) self.norm2 = nn.BatchNorm2d(output_dim * 2) elif self.norm_fn == "instance": self.norm1 = nn.InstanceNorm2d(self.in_planes) self.norm2 = nn.InstanceNorm2d(output_dim * 2) elif self.norm_fn == "none": self.norm1 = nn.Sequential() self.conv1 = nn.Conv2d( input_dim, self.in_planes, kernel_size=7, stride=2, padding=3, padding_mode="zeros", ) self.relu1 = nn.ReLU(inplace=True) self.layer1 = self._make_layer(64, stride=1) 
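# Six residual stages follow; in forward() every stage output is upsampled to
# (H // stride, W // stride) and the results are concatenated before conv2/conv3.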
self.layer2 = self._make_layer(96, stride=2) self.layer3 = self._make_layer(128, stride=2) self.layer4 = self._make_layer(128, stride=2) self.layer5 = self._make_layer(128, stride=1) self.layer6 = self._make_layer(64, stride=2) self.conv2 = nn.Conv2d( 64 + 128 + 128 + 128 + 96 + 64, output_dim * 2, kernel_size=3, padding=1, padding_mode="zeros", ) self.relu2 = nn.ReLU(inplace=True) self.conv3 = nn.Conv2d(output_dim * 2, output_dim, kernel_size=1) self.dropout = None if dropout > 0: self.dropout = nn.Dropout2d(p=dropout) for m in self.modules(): if isinstance(m, nn.Conv2d): nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): if m.weight is not None: nn.init.constant_(m.weight, 1) if m.bias is not None: nn.init.constant_(m.bias, 0) def _make_layer(self, dim, stride=1): layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride) layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1) layers = (layer1, layer2) self.in_planes = dim return nn.Sequential(*layers) def forward(self, x, feat_PE=None): _, _, H, W = x.shape x = self.conv1(x) x = self.norm1(x) x = self.relu1(x) if feat_PE is not None: x = self.conv_fuse(torch.cat([x, feat_PE], dim=1)) a = self.layer1(x) else: a = self.layer1(x) b = self.layer2(a) c = self.layer3(b) d = self.layer4(c) e = self.layer5(d) f = self.layer6(e) a = F.interpolate( a, (H // self.stride, W // self.stride), mode="bilinear", align_corners=True, ) b = F.interpolate( b, (H // self.stride, W // self.stride), mode="bilinear", align_corners=True, ) c = F.interpolate( c, (H // self.stride, W // self.stride), mode="bilinear", align_corners=True, ) d = F.interpolate( d, (H // self.stride, W // self.stride), mode="bilinear", align_corners=True, ) e = F.interpolate( e, (H // self.stride, W // self.stride), mode="bilinear", align_corners=True, ) f = F.interpolate( f, (H // self.stride, W // self.stride), mode="bilinear", align_corners=True, ) x = self.conv2(torch.cat([a, b, c, d, e, f], dim=1)) x = self.norm2(x) x = self.relu2(x) x = self.conv3(x) if self.training and self.dropout is not None: x = self.dropout(x) return x class CorrBlock: def __init__(self, fmaps, num_levels=4, radius=4, depths_dnG=None): B, S, C, H_prev, W_prev = fmaps.shape self.S, self.C, self.H, self.W = S, C, H_prev, W_prev self.num_levels = num_levels self.radius = radius self.fmaps_pyramid = [] self.depth_pyramid = [] self.fmaps_pyramid.append(fmaps) if depths_dnG is not None: self.depth_pyramid.append(depths_dnG) for i in range(self.num_levels - 1): if depths_dnG is not None: depths_dnG_ = depths_dnG.reshape(B * S, 1, H_prev, W_prev) depths_dnG_ = F.avg_pool2d(depths_dnG_, 2, stride=2) _, _, H, W = depths_dnG_.shape depths_dnG = depths_dnG_.reshape(B, S, 1, H, W) self.depth_pyramid.append(depths_dnG) fmaps_ = fmaps.reshape(B * S, C, H_prev, W_prev) fmaps_ = F.avg_pool2d(fmaps_, 2, stride=2) _, _, H, W = fmaps_.shape fmaps = fmaps_.reshape(B, S, C, H, W) H_prev = H W_prev = W self.fmaps_pyramid.append(fmaps) def sample(self, coords): r = self.radius B, S, N, D = coords.shape assert D == 2 H, W = self.H, self.W out_pyramid = [] for i in range(self.num_levels): corrs = self.corrs_pyramid[i] # B, S, N, H, W _, _, _, H, W = corrs.shape dx = torch.linspace(-r, r, 2 * r + 1) dy = torch.linspace(-r, r, 2 * r + 1) delta = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), axis=-1).to( coords.device ) centroid_lvl = coords.reshape(B * S * N, 1, 1, 2) / 2 ** i delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 
2) coords_lvl = centroid_lvl + delta_lvl corrs = bilinear_sampler(corrs.reshape(B * S * N, 1, H, W), coords_lvl) corrs = corrs.view(B, S, N, -1) out_pyramid.append(corrs) out = torch.cat(out_pyramid, dim=-1) # B, S, N, LRR*2 return out.contiguous().float() def corr(self, targets): B, S, N, C = targets.shape assert C == self.C assert S == self.S fmap1 = targets self.corrs_pyramid = [] for fmaps in self.fmaps_pyramid: _, _, _, H, W = fmaps.shape fmap2s = fmaps.view(B, S, C, H * W) corrs = torch.matmul(fmap1, fmap2s) corrs = corrs.view(B, S, N, H, W) corrs = corrs / torch.sqrt(torch.tensor(C).float()) self.corrs_pyramid.append(corrs) def corr_sample(self, targets, coords, coords_dp=None): B, S, N, C = targets.shape r = self.radius Dim_c = (2 * r + 1) ** 2 assert C == self.C assert S == self.S out_pyramid = [] out_pyramid_dp = [] for i in range(self.num_levels): dx = torch.linspace(-r, r, 2 * r + 1) dy = torch.linspace(-r, r, 2 * r + 1) delta = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), axis=-1).to( coords.device ) centroid_lvl = coords.reshape(B * S * N, 1, 1, 2) / 2 ** i delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2) coords_lvl = centroid_lvl + delta_lvl fmaps = self.fmaps_pyramid[i] _, _, _, H, W = fmaps.shape fmap2s = fmaps.view(B * S, C, H, W) if len(self.depth_pyramid) > 0: depths_dnG_i = self.depth_pyramid[i] depths_dnG_i = depths_dnG_i.view(B * S, 1, H, W) dnG_sample = bilinear_sampler(depths_dnG_i, coords_lvl.view(B * S, 1, N * Dim_c, 2)) dp_corrs = (dnG_sample.view(B * S, N, -1) - coords_dp[0]).abs() / coords_dp[0] out_pyramid_dp.append(dp_corrs) fmap2s_sample = bilinear_sampler(fmap2s, coords_lvl.view(B * S, 1, N * Dim_c, 2)) fmap2s_sample = fmap2s_sample.permute(0, 3, 1, 2) # B*S, N*Dim_c, C, -1 corrs = torch.matmul(targets.reshape(B * S * N, 1, -1), fmap2s_sample.reshape(B * S * N, Dim_c, -1).permute(0, 2, 1)) corrs = corrs / torch.sqrt(torch.tensor(C).float()) corrs = corrs.view(B, S, N, -1) out_pyramid.append(corrs) out = torch.cat(out_pyramid, dim=-1) # B, S, N, LRR*2 if len(self.depth_pyramid) > 0: out_dp = torch.cat(out_pyramid_dp, dim=-1) self.fcorrD = out_dp.contiguous().float() else: self.fcorrD = torch.zeros_like(out).contiguous().float() return out.contiguous().float() class Attention(nn.Module): def __init__(self, query_dim, num_heads=8, dim_head=48, qkv_bias=False, flash=False): super().__init__() inner_dim = self.inner_dim = dim_head * num_heads self.scale = dim_head ** -0.5 self.heads = num_heads self.flash = flash self.qkv = nn.Linear(query_dim, inner_dim * 3, bias=qkv_bias) self.proj = nn.Linear(inner_dim, query_dim) def forward(self, x, attn_bias=None): B, N1, _ = x.shape C = self.inner_dim h = self.heads qkv = self.qkv(x).reshape(B, N1, 3, h, C // h) q, k, v = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2] N2 = x.shape[1] k = k.reshape(B, N2, h, C // h).permute(0, 2, 1, 3) v = v.reshape(B, N2, h, C // h).permute(0, 2, 1, 3) q = q.reshape(B, N1, h, C // h).permute(0, 2, 1, 3) if self.flash == False: sim = (q @ k.transpose(-2, -1)) * self.scale if attn_bias is not None: sim = sim + attn_bias attn = sim.softmax(dim=-1) x = (attn @ v).transpose(1, 2).reshape(B, N1, C) else: input_args = [x.half().contiguous() for x in [q, k, v]] x = F.scaled_dot_product_attention(*input_args).permute(0, 2, 1, 3).reshape(B, N1, -1) # type: ignore return self.proj(x.float()) class AttnBlock(nn.Module): """ A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning. 
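Note: as implemented here, no conditioning input is used; the block is a plain pre-norm transformer block (LayerNorm -> Attention and LayerNorm -> MLP, each with a residual connection).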
""" def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, flash=False, **block_kwargs): super().__init__() self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) self.flash = flash self.attn = Attention( hidden_size, num_heads=num_heads, qkv_bias=True, flash=flash, **block_kwargs ) self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) mlp_hidden_dim = int(hidden_size * mlp_ratio) approx_gelu = lambda: nn.GELU(approximate="tanh") self.mlp = Mlp( in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0, ) def forward(self, x): x = x + self.attn(self.norm1(x)) x = x + self.mlp(self.norm2(x)) return x def bilinear_sampler(img, coords, mode="bilinear", mask=False): """Wrapper for grid_sample, uses pixel coordinates""" H, W = img.shape[-2:] xgrid, ygrid = coords.split([1, 1], dim=-1) # go to 0,1 then 0,2 then -1,1 xgrid = 2 * xgrid / (W - 1) - 1 ygrid = 2 * ygrid / (H - 1) - 1 grid = torch.cat([xgrid, ygrid], dim=-1) img = F.grid_sample(img, grid, align_corners=True) if mask: mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1) return img, mask.float() return img class EUpdateFormer(nn.Module): """ Transformer model that updates track estimates. """ def __init__( self, space_depth=12, time_depth=12, input_dim=320, hidden_size=384, num_heads=8, output_dim=130, mlp_ratio=4.0, vq_depth=3, add_space_attn=True, add_time_attn=True, flash=True ): super().__init__() self.out_channels = 2 self.num_heads = num_heads self.hidden_size = hidden_size self.add_space_attn = add_space_attn self.input_transform = torch.nn.Linear(input_dim, hidden_size, bias=True) self.flash = flash self.flow_head = nn.Sequential( nn.Linear(hidden_size, output_dim, bias=True), nn.ReLU(inplace=True), nn.Linear(output_dim, output_dim, bias=True), nn.ReLU(inplace=True), nn.Linear(output_dim, output_dim, bias=True) ) cross_attn_kwargs = { "d_model": self.hidden_size, "nhead": 4, "layer_names": ['self', 'cross'] * 3, } from mvtracker.models.core.loftr import LocalFeatureTransformer self.gnn = LocalFeatureTransformer(cross_attn_kwargs) # Attention Modules in the temporal dimension self.time_blocks = nn.ModuleList( [ AttnBlock( hidden_size, num_heads, mlp_ratio=mlp_ratio, flash=flash, ) if add_time_attn else nn.Identity() for _ in range(time_depth) ] ) if add_space_attn: self.space_blocks = nn.ModuleList( [ AttnBlock( hidden_size, num_heads, mlp_ratio=mlp_ratio, flash=flash, ) for _ in range(space_depth) ] ) assert len(self.time_blocks) >= len(self.space_blocks) self.initialize_weights() def initialize_weights(self): def _basic_init(module): if isinstance(module, nn.Linear): torch.nn.init.xavier_uniform_(module.weight) if module.bias is not None: nn.init.constant_(module.bias, 0) self.apply(_basic_init) def forward(self, input_tensor, se3_feature): """ Updating with Transformer Args: input_tensor: B, N, T, C arap_embed: B, N, T, C """ B, N, T, C = input_tensor.shape x = self.input_transform(input_tensor) tokens = x K = 0 j = 0 for i in range(len(self.time_blocks)): tokens_time = rearrange(tokens, "b n t c -> (b n) t c", b=B, t=T, n=N + K) tokens_time = self.time_blocks[i](tokens_time) tokens = rearrange(tokens_time, "(b n) t c -> b n t c ", b=B, t=T, n=N + K) if self.add_space_attn and ( i % (len(self.time_blocks) // len(self.space_blocks)) == 0 ): tokens_space = rearrange(tokens, "b n t c -> (b t) n c ", b=B, t=T, n=N) tokens_space = self.space_blocks[j](tokens_space) tokens = rearrange(tokens_space, "(b t) n c -> b n t c ", b=B, t=T, n=N) j += 1 B, N, 
S, _ = tokens.shape feat0, feat1 = self.gnn(tokens.view(B * N * S, -1)[None, ...], se3_feature[None, ...]) flow = self.flow_head(feat0.view(B, N, S, -1)) return flow, feat1 def pix2cam(coords, intr): """ Args: coords: [B, T, N, 3] intr: [B, T, 3, 3] """ B, S, N, _, = coords.shape assert coords.shape == (B, S, N, 3) assert intr.shape == (B, S, 3, 3) coords = coords.detach() xy_src = coords.reshape(B * S * N, 3) intr = intr[:, :, None, ...].repeat(1, 1, N, 1, 1).reshape(B * S * N, 3, 3) xy_src = torch.cat([xy_src[..., :2], torch.ones_like(xy_src[..., :1])], dim=-1) xyz_src = (torch.inverse(intr) @ xy_src[..., None])[..., 0] dp_pred = coords[..., 2] xyz_src_ = (xyz_src * (dp_pred.reshape(B * S * N, 1))) xyz_src_ = xyz_src_.reshape(B, S, N, 3) return xyz_src_ def cam2pix(coords, intr): """ Args: coords: [B, T, N, 3] intr: [B, T, 3, 3] """ coords = coords.detach() B, S, N, _, = coords.shape xy_src = coords.reshape(B * S * N, 3).clone() intr = intr[:, :, None, ...].repeat(1, 1, N, 1, 1).reshape(B * S * N, 3, 3) xy_src = xy_src / (xy_src[..., 2:] + 1e-5) xyz_src = (intr @ xy_src[..., None])[..., 0] dp_pred = coords[..., 2] xyz_src[..., 2] *= dp_pred.reshape(S * N) xyz_src = xyz_src.reshape(B, S, N, 3) return xyz_src ================================================ FILE: mvtracker/models/core/spatracker/softsplat.py ================================================ #!/usr/bin/env python """The code of softsplat function is modified from: https://github.com/sniklaus/softmax-splatting/blob/master/softsplat.py """ import collections import os import re import typing import cupy import torch objCudacache = {} def cuda_int32(intIn: int): return cupy.int32(intIn) def cuda_float32(fltIn: float): return cupy.float32(fltIn) def cuda_kernel(strFunction: str, strKernel: str, objVariables: typing.Dict): if 'device' not in objCudacache: objCudacache['device'] = torch.cuda.get_device_name() strKey = strFunction for strVariable in objVariables: objValue = objVariables[strVariable] strKey += strVariable if objValue is None: continue elif type(objValue) == int: strKey += str(objValue) elif type(objValue) == float: strKey += str(objValue) elif type(objValue) == bool: strKey += str(objValue) elif type(objValue) == str: strKey += objValue elif type(objValue) == torch.Tensor: strKey += str(objValue.dtype) strKey += str(objValue.shape) strKey += str(objValue.stride()) elif True: print(strVariable, type(objValue)) assert (False) strKey += objCudacache['device'] if strKey not in objCudacache: for strVariable in objVariables: objValue = objVariables[strVariable] if objValue is None: continue elif type(objValue) == int: strKernel = strKernel.replace('{{' + strVariable + '}}', str(objValue)) elif type(objValue) == float: strKernel = strKernel.replace('{{' + strVariable + '}}', str(objValue)) elif type(objValue) == bool: strKernel = strKernel.replace('{{' + strVariable + '}}', str(objValue)) elif type(objValue) == str: strKernel = strKernel.replace('{{' + strVariable + '}}', objValue) elif type(objValue) == torch.Tensor and objValue.dtype == torch.uint8: strKernel = strKernel.replace('{{type}}', 'unsigned char') elif type(objValue) == torch.Tensor and objValue.dtype == torch.float16: strKernel = strKernel.replace('{{type}}', 'half') elif type(objValue) == torch.Tensor and objValue.dtype == torch.float32: strKernel = strKernel.replace('{{type}}', 'float') elif type(objValue) == torch.Tensor and objValue.dtype == torch.float64: strKernel = strKernel.replace('{{type}}', 'double') elif type(objValue) == torch.Tensor and 
objValue.dtype == torch.int32: strKernel = strKernel.replace('{{type}}', 'int') elif type(objValue) == torch.Tensor and objValue.dtype == torch.int64: strKernel = strKernel.replace('{{type}}', 'long') elif type(objValue) == torch.Tensor: print(strVariable, objValue.dtype) assert (False) elif True: print(strVariable, type(objValue)) assert (False) while True: objMatch = re.search('(SIZE_)([0-4])(\()([^\)]*)(\))', strKernel) if objMatch is None: break intArg = int(objMatch.group(2)) strTensor = objMatch.group(4) intSizes = objVariables[strTensor].size() strKernel = strKernel.replace(objMatch.group(), str( intSizes[intArg] if torch.is_tensor(intSizes[intArg]) == False else intSizes[intArg].item())) while True: objMatch = re.search('(OFFSET_)([0-4])(\()', strKernel) if objMatch is None: break intStart = objMatch.span()[1] intStop = objMatch.span()[1] intParentheses = 1 while True: intParentheses += 1 if strKernel[intStop] == '(' else 0 intParentheses -= 1 if strKernel[intStop] == ')' else 0 if intParentheses == 0: break intStop += 1 intArgs = int(objMatch.group(2)) strArgs = strKernel[intStart:intStop].split(',') assert (intArgs == len(strArgs) - 1) strTensor = strArgs[0] intStrides = objVariables[strTensor].stride() strIndex = [] for intArg in range(intArgs): strIndex.append('((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip() + ')*' + str( intStrides[intArg] if torch.is_tensor(intStrides[intArg]) == False else intStrides[ intArg].item()) + ')') strKernel = strKernel.replace('OFFSET_' + str(intArgs) + '(' + strKernel[intStart:intStop] + ')', '(' + str.join('+', strIndex) + ')') while True: objMatch = re.search('(VALUE_)([0-4])(\()', strKernel) if objMatch is None: break intStart = objMatch.span()[1] intStop = objMatch.span()[1] intParentheses = 1 while True: intParentheses += 1 if strKernel[intStop] == '(' else 0 intParentheses -= 1 if strKernel[intStop] == ')' else 0 if intParentheses == 0: break intStop += 1 intArgs = int(objMatch.group(2)) strArgs = strKernel[intStart:intStop].split(',') assert (intArgs == len(strArgs) - 1) strTensor = strArgs[0] intStrides = objVariables[strTensor].stride() strIndex = [] for intArg in range(intArgs): strIndex.append('((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip() + ')*' + str( intStrides[intArg] if torch.is_tensor(intStrides[intArg]) == False else intStrides[ intArg].item()) + ')') strKernel = strKernel.replace('VALUE_' + str(intArgs) + '(' + strKernel[intStart:intStop] + ')', strTensor + '[' + str.join('+', strIndex) + ']') objCudacache[strKey] = { 'strFunction': strFunction, 'strKernel': strKernel } return strKey @cupy.memoize(for_each_device=True) def cuda_launch(strKey: str): if 'CUDA_HOME' not in os.environ: os.environ['CUDA_HOME'] = cupy.cuda.get_cuda_path() return cupy.cuda.compile_with_cache(objCudacache[strKey]['strKernel'], tuple( ['-I ' + os.environ['CUDA_HOME'], '-I ' + os.environ['CUDA_HOME'] + '/include'])).get_function( objCudacache[strKey]['strFunction']) ########################################################## def softsplat( tenIn: torch.Tensor, tenFlow: torch.Tensor, tenMetric: typing.Optional[torch.Tensor], strMode: str, tenoutH=None, tenoutW=None, use_pointcloud_splatting=False, return_normalization_tensor=False, ): assert (strMode.split('-')[0] in ['sum', 'avg', 'linear', 'soft']) if strMode == 'sum': assert (tenMetric is None) if strMode == 'avg': assert (tenMetric is None) if strMode.split('-')[0] == 'linear': assert (tenMetric is not None) if strMode.split('-')[0] == 'soft': assert 
(tenMetric is not None) if strMode == 'avg': tenIn = torch.cat([tenIn, tenIn.new_ones([tenIn.shape[0], 1, tenIn.shape[2], tenIn.shape[3]])], 1) elif strMode.split('-')[0] == 'linear': tenIn = torch.cat([tenIn * tenMetric, tenMetric], 1) elif strMode.split('-')[0] == 'soft': tenIn = torch.cat([tenIn * tenMetric.exp(), tenMetric.exp()], 1) # If tenIn only contains a HW grid where each position in the grid will be # taken into account for splatting as (grid_x + flow_x, grid_y + flow_y), # then we use the original softsplat function which was designed for this. # Otherwise, we assume the positions of the points in the grid do not matter # and only the flow should be taken into account as (flow_x, flow_y) # to determine the splatted position. if use_pointcloud_splatting: tenOut = softsplat_pointcloud_func.apply(tenIn, tenFlow, tenoutH, tenoutW) else: tenOut = softsplat_func.apply(tenIn, tenFlow, tenoutH, tenoutW) if strMode.split('-')[0] in ['avg', 'linear', 'soft']: tenNormalize = tenOut[:, -1:, :, :] if len(strMode.split('-')) == 1: tenNormalize = tenNormalize + 0.0001 elif strMode.split('-')[1] == 'addeps': tenNormalize = tenNormalize + 0.0001 elif strMode.split('-')[1] == 'zeroeps': tenNormalize[tenNormalize == 0.0] = 1.0 elif strMode.split('-')[1] == 'clipeps': tenNormalize = tenNormalize.clip(0.0001, None) tenOut = tenOut[:, :-1, :, :] / tenNormalize if return_normalization_tensor: return tenOut, tenNormalize else: return tenOut class softsplat_func(torch.autograd.Function): @staticmethod @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32) def forward(self, tenIn, tenFlow, H=None, W=None): if H is None: tenOut = tenIn.new_zeros([tenIn.shape[0], tenIn.shape[1], tenIn.shape[2], tenIn.shape[3]]) else: tenOut = tenIn.new_zeros([tenIn.shape[0], tenIn.shape[1], H, W]) if tenIn.is_cuda == True: cuda_launch(cuda_kernel('softsplat_out', ''' extern "C" __global__ void __launch_bounds__(512) softsplat_out( const long long int n, const {{type}}* __restrict__ tenIn, const {{type}}* __restrict__ tenFlow, {{type}}* __restrict__ tenOut ) { for (long long int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { const int intN = ( intIndex / SIZE_3(tenIn) / SIZE_2(tenIn) / SIZE_1(tenIn) ) % SIZE_0(tenIn); const int intC = ( intIndex / SIZE_3(tenIn) / SIZE_2(tenIn) ) % SIZE_1(tenIn); const int intY = ( intIndex / SIZE_3(tenIn) ) % SIZE_2(tenIn); const int intX = ( intIndex ) % SIZE_3(tenIn); assert(SIZE_1(tenFlow) == 2); {{type}} fltX = ({{type}}) (intX) + VALUE_4(tenFlow, intN, 0, intY, intX); {{type}} fltY = ({{type}}) (intY) + VALUE_4(tenFlow, intN, 1, intY, intX); if (isfinite(fltX) == false) { return; } if (isfinite(fltY) == false) { return; } {{type}} fltIn = VALUE_4(tenIn, intN, intC, intY, intX); int intNorthwestX = (int) (floor(fltX)); int intNorthwestY = (int) (floor(fltY)); int intNortheastX = intNorthwestX + 1; int intNortheastY = intNorthwestY; int intSouthwestX = intNorthwestX; int intSouthwestY = intNorthwestY + 1; int intSoutheastX = intNorthwestX + 1; int intSoutheastY = intNorthwestY + 1; {{type}} fltNorthwest = (({{type}}) (intSoutheastX) - fltX) * (({{type}}) (intSoutheastY) - fltY); {{type}} fltNortheast = (fltX - ({{type}}) (intSouthwestX)) * (({{type}}) (intSouthwestY) - fltY); {{type}} fltSouthwest = (({{type}}) (intNortheastX) - fltX) * (fltY - ({{type}}) (intNortheastY)); {{type}} fltSoutheast = (fltX - ({{type}}) (intNorthwestX)) * (fltY - ({{type}}) (intNorthwestY)); if ((intNorthwestX >= 0) && (intNorthwestX < SIZE_3(tenOut)) 
&& (intNorthwestY >= 0) && (intNorthwestY < SIZE_2(tenOut))) { atomicAdd(&tenOut[OFFSET_4(tenOut, intN, intC, intNorthwestY, intNorthwestX)], fltIn * fltNorthwest); } if ((intNortheastX >= 0) && (intNortheastX < SIZE_3(tenOut)) && (intNortheastY >= 0) && (intNortheastY < SIZE_2(tenOut))) { atomicAdd(&tenOut[OFFSET_4(tenOut, intN, intC, intNortheastY, intNortheastX)], fltIn * fltNortheast); } if ((intSouthwestX >= 0) && (intSouthwestX < SIZE_3(tenOut)) && (intSouthwestY >= 0) && (intSouthwestY < SIZE_2(tenOut))) { atomicAdd(&tenOut[OFFSET_4(tenOut, intN, intC, intSouthwestY, intSouthwestX)], fltIn * fltSouthwest); } if ((intSoutheastX >= 0) && (intSoutheastX < SIZE_3(tenOut)) && (intSoutheastY >= 0) && (intSoutheastY < SIZE_2(tenOut))) { atomicAdd(&tenOut[OFFSET_4(tenOut, intN, intC, intSoutheastY, intSoutheastX)], fltIn * fltSoutheast); } } } ''', { 'tenIn': tenIn, 'tenFlow': tenFlow, 'tenOut': tenOut }))( grid=tuple([int((tenIn.nelement() + 512 - 1) / 512), 1, 1]), block=tuple([512, 1, 1]), args=[cuda_int32(tenIn.nelement()), tenIn.data_ptr(), tenFlow.data_ptr(), tenOut.data_ptr()], stream=collections.namedtuple('Stream', 'ptr')(torch.cuda.current_stream().cuda_stream) ) elif tenIn.is_cuda != True: assert (False) self.save_for_backward(tenIn, tenFlow) return tenOut @staticmethod @torch.cuda.amp.custom_bwd def backward(self, tenOutgrad): tenIn, tenFlow = self.saved_tensors tenOutgrad = tenOutgrad.contiguous(); assert (tenOutgrad.is_cuda == True) tenIngrad = tenIn.new_zeros([tenIn.shape[0], tenIn.shape[1], tenIn.shape[2], tenIn.shape[3]]) if \ self.needs_input_grad[0] == True else None tenFlowgrad = tenFlow.new_zeros([tenFlow.shape[0], tenFlow.shape[1], tenFlow.shape[2], tenFlow.shape[3]]) if \ self.needs_input_grad[1] == True else None Hgrad = None Wgrad = None if tenIngrad is not None: cuda_launch(cuda_kernel('softsplat_ingrad', ''' extern "C" __global__ void __launch_bounds__(512) softsplat_ingrad( const long long int n, const {{type}}* __restrict__ tenIn, const {{type}}* __restrict__ tenFlow, const {{type}}* __restrict__ tenOutgrad, {{type}}* __restrict__ tenIngrad, {{type}}* __restrict__ tenFlowgrad ) { for (long long int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { const int intN = ( intIndex / SIZE_3(tenIngrad) / SIZE_2(tenIngrad) / SIZE_1(tenIngrad) ) % SIZE_0(tenIngrad); const int intC = ( intIndex / SIZE_3(tenIngrad) / SIZE_2(tenIngrad) ) % SIZE_1(tenIngrad); const int intY = ( intIndex / SIZE_3(tenIngrad) ) % SIZE_2(tenIngrad); const int intX = ( intIndex ) % SIZE_3(tenIngrad); assert(SIZE_1(tenFlow) == 2); {{type}} fltIngrad = 0.0f; {{type}} fltX = ({{type}}) (intX) + VALUE_4(tenFlow, intN, 0, intY, intX); {{type}} fltY = ({{type}}) (intY) + VALUE_4(tenFlow, intN, 1, intY, intX); if (isfinite(fltX) == false) { return; } if (isfinite(fltY) == false) { return; } int intNorthwestX = (int) (floor(fltX)); int intNorthwestY = (int) (floor(fltY)); int intNortheastX = intNorthwestX + 1; int intNortheastY = intNorthwestY; int intSouthwestX = intNorthwestX; int intSouthwestY = intNorthwestY + 1; int intSoutheastX = intNorthwestX + 1; int intSoutheastY = intNorthwestY + 1; {{type}} fltNorthwest = (({{type}}) (intSoutheastX) - fltX) * (({{type}}) (intSoutheastY) - fltY); {{type}} fltNortheast = (fltX - ({{type}}) (intSouthwestX)) * (({{type}}) (intSouthwestY) - fltY); {{type}} fltSouthwest = (({{type}}) (intNortheastX) - fltX) * (fltY - ({{type}}) (intNortheastY)); {{type}} fltSoutheast = (fltX - ({{type}}) (intNorthwestX)) * (fltY 
- ({{type}}) (intNorthwestY)); if ((intNorthwestX >= 0) && (intNorthwestX < SIZE_3(tenOutgrad)) && (intNorthwestY >= 0) && (intNorthwestY < SIZE_2(tenOutgrad))) { fltIngrad += VALUE_4(tenOutgrad, intN, intC, intNorthwestY, intNorthwestX) * fltNorthwest; } if ((intNortheastX >= 0) && (intNortheastX < SIZE_3(tenOutgrad)) && (intNortheastY >= 0) && (intNortheastY < SIZE_2(tenOutgrad))) { fltIngrad += VALUE_4(tenOutgrad, intN, intC, intNortheastY, intNortheastX) * fltNortheast; } if ((intSouthwestX >= 0) && (intSouthwestX < SIZE_3(tenOutgrad)) && (intSouthwestY >= 0) && (intSouthwestY < SIZE_2(tenOutgrad))) { fltIngrad += VALUE_4(tenOutgrad, intN, intC, intSouthwestY, intSouthwestX) * fltSouthwest; } if ((intSoutheastX >= 0) && (intSoutheastX < SIZE_3(tenOutgrad)) && (intSoutheastY >= 0) && (intSoutheastY < SIZE_2(tenOutgrad))) { fltIngrad += VALUE_4(tenOutgrad, intN, intC, intSoutheastY, intSoutheastX) * fltSoutheast; } tenIngrad[intIndex] = fltIngrad; } } ''', { 'tenIn': tenIn, 'tenFlow': tenFlow, 'tenOutgrad': tenOutgrad, 'tenIngrad': tenIngrad, 'tenFlowgrad': tenFlowgrad }))( grid=tuple([int((tenIngrad.nelement() + 512 - 1) / 512), 1, 1]), block=tuple([512, 1, 1]), args=[cuda_int32(tenIngrad.nelement()), tenIn.data_ptr(), tenFlow.data_ptr(), tenOutgrad.data_ptr(), tenIngrad.data_ptr(), None], stream=collections.namedtuple('Stream', 'ptr')(torch.cuda.current_stream().cuda_stream) ) if tenFlowgrad is not None: cuda_launch(cuda_kernel('softsplat_flowgrad', ''' extern "C" __global__ void __launch_bounds__(512) softsplat_flowgrad( const long long int n, const {{type}}* __restrict__ tenIn, const {{type}}* __restrict__ tenFlow, const {{type}}* __restrict__ tenOutgrad, {{type}}* __restrict__ tenIngrad, {{type}}* __restrict__ tenFlowgrad ) { for (long long int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { const int intN = ( intIndex / SIZE_3(tenFlowgrad) / SIZE_2(tenFlowgrad) / SIZE_1(tenFlowgrad) ) % SIZE_0(tenFlowgrad); const int intC = ( intIndex / SIZE_3(tenFlowgrad) / SIZE_2(tenFlowgrad) ) % SIZE_1(tenFlowgrad); const int intY = ( intIndex / SIZE_3(tenFlowgrad) ) % SIZE_2(tenFlowgrad); const int intX = ( intIndex ) % SIZE_3(tenFlowgrad); assert(SIZE_1(tenFlow) == 2); {{type}} fltFlowgrad = 0.0f; {{type}} fltX = ({{type}}) (intX) + VALUE_4(tenFlow, intN, 0, intY, intX); {{type}} fltY = ({{type}}) (intY) + VALUE_4(tenFlow, intN, 1, intY, intX); if (isfinite(fltX) == false) { return; } if (isfinite(fltY) == false) { return; } int intNorthwestX = (int) (floor(fltX)); int intNorthwestY = (int) (floor(fltY)); int intNortheastX = intNorthwestX + 1; int intNortheastY = intNorthwestY; int intSouthwestX = intNorthwestX; int intSouthwestY = intNorthwestY + 1; int intSoutheastX = intNorthwestX + 1; int intSoutheastY = intNorthwestY + 1; {{type}} fltNorthwest = 0.0f; {{type}} fltNortheast = 0.0f; {{type}} fltSouthwest = 0.0f; {{type}} fltSoutheast = 0.0f; if (intC == 0) { fltNorthwest = (({{type}}) (-1.0f)) * (({{type}}) (intSoutheastY) - fltY); fltNortheast = (({{type}}) (+1.0f)) * (({{type}}) (intSouthwestY) - fltY); fltSouthwest = (({{type}}) (-1.0f)) * (fltY - ({{type}}) (intNortheastY)); fltSoutheast = (({{type}}) (+1.0f)) * (fltY - ({{type}}) (intNorthwestY)); } else if (intC == 1) { fltNorthwest = (({{type}}) (intSoutheastX) - fltX) * (({{type}}) (-1.0f)); fltNortheast = (fltX - ({{type}}) (intSouthwestX)) * (({{type}}) (-1.0f)); fltSouthwest = (({{type}}) (intNortheastX) - fltX) * (({{type}}) (+1.0f)); fltSoutheast = (fltX - ({{type}}) 
(intNorthwestX)) * (({{type}}) (+1.0f)); } for (int intChannel = 0; intChannel < SIZE_1(tenOutgrad); intChannel += 1) { {{type}} fltIn = VALUE_4(tenIn, intN, intChannel, intY, intX); if ((intNorthwestX >= 0) && (intNorthwestX < SIZE_3(tenOutgrad)) && (intNorthwestY >= 0) && (intNorthwestY < SIZE_2(tenOutgrad))) { fltFlowgrad += VALUE_4(tenOutgrad, intN, intChannel, intNorthwestY, intNorthwestX) * fltIn * fltNorthwest; } if ((intNortheastX >= 0) && (intNortheastX < SIZE_3(tenOutgrad)) && (intNortheastY >= 0) && (intNortheastY < SIZE_2(tenOutgrad))) { fltFlowgrad += VALUE_4(tenOutgrad, intN, intChannel, intNortheastY, intNortheastX) * fltIn * fltNortheast; } if ((intSouthwestX >= 0) && (intSouthwestX < SIZE_3(tenOutgrad)) && (intSouthwestY >= 0) && (intSouthwestY < SIZE_2(tenOutgrad))) { fltFlowgrad += VALUE_4(tenOutgrad, intN, intChannel, intSouthwestY, intSouthwestX) * fltIn * fltSouthwest; } if ((intSoutheastX >= 0) && (intSoutheastX < SIZE_3(tenOutgrad)) && (intSoutheastY >= 0) && (intSoutheastY < SIZE_2(tenOutgrad))) { fltFlowgrad += VALUE_4(tenOutgrad, intN, intChannel, intSoutheastY, intSoutheastX) * fltIn * fltSoutheast; } } tenFlowgrad[intIndex] = fltFlowgrad; } } ''', { 'tenIn': tenIn, 'tenFlow': tenFlow, 'tenOutgrad': tenOutgrad, 'tenIngrad': tenIngrad, 'tenFlowgrad': tenFlowgrad }))( grid=tuple([int((tenFlowgrad.nelement() + 512 - 1) / 512), 1, 1]), block=tuple([512, 1, 1]), args=[cuda_int32(tenFlowgrad.nelement()), tenIn.data_ptr(), tenFlow.data_ptr(), tenOutgrad.data_ptr(), None, tenFlowgrad.data_ptr()], stream=collections.namedtuple('Stream', 'ptr')(torch.cuda.current_stream().cuda_stream) ) return tenIngrad, tenFlowgrad, Hgrad, Wgrad def cuda_int64(intIn: int): return cupy.int64(intIn) def cuda_kernel_longlong(strFunction: str, strKernel: str, objVariables: typing.Dict): if 'device' not in objCudacache: objCudacache['device'] = torch.cuda.get_device_name() strKey = strFunction for strVariable in objVariables: objValue = objVariables[strVariable] strKey += strVariable if objValue is None: continue elif type(objValue) == int: strKey += str(objValue) elif type(objValue) == float: strKey += str(objValue) elif type(objValue) == bool: strKey += str(objValue) elif type(objValue) == str: strKey += objValue elif type(objValue) == torch.Tensor: strKey += str(objValue.dtype) strKey += str(objValue.shape) strKey += str(objValue.stride()) elif True: print(strVariable, type(objValue)) assert (False) strKey += objCudacache['device'] if strKey not in objCudacache: for strVariable in objVariables: objValue = objVariables[strVariable] if objValue is None: continue elif type(objValue) == int: strKernel = strKernel.replace('{{' + strVariable + '}}', str(objValue)) elif type(objValue) == float: strKernel = strKernel.replace('{{' + strVariable + '}}', str(objValue)) elif type(objValue) == bool: strKernel = strKernel.replace('{{' + strVariable + '}}', str(objValue)) elif type(objValue) == str: strKernel = strKernel.replace('{{' + strVariable + '}}', objValue) elif type(objValue) == torch.Tensor and objValue.dtype == torch.uint8: strKernel = strKernel.replace('{{type}}', 'unsigned char') elif type(objValue) == torch.Tensor and objValue.dtype == torch.float16: strKernel = strKernel.replace('{{type}}', 'half') elif type(objValue) == torch.Tensor and objValue.dtype == torch.float32: strKernel = strKernel.replace('{{type}}', 'float') elif type(objValue) == torch.Tensor and objValue.dtype == torch.float64: strKernel = strKernel.replace('{{type}}', 'double') elif type(objValue) == torch.Tensor and 
objValue.dtype == torch.int32: strKernel = strKernel.replace('{{type}}', 'int') elif type(objValue) == torch.Tensor and objValue.dtype == torch.int64: strKernel = strKernel.replace('{{type}}', 'long') elif type(objValue) == torch.Tensor: print(strVariable, objValue.dtype) assert (False) elif True: print(strVariable, type(objValue)) assert (False) while True: objMatch = re.search('(SIZE_)([0-4])(\()([^\)]*)(\))', strKernel) if objMatch is None: break intArg = int(objMatch.group(2)) strTensor = objMatch.group(4) intSizes = objVariables[strTensor].size() strKernel = strKernel.replace(objMatch.group(), str( intSizes[intArg] if torch.is_tensor(intSizes[intArg]) == False else intSizes[intArg].item())) while True: objMatch = re.search('(OFFSET_)([0-4])(\()', strKernel) if objMatch is None: break intStart = objMatch.span()[1] intStop = objMatch.span()[1] intParentheses = 1 while True: intParentheses += 1 if strKernel[intStop] == '(' else 0 intParentheses -= 1 if strKernel[intStop] == ')' else 0 if intParentheses == 0: break intStop += 1 intArgs = int(objMatch.group(2)) strArgs = strKernel[intStart:intStop].split(',') assert (intArgs == len(strArgs) - 1) strTensor = strArgs[0] intStrides = objVariables[strTensor].stride() strIndex = [] for intArg in range(intArgs): idx_expr = strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip() stride_val = ( intStrides[intArg] if not torch.is_tensor(intStrides[intArg]) else intStrides[intArg].item() ) strIndex.append( '(static_cast(' + idx_expr + ') * ' + str(stride_val) + ')' ) strKernel = strKernel.replace( 'OFFSET_' + str(intArgs) + '(' + strKernel[intStart:intStop] + ')', '(' + ' + '.join(strIndex) + ')' ) while True: objMatch = re.search('(VALUE_)([0-4])(\()', strKernel) if objMatch is None: break intStart = objMatch.span()[1] intStop = objMatch.span()[1] intParentheses = 1 while True: intParentheses += 1 if strKernel[intStop] == '(' else 0 intParentheses -= 1 if strKernel[intStop] == ')' else 0 if intParentheses == 0: break intStop += 1 intArgs = int(objMatch.group(2)) strArgs = strKernel[intStart:intStop].split(',') assert (intArgs == len(strArgs) - 1) strTensor = strArgs[0] intStrides = objVariables[strTensor].stride() strIndex = [] for intArg in range(intArgs): idx_expr = strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip() stride_val = ( intStrides[intArg] if not torch.is_tensor(intStrides[intArg]) else intStrides[intArg].item() ) strIndex.append( '(static_cast(' + idx_expr + ') * ' + str(stride_val) + ')' ) strKernel = strKernel.replace( 'VALUE_' + str(intArgs) + '(' + strKernel[intStart:intStop] + ')', strTensor + '[' + ' + '.join(strIndex) + ']' ) objCudacache[strKey] = { 'strFunction': strFunction, 'strKernel': strKernel } return strKey class softsplat_pointcloud_func(torch.autograd.Function): @staticmethod @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32) def forward(self, tenIn, tenFlow, H=None, W=None): if H is None: tenOut = tenIn.new_zeros([tenIn.shape[0], tenIn.shape[1], tenIn.shape[2], tenIn.shape[3]]) else: tenOut = tenIn.new_zeros([tenIn.shape[0], tenIn.shape[1], H, W]) if tenIn.is_cuda == True: cuda_launch(cuda_kernel_longlong('softsplat_pointcloud_out', ''' extern "C" __global__ void __launch_bounds__(512) softsplat_pointcloud_out( const long long int n, const {{type}}* __restrict__ tenIn, const {{type}}* __restrict__ tenFlow, {{type}}* __restrict__ tenOut ) { for (long long int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { const int intN = ( intIndex / 
SIZE_3(tenIn) / SIZE_2(tenIn) / SIZE_1(tenIn) ) % SIZE_0(tenIn); const int intC = ( intIndex / SIZE_3(tenIn) / SIZE_2(tenIn) ) % SIZE_1(tenIn); const int intY = ( intIndex / SIZE_3(tenIn) ) % SIZE_2(tenIn); const int intX = ( intIndex ) % SIZE_3(tenIn); assert(SIZE_1(tenFlow) == 2); {{type}} fltX = ({{type}}) VALUE_4(tenFlow, intN, 0, intY, intX); {{type}} fltY = ({{type}}) VALUE_4(tenFlow, intN, 1, intY, intX); if (isfinite(fltX) == false) { return; } if (isfinite(fltY) == false) { return; } {{type}} fltIn = VALUE_4(tenIn, intN, intC, intY, intX); int intNorthwestX = (int) (floor(fltX)); int intNorthwestY = (int) (floor(fltY)); int intNortheastX = intNorthwestX + 1; int intNortheastY = intNorthwestY; int intSouthwestX = intNorthwestX; int intSouthwestY = intNorthwestY + 1; int intSoutheastX = intNorthwestX + 1; int intSoutheastY = intNorthwestY + 1; {{type}} fltNorthwest = (({{type}}) (intSoutheastX) - fltX) * (({{type}}) (intSoutheastY) - fltY); {{type}} fltNortheast = (fltX - ({{type}}) (intSouthwestX)) * (({{type}}) (intSouthwestY) - fltY); {{type}} fltSouthwest = (({{type}}) (intNortheastX) - fltX) * (fltY - ({{type}}) (intNortheastY)); {{type}} fltSoutheast = (fltX - ({{type}}) (intNorthwestX)) * (fltY - ({{type}}) (intNorthwestY)); if ((intNorthwestX >= 0) && (intNorthwestX < SIZE_3(tenOut)) && (intNorthwestY >= 0) && (intNorthwestY < SIZE_2(tenOut))) { atomicAdd(&tenOut[OFFSET_4(tenOut, intN, intC, intNorthwestY, intNorthwestX)], fltIn * fltNorthwest); } if ((intNortheastX >= 0) && (intNortheastX < SIZE_3(tenOut)) && (intNortheastY >= 0) && (intNortheastY < SIZE_2(tenOut))) { atomicAdd(&tenOut[OFFSET_4(tenOut, intN, intC, intNortheastY, intNortheastX)], fltIn * fltNortheast); } if ((intSouthwestX >= 0) && (intSouthwestX < SIZE_3(tenOut)) && (intSouthwestY >= 0) && (intSouthwestY < SIZE_2(tenOut))) { atomicAdd(&tenOut[OFFSET_4(tenOut, intN, intC, intSouthwestY, intSouthwestX)], fltIn * fltSouthwest); } if ((intSoutheastX >= 0) && (intSoutheastX < SIZE_3(tenOut)) && (intSoutheastY >= 0) && (intSoutheastY < SIZE_2(tenOut))) { atomicAdd(&tenOut[OFFSET_4(tenOut, intN, intC, intSoutheastY, intSoutheastX)], fltIn * fltSoutheast); } } } ''', { 'tenIn': tenIn, 'tenFlow': tenFlow, 'tenOut': tenOut }))( grid=tuple([int((tenIn.nelement() + 512 - 1) / 512), 1, 1]), block=tuple([512, 1, 1]), args=[cuda_int64(tenIn.nelement()), tenIn.data_ptr(), tenFlow.data_ptr(), tenOut.data_ptr()], stream=collections.namedtuple('Stream', 'ptr')(torch.cuda.current_stream().cuda_stream) ) elif tenIn.is_cuda != True: assert (False) self.save_for_backward(tenIn, tenFlow) return tenOut @staticmethod @torch.cuda.amp.custom_bwd def backward(self, tenOutgrad): tenIn, tenFlow = self.saved_tensors tenOutgrad = tenOutgrad.contiguous(); assert (tenOutgrad.is_cuda == True) tenIngrad = tenIn.new_zeros([tenIn.shape[0], tenIn.shape[1], tenIn.shape[2], tenIn.shape[3]]) if \ self.needs_input_grad[0] == True else None tenFlowgrad = tenFlow.new_zeros([tenFlow.shape[0], tenFlow.shape[1], tenFlow.shape[2], tenFlow.shape[3]]) if \ self.needs_input_grad[1] == True else None Hgrad = None Wgrad = None if tenIngrad is not None: cuda_launch(cuda_kernel_longlong('softsplat_pointcloud_ingrad', ''' extern "C" __global__ void __launch_bounds__(512) softsplat_pointcloud_ingrad( const long long int n, const {{type}}* __restrict__ tenIn, const {{type}}* __restrict__ tenFlow, const {{type}}* __restrict__ tenOutgrad, {{type}}* __restrict__ tenIngrad, {{type}}* __restrict__ tenFlowgrad ) { for (long long int intIndex = (blockIdx.x * blockDim.x) + 
threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { const int intN = ( intIndex / SIZE_3(tenIngrad) / SIZE_2(tenIngrad) / SIZE_1(tenIngrad) ) % SIZE_0(tenIngrad); const int intC = ( intIndex / SIZE_3(tenIngrad) / SIZE_2(tenIngrad) ) % SIZE_1(tenIngrad); const int intY = ( intIndex / SIZE_3(tenIngrad) ) % SIZE_2(tenIngrad); const int intX = ( intIndex ) % SIZE_3(tenIngrad); assert(SIZE_1(tenFlow) == 2); {{type}} fltIngrad = 0.0f; {{type}} fltX = ({{type}}) VALUE_4(tenFlow, intN, 0, intY, intX); {{type}} fltY = ({{type}}) VALUE_4(tenFlow, intN, 1, intY, intX); if (isfinite(fltX) == false) { return; } if (isfinite(fltY) == false) { return; } int intNorthwestX = (int) (floor(fltX)); int intNorthwestY = (int) (floor(fltY)); int intNortheastX = intNorthwestX + 1; int intNortheastY = intNorthwestY; int intSouthwestX = intNorthwestX; int intSouthwestY = intNorthwestY + 1; int intSoutheastX = intNorthwestX + 1; int intSoutheastY = intNorthwestY + 1; {{type}} fltNorthwest = (({{type}}) (intSoutheastX) - fltX) * (({{type}}) (intSoutheastY) - fltY); {{type}} fltNortheast = (fltX - ({{type}}) (intSouthwestX)) * (({{type}}) (intSouthwestY) - fltY); {{type}} fltSouthwest = (({{type}}) (intNortheastX) - fltX) * (fltY - ({{type}}) (intNortheastY)); {{type}} fltSoutheast = (fltX - ({{type}}) (intNorthwestX)) * (fltY - ({{type}}) (intNorthwestY)); if ((intNorthwestX >= 0) && (intNorthwestX < SIZE_3(tenOutgrad)) && (intNorthwestY >= 0) && (intNorthwestY < SIZE_2(tenOutgrad))) { fltIngrad += VALUE_4(tenOutgrad, intN, intC, intNorthwestY, intNorthwestX) * fltNorthwest; } if ((intNortheastX >= 0) && (intNortheastX < SIZE_3(tenOutgrad)) && (intNortheastY >= 0) && (intNortheastY < SIZE_2(tenOutgrad))) { fltIngrad += VALUE_4(tenOutgrad, intN, intC, intNortheastY, intNortheastX) * fltNortheast; } if ((intSouthwestX >= 0) && (intSouthwestX < SIZE_3(tenOutgrad)) && (intSouthwestY >= 0) && (intSouthwestY < SIZE_2(tenOutgrad))) { fltIngrad += VALUE_4(tenOutgrad, intN, intC, intSouthwestY, intSouthwestX) * fltSouthwest; } if ((intSoutheastX >= 0) && (intSoutheastX < SIZE_3(tenOutgrad)) && (intSoutheastY >= 0) && (intSoutheastY < SIZE_2(tenOutgrad))) { fltIngrad += VALUE_4(tenOutgrad, intN, intC, intSoutheastY, intSoutheastX) * fltSoutheast; } tenIngrad[intIndex] = fltIngrad; } } ''', { 'tenIn': tenIn, 'tenFlow': tenFlow, 'tenOutgrad': tenOutgrad, 'tenIngrad': tenIngrad, 'tenFlowgrad': tenFlowgrad }))( grid=tuple([int((tenIngrad.nelement() + 512 - 1) / 512), 1, 1]), block=tuple([512, 1, 1]), args=[cuda_int64(tenIngrad.nelement()), tenIn.data_ptr(), tenFlow.data_ptr(), tenOutgrad.data_ptr(), tenIngrad.data_ptr(), None], stream=collections.namedtuple('Stream', 'ptr')(torch.cuda.current_stream().cuda_stream) ) if tenFlowgrad is not None: cuda_launch(cuda_kernel_longlong('softsplat_pointcloud_flowgrad', ''' extern "C" __global__ void __launch_bounds__(512) softsplat_flowgrad( const long long int n, const {{type}}* __restrict__ tenIn, const {{type}}* __restrict__ tenFlow, const {{type}}* __restrict__ tenOutgrad, {{type}}* __restrict__ tenIngrad, {{type}}* __restrict__ tenFlowgrad ) { for (long long int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { const int intN = ( intIndex / SIZE_3(tenFlowgrad) / SIZE_2(tenFlowgrad) / SIZE_1(tenFlowgrad) ) % SIZE_0(tenFlowgrad); const int intC = ( intIndex / SIZE_3(tenFlowgrad) / SIZE_2(tenFlowgrad) ) % SIZE_1(tenFlowgrad); const int intY = ( intIndex / SIZE_3(tenFlowgrad) ) % SIZE_2(tenFlowgrad); const int intX = ( 
intIndex ) % SIZE_3(tenFlowgrad); assert(SIZE_1(tenFlow) == 2); {{type}} fltFlowgrad = 0.0f; {{type}} fltX = ({{type}}) VALUE_4(tenFlow, intN, 0, intY, intX); {{type}} fltY = ({{type}}) VALUE_4(tenFlow, intN, 1, intY, intX); if (isfinite(fltX) == false) { return; } if (isfinite(fltY) == false) { return; } int intNorthwestX = (int) (floor(fltX)); int intNorthwestY = (int) (floor(fltY)); int intNortheastX = intNorthwestX + 1; int intNortheastY = intNorthwestY; int intSouthwestX = intNorthwestX; int intSouthwestY = intNorthwestY + 1; int intSoutheastX = intNorthwestX + 1; int intSoutheastY = intNorthwestY + 1; {{type}} fltNorthwest = 0.0f; {{type}} fltNortheast = 0.0f; {{type}} fltSouthwest = 0.0f; {{type}} fltSoutheast = 0.0f; if (intC == 0) { fltNorthwest = (({{type}}) (-1.0f)) * (({{type}}) (intSoutheastY) - fltY); fltNortheast = (({{type}}) (+1.0f)) * (({{type}}) (intSouthwestY) - fltY); fltSouthwest = (({{type}}) (-1.0f)) * (fltY - ({{type}}) (intNortheastY)); fltSoutheast = (({{type}}) (+1.0f)) * (fltY - ({{type}}) (intNorthwestY)); } else if (intC == 1) { fltNorthwest = (({{type}}) (intSoutheastX) - fltX) * (({{type}}) (-1.0f)); fltNortheast = (fltX - ({{type}}) (intSouthwestX)) * (({{type}}) (-1.0f)); fltSouthwest = (({{type}}) (intNortheastX) - fltX) * (({{type}}) (+1.0f)); fltSoutheast = (fltX - ({{type}}) (intNorthwestX)) * (({{type}}) (+1.0f)); } for (int intChannel = 0; intChannel < SIZE_1(tenOutgrad); intChannel += 1) { {{type}} fltIn = VALUE_4(tenIn, intN, intChannel, intY, intX); if ((intNorthwestX >= 0) && (intNorthwestX < SIZE_3(tenOutgrad)) && (intNorthwestY >= 0) && (intNorthwestY < SIZE_2(tenOutgrad))) { fltFlowgrad += VALUE_4(tenOutgrad, intN, intChannel, intNorthwestY, intNorthwestX) * fltIn * fltNorthwest; } if ((intNortheastX >= 0) && (intNortheastX < SIZE_3(tenOutgrad)) && (intNortheastY >= 0) && (intNortheastY < SIZE_2(tenOutgrad))) { fltFlowgrad += VALUE_4(tenOutgrad, intN, intChannel, intNortheastY, intNortheastX) * fltIn * fltNortheast; } if ((intSouthwestX >= 0) && (intSouthwestX < SIZE_3(tenOutgrad)) && (intSouthwestY >= 0) && (intSouthwestY < SIZE_2(tenOutgrad))) { fltFlowgrad += VALUE_4(tenOutgrad, intN, intChannel, intSouthwestY, intSouthwestX) * fltIn * fltSouthwest; } if ((intSoutheastX >= 0) && (intSoutheastX < SIZE_3(tenOutgrad)) && (intSoutheastY >= 0) && (intSoutheastY < SIZE_2(tenOutgrad))) { fltFlowgrad += VALUE_4(tenOutgrad, intN, intChannel, intSoutheastY, intSoutheastX) * fltIn * fltSoutheast; } } tenFlowgrad[intIndex] = fltFlowgrad; } } ''', { 'tenIn': tenIn, 'tenFlow': tenFlow, 'tenOutgrad': tenOutgrad, 'tenIngrad': tenIngrad, 'tenFlowgrad': tenFlowgrad }))( grid=tuple([int((tenFlowgrad.nelement() + 512 - 1) / 512), 1, 1]), block=tuple([512, 1, 1]), args=[cuda_int32(tenFlowgrad.nelement()), tenIn.data_ptr(), tenFlow.data_ptr(), tenOutgrad.data_ptr(), None, tenFlowgrad.data_ptr()], stream=collections.namedtuple('Stream', 'ptr')(torch.cuda.current_stream().cuda_stream) ) return tenIngrad, tenFlowgrad, Hgrad, Wgrad ================================================ FILE: mvtracker/models/core/spatracker/spatracker_monocular.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. 
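# Illustrative sketch (hedged): the `softsplat` operator defined in the file above and imported
# below forward-warps a feature map. Every source pixel scatters its feature vector onto the four
# integer neighbours of its flow-displaced target location, weighted bilinearly (the fltNorthwest /
# fltNortheast / fltSouthwest / fltSoutheast terms in the kernels) and accumulated with atomicAdd.
# The minimal PyTorch reference below shows the summation mode only, for intuition; the name
# `softsplat_sum_reference` is hypothetical and unused by this module, which relies on the CUDA
# implementation with strMode="avg" (it additionally splats a ones-channel and divides by it).
# Note that the *_pointcloud variant interprets tenFlow as absolute target coordinates, not offsets.
import torch


def softsplat_sum_reference(ten_in, ten_flow, out_h=None, out_w=None):
    """Reference forward splatting: ten_in is (B, C, H, W), ten_flow is (B, 2, H, W) of (dx, dy) offsets."""
    b, c, h, w = ten_in.shape
    out_h = h if out_h is None else out_h
    out_w = w if out_w is None else out_w
    out = ten_in.new_zeros(b, c, out_h, out_w)
    gy, gx = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
    fx = gx.to(ten_in) + ten_flow[:, 0]  # target x per source pixel, (B, H, W)
    fy = gy.to(ten_in) + ten_flow[:, 1]  # target y per source pixel, (B, H, W)
    x0, y0 = fx.floor().long(), fy.floor().long()
    tx, ty = fx - x0.to(fx), fy - y0.to(fy)  # fractional offsets in [0, 1)
    for dx, dy in ((0, 0), (1, 0), (0, 1), (1, 1)):  # NW, NE, SW, SE corners
        xi, yi = x0 + dx, y0 + dy
        wgt = ((1.0 - tx) if dx == 0 else tx) * ((1.0 - ty) if dy == 0 else ty)
        valid = (xi >= 0) & (xi < out_w) & (yi >= 0) & (yi < out_h) \
                & torch.isfinite(fx) & torch.isfinite(fy)
        idx = yi.clamp(0, out_h - 1) * out_w + xi.clamp(0, out_w - 1)  # flat target index
        contrib = ten_in * (wgt * valid).unsqueeze(1)  # invalid or non-finite corners contribute zero
        out.view(b, c, -1).scatter_add_(2, idx.view(b, 1, -1).expand(b, c, -1),
                                        contrib.view(b, c, -1))
    return out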
import logging import warnings import numpy as np import torch import torch.nn.functional as F from einops import rearrange from torch import nn as nn from mvtracker.models.core.embeddings import ( get_3d_embedding, get_1d_sincos_pos_embed_from_grid, get_2d_sincos_pos_embed, get_3d_sincos_pos_embed_from_grid, Embedder_Fourier, ) from mvtracker.models.core.model_utils import ( bilinear_sample2d, smart_cat, sample_features5d, pixel_xy_and_camera_z_to_world_space ) from mvtracker.models.core.spatracker.blocks import ( BasicEncoder, CorrBlock, EUpdateFormer, pix2cam, cam2pix ) from mvtracker.models.core.spatracker.softsplat import softsplat # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. def sample_pos_embed(grid_size, embed_dim, coords): if coords.shape[-1] == 2: pos_embed = get_2d_sincos_pos_embed(embed_dim=embed_dim, grid_size=grid_size) pos_embed = ( torch.from_numpy(pos_embed) .reshape(grid_size[0], grid_size[1], embed_dim) .float() .unsqueeze(0) .to(coords.device) ) sampled_pos_embed = bilinear_sample2d( pos_embed.permute(0, 3, 1, 2), coords[:, 0, :, 0], coords[:, 0, :, 1] ) elif coords.shape[-1] == 3: sampled_pos_embed = get_3d_sincos_pos_embed_from_grid( embed_dim, coords[:, :1, ...] ).float()[:, 0, ...].permute(0, 2, 1) return sampled_pos_embed class SpaTracker(nn.Module): def __init__( self, sliding_window_len=8, stride=8, add_space_attn=True, num_heads=8, hidden_size=384, space_depth=12, time_depth=12, triplane_zres=128, ): super(SpaTracker, self).__init__() self.S = sliding_window_len self.stride = stride self.hidden_dim = 256 self.latent_dim = latent_dim = 128 self.b_latent_dim = self.latent_dim // 3 self.corr_levels = 4 self.corr_radius = 3 self.add_space_attn = add_space_attn self.triplane_zres = triplane_zres # @Encoder self.fnet = BasicEncoder(input_dim=3, output_dim=self.latent_dim, norm_fn="instance", dropout=0, stride=stride, Embed3D=False ) # conv head for the tri-plane features self.headyz = nn.Sequential( nn.Conv2d(self.latent_dim, self.latent_dim, 3, padding=1), nn.ReLU(inplace=True), nn.Conv2d(self.latent_dim, self.latent_dim, 3, padding=1)) self.headxz = nn.Sequential( nn.Conv2d(self.latent_dim, self.latent_dim, 3, padding=1), nn.ReLU(inplace=True), nn.Conv2d(self.latent_dim, self.latent_dim, 3, padding=1)) # @UpdateFormer self.updateformer = EUpdateFormer( space_depth=space_depth, time_depth=time_depth, input_dim=456, hidden_size=hidden_size, num_heads=num_heads, output_dim=latent_dim + 3, mlp_ratio=4.0, add_space_attn=add_space_attn, flash=True ) self.support_features = torch.zeros(100, 384).to("cuda") + 0.1 self.norm = nn.GroupNorm(1, self.latent_dim) self.ffeat_updater = nn.Sequential( nn.Linear(self.latent_dim, self.latent_dim), nn.GELU(), ) self.ffeatyz_updater = nn.Sequential( nn.Linear(self.latent_dim, self.latent_dim), nn.GELU(), ) self.ffeatxz_updater = nn.Sequential( nn.Linear(self.latent_dim, self.latent_dim), nn.GELU(), ) # TODO @NeuralArap: optimize the arap self.embed_traj = Embedder_Fourier( input_dim=5, max_freq_log2=5.0, N_freqs=3, include_input=True ) self.embed3d = Embedder_Fourier( input_dim=3, max_freq_log2=10.0, N_freqs=10, include_input=True ) self.embedConv = nn.Conv2d(self.latent_dim + 63, self.latent_dim, 3, padding=1) # @Vis_predictor self.vis_predictor = nn.Sequential( nn.Linear(128, 1), ) self.embedProj = nn.Linear(63, 456) self.zeroMLPflow = nn.Linear(195, 130) def prepare_track(self, rgbds, queries): """ NOTE: Normalized the rgbs and sorted the queries via their 
first appeared time Args: rgbds: the input rgbd images (B T 4 H W) queries: the input queries (B N 4) Return: rgbds: the normalized rgbds (B T 4 H W) queries: the sorted queries (B N 4) track_mask: """ assert (rgbds.shape[2] == 4) and (queries.shape[2] == 4) # Step1: normalize the rgbs input device = rgbds.device rgbds[:, :, :3, ...] = 2 * (rgbds[:, :, :3, ...] / 255.0) - 1.0 B, T, C, H, W = rgbds.shape B, N, __ = queries.shape self.traj_e = torch.zeros((B, T, N, 3), device=device) self.vis_e = torch.zeros((B, T, N), device=device) # Step2: sort the points via their first appeared time first_positive_inds = queries[0, :, 0].long() __, sort_inds = torch.sort(first_positive_inds, dim=0, descending=False) inv_sort_inds = torch.argsort(sort_inds, dim=0) first_positive_sorted_inds = first_positive_inds[sort_inds] # check if can be inverse assert torch.allclose( first_positive_inds, first_positive_inds[sort_inds][inv_sort_inds] ) # filter those points never appear points during 1 - T ind_array = torch.arange(T, device=device) ind_array = ind_array[None, :, None].repeat(B, 1, N) track_mask = (ind_array >= first_positive_inds[None, None, :]).unsqueeze(-1) # scale the coords_init coords_init = queries[:, :, 1:].reshape(B, 1, N, 3).repeat( 1, self.S, 1, 1 ) coords_init[..., :2] /= float(self.stride) # Step3: initial the regular grid gridx = torch.linspace(0, W // self.stride - 1, W // self.stride) gridy = torch.linspace(0, H // self.stride - 1, H // self.stride) gridx, gridy = torch.meshgrid(gridx, gridy, indexing="ij") gridxy = torch.stack([gridx, gridy], dim=-1).to(rgbds.device).permute( 2, 1, 0 ) vis_init = torch.ones((B, self.S, N, 1), device=device).float() * 10 # Step4: initial traj for neural arap T_series = torch.linspace(0, 5, T).reshape(1, T, 1, 1).cuda() # 1 T 1 1 T_series = T_series.repeat(B, 1, N, 1) # get the 3d traj in the camera coordinates intr_init = self.intrs[:, queries[0, :, 0].long()] Traj_series = pix2cam(queries[:, :, None, 1:].double(), intr_init.double()) # torch.inverse(intr_init.double())@queries[:,:,1:,None].double() # B N 3 1 Traj_series = Traj_series.repeat(1, 1, T, 1).permute(0, 2, 1, 3).float() Traj_series = torch.cat([T_series, Traj_series], dim=-1) # get the indicator for the neural arap Traj_mask = -1e2 * torch.ones_like(T_series) Traj_series = torch.cat([Traj_series, Traj_mask], dim=-1) return ( rgbds, first_positive_inds, first_positive_sorted_inds, sort_inds, inv_sort_inds, track_mask, gridxy, coords_init[..., sort_inds, :].clone(), vis_init, Traj_series[..., sort_inds, :].clone() ) def sample_trifeat(self, t, coords, featMapxy, featMapyz, featMapxz): """ Sample the features from the 5D triplane feature map 3*(B S C H W) Args: t: the time index coords: the coordinates of the points B S N 3 featMapxy: the feature map B S C Hx Wy featMapyz: the feature map B S C Hy Wz featMapxz: the feature map B S C Hx Wz """ # get xy_t yz_t xz_t queried_t = t.reshape(1, 1, -1, 1) xy_t = torch.cat( [queried_t, coords[..., [0, 1]]], dim=-1 ) yz_t = torch.cat( [queried_t, coords[..., [1, 2]]], dim=-1 ) xz_t = torch.cat( [queried_t, coords[..., [0, 2]]], dim=-1 ) featxy_init = sample_features5d(featMapxy, xy_t) featyz_init = sample_features5d(featMapyz, yz_t) featxz_init = sample_features5d(featMapxz, xz_t) featxy_init = featxy_init.repeat(1, self.S, 1, 1) featyz_init = featyz_init.repeat(1, self.S, 1, 1) featxz_init = featxz_init.repeat(1, self.S, 1, 1) return featxy_init, featyz_init, featxz_init def neural_arap(self, coords, Traj_arap, intrs_S, T_mark): """ calculate the ARAP 
embedding and offset Args: coords: the coordinates of the current points 1 S N' 3 Traj_arap: the trajectory of the points 1 T N' 5 intrs_S: the camera intrinsics B S 3 3 """ coords_out = coords.clone() coords_out[..., :2] *= float(self.stride) coords_out[..., 2] = coords_out[..., 2] / self.Dz coords_out[..., 2] = coords_out[..., 2] * (self.d_far - self.d_near) + self.d_near intrs_S = intrs_S[:, :, None, ...].repeat(1, 1, coords_out.shape[2], 1, 1) B, S, N, D = coords_out.shape if S != intrs_S.shape[1]: intrs_S = torch.cat( [intrs_S, intrs_S[:, -1:].repeat(1, S - intrs_S.shape[1], 1, 1, 1)], dim=1 ) T_mark = torch.cat( [T_mark, T_mark[:, -1:].repeat(1, S - T_mark.shape[1], 1)], dim=1 ) xyz_ = pix2cam(coords_out.double(), intrs_S.double()[:, :, 0]) xyz_ = xyz_.float() xyz_embed = torch.cat([T_mark[..., None], xyz_, torch.zeros_like(T_mark[..., None])], dim=-1) xyz_embed = self.embed_traj(xyz_embed) Traj_arap_embed = self.embed_traj(Traj_arap) d_xyz, traj_feat = self.arapFormer(xyz_embed, Traj_arap_embed) # update in camera coordinate xyz_ = xyz_ + d_xyz.clamp(-5, 5) # project back to the image plane coords_out = cam2pix(xyz_.double(), intrs_S[:, :, 0].double()).float() # resize back coords_out[..., :2] /= float(self.stride) coords_out[..., 2] = (coords_out[..., 2] - self.d_near) / (self.d_far - self.d_near) coords_out[..., 2] *= self.Dz return xyz_, coords_out, traj_feat def gradient_arap(self, coords, aff_avg=None, aff_std=None, aff_f_sg=None, iter=0, iter_num=4, neigh_idx=None, intr=None, msk_track=None): with torch.enable_grad(): coords.requires_grad_(True) y = self.ARAP_ln(coords, aff_f_sg=aff_f_sg, neigh_idx=neigh_idx, iter=iter, iter_num=iter_num, intr=intr, msk_track=msk_track) d_output = torch.ones_like(y, requires_grad=False, device=y.device) gradients = torch.autograd.grad( outputs=y, inputs=coords, grad_outputs=d_output, create_graph=True, retain_graph=True, only_inputs=True, allow_unused=True)[0] return gradients.detach() def forward_iteration( self, fmapXY, fmapYZ, fmapXZ, coords_init, feat_init=None, vis_init=None, track_mask=None, iters=4, intrs_S=None, ): B, S_init, N, D = coords_init.shape assert D == 3 assert B == 1 B, S, __, H8, W8 = fmapXY.shape device = fmapXY.device if S_init < S: coords = torch.cat( [coords_init, coords_init[:, -1].repeat(1, S - S_init, 1, 1)], dim=1 ) vis_init = torch.cat( [vis_init, vis_init[:, -1].repeat(1, S - S_init, 1, 1)], dim=1 ) intrs_S = torch.cat( [intrs_S, intrs_S[:, -1].repeat(1, S - S_init, 1, 1)], dim=1 ) else: coords = coords_init.clone() fcorr_fnXY = CorrBlock( fmapXY, num_levels=self.corr_levels, radius=self.corr_radius ) fcorr_fnYZ = CorrBlock( fmapYZ, num_levels=self.corr_levels, radius=self.corr_radius ) fcorr_fnXZ = CorrBlock( fmapXZ, num_levels=self.corr_levels, radius=self.corr_radius ) ffeats = torch.split(feat_init.clone(), dim=-1, split_size_or_sections=1) ffeats = [f.squeeze(-1) for f in ffeats] times_ = torch.linspace(0, S - 1, S).reshape(1, S, 1) pos_embed = sample_pos_embed( grid_size=(H8, W8), embed_dim=456, coords=coords[..., :2], ) pos_embed = rearrange(pos_embed, "b e n -> (b n) e").unsqueeze(1) times_embed = ( torch.from_numpy(get_1d_sincos_pos_embed_from_grid(456, times_[0]))[None] .repeat(B, 1, 1) .float() .to(device) ) coord_predictions = [] attn_predictions = [] Rot_ln = 0 support_feat = self.support_features for __ in range(iters): coords = coords.detach() # if self.args.if_ARAP == True: # # refine the track with arap # xyz_pred, coords, flows_cat0 = self.neural_arap(coords.detach(), # Traj_arap.detach(), # 
intrs_S, T_mark) fcorrsXY = fcorr_fnXY.corr_sample(ffeats[0], coords[..., :2]) fcorrsYZ = fcorr_fnYZ.corr_sample(ffeats[1], coords[..., [1, 2]]) fcorrsXZ = fcorr_fnXZ.corr_sample(ffeats[2], coords[..., [0, 2]]) # fcorrs = fcorrsXY fcorrs = fcorrsXY + fcorrsYZ + fcorrsXZ LRR = fcorrs.shape[3] fcorrs_ = fcorrs.permute(0, 2, 1, 3).reshape(B * N, S, LRR) flows_ = (coords - coords[:, 0:1]).permute(0, 2, 1, 3).reshape(B * N, S, 3) flows_cat = get_3d_embedding(flows_, 64, cat_coords=True) flows_cat = self.zeroMLPflow(flows_cat) ffeats_xy = ffeats[0].permute(0, 2, 1, 3).reshape(B * N, S, self.latent_dim) ffeats_yz = ffeats[1].permute(0, 2, 1, 3).reshape(B * N, S, self.latent_dim) ffeats_xz = ffeats[2].permute(0, 2, 1, 3).reshape(B * N, S, self.latent_dim) ffeats_ = ffeats_xy + ffeats_yz + ffeats_xz if track_mask.shape[1] < vis_init.shape[1]: track_mask = torch.cat( [ track_mask, torch.zeros_like(track_mask[:, 0]).repeat( 1, vis_init.shape[1] - track_mask.shape[1], 1, 1 ), ], dim=1, ) concat = ( torch.cat([track_mask, vis_init], dim=2) .permute(0, 2, 1, 3) .reshape(B * N, S, 2) ) transformer_input = torch.cat([flows_cat, fcorrs_, ffeats_, concat], dim=2) if transformer_input.shape[-1] < pos_embed.shape[-1]: # padding the transformer_input to the same dimension as pos_embed transformer_input = F.pad( transformer_input, (0, pos_embed.shape[-1] - transformer_input.shape[-1]), "constant", 0 ) x = transformer_input + pos_embed + times_embed x = rearrange(x, "(b n) t d -> b n t d", b=B) delta, delta_se3F = self.updateformer(x, support_feat) support_feat = support_feat + delta_se3F[0] / 100 delta = rearrange(delta, " b n t d -> (b n) t d") d_coord = delta[:, :, :3] d_feats = delta[:, :, 3:] ffeats_xy = self.ffeat_updater(self.norm(d_feats.view(-1, self.latent_dim))) + ffeats_xy.reshape(-1, self.latent_dim) ffeats_yz = self.ffeatyz_updater(self.norm(d_feats.view(-1, self.latent_dim))) + ffeats_yz.reshape(-1, self.latent_dim) ffeats_xz = self.ffeatxz_updater(self.norm(d_feats.view(-1, self.latent_dim))) + ffeats_xz.reshape(-1, self.latent_dim) ffeats[0] = ffeats_xy.reshape(B, N, S, self.latent_dim).permute( 0, 2, 1, 3 ) # B,S,N,C ffeats[1] = ffeats_yz.reshape(B, N, S, self.latent_dim).permute( 0, 2, 1, 3 ) # B,S,N,C ffeats[2] = ffeats_xz.reshape(B, N, S, self.latent_dim).permute( 0, 2, 1, 3 ) # B,S,N,C coords = coords + d_coord.reshape(B, N, S, 3).permute(0, 2, 1, 3) if torch.isnan(coords).any(): # import ipdb; # ipdb.set_trace() logging.error("nan in coords") coords_out = coords.clone() coords_out[..., :2] *= float(self.stride) coords_out[..., 2] = coords_out[..., 2] / self.Dz coords_out[..., 2] = coords_out[..., 2] * (self.d_far - self.d_near) + self.d_near coord_predictions.append(coords_out) ffeats_f = ffeats[0] + ffeats[1] + ffeats[2] vis_e = self.vis_predictor(ffeats_f.reshape(B * S * N, self.latent_dim)).reshape( B, S, N ) self.support_features = support_feat.detach() return coord_predictions, attn_predictions, vis_e, feat_init, Rot_ln def forward(self, rgbds, queries, iters=4, feat_init=None, is_train=False, intrs=None): self.support_features = torch.zeros(100, 384).to("cuda") + 0.1 self.is_train = is_train B, T, C, H, W = rgbds.shape # set the intrinsic or simply initialized if intrs is None: intrs = torch.from_numpy(np.array([[W, 0.0, W // 2], [0.0, W, H // 2], [0.0, 0.0, 1.0]])) intrs = intrs[None, None, ...].repeat(B, T, 1, 1).float().to(rgbds.device) self.intrs = intrs # prepare the input for tracking ( rgbds, first_positive_inds, first_positive_sorted_inds, sort_inds, inv_sort_inds, 
timestep_should_be_estimated_mask, gridxy, coords_init, vis_init, Traj_arap ) = self.prepare_track(rgbds.clone(), queries) coords_init_ = coords_init.clone() vis_init_ = vis_init[:, :, sort_inds].clone() depth_all = rgbds[:, :, 3, ...] d_near = self.d_near = depth_all[depth_all > 0.01].min().item() d_far = self.d_far = depth_all[depth_all > 0.01].max().item() B, N, __ = queries.shape self.Dz = Dz = self.triplane_zres w_idx_start = 0 p_idx_end = 0 p_idx_start = 0 fmaps_ = None vis_predictions = [] coord_predictions = [] attn_predictions = [] p_idx_end_list = [] Rigid_ln_total = 0 while w_idx_start < T - self.S // 2: curr_wind_points = torch.nonzero( first_positive_sorted_inds < w_idx_start + self.S) if curr_wind_points.shape[0] == 0: w_idx_start = w_idx_start + self.S // 2 logging.info(f"No points in window {w_idx_start}-{w_idx_start + self.S}; adding empty results to list") p_idx_end_list.append(torch.zeros((1,), dtype=torch.int64, device=first_positive_sorted_inds.device)) if is_train: vis_predictions.append(torch.zeros((B, self.S, 0), device=rgbds.device)) coord_predictions.append( [torch.zeros((B, self.S, 0, 3), device=rgbds.device) for _ in range(iters)]) attn_predictions.append([-1 for _ in range(iters)]) continue p_idx_end = curr_wind_points[-1] + 1 p_idx_end_list.append(p_idx_end) # the T may not be divided by self.S rgbds_seq = rgbds[:, w_idx_start:w_idx_start + self.S].clone() S = S_local = rgbds_seq.shape[1] if S < self.S: rgbds_seq = torch.cat( [rgbds_seq, rgbds_seq[:, -1, None].repeat(1, self.S - S, 1, 1, 1)], dim=1, ) S = rgbds_seq.shape[1] rgbs_ = rgbds_seq.reshape(B * S, C, H, W)[:, :3] depths = rgbds_seq.reshape(B * S, C, H, W)[:, 3:].clone() # open the mask # Traj_arap[:, w_idx_start:w_idx_start + self.S, :p_idx_end, -1] = 0 # step1: normalize the depth map depths = (depths - d_near) / (d_far - d_near) depths_dn = nn.functional.interpolate( depths, scale_factor=1.0 / self.stride, mode="nearest") depths_dnG = depths_dn * Dz # step2: normalize the coordinate coords_init_[:, :, p_idx_start:p_idx_end, 2] = ( coords_init[:, :, p_idx_start:p_idx_end, 2] - d_near ) / (d_far - d_near) coords_init_[:, :, p_idx_start:p_idx_end, 2] *= Dz # efficient triplane splatting gridxyz = torch.cat([gridxy[None, ...].repeat( depths_dn.shape[0], 1, 1, 1), depths_dnG], dim=1) Fxy2yz = gridxyz[:, [1, 2], ...] - gridxyz[:, :2] Fxy2xz = gridxyz[:, [0, 2], ...] - gridxyz[:, :2] gridxyz_nm = gridxyz.clone() gridxyz_nm[:, 0, ...] = (gridxyz_nm[:, 0, ...] - gridxyz_nm[:, 0, ...].min()) / ( gridxyz_nm[:, 0, ...].max() - gridxyz_nm[:, 0, ...].min()) gridxyz_nm[:, 1, ...] = (gridxyz_nm[:, 1, ...] - gridxyz_nm[:, 1, ...].min()) / ( gridxyz_nm[:, 1, ...].max() - gridxyz_nm[:, 1, ...].min()) gridxyz_nm[:, 2, ...] = (gridxyz_nm[:, 2, ...] 
- gridxyz_nm[:, 2, ...].min()) / ( gridxyz_nm[:, 2, ...].max() - gridxyz_nm[:, 2, ...].min()) gridxyz_nm = 2 * (gridxyz_nm - 0.5) _, _, h4, w4 = gridxyz_nm.shape gridxyz_nm = gridxyz_nm.permute(0, 2, 3, 1).reshape(S * h4 * w4, 3) featPE = self.embed3d(gridxyz_nm).view(S, h4, w4, -1).permute(0, 3, 1, 2) if fmaps_ is None: fmaps_ = torch.cat([self.fnet(rgbs_), featPE], dim=1) fmaps_ = self.embedConv(fmaps_) else: fmaps_new = torch.cat([self.fnet(rgbs_[self.S // 2:]), featPE[self.S // 2:]], dim=1) fmaps_new = self.embedConv(fmaps_new) fmaps_ = torch.cat( [fmaps_[self.S // 2:], fmaps_new], dim=0 ) fmapXY = fmaps_[:, :self.latent_dim].reshape( B, S, self.latent_dim, H // self.stride, W // self.stride ) fmapYZ = softsplat(fmapXY[0], Fxy2yz, None, strMode="avg", tenoutH=self.Dz, tenoutW=H // self.stride) fmapXZ = softsplat(fmapXY[0], Fxy2xz, None, strMode="avg", tenoutH=self.Dz, tenoutW=W // self.stride) fmapYZ = self.headyz(fmapYZ)[None, ...] fmapXZ = self.headxz(fmapXZ)[None, ...] if p_idx_end - p_idx_start > 0: queried_t = (first_positive_sorted_inds[p_idx_start:p_idx_end] - w_idx_start) (featxy_init, featyz_init, featxz_init) = self.sample_trifeat( t=queried_t, featMapxy=fmapXY, featMapyz=fmapYZ, featMapxz=fmapXZ, coords=coords_init_[:, :1, p_idx_start:p_idx_end] ) # T, S, N, C, 3 feat_init_curr = torch.stack([featxy_init, featyz_init, featxz_init], dim=-1) feat_init = smart_cat(feat_init, feat_init_curr, dim=2) if p_idx_start > 0: # preprocess the coordinates of last windows last_coords = coords[-1][:, self.S // 2:].clone() last_coords[..., :2] /= float(self.stride) last_coords[..., 2:] = (last_coords[..., 2:] - d_near) / (d_far - d_near) last_coords[..., 2:] = last_coords[..., 2:] * Dz coords_init_[:, : self.S // 2, :p_idx_start] = last_coords coords_init_[:, self.S // 2:, :p_idx_start] = last_coords[ :, -1 ].repeat(1, self.S // 2, 1, 1) last_vis = vis[:, self.S // 2:].unsqueeze(-1) vis_init_[:, : self.S // 2, :p_idx_start] = last_vis vis_init_[:, self.S // 2:, :p_idx_start] = last_vis[:, -1].repeat( 1, self.S // 2, 1, 1 ) coords, attns, vis, __, Rigid_ln = self.forward_iteration( fmapXY=fmapXY, fmapYZ=fmapYZ, fmapXZ=fmapXZ, coords_init=coords_init_[:, :, :p_idx_end], feat_init=feat_init[:, :, :p_idx_end], vis_init=vis_init_[:, :, :p_idx_end], track_mask=timestep_should_be_estimated_mask[:, w_idx_start: w_idx_start + self.S, :p_idx_end], iters=iters, intrs_S=self.intrs[:, w_idx_start: w_idx_start + self.S], ) Rigid_ln_total += Rigid_ln if is_train: vis_predictions.append(vis[:, :S_local]) coord_predictions.append([coord[:, :S_local] for coord in coords]) attn_predictions.append(attns) self.traj_e[:, w_idx_start:w_idx_start + self.S, :p_idx_end] = coords[-1][:, :S_local] self.vis_e[:, w_idx_start:w_idx_start + self.S, :p_idx_end] = vis[:, :S_local] timestep_should_be_estimated_mask[:, : w_idx_start + self.S, :p_idx_end] = 0.0 w_idx_start = w_idx_start + self.S // 2 p_idx_start = p_idx_end self.traj_e = self.traj_e[:, :, inv_sort_inds] self.vis_e = self.vis_e[:, :, inv_sort_inds] self.vis_e = torch.sigmoid(self.vis_e) train_data = ( (vis_predictions, coord_predictions, attn_predictions, p_idx_end_list, sort_inds, Rigid_ln_total) ) if self.is_train: return self.traj_e, feat_init, self.vis_e, train_data else: return self.traj_e, feat_init, self.vis_e class SpaTrackerMultiViewAdapter(nn.Module): def __init__(self, **kwargs): super(SpaTrackerMultiViewAdapter, self).__init__() self.spatracker = SpaTracker(**kwargs) def forward( self, rgbs, depths, query_points, intrs, extrs, iters=4, 
feat_init=None, is_train=False, save_debug_logs=False, debug_logs_path="", query_points_view=None, **kwargs, ): batch_size, num_views, num_frames, _, height, width = rgbs.shape _, num_points, _ = query_points.shape depths = depths.clamp(max=36.0) assert rgbs.shape == (batch_size, num_views, num_frames, 3, height, width) assert depths.shape == (batch_size, num_views, num_frames, 1, height, width) assert query_points.shape == (batch_size, num_points, 4) assert intrs.shape == (batch_size, num_views, num_frames, 3, 3) assert extrs.shape == (batch_size, num_views, num_frames, 3, 4) if feat_init is not None: raise NotImplementedError("feat_init is not supported yet") # Project the queries to each view query_points_t = query_points[:, :, :1].long() query_points_xyz_worldspace = query_points[:, :, 1:] query_points_xy_pixelspace_per_view = query_points.new_zeros((batch_size, num_views, num_points, 2)) query_points_z_cameraspace_per_view = query_points.new_zeros((batch_size, num_views, num_points, 1)) for batch_idx in range(batch_size): for t in query_points_t[batch_idx].unique(): query_points_t_mask = query_points_t[batch_idx].squeeze(-1) == t point_3d_world = query_points_xyz_worldspace[batch_idx][query_points_t_mask] # World to camera space point_4d_world_homo = torch.cat( [point_3d_world, point_3d_world.new_ones(point_3d_world[..., :1].shape)], -1) point_3d_camera = torch.einsum('Aij,Bj->ABi', extrs[batch_idx, :, t, :, :], point_4d_world_homo[:, :]) # Camera to pixel space point_2d_pixel_homo = torch.einsum('Aij,ABj->ABi', intrs[batch_idx, :, t, :, :], point_3d_camera[:, :]) point_2d_pixel = point_2d_pixel_homo[..., :2] / point_2d_pixel_homo[..., 2:] query_points_xy_pixelspace_per_view[batch_idx, :, query_points_t_mask] = point_2d_pixel query_points_z_cameraspace_per_view[batch_idx, :, query_points_t_mask] = point_3d_camera[..., -1:] # Estimate occlusion mask in each view based on depth maps query_points_depth_in_view = query_points.new_zeros((batch_size, num_views, num_points, 1)) for batch_idx in range(batch_size): for view_idx in range(num_views): for t in query_points_t[batch_idx].unique(): query_points_t_mask = query_points_t[batch_idx].squeeze(-1) == t interpolated_depth = bilinear_sample2d( im=depths[batch_idx, view_idx, t][None], x=query_points_xy_pixelspace_per_view[batch_idx, view_idx, query_points_t_mask, 0][None], y=query_points_xy_pixelspace_per_view[batch_idx, view_idx, query_points_t_mask, 1][None], )[0].permute(1, 0).type(query_points.dtype) query_points_depth_in_view[batch_idx, view_idx, query_points_t_mask] = interpolated_depth query_points_depth_in_view_masked = query_points_depth_in_view.clone() query_points_outside_of_view_box = ( (query_points_xy_pixelspace_per_view[..., 0] < 0) | (query_points_xy_pixelspace_per_view[..., 0] >= width) | (query_points_xy_pixelspace_per_view[..., 1] < 0) | (query_points_xy_pixelspace_per_view[..., 1] >= height) | (query_points_z_cameraspace_per_view[..., 0] < 0) ) if query_points_outside_of_view_box.all(1).any(): warnings.warn(f"There are some query points that are outside of the frame of every view: " f"{query_points_xy_pixelspace_per_view[query_points_outside_of_view_box.all(1)[:, None, :].repeat(1, num_views, 1)].reshape(num_views, -1, 2).permute(1, 0, 2)}") query_points_depth_in_view_masked[query_points_outside_of_view_box] = -1e4 # query_points_occluded_by_depthmap = (query_points_depth_in_view * 1.1 < query_points_z_cameraspace_per_view) # query_points_depth_in_view_masked[query_points_occluded_by_depthmap] = -1e3 
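# Worked example of the projection just computed (illustrative numbers only): for a world point
# X_w = (0.5, 0.0, 2.0), an identity extrinsic E = [I | 0] (3x4) and an intrinsic
# K = [[100, 0, 50], [0, 100, 50], [0, 0, 1]], the camera-space point is
# X_c = E @ [X_w; 1] = (0.5, 0.0, 2.0), the homogeneous pixel is K @ X_c = (150, 100, 2.0),
# so (u, v) = (150 / 2, 100 / 2) = (75, 50) and the camera depth is z = 2.0.
# The best-visibility selection below maximizes (depth sampled from the view's depth map at
# (u, v)) minus z: a point the view actually sees scores near zero, an occluded point scores
# clearly negative, and views where the point falls outside the image or behind the camera were
# already pushed to -1e4 above, so they are never chosen.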
query_points_best_visibility_view = ( query_points_depth_in_view_masked - query_points_z_cameraspace_per_view).argmax(1) query_points_best_visibility_view = query_points_best_visibility_view.squeeze(-1) if query_points_view is not None: query_points_best_visibility_view = query_points_view logging.info(f"Using the provided query_points_view instead of the estimated one") assert batch_size == 1, "Batch size > 1 is not supported yet" batch_idx = 0 results = {} # Call the SpaTracker for each view traj_e_per_view = {} feat_init_per_view = {} vis_e_per_view = {} train_data_per_view = {} for view_idx in range(num_views): track_mask = query_points_best_visibility_view[batch_idx] == view_idx if track_mask.sum() == 0: continue view_query_points = torch.concat([ query_points_t[batch_idx, :, :][track_mask], query_points_xy_pixelspace_per_view[batch_idx, view_idx, :, :][track_mask], query_points_z_cameraspace_per_view[batch_idx, view_idx, :, :][track_mask], ], dim=-1) view_rgbds = torch.concat([rgbs[batch_idx, view_idx], depths[batch_idx, view_idx]], dim=1) view_intrs = intrs[batch_idx, view_idx] view_extrs = extrs[batch_idx, view_idx] output_tuple = self.spatracker( rgbds=view_rgbds[None], queries=view_query_points[None], intrs=view_intrs[None], iters=iters, feat_init=None, is_train=is_train, ) if is_train: view_traj_e, view_feat_init, view_vis_e, view_train_data = output_tuple else: view_traj_e, view_feat_init, view_vis_e = output_tuple # Project points to the world space intrs_inv = torch.inverse(view_intrs.float()) view_extrs_square = torch.eye(4).to(view_extrs.device)[None].repeat(num_frames, 1, 1) view_extrs_square[:, :3, :] = view_extrs extrs_inv = torch.inverse(view_extrs_square.float()) view_traj_e = pixel_xy_and_camera_z_to_world_space( pixel_xy=view_traj_e[0, ..., :-1].float(), camera_z=view_traj_e[0, ..., -1:].float(), intrs_inv=intrs_inv, extrs_inv=extrs_inv, )[None] if is_train: num_windows = len(view_train_data[1]) num_iterations = len(view_train_data[1][0]) coord_predictions = view_train_data[1] window_start_t = 0 while window_start_t < num_frames - self.spatracker.S // 2: window_idx = window_start_t // (self.spatracker.S // 2) for iteration_idx in range(num_iterations): coord_predictions[window_idx][iteration_idx] = pixel_xy_and_camera_z_to_world_space( pixel_xy=coord_predictions[window_idx][iteration_idx][0, ..., :-1].float(), camera_z=coord_predictions[window_idx][iteration_idx][0, ..., -1:].float(), intrs_inv=intrs_inv[window_start_t:window_start_t + self.spatracker.S], extrs_inv=extrs_inv[window_start_t:window_start_t + self.spatracker.S], )[None] window_start_t = window_start_t + (self.spatracker.S // 2) assert window_idx == num_windows - 1, "The last window should be the last one" assert view_train_data[1] == coord_predictions, "The view_train_data[1] should be updated in-place" # Set the trajectory to (0,0,0) for the timesteps before the query timestep for point_idx, t in enumerate(query_points_t[batch_idx, :, :].squeeze(-1)[track_mask]): view_traj_e[0, :t, point_idx, :] = 0.0 traj_e_per_view[view_idx] = view_traj_e feat_init_per_view[view_idx] = view_feat_init vis_e_per_view[view_idx] = view_vis_e if is_train: train_data_per_view[view_idx] = view_train_data # Merging the results from all views views_to_keep = list(traj_e_per_view.keys()) traj_e = torch.cat([traj_e_per_view[view_idx] for view_idx in views_to_keep], dim=2) vis_e = torch.cat([vis_e_per_view[view_idx] for view_idx in views_to_keep], dim=2) feat_init = torch.cat([feat_init_per_view[view_idx] for view_idx in 
views_to_keep], dim=2) # Sort the traj_e and vis_e based on the original indices, since concatenating the results from all views # will first put the results from the first view, then the results from the second view, and so on. # But we want to keep the trajectories order to match the original query points order. sort_inds = [] for view_idx in views_to_keep: track_mask = query_points_best_visibility_view[batch_idx] == view_idx if track_mask.sum() == 0: continue global_indices = torch.nonzero(track_mask).squeeze(-1) sort_inds += [global_indices] sort_inds = torch.cat(sort_inds, dim=0) inv_sort_inds = torch.argsort(sort_inds, dim=0) # Use the inv_sort_inds to sort the traj_e and vis_e traj_e = traj_e[:, :, inv_sort_inds] vis_e = vis_e[:, :, inv_sort_inds] feat_init = None # Not supported yet, correct sorting needs to be implemented # Delete the intermediate variables to avoid confusion with the later variables del sort_inds, inv_sort_inds # # Sanity check that the sorted traj_e have about similar values for the query points # # The forward pass is expected to tweak the values a bit, but they would probably stay close # pred_xyz_for_query = traj_e[0][query_points_t[batch_idx].squeeze(-1), torch.arange(num_points)] # pred_xyz_for_query = pred_xyz_for_query.type(query_points_xyz_worldspace.dtype) # assert torch.allclose(pred_xyz_for_query, query_points_xyz_worldspace[batch_idx], atol=1) # # But, an untrained model might not be able to predict the query points exactly # # Also check that the query points are visible # pred_visibility_for_query = vis_e[0][query_points_t[batch_idx].squeeze(-1), torch.arange(num_points)] # assert torch.all(pred_visibility_for_query > 0.5) # # But, for some points the model might predict the query points to be occluded if not is_train: if torch.isnan(traj_e).any(): warnings.warn( f"Found {torch.isnan(traj_e).sum()}/{traj_e.numel()} NaN values in traj_e. Setting them to 0.") traj_e[traj_e.isnan()] = 0 if torch.isnan(vis_e).any(): warnings.warn( f"Found {torch.isnan(vis_e).sum()}/{vis_e.numel()} NaN values in visibilities. Setting them to 1.") vis_e[vis_e.isnan()] = 1 # Save to results results["traj_e"] = traj_e results["feat_init"] = feat_init results["vis_e"] = vis_e # If training mode, we need to merge the results from all views. # Those merged results are used in the backward pass to compute the loss. # train_data is a tuple of (vis_pred, coord_pred, attn_pred, p_idx_end_list, sort_inds, Rigid_ln_total) if is_train: # SpaTracker is using sliding windows, and for each window, it is using multiple iterations. 
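# A concrete example of the window layout (illustrative numbers): with T = 16 frames and a window
# length of S = 8, the window start advances by S // 2 = 4 while it stays below T - S // 2 = 12,
# so the windows cover frames [0:8], [4:12] and [8:16] (num_windows = 3), each half-overlapping
# with its neighbour. Within every window, forward_iteration runs `iters` refinement steps and
# keeps the coordinates of every step, which is why train_data_per_view[...][1] is indexed first
# by window and then by iteration below.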
num_windows = len(train_data_per_view[views_to_keep[0]][0]) num_iterations = len(train_data_per_view[views_to_keep[0]][1][0]) sort_inds = [] vis_predictions = [[] for _ in range(num_windows)] coord_predictions = [[[] for _ in range(num_iterations)] for _ in range(num_windows)] for window_idx in range(num_windows): for view_idx in views_to_keep: # What points will be tracked in this view track_mask = query_points_best_visibility_view[batch_idx] == view_idx if track_mask.sum() == 0: # This view does not track any points at all continue # Get the indices of points that appeared in this window (from the points tracked in this view) try: start_idx = 0 if window_idx == 0 else train_data_per_view[view_idx][3][window_idx - 1].item() end_idx = train_data_per_view[view_idx][3][window_idx].item() if end_idx == 0: # No points from this view were tracked in this window continue except Exception as e: logging.error(f"Error: {e}") logging.error(f"view_idx: {view_idx}, window_idx: {window_idx}") logging.error(f"train_data_per_view[view_idx][3]: {train_data_per_view[view_idx][3]}") raise e # Convert the view-specific sorted indices to "global" indices # that say which trajectory/query the point originally belonged to indices_in_view = train_data_per_view[view_idx][4][start_idx:end_idx] global_indices = torch.nonzero(track_mask).squeeze(-1)[indices_in_view] # Sorted indices are saying how the original trajectories were reordered/sorted # in the return results. This is because in the forward passes, we want to group # the points that will appear in the same window together. The points that haven't # appeared in a window will not be used in the forward pass for that window. # For each new window, points can only be added, not removed, and they will be added # if they have just appeared in that window. Since we are merging the results from # all views, we will first take all the points that appeared in the first window from # all views, then all the points that appeared in the second window from all views, # and so on. This is why we do a for loop over the windows first, then over the views # and merge the indices in the next line: sort_inds.append(global_indices) # The indices are now sorted in the order that they will appear in the merged results. # This can be illustrated as follows: # Final sorted indices for the merged results: [ # view 1 new points from window 1 # view 2 new points from window 1 # view ... new points from window 1 # view 1 new points from window 2 # view 2 new points from window 2 # view ... new points from window 2 # ... # ] # This also means that the results from each view need to be carefully merged to match # the expected ordering/sorting. To illustrate this, the merged results for the vis_predictions # and coord_predictions will look like this: # Window 1 results: [ # view 1 new points from window 1 # view 2 new points from window 1 # view ... new points from window 1 # ] # Window 2 results: [ # view 1 new points from window 1 # view 2 new points from window 1 # view ... new points from window 1 # view 1 new points from window 2 # view 2 new points from window 2 # view ... new points from window 2 # ] # Window ... 
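# A toy instance of the ordering described above (illustrative only): with two views and four
# query points, where view 0 tracks the global points {0, 2} and view 1 tracks {1, 3}, and where
# window 1 introduces point 0 (view 0) and points 1, 3 (view 1) while window 2 introduces
# point 2 (view 0), the concatenation below yields sort_inds = [0, 1, 3, 2]. Since
# torch.argsort(torch.tensor([0, 1, 3, 2])) = [0, 1, 3, 2], indexing the merged tensors with
# inv_sort_inds = argsort(sort_inds) later restores the original query order 0, 1, 2, 3.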
# Below we will merge the results from all views for each window as illustrated above for window_idx_inner in range(num_windows): vis_predictions[window_idx_inner].append( train_data_per_view[view_idx][0][window_idx_inner][:, :, start_idx:end_idx] ) for iteration_idx in range(num_iterations): coord_predictions[window_idx_inner][iteration_idx].append( train_data_per_view[view_idx][1][window_idx_inner][iteration_idx][ :, :, start_idx:end_idx, :] ) # Concatenate the merged results correctly sort_inds = torch.cat(sort_inds, dim=0) vis_predictions = [ torch.cat(vis_predictions[window_idx], dim=2) for window_idx in range(num_windows) ] coord_predictions = [ [ torch.cat(coord_predictions[window_idx][iteration_idx], dim=2) for iteration_idx in range(num_iterations) ] for window_idx in range(num_windows) ] # Compute the p_idx_end_list for each window, it is the sum of the number of points # that appeared in each view for that window as this is the way we have merged the results. p_idx_end_list = [ torch.stack([ train_data_per_view[view_idx][3][window_idx] for view_idx in views_to_keep ], dim=1).sum(dim=1) for window_idx in range(num_windows) ] # Compute the attn_predictions and Rigid_ln_total attn_predictions = None # Not supported yet Rigid_ln_total = None # Not supported yet # Sanity check that using the computed sort_inds gives the same results as the merged traj_e and vis_e traj_e_reproduced = traj_e.new_zeros(traj_e.shape) vis_e_reproduced = vis_e.new_zeros(vis_e.shape) window_start_t = 0 while window_start_t < num_frames - self.spatracker.S // 2: window_idx = window_start_t // (self.spatracker.S // 2) p_idx_end = p_idx_end_list[window_idx] if p_idx_end == 0: continue wind_coords = coord_predictions[window_idx][-1] wind_vis = vis_predictions[window_idx] traj_e_reproduced[:, window_start_t:window_start_t + self.spatracker.S, :p_idx_end] = wind_coords vis_e_reproduced[:, window_start_t:window_start_t + self.spatracker.S, :p_idx_end] = wind_vis window_start_t = window_start_t + (self.spatracker.S // 2) inv_sort_inds = torch.argsort(sort_inds, dim=0) traj_e_reproduced = traj_e_reproduced[:, :, inv_sort_inds] vis_e_reproduced = torch.sigmoid(vis_e_reproduced[:, :, inv_sort_inds]) # Set the trajectory to (0,0,0) for the timesteps before the query timestep for point_idx, t in enumerate(query_points_t[batch_idx, :, :].squeeze(-1)): traj_e_reproduced[0, :t, point_idx, :] = 0.0 assert torch.allclose(traj_e, traj_e_reproduced, atol=1e-3) assert torch.allclose(vis_e, vis_e_reproduced, atol=1e-3) # Save to results results["train_data"] = { "vis_predictions": vis_predictions, "coord_predictions": coord_predictions, "attn_predictions": attn_predictions, "p_idx_end_list": p_idx_end_list, "sort_inds": sort_inds, "Rigid_ln_total": Rigid_ln_total, } return results ================================================ FILE: mvtracker/models/core/spatracker/spatracker_multiview.py ================================================ import logging import os import warnings import cv2 import numpy as np import torch from einops import rearrange from matplotlib import pyplot as plt from torch import nn as nn from mvtracker.models.core.embeddings import Embedder_Fourier, get_3d_sincos_pos_embed_from_grid, \ get_1d_sincos_pos_embed_from_grid, get_3d_embedding from mvtracker.models.core.model_utils import sample_features5d, smart_cat from mvtracker.models.core.spatracker.blocks import BasicEncoder, EUpdateFormer, CorrBlock from mvtracker.models.core.spatracker.softsplat import softsplat from 
mvtracker.models.core.spatracker.spatracker_monocular import sample_pos_embed from mvtracker.utils.basic import to_homogeneous, from_homogeneous, time_now class MultiViewSpaTracker(nn.Module): """ Multi-view Spatial Tracker: A 3D Multi-View Tracker with Transformer-based Iterative Flow Updates. This version computes local correlation in a global triplane space that is aligned with the world coordinate planes. However, this leaves most of the triplane space empty since it is difficult to create one plane that covers all the relevant areas of interest. """ def __init__( self, sliding_window_len=8, stride=8, add_space_attn=True, use_3d_pos_embed=True, remove_zeromlpflow=True, concat_triplane_features=True, num_heads=8, hidden_size=384, space_depth=12, time_depth=12, fmaps_dim=128, triplane_xres=128, triplane_yres=128, triplane_zres=128, ): super(MultiViewSpaTracker, self).__init__() self.S = sliding_window_len self.stride = stride self.hidden_dim = 256 self.latent_dim = fmaps_dim self.flow_embed_dim = 64 self.b_latent_dim = self.latent_dim // 3 self.corr_levels = 4 self.corr_radius = 3 self.add_space_attn = add_space_attn self.use_3d_pos_embed = use_3d_pos_embed self.remove_zeromlpflow = remove_zeromlpflow self.concat_triplane_features = concat_triplane_features self.updateformer_input_dim = ( # The positional encoding of the 3D flow from t=i to t=0 + (self.flow_embed_dim + 1) * (3 if self.remove_zeromlpflow else 2) # The correlation features (LRR) for the three planes (xy, yz, xz), concatenated + 196 * (3 if self.concat_triplane_features else 1) # The features of the tracked points, one for each of the three planes + self.latent_dim * (3 if self.concat_triplane_features else 1) # The visibility mask + 1 # The whether-the-point-is-tracked mask + 1 ) self.triplane_xres = triplane_xres self.triplane_yres = triplane_yres self.triplane_zres = triplane_zres # Feature encoder self.fnet = BasicEncoder( input_dim=3, output_dim=self.latent_dim, norm_fn="instance", dropout=0, stride=stride, Embed3D=False, ) # Convolutional heads for the tri-plane features self.headxy = nn.Sequential( nn.Conv2d(self.latent_dim, self.latent_dim, 3, padding=1), nn.ReLU(inplace=True), nn.Conv2d(self.latent_dim, self.latent_dim, 3, padding=1), ) self.headyz = nn.Sequential( nn.Conv2d(self.latent_dim, self.latent_dim, 3, padding=1), nn.ReLU(inplace=True), nn.Conv2d(self.latent_dim, self.latent_dim, 3, padding=1), ) self.headxz = nn.Sequential( nn.Conv2d(self.latent_dim, self.latent_dim, 3, padding=1), nn.ReLU(inplace=True), nn.Conv2d(self.latent_dim, self.latent_dim, 3, padding=1), ) # Transformer for the iterative flow updates self.support_features = torch.zeros(100, 384).to("cuda") + 0.1 self.updateformer = EUpdateFormer( space_depth=space_depth, time_depth=time_depth, input_dim=self.updateformer_input_dim, hidden_size=hidden_size, num_heads=num_heads, output_dim=3 + self.latent_dim * 3, mlp_ratio=4.0, add_space_attn=add_space_attn, flash=True, ) # Updater of the features of the tracked points self.norm_xy = nn.GroupNorm(1, self.latent_dim) self.norm_yz = nn.GroupNorm(1, self.latent_dim) self.norm_xz = nn.GroupNorm(1, self.latent_dim) self.ffeatxy_updater = nn.Sequential( nn.Linear(self.latent_dim, self.latent_dim), nn.GELU(), ) self.ffeatyz_updater = nn.Sequential( nn.Linear(self.latent_dim, self.latent_dim), nn.GELU(), ) self.ffeatxz_updater = nn.Sequential( nn.Linear(self.latent_dim, self.latent_dim), nn.GELU(), ) # Embedders self.embed_traj = Embedder_Fourier(input_dim=5, max_freq_log2=5.0, N_freqs=3, 
include_input=True) self.embed3d = Embedder_Fourier(input_dim=3, max_freq_log2=10.0, N_freqs=10, include_input=True) self.embedConv = nn.Conv2d(self.latent_dim + 63, self.latent_dim, 3, padding=1) # Predictor of the visibility of the tracked points self.vis_predictor = nn.Sequential(nn.Linear(self.latent_dim * (3 if self.concat_triplane_features else 1), 1)) self.zeroMLPflow = nn.Linear(195, 130) def sample_trifeat(self, t, coords, featMapxy, featMapyz, featMapxz): """ Sample the features from the 5D triplane feature map 3*(B S C H W) Args: t: the time index coords: the coordinates of the points B S N 3 featMapxy: the feature map B S C Hx Wy featMapyz: the feature map B S C Hy Wz featMapxz: the feature map B S C Hx Wz """ # get xy_t yz_t xz_t queried_t = t.reshape(1, 1, -1, 1) xy_t = torch.cat( [queried_t, coords[..., [0, 1]]], dim=-1 ) yz_t = torch.cat( [queried_t, coords[..., [1, 2]]], dim=-1 ) xz_t = torch.cat( [queried_t, coords[..., [0, 2]]], dim=-1 ) featxy_init = sample_features5d(featMapxy, xy_t) featyz_init = sample_features5d(featMapyz, yz_t) featxz_init = sample_features5d(featMapxz, xz_t) featxy_init = featxy_init.repeat(1, self.S, 1, 1) featyz_init = featyz_init.repeat(1, self.S, 1, 1) featxz_init = featxz_init.repeat(1, self.S, 1, 1) return featxy_init, featyz_init, featxz_init def forward_iteration( self, fmapXY, fmapYZ, fmapXZ, coords_init, vis_init, track_mask, iters=4, feat_init=None, ): N = coords_init.shape[2] B, S, fmap_dim, triplane_H, triplane_W = fmapXY.shape triplane_D = fmapXZ.shape[-2] device = fmapXY.device if coords_init.shape[1] < S: coords = torch.cat([coords_init, coords_init[:, -1].repeat(1, S - coords_init.shape[1], 1, 1)], dim=1) vis_init = torch.cat([vis_init, vis_init[:, -1].repeat(1, S - coords_init.shape[1], 1, 1)], dim=1) else: coords = coords_init.clone() assert B == 1 assert fmapXY.shape == (B, S, fmap_dim, triplane_H, triplane_W) assert fmapYZ.shape == (B, S, fmap_dim, triplane_D, triplane_H) assert fmapXZ.shape == (B, S, fmap_dim, triplane_D, triplane_W) assert coords.shape == (B, S, N, 3) assert vis_init.shape == (B, S, N, 1) assert track_mask.shape == (B, S, N, 1) assert feat_init is None or feat_init.shape == (B, S, N, self.latent_dim, 3) fcorr_fnXY = CorrBlock(fmapXY, num_levels=self.corr_levels, radius=self.corr_radius) fcorr_fnYZ = CorrBlock(fmapYZ, num_levels=self.corr_levels, radius=self.corr_radius) fcorr_fnXZ = CorrBlock(fmapXZ, num_levels=self.corr_levels, radius=self.corr_radius) ffeats = torch.split(feat_init.clone(), dim=-1, split_size_or_sections=1) ffeats = [f.squeeze(-1) for f in ffeats] grid_size = coords.new_tensor([triplane_H, triplane_W, triplane_D]) # @Single-view-difference: # Instead of computing 2D positional embeddings in the XY plane of the single-view triplane # (which is aligned with the monocular view used in the single-view SpatialTracker), I will # compute 3D positional embeddings in the 3D grid of the triplane. This could allow the model # to more easily learn the 3D spatial relationships between the points in the triplane. # pos_embed = sample_pos_embed( # grid_size=(H8, W8), # embed_dim=456, # coords=coords[..., :2], # ) embed_dim = self.updateformer_input_dim if self.use_3d_pos_embed: # Ours if embed_dim % 3 != 0: # Make sure that the embed_dim is divisible by 3 embed_dim += 3 - (embed_dim % 3) pos_embed = get_3d_sincos_pos_embed_from_grid( embed_dim=embed_dim, # Normalize the coordinates so that the grid ranges over [-128,128] grid=((coords[:, :1, ...] 
/ grid_size) * 2 - 1) * 128, ).float()[:, 0, ...].permute(0, 2, 1) else: # Original if embed_dim % 4 != 0: # Make sure that the embed_dim is divisible by 4 embed_dim += 4 - (embed_dim % 4) pos_embed = sample_pos_embed( grid_size=(triplane_H, triplane_W), embed_dim=embed_dim, coords=coords[..., :2], ) if embed_dim > self.updateformer_input_dim: # If the embed_dim was increased for divisibility, then remove the extra dimensions pos_embed = pos_embed[:, :self.updateformer_input_dim, :] pos_embed = rearrange(pos_embed, "b e n -> (b n) e").unsqueeze(1) times_ = torch.linspace(0, S - 1, S).reshape(1, S, 1) embed_dim = self.updateformer_input_dim if embed_dim % 2 != 0: # Make sure that the embed_dim is divisible by 2 embed_dim += 2 - (embed_dim % 2) times_embed = ( torch.from_numpy(get_1d_sincos_pos_embed_from_grid(embed_dim, times_[0]))[None] .repeat(B, 1, 1) .float() .to(device) ) if embed_dim > self.updateformer_input_dim: # If the embed_dim was increased to be divisible by 2, then remove the extra dimensions times_embed = times_embed[:, :, :self.updateformer_input_dim] coord_predictions = [] support_feat = self.support_features for _ in range(iters): coords = coords.detach() fcorrsXY = fcorr_fnXY.corr_sample(ffeats[0], coords[..., :2]) fcorrsYZ = fcorr_fnYZ.corr_sample(ffeats[1], coords[..., [1, 2]]) fcorrsXZ = fcorr_fnXZ.corr_sample(ffeats[2], coords[..., [0, 2]]) # @Single-view-difference: # Instead of summing the correlations for different planes, I will concatenate them so that the model # can learn to differentiate between the correlations of different planes. Summing the correlations up # can make it very difficult for the model to differentiate between the correlations of different # planes unless, e.g., it learns to create the feature maps in a way that they are orthogonal # to each other. But rather than relying on the model to learn this, I believe that it is better # to provide the model with the information that the correlations are from different planes explicitly. # Note that this change will increase the dimension of the correlation features that are given to the # transformer: 196 * 3 = 588, instead of 196. # fcorrs = fcorrsXY + fcorrsYZ + fcorrsXZ if self.concat_triplane_features: # Ours fcorrs = torch.cat([fcorrsXY, fcorrsYZ, fcorrsXZ], dim=-1) else: # Original fcorrs = fcorrsXY + fcorrsYZ + fcorrsXZ LRR = fcorrs.shape[3] fcorrs_ = fcorrs.permute(0, 2, 1, 3).reshape(B * N, S, LRR) flows_ = (coords - coords[:, 0:1]).permute(0, 2, 1, 3).reshape(B * N, S, 3) flows_cat = get_3d_embedding(flows_, self.flow_embed_dim, cat_coords=True) # @Single-view-difference: # I have removed the zeroMLPflow linear layer which was added to project the flow embedding # from a 195-dimensional vector to a 130-dimensional to have a cleaner architecture. # I believe that the authors have added this layer just to match the 130 that the original # CoTracker implementation had used, but this can introduce confusion in the architecture's design. # flows_cat = self.zeroMLPflow(flows_cat) if self.remove_zeromlpflow: # Ours pass else: # Original flows_cat = self.zeroMLPflow(flows_cat) ffeats_xy = ffeats[0].permute(0, 2, 1, 3).reshape(B * N, S, self.latent_dim) ffeats_yz = ffeats[1].permute(0, 2, 1, 3).reshape(B * N, S, self.latent_dim) ffeats_xz = ffeats[2].permute(0, 2, 1, 3).reshape(B * N, S, self.latent_dim) # @Single-view-difference: # Instead of summing the features for different planes, I will concatenate them so that the model # can learn to differentiate between the features of different planes. 
Summing the features up # can make it very difficult for the model to differentiate between the features of different # planes. I believe that it is better to provide the model with the information that the features # are from different planes explicitly. Note that this change will increase the dimension of the # feature embeddings that are given to the transformer: 128 * 3 = 384, instead of 128. # ffeats_ = ffeats_xy + ffeats_yz + ffeats_xz if self.concat_triplane_features: # Ours ffeats_ = torch.cat([ffeats_xy, ffeats_yz, ffeats_xz], dim=-1) else: # Original ffeats_ = ffeats_xy + ffeats_yz + ffeats_xz if track_mask.shape[1] < vis_init.shape[1]: track_mask = torch.cat([ track_mask, torch.zeros_like(track_mask[:, 0]).repeat(1, vis_init.shape[1] - track_mask.shape[1], 1, 1), ], dim=1) track_mask_and_vis = torch.cat([track_mask, vis_init], dim=2).permute(0, 2, 1, 3).reshape(B * N, S, 2) transformer_input = torch.cat([flows_cat, fcorrs_, ffeats_, track_mask_and_vis], dim=2) assert transformer_input.shape[-1] == pos_embed.shape[-1] x = transformer_input + pos_embed + times_embed x = rearrange(x, "(b n) t d -> b n t d", b=B) delta, delta_se3F = self.updateformer(x, support_feat) support_feat = support_feat + delta_se3F[0] / 100 delta = rearrange(delta, " b n t d -> (b n) t d") d_coord = delta[:, :, :3] d_feats_xy = delta[:, :, 3:self.latent_dim + 3] d_feats_yz = delta[:, :, self.latent_dim + 3:self.latent_dim * 2 + 3] d_feats_xz = delta[:, :, self.latent_dim * 2 + 3:] d_feats_xy_norm = self.norm_xy(d_feats_xy.view(-1, self.latent_dim)) d_feats_yz_norm = self.norm_yz(d_feats_yz.view(-1, self.latent_dim)) d_feats_xz_norm = self.norm_xz(d_feats_xz.view(-1, self.latent_dim)) ffeats_xy = ffeats_xy.reshape(-1, self.latent_dim) + self.ffeatxy_updater(d_feats_xy_norm) ffeats_yz = ffeats_yz.reshape(-1, self.latent_dim) + self.ffeatyz_updater(d_feats_yz_norm) ffeats_xz = ffeats_xz.reshape(-1, self.latent_dim) + self.ffeatxz_updater(d_feats_xz_norm) ffeats[0] = ffeats_xy.reshape(B, N, S, self.latent_dim).permute(0, 2, 1, 3) ffeats[1] = ffeats_yz.reshape(B, N, S, self.latent_dim).permute(0, 2, 1, 3) ffeats[2] = ffeats_xz.reshape(B, N, S, self.latent_dim).permute(0, 2, 1, 3) coords = coords + d_coord.reshape(B, N, S, 3).permute(0, 2, 1, 3) if torch.isnan(coords).any(): logging.error("Got NaN values in coords, perhaps the training exploded") import ipdb; ipdb.set_trace() coord_predictions.append(coords.clone()) # @Single-view-difference: # Instead of summing the features for different planes, # I will concatenate before inputting them to the shallow visibility predictor. 
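
To make the dimensionality consequence of the comment above concrete: concatenating the per-plane vectors triples the channel dimension seen by the visibility predictor (and, analogously, by the transformer input), while summing keeps it unchanged. The following is a minimal standalone sketch with illustrative shapes; the tensor names and sizes are assumptions for the example, not values taken from this file.

import torch

B, S, N, latent_dim = 1, 8, 16, 128
ffeats = [torch.randn(B, S, N, latent_dim) for _ in range(3)]  # xy, yz, xz plane features

# Original (single-view) choice: sum the planes, keeping latent_dim channels.
summed = ffeats[0] + ffeats[1] + ffeats[2]
assert summed.shape == (B, S, N, latent_dim)             # 128 channels

# Variant described in the comment above: concatenate along channels instead.
concatenated = torch.cat(ffeats, dim=-1)
assert concatenated.shape == (B, S, N, 3 * latent_dim)   # 384 channels
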
# ffeats_f = ffeats[0] + ffeats[1] + ffeats[2] if self.concat_triplane_features: ffeats_f = torch.cat(ffeats, dim=-1) vis_e = self.vis_predictor(ffeats_f.reshape(B * S * N, self.latent_dim * 3)).reshape(B, S, N) else: ffeats_f = ffeats[0] + ffeats[1] + ffeats[2] vis_e = self.vis_predictor(ffeats_f.reshape(B * S * N, self.latent_dim)).reshape(B, S, N) self.support_features = support_feat.detach() return coord_predictions, vis_e, feat_init def forward( self, rgbs, depths, query_points, intrs, extrs, iters=4, feat_init=None, is_train=False, save_debug_logs=False, debug_logs_path="", **kwargs, ): batch_size, num_views, num_frames, _, height, width = rgbs.shape _, num_points, _ = query_points.shape assert rgbs.shape == (batch_size, num_views, num_frames, 3, height, width) assert depths.shape == (batch_size, num_views, num_frames, 1, height, width) assert query_points.shape == (batch_size, num_points, 4) assert intrs.shape == (batch_size, num_views, num_frames, 3, 3) assert extrs.shape == (batch_size, num_views, num_frames, 3, 4) if feat_init is not None: raise NotImplementedError("feat_init is not supported yet") if save_debug_logs: os.makedirs(debug_logs_path, exist_ok=True) if kwargs: warnings.warn(f"Received unexpected kwargs: {kwargs.keys()}") self.support_features = torch.zeros(100, 384).to("cuda") + 0.1 self.is_train = is_train # Unpack the query points query_points_t = query_points[:, :, :1].long() query_points_xyz_worldspace = query_points[:, :, 1:] # Invert intrinsics and extrinsics intrs_inv = torch.inverse(intrs.float()) extrs_square = torch.eye(4).to(extrs.device)[None].repeat(batch_size, num_views, num_frames, 1, 1) extrs_square[:, :, :, :3, :] = extrs extrs_inv = torch.inverse(extrs_square.float()) # Interpolate the rgbs and depthmaps to the stride of the SpaTracker strided_height = height // self.stride strided_width = width // self.stride strided_depths = nn.functional.interpolate( input=depths.reshape(-1, 1, height, width), scale_factor=1.0 / self.stride, mode="nearest", ).reshape(batch_size, num_views, num_frames, 1, strided_height, strided_width) strided_rgbs = nn.functional.interpolate( input=rgbs.reshape(-1, 3, height, width), scale_factor=1.0 / self.stride, mode="bilinear", ).reshape(batch_size, num_views, num_frames, 3, strided_height, strided_width) # Un-project strided depthmaps back to world coordinates pixel_xy = torch.stack(torch.meshgrid( (torch.arange(0, height / self.stride) + 0.5) * self.stride - 0.5, (torch.arange(0, width / self.stride) + 0.5) * self.stride - 0.5, indexing="ij", )[::-1], dim=-1) pixel_xy = pixel_xy.to(device=rgbs.device, dtype=rgbs.dtype) pixel_xy_homo = to_homogeneous(pixel_xy) depthmap_camera_xyz = torch.einsum('BVTij,HWj->BVTHWi', intrs_inv, pixel_xy_homo) depthmap_camera_xyz = depthmap_camera_xyz * strided_depths[..., 0, :, :, None] depthmap_camera_xyz_homo = to_homogeneous(depthmap_camera_xyz) depthmap_world_xyz_homo = torch.einsum('BVTij,BVTHWj->BVTHWi', extrs_inv, depthmap_camera_xyz_homo) depthmap_world_xyz = from_homogeneous(depthmap_world_xyz_homo) if save_debug_logs: t = 0 n_skip = 4 xyz = depthmap_world_xyz[0, :, t, ::n_skip, ::n_skip, :].reshape(-1, 3).cpu().numpy() c = strided_rgbs.permute(0, 1, 2, 4, 5, 3)[0, :, t, ::n_skip, ::n_skip].reshape(-1, 3).cpu().numpy() / 255 filename = time_now() + "__rgbd_with_queries" qp = query_points_xyz_worldspace[0].cpu().numpy() qc = np.array([[1, 0, 0]] * query_points_xyz_worldspace.shape[1]) self._plot_pointcloud(debug_logs_path, filename, xyz, c, qp, qc, show=False) # Put the three planes 
along the YX, ZX, and ZY axes # TODO: Hardcode the xyz ranges for the triplanes, # as taking the whole range would make the # central object of interest very tiny and # the grid would be wasted in representing # wast background. x_range = [-14, 14] y_range = [-14, 14] z_range = [-1, 10] query_points_outside_of_triplane_range = ( (query_points_xyz_worldspace[..., 0].flatten() < x_range[0]) | (query_points_xyz_worldspace[..., 0].flatten() > x_range[1]) | (query_points_xyz_worldspace[..., 1].flatten() < y_range[0]) | (query_points_xyz_worldspace[..., 1].flatten() > y_range[1]) | (query_points_xyz_worldspace[..., 2].flatten() < z_range[0]) | (query_points_xyz_worldspace[..., 2].flatten() > z_range[1]) ) if query_points_outside_of_triplane_range.any(): warnings.warn(f"Some Query points are outside of the triplane range. " f"x_range={x_range}, y_range={y_range}, z_range={z_range}. " f"query_points_xyz_worldspace={query_points_xyz_worldspace[:, query_points_outside_of_triplane_range]}") kwargs = {"device": depthmap_world_xyz.device, "dtype": depthmap_world_xyz.dtype} triplane_xyz_min = torch.tensor([x_range[0], y_range[0], z_range[0]], **kwargs) triplane_xyz_max = torch.tensor([x_range[1], y_range[1], z_range[1]], **kwargs) triplane_grid_dims = torch.tensor([self.triplane_xres, self.triplane_yres, self.triplane_zres], **kwargs) if save_debug_logs: t = 0 n_skip = 1 xyz = depthmap_world_xyz[0, :, t, ::n_skip, ::n_skip, :].reshape(-1, 3).cpu().numpy() c = strided_rgbs.permute(0, 1, 2, 4, 5, 3)[0, :, t, ::n_skip, ::n_skip, :].reshape(-1, 3).cpu().numpy() / 255 mask = ( (xyz[:, 0] >= x_range[0]) & (xyz[:, 0] <= x_range[1]) & (xyz[:, 1] >= y_range[0]) & (xyz[:, 1] <= y_range[1]) & (xyz[:, 2] >= z_range[0]) & (xyz[:, 2] <= z_range[1]) ) xyz_in_range = xyz[mask] c_in_range = c[mask] qp = query_points_xyz_worldspace[0].cpu().numpy() qc = np.array([[1, 0, 0]] * query_points_xyz_worldspace.shape[1]) mask = ( (qp[:, 0] >= x_range[0]) & (qp[:, 0] <= x_range[1]) & (qp[:, 1] >= y_range[0]) & (qp[:, 1] <= y_range[1]) & (qp[:, 2] >= z_range[0]) & (qp[:, 2] <= z_range[1]) ) qp_in_range = qp[mask] qc_in_range = qc[mask] filename = time_now() + "__rgbd_with_queries_within_triplane_range" self._plot_pointcloud(debug_logs_path, filename, xyz_in_range, c_in_range, qp_in_range, qc_in_range, show=False) # Pre-compute the per-view feature maps rgbs_normalized = 2 * (rgbs / 255.0) - 1.0 fnet_fmaps = self.fnet(rgbs_normalized.reshape(-1, 3, height, width)) fnet_fmaps = fnet_fmaps.reshape( batch_size, num_views, num_frames, self.latent_dim, strided_height, strided_width, ) # Add Positional 3D Embeddings/Encodings def world_to_triplane(points, inverse=False): assert points.shape[-1] == 3 if inverse: return points * (triplane_xyz_max - triplane_xyz_min) / (triplane_grid_dims - 1) + triplane_xyz_min else: return (points - triplane_xyz_min) / (triplane_xyz_max - triplane_xyz_min) * (triplane_grid_dims - 1) depthmap_world_xyz_normalized = (depthmap_world_xyz - triplane_xyz_min) / (triplane_xyz_max - triplane_xyz_min) positional_encoding_3d = self.embed3d(2 * depthmap_world_xyz_normalized.reshape(-1, 3) - 1) positional_encoding_3d = ( positional_encoding_3d .reshape(batch_size, num_views, num_frames, strided_height, strided_width, -1) .permute(0, 1, 2, 5, 3, 4) # HWC --> CHW ) fmaps = torch.cat([fnet_fmaps, positional_encoding_3d], dim=-3) fmaps = fmaps.reshape(-1, self.latent_dim + self.embed3d.out_dim, strided_height, strided_width) fmaps = self.embedConv(fmaps) fmaps = fmaps.reshape(batch_size, num_views, num_frames, 
self.latent_dim, strided_height, strided_width) # Compute the flows from each depthmap to the triplane # The flows are needed to splat the features from the depthmap to the triplane # The flow defines how one 2D plane is transformed to another 2D plane # In our case, the first plane will be of ... TODO describe the planes more depthmap_world_xyz_normalized_to_triplane_grid = depthmap_world_xyz_normalized * (triplane_grid_dims - 1) depthmap_world_xyz_reproduced = world_to_triplane( points=depthmap_world_xyz_normalized_to_triplane_grid, inverse=True, ) if not depthmap_world_xyz_reproduced.allclose(depthmap_world_xyz, atol=0.72): logging.info("depthmap_world_xyz_reproduced", depthmap_world_xyz_reproduced) logging.info("depthmap_world_xyz", depthmap_world_xyz) warnings.warn(f"Applying the inverse of world_to_triplane did not reproduce depthmap_world_xyz... " f"The maximum difference is {torch.max(torch.abs(depthmap_world_xyz_reproduced - depthmap_world_xyz))}") flow_pointcloud_to_xy = depthmap_world_xyz_normalized_to_triplane_grid[..., [0, 1]] flow_pointcloud_to_yz = depthmap_world_xyz_normalized_to_triplane_grid[..., [1, 2]] flow_pointcloud_to_xz = depthmap_world_xyz_normalized_to_triplane_grid[..., [0, 2]] flow_pointcloud_to_xy = ( flow_pointcloud_to_xy .permute(0, 2, 5, 3, 1, 4) .reshape(batch_size * num_frames, 2, strided_height, num_views * strided_width) ) flow_pointcloud_to_yz = ( flow_pointcloud_to_yz .permute(0, 2, 5, 3, 1, 4) .reshape(batch_size * num_frames, 2, strided_height, num_views * strided_width) ) flow_pointcloud_to_xz = ( flow_pointcloud_to_xz .permute(0, 2, 5, 3, 1, 4) .reshape(batch_size * num_frames, 2, strided_height, num_views * strided_width) ) # Compute the triplane features by splatting the per-view features following the flows def splat_fmaps(fmaps, flow_xy, flow_yz, flow_xz, out_shape): dtype = fmaps.dtype fmaps = fmaps.float() flow_xy = flow_xy.float() flow_yz = flow_yz.float() flow_xz = flow_xz.float() fmap_xy, fmap_xy_norm = softsplat( tenIn=fmaps, tenFlow=flow_xy, tenMetric=None, strMode="avg", tenoutH=out_shape[1], tenoutW=out_shape[0], use_pointcloud_splatting=True, return_normalization_tensor=True, ) fmap_yz, fmap_yz_norm = softsplat( tenIn=fmaps, tenFlow=flow_yz, tenMetric=None, strMode="avg", tenoutH=out_shape[2], tenoutW=out_shape[1], use_pointcloud_splatting=True, return_normalization_tensor=True, ) fmap_xz, fmap_xz_norm = softsplat( tenIn=fmaps, tenFlow=flow_xz, tenMetric=None, strMode="avg", tenoutH=out_shape[2], tenoutW=out_shape[0], use_pointcloud_splatting=True, return_normalization_tensor=True, ) if dtype != fmaps.dtype: fmap_xy = fmap_xy.to(dtype) fmap_yz = fmap_yz.to(dtype) fmap_xz = fmap_xz.to(dtype) fmap_xy_norm = fmap_xy_norm.to(dtype) fmap_yz_norm = fmap_yz_norm.to(dtype) fmap_xz_norm = fmap_xz_norm.to(dtype) return fmap_xy, fmap_yz, fmap_xz, fmap_xy_norm, fmap_yz_norm, fmap_xz_norm fmaps = ( fmaps .permute(0, 2, 3, 4, 1, 5) .reshape(batch_size * num_frames, self.latent_dim, strided_height, num_views * strided_width) ) fmap_xy, fmap_yz, fmap_xz, fmap_xy_norm, fmap_yz_norm, fmap_xz_norm = splat_fmaps( fmaps=fmaps, flow_xy=flow_pointcloud_to_xy, flow_yz=flow_pointcloud_to_yz, flow_xz=flow_pointcloud_to_xz, out_shape=(self.triplane_xres, self.triplane_yres, self.triplane_zres), ) if save_debug_logs and (self.triplane_xres == self.triplane_yres == self.triplane_zres): # Visualize how the splatting would look like if the strided_rgbs would be directly splatted instead of feature maps rgbs_fmaps = ( strided_rgbs .permute(0, 2, 3, 4, 1, 5) 
.reshape(batch_size * num_frames, 3, strided_height, num_views * strided_width) ) rgbs_fmap_xy, rgbs_fmap_yz, rgbs_fmap_xz, rgbs_fmap_xy_norm, rgbs_fmap_yz_norm, rgbs_fmap_xz_norm = splat_fmaps( fmaps=rgbs_fmaps, flow_xy=flow_pointcloud_to_xy, flow_yz=flow_pointcloud_to_yz, flow_xz=flow_pointcloud_to_xz, out_shape=(self.triplane_xres, self.triplane_yres, self.triplane_zres), ) rgbs_fmap_xy_yz_xz_concat = torch.concat([rgbs_fmap_xy, rgbs_fmap_yz, rgbs_fmap_xz], -1) rgbs_fmap_norm_xy_yz_xz_concat = torch.concat([rgbs_fmap_xy_norm, rgbs_fmap_yz_norm, rgbs_fmap_xz_norm], -1) self._plot_featuremaps( logs_path=debug_logs_path, filename=time_now() + "__splatted_rgbs", fmaps_before_splatting=rgbs_fmaps, splatted_fmaps=rgbs_fmap_xy_yz_xz_concat, splat_normalization=rgbs_fmap_norm_xy_yz_xz_concat, chosen_channels=(0, 1, 2), ) if save_debug_logs and (self.triplane_xres == self.triplane_yres == self.triplane_zres): # Also splat only the first view RGBs to see how the splatting would look like rgbs_fmaps = strided_rgbs[0, 0] rgbs_fmap_xy, rgbs_fmap_yz, rgbs_fmap_xz, rgbs_fmap_xy_norm, rgbs_fmap_yz_norm, rgbs_fmap_xz_norm = splat_fmaps( fmaps=rgbs_fmaps, flow_xy=flow_pointcloud_to_xy[:, :, :, :strided_width], flow_yz=flow_pointcloud_to_yz[:, :, :, :strided_width], flow_xz=flow_pointcloud_to_xz[:, :, :, :strided_width], out_shape=(self.triplane_xres, self.triplane_yres, self.triplane_zres), ) rgbs_fmap_xy_yz_xz_concat = torch.concat([rgbs_fmap_xy, rgbs_fmap_yz, rgbs_fmap_xz], -1) rgbs_fmap_norm_xy_yz_xz_concat = torch.concat([rgbs_fmap_xy_norm, rgbs_fmap_yz_norm, rgbs_fmap_xz_norm], -1) self._plot_featuremaps( logs_path=debug_logs_path, filename=time_now() + "__splatted_rgbs_first_view_only", fmaps_before_splatting=rgbs_fmaps, splatted_fmaps=rgbs_fmap_xy_yz_xz_concat, splat_normalization=rgbs_fmap_norm_xy_yz_xz_concat, chosen_channels=(0, 1, 2), ) xyz = to_homogeneous( flow_pointcloud_to_xy[0, :, :, :strided_width].permute(1, 2, 0).reshape(-1, 2)).cpu().numpy() c = strided_rgbs[0, 0, 0, :, :].permute(1, 2, 0).reshape(-1, 3).cpu().numpy() / 255 self._plot_pointcloud(debug_logs_path, time_now() + "__flow_xy_debug", xyz, c, show=False) if save_debug_logs and (self.triplane_xres == self.triplane_yres == self.triplane_zres): if not (self.triplane_xres == self.triplane_yres == self.triplane_zres): raise NotImplementedError("Current implementation assumed these, otherwise needs some padding/interp.") fmap_xy_yz_xz_concat = torch.concat([fmap_xy, fmap_yz, fmap_xz], dim=-1) fmap_norm_xy_yz_xz_concat = torch.concat([fmap_xy_norm, fmap_yz_norm, fmap_xz_norm], dim=-1) self._plot_featuremaps( logs_path=debug_logs_path, filename=time_now() + "__fmaps", fmaps_before_splatting=fmaps, splatted_fmaps=fmap_xy_yz_xz_concat, splat_normalization=fmap_norm_xy_yz_xz_concat, chosen_channels=(0, 1, 2), ) fmap_xy = self.headxy(fmap_xy) fmap_yz = self.headyz(fmap_yz) fmap_xz = self.headxz(fmap_xz) fmap_xy = fmap_xy.reshape(batch_size, num_frames, self.latent_dim, self.triplane_yres, self.triplane_xres) fmap_yz = fmap_yz.reshape(batch_size, num_frames, self.latent_dim, self.triplane_zres, self.triplane_yres) fmap_xz = fmap_xz.reshape(batch_size, num_frames, self.latent_dim, self.triplane_zres, self.triplane_xres) if save_debug_logs and (self.triplane_xres == self.triplane_yres == self.triplane_zres): if not (self.triplane_xres == self.triplane_yres == self.triplane_zres): raise NotImplementedError("Current implementation assumed these, otherwise needs some padding/interp.") fmap_xy_yz_xz_concat = torch.concat([fmap_xy[0], 
fmap_yz[0], fmap_xz[0]], dim=-1) fmap_norm_xy_yz_xz_concat = torch.concat([fmap_xy_norm, fmap_yz_norm, fmap_xz_norm], dim=-1) self._plot_featuremaps( logs_path=debug_logs_path, filename=time_now() + "__fmaps_after_head", fmaps_before_splatting=fmaps, splatted_fmaps=fmap_xy_yz_xz_concat, splat_normalization=fmap_norm_xy_yz_xz_concat, chosen_channels=(-3, -2, -1), ) # Filter the points that never appear during 1 - T assert batch_size == 1, "Batch size > 1 is not supported yet" query_points_t = query_points_t.squeeze(0).squeeze(-1) # BN1 --> N ind_array = torch.arange(num_frames, device=query_points.device) ind_array = ind_array[None, :, None].repeat(batch_size, 1, num_points) track_mask = (ind_array >= query_points_t[None, None, :]).unsqueeze(-1) # Prepare the initial coordinates and visibility coords_init = query_points_xyz_worldspace.unsqueeze(1).repeat(1, self.S, 1, 1) coords_init = world_to_triplane(coords_init) vis_init = query_points.new_ones((batch_size, self.S, num_points, 1)) * 10 # Sort the queries via their first appeared time _, sort_inds = torch.sort(query_points_t, dim=0, descending=False) inv_sort_inds = torch.argsort(sort_inds, dim=0) assert torch.allclose(query_points_t, query_points_t[sort_inds][inv_sort_inds]) query_points_t_ = query_points_t[sort_inds] coords_init_ = coords_init[..., sort_inds, :].clone() vis_init_ = vis_init[:, :, sort_inds].clone() track_mask_ = track_mask[:, :, sort_inds].clone() # Placeholders for the results (for the sorted points) traj_e_ = query_points.new_zeros((batch_size, num_frames, num_points, 3)) vis_e_ = query_points.new_zeros((batch_size, num_frames, num_points)) # Perform the iterative forward pass of the SpaTracker as usual, # but make sure to use the pre-computed triplane features w_idx_start = 0 p_idx_start = 0 vis_predictions = [] coord_predictions = [] p_idx_end_list = [] while w_idx_start < num_frames - self.S // 2: curr_wind_points = torch.nonzero(query_points_t_ < w_idx_start + self.S) if curr_wind_points.shape[0] == 0: w_idx_start = w_idx_start + self.S // 2 continue p_idx_end = curr_wind_points[-1] + 1 p_idx_end_list.append(p_idx_end) # TODO: Is cloning necessary here – I don't think so? 
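
The surrounding loop walks over overlapping temporal windows of length self.S with stride self.S // 2, and, because the queries were sorted by first-visible frame, each window only needs the leading p_idx_end points. A standalone sketch of that schedule is below; the function name and the example values are hypothetical and only meant to illustrate the indexing.

import torch

def window_schedule(query_times_sorted, num_frames, S):
    """Yield (window_start, num_active_points) for overlapping windows of length S."""
    w_start = 0
    while w_start < num_frames - S // 2:
        # A point becomes active once its (sorted) query time falls inside the window.
        active = torch.nonzero(query_times_sorted < w_start + S)
        if active.shape[0] > 0:
            yield w_start, int(active[-1]) + 1
        w_start += S // 2

# Example: 24 frames, window length 8, queries first visible at t = 0, 5 and 14.
for start, n_active in window_schedule(torch.tensor([0, 5, 14]), num_frames=24, S=8):
    print(f"window [{start:2d}, {start + 8:2d}) tracks the first {n_active} sorted queries")
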
fmap_xy_seq = fmap_xy[:, w_idx_start:w_idx_start + self.S].clone() fmap_yz_seq = fmap_yz[:, w_idx_start:w_idx_start + self.S].clone() fmap_xz_seq = fmap_xz[:, w_idx_start:w_idx_start + self.S].clone() # the number of frames may not be divisible by self.S S_local = fmap_xy_seq.shape[1] if S_local < self.S: fmap_xy_seq = torch.cat([fmap_xy_seq, fmap_xy_seq[:, -1, None].repeat(1, self.S - S_local, 1, 1, 1)], 1) fmap_yz_seq = torch.cat([fmap_yz_seq, fmap_yz_seq[:, -1, None].repeat(1, self.S - S_local, 1, 1, 1)], 1) fmap_xz_seq = torch.cat([fmap_xz_seq, fmap_xz_seq[:, -1, None].repeat(1, self.S - S_local, 1, 1, 1)], 1) if p_idx_end - p_idx_start > 0: queried_t = (query_points_t_[p_idx_start:p_idx_end] - w_idx_start) featxy_init, featyz_init, featxz_init = self.sample_trifeat( t=queried_t, featMapxy=fmap_xy_seq, featMapyz=fmap_yz_seq, featMapxz=fmap_xz_seq, coords=coords_init_[:, :1, p_idx_start:p_idx_end], ) feat_init_curr = torch.stack([featxy_init, featyz_init, featxz_init], dim=-1) feat_init = smart_cat(feat_init, feat_init_curr, dim=2) # Update the initial coordinates and visibility for non-first windows if p_idx_start > 0: last_coords = coords[-1][:, self.S // 2:].clone() # Take the predicted coords from the last window coords_init_[:, : self.S // 2, :p_idx_start] = last_coords coords_init_[:, self.S // 2:, :p_idx_start] = last_coords[:, -1].repeat(1, self.S // 2, 1, 1) last_vis = vis[:, self.S // 2:][..., None] vis_init_[:, : self.S // 2, :p_idx_start] = last_vis vis_init_[:, self.S // 2:, :p_idx_start] = last_vis[:, -1].repeat(1, self.S // 2, 1, 1) track_mask_current = track_mask_[:, w_idx_start: w_idx_start + self.S, :p_idx_end] if S_local < self.S: track_mask_current = torch.cat([ track_mask_current, track_mask_current[:, -1:].repeat(1, self.S - S_local, 1, 1), ], 1) coords, vis, _ = self.forward_iteration( fmapXY=fmap_xy_seq, fmapYZ=fmap_yz_seq, fmapXZ=fmap_xz_seq, coords_init=coords_init_[:, :, :p_idx_end], feat_init=feat_init[:, :, :p_idx_end], vis_init=vis_init_[:, :, :p_idx_end], track_mask=track_mask_current, iters=iters, ) coords_in_worldspace = [world_to_triplane(coord, inverse=True) for coord in coords] if is_train: coord_predictions.append([coord[:, :S_local] for coord in coords_in_worldspace]) vis_predictions.append(vis[:, :S_local]) traj_e_[:, w_idx_start:w_idx_start + self.S, :p_idx_end] = coords_in_worldspace[-1][:, :S_local] vis_e_[:, w_idx_start:w_idx_start + self.S, :p_idx_end] = torch.sigmoid(vis[:, :S_local]) track_mask_[:, : w_idx_start + self.S, :p_idx_end] = 0.0 w_idx_start = w_idx_start + self.S // 2 p_idx_start = p_idx_end traj_e = traj_e_[:, :, inv_sort_inds] vis_e = vis_e_[:, :, inv_sort_inds] results = { "traj_e": traj_e, "feat_init": feat_init, "vis_e": vis_e, } if self.is_train: results["train_data"] = { "vis_predictions": vis_predictions, "coord_predictions": coord_predictions, "attn_predictions": None, "p_idx_end_list": p_idx_end_list, "sort_inds": sort_inds, "Rigid_ln_total": None, } return results @staticmethod def _plot_pointcloud(logs_path, filename, xyz, c, q_xyz=None, q_c=None, elevations=(0, 30, 90), azimuths=(0, 45, 90), show=False): fig = plt.figure(figsize=(len(azimuths) * 4.8, len(elevations) * 4.8)) fig.suptitle(filename) for i, elev_ in enumerate(elevations): for j, azim in enumerate(azimuths): ax = fig.add_subplot(len(elevations), len(azimuths), i * len(azimuths) + j + 1, projection='3d') ax.scatter(xyz[:, 0], xyz[:, 1], xyz[:, 2], c=c, s=1, marker=".", label="RGBD pointcloud") if q_xyz is not None: ax.scatter(q_xyz[:, 0], q_xyz[:, 1], 
q_xyz[:, 2], c=q_c, s=3, marker="^", label="Query Points") ax.set_xlabel('x') ax.set_ylabel('y') ax.set_zlabel('z') ax.legend() ax.view_init(elev=elev_, azim=azim) plt.tight_layout(pad=0) plt.savefig(os.path.join(logs_path, f"{filename}.png")) if show: plt.show() plt.close() @staticmethod def _plot_featuremaps( logs_path, filename, fmaps_before_splatting, splatted_fmaps, splat_normalization, chosen_channels=(-3, -2, -1), ): num_frames, n_channels, height_before, width_before = fmaps_before_splatting.shape _, _, height_after, width_after = splatted_fmaps.shape assert fmaps_before_splatting.shape == (num_frames, n_channels, height_before, width_before) assert splatted_fmaps.shape == (num_frames, n_channels, height_after, width_after) assert splat_normalization.shape == (num_frames, 1, height_after, width_after) fmaps_before_splatting = fmaps_before_splatting.detach().cpu().float().numpy() splatted_fmaps = splatted_fmaps.detach().cpu().float().numpy() splat_normalization = splat_normalization.detach().cpu().float().numpy() # Extract the chosen channels and normalize them fmaps_before_splatting = fmaps_before_splatting[:, chosen_channels, :, :] splatted_fmaps = splatted_fmaps[:, chosen_channels, :, :] ch_min = fmaps_before_splatting.min(axis=(0, 2, 3), keepdims=True) ch_max = fmaps_before_splatting.max(axis=(0, 2, 3), keepdims=True) fmaps_before_splatting = (fmaps_before_splatting - ch_min) / (ch_max - ch_min) splatted_fmaps = (splatted_fmaps - ch_min) / (ch_max - ch_min) # Normalize the normalization as well ( ͡° ͜ʖ ͡°) splat_normalization = splat_normalization / splat_normalization.max() # Pad the shorter side to match the longer side if width_before != width_after: if width_after > width_before: fmaps_before_splatting = np.pad( fmaps_before_splatting, ((0, 0), (0, 0), (0, 0), (0, width_after - width_before)), mode='constant', constant_values=0 ) else: splatted_fmaps = np.pad( splatted_fmaps, ((0, 0), (0, 0), (0, 0), (0, width_before - width_after)), mode='constant', constant_values=0 ) splat_normalization = np.pad( splat_normalization, ((0, 0), (0, 0), (0, 0), (0, width_before - width_after)), mode='constant', constant_values=0 ) # Concatenate images along the height dimension splat_normalization = np.repeat(splat_normalization, 3, axis=1) imgs = [ np.concatenate([ fmaps_before_splatting[t], splatted_fmaps[t], splat_normalization[t] ], axis=1).transpose(1, 2, 0)[..., ::-1] for t in range(num_frames) ] video = cv2.VideoWriter( os.path.join(logs_path, f"{filename}.mp4"), cv2.VideoWriter_fourcc(*"mp4v"), 12, (imgs[0].shape[1], imgs[0].shape[0]), ) for img in imgs: video.write((img * 255).astype(np.uint8)) video.release() logging.info(f"Saved the featuremap video to {os.path.abspath(os.path.join(logs_path, f'{filename}.mp4'))}") ================================================ FILE: mvtracker/models/core/vggt/__init__.py ================================================ ================================================ FILE: mvtracker/models/core/vggt/heads/camera_head.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. 
import math import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from ..layers import Mlp from ..layers.block import Block from ..heads.head_act import activate_pose class CameraHead(nn.Module): """ CameraHead predicts camera parameters from token representations using iterative refinement. It applies a series of transformer blocks (the "trunk") to dedicated camera tokens. """ def __init__( self, dim_in: int = 2048, trunk_depth: int = 4, pose_encoding_type: str = "absT_quaR_FoV", num_heads: int = 16, mlp_ratio: int = 4, init_values: float = 0.01, trans_act: str = "linear", quat_act: str = "linear", fl_act: str = "relu", # Field of view activations: ensures FOV values are positive. ): super().__init__() if pose_encoding_type == "absT_quaR_FoV": self.target_dim = 9 else: raise ValueError(f"Unsupported camera encoding type: {pose_encoding_type}") self.trans_act = trans_act self.quat_act = quat_act self.fl_act = fl_act self.trunk_depth = trunk_depth # Build the trunk using a sequence of transformer blocks. self.trunk = nn.Sequential( *[ Block( dim=dim_in, num_heads=num_heads, mlp_ratio=mlp_ratio, init_values=init_values, ) for _ in range(trunk_depth) ] ) # Normalizations for camera token and trunk output. self.token_norm = nn.LayerNorm(dim_in) self.trunk_norm = nn.LayerNorm(dim_in) # Learnable empty camera pose token. self.empty_pose_tokens = nn.Parameter(torch.zeros(1, 1, self.target_dim)) self.embed_pose = nn.Linear(self.target_dim, dim_in) # Module for producing modulation parameters: shift, scale, and a gate. self.poseLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim_in, 3 * dim_in, bias=True)) # Adaptive layer normalization without affine parameters. self.adaln_norm = nn.LayerNorm(dim_in, elementwise_affine=False, eps=1e-6) self.pose_branch = Mlp( in_features=dim_in, hidden_features=dim_in // 2, out_features=self.target_dim, drop=0, ) def forward(self, aggregated_tokens_list: list, num_iterations: int = 4) -> list: """ Forward pass to predict camera parameters. Args: aggregated_tokens_list (list): List of token tensors from the network; the last tensor is used for prediction. num_iterations (int, optional): Number of iterative refinement steps. Defaults to 4. Returns: list: A list of predicted camera encodings (post-activation) from each iteration. """ # Use tokens from the last block for camera prediction. tokens = aggregated_tokens_list[-1] # Extract the camera tokens pose_tokens = tokens[:, :, 0] pose_tokens = self.token_norm(pose_tokens) pred_pose_enc_list = self.trunk_fn(pose_tokens, num_iterations) return pred_pose_enc_list def trunk_fn(self, pose_tokens: torch.Tensor, num_iterations: int) -> list: """ Iteratively refine camera pose predictions. Args: pose_tokens (torch.Tensor): Normalized camera tokens with shape [B, 1, C]. num_iterations (int): Number of refinement iterations. Returns: list: List of activated camera encodings from each iteration. """ B, S, C = pose_tokens.shape # S is expected to be 1. pred_pose_enc = None pred_pose_enc_list = [] for _ in range(num_iterations): # Use a learned empty pose for the first iteration. if pred_pose_enc is None: module_input = self.embed_pose(self.empty_pose_tokens.expand(B, S, -1)) else: # Detach the previous prediction to avoid backprop through time. pred_pose_enc = pred_pose_enc.detach() module_input = self.embed_pose(pred_pose_enc) # Generate modulation parameters and split them into shift, scale, and gate components. 
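
The modulation computed at this step follows the DiT-style adaptive layer-norm pattern that the modulate() helper at the bottom of this file implements (x * (1 + scale) + shift), gated and added back residually. A self-contained sketch of the same pattern, with assumed dimensions and variable names, looks like this:

import torch
import torch.nn as nn

dim = 2048
pose_tokens = torch.randn(2, 1, dim)   # [B, S, C] camera tokens (S = 1)
cond = torch.randn(2, 1, dim)          # conditioning embedding, e.g. the embedded previous pose

to_modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim, 3 * dim))
adaln_norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)

shift, scale, gate = to_modulation(cond).chunk(3, dim=-1)
modulated = gate * (adaln_norm(pose_tokens) * (1 + scale) + shift)
out = modulated + pose_tokens          # residual connection before the transformer trunk
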
shift_msa, scale_msa, gate_msa = self.poseLN_modulation(module_input).chunk(3, dim=-1) # Adaptive layer normalization and modulation. pose_tokens_modulated = gate_msa * modulate(self.adaln_norm(pose_tokens), shift_msa, scale_msa) pose_tokens_modulated = pose_tokens_modulated + pose_tokens pose_tokens_modulated = self.trunk(pose_tokens_modulated) # Compute the delta update for the pose encoding. pred_pose_enc_delta = self.pose_branch(self.trunk_norm(pose_tokens_modulated)) if pred_pose_enc is None: pred_pose_enc = pred_pose_enc_delta else: pred_pose_enc = pred_pose_enc + pred_pose_enc_delta # Apply final activation functions for translation, quaternion, and field-of-view. activated_pose = activate_pose( pred_pose_enc, trans_act=self.trans_act, quat_act=self.quat_act, fl_act=self.fl_act, ) pred_pose_enc_list.append(activated_pose) return pred_pose_enc_list def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: """ Modulate the input tensor using scaling and shifting parameters. """ # modified from https://github.com/facebookresearch/DiT/blob/796c29e532f47bba17c5b9c5eb39b9354b8b7c64/models.py#L19 return x * (1 + scale) + shift ================================================ FILE: mvtracker/models/core/vggt/heads/dpt_head.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. # Inspired by https://github.com/DepthAnything/Depth-Anything-V2 import os from typing import List, Dict, Tuple, Union import torch import torch.nn as nn import torch.nn.functional as F from .head_act import activate_head from .utils import create_uv_grid, position_grid_to_embed class DPTHead(nn.Module): """ DPT Head for dense prediction tasks. This implementation follows the architecture described in "Vision Transformers for Dense Prediction" (https://arxiv.org/abs/2103.13413). The DPT head processes features from a vision transformer backbone and produces dense predictions by fusing multi-scale features. Args: dim_in (int): Input dimension (channels). patch_size (int, optional): Patch size. Default is 14. output_dim (int, optional): Number of output channels. Default is 4. activation (str, optional): Activation type. Default is "inv_log". conf_activation (str, optional): Confidence activation type. Default is "expp1". features (int, optional): Feature channels for intermediate representations. Default is 256. out_channels (List[int], optional): Output channels for each intermediate layer. intermediate_layer_idx (List[int], optional): Indices of layers from aggregated tokens used for DPT. pos_embed (bool, optional): Whether to use positional embedding. Default is True. feature_only (bool, optional): If True, return features only without the last several layers and activation head. Default is False. down_ratio (int, optional): Downscaling factor for the output resolution. Default is 1. 
""" def __init__( self, dim_in: int, patch_size: int = 14, output_dim: int = 4, activation: str = "inv_log", conf_activation: str = "expp1", features: int = 256, out_channels: List[int] = [256, 512, 1024, 1024], intermediate_layer_idx: List[int] = [4, 11, 17, 23], pos_embed: bool = True, feature_only: bool = False, down_ratio: int = 1, ) -> None: super(DPTHead, self).__init__() self.patch_size = patch_size self.activation = activation self.conf_activation = conf_activation self.pos_embed = pos_embed self.feature_only = feature_only self.down_ratio = down_ratio self.intermediate_layer_idx = intermediate_layer_idx self.norm = nn.LayerNorm(dim_in) # Projection layers for each output channel from tokens. self.projects = nn.ModuleList( [ nn.Conv2d( in_channels=dim_in, out_channels=oc, kernel_size=1, stride=1, padding=0, ) for oc in out_channels ] ) # Resize layers for upsampling feature maps. self.resize_layers = nn.ModuleList( [ nn.ConvTranspose2d( in_channels=out_channels[0], out_channels=out_channels[0], kernel_size=4, stride=4, padding=0 ), nn.ConvTranspose2d( in_channels=out_channels[1], out_channels=out_channels[1], kernel_size=2, stride=2, padding=0 ), nn.Identity(), nn.Conv2d( in_channels=out_channels[3], out_channels=out_channels[3], kernel_size=3, stride=2, padding=1 ), ] ) self.scratch = _make_scratch( out_channels, features, expand=False, ) # Attach additional modules to scratch. self.scratch.stem_transpose = None self.scratch.refinenet1 = _make_fusion_block(features) self.scratch.refinenet2 = _make_fusion_block(features) self.scratch.refinenet3 = _make_fusion_block(features) self.scratch.refinenet4 = _make_fusion_block(features, has_residual=False) head_features_1 = features head_features_2 = 32 if feature_only: self.scratch.output_conv1 = nn.Conv2d(head_features_1, head_features_1, kernel_size=3, stride=1, padding=1) else: self.scratch.output_conv1 = nn.Conv2d( head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1 ) conv2_in_channels = head_features_1 // 2 self.scratch.output_conv2 = nn.Sequential( nn.Conv2d(conv2_in_channels, head_features_2, kernel_size=3, stride=1, padding=1), nn.ReLU(inplace=True), nn.Conv2d(head_features_2, output_dim, kernel_size=1, stride=1, padding=0), ) def forward( self, aggregated_tokens_list: List[torch.Tensor], images: torch.Tensor, patch_start_idx: int, frames_chunk_size: int = 8, inference_feature_only: bool = False, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: """ Forward pass through the DPT head, supports processing by chunking frames. Args: aggregated_tokens_list (List[Tensor]): List of token tensors from different transformer layers. images (Tensor): Input images with shape [B, S, 3, H, W], in range [0, 1]. patch_start_idx (int): Starting index for patch tokens in the token sequence. Used to separate patch tokens from other tokens (e.g., camera or register tokens). frames_chunk_size (int, optional): Number of frames to process in each chunk. If None or larger than S, all frames are processed at once. Default: 8. 
Returns: Tensor or Tuple[Tensor, Tensor]: - If feature_only=True: Feature maps with shape [B, S, C, H, W] - Otherwise: Tuple of (predictions, confidence) both with shape [B, S, 1, H, W] """ B, S, _, H, W = images.shape # If frames_chunk_size is not specified or greater than S, process all frames at once if frames_chunk_size is None or frames_chunk_size >= S: return self._forward_impl(aggregated_tokens_list, images, patch_start_idx, inference_feature_only = inference_feature_only) # Otherwise, process frames in chunks to manage memory usage assert frames_chunk_size > 0 # Process frames in batches all_preds = [] all_conf = [] for frames_start_idx in range(0, S, frames_chunk_size): frames_end_idx = min(frames_start_idx + frames_chunk_size, S) # Process batch of frames # if self.feature_only or inference_feature_only: # chunk_output = self._forward_impl( # aggregated_tokens_list, images, patch_start_idx, frames_start_idx, frames_end_idx, inference_feature_only = inference_feature_only # ) # all_preds.append(chunk_output) # else: # chunk_preds, chunk_conf = self._forward_impl( # aggregated_tokens_list, images, patch_start_idx, frames_start_idx, frames_end_idx, inference_feature_only = inference_feature_only # ) # all_preds.append(chunk_preds) # all_conf.append(chunk_conf) chunk_preds, chunk_conf = self._forward_impl( aggregated_tokens_list, images, patch_start_idx, frames_start_idx, frames_end_idx, inference_feature_only = inference_feature_only ) all_preds.append(chunk_preds) all_conf.append(chunk_conf) # Concatenate results along the sequence dimension # if self.feature_only or inference_feature_only: # return torch.cat(all_preds, dim=1) # else: # return torch.cat(all_preds, dim=1), torch.cat(all_conf, dim=1) return torch.cat(all_preds, dim=1), torch.cat(all_conf, dim=1) def _forward_impl( self, aggregated_tokens_list: List[torch.Tensor], images: torch.Tensor, patch_start_idx: int, frames_start_idx: int = None, frames_end_idx: int = None, inference_feature_only: bool = False, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: """ Implementation of the forward pass through the DPT head. This method processes a specific chunk of frames from the sequence. Args: aggregated_tokens_list (List[Tensor]): List of token tensors from different transformer layers. images (Tensor): Input images with shape [B, S, 3, H, W]. patch_start_idx (int): Starting index for patch tokens. frames_start_idx (int, optional): Starting index for frames to process. frames_end_idx (int, optional): Ending index for frames to process. Returns: Tensor or Tuple[Tensor, Tensor]: Feature maps or (predictions, confidence). """ if frames_start_idx is not None and frames_end_idx is not None: images = images[:, frames_start_idx:frames_end_idx].contiguous() B, S, _, H, W = images.shape patch_h, patch_w = H // self.patch_size, W // self.patch_size out = [] dpt_idx = 0 for layer_idx in self.intermediate_layer_idx: x = aggregated_tokens_list[layer_idx][:, :, patch_start_idx:] # Select frames if processing a chunk if frames_start_idx is not None and frames_end_idx is not None: x = x[:, frames_start_idx:frames_end_idx] x = x.view(B * S, -1, x.shape[-1]) x = self.norm(x) x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w)) x = self.projects[dpt_idx](x) if self.pos_embed: x = self._apply_pos_embed(x, W, H) x = self.resize_layers[dpt_idx](x) out.append(x) dpt_idx += 1 # Fuse features from multiple layers. out = self.scratch_forward(out) # Interpolate fused output to match target image resolution. 
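
The target size passed to the interpolation here undoes the patchification: each token covers patch_size pixels, so a patch_h x patch_w token grid maps back to roughly the input resolution divided by down_ratio. A worked example with assumed values (patch_size = 14, down_ratio = 2, a 378 x 518 input):

patch_size, down_ratio = 14, 2
H, W = 378, 518                                      # assumed input resolution, divisible by patch_size
patch_h, patch_w = H // patch_size, W // patch_size  # 27 x 37 patch tokens
out_h = int(patch_h * patch_size / down_ratio)       # 189
out_w = int(patch_w * patch_size / down_ratio)       # 259
print((patch_h, patch_w), (out_h, out_w))            # (27, 37) (189, 259)
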
out = custom_interpolate( out, (int(patch_h * self.patch_size / self.down_ratio), int(patch_w * self.patch_size / self.down_ratio)), mode="bilinear", align_corners=True, ) if self.pos_embed: out = self._apply_pos_embed(out, W, H) if self.feature_only or inference_feature_only: feature_output = out.view(B, S, *out.shape[1:]) # return out.view(B, S, *out.shape[1:]) out = self.scratch.output_conv2(out) preds, conf = activate_head(out, activation=self.activation, conf_activation=self.conf_activation) preds = preds.view(B, S, *preds.shape[1:]) conf = conf.view(B, S, *conf.shape[1:]) if self.feature_only or inference_feature_only: return feature_output, conf else: return preds, conf def _apply_pos_embed(self, x: torch.Tensor, W: int, H: int, ratio: float = 0.1) -> torch.Tensor: """ Apply positional embedding to tensor x. """ patch_w = x.shape[-1] patch_h = x.shape[-2] pos_embed = create_uv_grid(patch_w, patch_h, aspect_ratio=W / H, dtype=x.dtype, device=x.device) pos_embed = position_grid_to_embed(pos_embed, x.shape[1]) pos_embed = pos_embed * ratio pos_embed = pos_embed.permute(2, 0, 1)[None].expand(x.shape[0], -1, -1, -1) return x + pos_embed def scratch_forward(self, features: List[torch.Tensor]) -> torch.Tensor: """ Forward pass through the fusion blocks. Args: features (List[Tensor]): List of feature maps from different layers. Returns: Tensor: Fused feature map. """ layer_1, layer_2, layer_3, layer_4 = features layer_1_rn = self.scratch.layer1_rn(layer_1) layer_2_rn = self.scratch.layer2_rn(layer_2) layer_3_rn = self.scratch.layer3_rn(layer_3) layer_4_rn = self.scratch.layer4_rn(layer_4) out = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:]) del layer_4_rn, layer_4 out = self.scratch.refinenet3(out, layer_3_rn, size=layer_2_rn.shape[2:]) del layer_3_rn, layer_3 out = self.scratch.refinenet2(out, layer_2_rn, size=layer_1_rn.shape[2:]) del layer_2_rn, layer_2 out = self.scratch.refinenet1(out, layer_1_rn) del layer_1_rn, layer_1 out = self.scratch.output_conv1(out) return out ################################################################################ # Modules ################################################################################ def _make_fusion_block(features: int, size: int = None, has_residual: bool = True, groups: int = 1) -> nn.Module: return FeatureFusionBlock( features, nn.ReLU(inplace=True), deconv=False, bn=False, expand=False, align_corners=True, size=size, has_residual=has_residual, groups=groups, ) def _make_scratch(in_shape: List[int], out_shape: int, groups: int = 1, expand: bool = False) -> nn.Module: scratch = nn.Module() out_shape1 = out_shape out_shape2 = out_shape out_shape3 = out_shape if len(in_shape) >= 4: out_shape4 = out_shape if expand: out_shape1 = out_shape out_shape2 = out_shape * 2 out_shape3 = out_shape * 4 if len(in_shape) >= 4: out_shape4 = out_shape * 8 scratch.layer1_rn = nn.Conv2d( in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups ) scratch.layer2_rn = nn.Conv2d( in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups ) scratch.layer3_rn = nn.Conv2d( in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups ) if len(in_shape) >= 4: scratch.layer4_rn = nn.Conv2d( in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups ) return scratch class ResidualConvUnit(nn.Module): """Residual convolution module.""" def __init__(self, features, activation, bn, groups=1): """Init. 
Args: features (int): number of features """ super().__init__() self.bn = bn self.groups = groups self.conv1 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups) self.conv2 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups) self.norm1 = None self.norm2 = None self.activation = activation self.skip_add = nn.quantized.FloatFunctional() def forward(self, x): """Forward pass. Args: x (tensor): input Returns: tensor: output """ out = self.activation(x) out = self.conv1(out) if self.norm1 is not None: out = self.norm1(out) out = self.activation(out) out = self.conv2(out) if self.norm2 is not None: out = self.norm2(out) return self.skip_add.add(out, x) class FeatureFusionBlock(nn.Module): """Feature fusion block.""" def __init__( self, features, activation, deconv=False, bn=False, expand=False, align_corners=True, size=None, has_residual=True, groups=1, ): """Init. Args: features (int): number of features """ super(FeatureFusionBlock, self).__init__() self.deconv = deconv self.align_corners = align_corners self.groups = groups self.expand = expand out_features = features if self.expand == True: out_features = features // 2 self.out_conv = nn.Conv2d( features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=self.groups ) if has_residual: self.resConfUnit1 = ResidualConvUnit(features, activation, bn, groups=self.groups) self.has_residual = has_residual self.resConfUnit2 = ResidualConvUnit(features, activation, bn, groups=self.groups) self.skip_add = nn.quantized.FloatFunctional() self.size = size def forward(self, *xs, size=None): """Forward pass. Returns: tensor: output """ output = xs[0] if self.has_residual: res = self.resConfUnit1(xs[1]) output = self.skip_add.add(output, res) output = self.resConfUnit2(output) if (size is None) and (self.size is None): modifier = {"scale_factor": 2} elif size is None: modifier = {"size": self.size} else: modifier = {"size": size} output = custom_interpolate(output, **modifier, mode="bilinear", align_corners=self.align_corners) output = self.out_conv(output) return output def custom_interpolate( x: torch.Tensor, size: Tuple[int, int] = None, scale_factor: float = None, mode: str = "bilinear", align_corners: bool = True, ) -> torch.Tensor: """ Custom interpolate to avoid INT_MAX issues in nn.functional.interpolate. """ if size is None: size = (int(x.shape[-2] * scale_factor), int(x.shape[-1] * scale_factor)) INT_MAX = 1610612736 input_elements = size[0] * size[1] * x.shape[0] * x.shape[1] if input_elements > INT_MAX: chunks = torch.chunk(x, chunks=(input_elements // INT_MAX) + 1, dim=0) interpolated_chunks = [ nn.functional.interpolate(chunk, size=size, mode=mode, align_corners=align_corners) for chunk in chunks ] x = torch.cat(interpolated_chunks, dim=0) return x.contiguous() else: return nn.functional.interpolate(x, size=size, mode=mode, align_corners=align_corners) ================================================ FILE: mvtracker/models/core/vggt/heads/head_act.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. import torch import torch.nn.functional as F def activate_pose(pred_pose_enc, trans_act="linear", quat_act="linear", fl_act="linear"): """ Activate pose parameters with specified activation functions. 
Args: pred_pose_enc: Tensor containing encoded pose parameters [translation, quaternion, focal length] trans_act: Activation type for translation component quat_act: Activation type for quaternion component fl_act: Activation type for focal length component Returns: Activated pose parameters tensor """ T = pred_pose_enc[..., :3] quat = pred_pose_enc[..., 3:7] fl = pred_pose_enc[..., 7:] # or fov T = base_pose_act(T, trans_act) quat = base_pose_act(quat, quat_act) fl = base_pose_act(fl, fl_act) # or fov pred_pose_enc = torch.cat([T, quat, fl], dim=-1) return pred_pose_enc def base_pose_act(pose_enc, act_type="linear"): """ Apply basic activation function to pose parameters. Args: pose_enc: Tensor containing encoded pose parameters act_type: Activation type ("linear", "inv_log", "exp", "relu") Returns: Activated pose parameters """ if act_type == "linear": return pose_enc elif act_type == "inv_log": return inverse_log_transform(pose_enc) elif act_type == "exp": return torch.exp(pose_enc) elif act_type == "relu": return F.relu(pose_enc) else: raise ValueError(f"Unknown act_type: {act_type}") def activate_head(out, activation="norm_exp", conf_activation="expp1"): """ Process network output to extract 3D points and confidence values. Args: out: Network output tensor (B, C, H, W) activation: Activation type for 3D points conf_activation: Activation type for confidence values Returns: Tuple of (3D points tensor, confidence tensor) """ # Move channels from last dim to the 4th dimension => (B, H, W, C) fmap = out.permute(0, 2, 3, 1) # B,H,W,C expected # Split into xyz (first C-1 channels) and confidence (last channel) xyz = fmap[:, :, :, :-1] conf = fmap[:, :, :, -1] if activation == "norm_exp": d = xyz.norm(dim=-1, keepdim=True).clamp(min=1e-8) xyz_normed = xyz / d pts3d = xyz_normed * torch.expm1(d) elif activation == "norm": pts3d = xyz / xyz.norm(dim=-1, keepdim=True) elif activation == "exp": pts3d = torch.exp(xyz) elif activation == "relu": pts3d = F.relu(xyz) elif activation == "inv_log": pts3d = inverse_log_transform(xyz) elif activation == "xy_inv_log": xy, z = xyz.split([2, 1], dim=-1) z = inverse_log_transform(z) pts3d = torch.cat([xy * z, z], dim=-1) elif activation == "sigmoid": pts3d = torch.sigmoid(xyz) elif activation == "linear": pts3d = xyz else: raise ValueError(f"Unknown activation: {activation}") if conf_activation == "expp1": conf_out = 1 + conf.exp() elif conf_activation == "expp0": conf_out = conf.exp() elif conf_activation == "sigmoid": conf_out = torch.sigmoid(conf) else: raise ValueError(f"Unknown conf_activation: {conf_activation}") return pts3d, conf_out def inverse_log_transform(y): """ Apply inverse log transform: sign(y) * (exp(|y|) - 1) Args: y: Input tensor Returns: Transformed tensor """ return torch.sign(y) * (torch.expm1(torch.abs(y))) ================================================ FILE: mvtracker/models/core/vggt/heads/track_head.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. import torch.nn as nn from .dpt_head import DPTHead from .track_modules.base_track_predictor import BaseTrackerPredictor class TrackHead(nn.Module): """ Track head that uses DPT head to process tokens and BaseTrackerPredictor for tracking. The tracking is performed iteratively, refining predictions over multiple iterations. 
""" def __init__( self, dim_in, patch_size=14, features=128, iters=4, predict_conf=True, stride=2, corr_levels=7, corr_radius=4, hidden_size=384, ): """ Initialize the TrackHead module. Args: dim_in (int): Input dimension of tokens from the backbone. patch_size (int): Size of image patches used in the vision transformer. features (int): Number of feature channels in the feature extractor output. iters (int): Number of refinement iterations for tracking predictions. predict_conf (bool): Whether to predict confidence scores for tracked points. stride (int): Stride value for the tracker predictor. corr_levels (int): Number of correlation pyramid levels corr_radius (int): Radius for correlation computation, controlling the search area. hidden_size (int): Size of hidden layers in the tracker network. """ super().__init__() self.patch_size = patch_size # Feature extractor based on DPT architecture # Processes tokens into feature maps for tracking self.feature_extractor = DPTHead( dim_in=dim_in, patch_size=patch_size, features=features, feature_only=True, # Only output features, no activation down_ratio=2, # Reduces spatial dimensions by factor of 2 pos_embed=False, ) # Tracker module that predicts point trajectories # Takes feature maps and predicts coordinates and visibility self.tracker = BaseTrackerPredictor( latent_dim=features, # Match the output_dim of feature extractor predict_conf=predict_conf, stride=stride, corr_levels=corr_levels, corr_radius=corr_radius, hidden_size=hidden_size, ) self.iters = iters def forward(self, aggregated_tokens_list, images, patch_start_idx, query_points=None, iters=None): """ Forward pass of the TrackHead. Args: aggregated_tokens_list (list): List of aggregated tokens from the backbone. images (torch.Tensor): Input images of shape (B, S, C, H, W) where: B = batch size, S = sequence length. patch_start_idx (int): Starting index for patch tokens. query_points (torch.Tensor, optional): Initial query points to track. If None, points are initialized by the tracker. iters (int, optional): Number of refinement iterations. If None, uses self.iters. Returns: tuple: - coord_preds (torch.Tensor): Predicted coordinates for tracked points. - vis_scores (torch.Tensor): Visibility scores for tracked points. - conf_scores (torch.Tensor): Confidence scores for tracked points (if predict_conf=True). """ B, S, _, H, W = images.shape # Extract features from tokens # feature_maps has shape (B, S, C, H//2, W//2) due to down_ratio=2 feature_maps = self.feature_extractor(aggregated_tokens_list, images, patch_start_idx) # Use default iterations if not specified if iters is None: iters = self.iters # Perform tracking using the extracted features coord_preds, vis_scores, conf_scores = self.tracker( query_points=query_points, fmaps=feature_maps, iters=iters, ) return coord_preds, vis_scores, conf_scores ================================================ FILE: mvtracker/models/core/vggt/heads/track_modules/__init__.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. ================================================ FILE: mvtracker/models/core/vggt/heads/track_modules/base_track_predictor.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. 
# # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. import torch import torch.nn as nn from einops import rearrange, repeat from .blocks import EfficientUpdateFormer, CorrBlock from .utils import sample_features4d, get_2d_embedding, get_2d_sincos_pos_embed from .modules import Mlp class BaseTrackerPredictor(nn.Module): def __init__( self, stride=1, corr_levels=5, corr_radius=4, latent_dim=128, hidden_size=384, use_spaceatt=True, depth=6, max_scale=518, predict_conf=True, ): super(BaseTrackerPredictor, self).__init__() """ The base template to create a track predictor Modified from https://github.com/facebookresearch/co-tracker/ and https://github.com/facebookresearch/vggsfm """ self.stride = stride self.latent_dim = latent_dim self.corr_levels = corr_levels self.corr_radius = corr_radius self.hidden_size = hidden_size self.max_scale = max_scale self.predict_conf = predict_conf self.flows_emb_dim = latent_dim // 2 self.corr_mlp = Mlp( in_features=self.corr_levels * (self.corr_radius * 2 + 1) ** 2, hidden_features=self.hidden_size, out_features=self.latent_dim, ) self.transformer_dim = self.latent_dim + self.latent_dim + self.latent_dim + 4 self.query_ref_token = nn.Parameter(torch.randn(1, 2, self.transformer_dim)) space_depth = depth if use_spaceatt else 0 time_depth = depth self.updateformer = EfficientUpdateFormer( space_depth=space_depth, time_depth=time_depth, input_dim=self.transformer_dim, hidden_size=self.hidden_size, output_dim=self.latent_dim + 2, mlp_ratio=4.0, add_space_attn=use_spaceatt, ) self.fmap_norm = nn.LayerNorm(self.latent_dim) self.ffeat_norm = nn.GroupNorm(1, self.latent_dim) # A linear layer to update track feats at each iteration self.ffeat_updater = nn.Sequential(nn.Linear(self.latent_dim, self.latent_dim), nn.GELU()) self.vis_predictor = nn.Sequential(nn.Linear(self.latent_dim, 1)) if predict_conf: self.conf_predictor = nn.Sequential(nn.Linear(self.latent_dim, 1)) def forward(self, query_points, fmaps=None, iters=6, return_feat=False, down_ratio=1, apply_sigmoid=True): """ query_points: B x N x 2, the number of batches, tracks, and xy fmaps: B x S x C x HH x WW, the number of batches, frames, and feature dimension. 
note HH and WW is the size of feature maps instead of original images """ B, N, D = query_points.shape B, S, C, HH, WW = fmaps.shape assert D == 2, "Input points must be 2D coordinates" # apply a layernorm to fmaps here fmaps = self.fmap_norm(fmaps.permute(0, 1, 3, 4, 2)) fmaps = fmaps.permute(0, 1, 4, 2, 3) # Scale the input query_points because we may downsample the images # by down_ratio or self.stride # e.g., if a 3x1024x1024 image is processed to a 128x256x256 feature map # its query_points should be query_points/4 if down_ratio > 1: query_points = query_points / float(down_ratio) query_points = query_points / float(self.stride) # Init with coords as the query points # It means the search will start from the position of query points at the reference frames coords = query_points.clone().reshape(B, 1, N, 2).repeat(1, S, 1, 1) # Sample/extract the features of the query points in the query frame query_track_feat = sample_features4d(fmaps[:, 0], coords[:, 0]) # init track feats by query feats track_feats = query_track_feat.unsqueeze(1).repeat(1, S, 1, 1) # B, S, N, C # back up the init coords coords_backup = coords.clone() fcorr_fn = CorrBlock(fmaps, num_levels=self.corr_levels, radius=self.corr_radius) coord_preds = [] # Iterative Refinement for _ in range(iters): # Detach the gradients from the last iteration # (in my experience, not very important for performance) coords = coords.detach() fcorrs = fcorr_fn.corr_sample(track_feats, coords) corr_dim = fcorrs.shape[3] fcorrs_ = fcorrs.permute(0, 2, 1, 3).reshape(B * N, S, corr_dim) fcorrs_ = self.corr_mlp(fcorrs_) # Movement of current coords relative to query points flows = (coords - coords[:, 0:1]).permute(0, 2, 1, 3).reshape(B * N, S, 2) flows_emb = get_2d_embedding(flows, self.flows_emb_dim, cat_coords=False) # (In my trials, it is also okay to just add the flows_emb instead of concat) flows_emb = torch.cat([flows_emb, flows / self.max_scale, flows / self.max_scale], dim=-1) track_feats_ = track_feats.permute(0, 2, 1, 3).reshape(B * N, S, self.latent_dim) # Concatenate them as the input for the transformers transformer_input = torch.cat([flows_emb, fcorrs_, track_feats_], dim=2) # 2D positional embed # TODO: this can be much simplified pos_embed = get_2d_sincos_pos_embed(self.transformer_dim, grid_size=(HH, WW)).to(query_points.device) sampled_pos_emb = sample_features4d(pos_embed.expand(B, -1, -1, -1), coords[:, 0]) sampled_pos_emb = rearrange(sampled_pos_emb, "b n c -> (b n) c").unsqueeze(1) x = transformer_input + sampled_pos_emb # Add the query ref token to the track feats query_ref_token = torch.cat( [self.query_ref_token[:, 0:1], self.query_ref_token[:, 1:2].expand(-1, S - 1, -1)], dim=1 ) x = x + query_ref_token.to(x.device).to(x.dtype) # B, N, S, C x = rearrange(x, "(b n) s d -> b n s d", b=B) # Compute the delta coordinates and delta track features delta, _ = self.updateformer(x) # BN, S, C delta = rearrange(delta, " b n s d -> (b n) s d", b=B) delta_coords_ = delta[:, :, :2] delta_feats_ = delta[:, :, 2:] track_feats_ = track_feats_.reshape(B * N * S, self.latent_dim) delta_feats_ = delta_feats_.reshape(B * N * S, self.latent_dim) # Update the track features track_feats_ = self.ffeat_updater(self.ffeat_norm(delta_feats_)) + track_feats_ track_feats = track_feats_.reshape(B, N, S, self.latent_dim).permute(0, 2, 1, 3) # BxSxNxC # B x S x N x 2 coords = coords + delta_coords_.reshape(B, N, S, 2).permute(0, 2, 1, 3) # Force coord0 as query # because we assume the query points should not be changed coords[:, 0] = coords_backup[:, 
0] # The predicted tracks are in the original image scale if down_ratio > 1: coord_preds.append(coords * self.stride * down_ratio) else: coord_preds.append(coords * self.stride) # B, S, N vis_e = self.vis_predictor(track_feats.reshape(B * S * N, self.latent_dim)).reshape(B, S, N) if apply_sigmoid: vis_e = torch.sigmoid(vis_e) if self.predict_conf: conf_e = self.conf_predictor(track_feats.reshape(B * S * N, self.latent_dim)).reshape(B, S, N) if apply_sigmoid: conf_e = torch.sigmoid(conf_e) else: conf_e = None if return_feat: return coord_preds, vis_e, track_feats, query_track_feat, conf_e else: return coord_preds, vis_e, conf_e ================================================ FILE: mvtracker/models/core/vggt/heads/track_modules/blocks.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. # Modified from https://github.com/facebookresearch/co-tracker/ import math import torch import torch.nn as nn import torch.nn.functional as F from .utils import bilinear_sampler from .modules import Mlp, AttnBlock, CrossAttnBlock, ResidualBlock class EfficientUpdateFormer(nn.Module): """ Transformer model that updates track estimates. """ def __init__( self, space_depth=6, time_depth=6, input_dim=320, hidden_size=384, num_heads=8, output_dim=130, mlp_ratio=4.0, add_space_attn=True, num_virtual_tracks=64, ): super().__init__() self.out_channels = 2 self.num_heads = num_heads self.hidden_size = hidden_size self.add_space_attn = add_space_attn # Add input LayerNorm before linear projection self.input_norm = nn.LayerNorm(input_dim) self.input_transform = torch.nn.Linear(input_dim, hidden_size, bias=True) # Add output LayerNorm before final projection self.output_norm = nn.LayerNorm(hidden_size) self.flow_head = torch.nn.Linear(hidden_size, output_dim, bias=True) self.num_virtual_tracks = num_virtual_tracks if self.add_space_attn: self.virual_tracks = nn.Parameter(torch.randn(1, num_virtual_tracks, 1, hidden_size)) else: self.virual_tracks = None self.time_blocks = nn.ModuleList( [ AttnBlock( hidden_size, num_heads, mlp_ratio=mlp_ratio, attn_class=nn.MultiheadAttention, ) for _ in range(time_depth) ] ) if add_space_attn: self.space_virtual_blocks = nn.ModuleList( [ AttnBlock( hidden_size, num_heads, mlp_ratio=mlp_ratio, attn_class=nn.MultiheadAttention, ) for _ in range(space_depth) ] ) self.space_point2virtual_blocks = nn.ModuleList( [CrossAttnBlock(hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(space_depth)] ) self.space_virtual2point_blocks = nn.ModuleList( [CrossAttnBlock(hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(space_depth)] ) assert len(self.time_blocks) >= len(self.space_virtual2point_blocks) self.initialize_weights() def initialize_weights(self): def _basic_init(module): if isinstance(module, nn.Linear): torch.nn.init.xavier_uniform_(module.weight) if module.bias is not None: nn.init.constant_(module.bias, 0) torch.nn.init.trunc_normal_(self.flow_head.weight, std=0.001) self.apply(_basic_init) def forward(self, input_tensor, mask=None): # Apply input LayerNorm input_tensor = self.input_norm(input_tensor) tokens = self.input_transform(input_tensor) init_tokens = tokens B, _, T, _ = tokens.shape if self.add_space_attn: virtual_tokens = self.virual_tracks.repeat(B, 1, T, 1) tokens = torch.cat([tokens, virtual_tokens], dim=1) _, N, _, _ = tokens.shape j = 0 
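        # Descriptive note on the loop below: each iteration runs a temporal attention
        # block over the T frames of every track ((B N) x T x C). Every
        # len(time_blocks) // len(space_virtual_blocks) iterations, a spatial step is
        # also run per frame ((B T) x N x C): the virtual tracks first cross-attend to
        # the point tokens, then self-attend, and the point tokens finally cross-attend
        # back to the virtual tracks, so spatial mixing scales with num_virtual_tracks
        # rather than quadratically with the number of points.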
for i in range(len(self.time_blocks)): time_tokens = tokens.contiguous().view(B * N, T, -1) # B N T C -> (B N) T C time_tokens = self.time_blocks[i](time_tokens) tokens = time_tokens.view(B, N, T, -1) # (B N) T C -> B N T C if self.add_space_attn and (i % (len(self.time_blocks) // len(self.space_virtual_blocks)) == 0): space_tokens = tokens.permute(0, 2, 1, 3).contiguous().view(B * T, N, -1) # B N T C -> (B T) N C point_tokens = space_tokens[:, : N - self.num_virtual_tracks] virtual_tokens = space_tokens[:, N - self.num_virtual_tracks :] virtual_tokens = self.space_virtual2point_blocks[j](virtual_tokens, point_tokens, mask=mask) virtual_tokens = self.space_virtual_blocks[j](virtual_tokens) point_tokens = self.space_point2virtual_blocks[j](point_tokens, virtual_tokens, mask=mask) space_tokens = torch.cat([point_tokens, virtual_tokens], dim=1) tokens = space_tokens.view(B, T, N, -1).permute(0, 2, 1, 3) # (B T) N C -> B N T C j += 1 if self.add_space_attn: tokens = tokens[:, : N - self.num_virtual_tracks] tokens = tokens + init_tokens # Apply output LayerNorm before final projection tokens = self.output_norm(tokens) flow = self.flow_head(tokens) return flow, None class CorrBlock: def __init__(self, fmaps, num_levels=4, radius=4, multiple_track_feats=False, padding_mode="zeros"): """ Build a pyramid of feature maps from the input. fmaps: Tensor (B, S, C, H, W) num_levels: number of pyramid levels (each downsampled by factor 2) radius: search radius for sampling correlation multiple_track_feats: if True, split the target features per pyramid level padding_mode: passed to grid_sample / bilinear_sampler """ B, S, C, H, W = fmaps.shape self.S, self.C, self.H, self.W = S, C, H, W self.num_levels = num_levels self.radius = radius self.padding_mode = padding_mode self.multiple_track_feats = multiple_track_feats # Build pyramid: each level is half the spatial resolution of the previous self.fmaps_pyramid = [fmaps] # level 0 is full resolution current_fmaps = fmaps for i in range(num_levels - 1): B, S, C, H, W = current_fmaps.shape # Merge batch & sequence dimensions current_fmaps = current_fmaps.reshape(B * S, C, H, W) # Avg pool down by factor 2 current_fmaps = F.avg_pool2d(current_fmaps, kernel_size=2, stride=2) _, _, H_new, W_new = current_fmaps.shape current_fmaps = current_fmaps.reshape(B, S, C, H_new, W_new) self.fmaps_pyramid.append(current_fmaps) # Precompute a delta grid (of shape (2r+1, 2r+1, 2)) for sampling. # This grid is added to the (scaled) coordinate centroids. r = self.radius dx = torch.linspace(-r, r, 2 * r + 1, device=fmaps.device, dtype=fmaps.dtype) dy = torch.linspace(-r, r, 2 * r + 1, device=fmaps.device, dtype=fmaps.dtype) # delta: for every (dy,dx) displacement (i.e. Δx, Δy) self.delta = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), dim=-1) # shape: (2r+1, 2r+1, 2) def corr_sample(self, targets, coords): """ Instead of storing the entire correlation pyramid, we compute each level's correlation volume, sample it immediately, then discard it. This saves GPU memory. Args: targets: Tensor (B, S, N, C) — features for the current targets. coords: Tensor (B, S, N, 2) — coordinates at full resolution. Returns: Tensor (B, S, N, L) where L = num_levels * (2*radius+1)**2 (concatenated sampled correlations) """ B, S, N, C = targets.shape # If you have multiple track features, split them per level. 
if self.multiple_track_feats: targets_split = torch.split(targets, C // self.num_levels, dim=-1) out_pyramid = [] for i, fmaps in enumerate(self.fmaps_pyramid): # Get current spatial resolution H, W for this pyramid level. B, S, C, H, W = fmaps.shape # Reshape feature maps for correlation computation: # fmap2s: (B, S, C, H*W) fmap2s = fmaps.view(B, S, C, H * W) # Choose appropriate target features. fmap1 = targets_split[i] if self.multiple_track_feats else targets # shape: (B, S, N, C) # Compute correlation directly corrs = compute_corr_level(fmap1, fmap2s, C) corrs = corrs.view(B, S, N, H, W) # Prepare sampling grid: # Scale down the coordinates for the current level. centroid_lvl = coords.reshape(B * S * N, 1, 1, 2) / (2**i) # Make sure our precomputed delta grid is on the same device/dtype. delta_lvl = self.delta.to(coords.device).to(coords.dtype) # Now the grid for grid_sample is: # coords_lvl = centroid_lvl + delta_lvl (broadcasted over grid) coords_lvl = centroid_lvl + delta_lvl.view(1, 2 * self.radius + 1, 2 * self.radius + 1, 2) # Sample from the correlation volume using bilinear interpolation. # We reshape corrs to (B * S * N, 1, H, W) so grid_sample acts over each target. corrs_sampled = bilinear_sampler( corrs.reshape(B * S * N, 1, H, W), coords_lvl, padding_mode=self.padding_mode ) # The sampled output is (B * S * N, 1, 2r+1, 2r+1). Flatten the last two dims. corrs_sampled = corrs_sampled.view(B, S, N, -1) # Now shape: (B, S, N, (2r+1)^2) out_pyramid.append(corrs_sampled) # Concatenate all levels along the last dimension. out = torch.cat(out_pyramid, dim=-1).contiguous() return out def compute_corr_level(fmap1, fmap2s, C): # fmap1: (B, S, N, C) # fmap2s: (B, S, C, H*W) corrs = torch.matmul(fmap1, fmap2s) # (B, S, N, H*W) corrs = corrs.view(fmap1.shape[0], fmap1.shape[1], fmap1.shape[2], -1) # (B, S, N, H*W) return corrs / math.sqrt(C) ================================================ FILE: mvtracker/models/core/vggt/heads/track_modules/modules.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. 
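#
# These are the small building blocks used by the track predictor: Mlp, AttnBlock
# (pre-norm self-attention + MLP), CrossAttnBlock (pre-norm cross-attention + MLP),
# and ResidualBlock (a two-conv residual block). A minimal usage sketch, with
# illustrative shapes that are assumptions rather than values used elsewhere:
#
#   blk = AttnBlock(hidden_size=384, num_heads=8)
#   x = torch.randn(2, 100, 384)   # (batch, tokens, channels)
#   y = blk(x)                     # same shape as x
#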
import torch import torch.nn as nn import torch.nn.functional as F from functools import partial from typing import Callable import collections from torch import Tensor from itertools import repeat # From PyTorch internals def _ntuple(n): def parse(x): if isinstance(x, collections.abc.Iterable) and not isinstance(x, str): return tuple(x) return tuple(repeat(x, n)) return parse def exists(val): return val is not None def default(val, d): return val if exists(val) else d to_2tuple = _ntuple(2) class ResidualBlock(nn.Module): """ ResidualBlock: construct a block of two conv layers with residual connections """ def __init__(self, in_planes, planes, norm_fn="group", stride=1, kernel_size=3): super(ResidualBlock, self).__init__() self.conv1 = nn.Conv2d( in_planes, planes, kernel_size=kernel_size, padding=1, stride=stride, padding_mode="zeros", ) self.conv2 = nn.Conv2d( planes, planes, kernel_size=kernel_size, padding=1, padding_mode="zeros", ) self.relu = nn.ReLU(inplace=True) num_groups = planes // 8 if norm_fn == "group": self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) if not stride == 1: self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) elif norm_fn == "batch": self.norm1 = nn.BatchNorm2d(planes) self.norm2 = nn.BatchNorm2d(planes) if not stride == 1: self.norm3 = nn.BatchNorm2d(planes) elif norm_fn == "instance": self.norm1 = nn.InstanceNorm2d(planes) self.norm2 = nn.InstanceNorm2d(planes) if not stride == 1: self.norm3 = nn.InstanceNorm2d(planes) elif norm_fn == "none": self.norm1 = nn.Sequential() self.norm2 = nn.Sequential() if not stride == 1: self.norm3 = nn.Sequential() else: raise NotImplementedError if stride == 1: self.downsample = None else: self.downsample = nn.Sequential( nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3, ) def forward(self, x): y = x y = self.relu(self.norm1(self.conv1(y))) y = self.relu(self.norm2(self.conv2(y))) if self.downsample is not None: x = self.downsample(x) return self.relu(x + y) class Mlp(nn.Module): """MLP as used in Vision Transformer, MLP-Mixer and related networks""" def __init__( self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, norm_layer=None, bias=True, drop=0.0, use_conv=False, ): super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features bias = to_2tuple(bias) drop_probs = to_2tuple(drop) linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0]) self.act = act_layer() self.drop1 = nn.Dropout(drop_probs[0]) self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1]) self.drop2 = nn.Dropout(drop_probs[1]) def forward(self, x): x = self.fc1(x) x = self.act(x) x = self.drop1(x) x = self.fc2(x) x = self.drop2(x) return x class AttnBlock(nn.Module): def __init__( self, hidden_size, num_heads, attn_class: Callable[..., nn.Module] = nn.MultiheadAttention, mlp_ratio=4.0, **block_kwargs ): """ Self attention block """ super().__init__() self.norm1 = nn.LayerNorm(hidden_size) self.norm2 = nn.LayerNorm(hidden_size) self.attn = attn_class(embed_dim=hidden_size, num_heads=num_heads, batch_first=True, **block_kwargs) mlp_hidden_dim = int(hidden_size * mlp_ratio) self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, drop=0) def forward(self, x, mask=None): # Prepare the mask for PyTorch's attention (it expects a different format) # 
attn_mask = mask if mask is not None else None # Normalize before attention x = self.norm1(x) # PyTorch's MultiheadAttention returns attn_output, attn_output_weights # attn_output, _ = self.attn(x, x, x, attn_mask=attn_mask) attn_output, _ = self.attn(x, x, x) # Add & Norm x = x + attn_output x = x + self.mlp(self.norm2(x)) return x class CrossAttnBlock(nn.Module): def __init__(self, hidden_size, context_dim, num_heads=1, mlp_ratio=4.0, **block_kwargs): """ Cross attention block """ super().__init__() self.norm1 = nn.LayerNorm(hidden_size) self.norm_context = nn.LayerNorm(hidden_size) self.norm2 = nn.LayerNorm(hidden_size) self.cross_attn = nn.MultiheadAttention( embed_dim=hidden_size, num_heads=num_heads, batch_first=True, **block_kwargs ) mlp_hidden_dim = int(hidden_size * mlp_ratio) self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, drop=0) def forward(self, x, context, mask=None): # Normalize inputs x = self.norm1(x) context = self.norm_context(context) # Apply cross attention # Note: nn.MultiheadAttention returns attn_output, attn_output_weights attn_output, _ = self.cross_attn(x, context, context, attn_mask=mask) # Add & Norm x = x + attn_output x = x + self.mlp(self.norm2(x)) return x ================================================ FILE: mvtracker/models/core/vggt/heads/track_modules/utils.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. # Modified from https://github.com/facebookresearch/vggsfm # and https://github.com/facebookresearch/co-tracker/tree/main import torch import torch.nn as nn import torch.nn.functional as F from typing import Optional, Tuple, Union def get_2d_sincos_pos_embed(embed_dim: int, grid_size: Union[int, Tuple[int, int]], return_grid=False) -> torch.Tensor: """ This function initializes a grid and generates a 2D positional embedding using sine and cosine functions. It is a wrapper of get_2d_sincos_pos_embed_from_grid. Args: - embed_dim: The embedding dimension. - grid_size: The grid size. Returns: - pos_embed: The generated 2D positional embedding. """ if isinstance(grid_size, tuple): grid_size_h, grid_size_w = grid_size else: grid_size_h = grid_size_w = grid_size grid_h = torch.arange(grid_size_h, dtype=torch.float) grid_w = torch.arange(grid_size_w, dtype=torch.float) grid = torch.meshgrid(grid_w, grid_h, indexing="xy") grid = torch.stack(grid, dim=0) grid = grid.reshape([2, 1, grid_size_h, grid_size_w]) pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) if return_grid: return ( pos_embed.reshape(1, grid_size_h, grid_size_w, -1).permute(0, 3, 1, 2), grid, ) return pos_embed.reshape(1, grid_size_h, grid_size_w, -1).permute(0, 3, 1, 2) def get_2d_sincos_pos_embed_from_grid(embed_dim: int, grid: torch.Tensor) -> torch.Tensor: """ This function generates a 2D positional embedding from a given grid using sine and cosine functions. Args: - embed_dim: The embedding dimension. - grid: The grid to generate the embedding from. Returns: - emb: The generated 2D positional embedding. 
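    The first embed_dim/2 channels encode the first grid component and the remaining
    embed_dim/2 channels encode the second, each via the 1D sin-cos embedding below.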
""" assert embed_dim % 2 == 0 # use half of dimensions to encode grid_h emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) emb = torch.cat([emb_h, emb_w], dim=2) # (H*W, D) return emb def get_1d_sincos_pos_embed_from_grid(embed_dim: int, pos: torch.Tensor) -> torch.Tensor: """ This function generates a 1D positional embedding from a given grid using sine and cosine functions. Args: - embed_dim: The embedding dimension. - pos: The position to generate the embedding from. Returns: - emb: The generated 1D positional embedding. """ assert embed_dim % 2 == 0 omega = torch.arange(embed_dim // 2, dtype=torch.double) omega /= embed_dim / 2.0 omega = 1.0 / 10000**omega # (D/2,) pos = pos.reshape(-1) # (M,) out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product emb_sin = torch.sin(out) # (M, D/2) emb_cos = torch.cos(out) # (M, D/2) emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D) return emb[None].float() def get_2d_embedding(xy: torch.Tensor, C: int, cat_coords: bool = True) -> torch.Tensor: """ This function generates a 2D positional embedding from given coordinates using sine and cosine functions. Args: - xy: The coordinates to generate the embedding from. - C: The size of the embedding. - cat_coords: A flag to indicate whether to concatenate the original coordinates to the embedding. Returns: - pe: The generated 2D positional embedding. """ B, N, D = xy.shape assert D == 2 x = xy[:, :, 0:1] y = xy[:, :, 1:2] div_term = (torch.arange(0, C, 2, device=xy.device, dtype=torch.float32) * (1000.0 / C)).reshape(1, 1, int(C / 2)) pe_x = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32) pe_y = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32) pe_x[:, :, 0::2] = torch.sin(x * div_term) pe_x[:, :, 1::2] = torch.cos(x * div_term) pe_y[:, :, 0::2] = torch.sin(y * div_term) pe_y[:, :, 1::2] = torch.cos(y * div_term) pe = torch.cat([pe_x, pe_y], dim=2) # (B, N, C*3) if cat_coords: pe = torch.cat([xy, pe], dim=2) # (B, N, C*3+3) return pe def bilinear_sampler(input, coords, align_corners=True, padding_mode="border"): r"""Sample a tensor using bilinear interpolation `bilinear_sampler(input, coords)` samples a tensor :attr:`input` at coordinates :attr:`coords` using bilinear interpolation. It is the same as `torch.nn.functional.grid_sample()` but with a different coordinate convention. The input tensor is assumed to be of shape :math:`(B, C, H, W)`, where :math:`B` is the batch size, :math:`C` is the number of channels, :math:`H` is the height of the image, and :math:`W` is the width of the image. The tensor :attr:`coords` of shape :math:`(B, H_o, W_o, 2)` is interpreted as an array of 2D point coordinates :math:`(x_i,y_i)`. Alternatively, the input tensor can be of size :math:`(B, C, T, H, W)`, in which case sample points are triplets :math:`(t_i,x_i,y_i)`. Note that in this case the order of the components is slightly different from `grid_sample()`, which would expect :math:`(x_i,y_i,t_i)`. If `align_corners` is `True`, the coordinate :math:`x` is assumed to be in the range :math:`[0,W-1]`, with 0 corresponding to the center of the left-most image pixel :math:`W-1` to the center of the right-most pixel. If `align_corners` is `False`, the coordinate :math:`x` is assumed to be in the range :math:`[0,W]`, with 0 corresponding to the left edge of the left-most pixel :math:`W` to the right edge of the right-most pixel. 
Similar conventions apply to the :math:`y` for the range :math:`[0,H-1]` and :math:`[0,H]` and to :math:`t` for the range :math:`[0,T-1]` and :math:`[0,T]`. Args: input (Tensor): batch of input images. coords (Tensor): batch of coordinates. align_corners (bool, optional): Coordinate convention. Defaults to `True`. padding_mode (str, optional): Padding mode. Defaults to `"border"`. Returns: Tensor: sampled points. """ coords = coords.detach().clone() ############################################################ # IMPORTANT: coords = coords.to(input.device).to(input.dtype) ############################################################ sizes = input.shape[2:] assert len(sizes) in [2, 3] if len(sizes) == 3: # t x y -> x y t to match dimensions T H W in grid_sample coords = coords[..., [1, 2, 0]] if align_corners: scale = torch.tensor( [2 / max(size - 1, 1) for size in reversed(sizes)], device=coords.device, dtype=coords.dtype ) else: scale = torch.tensor([2 / size for size in reversed(sizes)], device=coords.device, dtype=coords.dtype) coords.mul_(scale) # coords = coords * scale coords.sub_(1) # coords = coords - 1 return F.grid_sample(input, coords, align_corners=align_corners, padding_mode=padding_mode) def sample_features4d(input, coords): r"""Sample spatial features `sample_features4d(input, coords)` samples the spatial features :attr:`input` represented by a 4D tensor :math:`(B, C, H, W)`. The field is sampled at coordinates :attr:`coords` using bilinear interpolation. :attr:`coords` is assumed to be of shape :math:`(B, R, 2)`, where each sample has the format :math:`(x_i, y_i)`. This uses the same convention as :func:`bilinear_sampler` with `align_corners=True`. The output tensor has one feature per point, and has shape :math:`(B, R, C)`. Args: input (Tensor): spatial features. coords (Tensor): points. Returns: Tensor: sampled features. """ B, _, _, _ = input.shape # B R 2 -> B R 1 2 coords = coords.unsqueeze(2) # B C R 1 feats = bilinear_sampler(input, coords) return feats.permute(0, 2, 1, 3).view(B, -1, feats.shape[1] * feats.shape[3]) # B C R 1 -> B R C ================================================ FILE: mvtracker/models/core/vggt/heads/utils.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. import torch import torch.nn as nn def position_grid_to_embed(pos_grid: torch.Tensor, embed_dim: int, omega_0: float = 100) -> torch.Tensor: """ Convert 2D position grid (HxWx2) to sinusoidal embeddings (HxWxC) Args: pos_grid: Tensor of shape (H, W, 2) containing 2D coordinates embed_dim: Output channel dimension for embeddings Returns: Tensor of shape (H, W, embed_dim) with positional embeddings """ H, W, grid_dim = pos_grid.shape assert grid_dim == 2 pos_flat = pos_grid.reshape(-1, grid_dim) # Flatten to (H*W, 2) # Process x and y coordinates separately emb_x = make_sincos_pos_embed(embed_dim // 2, pos_flat[:, 0], omega_0=omega_0) # [1, H*W, D/2] emb_y = make_sincos_pos_embed(embed_dim // 2, pos_flat[:, 1], omega_0=omega_0) # [1, H*W, D/2] # Combine and reshape emb = torch.cat([emb_x, emb_y], dim=-1) # [1, H*W, D] return emb.view(H, W, embed_dim) # [H, W, D] def make_sincos_pos_embed(embed_dim: int, pos: torch.Tensor, omega_0: float = 100) -> torch.Tensor: """ This function generates a 1D positional embedding from a given grid using sine and cosine functions. 
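    Concretely, emb[m, i] = sin(pos_m * omega_i) and emb[m, D/2 + i] = cos(pos_m * omega_i),
    where omega_i = 1 / omega_0 ** (2 * i / D), i = 0 .. D/2 - 1, and D = embed_dim.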
Args: - embed_dim: The embedding dimension. - pos: The position to generate the embedding from. Returns: - emb: The generated 1D positional embedding. """ assert embed_dim % 2 == 0 omega = torch.arange(embed_dim // 2, dtype=torch.double, device=pos.device) omega /= embed_dim / 2.0 omega = 1.0 / omega_0**omega # (D/2,) pos = pos.reshape(-1) # (M,) out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product emb_sin = torch.sin(out) # (M, D/2) emb_cos = torch.cos(out) # (M, D/2) emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D) return emb.float() # Inspired by https://github.com/microsoft/moge def create_uv_grid( width: int, height: int, aspect_ratio: float = None, dtype: torch.dtype = None, device: torch.device = None ) -> torch.Tensor: """ Create a normalized UV grid of shape (width, height, 2). The grid spans horizontally and vertically according to an aspect ratio, ensuring the top-left corner is at (-x_span, -y_span) and the bottom-right corner is at (x_span, y_span), normalized by the diagonal of the plane. Args: width (int): Number of points horizontally. height (int): Number of points vertically. aspect_ratio (float, optional): Width-to-height ratio. Defaults to width/height. dtype (torch.dtype, optional): Data type of the resulting tensor. device (torch.device, optional): Device on which the tensor is created. Returns: torch.Tensor: A (width, height, 2) tensor of UV coordinates. """ # Derive aspect ratio if not explicitly provided if aspect_ratio is None: aspect_ratio = float(width) / float(height) # Compute normalized spans for X and Y diag_factor = (aspect_ratio**2 + 1.0) ** 0.5 span_x = aspect_ratio / diag_factor span_y = 1.0 / diag_factor # Establish the linspace boundaries left_x = -span_x * (width - 1) / width right_x = span_x * (width - 1) / width top_y = -span_y * (height - 1) / height bottom_y = span_y * (height - 1) / height # Generate 1D coordinates x_coords = torch.linspace(left_x, right_x, steps=width, dtype=dtype, device=device) y_coords = torch.linspace(top_y, bottom_y, steps=height, dtype=dtype, device=device) # Create 2D meshgrid (width x height) and stack into UV uu, vv = torch.meshgrid(x_coords, y_coords, indexing="xy") uv_grid = torch.stack((uu, vv), dim=-1) return uv_grid ================================================ FILE: mvtracker/models/core/vggt/layers/__init__.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. from .mlp import Mlp from .patch_embed import PatchEmbed from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused from .block import NestedTensorBlock from .attention import MemEffAttention ================================================ FILE: mvtracker/models/core/vggt/layers/attention.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # # This source code is licensed under the Apache License, Version 2.0 # found in the LICENSE file in the root directory of this source tree. 
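#
# Note: XFORMERS_AVAILABLE is hard-coded to False in this copy, so MemEffAttention
# falls back to the plain Attention.forward path. Attention itself dispatches to
# F.scaled_dot_product_attention when fused_attn=True and applies an optional rotary
# position embedding (rope) to the queries and keys before the attention product.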
# References: # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py import logging import os import warnings from torch import Tensor from torch import nn import torch.nn.functional as F XFORMERS_AVAILABLE = False class Attention(nn.Module): def __init__( self, dim: int, num_heads: int = 8, qkv_bias: bool = True, proj_bias: bool = True, attn_drop: float = 0.0, proj_drop: float = 0.0, norm_layer: nn.Module = nn.LayerNorm, qk_norm: bool = False, fused_attn: bool = True, # use F.scaled_dot_product_attention or not rope=None, ) -> None: super().__init__() assert dim % num_heads == 0, "dim should be divisible by num_heads" self.num_heads = num_heads self.head_dim = dim // num_heads self.scale = self.head_dim**-0.5 self.fused_attn = fused_attn self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() self.attn_drop = nn.Dropout(attn_drop) self.proj = nn.Linear(dim, dim, bias=proj_bias) self.proj_drop = nn.Dropout(proj_drop) self.rope = rope def forward(self, x: Tensor, pos=None) -> Tensor: B, N, C = x.shape qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) q, k, v = qkv.unbind(0) q, k = self.q_norm(q), self.k_norm(k) if self.rope is not None: q = self.rope(q, pos) k = self.rope(k, pos) if self.fused_attn: x = F.scaled_dot_product_attention( q, k, v, dropout_p=self.attn_drop.p if self.training else 0.0, ) else: q = q * self.scale attn = q @ k.transpose(-2, -1) attn = attn.softmax(dim=-1) attn = self.attn_drop(attn) x = attn @ v x = x.transpose(1, 2).reshape(B, N, C) x = self.proj(x) x = self.proj_drop(x) return x class MemEffAttention(Attention): def forward(self, x: Tensor, attn_bias=None, pos=None) -> Tensor: assert pos is None if not XFORMERS_AVAILABLE: if attn_bias is not None: raise AssertionError("xFormers is required for using nested tensors") return super().forward(x) B, N, C = x.shape qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) q, k, v = unbind(qkv, 2) x = memory_efficient_attention(q, k, v, attn_bias=attn_bias) x = x.reshape([B, N, C]) x = self.proj(x) x = self.proj_drop(x) return x ================================================ FILE: mvtracker/models/core/vggt/layers/block.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # # This source code is licensed under the Apache License, Version 2.0 # found in the LICENSE file in the root directory of this source tree. 
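#
# Note: Block supports stochastic depth. During training, for drop rates above 0.1,
# drop_add_residual_stochastic_depth evaluates the attention/MLP residual branch on a
# random subset of the batch only and rescales the result by batch_size / subset_size
# before adding it back, which is cheaper than computing and masking full residuals.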
# References: # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py import logging import os from typing import Callable, List, Any, Tuple, Dict import warnings import torch from torch import nn, Tensor from .attention import Attention from .drop_path import DropPath from .layer_scale import LayerScale from .mlp import Mlp XFORMERS_AVAILABLE = False class Block(nn.Module): def __init__( self, dim: int, num_heads: int, mlp_ratio: float = 4.0, qkv_bias: bool = True, proj_bias: bool = True, ffn_bias: bool = True, drop: float = 0.0, attn_drop: float = 0.0, init_values=None, drop_path: float = 0.0, act_layer: Callable[..., nn.Module] = nn.GELU, norm_layer: Callable[..., nn.Module] = nn.LayerNorm, attn_class: Callable[..., nn.Module] = Attention, ffn_layer: Callable[..., nn.Module] = Mlp, qk_norm: bool = False, fused_attn: bool = True, # use F.scaled_dot_product_attention or not rope=None, ) -> None: super().__init__() self.norm1 = norm_layer(dim) self.attn = attn_class( dim, num_heads=num_heads, qkv_bias=qkv_bias, proj_bias=proj_bias, attn_drop=attn_drop, proj_drop=drop, qk_norm=qk_norm, fused_attn=fused_attn, rope=rope, ) self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() self.norm2 = norm_layer(dim) mlp_hidden_dim = int(dim * mlp_ratio) self.mlp = ffn_layer( in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop, bias=ffn_bias, ) self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() self.sample_drop_ratio = drop_path def forward(self, x: Tensor, pos=None) -> Tensor: def attn_residual_func(x: Tensor, pos=None) -> Tensor: return self.ls1(self.attn(self.norm1(x), pos=pos)) def ffn_residual_func(x: Tensor) -> Tensor: return self.ls2(self.mlp(self.norm2(x))) if self.training and self.sample_drop_ratio > 0.1: # the overhead is compensated only for a drop path rate larger than 0.1 x = drop_add_residual_stochastic_depth( x, pos=pos, residual_func=attn_residual_func, sample_drop_ratio=self.sample_drop_ratio, ) x = drop_add_residual_stochastic_depth( x, residual_func=ffn_residual_func, sample_drop_ratio=self.sample_drop_ratio, ) elif self.training and self.sample_drop_ratio > 0.0: x = x + self.drop_path1(attn_residual_func(x, pos=pos)) x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2 else: x = x + attn_residual_func(x, pos=pos) x = x + ffn_residual_func(x) return x def drop_add_residual_stochastic_depth( x: Tensor, residual_func: Callable[[Tensor], Tensor], sample_drop_ratio: float = 0.0, pos=None, ) -> Tensor: # 1) extract subset using permutation b, n, d = x.shape sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) brange = (torch.randperm(b, device=x.device))[:sample_subset_size] x_subset = x[brange] # 2) apply residual_func to get residual if pos is not None: # if necessary, apply rope to the subset pos = pos[brange] residual = residual_func(x_subset, pos=pos) else: residual = residual_func(x_subset) x_flat = x.flatten(1) residual = residual.flatten(1) residual_scale_factor = b / sample_subset_size # 3) add the residual x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) return x_plus_residual.view_as(x) def get_branges_scales(x, sample_drop_ratio=0.0): b, n, d = 
x.shape sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) brange = (torch.randperm(b, device=x.device))[:sample_subset_size] residual_scale_factor = b / sample_subset_size return brange, residual_scale_factor def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None): if scaling_vector is None: x_flat = x.flatten(1) residual = residual.flatten(1) x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) else: x_plus_residual = scaled_index_add( x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor ) return x_plus_residual attn_bias_cache: Dict[Tuple, Any] = {} def get_attn_bias_and_cat(x_list, branges=None): """ this will perform the index select, cat the tensors, and provide the attn_bias from cache """ batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list] all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list)) if all_shapes not in attn_bias_cache.keys(): seqlens = [] for b, x in zip(batch_sizes, x_list): for _ in range(b): seqlens.append(x.shape[1]) attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens) attn_bias._batch_sizes = batch_sizes attn_bias_cache[all_shapes] = attn_bias if branges is not None: cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1]) else: tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list) cat_tensors = torch.cat(tensors_bs1, dim=1) return attn_bias_cache[all_shapes], cat_tensors def drop_add_residual_stochastic_depth_list( x_list: List[Tensor], residual_func: Callable[[Tensor, Any], Tensor], sample_drop_ratio: float = 0.0, scaling_vector=None, ) -> Tensor: # 1) generate random set of indices for dropping samples in the batch branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list] branges = [s[0] for s in branges_scales] residual_scale_factors = [s[1] for s in branges_scales] # 2) get attention bias and index+concat the tensors attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges) # 3) apply residual_func to get residual, and split the result residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore outputs = [] for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors): outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x)) return outputs class NestedTensorBlock(Block): def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]: """ x_list contains a list of tensors to nest together and run """ assert isinstance(self.attn, MemEffAttention) if self.training and self.sample_drop_ratio > 0.0: def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: return self.attn(self.norm1(x), attn_bias=attn_bias) def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: return self.mlp(self.norm2(x)) x_list = drop_add_residual_stochastic_depth_list( x_list, residual_func=attn_residual_func, sample_drop_ratio=self.sample_drop_ratio, scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None, ) x_list = drop_add_residual_stochastic_depth_list( x_list, residual_func=ffn_residual_func, sample_drop_ratio=self.sample_drop_ratio, scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None, ) return x_list else: def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: return self.ls1(self.attn(self.norm1(x), 
attn_bias=attn_bias)) def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: return self.ls2(self.mlp(self.norm2(x))) attn_bias, x = get_attn_bias_and_cat(x_list) x = x + attn_residual_func(x, attn_bias=attn_bias) x = x + ffn_residual_func(x) return attn_bias.split(x) def forward(self, x_or_x_list): if isinstance(x_or_x_list, Tensor): return super().forward(x_or_x_list) elif isinstance(x_or_x_list, list): if not XFORMERS_AVAILABLE: raise AssertionError("xFormers is required for using nested tensors") return self.forward_nested(x_or_x_list) else: raise AssertionError ================================================ FILE: mvtracker/models/core/vggt/layers/drop_path.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # # This source code is licensed under the Apache License, Version 2.0 # found in the LICENSE file in the root directory of this source tree. # References: # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py from torch import nn def drop_path(x, drop_prob: float = 0.0, training: bool = False): if drop_prob == 0.0 or not training: return x keep_prob = 1 - drop_prob shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets random_tensor = x.new_empty(shape).bernoulli_(keep_prob) if keep_prob > 0.0: random_tensor.div_(keep_prob) output = x * random_tensor return output class DropPath(nn.Module): """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" def __init__(self, drop_prob=None): super(DropPath, self).__init__() self.drop_prob = drop_prob def forward(self, x): return drop_path(x, self.drop_prob, self.training) ================================================ FILE: mvtracker/models/core/vggt/layers/layer_scale.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # # This source code is licensed under the Apache License, Version 2.0 # found in the LICENSE file in the root directory of this source tree. # Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110 from typing import Union import torch from torch import Tensor from torch import nn class LayerScale(nn.Module): def __init__( self, dim: int, init_values: Union[float, Tensor] = 1e-5, inplace: bool = False, ) -> None: super().__init__() self.inplace = inplace self.gamma = nn.Parameter(init_values * torch.ones(dim)) def forward(self, x: Tensor) -> Tensor: return x.mul_(self.gamma) if self.inplace else x * self.gamma ================================================ FILE: mvtracker/models/core/vggt/layers/mlp.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # # This source code is licensed under the Apache License, Version 2.0 # found in the LICENSE file in the root directory of this source tree. 
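#
# Standard transformer MLP: Linear -> GELU -> Dropout -> Linear -> Dropout.
# A minimal usage sketch (the shapes are illustrative assumptions only):
#
#   mlp = Mlp(in_features=768, hidden_features=3072)   # 4x expansion, as in ViT
#   y = mlp(torch.randn(2, 197, 768))                   # output shape equals input shape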
# References: # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py from typing import Callable, Optional from torch import Tensor, nn class Mlp(nn.Module): def __init__( self, in_features: int, hidden_features: Optional[int] = None, out_features: Optional[int] = None, act_layer: Callable[..., nn.Module] = nn.GELU, drop: float = 0.0, bias: bool = True, ) -> None: super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) self.act = act_layer() self.fc2 = nn.Linear(hidden_features, out_features, bias=bias) self.drop = nn.Dropout(drop) def forward(self, x: Tensor) -> Tensor: x = self.fc1(x) x = self.act(x) x = self.drop(x) x = self.fc2(x) x = self.drop(x) return x ================================================ FILE: mvtracker/models/core/vggt/layers/patch_embed.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # # This source code is licensed under the Apache License, Version 2.0 # found in the LICENSE file in the root directory of this source tree. # References: # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py from typing import Callable, Optional, Tuple, Union from torch import Tensor import torch.nn as nn def make_2tuple(x): if isinstance(x, tuple): assert len(x) == 2 return x assert isinstance(x, int) return (x, x) class PatchEmbed(nn.Module): """ 2D image to patch embedding: (B,C,H,W) -> (B,N,D) Args: img_size: Image size. patch_size: Patch token size. in_chans: Number of input image channels. embed_dim: Number of linear projection output channels. norm_layer: Normalization layer. 
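    Example: with img_size=224 and patch_size=16 the patch grid is 14 x 14, so the
    forward pass maps a (B, 3, 224, 224) image to (B, 196, embed_dim) tokens, or to
    (B, 14, 14, embed_dim) when flatten_embedding=False.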
""" def __init__( self, img_size: Union[int, Tuple[int, int]] = 224, patch_size: Union[int, Tuple[int, int]] = 16, in_chans: int = 3, embed_dim: int = 768, norm_layer: Optional[Callable] = None, flatten_embedding: bool = True, ) -> None: super().__init__() image_HW = make_2tuple(img_size) patch_HW = make_2tuple(patch_size) patch_grid_size = ( image_HW[0] // patch_HW[0], image_HW[1] // patch_HW[1], ) self.img_size = image_HW self.patch_size = patch_HW self.patches_resolution = patch_grid_size self.num_patches = patch_grid_size[0] * patch_grid_size[1] self.in_chans = in_chans self.embed_dim = embed_dim self.flatten_embedding = flatten_embedding self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW) self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() def forward(self, x: Tensor) -> Tensor: _, _, H, W = x.shape patch_H, patch_W = self.patch_size assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}" assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}" x = self.proj(x) # B C H W H, W = x.size(2), x.size(3) x = x.flatten(2).transpose(1, 2) # B HW C x = self.norm(x) if not self.flatten_embedding: x = x.reshape(-1, H, W, self.embed_dim) # B H W C return x def flops(self) -> float: Ho, Wo = self.patches_resolution flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) if self.norm is not None: flops += Ho * Wo * self.embed_dim return flops ================================================ FILE: mvtracker/models/core/vggt/layers/rope.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # # This source code is licensed under the Apache License, Version 2.0 # found in the LICENSE file in the root directory of this source tree. # Implementation of 2D Rotary Position Embeddings (RoPE). # This module provides a clean implementation of 2D Rotary Position Embeddings, # which extends the original RoPE concept to handle 2D spatial positions. # Inspired by: # https://github.com/meta-llama/codellama/blob/main/llama/model.py # https://github.com/naver-ai/rope-vit import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from typing import Dict, Tuple class PositionGetter: """Generates and caches 2D spatial positions for patches in a grid. This class efficiently manages the generation of spatial coordinates for patches in a 2D grid, caching results to avoid redundant computations. Attributes: position_cache: Dictionary storing precomputed position tensors for different grid dimensions. """ def __init__(self): """Initializes the position generator with an empty cache.""" self.position_cache: Dict[Tuple[int, int], torch.Tensor] = {} def __call__(self, batch_size: int, height: int, width: int, device: torch.device) -> torch.Tensor: """Generates spatial positions for a batch of patches. Args: batch_size: Number of samples in the batch. height: Height of the grid in patches. width: Width of the grid in patches. device: Target device for the position tensor. Returns: Tensor of shape (batch_size, height*width, 2) containing y,x coordinates for each position in the grid, repeated for each batch item. 
""" if (height, width) not in self.position_cache: y_coords = torch.arange(height, device=device) x_coords = torch.arange(width, device=device) positions = torch.cartesian_prod(y_coords, x_coords) self.position_cache[height, width] = positions cached_positions = self.position_cache[height, width] return cached_positions.view(1, height * width, 2).expand(batch_size, -1, -1).clone() class RotaryPositionEmbedding2D(nn.Module): """2D Rotary Position Embedding implementation. This module applies rotary position embeddings to input tokens based on their 2D spatial positions. It handles the position-dependent rotation of features separately for vertical and horizontal dimensions. Args: frequency: Base frequency for the position embeddings. Default: 100.0 scaling_factor: Scaling factor for frequency computation. Default: 1.0 Attributes: base_frequency: Base frequency for computing position embeddings. scaling_factor: Factor to scale the computed frequencies. frequency_cache: Cache for storing precomputed frequency components. """ def __init__(self, frequency: float = 100.0, scaling_factor: float = 1.0): """Initializes the 2D RoPE module.""" super().__init__() self.base_frequency = frequency self.scaling_factor = scaling_factor self.frequency_cache: Dict[Tuple, Tuple[torch.Tensor, torch.Tensor]] = {} def _compute_frequency_components( self, dim: int, seq_len: int, device: torch.device, dtype: torch.dtype ) -> Tuple[torch.Tensor, torch.Tensor]: """Computes frequency components for rotary embeddings. Args: dim: Feature dimension (must be even). seq_len: Maximum sequence length. device: Target device for computations. dtype: Data type for the computed tensors. Returns: Tuple of (cosine, sine) tensors for frequency components. """ cache_key = (dim, seq_len, device, dtype) if cache_key not in self.frequency_cache: # Compute frequency bands exponents = torch.arange(0, dim, 2, device=device).float() / dim inv_freq = 1.0 / (self.base_frequency**exponents) # Generate position-dependent frequencies positions = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) angles = torch.einsum("i,j->ij", positions, inv_freq) # Compute and cache frequency components angles = angles.to(dtype) angles = torch.cat((angles, angles), dim=-1) cos_components = angles.cos().to(dtype) sin_components = angles.sin().to(dtype) self.frequency_cache[cache_key] = (cos_components, sin_components) return self.frequency_cache[cache_key] @staticmethod def _rotate_features(x: torch.Tensor) -> torch.Tensor: """Performs feature rotation by splitting and recombining feature dimensions. Args: x: Input tensor to rotate. Returns: Rotated feature tensor. """ feature_dim = x.shape[-1] x1, x2 = x[..., : feature_dim // 2], x[..., feature_dim // 2 :] return torch.cat((-x2, x1), dim=-1) def _apply_1d_rope( self, tokens: torch.Tensor, positions: torch.Tensor, cos_comp: torch.Tensor, sin_comp: torch.Tensor ) -> torch.Tensor: """Applies 1D rotary position embeddings along one dimension. Args: tokens: Input token features. positions: Position indices. cos_comp: Cosine components for rotation. sin_comp: Sine components for rotation. Returns: Tokens with applied rotary position embeddings. 
""" # Embed positions with frequency components cos = F.embedding(positions, cos_comp)[:, None, :, :] sin = F.embedding(positions, sin_comp)[:, None, :, :] # Apply rotation return (tokens * cos) + (self._rotate_features(tokens) * sin) def forward(self, tokens: torch.Tensor, positions: torch.Tensor) -> torch.Tensor: """Applies 2D rotary position embeddings to input tokens. Args: tokens: Input tensor of shape (batch_size, n_heads, n_tokens, dim). The feature dimension (dim) must be divisible by 4. positions: Position tensor of shape (batch_size, n_tokens, 2) containing the y and x coordinates for each token. Returns: Tensor of same shape as input with applied 2D rotary position embeddings. Raises: AssertionError: If input dimensions are invalid or positions are malformed. """ # Validate inputs assert tokens.size(-1) % 2 == 0, "Feature dimension must be even" assert positions.ndim == 3 and positions.shape[-1] == 2, "Positions must have shape (batch_size, n_tokens, 2)" # Compute feature dimension for each spatial direction feature_dim = tokens.size(-1) // 2 # Get frequency components max_position = int(positions.max()) + 1 cos_comp, sin_comp = self._compute_frequency_components(feature_dim, max_position, tokens.device, tokens.dtype) # Split features for vertical and horizontal processing vertical_features, horizontal_features = tokens.chunk(2, dim=-1) # Apply RoPE separately for each dimension vertical_features = self._apply_1d_rope(vertical_features, positions[..., 0], cos_comp, sin_comp) horizontal_features = self._apply_1d_rope(horizontal_features, positions[..., 1], cos_comp, sin_comp) # Combine processed features return torch.cat((vertical_features, horizontal_features), dim=-1) ================================================ FILE: mvtracker/models/core/vggt/layers/swiglu_ffn.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # # This source code is licensed under the Apache License, Version 2.0 # found in the LICENSE file in the root directory of this source tree. 
import os from typing import Callable, Optional import warnings from torch import Tensor, nn import torch.nn.functional as F class SwiGLUFFN(nn.Module): def __init__( self, in_features: int, hidden_features: Optional[int] = None, out_features: Optional[int] = None, act_layer: Callable[..., nn.Module] = None, drop: float = 0.0, bias: bool = True, ) -> None: super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias) self.w3 = nn.Linear(hidden_features, out_features, bias=bias) def forward(self, x: Tensor) -> Tensor: x12 = self.w12(x) x1, x2 = x12.chunk(2, dim=-1) hidden = F.silu(x1) * x2 return self.w3(hidden) XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None # try: # if XFORMERS_ENABLED: # from xformers.ops import SwiGLU # XFORMERS_AVAILABLE = True # warnings.warn("xFormers is available (SwiGLU)") # else: # warnings.warn("xFormers is disabled (SwiGLU)") # raise ImportError # except ImportError: SwiGLU = SwiGLUFFN XFORMERS_AVAILABLE = False # warnings.warn("xFormers is not available (SwiGLU)") class SwiGLUFFNFused(SwiGLU): def __init__( self, in_features: int, hidden_features: Optional[int] = None, out_features: Optional[int] = None, act_layer: Callable[..., nn.Module] = None, drop: float = 0.0, bias: bool = True, ) -> None: out_features = out_features or in_features hidden_features = hidden_features or in_features hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8 super().__init__( in_features=in_features, hidden_features=hidden_features, out_features=out_features, bias=bias, ) ================================================ FILE: mvtracker/models/core/vggt/layers/vision_transformer.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # # This source code is licensed under the Apache License, Version 2.0 # found in the LICENSE file in the root directory of this source tree. # References: # https://github.com/facebookresearch/dino/blob/main/vision_transformer.py # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py from functools import partial import math import logging from typing import Sequence, Tuple, Union, Callable import torch import torch.nn as nn from torch.utils.checkpoint import checkpoint from torch.nn.init import trunc_normal_ from . 
import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block logger = logging.getLogger("dinov2") def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module: if not depth_first and include_root: fn(module=module, name=name) for child_name, child_module in module.named_children(): child_name = ".".join((name, child_name)) if name else child_name named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True) if depth_first and include_root: fn(module=module, name=name) return module class BlockChunk(nn.ModuleList): def forward(self, x): for b in self: x = b(x) return x class DinoVisionTransformer(nn.Module): def __init__( self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.0, qkv_bias=True, ffn_bias=True, proj_bias=True, drop_path_rate=0.0, drop_path_uniform=False, init_values=None, # for layerscale: None or 0 => no layerscale embed_layer=PatchEmbed, act_layer=nn.GELU, block_fn=Block, ffn_layer="mlp", block_chunks=1, num_register_tokens=0, interpolate_antialias=False, interpolate_offset=0.1, qk_norm=False, ): """ Args: img_size (int, tuple): input image size patch_size (int, tuple): patch size in_chans (int): number of input channels embed_dim (int): embedding dimension depth (int): depth of transformer num_heads (int): number of attention heads mlp_ratio (int): ratio of mlp hidden dim to embedding dim qkv_bias (bool): enable bias for qkv if True proj_bias (bool): enable bias for proj in attn if True ffn_bias (bool): enable bias for ffn if True drop_path_rate (float): stochastic depth rate drop_path_uniform (bool): apply uniform drop rate across blocks weight_init (str): weight init scheme init_values (float): layer-scale init values embed_layer (nn.Module): patch embedding layer act_layer (nn.Module): MLP activation layer block_fn (nn.Module): transformer block class ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity" block_chunks: (int) split block sequence into block_chunks units for FSDP wrap num_register_tokens: (int) number of extra cls tokens (so-called "registers") interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings """ super().__init__() norm_layer = partial(nn.LayerNorm, eps=1e-6) # tricky but makes it work self.use_checkpoint = False # self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models self.num_tokens = 1 self.n_blocks = depth self.num_heads = num_heads self.patch_size = patch_size self.num_register_tokens = num_register_tokens self.interpolate_antialias = interpolate_antialias self.interpolate_offset = interpolate_offset self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) num_patches = self.patch_embed.num_patches self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim)) assert num_register_tokens >= 0 self.register_tokens = ( nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None ) if drop_path_uniform is True: dpr = [drop_path_rate] * depth else: dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule if ffn_layer == "mlp": logger.info("using MLP layer as FFN") ffn_layer = Mlp elif ffn_layer == 
"swiglufused" or ffn_layer == "swiglu": logger.info("using SwiGLU layer as FFN") ffn_layer = SwiGLUFFNFused elif ffn_layer == "identity": logger.info("using Identity layer as FFN") def f(*args, **kwargs): return nn.Identity() ffn_layer = f else: raise NotImplementedError blocks_list = [ block_fn( dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, proj_bias=proj_bias, ffn_bias=ffn_bias, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer, ffn_layer=ffn_layer, init_values=init_values, qk_norm=qk_norm, ) for i in range(depth) ] if block_chunks > 0: self.chunked_blocks = True chunked_blocks = [] chunksize = depth // block_chunks for i in range(0, depth, chunksize): # this is to keep the block index consistent if we chunk the block list chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize]) self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks]) else: self.chunked_blocks = False self.blocks = nn.ModuleList(blocks_list) self.norm = norm_layer(embed_dim) self.head = nn.Identity() self.mask_token = nn.Parameter(torch.zeros(1, embed_dim)) self.init_weights() def init_weights(self): trunc_normal_(self.pos_embed, std=0.02) nn.init.normal_(self.cls_token, std=1e-6) if self.register_tokens is not None: nn.init.normal_(self.register_tokens, std=1e-6) named_apply(init_weights_vit_timm, self) def interpolate_pos_encoding(self, x, w, h): previous_dtype = x.dtype npatch = x.shape[1] - 1 N = self.pos_embed.shape[1] - 1 if npatch == N and w == h: return self.pos_embed pos_embed = self.pos_embed.float() class_pos_embed = pos_embed[:, 0] patch_pos_embed = pos_embed[:, 1:] dim = x.shape[-1] w0 = w // self.patch_size h0 = h // self.patch_size M = int(math.sqrt(N)) # Recover the number of patches in each dimension assert N == M * M kwargs = {} if self.interpolate_offset: # Historical kludge: add a small number to avoid floating point error in the interpolation, see https://github.com/facebookresearch/dino/issues/8 # Note: still needed for backward-compatibility, the underlying operators are using both output size and scale factors sx = float(w0 + self.interpolate_offset) / M sy = float(h0 + self.interpolate_offset) / M kwargs["scale_factor"] = (sx, sy) else: # Simply specify an output size instead of a scale factor kwargs["size"] = (w0, h0) patch_pos_embed = nn.functional.interpolate( patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2), mode="bicubic", antialias=self.interpolate_antialias, **kwargs, ) assert (w0, h0) == patch_pos_embed.shape[-2:] patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype) def prepare_tokens_with_masks(self, x, masks=None): B, nc, w, h = x.shape x = self.patch_embed(x) if masks is not None: x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x) x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) x = x + self.interpolate_pos_encoding(x, w, h) if self.register_tokens is not None: x = torch.cat( ( x[:, :1], self.register_tokens.expand(x.shape[0], -1, -1), x[:, 1:], ), dim=1, ) return x def forward_features_list(self, x_list, masks_list): x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)] for blk in self.blocks: if self.use_checkpoint: x = checkpoint(blk, x, use_reentrant=self.use_reentrant) else: x = blk(x) all_x = x output = [] for x, masks in zip(all_x, masks_list): x_norm = self.norm(x) output.append( { "x_norm_clstoken": 
x_norm[:, 0], "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1], "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :], "x_prenorm": x, "masks": masks, } ) return output def forward_features(self, x, masks=None): if isinstance(x, list): return self.forward_features_list(x, masks) x = self.prepare_tokens_with_masks(x, masks) for blk in self.blocks: if self.use_checkpoint: x = checkpoint(blk, x, use_reentrant=self.use_reentrant) else: x = blk(x) x_norm = self.norm(x) return { "x_norm_clstoken": x_norm[:, 0], "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1], "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :], "x_prenorm": x, "masks": masks, } def _get_intermediate_layers_not_chunked(self, x, n=1): x = self.prepare_tokens_with_masks(x) # If n is an int, take the n last blocks. If it's a list, take them output, total_block_len = [], len(self.blocks) blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n for i, blk in enumerate(self.blocks): x = blk(x) if i in blocks_to_take: output.append(x) assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found" return output def _get_intermediate_layers_chunked(self, x, n=1): x = self.prepare_tokens_with_masks(x) output, i, total_block_len = [], 0, len(self.blocks[-1]) # If n is an int, take the n last blocks. If it's a list, take them blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n for block_chunk in self.blocks: for blk in block_chunk[i:]: # Passing the nn.Identity() x = blk(x) if i in blocks_to_take: output.append(x) i += 1 assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found" return output def get_intermediate_layers( self, x: torch.Tensor, n: Union[int, Sequence] = 1, # Layers or n last layers to take reshape: bool = False, return_class_token: bool = False, norm=True, ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]: if self.chunked_blocks: outputs = self._get_intermediate_layers_chunked(x, n) else: outputs = self._get_intermediate_layers_not_chunked(x, n) if norm: outputs = [self.norm(out) for out in outputs] class_tokens = [out[:, 0] for out in outputs] outputs = [out[:, 1 + self.num_register_tokens :] for out in outputs] if reshape: B, _, w, h = x.shape outputs = [ out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous() for out in outputs ] if return_class_token: return tuple(zip(outputs, class_tokens)) return tuple(outputs) def forward(self, *args, is_training=True, **kwargs): ret = self.forward_features(*args, **kwargs) if is_training: return ret else: return self.head(ret["x_norm_clstoken"]) def init_weights_vit_timm(module: nn.Module, name: str = ""): """ViT weight initialization, original timm impl (for reproducibility)""" if isinstance(module, nn.Linear): trunc_normal_(module.weight, std=0.02) if module.bias is not None: nn.init.zeros_(module.bias) def vit_small(patch_size=16, num_register_tokens=0, **kwargs): model = DinoVisionTransformer( patch_size=patch_size, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, block_fn=partial(Block, attn_class=MemEffAttention), num_register_tokens=num_register_tokens, **kwargs, ) return model def vit_base(patch_size=16, num_register_tokens=0, **kwargs): model = DinoVisionTransformer( patch_size=patch_size, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, block_fn=partial(Block, attn_class=MemEffAttention), 
num_register_tokens=num_register_tokens, **kwargs, ) return model def vit_large(patch_size=16, num_register_tokens=0, **kwargs): model = DinoVisionTransformer( patch_size=patch_size, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, block_fn=partial(Block, attn_class=MemEffAttention), num_register_tokens=num_register_tokens, **kwargs, ) return model def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs): """ Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64 """ model = DinoVisionTransformer( patch_size=patch_size, embed_dim=1536, depth=40, num_heads=24, mlp_ratio=4, block_fn=partial(Block, attn_class=MemEffAttention), num_register_tokens=num_register_tokens, **kwargs, ) return model ================================================ FILE: mvtracker/models/core/vggt/models/aggregator.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. import logging import torch import torch.nn as nn import torch.nn.functional as F from typing import Optional, Tuple, Union, List, Dict, Any from ..layers import PatchEmbed from ..layers.block import Block from ..layers.rope import RotaryPositionEmbedding2D, PositionGetter from ..layers.vision_transformer import vit_small, vit_base, vit_large, vit_giant2 logger = logging.getLogger(__name__) _RESNET_MEAN = [0.485, 0.456, 0.406] _RESNET_STD = [0.229, 0.224, 0.225] class Aggregator(nn.Module): """ The Aggregator applies alternating-attention over input frames, as described in VGGT: Visual Geometry Grounded Transformer. Args: img_size (int): Image size in pixels. patch_size (int): Size of each patch for PatchEmbed. embed_dim (int): Dimension of the token embeddings. depth (int): Number of blocks. num_heads (int): Number of attention heads. mlp_ratio (float): Ratio of MLP hidden dim to embedding dim. num_register_tokens (int): Number of register tokens. block_fn (nn.Module): The block type used for attention (Block by default). qkv_bias (bool): Whether to include bias in QKV projections. proj_bias (bool): Whether to include bias in the output projection. ffn_bias (bool): Whether to include bias in MLP layers. patch_embed (str): Type of patch embed. e.g., "conv" or "dinov2_vitl14_reg". aa_order (list[str]): The order of alternating attention, e.g. ["frame", "global"]. aa_block_size (int): How many blocks to group under each attention type before switching. If not necessary, set to 1. qk_norm (bool): Whether to apply QK normalization. rope_freq (int): Base frequency for rotary embedding. -1 to disable. init_values (float): Init scale for layer scale. 
""" def __init__( self, img_size=518, patch_size=14, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4.0, num_register_tokens=4, block_fn=Block, qkv_bias=True, proj_bias=True, ffn_bias=True, patch_embed="dinov2_vitl14_reg", aa_order=["frame", "global"], aa_block_size=1, qk_norm=True, rope_freq=100, init_values=0.01, ): super().__init__() self.__build_patch_embed__(patch_embed, img_size, patch_size, num_register_tokens, embed_dim=embed_dim) # Initialize rotary position embedding if frequency > 0 self.rope = RotaryPositionEmbedding2D(frequency=rope_freq) if rope_freq > 0 else None self.position_getter = PositionGetter() if self.rope is not None else None self.frame_blocks = nn.ModuleList( [ block_fn( dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, proj_bias=proj_bias, ffn_bias=ffn_bias, init_values=init_values, qk_norm=qk_norm, rope=self.rope, ) for _ in range(depth) ] ) self.global_blocks = nn.ModuleList( [ block_fn( dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, proj_bias=proj_bias, ffn_bias=ffn_bias, init_values=init_values, qk_norm=qk_norm, rope=self.rope, ) for _ in range(depth) ] ) self.depth = depth self.aa_order = aa_order self.patch_size = patch_size self.aa_block_size = aa_block_size # Validate that depth is divisible by aa_block_size if self.depth % self.aa_block_size != 0: raise ValueError(f"depth ({depth}) must be divisible by aa_block_size ({aa_block_size})") self.aa_block_num = self.depth // self.aa_block_size # Note: We have two camera tokens, one for the first frame and one for the rest # The same applies for register tokens self.camera_token = nn.Parameter(torch.randn(1, 2, 1, embed_dim)) self.register_token = nn.Parameter(torch.randn(1, 2, num_register_tokens, embed_dim)) # The patch tokens start after the camera and register tokens self.patch_start_idx = 1 + num_register_tokens # Initialize parameters with small values nn.init.normal_(self.camera_token, std=1e-6) nn.init.normal_(self.register_token, std=1e-6) # Register normalization constants as buffers for name, value in ( ("_resnet_mean", _RESNET_MEAN), ("_resnet_std", _RESNET_STD), ): self.register_buffer( name, torch.FloatTensor(value).view(1, 1, 3, 1, 1), persistent=False, ) def __build_patch_embed__( self, patch_embed, img_size, patch_size, num_register_tokens, interpolate_antialias=True, interpolate_offset=0.0, block_chunks=0, init_values=1.0, embed_dim=1024, ): """ Build the patch embed layer. If 'conv', we use a simple PatchEmbed conv layer. Otherwise, we use a vision transformer. """ if "conv" in patch_embed: self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size, in_chans=3, embed_dim=embed_dim) else: vit_models = { "dinov2_vitl14_reg": vit_large, "dinov2_vitb14_reg": vit_base, "dinov2_vits14_reg": vit_small, "dinov2_vitg2_reg": vit_giant2, } self.patch_embed = vit_models[patch_embed]( img_size=img_size, patch_size=patch_size, num_register_tokens=num_register_tokens, interpolate_antialias=interpolate_antialias, interpolate_offset=interpolate_offset, block_chunks=block_chunks, init_values=init_values, ) # Disable gradient updates for mask token if hasattr(self.patch_embed, "mask_token"): self.patch_embed.mask_token.requires_grad_(False) def forward( self, images: torch.Tensor, ) -> Tuple[List[torch.Tensor], int]: """ Args: images (torch.Tensor): Input images with shape [B, S, 3, H, W], in range [0, 1]. 
B: batch size, S: sequence length, 3: RGB channels, H: height, W: width Returns: (list[torch.Tensor], int): The list of outputs from the attention blocks, and the patch_start_idx indicating where patch tokens begin. """ B, S, C_in, H, W = images.shape if C_in != 3: raise ValueError(f"Expected 3 input channels, got {C_in}") # Normalize images and reshape for patch embed images = (images - self._resnet_mean) / self._resnet_std # Reshape to [B*S, C, H, W] for patch embedding images = images.view(B * S, C_in, H, W) patch_tokens = self.patch_embed(images) if isinstance(patch_tokens, dict): patch_tokens = patch_tokens["x_norm_patchtokens"] _, P, C = patch_tokens.shape # Expand camera and register tokens to match batch size and sequence length camera_token = slice_expand_and_flatten(self.camera_token, B, S) register_token = slice_expand_and_flatten(self.register_token, B, S) # Concatenate special tokens with patch tokens tokens = torch.cat([camera_token, register_token, patch_tokens], dim=1) pos = None if self.rope is not None: pos = self.position_getter(B * S, H // self.patch_size, W // self.patch_size, device=images.device) if self.patch_start_idx > 0: # do not use position embedding for special tokens (camera and register tokens) # so set pos to 0 for the special tokens pos = pos + 1 pos_special = torch.zeros(B * S, self.patch_start_idx, 2).to(images.device).to(pos.dtype) pos = torch.cat([pos_special, pos], dim=1) # update P because we added special tokens _, P, C = tokens.shape frame_idx = 0 global_idx = 0 output_list = [] for _ in range(self.aa_block_num): for attn_type in self.aa_order: if attn_type == "frame": tokens, frame_idx, frame_intermediates = self._process_frame_attention( tokens, B, S, P, C, frame_idx, pos=pos ) elif attn_type == "global": tokens, global_idx, global_intermediates = self._process_global_attention( tokens, B, S, P, C, global_idx, pos=pos ) else: raise ValueError(f"Unknown attention type: {attn_type}") for i in range(len(frame_intermediates)): # concat frame and global intermediates, [B x S x P x 2C] concat_inter = torch.cat([frame_intermediates[i], global_intermediates[i]], dim=-1) output_list.append(concat_inter) del concat_inter del frame_intermediates del global_intermediates return output_list, self.patch_start_idx def _process_frame_attention(self, tokens, B, S, P, C, frame_idx, pos=None): """ Process frame attention blocks. We keep tokens in shape (B*S, P, C). """ # If needed, reshape tokens or positions: if tokens.shape != (B * S, P, C): tokens = tokens.view(B, S, P, C).view(B * S, P, C) if pos is not None and pos.shape != (B * S, P, 2): pos = pos.view(B, S, P, 2).view(B * S, P, 2) intermediates = [] # by default, self.aa_block_size=1, which processes one block at a time for _ in range(self.aa_block_size): tokens = self.frame_blocks[frame_idx](tokens, pos=pos) frame_idx += 1 intermediates.append(tokens.view(B, S, P, C)) return tokens, frame_idx, intermediates def _process_global_attention(self, tokens, B, S, P, C, global_idx, pos=None): """ Process global attention blocks. We keep tokens in shape (B, S*P, C). 
""" if tokens.shape != (B, S * P, C): tokens = tokens.view(B, S, P, C).view(B, S * P, C) if pos is not None and pos.shape != (B, S * P, 2): pos = pos.view(B, S, P, 2).view(B, S * P, 2) intermediates = [] # by default, self.aa_block_size=1, which processes one block at a time for _ in range(self.aa_block_size): tokens = self.global_blocks[global_idx](tokens, pos=pos) global_idx += 1 intermediates.append(tokens.view(B, S, P, C)) return tokens, global_idx, intermediates def slice_expand_and_flatten(token_tensor, B, S): """ Processes specialized tokens with shape (1, 2, X, C) for multi-frame processing: 1) Uses the first position (index=0) for the first frame only 2) Uses the second position (index=1) for all remaining frames (S-1 frames) 3) Expands both to match batch size B 4) Concatenates to form (B, S, X, C) where each sequence has 1 first-position token followed by (S-1) second-position tokens 5) Flattens to (B*S, X, C) for processing Returns: torch.Tensor: Processed tokens with shape (B*S, X, C) """ # Slice out the "query" tokens => shape (1, 1, ...) query = token_tensor[:, 0:1, ...].expand(B, 1, *token_tensor.shape[2:]) # Slice out the "other" tokens => shape (1, S-1, ...) others = token_tensor[:, 1:, ...].expand(B, S - 1, *token_tensor.shape[2:]) # Concatenate => shape (B, S, ...) combined = torch.cat([query, others], dim=1) # Finally flatten => shape (B*S, ...) combined = combined.view(B * S, *combined.shape[2:]) return combined ================================================ FILE: mvtracker/models/core/vggt/models/vggt.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. import torch import torch.nn as nn from huggingface_hub import PyTorchModelHubMixin # used for model hub from ..models.aggregator import Aggregator from ..heads.camera_head import CameraHead from ..heads.dpt_head import DPTHead from ..heads.track_head import TrackHead class VGGT(nn.Module, PyTorchModelHubMixin): def __init__(self, img_size=518, patch_size=14, embed_dim=1024): super().__init__() self.aggregator = Aggregator(img_size=img_size, patch_size=patch_size, embed_dim=embed_dim) self.camera_head = CameraHead(dim_in=2 * embed_dim) self.point_head = DPTHead(dim_in=2 * embed_dim, output_dim=4, activation="inv_log", conf_activation="expp1") self.depth_head = DPTHead(dim_in=2 * embed_dim, output_dim=2, activation="exp", conf_activation="expp1") self.track_head = TrackHead(dim_in=2 * embed_dim, patch_size=patch_size) def forward( self, images: torch.Tensor, query_points: torch.Tensor = None, ): """ Forward pass of the VGGT model. Args: images (torch.Tensor): Input images with shape [S, 3, H, W] or [B, S, 3, H, W], in range [0, 1]. B: batch size, S: sequence length, 3: RGB channels, H: height, W: width query_points (torch.Tensor, optional): Query points for tracking, in pixel coordinates. Shape: [N, 2] or [B, N, 2], where N is the number of query points. 
Default: None Returns: dict: A dictionary containing the following predictions: - pose_enc (torch.Tensor): Camera pose encoding with shape [B, S, 9] (from the last iteration) - depth (torch.Tensor): Predicted depth maps with shape [B, S, H, W, 1] - depth_conf (torch.Tensor): Confidence scores for depth predictions with shape [B, S, H, W] - world_points (torch.Tensor): 3D world coordinates for each pixel with shape [B, S, H, W, 3] - world_points_conf (torch.Tensor): Confidence scores for world points with shape [B, S, H, W] - images (torch.Tensor): Original input images, preserved for visualization If query_points is provided, also includes: - track (torch.Tensor): Point tracks with shape [B, S, N, 2] (from the last iteration), in pixel coordinates - vis (torch.Tensor): Visibility scores for tracked points with shape [B, S, N] - conf (torch.Tensor): Confidence scores for tracked points with shape [B, S, N] """ # If without batch dimension, add it if len(images.shape) == 4: images = images.unsqueeze(0) if query_points is not None and len(query_points.shape) == 2: query_points = query_points.unsqueeze(0) aggregated_tokens_list, patch_start_idx = self.aggregator(images) predictions = {} with torch.cuda.amp.autocast(enabled=False): if self.camera_head is not None: pose_enc_list = self.camera_head(aggregated_tokens_list) predictions["pose_enc"] = pose_enc_list[-1] # pose encoding of the last iteration if self.depth_head is not None: depth, depth_conf = self.depth_head( aggregated_tokens_list, images=images, patch_start_idx=patch_start_idx ) predictions["depth"] = depth predictions["depth_conf"] = depth_conf if self.point_head is not None: pts3d, pts3d_conf = self.point_head( aggregated_tokens_list, images=images, patch_start_idx=patch_start_idx ) predictions["world_points"] = pts3d predictions["world_points_conf"] = pts3d_conf if self.track_head is not None and query_points is not None: track_list, vis, conf = self.track_head( aggregated_tokens_list, images=images, patch_start_idx=patch_start_idx, query_points=query_points ) predictions["track"] = track_list[-1] # track of the last iteration predictions["vis"] = vis predictions["conf"] = conf predictions["images"] = images return predictions ================================================ FILE: mvtracker/models/core/vggt/utils/geometry.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. import os import torch import numpy as np def unproject_depth_map_to_point_map( depth_map: np.ndarray, extrinsics_cam: np.ndarray, intrinsics_cam: np.ndarray ) -> np.ndarray: """ Unproject a batch of depth maps to 3D world coordinates. 
Args: depth_map (np.ndarray): Batch of depth maps of shape (S, H, W, 1) or (S, H, W) extrinsics_cam (np.ndarray): Batch of camera extrinsic matrices of shape (S, 3, 4) intrinsics_cam (np.ndarray): Batch of camera intrinsic matrices of shape (S, 3, 3) Returns: np.ndarray: Batch of 3D world coordinates of shape (S, H, W, 3) """ if isinstance(depth_map, torch.Tensor): depth_map = depth_map.cpu().numpy() if isinstance(extrinsics_cam, torch.Tensor): extrinsics_cam = extrinsics_cam.cpu().numpy() if isinstance(intrinsics_cam, torch.Tensor): intrinsics_cam = intrinsics_cam.cpu().numpy() world_points_list = [] for frame_idx in range(depth_map.shape[0]): cur_world_points, _, _ = depth_to_world_coords_points( depth_map[frame_idx].squeeze(-1), extrinsics_cam[frame_idx], intrinsics_cam[frame_idx] ) world_points_list.append(cur_world_points) world_points_array = np.stack(world_points_list, axis=0) return world_points_array def depth_to_world_coords_points( depth_map: np.ndarray, extrinsic: np.ndarray, intrinsic: np.ndarray, eps=1e-8, ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: """ Convert a depth map to world coordinates. Args: depth_map (np.ndarray): Depth map of shape (H, W). intrinsic (np.ndarray): Camera intrinsic matrix of shape (3, 3). extrinsic (np.ndarray): Camera extrinsic matrix of shape (3, 4). OpenCV camera coordinate convention, cam from world. Returns: tuple[np.ndarray, np.ndarray]: World coordinates (H, W, 3) and valid depth mask (H, W). """ if depth_map is None: return None, None, None # Valid depth mask point_mask = depth_map > eps # Convert depth map to camera coordinates cam_coords_points = depth_to_cam_coords_points(depth_map, intrinsic) # Multiply with the inverse of extrinsic matrix to transform to world coordinates # extrinsic_inv is 4x4 (note closed_form_inverse_OpenCV is batched, the output is (N, 4, 4)) cam_to_world_extrinsic = closed_form_inverse_se3(extrinsic[None])[0] R_cam_to_world = cam_to_world_extrinsic[:3, :3] t_cam_to_world = cam_to_world_extrinsic[:3, 3] # Apply the rotation and translation to the camera coordinates world_coords_points = np.dot(cam_coords_points, R_cam_to_world.T) + t_cam_to_world # HxWx3, 3x3 -> HxWx3 # world_coords_points = np.einsum("ij,hwj->hwi", R_cam_to_world, cam_coords_points) + t_cam_to_world return world_coords_points, cam_coords_points, point_mask def depth_to_cam_coords_points(depth_map: np.ndarray, intrinsic: np.ndarray) -> tuple[np.ndarray, np.ndarray]: """ Convert a depth map to camera coordinates. Args: depth_map (np.ndarray): Depth map of shape (H, W). intrinsic (np.ndarray): Camera intrinsic matrix of shape (3, 3). Returns: tuple[np.ndarray, np.ndarray]: Camera coordinates (H, W, 3) """ H, W = depth_map.shape assert intrinsic.shape == (3, 3), "Intrinsic matrix must be 3x3" assert intrinsic[0, 1] == 0 and intrinsic[1, 0] == 0, "Intrinsic matrix must have zero skew" # Intrinsic parameters fu, fv = intrinsic[0, 0], intrinsic[1, 1] cu, cv = intrinsic[0, 2], intrinsic[1, 2] # Generate grid of pixel coordinates u, v = np.meshgrid(np.arange(W), np.arange(H)) # Unproject to camera coordinates x_cam = (u - cu) * depth_map / fu y_cam = (v - cv) * depth_map / fv z_cam = depth_map # Stack to form camera coordinates cam_coords = np.stack((x_cam, y_cam, z_cam), axis=-1).astype(np.float32) return cam_coords def closed_form_inverse_se3(se3, R=None, T=None): """ Compute the inverse of each 4x4 (or 3x4) SE3 matrix in a batch. If `R` and `T` are provided, they must correspond to the rotation and translation components of `se3`. 
Otherwise, they will be extracted from `se3`. Args: se3: Nx4x4 or Nx3x4 array or tensor of SE3 matrices. R (optional): Nx3x3 array or tensor of rotation matrices. T (optional): Nx3x1 array or tensor of translation vectors. Returns: Inverted SE3 matrices with the same type and device as `se3`. Shapes: se3: (N, 4, 4) R: (N, 3, 3) T: (N, 3, 1) """ # Check if se3 is a numpy array or a torch tensor is_numpy = isinstance(se3, np.ndarray) # Validate shapes if se3.shape[-2:] != (4, 4) and se3.shape[-2:] != (3, 4): raise ValueError(f"se3 must be of shape (N,4,4), got {se3.shape}.") # Extract R and T if not provided if R is None: R = se3[:, :3, :3] # (N,3,3) if T is None: T = se3[:, :3, 3:] # (N,3,1) # Transpose R if is_numpy: # Compute the transpose of the rotation for NumPy R_transposed = np.transpose(R, (0, 2, 1)) # -R^T t for NumPy top_right = -np.matmul(R_transposed, T) inverted_matrix = np.tile(np.eye(4), (len(R), 1, 1)) else: R_transposed = R.transpose(1, 2) # (N,3,3) top_right = -torch.bmm(R_transposed, T) # (N,3,1) inverted_matrix = torch.eye(4, 4)[None].repeat(len(R), 1, 1) inverted_matrix = inverted_matrix.to(R.dtype).to(R.device) inverted_matrix[:, :3, :3] = R_transposed inverted_matrix[:, :3, 3:] = top_right return inverted_matrix ================================================ FILE: mvtracker/models/core/vggt/utils/load_fn.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. import torch from PIL import Image from torchvision import transforms as TF def load_and_preprocess_images(image_path_list, mode="crop"): """ A quick start function to load and preprocess images for model input. This assumes the images should have the same shape for easier batching, but our model can also work well with different shapes. Args: image_path_list (list): List of paths to image files mode (str, optional): Preprocessing mode, either "crop" or "pad". - "crop" (default): Sets width to 518px and center crops height if needed. - "pad": Preserves all pixels by making the largest dimension 518px and padding the smaller dimension to reach a square shape. 
Returns: torch.Tensor: Batched tensor of preprocessed images with shape (N, 3, H, W) Raises: ValueError: If the input list is empty or if mode is invalid Notes: - Images with different dimensions will be padded with white (value=1.0) - A warning is printed when images have different shapes - When mode="crop": The function ensures width=518px while maintaining aspect ratio and height is center-cropped if larger than 518px - When mode="pad": The function ensures the largest dimension is 518px while maintaining aspect ratio and the smaller dimension is padded to reach a square shape (518x518) - Dimensions are adjusted to be divisible by 14 for compatibility with model requirements """ # Check for empty list if len(image_path_list) == 0: raise ValueError("At least 1 image is required") # Validate mode if mode not in ["crop", "pad"]: raise ValueError("Mode must be either 'crop' or 'pad'") images = [] shapes = set() to_tensor = TF.ToTensor() target_size = 518 # First process all images and collect their shapes for image_path in image_path_list: # Open image img = Image.open(image_path) # If there's an alpha channel, blend onto white background: if img.mode == "RGBA": # Create white background background = Image.new("RGBA", img.size, (255, 255, 255, 255)) # Alpha composite onto the white background img = Image.alpha_composite(background, img) # Now convert to "RGB" (this step assigns white for transparent areas) img = img.convert("RGB") width, height = img.size if mode == "pad": # Make the largest dimension 518px while maintaining aspect ratio if width >= height: new_width = target_size new_height = round(height * (new_width / width) / 14) * 14 # Make divisible by 14 else: new_height = target_size new_width = round(width * (new_height / height) / 14) * 14 # Make divisible by 14 else: # mode == "crop" # Original behavior: set width to 518px new_width = target_size # Calculate height maintaining aspect ratio, divisible by 14 new_height = round(height * (new_width / width) / 14) * 14 # Resize with new dimensions (width, height) img = img.resize((new_width, new_height), Image.Resampling.BICUBIC) img = to_tensor(img) # Convert to tensor (0, 1) # Center crop height if it's larger than 518 (only in crop mode) if mode == "crop" and new_height > target_size: start_y = (new_height - target_size) // 2 img = img[:, start_y : start_y + target_size, :] # For pad mode, pad to make a square of target_size x target_size if mode == "pad": h_padding = target_size - img.shape[1] w_padding = target_size - img.shape[2] if h_padding > 0 or w_padding > 0: pad_top = h_padding // 2 pad_bottom = h_padding - pad_top pad_left = w_padding // 2 pad_right = w_padding - pad_left # Pad with white (value=1.0) img = torch.nn.functional.pad( img, (pad_left, pad_right, pad_top, pad_bottom), mode="constant", value=1.0 ) shapes.add((img.shape[1], img.shape[2])) images.append(img) # Check if we have different shapes # In theory our model can also work well with different shapes if len(shapes) > 1: print(f"Warning: Found images with different shapes: {shapes}") # Find maximum dimensions max_height = max(shape[0] for shape in shapes) max_width = max(shape[1] for shape in shapes) # Pad images if necessary padded_images = [] for img in images: h_padding = max_height - img.shape[1] w_padding = max_width - img.shape[2] if h_padding > 0 or w_padding > 0: pad_top = h_padding // 2 pad_bottom = h_padding - pad_top pad_left = w_padding // 2 pad_right = w_padding - pad_left img = torch.nn.functional.pad( img, (pad_left, pad_right, pad_top, 
pad_bottom), mode="constant", value=1.0 ) padded_images.append(img) images = padded_images images = torch.stack(images) # concatenate images # Ensure correct shape when single image if len(image_path_list) == 1: # Verify shape is (1, C, H, W) if images.dim() == 3: images = images.unsqueeze(0) return images ================================================ FILE: mvtracker/models/core/vggt/utils/pose_enc.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. import torch from .rotation import quat_to_mat, mat_to_quat def extri_intri_to_pose_encoding( extrinsics, intrinsics, image_size_hw=None, # e.g., (256, 512) pose_encoding_type="absT_quaR_FoV", ): """Convert camera extrinsics and intrinsics to a compact pose encoding. This function transforms camera parameters into a unified pose encoding format, which can be used for various downstream tasks like pose prediction or representation. Args: extrinsics (torch.Tensor): Camera extrinsic parameters with shape BxSx3x4, where B is batch size and S is sequence length. In OpenCV coordinate system (x-right, y-down, z-forward), representing camera from world transformation. The format is [R|t] where R is a 3x3 rotation matrix and t is a 3x1 translation vector. intrinsics (torch.Tensor): Camera intrinsic parameters with shape BxSx3x3. Defined in pixels, with format: [[fx, 0, cx], [0, fy, cy], [0, 0, 1]] where fx, fy are focal lengths and (cx, cy) is the principal point image_size_hw (tuple): Tuple of (height, width) of the image in pixels. Required for computing field of view values. For example: (256, 512). pose_encoding_type (str): Type of pose encoding to use. Currently only supports "absT_quaR_FoV" (absolute translation, quaternion rotation, field of view). Returns: torch.Tensor: Encoded camera pose parameters with shape BxSx9. For "absT_quaR_FoV" type, the 9 dimensions are: - [:3] = absolute translation vector T (3D) - [3:7] = rotation as quaternion quat (4D) - [7:] = field of view (2D) """ # extrinsics: BxSx3x4 # intrinsics: BxSx3x3 if pose_encoding_type == "absT_quaR_FoV": R = extrinsics[:, :, :3, :3] # BxSx3x3 T = extrinsics[:, :, :3, 3] # BxSx3 quat = mat_to_quat(R) # Note the order of h and w here H, W = image_size_hw fov_h = 2 * torch.atan((H / 2) / intrinsics[..., 1, 1]) fov_w = 2 * torch.atan((W / 2) / intrinsics[..., 0, 0]) pose_encoding = torch.cat([T, quat, fov_h[..., None], fov_w[..., None]], dim=-1).float() else: raise NotImplementedError return pose_encoding def pose_encoding_to_extri_intri( pose_encoding, image_size_hw=None, # e.g., (256, 512) pose_encoding_type="absT_quaR_FoV", build_intrinsics=True, ): """Convert a pose encoding back to camera extrinsics and intrinsics. This function performs the inverse operation of extri_intri_to_pose_encoding, reconstructing the full camera parameters from the compact encoding. Args: pose_encoding (torch.Tensor): Encoded camera pose parameters with shape BxSx9, where B is batch size and S is sequence length. For "absT_quaR_FoV" type, the 9 dimensions are: - [:3] = absolute translation vector T (3D) - [3:7] = rotation as quaternion quat (4D) - [7:] = field of view (2D) image_size_hw (tuple): Tuple of (height, width) of the image in pixels. Required for reconstructing intrinsics from field of view values. For example: (256, 512). pose_encoding_type (str): Type of pose encoding used. 
Currently only supports "absT_quaR_FoV" (absolute translation, quaternion rotation, field of view). build_intrinsics (bool): Whether to reconstruct the intrinsics matrix. If False, only extrinsics are returned and intrinsics will be None. Returns: tuple: (extrinsics, intrinsics) - extrinsics (torch.Tensor): Camera extrinsic parameters with shape BxSx3x4. In OpenCV coordinate system (x-right, y-down, z-forward), representing camera from world transformation. The format is [R|t] where R is a 3x3 rotation matrix and t is a 3x1 translation vector. - intrinsics (torch.Tensor or None): Camera intrinsic parameters with shape BxSx3x3, or None if build_intrinsics is False. Defined in pixels, with format: [[fx, 0, cx], [0, fy, cy], [0, 0, 1]] where fx, fy are focal lengths and (cx, cy) is the principal point, assumed to be at the center of the image (W/2, H/2). """ intrinsics = None if pose_encoding_type == "absT_quaR_FoV": T = pose_encoding[..., :3] quat = pose_encoding[..., 3:7] fov_h = pose_encoding[..., 7] fov_w = pose_encoding[..., 8] R = quat_to_mat(quat) extrinsics = torch.cat([R, T[..., None]], dim=-1) if build_intrinsics: H, W = image_size_hw fy = (H / 2.0) / torch.tan(fov_h / 2.0) fx = (W / 2.0) / torch.tan(fov_w / 2.0) intrinsics = torch.zeros(pose_encoding.shape[:2] + (3, 3), device=pose_encoding.device) intrinsics[..., 0, 0] = fx intrinsics[..., 1, 1] = fy intrinsics[..., 0, 2] = W / 2 intrinsics[..., 1, 2] = H / 2 intrinsics[..., 2, 2] = 1.0 # Set the homogeneous coordinate to 1 else: raise NotImplementedError return extrinsics, intrinsics ================================================ FILE: mvtracker/models/core/vggt/utils/rotation.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. # Modified from PyTorch3D, https://github.com/facebookresearch/pytorch3d import torch import numpy as np import torch.nn.functional as F def quat_to_mat(quaternions: torch.Tensor) -> torch.Tensor: """ Quaternion Order: XYZW or say ijkr, scalar-last Convert rotations given as quaternions to rotation matrices. Args: quaternions: quaternions with real part last, as tensor of shape (..., 4). Returns: Rotation matrices as tensor of shape (..., 3, 3). """ i, j, k, r = torch.unbind(quaternions, -1) # pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`. two_s = 2.0 / (quaternions * quaternions).sum(-1) o = torch.stack( ( 1 - two_s * (j * j + k * k), two_s * (i * j - k * r), two_s * (i * k + j * r), two_s * (i * j + k * r), 1 - two_s * (i * i + k * k), two_s * (j * k - i * r), two_s * (i * k - j * r), two_s * (j * k + i * r), 1 - two_s * (i * i + j * j), ), -1, ) return o.reshape(quaternions.shape[:-1] + (3, 3)) def mat_to_quat(matrix: torch.Tensor) -> torch.Tensor: """ Convert rotations given as rotation matrices to quaternions. Args: matrix: Rotation matrices as tensor of shape (..., 3, 3). Returns: quaternions with real part last, as tensor of shape (..., 4). 
Quaternion Order: XYZW or say ijkr, scalar-last """ if matrix.size(-1) != 3 or matrix.size(-2) != 3: raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.") batch_dim = matrix.shape[:-2] m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(matrix.reshape(batch_dim + (9,)), dim=-1) q_abs = _sqrt_positive_part( torch.stack( [ 1.0 + m00 + m11 + m22, 1.0 + m00 - m11 - m22, 1.0 - m00 + m11 - m22, 1.0 - m00 - m11 + m22, ], dim=-1, ) ) # we produce the desired quaternion multiplied by each of r, i, j, k quat_by_rijk = torch.stack( [ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and # `int`. torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1), # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and # `int`. torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1), # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and # `int`. torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1), # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and # `int`. torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1), ], dim=-2, ) # We floor here at 0.1 but the exact level is not important; if q_abs is small, # the candidate won't be picked. flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device) quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr)) # if not for numerical problems, quat_candidates[i] should be same (up to a sign), # forall i; we pick the best-conditioned one (with the largest denominator) out = quat_candidates[F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :].reshape(batch_dim + (4,)) # Convert from rijk to ijkr out = out[..., [1, 2, 3, 0]] out = standardize_quaternion(out) return out def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor: """ Returns torch.sqrt(torch.max(0, x)) but with a zero subgradient where x is 0. """ ret = torch.zeros_like(x) positive_mask = x > 0 if torch.is_grad_enabled(): ret[positive_mask] = torch.sqrt(x[positive_mask]) else: ret = torch.where(positive_mask, torch.sqrt(x), ret) return ret def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor: """ Convert a unit quaternion to a standard form: one in which the real part is non negative. Args: quaternions: Quaternions with real part last, as tensor of shape (..., 4). Returns: Standardized quaternions as tensor of shape (..., 4). """ return torch.where(quaternions[..., 3:4] < 0, -quaternions, quaternions) ================================================ FILE: mvtracker/models/core/vggt/utils/visual_track.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. import cv2 import torch import numpy as np import os def color_from_xy(x, y, W, H, cmap_name="hsv"): """ Map (x, y) -> color in (R, G, B). 1) Normalize x,y to [0,1]. 2) Combine them into a single scalar c in [0,1]. 3) Use matplotlib's colormap to convert c -> (R,G,B). You can customize step 2, e.g., c = (x + y)/2, or some function of (x, y). 
""" import matplotlib.cm import matplotlib.colors x_norm = x / max(W - 1, 1) y_norm = y / max(H - 1, 1) # Simple combination: c = (x_norm + y_norm) / 2.0 cmap = matplotlib.cm.get_cmap(cmap_name) # cmap(c) -> (r,g,b,a) in [0,1] rgba = cmap(c) r, g, b = rgba[0], rgba[1], rgba[2] return (r, g, b) # in [0,1], RGB order def get_track_colors_by_position(tracks_b, vis_mask_b=None, image_width=None, image_height=None, cmap_name="hsv"): """ Given all tracks in one sample (b), compute a (N,3) array of RGB color values in [0,255]. The color is determined by the (x,y) position in the first visible frame for each track. Args: tracks_b: Tensor of shape (S, N, 2). (x,y) for each track in each frame. vis_mask_b: (S, N) boolean mask; if None, assume all are visible. image_width, image_height: used for normalizing (x, y). cmap_name: for matplotlib (e.g., 'hsv', 'rainbow', 'jet'). Returns: track_colors: np.ndarray of shape (N, 3), each row is (R,G,B) in [0,255]. """ S, N, _ = tracks_b.shape track_colors = np.zeros((N, 3), dtype=np.uint8) if vis_mask_b is None: # treat all as visible vis_mask_b = torch.ones(S, N, dtype=torch.bool, device=tracks_b.device) for i in range(N): # Find first visible frame for track i visible_frames = torch.where(vis_mask_b[:, i])[0] if len(visible_frames) == 0: # track is never visible; just assign black or something track_colors[i] = (0, 0, 0) continue first_s = int(visible_frames[0].item()) # use that frame's (x,y) x, y = tracks_b[first_s, i].tolist() # map (x,y) -> (R,G,B) in [0,1] r, g, b = color_from_xy(x, y, W=image_width, H=image_height, cmap_name=cmap_name) # scale to [0,255] r, g, b = int(r * 255), int(g * 255), int(b * 255) track_colors[i] = (r, g, b) return track_colors def visualize_tracks_on_images( images, tracks, track_vis_mask=None, out_dir="track_visuals_concat_by_xy", image_format="CHW", # "CHW" or "HWC" normalize_mode="[0,1]", cmap_name="hsv", # e.g. "hsv", "rainbow", "jet" frames_per_row=4, # New parameter for grid layout save_grid=True, # Flag to control whether to save the grid image ): """ Visualizes frames in a grid layout with specified frames per row. Each track's color is determined by its (x,y) position in the first visible frame (or frame 0 if always visible). Finally convert the BGR result to RGB before saving. Also saves each individual frame as a separate PNG file. Args: images: torch.Tensor (S, 3, H, W) if CHW or (S, H, W, 3) if HWC. tracks: torch.Tensor (S, N, 2), last dim = (x, y). track_vis_mask: torch.Tensor (S, N) or None. out_dir: folder to save visualizations. image_format: "CHW" or "HWC". normalize_mode: "[0,1]", "[-1,1]", or None for direct raw -> 0..255 cmap_name: a matplotlib colormap name for color_from_xy. frames_per_row: number of frames to display in each row of the grid. save_grid: whether to save all frames in one grid image. Returns: None (saves images in out_dir). """ if len(tracks.shape) == 4: tracks = tracks.squeeze(0) images = images.squeeze(0) if track_vis_mask is not None: track_vis_mask = track_vis_mask.squeeze(0) import matplotlib matplotlib.use("Agg") # for non-interactive (optional) os.makedirs(out_dir, exist_ok=True) S = images.shape[0] _, N, _ = tracks.shape # (S, N, 2) # Move to CPU images = images.cpu().clone() tracks = tracks.cpu().clone() if track_vis_mask is not None: track_vis_mask = track_vis_mask.cpu().clone() # Infer H, W from images shape if image_format == "CHW": # e.g. images[s].shape = (3, H, W) H, W = images.shape[2], images.shape[3] else: # e.g. 
images[s].shape = (H, W, 3) H, W = images.shape[1], images.shape[2] # Pre-compute the color for each track i based on first visible position track_colors_rgb = get_track_colors_by_position( tracks, # shape (S, N, 2) vis_mask_b=track_vis_mask if track_vis_mask is not None else None, image_width=W, image_height=H, cmap_name=cmap_name, ) # We'll accumulate each frame's drawn image in a list frame_images = [] for s in range(S): # shape => either (3, H, W) or (H, W, 3) img = images[s] # Convert to (H, W, 3) if image_format == "CHW": img = img.permute(1, 2, 0) # (H, W, 3) # else "HWC", do nothing img = img.numpy().astype(np.float32) # Scale to [0,255] if needed if normalize_mode == "[0,1]": img = np.clip(img, 0, 1) * 255.0 elif normalize_mode == "[-1,1]": img = (img + 1.0) * 0.5 * 255.0 img = np.clip(img, 0, 255.0) # else no normalization # Convert to uint8 img = img.astype(np.uint8) # For drawing in OpenCV, convert to BGR img_bgr = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) # Draw each visible track cur_tracks = tracks[s] # shape (N, 2) if track_vis_mask is not None: valid_indices = torch.where(track_vis_mask[s])[0] else: valid_indices = range(N) cur_tracks_np = cur_tracks.numpy() for i in valid_indices: x, y = cur_tracks_np[i] pt = (int(round(x)), int(round(y))) # track_colors_rgb[i] is (R,G,B). For OpenCV circle, we need BGR R, G, B = track_colors_rgb[i] color_bgr = (int(B), int(G), int(R)) cv2.circle(img_bgr, pt, radius=3, color=color_bgr, thickness=-1) # Convert back to RGB for consistent final saving: img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB) # Save individual frame frame_path = os.path.join(out_dir, f"frame_{s:04d}.png") # Convert to BGR for OpenCV imwrite frame_bgr = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR) cv2.imwrite(frame_path, frame_bgr) frame_images.append(img_rgb) # Only create and save the grid image if save_grid is True if save_grid: # Calculate grid dimensions num_rows = (S + frames_per_row - 1) // frames_per_row # Ceiling division # Create a grid of images grid_img = None for row in range(num_rows): start_idx = row * frames_per_row end_idx = min(start_idx + frames_per_row, S) # Concatenate this row horizontally row_img = np.concatenate(frame_images[start_idx:end_idx], axis=1) # If this row has fewer than frames_per_row images, pad with black if end_idx - start_idx < frames_per_row: padding_width = (frames_per_row - (end_idx - start_idx)) * W padding = np.zeros((H, padding_width, 3), dtype=np.uint8) row_img = np.concatenate([row_img, padding], axis=1) # Add this row to the grid if grid_img is None: grid_img = row_img else: grid_img = np.concatenate([grid_img, row_img], axis=0) out_path = os.path.join(out_dir, "tracks_grid.png") # Convert back to BGR for OpenCV imwrite grid_img_bgr = cv2.cvtColor(grid_img, cv2.COLOR_RGB2BGR) cv2.imwrite(out_path, grid_img_bgr) print(f"[INFO] Saved color-by-XY track visualization grid -> {out_path}") print(f"[INFO] Saved {S} individual frames to {out_dir}/frame_*.png") ================================================ FILE: mvtracker/models/core/vit/__init__.py ================================================ ================================================ FILE: mvtracker/models/core/vit/common.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. 
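
# Minimal sketch, assuming unit weight and zero bias: the LayerNorm2d defined
# below normalizes an NCHW tensor over its channel dimension; the same result
# can be reproduced with F.layer_norm after moving channels last. Shapes are
# illustrative only.
if __name__ == "__main__":
    import torch
    import torch.nn.functional as F

    x = torch.randn(2, 8, 4, 4)  # (N, C, H, W)
    eps = 1e-6

    # Channel-wise normalization, as in LayerNorm2d.forward below
    u = x.mean(1, keepdim=True)
    s = (x - u).pow(2).mean(1, keepdim=True)
    y_manual = (x - u) / torch.sqrt(s + eps)

    # Same statistics via F.layer_norm applied over the channel axis
    y_ref = F.layer_norm(x.permute(0, 2, 3, 1), (8,), eps=eps).permute(0, 3, 1, 2)
    assert torch.allclose(y_manual, y_ref, atol=1e-5)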
from typing import Type import torch import torch.nn as nn class MLPBlock(nn.Module): def __init__( self, embedding_dim: int, mlp_dim: int, act: Type[nn.Module] = nn.GELU, ) -> None: super().__init__() self.lin1 = nn.Linear(embedding_dim, mlp_dim) self.lin2 = nn.Linear(mlp_dim, embedding_dim) self.act = act() def forward(self, x: torch.Tensor) -> torch.Tensor: return self.lin2(self.act(self.lin1(x))) # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa # Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa class LayerNorm2d(nn.Module): def __init__(self, num_channels: int, eps: float = 1e-6) -> None: super().__init__() self.weight = nn.Parameter(torch.ones(num_channels)) self.bias = nn.Parameter(torch.zeros(num_channels)) self.eps = eps def forward(self, x: torch.Tensor) -> torch.Tensor: u = x.mean(1, keepdim=True) s = (x - u).pow(2).mean(1, keepdim=True) x = (x - u) / torch.sqrt(s + self.eps) x = self.weight[:, None, None] * x + self.bias[:, None, None] return x ================================================ FILE: mvtracker/models/core/vit/encoder.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. from typing import Optional, Tuple, Type import torch import torch.nn as nn import torch.nn.functional as F from mvtracker.models.core.vit.common import ( LayerNorm2d, MLPBlock ) # This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa class ImageEncoderViT(nn.Module): def __init__( self, img_size: int = 1024, patch_size: int = 16, in_chans: int = 3, embed_dim: int = 768, depth: int = 12, num_heads: int = 12, mlp_ratio: float = 4.0, out_chans: int = 256, qkv_bias: bool = True, norm_layer: Type[nn.Module] = nn.LayerNorm, act_layer: Type[nn.Module] = nn.GELU, use_abs_pos: bool = True, use_rel_pos: bool = False, rel_pos_zero_init: bool = True, window_size: int = 0, global_attn_indexes: Tuple[int, ...] = (), ) -> None: """ Args: img_size (int): Input image size. patch_size (int): Patch size. in_chans (int): Number of input image channels. embed_dim (int): Patch embedding dimension. depth (int): Depth of ViT. num_heads (int): Number of attention heads in each ViT block. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. qkv_bias (bool): If True, add a learnable bias to query, key, value. norm_layer (nn.Module): Normalization layer. act_layer (nn.Module): Activation layer. use_abs_pos (bool): If True, use absolute positional embeddings. use_rel_pos (bool): If True, add relative positional embeddings to the attention map. rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. window_size (int): Window size for window attention blocks. global_attn_indexes (list): Indexes for blocks using global attention. """ super().__init__() self.img_size = img_size self.patch_embed = PatchEmbed( kernel_size=(patch_size, patch_size), stride=(patch_size, patch_size), in_chans=in_chans, embed_dim=embed_dim, ) self.pos_embed: Optional[nn.Parameter] = None if use_abs_pos: # Initialize absolute positional embedding with pretrain image size. 
self.pos_embed = nn.Parameter( torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim) ) self.blocks = nn.ModuleList() for i in range(depth): block = Block( dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, norm_layer=norm_layer, act_layer=act_layer, use_rel_pos=use_rel_pos, rel_pos_zero_init=rel_pos_zero_init, window_size=window_size if i not in global_attn_indexes else 0, input_size=(img_size // patch_size, img_size // patch_size), ) self.blocks.append(block) self.neck = nn.Sequential( nn.Conv2d( embed_dim, out_chans, kernel_size=1, bias=False, ), LayerNorm2d(out_chans), nn.Conv2d( out_chans, out_chans, kernel_size=3, padding=1, bias=False, ), LayerNorm2d(out_chans), ) def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.patch_embed(x) if self.pos_embed is not None: x = x + self.pos_embed for blk in self.blocks: x = blk(x) x = self.neck(x.permute(0, 3, 1, 2)) return x class Block(nn.Module): """Transformer blocks with support of window attention and residual propagation blocks""" def __init__( self, dim: int, num_heads: int, mlp_ratio: float = 4.0, qkv_bias: bool = True, norm_layer: Type[nn.Module] = nn.LayerNorm, act_layer: Type[nn.Module] = nn.GELU, use_rel_pos: bool = False, rel_pos_zero_init: bool = True, window_size: int = 0, input_size: Optional[Tuple[int, int]] = None, ) -> None: """ Args: dim (int): Number of input channels. num_heads (int): Number of attention heads in each ViT block. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. qkv_bias (bool): If True, add a learnable bias to query, key, value. norm_layer (nn.Module): Normalization layer. act_layer (nn.Module): Activation layer. use_rel_pos (bool): If True, add relative positional embeddings to the attention map. rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. window_size (int): Window size for window attention blocks. If it equals 0, then use global attention. input_size (tuple(int, int) or None): Input resolution for calculating the relative positional parameter size. """ super().__init__() self.norm1 = norm_layer(dim) self.attn = Attention( dim, num_heads=num_heads, qkv_bias=qkv_bias, use_rel_pos=use_rel_pos, rel_pos_zero_init=rel_pos_zero_init, input_size=input_size if window_size == 0 else (window_size, window_size), ) self.norm2 = norm_layer(dim) self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer) self.window_size = window_size def forward(self, x: torch.Tensor) -> torch.Tensor: shortcut = x x = self.norm1(x) # Window partition if self.window_size > 0: H, W = x.shape[1], x.shape[2] x, pad_hw = window_partition(x, self.window_size) x = self.attn(x) # Reverse window partition if self.window_size > 0: x = window_unpartition(x, self.window_size, pad_hw, (H, W)) x = shortcut + x x = x + self.mlp(self.norm2(x)) return x class Attention(nn.Module): """Multi-head Attention block with relative position embeddings.""" def __init__( self, dim: int, num_heads: int = 8, qkv_bias: bool = True, use_rel_pos: bool = False, rel_pos_zero_init: bool = True, input_size: Optional[Tuple[int, int]] = None, ) -> None: """ Args: dim (int): Number of input channels. num_heads (int): Number of attention heads. qkv_bias (bool): If True, add a learnable bias to query, key, value. rel_pos (bool): If True, add relative positional embeddings to the attention map. rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. 
input_size (tuple(int, int) or None): Input resolution for calculating the relative positional parameter size. """ super().__init__() self.num_heads = num_heads head_dim = dim // num_heads self.scale = head_dim ** -0.5 self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) self.proj = nn.Linear(dim, dim) self.use_rel_pos = use_rel_pos if self.use_rel_pos: assert ( input_size is not None ), "Input size must be provided if using relative positional encoding." # initialize relative positional embeddings self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim)) self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim)) def forward(self, x: torch.Tensor) -> torch.Tensor: B, H, W, _ = x.shape # qkv with shape (3, B, nHead, H * W, C) qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) # q, k, v with shape (B * nHead, H * W, C) q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0) attn = (q * self.scale) @ k.transpose(-2, -1) if self.use_rel_pos: attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W)) attn = attn.softmax(dim=-1) x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1) x = self.proj(x) return x def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]: """ Partition into non-overlapping windows with padding if needed. Args: x (tensor): input tokens with [B, H, W, C]. window_size (int): window size. Returns: windows: windows after partition with [B * num_windows, window_size, window_size, C]. (Hp, Wp): padded height and width before partition """ B, H, W, C = x.shape pad_h = (window_size - H % window_size) % window_size pad_w = (window_size - W % window_size) % window_size if pad_h > 0 or pad_w > 0: x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) Hp, Wp = H + pad_h, W + pad_w x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C) windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) return windows, (Hp, Wp) def window_unpartition( windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int] ) -> torch.Tensor: """ Window unpartition into original sequences and removing padding. Args: windows (tensor): input tokens with [B * num_windows, window_size, window_size, C]. window_size (int): window size. pad_hw (Tuple): padded height and width (Hp, Wp). hw (Tuple): original height and width (H, W) before padding. Returns: x: unpartitioned sequences with [B, H, W, C]. """ Hp, Wp = pad_hw H, W = hw B = windows.shape[0] // (Hp * Wp // window_size // window_size) x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1) x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1) if Hp > H or Wp > W: x = x[:, :H, :W, :].contiguous() return x def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor: """ Get relative positional embeddings according to the relative positions of query and key sizes. Args: q_size (int): size of query q. k_size (int): size of key k. rel_pos (Tensor): relative position embeddings (L, C). Returns: Extracted positional embeddings according to relative positions. """ max_rel_dist = int(2 * max(q_size, k_size) - 1) # Interpolate rel pos if needed. if rel_pos.shape[0] != max_rel_dist: # Interpolate rel pos. 
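# Illustrative sketch (not part of the original file): window_partition above pads an
# NHWC feature map up to a multiple of the window size and splits it into
# (B * num_windows, ws, ws, C) tiles; window_unpartition reverses this exactly.
# The shapes below are arbitrary.
import torch

x = torch.randn(2, 30, 45, 64)           # B, H, W, C with H and W not divisible by 14
windows, (Hp, Wp) = window_partition(x, window_size=14)
assert (Hp, Wp) == (42, 56)               # padded up to multiples of the window size
assert windows.shape == (2 * (42 // 14) * (56 // 14), 14, 14, 64)
x_rec = window_unpartition(windows, 14, (Hp, Wp), (30, 45))
assert torch.equal(x_rec, x)              # padding is cropped away, round trip is exact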
rel_pos_resized = F.interpolate( rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), size=max_rel_dist, mode="linear", ) rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0) else: rel_pos_resized = rel_pos # Scale the coords with short length if shapes for q and k are different. q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0) k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0) relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0) return rel_pos_resized[relative_coords.long()] def add_decomposed_rel_pos( attn: torch.Tensor, q: torch.Tensor, rel_pos_h: torch.Tensor, rel_pos_w: torch.Tensor, q_size: Tuple[int, int], k_size: Tuple[int, int], ) -> torch.Tensor: """ Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`. https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950 Args: attn (Tensor): attention map. q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C). rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis. rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis. q_size (Tuple): spatial sequence size of query q with (q_h, q_w). k_size (Tuple): spatial sequence size of key k with (k_h, k_w). Returns: attn (Tensor): attention map with added relative positional embeddings. """ q_h, q_w = q_size k_h, k_w = k_size Rh = get_rel_pos(q_h, k_h, rel_pos_h) Rw = get_rel_pos(q_w, k_w, rel_pos_w) B, _, dim = q.shape r_q = q.reshape(B, q_h, q_w, dim) rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh) rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw) attn = ( attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :] ).view(B, q_h * q_w, k_h * k_w) return attn class PatchEmbed(nn.Module): """ Image to Patch Embedding. """ def __init__( self, kernel_size: Tuple[int, int] = (16, 16), stride: Tuple[int, int] = (16, 16), padding: Tuple[int, int] = (0, 0), in_chans: int = 3, embed_dim: int = 768, ) -> None: """ Args: kernel_size (Tuple): kernel size of the projection layer. stride (Tuple): stride of the projection layer. padding (Tuple): padding size of the projection layer. in_chans (int): Number of input image channels. embed_dim (int): Patch embedding dimension. 
""" super().__init__() self.proj = nn.Conv2d( in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding ) def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.proj(x) # B C H W -> B H W C x = x.permute(0, 2, 3, 1) return x ================================================ FILE: mvtracker/models/evaluation_predictor_3dpt.py ================================================ import os import random from typing import Optional, Tuple import numpy as np import torch import torch.nn.functional as F from tqdm import tqdm from mvtracker.models.core.model_utils import bilinear_sample2d, get_points_on_a_grid from mvtracker.models.core.model_utils import world_space_to_pixel_xy_and_camera_z from mvtracker.models.core.mvtracker.mvtracker import save_pointcloud_to_ply from mvtracker.utils.basic import to_homogeneous, from_homogeneous, time_now from mvtracker.utils.visualizer_mp4 import MultiViewVisualizer class EvaluationPredictor(torch.nn.Module): def __init__( self, multiview_model: torch.nn.Module, interp_shape: Optional[Tuple[int, int]] = (384, 512), visibility_threshold=0.5, grid_size: int = 5, n_grids_per_view: int = 1, local_grid_size: int = 8, local_extent: int = 50, single_point: bool = False, sift_size: int = 0, num_uniformly_sampled_pts: int = 0, n_iters: int = 6, ) -> None: super(EvaluationPredictor, self).__init__() self.model = multiview_model self.interp_shape = interp_shape self.visibility_threshold = visibility_threshold self.grid_size = grid_size self.n_grids_per_view = n_grids_per_view self.local_grid_size = local_grid_size self.local_extent = local_extent self.single_point = single_point self.sift_size = sift_size self.num_uniformly_sampled_pts = num_uniformly_sampled_pts self.n_iters = n_iters self.model.eval() def forward( self, rgbs, depths, query_points_3d, intrs, extrs, save_debug_logs=False, debug_logs_path="", query_points_view=None, **kwargs, ): batch_size, num_views, num_frames, _, height_raw, width_raw = rgbs.shape _, num_points, _ = query_points_3d.shape assert rgbs.shape == (batch_size, num_views, num_frames, 3, height_raw, width_raw) assert depths.shape == (batch_size, num_views, num_frames, 1, height_raw, width_raw) assert query_points_3d.shape == (batch_size, num_points, 4) assert intrs.shape == (batch_size, num_views, num_frames, 3, 3) assert extrs.shape == (batch_size, num_views, num_frames, 3, 4) if batch_size != 1: raise NotImplementedError # Interpolate the inputs to the desired resolution, if needed if self.interp_shape is None: height, width = height_raw, width_raw else: height, width = self.interp_shape rgbs = rgbs.reshape(-1, 3, height_raw, width_raw) rgbs = F.interpolate(rgbs, (height, width), mode="nearest") rgbs = rgbs.reshape(batch_size, num_views, num_frames, 3, height, width) depths = depths.reshape(-1, 1, height_raw, width_raw) depths = F.interpolate(depths, (height, width), mode="nearest") depths = depths.reshape(batch_size, num_views, num_frames, 1, height, width) intrs_resize_transform = torch.tensor([ [width / width_raw, 0, 0], [0, height / height_raw, 0], [0, 0, 1], ], device=intrs.device, dtype=intrs.dtype) intrs = torch.einsum("ij,BVTjk->BVTik", intrs_resize_transform, intrs) # Unpack the query points query_points_t = query_points_3d[:, :, :1].long() query_points_xyz_worldspace = query_points_3d[:, :, 1:] # Invert intrinsics and extrinsics intrs_inv = torch.inverse(intrs.float()).type(intrs.dtype) extrs_square = torch.eye(4).to(extrs.device)[None].repeat(batch_size, num_views, num_frames, 1, 1) extrs_square[:, :, :, :3, 
:] = extrs extrs_inv = torch.inverse(extrs_square.float()).type(extrs.dtype) support_points = torch.zeros((batch_size, 0, 4), device=rgbs.device) grid_points = [] if self.grid_size > 0: pixel_xy = get_points_on_a_grid(self.grid_size, (height, width), device=rgbs.device) pixel_xy_homo = to_homogeneous(pixel_xy) for t in range(0, num_frames, max(1, num_frames // self.n_grids_per_view)): for view_idx in range(num_views): camera_z = bilinear_sample2d( depths[0, view_idx, t][None], pixel_xy[..., 0], pixel_xy[..., 1], ).permute(0, 2, 1) camera_xyz = torch.einsum('Bij,BNj->BNi', intrs_inv[:, view_idx, t, :, :], pixel_xy_homo) camera_xyz = camera_xyz * camera_z camera_xyz_homo = to_homogeneous(camera_xyz) world_xyz_homo = torch.einsum('Bij,BNj->BNi', extrs_inv[:, view_idx, t, :, :], camera_xyz_homo) world_xyz = from_homogeneous(world_xyz_homo) grid_points_i = torch.cat([torch.ones_like(world_xyz[:, :, :1]) * t, world_xyz], dim=2) grid_points.append(grid_points_i) grid_points = torch.cat(grid_points, dim=1) support_points = torch.concat([support_points, grid_points], dim=1) if save_debug_logs: os.makedirs(debug_logs_path, exist_ok=True) save_pointcloud_to_ply( filename=os.path.join(debug_logs_path, time_now() + "__predictor__query_points.ply"), points=query_points_xyz_worldspace[0].cpu().numpy(), colors=np.ones_like(query_points_xyz_worldspace[0].cpu().numpy(), dtype=int) * np.array( [255, 30, 60]), ) save_pointcloud_to_ply( filename=os.path.join(debug_logs_path, time_now() + "__predictor__support_grid_points.ply"), points=grid_points[0, :, 1:].cpu().numpy(), colors=np.ones_like(grid_points[0, :, 1:].cpu().numpy(), dtype=int) * np.array([45, 255, 60]), ) sift_points = [] if self.sift_size > 0: raise NotImplementedError # xy = get_sift_sampled_pts(video, sift_size, T, [H, W], device=device) # if xy.shape[1] == sift_size: # queries = torch.cat([queries, xy], dim=1) # # else: # sift_size = 0 sift_points = torch.cat(sift_points, dim=1) support_points = torch.concat([support_points, sift_points], dim=1) support_uniform_pts = [] if self.num_uniformly_sampled_pts > 0: sampled_pts = get_uniformly_sampled_pts( self.num_uniformly_sampled_pts, num_frames, (height, width), device=rgbs.device, )[0] # shape: (N, 3) where each row is (t, y, x) t_samples = sampled_pts[:, 0].long() y_samples = sampled_pts[:, 1].float() x_samples = sampled_pts[:, 2].float() pixel_xy = torch.stack([x_samples, y_samples], dim=-1)[None] # (1, N, 2) pixel_xy_homo = to_homogeneous(pixel_xy) for idx in range(sampled_pts.shape[0]): t = t_samples[idx].item() x = x_samples[idx].item() y = y_samples[idx].item() for view_idx in range(num_views): depth_val = bilinear_sample2d( depths[0, view_idx, t][None], # shape (1, 1, H, W) torch.tensor([[x]], device=rgbs.device), torch.tensor([[y]], device=rgbs.device), ).item() cam_xy_h = torch.tensor([[x, y, 1.0]], device=rgbs.device).T K_inv = intrs_inv[0, view_idx, t] extr_inv = extrs_inv[0, view_idx, t] cam_xyz = (K_inv @ cam_xy_h).squeeze() * depth_val cam_xyz_h = to_homogeneous(cam_xyz[None])[0] world_xyz_h = extr_inv @ cam_xyz_h world_xyz = from_homogeneous(world_xyz_h[None])[0] support_point = torch.cat([torch.tensor([t], device=rgbs.device), world_xyz]) support_uniform_pts.append(support_point) if support_uniform_pts: support_uniform_pts = torch.stack(support_uniform_pts, dim=0)[None] # (1, N, 4) support_points = torch.cat([support_points, support_uniform_pts], dim=1) if self.single_point: # Project the queries to each view # This will be needed if adding local grid points 
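# Illustrative sketch (not part of the original file): the support-grid construction above
# lifts pixel coordinates with sampled depth to 3D, first into camera space via the inverse
# intrinsics and then into world space via the inverse (4x4) extrinsics. The intrinsics and
# extrinsics values below are made up for the example; the round-trip reprojection checks
# the math.
import torch

def unproject_pixels(pixel_xy, depth_z, intr, extr_4x4):
    """pixel_xy: (N, 2) pixels, depth_z: (N, 1) depths; returns world-space points (N, 3)."""
    ones = torch.ones_like(depth_z)
    cam_xyz = (torch.inverse(intr) @ torch.cat([pixel_xy, ones], -1).T).T * depth_z
    cam_h = torch.cat([cam_xyz, ones], -1)
    world_h = (torch.inverse(extr_4x4) @ cam_h.T).T
    return world_h[:, :3] / world_h[:, 3:4]

intr = torch.tensor([[500.0, 0, 256], [0, 500.0, 192], [0, 0, 1]])
extr = torch.eye(4)
extr[:3, 3] = torch.tensor([0.1, -0.2, 0.5])          # world-to-camera translation
pix = torch.tensor([[256.0, 192.0], [300.0, 100.0]])
z = torch.tensor([[2.0], [3.5]])
world = unproject_pixels(pix, z, intr, extr)
# Re-projecting the world points with extr/intr should reproduce the input pixels.
cam = (extr @ torch.cat([world, torch.ones(2, 1)], -1).T).T[:, :3]
reproj = (intr @ cam.T).T
assert torch.allclose(reproj[:, :2] / reproj[:, 2:3], pix, atol=1e-4)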
query_points_xyz_worldspace_homo = to_homogeneous(query_points_xyz_worldspace) query_points_perview_camera_xyz = torch.einsum('BVTij,BNj->BVTNi', extrs, query_points_xyz_worldspace_homo) query_points_perview_pixel_xy_homo = torch.einsum('BVTij,BVTNj->BVTNi', intrs, query_points_perview_camera_xyz) query_points_perview_pixel_xy = from_homogeneous(query_points_perview_pixel_xy_homo) query_points_perview_camera_xyz = query_points_perview_camera_xyz[ # Extract at the correct per-query timestep torch.arange(batch_size)[:, None, None], torch.arange(num_views)[None, :, None], query_points_t[:, None, :, 0], torch.arange(num_points)[None, None, :], ] query_points_perview_pixel_xy = query_points_perview_pixel_xy[ # Extract at the correct per-query timestep torch.arange(batch_size)[:, None, None], torch.arange(num_views)[None, :, None], query_points_t[:, None, :, 0], torch.arange(num_points)[None, None, :], ] query_points_perview_camera_z = query_points_perview_camera_xyz[..., -1:] traj_e = torch.zeros((batch_size, num_frames, num_points, 3), device=rgbs.device) vis_e = torch.zeros((batch_size, num_frames, num_points), device=rgbs.device) for point_idx in tqdm(range(num_points), desc="Single point evaluation"): # Support points for this query point support_points_i = torch.zeros((batch_size, 0, 4), device=rgbs.device) # Add the local support points if self.local_grid_size > 0: t = query_points_t[0, point_idx, 0].item() local_grid_points = torch.zeros((batch_size, 0, 4), device=rgbs.device) for view_idx in range(num_views): pixel_xy = get_points_on_a_grid( size=self.local_grid_size, extent=(self.local_extent, self.local_extent), center=(query_points_perview_pixel_xy[0, view_idx, point_idx, 1].item(), query_points_perview_pixel_xy[0, view_idx, point_idx, 0].item()), device=rgbs.device, ) inside_frame = ((pixel_xy[0, :, 0] >= 0) & (pixel_xy[0, :, 0] < width) & (pixel_xy[0, :, 1] >= 0) & (pixel_xy[0, :, 1] < height)) if not inside_frame.any(): continue pixel_xy = pixel_xy[:, inside_frame, :] pixel_xy_homo = to_homogeneous(pixel_xy) camera_z = bilinear_sample2d( depths[0, view_idx, t][None], pixel_xy[..., 0], pixel_xy[..., 1], ).permute(0, 2, 1) camera_xyz = torch.einsum('Bij,BNj->BNi', intrs_inv[:, view_idx, t, :, :], pixel_xy_homo) camera_xyz = camera_xyz * camera_z camera_xyz_homo = to_homogeneous(camera_xyz) world_xyz_homo = torch.einsum('Bij,BNj->BNi', extrs_inv[:, view_idx, t, :, :], camera_xyz_homo) world_xyz = from_homogeneous(world_xyz_homo) local_grid_points_i = torch.cat([torch.ones_like(world_xyz[:, :, :1]) * t, world_xyz], dim=2) local_grid_points = torch.cat([local_grid_points, local_grid_points_i], dim=1) support_points_i = torch.cat([support_points_i, local_grid_points], dim=1) # Add the global support points support_points_i = torch.cat([support_points_i, support_points], dim=1) # Forward pass for this query point query_points_i = torch.cat([query_points_3d[:, point_idx: point_idx + 1, :], support_points_i], dim=1) if query_points_view is not None: query_points_view = torch.cat([ query_points_view[:, point_idx: point_idx + 1], query_points_view.new_zeros(support_points_i[:, :, 0].shape), ], dim=1) results_i = self.model( rgbs, depths=depths, query_points=query_points_i, intrs=intrs, extrs=extrs, iters=self.n_iters, save_debug_logs=save_debug_logs and point_idx == 0, debug_logs_path=debug_logs_path, query_points_view=query_points_view, **kwargs, ) traj_e[:, :, point_idx: point_idx + 1] = results_i["traj_e"][:, :, :1] vis_e[:, :, point_idx: point_idx + 1] = results_i["vis_e"][:, :, :1] if 
save_debug_logs and (point_idx in [0, 1, 2, 3, 4] or point_idx % 100 == 0): visualizer = MultiViewVisualizer( save_dir=debug_logs_path, pad_value=16, fps=12, show_first_frame=0, tracks_leave_trace=0, ) # filename, pred_trajectories, pred_visibilities, qps tuples_to_process = [] tuples_to_process += [( f"predictor__pidx={point_idx}__viz_A_pred", results_i["traj_e"][:, :, :1], results_i["vis_e"][:, :, :1], query_points_i[:, :1, :], )] tuples_to_process += [( f"predictor__pidx={point_idx}__viz_B_pred_w_support", results_i["traj_e"], results_i["vis_e"], query_points_i[:, :, :], )] if self.local_grid_size > 0 and local_grid_points.shape[1] > 0: num_local_support_points = local_grid_points.shape[1] tuples_to_process += [( f"predictor__pidx={point_idx}__viz_C_local_support_grid", results_i["traj_e"][:, :, 1:1 + num_local_support_points, :], results_i["vis_e"][:, :, 1:1 + num_local_support_points], query_points_i[:, 1:1 + num_local_support_points, :], )] if self.grid_size > 0: num_global_support_points = support_points.shape[1] tuples_to_process += [( f"predictor__pidx={point_idx}__viz_D_global_support_grid", results_i["traj_e"][:, :, -num_global_support_points:, :], results_i["vis_e"][:, :, -num_global_support_points:], query_points_i[:, -num_global_support_points:, :], )] for filename, pred_trajectories, pred_visibilities, qps in tuples_to_process: filename = time_now() + "__" + filename # Project the predictions to pixel space for visualization pred_trajectories_pixel_xy_camera_z_per_view = torch.stack([ torch.cat(world_space_to_pixel_xy_and_camera_z( world_xyz=pred_trajectories[0], intrs=intrs[0, view_idx], extrs=extrs[0, view_idx], ), dim=-1) for view_idx in range(num_views) ], dim=0)[None] pred_viz, _ = visualizer.visualize( video=rgbs, video_depth=depths, tracks=pred_trajectories_pixel_xy_camera_z_per_view, visibility=pred_visibilities > 0.5, query_frame=qps[..., 0].long().clone(), filename=filename, writer=None, step=0, save_video=True, ) else: query_points_3d = torch.cat([query_points_3d, support_points], dim=1) if query_points_view is not None: query_points_view = torch.cat([ query_points_view, query_points_view.new_zeros(support_points[:, :, 0].shape) ], dim=1) results = self.model( rgbs, depths=depths, query_points=query_points_3d, intrs=intrs, extrs=extrs, iters=self.n_iters, save_debug_logs=save_debug_logs, debug_logs_path=debug_logs_path, query_points_view=query_points_view, **kwargs, ) traj_e = results["traj_e"][:, :, :num_points, :] vis_e = results["vis_e"][:, :, :num_points] if save_debug_logs: visualizer = MultiViewVisualizer( save_dir=debug_logs_path, pad_value=16, fps=12, show_first_frame=0, tracks_leave_trace=0, ) num_support_grid_points = grid_points.shape[1] if self.grid_size > 0 else 0 view_pts_all_timesteps = num_support_grid_points // num_views view_pts = view_pts_all_timesteps // self.n_grids_per_view if self.grid_size > 0 else 0 for filename, pred_trajectories, pred_visibilities, qps in [ ("predictor__viz_A_pred", traj_e, vis_e, query_points_3d[:, :num_points, :]), ("predictor__viz_B_pred_w_support_grid", results["traj_e"], results["vis_e"], query_points_3d), ("predictor__viz_C_support_grid_only", results["traj_e"][:, :, num_points:, :], results["vis_e"][:, :, num_points:], query_points_3d[:, num_points:, :]), *[( f"predictor__viz_D_support_grid_only__t-0_view-{view_idx}", results["traj_e"][:, :, num_points + view_pts * view_idx:num_points + view_pts * (view_idx + 1), :], results["vis_e"][:, :, num_points + view_pts * view_idx:num_points + view_pts * (view_idx + 1)], 
query_points_3d[:, num_points + view_pts * view_idx:num_points + view_pts * (view_idx + 1), :], ) for view_idx in range(num_views)], ]: filename = time_now() + "__" + filename # Project the predictions to pixel space for visualization pred_trajectories_pixel_xy_camera_z_per_view = torch.stack([ torch.cat(world_space_to_pixel_xy_and_camera_z( world_xyz=pred_trajectories[0], intrs=intrs[0, view_idx], extrs=extrs[0, view_idx], ), dim=-1) for view_idx in range(num_views) ], dim=0)[None] pred_viz, _ = visualizer.visualize( video=rgbs, video_depth=depths, tracks=pred_trajectories_pixel_xy_camera_z_per_view, visibility=pred_visibilities > 0.5, query_frame=qps[..., 0].long().clone(), filename=filename, writer=None, step=0, save_video=True, ) return { "traj_e": traj_e, "vis_e": vis_e > self.visibility_threshold, "vis_e_as_prob": vis_e, } def get_uniformly_sampled_pts( size: int, num_frames: int, extent: Tuple[float, ...], device: Optional[torch.device] = torch.device("cpu"), ): time_points = torch.randint(low=0, high=num_frames, size=(size, 1), device=device) space_points = torch.rand(size, 2, device=device) * torch.tensor( [extent[1], extent[0]], device=device ) points = torch.cat((time_points, space_points), dim=1) return points[None] def get_superpoint_sampled_pts( video, size: int, num_frames: int, extent: Tuple[float, ...], device: Optional[torch.device] = torch.device("cpu"), ): extractor = SuperPoint(max_num_keypoints=48).eval().cuda() points = list() for _ in range(8): frame_num = random.randint(0, int(num_frames * 0.25)) key_points = extractor.extract( video[0, frame_num, :, :, :] / 255.0, resize=None )["keypoints"] frame_tensor = torch.full((1, key_points.shape[1], 1), frame_num).cuda() points.append(torch.cat([frame_tensor.cuda(), key_points], dim=2)) return torch.cat(points, dim=1)[:, :size, :] def get_sift_sampled_pts( video, size: int, num_frames: int, extent: Tuple[float, ...], device: Optional[torch.device] = torch.device("cpu"), num_sampled_frames: int = 8, sampling_length_percent: float = 0.25, ): import cv2 # assert size == 384, "hardcoded for experiment" sift = cv2.SIFT_create(nfeatures=size // num_sampled_frames) points = list() for _ in range(num_sampled_frames): frame_num = random.randint(0, int(num_frames * sampling_length_percent)) key_points, _ = sift.detectAndCompute( video[0, frame_num, :, :, :] .cpu() .permute(1, 2, 0) .numpy() .astype(np.uint8), None, ) for kp in key_points: points.append([frame_num, int(kp.pt[0]), int(kp.pt[1])]) return torch.tensor(points[:size], device=device)[None] ================================================ FILE: mvtracker/utils/__init__.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. ================================================ FILE: mvtracker/utils/basic.py ================================================ import os from datetime import datetime import numpy as np import torch import torch.nn.functional as F EPS = 1e-6 def sub2ind(height, width, y, x): return y * width + x def ind2sub(height, width, ind): y = ind // width x = ind % width return y, x def get_lr_str(lr): lrn = "%.1e" % lr # e.g., 5.0e-04 lrn = lrn[0] + lrn[3:5] + lrn[-1] # e.g., 5e-4 return lrn def strnum(x): s = '%g' % x if '.' 
in s: if x < 1.0: s = s[s.index('.'):] s = s[:min(len(s), 4)] return s def assert_same_shape(t1, t2): for (x, y) in zip(list(t1.shape), list(t2.shape)): assert (x == y) def print_stats(name, tensor): shape = tensor.shape tensor = tensor.detach().cpu().numpy() print('%s (%s) min = %.2f, mean = %.2f, max = %.2f' % ( name, tensor.dtype, np.min(tensor), np.mean(tensor), np.max(tensor)), shape) def print_stats_py(name, tensor): shape = tensor.shape print('%s (%s) min = %.2f, mean = %.2f, max = %.2f' % ( name, tensor.dtype, np.min(tensor), np.mean(tensor), np.max(tensor)), shape) def print_(name, tensor): tensor = tensor.detach().cpu().numpy() print(name, tensor, tensor.shape) def mkdir(path): if not os.path.exists(path): os.makedirs(path) def normalize_single(d): # d is a whatever shape torch tensor dmin = torch.min(d) dmax = torch.max(d) d = (d - dmin) / (EPS + (dmax - dmin)) return d def normalize(d): # d is B x whatever. normalize within each element of the batch out = torch.zeros(d.size()) if d.is_cuda: out = out.cuda() B = list(d.size())[0] for b in list(range(B)): out[b] = normalize_single(d[b]) return out def hard_argmax2d(tensor): B, C, Y, X = list(tensor.shape) assert (C == 1) # flatten the Tensor along the height and width axes flat_tensor = tensor.reshape(B, -1) # argmax of the flat tensor argmax = torch.argmax(flat_tensor, dim=1) # convert the indices into 2d coordinates argmax_y = torch.floor(argmax / X) # row argmax_x = argmax % X # col argmax_y = argmax_y.reshape(B) argmax_x = argmax_x.reshape(B) return argmax_y, argmax_x def argmax2d(heat, hard=True): B, C, Y, X = list(heat.shape) assert (C == 1) if hard: # hard argmax loc_y, loc_x = hard_argmax2d(heat) loc_y = loc_y.float() loc_x = loc_x.float() else: heat = heat.reshape(B, Y * X) prob = torch.nn.functional.softmax(heat, dim=1) grid_y, grid_x = meshgrid2d(B, Y, X) grid_y = grid_y.reshape(B, -1) grid_x = grid_x.reshape(B, -1) loc_y = torch.sum(grid_y * prob, dim=1) loc_x = torch.sum(grid_x * prob, dim=1) # these are B return loc_y, loc_x def reduce_masked_mean(x, mask, dim=None, keepdim=False): # x and mask are the same shape, or at least broadcastably so < actually it's safer if you disallow broadcasting # returns shape-1 # axis can be a list of axes for (a, b) in zip(x.size(), mask.size()): # if not b==1: assert (a == b) # some shape mismatch! 
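# Illustrative sketch (not part of the original file): argmax2d(heat, hard=False) above is a
# soft-argmax, i.e. the expected (y, x) location under a softmax over the heatmap; for a
# sharply peaked map it approaches the hard argmax. Standalone version of the same idea:
import torch

def soft_argmax2d(heat):
    """heat: (B, 1, Y, X) -> expected (y, x) locations, each of shape (B,)."""
    B, _, Y, X = heat.shape
    prob = heat.reshape(B, Y * X).softmax(dim=1).reshape(B, Y, X)
    ys = torch.arange(Y, dtype=heat.dtype).view(1, Y, 1)
    xs = torch.arange(X, dtype=heat.dtype).view(1, 1, X)
    return (prob * ys).sum(dim=(1, 2)), (prob * xs).sum(dim=(1, 2))

heat = torch.full((1, 1, 16, 24), -10.0)
heat[0, 0, 5, 17] = 10.0                  # a single strong peak at (y=5, x=17)
y, x = soft_argmax2d(heat)
assert torch.allclose(y, torch.tensor([5.0]), atol=1e-2)
assert torch.allclose(x, torch.tensor([17.0]), atol=1e-2)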
# assert(x.size() == mask.size()) prod = x * mask if dim is None: numer = torch.sum(prod) denom = EPS + torch.sum(mask) else: numer = torch.sum(prod, dim=dim, keepdim=keepdim) denom = EPS + torch.sum(mask, dim=dim, keepdim=keepdim) mean = numer / denom return mean def reduce_masked_median(x, mask, keep_batch=False): # x and mask are the same shape assert (x.size() == mask.size()) device = x.device B = list(x.shape)[0] x = x.detach().cpu().numpy() mask = mask.detach().cpu().numpy() if keep_batch: x = np.reshape(x, [B, -1]) mask = np.reshape(mask, [B, -1]) meds = np.zeros([B], np.float32) for b in list(range(B)): xb = x[b] mb = mask[b] if np.sum(mb) > 0: xb = xb[mb > 0] meds[b] = np.median(xb) else: meds[b] = np.nan meds = torch.from_numpy(meds).to(device) return meds.float() else: x = np.reshape(x, [-1]) mask = np.reshape(mask, [-1]) if np.sum(mask) > 0: x = x[mask > 0] med = np.median(x) else: med = np.nan med = np.array([med], np.float32) med = torch.from_numpy(med).to(device) return med.float() def pack_seqdim(tensor, B): shapelist = list(tensor.shape) B_, S = shapelist[:2] assert (B == B_) otherdims = shapelist[2:] tensor = torch.reshape(tensor, [B * S] + otherdims) return tensor def unpack_seqdim(tensor, B): shapelist = list(tensor.shape) BS = shapelist[0] assert (BS % B == 0) otherdims = shapelist[1:] S = int(BS / B) tensor = torch.reshape(tensor, [B, S] + otherdims) return tensor def meshgrid2d(B, Y, X, stack=False, norm=False, device='cuda', on_chans=False): # returns a meshgrid sized B x Y x X grid_y = torch.linspace(0.0, Y - 1, Y, device=torch.device(device)) grid_y = torch.reshape(grid_y, [1, Y, 1]) grid_y = grid_y.repeat(B, 1, X) grid_x = torch.linspace(0.0, X - 1, X, device=torch.device(device)) grid_x = torch.reshape(grid_x, [1, 1, X]) grid_x = grid_x.repeat(B, Y, 1) if norm: grid_y, grid_x = normalize_grid2d( grid_y, grid_x, Y, X) if stack: # note we stack in xy order # (see https://pytorch.org/docs/stable/nn.functional.html#torch.nn.functional.grid_sample) if on_chans: grid = torch.stack([grid_x, grid_y], dim=1) else: grid = torch.stack([grid_x, grid_y], dim=-1) return grid else: return grid_y, grid_x def meshgrid3d(B, Z, Y, X, stack=False, norm=False, device='cuda'): # returns a meshgrid sized B x Z x Y x X grid_z = torch.linspace(0.0, Z - 1, Z, device=device) grid_z = torch.reshape(grid_z, [1, Z, 1, 1]) grid_z = grid_z.repeat(B, 1, Y, X) grid_y = torch.linspace(0.0, Y - 1, Y, device=device) grid_y = torch.reshape(grid_y, [1, 1, Y, 1]) grid_y = grid_y.repeat(B, Z, 1, X) grid_x = torch.linspace(0.0, X - 1, X, device=device) grid_x = torch.reshape(grid_x, [1, 1, 1, X]) grid_x = grid_x.repeat(B, Z, Y, 1) # if cuda: # grid_z = grid_z.cuda() # grid_y = grid_y.cuda() # grid_x = grid_x.cuda() if norm: grid_z, grid_y, grid_x = normalize_grid3d( grid_z, grid_y, grid_x, Z, Y, X) if stack: # note we stack in xyz order # (see https://pytorch.org/docs/stable/nn.functional.html#torch.nn.functional.grid_sample) grid = torch.stack([grid_x, grid_y, grid_z], dim=-1) return grid else: return grid_z, grid_y, grid_x def normalize_grid2d(grid_y, grid_x, Y, X, clamp_extreme=True): # make things in [-1,1] grid_y = 2.0 * (grid_y / float(Y - 1)) - 1.0 grid_x = 2.0 * (grid_x / float(X - 1)) - 1.0 if clamp_extreme: grid_y = torch.clamp(grid_y, min=-2.0, max=2.0) grid_x = torch.clamp(grid_x, min=-2.0, max=2.0) return grid_y, grid_x def normalize_grid3d(grid_z, grid_y, grid_x, Z, Y, X, clamp_extreme=True): # make things in [-1,1] grid_z = 2.0 * (grid_z / float(Z - 1)) - 1.0 grid_y = 2.0 * (grid_y / 
float(Y - 1)) - 1.0 grid_x = 2.0 * (grid_x / float(X - 1)) - 1.0 if clamp_extreme: grid_z = torch.clamp(grid_z, min=-2.0, max=2.0) grid_y = torch.clamp(grid_y, min=-2.0, max=2.0) grid_x = torch.clamp(grid_x, min=-2.0, max=2.0) return grid_z, grid_y, grid_x def gridcloud2d(B, Y, X, norm=False, device='cuda'): # we want to sample for each location in the grid grid_y, grid_x = meshgrid2d(B, Y, X, norm=norm, device=device) x = torch.reshape(grid_x, [B, -1]) y = torch.reshape(grid_y, [B, -1]) # these are B x N xy = torch.stack([x, y], dim=2) # this is B x N x 2 return xy def gridcloud3d(B, Z, Y, X, norm=False, device='cuda'): # we want to sample for each location in the grid grid_z, grid_y, grid_x = meshgrid3d(B, Z, Y, X, norm=norm, device=device) x = torch.reshape(grid_x, [B, -1]) y = torch.reshape(grid_y, [B, -1]) z = torch.reshape(grid_z, [B, -1]) # these are B x N xyz = torch.stack([x, y, z], dim=2) # this is B x N x 3 return xyz import re def readPFM(file): file = open(file, 'rb') color = None width = None height = None scale = None endian = None header = file.readline().rstrip() if header == b'PF': color = True elif header == b'Pf': color = False else: raise Exception('Not a PFM file.') dim_match = re.match(rb'^(\d+)\s(\d+)\s$', file.readline()) if dim_match: width, height = map(int, dim_match.groups()) else: raise Exception('Malformed PFM header.') scale = float(file.readline().rstrip()) if scale < 0: # little-endian endian = '<' scale = -scale else: endian = '>' # big-endian data = np.fromfile(file, endian + 'f') shape = (height, width, 3) if color else (height, width) data = np.reshape(data, shape) data = np.flipud(data) return data def normalize_boxlist2d(boxlist2d, H, W): boxlist2d = boxlist2d.clone() ymin, xmin, ymax, xmax = torch.unbind(boxlist2d, dim=2) ymin = ymin / float(H) ymax = ymax / float(H) xmin = xmin / float(W) xmax = xmax / float(W) boxlist2d = torch.stack([ymin, xmin, ymax, xmax], dim=2) return boxlist2d def unnormalize_boxlist2d(boxlist2d, H, W): boxlist2d = boxlist2d.clone() ymin, xmin, ymax, xmax = torch.unbind(boxlist2d, dim=2) ymin = ymin * float(H) ymax = ymax * float(H) xmin = xmin * float(W) xmax = xmax * float(W) boxlist2d = torch.stack([ymin, xmin, ymax, xmax], dim=2) return boxlist2d def unnormalize_box2d(box2d, H, W): return unnormalize_boxlist2d(box2d.unsqueeze(1), H, W).squeeze(1) def normalize_box2d(box2d, H, W): return normalize_boxlist2d(box2d.unsqueeze(1), H, W).squeeze(1) def get_gaussian_kernel_2d(channels, kernel_size=3, sigma=2.0, mid_one=False): C = channels xy_grid = gridcloud2d(C, kernel_size, kernel_size) # C x N x 2 mean = (kernel_size - 1) / 2.0 variance = sigma ** 2.0 gaussian_kernel = (1.0 / (2.0 * np.pi * variance) ** 1.5) * torch.exp( -torch.sum((xy_grid - mean) ** 2.0, dim=-1) / (2.0 * variance)) # C X N gaussian_kernel = gaussian_kernel.view(C, 1, kernel_size, kernel_size) # C x 1 x 3 x 3 kernel_sum = torch.sum(gaussian_kernel, dim=(2, 3), keepdim=True) gaussian_kernel = gaussian_kernel / kernel_sum # normalize if mid_one: # normalize so that the middle element is 1 maxval = gaussian_kernel[:, :, (kernel_size // 2), (kernel_size // 2)].reshape(C, 1, 1, 1) gaussian_kernel = gaussian_kernel / maxval return gaussian_kernel def gaussian_blur_2d(input, kernel_size=3, sigma=2.0, reflect_pad=False, mid_one=False): B, C, Z, X = input.shape kernel = get_gaussian_kernel_2d(C, kernel_size, sigma, mid_one=mid_one) if reflect_pad: pad = (kernel_size - 1) // 2 out = F.pad(input, (pad, pad, pad, pad), mode='reflect') out = F.conv2d(out, kernel, 
padding=0, groups=C)
    else:
        out = F.conv2d(input, kernel, padding=(kernel_size - 1) // 2, groups=C)
    return out

def gradient2d(x, absolute=False, square=False, return_sum=False):
    # x should be B x C x H x W
    dh = x[:, :, 1:, :] - x[:, :, :-1, :]
    dw = x[:, :, :, 1:] - x[:, :, :, :-1]
    zeros = torch.zeros_like(x)
    zero_h = zeros[:, :, 0:1, :]
    zero_w = zeros[:, :, :, 0:1]
    dh = torch.cat([dh, zero_h], axis=2)
    dw = torch.cat([dw, zero_w], axis=3)
    if absolute:
        dh = torch.abs(dh)
        dw = torch.abs(dw)
    if square:
        dh = dh ** 2
        dw = dw ** 2
    if return_sum:
        return dh + dw
    else:
        return dh, dw

def to_homogeneous(x):
    return torch.cat([x, x.new_ones(x[..., :1].shape)], -1)

def from_homogeneous(x, assert_homogeneous_part_is_equal_to_1=False, eps=0.1):
    if assert_homogeneous_part_is_equal_to_1:
        assert torch.allclose(x[..., -1], x.new_ones(x[..., -1].shape), atol=eps)
    return x[..., :-1] / x[..., -1:]

def time_now():
    return datetime.now().strftime("%Y%m%d_%H%M%S_%f")

================================================
FILE: mvtracker/utils/eval_utils.py
================================================

import os
import matplotlib
import numpy as np
import rerun as rr
import json
from tqdm import tqdm
from scipy.stats import multivariate_normal

def medianTrajError(output, target):
    diff = np.linalg.norm(target - output, axis=1)
    orderedDiff = np.sort(diff)
    return orderedDiff[len(orderedDiff) // 2]

def averageTrajError(output, target):
    diff = np.linalg.norm(target - output, axis=1)
    return np.mean(diff, axis=0)

def pointTrack(queryPoint, anchorPos, anchorRot):
    R = qToRot(anchorRot[0])
    t0 = R.T @ (queryPoint - anchorPos[0])
    track = []
    for idx in tqdm(range(len(anchorPos)), 'Track', position=1, leave=False):
        track.append(anchorPos[idx] + qToRot(anchorRot[idx]) @ t0)
    return np.array(track)

def qToRot(q):
    norm = np.linalg.norm(q)
    r = q[0] / norm
    x = q[1] / norm
    y = q[2] / norm
    z = q[3] / norm
    R = np.array(
        [[1.0 - 2.0 * (y * y + z * z), 2.0 * (x * y - r * z), 2.0 * (x * z + r * y)],
         [2.0 * (x * y + r * z), 1.0 - 2.0 * (x * x + z * z), 2.0 * (y * z - r * x)],
         [2.0 * (x * z - r * y), 2.0 * (y * z + r * x), 1.0 - 2.0 * (x * x + y * y)]]
    )
    return R

def get3DCov(scale, rotation, scale_mod=1):
    S = np.zeros((3, 3))
    S[0][0] = scale_mod * scale[0]
    S[1][1] = scale_mod * scale[1]
    S[2][2] = scale_mod * scale[2]
    R = qToRot(rotation)
    # Covariance of a Gaussian with rotation R and axis-aligned scales S (matrix products):
    # Sigma = (R S) (R S)^T = R S S^T R^T
    M = R @ S
    sigma = M @ M.T
    return sigma

def getAll3DCov(scales, rotations, scale_mod=1):
    cov3Ds = []
    for idx in tqdm(range(len(scales)), 'Cov'):
        cov3Ds.append(get3DCov(scales[idx], rotations[idx], scale_mod))
    return np.array(cov3Ds)

def getContributions(mean3Ds, cov3Ds, query):
    assert len(mean3Ds) == len(cov3Ds), f'{mean3Ds.shape} {cov3Ds.shape}'
    PDFs = []
    for idx in tqdm(range(len(mean3Ds)), 'PDF', position=1, leave=False):
        try:
            pdf = multivariate_normal.pdf(query, mean=mean3Ds[idx], cov=cov3Ds[idx])
            PDFs.append(pdf)
        except:
            PDFs.append(-1)
    return np.array(PDFs)

================================================
FILE: mvtracker/utils/geom.py
================================================

import numpy as np
import torch
import torchvision.ops as ops

def matmul2(mat1, mat2):
    return torch.matmul(mat1, mat2)

def matmul3(mat1, mat2, mat3):
    return torch.matmul(mat1, torch.matmul(mat2, mat3))

def eye_3x3(B, device='cuda'):
    rt = torch.eye(3, device=torch.device(device)).view(1, 3, 3).repeat([B, 1, 1])
    return rt

def eye_4x4(B, device='cuda'):
    rt = torch.eye(4, device=torch.device(device)).view(1, 4, 4).repeat([B, 1, 1])
    return rt

def safe_inverse(a):
    # parallel version
    B, _, _ = list(a.shape)
    inv =
a.clone() r_transpose = a[:, :3, :3].transpose(1, 2) # inverse of rotation matrix inv[:, :3, :3] = r_transpose inv[:, :3, 3:4] = -torch.matmul(r_transpose, a[:, :3, 3:4]) return inv def safe_inverse_single(a): r, t = split_rt_single(a) t = t.view(3, 1) r_transpose = r.t() inv = torch.cat([r_transpose, -torch.matmul(r_transpose, t)], 1) bottom_row = a[3:4, :] # this is [0, 0, 0, 1] # bottom_row = torch.tensor([0.,0.,0.,1.]).view(1,4) inv = torch.cat([inv, bottom_row], 0) return inv def split_intrinsics(K): # K is B x 3 x 3 or B x 4 x 4 fx = K[:, 0, 0] fy = K[:, 1, 1] x0 = K[:, 0, 2] y0 = K[:, 1, 2] return fx, fy, x0, y0 def apply_pix_T_cam(pix_T_cam, xyz): fx, fy, x0, y0 = split_intrinsics(pix_T_cam) # xyz is shaped B x H*W x 3 # returns xy, shaped B x H*W x 2 B, N, C = list(xyz.shape) assert (C == 3) x, y, z = torch.unbind(xyz, axis=-1) fx = torch.reshape(fx, [B, 1]) fy = torch.reshape(fy, [B, 1]) x0 = torch.reshape(x0, [B, 1]) y0 = torch.reshape(y0, [B, 1]) EPS = 1e-4 z = torch.clamp(z, min=EPS) x = (x * fx) / (z) + x0 y = (y * fy) / (z) + y0 xy = torch.stack([x, y], axis=-1) return xy def apply_pix_T_cam_py(pix_T_cam, xyz): fx, fy, x0, y0 = split_intrinsics(pix_T_cam) # xyz is shaped B x H*W x 3 # returns xy, shaped B x H*W x 2 B, N, C = list(xyz.shape) assert (C == 3) x, y, z = xyz[:, :, 0], xyz[:, :, 1], xyz[:, :, 2] fx = np.reshape(fx, [B, 1]) fy = np.reshape(fy, [B, 1]) x0 = np.reshape(x0, [B, 1]) y0 = np.reshape(y0, [B, 1]) EPS = 1e-4 z = np.clip(z, EPS, None) x = (x * fx) / (z) + x0 y = (y * fy) / (z) + y0 xy = np.stack([x, y], axis=-1) return xy def get_camM_T_camXs(origin_T_camXs, ind=0): B, S = list(origin_T_camXs.shape)[0:2] camM_T_camXs = torch.zeros_like(origin_T_camXs) for b in list(range(B)): camM_T_origin = safe_inverse_single(origin_T_camXs[b, ind]) for s in list(range(S)): camM_T_camXs[b, s] = torch.matmul(camM_T_origin, origin_T_camXs[b, s]) return camM_T_camXs def apply_4x4(RT, xyz): B, N, _ = list(xyz.shape) ones = torch.ones_like(xyz[:, :, 0:1]) xyz1 = torch.cat([xyz, ones], 2) xyz1_t = torch.transpose(xyz1, 1, 2) # this is B x 4 x N xyz2_t = torch.matmul(RT, xyz1_t) xyz2 = torch.transpose(xyz2_t, 1, 2) xyz2 = xyz2[:, :, :3] return xyz2 def apply_4x4_py(RT, xyz): # print('RT', RT.shape) B, N, _ = list(xyz.shape) ones = np.ones_like(xyz[:, :, 0:1]) xyz1 = np.concatenate([xyz, ones], 2) # print('xyz1', xyz1.shape) xyz1_t = xyz1.transpose(0, 2, 1) # print('xyz1_t', xyz1_t.shape) # this is B x 4 x N xyz2_t = np.matmul(RT, xyz1_t) # print('xyz2_t', xyz2_t.shape) xyz2 = xyz2_t.transpose(0, 2, 1) # print('xyz2', xyz2.shape) xyz2 = xyz2[:, :, :3] return xyz2 def apply_3x3(RT, xy): B, N, _ = list(xy.shape) ones = torch.ones_like(xy[:, :, 0:1]) xy1 = torch.cat([xy, ones], 2) xy1_t = torch.transpose(xy1, 1, 2) # this is B x 4 x N xy2_t = torch.matmul(RT, xy1_t) xy2 = torch.transpose(xy2_t, 1, 2) xy2 = xy2[:, :, :2] return xy2 def generate_polygon(ctr_x, ctr_y, avg_r, irregularity, spikiness, num_verts): ''' Start with the center of the polygon at ctr_x, ctr_y, Then creates the polygon by sampling points on a circle around the center. Random noise is added by varying the angular spacing between sequential points, and by varying the radial distance of each point from the centre. Params: ctr_x, ctr_y - coordinates of the "centre" of the polygon avg_r - in px, the average radius of this polygon, this roughly controls how large the polygon is, really only useful for order of magnitude. irregularity - [0,1] indicating how much variance there is in the angular spacing of vertices. 
[0,1] will map to [0, 2pi/numberOfVerts] spikiness - [0,1] indicating how much variance there is in each vertex from the circle of radius avg_r. [0,1] will map to [0, avg_r] pp num_verts Returns: np.array [num_verts, 2] - CCW order. ''' # spikiness spikiness = np.clip(spikiness, 0, 1) * avg_r # generate n angle steps irregularity = np.clip(irregularity, 0, 1) * 2 * np.pi / num_verts lower = (2 * np.pi / num_verts) - irregularity upper = (2 * np.pi / num_verts) + irregularity # angle steps angle_steps = np.random.uniform(lower, upper, num_verts) sc = (2 * np.pi) / angle_steps.sum() angle_steps *= sc # get all radii angle = np.random.uniform(0, 2 * np.pi) radii = np.clip(np.random.normal(avg_r, spikiness, num_verts), 0, 2 * avg_r) # compute all points points = [] for i in range(num_verts): x = ctr_x + radii[i] * np.cos(angle) y = ctr_y + radii[i] * np.sin(angle) points.append([x, y]) angle += angle_steps[i] return np.array(points).astype(int) def get_random_affine_2d(B, rot_min=-5.0, rot_max=5.0, tx_min=-0.1, tx_max=0.1, ty_min=-0.1, ty_max=0.1, sx_min=-0.05, sx_max=0.05, sy_min=-0.05, sy_max=0.05, shx_min=-0.05, shx_max=0.05, shy_min=-0.05, shy_max=0.05): ''' Params: rot_min: rotation amount min rot_max: rotation amount max tx_min: translation x min tx_max: translation x max ty_min: translation y min ty_max: translation y max sx_min: scaling x min sx_max: scaling x max sy_min: scaling y min sy_max: scaling y max shx_min: shear x min shx_max: shear x max shy_min: shear y min shy_max: shear y max Returns: transformation matrix: (B, 3, 3) ''' # rotation if rot_max - rot_min != 0: rot_amount = np.random.uniform(low=rot_min, high=rot_max, size=B) rot_amount = np.pi / 180.0 * rot_amount else: rot_amount = rot_min rotation = np.zeros((B, 3, 3)) # B, 3, 3 rotation[:, 2, 2] = 1 rotation[:, 0, 0] = np.cos(rot_amount) rotation[:, 0, 1] = -np.sin(rot_amount) rotation[:, 1, 0] = np.sin(rot_amount) rotation[:, 1, 1] = np.cos(rot_amount) # translation translation = np.zeros((B, 3, 3)) # B, 3, 3 translation[:, [0, 1, 2], [0, 1, 2]] = 1 if (tx_max - tx_min) > 0: trans_x = np.random.uniform(low=tx_min, high=tx_max, size=B) translation[:, 0, 2] = trans_x # else: # translation[:, 0, 2] = tx_max if ty_max - ty_min != 0: trans_y = np.random.uniform(low=ty_min, high=ty_max, size=B) translation[:, 1, 2] = trans_y # else: # translation[:, 1, 2] = ty_max # scaling scaling = np.zeros((B, 3, 3)) # B, 3, 3 scaling[:, [0, 1, 2], [0, 1, 2]] = 1 if (sx_max - sx_min) > 0: scale_x = 1 + np.random.uniform(low=sx_min, high=sx_max, size=B) scaling[:, 0, 0] = scale_x # else: # scaling[:, 0, 0] = sx_max if (sy_max - sy_min) > 0: scale_y = 1 + np.random.uniform(low=sy_min, high=sy_max, size=B) scaling[:, 1, 1] = scale_y # else: # scaling[:, 1, 1] = sy_max # shear shear = np.zeros((B, 3, 3)) # B, 3, 3 shear[:, [0, 1, 2], [0, 1, 2]] = 1 if (shx_max - shx_min) > 0: shear_x = np.random.uniform(low=shx_min, high=shx_max, size=B) shear[:, 0, 1] = shear_x # else: # shear[:, 0, 1] = shx_max if (shy_max - shy_min) > 0: shear_y = np.random.uniform(low=shy_min, high=shy_max, size=B) shear[:, 1, 0] = shear_y # else: # shear[:, 1, 0] = shy_max # compose all those rt = np.einsum("ijk,ikl->ijl", rotation, translation) ss = np.einsum("ijk,ikl->ijl", scaling, shear) trans = np.einsum("ijk,ikl->ijl", rt, ss) return trans def get_centroid_from_box2d(box2d): ymin = box2d[:, 0] xmin = box2d[:, 1] ymax = box2d[:, 2] xmax = box2d[:, 3] x = (xmin + xmax) / 2.0 y = (ymin + ymax) / 2.0 return y, x def normalize_boxlist2d(boxlist2d, H, W): boxlist2d = 
boxlist2d.clone() ymin, xmin, ymax, xmax = torch.unbind(boxlist2d, dim=2) ymin = ymin / float(H) ymax = ymax / float(H) xmin = xmin / float(W) xmax = xmax / float(W) boxlist2d = torch.stack([ymin, xmin, ymax, xmax], dim=2) return boxlist2d def unnormalize_boxlist2d(boxlist2d, H, W): boxlist2d = boxlist2d.clone() ymin, xmin, ymax, xmax = torch.unbind(boxlist2d, dim=2) ymin = ymin * float(H) ymax = ymax * float(H) xmin = xmin * float(W) xmax = xmax * float(W) boxlist2d = torch.stack([ymin, xmin, ymax, xmax], dim=2) return boxlist2d def unnormalize_box2d(box2d, H, W): return unnormalize_boxlist2d(box2d.unsqueeze(1), H, W).squeeze(1) def normalize_box2d(box2d, H, W): return normalize_boxlist2d(box2d.unsqueeze(1), H, W).squeeze(1) def get_size_from_box2d(box2d): ymin = box2d[:, 0] xmin = box2d[:, 1] ymax = box2d[:, 2] xmax = box2d[:, 3] height = ymax - ymin width = xmax - xmin return height, width def crop_and_resize(im, boxlist, PH, PW, boxlist_is_normalized=False): B, C, H, W = im.shape B2, N, D = boxlist.shape assert (B == B2) assert (D == 4) # PH, PW is the size to resize to # output is B,N,C,PH,PW # pt wants xy xy, unnormalized if boxlist_is_normalized: boxlist_unnorm = unnormalize_boxlist2d(boxlist, H, W) else: boxlist_unnorm = boxlist ymin, xmin, ymax, xmax = boxlist_unnorm.unbind(2) # boxlist_pt = torch.stack([boxlist_unnorm[:,1], boxlist_unnorm[:,0], boxlist_unnorm[:,3], boxlist_unnorm[:,2]], dim=1) boxlist_pt = torch.stack([xmin, ymin, xmax, ymax], dim=2) # we want a B-len list of K x 4 arrays # print('im', im.shape) # print('boxlist', boxlist.shape) # print('boxlist_pt', boxlist_pt.shape) # boxlist_pt = list(boxlist_pt.unbind(0)) crops = [] for b in range(B): crops_b = ops.roi_align(im[b:b + 1], [boxlist_pt[b]], output_size=(PH, PW)) crops.append(crops_b) # # crops = im # print('crops', crops.shape) # crops = crops.reshape(B,N,C,PH,PW) # crops = [] # for b in range(B): # crop_b = ops.roi_align(im[b:b+1], [boxlist_pt[b]], output_size=(PH, PW)) # print('crop_b', crop_b.shape) # crops.append(crop_b) crops = torch.stack(crops, dim=0) # print('crops', crops.shape) # boxlist_list = boxlist_pt.unbind(0) # print('rgb_crop', rgb_crop.shape) return crops # def get_boxlist_from_centroid_and_size(cy, cx, h, w, clip=True): # # cy,cx are both B,N # ymin = cy - h/2 # ymax = cy + h/2 # xmin = cx - w/2 # xmax = cx + w/2 # box = torch.stack([ymin, xmin, ymax, xmax], dim=-1) # if clip: # box = torch.clamp(box, 0, 1) # return box def get_boxlist_from_centroid_and_size(cy, cx, h, w): # , clip=False): # cy,cx are the same shape ymin = cy - h / 2 ymax = cy + h / 2 xmin = cx - w / 2 xmax = cx + w / 2 # if clip: # ymin = torch.clamp(ymin, 0, H-1) # ymax = torch.clamp(ymax, 0, H-1) # xmin = torch.clamp(xmin, 0, W-1) # xmax = torch.clamp(xmax, 0, W-1) box = torch.stack([ymin, xmin, ymax, xmax], dim=-1) return box def get_box2d_from_mask(mask, normalize=False): # mask is B, 1, H, W B, C, H, W = mask.shape assert (C == 1) xy = utils.basic.gridcloud2d(B, H, W, norm=False, device=mask.device) # B, H*W, 2 box = torch.zeros((B, 4), dtype=torch.float32, device=mask.device) for b in range(B): xy_b = xy[b] # H*W, 2 mask_b = mask[b].reshape(H * W) xy_ = xy_b[mask_b > 0] x_ = xy_[:, 0] y_ = xy_[:, 1] ymin = torch.min(y_) ymax = torch.max(y_) xmin = torch.min(x_) xmax = torch.max(x_) box[b] = torch.stack([ymin, xmin, ymax, xmax], dim=0) if normalize: box = normalize_boxlist2d(box.unsqueeze(1), H, W).squeeze(1) return box def convert_box2d_to_intrinsics(box2d, pix_T_cam, H, W, use_image_aspect_ratio=True, 
mult_padding=1.0): # box2d is B x 4, with ymin, xmin, ymax, xmax in normalized coords # ymin, xmin, ymax, xmax = torch.unbind(box2d, dim=1) # H, W is the original size of the image # mult_padding is relative to object size in pixels # i assume we're rendering an image the same size as the original (H, W) if not mult_padding == 1.0: y, x = get_centroid_from_box2d(box2d) h, w = get_size_from_box2d(box2d) box2d = get_box2d_from_centroid_and_size( y, x, h * mult_padding, w * mult_padding, clip=False) if use_image_aspect_ratio: h, w = get_size_from_box2d(box2d) y, x = get_centroid_from_box2d(box2d) # note h,w are relative right now # we need to undo this, to see the real ratio h = h * float(H) w = w * float(W) box_ratio = h / w im_ratio = H / float(W) # print('box_ratio:', box_ratio) # print('im_ratio:', im_ratio) if box_ratio >= im_ratio: w = h / im_ratio # print('setting w:', h/im_ratio) else: h = w * im_ratio # print('setting h:', w*im_ratio) box2d = get_box2d_from_centroid_and_size( y, x, h / float(H), w / float(W), clip=False) assert (h > 1e-4) assert (w > 1e-4) ymin, xmin, ymax, xmax = torch.unbind(box2d, dim=1) fx, fy, x0, y0 = split_intrinsics(pix_T_cam) # the topleft of the new image will now have a different offset from the center of projection new_x0 = x0 - xmin * W new_y0 = y0 - ymin * H pix_T_cam = pack_intrinsics(fx, fy, new_x0, new_y0) # this alone will give me an image in original resolution, # with its topleft at the box corner box_h, box_w = get_size_from_box2d(box2d) # these are normalized, and shaped B. (e.g., [0.4], [0.3]) # we are going to scale the image by the inverse of this, # since we are zooming into this area sy = 1. / box_h sx = 1. / box_w pix_T_cam = scale_intrinsics(pix_T_cam, sx, sy) return pix_T_cam, box2d def pixels2camera(x, y, z, fx, fy, x0, y0): # x and y are locations in pixel coordinates, z is a depth in meters # they can be images or pointclouds # fx, fy, x0, y0 are camera intrinsics # returns xyz, sized B x N x 3 B = x.shape[0] fx = torch.reshape(fx, [B, 1]) fy = torch.reshape(fy, [B, 1]) x0 = torch.reshape(x0, [B, 1]) y0 = torch.reshape(y0, [B, 1]) x = torch.reshape(x, [B, -1]) y = torch.reshape(y, [B, -1]) z = torch.reshape(z, [B, -1]) # unproject x = (z / fx) * (x - x0) y = (z / fy) * (y - y0) xyz = torch.stack([x, y, z], dim=2) # B x N x 3 return xyz def camera2pixels(xyz, pix_T_cam): # xyz is shaped B x H*W x 3 # returns xy, shaped B x H*W x 2 fx, fy, x0, y0 = split_intrinsics(pix_T_cam) x, y, z = torch.unbind(xyz, dim=-1) B = list(z.shape)[0] fx = torch.reshape(fx, [B, 1]) fy = torch.reshape(fy, [B, 1]) x0 = torch.reshape(x0, [B, 1]) y0 = torch.reshape(y0, [B, 1]) x = torch.reshape(x, [B, -1]) y = torch.reshape(y, [B, -1]) z = torch.reshape(z, [B, -1]) EPS = 1e-4 z = torch.clamp(z, min=EPS) x = (x * fx) / z + x0 y = (y * fy) / z + y0 xy = torch.stack([x, y], dim=-1) return xy def depth2pointcloud(z, pix_T_cam): B, C, H, W = list(z.shape) device = z.device y, x = utils.basic.meshgrid2d(B, H, W, device=device) z = torch.reshape(z, [B, H, W]) fx, fy, x0, y0 = split_intrinsics(pix_T_cam) xyz = pixels2camera(x, y, z, fx, fy, x0, y0) return xyz ================================================ FILE: mvtracker/utils/improc.py ================================================ import cv2 import matplotlib.pyplot as plt import numpy as np import torch import torch.nn.functional as F import torchvision from matplotlib import cm from sklearn.decomposition import PCA EPS = 1e-6 from skimage.color import ( hsv2rgb) def _convert(input_, type_): return { 'float': 
input_.float(), 'double': input_.double(), }.get(type_, input_) def _generic_transform_sk_3d(transform, in_type='', out_type=''): def apply_transform_individual(input_): device = input_.device input_ = input_.cpu() input_ = _convert(input_, in_type) input_ = input_.permute(1, 2, 0).detach().numpy() transformed = transform(input_) output = torch.from_numpy(transformed).float().permute(2, 0, 1) output = _convert(output, out_type) return output.to(device) def apply_transform(input_): to_stack = [] for image in input_: to_stack.append(apply_transform_individual(image)) return torch.stack(to_stack) return apply_transform hsv_to_rgb = _generic_transform_sk_3d(hsv2rgb) def preprocess_color_tf(x): import tensorflow as tf return tf.cast(x, tf.float32) * 1. / 255 - 0.5 def preprocess_color(x): if isinstance(x, np.ndarray): return x.astype(np.float32) * 1. / 255 - 0.5 else: return x.float() * 1. / 255 - 0.5 def pca_embed(emb, keep, valid=None): ## emb -- [S,H/2,W/2,C] ## keep is the number of principal components to keep ## Helper function for reduce_emb. emb = emb + EPS # emb is B x C x H x W emb = emb.permute(0, 2, 3, 1).cpu().detach().numpy() # this is B x H x W x C if valid: valid = valid.cpu().detach().numpy().reshape((H * W)) emb_reduced = list() B, H, W, C = np.shape(emb) for img in emb: if np.isnan(img).any(): emb_reduced.append(np.zeros([H, W, keep])) continue pixels_kd = np.reshape(img, (H * W, C)) if valid: pixels_kd_pca = pixels_kd[valid] else: pixels_kd_pca = pixels_kd P = PCA(keep) P.fit(pixels_kd_pca) if valid: pixels3d = P.transform(pixels_kd) * valid else: pixels3d = P.transform(pixels_kd) out_img = np.reshape(pixels3d, [H, W, keep]).astype(np.float32) if np.isnan(out_img).any(): emb_reduced.append(np.zeros([H, W, keep])) continue emb_reduced.append(out_img) emb_reduced = np.stack(emb_reduced, axis=0).astype(np.float32) return torch.from_numpy(emb_reduced).permute(0, 3, 1, 2) def pca_embed_together(emb, keep): ## emb -- [S,H/2,W/2,C] ## keep is the number of principal components to keep ## Helper function for reduce_emb. emb = emb + EPS # emb is B x C x H x W emb = emb.permute(0, 2, 3, 1).cpu().detach().numpy() # this is B x H x W x C B, H, W, C = np.shape(emb) if np.isnan(emb).any(): return torch.zeros(B, keep, H, W) pixelskd = np.reshape(emb, (B * H * W, C)) P = PCA(keep) P.fit(pixelskd) pixels3d = P.transform(pixelskd) out_img = np.reshape(pixels3d, [B, H, W, keep]).astype(np.float32) if np.isnan(out_img).any(): return torch.zeros(B, keep, H, W) return torch.from_numpy(out_img).permute(0, 3, 1, 2) def reduce_emb(emb, valid=None, inbound=None, together=False): ## emb -- [S,C,H/2,W/2], inbound -- [S,1,H/2,W/2] ## Reduce number of chans to 3 with PCA. For vis. # S,H,W,C = emb.shape.as_list() S, C, H, W = list(emb.size()) keep = 3 if together: reduced_emb = pca_embed_together(emb, keep) else: reduced_emb = pca_embed(emb, keep, valid) # not im reduced_emb = utils.basic.normalize(reduced_emb) - 0.5 if inbound is not None: emb_inbound = emb * inbound else: emb_inbound = None return reduced_emb, emb_inbound def get_feat_pca(feat, valid=None): B, C, D, W = list(feat.size()) # feat is B x C x D x W. If 3D input, average it through Height dimension before passing into this function. 
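# Illustrative sketch (not part of the original file): pca_embed_together above flattens a
# (B, C, H, W) feature map to (B*H*W, C), fits a 3-component PCA on all pixels jointly, and
# reshapes the projection back into a 3-channel "image" for visualization. A standalone
# version of the same idea, with a simple per-channel min-max normalization added:
import numpy as np
import torch
from sklearn.decomposition import PCA

def features_to_rgb(feat: torch.Tensor) -> torch.Tensor:
    """feat: (B, C, H, W) -> (B, 3, H, W), each channel min-max normalized to [0, 1]."""
    B, C, H, W = feat.shape
    flat = feat.permute(0, 2, 3, 1).reshape(B * H * W, C).detach().cpu().numpy()
    proj = PCA(n_components=3).fit_transform(flat).reshape(B, H, W, 3)
    rgb = torch.from_numpy(proj).permute(0, 3, 1, 2).float()
    rgb = rgb - rgb.amin(dim=(2, 3), keepdim=True)
    rgb = rgb / rgb.amax(dim=(2, 3), keepdim=True).clamp(min=1e-6)
    return rgb

feat = torch.randn(2, 128, 32, 48)
rgb = features_to_rgb(feat)
assert rgb.shape == (2, 3, 32, 48) and rgb.min() >= 0 and rgb.max() <= 1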
pca, _ = reduce_emb(feat, valid=valid, inbound=None, together=True) # pca is B x 3 x W x D return pca def gif_and_tile(ims, just_gif=False): S = len(ims) # each im is B x H x W x C # i want a gif in the left, and the tiled frames on the right # for the gif tool, this means making a B x S x H x W tensor # where the leftmost part is sequential and the rest is tiled gif = torch.stack(ims, dim=1) if just_gif: return gif til = torch.cat(ims, dim=2) til = til.unsqueeze(dim=1).repeat(1, S, 1, 1, 1) im = torch.cat([gif, til], dim=3) return im def back2color(i, blacken_zeros=False): if blacken_zeros: const = torch.tensor([-0.5]) i = torch.where(i == 0.0, const.cuda() if i.is_cuda else const, i) return back2color(i) else: return ((i + 0.5) * 255).type(torch.ByteTensor) def convert_occ_to_height(occ, reduce_axis=3): B, C, D, H, W = list(occ.shape) assert (C == 1) # note that height increases DOWNWARD in the tensor # (like pixel/camera coordinates) G = list(occ.shape)[reduce_axis] values = torch.linspace(float(G), 1.0, steps=G, dtype=torch.float32, device=occ.device) if reduce_axis == 2: # fro view values = values.view(1, 1, G, 1, 1) elif reduce_axis == 3: # top view values = values.view(1, 1, 1, G, 1) elif reduce_axis == 4: # lateral view values = values.view(1, 1, 1, 1, G) else: assert (False) # you have to reduce one of the spatial dims (2-4) values = torch.max(occ * values, dim=reduce_axis)[0] / float(G) # values = values.view([B, C, D, W]) return values def xy2heatmap(xy, sigma, grid_xs, grid_ys, norm=False): # xy is B x N x 2, containing float x and y coordinates of N things # grid_xs and grid_ys are B x N x Y x X B, N, Y, X = list(grid_xs.shape) mu_x = xy[:, :, 0].clone() mu_y = xy[:, :, 1].clone() x_valid = (mu_x > -0.5) & (mu_x < float(X + 0.5)) y_valid = (mu_y > -0.5) & (mu_y < float(Y + 0.5)) not_valid = ~(x_valid & y_valid) mu_x[not_valid] = -10000 mu_y[not_valid] = -10000 mu_x = mu_x.reshape(B, N, 1, 1).repeat(1, 1, Y, X) mu_y = mu_y.reshape(B, N, 1, 1).repeat(1, 1, Y, X) sigma_sq = sigma * sigma # sigma_sq = (sigma*sigma).reshape(B, N, 1, 1) sq_diff_x = (grid_xs - mu_x) ** 2 sq_diff_y = (grid_ys - mu_y) ** 2 term1 = 1. / 2. * np.pi * sigma_sq term2 = torch.exp(-(sq_diff_x + sq_diff_y) / (2. 
* sigma_sq)) gauss = term1 * term2 if norm: # normalize so each gaussian peaks at 1 gauss_ = gauss.reshape(B * N, Y, X) gauss_ = utils.basic.normalize(gauss_) gauss = gauss_.reshape(B, N, Y, X) return gauss def xy2heatmaps(xy, Y, X, sigma=30.0, norm=True): # xy is B x N x 2 B, N, D = list(xy.shape) assert (D == 2) device = xy.device grid_y, grid_x = utils.basic.meshgrid2d(B, Y, X, device=device) # grid_x and grid_y are B x Y x X grid_xs = grid_x.unsqueeze(1).repeat(1, N, 1, 1) grid_ys = grid_y.unsqueeze(1).repeat(1, N, 1, 1) heat = xy2heatmap(xy, sigma, grid_xs, grid_ys, norm=norm) return heat def draw_circles_at_xy(xy, Y, X, sigma=12.5, round=False): B, N, D = list(xy.shape) assert (D == 2) prior = xy2heatmaps(xy, Y, X, sigma=sigma) # prior is B x N x Y x X if round: prior = (prior > 0.5).float() return prior def seq2color(im, norm=True, colormap='coolwarm'): B, S, H, W = list(im.shape) # S is sequential # prep a mask of the valid pixels, so we can blacken the invalids later mask = torch.max(im, dim=1, keepdim=True)[0] # turn the S dim into an explicit sequence coeffs = np.linspace(1.0, float(S), S).astype(np.float32) / float(S) # # increase the spacing from the center # coeffs[:int(S/2)] -= 2.0 # coeffs[int(S/2)+1:] += 2.0 coeffs = torch.from_numpy(coeffs).float().cuda() coeffs = coeffs.reshape(1, S, 1, 1).repeat(B, 1, H, W) # scale each channel by the right coeff im = im * coeffs # now im is in [1/S, 1], except for the invalid parts which are 0 # keep the highest valid coeff at each pixel im = torch.max(im, dim=1, keepdim=True)[0] out = [] for b in range(B): im_ = im[b] # move channels out to last dim_ im_ = im_.detach().cpu().numpy() im_ = np.squeeze(im_) # im_ is H x W if colormap == 'coolwarm': im_ = cm.coolwarm(im_)[:, :, :3] elif colormap == 'PiYG': im_ = cm.PiYG(im_)[:, :, :3] elif colormap == 'winter': im_ = cm.winter(im_)[:, :, :3] elif colormap == 'spring': im_ = cm.spring(im_)[:, :, :3] elif colormap == 'onediff': im_ = np.reshape(im_, (-1)) im0_ = cm.spring(im_)[:, :3] im1_ = cm.winter(im_)[:, :3] im1_[im_ == 1 / float(S)] = im0_[im_ == 1 / float(S)] im_ = np.reshape(im1_, (H, W, 3)) else: assert (False) # invalid colormap # move channels into dim 0 im_ = np.transpose(im_, [2, 0, 1]) im_ = torch.from_numpy(im_).float().cuda() out.append(im_) out = torch.stack(out, dim=0) # blacken the invalid pixels, instead of using the 0-color out = out * mask # out = out*255.0 # put it in [-0.5, 0.5] out = out - 0.5 return out def colorize(d): # this is actually just grayscale right now if d.ndim == 2: d = d.unsqueeze(dim=0) else: assert (d.ndim == 3) # color_map = cm.get_cmap('plasma') color_map = cm.get_cmap('inferno') # S1, D = traj.shape # print('d1', d.shape) C, H, W = d.shape assert (C == 1) d = d.reshape(-1) d = d.detach().cpu().numpy() # print('d2', d.shape) color = np.array(color_map(d)) * 255 # rgba # print('color1', color.shape) color = np.reshape(color[:, :3], [H * W, 3]) # print('color2', color.shape) color = torch.from_numpy(color).permute(1, 0).reshape(3, H, W) # # gather # cm = matplotlib.cm.get_cmap(cmap if cmap is not None else 'gray') # if cmap=='RdBu' or cmap=='RdYlGn': # colors = cm(np.arange(256))[:, :3] # else: # colors = cm.colors # colors = np.array(colors).astype(np.float32) # colors = np.reshape(colors, [-1, 3]) # colors = tf.constant(colors, dtype=tf.float32) # value = tf.gather(colors, indices) # colorize(value, normalize=True, vmin=None, vmax=None, cmap=None, vals=255) # copy to the three chans # d = d.repeat(3, 1, 1) return color def oned2inferno(d, 
norm=True, do_colorize=False): # convert a 1chan input to a 3chan image output # if it's just B x H x W, add a C dim if d.ndim == 3: d = d.unsqueeze(dim=1) # d should be B x C x H x W, where C=1 B, C, H, W = list(d.shape) assert (C == 1) if norm: d = utils.basic.normalize(d) if do_colorize: rgb = torch.zeros(B, 3, H, W) for b in list(range(B)): rgb[b] = colorize(d[b]) else: rgb = d.repeat(1, 3, 1, 1) * 255.0 # rgb = (255.0*rgb).type(torch.ByteTensor) rgb = rgb.type(torch.ByteTensor) # rgb = tf.cast(255.0*rgb, tf.uint8) # rgb = tf.reshape(rgb, [-1, hyp.H, hyp.W, 3]) # rgb = tf.expand_dims(rgb, axis=0) return rgb def oned2gray(d, norm=True): # convert a 1chan input to a 3chan image output # if it's just B x H x W, add a C dim if d.ndim == 3: d = d.unsqueeze(dim=1) # d should be B x C x H x W, where C=1 B, C, H, W = list(d.shape) assert (C == 1) if norm: d = utils.basic.normalize(d) rgb = d.repeat(1, 3, 1, 1) rgb = (255.0 * rgb).type(torch.ByteTensor) return rgb def draw_frame_id_on_vis(vis, frame_id, scale=0.5, left=5, top=20): rgb = vis.detach().cpu().numpy()[0] rgb = np.transpose(rgb, [1, 2, 0]) # put channels last rgb = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR) color = (255, 255, 255) # print('putting frame id', frame_id) frame_str = utils.basic.strnum(frame_id) text_color_bg = (0, 0, 0) font = cv2.FONT_HERSHEY_SIMPLEX text_size, _ = cv2.getTextSize(frame_str, font, scale, 1) text_w, text_h = text_size cv2.rectangle(rgb, (left, top - text_h), (left + text_w, top + 1), text_color_bg, -1) cv2.putText( rgb, frame_str, (left, top), # from left, from top font, scale, # font scale (float) color, 1) # font thickness (int) rgb = cv2.cvtColor(rgb.astype(np.uint8), cv2.COLOR_BGR2RGB) vis = torch.from_numpy(rgb).permute(2, 0, 1).unsqueeze(0) return vis COLORMAP_FILE = "./utils/bremm.png" class ColorMap2d: def __init__(self, filename=None): self._colormap_file = filename or COLORMAP_FILE self._img = plt.imread(self._colormap_file) self._height = self._img.shape[0] self._width = self._img.shape[1] def __call__(self, X): assert len(X.shape) == 2 output = np.zeros((X.shape[0], 3)) for i in range(X.shape[0]): x, y = X[i, :] xp = int((self._width - 1) * x) yp = int((self._height - 1) * y) xp = np.clip(xp, 0, self._width - 1) yp = np.clip(yp, 0, self._height - 1) output[i, :] = self._img[yp, xp] return output def get_n_colors(N, sequential=False): label_colors = [] for ii in range(N): if sequential: rgb = cm.winter(ii / (N - 1)) rgb = (np.array(rgb) * 255).astype(np.uint8)[:3] else: rgb = np.zeros(3) while np.sum(rgb) < 128: # ensure min brightness rgb = np.random.randint(0, 256, 3) label_colors.append(rgb) return label_colors class Summ_writer(object): def __init__(self, writer, global_step, log_freq=10, fps=8, scalar_freq=100, just_gif=False): self.writer = writer self.global_step = global_step self.log_freq = log_freq self.fps = fps self.just_gif = just_gif self.maxwidth = 10000 self.save_this = (self.global_step % self.log_freq == 0) self.scalar_freq = max(scalar_freq, 1) def summ_gif(self, name, tensor, blacken_zeros=False): # tensor should be in B x S x C x H x W assert tensor.dtype in {torch.uint8, torch.float32} shape = list(tensor.shape) if tensor.dtype == torch.float32: tensor = back2color(tensor, blacken_zeros=blacken_zeros) video_to_write = tensor[0:1] S = video_to_write.shape[1] if S == 1: # video_to_write is 1 x 1 x C x H x W self.writer.add_image(name, video_to_write[0, 0], global_step=self.global_step) else: self.writer.add_video(name, video_to_write, fps=self.fps, global_step=self.global_step) 
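            # Multi-frame tensors are logged as an animated video summary at
            # self.fps; only the first batch element (sliced above into
            # video_to_write) is written, which keeps the event files small.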
return video_to_write def draw_boxlist2d_on_image(self, rgb, boxlist, scores=None, tids=None, linewidth=1): B, C, H, W = list(rgb.shape) assert (C == 3) B2, N, D = list(boxlist.shape) assert (B2 == B) assert (D == 4) # ymin, xmin, ymax, xmax rgb = back2color(rgb) if scores is None: scores = torch.ones(B2, N).float() if tids is None: tids = torch.arange(N).reshape(1, N).repeat(B2, N).long() # tids = torch.zeros(B2, N).long() out = self.draw_boxlist2d_on_image_py( rgb[0].cpu().detach().numpy(), boxlist[0].cpu().detach().numpy(), scores[0].cpu().detach().numpy(), tids[0].cpu().detach().numpy(), linewidth=linewidth) out = torch.from_numpy(out).type(torch.ByteTensor).permute(2, 0, 1) out = torch.unsqueeze(out, dim=0) out = preprocess_color(out) out = torch.reshape(out, [1, C, H, W]) return out def draw_boxlist2d_on_image_py(self, rgb, boxlist, scores, tids, linewidth=1): # all inputs are numpy tensors # rgb is H x W x 3 # boxlist is N x 4 # scores is N # tids is N rgb = np.transpose(rgb, [1, 2, 0]) # put channels last # rgb = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR) rgb = rgb.astype(np.uint8).copy() H, W, C = rgb.shape assert (C == 3) N, D = boxlist.shape assert (D == 4) # color_map = cm.get_cmap('tab20') # color_map = cm.get_cmap('set1') color_map = cm.get_cmap('Accent') color_map = color_map.colors # print('color_map', color_map) # draw for ind, box in enumerate(boxlist): # box is 4 if not np.isclose(scores[ind], 0.0): # box = utils.geom.scale_box2d(box, H, W) ymin, xmin, ymax, xmax = box # ymin, ymax = ymin*H, ymax*H # xmin, xmax = xmin*W, xmax*W # print 'score = %.2f' % scores[ind] # color_id = tids[ind] % 20 color_id = tids[ind] color = color_map[color_id] color = np.array(color) * 255.0 color = color.round() # color = color.astype(np.uint8) # color = color[::-1] # print('color', color) # print 'tid = %d; score = %.3f' % (tids[ind], scores[ind]) # if False: if scores[ind] < 1.0: # not gt cv2.putText(rgb, # '%d (%.2f)' % (tids[ind], scores[ind]), '%.2f' % (scores[ind]), (int(xmin), int(ymin)), cv2.FONT_HERSHEY_SIMPLEX, 0.5, # font size color), # 1) # font weight xmin = np.clip(int(xmin), 0, W - 1) xmax = np.clip(int(xmax), 0, W - 1) ymin = np.clip(int(ymin), 0, H - 1) ymax = np.clip(int(ymax), 0, H - 1) cv2.line(rgb, (xmin, ymin), (xmin, ymax), color, linewidth, cv2.LINE_AA) cv2.line(rgb, (xmin, ymin), (xmax, ymin), color, linewidth, cv2.LINE_AA) cv2.line(rgb, (xmax, ymin), (xmax, ymax), color, linewidth, cv2.LINE_AA) cv2.line(rgb, (xmax, ymax), (xmin, ymax), color, linewidth, cv2.LINE_AA) # rgb = cv2.cvtColor(rgb.astype(np.uint8), cv2.COLOR_BGR2RGB) return rgb def summ_boxlist2d(self, name, rgb, boxlist, scores=None, tids=None, frame_id=None, only_return=False, linewidth=2): B, C, H, W = list(rgb.shape) boxlist_vis = self.draw_boxlist2d_on_image(rgb, boxlist, scores=scores, tids=tids, linewidth=linewidth) return self.summ_rgb(name, boxlist_vis, frame_id=frame_id, only_return=only_return) def summ_rgbs(self, name, ims, frame_ids=None, blacken_zeros=False, only_return=False): if self.save_this: ims = gif_and_tile(ims, just_gif=self.just_gif) vis = ims assert vis.dtype in {torch.uint8, torch.float32} if vis.dtype == torch.float32: vis = back2color(vis, blacken_zeros) B, S, C, H, W = list(vis.shape) if frame_ids is not None: assert (len(frame_ids) == S) for s in range(S): vis[:, s] = draw_frame_id_on_vis(vis[:, s], frame_ids[s]) if int(W) > self.maxwidth: vis = vis[:, :, :, :self.maxwidth] if only_return: return vis else: return self.summ_gif(name, vis, blacken_zeros) def summ_rgb(self, name, ims, 
blacken_zeros=False, frame_id=None, only_return=False, halfres=False): if self.save_this: assert ims.dtype in {torch.uint8, torch.float32} if ims.dtype == torch.float32: ims = back2color(ims, blacken_zeros) # ims is B x C x H x W vis = ims[0:1] # just the first one B, C, H, W = list(vis.shape) if halfres: vis = F.interpolate(vis, scale_factor=0.5) if frame_id is not None: vis = draw_frame_id_on_vis(vis, frame_id) if int(W) > self.maxwidth: vis = vis[:, :, :, :self.maxwidth] if only_return: return vis else: return self.summ_gif(name, vis.unsqueeze(1), blacken_zeros) def flow2color(self, flow, clip=50.0): """ :param flow: Optical flow tensor. :return: RGB image normalized between 0 and 1. """ # flow is B x C x H x W B, C, H, W = list(flow.size()) flow = flow.clone().detach() abs_image = torch.abs(flow) flow_mean = abs_image.mean(dim=[1, 2, 3]) flow_std = abs_image.std(dim=[1, 2, 3]) if clip: flow = torch.clamp(flow, -clip, clip) / clip else: # Apply some kind of normalization. Divide by the perceived maximum (mean + std*2) flow_max = flow_mean + flow_std * 2 + 1e-10 for b in range(B): flow[b] = flow[b].clamp(-flow_max[b].item(), flow_max[b].item()) / flow_max[b].clamp(min=1) radius = torch.sqrt(torch.sum(flow ** 2, dim=1, keepdim=True)) # B x 1 x H x W radius_clipped = torch.clamp(radius, 0.0, 1.0) angle = torch.atan2(flow[:, 1:], flow[:, 0:1]) / np.pi # B x 1 x H x W hue = torch.clamp((angle + 1.0) / 2.0, 0.0, 1.0) saturation = torch.ones_like(hue) * 0.75 value = radius_clipped hsv = torch.cat([hue, saturation, value], dim=1) # B x 3 x H x W # flow = tf.image.hsv_to_rgb(hsv) flow = hsv_to_rgb(hsv) flow = (flow * 255.0).type(torch.ByteTensor) return flow def summ_flow(self, name, im, clip=0.0, only_return=False, frame_id=None): # flow is B x C x D x W if self.save_this: return self.summ_rgb(name, self.flow2color(im, clip=clip), only_return=only_return, frame_id=frame_id) else: return None def summ_oneds(self, name, ims, frame_ids=None, bev=False, fro=False, logvis=False, reduce_max=False, max_val=0.0, norm=True, only_return=False, do_colorize=False): if self.save_this: if bev: B, C, H, _, W = list(ims[0].shape) if reduce_max: ims = [torch.max(im, dim=3)[0] for im in ims] else: ims = [torch.mean(im, dim=3) for im in ims] elif fro: B, C, _, H, W = list(ims[0].shape) if reduce_max: ims = [torch.max(im, dim=2)[0] for im in ims] else: ims = [torch.mean(im, dim=2) for im in ims] if len(ims) != 1: # sequence im = gif_and_tile(ims, just_gif=self.just_gif) else: im = torch.stack(ims, dim=1) # single frame B, S, C, H, W = list(im.shape) if logvis and max_val: max_val = np.log(max_val) im = torch.log(torch.clamp(im, 0) + 1.0) im = torch.clamp(im, 0, max_val) im = im / max_val norm = False elif max_val: im = torch.clamp(im, 0, max_val) im = im / max_val norm = False if norm: # normalize before oned2inferno, # so that the ranges are similar within B across S im = utils.basic.normalize(im) im = im.view(B * S, C, H, W) vis = oned2inferno(im, norm=norm, do_colorize=do_colorize) vis = vis.view(B, S, 3, H, W) if frame_ids is not None: assert (len(frame_ids) == S) for s in range(S): vis[:, s] = draw_frame_id_on_vis(vis[:, s], frame_ids[s]) if W > self.maxwidth: vis = vis[..., :self.maxwidth] if only_return: return vis else: self.summ_gif(name, vis) def summ_oned(self, name, im, bev=False, fro=False, logvis=False, max_val=0, max_along_y=False, norm=True, frame_id=None, only_return=False): if self.save_this: if bev: B, C, H, _, W = list(im.shape) if max_along_y: im = torch.max(im, dim=3)[0] else: im = 
torch.mean(im, dim=3) elif fro: B, C, _, H, W = list(im.shape) if max_along_y: im = torch.max(im, dim=2)[0] else: im = torch.mean(im, dim=2) else: B, C, H, W = list(im.shape) im = im[0:1] # just the first one assert (C == 1) if logvis and max_val: max_val = np.log(max_val) im = torch.log(im) im = torch.clamp(im, 0, max_val) im = im / max_val norm = False elif max_val: im = torch.clamp(im, 0, max_val) / max_val norm = False vis = oned2inferno(im, norm=norm) if W > self.maxwidth: vis = vis[..., :self.maxwidth] return self.summ_rgb(name, vis, blacken_zeros=False, frame_id=frame_id, only_return=only_return) def summ_feats(self, name, feats, valids=None, pca=True, fro=False, only_return=False, frame_ids=None): if self.save_this: if valids is not None: valids = torch.stack(valids, dim=1) feats = torch.stack(feats, dim=1) # feats leads with B x S x C if feats.ndim == 6: # feats is B x S x C x D x H x W if fro: reduce_dim = 3 else: reduce_dim = 4 if valids is None: feats = torch.mean(feats, dim=reduce_dim) else: valids = valids.repeat(1, 1, feats.size()[2], 1, 1, 1) feats = utils.basic.reduce_masked_mean(feats, valids, dim=reduce_dim) B, S, C, D, W = list(feats.size()) if not pca: # feats leads with B x S x C feats = torch.mean(torch.abs(feats), dim=2, keepdims=True) # feats leads with B x S x 1 feats = torch.unbind(feats, dim=1) return self.summ_oneds(name=name, ims=feats, norm=True, only_return=only_return, frame_ids=frame_ids) else: __p = lambda x: utils.basic.pack_seqdim(x, B) __u = lambda x: utils.basic.unpack_seqdim(x, B) feats_ = __p(feats) if valids is None: feats_pca_ = get_feat_pca(feats_) else: valids_ = __p(valids) feats_pca_ = get_feat_pca(feats_, valids) feats_pca = __u(feats_pca_) return self.summ_rgbs(name=name, ims=torch.unbind(feats_pca, dim=1), only_return=only_return, frame_ids=frame_ids) def summ_feat(self, name, feat, valid=None, pca=True, only_return=False, bev=False, fro=False, frame_id=None): if self.save_this: if feat.ndim == 5: # B x C x D x H x W if bev: reduce_axis = 3 elif fro: reduce_axis = 2 else: # default to bev reduce_axis = 3 if valid is None: feat = torch.mean(feat, dim=reduce_axis) else: valid = valid.repeat(1, feat.size()[1], 1, 1, 1) feat = utils.basic.reduce_masked_mean(feat, valid, dim=reduce_axis) B, C, D, W = list(feat.shape) if not pca: feat = torch.mean(torch.abs(feat), dim=1, keepdims=True) # feat is B x 1 x D x W return self.summ_oned(name=name, im=feat, norm=True, only_return=only_return, frame_id=frame_id) else: feat_pca = get_feat_pca(feat, valid) return self.summ_rgb(name, feat_pca, only_return=only_return, frame_id=frame_id) def summ_scalar(self, name, value): if (not (isinstance(value, int) or isinstance(value, float) or isinstance(value, np.float32))) and ( 'torch' in value.type()): value = value.detach().cpu().numpy() if not np.isnan(value): if (self.log_freq == 1): self.writer.add_scalar(name, value, global_step=self.global_step) elif self.save_this or np.mod(self.global_step, self.scalar_freq) == 0: self.writer.add_scalar(name, value, global_step=self.global_step) def summ_seg(self, name, seg, only_return=False, frame_id=None, colormap='tab20', label_colors=None): if not self.save_this: return B, H, W = seg.shape if label_colors is None: custom_label_colors = False # label_colors = get_n_colors(int(torch.max(seg).item()), sequential=True) label_colors = cm.get_cmap(colormap).colors label_colors = [[int(i * 255) for i in l] for l in label_colors] else: custom_label_colors = True # label_colors = matplotlib.cm.get_cmap(colormap).colors # 
label_colors = [[int(i*255) for i in l] for l in label_colors] # print('label_colors', label_colors) # label_colors = [ # (0, 0, 0), # None # (70, 70, 70), # Buildings # (190, 153, 153), # Fences # (72, 0, 90), # Other # (220, 20, 60), # Pedestrians # (153, 153, 153), # Poles # (157, 234, 50), # RoadLines # (128, 64, 128), # Roads # (244, 35, 232), # Sidewalks # (107, 142, 35), # Vegetation # (0, 0, 255), # Vehicles # (102, 102, 156), # Walls # (220, 220, 0) # TrafficSigns # ] r = torch.zeros_like(seg, dtype=torch.uint8) g = torch.zeros_like(seg, dtype=torch.uint8) b = torch.zeros_like(seg, dtype=torch.uint8) for label in range(0, len(label_colors)): if (not custom_label_colors): # and (N > 20): label_ = label % 20 else: label_ = label idx = (seg == label + 1) r[idx] = label_colors[label_][0] g[idx] = label_colors[label_][1] b[idx] = label_colors[label_][2] rgb = torch.stack([r, g, b], axis=1) return self.summ_rgb(name, rgb, only_return=only_return, frame_id=frame_id) def summ_pts_on_rgb(self, name, trajs, rgb, valids=None, frame_id=None, only_return=False, show_dots=True, cmap='coolwarm', linewidth=1): # trajs is B, S, N, 2 # rgbs is B, S, C, H, W B, C, H, W = rgb.shape B, S, N, D = trajs.shape rgb = rgb[0] # C, H, W trajs = trajs[0] # S, N, 2 if valids is None: valids = torch.ones_like(trajs[:, :, 0]) # S, N else: valids = valids[0] # print('trajs', trajs.shape) # print('valids', valids.shape) rgb = back2color(rgb).detach().cpu().numpy() rgb = np.transpose(rgb, [1, 2, 0]) # put channels last trajs = trajs.long().detach().cpu().numpy() # S, N, 2 valids = valids.long().detach().cpu().numpy() # S, N rgb = rgb.astype(np.uint8).copy() for i in range(N): if cmap == 'onediff' and i == 0: cmap_ = 'spring' elif cmap == 'onediff': cmap_ = 'winter' else: cmap_ = cmap traj = trajs[:, i] # S,2 valid = valids[:, i] # S color_map = cm.get_cmap(cmap) color = np.array(color_map(i)[:3]) * 255 # rgb for s in range(S): if valid[s]: cv2.circle(rgb, (int(traj[s, 0]), int(traj[s, 1])), linewidth, color, -1) rgb = torch.from_numpy(rgb).permute(2, 0, 1).unsqueeze(0) rgb = preprocess_color(rgb) return self.summ_rgb(name, rgb, only_return=only_return, frame_id=frame_id) def summ_pts_on_rgbs(self, name, trajs, rgbs, valids=None, frame_ids=None, only_return=False, show_dots=True, cmap='coolwarm', linewidth=1): # trajs is B, S, N, 2 # rgbs is B, S, C, H, W B, S, C, H, W = rgbs.shape B, S2, N, D = trajs.shape assert (S == S2) rgbs = rgbs[0] # S, C, H, W trajs = trajs[0] # S, N, 2 if valids is None: valids = torch.ones_like(trajs[:, :, 0]) # S, N else: valids = valids[0] # print('trajs', trajs.shape) # print('valids', valids.shape) rgbs_color = [] for rgb in rgbs: rgb = back2color(rgb).detach().cpu().numpy() rgb = np.transpose(rgb, [1, 2, 0]) # put channels last rgbs_color.append(rgb) # each element 3 x H x W trajs = trajs.long().detach().cpu().numpy() # S, N, 2 valids = valids.long().detach().cpu().numpy() # S, N rgbs_color = [rgb.astype(np.uint8).copy() for rgb in rgbs_color] for i in range(N): traj = trajs[:, i] # S,2 valid = valids[:, i] # S color_map = cm.get_cmap(cmap) color = np.array(color_map(0)[:3]) * 255 # rgb for s in range(S): if valid[s]: cv2.circle(rgbs_color[s], (traj[s, 0], traj[s, 1]), linewidth, color, -1) rgbs = [] for rgb in rgbs_color: rgb = torch.from_numpy(rgb).permute(2, 0, 1).unsqueeze(0) rgbs.append(preprocess_color(rgb)) return self.summ_rgbs(name, rgbs, only_return=only_return, frame_ids=frame_ids) def summ_traj2ds_on_rgbs(self, name, trajs, rgbs, valids=None, frame_ids=None, 
only_return=False, show_dots=False, cmap='coolwarm', vals=None, linewidth=1): # trajs is B, S, N, 2 # rgbs is B, S, C, H, W B, S, C, H, W = rgbs.shape B, S2, N, D = trajs.shape assert (S == S2) rgbs = rgbs[0] # S, C, H, W trajs = trajs[0] # S, N, 2 if valids is None: valids = torch.ones_like(trajs[:, :, 0]) # S, N else: valids = valids[0] # print('trajs', trajs.shape) # print('valids', valids.shape) if vals is not None: vals = vals[0] # N # print('vals', vals.shape) rgbs_color = [] for rgb in rgbs: rgb = back2color(rgb).detach().cpu().numpy() rgb = np.transpose(rgb, [1, 2, 0]) # put channels last rgbs_color.append(rgb) # each element 3 x H x W for i in range(N): if cmap == 'onediff' and i == 0: cmap_ = 'spring' elif cmap == 'onediff': cmap_ = 'winter' else: cmap_ = cmap traj = trajs[:, i].long().detach().cpu().numpy() # S, 2 valid = valids[:, i].long().detach().cpu().numpy() # S # print('traj', traj.shape) # print('valid', valid.shape) if vals is not None: # val = vals[:,i].float().detach().cpu().numpy() # [] val = vals[i].float().detach().cpu().numpy() # [] # print('val', val.shape) else: val = None for t in range(S): # if valid[t]: # traj_seq = traj[max(t-16,0):t+1] traj_seq = traj[max(t - 8, 0):t + 1] val_seq = np.linspace(0, 1, len(traj_seq)) # if t<2: # val_seq = np.zeros_like(val_seq) # print('val_seq', val_seq) # val_seq = 1.0 # val_seq = np.arange(8)/8.0 # val_seq = val_seq[-len(traj_seq):] # rgbs_color[t] = self.draw_traj_on_image_py(rgbs_color[t], traj_seq, S=S, show_dots=show_dots, cmap=cmap_, val=val_seq, linewidth=linewidth) rgbs_color[t] = self.draw_traj_on_image_py(rgbs_color[t], traj_seq, S=S, show_dots=show_dots, cmap=cmap_, val=val_seq, linewidth=linewidth) # input() for i in range(N): if cmap == 'onediff' and i == 0: cmap_ = 'spring' elif cmap == 'onediff': cmap_ = 'winter' else: cmap_ = cmap traj = trajs[:, i] # S,2 # vis = visibles[:,i] # S vis = torch.ones_like(traj[:, 0]) # S valid = valids[:, i] # S rgbs_color = self.draw_circ_on_images_py(rgbs_color, traj, vis, S=0, show_dots=show_dots, cmap=cmap_, linewidth=linewidth) rgbs = [] for rgb in rgbs_color: rgb = torch.from_numpy(rgb).permute(2, 0, 1).unsqueeze(0) rgbs.append(preprocess_color(rgb)) return self.summ_rgbs(name, rgbs, only_return=only_return, frame_ids=frame_ids) def summ_traj2ds_on_rgbs2(self, name, trajs, visibles, rgbs, valids=None, frame_ids=None, only_return=False, show_dots=True, cmap=None, linewidth=1): # trajs is B, S, N, 2 # rgbs is B, S, C, H, W B, S, C, H, W = rgbs.shape B, S2, N, D = trajs.shape assert (S == S2) rgbs = rgbs[0] # S, C, H, W trajs = trajs[0] # S, N, 2 visibles = visibles[0] # S, N if valids is None: valids = torch.ones_like(trajs[:, :, 0]) # S, N else: valids = valids[0] # print('trajs', trajs.shape) # print('valids', valids.shape) rgbs_color = [] for rgb in rgbs: rgb = back2color(rgb).detach().cpu().numpy() rgb = np.transpose(rgb, [1, 2, 0]) # put channels last rgbs_color.append(rgb) # each element 3 x H x W trajs = trajs.long().detach().cpu().numpy() # S, N, 2 visibles = visibles.float().detach().cpu().numpy() # S, N valids = valids.long().detach().cpu().numpy() # S, N for i in range(N): if cmap == 'onediff' and i == 0: cmap_ = 'spring' elif cmap == 'onediff': cmap_ = 'winter' else: cmap_ = cmap traj = trajs[:, i] # S,2 vis = visibles[:, i] # S valid = valids[:, i] # S rgbs_color = self.draw_traj_on_images_py(rgbs_color, traj, S=S, show_dots=show_dots, cmap=cmap_, linewidth=linewidth) for i in range(N): if cmap == 'onediff' and i == 0: cmap_ = 'spring' elif cmap == 'onediff': 
cmap_ = 'winter' else: cmap_ = cmap traj = trajs[:, i] # S,2 vis = visibles[:, i] # S valid = valids[:, i] # S if valid[0]: rgbs_color = self.draw_circ_on_images_py(rgbs_color, traj, vis, S=S, show_dots=show_dots, cmap=None, linewidth=linewidth) rgbs = [] for rgb in rgbs_color: rgb = torch.from_numpy(rgb).permute(2, 0, 1).unsqueeze(0) rgbs.append(preprocess_color(rgb)) return self.summ_rgbs(name, rgbs, only_return=only_return, frame_ids=frame_ids) def summ_traj2ds_on_rgb(self, name, trajs, rgb, valids=None, show_dots=False, show_lines=True, frame_id=None, only_return=False, cmap='coolwarm', linewidth=1): # trajs is B, S, N, 2 # rgb is B, C, H, W B, C, H, W = rgb.shape B, S, N, D = trajs.shape rgb = rgb[0] # S, C, H, W trajs = trajs[0] # S, N, 2 if valids is None: valids = torch.ones_like(trajs[:, :, 0]) else: valids = valids[0] rgb_color = back2color(rgb).detach().cpu().numpy() rgb_color = np.transpose(rgb_color, [1, 2, 0]) # put channels last # using maxdist will dampen the colors for short motions norms = torch.sqrt(1e-4 + torch.sum((trajs[-1] - trajs[0]) ** 2, dim=1)) # N maxdist = torch.quantile(norms, 0.95).detach().cpu().numpy() maxdist = None trajs = trajs.long().detach().cpu().numpy() # S, N, 2 valids = valids.long().detach().cpu().numpy() # S, N for i in range(N): if cmap == 'onediff' and i == 0: cmap_ = 'spring' elif cmap == 'onediff': cmap_ = 'winter' else: cmap_ = cmap traj = trajs[:, i] # S, 2 valid = valids[:, i] # S if valid[0] == 1: traj = traj[valid > 0] rgb_color = self.draw_traj_on_image_py( rgb_color, traj, S=S, show_dots=show_dots, show_lines=show_lines, cmap=cmap_, maxdist=maxdist, linewidth=linewidth) rgb_color = torch.from_numpy(rgb_color).permute(2, 0, 1).unsqueeze(0) rgb = preprocess_color(rgb_color) return self.summ_rgb(name, rgb, only_return=only_return, frame_id=frame_id) def draw_traj_on_image_py(self, rgb, traj, S=50, linewidth=1, show_dots=False, show_lines=True, cmap='coolwarm', val=None, maxdist=None): # all inputs are numpy tensors # rgb is 3 x H x W # traj is S x 2 H, W, C = rgb.shape assert (C == 3) rgb = rgb.astype(np.uint8).copy() S1, D = traj.shape assert (D == 2) color_map = cm.get_cmap(cmap) S1, D = traj.shape for s in range(S1): if val is not None: # if len(val) == S1: color = np.array(color_map(val[s])[:3]) * 255 # rgb # else: # color = np.array(color_map(val)[:3]) * 255 # rgb else: if maxdist is not None: val = (np.sqrt(np.sum((traj[s] - traj[0]) ** 2)) / maxdist).clip(0, 1) color = np.array(color_map(val)[:3]) * 255 # rgb else: color = np.array(color_map((s) / max(1, float(S - 2)))[:3]) * 255 # rgb if show_lines and s < (S1 - 1): cv2.line(rgb, (int(traj[s, 0]), int(traj[s, 1])), (int(traj[s + 1, 0]), int(traj[s + 1, 1])), color, linewidth, cv2.LINE_AA) if show_dots: cv2.circle(rgb, (int(traj[s, 0]), int(traj[s, 1])), linewidth, np.array(color_map(1)[:3]) * 255, -1) # if maxdist is not None: # val = (np.sqrt(np.sum((traj[-1]-traj[0])**2))/maxdist).clip(0,1) # color = np.array(color_map(val)[:3]) * 255 # rgb # else: # # draw the endpoint of traj, using the next color (which may be the last color) # color = np.array(color_map((S1-1)/max(1,float(S-2)))[:3]) * 255 # rgb # # emphasize endpoint # cv2.circle(rgb, (traj[-1,0], traj[-1,1]), linewidth*2, color, -1) return rgb def draw_traj_on_images_py(self, rgbs, traj, S=50, linewidth=1, show_dots=False, cmap='coolwarm', maxdist=None): # all inputs are numpy tensors # rgbs is a list of H,W,3 # traj is S,2 H, W, C = rgbs[0].shape assert (C == 3) rgbs = [rgb.astype(np.uint8).copy() for rgb in rgbs] S1, D = 
traj.shape assert (D == 2) x = int(np.clip(traj[0, 0], 0, W - 1)) y = int(np.clip(traj[0, 1], 0, H - 1)) color = rgbs[0][y, x] color = (int(color[0]), int(color[1]), int(color[2])) for s in range(S): # bak_color = np.array(color_map(1.0)[:3]) * 255 # rgb # cv2.circle(rgbs[s], (traj[s,0], traj[s,1]), linewidth*4, bak_color, -1) cv2.polylines(rgbs[s], [traj[:s + 1]], False, color, linewidth, cv2.LINE_AA) return rgbs def draw_circs_on_image_py(self, rgb, xy, colors=None, linewidth=10, radius=3, show_dots=False, maxdist=None): # all inputs are numpy tensors # rgbs is a list of 3,H,W # xy is N,2 H, W, C = rgb.shape assert (C == 3) rgb = rgb.astype(np.uint8).copy() N, D = xy.shape assert (D == 2) xy = xy.astype(np.float32) xy[:, 0] = np.clip(xy[:, 0], 0, W - 1) xy[:, 1] = np.clip(xy[:, 1], 0, H - 1) xy = xy.astype(np.int32) if colors is None: colors = get_n_colors(N) for n in range(N): color = colors[n] # print('color', color) # color = (color[0]*255).astype(np.uint8) color = (int(color[0]), int(color[1]), int(color[2])) # x = int(np.clip(xy[0,0], 0, W-1)) # y = int(np.clip(xy[0,1], 0, H-1)) # color_ = rgbs[0][y,x] # color_ = (int(color_[0]),int(color_[1]),int(color_[2])) # color_ = (int(color_[0]),int(color_[1]),int(color_[2])) cv2.circle(rgb, (xy[n, 0], xy[n, 1]), linewidth, color, 3) # vis_color = int(np.squeeze(vis[s])*255) # vis_color = (vis_color,vis_color,vis_color) # cv2.circle(rgbs[s], (traj[s,0], traj[s,1]), linewidth+1, vis_color, -1) return rgb def draw_circ_on_images_py(self, rgbs, traj, vis, S=50, linewidth=1, show_dots=False, cmap=None, maxdist=None): # all inputs are numpy tensors # rgbs is a list of 3,H,W # traj is S,2 H, W, C = rgbs[0].shape assert (C == 3) rgbs = [rgb.astype(np.uint8).copy() for rgb in rgbs] S1, D = traj.shape assert (D == 2) if cmap is None: bremm = ColorMap2d() traj_ = traj[0:1].astype(np.float32) traj_[:, 0] /= float(W) traj_[:, 1] /= float(H) color = bremm(traj_) # print('color', color) color = (color[0] * 255).astype(np.uint8) # color = (int(color[0]),int(color[1]),int(color[2])) color = (int(color[2]), int(color[1]), int(color[0])) for s in range(S1): if cmap is not None: color_map = cm.get_cmap(cmap) # color = np.array(color_map(s/(S-1))[:3]) * 255 # rgb color = np.array(color_map((s + 1) / max(1, float(S - 1)))[:3]) * 255 # rgb # color = color.astype(np.uint8) # color = (color[0], color[1], color[2]) # print('color', color) # import ipdb; ipdb.set_trace() cv2.circle(rgbs[s], (int(traj[s, 0]), int(traj[s, 1])), linewidth + 1, color, -1) # vis_color = int(np.squeeze(vis[s])*255) # vis_color = (vis_color,vis_color,vis_color) # cv2.circle(rgbs[s], (int(traj[s,0]), int(traj[s,1])), linewidth+1, vis_color, -1) return rgbs def summ_traj_as_crops(self, name, trajs_e, rgbs, frame_id=None, only_return=False, show_circ=False, trajs_g=None, is_g=False): B, S, N, D = trajs_e.shape assert (N == 1) assert (D == 2) rgbs_vis = [] n = 0 pad_amount = 100 trajs_e_py = trajs_e[0].detach().cpu().numpy() # trajs_e_py = np.clip(trajs_e_py, min=pad_amount/2, max=pad_amoun trajs_e_py = trajs_e_py + pad_amount if trajs_g is not None: trajs_g_py = trajs_g[0].detach().cpu().numpy() trajs_g_py = trajs_g_py + pad_amount for s in range(S): rgb = rgbs[0, s].detach().cpu().numpy() # print('orig rgb', rgb.shape) rgb = np.transpose(rgb, (1, 2, 0)) # H, W, 3 rgb = np.pad(rgb, ((pad_amount, pad_amount), (pad_amount, pad_amount), (0, 0))) # print('pad rgb', rgb.shape) H, W, C = rgb.shape if trajs_g is not None: xy_g = trajs_g_py[s, n] xy_g[0] = np.clip(xy_g[0], pad_amount, W - pad_amount) 
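                # Clamp the ground-truth point into the padded canvas (the y-coordinate
                # is clamped next), so the green GT marker drawn below always lands
                # inside the image bounds.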
xy_g[1] = np.clip(xy_g[1], pad_amount, H - pad_amount) rgb = self.draw_circs_on_image_py(rgb, xy_g.reshape(1, 2), colors=[(0, 255, 0)], linewidth=2, radius=3) xy_e = trajs_e_py[s, n] xy_e[0] = np.clip(xy_e[0], pad_amount, W - pad_amount) xy_e[1] = np.clip(xy_e[1], pad_amount, H - pad_amount) if show_circ: if is_g: rgb = self.draw_circs_on_image_py(rgb, xy_e.reshape(1, 2), colors=[(0, 255, 0)], linewidth=2, radius=3) else: rgb = self.draw_circs_on_image_py(rgb, xy_e.reshape(1, 2), colors=[(255, 0, 255)], linewidth=2, radius=3) xmin = int(xy_e[0]) - pad_amount // 2 xmax = int(xy_e[0]) + pad_amount // 2 ymin = int(xy_e[1]) - pad_amount // 2 ymax = int(xy_e[1]) + pad_amount // 2 rgb_ = rgb[ymin:ymax, xmin:xmax] H_, W_ = rgb_.shape[:2] # if np.any(rgb_.shape==0): # input() if H_ == 0 or W_ == 0: import ipdb; ipdb.set_trace() rgb_ = rgb_.transpose(2, 0, 1) rgb_ = torch.from_numpy(rgb_) rgbs_vis.append(rgb_) # nrow = int(np.sqrt(S)*(16.0/9)/2.0) nrow = int(np.sqrt(S) * 1.5) grid_img = torchvision.utils.make_grid(torch.stack(rgbs_vis, dim=0), nrow=nrow).unsqueeze(0) # print('grid_img', grid_img.shape) return self.summ_rgb(name, grid_img.byte(), frame_id=frame_id, only_return=only_return) def summ_occ(self, name, occ, reduce_axes=[3], bev=False, fro=False, pro=False, frame_id=None, only_return=False): if self.save_this: B, C, D, H, W = list(occ.shape) if bev: reduce_axes = [3] elif fro: reduce_axes = [2] elif pro: reduce_axes = [4] for reduce_axis in reduce_axes: height = convert_occ_to_height(occ, reduce_axis=reduce_axis) if reduce_axis == reduce_axes[-1]: return self.summ_oned(name=('%s_ax%d' % (name, reduce_axis)), im=height, norm=False, frame_id=frame_id, only_return=only_return) else: self.summ_oned(name=('%s_ax%d' % (name, reduce_axis)), im=height, norm=False, frame_id=frame_id, only_return=only_return) def erode2d(im, times=1, device='cuda'): weights2d = torch.ones(1, 1, 3, 3, device=device) for time in range(times): im = 1.0 - F.conv2d(1.0 - im, weights2d, padding=1).clamp(0, 1) return im def dilate2d(im, times=1, device='cuda', mode='square'): weights2d = torch.ones(1, 1, 3, 3, device=device) if mode == 'cross': weights2d[:, :, 0, 0] = 0.0 weights2d[:, :, 0, 2] = 0.0 weights2d[:, :, 2, 0] = 0.0 weights2d[:, :, 2, 2] = 0.0 for time in range(times): im = F.conv2d(im, weights2d, padding=1).clamp(0, 1) return im ================================================ FILE: mvtracker/utils/misc.py ================================================ import numpy as np import torch from prettytable import PrettyTable def count_parameters(model): table = PrettyTable(["Modules", "Parameters"]) total_params = 0 for name, parameter in model.named_parameters(): if not parameter.requires_grad: continue param = parameter.numel() if param > 100000: table.add_row([name, param]) total_params += param print(table) print('total params: %.2f M' % (total_params / 1000000.0)) return total_params def posemb_sincos_2d_xy(xy, C, temperature=10000, dtype=torch.float32, cat_coords=False): device = xy.device dtype = xy.dtype B, S, D = xy.shape assert (D == 2) x = xy[:, :, 0] y = xy[:, :, 1] assert (C % 4) == 0, 'feature dimension must be multiple of 4 for sincos emb' omega = torch.arange(C // 4, device=device) / (C // 4 - 1) omega = 1. 
        / (temperature ** omega)
    y = y.flatten()[:, None] * omega[None, :]
    x = x.flatten()[:, None] * omega[None, :]
    pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1)
    pe = pe.reshape(B, S, C).type(dtype)
    if cat_coords:
        pe = torch.cat([pe, xy], dim=2)  # B,N,C+2
    return pe


class SimplePool():
    def __init__(self, pool_size, version='pt'):
        self.pool_size = pool_size
        self.version = version
        self.items = []
        if not (version == 'pt' or version == 'np'):
            print('version = %s; please choose pt or np' % version)
            assert (False)  # please choose pt or np

    def __len__(self):
        return len(self.items)

    def mean(self, min_size=1):
        if min_size == 'half':
            pool_size_thresh = self.pool_size / 2
        else:
            pool_size_thresh = min_size
        if self.version == 'np':
            if len(self.items) >= pool_size_thresh:
                return np.sum(self.items) / float(len(self.items))
            else:
                return np.nan
        if self.version == 'pt':
            if len(self.items) >= pool_size_thresh:
                return torch.sum(self.items) / float(len(self.items))
            else:
                # torch.from_numpy(np.nan) would raise, since np.nan is a plain float
                return torch.tensor(float('nan'))

    def sample(self, with_replacement=True):
        idx = np.random.randint(len(self.items))
        if with_replacement:
            return self.items[idx]
        else:
            return self.items.pop(idx)

    def fetch(self, num=None):
        if self.version == 'pt':
            item_array = torch.stack(self.items)
        elif self.version == 'np':
            item_array = np.stack(self.items)
        if num is not None:
            # there better be some items
            assert (len(self.items) >= num)
            # if there are not that many elements just return however many there are
            if len(self.items) < num:
                return item_array
            else:
                idxs = np.random.randint(len(self.items), size=num)
                return item_array[idxs]
        else:
            return item_array

    def is_full(self):
        full = len(self.items) == self.pool_size
        return full

    def empty(self):
        self.items = []

    def update(self, items):
        for item in items:
            if len(self.items) < self.pool_size:
                # the pool is not full, so let's add this in
                self.items.append(item)
            else:
                # the pool is full
                # pop from the front
                self.items.pop(0)
                # add to the back
                self.items.append(item)
        return self.items


def farthest_point_sample(xyz, npoint, include_ends=False, deterministic=False):
    """
    Input:
        xyz: pointcloud data, [B, N, C], where C is probably 3
        npoint: number of samples
    Return:
        inds: sampled pointcloud index, [B, npoint]
    """
    device = xyz.device
    B, N, C = xyz.shape
    xyz = xyz.float()
    inds = torch.zeros(B, npoint, dtype=torch.long).to(device)
    distance = torch.ones(B, N).to(device) * 1e10
    if deterministic:
        farthest = torch.randint(0, 1, (B,), dtype=torch.long).to(device)
    else:
        farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device)
    batch_indices = torch.arange(B, dtype=torch.long).to(device)
    for i in range(npoint):
        if include_ends:
            if i == 0:
                farthest = 0
            elif i == 1:
                farthest = N - 1
        inds[:, i] = farthest
        centroid = xyz[batch_indices, farthest, :].view(B, 1, C)
        dist = torch.sum((xyz - centroid) ** 2, -1)
        mask = dist < distance
        distance[mask] = dist[mask]
        farthest = torch.max(distance, -1)[1]
        if npoint > N:  # if we need more samples, make them random
            distance += torch.randn_like(distance)
    return inds


def farthest_point_sample_py(xyz, npoint):
    N, C = xyz.shape
    inds = np.zeros(npoint, dtype=np.int32)
    distance = np.ones(N) * 1e10
    farthest = np.random.randint(0, N, dtype=np.int32)
    for i in range(npoint):
        inds[i] = farthest
        centroid = xyz[farthest, :].reshape(1, C)
        dist = np.sum((xyz - centroid) ** 2, -1)
        mask = dist < distance
        distance[mask] = dist[mask]
        farthest = np.argmax(distance, -1)
        if npoint > N:  # if we need more samples, make them random
            distance += np.random.randn(*distance.shape)
    return inds
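

# -----------------------------------------------------------------------------
# Illustrative usage sketch (not from the original repository): rough
# input/output shapes for two of the helpers above, assuming their docstrings
# hold. All variable names below are made up for the example.
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    import torch

    # Farthest-point sampling: pick 128 well-spread indices from a point cloud.
    xyz = torch.rand(2, 1024, 3)  # [B, N, 3]
    inds = farthest_point_sample(xyz, 128, deterministic=True)  # [B, 128]
    sampled = torch.gather(xyz, 1, inds[..., None].expand(-1, -1, 3))  # [B, 128, 3]

    # Sin/cos positional embedding of 2D coordinates; C must be divisible by 4.
    xy = torch.rand(2, 128, 2) * 256  # [B, S, 2] pixel coordinates
    pe = posemb_sincos_2d_xy(xy, C=64)  # [B, S, 64]

    print(sampled.shape, pe.shape)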
================================================ FILE: mvtracker/utils/visualizer_mp4.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. import logging import os import threading from typing import Tuple import cv2 import flow_vis import matplotlib.pyplot as plt import numpy as np import torch import torch.nn.functional as F import torchvision.transforms as transforms from matplotlib import cm from moviepy.editor import ImageSequenceClip from mvtracker.models.core.model_utils import world_space_to_pixel_xy_and_camera_z def read_video_from_path(path): cap = cv2.VideoCapture(path) if not cap.isOpened(): raise ValueError(f"Unable to open video file: {path}") frames = [] while cap.isOpened(): ret, frame = cap.read() if ret == True: frames.append(np.array(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))) else: break cap.release() return np.stack(frames) class Visualizer: def __init__( self, save_dir: str = "./results", grayscale: bool = False, pad_value: int = 0, fps: int = 10, mode: str = "rainbow", # 'cool', 'optical_flow' linewidth: int = 2, show_first_frame: int = 10, tracks_leave_trace: int = 0, # -1 for infinite tracks_use_alpha: bool = False, print_debug_info: bool = False, ): self.mode = mode self.save_dir = save_dir if mode == "rainbow": self.color_map = cm.get_cmap("gist_rainbow") elif mode == "cool": self.color_map = cm.get_cmap(mode) self.show_first_frame = show_first_frame self.grayscale = grayscale self.tracks_leave_trace = tracks_leave_trace self.tracks_use_alpha = tracks_use_alpha self.print_debug_info = print_debug_info self.pad_value = pad_value self.linewidth = linewidth self.fps = fps def visualize( self, video: torch.Tensor, # (B,T,C,H,W) tracks: torch.Tensor, # (B,T,N,2) visibility: torch.Tensor = None, # (B, T, N) bool gt_tracks: torch.Tensor = None, # (B,T,N,2) segm_mask: torch.Tensor = None, # (B,1,H,W) filename: str = "video", writer=None, # tensorboard Summary Writer, used for visualization during training step: int = 0, query_frame: torch.Tensor = None, # (B,N) save_video: bool = True, compensate_for_camera_motion: bool = False, rigid_part=None, video_depth=None, # (B,T,C,H,W) vector_colors=None, ): batch_size, num_frames, _, height, width = video.shape num_points = tracks.shape[-2] num_dims = tracks.shape[-1] assert video.shape == (batch_size, num_frames, 3, height, width) assert tracks.shape == (batch_size, num_frames, num_points, num_dims) if visibility is not None: assert visibility.shape == (batch_size, num_frames, num_points) if gt_tracks is not None: assert gt_tracks.shape == (batch_size, num_frames, num_points, num_dims) if query_frame is not None: assert query_frame.shape == (batch_size, num_points) if compensate_for_camera_motion: assert segm_mask is not None if segm_mask is not None: assert (query_frame == 0).all().item() coords = tracks[0, 0].round().long() segm_mask = segm_mask[0, 0][coords[:, 1], coords[:, 0]].long() video = F.pad( video, (self.pad_value, self.pad_value, self.pad_value, self.pad_value), "constant", 255, ) if video_depth is not None: video_depth = video_depth.squeeze(2) video_depth = video_depth.cpu().numpy() highest_depth_value = max(video_depth.max(), 100) video_depth = plt.cm.Spectral(video_depth / highest_depth_value) * 255 video_depth = video_depth[..., :3] video_depth = video_depth.astype(np.uint8) video_depth = torch.from_numpy(video_depth) 
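            # At this point the depth maps have been converted to uint8 RGB via the
            # Spectral colormap (normalized by the largest depth value, floored at 100);
            # next they are rearranged to (B, T, C, H, W) and padded exactly like the
            # RGB video so tracks can be drawn onto them as well.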
video_depth = video_depth.permute(0, 1, 4, 2, 3) video_depth = F.pad( video_depth, (self.pad_value, self.pad_value, self.pad_value, self.pad_value), "constant", 255, ) tracks = tracks + self.pad_value if self.grayscale: transform = transforms.Grayscale() video = transform(video) video = video.repeat(1, 1, 3, 1, 1) res_video, vector_colors = self.draw_tracks_on_video( video=video, tracks=tracks[..., :2], visibility=visibility, segm_mask=segm_mask, gt_tracks=gt_tracks, query_frame=query_frame, compensate_for_camera_motion=compensate_for_camera_motion, rigid_part=rigid_part, vector_colors=vector_colors, ) if video_depth is not None: res_video_depth, _ = self.draw_tracks_on_video( video=video_depth, tracks=tracks[..., :2], visibility=visibility, segm_mask=segm_mask, gt_tracks=gt_tracks, query_frame=query_frame, compensate_for_camera_motion=compensate_for_camera_motion, vector_colors=vector_colors, ) res_video = torch.cat([res_video, res_video_depth], dim=4) # B, T, 3, H, [W] if save_video: # self.save_video(res_video, filename=filename, writer=writer, step=step) thread = threading.Thread( target=Visualizer.save_video, args=(res_video, self.save_dir, filename, writer, self.fps, step) ) thread.start() return res_video, vector_colors @staticmethod def save_video(video, save_dir, filename, writer=None, fps=12, step=0): if writer is not None: writer.add_video(f"{filename}", video.to(torch.uint8), global_step=step, fps=fps) writer.flush() logging.info(f"Video {filename} saved to tensorboard") if save_dir is not None: os.makedirs(save_dir, exist_ok=True) wide_list = list(video.unbind(1)) wide_list = [wide[0].permute(1, 2, 0).cpu().numpy() for wide in wide_list] clip = ImageSequenceClip(wide_list, fps=fps) # Write the video file save_path = os.path.join(save_dir, f"{filename}_step_{step}.mp4") clip.write_videofile(save_path, codec="libx264", fps=fps, logger=None) logging.info(f"Video saved to {save_path}") def draw_tracks_on_video( self, video: torch.Tensor, tracks: torch.Tensor, visibility: torch.Tensor = None, segm_mask: torch.Tensor = None, gt_tracks=None, query_frame: torch.Tensor = None, compensate_for_camera_motion=False, vector_colors=None, rigid_part=None, ): B, T, C, H, W = video.shape _, _, N, D = tracks.shape assert D == 2 assert C == 3 video = video[0].permute(0, 2, 3, 1).byte().detach().cpu().numpy() # S, H, W, C tracks = tracks[0].long().detach().cpu().numpy() # S, N, 2 if query_frame is not None: query_frame = query_frame[0].long().detach().cpu().numpy() # N if gt_tracks is not None: gt_tracks = gt_tracks[0].detach().cpu().numpy() res_video = [] # process input video for rgb in video: res_video.append(rgb.copy()) if vector_colors is None: vector_colors = np.zeros((T, N, 3)) if self.mode == "optical_flow": vector_colors = flow_vis.flow_to_color(tracks - tracks[query_frame, torch.arange(N)][None]) elif segm_mask is None: if self.mode == "rainbow": # y_min, y_max = ( # tracks[query_frame, :, 1].min(), # tracks[query_frame, :, 1].max(), # ) y_min, y_max = 0, H norm = plt.Normalize(y_min, y_max) for n in range(N): color = self.color_map(norm(tracks[query_frame[n], n, 1])) color = np.array(color[:3])[None] * 255 vector_colors[:, n] = np.repeat(color, T, axis=0) else: # color changes with time for t in range(T): color = np.array(self.color_map(t / T)[:3])[None] * 255 vector_colors[t] = np.repeat(color, N, axis=0) else: if self.mode == "rainbow": vector_colors[:, segm_mask <= 0, :] = 255 # y_min, y_max = ( # tracks[0, segm_mask > 0, 1].min(), # tracks[0, segm_mask > 0, 1].max(), # ) y_min, 
y_max = 0, H norm = plt.Normalize(y_min, y_max) for n in range(N): if segm_mask[n] > 0: color = self.color_map(norm(tracks[0, n, 1])) color = np.array(color[:3])[None] * 255 vector_colors[:, n] = np.repeat(color, T, axis=0) else: # color changes with segm class segm_mask = segm_mask.cpu() color = np.zeros((segm_mask.shape[0], 3), dtype=np.float32) color[segm_mask > 0] = np.array(self.color_map(1.0)[:3]) * 255.0 color[segm_mask <= 0] = np.array(self.color_map(0.0)[:3]) * 255.0 vector_colors = np.repeat(color[None], T, axis=0) # draw tracks if self.tracks_leave_trace != 0: for t in range(1, T): first_ind = ( max(0, t - self.tracks_leave_trace) if self.tracks_leave_trace >= 0 else 0 ) curr_tracks = tracks[first_ind: t + 1] curr_colors = vector_colors[first_ind: t + 1] if compensate_for_camera_motion: diff = ( tracks[first_ind: t + 1, segm_mask <= 0] - tracks[t: t + 1, segm_mask <= 0] ).mean(1)[:, None] curr_tracks = curr_tracks - diff curr_tracks = curr_tracks[:, segm_mask > 0] curr_colors = curr_colors[:, segm_mask > 0] res_video[t] = self._draw_pred_tracks( res_video[t], curr_tracks, curr_colors, query_frame - first_ind, use_alpha=self.tracks_use_alpha, ) if gt_tracks is not None: res_video[t] = self._draw_gt_tracks( res_video[t], gt_tracks[first_ind: t + 1] ) # Add frame number if self.print_debug_info: for t in range(T): min_x = tracks[t].min(0)[0] min_y = tracks[t].min(0)[1] min_xy = f"{min_x:6.1f}, {min_y:6.1f}" median_x = np.median(tracks[t], axis=0)[0] median_y = np.median(tracks[t], axis=0)[1] median_xy = f"{median_x:6.1f}, {median_y:6.1f}" max_x = tracks[t].max(0)[0] max_y = tracks[t].max(0)[1] max_xy = f"{max_x:6.1f}, {max_y:6.1f}" text = ( f"Frame {t}" f"\nH,W={H},{W}" f"\nT,N={T},{N}" f"\nmin_xy = {min_xy} " f"\nmedian_xy = {median_xy} " f"\nmax_xy = {max_xy} " ) res_video[t] = put_debug_text_onto_image(res_video[t], text) if rigid_part is not None: cls_label = torch.unique(rigid_part) cls_num = len(torch.unique(rigid_part)) # visualize the clustering results cmap = plt.get_cmap('jet') # get the color mapping colors = cmap(np.linspace(0, 1, cls_num)) colors = (colors[:, :3] * 255) color_map = {label.item(): color for label, color in zip(cls_label, colors)} # draw points for t in range(T): for i in range(N): if query_frame is not None and query_frame[i] > t: continue coord = (tracks[t, i, 0], tracks[t, i, 1]) visibile = True if visibility is not None: visibile = visibility[0, t, i] # Check for NaN or Inf in coordinates if np.isnan(coord).any() or np.isinf(coord).any(): logging.info(f"Warning: Skipping track {i} at t={t} due to NaN or Inf coord={coord}.") continue # Skip plotting this point if coord[0] != 0 and coord[1] != 0: if not compensate_for_camera_motion or ( compensate_for_camera_motion and segm_mask[i] > 0 ): if rigid_part is not None: color = color_map[rigid_part.squeeze()[i].item()] cv2.circle( res_video[t], coord, int(self.linewidth * 2), color.tolist(), thickness=-1 if visibile else 2 - 1, ) else: cv2.circle( res_video[t], coord, int(self.linewidth * 2), vector_colors[t, i].tolist(), thickness=-1 if visibile else 2 - 1, ) # construct the final rgb sequence if self.show_first_frame > 0: res_video = [res_video[0]] * self.show_first_frame + res_video[1:] return torch.from_numpy(np.stack(res_video)).permute(0, 3, 1, 2)[None].byte(), vector_colors def _draw_pred_tracks( self, rgb: np.ndarray, # H x W x 3 tracks: np.ndarray, # shape: [T, N, 2] vector_colors: np.ndarray, # shape: [T, N, 3] query_frame: np.ndarray, # shape: [N], each entry = birth frame for track i use_alpha: 
bool = False,
    ) -> np.ndarray:
        """
        Draws trajectory lines from frame s to s+1, but only if s >= query_frame[i].
        That is, no lines are drawn before the track 'appears' at query_frame[i].
        """
        T, N, _ = tracks.shape
        for s in range(T - 1):
            # We'll blend older lines more lightly (alpha) if desired:
            original_rgb = rgb.copy()
            if use_alpha:
                alpha = (s / T) ** 2  # or pick some function of s, T
            else:
                alpha = 1
            for i in range(N):
                # If the query/birth frame for track i is after s, skip drawing
                if query_frame is not None and s < query_frame[i]:
                    continue
                pt_s = (int(tracks[s, i, 0]), int(tracks[s, i, 1]))
                pt_sp1 = (int(tracks[s + 1, i, 0]), int(tracks[s + 1, i, 1]))
                # Skip if the points are 0 or invalid
                if pt_s[0] == 0 and pt_s[1] == 0:
                    continue
                if pt_sp1[0] == 0 and pt_sp1[1] == 0:
                    continue
                color = vector_colors[s, i].tolist()
                cv2.line(rgb, pt_s, pt_sp1, color, self.linewidth, cv2.LINE_AA)
            # Optionally alpha-blend older lines if you want them to fade out:
            rgb = cv2.addWeighted(rgb, alpha, original_rgb, 1 - alpha, 0)
        return rgb

    def _draw_gt_tracks(
            self,
            rgb: np.ndarray,  # H x W x 3
            gt_tracks: np.ndarray,  # T x N x 2
    ):
        T, N, _ = gt_tracks.shape
        color = np.array((211.0, 0.0, 0.0))
        for t in range(T):
            for i in range(N):
                # Use a per-point variable here; reassigning gt_tracks itself would
                # clobber the full array after the first iteration.
                gt_track = gt_tracks[t][i]
                # draw a red cross
                if gt_track[0] > 0 and gt_track[1] > 0:
                    length = self.linewidth * 3
                    coord_y = (int(gt_track[0]) + length, int(gt_track[1]) + length)
                    coord_x = (int(gt_track[0]) - length, int(gt_track[1]) - length)
                    cv2.line(
                        rgb,
                        coord_y,
                        coord_x,
                        color,
                        self.linewidth,
                        cv2.LINE_AA,
                    )
                    coord_y = (int(gt_track[0]) - length, int(gt_track[1]) + length)
                    coord_x = (int(gt_track[0]) + length, int(gt_track[1]) - length)
                    cv2.line(
                        rgb,
                        coord_y,
                        coord_x,
                        color,
                        self.linewidth,
                        cv2.LINE_AA,
                    )
        return rgb


def put_debug_text_onto_image(img: np.ndarray, text: str,
                              font_scale: float = 0.5,
                              left: int = 5,
                              top: int = 20,
                              font_thickness: int = 1,
                              text_color_bg: Tuple[int, int, int] = (0, 0, 0)) -> np.ndarray:
    """
    Overlay debug text on the provided image.

    Parameters
    ----------
    img : np.ndarray
        A 3D numpy array representing the input image. The image is expected to have three color channels.
    text : str
        The debug text to overlay on the image. The text can include newline characters ('\n') to create
        multi-line text.
    font_scale : float, default 0.5
        The scale factor that is multiplied by the font-specific base size.
    left : int, default 5
        The left-most coordinate where the text is to be put.
    top : int, default 20
        The top-most coordinate where the text is to be put.
    font_thickness : int, default 1
        Thickness of the lines used to draw the text.
    text_color_bg : Tuple[int, int, int], default (0, 0, 0)
        The color of the text background in BGR format.

    Returns
    -------
    img : np.ndarray
        A 3D numpy array representing the image with the debug text overlaid.
""" img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) font_color = (255, 255, 255) # Write each line of text in a new row (_, label_height), _ = cv2.getTextSize(text, cv2.FONT_HERSHEY_SIMPLEX, font_scale, font_thickness) if text_color_bg is not None: for i, line in enumerate(text.split('\n')): (line_width, _), _ = cv2.getTextSize(line, cv2.FONT_HERSHEY_SIMPLEX, font_scale, font_thickness) top_i = top + i * label_height cv2.rectangle(img, (left, top_i - label_height), (left + line_width, top_i), text_color_bg, -1) for i, line in enumerate(text.split('\n')): top_i = top + i * label_height cv2.putText(img, line, (left, top_i), cv2.FONT_HERSHEY_SIMPLEX, font_scale, font_color, font_thickness) img = cv2.cvtColor(img.astype(np.uint8), cv2.COLOR_BGR2RGB) return img class MultiViewVisualizer(Visualizer): def __init__(self, **kwargs): super().__init__(**kwargs) def visualize( self, video: torch.Tensor, # (B,V,T,C,H,W) tracks: torch.Tensor, # (B,V,T,N,2) visibility: torch.Tensor = None, # (B,V,T,N) bool gt_tracks: torch.Tensor = None, # (B,V,T,N,2) segm_mask: torch.Tensor = None, # (B,V,1,H,W) filename: str = "video", writer=None, # tensorboard Summary Writer, used for visualization during training step: int = 0, query_frame: torch.Tensor = None, # (B,N) save_video: bool = True, compensate_for_camera_motion: bool = False, rigid_part=None, video_depth=None, # (B,V,T,C,H,W) vector_colors=None, ): # Replace NaN and Inf values with 0 tracks = tracks.detach().clone().clip(-1e4, 1e4) tracks[torch.isnan(tracks)] = 0 gt_tracks = gt_tracks.detach().clone().clip(-1e4, 1e4) if gt_tracks is not None else None batch_size, num_views, num_frames, _, height, width = video.shape num_points = tracks.shape[-2] num_dims = tracks.shape[-1] # Repeat visibility for each view if only global visibility is provided if visibility is not None and visibility.dim() == 3: visibility = visibility[:, None, :, :].repeat(1, num_views, 1, 1) # Assert shapes of per-view data assert video.shape == (batch_size, num_views, num_frames, 3, height, width) assert tracks.shape == (batch_size, num_views, num_frames, num_points, num_dims) assert num_dims in [2, 3] if gt_tracks is not None: assert gt_tracks.shape == (batch_size, num_views, num_frames, num_points, num_dims) if visibility is not None: assert visibility.shape == (batch_size, num_views, num_frames, num_points) if segm_mask is not None: assert segm_mask.shape == (batch_size, num_views, 1, height, width) if video_depth is not None: assert video_depth.shape == (batch_size, num_views, num_frames, 1, height, width) res_video_list = [] for view_idx in range(num_views): res_video, vector_colors = super(MultiViewVisualizer, self).visualize( # Extract view-specific data video=video[:, view_idx], tracks=tracks[:, view_idx], visibility=visibility[:, view_idx], gt_tracks=gt_tracks[:, view_idx] if gt_tracks is not None else None, segm_mask=segm_mask[:, view_idx] if segm_mask is not None else None, video_depth=video_depth[:, view_idx] if video_depth is not None else None, # Pass-through arguments step=step, query_frame=query_frame, compensate_for_camera_motion=compensate_for_camera_motion, rigid_part=rigid_part, vector_colors=vector_colors, # Disable saving video for individual views as we will save the merged videos filename=None, writer=None, save_video=False ) res_video_list.append(res_video) res_video = torch.cat(res_video_list, dim=3) if save_video: # Visualizer.save_video(res_video, self.save_dir, filename, writer, self.fps, step) thread = threading.Thread( target=Visualizer.save_video, 
args=(res_video, self.save_dir, filename, writer, self.fps, step) ) thread.start() return res_video, vector_colors def log_mp4_track_viz( log_dir, dataset_name, datapoint_idx, rgbs, intrs, extrs, gt_trajectories, gt_visibilities, pred_trajectories, pred_visibilities, query_points_3d, step=0, prefix="comparison__", max_tracks_to_visualize=36, max_individual_tracks_to_visualize=6, ): batch_size, num_frames, num_points, _ = gt_trajectories.shape num_views = rgbs.shape[1] intrs_inv = torch.inverse(intrs.float()).type(intrs.dtype) extrs_square = torch.eye(4).to(extrs.device)[None].repeat(batch_size, num_views, num_frames, 1, 1) extrs_square[:, :, :, :3, :] = extrs extrs_inv = torch.inverse(extrs_square.float()).type(extrs.dtype) assert intrs_inv.shape == (batch_size, num_views, num_frames, 3, 3) assert extrs_inv.shape == (batch_size, num_views, num_frames, 4, 4) gt_pix_xy_cam_z = torch.stack([ torch.cat(world_space_to_pixel_xy_and_camera_z( world_xyz=gt_trajectories[0], intrs=intrs[0, view_idx], extrs=extrs[0, view_idx], ), dim=-1) for view_idx in range(num_views) ], dim=0)[None] pred_pix_xy_cam_z = torch.stack([ torch.cat(world_space_to_pixel_xy_and_camera_z( world_xyz=pred_trajectories[0], intrs=intrs[0, view_idx], extrs=extrs[0, view_idx], ), dim=-1) for view_idx in range(num_views) ], dim=0)[None] visualizer = MultiViewVisualizer( save_dir=log_dir, pad_value=0, fps=30 if "panoptic" in dataset_name else 12, show_first_frame=0, tracks_leave_trace=-1, ) seq_name = f"seq-{datapoint_idx}" # Plot all tracks at the same time gt_viz, vector_colors = visualizer.visualize( video=rgbs.cpu(), video_depth=None, tracks=gt_pix_xy_cam_z[:, :, :, :max_tracks_to_visualize].cpu(), visibility=gt_visibilities.clone()[:, :, :max_tracks_to_visualize].cpu(), query_frame=query_points_3d[..., 0].long().clone()[:, :max_tracks_to_visualize].cpu(), filename=f"eval_{dataset_name}_gt_traj_{seq_name}_any_visib", save_video=False, ) pred_viz, _ = visualizer.visualize( video=rgbs.cpu(), video_depth=None, tracks=pred_pix_xy_cam_z[:, :, :, :max_tracks_to_visualize].cpu(), visibility=pred_visibilities[:, :, :max_tracks_to_visualize].cpu(), query_frame=query_points_3d[..., 0].long().clone()[:, :max_tracks_to_visualize].cpu(), filename=f"eval_{dataset_name}_pred_traj_{seq_name}", save_video=False, vector_colors=vector_colors, ) viz = torch.cat([gt_viz, pred_viz], dim=-1) thread = threading.Thread( target=Visualizer.save_video, args=(viz, visualizer.save_dir, f"{prefix}{seq_name}", None, visualizer.fps, step) ) thread.start() thread.join() # Plot individual tracks for track_idx in range(min(num_points, max_individual_tracks_to_visualize)): seq_name_i = f"seq-{datapoint_idx}-point-{track_idx:02d}" gt_viz, vector_colors_i = visualizer.visualize( video=rgbs.cpu(), video_depth=None, tracks=gt_pix_xy_cam_z[:, :, :, track_idx:track_idx + 1].cpu(), visibility=gt_visibilities.clone()[:, :, track_idx:track_idx + 1].cpu(), query_frame=query_points_3d[..., 0].long().clone()[:, track_idx:track_idx + 1].cpu(), filename=f"eval_{dataset_name}_gt_traj_{seq_name_i}_any_visib", step=step, save_video=False, ) pred_viz, _ = visualizer.visualize( video=rgbs.cpu(), video_depth=None, tracks=pred_pix_xy_cam_z[:, :, :, track_idx:track_idx + 1].cpu(), visibility=pred_visibilities[:, :, track_idx:track_idx + 1].cpu(), query_frame=query_points_3d[..., 0].long().clone()[:, track_idx:track_idx + 1].cpu(), filename=f"eval_{dataset_name}_pred_traj_{seq_name_i}", save_video=False, vector_colors=vector_colors_i, ) viz = torch.cat([gt_viz, pred_viz], dim=-1) 
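        # GT panel (left) and prediction panel (right) are concatenated along the
        # width axis; the video is written on a worker thread, but the join() right
        # after start() makes the call effectively blocking.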
thread = threading.Thread( target=Visualizer.save_video, args=(viz, visualizer.save_dir, f"{prefix}{seq_name_i}", None, visualizer.fps, step) ) thread.start() thread.join() ================================================ FILE: mvtracker/utils/visualizer_rerun.py ================================================ from typing import Union, Optional, List, Dict, Any import matplotlib import numpy as np import pandas as pd import rerun as rr import seaborn as sns import torch from matplotlib import pyplot as plt, colors as mcolors, cm as cm from sklearn.decomposition import PCA def setup_libs(latex=False): pd.set_option('display.max_rows', 500) pd.set_option('display.max_columns', 500) pd.set_option('display.width', 1000) sns.set_theme(style="white", rc={"axes.facecolor": (0, 0, 0, 0)}) sns.set_style("ticks") sns.set_palette("flare") if latex: plt.rc('font', **{'family': 'serif', 'serif': ['Computer Modern Roman']}) plt.rc('text', usetex=True) plt.rcParams.update({ 'figure.titlesize': '28', 'axes.titlesize': '22', 'axes.titlepad': '10', 'legend.title_fontsize': '16', 'legend.fontsize': '14', 'axes.labelsize': '18', 'xtick.labelsize': '16', 'ytick.labelsize': '16', 'figure.dpi': 200, }) def log_pointclouds_to_rerun( dataset_name: str, datapoint_idx: Union[int, str], rgbs: torch.Tensor, depths: torch.Tensor, intrs: torch.Tensor, extrs: torch.Tensor, depths_conf: Optional[torch.Tensor] = None, conf_thrs: Optional[List[float]] = None, log_only_confident_pc: bool = False, radii: float = -2.45, fps: float = 30.0, bbox_crop: Optional[torch.Tensor] = None, # e.g., np.array([[-4, 4], [-3, 3.7], [1.2, 5.2]]) sphere_radius_crop: Optional[float] = None, # e.g., 6.0 sphere_center_crop: Optional[np.ndarray] = np.array([0, 0, 0]), log_rgb_image: bool = False, log_depthmap_as_image_v1: bool = False, log_depthmap_as_image_v2: bool = False, log_camera_frustrum: bool = True, log_rgb_pointcloud: bool = True, timesteps_to_log: Optional[List[int]] = None, ): # Set the up-axis for the world # Log coordinate axes for reference rr.set_time_seconds("frame", 0) B, V, T, _, H, W = rgbs.shape assert rgbs.shape == (B, V, T, 3, H, W) assert depths.shape == (B, V, T, 1, H, W) assert depths_conf is None or depths_conf.shape == (B, V, T, 1, H, W) assert intrs.shape == (B, V, T, 3, 3) assert extrs.shape == (B, V, T, 3, 4) assert B == 1 # Compute inverse intrinsics and extrinsics intrs_inv = torch.inverse(intrs.float()).type(intrs.dtype) extrs_square = torch.eye(4).to(extrs.device)[None].repeat(B, V, T, 1, 1) extrs_square[:, :, :, :3, :] = extrs extrs_inv = torch.inverse(extrs_square.float()).type(extrs.dtype) assert intrs_inv.shape == (B, V, T, 3, 3) assert extrs_inv.shape == (B, V, T, 4, 4) for v in range(V): # Iterate over views for t in range(T): # Iterate over frames if timesteps_to_log is not None and t not in timesteps_to_log: continue rr.set_time_seconds("frame", t / fps) # Log RGB image rgb_image = rgbs[0, v, t].permute(1, 2, 0).cpu().numpy() if log_rgb_image: rr.log(f"sequence-{datapoint_idx}/{dataset_name}/image/view-{v}/rgb", rr.Image(rgb_image)) # Log Depth map depth_map = depths[0, v, t, 0].cpu().numpy() if log_depthmap_as_image_v1: rr.log(f"sequence-{datapoint_idx}/{dataset_name}/image/view-{v}/depth", rr.DepthImage(depth_map, point_fill_ratio=0.2)) # Log Depth map as RGB d_min, d_max = depth_map.min(), depth_map.max() norm = mcolors.Normalize(vmin=d_min, vmax=d_max) turbo_cmap = cm.get_cmap("turbo") # "viridis", "plasma", etc. 
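            # Map the depth values through a min-max normalization and the "turbo"
            # colormap to obtain an RGB rendering of the depth map; it is logged
            # below only when log_depthmap_as_image_v2 is set.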
            depth_color_rgba = turbo_cmap(norm(depth_map))
            depth_color_rgb = (depth_color_rgba[..., :3] * 255).astype(np.uint8)
            if log_depthmap_as_image_v2:
                rr.log(f"sequence-{datapoint_idx}/{dataset_name}/image/view-{v}/depth-as-rgb",
                       rr.Image(depth_color_rgb))

            # Log Camera
            K = intrs[0, v, t].cpu().numpy()
            world_T_cam = np.eye(4)
            world_T_cam[:3, :3] = extrs_inv[0, v, t, :3, :3].cpu().numpy()
            world_T_cam[:3, 3] = extrs_inv[0, v, t, :3, 3].cpu().numpy()
            if log_camera_frustrum:
                rr.log(f"sequence-{datapoint_idx}/{dataset_name}/image/view-{v}",
                       rr.Pinhole(image_from_camera=K, width=W, height=H))
                rr.log(f"sequence-{datapoint_idx}/{dataset_name}/image/view-{v}",
                       rr.Transform3D(translation=world_T_cam[:3, 3], mat3x3=world_T_cam[:3, :3]))

            # Generate and log point cloud colored by RGB values
            # Compute 3D points from depth map
            y, x = np.indices((H, W))
            homo_pixel_coords = np.stack([x.ravel(), y.ravel(), np.ones_like(x).ravel()], axis=1).T
            depth_values = depth_map.ravel()
            cam_coords = (intrs_inv[0, v, t].cpu().numpy() @ homo_pixel_coords) * depth_values
            cam_coords = np.vstack((cam_coords, np.ones((1, cam_coords.shape[1]))))
            world_coords = (world_T_cam @ cam_coords)[:3].T
            rgb_colors = rgb_image.reshape(-1, 3).astype(np.uint8)

            # Log point clouds
            if log_rgb_pointcloud:
                # Filter out points with zero depth
                valid_mask = depth_values > 0

                # Filter out points outside this bbox
                # bbox_crop = np.array([[-4, 4], [-3, 3.7], [1.2, 5.2]])
                if bbox_crop is not None:
                    bbox_mask = (
                            (world_coords[..., 0] > bbox_crop[0, 0])
                            & (world_coords[..., 0] < bbox_crop[0, 1])
                            & (world_coords[..., 1] > bbox_crop[1, 0])
                            & (world_coords[..., 1] < bbox_crop[1, 1])
                            & (world_coords[..., 2] > bbox_crop[2, 0])
                            & (world_coords[..., 2] < bbox_crop[2, 1])
                    )
                    valid_mask = valid_mask & bbox_mask

                # Filter out points outside a sphere (lightweight crop for Kubric and DexYCB)
                if sphere_radius_crop is not None:
                    assert sphere_center_crop is not None
                    sphere_mask = ((world_coords - sphere_center_crop) ** 2).sum(-1) < sphere_radius_crop ** 2
                    valid_mask = valid_mask & sphere_mask

                # Filter out points with confidence below threshold
                pc_name__mask__tuples = []
                if not (log_only_confident_pc and depths_conf is not None):
                    pc_name__mask__tuples += [("point_cloud", valid_mask)]
                if depths_conf is not None:
                    confs = depths_conf[0, v, t, 0].cpu().numpy()
                    assert conf_thrs is not None
                    for thr in conf_thrs:
                        name = f"point_cloud__conf-{thr}"
                        mask = valid_mask & (confs.ravel() > thr)
                        if (valid_mask == mask).all():
                            continue
                        pc_name__mask__tuples += [(name, mask)]
                for pc_name, mask in pc_name__mask__tuples:
                    rr.log(f"sequence-{datapoint_idx}/{dataset_name}/{pc_name}/view-{v}",
                           rr.Points3D(world_coords[mask], colors=rgb_colors[mask], radii=radii))


def _log_tracks_to_rerun(
        tracks: np.ndarray,
        visibles: np.ndarray,
        query_timestep: np.ndarray,
        colors: np.ndarray,
        track_names=None,
        fps=30.0,
        entity_format_str="{}",
        log_points=True,
        points_radii=-3.6,
        log_line_strips=True,
        max_strip_length_past=10,
        max_strip_length_future=0,
        strips_radii=-1.8,
        log_error_lines=False,
        error_lines_radii=0.0042,
        error_lines_color=[1., 0., 0.],
        gt_for_error_lines=None,
) -> None:
    """
    Log tracks to Rerun.

    Parameters:
        tracks: Shape (T, N, 3), the 3D trajectories of points.
        visibles: Shape (T, N), boolean visibility mask for each point at each timestep.
        query_timestep: Shape (N,), the frame index at which each track starts.
        colors: Shape (N, 4), RGBA colors for each point.
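    Example (illustrative sketch only; the arrays below are placeholders and an active
    Rerun recording is assumed to have been created beforehand via rr.init()):
        T, N = 24, 8
        tracks = np.zeros((T, N, 3), dtype=np.float32)
        visibles = np.ones((T, N), dtype=bool)
        query_timestep = np.zeros((N,), dtype=np.int64)
        colors = np.ones((N, 4), dtype=np.float32)
        _log_tracks_to_rerun(tracks, visibles, query_timestep, colors,
                             entity_format_str="sequence-0/tracks/pred/{}")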
""" T, N, _ = tracks.shape assert tracks.shape == (T, N, 3) assert visibles.shape == (T, N) assert query_timestep.shape == (N,) assert query_timestep.min() >= 0 assert query_timestep.max() < T assert colors.shape == (N, 4) for n in range(N): track_name = track_names[n] if track_names is not None else f"track-{n}" rr.log(entity_format_str.format(track_name), rr.Clear(recursive=True)) for t in range(query_timestep[n], T): # if t not in [0] + [T * (x + 1) // 3 - 1 for x in range(3)]: # if t not in [T - 1]: # continue rr.set_time_seconds("frame", t / fps) # Log the point (special handling for invisible points) if log_points: rr.log( entity_format_str.format(f"{track_name}/point"), rr.Points3D( positions=[tracks[t, n]], colors=[colors[n, :3]] if visibles[t, n] else [colors[n, :3] * 0.7], radii=points_radii, ), ) # Log line segments for visible tracks if log_line_strips and t > query_timestep[n]: strip_t_start = max(t - max_strip_length_past, query_timestep[n].item()) strip_t_end = min(t + max_strip_length_future, T - 1) strips = np.stack([ tracks[strip_t_start:strip_t_end, n], tracks[strip_t_start + 1:strip_t_end + 1, n], ], axis=-2) strips_visibility = visibles[strip_t_start + 1:strip_t_end + 1, n] strips_colors = np.where( strips_visibility[:, None], colors[None, n, :3], colors[None, n, :3] * 0.7, ) rr.log( entity_format_str.format(f"{track_name}/line"), rr.LineStrips3D(strips=strips, colors=strips_colors, radii=strips_radii), ) if log_error_lines: assert gt_for_error_lines is not None strips = np.stack([ tracks[t, n], gt_for_error_lines[t, n], ], axis=-2) rr.log( entity_format_str.format(f"{track_name}/error"), rr.LineStrips3D(strips=strips, colors=error_lines_color, radii=error_lines_radii), ) def _log_tracks_to_rerun_lightweight( tracks: np.ndarray, visibles: np.ndarray, query_timestep: np.ndarray, colors: np.ndarray, track_names=None, fps=30.0, entity_format_str="{}", log_points=True, points_radii=0.01, log_line_strips=True, max_strip_length_past=24, max_strip_length_future=0, strips_radii=0.0042, log_error_lines=False, error_lines_radii=0.0010, error_lines_color=[1., 0., 0.], gt_for_error_lines=None, ) -> None: """ Log tracks to Rerun. Parameters: tracks: Shape (T, N, 3), the 3D trajectories of points. visibles: Shape (T, N), boolean visibility mask for each point at each timestep. query_timestep: Shape (T, N), the frame index after which the tracks start. colors: Shape (N, 4), RGBA colors for each point. 
""" T, N, _ = tracks.shape assert tracks.shape == (T, N, 3) assert visibles.shape == (T, N) assert query_timestep.shape == (N,) assert query_timestep.min() >= 0 assert query_timestep.max() < T assert colors.shape == (N, 4) for t in range(T): rr.set_time_seconds("frame", t / fps) points_list, points_colors = [], [] strips_list, strips_colors_list = [], [] errors_list = [] for n in range(N): if t > query_timestep[n]: strip_t_start = max(t - max_strip_length_past, query_timestep[n].item()) strip_t_end = min(t + max_strip_length_future, T - 1) strips = np.stack([ tracks[strip_t_start:strip_t_end, n], tracks[strip_t_start + 1:strip_t_end + 1, n], ], axis=-2) strips_visibility = visibles[strip_t_start + 1:strip_t_end + 1, n] strips_colors = np.where( strips_visibility[:, None], colors[None, n, :3], colors[None, n, :3] * 0.7, ) if log_line_strips: strips_list.append(strips) strips_colors_list.append(strips_colors) for t_ in range(strip_t_start, strip_t_end + 1): if log_points: points_list += [tracks[t_, n]] points_colors += [colors[n, :3]] if visibles[t_, n] else [colors[n, :3] * 0.7] if log_error_lines: assert gt_for_error_lines is not None error_lines = np.stack([ tracks[t_, n], gt_for_error_lines[t_, n], ], axis=-2) errors_list.append(error_lines) if log_points and len(points_list) > 0: rr.log( entity_format_str.format(f"points"), rr.Points3D( positions=points_list, colors=points_colors, radii=points_radii, ), ) if log_line_strips and len(strips_list) > 0: rr.log( entity_format_str.format(f"trajectories"), rr.LineStrips3D( strips=np.concatenate(strips_list, axis=0), colors=np.concatenate(strips_colors_list, axis=0), radii=strips_radii, ), ) if log_error_lines and len(errors_list) > 0: rr.log( entity_format_str.format(f"errors"), rr.LineStrips3D( strips=np.stack(errors_list), colors=error_lines_color, radii=error_lines_radii, ), ) def log_tracks_to_rerun( dataset_name: str, datapoint_idx: Union[int, str], predictor_name: str, gt_trajectories_3d_worldspace: Optional[torch.Tensor], gt_visibilities_any_view: Optional[torch.Tensor], query_points_3d: torch.Tensor, pred_trajectories: torch.Tensor, pred_visibilities: torch.Tensor, per_track_results: Optional[Dict[str, Any]] = None, radii_scale: float = 1.0, fps: float = 30.0, sphere_radius_crop: Optional[float] = None, # e.g., 6.0 sphere_center_crop: Optional[np.ndarray] = np.array([0, 0, 0]), log_per_interval_results: bool = False, max_tracks_to_log: Optional[int] = None, track_batch_size: int = 100, method_id: Optional[int] = None, color_per_method_id: Optional[Dict[int, tuple]] = None, # { 0: (46, 204, 113), ... 
} memory_lightweight_logging: bool = True, ): # Prepare track data gt_tracks = gt_trajectories_3d_worldspace[0].cpu().numpy() if gt_trajectories_3d_worldspace is not None else None gt_vis = gt_visibilities_any_view[0].cpu().numpy() if gt_visibilities_any_view is not None else None pred_tracks = pred_trajectories[0].cpu().numpy() pred_vis = pred_visibilities[0].cpu().numpy() query_timestep = query_points_3d[0, :, 0].cpu().numpy().astype(int) T, N, _ = pred_tracks.shape assert gt_tracks is None or gt_tracks.shape == (T, N, 3) assert gt_vis is None or gt_vis.shape == (T, N) assert pred_tracks.shape == (T, N, 3) assert pred_vis.shape == (T, N) assert query_timestep.shape == (N,) if sphere_radius_crop is not None: pred_tracks = pred_tracks.copy() assert sphere_center_crop is not None dist = np.linalg.norm(pred_tracks - sphere_center_crop, axis=-1, keepdims=True) mask = dist > sphere_radius_crop pred_tracks[mask[..., 0]] = ( sphere_center_crop + sphere_radius_crop * (pred_tracks[mask[..., 0]] - sphere_center_crop) / dist[mask][..., None] ) if gt_tracks is not None: gt_tracks = gt_tracks.copy() assert sphere_center_crop is not None dist = np.linalg.norm(gt_tracks - sphere_center_crop, axis=-1, keepdims=True) mask = dist > sphere_radius_crop gt_tracks[mask[..., 0]] = ( sphere_center_crop + sphere_radius_crop * (gt_tracks[mask[..., 0]] - sphere_center_crop) / dist[mask][..., None] ) # Last timestamp determines track color (unless method_id is specified) final_xyz = gt_tracks[-1] if gt_tracks is not None else pred_tracks[-1] # (N, 3) pca = PCA(n_components=1).fit_transform(final_xyz) # Apply PCA to spread values across 1D axis pca_normalized = (pca - pca.min()) / (pca.max() - pca.min() + 1e-8) # Normalize to [0, 1] cmap = matplotlib.colormaps["gist_rainbow"] colors = cmap(pca_normalized[:, 0]) # Map to colormap assert colors.shape == (N, 4) # If method_id is specified, use fixed colors # Fixed color mapping per method if color_per_method_id is None: color_per_method_id = { 0: (46, 204, 113), 1: (52, 152, 219), 2: (241, 196, 15), 3: (155, 89, 182), 4: (230, 126, 34), 5: (26, 188, 156), } if method_id is not None: assert method_id in color_per_method_id base_rgb = np.array(color_per_method_id[method_id]) / 255.0 colors = np.tile(np.append(base_rgb, 1.0), (N, 1)) assert colors.shape == (N, 4) # Log the tracks common_kwargs = { "points_radii": -3.6 * radii_scale, "strips_radii": -1.8 * radii_scale, "error_lines_radii": 0.0042 * radii_scale, "fps": fps, } if max_tracks_to_log: N = min(N, max_tracks_to_log) for tracks_batch_start in range(0, N, track_batch_size): tracks_batch_end = min(tracks_batch_start + track_batch_size, N) entity_format_strs = [] entity_format_strs += [ f"sequence-{datapoint_idx}/tracks/{{track_name}}/{tracks_batch_start:02d}-{tracks_batch_end:02d}/{{{{}}}}" ] if not memory_lightweight_logging: entity_format_strs += [ f"sequence-{datapoint_idx}/tracks/all/{tracks_batch_start:02d}-{tracks_batch_end:02d}/{{{{}}}}/{{track_name}}" ] for entity_format_str in entity_format_strs: log_tracks_fn = _log_tracks_to_rerun if not memory_lightweight_logging else _log_tracks_to_rerun_lightweight # Log the GT tracks if gt_tracks is not None and (method_id is None or method_id == 0): log_tracks_fn( tracks=gt_tracks[:, tracks_batch_start:tracks_batch_end], visibles=gt_vis[:, tracks_batch_start:tracks_batch_end], query_timestep=query_timestep[tracks_batch_start:tracks_batch_end], colors=colors[tracks_batch_start:tracks_batch_end] * 0 + np.array([1, 1, 1, 1]), track_names=[f"track-{i:02d}" for i in 
range(tracks_batch_start, tracks_batch_end)], entity_format_str=entity_format_str.format(track_name=f"gt"), **common_kwargs, ) # Log the predicted tracks log_tracks_fn( tracks=pred_tracks[:, tracks_batch_start:tracks_batch_end], visibles=pred_vis[:, tracks_batch_start:tracks_batch_end], query_timestep=query_timestep[tracks_batch_start:tracks_batch_end], colors=colors[tracks_batch_start:tracks_batch_end], track_names=[f"track-{i:02d}" for i in range(tracks_batch_start, tracks_batch_end)], entity_format_str=entity_format_str.format(track_name=f"pred--{predictor_name}"), log_error_lines=gt_tracks is not None, gt_for_error_lines=gt_tracks[:, tracks_batch_start:tracks_batch_end] if gt_tracks is not None else None, **common_kwargs, ) if log_per_interval_results and per_track_results is not None: intervals = [(i / 10 * 100, (i + 1) / 10 * 100) for i in range(10)] # Intervals for 0-10%, ..., 90-100% intervals += [(0, 33), (33, 66), (66, 100)] # Intervals for lower, middle, upper third else: intervals = [] for lower, upper in intervals: for point_type in ["dynamic", "very_dynamic", "static", "any"]: if f"all_{point_type}" not in per_track_results: continue if lower == 0: # Special case to include 0 track_indices = per_track_results[f"all_{point_type}"].indices[ (per_track_results[f"all_{point_type}"].average_pts_within_thresh_per_track >= lower) & (per_track_results[f"all_{point_type}"].average_pts_within_thresh_per_track <= upper) ] else: track_indices = per_track_results[f"all_{point_type}"].indices[ (per_track_results[f"all_{point_type}"].average_pts_within_thresh_per_track > lower) & (per_track_results[f"all_{point_type}"].average_pts_within_thresh_per_track <= upper) ] if len(track_indices) == 0: continue entity_format_str = f"sequence-{datapoint_idx}/tracks/location-accuracy-for-{point_type}/{int(lower)}-{int(upper)}-percent-{{track_name}}/{{{{}}}}" # Log the GT tracks _log_tracks_to_rerun( tracks=gt_tracks[:, track_indices], visibles=gt_vis[:, track_indices], query_timestep=query_timestep[track_indices], colors=colors[track_indices] * 0 + np.array([1, 1, 1, 1]), track_names=[f"track-{i:02d}" for i in track_indices], entity_format_str=entity_format_str.format(track_name=f"gt"), **common_kwargs, ) # Log the predicted tracks _log_tracks_to_rerun( tracks=pred_tracks[:, track_indices], visibles=pred_vis[:, track_indices], query_timestep=query_timestep[track_indices], colors=colors[track_indices], track_names=[f"track-{i:02d}" for i in track_indices], entity_format_str=entity_format_str.format(track_name=f"pred-{dataset_name}"), log_error_lines=True, gt_for_error_lines=gt_tracks[:, track_indices], **common_kwargs, ) ================================================ FILE: requirements.full.txt ================================================ # Minimal runtime numpy==1.24.3 huggingface-hub==0.30.2 easydict==1.13 pandas==2.2.2 einops==0.7.0 opencv-python==4.11.0.86 matplotlib==3.8.3 seaborn==0.13.2 scikit-image==0.22.0 scikit-learn==1.4.1.post1 pypng==0.20220715.0 kornia==0.7.3 flow-vis==0.1 moviepy==1.0.3 mediapy==1.2.0 rerun-sdk==0.21.0 # Training / baselines torchdata==0.11.0 lightning==2.4.0 timm==0.6.7 prettytable==3.10.0 # tensorflow==2.12.1 # tensorflow-datasets==4.9.8 # tensorflow-graphics==2021.12.3 tensorboard==2.12.3 tqdm==4.67.1 gpustat==1.1.1 hydra-core==1.3.2 wandb==0.19.9 rich==14.0.0 ================================================ FILE: requirements.txt ================================================ # Minimal dependencies numpy==1.24.3 huggingface-hub==0.30.2 easydict==1.13 
pandas==2.2.2 einops==0.7.0 opencv-python==4.11.0.86 matplotlib==3.8.3 seaborn==0.13.2 scikit-image==0.22.0 scikit-learn==1.4.1.post1 pypng==0.20220715.0 kornia==0.7.3 flow-vis==0.1 moviepy==1.0.3 mediapy==1.2.0 rerun-sdk==0.21.0 ================================================ FILE: scripts/4ddress_preprocessing.py ================================================ """ First download the dataset. You'll have to fill in an online ETH form and then wait for a few days to get a temporary access code over email. I used the following sequence of commands to download and unpack the data into the expected structure. You can probably replace the `dt=...` with your access token that you can probably find in the access URL (or otherwise in the page source of the download page that will be linked). Note that you don't need to download all the data if you don't need it, e.g., maybe you just want to download a small sample. Note also that in the commands below, I didn't delete the `*.tar.gz` and `*.zip` files, but you can do so if you'd like. Note also that the extraction of 00135 had some unexpected structure in that some takes were in the root of 00135 instead of subfolders, but I ignored that. ```bash wget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00122_Inner.tar.gz' -O 00122_Inner.tar.gz wget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00122_Outer.tar.gz' -O 00122_Outer.tar.gz wget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00123_Inner.tar.gz' -O 00123_Inner.tar.gz wget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00123_Outer.tar.gz' -O 00123_Outer.tar.gz wget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00127_Inner.tar.gz' -O 00127_Inner.tar.gz wget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00127_Outer.tar.gz' -O 00127_Outer.tar.gz wget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00129_Inner.tar.gz' -O 
00129_Inner.tar.gz wget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00129_Outer.tar.gz' -O 00129_Outer.tar.gz wget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00134_Inner.tar.gz' -O 00134_Inner.tar.gz wget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00134_Outer.tar.gz' -O 00134_Outer.tar.gz wget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00135_Inner.tar.gz' -O 00135_Inner.tar.gz wget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00135_Outer_1.tar.gz' -O 00135_Outer_1.tar.gz wget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00135_Outer_2.tar.gz' -O 00135_Outer_2.tar.gz wget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00136_Inner.tar.gz' -O 00136_Inner.tar.gz wget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00136_Outer_1.tar.gz' -O 00136_Outer_1.tar.gz wget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00136_Outer_2.tar.gz' -O 00136_Outer_2.tar.gz wget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00137_Inner_1.tar.gz' -O 00137_Inner_1.tar.gz wget 
'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00137_Inner_2.tar.gz' -O 00137_Inner_2.tar.gz wget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00137_Outer_1.tar.gz' -O 00137_Outer_1.tar.gz wget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00137_Outer_2.tar.gz' -O 00137_Outer_2.tar.gz wget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00140_Inner_1.tar.gz' -O 00140_Inner_1.tar.gz wget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00140_Inner_2.tar.gz' -O 00140_Inner_2.tar.gz wget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00140_Outer_1.tar.gz' -O 00140_Outer_1.tar.gz wget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00140_Outer_2.tar.gz' -O 00140_Outer_2.tar.gz wget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00147_Inner.tar.gz' -O 00147_Inner.tar.gz wget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00147_Outer.tar.gz' -O 00147_Outer.tar.gz wget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00148_Inner.tar.gz' -O 00148_Inner.tar.gz wget 
'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00148_Outer.tar.gz' -O 00148_Outer.tar.gz wget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00149_Inner_1.tar.gz' -O 00149_Inner_1.tar.gz wget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00149_Inner_2.tar.gz' -O 00149_Inner_2.tar.gz wget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00149_Outer_1.tar.gz' -O 00149_Outer_1.tar.gz wget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00149_Outer_2.tar.gz' -O 00149_Outer_2.tar.gz wget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00151_Inner.tar.gz' -O 00151_Inner.tar.gz wget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00151_Outer.tar.gz' -O 00151_Outer.tar.gz wget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00152_Inner.tar.gz' -O 00152_Inner.tar.gz wget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00152_Outer_1.tar.gz' -O 00152_Outer_1.tar.gz wget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00152_Outer_2.tar.gz' -O 00152_Outer_2.tar.gz wget 
'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00154_Inner.tar.gz' -O 00154_Inner.tar.gz wget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00154_Outer_1.tar.gz' -O 00154_Outer_1.tar.gz wget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00154_Outer_2.tar.gz' -O 00154_Outer_2.tar.gz wget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00156_Inner.tar.gz' -O 00156_Inner.tar.gz wget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00156_Outer.tar.gz' -O 00156_Outer.tar.gz wget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00160_Inner.tar.gz' -O 00160_Inner.tar.gz wget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00160_Outer.tar.gz' -O 00160_Outer.tar.gz wget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00163_Inner_1.tar.gz' -O 00163_Inner_1.tar.gz wget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00163_Inner_2.tar.gz' -O 00163_Inner_2.tar.gz wget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00163_Outer.tar.gz' -O 00163_Outer.tar.gz wget 
'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00167_Inner.tar.gz' -O 00167_Inner.tar.gz wget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00167_Outer.tar.gz' -O 00167_Outer.tar.gz wget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00168_Inner.tar.gz' -O 00168_Inner.tar.gz wget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00168_Outer_1.tar.gz' -O 00168_Outer_1.tar.gz wget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00168_Outer_2.tar.gz' -O 00168_Outer_2.tar.gz wget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00169_Inner.tar.gz' -O 00169_Inner.tar.gz wget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00169_Outer.tar.gz' -O 00169_Outer.tar.gz wget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00170_Inner_1.tar.gz' -O 00170_Inner_1.tar.gz wget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00170_Inner_2.tar.gz' -O 00170_Inner_2.tar.gz wget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00170_Outer.tar.gz' -O 00170_Outer.tar.gz wget 
'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00174_Inner.tar.gz' -O 00174_Inner.tar.gz wget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00174_Outer.tar.gz' -O 00174_Outer.tar.gz wget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00175_Inner_1.tar.gz' -O 00175_Inner_1.tar.gz wget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00175_Inner_2.tar.gz' -O 00175_Inner_2.tar.gz wget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00175_Outer_1.tar.gz' -O 00175_Outer_1.tar.gz wget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00175_Outer_2.tar.gz' -O 00175_Outer_2.tar.gz wget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00176_Inner.tar.gz' -O 00176_Inner.tar.gz wget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00176_Outer.tar.gz' -O 00176_Outer.tar.gz wget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00179_Inner.tar.gz' -O 00179_Inner.tar.gz wget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00179_Outer.tar.gz' -O 00179_Outer.tar.gz wget 
'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00180_Inner.tar.gz' -O 00180_Inner.tar.gz wget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00180_Outer.tar.gz' -O 00180_Outer.tar.gz wget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00185_Inner_1.tar.gz' -O 00185_Inner_1.tar.gz wget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00185_Inner_2.tar.gz' -O 00185_Inner_2.tar.gz wget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00185_Outer_1.tar.gz' -O 00185_Outer_1.tar.gz wget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00185_Outer_2.tar.gz' -O 00185_Outer_2.tar.gz wget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00187_Inner_1.tar.gz' -O 00187_Inner_1.tar.gz wget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00187_Inner_2.tar.gz' -O 00187_Inner_2.tar.gz wget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00187_Outer.tar.gz' -O 00187_Outer.tar.gz wget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00188_Inner.tar.gz' -O 00188_Inner.tar.gz wget 
'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00188_Outer.tar.gz' -O 00188_Outer.tar.gz wget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00190_Inner.tar.gz' -O 00190_Inner.tar.gz wget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00190_Outer.tar.gz' -O 00190_Outer.tar.gz wget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00191_Inner.tar.gz' -O 00191_Inner.tar.gz wget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00191_Outer.tar.gz' -O 00191_Outer.tar.gz wget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/Overview.tar.gz' -O Overview.tar.gz wget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/README.md' -O README.md wget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/Template.tar.gz' -O Template.tar.gz mkdir benchmark wget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/Benchmark/Clothing_Recon_inner.zip' -O benchmark/Clothing_Recon_inner.zip wget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/Benchmark/Clothing_Recon_outer.zip' -O benchmark/Clothing_Recon_outer.zip wget 
'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/Benchmark/Human_Recon.zip' -O benchmark/Human_Recon.zip mkdir -p 00122 00123 00127 00129 00134 00135 00136 00137 00140 00147 00148 00149 00151 00152 00154 00156 00160 00163 00167 00168 00169 00170 00174 00175 00176 00179 00180 00185 00187 00188 00190 00191 tar -xvzf 00122_Inner.tar.gz -C 00122 tar -xvzf 00122_Outer.tar.gz -C 00122 tar -xvzf 00123_Inner.tar.gz -C 00123 tar -xvzf 00123_Outer.tar.gz -C 00123 tar -xvzf 00127_Inner.tar.gz -C 00127 tar -xvzf 00127_Outer.tar.gz -C 00127 tar -xvzf 00129_Inner.tar.gz -C 00129 tar -xvzf 00129_Outer.tar.gz -C 00129 tar -xvzf 00134_Inner.tar.gz -C 00134 tar -xvzf 00134_Outer.tar.gz -C 00134 tar -xvzf 00135_Inner.tar.gz -C 00135 tar -xvzf 00135_Outer_1.tar.gz -C 00135 tar -xvzf 00135_Outer_2.tar.gz -C 00135 tar -xvzf 00136_Inner.tar.gz -C 00136 tar -xvzf 00136_Outer_1.tar.gz -C 00136 tar -xvzf 00136_Outer_2.tar.gz -C 00136 tar -xvzf 00137_Inner_1.tar.gz -C 00137 tar -xvzf 00137_Inner_2.tar.gz -C 00137 tar -xvzf 00137_Outer_1.tar.gz -C 00137 tar -xvzf 00137_Outer_2.tar.gz -C 00137 tar -xvzf 00140_Inner_1.tar.gz -C 00140 tar -xvzf 00140_Inner_2.tar.gz -C 00140 tar -xvzf 00140_Outer_1.tar.gz -C 00140 tar -xvzf 00140_Outer_2.tar.gz -C 00140 tar -xvzf 00147_Inner.tar.gz -C 00147 tar -xvzf 00147_Outer.tar.gz -C 00147 tar -xvzf 00148_Inner.tar.gz -C 00148 tar -xvzf 00148_Outer.tar.gz -C 00148 tar -xvzf 00149_Inner_1.tar.gz -C 00149 tar -xvzf 00149_Inner_2.tar.gz -C 00149 tar -xvzf 00149_Outer_1.tar.gz -C 00149 tar -xvzf 00149_Outer_2.tar.gz -C 00149 tar -xvzf 00151_Inner.tar.gz -C 00151 tar -xvzf 00151_Outer.tar.gz -C 00151 tar -xvzf 00152_Inner.tar.gz -C 00152 tar -xvzf 00152_Outer_1.tar.gz -C 00152 tar -xvzf 00152_Outer_2.tar.gz -C 00152 tar -xvzf 00154_Inner.tar.gz -C 00154 tar -xvzf 00154_Outer_1.tar.gz -C 00154 tar -xvzf 00154_Outer_2.tar.gz -C 00154 tar -xvzf 00156_Inner.tar.gz -C 00156 tar -xvzf 00156_Outer.tar.gz -C 00156 tar -xvzf 00160_Inner.tar.gz -C 00160 tar -xvzf 00160_Outer.tar.gz -C 00160 tar -xvzf 00163_Inner_1.tar.gz -C 00163 tar -xvzf 00163_Inner_2.tar.gz -C 00163 tar -xvzf 00163_Outer.tar.gz -C 00163 tar -xvzf 00167_Inner.tar.gz -C 00167 tar -xvzf 00167_Outer.tar.gz -C 00167 tar -xvzf 00168_Inner.tar.gz -C 00168 tar -xvzf 00168_Outer_1.tar.gz -C 00168 tar -xvzf 00168_Outer_2.tar.gz -C 00168 tar -xvzf 00169_Inner.tar.gz -C 00169 tar -xvzf 00169_Outer.tar.gz -C 00169 tar -xvzf 00170_Inner_1.tar.gz -C 00170 tar -xvzf 00170_Inner_2.tar.gz -C 00170 tar -xvzf 00170_Outer.tar.gz -C 00170 tar -xvzf 00174_Inner.tar.gz -C 00174 tar -xvzf 00174_Outer.tar.gz -C 00174 tar -xvzf 00175_Inner_1.tar.gz -C 00175 tar -xvzf 00175_Inner_2.tar.gz -C 00175 tar -xvzf 00175_Outer_1.tar.gz -C 00175 tar -xvzf 00175_Outer_2.tar.gz -C 00175 tar -xvzf 00176_Inner.tar.gz -C 00176 tar -xvzf 00176_Outer.tar.gz -C 00176 tar -xvzf 00179_Inner.tar.gz -C 00179 tar -xvzf 00179_Outer.tar.gz -C 00179 tar -xvzf 00180_Inner.tar.gz -C 00180 tar -xvzf 00180_Outer.tar.gz -C 00180 tar -xvzf 00185_Inner_1.tar.gz -C 00185 tar -xvzf 00185_Inner_2.tar.gz -C 00185 tar -xvzf 00185_Outer_1.tar.gz -C 00185 tar -xvzf 00185_Outer_2.tar.gz -C 00185 tar -xvzf 00187_Inner_1.tar.gz -C 00187 tar -xvzf 00187_Inner_2.tar.gz -C 00187 tar -xvzf 00187_Outer.tar.gz -C 00187 tar -xvzf 00188_Inner.tar.gz -C 00188 tar -xvzf 
00188_Outer.tar.gz -C 00188 tar -xvzf 00190_Inner.tar.gz -C 00190 tar -xvzf 00190_Outer.tar.gz -C 00190 tar -xvzf 00191_Inner.tar.gz -C 00191 tar -xvzf 00191_Outer.tar.gz -C 00191 tar -xvzf Overview.tar.gz tar -xvzf Template.tar.gz cd benchmark unzip Clothing_Recon_inner.zip unzip Clothing_Recon_outer.zip unzip Human_Recon.zip ``` With the data downloaded, you can run the script: `python -m scripts.4ddress_preprocessing`. I create a subselection of the sequences as: ```bash SRC=datasets/4d-dress-processed-resized-512 DST=datasets/4d-dress-processed-resized-512-selection mkdir ${DST} cp ${SRC}/00129_Inner_Take3.pkl ${DST}/00129_Inner_Take3_happy.pkl cp ${SRC}/00129_Inner_Take4.pkl ${DST}/00129_Inner_Take4_stretch.pkl cp ${SRC}/00129_Inner_Take5.pkl ${DST}/00129_Inner_Take5_balerina.pkl cp ${SRC}/00129_Outer_Take13.pkl ${DST}/00129_Outer_Take13_kolo.pkl cp ${SRC}/00140_Inner_Take8.pkl ${DST}/00140_Inner_Take8_football.pkl cp ${SRC}/00140_Outer_Take13.pkl ${DST}/00140_Outer_Take13_stretch.pkl cp ${SRC}/00140_Outer_Take15.pkl ${DST}/00140_Outer_Take15_kicks.pkl cp ${SRC}/00147_Inner_Take10.pkl ${DST}/00147_Inner_Take10_basketball.pkl cp ${SRC}/00147_Inner_Take11.pkl ${DST}/00147_Inner_Take11_football.pkl cp ${SRC}/00147_Outer_Take16.pkl ${DST}/00147_Outer_Take16_dance.pkl cp ${SRC}/00147_Outer_Take17.pkl ${DST}/00147_Outer_Take17_avatar.pkl cp ${SRC}/00174_Inner_Take9.pkl ${DST}/00174_Inner_Take9_stretching.pkl cp ${SRC}/00175_Inner_Take6.pkl ${DST}/00175_Inner_Take6_basketball.pkl ``` """ import os import pickle from typing import Optional import cv2 import numpy as np import rerun as rr import torch import tqdm from PIL import Image from pytorch3d.renderer import ( PerspectiveCameras, MeshRasterizer, RasterizationSettings, ) from pytorch3d.structures import Meshes from scipy.spatial.transform import Rotation from mvtracker.datasets.utils import transform_scene def load_pickle(p): with open(p, "rb") as f: return pickle.load(f) def save_pickle(p, data): with open(p, "wb") as f: pickle.dump(data, f) def load_image(path): return np.array(Image.open(path)) def extract_4d_dress_data( dataset_root: str, subject_name: str, outfit_name: str, take_name, save_pkl_path, downscaled_longerside: Optional[int] = None, save_rerun_viz: bool = True, stream_rerun_viz: bool = False, skip_if_output_exists: bool = False, ): # Skip if output exists if skip_if_output_exists and os.path.exists(save_pkl_path): print(f"Skipping {save_pkl_path} since it already exists") print() return save_pkl_path else: print(f"Processing {save_pkl_path}...") base_dir = os.path.join(dataset_root, subject_name, outfit_name, take_name) capture_dir = os.path.join(base_dir, "Capture") mesh_dir = os.path.join(base_dir, "Meshes_pkl") basic_info = load_pickle(os.path.join(base_dir, "basic_info.pkl")) scan_frames = basic_info['scan_frames'] cameras = load_pickle(os.path.join(capture_dir, "cameras.pkl")) cam_names = sorted(list(cameras.keys())) # Prepare final structure rgbs, intrs, extrs, depths = {}, {}, {}, {} for cam_name in cam_names: rgbs[cam_name] = [] depths[cam_name] = [] device = torch.device("cuda" if torch.cuda.is_available() else "cpu") for frame in tqdm.tqdm(scan_frames, desc="Extracting frame data"): mesh_path = os.path.join(mesh_dir, f"mesh-f{frame}.pkl") mesh_data = load_pickle(mesh_path) vertices = mesh_data["vertices"] faces = mesh_data["faces"] verts = torch.tensor(vertices, dtype=torch.float32, device=device).unsqueeze(0) faces = torch.tensor(faces, dtype=torch.int64, device=device).unsqueeze(0) mesh = Meshes(verts=verts, 
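            # A single-element Meshes batch holding this frame's scan; it is rasterized below with
            # PyTorch3D's MeshRasterizer once per camera, and the rasterizer's z-buffer is stored
            # as that view's depth map.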
faces=faces) for cam_name in cam_names: cam_path = os.path.join(capture_dir, cam_name) img_path = os.path.join(cam_path, "images", f"capture-f{frame}.png") if not os.path.exists(img_path): continue image = load_image(img_path) h, w = image.shape[:2] intr = cameras[cam_name]['intrinsics'].copy() extr = cameras[cam_name]['extrinsics'].copy() if downscaled_longerside is not None: scale = downscaled_longerside / max(h, w) h, w = int(h * scale), int(w * scale) image = cv2.resize(image, (w, h), interpolation=cv2.INTER_AREA) intr[:2] *= scale if cam_name not in intrs: intrs[cam_name] = intr extrs[cam_name] = extr rgbs[cam_name].append(image) # Convert intrinsics to normalized device coords fx, fy = intr[0, 0], intr[1, 1] cx, cy = intr[0, 2], intr[1, 2] R = extr[:3, :3] T = extr[:3, 3] R = R.T R = R @ np.diag(np.array([-1, -1, 1.])) # Flip the x and y axes (or multiply f by -1) T = T @ np.diag(np.array([-1, -1, 1.])) # Flip the x and y axes (or multiply f by -1) cameras_p3d = PerspectiveCameras( focal_length=torch.tensor([[fx, fy]], dtype=torch.float32, device=device), principal_point=torch.tensor([[cx, cy]], dtype=torch.float32, device=device), R=torch.tensor(R, dtype=torch.float32, device=device).unsqueeze(0), T=torch.tensor(T, dtype=torch.float32, device=device).unsqueeze(0), image_size=torch.tensor([[h, w]], dtype=torch.float32, device=device), in_ndc=False, device=device, ) raster_settings = RasterizationSettings( image_size=(h, w), blur_radius=0.0, faces_per_pixel=1, bin_size=0 ) rasterizer = MeshRasterizer(cameras=cameras_p3d, raster_settings=raster_settings) fragments = rasterizer(mesh) zbuf = fragments.zbuf.squeeze().cpu().numpy() zbuf[np.isnan(zbuf)] = 0.0 depths[cam_name].append(zbuf) for cam_name in cam_names: if rgbs[cam_name]: rgbs[cam_name] = np.stack(rgbs[cam_name]).transpose(0, 3, 1, 2) # T, C, H, W depths[cam_name] = np.stack(depths[cam_name]) # T, H, W # Rotate the scene to have the ground at z=0 rot_x = Rotation.from_euler('x', 90, degrees=True).as_matrix() rot_y = Rotation.from_euler('y', 0, degrees=True).as_matrix() rot_z = Rotation.from_euler('z', 0, degrees=True).as_matrix() rot = torch.from_numpy(rot_z @ rot_y @ rot_x) translation = torch.tensor([0.0, 0.0, 0.0], dtype=torch.float32) for cam_name in cam_names: extrs[cam_name] = transform_scene( 1, rot, translation, None, torch.from_numpy(extrs[cam_name][None, None]), )[1][0, 0].numpy() # Check shapes n_frames, _, h, w = rgbs[cam_names[0]].shape for cam_name in cam_names: assert rgbs[cam_name].shape == (n_frames, 3, h, w) assert intrs[cam_name].shape == (3, 3) assert extrs[cam_name].shape == (3, 4) # Save processed output to a pickle file save_pickle(save_pkl_path, dict( rgbs=rgbs, intrs=intrs, extrs=extrs, depths=depths, ego_cam_name=None, )) # Visualize the data sample using rerun rerun_modes = [] if stream_rerun_viz: rerun_modes += ["stream"] if save_rerun_viz: rerun_modes += ["save"] for rerun_mode in rerun_modes: rr.init(f"3dpt", recording_id="v0.16") if rerun_mode == "stream": rr.connect_tcp() rr.log("world", rr.ViewCoordinates.RIGHT_HAND_Z_UP, static=True) rr.set_time_seconds("frame", 0) rr.log( "world/xyz", rr.Arrows3D( vectors=[[1, 0, 0], [0, 2, 0], [0, 0, 3]], colors=[[255, 0, 0], [0, 255, 0], [0, 0, 255]], ), ) rr.log( "mesh", rr.Mesh3D( vertex_positions=vertices.astype(np.float32), # (N, 3) triangle_indices=faces.cpu().numpy().reshape(-1, 3).astype(np.int32), # (M, 3) albedo_factor=[200, 200, 255], # Optional color ), ) fps = 30 for frame_idx in range(n_frames): rr.set_time_seconds("frame", frame_idx / fps) for 
cam_name in cam_names: extr = extrs[cam_name] intr = intrs[cam_name] img = rgbs[cam_name][frame_idx].transpose(1, 2, 0).astype(np.uint8) depth = depths[cam_name][frame_idx] h, w = img.shape[:2] fx, fy = intr[0, 0], intr[1, 1] cx, cy = intr[0, 2], intr[1, 2] # Camera pose T = np.eye(4) T[:3, :] = extr world_T_cam = np.linalg.inv(T) rr.log(f"{cam_name}/image", rr.Transform3D( translation=world_T_cam[:3, 3], mat3x3=world_T_cam[:3, :3], )) rr.log(f"{cam_name}/image", rr.Pinhole( image_from_camera=intr, width=w, height=h )) rr.log(f"{cam_name}/image", rr.Image(img)) rr.log(f"{cam_name}/depth", rr.Transform3D( translation=world_T_cam[:3, 3], mat3x3=world_T_cam[:3, :3], )) rr.log(f"{cam_name}/depth", rr.Pinhole( image_from_camera=intr, width=w, height=h )) rr.log(f"{cam_name}/depth", rr.DepthImage(depth, meter=1.0, colormap="viridis")) # Unproject depth to point cloud y, x = np.meshgrid(np.arange(h), np.arange(w), indexing="ij") z = depth valid = z > 0 x = x[valid] y = y[valid] z = z[valid] X = (x - cx) * z / fx Y = (y - cy) * z / fy pts_cam = np.stack([X, Y, z], axis=-1) # Transform to world R = world_T_cam[:3, :3] t = world_T_cam[:3, 3] pts_world = pts_cam @ R.T + t # Color colors = img[y, x] rr.log(f"point_cloud/{cam_name}", rr.Points3D(positions=pts_world, colors=colors)) if rerun_mode == "save": base, name = os.path.split(save_pkl_path) name_no_ext = os.path.splitext(name)[0] save_rrd_path = os.path.join(base, f"rerun__{name_no_ext}.rrd") rr.save(save_rrd_path) print(f"Saved rerun viz to {os.path.abspath(save_rrd_path)}") print(f"Done with {save_pkl_path}.") print() def crete_overview_pngs(dataset_root, subject_names, overview_dir): os.makedirs(overview_dir, exist_ok=True) for subject_name in tqdm.tqdm(subject_names): if "." in subject_name: continue for outfit_name in os.listdir(os.path.join(dataset_root, subject_name)): if outfit_name not in ["Inner", "Outer"]: continue for take_name in os.listdir(os.path.join(dataset_root, subject_name, outfit_name)): if "." in take_name: continue cam_dir = os.path.join(dataset_root, subject_name, outfit_name, take_name, "Capture") cam_names = sorted([name for name in os.listdir(cam_dir) if "." not in name]) first_cam = cam_names[0] img_folder = os.path.join(cam_dir, first_cam, "images") images = sorted(os.listdir(img_folder)) last_img = images[-1] img_path = os.path.join(img_folder, last_img) # Load image and overlay info from PIL import Image, ImageDraw, ImageFont img = Image.open(img_path).convert("RGB") draw = ImageDraw.Draw(img) text = ( f"{subject_name} / {outfit_name} / {take_name}\n" f"Frame: {last_img.split('-')[-1].split('.')[0]}\n" f"Cams: {cam_names}" ) try: font = ImageFont.truetype("DejaVuSans-Bold.ttf", 16) except: font = ImageFont.load_default() draw.text((10, 10), text, fill="white", font=font) # Save image overview_path = os.path.join(overview_dir, f"{subject_name}__{outfit_name}__{take_name}.png") img.save(overview_path) print(f"Saved overview to {overview_path}") def crete_overview_mp4s(dataset_root, subject_names, overview_dir, fps=30): os.makedirs(overview_dir, exist_ok=True) for subject_name in tqdm.tqdm(subject_names): if "." in subject_name: continue for outfit_name in os.listdir(os.path.join(dataset_root, subject_name)): if outfit_name not in ["Inner", "Outer"]: continue for take_name in os.listdir(os.path.join(dataset_root, subject_name, outfit_name)): if "." in take_name: continue cam_dir = os.path.join(dataset_root, subject_name, outfit_name, take_name, "Capture") cam_names = sorted([name for name in os.listdir(cam_dir) if "." 
not in name]) first_cam = cam_names[0] img_folder = os.path.join(cam_dir, first_cam, "images") images = sorted(os.listdir(img_folder)) # Load first frame to get size first_img = cv2.imread(os.path.join(img_folder, images[0])) height, width = first_img.shape[:2] video_path = os.path.join( overview_dir, f"{subject_name}__{outfit_name}__{take_name}.mp4" ) writer = cv2.VideoWriter( video_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (width, height) ) for img_name in images: img_path = os.path.join(img_folder, img_name) img = cv2.imread(img_path) overlay_text = ( f"{subject_name} / {outfit_name} / {take_name} | " f"Frame: {img_name.split('-')[-1].split('.')[0]} | " f"Cams: {', '.join(cam_names)}" ) cv2.putText( img, overlay_text, (10, 25), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2, lineType=cv2.LINE_AA ) writer.write(img) writer.release() print(f"Saved video to {video_path}") if __name__ == "__main__": dataset_root = "datasets/4d-dress" output_root = "datasets/4d-dress-processed" create_overviews = True # Creates an overview folder with a png/mp4 summary of each subject-outfit-take longside_resolution: Optional[int] = 512 if longside_resolution is not None: output_root += f"-resized-{longside_resolution}" os.makedirs(output_root, exist_ok=True) subject_names = [ "00122", "00123", "00127", "00129", "00134", "00135", "00136", "00137", "00140", "00147", "00148", "00149", "00151", "00152", "00154", "00156", "00160", "00163", "00167", "00168", "00169", "00170", "00174", "00175", "00176", "00179", "00180", "00185", "00187", "00188", "00190", "00191", ] if create_overviews: crete_overview_pngs(dataset_root, subject_names, os.path.join(dataset_root, "overview-pngs")) crete_overview_mp4s(dataset_root, subject_names, os.path.join(dataset_root, "overview-mp4s")) for subject_name in tqdm.tqdm(subject_names): if "." in subject_name: continue for outfit_name in os.listdir(os.path.join(dataset_root, subject_name)): if outfit_name not in ["Inner", "Outer"]: continue for take_name in os.listdir(os.path.join(dataset_root, subject_name, outfit_name)): if "." 
in take_name: continue pkl_path = os.path.join(output_root, f"{subject_name}_{outfit_name}_{take_name}.pkl") extract_4d_dress_data( dataset_root=dataset_root, subject_name=subject_name, outfit_name=outfit_name, take_name=take_name, downscaled_longerside=longside_resolution, save_pkl_path=pkl_path, save_rerun_viz=True, stream_rerun_viz=False, skip_if_output_exists=True, ) ================================================ FILE: scripts/__init__.py ================================================ ================================================ FILE: scripts/compare_cdist-topk_against_pointops-knn.py ================================================ import time import torch from pointops import knn_query B, N, M, D, K = 12, 49152, 928, 3, 16 def knn_torch(k: int, xyz_ref: torch.Tensor, xyz_query: torch.Tensor): dists = torch.cdist(xyz_query, xyz_ref, p=2) # shape: (B, M, N) sorted_dists, indices = torch.topk(dists, k, dim=-1, largest=False, sorted=True) return sorted_dists, indices def knn_pointops(k: int, xyz_ref: torch.Tensor, xyz_query: torch.Tensor): B, N, _ = xyz_ref.shape _, M, _ = xyz_query.shape orig_dtype = xyz_ref.dtype xyz_ref_flat = xyz_ref.contiguous().view(B * N, 3).to(torch.float32) xyz_query_flat = xyz_query.contiguous().view(B * M, 3).to(torch.float32) offset = torch.arange(1, B + 1, device=xyz_ref.device) * N new_offset = torch.arange(1, B + 1, device=xyz_query.device) * M idx, dists = knn_query(k, xyz_ref_flat, offset, xyz_query_flat, new_offset) # Remap global indices to local per-batch idx = idx.view(B, M, k) idx = idx - (torch.arange(B, device=idx.device).view(B, 1, 1) * N) dists = dists.view(B, M, k).to(orig_dtype) return dists, idx def benchmark(fn, name, HALF_PRECISION=False, iters=100): total_time = 0.0 peak_memories = [] for _ in range(iters): xyz_ref = torch.randn(B, N, D, device="cuda") xyz_query = torch.randn(B, M, D, device="cuda") if HALF_PRECISION: xyz_ref = xyz_ref.half() xyz_query = xyz_query.half() fn(K, xyz_ref, xyz_query) # warm up torch.cuda.empty_cache() torch.cuda.reset_peak_memory_stats() torch.cuda.synchronize() start = time.time() fn(K, xyz_ref, xyz_query) torch.cuda.synchronize() total_time += time.time() - start peak_memories.append(torch.cuda.max_memory_allocated() / 1e6) # MB avg_time = total_time / iters peak_memory_min = min(peak_memories) peak_memory_avg = sum(peak_memories) / len(peak_memories) peak_memory_max = max(peak_memories) print(f"{name:<24} | " f"Avg Time: {avg_time:.6f} s | " f"Peak Memory: {peak_memory_avg:>6.2f} MB (min: {peak_memory_min:>6.2f}, max: {peak_memory_max:>6.2f})") print("Benchmarking KNN with different methods (HALF_PRECISION=True):") benchmark(knn_torch, "torch.cdist+torch.topk", True) benchmark(knn_pointops, "pointops.knn_query", True) print("\nBenchmarking KNN with different methods (HALF_PRECISION=False):") benchmark(knn_torch, "torch.cdist+torch.topk", False) benchmark(knn_pointops, "pointops.knn_query", False) ================================================ FILE: scripts/dex_ycb_to_neus_format.py ================================================ """ Before running the script, you need to install the toolkit and other dependencies, as well as download the data and necessary MANO checkpoints/models. 
Install the toolkit and dependencies: ```sh # Create a new conda environment conda create -n dexycb python=3.9 conda activate dexycb conda install pytorch==2.4.1 torchvision==0.19.1 torchaudio==2.4.1 pytorch-cuda=12.1 -c pytorch -c nvidia conda install -c iopath iopath pip install --upgrade setuptools wheel pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/py39_cu121_pyt241/download.html conda install ninja scipy matplotlib -c conda-forge pip install numpy==1.21.6 matplotlib==3.6 pandas==2.0 scikit-image scipy==1.11 rerun-sdk pyembree rtree --no-deps # Install dex-ycb-toolkit cd /home/frrajic/xode/03-macos/ git clone --recursive git@github.com:NVlabs/dex-ycb-toolkit.git cd dex-ycb-toolkit pip install -e . # Install bop_toolkit dependencies cd bop_toolkit pip install -r requirements.txt cd .. # Install manopth cd manopth pip install -e . cd .. # Make sure numpy version is not too high (so that np.bool is not deprecated) pip install numpy==1.21.6 matplotlib==3.6 pandas==2.0 scikit-image scipy==1.11 rerun-sdk pyembree rtree --no-deps ``` Download the DexYCB dataset from the [project site](https://dex-ycb.github.io): ```sh export DEX_YCB_DIR=/home/frrajic/xode/00-data/dex-january-2025 cd $DEX_YCB_DIR # 20200709-subject-01.tar.gz (12G) # 20200813-subject-02.tar.gz (12G) # 20200820-subject-03.tar.gz (12G) # 20200903-subject-04.tar.gz (12G) # 20200908-subject-05.tar.gz (12G) # 20200918-subject-06.tar.gz (12G) # 20200928-subject-07.tar.gz (12G) # 20201002-subject-08.tar.gz (12G) # 20201015-subject-09.tar.gz (12G) # 20201022-subject-10.tar.gz (12G) gdown --fuzzy https://drive.google.com/file/d/1Ehh92wDE3CWAiKG7E9E73HjN2Xk2XfEk/view?usp=sharing & gdown --fuzzy https://drive.google.com/file/d/1Uo7MLqTbXEa-8s7YQZ3duugJ1nXFEo62/view?usp=sharing & gdown --fuzzy https://drive.google.com/file/d/1FkUxas8sv8UcVGgAzmSZlJw1eI5W5CXq/view?usp=sharing & gdown --fuzzy https://drive.google.com/file/d/14up6qsTpvgEyqOQ5hir-QbjMB_dHfdpA/view?usp=sharing & gdown --fuzzy https://drive.google.com/file/d/1NBA_FPyGWOQF5-X9ueAat5g8lDMz-EmS/view?usp=sharing & gdown --fuzzy https://drive.google.com/file/d/1UWIN2-wOBZX2T0dkAi4ctAAW8KffkXMQ/view?usp=sharing & gdown --fuzzy https://drive.google.com/file/d/1oWEYD_o3PVh39pLzMlJcArkDtMj4nzI0/view?usp=sharing & gdown --fuzzy https://drive.google.com/file/d/1GTNZwhWbs7Mfez0krTgXwLPndvrw1Ztv/view?usp=sharing & gdown --fuzzy https://drive.google.com/file/d/1j0BLkaCjIuwjakmywKdOO9vynHTWR0UH/view?usp=sharing & gdown --fuzzy https://drive.google.com/file/d/1FvFlRfX-p5a5sAWoKEGc17zKJWwKaSB-/view?usp=sharing & # bop.tar.gz (1.2G) # calibration.tar.gz (16K) # models.tar.gz (1.4G) gdown --fuzzy https://drive.google.com/file/d/1CPqLjsaYNjE3xSJbuWmqaMsGvyGIxiKL/view?usp=sharing & gdown --fuzzy https://drive.google.com/file/d/1UAwVKT4Rgb1fLcFoa1o71_-0NtSvvLAQ/view?usp=sharing & gdown --fuzzy https://drive.google.com/file/d/1cAzlQBpcTatI5ykYQ8ziQiHLUG_a_UpM/view?usp=sharing & tar xvf 20200709-subject-01.tar.gz & tar xvf 20200813-subject-02.tar.gz & tar xvf 20200820-subject-03.tar.gz & tar xvf 20200903-subject-04.tar.gz & tar xvf 20200908-subject-05.tar.gz & tar xvf 20200918-subject-06.tar.gz & tar xvf 20200928-subject-07.tar.gz & tar xvf 20201002-subject-08.tar.gz & tar xvf 20201015-subject-09.tar.gz & tar xvf 20201022-subject-10.tar.gz & tar xvf bop.tar.gz & tar xvf calibration.tar.gz & tar xvf models.tar.gz & rm 20200709-subject-01.tar.gz rm 20200813-subject-02.tar.gz rm 20200820-subject-03.tar.gz rm 20200903-subject-04.tar.gz rm 
20200908-subject-05.tar.gz rm 20200918-subject-06.tar.gz rm 20200928-subject-07.tar.gz rm 20201002-subject-08.tar.gz rm 20201015-subject-09.tar.gz rm 20201022-subject-10.tar.gz rm bop.tar.gz rm calibration.tar.gz rm models.tar.gz ``` The structure of the dataset should look like this: ```sh tree -L 1 $DEX_YCB_DIR # /home/frrajic/xode/00-data/dex-january-2025 # ├── 20200709-subject-01 # ├── 20200813-subject-02 # ├── 20200820-subject-03 # ├── 20200903-subject-04 # ├── 20200908-subject-05 # ├── 20200918-subject-06 # ├── 20200928-subject-07 # ├── 20201002-subject-08 # ├── 20201015-subject-09 # ├── 20201022-subject-10 # ├── bop # ├── calibration # └── models du -sch $DEX_YCB_DIR/* # 13G /home/frrajic/xode/00-data/dex-january-2025/20200709-subject-01 # 13G /home/frrajic/xode/00-data/dex-january-2025/20200813-subject-02 # 13G /home/frrajic/xode/00-data/dex-january-2025/20200820-subject-03 # 13G /home/frrajic/xode/00-data/dex-january-2025/20200903-subject-04 # 13G /home/frrajic/xode/00-data/dex-january-2025/20200908-subject-05 # 13G /home/frrajic/xode/00-data/dex-january-2025/20200918-subject-06 # 13G /home/frrajic/xode/00-data/dex-january-2025/20200928-subject-07 # 13G /home/frrajic/xode/00-data/dex-january-2025/20201002-subject-08 # 13G /home/frrajic/xode/00-data/dex-january-2025/20201015-subject-09 # 13G /home/frrajic/xode/00-data/dex-january-2025/20201022-subject-10 # 24G /home/frrajic/xode/00-data/dex-january-2025/bop # 200K /home/frrajic/xode/00-data/dex-january-2025/calibration # 3.5G /home/frrajic/xode/00-data/dex-january-2025/models # 154G total ``` Download MANO models and code (`mano_v1_2.zip`) from the [MANO website](https://mano.is.tue.mpg.de) and place the file under `manopath`. Unzip the file and create symlink: ```sh cd /home/frrajic/xode/03-macos/dex-ycb-toolkit cd manopth unzip mano_v1_2.zip cd mano ln -s ../mano_v1_2/models models cd ../.. 
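# Note (assumption, not from the original instructions): manopth resolves the MANO pickle
# files relative to mano/models by default, so this symlink is what lets the toolkit's MANO
# layers find the unzipped model files without any extra configuration.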
``` Finally, run the script: ```sh conda activate dexycb export DEX_YCB_DIR=/home/frrajic/xode/00-data/dex-january-2025 cd /home/frrajic/xode/03-macos/dex-ycb-toolkit python /home/frrajic/xode/03-macos/spatialtracker/scripts/dex_ycb_to_neus_format.py ``` """ import os import cv2 import imageio import math import matplotlib import matplotlib.pyplot as plt import numpy as np import open3d as o3d import open3d.visualization as vis import rerun as rr import torch import trimesh import yaml from dex_ycb_toolkit.layers.mano_group_layer import MANOGroupLayer from dex_ycb_toolkit.layers.ycb_group_layer import YCBGroupLayer from dex_ycb_toolkit.layers.ycb_layer import dcm2rv, rv2dcm from matplotlib import cm from matplotlib.cm import get_cmap from pytorch3d.renderer import ( MeshRasterizer, MeshRendererWithFragments, RasterizationSettings, SoftPhongShader, PointLights, ) from pytorch3d.renderer import TexturesVertex from pytorch3d.structures import Meshes from pytorch3d.utils.camera_conversions import cameras_from_opencv_projection from scipy.spatial.transform import Rotation as Rot from tqdm import tqdm def sample_surface(mesh: trimesh.Trimesh, count, face_weight=None, seed=None): """ Sample the surface of a mesh, returning the specified number of points For individual triangle sampling uses this method: http://mathworld.wolfram.com/TrianglePointPicking.html Adapted from: https://github.com/mikedh/trimesh/blob/a47b66d2d18404bc044aa9fcb983a80b1287919a/trimesh/sample.py#L23 Parameters ----------- mesh : trimesh.Trimesh Geometry to sample the surface of count : int Number of points to return face_weight : None or len(mesh.faces) float Weight faces by a factor other than face area. If None will be the same as face_weight=mesh.area seed : None or int If passed as an integer will provide deterministic results otherwise pulls the seed from operating system entropy. 
Returns --------- samples : (count, 3) float Points in space on the surface of mesh face_index : (count,) int Indices of faces for each sampled point colors : (count, 4) float Colors of each sampled point Returns only when the sample_color is True """ if face_weight is None: # len(mesh.faces) float, array of the areas # of each face of the mesh face_weight = mesh.area_faces # cumulative sum of weights (len(mesh.faces)) # cumulative sum of weights (len(mesh.faces)) weight_cum = np.cumsum(face_weight) # seed the random number generator as requested default_rng = np.random.default_rng random = default_rng(seed).random # last value of cumulative sum is total summed weight/area face_pick = random(count) * weight_cum[-1] # get the index of the selected faces picked_faces = np.searchsorted(weight_cum, face_pick) # pull triangles into the form of an origin + 2 vectors tri_origins = mesh.vertices[mesh.faces[:, 0]] tri_vectors = mesh.vertices[mesh.faces[:, 1:]].copy() tri_vectors -= np.tile(tri_origins, (1, 2)).reshape((-1, 2, 3)) # pull the vectors for the faces we are going to sample from tri_origins = tri_origins[picked_faces] tri_vectors = tri_vectors[picked_faces] # randomly generate two 0-1 scalar components to multiply edge vectors b picked_weights = random((len(tri_vectors), 2, 1)) # points will be distributed on a quadrilateral if we use 2 0-1 samples # if the two scalar components sum less than 1.0 the point will be # inside the triangle, so we find vectors longer than 1.0 and # transform them to be inside the triangle outside_triangle = picked_weights.sum(axis=1).reshape(-1) > 1.0 picked_weights[outside_triangle] -= 1.0 picked_weights = np.abs(picked_weights) # multiply triangle edge vectors by the random lengths and sum sample_vector = (tri_vectors * picked_weights).sum(axis=1) # finally, offset by the origin to generate # (n,3) points in space on the triangle picked_points = sample_vector + tri_origins return picked_faces, picked_weights, picked_points def pick_points_from_mesh(mesh, picked_faces, picked_weights, reference_mesh): if reference_mesh is not None: # Number of vertices must match, but the 3D location of vertices can change assert reference_mesh.vertices.shape == mesh.vertices.shape, "Number of vertices must match" # The faces must be the same assert np.allclose(reference_mesh.faces, mesh.faces), "Faces must be the same" # pull triangles into the form of an origin + 2 vectors tri_origins = mesh.vertices[mesh.faces[:, 0]] tri_vectors = mesh.vertices[mesh.faces[:, 1:]].copy() tri_vectors -= np.tile(tri_origins, (1, 2)).reshape((-1, 2, 3)) # pull the vectors for the faces we are going to sample from tri_origins = tri_origins[picked_faces] tri_vectors = tri_vectors[picked_faces] # multiply triangle edge vectors by the random lengths and sum sample_vector = (tri_vectors * picked_weights).sum(axis=1) picked_points = sample_vector + tri_origins return picked_points class SequenceLoader(): """DexYCB sequence loader.""" def __init__( self, name, device='cuda:0', preload=True, app='viewer', **kwargs, ): """Constructor. Args: name: Sequence name. device: A torch.device string argument. The specified device is used only for certain data loading computations, but not storing the loaded data. Currently the loaded data is always stored as numpy arrays on CPU. preload: Whether to preload the point cloud or load it online. app: 'viewer' or 'renderer'. 
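        'convert_to_neus' is also accepted: it exports the sequence (per-view RGB/depth/mask
        images, camera intrinsics/extrinsics, and sampled 3D tracks) into a NeuS-style dataset
        folder, configured through **kwargs such as output_dataset_path, n_subsample, and n_points.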
""" assert device in ('cuda', 'cpu') or device.split(':')[0] == 'cuda' assert app in ('viewer', 'renderer', 'convert_to_neus') self._name = name self._device = torch.device(device) self._preload = preload self._app = app assert 'DEX_YCB_DIR' in os.environ, "environment variable 'DEX_YCB_DIR' is not set" self._dex_ycb_dir = os.environ['DEX_YCB_DIR'] # Load meta. meta_file = self._dex_ycb_dir + '/' + self._name + "/meta.yml" with open(meta_file, 'r') as f: meta = yaml.load(f, Loader=yaml.FullLoader) self._serials = meta['serials'] self._h = 480 self._w = 640 self._num_cameras = len(self._serials) self._data_dir = [ self._dex_ycb_dir + '/' + self._name + '/' + s for s in self._serials ] self._color_prefix = "color_" self._depth_prefix = "aligned_depth_to_color_" self._label_prefix = "labels_" self._num_frames = meta['num_frames'] self._ycb_ids = meta['ycb_ids'] self._mano_sides = meta['mano_sides'] # Load intrinsics. def intr_to_K(x): return torch.tensor( [[x['fx'], 0.0, x['ppx']], [0.0, x['fy'], x['ppy']], [0.0, 0.0, 1.0]], dtype=torch.float32, device=self._device) self._K = [] for s in self._serials: intr_file = self._dex_ycb_dir + "/calibration/intrinsics/" + s + '_' + str( self._w) + 'x' + str(self._h) + ".yml" with open(intr_file, 'r') as f: intr = yaml.load(f, Loader=yaml.FullLoader) K = intr_to_K(intr['color']) self._K.append(K) self._K_inv = [torch.inverse(k) for k in self._K] # Load extrinsics. extr_file = self._dex_ycb_dir + "/calibration/extrinsics_" + meta[ 'extrinsics'] + "/extrinsics.yml" with open(extr_file, 'r') as f: extr = yaml.load(f, Loader=yaml.FullLoader) T = extr['extrinsics'] T = { s: torch.tensor(T[s], dtype=torch.float32, device=self._device).view(3, 4) for s in T } self._R = [T[s][:, :3] for s in self._serials] self._t = [T[s][:, 3] for s in self._serials] self._R_inv = [torch.inverse(r) for r in self._R] self._t_inv = [torch.mv(r, -t) for r, t in zip(self._R_inv, self._t)] self._master_intrinsics = self._K[[ i for i, s in enumerate(self._serials) if s == extr['master'] ][0]].cpu().numpy() self._tag_R = T['apriltag'][:, :3] self._tag_t = T['apriltag'][:, 3] self._tag_R_inv = torch.inverse(self._tag_R) self._tag_t_inv = torch.mv(self._tag_R_inv, -self._tag_t) self._tag_lim = [-0.00, +1.20, -0.10, +0.70, -0.10, +0.70] # Compute texture coordinates. y, x = torch.meshgrid(torch.arange(self._h), torch.arange(self._w), indexing="ij") x = x.float() y = y.float() s = torch.stack((x / (self._w - 1), y / (self._h - 1)), dim=2) self._pcd_tex_coord = [s.numpy()] * self._num_cameras # Compute rays. self._p = [] ones = torch.ones((self._h, self._w), dtype=torch.float32) xy1s = torch.stack((x, y, ones), dim=2).view(self._w * self._h, 3).t() xy1s = xy1s.to(self._device) for c in range(self._num_cameras): p = torch.mm(self._K_inv[c], xy1s) self._p.append(p) # Load point cloud. 
if self._preload: print('Preloading point cloud') self._color = [] self._depth = [] for c in range(self._num_cameras): color = [] depth = [] for i in range(self._num_frames): rgb, d = self._load_frame_rgbd(c, i) color.append(rgb) depth.append(d) self._color.append(color) self._depth.append(depth) self._color = np.array(self._color, dtype=np.uint8) self._depth = np.array(self._depth, dtype=np.uint16) self._pcd_rgb = [x for x in self._color] self._pcd_vert = [] self._pcd_mask = [] for c in range(self._num_cameras): p, m = self._deproject_depth_and_filter_points(self._depth[c], c) self._pcd_vert.append(p) self._pcd_mask.append(m) else: print('Loading point cloud online') self._pcd_rgb = [ np.zeros((self._h, self._w, 3), dtype=np.uint8) for _ in range(self._num_cameras) ] self._pcd_vert = [ np.zeros((self._h, self._w, 3), dtype=np.float32) for _ in range(self._num_cameras) ] self._pcd_mask = [ np.zeros((self._h, self._w), dtype=np.bool) for _ in range(self._num_cameras) ] # Create YCB group layer. self._ycb_group_layer = YCBGroupLayer(self._ycb_ids).to(self._device) self._ycb_model_dir = self._dex_ycb_dir + "/models" self._ycb_count = self._ycb_group_layer.count self._ycb_material = self._ycb_group_layer.material self._ycb_tex_coords = self._ycb_group_layer.tex_coords # Create MANO group layer. mano_betas = [] for m in meta['mano_calib']: mano_calib_file = self._dex_ycb_dir + "/calibration/mano_" + m + "/mano.yml" with open(mano_calib_file, 'r') as f: mano_calib = yaml.load(f, Loader=yaml.FullLoader) betas = np.array(mano_calib['betas'], dtype=np.float32) mano_betas.append(betas) self._mano_group_layer = MANOGroupLayer(self._mano_sides, mano_betas).to(self._device) # Prepare data for viewer. if app == 'viewer': s = np.cumsum([0] + self._ycb_group_layer.count[:-1]) e = np.cumsum(self._ycb_group_layer.count) self._ycb_seg = list(zip(s, e)) ycb_file = self._dex_ycb_dir + '/' + self._name + "/pose.npz" data = np.load(ycb_file) ycb_pose = data['pose_y'] i = np.any(ycb_pose != [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0], axis=2) pose = ycb_pose.reshape(-1, 7) v, n = self.transform_ycb(pose) self._ycb_vert = [ np.zeros((self._num_frames, n, 3), dtype=np.float32) for n in self._ycb_count ] self._ycb_norm = [ np.zeros((self._num_frames, n, 3), dtype=np.float32) for n in self._ycb_count ] for o in range(self._ycb_group_layer.num_obj): io = i[:, o] self._ycb_vert[o][io] = v[io, self._ycb_seg[o][0]:self._ycb_seg[o][1]] self._ycb_norm[o][io] = n[io, self._ycb_seg[o][0]:self._ycb_seg[o][1]] mano_file = self._dex_ycb_dir + '/' + self._name + "/pose.npz" data = np.load(mano_file) mano_pose = data['pose_m'] i = np.any(mano_pose != 0.0, axis=2) pose = torch.from_numpy(mano_pose).to(self._device) pose = pose.view(-1, self._mano_group_layer.num_obj * 51) verts, _ = self._mano_group_layer(pose) # Numpy array is faster than PyTorch Tensor here. 
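            # The MANO group layer outputs 778 vertices per hand; indexing with the face array f
            # unrolls them to 4614 (= 1538 faces x 3) per-corner vertices, per-face normals come
            # from the cross product of the two triangle edges, and f[:, [0, 1, 1, 2, 2, 0]]
            # yields the 9228 endpoints used for the wireframe line segments.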
verts = verts.cpu().numpy() f = self._mano_group_layer.f.cpu().numpy() v = verts[:, f.ravel()] n = np.cross(v[:, 1::3, :] - v[:, 0::3, :], v[:, 2::3, :] - v[:, 1::3, :]) n = np.repeat(n, 3, axis=1) l = verts[:, f[:, [0, 1, 1, 2, 2, 0]].ravel(), :] self._mano_vert = [ np.zeros((self._num_frames, 4614, 3), dtype=np.float32) for _ in range(self._mano_group_layer.num_obj) ] self._mano_norm = [ np.zeros((self._num_frames, 4614, 3), dtype=np.float32) for _ in range(self._mano_group_layer.num_obj) ] self._mano_line = [ np.zeros((self._num_frames, 9228, 3), dtype=np.float32) for _ in range(self._mano_group_layer.num_obj) ] for o in range(self._mano_group_layer.num_obj): io = i[:, o] self._mano_vert[o][io] = v[io, 4614 * o:4614 * (o + 1), :] self._mano_norm[o][io] = n[io, 4614 * o:4614 * (o + 1), :] self._mano_line[o][io] = l[io, 9228 * o:9228 * (o + 1), :] # Prepare data for renderer. if app == 'renderer': self._ycb_pose = [] self._mano_vert = [] self._mano_joint_3d = [] for c in range(self._num_cameras): ycb_pose = [] mano_pose = [] mano_joint_3d = [] for i in range(self._num_frames): label_file = self._data_dir[ c] + '/' + self._label_prefix + "{:06d}.npz".format(i) label = np.load(label_file) pose_y = np.hstack((label['pose_y'], np.array([[[0, 0, 0, 1]]] * len(label['pose_y']), dtype=np.float32))) pose_m = label['pose_m'] joint_3d = label['joint_3d'] ycb_pose.append(pose_y) mano_pose.append(pose_m) mano_joint_3d.append(joint_3d) ycb_pose = np.array(ycb_pose, dtype=np.float32) mano_pose = np.array(mano_pose, dtype=np.float32) mano_joint_3d = np.array(mano_joint_3d, dtype=np.float32) self._ycb_pose.append(ycb_pose) self._mano_joint_3d.append(mano_joint_3d) i = np.any(mano_pose != 0.0, axis=2) pose = torch.from_numpy(mano_pose).to(self._device) pose = pose.view(-1, self._mano_group_layer.num_obj * 51) verts, _ = self._mano_group_layer(pose) verts = verts.cpu().numpy() mano_vert = [ np.zeros((self._num_frames, 778, 3), dtype=np.float32) for _ in range(self._mano_group_layer.num_obj) ] for o in range(self._mano_group_layer.num_obj): io = i[:, o] mano_vert[o][io] = verts[io, 778 * o:778 * (o + 1), :] self._mano_vert.append(mano_vert) # Convert to Neus format. if app == "convert_to_neus": output_dataset_path = kwargs.get("output_dataset_path", "output_dataset") downscaling_factor = kwargs.get("downscaling_factor", 1) n_points = kwargs.get("n_points", 3_600) n_subsample = kwargs.get("n_subsample", 1) seed = kwargs.get("seed", 72) stream_rerun_viz = kwargs.get("stream_rerun_viz", False) save_rerun_viz = kwargs.get("save_rerun_viz", False) np.random.seed(seed) torch.manual_seed(seed) # Save camera centers as a .ply pointcloud, for debugging purposes. t_centered = torch.stack(self._t) - torch.tensor([0., 0., 1.]) # Move along z axis by -1 colors = cm.get_cmap('tab10')(np.linspace(0, 1, self._num_cameras))[:, :3] pcd = o3d.geometry.PointCloud() pcd.points = o3d.utility.Vector3dVector(t_centered.cpu().numpy()) pcd.colors = o3d.utility.Vector3dVector(colors) pcd_file = os.path.join(output_dataset_path, f"camera_center__{c:02d}_cameras.ply") o3d.io.write_point_cloud(pcd_file, pcd) # Create the view folders. for c in range(self._num_cameras): view_folder = os.path.join(output_dataset_path, f"view_{c:02d}") os.makedirs(view_folder, exist_ok=True) # Save the intrinsics.txt file. 
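                # Intrinsics are stored as a homogeneous 4x4 matrix (K in the top-left 3x3 block,
                # 1 in the bottom-right corner), written as four space-separated rows of text.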
intrinsics_file = os.path.join(view_folder, "intrinsics.txt") intrinsics = np.zeros((4, 4), dtype=np.float32) intrinsics[:3, :3] = self._K[c].cpu().numpy() intrinsics[3, 3] = 1 intrinsics_str = '\n'.join([' '.join([str(x) for x in row]) for row in intrinsics]) with open(intrinsics_file, "w") as f: f.write(intrinsics_str) # Save the cameras_sphere.npz file. R = self._R t = self._t t_centered = [t_ - torch.tensor([0., 0., 1.]) for t_ in t] # Move along z axis by -1 R_inv = [torch.inverse(r) for r in R] t_centered_inv = [torch.mv(r, -t) for r, t in zip(R_inv, t_centered)] extrinsics = np.zeros((4, 4), dtype=np.float32) extrinsics[:3, :3] = R_inv[c].cpu().numpy() extrinsics[:3, 3] = t_centered_inv[c].cpu().numpy() extrinsics[3, 3] = 1 cameras_sphere_file = os.path.join(view_folder, "cameras_sphere.npz") cameras_sphere = { **{f'world_mat_{output_frame_id}': intrinsics @ extrinsics for output_frame_id in range(math.ceil(self._num_frames / n_subsample))}, **{f'scale_mat_{output_frame_id}': np.diag( [downscaling_factor, downscaling_factor, downscaling_factor, 1.0]) for output_frame_id in range(math.ceil(self._num_frames / n_subsample))} } np.savez_compressed(cameras_sphere_file, **cameras_sphere) # Also, save the intrinsics and extrinsics directly into a .npz file. camera_params_path = os.path.join(view_folder, "intrinsics_extrinsics.npz") np.savez_compressed(camera_params_path, intrinsics=intrinsics, extrinsics=extrinsics) # Save the rgb and depth images. And dummy masks. rgb_folder = os.path.join(view_folder, "rgb") depth_folder = os.path.join(view_folder, "depth") mask_folder = os.path.join(view_folder, "mask") rgb_with_valid_depth_folder = os.path.join(view_folder, "rgb_with_valid_depth") os.makedirs(rgb_folder, exist_ok=True) os.makedirs(depth_folder, exist_ok=True) os.makedirs(mask_folder, exist_ok=True) os.makedirs(rgb_with_valid_depth_folder, exist_ok=True) for output_frame_id in range(math.ceil(self._num_frames / n_subsample)): input_frame_id = output_frame_id * n_subsample rgb = self._color[c][input_frame_id][:, :, ::-1] rgb_file = os.path.join(rgb_folder, f"{output_frame_id:05d}.png") cv2.imwrite(rgb_file, rgb) depth = self._depth[c][input_frame_id] depth_file = os.path.join(depth_folder, f"{output_frame_id:05d}.png") cv2.imwrite(depth_file, depth) rgb_plot = rgb.copy() rgb_plot[depth == 0] = 255 cv2.imwrite(os.path.join(rgb_with_valid_depth_folder, f"{output_frame_id:05d}.png"), rgb_plot) label_file = self._data_dir[c] + '/' + self._label_prefix + "{:06d}.npz".format(input_frame_id) label = np.load(label_file) seg_mask = label["seg"] mask = seg_mask != 0 # Everything that is not background mask = mask[:, :, None].astype(np.uint8).repeat(3, 2) * 255 # dummy_mask = np.ones((self._h, self._w, 3)).astype(np.uint8) * 255 # mask = dummy_mask mask_file = os.path.join(mask_folder, f"{output_frame_id:05d}.png") imageio.imwrite(mask_file, mask) # Backproject the depth image to 3D points for visualization purposes. 
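                    # Only done for the first and last output frames: depth in millimeters is
                    # converted to meters, each valid pixel is lifted along its ray K^{-1} [x, y, 1]^T
                    # scaled by depth, mapped to world coordinates via [R | t], and saved as a colored .ply.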
if output_frame_id in [0, math.ceil(self._num_frames / n_subsample) - 1] and c in range( self._num_cameras): d = self._depth[c][input_frame_id] d = d.astype(np.float32) / 1000 d = torch.from_numpy(d).to(self._device) p = torch.mul( d.view(1, -1, self._w * self._h).expand(3, -1, -1), self._p[c].unsqueeze(1)) p = torch.addmm(self._t[c].unsqueeze(1), self._R[c], p.view(3, -1)) p = p.t().view(self._h, self._w, 3) p = p.cpu().numpy() m = d > 0 p = p[m] colors = self._color[c][input_frame_id][m] / 255 pcd = o3d.geometry.PointCloud() pcd.points = o3d.utility.Vector3dVector(p) pcd.colors = o3d.utility.Vector3dVector(colors) pcd_file = os.path.join(view_folder, f"pcd_for_t{output_frame_id:03d}.ply") o3d.io.write_point_cloud(pcd_file, pcd) # Compute meshes for each frame. s = np.cumsum([0] + self._ycb_group_layer.count[:-1]) e = np.cumsum(self._ycb_group_layer.count) self._ycb_seg = list(zip(s, e)) ycb_file = self._dex_ycb_dir + '/' + self._name + "/pose.npz" data = np.load(ycb_file) ycb_pose = data['pose_y'][::n_subsample] i = np.any(ycb_pose != [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0], axis=2) pose = ycb_pose.reshape(-1, 7) v, n = self.transform_ycb(pose) self._ycb_vert = [ np.zeros((math.ceil(self._num_frames / n_subsample), n, 3), dtype=np.float32) for n in self._ycb_count ] self._ycb_norm = [ np.zeros((math.ceil(self._num_frames / n_subsample), n, 3), dtype=np.float32) for n in self._ycb_count ] for o in range(self._ycb_group_layer.num_obj): io = i[:, o] self._ycb_vert[o][io] = v[io, self._ycb_seg[o][0]:self._ycb_seg[o][1]] self._ycb_norm[o][io] = n[io, self._ycb_seg[o][0]:self._ycb_seg[o][1]] self._ycb_faces = [ np.arange(n).reshape(-1, 3) for n in self._ycb_count ] mano_file = self._dex_ycb_dir + '/' + self._name + "/pose.npz" data = np.load(mano_file) mano_pose = data['pose_m'][::n_subsample] i = np.any(mano_pose != 0.0, axis=2) pose = torch.from_numpy(mano_pose).to(self._device) pose = pose.view(-1, self._mano_group_layer.num_obj * 51) verts, _ = self._mano_group_layer(pose) # Numpy array is faster than PyTorch Tensor here. 
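            # Same unrolling as in the 'viewer' branch above, but applied to the temporally
            # subsampled poses, i.e. ceil(num_frames / n_subsample) output frames.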
verts = verts.cpu().numpy() f = self._mano_group_layer.f.cpu().numpy() v = verts[:, f.ravel()] n = np.cross(v[:, 1::3, :] - v[:, 0::3, :], v[:, 2::3, :] - v[:, 1::3, :]) n = np.repeat(n, 3, axis=1) l = verts[:, f[:, [0, 1, 1, 2, 2, 0]].ravel(), :] self._mano_vert = [ np.zeros((math.ceil(self._num_frames / n_subsample), 4614, 3), dtype=np.float32) for _ in range(self._mano_group_layer.num_obj) ] self._mano_norm = [ np.zeros((math.ceil(self._num_frames / n_subsample), 4614, 3), dtype=np.float32) for _ in range(self._mano_group_layer.num_obj) ] self._mano_line = [ np.zeros((math.ceil(self._num_frames / n_subsample), 9228, 3), dtype=np.float32) for _ in range(self._mano_group_layer.num_obj) ] self._mano_faces = [ np.arange(4614).reshape(-1, 3) for _ in range(self._mano_group_layer.num_obj) ] for o in range(self._mano_group_layer.num_obj): io = i[:, o] self._mano_vert[o][io] = v[io, 4614 * o:4614 * (o + 1), :] self._mano_norm[o][io] = n[io, 4614 * o:4614 * (o + 1), :] self._mano_line[o][io] = l[io, 9228 * o:9228 * (o + 1), :] vert = [] vert += self._ycb_vert vert += self._mano_vert norm = [] norm += self._ycb_norm norm += self._mano_norm ids = [] ids += self._ycb_group_layer._ids ids += [255 for _ in self._mano_group_layer._sides] names = [] names += ["ycb-" + layer._class_name for layer in self._ycb_group_layer._layers] names += [f"mano-{side}-hand" for side in self._mano_group_layer._sides] faces = [] faces += self._ycb_faces faces += self._mano_faces print(f"Number of meshes: {len(vert)}") assert len(vert) == len(norm) == len(faces) == len(ids) == len(names) print(f"Mesh names: {names}") print(f"Mesh IDS: {ids}") all_vertices = np.concatenate(vert, axis=1) all_normals = np.concatenate(norm, axis=1) all_faces = np.concatenate([ f + np.sum([v.shape[1] for v in vert[:i]]).astype(np.uint32) for i, f in enumerate(faces) ]) all_ids = np.concatenate([np.full(v.shape[1], i) for i, v in enumerate(vert)]) assert all_vertices.shape[0] == all_normals.shape[0] assert all_vertices.shape[1] == all_normals.shape[1] == all_faces.shape[0] * 3 == all_ids.shape[0] assert all_faces.max() + 1 == all_vertices.shape[1] print(f"all_vertices.shape: {all_vertices.shape}") print(f"all_normals.shape: {all_normals.shape}") print(f"all_faces.shape: {all_faces.shape}") print(f"all_ids.shape: {all_ids.shape}") n_frames = all_vertices.shape[0] meshes = [ trimesh.Trimesh( vertices=all_vertices[frame_idx], faces=all_faces, vertex_normals=all_normals[frame_idx], process=False, ) for frame_idx in range(n_frames) ] # Put the query points onto the frame where the hand is first visible hands_visible = np.any(mano_pose != 0.0, axis=2).all(axis=1) assert np.any(hands_visible), "Hands must be visible in at least one frame" t0 = np.argmax(hands_visible, axis=0) objects_visible = np.any(ycb_pose != [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0], axis=2).all(axis=1) assert objects_visible[t0], "Objects must be visible in the first frame where the hands are visible" picked_faces, picked_weights, picked_points = sample_surface(meshes[t0], n_points, seed=seed) assert np.allclose(picked_points, pick_points_from_mesh(meshes[t0], picked_faces, picked_weights, meshes[t0])) picked_vertices = meshes[t0].faces[:, 0][picked_faces] picked_ids = all_ids[picked_vertices] # Track the points tracks_3d = [] for frame_idx in range(n_frames): points = pick_points_from_mesh(meshes[frame_idx], picked_faces, picked_weights, meshes[t0]) tracks_3d.append(points) if frame_idx == t0: assert np.allclose(points, picked_points) tracks_3d = np.stack(tracks_3d) # (n_frames, 
n_points, 3) # Project the points to the camera tracks_2d = [] tracks_2d_z = [] for c in range(self._num_cameras): p = torch.from_numpy(tracks_3d).to(self._device).T.reshape(3, -1) p = self._R_inv[c].double() @ p + self._t_inv[c][:, None] p = self._K[c].double() @ p z = p[2] p = p[:2] / z p = p.cpu().numpy().reshape(2, n_points, math.ceil(self._num_frames / n_subsample)).T z = z.cpu().numpy().reshape(n_points, math.ceil(self._num_frames / n_subsample)).T tracks_2d.append(p) tracks_2d_z.append(z) tracks_2d = np.stack(tracks_2d) tracks_2d_z = np.stack(tracks_2d_z) # --- Estimate occlusion rendered_depth = [] for c in range(self._num_cameras): rendered_depth_camera = [] for frame_idx in range(n_frames): rgb = self._color[c][0] depth = (self._depth[c][0] / 1000).clip(0, 2) h, w = self._h, self._w K = self._K[c].cpu().numpy() w2c = np.eye(4, dtype=float) w2c[:3, :3] = self._R_inv[c].cpu().numpy() w2c[:3, 3] = self._t_inv[c].cpu().numpy() c2w = np.linalg.inv(w2c) # Render depth device = "cuda" vertices = torch.tensor(all_vertices[frame_idx], dtype=torch.float32).to(device) faces = torch.tensor(all_faces, dtype=torch.int64).to(device) vertex_colors = torch.ones_like(vertices).unsqueeze(0).to(device) textures = TexturesVertex(verts_features=vertex_colors) mesh = Meshes(verts=[vertices], faces=[faces], textures=textures) intrinsics = torch.eye(4, dtype=torch.float32).to(device) intrinsics[:3, :3] = torch.from_numpy(K) cameras = cameras_from_opencv_projection( R=torch.from_numpy(w2c[:3, :3]).to(device)[None].float(), tvec=torch.from_numpy(w2c[:3, 3]).to(device)[None].float(), camera_matrix=self._K[c].to(device)[None].float(), image_size=torch.tensor([self._h, self._w], dtype=torch.int32).to(device)[None].float(), ) raster_settings = RasterizationSettings( image_size=(self._h, self._w), blur_radius=0.0, faces_per_pixel=1, bin_size=0, ) renderer = MeshRendererWithFragments( rasterizer=MeshRasterizer(cameras=cameras, raster_settings=raster_settings), shader=SoftPhongShader(device=device, cameras=cameras, lights=PointLights(device=device)), ) images, fragments = renderer(mesh) depth_map = fragments.zbuf rendered_depth_camera.append(depth_map.cpu().numpy()[0, :, :, 0]) rendered_depth.append(rendered_depth_camera) rendered_depth = np.stack(rendered_depth) assert rendered_depth.shape == (self._num_cameras, n_frames, self._h, self._w) seg_masks = [] for c in range(self._num_cameras): seg_masks_camera = [] for frame_idx in range(n_frames): input_frame_id = frame_idx * n_subsample label_file = self._data_dir[c] + '/' + self._label_prefix + "{:06d}.npz".format(input_frame_id) label = np.load(label_file) seg_masks_camera.append(label["seg"]) seg_masks.append(seg_masks_camera) seg_masks = np.stack(seg_masks) assert seg_masks.shape == (self._num_cameras, n_frames, self._h, self._w) seg_unique = np.unique(seg_masks) cmap = get_cmap("tab10") seg_masks_rgb = np.zeros((*seg_masks.shape, 3), dtype=np.uint8) for idx, val in enumerate(seg_unique): seg_masks_rgb[seg_masks == val] = (np.array(cmap(idx / len(seg_unique))[:3]) * 255).astype(np.uint8) assert seg_masks_rgb.shape == (self._num_cameras, n_frames, self._h, self._w, 3) def estimate_occlusion_by_depth_and_segment( depth_map, x, y, num_frames, thresh, seg_id=None, segments=None, min_or_max_reduce="max", convert_to_pixel_coords=True, occlude_if_depth_larger_than_xxx=None, ): # need to convert from raster to pixel coordinates if convert_to_pixel_coords: x = x - 0.5 y = y - 0.5 x0 = np.floor(x).astype(np.int32) x1 = x0 + 1 y0 = np.floor(y).astype(np.int32) y1 = y0 + 1 
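                # Each projected point is compared against its four integer pixel neighbours
                # (x0/x1, y0/y1) in the depth map; reducing over them with max (or min) makes the
                # test tolerant to sub-pixel projection error, and a point is marked occluded when
                # the looked-up surface depth lies in front of (is smaller than) the point's threshold.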
shp = depth_map.shape assert len(depth_map.shape) == 3 x0 = np.clip(x0, 0, shp[2] - 1) x1 = np.clip(x1, 0, shp[2] - 1) y0 = np.clip(y0, 0, shp[1] - 1) y1 = np.clip(y1, 0, shp[1] - 1) depth_map = depth_map.reshape(-1) rng = np.arange(num_frames)[:, np.newaxis] assert x.shape[0] == y.shape[0] == num_frames i1 = np.take(depth_map, rng * shp[1] * shp[2] + y0 * shp[2] + x0) i2 = np.take(depth_map, rng * shp[1] * shp[2] + y1 * shp[2] + x0) i3 = np.take(depth_map, rng * shp[1] * shp[2] + y0 * shp[2] + x1) i4 = np.take(depth_map, rng * shp[1] * shp[2] + y1 * shp[2] + x1) if min_or_max_reduce == "max": depth = np.maximum(np.maximum(np.maximum(i1, i2), i3), i4) elif min_or_max_reduce == "min": depth = np.minimum(np.minimum(np.minimum(i1, i2), i3), i4) else: raise ValueError(f"Unknown min_or_max_reduce: {min_or_max_reduce}") if occlude_if_depth_larger_than_xxx is not None: depth[depth >= occlude_if_depth_larger_than_xxx] = 0 depth_occluded = depth < thresh print("┌ Depth occlusion: ", depth_occluded.sum(), "/", depth_occluded.size) occluded = depth_occluded if segments is not None: segments = segments.reshape(-1) i1 = np.take(segments, rng * shp[1] * shp[2] + y0 * shp[2] + x0) i2 = np.take(segments, rng * shp[1] * shp[2] + y1 * shp[2] + x0) i3 = np.take(segments, rng * shp[1] * shp[2] + y0 * shp[2] + x1) i4 = np.take(segments, rng * shp[1] * shp[2] + y1 * shp[2] + x1) seg_occluded = np.ones_like(depth_occluded, dtype=bool) for i in [i1, i2, i3, i4]: i = i.astype(np.int32) seg_occluded = np.logical_and(seg_occluded, seg_id != i) print("| Segmentation occlusion: ", seg_occluded.sum(), "/", seg_occluded.size) occluded = np.logical_or(occluded, seg_occluded) return occluded tracks_2d_visibilities = [] for c in range(self._num_cameras): occlusion = np.zeros((tracks_2d[c].shape[0], tracks_2d[c].shape[1]), dtype=bool) print(f"N occluded: {occlusion.sum()} / {occlusion.size}") occlusion = np.logical_or(occlusion, (tracks_2d_z[c] <= 0) | (tracks_2d_z[c] >= (65535 / 1000))) print(f"N occluded (after Z): {occlusion.sum()} / {occlusion.size}") occlusion = np.logical_or(occlusion, tracks_2d[c][:, :, 0] <= 0) occlusion = np.logical_or(occlusion, tracks_2d[c][:, :, 1] <= 0) occlusion = np.logical_or(occlusion, tracks_2d[c][:, :, 0] >= self._w - 1) occlusion = np.logical_or(occlusion, tracks_2d[c][:, :, 1] >= self._h - 1) print(f"N occluded (& out-of-frame): {occlusion.sum()} / {occlusion.size}") # # V1: Use the depth map to estimate occlusion # depth_map_for_occlusion = self._depth[c][::n_subsample].copy() # depth_map_for_occlusion[depth_map_for_occlusion == 0] = 65535 # depth_map_for_occlusion = depth_map_for_occlusion / 1000.0 # # V2: Make the depth for occlussion be the depth from projected predicted points, taking the minimum z over all points at a pixel # depth_map_for_occlusion = np.ones((tracks_2d_z.shape[1], self._h, self._w), # dtype=np.float32) * 65535 / 1000 # for frame_idx in range(math.ceil(self._num_frames / n_subsample)): # for point_idx in range(n_points): # if np.isnan(tracks_2d[c][frame_idx, point_idx]).any(): # continue # x = int(tracks_2d[c][frame_idx, point_idx, 0]) # y = int(tracks_2d[c][frame_idx, point_idx, 1]) # z = tracks_2d_z[c][frame_idx, point_idx] # if 0 <= x < self._w and 0 <= y < self._h: # depth_map_for_occlusion[frame_idx, y - 3:y + 3, x - 3:x + 3] = np.minimum( # depth_map_for_occlusion[frame_idx, y - 3:y + 3, x - 3:x + 3], # z, # ) # # Visualize it side by side with GT depth # if False: # for frame_idx in range(math.ceil(self._num_frames / n_subsample)): # if frame_idx not in [0, 
math.ceil(self._num_frames / n_subsample) - 1]: # continue # d1 = self._depth[c][frame_idx * n_subsample] / 1000 # d2 = depth_map_for_occlusion[frame_idx] # d12 = np.concatenate([d1, d2], axis=1) # plt.figure(dpi=150, figsize=(d12.shape[1] / 100, d12.shape[0] / 100)) # plt.title(f"Depth GT (left) vs Depth used for occlusion (right), frame {frame_idx}") # plt.imshow(d12.clip(0.5, 1)) # plt.axis('off') # plt.tight_layout(pad=0) # plt.savefig(os.path.join(output_dataset_path, # f"depth_used_for_occlussion_view_{c:02d}_frame_{frame_idx:05d}.png")) # # plt.show() # # seg_mask = [] # for output_frame_id in range(math.ceil(self._num_frames / n_subsample)): # input_frame_id = output_frame_id * n_subsample # label_file = self._data_dir[c] + '/' + self._label_prefix + "{:06d}.npz".format(input_frame_id) # label = np.load(label_file) # seg_mask.append(label["seg"]) # seg_mask = np.stack(seg_mask) # depth_or_segment_occluded = estimate_occlusion_by_depth_and_segment( # depth_map=depth_map_for_occlusion, # segments=seg_mask, # x=tracks_2d[c][:, :, 0], # y=tracks_2d[c][:, :, 1], # num_frames=tracks_2d[c].shape[0], # thresh=tracks_2d_z[c] * 0.995, # seg_id=np.array(ids)[picked_ids], # ) # occlusion = np.logical_or(occlusion, depth_or_segment_occluded) # print(f"N occluded (& obscured by other objects): {occlusion.sum()} / {occlusion.size}") # print() # tracks_2d_visibilities.append(~occlusion) # # V3.a: Neither the GT depth nor the segmentation mask are reliable for occlusion estimation. # # Instead, we will use the rendered depth map, with a little help from the GT depth. # # First, the rendered depth needs to match the point depth, if not, the point is occluded. # # This will work perfectly for all the objects that have a full 3D mesh over time. So all # # the objects on the table, plus the MANO hand (but not the arm). This is susceptible to # # errors in estimating the mesh location, but it should be less problematic than the other # # segmentation mask and GT depth in that it will be less noisy and more consistent over time. # rendered_depth_for_occlusion = rendered_depth[c].copy() # rendered_depth_for_occlusion[rendered_depth_for_occlusion <= 0] = 65535 / 1000 # depth_or_segment_occluded = estimate_occlusion_by_depth_and_segment( # depth_map=rendered_depth_for_occlusion, # x=tracks_2d[c, :, :, 0], # y=tracks_2d[c, :, :, 1], # num_frames=n_frames, # thresh=tracks_2d_z[c, :, :] - 0.01, # min_or_max_reduce="min", # convert_to_pixel_coords=False, # occlude_if_depth_larger_than_xxx=65535 / 1000, # ) # occlusion = np.logical_or(occlusion, depth_or_segment_occluded) # print(f"N occluded (& obscured in rendered depth): {occlusion.sum()} / {occlusion.size}") # print() # # # # V3.b: Second, to avoid occlusion by the arm, we will use the GT depth map but with a high threshold. # # # This will avoid the arm occluding the points as the arm is not in the rendered depth map. # # depth_or_segment_occluded = estimate_occlusion_by_depth_and_segment( # # depth_map=self._depth[c][::n_subsample] / 1000, # # x=tracks_2d[c, :, :, 0], # # y=tracks_2d[c, :, :, 1], # # num_frames=n_frames, # # thresh=tracks_2d_z[c, :, :] * 0.995, # # ) # # occlusion = np.logical_or(occlusion, depth_or_segment_occluded) # # print(f"N occluded (& obscured in GT depth): {occlusion.sum()} / {occlusion.size}") # # print() # V4.a: Forget the rendered depths, it's still difficult because the depth map is pixelized. # Instead, let's shoot rays from the camera onto the scene mesh and see where they intersect. # It is very very slow but most accurate. 
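            # V4.a: for every point not yet marked occluded, cast a ray from the camera centre
            # towards the point and intersect it with the full scene mesh (trimesh ray casting);
            # the point stays visible only if the nearest intersection lies within ~1 mm (atol=0.001)
            # of the point's own distance to the camera.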
camera_center = self._t[c].cpu().numpy() for frame_idx in range(n_frames): for track_idx in tqdm(range(n_points), desc=f"Ray casting for camera {c} frame {frame_idx}"): if occlusion[frame_idx, track_idx]: continue ray_direction = tracks_3d[frame_idx, track_idx] - camera_center ray_direction /= np.linalg.norm(ray_direction) intersections = meshes[frame_idx].ray.intersects_location(camera_center[None], ray_direction[None]) if len(intersections[0]) == 0: occlusion[frame_idx, track_idx] = True continue intersection_depth = np.inf for intersection in intersections[0]: intersection_depth = min(intersection_depth, np.linalg.norm(intersection - camera_center)) track_depth = np.linalg.norm(tracks_3d[frame_idx, track_idx] - camera_center) occlusion[frame_idx, track_idx] = not np.isclose(intersection_depth, track_depth, atol=0.001) print(f"N occluded (& obscured in scene mesh): {occlusion.sum()} / {occlusion.size}") print() # V4.b: The arm is not in the scene mesh and it is causing problems for 1/2 cameras. Let's use the # GT depths to figure out if the arm is occluding the points. Unfortunately, this will also not # work perfectly because the GT depths are missing around silhouette edges. depth_map_for_occlusion = self._depth[c][::n_subsample].copy() depth_map_for_occlusion[depth_map_for_occlusion <= 0] = 65535 depth_map_for_occlusion = depth_map_for_occlusion / 1000.0 depth_or_segment_occluded = estimate_occlusion_by_depth_and_segment( depth_map=depth_map_for_occlusion, x=tracks_2d[c, :, :, 0], y=tracks_2d[c, :, :, 1], num_frames=n_frames, thresh=tracks_2d_z[c, :, :] - 0.12, min_or_max_reduce="min", convert_to_pixel_coords=False, ) occlusion = np.logical_or(occlusion, depth_or_segment_occluded) print(f"N occluded (& obscured in GT depth): {occlusion.sum()} / {occlusion.size}") print() # Idea for V5: Do V4 and additionally try looking at if the RGB changed a lot for the point. # If it did, then it is likely that the point is occluded by the arm/person. # However, this might suffer from the same problem as in V3: edges would be noisy # and might quickly jump from visible to occluded to visible again. ... tracks_2d_visibilities.append(~occlusion) tracks_2d_visibilities = np.stack(tracks_2d_visibilities) tracks_3d_visibilities = tracks_2d_visibilities.any(axis=0) if stream_rerun_viz or save_rerun_viz: assert not (stream_rerun_viz and save_rerun_viz), ("Stream and save rerun at the same time not " "supported. But you can save what was streamed " "within the rerun viewer. Or run again. 
Or impl it.") rr.init("dexycb_preprocessing", recording_id="v0.1") if stream_rerun_viz: rr.connect_tcp() rr.log("world", rr.ViewCoordinates.RIGHT_HAND_Z_UP, static=True) rr.set_time_seconds("frame", 0) rr.log( "world/xyz", rr.Arrows3D( vectors=[[1, 0, 0], [0, 2, 0], [0, 0, 3]], colors=[[255, 0, 0], [0, 255, 0], [0, 0, 255]], ), ) entity_prefix = f"{os.path.basename(output_dataset_path)}/" radii_scale = 0.1 for t in range(n_frames): t_input = t * n_subsample rr.set_time_seconds("frame", t / 12) rr.log(f"{entity_prefix}mesh", rr.Mesh3D( vertex_positions=np.asarray(meshes[frame_idx].vertices), triangle_indices=np.asarray(meshes[frame_idx].faces), )) for c in range(self._num_cameras): rgb = self._color[c, t_input] depth = (self._depth[c, t_input] / 1000).clip(0, 2) rend_depth = rendered_depth[c, t].clip(0, 2) seg_mask = seg_masks_rgb[c, t] seg_rgb = seg_masks_rgb[c, t] h, w = self._h, self._w K = self._K[c].cpu().numpy() K_inv = np.linalg.inv(K) w2c = np.eye(4, dtype=float) w2c[:3, :3] = self._R_inv[c].cpu().numpy() w2c[:3, 3] = self._t_inv[c].cpu().numpy() c2w = np.linalg.inv(w2c) cam_pinhole = rr.Pinhole(image_from_camera=K, width=w, height=h) cam_transform = rr.Transform3D(translation=c2w[:3, 3], mat3x3=c2w[:3, :3]) for name, archetype in [ ("rgb", rr.Image(rgb)), ("seg", rr.Image(seg_rgb)), ("depth-gt", rr.DepthImage(depth, point_fill_ratio=0.2)), ("depth-rendered", rr.DepthImage(rend_depth, point_fill_ratio=0.2)), ]: rr.log(f"{entity_prefix}/image/{name}/view-{c:02d}", cam_pinhole) rr.log(f"{entity_prefix}/image/{name}/view-{c:02d}", cam_transform) rr.log(f"{entity_prefix}/image/{name}/view-{c:02d}/{name}", archetype) # Compute 3D points from GT depth map y, x = np.indices((self._h, self._w)) homo_pixel_coords = np.stack([x.ravel(), y.ravel(), np.ones_like(x).ravel()], axis=1).T cam_coords = (K_inv @ homo_pixel_coords) * depth.ravel() cam_coords = np.vstack((cam_coords, np.ones((1, cam_coords.shape[1])))) world_coords = (c2w @ cam_coords)[:3].T valid_mask = depth.ravel() > 0 rr.log(f"{entity_prefix}point_cloud/rgb-gt/view-{c}", rr.Points3D( positions=world_coords[valid_mask], colors=rgb.reshape(-1, 3)[valid_mask].astype(np.uint8), radii=0.01 * radii_scale, )) rr.log(f"{entity_prefix}point_cloud/seg-gt/view-{c}", rr.Points3D( positions=world_coords[valid_mask], colors=seg_rgb.reshape(-1, 3)[valid_mask].astype(np.uint8), radii=0.01 * radii_scale, )) # Compute 3D points from GT depth map y, x = np.indices((self._h, self._w)) homo_pixel_coords = np.stack([x.ravel(), y.ravel(), np.ones_like(x).ravel()], axis=1).T cam_coords = (K_inv @ homo_pixel_coords) * rend_depth.ravel() cam_coords = np.vstack((cam_coords, np.ones((1, cam_coords.shape[1])))) world_coords = (c2w @ cam_coords)[:3].T valid_mask = rend_depth.ravel() > 0 rr.log(f"{entity_prefix}point_cloud/rgb-rend/view-{c}", rr.Points3D( positions=world_coords[valid_mask], colors=rgb.reshape(-1, 3)[valid_mask].astype(np.uint8), radii=0.01 * radii_scale, )) rr.log(f"{entity_prefix}point_cloud/seg-rend/view-{c}", rr.Points3D( positions=world_coords[valid_mask], colors=seg_rgb.reshape(-1, 3)[valid_mask].astype(np.uint8), radii=0.01 * radii_scale, )) def log_tracks( tracks: np.ndarray, visibles: np.ndarray, query_timestep: np.ndarray, colors: np.ndarray, entity_format_str="{}", log_points=True, points_radii=0.03 * radii_scale, invisible_color=[0., 0., 0.], log_line_strips=True, max_strip_length_past=12, max_strip_length_future=1, strips_radii=0.0042 * radii_scale, log_error_lines=False, error_lines_radii=0.0072 * radii_scale, error_lines_color=[1., 
0., 0.], gt_for_error_lines=None, ) -> None: """ Log tracks to Rerun. Parameters: tracks: Shape (T, N, 3), the 3D trajectories of points. visibles: Shape (T, N), boolean visibility mask for each point at each timestep. query_timestep: Shape (T, N), the frame index after which the tracks start. colors: Shape (N, 4), RGBA colors for each point. entity_prefix: String prefix for entity hierarchy in Rerun. entity_suffix: String suffix for entity hierarchy in Rerun. """ T, N, _ = tracks.shape assert tracks.shape == (T, N, 3) assert visibles.shape == (T, N) assert query_timestep.shape == (N,) assert query_timestep.min() >= 0 assert query_timestep.max() < T assert colors.shape == (N, 4) for n in range(N): rr.log(entity_format_str.format(f"track-{n}"), rr.Clear(recursive=True)) for t in range(query_timestep[n], T): rr.set_time_seconds("frame", t / 12) # Log the point (special handling for invisible points) if log_points: rr.log( entity_format_str.format(f"track-{n}/point"), rr.Points3D( positions=[tracks[t, n]], colors=[colors[n, :3]] if visibles[t, n] else [invisible_color], radii=points_radii, ), ) # Log line segments for visible tracks if log_line_strips and t > query_timestep[n]: strip_t_start = max(t - max_strip_length_past, query_timestep[n].item()) strip_t_end = min(t + max_strip_length_future, T - 1) strips = np.stack([ tracks[strip_t_start:strip_t_end, n], tracks[strip_t_start + 1:strip_t_end + 1, n], ], axis=-2) strips_visibility = visibles[strip_t_start + 1:strip_t_end + 1, n] strips_colors = np.where( strips_visibility[:, None], colors[None, n, :3], [invisible_color], ) rr.log( entity_format_str.format(f"track-{n}/line"), rr.LineStrips3D(strips=strips, colors=strips_colors, radii=strips_radii), ) if log_error_lines: assert gt_for_error_lines is not None strips = np.stack([ tracks[t, n], gt_for_error_lines[t, n], ], axis=-2) rr.log( entity_format_str.format(f"track-{n}/error"), rr.LineStrips3D(strips=strips, colors=error_lines_color, radii=error_lines_radii), ) # Log the tracks cmap = matplotlib.colormaps["gist_rainbow"] norm = matplotlib.colors.Normalize(vmin=tracks_3d[..., 0].min(), vmax=tracks_3d[..., 0].max()) track_color = cmap(norm(tracks_3d[-1, :, 0])) track_color = track_color * 0 + 1 # Just make all tracks white N = 800 B = 200 for tracks_batch_start in range(0, N, B): tracks_batch_end = min(tracks_batch_start + B, N) for name, visibles in [ ("tracks/c01234567-visibility", tracks_2d_visibilities.any(0)[:, tracks_batch_start:tracks_batch_end]), ("tracks/c0123-visibility", tracks_2d_visibilities.any(0)[:, tracks_batch_start:tracks_batch_end]), ("tracks/c2345-visibility", tracks_2d_visibilities.any(0)[:, tracks_batch_start:tracks_batch_end]), ("tracks/c0-visibility", tracks_2d_visibilities[0, :, tracks_batch_start:tracks_batch_end]), ("tracks/c1-visibility", tracks_2d_visibilities[1, :, tracks_batch_start:tracks_batch_end]), ("tracks/c2-visibility", tracks_2d_visibilities[2, :, tracks_batch_start:tracks_batch_end]), ("tracks/c3-visibility", tracks_2d_visibilities[3, :, tracks_batch_start:tracks_batch_end]), ("tracks/c4-visibility", tracks_2d_visibilities[4, :, tracks_batch_start:tracks_batch_end]), ("tracks/c5-visibility", tracks_2d_visibilities[5, :, tracks_batch_start:tracks_batch_end]), ("tracks/c6-visibility", tracks_2d_visibilities[6, :, tracks_batch_start:tracks_batch_end]), ("tracks/c7-visibility", tracks_2d_visibilities[7, :, tracks_batch_start:tracks_batch_end]), ]: log_tracks( tracks=tracks_3d[:, tracks_batch_start:tracks_batch_end], visibles=visibles, 
query_timestep=visibles.argmax(axis=0), colors=track_color[tracks_batch_start:tracks_batch_end], entity_format_str=f"{entity_prefix}/{name}/{tracks_batch_start:02d}-{tracks_batch_end:02d}/{{}}", max_strip_length_future=0, ) if save_rerun_viz: rr_rrd_path = os.path.join(output_dataset_path, f"rerun_viz.rrd") rr.save(rr_rrd_path) print(f"Saved Rerun recording to: {os.path.abspath(rr_rrd_path)}") # import pydevd_pycharm # pydevd_pycharm.settrace('localhost', port=51234, stdoutToServer=True, stderrToServer=True) # Save the tracks tracks_3d_file = os.path.join(output_dataset_path, "tracks_3d.npz") np.savez( tracks_3d_file, tracks_3d=(tracks_3d - np.array([0., 0., 1.])) / DOWNSCALING_FACTOR, tracks_3d_visibilities=tracks_3d_visibilities, object_ids=np.array(ids)[picked_ids], object_id_to_name={i: name for i, name in zip(ids, names)}, tracks_2d=tracks_2d, tracks_2d_z=tracks_2d_z, tracks_2d_visibilities=tracks_2d_visibilities, ) # Save some .ply files of the trajectories for debugging colors = plt.cm.viridis(tracks_3d[t0, :, 2] / tracks_3d[t0, :, 2].max())[:, :3] for frame_idx in [0, t0, t0 + 1, n_frames // 3, (2 * n_frames) // 3, n_frames - 1]: pcd = o3d.geometry.PointCloud() pcd.points = o3d.utility.Vector3dVector(tracks_3d[frame_idx]) pcd.colors = o3d.utility.Vector3dVector(colors) pcd_file = os.path.join(output_dataset_path, f"tracks_3d_{frame_idx}.ply") o3d.io.write_point_cloud(pcd_file, pcd) pcd = o3d.geometry.PointCloud() pcd.points = o3d.utility.Vector3dVector(tracks_3d[frame_idx][tracks_3d_visibilities[frame_idx]]) pcd.colors = o3d.utility.Vector3dVector(colors[tracks_3d_visibilities[frame_idx]]) pcd_file = os.path.join(output_dataset_path, f"tracks_3d_{frame_idx}_visible.ply") o3d.io.write_point_cloud(pcd_file, pcd) # Also save the first frame trimesh as a mesh meshes[0].export(os.path.join(output_dataset_path, "first_frame_mesh.obj")) self._frame = -1 def _load_frame_rgbd(self, c, i): """Loads an RGB-D frame. Args: c: Camera index. i: Frame index. Returns: color: A uint8 numpy array of shape [H, W, 3] containing the color image. depth: A uint16 numpy array of shape [H, W] containing the depth image. """ color_file = self._data_dir[c] + '/' + self._color_prefix + "{:06d}.jpg".format(i) color = cv2.imread(color_file) color = color[:, :, ::-1] depth_file = self._data_dir[c] + '/' + self._depth_prefix + "{:06d}.png".format(i) depth = cv2.imread(depth_file, cv2.IMREAD_ANYDEPTH) return color, depth def _deproject_depth_and_filter_points(self, d, c): """Deprojects a depth image to a point cloud and filters points. Args: d: A uint16 numpy array of shape [F, H, W] or [H, W] containing the depth image in millimeters. c: Camera index. Returns: p: A float32 numpy array of shape [F, H, W, 3] or [H, W, 3] containing the point cloud. m: A bool numpy array of shape [F, H, W] or [H, W] containing the mask for points within the tag coordinate limit.
""" nd = d.ndim d = d.astype(np.float32) / 1000 d = torch.from_numpy(d).to(self._device) p = torch.mul( d.view(1, -1, self._w * self._h).expand(3, -1, -1), self._p[c].unsqueeze(1)) p = torch.addmm(self._t[c].unsqueeze(1), self._R[c], p.view(3, -1)) p_tag = torch.addmm(self._tag_t_inv.unsqueeze(1), self._tag_R_inv, p) mx1 = p_tag[0, :] > self._tag_lim[0] mx2 = p_tag[0, :] < self._tag_lim[1] my1 = p_tag[1, :] > self._tag_lim[2] my2 = p_tag[1, :] < self._tag_lim[3] mz1 = p_tag[2, :] > self._tag_lim[4] mz2 = p_tag[2, :] < self._tag_lim[5] m = mx1 & mx2 & my1 & my2 & mz1 & mz2 p = p.t().view(-1, self._h, self._w, 3) m = m.view(-1, self._h, self._w) if nd == 2: p = p.squeeze(0) m = m.squeeze(0) p = p.cpu().numpy() m = m.cpu().numpy() return p, m def transform_ycb(self, pose, c=None, camera_to_world=True, run_ycb_group_layer=True, return_trans_mat=False): """Transforms poses in SE3 between world and camera frames. Args: pose: A float32 numpy array of shape [N, 7] or [N, 6] containing the poses. Each row contains one pose represented by rotation in quaternion (x, y, z, w) or rotation vector and translation. c: Camera index. camera_to_world: Whether from camera to world or from world to camera. run_ycb_group_layer: Whether to return vertices and normals by running the YCB group layer or to return poses. return_trans_mat: Whether to return poses in transformation matrices. Returns: If run_ycb_group_layer is True: v: A float32 numpy array of shape [F, V, 3] containing the vertices. n: A float32 numpy array of shape [F, V, 3] containing the normals. else: A float32 numpy array of shape [N, 6] containing the transformed poses. """ if pose.shape[1] == 7: q = pose[:, :4] t = pose[:, 4:] R = Rot.from_quat(q).as_matrix().astype(np.float32) R = torch.from_numpy(R).to(self._device) t = torch.from_numpy(t).to(self._device) if pose.shape[1] == 6: r = pose[:, :3] t = pose[:, 3:] r = torch.from_numpy(r).to(self._device) t = torch.from_numpy(t).to(self._device) R = rv2dcm(r) if c is not None: if camera_to_world: R_c = self._R[c] t_c = self._t[c] else: R_c = self._R_inv[c] t_c = self._t_inv[c] R = torch.bmm(R_c.expand(R.size(0), -1, -1), R) t = torch.addmm(t_c, t, R_c.t()) if run_ycb_group_layer or not return_trans_mat: r = dcm2rv(R) p = torch.cat([r, t], dim=1) else: p = torch.cat([R, t.unsqueeze(2)], dim=2) p = torch.cat([ p, torch.tensor([[[0, 0, 0, 1]]] * R.size(0), dtype=torch.float32, device=self._device) ], dim=1) if run_ycb_group_layer: p = p.view(-1, self._ycb_group_layer.num_obj * 6) v, n = self._ycb_group_layer(p) v = v[:, self._ycb_group_layer.f.view(-1)] n = n[:, self._ycb_group_layer.f.view(-1)] v = v.cpu().numpy() n = n.cpu().numpy() return v, n else: p = p.cpu().numpy() return p @property def serials(self): return self._serials @property def num_cameras(self): return self._num_cameras @property def num_frames(self): return self._num_frames @property def dimensions(self): return self._w, self._h @property def ycb_ids(self): return self._ycb_ids @property def K(self): return self._K @property def master_intrinsics(self): return self._master_intrinsics def step(self): """Steps the frame.""" self._frame = (self._frame + 1) % self._num_frames if not self._preload: self._update_pcd() def _update_pcd(self): """Updates the point cloud.""" for c in range(self._num_cameras): rgb, d = self._load_frame_rgbd(c, self._frame) p, m = self._deproject_depth_and_filter_points(d, c) self._pcd_rgb[c][:] = rgb self._pcd_vert[c][:] = p self._pcd_mask[c][:] = m @property def pcd_rgb(self): if self._preload: return 
[x[self._frame] for x in self._pcd_rgb] else: return self._pcd_rgb @property def pcd_vert(self): if self._preload: return [x[self._frame] for x in self._pcd_vert] else: return self._pcd_vert @property def pcd_tex_coord(self): return self._pcd_tex_coord @property def pcd_mask(self): if self._preload: return [x[self._frame] for x in self._pcd_mask] else: return self._pcd_mask @property def ycb_group_layer(self): return self._ycb_group_layer @property def num_ycb(self): return self._ycb_group_layer.num_obj @property def ycb_model_dir(self): return self._ycb_model_dir @property def ycb_count(self): return self._ycb_count @property def ycb_material(self): return self._ycb_material @property def ycb_pose(self): if self._app == 'viewer': return None if self._app == 'renderer': return [x[self._frame] for x in self._ycb_pose] @property def ycb_vert(self): if self._app == 'viewer': return [x[self._frame] for x in self._ycb_vert] if self._app == 'renderer': return None @property def ycb_norm(self): if self._app == 'viewer': return [x[self._frame] for x in self._ycb_norm] if self._app == 'renderer': return None @property def ycb_tex_coords(self): return self._ycb_tex_coords @property def mano_group_layer(self): return self._mano_group_layer @property def num_mano(self): return self._mano_group_layer.num_obj @property def mano_vert(self): if self._app == 'viewer': return [x[self._frame] for x in self._mano_vert] if self._app == 'renderer': return [[y[self._frame] for y in x] for x in self._mano_vert] @property def mano_norm(self): if self._app == 'viewer': return [x[self._frame] for x in self._mano_norm] if self._app == 'renderer': return None @property def mano_line(self): if self._app == 'viewer': return [x[self._frame] for x in self._mano_line] if self._app == 'renderer': return None @property def mano_joint_3d(self): if self._app == 'viewer': return None if self._app == 'renderer': return [x[self._frame] for x in self._mano_joint_3d] # Some hacking with global variables to make the visualization work first_frame_seen = False ready_to_close = False def visualize_3dpt_tracks(tracks_path, output_video_path): global first_frame_seen, ready_to_close print(f"Visualizing 3D point tracks from {tracks_path} to {output_video_path}...") tracks = np.load(tracks_path)["tracks_3d"] + np.array([0, 0, 1]) n_frames, n_points, _ = tracks.shape frames_path = f"{output_video_path}__frames" os.makedirs(frames_path, exist_ok=True) first_frame_seen = False ready_to_close = False # images = [imageio.imread(f"{frames_path}/frame_{i:04d}.png") for i in range(n_frames)] # video_writer = imageio.get_writer(output_video_path, fps=10) # for img in images: # video_writer.append_data(img) # video_writer.close() # ready_to_close = True # return z = tracks[2 * n_frames // 3, :, 2] point_colors = np.zeros((n_points, 3)) point_colors[:, 0] = np.sin(z) point_colors[:, 1] = np.sin(z + 2 * np.pi / 3) point_colors[:, 2] = np.sin(z + 4 * np.pi / 3) point_colors = cm.jet(z / np.percentile(z, 99.9))[:, :3] print("Preparing clouds...") pointclouds = [] for frame_idx in tqdm(range(n_frames)): pc = o3d.geometry.PointCloud() pc.points = o3d.utility.Vector3dVector(tracks[frame_idx]) pc.colors = o3d.utility.Vector3dVector(point_colors) pointclouds += [{ "name": f"cloud t={frame_idx}", "geometry": pc, "time": frame_idx / 4, }] def start_animation(w: o3d.cpu.pybind.visualization.O3DVisualizer) -> None: w.is_animating = True frames_path = f"{output_video_path}__frames" os.makedirs(frames_path, exist_ok=True) first_frame_seen = False ready_to_close = 
False def create_video(w: o3d.cpu.pybind.visualization.O3DVisualizer, t: float) -> None: global first_frame_seen, ready_to_close if ready_to_close: print("Please close the window to finish the video export.") return if t == 0 and not first_frame_seen: first_frame_seen = True elif t == 0 and first_frame_seen: images = [imageio.imread(f"{frames_path}/frame_{i:04d}.png") for i in range(n_frames)] video_writer = imageio.get_writer(output_video_path, fps=10) for img in images: video_writer.append_data(img) video_writer.close() ready_to_close = True return w.export_current_image(f"{frames_path}/frame_{int(t * 4):04d}.png") vis.draw( title=tracks_path, width=1920, height=1080, point_size=4, geometry=pointclouds, animation_time_step=1 / 4, # ibl="crossroads", eye=np.array([0, 0, 0]), lookat=np.array([0, 0, 1]), up=np.array([0, -1, 0]), field_of_view=60.0, on_init=start_animation, on_animation_frame=create_video, on_animation_tick=None, ) DOWNSCALING_FACTOR = 1.0 SEQUENCES = [ # Each sequence has a different target object and a different (human) subject performing an action. "20200709-subject-01/20200709_141754", "20200813-subject-02/20200813_145653", "20200820-subject-03/20200820_135841", "20200903-subject-04/20200903_104428", "20200908-subject-05/20200908_144409", "20200918-subject-06/20200918_114117", "20200928-subject-07/20200928_144906", "20201002-subject-08/20201002_110227", "20201015-subject-09/20201015_144721", "20201022-subject-10/20201022_112651", ] def main(): assert os.environ['DEX_YCB_DIR'] for n_subsample in [3]: for sequence in tqdm(SEQUENCES): print(f"Processing sequence: {sequence}") SequenceLoader( sequence, device="cpu", preload=True, app="convert_to_neus", output_dataset_path=os.path.join(os.environ['DEX_YCB_DIR'], f"neus_nsubsample-{n_subsample}/{sequence.replace('/', '__')}"), downscaling_factor=DOWNSCALING_FACTOR, n_subsample=n_subsample, seed=72, stream_rerun_viz=False, save_rerun_viz=True, ) print("Done converting the dataset.") if __name__ == '__main__': main() ================================================ FILE: scripts/egoexo4d_preprocessing.py ================================================ """ Environment setup: ```bash cd .. # Clone the projectaria_tools repository git clone -b 1.5.0 https://github.com/facebookresearch/projectaria_tools cd projectaria_tools/ # Install required libraries using Conda conda install -c conda-forge cmake fmt xxhash libjpeg-turbo gcc_linux-64 gxx_linux-64 conda install -c conda-forge boost-cpp=1.82.0 boost=1.82.0 # Set compiler environment variables export BOOST_ROOT=$CONDA_PREFIX export BOOST_INCLUDEDIR=$CONDA_PREFIX/include export BOOST_LIBRARYDIR=$CONDA_PREFIX/lib # Clean previous builds and install projectaria_tools rm -rf build/ dist/ *.egg-info cmake -S . -B build \ -DBOOST_ROOT=$BOOST_ROOT \ -DBoost_NO_SYSTEM_PATHS=ON \ -DBoost_INCLUDE_DIR=$BOOST_INCLUDEDIR \ -DBoost_LIBRARY_DIR=$BOOST_LIBRARYDIR \ -DBUILD_PYTHON_BINDINGS=ON cmake --build build -j pip install . # Additional packages (if required) pip install av cd ../mvtracker ``` Download a subset of the data: ```bash # Install CLI for downloading the data pip install ego4d --upgrade # Get an access id and key after filling a form at https://ego4ddataset.com/egoexo-license/ ... # Install AWS CLI from https://aws.amazon.com/cli/ (assuming no sudo) cd .. 
curl "https://awscli.amazonaws.com/awscli-exe-linux-x86_64.zip" -o "awscliv2.zip" unzip awscliv2.zip ./aws/install -i ~/local/aws-cli -b ~/local/bin # Add to ~/.bashrc: export PATH=$HOME/.local/bin:$PATH source ~/.bashrc aws --version # aws-cli/2.27.49 Python/3.13.4 Linux/6.8.0-57-generic exe/x86_64.ubuntu.24 aws configure # Now you can enter the access id and key... # Download a small subset of the data (around 100 GB) egoexo -o ./datasets/egoexo4d --parts metadata egoexo -o ./datasets/egoexo4d --parts take_trajectory take_vrs_noimagestream captures annotations metadata takes downscaled_takes/448 --uids ed3ec638-8363-4e1d-9851-c7936cbfad8c 51fc36b3-e769-4617-b087-3826b280cad3 f179e1a2-3265-464a-a106-a08c30d0a2ae 43dca3b5-21d9-4ebf-856e-515a5c417699 c3915dd7-3ac0-40b7-a69b-73b7326bd15c e08856e4-a1c7-4e36-96b6-a233efb27bfd 425d8f94-ed65-49d5-86e7-174f555fda5d ed698f62-ccdb-4601-8a0a-ee89a0a7e1c0 4e5aa06a-7a60-4e23-9853-d55260a9e6e9 001ae9a5-9c8a-4710-9f7f-7dc67597a02f 2aaaca24-108a-437e-ab9a-bc3e8d65fcdf 503cc92d-7052-44ff-a21d-da6c4a5d6927 f32dc6d9-0eb8-4c85-8ab9-7d47b8c4c660 2423c2ff-c85d-4998-afab-29de8d26d263 0e5d13c6-87ba-4c9b-ab2f-1aaac4e0aacb 1a9a21ab-9023-402f-ac64-df08feaabb5b 811ad284-702f-4d38-af99-a2c4006fa298 a261cc1d-7a45-479f-81a9-7c73eb379e6c c2fb62e3-8894-4101-9923-5eedeb1b4282 egoexo -o ./datasets/egoexo4d --parts captures egoexo -o ./datasets/egoexo4d --parts annotations --benchmarks egopose ``` Running the script: `PYTHONPATH=/local/home/frrajic/xode/duster:$PYTHONPATH python -m scripts.egoexo4d_preprocessing` Note that you need to set up dust3r first, see docstring of `scripts/estimate_depth_with_duster.py`. """ import json import os import pickle import time from typing import Optional import av import cv2 import math import numpy as np import pandas as pd import rerun as rr import torch from projectaria_tools.core import calibration from projectaria_tools.core import data_provider from projectaria_tools.core import mps from projectaria_tools.core.calibration import CameraCalibration, KANNALA_BRANDT_K3 from projectaria_tools.core.stream_id import StreamId from tqdm import tqdm from scripts.estimate_depth_with_duster import run_duster def main_preprocess_egoexo4d( release_dir: str, take_name: str, outputs_dir: str, max_frames: Optional[int] = None, frames_downsampling_factor: Optional[int] = None, downscaled_longerside: Optional[int] = None, save_rerun_viz: bool = True, stream_rerun_viz: bool = False, skip_if_output_exists: bool = True, ): # Skip if output exists save_pkl_path = os.path.join(outputs_dir, f"{take_name}.pkl") if skip_if_output_exists and os.path.exists(save_pkl_path): print(f"Skipping {save_pkl_path} since it already exists") print() return else: print(f"Processing {take_name}...") # Load necessary metadata files egoexo = { "takes": os.path.join(release_dir, "takes.json"), "captures": os.path.join(release_dir, "captures.json") } for k, v in egoexo.items(): egoexo[k] = json.load(open(v)) takes = egoexo["takes"] captures = egoexo["captures"] takes_by_name = {x["take_name"]: x for x in takes} # Take the take take = takes_by_name[take_name] # Initialize exo cameras from calibration file traj_dir = os.path.join(release_dir, take["root_dir"], "trajectory") exo_traj_path = os.path.join(traj_dir, "gopro_calibs.csv") exo_traj_df = pd.read_csv(exo_traj_path) exo_cam_names = list(exo_traj_df["cam_uid"]) ego_cam_names = [x["cam_id"] for x in take["capture"]["cameras"] if x["is_ego"] and x["cam_id"].startswith("aria")] all_cams = ego_cam_names + exo_cam_names 
ego_cam_name = ego_cam_names[0] print("exo cameras: ", exo_cam_names) print(" ego camera: ", ego_cam_name) go_pro_proxy = {} static_calibrations = mps.read_static_camera_calibrations(exo_traj_path) for static_calibration in static_calibrations: # assert the GoPro was correctly localized if static_calibration.quality != 1.0: print(f"Camera: {static_calibration.camera_uid} was not localized, ignoring this camera.") continue proxy = {} proxy["name"] = static_calibration.camera_uid proxy["pose"] = static_calibration.transform_world_cam proxy["camera"] = CameraCalibration( static_calibration.camera_uid, KANNALA_BRANDT_K3, static_calibration.intrinsics, static_calibration.transform_world_cam, # probably extrinsics static_calibration.width, static_calibration.height, None, math.pi, "") go_pro_proxy[static_calibration.camera_uid] = proxy # Configure the VRSDataProvider (interface used to retrieve Trajectory data) ego_exo_project_path = os.path.join(release_dir, 'takes', take['take_name']) aria_dir = os.path.join(release_dir, take["root_dir"]) aria_path = os.path.join(aria_dir, f"{ego_cam_name}.vrs") vrs_data_provider = data_provider.create_vrs_data_provider(aria_path) device_calibration = vrs_data_provider.get_device_calibration() ego_stream_name = "214-1" rgb_stream_id = StreamId(ego_stream_name) rgb_stream_label = vrs_data_provider.get_label_from_stream_id(rgb_stream_id) rgb_camera_calibration = device_calibration.get_camera_calib(rgb_stream_label) mps_data_paths_provider = mps.MpsDataPathsProvider(ego_exo_project_path) mps_data_paths = mps_data_paths_provider.get_data_paths() mps_data_provider = mps.MpsDataProvider(mps_data_paths) # Extract ego extrinsics capture_name = take["capture"]["capture_name"] timesync = pd.read_csv(os.path.join(release_dir, f"captures/{capture_name}/timesync.csv")) start_idx = take["timesync_start_idx"] + 1 end_idx = take["timesync_end_idx"] take_timestamps = [] for idx in range(start_idx, end_idx): ts = timesync.iloc[idx][f"{ego_cam_name}_{ego_stream_name}_capture_timestamp_ns"] take_timestamps.append(ts) if frames_downsampling_factor is not None: take_timestamps = take_timestamps[::frames_downsampling_factor] if max_frames is not None: take_timestamps = take_timestamps[:max_frames] valid_frames = np.array([not np.isnan(ts) for ts in take_timestamps]) if not valid_frames.all(): print(f"Number of invalid frames (with nan ego timesync): {(~valid_frames).sum()}") take_timestamps = np.array(take_timestamps)[valid_frames].astype(int) ego_closed_loop_poses = [mps_data_provider.get_closed_loop_pose(t) for t in take_timestamps] ego_extrs = [] T_device_camera = rgb_camera_calibration.get_transform_device_camera() for pose in ego_closed_loop_poses: assert pose is not None T_world_device = pose.transform_world_device T_world_camera = T_world_device @ T_device_camera extrinsic_matrix = T_world_camera.inverse().to_matrix()[:3, :] # Rotate camera 90° clockwise around Z R_z_90 = np.array([ [0, -1, 0], [1, 0, 0], [0, 0, 1] ]) extrinsic_matrix[:3, :] = R_z_90 @ extrinsic_matrix[:3, :] ego_extrs.append(extrinsic_matrix) # Extract videos base_directory = os.path.join(release_dir, take["root_dir"]) videos = {} for cam_name in all_cams: if cam_name in exo_cam_names: stream_name = '0' else: stream_name = 'rgb' local_path = os.path.join(base_directory, take['frame_aligned_videos'][cam_name][stream_name]['relative_path']) container = av.open(local_path) frames = [] for frame_idx, frame in enumerate(tqdm(container.decode(video=0))): if frame_idx % frames_downsampling_factor != 0: continue if 
max_frames is not None and len(frames) >= max_frames: break frames.append(np.array(frame.to_image())) frames = np.stack(frames)[valid_frames] videos[cam_name] = frames # Undistorted videos rgbs = {} intrs = {} extrs = {} for cam_name in all_cams: frames = videos[cam_name] h, w = frames[0].shape[:2] if cam_name in exo_cam_names: calib = exo_traj_df[exo_traj_df.cam_uid == cam_name].iloc[0].to_dict() D = np.array([calib[f"intrinsics_{i}"] for i in range(4, 8)]) K = np.array([ [calib["intrinsics_0"], 0, calib["intrinsics_2"]], [0, calib["intrinsics_1"], calib["intrinsics_3"]], [0, 0, 1] ]) width, height = calib["image_width"], calib["image_height"] scaled_K = K * w / width scaled_K[2][2] = 1.0 new_K = cv2.fisheye.estimateNewCameraMatrixForUndistortRectify(scaled_K, D, (w, h), np.eye(3), balance=0.0) map1, map2 = cv2.fisheye.initUndistortRectifyMap(scaled_K, D, np.eye(3), new_K, (w, h), cv2.CV_16SC2) undistorted = [] for img in tqdm(frames, desc=f"Undistorting {cam_name}"): ud = cv2.remap(img, map1, map2, interpolation=cv2.INTER_LINEAR) undistorted.append(ud) intrs[cam_name] = new_K extrs[cam_name] = go_pro_proxy[cam_name]["pose"].inverse().to_matrix()[:3, :] rgbs[cam_name] = np.stack([f.transpose(2, 0, 1) for f in undistorted]) else: src_calib = rgb_camera_calibration dst_calib = calibration.get_linear_camera_calibration(w, h, 450) fx, fy = dst_calib.get_focal_lengths() cx, cy = dst_calib.get_principal_point() K = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]]) undistorted = [] for img in tqdm(frames, desc=f"Undistorting {cam_name}"): img = cv2.rotate(img, cv2.ROTATE_90_COUNTERCLOCKWISE) ud = calibration.distort_by_calibration(img, dst_calib, src_calib) ud = cv2.rotate(ud, cv2.ROTATE_90_CLOCKWISE) undistorted.append(ud) undistorted = [ud.transpose(2, 0, 1) for ud in undistorted] intrs[cam_name] = K extrs[cam_name] = np.stack(ego_extrs) rgbs[cam_name] = np.stack(undistorted) # Check shapes n_frames, _, h_exo, w_exo = rgbs[exo_cam_names[0]].shape _, _, h_ego, w_ego = rgbs[ego_cam_name].shape for cam_name in all_cams: if cam_name in exo_cam_names: assert rgbs[cam_name].shape == (n_frames, 3, h_exo, w_exo) assert intrs[cam_name].shape == (3, 3) assert extrs[cam_name].shape == (3, 4) else: assert rgbs[cam_name].shape == (n_frames, 3, h_ego, w_ego) assert intrs[cam_name].shape == (3, 3) assert extrs[cam_name].shape == (n_frames, 3, 4) # Save downsized version if downscaled_longerside is not None: print(f"Downscaling to longer side {downscaled_longerside}") for cam_name in rgbs: _, _, h, w = rgbs[cam_name].shape scale = downscaled_longerside / max(h, w) new_h, new_w = int(h * scale), int(w * scale) resized = [] for img in rgbs[cam_name]: img = img.transpose(1, 2, 0) # CHW -> HWC img = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_AREA) resized.append(img.transpose(2, 0, 1)) # HWC -> CHW rgbs[cam_name] = np.stack(resized) # scale intrinsics intrs[cam_name][:2] *= scale # Save processed output to a pickle file os.makedirs(outputs_dir, exist_ok=True) with open(save_pkl_path, "wb") as f: pickle.dump( dict( rgbs=rgbs, intrs=intrs, extrs=extrs, ego_cam_name=ego_cam_name, ), f, protocol=pickle.HIGHEST_PROTOCOL, ) print(f"Saved {save_pkl_path}") # Visualize the data sample using rerun rerun_modes = [] if stream_rerun_viz: rerun_modes += ["stream"] if save_rerun_viz: rerun_modes += ["save"] for rerun_mode in rerun_modes: rr.init(f"3dpt", recording_id="v0.16") if rerun_mode == "stream": rr.connect_tcp() rr.log("world", rr.ViewCoordinates.RIGHT_HAND_Z_UP, static=True) rr.set_time_seconds("frame", 0) 
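        # Rerun expects camera-to-world transforms for rr.Transform3D, whereas the extrinsics
        # assembled above are world-to-camera 3x4 matrices. The per-frame loop below therefore
        # lifts each extrinsic to a 4x4 matrix and inverts it before logging, roughly:
        #   T = np.eye(4); T[:3, :] = extr; T_world_cam = np.linalg.inv(T)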
rr.log( "world/xyz", rr.Arrows3D( vectors=[[1, 0, 0], [0, 2, 0], [0, 0, 3]], colors=[[255, 0, 0], [0, 255, 0], [0, 0, 255]], ), ) fps = 30 for frame_idx in range(n_frames): rr.set_time_seconds("frame", frame_idx / fps) for cam_name in all_cams: extr = extrs[cam_name] if cam_name in exo_cam_names else extrs[cam_name][frame_idx] intr = intrs[cam_name] img = rgbs[cam_name][frame_idx].transpose(1, 2, 0).astype(np.uint8) # Camera pose logging E = extr if extr.shape == (3, 4) else extr[0] T = np.eye(4) T[:3, :] = E T_world_cam = np.linalg.inv(T) rr.log(f"{cam_name}/image", rr.Transform3D( translation=T_world_cam[:3, 3], mat3x3=T_world_cam[:3, :3], )) # Intrinsics and image rr.log(f"{cam_name}/image", rr.Pinhole( image_from_camera=intr, width=img.shape[1], height=img.shape[0] )) rr.log(f"{cam_name}/image", rr.Image(img)) if rerun_mode == "save": save_rrd_path = os.path.join(outputs_dir, f"rerun__{take_name}.rrd") rr.save(save_rrd_path) print(f"Saved rerun viz to {os.path.abspath(save_rrd_path)}") def main_estimate_duster_depth( pkl_scene_file, depths_output_dir, save_rerun_viz=False, skip_if_output_already_exists=True, ): duster_kwargs = { "model_name_or_path": "../duster/checkpoints/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth", "silent": True, "output_2d_matches": False, "dump_exhaustive_data": False, "save_ply": False, "save_png_viz": False, "show_debug_plots": False, "save_rerun_viz": save_rerun_viz, "skip_if_output_already_exists": skip_if_output_already_exists, } print(f"Generating DUSt3R depths to {os.path.abspath(depths_output_dir)}") assert os.path.exists(pkl_scene_file) with open(pkl_scene_file, "rb") as f: scene = pickle.load(f) rgbs = scene["rgbs"] intrs = scene["intrs"] extrs = scene["extrs"] ego_cam_name = scene["ego_cam_name"] exo_cam_names = sorted([cam_name for cam_name in rgbs.keys() if cam_name != ego_cam_name]) n_frames, _, h, w = rgbs[exo_cam_names[0]].shape fx, fy, cx, cy, extrinsics = [], [], [], [], [] for cam_name in exo_cam_names: intrinsics = intrs[cam_name] extrinsics_view = np.eye(4) extrinsics_view[:3, :4] = extrs[cam_name] assert np.isclose(intrinsics[0, 1], 0) assert np.isclose(intrinsics[1, 0], 0) assert np.isclose(intrinsics[2, 0], 0) assert np.isclose(intrinsics[2, 1], 0) assert np.isclose(intrinsics[2, 2], 1) fx.append(intrinsics[0, 0]) fy.append(intrinsics[1, 1]) cx.append(intrinsics[0, 2]) cy.append(intrinsics[1, 2]) extrinsics.append(extrinsics_view) fx = torch.tensor(fx).float() fy = torch.tensor(fy).float() cx = torch.tensor(cx).float() cy = torch.tensor(cy).float() extrinsics = torch.from_numpy(np.stack(extrinsics)).float() start = time.time() images_tensor = torch.from_numpy(np.stack([rgbs[cam_name] for cam_name in exo_cam_names])) run_duster(images_tensor, depths_output_dir, fx, fy, cx, cy, extrinsics, **duster_kwargs) time_elapsed = time.time() - start print(f"Time elapsed for DUST3R: {time_elapsed:.2f} seconds") if __name__ == '__main__': release_dir = "datasets/egoexo4d/" outputs_dir = "datasets/egoexo4d-processed/" num_devices = 1 device_id = int(os.environ.get("CUDA_VISIBLE_DEVICES", "0")) device_id = device_id % num_devices print(f"Device ID: {device_id} (out of {num_devices}). 
The devices split the work.") for i, take_name in enumerate([ "fair_cooking_06_4", # take_uid = "a261cc1d-7a45-479f-81a9-7c73eb379e6c" "cmu_bike01_2", # take_uid = "ed3ec638-8363-4e1d-9851-c7936cbfad8c" "georgiatech_cooking_01_01_2", # take_uid = "51fc36b3-e769-4617-b087-3826b280cad3" "iiith_cooking_49_2", # take_uid = "f179e1a2-3265-464a-a106-a08c30d0a2ae" "indiana_bike_12_5", # take_uid = "43dca3b5-21d9-4ebf-856e-515a5c417699" "minnesota_rockclimbing_033_20", # take_uid = "c3915dd7-3ac0-40b7-a69b-73b7326bd15c" "sfu_basketball_09_21", # take_uid = "425d8f94-ed65-49d5-86e7-174f555fda5d" "unc_basketball_03-09-23_02_11", # take_uid = "ed698f62-ccdb-4601-8a0a-ee89a0a7e1c0" "unc_music_04-26-23_02_7", # take_uid = "4e5aa06a-7a60-4e23-9853-d55260a9e6e9" "uniandes_dance_017_57", # take_uid = "0e5d13c6-87ba-4c9b-ab2f-1aaac4e0aacb" "upenn_0331_Guitar_2_4", # take_uid = "1a9a21ab-9023-402f-ac64-df08feaabb5b" "unc_basketball_02-24-23_01_12", # take_uid = "c2fb62e3-8894-4101-9923-5eedeb1b4282" ]): if i % num_devices != device_id: continue for max_frames, frames_downsampling_factor, downscaled_longerside in [(300, 1, 512), (300, 1, 518)]: # Extract rgbs, intrs, extrs from EgoExo4D dataset outputs_subdir = os.path.join(outputs_dir, f"maxframes-{max_frames}_" f"downsample-{frames_downsampling_factor}_" f"downscale-{downscaled_longerside}") main_preprocess_egoexo4d(release_dir, take_name, outputs_subdir, max_frames, frames_downsampling_factor, downscaled_longerside) # Run Dust3r to estimate depths from rgbs, fix the known intrs and extrs during multi-view stereo optim take_pkl = os.path.join(outputs_subdir, f"{take_name}.pkl") depth_subdir = os.path.join(outputs_subdir, f"duster_depths__{take_name}") main_estimate_duster_depth( pkl_scene_file=take_pkl, depths_output_dir=depth_subdir, ) # Run VGGT to estimate depths from rgbs, align with the known extrs afterward ... 
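            # Illustrative sketch (not executed; keys as written by the pickle.dump in
            # main_preprocess_egoexo4d above) of how a preprocessed take could be loaded back:
            #   with open(take_pkl, "rb") as f:
            #       scene = pickle.load(f)
            #   rgbs = scene["rgbs"]      # dict: cam_name -> image array of shape (T, 3, H, W)
            #   intrs = scene["intrs"]    # dict: cam_name -> (3, 3) intrinsics
            #   extrs = scene["extrs"]    # dict: cam_name -> (3, 4) or (T, 3, 4) world-to-camera
            #   ego_cam_name = scene["ego_cam_name"]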
================================================ FILE: scripts/estimate_depth_with_duster.py ================================================ """ Set up the environment: ```sh cd /local/home/frrajic/xode git clone --recursive git@github.com:ethz-vlg/duster.git cd duster # Fix models path, since there are two in the project sed -i 's/from models/from croco.models/g' croco/*.py sed -i 's/from models/from croco.models/g' croco/*/*.py sed -i 's/from models/from croco.models/g' dust3r/*.py sed -i 's/from models/from croco.models/g' dust3r/*/*.py # Download the checkpoint wget https://download.europe.naverlabs.com/ComputerVision/DUSt3R/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth -P checkpoints md5sum checkpoints/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth # c3fab9b455b03f23d20e6bf77f2607bb checkpoints/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth # You should be able to use the same environment as for # the rest of the project, just install missing packages: pip install roma==1.5.1 ``` Running the script: ```sh cd /local/home/frrajic/xode/mvtracker export PYTHONPATH=/local/home/frrajic/xode/duster:$PYTHONPATH python scripts/estimate_depth_with_duster.py --dataset dexycb python scripts/estimate_depth_with_duster.py --dataset kubric-val python scripts/estimate_depth_with_duster.py --dataset kubric-train ``` Running the script on Panoptic Sports from Dynamic 3DGS: ```sh # Download the data cd datasets wget https://omnomnom.vision.rwth-aachen.de/data/Dynamic3DGaussians/data.zip unzip data.zip mv data panoptic_d3dgs cd - # Run the script cd /local/home/frrajic/xode/duster/mvtracker export PYTHONPATH=/local/home/frrajic/xode/duster:$PYTHONPATH python scripts/estimate_depth_with_duster.py --dataset panoptic_d3dgs ``` """ import argparse import json import os import random import time import warnings from copy import deepcopy from pathlib import Path import cv2 import matplotlib.pyplot as plt import numpy as np import rerun as rr import torch import torch.nn.functional as F import trimesh from PIL import Image from PIL.ImageOps import exif_transpose from dust3r.cloud_opt import PointCloudOptimizer from dust3r.image_pairs import make_pairs from dust3r.inference import inference from dust3r.model import AsymmetricCroCo3DStereo from dust3r.utils.device import to_numpy from dust3r.utils.geometry import find_reciprocal_matches, xy_grid from dust3r.utils.image import load_images from dust3r.utils.image import rgb, heif_support_enabled, _resize_pil_image, ImgNorm from mvtracker.datasets import KubricMultiViewDataset torch.backends.cuda.matmul.allow_tf32 = True # for gpu >= Ampere and pytorch >= 1.12 def seed_all(seed): """ Seed all random number generators. Parameters ---------- seed : int The seed to use. 
Returns ------- None """ random.seed(seed) os.environ["PYTHONHASHSEED"] = str(seed) np.random.seed(seed) torch.manual_seed(seed) torch.cuda.manual_seed(seed) torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False def get_view_visibility(scene, pts): vis = np.zeros((len(scene.imgs), len(pts)), dtype=bool) poses = scene.get_im_poses().detach().cpu().numpy() extrinsics = np.linalg.inv(poses) focals = scene.get_focals().squeeze(-1).detach().cpu().numpy() pps = scene.get_principal_points().detach().cpu().numpy() depths = [d.detach().cpu().numpy() for d in scene.get_depthmaps(raw=False)] # Apply masks to the depthmaps as to not consider points that have low confidence per_view_masks = [m.detach().cpu().numpy() for m in scene.get_masks()] for view_idx, mask in enumerate(per_view_masks): depths[view_idx] = depths[view_idx] * mask for view_idx in range(len(scene.imgs)): p_world = pts p_world = np.concatenate([p_world, np.ones((len(p_world), 1))], axis=1) p_cam = extrinsics[view_idx] @ p_world.T z = p_cam[2] x = p_cam[0, :] / z[:] * focals[view_idx, 0] + pps[view_idx, 0] y = p_cam[1, :] / z[:] * focals[view_idx, 1] + pps[view_idx, 1] x_floor = np.floor(x).astype(int) y_floor = np.floor(y).astype(int) x_ceil = np.ceil(x).astype(int) y_ceil = np.ceil(y).astype(int) h, w = depths[view_idx].shape[:2] out_of_view = ( (x_floor < 0) | (x_ceil >= w) | (y_floor < 0) | (y_ceil >= h) | (z < 0) ) z_from_depthmap_1 = depths[view_idx][y_floor[~out_of_view], x_floor[~out_of_view]] z_from_depthmap_2 = depths[view_idx][y_floor[~out_of_view], x_ceil[~out_of_view]] z_from_depthmap_3 = depths[view_idx][y_ceil[~out_of_view], x_floor[~out_of_view]] z_from_depthmap_4 = depths[view_idx][y_ceil[~out_of_view], x_ceil[~out_of_view]] z_from_depthmap = np.stack([z_from_depthmap_1, z_from_depthmap_2, z_from_depthmap_3, z_from_depthmap_4], axis=0) vis[view_idx] = ~out_of_view vis[view_idx][~out_of_view] = np.isclose(z[~out_of_view], z_from_depthmap.min(axis=0), rtol=0.001, atol=0.1) # import pandas as pd # x = pd.Series(np.abs(z[~out_of_view] - z_from_depthmap.min(axis=0))) # quantiles_to_print = [0.001, 0.01, 0.05, 0.1, 0.5, 0.9, 0.95, 0.99, 0.999] # print(f"Quantiles of the difference between the depthmap and the z coordinate of the point in the camera frame") # for q in quantiles_to_print: # print(f"{q=}: {x.quantile(q)}") return vis def get_3D_model_from_scene( output_file_prefix, silent, scene, min_conf_thr=3, mask_sky=False, clean_depth=False, feats=None, dump_exhaustive_data=False, save_ply=False, save_png_viz=False, save_rerun_viz=False, rerun_radii=0.01, rerun_viz_timestamp=0, ): scene = deepcopy(scene) if clean_depth: scene = scene.clean_pointcloud() if mask_sky: scene = scene.mask_sky() rgbimg = scene.imgs pts3d = to_numpy(scene.get_pts3d()) scene.min_conf_thr = float(scene.conf_trf(torch.tensor(min_conf_thr))) msk = to_numpy(scene.get_masks()) if not silent: print(f'Exporting 3D scene to prefix={output_file_prefix}') assert len(pts3d) == len(msk) <= len(rgbimg) pts3d = to_numpy(pts3d) pts3d_view_idx = [view_idx * np.ones_like(p[:, :, 0]) for view_idx, p in enumerate(pts3d)] imgs = to_numpy(rgbimg) pts_view_idx = np.concatenate([pvi[m] for pvi, m in zip(pts3d_view_idx, msk)]) pts = np.concatenate([p[m] for p, m in zip(pts3d, msk)]) col = np.concatenate([p[m] for p, m in zip(imgs, msk)]) # get_view_visibility(scene, np.stack(pts3d).reshape(-1, 3)[:10], np.stack(pts3d_view_idx).reshape(-1)[:10]) # debug vis = get_view_visibility(scene, pts) msk = np.stack([m for m in msk]) depths = 
to_numpy(scene.get_depthmaps()) depths = np.stack([d for d in depths]) confs = to_numpy([c for c in scene.im_conf]) confs = np.stack([c for c in confs]) output_dict = { "depths": depths, "confs": confs, "cleaned_mask": msk, "min_conf_thr": min_conf_thr, "mask_sky": mask_sky, "clean_depth": clean_depth, } if dump_exhaustive_data: output_dict.update({ "pts": pts, "pts_view": pts_view_idx, "col": col, "vis": vis, "rgbs": imgs, }) if feats is not None: output_dict["feats"] = feats np.savez(f"{output_file_prefix}__scene.npz", **output_dict) if save_ply: pcd = trimesh.PointCloud(vertices=pts, colors=col) pcd.export(f"{output_file_prefix}__pc.ply") if rerun_viz_timestamp == 0: init_pt_cld = np.concatenate([pts, col, np.ones_like(pts[:, :1])], axis=1) np.savez(f"{output_file_prefix}__init_pt_cld.npz", data=init_pt_cld) if save_png_viz: # Results visualization rgbimg = scene.imgs cmap = plt.get_cmap('jet') depths_max = max([d.max() for d in depths]) depths_viz = [d / depths_max for d in depths] confs_max = max([d.max() for d in confs]) confs_viz = [cmap(d / confs_max) for d in confs] assert len(rgbimg) == len(depths_viz) == len(confs) H, W = rgbimg[0].shape[:2] N = len(rgbimg) plt.figure(dpi=100, figsize=(4 * W / 100, N * H / 100)) for i in range(N): a = rgbimg[i] b = rgb(depths_viz[i]) c = rgb(confs_viz[i]) d = rgb(msk[i]) plt.subplot(N, 4, 1 + 4 * i) plt.imshow(a) plt.axis('off') plt.subplot(N, 4, 2 + 4 * i) plt.imshow(b) plt.axis('off') plt.subplot(N, 4, 3 + 4 * i) plt.imshow(c) plt.axis('off') plt.subplot(N, 4, 4 + 4 * i) plt.imshow(d) plt.axis('off') plt.tight_layout(pad=0) plt.savefig(f"{output_file_prefix}__viz.png") plt.close() if save_rerun_viz: rr.init("reconstruction", recording_id="v0.1") # rr.connect_tcp() rr.log("world", rr.ViewCoordinates.RIGHT_HAND_Z_UP, static=True) rr.set_time_seconds("frame", 0) rr.log( "world/xyz", rr.Arrows3D( vectors=[[1, 0, 0], [0, 2, 0], [0, 0, 3]], colors=[[255, 0, 0], [0, 255, 0], [0, 0, 255]], ), ) rr.set_time_seconds("frame", rerun_viz_timestamp / 30) for v in range(len(rgbimg)): h, w = scene.imshape fx, fy = scene.get_focals().cpu().numpy()[v] cx, cy = scene.get_principal_points().cpu().numpy()[v] K = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]]) c2w = scene.get_im_poses().cpu().numpy()[v] rr.log(f"image/view-{v}/rgb", rr.Image(scene.imgs[v])) rr.log(f"image/view-{v}/depth", rr.DepthImage(depths[v], point_fill_ratio=0.2)) rr.log(f"image/view-{v}", rr.Pinhole(image_from_camera=K, width=w, height=h)) rr.log(f"image/view-{v}", rr.Transform3D(translation=c2w[:3, 3], mat3x3=c2w[:3, :3])) rr.log(f"point_cloud/duster-cleaned/view-{v}", rr.Points3D(pts, colors=col, radii=rerun_radii)) rr.log(f"point_cloud/duster-raw/view-{v}", rr.Points3D(positions=np.stack(pts3d).reshape(-1, 3), colors=np.stack(imgs).reshape(-1, 3), radii=rerun_radii)) rr_rrd_path = f"{output_file_prefix}__rerun_viz.rrd" rr.save(rr_rrd_path) print(f"Saved Rerun recording to: {os.path.abspath(rr_rrd_path)}") def get_2D_matches(output_file_prefix, scene, input_views, min_conf_thr, clean_depth, viz_matches=False): scene = deepcopy(scene) scene.min_conf_thr = float(scene.conf_trf(torch.tensor(min_conf_thr))) if clean_depth: scene = scene.clean_pointcloud() # retrieve useful values from scene: imgs = scene.imgs pts3d = scene.get_pts3d() confidence_masks = scene.get_masks() pts2d_list, pts3d_list = {}, {} for view_i in range(len(input_views)): conf_i = confidence_masks[view_i].cpu().numpy() pts2d_list[view_i] = xy_grid(*imgs[view_i].shape[:2][::-1])[conf_i] # imgs[i].shape[:2] = (H, W) 
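        # pts2d_list[view_i] holds the pixel coordinates of the confident points in view_i, in
        # the same order as their 3D counterparts in pts3d_list[view_i] below. The matching loop
        # that follows finds mutual nearest neighbours in 3D (find_reciprocal_matches) and uses
        # this shared ordering to read off the corresponding 2D-2D matches.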
pts3d_list[view_i] = pts3d[view_i].detach().cpu().numpy()[conf_i] matches = {} for view_i in range(len(input_views) - 1): for view_j in range(view_i + 1, len(input_views)): # find 2D-2D matches between the two images reciprocal_in_P2, nn2_in_P1, num_matches = find_reciprocal_matches(pts3d_list[view_i], pts3d_list[view_j]) assert num_matches == reciprocal_in_P2.sum() print(f'view_{view_i}-view_{view_j}: {num_matches} matches') matches_i_xy = pts2d_list[view_i][nn2_in_P1][reciprocal_in_P2] matches_j_xy = pts2d_list[view_j][reciprocal_in_P2] matches_i_xyz = pts3d_list[view_i][nn2_in_P1][reciprocal_in_P2] matches_j_xyz = pts3d_list[view_j][reciprocal_in_P2] assert len(matches_i_xy) == len(matches_j_xy) == len(matches_i_xyz) == len(matches_j_xyz) == num_matches # store the matches matches[(view_i, view_j)] = { 'matches_i_xy': matches_i_xy, 'matches_j_xy': matches_j_xy, 'matches_i_xyz': matches_i_xyz, 'matches_j_xyz': matches_j_xyz, } # visualize a few matches if viz_matches: n_viz = 18 match_idx_to_viz = np.round(np.linspace(0, num_matches - 1, n_viz)).astype(int) viz_matches_im0, viz_matches_im1 = matches_i_xy[match_idx_to_viz], matches_j_xy[match_idx_to_viz] H0, W0, H1, W1 = *imgs[0].shape[:2], *imgs[1].shape[:2] img0 = np.pad(imgs[0], ((0, max(H1 - H0, 0)), (0, 0), (0, 0)), 'constant', constant_values=0) img1 = np.pad(imgs[1], ((0, max(H0 - H1, 0)), (0, 0), (0, 0)), 'constant', constant_values=0) img = np.concatenate((img0, img1), axis=1) plt.figure(dpi=200) plt.imshow(img) cmap = plt.get_cmap('jet') for i in range(n_viz): (x0, y0), (x1, y1) = viz_matches_im0[i].T, viz_matches_im1[i].T plt.plot([x0, x1 + W0], [y0, y1], '-+', color=cmap(i / (n_viz - 1)), scalex=False, scaley=False) plt.savefig(f"{output_file_prefix}__matches__v{view_i}-v{view_j}.png") plt.tight_layout(pad=0) plt.close() # save the matches np.savez(f"{output_file_prefix}__matches.npz", matches=matches) def load_images(folder_or_list, size, square_ok=False, verbose=True): """ open and convert all images in a list or folder to proper input format for DUSt3R """ if isinstance(folder_or_list, str): if verbose: print(f'>> Loading images from {folder_or_list}') root, folder_content = folder_or_list, sorted(os.listdir(folder_or_list)) elif isinstance(folder_or_list, list): if verbose: print(f'>> Loading a list of {len(folder_or_list)} images') root, folder_content = '', folder_or_list else: raise ValueError(f'bad {folder_or_list=} ({type(folder_or_list)})') supported_images_extensions = ['.jpg', '.jpeg', '.png'] if heif_support_enabled: supported_images_extensions += ['.heic', '.heif'] supported_images_extensions = tuple(supported_images_extensions) imgs = [] for path in folder_content: if not path.lower().endswith(supported_images_extensions): continue img = exif_transpose(Image.open(os.path.join(root, path))).convert('RGB') W1, H1 = img.size if size == 224: # resize short side to 224 (then crop) img = _resize_pil_image(img, round(size * max(W1 / H1, H1 / W1))) else: # resize long side to 512 img = _resize_pil_image(img, size) # W, H = img.size # cx, cy = W // 2, H // 2 # if size == 224: # half = min(cx, cy) # img = img.crop((cx - half, cy - half, cx + half, cy + half)) # else: # halfw, halfh = ((2 * cx) // 16) * 8, ((2 * cy) // 16) * 8 # if not (square_ok) and W == H: # halfh = 3 * halfw / 4 # img = img.crop((cx - halfw, cy - halfh, cx + halfw, cy + halfh)) W2, H2 = img.size if verbose: print(f' - adding {path} with resolution {W1}x{H1} --> {W2}x{H2}') imgs.append(dict(img=ImgNorm(img)[None], true_shape=np.int32( [img.size[::-1]]), 
idx=len(imgs), instance=str(len(imgs)))) assert imgs, 'no images found at ' + root if verbose: print(f' (Found {len(imgs)} images)') return imgs, (W1, H1, W2, H2) def tensor_to_pil(img_tensor): """Convert uint8 torch tensor [3, H, W] to PIL.Image""" return Image.fromarray(img_tensor.permute(1, 2, 0).cpu().numpy()) def load_tensor_images(tensor_list, size, square_ok=False, verbose=True): """Convert torch.Tensor RGB uint8 images to DUSt3R-ready format""" imgs = [] for i, tensor in enumerate(tensor_list): if not (isinstance(tensor, torch.Tensor) and tensor.dtype == torch.uint8 and tensor.ndim == 3 and tensor.shape[0] == 3): raise ValueError(f"Invalid tensor at index {i}") img = tensor_to_pil(tensor) W1, H1 = img.size if size == 224: img = _resize_pil_image(img, round(size * max(W1 / H1, H1 / W1))) else: img = _resize_pil_image(img, size) W2, H2 = img.size if verbose: print(f' - tensor[{i}] resolution {W1}x{H1} --> {W2}x{H2}') imgs.append(dict( img=ImgNorm(img)[None], true_shape=np.int32([img.size[::-1]]), idx=i, instance=str(i) )) if not imgs: raise ValueError('No valid images in input list.') return imgs, (W1, H1, W2, H2) def global_aligner(dust3r_output, device, **optim_kw): view1, view2, pred1, pred2 = [dust3r_output[k] for k in 'view1 view2 pred1 pred2'.split()] net = PointCloudOptimizer(view1, view2, pred1, pred2, **optim_kw).to(device) return net def load_known_camera_parameters_from_neus_dataset(dataset_path, input_views): fx = [] fy = [] cx = [] cy = [] extrinsics = [] for input_view in input_views: cameras_sphere_path = os.path.join(dataset_path, input_view, "cameras_sphere.npz") assert os.path.exists(cameras_sphere_path) cameras_sphere = np.load(cameras_sphere_path) world_mat_0 = cameras_sphere['world_mat_0'] out = cv2.decomposeProjectionMatrix(world_mat_0[:3, :]) K, R, t = out[:3] K = K / K[2, 2] t = t[:3].squeeze() / t[3] fx.append(K[0, 0]) fy.append(K[1, 1]) cx.append(K[0, 2]) cy.append(K[1, 2]) pose = np.eye(4) pose[:3, :3] = R.T pose[:3, 3] = t extrinsics_ = np.linalg.inv(pose) extrinsics.append(extrinsics_) fx = torch.tensor(fx).float() fy = torch.tensor(fy).float() cx = torch.tensor(cx).float() cy = torch.tensor(cy).float() extrinsics = torch.from_numpy(np.stack(extrinsics)).float() return fx, fy, cx, cy, extrinsics def run_duster( images_tensor_or_image_paths, output_path, fx, fy, cx, cy, extrinsics, model_name_or_path="../duster/checkpoints/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth", device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'), image_size=512, skip_if_output_already_exists=True, silent=False, output_2d_matches=False, dump_exhaustive_data=False, save_ply=False, save_png_viz=False, show_debug_plots=False, save_rerun_viz=False, rerun_radii=0.01, frame_selection=None, ga_lr=0.01, ga_schedule='linear', # linear, cosine scenegraph_type="complete", # complete, swin, oneref use_known_poses_for_pairwise_pose_init=False, # True, False ga_niter=300, # from 0 to 5000, default in demo was 300 min_conf_thr=20, # from 1 to 20, step 0.1, default in demo was 3 mask_sky=False, # True, False, default in demo was False clean_depth=True, # True, False, default in demo was True ): # Set the random seed seed_all(72) os.makedirs(output_path, exist_ok=True) output_path = Path(output_path) # Load the model model = AsymmetricCroCo3DStereo.from_pretrained(model_name_or_path).to(device) # Load images into the DUSt3R input format images_all = [] n_views, n_frames = None, None original_w, original_h, target_w, target_h = None, None, None, None if not isinstance(images_tensor_or_image_paths,
torch.Tensor): n_views = len(images_tensor_or_image_paths) n_frames = len(images_tensor_or_image_paths[0]) for frame_idx in range(n_frames): frame_img_paths = [str(images_tensor_or_image_paths[view_idx][frame_idx]) for view_idx in range(n_views)] images, shapes = load_images(frame_img_paths, image_size, verbose=not silent) if original_w is None: original_w, original_h, target_w, target_h = shapes images_all.append(images) else: n_views, n_frames, _, original_h, original_w = images_tensor_or_image_paths.shape for frame_idx in range(n_frames): frame_imgs = [images_tensor_or_image_paths[view_idx, frame_idx] for view_idx in range(n_views)] images, shapes = load_tensor_images(frame_imgs, image_size, verbose=not silent) if target_w is None: assert (original_w, original_h) == shapes[:2] _, _, target_w, target_h = shapes images_all.append(images) # Check the input data assert len(fx) == len(fy) == len(cx) == len(cy) == len(extrinsics) == n_views assert all(extrinsics[view_idx].shape == (4, 4) for view_idx in range(n_views)) # Assume known camera parameters known_poses = extrinsics.inverse() known_focals = torch.stack([fx, fy], dim=-1) known_pp = torch.stack([cx, cy], dim=-1) patch_h, patch_w = model.patch_embed.patch_size # e.g., (16, 16) pad_h = (patch_h - (target_h % patch_h)) % patch_h pad_w = (patch_w - (target_w % patch_w)) % patch_w assert pad_h % 2 == 0, f"pad_h {pad_h} is not divisible by 2" assert pad_w % 2 == 0, f"pad_w {pad_w} is not divisible by 2" pad_top = pad_h // 2 pad_bottom = pad_h - pad_top pad_left = pad_w // 2 pad_right = pad_w - pad_left if pad_h or pad_w: for frame_images in images_all: # images_all[frame_idx] == list of dicts per view for im_dict in frame_images: # shape: [1, 3, H, W] assert im_dict["img"].shape[-2:] == (target_h, target_w) # F.pad takes (left, right, top, bottom) im_dict["img"] = F.pad(im_dict["img"], (pad_left, pad_right, pad_top, pad_bottom), mode="replicate") im_dict["true_shape"] = np.int32([[target_h + pad_h, target_w + pad_w]]) # shift principal point to the padded image coordinate system # (we padded symmetrically, so add half the padding on each axis) known_pp = known_pp.clone() known_pp[..., 0] = known_pp[..., 0] + pad_left # cx known_pp[..., 1] = known_pp[..., 1] + pad_top # cy if frame_selection is None: frame_selection = range(n_frames) for frame_idx in frame_selection: print(f"Processing frame {frame_idx:05d}/{n_frames:05d}...") if skip_if_output_already_exists and os.path.exists(output_path / f"3d_model__{frame_idx:05d}__scene.npz"): try: np.load(output_path / f"3d_model__{frame_idx:05d}__scene.npz") print(f"Skipping frame because the output file already exists.") continue except Exception as e: print(f"Output file already exists but is corrupted: {e}") # Load preprocessed input images images = images_all[frame_idx] assert (target_h + pad_h, target_w + pad_w) == images[0]['img'].shape[-2:] assert len(images) == n_views print(f"Loaded {len(images)} images. " f"Original resolution: {original_w}x{original_h}. 
" f"Target resolution: {target_w}x{target_h}.") # Extract encoder features for each image feats = [] for view_idx in range(n_views): with torch.no_grad(): feat, pos_enc, _ = model._encode_image(images[view_idx]["img"].to(device), images[view_idx]["true_shape"]) feats.append(feat) feats = torch.concat(feats).detach().cpu().numpy() # Run DUSt3R on the pairs pairs = make_pairs(images, scene_graph=scenegraph_type, prefilter=None, symmetrize=True) output = inference(pairs, model, device, batch_size=1, verbose=not silent) # Unpad the output if padding was applied if pad_h or pad_w: H_pad = target_h + pad_h W_pad = target_w + pad_w t, l = pad_top, pad_left b, r = t + target_h, l + target_w assert output["view1"]["img"].shape == (len(pairs), 3, H_pad, W_pad) assert output["view2"]["img"].shape == (len(pairs), 3, H_pad, W_pad) assert output["pred1"]["conf"].shape == (len(pairs), H_pad, W_pad) assert output["pred2"]["conf"].shape == (len(pairs), H_pad, W_pad) assert output["pred1"]["pts3d"].shape == (len(pairs), H_pad, W_pad, 3) assert output["pred2"]["pts3d_in_other_view"].shape == (len(pairs), H_pad, W_pad, 3) output["view1"]["img"] = output["view1"]["img"][:, :, t:b, l:r].contiguous() output["view2"]["img"] = output["view2"]["img"][:, :, t:b, l:r].contiguous() output["pred1"]["conf"] = output["pred1"]["conf"][:, t:b, l:r].contiguous() output["pred2"]["conf"] = output["pred2"]["conf"][:, t:b, l:r].contiguous() output["pred1"]["pts3d"] = output["pred1"]["pts3d"][:, t:b, l:r, :].contiguous() output["pred2"]["pts3d_in_other_view"] = output["pred2"]["pts3d_in_other_view"][:, t:b, l:r, :].contiguous() output["view1"]["true_shape"] = np.int32([[target_h, target_w]]) output["view2"]["true_shape"] = np.int32([[target_h, target_w]]) # Set the known camera parameters scene = global_aligner(output, device=device, verbose=not silent) if not np.isclose(target_w / original_w, target_h / original_h): warnings.warn(f"The aspect ratio of the input images is different from the target aspect ratio:\n" f" - rescaling factor x: {target_w}/{original_w} = {target_w / original_w}\n" f" - rescaling factor y: {target_h}/{original_h} = {target_h / original_h}") if target_w == 512: rescaling_factor = target_w / original_w elif target_h == 512: rescaling_factor = target_h / original_h else: raise ValueError(f"Unexpected target resolution: {target_w}x{target_h}") print(f"We will use the rescaling factor: {target_w}/{original_w} = {rescaling_factor}") scene.preset_focal(known_focals.clone() * rescaling_factor) scene.im_pp.requires_grad_(True) scene.preset_principal_point(known_pp.clone() * rescaling_factor) scene.preset_pose(known_poses.clone()) # scene.im_pp.requires_grad_(True) # Run global alignment to get the global pointcloud and estimated camera parameters init = 'mst' if not use_known_poses_for_pairwise_pose_init else 'known_poses' try: loss = scene.compute_global_alignment(init=init, niter=ga_niter, schedule=ga_schedule, lr=ga_lr) except Exception as e: other_init = {"mst": "known_poses", "known_poses": "mst"} print(f"Error during global alignment: {e}") print(f"Trying the other initialization method init={other_init[init]} instead of init={init}") loss = scene.compute_global_alignment(init=other_init[init], niter=ga_niter, schedule=ga_schedule, lr=ga_lr) print(f"Global alignment loss: {loss}") print(f"Poses after global alignment:") print(f"{scene.get_im_poses().cpu().tolist()},") print(f"Intrinsic after global alignment:") print(f"{scene.get_focals().cpu().tolist()}") 
print(f"{scene.get_principal_points().cpu().tolist()}") print() # Save the scene data, pointclouds, and camera parameters if feats is not None and (pad_h or pad_w): warnings.warn(f"The saved 'feats' won't take into account the padding (pad_h={pad_h}, pad_w={pad_w}).") get_3D_model_from_scene( output_file_prefix=output_path / f"3d_model__{frame_idx:05d}", silent=silent, scene=scene, min_conf_thr=min_conf_thr, mask_sky=mask_sky, clean_depth=clean_depth, feats=feats, dump_exhaustive_data=dump_exhaustive_data, save_ply=save_ply, save_png_viz=save_png_viz, save_rerun_viz=save_rerun_viz, rerun_radii=rerun_radii, rerun_viz_timestamp=frame_idx, ) # get_3D_model_from_scene(output_path / f"low_threshold_3d_model__{frame_idx:05d}", silent, scene, 1, mask_sky, clean_depth) # get_3D_model_from_scene(output_path / f"non_clean_3d_model__{frame_idx:05d}", silent, scene, 0, mask_sky, False) if output_2d_matches: output_file_prefix = os.path.join(output_path, f"frame_{frame_idx}") get_2D_matches(output_file_prefix, scene, image_paths, min_conf_thr, clean_depth, viz_matches=True) if show_debug_plots: from sklearn.decomposition import PCA reducer = PCA(n_components=3) fvec_flat_all = feats.reshape(-1, 1024) reducer.fit(fvec_flat_all) fvec_reduced = reducer.transform(fvec_flat_all) reducer_min = fvec_reduced.min(axis=0) reducer_max = fvec_reduced.max(axis=0) def fvec_to_rgb(fvec): fvec_reduced = reducer.transform(fvec) fvec_reduced_rescaled = (fvec_reduced - reducer_min) / (reducer_max - reducer_min) fvec_reduced_rgb = (fvec_reduced_rescaled * 255).astype(int) return fvec_reduced_rgb rgb_with_feat_list = [] for view_idx in range(n_views): fvec_flat = feats[view_idx, :, :].reshape(((target_h + pad_h) // 16) * ((target_w + 16) // 16), 1024) fvec_reduced_rgb = fvec_to_rgb(fvec_flat).reshape((target_h + pad_h) // 16, (target_w + pad_w) // 16, 3) rgb_img = ((images[view_idx]["img"][0].permute(1, 2, 0).numpy() / 2 + 0.5) * 255).astype(int) fvec_img = np.kron(fvec_reduced_rgb, np.ones((16, 16, 1))).astype(int) rgb_with_feat = np.concatenate([rgb_img, fvec_img], axis=1) rgb_with_feat_list.append(rgb_with_feat) rgb_with_feat = np.concatenate(rgb_with_feat_list, axis=0) import matplotlib.pyplot as plt; plt.figure(figsize=(rgb_with_feat.shape[1] / 100, rgb_with_feat.shape[0] / 100), dpi=100) plt.imshow(rgb_with_feat) plt.axis('off') plt.tight_layout(pad=0) plt.savefig(os.path.join(output_path, f"debug__{frame_idx:05d}__rgb_with_encoder_features.png")) # plt.show() plt.close() def main_on_neus_scene(scene_root, views_selection, **duster_kwargs): views_selection_str = ''.join(str(v) for v in views_selection) output_path = scene_root / f'duster-views-{views_selection_str}' view_paths = [scene_root / f"view_{v:02d}" for v in views_selection] frame_paths = [sorted((view_path / "rgb").glob("*.png")) for view_path in view_paths] n_frames = len(frame_paths[0]) assert n_frames > 0 assert all(len(f) == n_frames for f in frame_paths) fx, fy, cx, cy, extrinsics = [], [], [], [], [] for view_path in view_paths: camera_params_file = os.path.join(view_path, "intrinsics_extrinsics.npz") params = np.load(camera_params_file) intrinsics = params["intrinsics"] extrinsics_view = params["extrinsics"] assert intrinsics[0, 1] == 0 assert intrinsics[1, 0] == 0 assert intrinsics[2, 0] == 0 assert intrinsics[2, 1] == 0 assert intrinsics[2, 2] == 1 fx.append(intrinsics[0, 0]) fy.append(intrinsics[1, 1]) cx.append(intrinsics[0, 2]) cy.append(intrinsics[1, 2]) extrinsics.append(extrinsics_view) fx = torch.tensor(fx).float() fy = 
torch.tensor(fy).float() cx = torch.tensor(cx).float() cy = torch.tensor(cy).float() extrinsics = torch.from_numpy(np.stack(extrinsics)).float() print(f"Processing {output_path}") run_duster(frame_paths, output_path, fx, fy, cx, cy, extrinsics, **duster_kwargs) def main_on_kubric_scene(scene_root, views_selection, **duster_kwargs): views_selection_str = ''.join(str(v) for v in views_selection) output_path = scene_root / f'duster-views-{views_selection_str}' view_paths = [scene_root / f"view_{v:01d}" for v in views_selection] frame_paths = [sorted(view_path.glob("rgba_*.png")) for view_path in view_paths] n_frames = len(frame_paths[0]) assert n_frames > 0 assert all(len(f) == n_frames for f in frame_paths) datapoint = KubricMultiViewDataset.getitem_raw_datapoint(scene_root) fx, fy, cx, cy, extrinsics = [], [], [], [], [] for view_idx in views_selection: intrinsics = datapoint["views"][view_idx]["intrinsics"] extrinsics_view = np.eye(4) extrinsics_view[:3, :4] = datapoint["views"][view_idx]["extrinsics"][0] assert intrinsics[0, 1] == 0 assert intrinsics[1, 0] == 0 assert intrinsics[2, 0] == 0 assert intrinsics[2, 1] == 0 assert intrinsics[2, 2] == 1 fx.append(intrinsics[0, 0]) fy.append(intrinsics[1, 1]) cx.append(intrinsics[0, 2]) cy.append(intrinsics[1, 2]) extrinsics.append(extrinsics_view) fx = torch.tensor(fx).float() fy = torch.tensor(fy).float() cx = torch.tensor(cx).float() cy = torch.tensor(cy).float() extrinsics = torch.from_numpy(np.stack(extrinsics)).float() start = time.time() print(f"Processing {output_path}") run_duster(frame_paths, output_path, fx, fy, cx, cy, extrinsics, **duster_kwargs) time_elapsed = time.time() - start print(f"Time elapsed for DUST3R: {time_elapsed:.2f} seconds") def main_on_d3dgs_panoptic_scene( scene_root, views_selection, save_rerun_viz=False, rerun_radii=0.002, **duster_kwargs, ): md = json.load(open(os.path.join(scene_root, "train_meta.json"), 'r')) n_frames = len(md['fn']) # Check that the selected views are in the training set view_paths = [] for view_idx in views_selection: view_path = scene_root / "ims" / f"{view_idx}" assert view_idx in md["cam_id"][0], f"Camera {view_idx} is not in the training set" assert view_path.exists() view_paths.append(view_path) frame_paths = [sorted(view_path.glob("*.jpg")) for view_path in view_paths] assert all(len(frame_paths[v]) == n_frames for v in range(len(views_selection))) # Create the output directory views_selection_str = '-'.join(str(v) for v in views_selection) output_path = scene_root / f'duster-views-{views_selection_str}' os.makedirs(output_path, exist_ok=True) # Load the camera parameters fx, fy, cx, cy, extrinsics = [], [], [], [], [] for view_idx in views_selection: fx_current, fy_current, cx_current, cy_current, extrinsics_current = [], [], [], [], [] for t in range(n_frames): view_idx_in_array = md['cam_id'][t].index(view_idx) k = md['k'][t][view_idx_in_array] w2c = np.array(md['w2c'][t][view_idx_in_array]) fx_current.append(k[0][0]) fy_current.append(k[1][1]) cx_current.append(k[0][2]) cy_current.append(k[1][2]) extrinsics_current.append(w2c) assert all(np.equal(fx_current[0], fx_current[t]).all() for t in range(1, n_frames)) assert all(np.equal(fy_current[0], fy_current[t]).all() for t in range(1, n_frames)) assert all(np.equal(cx_current[0], cx_current[t]).all() for t in range(1, n_frames)) assert all(np.equal(cy_current[0], cy_current[t]).all() for t in range(1, n_frames)) assert all(np.equal(extrinsics_current[0], extrinsics_current[t]).all() for t in range(1, n_frames)) 
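        # The Panoptic/D3DGS cameras are static: the asserts above verify that the per-frame
        # intrinsics and w2c extrinsics are identical across time, so only the frame-0 values
        # are kept for each view below.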
fx.append(fx_current[0]) fy.append(fy_current[0]) cx.append(cx_current[0]) cy.append(cy_current[0]) extrinsics.append(extrinsics_current[0]) fx = torch.tensor(fx).float() fy = torch.tensor(fy).float() cx = torch.tensor(cx).float() cy = torch.tensor(cy).float() extrinsics = torch.from_numpy(np.stack(extrinsics)).float() # Visualize the initialization point cloud used in D3DGS if save_rerun_viz: init_pt_cld = np.load(scene_root / "init_pt_cld.npz")["data"] xyz = init_pt_cld[:, :3] col = init_pt_cld[:, 3:6] seg = init_pt_cld[:, 6:7] rr.init("reconstruction", recording_id="v0.1") # rr.connect_tcp() rr.set_time_seconds("frame", 0 / 30) rr.log(f"point_cloud/sfm-full", rr.Points3D(xyz, colors=col, radii=rerun_radii)) rr.log(f"point_cloud/sfm-full-seg", rr.Points3D(xyz, colors=col * seg, radii=rerun_radii)) rr.log("world", rr.ViewCoordinates.RIGHT_HAND_Z_UP, static=True) rr.set_time_seconds("frame", 0) rr.log( "world/xyz", rr.Arrows3D( vectors=[[1, 0, 0], [0, 2, 0], [0, 0, 3]], colors=[[255, 0, 0], [0, 255, 0], [0, 0, 255]], ), ) # moge_depths = [] # moge_masks = [] for selected_view_idx, view_idx in enumerate(views_selection): rgbs = np.stack([np.array(Image.open(frame_paths[selected_view_idx][t])) for t in range(n_frames)]) rgbs = torch.from_numpy(rgbs).permute(0, 3, 1, 2).float() H, W = rgbs.shape[-2], rgbs.shape[-1] K = np.array([ [fx[selected_view_idx], 0, cx[selected_view_idx]], [0, fy[selected_view_idx], cy[selected_view_idx]], [0, 0, 1], ]) K_inv = np.linalg.inv(K) K_for_moge = np.array([ [fx[selected_view_idx] / W, 0, 0.5], [0, fy[selected_view_idx] / H, 0.5], [0, 0, 1], ]) # depths, i, _, _, mask = moge(rgbs[::10], intrinsics=K_for_moge) # moge_depths.append(depths) # moge_masks.append(mask) for t in range(0, n_frames, 10): rr.set_time_seconds("frame", t / 30) c2w = torch.linalg.inv(extrinsics[selected_view_idx]).numpy() rr.log(f"image/view-{view_idx}/rgb", rr.Image(rgbs[t].permute(1, 2, 0).numpy())) # rr.log(f"image/view-{view_idx}/depth", # rr.DepthImage(moge_depths[selected_view_idx][t // 10], point_fill_ratio=0.2)) rr.log(f"image/view-{view_idx}", rr.Pinhole(image_from_camera=K, width=W, height=H)) rr.log(f"image/view-{view_idx}", rr.Transform3D(translation=c2w[:3, 3], mat3x3=c2w[:3, :3])) # # Generate and log point cloud colored by RGB values # y, x = np.indices((H, W)) # homo_pixel_coords = np.stack([x.ravel(), y.ravel(), np.ones_like(x).ravel()], axis=1).T # depth_values = moge_depths[selected_view_idx][t // 10].ravel() # cam_coords = (K_inv @ homo_pixel_coords) * depth_values # cam_coords = np.vstack((cam_coords, np.ones((1, cam_coords.shape[1])))) # world_coords = (c2w @ cam_coords)[:3].T # valid_mask = (depth_values > 0) & moge_masks[selected_view_idx][t // 10].reshape(-1, ) # world_coords = world_coords[valid_mask] # rgb_colors = rgbs[t].permute(1, 2, 0).reshape(-1, 3).numpy()[valid_mask].astype(np.uint8) # rr.log(f"point_cloud/view-{view_idx}", rr.Points3D(world_coords, colors=rgb_colors, radii=rerun_radii)) rr.save(output_path / "init_pt_cld.rrd") # Run DUSt3R print(f"Processing {output_path}") run_duster(frame_paths, output_path, fx, fy, cx, cy, extrinsics, save_rerun_viz=save_rerun_viz, rerun_radii=rerun_radii, **duster_kwargs) if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('--dataset', type=str, required=True, help='The dataset to process') args = parser.parse_args() duster_kwargs = { "model_name_or_path": "../duster/checkpoints/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth", "silent": False, "output_2d_matches": False, "dump_exhaustive_data": 
True, "save_ply": True, "save_png_viz": True, "show_debug_plots": True, } if args.dataset == "dexycb": data_root = Path('./datasets/dex-january-2025/neus_nsubsample-3/') views_selections = [ [0, 1, 2, 3], [2, 3, 4, 5], [4, 5, 6, 7], [0, 1, 2, 3, 4, 5, 6, 7], ] for scene_root in sorted(data_root.glob("*")): for views_selection in views_selections: main_on_neus_scene(scene_root, views_selection, **duster_kwargs) elif args.dataset == "kubric-val": data_root = Path('./datasets/kubric_multiview_003/test/') duster_kwargs["save_rerun_viz"] = True views_selections = [ # [0, 1], [0, 1, 2, 3], [0, 1, 2, 3, 4, 5, 6, 7], ] for scene_root in sorted(data_root.glob("[!.]*")): for views_selection in views_selections: main_on_kubric_scene(scene_root, views_selection, **duster_kwargs) elif args.dataset == "kubric-train": # Save space by not saving all logs duster_kwargs["dump_exhaustive_data"] = False duster_kwargs["save_ply"] = False duster_kwargs["save_png_viz"] = False duster_kwargs["show_debug_plots"] = False data_root = Path('./datasets/kubric_multiview_003/train/') views_selections = [ [0, 1, 2, 3], [0, 1, 2, 3, 4, 5, 6, 7], ] # # Parallelize across a machine with 4 GPUs # total_gpus = 4 # gpu_id = int(os.environ.get("CUDA_VISIBLE_DEVICES")) # # Run, e.g., as: # # -------------- # # CUDA_VISIBLE_DEVICES=0 python scripts/estimate_depth_with_duster.py --dataset kubric-train # # CUDA_VISIBLE_DEVICES=1 python scripts/estimate_depth_with_duster.py --dataset kubric-train # # CUDA_VISIBLE_DEVICES=2 python scripts/estimate_depth_with_duster.py --dataset kubric-train # # CUDA_VISIBLE_DEVICES=3 python scripts/estimate_depth_with_duster.py --dataset kubric-train # Parallelize across 128 machines with 4 GPUs each total_gpus = 128 * 4 a = int(os.environ.get("CHUNK")) b = int(os.environ.get("CUDA_VISIBLE_DEVICES")) gpu_id = a * 4 + b # Run, e.g., as: # -------------- # CHUNK=0 CUDA_VISIBLE_DEVICES=0 python scripts/estimate_depth_with_duster.py --dataset kubric-train # CHUNK=0 CUDA_VISIBLE_DEVICES=1 python scripts/estimate_depth_with_duster.py --dataset kubric-train # CHUNK=0 CUDA_VISIBLE_DEVICES=2 python scripts/estimate_depth_with_duster.py --dataset kubric-train # CHUNK=0 CUDA_VISIBLE_DEVICES=3 python scripts/estimate_depth_with_duster.py --dataset kubric-train # CHUNK=1 CUDA_VISIBLE_DEVICES=1 python scripts/estimate_depth_with_duster.py --dataset kubric-train # ... 
# CHUNK=15 CUDA_VISIBLE_DEVICES=3 python scripts/estimate_depth_with_duster.py --dataset kubric-train print(f"Running on GPU {gpu_id} (out of {total_gpus})") print(f'Total scenes to process: {len(sorted(data_root.glob("[!.]*"))[gpu_id::total_gpus])}') for scene_root in sorted(data_root.glob("[!.]*"))[gpu_id::total_gpus]: for views_selection in views_selections: main_on_kubric_scene(scene_root, views_selection, **duster_kwargs) elif args.dataset == "panoptic_d3dgs": duster_kwargs["skip_if_output_already_exists"] = True duster_kwargs["save_rerun_viz"] = False duster_kwargs["frame_selection"] = None # [0] data_root = Path('./datasets/panoptic_d3dgs/') views_selections = [ # [27, 16, 14, 8, 11, 19, 11, 6, 23, 1], # 10 views [27, 16, 14, 8, 11, 19, 11, 6], # 8 views [27, 16, 14, 8], # 4 views [27, 16], # 2 views # [1, 4, 7, 11, 14, 17, 20, 23, 26, 29], # 10 views # # [5, 8, 11, 14, 17, 20, 23, 26, 29], # 9 views [1, 4, 7, 11, 14, 17, 20, 23], # 8 views # [1, 4, 7, 11, ], # 4 views - v1 [1, 7, 14, 20, ], # 4 views - v2 # # [1, 4], # 2 views - v1 # [1, 14], # 2 views - v2 ] for scene_root in sorted(data_root.glob("[!.]*")): for views_selection in views_selections: main_on_d3dgs_panoptic_scene(scene_root, views_selection, **duster_kwargs) else: raise ValueError(f"Unknown dataset: {args.dataset}") print(f"Done.") ================================================ FILE: scripts/hi4d_preprocessing.py ================================================ """ First, download the dataset. You'll have to fill in an online ETH form and then wait a few days to receive a temporary access code over email. I used the following sequence of commands to download and unpack the data into the expected structure. Replace the `dt=...` value with your own access token, which you can find in the access URL (or in the page source of the download page it links to). Note that you don't need to download all of the data, e.g., you may only want a small sample. Note also that the commands below don't delete the `*.tar.gz` files; you can do so afterwards if you'd like.
```bash wget 'https://hi4d.ait.ethz.ch/download.php?dt=def502001190eca4e725f10acbfbd3520f0caca29004163d940aa67e31c024acac2f55ce060924e95b528e99e47e167d6d3e8dd34449e7c89fc60b1139e6ee28f45ed216e5f452230156127a2a1919ef0b796c8cc016630353296abd0c4294db83582d7a99a132d033e95928e4a1&file=/LICENSE.txt' -O LICENSE.txt wget 'https://hi4d.ait.ethz.ch/download.php?dt=def502001190eca4e725f10acbfbd3520f0caca29004163d940aa67e31c024acac2f55ce060924e95b528e99e47e167d6d3e8dd34449e7c89fc60b1139e6ee28f45ed216e5f452230156127a2a1919ef0b796c8cc016630353296abd0c4294db83582d7a99a132d033e95928e4a1&file=/README.md' -O README.md wget 'https://hi4d.ait.ethz.ch/download.php?dt=def502001190eca4e725f10acbfbd3520f0caca29004163d940aa67e31c024acac2f55ce060924e95b528e99e47e167d6d3e8dd34449e7c89fc60b1139e6ee28f45ed216e5f452230156127a2a1919ef0b796c8cc016630353296abd0c4294db83582d7a99a132d033e95928e4a1&file=/pair00_1.tar.gz' -O pair00_1.tar.gz wget 'https://hi4d.ait.ethz.ch/download.php?dt=def502001190eca4e725f10acbfbd3520f0caca29004163d940aa67e31c024acac2f55ce060924e95b528e99e47e167d6d3e8dd34449e7c89fc60b1139e6ee28f45ed216e5f452230156127a2a1919ef0b796c8cc016630353296abd0c4294db83582d7a99a132d033e95928e4a1&file=/pair00_2.tar.gz' -O pair00_2.tar.gz wget 'https://hi4d.ait.ethz.ch/download.php?dt=def502001190eca4e725f10acbfbd3520f0caca29004163d940aa67e31c024acac2f55ce060924e95b528e99e47e167d6d3e8dd34449e7c89fc60b1139e6ee28f45ed216e5f452230156127a2a1919ef0b796c8cc016630353296abd0c4294db83582d7a99a132d033e95928e4a1&file=/pair01.tar.gz' -O pair01.tar.gz wget 'https://hi4d.ait.ethz.ch/download.php?dt=def502001190eca4e725f10acbfbd3520f0caca29004163d940aa67e31c024acac2f55ce060924e95b528e99e47e167d6d3e8dd34449e7c89fc60b1139e6ee28f45ed216e5f452230156127a2a1919ef0b796c8cc016630353296abd0c4294db83582d7a99a132d033e95928e4a1&file=/pair02_1.tar.gz' -O pair02_1.tar.gz wget 'https://hi4d.ait.ethz.ch/download.php?dt=def502001190eca4e725f10acbfbd3520f0caca29004163d940aa67e31c024acac2f55ce060924e95b528e99e47e167d6d3e8dd34449e7c89fc60b1139e6ee28f45ed216e5f452230156127a2a1919ef0b796c8cc016630353296abd0c4294db83582d7a99a132d033e95928e4a1&file=/pair02_2.tar.gz' -O pair02_2.tar.gz wget 'https://hi4d.ait.ethz.ch/download.php?dt=def502001190eca4e725f10acbfbd3520f0caca29004163d940aa67e31c024acac2f55ce060924e95b528e99e47e167d6d3e8dd34449e7c89fc60b1139e6ee28f45ed216e5f452230156127a2a1919ef0b796c8cc016630353296abd0c4294db83582d7a99a132d033e95928e4a1&file=/pair09.tar.gz' -O pair09.tar.gz wget 'https://hi4d.ait.ethz.ch/download.php?dt=def502001190eca4e725f10acbfbd3520f0caca29004163d940aa67e31c024acac2f55ce060924e95b528e99e47e167d6d3e8dd34449e7c89fc60b1139e6ee28f45ed216e5f452230156127a2a1919ef0b796c8cc016630353296abd0c4294db83582d7a99a132d033e95928e4a1&file=/pair10.tar.gz' -O pair10.tar.gz wget 'https://hi4d.ait.ethz.ch/download.php?dt=def502001190eca4e725f10acbfbd3520f0caca29004163d940aa67e31c024acac2f55ce060924e95b528e99e47e167d6d3e8dd34449e7c89fc60b1139e6ee28f45ed216e5f452230156127a2a1919ef0b796c8cc016630353296abd0c4294db83582d7a99a132d033e95928e4a1&file=/pair12.tar.gz' -O pair12.tar.gz wget 'https://hi4d.ait.ethz.ch/download.php?dt=def502001190eca4e725f10acbfbd3520f0caca29004163d940aa67e31c024acac2f55ce060924e95b528e99e47e167d6d3e8dd34449e7c89fc60b1139e6ee28f45ed216e5f452230156127a2a1919ef0b796c8cc016630353296abd0c4294db83582d7a99a132d033e95928e4a1&file=/pair13_1.tar.gz' -O pair13_1.tar.gz wget 
'https://hi4d.ait.ethz.ch/download.php?dt=def502001190eca4e725f10acbfbd3520f0caca29004163d940aa67e31c024acac2f55ce060924e95b528e99e47e167d6d3e8dd34449e7c89fc60b1139e6ee28f45ed216e5f452230156127a2a1919ef0b796c8cc016630353296abd0c4294db83582d7a99a132d033e95928e4a1&file=/pair13_2.tar.gz' -O pair13_2.tar.gz wget 'https://hi4d.ait.ethz.ch/download.php?dt=def502001190eca4e725f10acbfbd3520f0caca29004163d940aa67e31c024acac2f55ce060924e95b528e99e47e167d6d3e8dd34449e7c89fc60b1139e6ee28f45ed216e5f452230156127a2a1919ef0b796c8cc016630353296abd0c4294db83582d7a99a132d033e95928e4a1&file=/pair14.tar.gz' -O pair14.tar.gz wget 'https://hi4d.ait.ethz.ch/download.php?dt=def502001190eca4e725f10acbfbd3520f0caca29004163d940aa67e31c024acac2f55ce060924e95b528e99e47e167d6d3e8dd34449e7c89fc60b1139e6ee28f45ed216e5f452230156127a2a1919ef0b796c8cc016630353296abd0c4294db83582d7a99a132d033e95928e4a1&file=/pair15_1.tar.gz' -O pair15_1.tar.gz wget 'https://hi4d.ait.ethz.ch/download.php?dt=def502001190eca4e725f10acbfbd3520f0caca29004163d940aa67e31c024acac2f55ce060924e95b528e99e47e167d6d3e8dd34449e7c89fc60b1139e6ee28f45ed216e5f452230156127a2a1919ef0b796c8cc016630353296abd0c4294db83582d7a99a132d033e95928e4a1&file=/pair15_2.tar.gz' -O pair15_2.tar.gz wget 'https://hi4d.ait.ethz.ch/download.php?dt=def502001190eca4e725f10acbfbd3520f0caca29004163d940aa67e31c024acac2f55ce060924e95b528e99e47e167d6d3e8dd34449e7c89fc60b1139e6ee28f45ed216e5f452230156127a2a1919ef0b796c8cc016630353296abd0c4294db83582d7a99a132d033e95928e4a1&file=/pair16.tar.gz' -O pair16.tar.gz wget 'https://hi4d.ait.ethz.ch/download.php?dt=def502001190eca4e725f10acbfbd3520f0caca29004163d940aa67e31c024acac2f55ce060924e95b528e99e47e167d6d3e8dd34449e7c89fc60b1139e6ee28f45ed216e5f452230156127a2a1919ef0b796c8cc016630353296abd0c4294db83582d7a99a132d033e95928e4a1&file=/pair17_1.tar.gz' -O pair17_1.tar.gz wget 'https://hi4d.ait.ethz.ch/download.php?dt=def502001190eca4e725f10acbfbd3520f0caca29004163d940aa67e31c024acac2f55ce060924e95b528e99e47e167d6d3e8dd34449e7c89fc60b1139e6ee28f45ed216e5f452230156127a2a1919ef0b796c8cc016630353296abd0c4294db83582d7a99a132d033e95928e4a1&file=/pair17_2.tar.gz' -O pair17_2.tar.gz wget 'https://hi4d.ait.ethz.ch/download.php?dt=def502001190eca4e725f10acbfbd3520f0caca29004163d940aa67e31c024acac2f55ce060924e95b528e99e47e167d6d3e8dd34449e7c89fc60b1139e6ee28f45ed216e5f452230156127a2a1919ef0b796c8cc016630353296abd0c4294db83582d7a99a132d033e95928e4a1&file=/pair18_1.tar.gz' -O pair18_1.tar.gz wget 'https://hi4d.ait.ethz.ch/download.php?dt=def502001190eca4e725f10acbfbd3520f0caca29004163d940aa67e31c024acac2f55ce060924e95b528e99e47e167d6d3e8dd34449e7c89fc60b1139e6ee28f45ed216e5f452230156127a2a1919ef0b796c8cc016630353296abd0c4294db83582d7a99a132d033e95928e4a1&file=/pair18_2.tar.gz' -O pair18_2.tar.gz wget 'https://hi4d.ait.ethz.ch/download.php?dt=def502001190eca4e725f10acbfbd3520f0caca29004163d940aa67e31c024acac2f55ce060924e95b528e99e47e167d6d3e8dd34449e7c89fc60b1139e6ee28f45ed216e5f452230156127a2a1919ef0b796c8cc016630353296abd0c4294db83582d7a99a132d033e95928e4a1&file=/pair19_1.tar.gz' -O pair19_1.tar.gz wget 'https://hi4d.ait.ethz.ch/download.php?dt=def502001190eca4e725f10acbfbd3520f0caca29004163d940aa67e31c024acac2f55ce060924e95b528e99e47e167d6d3e8dd34449e7c89fc60b1139e6ee28f45ed216e5f452230156127a2a1919ef0b796c8cc016630353296abd0c4294db83582d7a99a132d033e95928e4a1&file=/pair19_2.tar.gz' -O pair19_2.tar.gz wget 
'https://hi4d.ait.ethz.ch/download.php?dt=def502001190eca4e725f10acbfbd3520f0caca29004163d940aa67e31c024acac2f55ce060924e95b528e99e47e167d6d3e8dd34449e7c89fc60b1139e6ee28f45ed216e5f452230156127a2a1919ef0b796c8cc016630353296abd0c4294db83582d7a99a132d033e95928e4a1&file=/pair21_1.tar.gz' -O pair21_1.tar.gz wget 'https://hi4d.ait.ethz.ch/download.php?dt=def502001190eca4e725f10acbfbd3520f0caca29004163d940aa67e31c024acac2f55ce060924e95b528e99e47e167d6d3e8dd34449e7c89fc60b1139e6ee28f45ed216e5f452230156127a2a1919ef0b796c8cc016630353296abd0c4294db83582d7a99a132d033e95928e4a1&file=/pair21_2.tar.gz' -O pair21_2.tar.gz wget 'https://hi4d.ait.ethz.ch/download.php?dt=def502001190eca4e725f10acbfbd3520f0caca29004163d940aa67e31c024acac2f55ce060924e95b528e99e47e167d6d3e8dd34449e7c89fc60b1139e6ee28f45ed216e5f452230156127a2a1919ef0b796c8cc016630353296abd0c4294db83582d7a99a132d033e95928e4a1&file=/pair22.tar.gz' -O pair22.tar.gz wget 'https://hi4d.ait.ethz.ch/download.php?dt=def502001190eca4e725f10acbfbd3520f0caca29004163d940aa67e31c024acac2f55ce060924e95b528e99e47e167d6d3e8dd34449e7c89fc60b1139e6ee28f45ed216e5f452230156127a2a1919ef0b796c8cc016630353296abd0c4294db83582d7a99a132d033e95928e4a1&file=/pair23_1.tar.gz' -O pair23_1.tar.gz wget 'https://hi4d.ait.ethz.ch/download.php?dt=def502001190eca4e725f10acbfbd3520f0caca29004163d940aa67e31c024acac2f55ce060924e95b528e99e47e167d6d3e8dd34449e7c89fc60b1139e6ee28f45ed216e5f452230156127a2a1919ef0b796c8cc016630353296abd0c4294db83582d7a99a132d033e95928e4a1&file=/pair23_2.tar.gz' -O pair23_2.tar.gz wget 'https://hi4d.ait.ethz.ch/download.php?dt=def502001190eca4e725f10acbfbd3520f0caca29004163d940aa67e31c024acac2f55ce060924e95b528e99e47e167d6d3e8dd34449e7c89fc60b1139e6ee28f45ed216e5f452230156127a2a1919ef0b796c8cc016630353296abd0c4294db83582d7a99a132d033e95928e4a1&file=/pair27_1.tar.gz' -O pair27_1.tar.gz wget 'https://hi4d.ait.ethz.ch/download.php?dt=def502001190eca4e725f10acbfbd3520f0caca29004163d940aa67e31c024acac2f55ce060924e95b528e99e47e167d6d3e8dd34449e7c89fc60b1139e6ee28f45ed216e5f452230156127a2a1919ef0b796c8cc016630353296abd0c4294db83582d7a99a132d033e95928e4a1&file=/pair27_2.tar.gz' -O pair27_2.tar.gz wget 'https://hi4d.ait.ethz.ch/download.php?dt=def502001190eca4e725f10acbfbd3520f0caca29004163d940aa67e31c024acac2f55ce060924e95b528e99e47e167d6d3e8dd34449e7c89fc60b1139e6ee28f45ed216e5f452230156127a2a1919ef0b796c8cc016630353296abd0c4294db83582d7a99a132d033e95928e4a1&file=/pair28.tar.gz' -O pair28.tar.gz wget 'https://hi4d.ait.ethz.ch/download.php?dt=def502001190eca4e725f10acbfbd3520f0caca29004163d940aa67e31c024acac2f55ce060924e95b528e99e47e167d6d3e8dd34449e7c89fc60b1139e6ee28f45ed216e5f452230156127a2a1919ef0b796c8cc016630353296abd0c4294db83582d7a99a132d033e95928e4a1&file=/pair32_1.tar.gz' -O pair32_1.tar.gz wget 'https://hi4d.ait.ethz.ch/download.php?dt=def502001190eca4e725f10acbfbd3520f0caca29004163d940aa67e31c024acac2f55ce060924e95b528e99e47e167d6d3e8dd34449e7c89fc60b1139e6ee28f45ed216e5f452230156127a2a1919ef0b796c8cc016630353296abd0c4294db83582d7a99a132d033e95928e4a1&file=/pair32_2.tar.gz' -O pair32_2.tar.gz wget 'https://hi4d.ait.ethz.ch/download.php?dt=def502001190eca4e725f10acbfbd3520f0caca29004163d940aa67e31c024acac2f55ce060924e95b528e99e47e167d6d3e8dd34449e7c89fc60b1139e6ee28f45ed216e5f452230156127a2a1919ef0b796c8cc016630353296abd0c4294db83582d7a99a132d033e95928e4a1&file=/pair37_1.tar.gz' -O pair37_1.tar.gz wget 
'https://hi4d.ait.ethz.ch/download.php?dt=def502001190eca4e725f10acbfbd3520f0caca29004163d940aa67e31c024acac2f55ce060924e95b528e99e47e167d6d3e8dd34449e7c89fc60b1139e6ee28f45ed216e5f452230156127a2a1919ef0b796c8cc016630353296abd0c4294db83582d7a99a132d033e95928e4a1&file=/pair37_2.tar.gz' -O pair37_2.tar.gz mkdir -p pair00 pair01 pair02 pair09 pair10 pair12 pair13 pair14 pair15 pair16 pair17 pair18 pair19 pair21 pair22 pair23 pair27 pair28 pair32 pair37 tar -xvzf pair00_1.tar.gz -C pair00 tar -xvzf pair00_2.tar.gz -C pair00 tar -xvzf pair01.tar.gz pair01 tar -xvzf pair02_1.tar.gz -C pair02 tar -xvzf pair02_2.tar.gz -C pair02 tar -xvzf pair09.tar.gz -C pair09 tar -xvzf pair10.tar.gz -C pair10 tar -xvzf pair12.tar.gz -C pair12 tar -xvzf pair13_1.tar.gz -C pair13 tar -xvzf pair13_2.tar.gz -C pair13 tar -xvzf pair14.tar.gz -C pair14 tar -xvzf pair15_1.tar.gz -C pair15 tar -xvzf pair15_2.tar.gz -C pair15 tar -xvzf pair16.tar.gz -C pair16 tar -xvzf pair17_1.tar.gz -C pair17 tar -xvzf pair17_2.tar.gz -C pair17 tar -xvzf pair18_1.tar.gz -C pair18 tar -xvzf pair18_2.tar.gz -C pair18 tar -xvzf pair19_1.tar.gz -C pair19 tar -xvzf pair19_2.tar.gz -C pair19 tar -xvzf pair21_1.tar.gz -C pair21 tar -xvzf pair21_2.tar.gz -C pair21 tar -xvzf pair22.tar.gz -C pair22 tar -xvzf pair23_1.tar.gz -C pair23 tar -xvzf pair23_2.tar.gz -C pair23 tar -xvzf pair27_1.tar.gz -C pair27 tar -xvzf pair27_2.tar.gz -C pair27 tar -xvzf pair28.tar.gz -C pair28 tar -xvzf pair32_1.tar.gz -C pair32 tar -xvzf pair32_2.tar.gz -C pair32 tar -xvzf pair37_1.tar.gz -C pair37 tar -xvzf pair37_2.tar.gz -C pair37 # Some cleanup because the tars were not consistently structured mv pair00/pair00/* pair00/ mv pair01/pair01/* pair01/ mv pair02/pair02/* pair02/ mv pair09/pair09/* pair09/ mv pair10/pair10/* pair10/ mv pair12/pair12/* pair12/ mv pair13/pair13/* pair13/ mv pair14/pair14/* pair14/ mv pair15/pair15/* pair15/ mv pair16/pair16/* pair16/ mv pair17/pair17/* pair17/ mv pair18/pair18/* pair18/ mv pair19/pair19/* pair19/ mv pair21/pair21/* pair21/ mv pair22/pair22/* pair22/ mv pair23/pair23/* pair23/ mv pair27/pair27/* pair27/ mv pair28/pair28/* pair28/ mv pair32/pair32/* pair32/ mv pair37/pair37/* pair37/ rm -rf pair*/pair*/ ``` With the data downloaded, you can run the script: `python -m scripts.hi4d_preprocessing`. """ from mvtracker.datasets.utils import transform_scene def load_pickle(p): with open(p, "rb") as f: return pickle.load(f) import glob import os import pickle from typing import Optional, Dict, List, Tuple import cv2 import numpy as np import rerun as rr import torch import tqdm from PIL import Image from pytorch3d.io import load_objs_as_meshes from pytorch3d.renderer import ( PerspectiveCameras, MeshRasterizer, RasterizationSettings, ) from pytorch3d.structures import Meshes from scipy.spatial.transform import Rotation def save_pickle(p, data): os.makedirs(os.path.dirname(p), exist_ok=True) with open(p, "wb") as f: pickle.dump(data, f) def load_image(path): return np.array(Image.open(path)) def _safe_load_rgb_cameras(npz_path: str) -> Dict[str, np.ndarray]: """ Hi4D has a typo in docs ('intirnsics'). Support both. Returns dict with keys: ids [N], intrinsics [N,3,3], extrinsics [N,3,4], dist_coeffs [N,5] """ data = dict(np.load(npz_path)) ids = data.get("ids") intr = data.get("intrinsics", data.get("intirnsics")) extr = data.get("extrinsics") dist = data.get("dist_coeffs") assert ids is not None and intr is not None and extr is not None, \ f"Missing keys in {npz_path}. 
Found keys: {list(data.keys())}" return {"ids": ids, "intrinsics": intr, "extrinsics": extr, "dist_coeffs": dist} def _find_all_frames_for_action(images_root: str, cam_ids: List[int]) -> List[int]: """ Robustly infer the list of frame indices by intersecting the available frames across cams. Hi4D names images as 000XXX.jpg (zero-padded 6). """ per_cam_sets = [] for cid in cam_ids: cam_dir = os.path.join(images_root, f"{cid}") jpgs = sorted(glob.glob(os.path.join(cam_dir, "*.jpg"))) frames = set(int(os.path.splitext(os.path.basename(p))[0]) for p in jpgs) per_cam_sets.append(frames) if not per_cam_sets: return [] common = set.intersection(*per_cam_sets) if len(per_cam_sets) > 1 else per_cam_sets[0] return sorted(list(common)) def _mesh_path_for_frame(frames_dir: str, frame_idx: int) -> str: """ Hi4D meshes are 'mesh-f00XXX.obj' (5 digits). We'll format with 5 digits. """ return os.path.join(frames_dir, f"mesh-f{frame_idx:05d}.obj") def extract_hi4d_action_to_pkl( dataset_root: str, pair: str, action: str, save_pkl_path: str, downscaled_longerside: Optional[int] = None, save_rerun_viz: bool = True, stream_rerun_viz: bool = False, skip_if_output_exists: bool = False, ): """ Build a single .pkl for a (pair, action): - rgbs: dict[cam_id_str] -> [T,3,H,W] uint8 - intrs: dict[cam_id_str] -> [3,3] float32 (scaled if resized) - extrs: dict[cam_id_str] -> [3,4] float32 - depths:dict[cam_id_str] -> [T,H,W] float32 (mesh-rendered) - ego_cam_name: None """ if skip_if_output_exists and os.path.exists(save_pkl_path): print(f"Skipping {save_pkl_path} (exists).") return save_pkl_path print(f"Processing {pair}/{action} -> {save_pkl_path}") root = os.path.join(dataset_root, pair, action) frames_dir = os.path.join(root, "frames") images_dir = os.path.join(root, "images") cameras_npz = os.path.join(root, "cameras", "rgb_cameras.npz") meta_npz = os.path.join(root, "meta.npz") cams = _safe_load_rgb_cameras(cameras_npz) cam_ids: List[int] = list(map(int, cams["ids"])) # e.g., [4,16,28,40,52,64,76,88] intr_all = cams["intrinsics"].astype(np.float32) # [N,3,3] extr_all = cams["extrinsics"].astype(np.float32) # [N,3,4] meta = dict(np.load(meta_npz)) frame_ids = _find_all_frames_for_action(images_dir, cam_ids) assert len(frame_ids) > 0, f"No common frames found across cameras at {images_dir}" assert frame_ids[0] == meta["start"].item() assert frame_ids[-1] == meta["end"].item() assert len(frame_ids) == (meta["end"].item() - meta["start"].item() + 1) # Build containers rgbs: Dict[str, List[np.ndarray]] = {str(cid): [] for cid in cam_ids} depths: Dict[str, List[np.ndarray]] = {str(cid): [] for cid in cam_ids} intrs: Dict[str, np.ndarray] = {} extrs: Dict[str, np.ndarray] = {} device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # Pre-load a single mesh per frame and rasterize to each camera # (This is typically faster than reloading the mesh V times.) raster_settings_cache: Dict[Tuple[int, int], RasterizationSettings] = {} for frame in tqdm.tqdm(frame_ids, desc=f"Frames {pair}/{action}"): mesh_path = _mesh_path_for_frame(frames_dir, frame) if not os.path.isfile(mesh_path): # Some sequences may use different padding; try 6 digits as fallback. 
alt = os.path.join(frames_dir, f"mesh-f{frame:06d}.obj") if os.path.isfile(alt): mesh_path = alt else: # Skip missing mesh frame continue # Load mesh (geometry only is enough for depth) meshes: Meshes = load_objs_as_meshes([mesh_path], device=device) # For each camera, render depth & collect RGB for i, cid in enumerate(cam_ids): cam_name = str(cid) img_path = os.path.join(images_dir, cam_name, f"{frame:06d}.jpg") if not os.path.isfile(img_path): # Skip if that particular view is missing the image for this frame continue image = load_image(img_path) h0, w0 = image.shape[:2] # Copy camera params K = intr_all[i].copy() # [3,3] E = extr_all[i].copy() # [3,4] world->cam (Hi4D) # Optional downscale (longer side) + scale intrinsics if downscaled_longerside is not None: scale = downscaled_longerside / float(max(h0, w0)) nh, nw = int(round(h0 * scale)), int(round(w0 * scale)) if (nh, nw) != (h0, w0): image = cv2.resize(image, (nw, nh), interpolation=cv2.INTER_AREA) K[:2] *= scale h, w = nh, nw else: h, w = h0, w0 # Stash static intr/extr once (raw, no global transform) if cam_name not in intrs: intrs[cam_name] = K.astype(np.float32) extrs[cam_name] = E.astype(np.float32) rgbs[cam_name].append(image) # Build PyTorch3D camera from raw E fx, fy = K[0, 0], K[1, 1] cx, cy = K[0, 2], K[1, 2] R = E[:3, :3] t = E[:3, 3] # 4D-DRESS convention: transpose + flip X/Y R = R.T R = R @ np.diag(np.array([-1.0, -1.0, 1.0], dtype=np.float32)) t = t @ np.diag(np.array([-1.0, -1.0, 1.0], dtype=np.float32)) cameras_p3d = PerspectiveCameras( focal_length=torch.tensor([[fx, fy]], dtype=torch.float32, device=device), principal_point=torch.tensor([[cx, cy]], dtype=torch.float32, device=device), R=torch.tensor(R, dtype=torch.float32, device=device).unsqueeze(0), T=torch.tensor(t, dtype=torch.float32, device=device).unsqueeze(0), image_size=torch.tensor([[h, w]], dtype=torch.float32, device=device), in_ndc=False, device=device, ) # Rasterize (no global transform on mesh here) rs_key = (h, w) if rs_key not in raster_settings_cache: raster_settings_cache[rs_key] = RasterizationSettings( image_size=(h, w), blur_radius=0.0, faces_per_pixel=1, bin_size=0, ) rasterizer = MeshRasterizer(cameras=cameras_p3d, raster_settings=raster_settings_cache[rs_key]) fragments = rasterizer(meshes) # faces_per_pixel=1 -> (1,H,W,1) -> (H,W) zbuf = fragments.zbuf[0, ..., 0].detach().cpu().numpy() zbuf = np.nan_to_num(zbuf, nan=0.0) depths[cam_name].append(zbuf.astype(np.float32)) # Stack per-camera data cam_names = sorted(rgbs.keys(), key=lambda s: int(s)) for cam_name in cam_names: if len(rgbs[cam_name]) == 0: # Camera had no valid frames (skip) del intrs[cam_name], extrs[cam_name], rgbs[cam_name], depths[cam_name] continue rgbs[cam_name] = np.stack(rgbs[cam_name]).transpose(0, 3, 1, 2).astype(np.uint8) # [T,3,H,W] depths[cam_name] = np.stack(depths[cam_name]).astype(np.float32) # [T,H,W] # Basic shape checks (use first cam as reference) kept_cams = sorted(rgbs.keys(), key=lambda s: int(s)) assert len(kept_cams) > 0, "No cameras with data." 
n_frames, _, h, w = rgbs[kept_cams[0]].shape for cam_name in kept_cams: assert rgbs[cam_name].shape == (n_frames, 3, h, w) assert intrs[cam_name].shape == (3, 3) assert extrs[cam_name].shape == (3, 4) assert depths[cam_name].shape == (n_frames, h, w) # Rotate the scene to have the ground at z=0 rot_x = Rotation.from_euler('x', 90, degrees=True).as_matrix() rot_y = Rotation.from_euler('y', 0, degrees=True).as_matrix() rot_z = Rotation.from_euler('z', 0, degrees=True).as_matrix() rot = torch.from_numpy(rot_z @ rot_y @ rot_x) translation = torch.tensor([0.0, 0.0, 0.0], dtype=torch.float32) for cam_name in kept_cams: E = torch.from_numpy(extrs[cam_name][None, None]) # [1,1,3,4] E_tx = transform_scene(1, rot, translation, None, E, None, None, None)[1] extrs[cam_name] = E_tx[0, 0].numpy() # Save save_pickle(save_pkl_path, dict( rgbs=rgbs, intrs=intrs, extrs=extrs, depths=depths, ego_cam_name=None, )) # Visualize the data sample using rerun rerun_modes = [] if stream_rerun_viz: rerun_modes += ["stream"] if save_rerun_viz: rerun_modes += ["save"] for rerun_mode in rerun_modes: rr.init(f"3dpt", recording_id="v0.16") if rerun_mode == "stream": rr.connect_tcp() rr.log("world", rr.ViewCoordinates.RIGHT_HAND_Z_UP, static=True) rr.set_time_seconds("frame", 0) rr.log( "world/xyz", rr.Arrows3D( vectors=[[1, 0, 0], [0, 2, 0], [0, 0, 3]], colors=[[255, 0, 0], [0, 255, 0], [0, 0, 255]], ), ) mesh_vertices = meshes._verts_list[0].cpu() mesh_faces = meshes._faces_list[0].cpu() mesh_vertices = transform_scene(1, rot, translation, None, None, None, mesh_vertices[None], None)[3][0] rr.log( "mesh", rr.Mesh3D( vertex_positions=mesh_vertices.numpy().astype(np.float32), # (N, 3) triangle_indices=mesh_faces.numpy().reshape(-1, 3).astype(np.int32), # (M, 3) albedo_factor=[200, 200, 255], # Optional color ), ) fps = 30 for frame_idx in range(n_frames): rr.set_time_seconds("frame", frame_idx / fps) for cam_name in cam_names: extr = extrs[cam_name] intr = intrs[cam_name] img = rgbs[cam_name][frame_idx].transpose(1, 2, 0).astype(np.uint8) depth = depths[cam_name][frame_idx] h, w = img.shape[:2] fx, fy = intr[0, 0], intr[1, 1] cx, cy = intr[0, 2], intr[1, 2] # Camera pose T = np.eye(4) T[:3, :] = extr world_T_cam = np.linalg.inv(T) rr.log(f"{cam_name}/image", rr.Transform3D( translation=world_T_cam[:3, 3], mat3x3=world_T_cam[:3, :3], )) rr.log(f"{cam_name}/image", rr.Pinhole( image_from_camera=intr, width=w, height=h )) rr.log(f"{cam_name}/image", rr.Image(img)) rr.log(f"{cam_name}/depth", rr.Transform3D( translation=world_T_cam[:3, 3], mat3x3=world_T_cam[:3, :3], )) rr.log(f"{cam_name}/depth", rr.Pinhole( image_from_camera=intr, width=w, height=h )) rr.log(f"{cam_name}/depth", rr.DepthImage(depth, meter=1.0, colormap="viridis")) # Unproject depth to point cloud y, x = np.meshgrid(np.arange(h), np.arange(w), indexing="ij") z = depth valid = z > 0 x = x[valid] y = y[valid] z = z[valid] X = (x - cx) * z / fx Y = (y - cy) * z / fy pts_cam = np.stack([X, Y, z], axis=-1) # Transform to world R = world_T_cam[:3, :3] t = world_T_cam[:3, 3] pts_world = pts_cam @ R.T + t # Color colors = img[y, x] rr.log(f"point_cloud/{cam_name}", rr.Points3D(positions=pts_world, colors=colors)) if rerun_mode == "save": base, name = os.path.split(save_pkl_path) name_no_ext = os.path.splitext(name)[0] save_rrd_path = os.path.join(base, f"rerun__{name_no_ext}.rrd") rr.save(save_rrd_path) print(f"Saved rerun viz to {os.path.abspath(save_rrd_path)}") print(f"Done with {save_pkl_path}.") print() if __name__ == "__main__": dataset_root = "datasets/hi4d" 
output_root = "datasets/hi4d-processed" longside_resolution: Optional[int] = 512 if longside_resolution is not None: output_root += f"-resized-{longside_resolution}" os.makedirs(output_root, exist_ok=True) pairs = [ "pair00", "pair01", "pair02", "pair09", "pair10", "pair12", "pair13", "pair14", "pair15", "pair16", "pair17", "pair18", "pair19", "pair21", "pair22", "pair23", "pair27", "pair28", "pair32", "pair37" ] # Enumerate actions per pair automatically for pair in tqdm.tqdm(pairs, desc="Pairs"): pair_dir = os.path.join(dataset_root, pair) assert os.path.isdir(pair_dir) actions = sorted([ d for d in os.listdir(pair_dir) if os.path.isdir(os.path.join(pair_dir, d)) and not d.startswith(".") ]) for action in tqdm.tqdm(actions, desc=f"{pair} actions", leave=False): out_pkl = os.path.join(output_root, f"{pair}__{action}.pkl") extract_hi4d_action_to_pkl( dataset_root=dataset_root, pair=pair, action=action, save_pkl_path=out_pkl, downscaled_longerside=longside_resolution, save_rerun_viz=True, stream_rerun_viz=False, skip_if_output_exists=True, ) ================================================ FILE: scripts/merge_comparison_mp4s.py ================================================ """ Merge MP4 files of different methods into a single side-by-side comparison, adding a small text bar for each method using Pillow + ImageClip instead of MoviePy's TextClip (which requires ImageMagick). Usage: python merge_comparison_mp4s.py """ import os import numpy as np from PIL import Image, ImageDraw, ImageFont from moviepy.editor import ( VideoFileClip, ImageClip, clips_array, CompositeVideoClip ) def create_title_image(text, width, height=50, bg_color=(255, 255, 255)): """ Creates a PIL Image of size (width x height) with the given text, centered. Returns a NumPy array (H x W x 3). """ # Create a blank RGB image img = Image.new("RGB", (width, height), color=bg_color) draw = ImageDraw.Draw(img) # Choose a default font. If you have a TTF file, specify it here: font = ImageFont.truetype("times_new_roman.ttf", size=36) # font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", size=24) # If you don't have a TTF file handy, ImageFont.load_default() is the fallback: # font = ImageFont.load_default() text_w, text_h = draw.textsize(text, font=font) x = (width - text_w) // 2 y = (height - text_h) // 2 draw.text((x, y), text, fill=(0, 0, 0), font=font) return np.array(img) def merge_mp4s(mp4s_title_to_path_dict, merged_mp4_output_path, num_columns): """ Merges each input MP4 (which presumably has a 'first column' or 'second column' that you want to extract) into a side-by-side comparison video, arranged in multiple rows if num_columns < number_of_videos, AND places each method's title bar above its own clip. 
:param mp4s_title_to_path_dict: dict of {title: path_to_video} :param merged_mp4_output_path: output MP4 path :param num_columns: number of clips to display per row """ titles = list(mp4s_title_to_path_dict.keys()) raw_clips = [] # 1) Load each video and crop the relevant half-column for title in titles: path = mp4s_title_to_path_dict[title] if not os.path.exists(path): raise FileNotFoundError(f"Video file not found: {path}") clip = VideoFileClip(path) w, h = clip.size # (width, height) if "GT" in title: # Crop the first column sub_clip = clip.crop(x1=0, x2=w // 2, y1=0, y2=h) else: # Crop the second column sub_clip = clip.crop(x1=w // 2, x2=w, y1=0, y2=h) raw_clips.append((title, sub_clip)) # 2) For each sub-clip, create a small "title bar" on top # so each method has its own label above its clip. bar_height = 50 titled_clips = [] for (title, subclip) in raw_clips: # Create a bar image for the subclip width title_img_array = create_title_image(title, subclip.w, bar_height) title_iclip = ImageClip(title_img_array, duration=subclip.duration) # Shift subclip downward by bar_height subclip_shifted = subclip.set_position((0, bar_height)) # Composite them vertically: [title bar on top, subclip below] comp_h = bar_height + subclip.h comp_w = subclip.w composite = CompositeVideoClip( [title_iclip, subclip_shifted], size=(comp_w, comp_h) ) titled_clips.append(composite) # 3) Normalize all titled_clips to the same height if they differ. import math min_height = min(tc.h for tc in titled_clips) normalized_clips = [] for tc in titled_clips: if tc.h != min_height: scale = min_height / tc.h new_w = int(tc.w * scale) resized = tc.resize((new_w, min_height)) normalized_clips.append(resized) else: normalized_clips.append(tc) # 4) Arrange the normalized clips in rows of length `num_columns`. 
n = len(normalized_clips) n_rows = math.ceil(n / num_columns) rows = [] idx = 0 for _ in range(n_rows): row_clips = normalized_clips[idx: idx + num_columns] rows.append(row_clips) idx += num_columns # 5) Stack them using clips_array final_clip = clips_array(rows) # 6) Write to output final_clip.write_videofile( merged_mp4_output_path, fps=12, codec="libx264", threads=4 # adjust as needed ) print(f"✅ Merged video saved successfully to {merged_mp4_output_path}") if __name__ == '__main__': for selection in ["A", "B", "C"]: if selection == "A": datasets_seq = [ *[("kubric-multiview-v3-views0123-novelviews4", seq) for seq in [0, 3, 4, 5]], *[("panoptic-multiview-views1_7_14_20-novelviews24", seq) for seq in [0, 3, 4, 5]], *[("panoptic-multiview-views1_7_14_20-novelviews27", seq) for seq in [0, 3, 4, 5]], ] mp4s = { "GT": "logs/cotracker3-online/eval_{dataset}/comparison_v4b-novel__seq-{seq}_step_-1.mp4", "Dynamic 3DGS": "logs/dynamic_3dgs/eval_{dataset}/comparison_v4b-novel__seq-{seq}_step_-1.mp4", "Shape of Motion": "logs/shape_of_motion/eval_{dataset}/comparison_v4b-novel__seq-{seq}_step_-1.mp4", "LocoTrack": "logs/locotrack/eval_{dataset}/comparison_v4b-novel__seq-{seq}_step_-1.mp4", "CoTracker3": "logs/cotracker3-online/eval_{dataset}/comparison_v4b-novel__seq-{seq}_step_-1.mp4", "DELTA": "logs/delta/eval_{dataset}/comparison_v4b-novel__seq-{seq}_step_-1.mp4", # "SpaTracker-1": "logs/kubric_v3_duster0123/multiview-adapter-pretrained-001/eval_{dataset}/comparison_v4b-novel__seq-{seq}_step_-1.mp4", "SpaTracker": "logs/kubric_v3/multiview-adapter-002/eval_{dataset}/comparison_v4b-novel__seq-{seq}_step_69799.mp4", # "SpaTracker-3": "logs/kubric_v3_duster0123/multiview-adapter-001/eval_{dataset}/comparison_v4b-novel__seq-{seq}_step_90799.mp4", "Triplane Baseline": "logs/kubric_v3/multiview-v1-with-128-triplane-001/eval_{dataset}/comparison_v4b-novel__seq-{seq}_step_99999.mp4", # "Triplane-2": "logs/kubric_v3_duster0123/multiview-v1-with-256-triplane-001/eval_{dataset}/comparison_v4b-novel__seq-{seq}_step_99999.mp4", "Ours": "logs/kubric_v3_augs/mvtracker/eval_{dataset}/comparison_v4b-novel__seq-{seq}_step_159999.mp4", } elif selection == "B": datasets_seq = [ *[("dex-ycb-multiview-duster0123-novelviews4", seq) for seq in [0, 3, 4, 5]], *[("dex-ycb-multiview-duster0123-novelviews5", seq) for seq in [0, 3, 4, 5]], *[("dex-ycb-multiview-duster0123-novelviews6", seq) for seq in [0, 3, 4, 5]], *[("dex-ycb-multiview-duster0123-novelviews7", seq) for seq in [0, 3, 4, 5]], ] mp4s = { "GT": "logs/cotracker3-online/eval_{dataset}/comparison_v4b-novel__seq-{seq}_step_-1.mp4", "Dynamic 3DGS": "logs/dynamic_3dgs/eval_{dataset}/comparison_v4b-novel__seq-{seq}_step_-1.mp4", "Shape of Motion": "logs/shape_of_motion/eval_{dataset}/comparison_v4b-novel__seq-{seq}_step_-1.mp4", "LocoTrack": "logs/locotrack/eval_{dataset}/comparison_v4b-novel__seq-{seq}_step_-1.mp4", "CoTracker3": "logs/cotracker3-online/eval_{dataset}/comparison_v4b-novel__seq-{seq}_step_-1.mp4", "DELTA": "logs/delta/eval_{dataset}/comparison_v4b-novel__seq-{seq}_step_-1.mp4", # "SpaTracker-1": "logs/kubric_v3_duster0123/multiview-adapter-pretrained-001/eval_{dataset}/comparison_v4b-novel__seq-{seq}_step_-1.mp4", # "SpaTracker-2": "logs/kubric_v3/multiview-adapter-002/eval_{dataset}/comparison_v4b-novel__seq-{seq}_step_69799.mp4", "SpaTracker": "logs/kubric_v3_duster0123/multiview-adapter-001/eval_{dataset}/comparison_v4b-novel__seq-{seq}_step_90799.mp4", # "Triplane-1": 
"logs/kubric_v3/multiview-v1-with-128-triplane-001/eval_{dataset}/comparison_v4b-novel__seq-{seq}_step_99999.mp4", "Triplane Baseline": "logs/kubric_v3_duster0123/multiview-v1-with-256-triplane-001/eval_{dataset}/comparison_v4b-novel__seq-{seq}_step_99999.mp4", "Ours": "logs/kubric_v3_augs/mvtracker/eval_{dataset}/comparison_v4b-novel__seq-{seq}_step_159999.mp4", } elif selection == "C": datasets_seq = [ *[("dex-ycb-multiview-duster2345-novelviews7", seq) for seq in [0, 3, 4, 5]], *[("dex-ycb-multiview-duster4567-novelviews7", seq) for seq in [0, 3, 4, 5]], *[("dex-ycb-multiview-duster4567-novelviews0", seq) for seq in [0, 3, 4, 5]], ] mp4s = { "GT": "logs/cotracker3-online/eval_{dataset}/comparison_v4b-novel__seq-{seq}_step_-1.mp4", # "Dynamic 3DGS": "logs/dynamic_3dgs/eval_{dataset}/comparison_v4b-novel__seq-{seq}_step_-1.mp4", # "Shape of Motion": "logs/shape_of_motion/eval_{dataset}/comparison_v4b-novel__seq-{seq}_step_-1.mp4", # "LocoTrack": "logs/locotrack/eval_{dataset}/comparison_v4b-novel__seq-{seq}_step_-1.mp4", "CoTracker3": "logs/cotracker3-online/eval_{dataset}/comparison_v4b-novel__seq-{seq}_step_-1.mp4", "DELTA": "logs/delta/eval_{dataset}/comparison_v4b-novel__seq-{seq}_step_-1.mp4", # "SpaTracker-1": "logs/kubric_v3_duster0123/multiview-adapter-pretrained-001/eval_{dataset}/comparison_v4b-novel__seq-{seq}_step_-1.mp4", # "SpaTracker-2": "logs/kubric_v3/multiview-adapter-002/eval_{dataset}/comparison_v4b-novel__seq-{seq}_step_69799.mp4", "SpaTracker": "logs/kubric_v3_duster0123/multiview-adapter-001/eval_{dataset}/comparison_v4b-novel__seq-{seq}_step_90799.mp4", # "Triplane-1": "logs/kubric_v3/multiview-v1-with-128-triplane-001/eval_{dataset}/comparison_v4b-novel__seq-{seq}_step_99999.mp4", "Triplane Baseline": "logs/kubric_v3_duster0123/multiview-v1-with-256-triplane-001/eval_{dataset}/comparison_v4b-novel__seq-{seq}_step_99999.mp4", "Ours": "logs/kubric_v3_augs/mvtracker/eval_{dataset}/comparison_v4b-novel__seq-{seq}_step_159999.mp4", } else: raise ValueError(f"Invalid selection: {selection}") for dataset, seq in datasets_seq: mp4s_title_to_path_dict = { key: path.format(dataset=dataset, seq=seq) for key, path in mp4s.items() } if not mp4s_title_to_path_dict: print(f"⚠️ Warning: No valid MP4 files found for dataset {dataset} seq {seq}. Skipping...") continue merged_mp4 = f"logs/comparison_v4__{dataset}__seq-{seq}.mp4" merge_mp4s(mp4s_title_to_path_dict, merged_mp4, num_columns=3) ================================================ FILE: scripts/panoptic_studio_preprocessing.py ================================================ """ This script will convert the Panoptic Studio subset of TAPVid-3D to multi-view 3D point tracking dataset. 
First, follow the instructions at https://github.com/google-deepmind/tapnet/tree/main/tapnet/tapvid3d to download the raw panoptic studio data, for example, as follows: ```bash # Set up a temporary environment conda create -n panoptic-preprocessing python=3.10.12 -y conda activate panoptic-preprocessing pip install "git+https://github.com/google-deepmind/tapnet.git#egg=tapnet[tapvid3d_eval,tapvid3d_generation]" # Download the raw data python -m tapnet.tapvid3d.annotation_generation.generate_pstudio --output_dir datasets/panoptic_studio_tapvid3d mkdir datasets/panoptic-multiview mv datasets/panoptic_studio_tapvid3d/tmp/data/* datasets/panoptic-multiview/ # If you like, you can remove the temporary environment now conda deactivate conda env remove -n panoptic-preprocessing ``` Following https://github.com/JonathonLuiten/Dynamic3DGaussians#run-visualizer-on-pretrained-models, download and unzip the pretrained Dynamic3DGS checkpoints, e.g. as follows: ```bash wget https://omnomnom.vision.rwth-aachen.de/data/Dynamic3DGaussians/output.zip -O checkpoints/output.zip unzip checkpoints/output.zip -d checkpoints/ rm checkpoints/output.zip mv checkpoints/output/pretrained checkpoints/dynamic3dgs_pretrained ``` Install the missing dependencies needed by Dynamic3DGS: ```bash conda activate 3dpt conda install -c conda-forge gcc_linux-64=11.3.0 gxx_linux-64=11.3.0 gcc=11.3.0 gxx=11.3.0 -y pip install git+https://github.com/JonathonLuiten/diff-gaussian-rasterization-w-depth.git pip install open3d==0.16.0 ``` Now you can run this script to generate the Dynamic3DGS depths and merge the TAP-Vid3D annotations: ```bash python -m scripts.panoptic_studio_preprocessing \ --dataset_root ./datasets/panoptic-multiview \ --checkpoint_root ./checkpoints/dynamic3dgs_pretrained \ --tapvid3d_root ./datasets/panoptic_studio_tapvid3d ``` The processed dataset is now stored in ./datasets/panoptic-multiview. 
If you'd like, you can remove the raw tapvid3d data now to save space: ```bash rm -rf ./datasets/panoptic_studio_tapvid3d ``` """ import argparse from pathlib import Path from tqdm import tqdm from mvtracker.models.core.dynamic3dgs.export_depths_from_pretrained_checkpoint import export_depth from mvtracker.models.core.dynamic3dgs.merge_tapvid3d_per_camera_annotations import merge_annotations def parse_args(): parser = argparse.ArgumentParser(description="Preprocess Panoptic Studio TAPVid-3D subset.") parser.add_argument("--dataset_root", type=Path, required=True, help="Root path to Panoptic Studio dataset (per-sequence folders).") parser.add_argument("--checkpoint_root", type=Path, required=True, help="Root path to Dynamic3DGS pretrained checkpoints (per-sequence).") parser.add_argument("--tapvid3d_root", type=Path, required=True, help="Root path to TAPVid-3D annotations for Panoptic Studio.") return parser.parse_args() if __name__ == '__main__': args = parse_args() sequences = ["basketball", "boxes", "football", "juggle", "softball", "tennis"] print("Exporting depths from pretrained checkpoints") for sequence_name in tqdm(sequences): scene_root = args.dataset_root / sequence_name output_path = scene_root / "dynamic3dgs_depth" checkpoint_path = args.checkpoint_root / sequence_name export_depth(scene_root, output_path, checkpoint_path) print("Merging TAP-Vid3D per-camera annotations.") for sequence_name in tqdm(sequences): scene_root = args.dataset_root / sequence_name checkpoint_path = args.checkpoint_root / sequence_name tapvid3d_annotation_paths = list(args.tapvid3d_root.glob(f"{sequence_name}_*.npz")) merge_annotations( scene_root, checkpoint_path, tapvid3d_annotation_paths, skip_if_output_already_exists=True, rerun_logging=True ) ================================================ FILE: scripts/plot_aj_for_varying_depth_noise_levels.py ================================================ import os import matplotlib.pyplot as plt import numpy as np import seaborn as sns # set_size from https://jwalton.info/Embed-Publication-Matplotlib-Latex/ def set_size(width, fraction=1, golden_ratio=(5 ** .5 - 1) / 2): """Set figure dimensions to avoid scaling in LaTeX. Parameters ---------- width: float Document textwidth or columnwidth in pts fraction: float, optional Fraction of the width which you wish the figure to occupy Returns ------- fig_dim: tuple Dimensions of figure in inches """ # Width of figure (in pts) fig_width_pt = width * fraction # Convert from pt to inches inches_per_pt = 1 / 72.27 # Golden ratio to set aesthetic figure height # https://disq.us/p/2940ij3 # golden_ratio = (5 ** .5 - 1) / 2 # Figure width in inches fig_width_in = fig_width_pt * inches_per_pt # Figure height in inches fig_height_in = fig_width_in * golden_ratio fig_dim = (fig_width_in, fig_height_in) return fig_dim def setup_plot(): sns.set_theme(style="whitegrid") sns.set_palette("tab10") plt.rcParams["font.family"] = "Times New Roman" plt.rcParams['font.weight'] = 'normal' def plot_aj( save_name='plot_robustness_to_depth_noise.pdf', width_in_paper_pts=237.13594, # \showthe\linewidth --> > 237.13594pt. 
linewidth=1.5, marker_size=5, label_font_size=9, tick_font_size=9, legend_font_size=7, dpi=400, results_dir=None, save_svg=False, ): setup_plot() fig, ax = plt.subplots(figsize=set_size(width_in_paper_pts, golden_ratio=0.3), dpi=dpi) x_labels = ['0', '1', '2', '5', '10', '20', '50', '100', '200'] x = np.arange(len(x_labels)) # x = np.array([0, 1, 2, 5, 10, 20, 50, 100, 200]) # x_labels = ['0', '1', '2', '5', '10', '20', '50', '100', '200'] results = { "Ours": [81.6, 80.7, 77.4, 69.8, 63.1, 59.3, 56.1, 54.3, 52.8], "Triplane": [75.4, 75.0, 73.7, 69.2, 63.4, 57.4, 51.5, 49.1, 47.6], "SpaTracker": [65.5, 63.8, 62.1, 58.7, 55.8, 52.6, 48.6, 45.4, 43.3], "DELTA": [57.4, 51.8, 46.2, 34.3, 23.8, 13.2, 5.0, 2.3, 1.0], } for label, y in results.items(): sns.lineplot(x=x, y=y, ax=ax, linewidth=linewidth, marker='o', markersize=marker_size, label=label) # ax.axhline(y=47.2, color=sns.color_palette("tab10")[1], linestyle='--', linewidth=1.5, label='Blind Baseline') ax.set_xticks(x) ax.set_xticklabels(x_labels) ax.set_yticks(np.arange(40, 90, 10)) ax.set_ylim([40, 83]) ax.tick_params(axis='both', which='major', labelsize=tick_font_size) ax.set_xlabel('Depth Noise (σ, in cm)', fontsize=label_font_size, fontweight='normal', labelpad=0) ax.set_ylabel('AJ', fontsize=label_font_size, fontweight='normal', labelpad=2) for spine in ax.spines.values(): spine.set_color('black') ax.grid(axis='y', color='lightgrey') ax.tick_params(axis="y", direction="in") ax.tick_params(axis="x", direction="in") legend = plt.legend( frameon=True, fancybox=False, loc=(0.675, 0.265), # loc="upper right", prop={'size': legend_font_size}, handletextpad=0.2, labelspacing=0.1, ) # legend.get_frame().set_facecolor('white') # legend.get_frame().set_edgecolor('black') plt.tight_layout(pad=0) if save_name: if results_dir: os.makedirs(results_dir, exist_ok=True) save_name = os.path.join(results_dir, save_name) plt.savefig(save_name, bbox_inches='tight', pad_inches=0) if save_svg: plt.savefig(save_name.replace('.pdf', '.svg'), bbox_inches='tight', pad_inches=0) plt.show() if __name__ == '__main__': plot_aj() ================================================ FILE: scripts/plot_aj_for_varying_n_of_views.py ================================================ import os import matplotlib.pyplot as plt import matplotlib.ticker as ticker import numpy as np import seaborn as sns # set_size from https://jwalton.info/Embed-Publication-Matplotlib-Latex/ def set_size(width, fraction=1): """Set figure dimensions to avoid scaling in LaTeX. Parameters ---------- width: float Document textwidth or columnwidth in pts fraction: float, optional Fraction of the width which you wish the figure to occupy Returns ------- fig_dim: tuple Dimensions of figure in inches """ # Width of figure (in pts) fig_width_pt = width * fraction # Convert from pt to inches inches_per_pt = 1 / 72.27 # Golden ratio to set aesthetic figure height # https://disq.us/p/2940ij3 golden_ratio = (5 ** .5 - 1) / 2 # Figure width in inches fig_width_in = fig_width_pt * inches_per_pt # Figure height in inches fig_height_in = fig_width_in * golden_ratio fig_dim = (fig_width_in, fig_height_in) return fig_dim def setup_plot(): sns.set_theme(style="whitegrid") sns.set_palette("tab10") plt.rcParams["font.family"] = "Times New Roman" plt.rcParams['font.weight'] = 'normal' def plot_aj( save_name='plot_number_of_views.pdf', width_in_paper_pts=237.13594, # \showthe\linewidth --> > 237.13594pt. 
linewidth=1.5, marker_size=5, label_font_size=9, tick_font_size=9, legend_font_size=7, dpi=400, results_dir=None, save_svg=False, ): setup_plot() fig, ax = plt.subplots(figsize=set_size(width_in_paper_pts), dpi=dpi) x = np.arange(1, 9) y_data = { "MVTracker (ours)": [64.0, 66.8, 73.2, 71.1, 77.4, 76.7, 77.3, 79.2], "Triplane": [44.0, 48.0, 56.0, 57.6, 63.5, 64.5, 65.5, 66.8], # "TAPIP3D": [36.6, 35.6, 40.5, 38.8, 57.7, 54.2, 55.2, 56.4], # "SpatialTrackerV2": [39.8, 39.5, 36.5, 35.5, 41.1, 37.1, 37.0, 37.7], "SpatialTracker": [60.6, 58.4, 61.8, 58.3, 63.2, 62.4, 62.9, 63.4], "CoTracker3": [28.6, 27.0, 29.5, 29.4, 39.1, 37.5, 37.1, 37.3], # "CoTracker2": [29.8, 26.4, 29.2, 28.8, 37.8, 36.2, 36.0, 36.0], "DELTA": [33.0, 34.3, 38.0, 36.5, 37.2, 35.4, 34.9, 35.7], "LocoTrack": [27.9, 26.0, 28.1, 27.8, 36.3, 34.8, 34.7, 34.9] } for label, y in y_data.items(): sns.lineplot(x=x, y=y, label=label, ax=ax, linewidth=linewidth, marker='o', markersize=marker_size) ax.set_xticks(x) ax.set_yticks(np.arange(30, 81, 10)) ax.set_ylim([25, 80]) ax.xaxis.set_major_formatter(ticker.ScalarFormatter()) ax.tick_params(axis='both', which='major', labelsize=tick_font_size) ax.set_xlabel('Number of Views', fontsize=label_font_size, fontweight='normal', labelpad=0) ax.set_ylabel('Average Jaccard (AJ)', fontsize=label_font_size, fontweight='normal', labelpad=2) for spine in ax.spines.values(): spine.set_color('black') ax.grid(axis='y', color='lightgrey') ax.tick_params(axis="y", direction="in") ax.tick_params(axis="x", direction="in") legend = plt.legend( frameon=True, fancybox=False, loc=(0.625, 0.265), prop={'size': legend_font_size}, handletextpad=0.2, labelspacing=0.1, ) # legend.get_frame().set_facecolor('white') # legend.get_frame().set_edgecolor('black') plt.tight_layout(pad=0) if save_name: if results_dir: os.makedirs(results_dir, exist_ok=True) save_name = os.path.join(results_dir, save_name) plt.savefig(save_name, bbox_inches='tight', pad_inches=0) if save_svg: plt.savefig(save_name.replace('.pdf', '.svg'), bbox_inches='tight', pad_inches=0) plt.show() if __name__ == '__main__': plot_aj() ================================================ FILE: scripts/profiling.md ================================================ # Profiling Notes This document summarizes how to run performance profiling using PyTorch’s built-in tools, and how to interpret the results. To profile one training iteration (forward + backward + optimizer step), the following snippet can be used: ```python from torch.profiler import profile, ProfilerActivity with profile( activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], with_stack=True, with_flops=True, profile_memory=True, record_shapes=True, ) as prof: # one iteration of fwd + bwd + optimize pass print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=36)) print(prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=36)) print(prof.key_averages().table(sort_by="self_cpu_time_total", row_limit=36)) print(prof.key_averages().table(sort_by="self_cuda_memory_usage", row_limit=36)) prof.export_chrome_trace("trace.json") breakpoint() ``` The printed summary tables produced by `.key_averages()` are already extremely informative. Sorting by `cuda_time_total` highlights which operations dominate the runtime on GPU. The `self_cuda_memory_usage` sort can reveal the main contributors to memory consumption. Enabling `record_shapes=True` further helps diagnose operations that unexpectedly receive large tensors. 
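If it helps to see the placeholder filled in, here is a self-contained toy version of the same pattern; the model, optimizer, and tensor sizes below are arbitrary stand-ins rather than the project's actual training step:

```python
import torch
from torch.profiler import profile, ProfilerActivity

# Arbitrary toy model/optimizer/input, only to make the snippet runnable end-to-end.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.nn.Sequential(torch.nn.Linear(512, 512), torch.nn.ReLU(), torch.nn.Linear(512, 512)).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
x = torch.randn(64, 512, device=device)

activities = [ProfilerActivity.CPU] + ([ProfilerActivity.CUDA] if device == "cuda" else [])
with profile(activities=activities, with_stack=True, with_flops=True,
             profile_memory=True, record_shapes=True) as prof:
    loss = model(x).square().mean()  # forward
    loss.backward()                  # backward
    optimizer.step()                 # optimizer step
    optimizer.zero_grad()

print(prof.key_averages().table(sort_by="self_cpu_time_total", row_limit=10))
```

Swapping the toy iteration for the real forward/backward/optimizer call is all that is needed to profile the actual training loop.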
For more detailed inspection, the exported trace file `trace.json` can be opened in Chrome at `chrome://tracing`. This presents a flamegraph-style timeline of kernel launches, memory activity, and execution order, which can be especially helpful for understanding the global structure and scheduling behavior of the code. For a brief tutorial on how to navigate this view, see [here](https://www.youtube.com/watch?v=AhIOohJYSrw). Note that the trace file can become very large (e.g., over 1 GB) depending on how much code is profiled. This may slow down both trace generation and visualization. While `.key_averages()` provides a fast, summary-level view that is often sufficient for identifying key bottlenecks, the flamegraph timeline can be equally valuable for temporal insights. ================================================ FILE: scripts/selfcap_preprocessing.py ================================================ """ SelfCap dataset (https://zju3dv.github.io/longvolcap/) Download the dataset (but first fill in the form at https://forms.gle/MzJqZjBfyZ53fRMZ7): ```bash mkdir -p datasets/selfcap cd datasets/selfcap gdown --fuzzy https://drive.google.com/file/d/1iTr6sTVQoCtTK4FbA3lRxMrh7sC0MhzP/view?usp=share_link # LICENSE gdown --fuzzy https://drive.google.com/file/d/1cg54hE_IBsnVXuMCj44JCQEGnqU1Hr5b/view?usp=share_link # yoga-calib.tar.gz gdown --fuzzy https://drive.google.com/file/d/1l84Pna4eO9m_bql2mR8nm6VnLO80e717/view?usp=share_link # hair-calib.tar.gz gdown --fuzzy https://drive.google.com/file/d/1Desj7th500-vsyRYzRq8Xb6TtUgDPU4u/view?usp=share_link # README.md gdown --fuzzy https://drive.google.com/file/d/1Ex3OtLmz6kBbgB84MImlDLJpVE6vI3ks/view?usp=share_link # bike-release.tar.gz gdown --fuzzy https://drive.google.com/file/d/1muPLxdCm4il_X6TRVLaxx-6sYO6XYIwH/view?usp=share_link # yoga-release.tar.gz gdown --fuzzy https://drive.google.com/file/d/12mRUCpaTk1XearBq2hUIf5ZbHZw4AQAw/view?usp=share_link # dance-release.tar.gz gdown --fuzzy https://drive.google.com/file/d/1AEiQBC9CIthR97qZeZzkH2nlXXpogfxH/view?usp=share_link # hair-release.tar.gz gdown --fuzzy https://drive.google.com/file/d/1NFrHh-SxUER4jWBV0irnCcDhEmkg3WUg/view?usp=share_link # corgi-release.tar.gz gdown --fuzzy https://drive.google.com/file/d/1b9Hf3YY_usPrtddgpMe569dSqh0bEGLo/view?usp=share_link # bar-release.tar.gz tar xvf bar-release.tar.gz tar xvf bike-release.tar.gz tar xvf corgi-release.tar.gz tar xvf dance-release.tar.gz tar xvf hair-calib.tar.gz tar xvf hair-release.tar.gz tar xvf yoga-calib.tar.gz tar xvf yoga-release.tar.gz rm *.tar.gz cd - ``` Running the script: `PYTHONPATH=/local/home/frrajic/xode/duster:$PYTHONPATH python -m scripts.selfcap_preprocessing` Note that you need to set up dust3r first, see docstring of `scripts/estimate_depth_with_duster.py`. 
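The script saves each processed scene as `<scene_name>.pkl` in the outputs directory: a plain pickle of a dict with per-camera `rgbs` ([T, 3, H, W] uint8), `intrs` ([3, 3]), `extrs` ([3, 4]), and `ego_cam_name`. A minimal loading sketch (the path below is just an example output):

```python
import pickle

# Example output path; adjust to your outputs_dir and scene name.
with open("datasets/selfcap-processed/yoga.pkl", "rb") as f:
    sample = pickle.load(f)

for cam_name, rgb in sample["rgbs"].items():
    print(cam_name, rgb.shape, sample["intrs"][cam_name].shape, sample["extrs"][cam_name].shape)
```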
""" import concurrent.futures import json import os import pickle from typing import Optional import cv2 import numpy as np import rerun as rr from scipy.spatial.transform import Rotation as R from tqdm import tqdm from scripts.egoexo4d_preprocessing import main_estimate_duster_depth def main_preprocess_selfcap( dataset_root: str, scene_name: str, outputs_dir: str, num_cameras: Optional[int] = None, sample_cameras_sequentially: Optional[bool] = False, start_frame: Optional[int] = None, max_frames: Optional[int] = None, frames_downsampling_factor: Optional[int] = None, downscaled_longerside: Optional[int] = None, save_rerun_viz: bool = True, stream_rerun_viz: bool = False, skip_if_output_exists: bool = True, ): # Skip if output exists save_pkl_path = os.path.join(outputs_dir, f"{scene_name}.pkl") if skip_if_output_exists and os.path.exists(save_pkl_path): print(f"Skipping {save_pkl_path} since it already exists") print() return save_pkl_path else: print(f"Processing {scene_name}...") # --- Load calibration --- calib_dir = os.path.join(dataset_root, f"{scene_name}-calib", "optimized") intri_path = os.path.join(calib_dir, "intri.yml") extri_path = os.path.join(calib_dir, "extri.yml") sync_path = os.path.join(calib_dir, "sync.json") assert all(os.path.exists(p) for p in [intri_path, extri_path, sync_path]) intri_fs = cv2.FileStorage(intri_path, cv2.FILE_STORAGE_READ) extri_fs = cv2.FileStorage(extri_path, cv2.FILE_STORAGE_READ) with open(sync_path) as f: sync_data = json.load(f) # --- Load videos --- video_dir = os.path.join(dataset_root, f"{scene_name}-release", "videos") cam_names = sorted([f.replace(".mp4", "") for f in os.listdir(video_dir) if f.endswith(".mp4")]) if num_cameras is not None and num_cameras < len(cam_names): if sample_cameras_sequentially: cam_names = cam_names[:num_cameras] else: step = len(cam_names) / num_cameras cam_names = [cam_names[int(i * step)] for i in range(num_cameras)] rgbs, intrs, extrs = {}, {}, {} def load_cam_video(cam): vid_path = os.path.join(video_dir, f"{cam}.mp4") cap = cv2.VideoCapture(vid_path) fps = cap.get(cv2.CAP_PROP_FPS) offset = int(round(sync_data[cam] * fps)) frames = [] i = 0 while cap.isOpened(): ret, frame = cap.read() if not ret: break idx = i - offset i += 1 if idx < 0: continue if start_frame is not None and idx < start_frame: continue if frames_downsampling_factor and ((idx - start_frame) % frames_downsampling_factor != 0): continue if max_frames and len(frames) >= max_frames: break img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB).transpose(2, 0, 1) frames.append(img) cap.release() if not frames: return None, None, None rgb = np.stack(frames) intr = intri_fs.getNode(f"K_{cam}").mat().astype(np.float32) R = extri_fs.getNode(f"Rot_{cam}").mat().astype(np.float32) T = extri_fs.getNode(f"T_{cam}").mat().astype(np.float32).reshape(3) extr = np.concatenate([R, T[:, None]], axis=1) return cam, rgb, intr, extr # Run parallel loading with concurrent.futures.ThreadPoolExecutor(max_workers=8) as executor: futures = [executor.submit(load_cam_video, cam) for cam in cam_names] for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures)): cam, rgb, intr, extr = future.result() if cam is None: print("Warning: camera skipped due to no usable frames.") continue rgbs[cam] = rgb intrs[cam] = intr extrs[cam] = extr intri_fs.release() extri_fs.release() # Apply a global -90° rotation around X axis to the scene rot_x = R.from_euler('x', -90, degrees=True).as_matrix() rot_y = R.from_euler('y', 0, degrees=True).as_matrix() rot_z = 
R.from_euler('z', 0, degrees=True).as_matrix() rot = rot_z @ rot_y @ rot_x T_rot = np.eye(4) T_rot[:3, :3] = rot for cam in extrs: extrs_square = np.eye(4, dtype=extrs[cam].dtype) extrs_square[:3, :] = extrs[cam] extrs_trans_square = np.einsum('ki,ij->kj', extrs_square, T_rot.T) extrs_trans = extrs_trans_square[..., :3, :] assert np.allclose(extrs_trans_square[..., 3, 3], np.ones_like(extrs_trans_square[..., 3, 3])) extrs[cam] = extrs_trans print(f"Loaded SelfCap scene '{scene_name}' with {len(cam_names)} cams and {rgbs[cam_names[0]].shape[0]} frames.") # Check shapes n_frames, _, h, w = rgbs[cam_names[0]].shape for cam_name in cam_names: assert rgbs[cam_name].shape == (n_frames, 3, h, w) assert intrs[cam_name].shape == (3, 3) assert extrs[cam_name].shape == (3, 4) # Save downsized version if downscaled_longerside is not None: print(f"Downscaling to longer side {downscaled_longerside}") for cam_name in tqdm(cam_names, desc="Downscaling"): _, _, h, w = rgbs[cam_name].shape scale = downscaled_longerside / max(h, w) new_h, new_w = int(h * scale), int(w * scale) resized = [] for img in rgbs[cam_name]: img = img.transpose(1, 2, 0) # CHW -> HWC img = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_AREA) resized.append(img.transpose(2, 0, 1)) # HWC -> CHW rgbs[cam_name] = np.stack(resized) # scale intrinsics intrs[cam_name][:2] *= scale # Save processed output to a pickle file os.makedirs(outputs_dir, exist_ok=True) with open(save_pkl_path, "wb") as f: pickle.dump( dict( rgbs=rgbs, intrs=intrs, extrs=extrs, ego_cam_name=None, ), f, protocol=pickle.HIGHEST_PROTOCOL, ) print(f"Saved {save_pkl_path}") # Visualize the data sample using rerun rerun_modes = [] if stream_rerun_viz: rerun_modes += ["stream"] if save_rerun_viz: rerun_modes += ["save"] for rerun_mode in rerun_modes: rr.init(f"3dpt", recording_id="v0.16") if rerun_mode == "stream": rr.connect_tcp() rr.log("world", rr.ViewCoordinates.RIGHT_HAND_Z_UP, static=True) rr.set_time_seconds("frame", 0) rr.log( "world/xyz", rr.Arrows3D( vectors=[[1, 0, 0], [0, 2, 0], [0, 0, 3]], colors=[[255, 0, 0], [0, 255, 0], [0, 0, 255]], ), ) fps = 30 for frame_idx in range(min(n_frames, 30)): rr.set_time_seconds("frame", frame_idx / fps) for cam_name in cam_names: extr = extrs[cam_name] intr = intrs[cam_name] img = rgbs[cam_name][frame_idx].transpose(1, 2, 0).astype(np.uint8) # Camera pose logging E = extr if extr.shape == (3, 4) else extr[0] T = np.eye(4) T[:3, :] = E T_world_cam = np.linalg.inv(T) rr.log(f"{cam_name}/image", rr.Transform3D( translation=T_world_cam[:3, 3], mat3x3=T_world_cam[:3, :3], )) # Intrinsics and image rr.log(f"{cam_name}/image", rr.Pinhole( image_from_camera=intr, width=img.shape[1], height=img.shape[0] )) rr.log(f"{cam_name}/image", rr.Image(img)) if rerun_mode == "save": save_rrd_path = os.path.join(outputs_dir, f"rerun__{scene_name}.rrd") rr.save(save_rrd_path) print(f"Saved rerun viz to {os.path.abspath(save_rrd_path)}") return save_pkl_path if __name__ == '__main__': dataset_root = "datasets/selfcap/" outputs_dir = "datasets/selfcap-processed/" for scene_name in ["yoga", "hair"]: for num_cameras, sequential_cams, start_frame, max_frames, frames_downsampling_factor, downscaled_longerside in [ (8, False, 90, 256, 10, 512), (8, True, 90, 256, 10, 512), (8, False, 90, 2560, 10, 512), (4, False, 90, 256, 10, 512), (4, True, 90, 256, 10, 512), (16, False, 90, 256, 10, 512), (16, True, 90, 256, 10, 512), (16, True, 90, 2560, 10, 512), (8, False, 90, 256, 1, 512), (8, False, 90, 2560, 1, 512), (8, False, 90, 256, 5, 512), (8, 
False, 90, 256, 20, 512), (8, False, 90, 256, 30, 512), ]: # Extract rgbs, intrs, extrs from SelfCap outputs_subdir = os.path.join( outputs_dir, f"numcams-{num_cameras}-seq-{sequential_cams}_" f"startframe-{start_frame}_" f"maxframes-{max_frames}_" f"downsample-{frames_downsampling_factor}_" f"downscale-{downscaled_longerside}" ) scene_pkl = main_preprocess_selfcap( dataset_root=dataset_root, scene_name=scene_name, outputs_dir=outputs_subdir, num_cameras=num_cameras, sample_cameras_sequentially=sequential_cams, start_frame=start_frame, max_frames=max_frames, frames_downsampling_factor=frames_downsampling_factor, downscaled_longerside=downscaled_longerside, ) # Run Dust3r to estimate depths from rgbs, fix the known intrs and extrs during multi-view stereo optim depth_subdir = os.path.join(outputs_subdir, f"duster_depths__{scene_name}") main_estimate_duster_depth( pkl_scene_file=scene_pkl, depths_output_dir=depth_subdir, ) # Run VGGT to estimate depths from rgbs, align with the known extrs afterward ... ================================================ FILE: scripts/slurm/eval.sh ================================================ #!/bin/bash #SBATCH --job-name=eval-058 #SBATCH --nodes=1 #SBATCH --ntasks-per-node=1 #SBATCH --cpus-per-task=32 #SBATCH --gres=gpu:1 #SBATCH --mem=460000 #SBATCH --partition=normal #SBATCH --account=a-a03 #SBATCH --time=00:10:00 #SBATCH --dependency=singleton #SBATCH --mail-type=begin #SBATCH --mail-type=end #SBATCH --mail-user=frano.rajic@inf.ethz.ch #SBATCH --output=./logs/slurm_logs/%x-%j.out #SBATCH --array=0-85 set -x cat $0 DIR=$(realpath .) mkdir -p $DIR/runs CKPTS=( # "experiment_path=logs/eval/copycat model=copycat" # "experiment_path=logs/dynamic_3dgs model=locotrack" # "experiment_path=logs/shape_of_motion model=locotrack" # # "experiment_path=logs/eval/tapip3d model=tapip3d" # "experiment_path=logs/eval/scenetracker model=scenetracker" # "experiment_path=logs/eval/locotrack model=locotrack" # "experiment_path=logs/eval/delta model=delta" # "experiment_path=logs/eval/cotracker2_online model=cotracker2_online" # "experiment_path=logs/eval/cotracker3_online model=cotracker3_online" # # "experiment_path=logs/eval/spatracker_monocular_pretrained model=spatracker_monocular_pretrained restore_ckpt_path=checkpoints/spatracker_monocular_original-authors-ckpt.pth" # "experiment_path=logs/eval/spatracker_monocular_kubric-training model=spatracker_monocular restore_ckpt_path=checkpoints/spatracker_monocular_trained-on-kubric-depth_069800.pth" # "experiment_path=logs/eval/spatracker_monocular_duster-training model=spatracker_monocular restore_ckpt_path=checkpoints/spatracker_monocular_trained-on-duster-depth_090800.pth" # "experiment_path=logs/eval/spatracker_multiview_kubric-training model=spatracker_multiview restore_ckpt_path=checkpoints/spatracker_multiview_trained-on-kubric-depth_100000.pth model.triplane_xres=128 model.triplane_yres=128 model.triplane_zres=128" # "experiment_path=logs/eval/spatracker_multiview_duster-training model=spatracker_multiview restore_ckpt_path=checkpoints/spatracker_multiview_trained-on-duster-depth_100000.pth model.triplane_xres=256 model.triplane_yres=256 model.triplane_zres=128" # # "experiment_path=logs/eval/mvtracker-v0_kubric-training model=mvtracker restore_ckpt_path=checkpoints/mvtracker_v0_trained-on-kubric-depth_091600.pth model.updatetransformer_type=spatracker model.apply_sigmoid_to_vis=true trainer.precision=16-mixed model.fmaps_dim=384 model.hidden_size=384 model.num_heads=8" # 
"experiment_path=logs/eval/mvtracker-v0_duster-training model=mvtracker restore_ckpt_path=checkpoints/mvtracker_v0_trained-on-duster-depth_100000.pth model.updatetransformer_type=spatracker model.apply_sigmoid_to_vis=true trainer.precision=16-mixed model.fmaps_dim=384" # "experiment_path=logs/eval/mvtracker-iccv-march2025 model=mvtracker restore_ckpt_path=checkpoints/mvtracker_160000_march2025.pth model.updatetransformer_type=spatracker model.apply_sigmoid_to_vis=true trainer.precision=16-mixed " "experiment_path=logs/mvtracker model=mvtracker restore_ckpt_path=checkpoints/mvtracker_200000_june2025.pth" ) DATASETS=( ############################ ### ~~~ Main results ~~~ ### ############################ dex-ycb-multiview dex-ycb-multiview-duster0123 dex-ycb-multiview-duster0123cleaned panoptic-multiview-views1_7_14_20 panoptic-multiview-views27_16_14_8 panoptic-multiview-views1_4_7_11 kubric-multiview-v3-views0123 kubric-multiview-v3-duster0123 kubric-multiview-v3-duster0123cleaned tapvid2d-davis-mogewithextrinsics-256x256 ############################# ### ~~~ 2DPT Ablation ~~~ ### ############################# dex-ycb-multiview-2dpt dex-ycb-multiview-duster0123-2dpt panoptic-multiview-views1_7_14_20-2dpt panoptic-multiview-views27_16_14_8-2dpt panoptic-multiview-views1_4_7_11-2dpt kubric-multiview-v3-views0123-2dpt kubric-multiview-v3-duster0123-2dpt #################################### ### ~~~ Single-point results ~~~ ### #################################### # dex-ycb-multiview-single # dex-ycb-multiview-duster0123-single # dex-ycb-multiview-duster0123cleaned-single # panoptic-multiview-views1_7_14_20-single # panoptic-multiview-views27_16_14_8-single # panoptic-multiview-views1_4_7_11-single # kubric-multiview-v3-views0123-single # kubric-multiview-v3-duster0123-single # kubric-multiview-v3-duster0123cleaned-single # tapvid2d-davis-mogewithextrinsics-256x256-single ##################################### ### ~~~ Camera-setup Ablation ~~~ ### ##################################### panoptic-multiview-views1_7_14_20 panoptic-multiview-views27_16_14_8 panoptic-multiview-views1_4_7_11 dex-ycb-multiview-duster0123 dex-ycb-multiview-duster2345 dex-ycb-multiview-duster4567 ######################################## ### ~~~ Number-of-views Ablation ~~~ ### ######################################## kubric-multiview-v3-views0 kubric-multiview-v3-views01 kubric-multiview-v3-views012 kubric-multiview-v3-views0123 kubric-multiview-v3-views01234 kubric-multiview-v3-views012345 kubric-multiview-v3-views0123456 kubric-multiview-v3-views01234567 kubric-multiview-v3-duster0123-views0 kubric-multiview-v3-duster0123-views01 kubric-multiview-v3-duster0123-views012 kubric-multiview-v3-duster0123-views0123 kubric-multiview-v3-duster01234567-views01234 kubric-multiview-v3-duster01234567-views012345 kubric-multiview-v3-duster01234567-views0123456 kubric-multiview-v3-duster01234567-views01234567 panoptic-multiview-views1 panoptic-multiview-views1_14 panoptic-multiview-views1_7_14 panoptic-multiview-views1_7_14_20 panoptic-multiview-views1_4_7_14_20 panoptic-multiview-views1_4_7_14_17_20 panoptic-multiview-views1_4_7_11_14_17_20 panoptic-multiview-views1_4_7_11_14_17_20_23 dex-ycb-multiview-duster0123-views0 dex-ycb-multiview-duster0123-views01 dex-ycb-multiview-duster0123-views012 dex-ycb-multiview-duster0123-views0123 dex-ycb-multiview-duster01234567-views01234 dex-ycb-multiview-duster01234567-views012345 dex-ycb-multiview-duster01234567-views0123456 dex-ycb-multiview-duster01234567-views01234567 
##################################### ### ~~~ For video comparisons ~~~ ### ##################################### kubric-multiview-v3-views0123-novelviews4 panoptic-multiview-views1_7_14_20-novelviews24 panoptic-multiview-views1_7_14_20-novelviews27 dex-ycb-multiview-duster0123-novelviews4 dex-ycb-multiview-duster0123-novelviews5 dex-ycb-multiview-duster0123-novelviews6 dex-ycb-multiview-duster0123-novelviews7 dex-ycb-multiview-duster2345-novelviews7 dex-ycb-multiview-duster4567-novelviews7 dex-ycb-multiview-duster4567-novelviews0 #################################### ### ~~~ For noise experiment ~~~ ### #################################### kubric-multiview-v3-noise0cm kubric-multiview-v3-noise1cm kubric-multiview-v3-noise2cm kubric-multiview-v3-noise5cm kubric-multiview-v3-noise10cm kubric-multiview-v3-noise20cm kubric-multiview-v3-noise50cm kubric-multiview-v3-noise100cm kubric-multiview-v3-noise200cm kubric-multiview-v3-noise1000cm ) # Compute number of jobs needed NUM_CKPTS=${#CKPTS[@]} NUM_DATASETS=${#DATASETS[@]} TOTAL_JOBS=$((NUM_CKPTS * NUM_DATASETS)) # Check if SLURM_ARRAY_TASK_ID is valid if [ "$SLURM_ARRAY_TASK_ID" -ge "$TOTAL_JOBS" ]; then echo "Error: SLURM_ARRAY_TASK_ID=$SLURM_ARRAY_TASK_ID exceeds the max index $((TOTAL_JOBS-1))" exit 1 fi # Map SLURM_ARRAY_TASK_ID to checkpoint and dataset CKPT_INDEX=$((SLURM_ARRAY_TASK_ID % NUM_CKPTS)) DATASET_INDEX=$((SLURM_ARRAY_TASK_ID / NUM_CKPTS)) SELECTED_CKPT=${CKPTS[$CKPT_INDEX]} SELECTED_DATASET=${DATASETS[$DATASET_INDEX]} echo "Selected Checkpoint: $SELECTED_CKPT" echo "Selected Dataset: $SELECTED_DATASET" # Run the job with the extracted checkpoint & dataset srun -ul --container-writable --environment=my_pytorch_env numactl --membind=0-3 bash -c " source /users/fraji/venvs/spa10/bin/activate && CUDA_VISIBLE_DEVICES=0 TORCH_HOME=./checkpoints/.cache python eval.py $SELECTED_CKPT datasets.eval.names=[$SELECTED_DATASET] " ================================================ FILE: scripts/slurm/mvtracker-nodepthaugs.sh ================================================ #!/bin/bash #SBATCH --job-name=mvtracker_200000_june2025_cleandepths #SBATCH --nodes=2 #SBATCH --ntasks-per-node=4 #SBATCH --cpus-per-task=72 #SBATCH --gres=gpu:4 #SBATCH --mem=460000 #SBATCH --partition=normal #SBATCH --account=a136 #SBATCH --time=12:00:00 #SBATCH --dependency=singleton #SBATCH --mail-type=begin #SBATCH --mail-type=end #SBATCH --mail-user=frano.rajic@inf.ethz.ch #SBATCH --output=./logs/slurm_logs/%x-%j.out #SBATCH --error=./logs/slurm_logs/%x-%j.out #SBATCH --signal=USR1@60 set -euo pipefail set -x cat $0 DIR=$(realpath .) 
# Wrap the commands CMD=" source /users/fraji/venvs/spa10/bin/activate cd $DIR python train.py model=mvtracker \ trainer.num_steps=200000 \ trainer.eval_freq=10000 \ trainer.viz_freq=10000 \ trainer.save_ckpt_freq=500 \ trainer.lr=0.0005 \ datasets.train.traj_per_sample=2048 \ model.updatetransformer_type=cotracker2 \ reproducibility.seed=36 \ trainer.precision=bf16-mixed \ modes.do_initial_static_pretrain=false \ trainer.augment_train_iters=false \ model.apply_sigmoid_to_vis=false \ augmentations.variable_depth_type=false \ logging.log_wandb=true \ experiment_path=logs/${SLURM_JOB_NAME} " # Execute within the container srun -ul --environment=my_pytorch_env bash -c "$CMD" ================================================ FILE: scripts/slurm/mvtracker.sh ================================================ #!/bin/bash #SBATCH --job-name=mvtracker_200000_june2025 #SBATCH --nodes=2 #SBATCH --ntasks-per-node=4 #SBATCH --cpus-per-task=72 #SBATCH --gres=gpu:4 #SBATCH --mem=460000 #SBATCH --partition=normal #SBATCH --account=a136 #SBATCH --time=12:00:00 #SBATCH --dependency=singleton #SBATCH --mail-type=begin #SBATCH --mail-type=end #SBATCH --mail-user=frano.rajic@inf.ethz.ch #SBATCH --output=./logs/slurm_logs/%x-%j.out #SBATCH --error=./logs/slurm_logs/%x-%j.out #SBATCH --signal=USR1@60 set -euo pipefail set -x cat $0 DIR=$(realpath .) # Wrap the commands CMD=" source /users/fraji/venvs/spa10/bin/activate cd $DIR python train.py model=mvtracker \ trainer.num_steps=200000 \ trainer.eval_freq=10000 \ trainer.viz_freq=10000 \ trainer.save_ckpt_freq=500 \ trainer.lr=0.0005 \ datasets.train.traj_per_sample=2048 \ model.updatetransformer_type=cotracker2 \ reproducibility.seed=36 \ trainer.precision=bf16-mixed \ modes.do_initial_static_pretrain=false \ trainer.augment_train_iters=false \ model.apply_sigmoid_to_vis=false \ logging.log_wandb=true \ experiment_path=logs/${SLURM_JOB_NAME} " # Execute within the container srun -ul --environment=my_pytorch_env bash -c "$CMD" ================================================ FILE: scripts/slurm/spatracker.sh ================================================ #!/bin/bash #SBATCH --job-name=spatracker_monocular #SBATCH --nodes=8 #SBATCH --ntasks-per-node=4 #SBATCH --cpus-per-task=72 #SBATCH --gres=gpu:4 #SBATCH --mem=460000 #SBATCH --partition=normal #SBATCH --account=a-a136-1 #SBATCH --time=12:00:00 #SBATCH --dependency=singleton #SBATCH --mail-type=begin #SBATCH --mail-type=end #SBATCH --mail-user=frano.rajic@inf.ethz.ch #SBATCH --output=./logs/slurm_logs/%x-%j.out #SBATCH --error=./logs/slurm_logs/%x-%j.out #SBATCH --signal=USR1@60 set -euo pipefail set -x cat $0 DIR=$(realpath .) 
# Wrap the commands CMD=" source /users/fraji/venvs/spa10/bin/activate cd $DIR python train.py model=spatracker_monocular \ trainer.num_steps=200000 \ trainer.eval_freq=10000 \ trainer.viz_freq=10000 \ trainer.save_ckpt_freq=500 \ trainer.lr=0.001 \ datasets.train.traj_per_sample=512 \ reproducibility.seed=72 \ trainer.precision=bf16-mixed \ modes.do_initial_static_pretrain=true \ trainer.augment_train_iters=true \ experiment_path=logs/${SLURM_JOB_NAME} " # Execute within the container srun -ul --environment=my_pytorch_env bash -c "$CMD" ================================================ FILE: scripts/slurm/test_reproducibility.sh ================================================ #!/bin/bash #SBATCH --job-name=repro-test-mvtracker #SBATCH --nodes=2 #SBATCH --ntasks-per-node=4 #SBATCH --cpus-per-task=32 #SBATCH --gres=gpu:4 #SBATCH --mem=460000 #SBATCH --partition=normal #SBATCH --account=a-a03 #SBATCH --time=00:20:00 #SBATCH --output=./logs/slurm_logs/%x-%j.out set -euo pipefail set -x cat $0 DIR=$(realpath .) cd "$DIR" # Use job ID for run directory RUN1="logs/debug/test_repro_${SLURM_JOB_ID}_run1" RUN2="logs/debug/test_repro_${SLURM_JOB_ID}_run2" [[ ! -e "$RUN1" ]] || { echo "ERROR: $RUN1 already exists"; exit 1; } [[ ! -e "$RUN2" ]] || { echo "ERROR: $RUN2 already exists"; exit 1; } # Wrap the commands CMD=" source /users/fraji/venvs/spa10/bin/activate cd $DIR export CUBLAS_WORKSPACE_CONFIG=:4096:8 export PYTHONHASHSEED=0 # === Run 1 === python train.py +experiment=mvtracker_overfit \ datasets.eval.names=[] \ modes.tune_per_scene=false \ trainer.num_steps=10 \ reproducibility.deterministic=true \ dataset.train.num_workers=4 \ trainer.precision=32 \ experiment_path=$RUN1 # === Run 2 === python train.py +experiment=mvtracker_overfit \ datasets.eval.names=[] \ modes.tune_per_scene=false \ trainer.num_steps=10 \ reproducibility.deterministic=true \ dataset.train.num_workers=4 \ trainer.precision=32 \ experiment_path=$RUN2 " # Execute within the container srun -ul --environment=my_pytorch_env bash -c "$CMD" ================================================ FILE: scripts/slurm/triplane-128.sh ================================================ #!/bin/bash #SBATCH --job-name=spatracker_multiview_128 #SBATCH --nodes=8 #SBATCH --ntasks-per-node=4 #SBATCH --cpus-per-task=72 #SBATCH --gres=gpu:4 #SBATCH --mem=460000 #SBATCH --partition=normal #SBATCH --account=a-a136-1 #SBATCH --time=12:00:00 #SBATCH --dependency=singleton #SBATCH --mail-type=begin #SBATCH --mail-type=end #SBATCH --mail-user=frano.rajic@inf.ethz.ch #SBATCH --output=./logs/slurm_logs/%x-%j.out #SBATCH --error=./logs/slurm_logs/%x-%j.out #SBATCH --signal=USR1@60 set -euo pipefail set -x cat $0 DIR=$(realpath .) 
# Wrap the commands CMD=" source /users/fraji/venvs/spa10/bin/activate cd $DIR python train.py model=spatracker_multiview model.triplane_xres=128 model.triplane_yres=128 model.triplane_zres=128 \ trainer.num_steps=200000 \ trainer.eval_freq=10000 \ trainer.viz_freq=10000 \ trainer.save_ckpt_freq=500 \ trainer.lr=0.001 \ datasets.train.traj_per_sample=768 \ reproducibility.seed=36 \ trainer.precision=bf16-mixed \ modes.do_initial_static_pretrain=true \ trainer.augment_train_iters=true \ experiment_path=logs/${SLURM_JOB_NAME} " # Execute within the container srun -ul --environment=my_pytorch_env bash -c "$CMD" ================================================ FILE: scripts/slurm/triplane-256.sh ================================================ #!/bin/bash #SBATCH --job-name=spatracker_multiview_256 #SBATCH --nodes=8 #SBATCH --ntasks-per-node=4 #SBATCH --cpus-per-task=72 #SBATCH --gres=gpu:4 #SBATCH --mem=460000 #SBATCH --partition=normal #SBATCH --account=a-a136-1 #SBATCH --time=12:00:00 #SBATCH --dependency=singleton #SBATCH --mail-type=begin #SBATCH --mail-type=end #SBATCH --mail-user=frano.rajic@inf.ethz.ch #SBATCH --output=./logs/slurm_logs/%x-%j.out #SBATCH --error=./logs/slurm_logs/%x-%j.out #SBATCH --signal=USR1@60 set -euo pipefail set -x cat $0 DIR=$(realpath .) # Wrap the commands CMD=" source /users/fraji/venvs/spa10/bin/activate cd $DIR python train.py model=spatracker_multiview model.triplane_xres=256 model.triplane_yres=256 model.triplane_zres=128 \ trainer.num_steps=200000 \ trainer.eval_freq=10000 \ trainer.viz_freq=10000 \ trainer.save_ckpt_freq=500 \ trainer.lr=0.001 \ datasets.train.traj_per_sample=384 \ reproducibility.seed=36 \ trainer.precision=bf16-mixed \ modes.do_initial_static_pretrain=true \ trainer.augment_train_iters=true \ experiment_path=logs/${SLURM_JOB_NAME} " # Execute within the container srun -ul --environment=my_pytorch_env bash -c "$CMD" ================================================ FILE: scripts/summarize_eval_results.py ================================================ import os import re import warnings import pandas as pd REMAP_KUBRIC = { "Method": ("", "Method"), "average_jaccard__dynamic": ("Dynamic Points (motion > 0.1)", "Jacc."), "jaccard_0.05__dynamic": ("Dynamic Points (motion > 0.1)", "< 0.05"), "jaccard_0.10__dynamic": ("Dynamic Points (motion > 0.1)", "< 0.1"), "jaccard_0.20__dynamic": ("Dynamic Points (motion > 0.1)", "< 0.2"), "jaccard_0.40__dynamic": ("Dynamic Points (motion > 0.1)", "< 0.4"), "jaccard_0.80__dynamic": ("Dynamic Points (motion > 0.1)", "< 0.8"), "average_pts_within_thresh__dynamic": ("Dynamic Points (motion > 0.1)", "Loc."), "pts_within_0.05__dynamic": ("Dynamic Points (motion > 0.1)", "< 0.05"), "pts_within_0.10__dynamic": ("Dynamic Points (motion > 0.1)", "< 0.1"), "pts_within_0.20__dynamic": ("Dynamic Points (motion > 0.1)", "< 0.2"), "pts_within_0.40__dynamic": ("Dynamic Points (motion > 0.1)", "< 0.4"), "pts_within_0.80__dynamic": ("Dynamic Points (motion > 0.1)", "< 0.8"), "survival__dynamic": ("Dynamic Points (motion > 0.1)", "Surv."), "occlusion_accuracy__dynamic": ("Dynamic Points (motion > 0.1)", "OA"), "mte_visible__dynamic": ("Dynamic Points (motion > 0.1)", "MTE"), "ate_visible__dynamic": ("Dynamic Points (motion > 0.1)", "ATE"), "fde_visible__dynamic": ("Dynamic Points (motion > 0.1)", "FDE"), "n__dynamic": ("Dynamic Points (motion > 0.1)", "n"), "v__dynamic": ("Dynamic Points (motion > 0.1)", "v"), "average_jaccard__very_dynamic": ("Very Dynamic", "Jacc."), "average_pts_within_thresh__very_dynamic": ("Very 
Dynamic", "Loc."), "survival__very_dynamic": ("Very Dynamic", "Surv."), "occlusion_accuracy__very_dynamic": ("Very Dynamic", "OA"), "mte_visible__very_dynamic": ("Very Dynamic", "MTE"), "average_jaccard__static": ("Static Points (motion < 0.01)", "Jacc."), "average_pts_within_thresh__static": ("Static Points (motion < 0.01)", "Loc."), "survival__static": ("Static Points (motion < 0.01)", "Surv."), "occlusion_accuracy__static": ("Static Points (motion < 0.01)", "OA"), "mte_visible__static": ("Static Points (motion < 0.01)", "MTE"), "average_jaccard__any": ("Any Points", "Jacc."), "average_pts_within_thresh__any": ("Any Points", "Loc."), "survival__any": ("Any Points", "Surv."), "occlusion_accuracy__any": ("Any Points", "OA"), "mte_visible__any": ("Any Points", "MTE"), "n_iters": ("", "#iters"), } REMAP_DEXYCB_V1 = { "Method": ("", "Method"), "average_jaccard__dynamic": ("Dynamic Points (motion > 0.1)", "Jacc."), "jaccard_0.01__dynamic": ("Dynamic Points (motion > 0.1)", "< 0.01"), "jaccard_0.02__dynamic": ("Dynamic Points (motion > 0.1)", "< 0.02"), "jaccard_0.05__dynamic": ("Dynamic Points (motion > 0.1)", "< 0.05"), "jaccard_0.10__dynamic": ("Dynamic Points (motion > 0.1)", "< 0.10"), "jaccard_0.20__dynamic": ("Dynamic Points (motion > 0.1)", "< 0.20"), "average_pts_within_thresh__dynamic": ("Dynamic Points (motion > 0.1)", "Loc."), "pts_within_0.01__dynamic": ("Dynamic Points (motion > 0.1)", "< 0.01"), "pts_within_0.02__dynamic": ("Dynamic Points (motion > 0.1)", "< 0.02"), "pts_within_0.05__dynamic": ("Dynamic Points (motion > 0.1)", "< 0.05"), "pts_within_0.10__dynamic": ("Dynamic Points (motion > 0.1)", "< 0.10"), "pts_within_0.20__dynamic": ("Dynamic Points (motion > 0.1)", "< 0.20"), "survival__dynamic": ("Dynamic Points (motion > 0.1)", "Surv."), "occlusion_accuracy__dynamic": ("Dynamic Points (motion > 0.1)", "OA"), "mte_visible__dynamic": ("Dynamic Points (motion > 0.1)", "MTE"), "ate_visible__dynamic": ("Dynamic Points (motion > 0.1)", "ATE"), "fde_visible__dynamic": ("Dynamic Points (motion > 0.1)", "FDE"), "n__dynamic": ("Dynamic Points (motion > 0.1)", "n"), "v__dynamic": ("Dynamic Points (motion > 0.1)", "v"), "average_jaccard__very_dynamic": ("Very Dynamic", "Jacc."), "average_pts_within_thresh__very_dynamic": ("Very Dynamic", "Loc."), "survival__very_dynamic": ("Very Dynamic", "Surv."), "occlusion_accuracy__very_dynamic": ("Very Dynamic", "OA"), "average_jaccard__static": ("Static Points (motion < 0.01)", "Jacc."), "average_pts_within_thresh__static": ("Static Points (motion < 0.01)", "Loc."), "survival__static": ("Static Points (motion < 0.01)", "Surv."), "occlusion_accuracy__static": ("Static Points (motion < 0.01)", "OA"), "average_jaccard__any": ("Any Points", "Jacc."), "average_pts_within_thresh__any": ("Any Points", "Loc."), "survival__any": ("Any Points", "Surv."), "occlusion_accuracy__any": ("Any Points", "OA"), "n_iters": ("", "#iters"), } # Initialize remapping dictionary with the correct order REMAP_DEXYCB_V2 = {} REMAP_DEXYCB_V2["Method"] = ("", "Method") # Define ordered point categories (dynamic first, then very dynamic, static, and any) POINT_TYPES = { "dynamic": "Dynamic Points (motion > 0.1)", "very_dynamic": "Very Dynamic", "static": "Static Points (motion < 0.01)", "any": "Any Points", "dynamic-static-mean": "Dynamic+Static Points Mean", } METRICS = { "average_jaccard": "Jacc.", "jaccard": "<{threshold}", "average_pts_within_thresh": "Loc.", "pts_within": "<{threshold}", "survival": "Surv.", "occlusion_accuracy": "OA", "occlusion_accuracy_for_vis0": 
"OA(v=0)", "occlusion_accuracy_for_vis1": "OA(v=1)", "mte_visible": "MTE", "ate_visible": "ATE", "fde_visible": "FDE", "n": "n", "v": "v" } THRESHOLDS = ["0.01", "0.02", "0.05", "0.10", "0.20"] for pt_key, pt_label in POINT_TYPES.items(): for metric, metric_label in METRICS.items(): if metric in ["jaccard", "pts_within"]: # Threshold-based metrics for thresh in THRESHOLDS: REMAP_DEXYCB_V2[f"{metric}_{thresh}__{pt_key}"] = (pt_label, metric_label.format(threshold=thresh)) else: # Regular metrics REMAP_DEXYCB_V2[f"{metric}__{pt_key}"] = (pt_label, metric_label) REMAP_DEXYCB_V2["n_iters"] = ("", "#iters") REMAP_TAPVID2D_INDEX_NAMES = ["Metric Definition", "Metric"] REMAP_TAPVID2D = { "Method": ("", "Method",), "average_jaccard__any": ("Our Metrics", "Jacc.",), "jaccard_1.00__any": ("Our Metrics", "< 1",), "jaccard_2.00__any": ("Our Metrics", "< 2",), "jaccard_4.00__any": ("Our Metrics", "< 4",), "jaccard_8.00__any": ("Our Metrics", "< 8",), "jaccard_16.00__any": ("Our Metrics", "< 16",), "average_pts_within_thresh__any": ("Our Metrics", "Loc.",), "pts_within_1.00__any": ("Our Metrics", "< 1",), "pts_within_2.00__any": ("Our Metrics", "< 2",), "pts_within_4.00__any": ("Our Metrics", "< 4",), "pts_within_8.00__any": ("Our Metrics", "< 8",), "pts_within_16.00__any": ("Our Metrics", "< 16",), "survival__any": ("Our Metrics", "Surv.",), "occlusion_accuracy__any": ("Our Metrics", "OA",), "occlusion_accuracy_for_vis0__any": ("Our Metrics", "OA(v=0)",), "occlusion_accuracy_for_vis1__any": ("Our Metrics", "OA(v=1)",), "mte_visible__any": ("Our Metrics", "MTE",), "ate_visible__any": ("Our Metrics", "ATE",), "fde_visible__any": ("Our Metrics", "FDE",), "n__any": ("Our Metrics", "n",), "v__any": ("Our Metrics", "v",), "tapvid2d_average_jaccard": ("TAPVid-2D Metrics", "Jacc.",), "tapvid2d_jaccard_1": ("TAPVid-2D Metrics", "< 1",), "tapvid2d_jaccard_2": ("TAPVid-2D Metrics", "< 2",), "tapvid2d_jaccard_4": ("TAPVid-2D Metrics", "< 4",), "tapvid2d_jaccard_8": ("TAPVid-2D Metrics", "< 8",), "tapvid2d_jaccard_16": ("TAPVid-2D Metrics", "< 16",), "tapvid2d_average_pts_within_thresh": ("TAPVid-2D Metrics", "Loc.",), "tapvid2d_pts_within_1": ("TAPVid-2D Metrics", "< 1",), "tapvid2d_pts_within_2": ("TAPVid-2D Metrics", "< 2",), "tapvid2d_pts_within_4": ("TAPVid-2D Metrics", "< 4",), "tapvid2d_pts_within_8": ("TAPVid-2D Metrics", "< 8",), "tapvid2d_pts_within_16": ("TAPVid-2D Metrics", "< 16",), "tapvid2d_occlusion_accuracy": ("TAPVid-2D Metrics", "OA",), "n_iters": ("", "#iters",), } REMAP_PANOPTIC = {} REMAP_PANOPTIC["Method"] = ("", "Method") for pt_key in ["any"]: pt_label = POINT_TYPES[pt_key] for metric, metric_label in METRICS.items(): if metric in ["jaccard", "pts_within"]: # Threshold-based metrics for thresh in ["0.05", "0.10", "0.20", "0.40"]: REMAP_PANOPTIC[f"{metric}_{thresh}__{pt_key}"] = (pt_label, metric_label.format(threshold=thresh)) else: # Regular metrics REMAP_PANOPTIC[f"{metric}__{pt_key}"] = (pt_label, metric_label) REMAP_PANOPTIC["n_iters"] = ("", "#iters") PARTIAL_REMAP_FOR_2DPT_ABLATION = {} for pt_key, pt_label in POINT_TYPES.items(): for metric, metric_label in METRICS.items(): if "jaccard" in metric or "occlusion" in metric: continue if metric in ["jaccard", "pts_within"]: # Threshold-based metrics for thresh in ["1.00", "2.00", "4.00", "8.00", "16.00"]: PARTIAL_REMAP_FOR_2DPT_ABLATION[f"2dpt__{metric}_{thresh}__{pt_key}"] = ( "(2DPT) " + pt_label, metric_label.format(threshold=thresh) ) else: # Regular metrics PARTIAL_REMAP_FOR_2DPT_ABLATION[f"2dpt__{metric}__{pt_key}"] = ("(2DPT) " + 
pt_label, metric_label) for logged_key, (pt_label, metric_label) in REMAP_TAPVID2D.items(): if "jaccard" in logged_key or "occlusion" in logged_key: continue if "tapvid2d" not in logged_key: continue PARTIAL_REMAP_FOR_2DPT_ABLATION[f"2dpt__{logged_key}"] = ("(2DPT) " + pt_label, metric_label) REMAP_2DPT_ABLATION = REMAP_KUBRIC | PARTIAL_REMAP_FOR_2DPT_ABLATION ONE_REMAP_TO_RULE_THEM_ALL = {} ONE_REMAP_TO_RULE_THEM_ALL["Method"] = ("", "Method") ONE_REMAP_TO_RULE_THEM_ALL["Dataset"] = ("", "Dataset") THRESHOLDS = ["0.01", "0.02", "0.05", "0.10", "0.20", "0.40"] for pt_key, pt_label in POINT_TYPES.items(): for metric, metric_label in METRICS.items(): if metric in ["jaccard", "pts_within"]: # Threshold-based metrics for thresh in THRESHOLDS: ONE_REMAP_TO_RULE_THEM_ALL[f"{metric}_{thresh}__{pt_key}"] = ( pt_label, metric_label.format(threshold=thresh)) else: # Regular metrics ONE_REMAP_TO_RULE_THEM_ALL[f"{metric}__{pt_key}"] = (pt_label, metric_label) ONE_REMAP_TO_RULE_THEM_ALL["n_iters"] = ("", "#iters") def find_file_with_max_steps(folder): if not os.path.isdir(folder): return None, -1 pattern = re.compile(r"step-(\d+)_metrics_avg.csv") max_steps = -1 max_file = None for filename in os.listdir(folder): m = pattern.search(filename) if m: steps = int(m.group(1)) if steps > max_steps: max_steps = steps max_file = filename return max_file, max_steps def create_table( method_name_to_csv_path, remap=REMAP_KUBRIC, remap_index_names=["Type", "Metric"], header=True, skip_missing=False, ): assert len(method_name_to_csv_path) > 0, "No CSV files provided" rows = [] order = [] for method_name, path in method_name_to_csv_path.items(): if "step-?_" in path: filename, n_iters = find_file_with_max_steps(os.path.dirname(path)) if filename is None: warnings.warn(f"No CSV files found in {os.path.dirname(path)}") continue path = os.path.join(os.path.dirname(path), filename) if not os.path.exists(path): if skip_missing: warnings.warn(f"Skipping missing file: {path}") continue raise FileNotFoundError(f"File not found: {path}") df = pd.read_csv(path, header=None, names=["Metric", "Value"]) df = df.dropna(subset=["Metric"]).reset_index(drop=True) if type(method_name) == tuple: method_name, dataset_name = method_name else: dataset_name = os.path.basename(os.path.dirname(path)).replace("eval_", "") df["Method"] = method_name match = re.search(r"step-(\d+)", path) n_iters = int(match.group(1)) if match else 0 df.loc[len(df)] = ["n_iters", n_iters, method_name] df["Metric"] = df["Metric"].str.split("/").str[-1].str.replace("model__", "") df["Dataset"] = dataset_name rows.append(df) order.append((method_name, dataset_name)) combined_df = pd.concat(rows) pivot_df = combined_df.pivot(index=["Method", "Dataset"], columns="Metric", values="Value").reset_index() pivot_df = pivot_df.set_index(["Method", "Dataset"]).reindex(order).reset_index() # Define a mapping for the new names for k in remap.keys(): if k not in pivot_df.columns: pivot_df[k] = None pivot_df = pivot_df.copy() # To avoid "DataFrame is highly fragmented" warning pivot_df = pivot_df[remap.keys()] multi_index = pd.MultiIndex.from_tuples( tuples=[remap[col] for col in pivot_df.columns], names=remap_index_names, ) pivot_df.columns = multi_index return pivot_df, pivot_df.to_csv(index=False, header=header) def kubric_single_point(): print("Kubric single-point evaluation results:") print("================================") df, csv_str = create_table({ # ls logs/kubric_v3/*/eval_kubric-multiview-v3-single/step-*_kubric-multiview-v3-single_metrics_avg.csv | cat 
"SpaTracker (pretrained)": "logs/kubric_v3/multiview-adapter-pretrained-004/eval_kubric-multiview-v3-single/step--1_kubric-multiview-v3-single_metrics_avg.csv", "SpaTracker (single-view baseline)": "logs/kubric_v3/multiview-adapter-002/eval_kubric-multiview-v3-single/step-69799_kubric-multiview-v3-single_metrics_avg.csv", "Multi-view-V1 (ours)": "logs/kubric_v3/multiview-v1-with-128-triplane-001/eval_kubric-multiview-v3-single/step-99999_kubric-multiview-v3-single_metrics_avg.csv", "Multi-view-V2 (ours)": "logs/kubric_v3/multiview-v2-002--k-16--fmaps-384--groups-1--levels-4--grad-clip-1--iters-4--window-12/eval_kubric-multiview-v3-single/step-91599_kubric-multiview-v3-single_metrics_avg.csv", }) print(csv_str) def kubric_before_gt0123(): print("Kubric multi-point evaluation results:") print("================================") df, csv_str = create_table({ # ls logs/kubric_v3/*/eval_kubric-multiview-v3/step-*_kubric-multiview-v3_metrics_avg.csv | cat "CopyCat (No motion baseline)": "logs/copycat/eval_kubric-multiview-v3/step--1_kubric-multiview-v3_metrics_avg.csv", "SpaTracker (pretrained)": "logs/kubric_v3/multiview-adapter-pretrained-004/eval_kubric-multiview-v3/step--1_kubric-multiview-v3_metrics_avg.csv", "SpaTracker (single-view baseline)": "logs/kubric_v3/multiview-adapter-002/eval_kubric-multiview-v3/step-69799_kubric-multiview-v3_metrics_avg.csv", "Multi-view-V1 (ours)": "logs/kubric_v3/multiview-v1-with-128-triplane-001/eval_kubric-multiview-v3/step-99999_kubric-multiview-v3_metrics_avg.csv", "Multi-view-V2 (ours)": "logs/kubric_v3/multiview-v2-002--k-16--fmaps-384--groups-1--levels-4--grad-clip-1--iters-4--window-12/eval_kubric-multiview-v3/step-91599_kubric-multiview-v3_metrics_avg.csv", "Multi-view-V1 (ours) (128; trained on D4)": "logs/kubric_v3_duster0123/multiview-v1-with-128-triplane-001/eval_kubric-multiview-v3/step-100000_kubric-multiview-v3_metrics_avg.csv", "Multi-view-V1 (ours) (256; trained on D4)": "logs/kubric_v3_duster0123/multiview-v1-with-256-triplane-001/eval_kubric-multiview-v3/step-100000_kubric-multiview-v3_metrics_avg.csv", "Multi-view-V2 (ours) (finetuned on D4c)": "logs/kubric_v3_duster0123/multiview-v2-pretrained-cleaned-003--lr-2.5e-4--k-16--fmaps-384--groups-1--levels-4--grad-clip-1--iters-4--window-12/eval_kubric-multiview-v3/step-10000_kubric-multiview-v3_metrics_avg.csv", "Multi-view-V2 (ours) (trained on D4)": "logs/kubric_v3_duster0123/multiview-v2-001--k-16--fmaps-384--groups-1--levels-4--grad-clip-1--iters-4--window-12/eval_kubric-multiview-v3/step-99999_kubric-multiview-v3_metrics_avg.csv", }) print(csv_str) def kubric(): print("Kubric multi-point evaluation results:") print("================================") df, csv_str = create_table({ "CopyCat (No motion baseline)": "logs/copycat/eval_kubric-multiview-v3-gt0123/step--1_kubric-multiview-v3-gt0123_metrics_avg.csv", "SpaTracker (pretrained)": "logs/kubric_v3/multiview-adapter-pretrained-004/eval_kubric-multiview-v3-gt0123/step--1_kubric-multiview-v3-gt0123_metrics_avg.csv", "SpaTracker (single-view baseline)": "logs/kubric_v3/multiview-adapter-002/eval_kubric-multiview-v3-gt0123/step-69799_kubric-multiview-v3-gt0123_metrics_avg.csv", "Multi-view-V1 (ours)": "logs/kubric_v3/multiview-v1-with-128-triplane-001/eval_kubric-multiview-v3-gt0123/step-99999_kubric-multiview-v3-gt0123_metrics_avg.csv", "Multi-view-V2 (ours)": 
"logs/kubric_v3/multiview-v2-002--k-16--fmaps-384--groups-1--levels-4--grad-clip-1--iters-4--window-12/eval_kubric-multiview-v3-gt0123/step-91599_kubric-multiview-v3-gt0123_metrics_avg.csv", "Multi-view-V1 (ours) (128; trained on D4)": "logs/kubric_v3_duster0123/multiview-v1-with-128-triplane-001/eval_kubric-multiview-v3-gt0123/step-99999_kubric-multiview-v3-gt0123_metrics_avg.csv", "Multi-view-V1 (ours) (256; trained on D4)": "logs/kubric_v3_duster0123/multiview-v1-with-256-triplane-001/eval_kubric-multiview-v3-gt0123/step-99999_kubric-multiview-v3-gt0123_metrics_avg.csv", "Multi-view-V2 (ours) (finetuned on D4c)": "logs/kubric_v3_duster0123/multiview-v2-pretrained-cleaned-003--lr-2.5e-4--k-16--fmaps-384--groups-1--levels-4--grad-clip-1--iters-4--window-12/eval_kubric-multiview-v3-gt0123/step-9999_kubric-multiview-v3-gt0123_metrics_avg.csv", "Multi-view-V2 (ours) (trained on D4)": "logs/kubric_v3_duster0123/multiview-v2-001--k-16--fmaps-384--groups-1--levels-4--grad-clip-1--iters-4--window-12/eval_kubric-multiview-v3-gt0123/step-99999_kubric-multiview-v3-gt0123_metrics_avg.csv", "Multi-view-V2 (ours) (trained on D4c)": "logs/kubric_v3_duster0123/multiview-v2-cleaned-001--k-16--fmaps-384--groups-1--levels-4--grad-clip-1--iters-4--window-12/eval_kubric-multiview-v3-gt0123/step-99999_metrics_avg.csv", "Multi-view-V3 (ours) (finetuned^2 on D4c;s=4)": "logs/kubric_v3_duster0123/multiview-v3-001--lr-2.5e-4--fmaps-384/eval_kubric-multiview-v3-gt0123/step-9999_kubric-multiview-v3-gt0123_metrics_avg.csv", }) print(csv_str) def kubric_duster(): print("Kubric multi-point evaluation results, Duster0123:") print("================================") df, csv_str = create_table({ # ls logs/kubric_v3/*/eval_kubric-multiview-v3-duster0123/step-*_kubric-multiview-v3-duster0123_metrics_avg.csv | cat "CopyCat (No motion baseline)": "logs/copycat/eval_kubric-multiview-v3-duster0123/step--1_kubric-multiview-v3-duster0123_metrics_avg.csv", "SpaTracker (pretrained)": "logs/kubric_v3/multiview-adapter-pretrained-004/eval_kubric-multiview-v3-duster0123/step--1_kubric-multiview-v3-duster0123_metrics_avg.csv", "SpaTracker (single-view baseline)": "logs/kubric_v3/multiview-adapter-002/eval_kubric-multiview-v3-duster0123/step-69799_kubric-multiview-v3-duster0123_metrics_avg.csv", "Multi-view-V1 (ours)": "logs/kubric_v3/multiview-v1-with-128-triplane-001/eval_kubric-multiview-v3-duster0123/step-99999_kubric-multiview-v3-duster0123_metrics_avg.csv", "Multi-view-V2 (ours)": "logs/kubric_v3/multiview-v2-002--k-16--fmaps-384--groups-1--levels-4--grad-clip-1--iters-4--window-12/eval_kubric-multiview-v3-duster0123/step-91599_kubric-multiview-v3-duster0123_metrics_avg.csv", "SpaTracker (single-view baseline) (trained on D4)": "logs/kubric_v3_duster0123/multiview-adapter-001/eval_kubric-multiview-v3-duster0123/step-90000_kubric-multiview-v3-duster0123_metrics_avg.csv", "Multi-view-V1 (ours) (128; trained on D4)": "logs/kubric_v3_duster0123/multiview-v1-with-128-triplane-001/eval_kubric-multiview-v3-duster0123/step-100000_kubric-multiview-v3-duster0123_metrics_avg.csv", "Multi-view-V1 (ours) (256; trained on D4)": "logs/kubric_v3_duster0123/multiview-v1-with-256-triplane-001/eval_kubric-multiview-v3-duster0123/step-100000_kubric-multiview-v3-duster0123_metrics_avg.csv", "Multi-view-V2 (ours) (finetuned on D4)": 
"logs/kubric_v3_duster0123/multiview-v2-pretrained-cleaned-003--lr-2.5e-4--k-16--fmaps-384--groups-1--levels-4--grad-clip-1--iters-4--window-12/eval_kubric-multiview-v3-duster0123/step-10000_kubric-multiview-v3-duster0123_metrics_avg.csv", "Multi-view-V2 (ours) (trained on D4)": "logs/kubric_v3_duster0123/multiview-v2-001--k-16--fmaps-384--groups-1--levels-4--grad-clip-1--iters-4--window-12/eval_kubric-multiview-v3-duster0123/step-100000_kubric-multiview-v3-duster0123_metrics_avg.csv", # "Multi-view-V2 (ours) (trained on D4c)": "logs/kubric_v3_duster0123/multiview-v2-cleaned-001--k-16--fmaps-384--groups-1--levels-4--grad-clip-1--iters-4--window-12/eval_kubric-multiview-v3-duster0123/step-70000_kubric-multiview-v3-duster0123_metrics_avg.csv", # "Multi-view-V3 (ours) (finetuned^2 on D4c;s=4)": "logs/kubric_v3_duster0123/multiview-v3-001--lr-2.5e-4--fmaps-384/eval_kubric-multiview-v3-duster0123/step-10000_kubric-multiview-v3-duster0123_metrics_avg.csv", }) print(csv_str) df, csv_str = create_table({ # "SpaTracker (single-view baseline) (trained on D4)": "logs/kubric_v3_duster0123/multiview-adapter-001/eval_kubric-multiview-v3-duster0123cleaned/step-90000_kubric-multiview-v3-duster0123cleaned_metrics_avg.csv", # "Multi-view-V1 (ours) (128; trained on D4)": "logs/kubric_v3_duster0123/multiview-v1-with-128-triplane-001/eval_kubric-multiview-v3-duster0123cleaned/step-100000_kubric-multiview-v3-duster0123cleaned_metrics_avg.csv", # "Multi-view-V1 (ours) (256; trained on D4)": "logs/kubric_v3_duster0123/multiview-v1-with-256-triplane-001/eval_kubric-multiview-v3-duster0123cleaned/step-100000_kubric-multiview-v3-duster0123cleaned_metrics_avg.csv", # "Multi-view-V2 (ours) (finetuned on D4)": "logs/kubric_v3_duster0123/multiview-v2-pretrained-cleaned-003--lr-2.5e-4--k-16--fmaps-384--groups-1--levels-4--grad-clip-1--iters-4--window-12/eval_kubric-multiview-v3-duster0123cleaned/step-10000_kubric-multiview-v3-duster0123cleaned_metrics_avg.csv", # "Multi-view-V2 (ours) (trained on D4)": "logs/kubric_v3_duster0123/multiview-v2-001--k-16--fmaps-384--groups-1--levels-4--grad-clip-1--iters-4--window-12/eval_kubric-multiview-v3-duster0123cleaned/step-100000_kubric-multiview-v3-duster0123cleaned_metrics_avg.csv", "Multi-view-V2 (ours) (trained on D4c)": "logs/kubric_v3_duster0123/multiview-v2-cleaned-001--k-16--fmaps-384--groups-1--levels-4--grad-clip-1--iters-4--window-12/eval_kubric-multiview-v3-duster0123cleaned/step-70000_kubric-multiview-v3-duster0123cleaned_metrics_avg.csv", "Multi-view-V3 (ours) (finetuned^2 on D4c;s=4)": "logs/kubric_v3_duster0123/multiview-v3-001--lr-2.5e-4--fmaps-384/eval_kubric-multiview-v3-duster0123cleaned/step-10000_kubric-multiview-v3-duster0123cleaned_metrics_avg.csv", # "Multi-view-V3 (ours) (trained on D4c;s=4)": "TBD", # "Multi-view-V3 (ours) (trained on D4c;s=16)": "TBD", }) def mv3_kubric_duster_transformed(): print("Kubric transformed, Duster0123, Multi-view-V3 (ours) (finetuned^2 on D4c;s=4):") print("================================") df, csv_str = create_table({ "Kubric (no world space transformations)": f"logs/kubric_v3_duster0123/multiview-v3-001--lr-2.5e-4--fmaps-384/eval_kubric-multiview-v3-duster0123cleaned/step-9999_kubric-multiview-v3-duster0123cleaned_metrics_avg.00_no_transform.csv", "Kubric (translated by z-10)": f"logs/kubric_v3_duster0123/multiview-v3-001--lr-2.5e-4--fmaps-384/eval_kubric-multiview-v3-duster0123cleaned/step-9999_kubric-multiview-v3-duster0123cleaned_metrics_avg.01_translate_z-10.csv", "Kubric (translated by x+4, y+4)": 
f"logs/kubric_v3_duster0123/multiview-v3-001--lr-2.5e-4--fmaps-384/eval_kubric-multiview-v3-duster0123cleaned/step-9999_kubric-multiview-v3-duster0123cleaned_metrics_avg.02a_translate_x+4_y+4.csv", "Kubric (translated by x+10, y+10)": f"logs/kubric_v3_duster0123/multiview-v3-001--lr-2.5e-4--fmaps-384/eval_kubric-multiview-v3-duster0123cleaned/step-9999_kubric-multiview-v3-duster0123cleaned_metrics_avg.02b_translate_x+10_y+10.csv", "Kubric (rotated x+90)": f"logs/kubric_v3_duster0123/multiview-v3-001--lr-2.5e-4--fmaps-384/eval_kubric-multiview-v3-duster0123cleaned/step-9999_kubric-multiview-v3-duster0123cleaned_metrics_avg.03_rotate_x+90.csv", "Kubric (rotated y+90)": f"logs/kubric_v3_duster0123/multiview-v3-001--lr-2.5e-4--fmaps-384/eval_kubric-multiview-v3-duster0123cleaned/step-9999_kubric-multiview-v3-duster0123cleaned_metrics_avg.04_rotate_y+90.csv", "Kubric (rotated z+90)": f"logs/kubric_v3_duster0123/multiview-v3-001--lr-2.5e-4--fmaps-384/eval_kubric-multiview-v3-duster0123cleaned/step-9999_kubric-multiview-v3-duster0123cleaned_metrics_avg.05_rotate_z+90.csv", "Kubric (scaled down 2x)": f"logs/kubric_v3_duster0123/multiview-v3-001--lr-2.5e-4--fmaps-384/eval_kubric-multiview-v3-duster0123cleaned/step-9999_kubric-multiview-v3-duster0123cleaned_metrics_avg.06_scale_down_2x.csv", "Kubric (scaled down 8x)": f"logs/kubric_v3_duster0123/multiview-v3-001--lr-2.5e-4--fmaps-384/eval_kubric-multiview-v3-duster0123cleaned/step-9999_kubric-multiview-v3-duster0123cleaned_metrics_avg.07_scale_down_8x.csv", "Kubric (scaled up 2x)": f"logs/kubric_v3_duster0123/multiview-v3-001--lr-2.5e-4--fmaps-384/eval_kubric-multiview-v3-duster0123cleaned/step-9999_kubric-multiview-v3-duster0123cleaned_metrics_avg.08_scale_up_2x.csv", "Kubric (scaled up 8x)": f"logs/kubric_v3_duster0123/multiview-v3-001--lr-2.5e-4--fmaps-384/eval_kubric-multiview-v3-duster0123cleaned/step-9999_kubric-multiview-v3-duster0123cleaned_metrics_avg.09_scale_up_8x.csv", }) "" print(csv_str) def mv3_kubric_nviews(): print("eval_kubric-multiview-v3-views..., Multi-view-V2 (ours) (trained on D4):") print("================================") df, csv_str = create_table({ "1": "logs/kubric_v3_duster0123/multiview-v2-001--k-16--fmaps-384--groups-1--levels-4--grad-clip-1--iters-4--window-12/eval_kubric-multiview-v3-views0/step-99999_metrics_avg.csv", "2": "logs/kubric_v3_duster0123/multiview-v2-001--k-16--fmaps-384--groups-1--levels-4--grad-clip-1--iters-4--window-12/eval_kubric-multiview-v3-views01/step-99999_metrics_avg.csv", "3": "logs/kubric_v3_duster0123/multiview-v2-001--k-16--fmaps-384--groups-1--levels-4--grad-clip-1--iters-4--window-12/eval_kubric-multiview-v3-views012/step-99999_metrics_avg.csv", "4": "logs/kubric_v3_duster0123/multiview-v2-001--k-16--fmaps-384--groups-1--levels-4--grad-clip-1--iters-4--window-12/eval_kubric-multiview-v3-views0123/step-99999_metrics_avg.csv", "5": "logs/kubric_v3_duster0123/multiview-v2-001--k-16--fmaps-384--groups-1--levels-4--grad-clip-1--iters-4--window-12/eval_kubric-multiview-v3-views01234/step-99999_metrics_avg.csv", "6": "logs/kubric_v3_duster0123/multiview-v2-001--k-16--fmaps-384--groups-1--levels-4--grad-clip-1--iters-4--window-12/eval_kubric-multiview-v3-views012345/step-99999_metrics_avg.csv", "7": "logs/kubric_v3_duster0123/multiview-v2-001--k-16--fmaps-384--groups-1--levels-4--grad-clip-1--iters-4--window-12/eval_kubric-multiview-v3-views0123456/step-99999_metrics_avg.csv", "8": 
"logs/kubric_v3_duster0123/multiview-v2-001--k-16--fmaps-384--groups-1--levels-4--grad-clip-1--iters-4--window-12/eval_kubric-multiview-v3-views01234567/step-99999_metrics_avg.csv", "9": "logs/kubric_v3_duster0123/multiview-v2-001--k-16--fmaps-384--groups-1--levels-4--grad-clip-1--iters-4--window-12/eval_kubric-multiview-v3-views012345678/step-99999_metrics_avg.csv", "10": "logs/kubric_v3_duster0123/multiview-v2-001--k-16--fmaps-384--groups-1--levels-4--grad-clip-1--iters-4--window-12/eval_kubric-multiview-v3-views0123456789/step-99999_metrics_avg.csv", }) print(csv_str) def mv3_kubric_duster_nviews(): print("eval_kubric-multiview-v3-duster0123-views..., Multi-view-V2 (ours) (trained on D4):") print("================================") df, csv_str = create_table({ "1": "logs/kubric_v3_duster0123/multiview-v2-001--k-16--fmaps-384--groups-1--levels-4--grad-clip-1--iters-4--window-12/eval_kubric-multiview-v3-duster0123-views0/step-99999_metrics_avg.csv", "2": "logs/kubric_v3_duster0123/multiview-v2-001--k-16--fmaps-384--groups-1--levels-4--grad-clip-1--iters-4--window-12/eval_kubric-multiview-v3-duster0123-views01/step-99999_metrics_avg.csv", "3": "logs/kubric_v3_duster0123/multiview-v2-001--k-16--fmaps-384--groups-1--levels-4--grad-clip-1--iters-4--window-12/eval_kubric-multiview-v3-duster0123-views012/step-99999_metrics_avg.csv", "4": "logs/kubric_v3_duster0123/multiview-v2-001--k-16--fmaps-384--groups-1--levels-4--grad-clip-1--iters-4--window-12/eval_kubric-multiview-v3-duster0123-views0123/step-99999_metrics_avg.csv", "5": "logs/kubric_v3_duster0123/multiview-v2-001--k-16--fmaps-384--groups-1--levels-4--grad-clip-1--iters-4--window-12/eval_kubric-multiview-v3-duster01234567-views01234/step-99999_metrics_avg.csv", "6": "logs/kubric_v3_duster0123/multiview-v2-001--k-16--fmaps-384--groups-1--levels-4--grad-clip-1--iters-4--window-12/eval_kubric-multiview-v3-duster01234567-views012345/step-99999_metrics_avg.csv", "7": "logs/kubric_v3_duster0123/multiview-v2-001--k-16--fmaps-384--groups-1--levels-4--grad-clip-1--iters-4--window-12/eval_kubric-multiview-v3-duster01234567-views0123456/step-99999_metrics_avg.csv", "8": "logs/kubric_v3_duster0123/multiview-v2-001--k-16--fmaps-384--groups-1--levels-4--grad-clip-1--iters-4--window-12/eval_kubric-multiview-v3-duster01234567-views01234567/step-99999_metrics_avg.csv", }) print(csv_str) def kubric_nviews(): print("=" * 80) print("=" * 80) print("=" * 80) method_name_to_csv_path_template = { "CopyCat (No motion baseline),{}": "logs/copycat/eval_{}/step--1_metrics_avg.csv", "SpaTracker (pretrained),{}": "logs/kubric_v3/multiview-adapter-pretrained-004/eval_{}/step--1_metrics_avg.csv", "SpaTracker (single-view baseline),{}": "logs/kubric_v3/multiview-adapter-002/eval_{}/step-69799_metrics_avg.csv", "Multi-view-V1 (ours),{}": "logs/kubric_v3/multiview-v1-with-128-triplane-001/eval_{}/step-99999_metrics_avg.csv", # "Multi-view-V2 (ours),{}": "logs/kubric_v3/multiview-v2-002--k-16--fmaps-384--groups-1--levels-4--grad-clip-1--iters-4--window-12/eval_{}/step-91599_metrics_avg.csv", "SpaTracker (single-view baseline) (trained on D4),{}": "logs/kubric_v3_duster0123/multiview-adapter-001/eval_{}/step-logs/kubric_v3_duster0123/multiview-adapter-001_metrics_avg.csv", # "Multi-view-V1 (ours) (128; trained on D4),{}": "logs/kubric_v3_duster0123/multiview-v1-with-128-triplane-001/eval_{}/step-99999_metrics_avg.csv", "Multi-view-V1 (ours) (256; trained on D4),{}": "logs/kubric_v3_duster0123/multiview-v1-with-256-triplane-001/eval_{}/step-99999_metrics_avg.csv", # 
"Multi-view-V2 (ours) (finetuned on D4c),{}": "logs/kubric_v3_duster0123/multiview-v2-pretrained-cleaned-003--lr-2.5e-4--k-16--fmaps-384--groups-1--levels-4--grad-clip-1--iters-4--window-12/eval_{}/step-9999_metrics_avg.csv", "Multi-view-V2 (ours) (trained on D4),{}": "logs/kubric_v3_duster0123/multiview-v2-001--k-16--fmaps-384--groups-1--levels-4--grad-clip-1--iters-4--window-12/eval_{}/step-99999_metrics_avg.csv", # "Multi-view-V2 (ours) (trained on D4c),{}": "logs/kubric_v3_duster0123/multiview-v2-cleaned-001--k-16--fmaps-384--groups-1--levels-4--grad-clip-1--iters-4--window-12/eval_{}/step-99999_metrics_avg.csv", # "Multi-view-V3 (ours) (finetuned^2 on D4c;s=4),{}": "logs/kubric_v3_duster0123/multiview-v3-001--lr-2.5e-4--fmaps-384/eval_{}/step-9999_metrics_avg.csv", } method_name_to_csv_path_per_dataset = {} for dataset_prefix in [ "kubric-multiview-v3-views", "kubric-multiview-v3-duster0123-views", "kubric-multiview-v3-duster01234567-views", "kubric-multiview-v3-duster0123cleaned-views", "kubric-multiview-v3-duster01234567cleaned-views", ]: method_name_to_csv_path_per_dataset[dataset_prefix] = {} for method_name_template, csv_path_template in method_name_to_csv_path_template.items(): for n in range(8): if ("-duster0123-" in dataset_prefix or "-duster0123cleaned-" in dataset_prefix) and n > 4: continue if ("-duster01234567-" in dataset_prefix or "-duster01234567cleaned-" in dataset_prefix) and n < 5: continue dataset = dataset_prefix + "".join(str(i) for i in range(n + 1)) method_name = method_name_template.format(n + 1) csv_path = csv_path_template.format(dataset) assert method_name not in method_name_to_csv_path_per_dataset[ dataset_prefix], f"Duplicate method name: {method_name}" method_name_to_csv_path_per_dataset[dataset_prefix][method_name] = csv_path for dataset_prefix, method_name_to_csv_path in method_name_to_csv_path_per_dataset.items(): print(method_name_to_csv_path) print(f"Kubric multi-point evaluation results, {dataset_prefix}:") print("================================") df, csv_str = create_table(method_name_to_csv_path) print(csv_str) MODELS = { "copycat": { "name": "CopyCat (No motion baseline)", "csv": "logs/copycat/eval_{dataset}/step--1_metrics_avg.csv", }, "cotracker3": { "name": "CoTracker3 Offline (x)", "csv": "logs/cotracker3/eval_{dataset}/step--1_metrics_avg.csv", }, "cotracker3offline": { "name": "CoTracker3 Offline", "csv": "logs/cotracker3-offline/eval_{dataset}/step--1_metrics_avg.csv", }, "cotracker3online": { "name": "CoTracker3 Online", "csv": "logs/cotracker3-online/eval_{dataset}/step--1_metrics_avg.csv", }, "cotracker2offline": { "name": "CoTracker2 Offline", "csv": "logs/cotracker2-offline/eval_{dataset}/step--1_metrics_avg.csv", }, "cotracker2online": { "name": "CoTracker2 Online", "csv": "logs/cotracker2-online/eval_{dataset}/step--1_metrics_avg.csv", }, "cotracker1offline": { "name": "CoTracker1 Offline", "csv": "logs/cotracker1-offline/eval_{dataset}/step--1_metrics_avg.csv", }, "cotracker1online": { "name": "CoTracker1 Online", "csv": "logs/cotracker1-online/eval_{dataset}/step--1_metrics_avg.csv", }, "delta": { "name": "DELTA", "csv": "logs/delta/eval_{dataset}/step--1_metrics_avg.csv", }, "locotrack": { "name": "LocoTrack", "csv": "logs/locotrack/eval_{dataset}/step--1_metrics_avg.csv", }, "scenetracker": { "name": "SceneTracker", "csv": "logs/scenetracker/eval_{dataset}/step--1_metrics_avg.csv", }, "spatracker-pretrained": { "name": "SpaTracker (pretrained)", "csv": 
"logs/kubric_v3_duster0123/multiview-adapter-pretrained-001/eval_{dataset}/step--1_metrics_avg.csv", }, "spatracker": { "name": "SpaTracker (single-view baseline)", "csv": "logs/kubric_v3/multiview-adapter-002/eval_{dataset}/step-69799_metrics_avg.csv", }, "mv1": { "name": "Multi-view-V1 (ours)", "csv": "logs/kubric_v3/multiview-v1-with-128-triplane-001/eval_{dataset}/step-99999_metrics_avg.csv", }, "mv2": { "name": "Multi-view-V2 (ours)", "csv": "logs/kubric_v3/multiview-v2-002--k-16--fmaps-384--groups-1--levels-4--grad-clip-1--iters-4--window-12/eval_{dataset}/step-91599_metrics_avg.csv", }, "spatracker-d4": { "name": "SpaTracker (single-view baseline) (trained on D4)", "csv": "logs/kubric_v3_duster0123/multiview-adapter-001/eval_{dataset}/step-90799_metrics_avg.csv", }, "mv1-d4": { "name": "Multi-view-V1 (ours) (256; trained on D4)", "csv": "logs/kubric_v3_duster0123/multiview-v1-with-256-triplane-001/eval_{dataset}/step-99999_metrics_avg.csv", }, "mv2-d4": { "name": "Multi-view-V2 (ours) (trained on D4)", "csv": "logs/kubric_v3_duster0123/multiview-v2-001--k-16--fmaps-384--groups-1--levels-4--grad-clip-1--iters-4--window-12/eval_{dataset}/step-99999_metrics_avg.csv", }, "mv3-d4c": { "name": "Multi-view-V3 (ours) (finetuned^2 on D4c;s=4)", "csv": "logs/kubric_v3_duster0123/multiview-v3-001--lr-2.5e-4--fmaps-384/eval_{dataset}/step-9999_metrics_avg.csv", }, "mv4-a07": { "name": "Multi-view-V4 (ours) (A07)", "csv": "logs/kubric_v3_augs/multiview-v4-A07.augs_4.002/eval_{dataset}/step-25599_metrics_avg.csv", }, "mv4-b01": { "name": "Multi-view-V4 (ours) (B01)", "csv": "logs/kubric_v3_augs/multiview-v4-B01.vary_n_views.004/eval_{dataset}/step-199999_metrics_avg.csv", }, "mv4-b02": { "name": "Multi-view-V4 (ours) (B02)", "csv": "logs/kubric_v3_augs/multiview-v4-B02.vary_depth_type.002a/eval_{dataset}/step-199999_metrics_avg.csv", }, "mv4-b03": { "name": "Multi-view-V4 (ours) (B03)", "csv": "logs/kubric_v3_augs/multiview-v4-B03.vary_both.004/eval_{dataset}/step-?_metrics_avg.csv", }, "mv4-b03-paper": { "name": "Multi-view-V4 (ours) (B03 paper ckpt)", "csv": "logs/kubric_v3_augs/multiview-v4-B03.vary_both.004/eval_{dataset}/step-153999_metrics_avg.csv", }, # # "C01.001.0" : { # # "name": "Ablation (C01) – Offset 1 AddXYZ 0 K 16 P 4", # # "csv": "logs/kubric_v3_augs/ablate-correlation.001.0_K-16_FMAP-128_PYR-4_KNN-default/eval_{dataset}/step-?_metrics_avg.csv", # # }, # "C01.001.1" : { # "name": "Ablation (C01) – Offset 0 AddXYZ 0", # "csv": "logs/kubric_v3_augs/ablate-correlation.001.1_K-16_FMAP-128_PYR-4_KNN-remove_offset/eval_{dataset}/step-?_metrics_avg.csv", # }, # "C01.001.2" : { # "name": "Ablation (C01) – Offset 1 AddXYZ 1", # "csv": "logs/kubric_v3_augs/ablate-correlation.001.2_K-16_FMAP-128_PYR-4_KNN-add_neighbor_xyz/eval_{dataset}/step-?_metrics_avg.csv", # }, # # "C01.001.3" : { # # "name": "Ablation (C01) – Offset 0 AddXYZ 1", # # "csv": "logs/kubric_v3_augs/ablate-correlation.001.3_K-16_FMAP-128_PYR-4_KNN-remove_offset_and_add_neighbor_xyz/eval_{dataset}/step-?_metrics_avg.csv", # # }, # "C01.001.4" : { # "name": "Ablation (C01) – K 1", # "csv": "logs/kubric_v3_augs/ablate-correlation.001.4_K-1_FMAP-128_PYR-4_KNN-default/eval_{dataset}/step-?_metrics_avg.csv", # }, # "C01.001.5" : { # "name": "Ablation (C01) – K 4", # "csv": "logs/kubric_v3_augs/ablate-correlation.001.5_K-4_FMAP-128_PYR-4_KNN-default/eval_{dataset}/step-?_metrics_avg.csv", # }, # "C01.001.6" : { # "name": "Ablation (C01) – K 8", # "csv": 
"logs/kubric_v3_augs/ablate-correlation.001.6_K-8_FMAP-128_PYR-4_KNN-default/eval_{dataset}/step-?_metrics_avg.csv", # }, # # "C01.001.7" : { # # "name": "Ablation (C01) – K 32", # # "csv": "logs/kubric_v3_augs/ablate-correlation.001.7_K-32_FMAP-128_PYR-4_KNN-default/eval_{dataset}/step-?_metrics_avg.csv", # # }, # # "C01.001.8" : { # # "name": "Ablation (C01) – K 64", # # "csv": "logs/kubric_v3_augs/ablate-correlation.001.8_K-64_FMAP-128_PYR-4_KNN-default/eval_{dataset}/step-?_metrics_avg.csv", # # }, # "C01.001.9" : { # "name": "Ablation (C01) – P 1", # "csv": "logs/kubric_v3_augs/ablate-correlation.001.9_K-16_FMAP-128_PYR-1_KNN-default/eval_{dataset}/step-?_metrics_avg.csv", # }, # "C01.001.10" : { # "name": "Ablation (C01) – P 2", # "csv": "logs/kubric_v3_augs/ablate-correlation.001.10_K-16_FMAP-128_PYR-2_KNN-default/eval_{dataset}/step-?_metrics_avg.csv", # }, # # "C01.001.11" : { # # "name": "Ablation (C01) – P 6", # # "csv": "logs/kubric_v3_augs/ablate-correlation.001.11_K-16_FMAP-128_PYR-6_KNN-default/eval_{dataset}/step-?_metrics_avg.csv", # # }, "C02.001.0": { "name": "Ablation (C02) – Offset 1 AddXYZ 0 K 16 P 4", "csv": "logs/kubric_v3_augs/ablate-correlation.dusterdepths.C02.001.0_K-16_FMAP-128_PYR-4_KNN-default/eval_{dataset}/step-?_metrics_avg.csv", }, "C02.001.1": { "name": "Ablation (C02) – Offset 0 AddXYZ 0", "csv": "logs/kubric_v3_augs/ablate-correlation.dusterdepths.C02.001.1_K-16_FMAP-128_PYR-4_KNN-remove_offset/eval_{dataset}/step-?_metrics_avg.csv", }, "C02.001.2": { "name": "Ablation (C02) – Offset 1 AddXYZ 1", "csv": "logs/kubric_v3_augs/ablate-correlation.dusterdepths.C02.001.2_K-16_FMAP-128_PYR-4_KNN-add_neighbor_xyz/eval_{dataset}/step-?_metrics_avg.csv", }, # "C02.001.3" : { # "name": "Ablation (C02) – Offset 0 AddXYZ 1", # "csv": "logs/kubric_v3_augs/ablate-correlation.dusterdepths.C02.001.3_K-16_FMAP-128_PYR-4_KNN-remove_offset_and_add_neighbor_xyz/eval_{dataset}/step-?_metrics_avg.csv", # }, "C02.001.4": { "name": "Ablation (C02) – K 1", "csv": "logs/kubric_v3_augs/ablate-correlation.dusterdepths.C02.001.4_K-1_FMAP-128_PYR-4_KNN-default/eval_{dataset}/step-?_metrics_avg.csv", }, "C02.001.5": { "name": "Ablation (C02) – K 4", "csv": "logs/kubric_v3_augs/ablate-correlation.dusterdepths.C02.001.5_K-4_FMAP-128_PYR-4_KNN-default/eval_{dataset}/step-?_metrics_avg.csv", }, "C02.001.6": { "name": "Ablation (C02) – K 8", "csv": "logs/kubric_v3_augs/ablate-correlation.dusterdepths.C02.001.6_K-8_FMAP-128_PYR-4_KNN-default/eval_{dataset}/step-?_metrics_avg.csv", }, # "C02.002.7" : { # "name": "Ablation (C02) – K 32", # "csv": "logs/kubric_v3_augs/ablate-correlation.dusterdepths.C02.002.7_K-32_FMAP-128_PYR-4_KNN-default/eval_{dataset}/step-?_metrics_avg.csv", # }, "C02.001.8": { "name": "Ablation (C02) – K 64", "csv": "logs/kubric_v3_augs/ablate-correlation.dusterdepths.C02.001.8_K-64_FMAP-128_PYR-4_KNN-default/eval_{dataset}/step-?_metrics_avg.csv", }, "C02.001.9": { "name": "Ablation (C02) – P 1", "csv": "logs/kubric_v3_augs/ablate-correlation.dusterdepths.C02.001.9_K-16_FMAP-128_PYR-1_KNN-default/eval_{dataset}/step-?_metrics_avg.csv", }, "C02.001.10": { "name": "Ablation (C02) – P 2", "csv": "logs/kubric_v3_augs/ablate-correlation.dusterdepths.C02.001.10_K-16_FMAP-128_PYR-2_KNN-default/eval_{dataset}/step-?_metrics_avg.csv", }, "C02.001.11": { "name": "Ablation (C02) – P 6", "csv": "logs/kubric_v3_augs/ablate-correlation.dusterdepths.C02.001.11_K-16_FMAP-128_PYR-6_KNN-default/eval_{dataset}/step-?_metrics_avg.csv", }, "shape-of-motion": { "name": "Shape of Motion 
    "shape-of-motion": {
        "name": "Shape of Motion (MV)",
        "csv": "logs/shape_of_motion/eval_{dataset}/step--1_metrics_avg.csv",
    },
    # June 2025
    "mvtracker-march": {
        "name": "MV-Tracker (ours; March 2025)",
        "csv": "logs/eval/mvtracker-iccv-march2025/eval_{dataset}/step--1_metrics_avg.csv",
    },
    "mvtracker-june": {
        "name": "MV-Tracker (ours; June 2025)",
        "csv": "logs/eval/mvtracker-june2025/eval_{dataset}/step--1_metrics_avg.csv",
    },
}


def tavid2d_davis():
    print("TAPVid-2D DAVIS:")
    print("================")
    models_to_report = [
        "copycat", "locotrack", "scenetracker", "delta",
        "cotracker1online", "cotracker2online", "cotracker3online",
        "cotracker1offline", "cotracker2offline", "cotracker3offline",
        "spatracker-pretrained", "spatracker", "spatracker-d4",
        "mv1-d4", "mv2-d4",
        "mv4-b01", "mv4-b02", "mv4-b03",
    ]
    assert all(m in MODELS for m in models_to_report)
    for resolution in [
        "-256x256",
        # "",
    ]:
        for depth_estimator in [
            # "zoedepth",
            # "moge",
            "mogewithextrinsics",
        ]:
            df, csv_str = create_table({
                MODELS[m]["name"]: MODELS[m]["csv"].format(dataset=f"tapvid2d-davis-{depth_estimator}{resolution}")
                for m in models_to_report
            }, remap=REMAP_TAPVID2D, remap_index_names=REMAP_TAPVID2D_INDEX_NAMES)
            print(f"Resolution: {resolution}, Depth estimator: {depth_estimator}")
            print(csv_str)
            print()


def dexycb():
    print("DexYCB evaluation results:")
    print("==========================")
    for models_to_report, depths in [
        (["copycat", "locotrack", "scenetracker", "delta",
          "cotracker1online", "cotracker2online", "cotracker3online",
          "cotracker1offline", "cotracker2offline", "cotracker3offline",
          "spatracker-pretrained", "spatracker", "mv1", "mv2",
          "spatracker-d4", "mv1-d4", "mv2-d4", "mv3-d4c",
          "mv4-b01", "mv4-b02", "mv4-b03"], ""),
        (["copycat", "locotrack", "scenetracker", "delta",
          "cotracker1online", "cotracker2online", "cotracker3online",
          "cotracker1offline", "cotracker2offline", "cotracker3offline",
          "spatracker-pretrained", "spatracker", "mv1", "mv2",
          "spatracker-d4", "mv1-d4", "mv2-d4",
          "mv4-b01", "mv4-b02", "mv4-b03",
          "shape-of-motion", "mv4-b03-paper"], "-duster0123"),
        (["locotrack", "scenetracker", "delta",
          "cotracker1online", "cotracker2online", "cotracker3online",
          "cotracker1offline", "cotracker2offline", "cotracker3offline",
          "mv3-d4c", "mv4-b01", "mv4-b02", "mv4-b03"], "-duster0123cleaned"),
    ]:
        assert all(m in MODELS for m in models_to_report)
        # for remove_hand in ["", "-removehand"]:
        for remove_hand in [""]:
            df, csv_str = create_table({
                MODELS[m]["name"]: MODELS[m]["csv"].format(dataset=f"dex-ycb-multiview{depths}{remove_hand}")
                for m in models_to_report
            }, remap=REMAP_DEXYCB_V2)
            print(f"Depths: {depths} Remove hand: {remove_hand}")
            print(csv_str)
            print()
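

# Illustrative sketch (not used elsewhere in this script): every table above is
# assembled the same way -- look a model key up in MODELS, substitute the dataset
# name into its "csv" template, and hand the resulting {display name: csv path}
# mapping to create_table(). The hypothetical helper below only makes that
# resolution step explicit and reports whether the metrics CSV exists on disk.
def _resolve_metrics_csv(model_key, dataset):
    """Return (display name, csv path, whether the csv exists) for a model/dataset pair."""
    import os  # local import to keep this illustrative helper self-contained
    entry = MODELS[model_key]
    csv_path = entry["csv"].format(dataset=dataset)
    return entry["name"], csv_path, os.path.exists(csv_path)
# e.g. _resolve_metrics_csv("mv2-d4", "kubric-multiview-v3-views0123")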
"cotracker1online", "cotracker2online", "cotracker3online", "cotracker1offline", "cotracker2offline", "cotracker3offline", # "spatracker-pretrained", # "spatracker", "mv1", "mv2", # "spatracker-d4", "mv1-d4", "mv2-d4", "mv4-b01", "mv4-b02", "mv4-b03"], "-duster0123cleaned"), (["spatracker-d4"], "-duster0123cleaned-views0123"), ]: assert all(m in MODELS for m in models_to_report) df, csv_str = create_table({ MODELS[m]["name"]: MODELS[m]["csv"].format(dataset=f"kubric-multiview-v3{depths}") for m in models_to_report }, remap=REMAP_KUBRIC) print(f"Depths: {depths}") print(csv_str) print() def panoptic(): print("Panoptic Studio evaluation results:") print("===================================") models_to_report = [ "copycat", "locotrack", "scenetracker", "delta", "cotracker1online", "cotracker2online", "cotracker3online", "cotracker1offline", "cotracker2offline", "cotracker3offline", "spatracker-pretrained", "spatracker", "mv1", "mv2", "spatracker-d4", "mv1-d4", "mv2-d4", "mv4-b01", "mv4-b02", "mv4-b03", "shape-of-motion" ] assert all(m in MODELS for m in models_to_report) for views in ["-views1_7_14_20", "-views27_16_14_8", "-views1_4_7_11"]: df, csv_str = create_table({ MODELS[m]["name"]: MODELS[m]["csv"].format(dataset=f"panoptic-multiview{views}") for m in models_to_report }, remap=REMAP_PANOPTIC) print(f"*** Views: {views} ***") print(csv_str) print() def kubric_single(): print("Kubric single-point evaluation results:") print("==========================") for models_to_report, depths in [ (["copycat", "cotracker3", "spatracker-pretrained", "spatracker", "mv1", "mv2", "spatracker-d4", "mv1-d4", "mv2-d4"], "-views0123"), (["copycat", "cotracker3", # "spatracker-pretrained", # "spatracker", "mv1", "mv2", # "spatracker-d4", "mv1-d4", "mv2-d4", ], "-duster0123"), (["cotracker3", "spatracker-pretrained", "spatracker", "mv1", "mv2", "spatracker-d4", "mv1-d4", "mv2-d4", ], "-duster0123cleaned"), ]: assert all(m in MODELS for m in models_to_report) df, csv_str = create_table({ MODELS[m]["name"]: MODELS[m]["csv"].format(dataset=f"kubric-multiview-v3{depths}") for m in models_to_report }, remap=REMAP_KUBRIC) print(f"Depths: {depths}") print(csv_str) print() def dexycb_single(): print("DexYCB single-point evaluation results:") print("==========================") for models_to_report, depths in [ (["copycat", "cotracker3", "spatracker-pretrained", "spatracker", "mv1", "mv2", "spatracker-d4", "mv1-d4", "mv2-d4"], ""), (["copycat", "cotracker3", "spatracker-pretrained", "spatracker", "mv1", "mv2", "spatracker-d4", "mv1-d4", "mv2-d4", ], "-duster0123"), (["copycat", "cotracker3", # "spatracker-pretrained", # "spatracker", "mv1", "mv2", # "spatracker-d4", "mv1-d4", "mv2-d4", ], "-duster0123cleaned"), ]: assert all(m in MODELS for m in models_to_report) df, csv_str = create_table({ MODELS[m]["name"]: MODELS[m]["csv"].format(dataset=f"dex-ycb-multiview{depths}-single") for m in models_to_report }, remap=REMAP_DEXYCB_V2) print(f"Depths: {depths}") print(csv_str) print() def panoptic_single(): print("Panoptic Studio single-point evaluation results:") print("================================================") models_to_report = [ "copycat", "cotracker3", "spatracker-pretrained", "spatracker", "mv1", "mv2", "spatracker-d4", "mv1-d4", "mv2-d4", ] assert all(m in MODELS for m in models_to_report) for views in [ # "-views27_16_14_8", # "-views1_4_7_11", "-views1_7_14_20", ]: df, csv_str = create_table({ MODELS[m]["name"]: MODELS[m]["csv"].format(dataset=f"panoptic-multiview{views}-single") for m in models_to_report 


def panoptic_single():
    print("Panoptic Studio single-point evaluation results:")
    print("================================================")
    models_to_report = [
        "copycat", "cotracker3", "spatracker-pretrained", "spatracker",
        "mv1", "mv2", "spatracker-d4", "mv1-d4", "mv2-d4",
    ]
    assert all(m in MODELS for m in models_to_report)
    for views in [
        # "-views27_16_14_8",
        # "-views1_4_7_11",
        "-views1_7_14_20",
    ]:
        df, csv_str = create_table({
            MODELS[m]["name"]: MODELS[m]["csv"].format(dataset=f"panoptic-multiview{views}-single")
            for m in models_to_report
        }, remap=REMAP_PANOPTIC)
        print(f"*** Views: {views} ***")
        print(csv_str)
        print()


MODEL_KEYS_ABLATION = [
    "copycat", "locotrack", "scenetracker", "delta",
    "cotracker1online", "cotracker2online", "cotracker3online",
    "cotracker1offline", "cotracker2offline", "cotracker3offline",
    "spatracker-pretrained", "spatracker", "mv1", "mv2",
    "spatracker-d4", "mv1-d4", "mv2-d4",
    "mv4-b01", "mv4-b02", "mv4-b03",
]


def ablation_2dpt():
    datasets = [
        "kubric-multiview-v3-views0123-2dpt",
        "kubric-multiview-v3-duster0123-2dpt",
        "dex-ycb-multiview-2dpt",
        "dex-ycb-multiview-duster0123-2dpt",
        "panoptic-multiview-views1_7_14_20-2dpt",
        "panoptic-multiview-views27_16_14_8-2dpt",
        "panoptic-multiview-views1_4_7_11-2dpt",
    ]
    models_to_report = MODEL_KEYS_ABLATION
    assert all(m in MODELS for m in models_to_report)
    for dataset in datasets:
        df, csv_str = create_table({
            MODELS[m]["name"]: MODELS[m]["csv"].format(dataset=dataset)
            for m in models_to_report
        }, remap=REMAP_KUBRIC | PARTIAL_REMAP_FOR_2DPT_ABLATION, header=dataset == datasets[0])
        print(f"DATASET: {dataset}")
        print(csv_str)
        print()


def one_to_rule_them_all(models, datasets, separate_datasets=True, **create_table_kwargs):
    assert all(m in MODELS for m in models)
    if not separate_datasets:
        df, csv_str = create_table({
            (MODELS[m]["name"], dataset): MODELS[m]["csv"].format(dataset=dataset)
            for m in models
            for dataset in datasets
        }, remap=ONE_REMAP_TO_RULE_THEM_ALL, header=True, **create_table_kwargs)
        print(csv_str)
        print()
    else:
        for dataset in datasets:
            df, csv_str = create_table({
                MODELS[m]["name"]: MODELS[m]["csv"].format(dataset=dataset)
                for m in models
            }, remap=ONE_REMAP_TO_RULE_THEM_ALL, header=dataset == datasets[0], **create_table_kwargs)
            print(f"DATASET: {dataset}")
            print(csv_str)
            print()


def ablation_model_params():
    datasets = [
        "kubric-multiview-v3-views0123",
        "kubric-multiview-v3-duster0123",
        "dex-ycb-multiview",
        "dex-ycb-multiview-duster0123",
        "panoptic-multiview-views1_7_14_20",
        "panoptic-multiview-views27_16_14_8",
        "panoptic-multiview-views1_4_7_11",
    ]
    models = [m for m in MODELS if m.startswith("C01") or m.startswith("C02")]
    one_to_rule_them_all(models, datasets)


def ablation_camera_setups():
    datasets = [
        "panoptic-multiview-views1_7_14_20",
        "panoptic-multiview-views27_16_14_8",
        "panoptic-multiview-views1_4_7_11",
        "dex-ycb-multiview-duster0123",
        "dex-ycb-multiview-duster2345",
        "dex-ycb-multiview-duster4567",
    ]
    one_to_rule_them_all(MODEL_KEYS_ABLATION, datasets)
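

# Usage sketch (hypothetical call, not executed anywhere in this script):
# one_to_rule_them_all() can also emit a single combined table over several
# datasets by turning off the per-dataset splitting, with any extra keyword
# arguments forwarded to create_table() (ablation_num_views() below forwards
# skip_missing=True this way), e.g.:
#
#     one_to_rule_them_all(
#         ["mv2-d4", "mv4-b03"],
#         ["kubric-multiview-v3-views0123", "dex-ycb-multiview-duster0123"],
#         separate_datasets=False,
#         skip_missing=True,
#     )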
"dex-ycb-multiview-duster0123-views012", "dex-ycb-multiview-duster0123-views0123", "dex-ycb-multiview-duster01234567-views01234", "dex-ycb-multiview-duster01234567-views012345", "dex-ycb-multiview-duster01234567-views0123456", "dex-ycb-multiview-duster01234567-views01234567", ] one_to_rule_them_all(MODEL_KEYS_ABLATION, datasets, separate_datasets=separate_datasets, skip_missing=True) if __name__ == '__main__': # kubric_single_point() # kubric_before_gt0123() # kubric() # kubric_duster() # mv3_kubric_duster_transformed() # mv3_kubric_nviews() # mv3_kubric_duster_nviews() # kubric_nviews() # tavid2d_davis() # dexycb() # kubric_refactored() # panoptic() # kubric_single() # dexycb_single() # panoptic_single() # ablation_model_params() # ablation_2dpt() # ablation_camera_setups() # ablation_num_views(separate_datasets=False) # ablation_num_views(separate_datasets=True) ######################################### # print("Dirty results:") # print("==========================") # df, csv_str = create_table({ # "CoTracker3 Online": "logs/eval/cotracker3_online/eval_tapvid2d-davis-megasam-256x256/step--1_metrics_avg.csv", # "MV-Tracker + MoGe": "logs/mvtracker-may/eval_tapvid2d-davis-moge-256x256/step--1_metrics_avg.csv", # "MV-Tracker + MoGe-with-extrinsics": "logs/mvtracker-may/eval_tapvid2d-davis-mogewithextrinsics-256x256/step--1_metrics_avg.csv", # "MV-Tracker + ZoeDepth": "logs/mvtracker-may/eval_tapvid2d-davis-zoedepth-256x256/step--1_metrics_avg.csv", # "MV-Tracker + MegaSAM": "logs/mvtracker-may/eval_tapvid2d-davis-megasam-256x256/step--1_metrics_avg.csv", # }, remap=REMAP_TAPVID2D, remap_index_names=REMAP_TAPVID2D_INDEX_NAMES) # print(csv_str) # # print("Depth + Gaussian noise") # print("==========================") # df, csv_str = create_table({ # f"{model};{noise}": f"{model}/eval_kubric-multiview-v3-noise{noise}/step--1_metrics_avg.csv" # for model in [ # "logs/eval/delta", # "logs/eval/spatracker_monocular_pretrained", # "logs/eval/spatracker_monocular_kubric-training", # "logs/eval/spatracker_monocular_duster-training", # "logs/eval/spatracker_multiview_kubric-training", # # "logs/eval/spatracker_multiview_duster-training", # # "logs/mvtracker-noise2", # "logs/eval/spatracker_multiview_duster-training-noise3", # "logs/mvtracker-noise3", # ] # for noise in ["0cm", "1cm", "2cm", "5cm", "10cm", "20cm", "50cm", "100cm", "200cm", "1000cm"] # }, remap=ONE_REMAP_TO_RULE_THEM_ALL, remap_index_names=REMAP_TAPVID2D_INDEX_NAMES) # print(csv_str) ######################################### print("Final full-scale model re-training (June 2025)") print("==========================") datasets = [ "kubric-multiview-v3-views0123", "kubric-multiview-v3-duster0123", "dex-ycb-multiview", "dex-ycb-multiview-duster0123", "panoptic-multiview-views1_7_14_20", "panoptic-multiview-views27_16_14_8", "panoptic-multiview-views1_4_7_11", "tapvid2d-davis-mogewithextrinsics-256x256", "tapvid2d-davis-megasam-256x256", ] models = ["mvtracker-march", "mvtracker-june"] one_to_rule_them_all(models, datasets)